import { v4 as uuidv4 } from "uuid";
import slugify from "slugify";
import type { PointwiseInteriorConstraintType, PointwiseInteriorConstraintCreator } from "./types";
import { checkRequiredFields } from "modulus-interop/utils";

/**
 * Baseline settings for a pointwise interior constraint.
 *
 * `PointwiseInteriorConstraint()` spreads these first and then the
 * caller-supplied settings on top. Any key the caller provides therefore
 * overrides the value here.
 */
export const defaultPointwiseInteriorConstraintSettings = {
  batch_size: 1024,
  batch_per_epoch: 1000,
  fixed_dataset: true,
  compute_sdf_derivatives: false,
  quasirandom: false,
  shuffle: true,
  used_for_training: true,
};

/** Renders a JavaScript truthiness check as a Python boolean literal. */
function toPythonBool(value: unknown): string {
  return value ? "True" : "False";
}

/**
 * Creates a builder for a Modulus `PointwiseInteriorConstraint`. The builder
 * can emit the corresponding Python snippet.
 *
 * The caller's `settings` are layered over
 * `defaultPointwiseInteriorConstraintSettings`. `set()` later mutates those
 * merged settings in place. `generateCode()` always reads the current
 * (post-`set`) values.
 *
 * @param settings - Initial constraint settings. `label`, `geometry` and
 *   `outvar` are checked by `validate()` before any code is generated.
 * @returns A creator object with a freshly generated UUID `id`.
 */
export function PointwiseInteriorConstraint(
  settings: PointwiseInteriorConstraintType
): PointwiseInteriorConstraintCreator {
  return {
    id: uuidv4(),
    mode: Object.freeze("PointwiseInteriorConstraint"),
    slug: Object.freeze("pointwise_interior_constraint"),
    settings: {
      ...defaultPointwiseInteriorConstraintSettings,
      ...settings,
    },
    /** Shallow-merges `settings` into the current settings and returns the (mutated) settings object. */
    set(settings: Partial<PointwiseInteriorConstraintType>) {
      Object.assign(this.settings, settings);
      return this.settings;
    },
    /**
     * Checks that `label`, `geometry` and `outvar` are present on the current
     * settings. Failure handling is defined by `checkRequiredFields`, which
     * presumably throws on a missing field (TODO confirm).
     */
    validate() {
      checkRequiredFields(["label", "geometry", "outvar"], this.settings);
    },
    /**
     * Validates the current settings and returns the Python source that
     * constructs the constraint and registers it on `domain`.
     *
     * The emitted code refers to `nodes`, `domain`, `modulus` and
     * `<geometry_label>_mesh`, which must already exist in the generated script.
     */
    generateCode() {
      this.validate();
      const {
        label: rawLabel,
        batch_size,
        outvar,
        criteria,
        lambda_weighting,
        quasirandom,
        shuffle,
        compute_sdf_derivatives,
        fixed_dataset,
        parameterization,
        geometry_label,
      } = this.settings;
      // Read the label from the live settings rather than the constructor
      // argument, so that a label changed through set() is honoured.
      // `strict` strips characters that slugify keeps by default (e.g. "-",
      // ".", "(") and that would make `${label}_interior` an invalid Python
      // identifier.
      const label = slugify(rawLabel, { lower: true, replacement: "_", strict: true });
      // NOTE(review): validate() checks `geometry`, but the output uses
      // `geometry_label`. If that is unset, the snippet references
      // `undefined_mesh`. Confirm how geometry_label is populated.
      const parameterizationAttr =
        parameterization && parameterization.length > 0
          ? `{${parameterization.map(({ symbol, min, max }) => `${symbol}: (${min}, ${max})`).join(", ")}}`
          : "None";
      const lambda_weightingAttr = lambda_weighting
        ? `{${Object.entries(lambda_weighting)
            .map(([key, value]) => `"${key}": ${value}`)
            .join(", ")}}`
        : "None";
      const outvarAttr = `{${Object.entries(outvar)
        .map(([key, value]) => `"${key.toString()}": ${value.toString()}`)
        .join(", ")}}`;

      return `
    ${label}_interior = modulus.domain.constraint.PointwiseInteriorConstraint(
        nodes=nodes,
        geometry=${geometry_label}_mesh,
        outvar=${outvarAttr},
        batch_size=${batch_size},
        lambda_weighting=${lambda_weightingAttr},
        compute_sdf_derivatives=${toPythonBool(compute_sdf_derivatives)},
        fixed_dataset=${toPythonBool(fixed_dataset)},
        quasirandom=${toPythonBool(quasirandom)},
        shuffle=${toPythonBool(shuffle)},
        criteria=${criteria ? criteria : "None"},
        parameterization=${parameterizationAttr})
        
    domain.add_constraint(${label}_interior, "${label}")
`;
    },
  };
}
