import { checkRequiredFields } from "modulus-interop/utils";
import { v4 as uuidv4 } from "uuid";
import slugify from "slugify";
import { AdvectionDiffusionType, AdvectionDiffusionCreator } from "./types";

/**
 * Physics fields understood by {@link checkRequiredPhysicsFields}, paired with the
 * error message raised when that field is required but missing. The array order is
 * the order in which the checks run, so the first missing field (in this order) wins.
 */
const PHYSICS_FIELD_MESSAGES: ReadonlyArray<readonly [string, string]> = [
  ["T", "The dependent variable (T) not set!"],
  ["D", "Diffusivity (D) not set!"],
  ["Q", "The source term (Q) not set!"],
  ["rho", "The density (rho) not set!"],
];

/**
 * Returns true when `settings` is an object that owns `field` with a defined value.
 *
 * `Object.prototype.hasOwnProperty.call` is used instead of `settings.hasOwnProperty`
 * so that null-prototype objects, or objects that shadow `hasOwnProperty`, are handled
 * correctly. A field explicitly set to `undefined` is treated as not set.
 */
function isPhysicsFieldSet(settings: unknown, field: string): boolean {
  if (typeof settings !== "object" || settings === null) {
    return false;
  }
  return (
    Object.prototype.hasOwnProperty.call(settings, field) &&
    (settings as Record<string, unknown>)[field] !== undefined
  );
}

/**
 * Asserts that every requested physics field (`T`, `D`, `Q`, `rho`) is present on `settings`.
 *
 * Names in `requiredFields` other than those four are ignored here.
 *
 * @param requiredFields - Names of the fields the caller requires.
 * @param settings - The settings object to inspect.
 * @throws Error with a field-specific message for the first required field that is
 *   missing or `undefined`.
 */
function checkRequiredPhysicsFields<T>(requiredFields: readonly string[], settings: T): void {
  for (const [field, message] of PHYSICS_FIELD_MESSAGES) {
    if (requiredFields.includes(field) && !isPhysicsFieldSet(settings, field)) {
      throw new Error(message);
    }
  }
}

/**
 * Baseline settings for an advection–diffusion node.
 *
 * `AdvectionDiffusion` spreads these first and the caller's settings second, so any
 * caller-supplied key replaces the default wholesale.
 *
 * NOTE(review): that merge is a shallow spread, so the nested `D`, `Q` and `rho` objects
 * below are shared by reference with every creator that does not override them. Mutating
 * e.g. `creator.settings.D.value` in place would also change these defaults.
 */
export const defaultAdvectionDiffusionSettings: Partial<AdvectionDiffusionType> = {
  // Name of the dependent variable.
  T: "T",
  // Diffusivity: rendered as Symbol("D_T") when parameterized, otherwise its `value`.
  D: { symbol: "D_T", value: "D", parameterized: false },
  // Source term; the numeric default is 0.
  Q: { symbol: "Q_T", value: 0, parameterized: false },
  // Density.
  rho: { symbol: "rho_AdvectionDiffusion", value: 1.0, parameterized: false },
  // Passed through as the `dim` argument; presumably the spatial dimension — confirm.
  dim: 3,
  // Emitted as the Python `time=` flag.
  time: false,
  // Emitted as the Python `mixed_form=` flag.
  mixed_form: false,
};

/**
 * Renders one PDE coefficient (D, Q or rho) as a Python expression for the generated code.
 *
 * - An object whose `parameterized` flag is truthy becomes `Symbol("<symbol>")`.
 * - Otherwise the object's `value` is emitted. `??` (not `||`) is used so that falsy but
 *   valid values, such as the default Q value of `0`, are kept. The code does not fall
 *   back to stringifying the whole object ("[object Object]").
 * - A non-object coefficient (e.g. a bare number or string) is emitted as-is.
 *
 * @param coefficient - The coefficient taken from the settings.
 * @param name - Field name, used only in the error message.
 * @throws Error if the coefficient is `null` or `undefined`.
 */
function renderCoefficient(coefficient: unknown, name: string): string {
  if (coefficient === null || coefficient === undefined) {
    throw new Error(`Cannot render coefficient ${name}: it is not set`);
  }
  if (typeof coefficient !== "object") {
    return String(coefficient);
  }
  const { symbol, value, parameterized } = coefficient as {
    symbol?: unknown;
    value?: unknown;
    parameterized?: unknown;
  };
  if (parameterized) {
    return `Symbol("${String(symbol)}")`;
  }
  return String(value ?? coefficient);
}

/**
 * Creates an advection–diffusion physics node that can emit Modulus Python code.
 *
 * @param settings - Overrides spread on top of `defaultAdvectionDiffusionSettings`.
 *   This is a shallow merge: a supplied `D`/`Q`/`rho` object replaces the default one
 *   entirely rather than being merged into it.
 * @returns A creator with a fresh `uuidv4()` id, fixed `mode`/`slug` identifiers, a
 *   mutable `settings` object, and `set` / `validate` / `generateCode` methods.
 */
export function AdvectionDiffusion(settings?: AdvectionDiffusionType): AdvectionDiffusionCreator {
  return {
    id: uuidv4(),
    // Plain literals: Object.freeze on a string primitive is a no-op.
    mode: "AdvectionDiffusion",
    slug: "advection_diffusion",
    settings: {
      ...defaultAdvectionDiffusionSettings,
      ...settings,
    },
    /** Shallow-merges `updates` into the current settings (in place) and returns them. */
    set(updates: Partial<AdvectionDiffusionType>) {
      Object.assign(this.settings, updates);
      return this.settings;
    },
    /**
     * Runs the field checks.
     *
     * T, D, Q and rho are checked via `checkRequiredPhysicsFields`. `dim` is checked
     * via `checkRequiredFields` from modulus-interop/utils, which presumably throws when
     * the field is missing.
     */
    validate() {
      checkRequiredPhysicsFields(["T", "D", "Q", "rho"], this.settings);
      checkRequiredFields(["dim"], this.settings);
    },
    /**
     * Validates the settings, then returns Python source that imports and instantiates
     * Modulus' `AdvectionDiffusion` under the name `this.slug`. The source also appends
     * the node's `make_nodes()` output to a `nodes` variable, which the surrounding
     * generated script is expected to define. JS booleans are converted to Python
     * `True`/`False`.
     */
    generateCode() {
      this.validate();
      const { T, D, Q, rho, dim, time, mixed_form } = this.settings;

      return `
    from modulus.eq.pdes.advection_diffusion import AdvectionDiffusion
    ${this.slug} = AdvectionDiffusion(
        T=${T},
        D=${renderCoefficient(D, "D")},
        Q=${renderCoefficient(Q, "Q")},
        rho=${renderCoefficient(rho, "rho")},
        time=${time ? "True" : "False"},
        mixed_form=${mixed_form ? "True" : "False"},
        dim=${dim})
    nodes = nodes + ${this.slug}.make_nodes()
`;
    },
  };
}
