import { checkRequiredFields } from "modulus-interop/utils";
import { v4 as uuidv4 } from "uuid";
import { HelmholtzEquationCreator, HelmholtzEquationType, WaveEquationCreator, WaveEquationType } from "./types";

/**
 * Default settings merged underneath user overrides by `WaveEquation()`.
 *
 * `c` is a parameter descriptor. It is non-parameterized by default, so its
 * `value` ("c") is inlined into the generated code. Spreading these defaults
 * is shallow, so every creator built without an explicit `c` shares this same
 * nested `c` object by reference.
 */
export const defaultWaveEquationSettings: Partial<WaveEquationType> = {
  c: {
    symbol: "c_WaveEquation",
    value: "c",
    parameterized: false,
  },
  u: "u",
  dim: 3,
  time: false,
  mixed_form: false,
};

/**
 * Creates a code generator for the Modulus `WaveEquation` PDE node.
 *
 * @param settings - Overrides shallow-merged on top of `defaultWaveEquationSettings`.
 * @returns A creator whose `generateCode()` returns a Python snippet. The snippet
 *   imports `WaveEquation`, instantiates it under `slug`, and appends its
 *   `make_nodes()` result to `nodes`.
 */
export function WaveEquation(settings?: WaveEquationType): WaveEquationCreator {
  return {
    id: uuidv4(),
    mode: Object.freeze("WaveEquation"),
    slug: Object.freeze("wave_equation"),
    settings: {
      ...defaultWaveEquationSettings,
      ...settings,
    },
    /** Shallow-merges `settings` into the current settings and returns the merged object. */
    set(settings: Partial<WaveEquationType>) {
      Object.assign(this.settings, settings);
      return this.settings;
    },
    validate() {
      // NOTE(review): `max_distance` is required here but is never read by
      // generateCode(). Confirm whether it belongs in this list.
      checkRequiredFields(["c", "u", "max_distance", "dim", "time"], this.settings);
    },
    /** Validates the settings, then renders the Python snippet for this node. */
    generateCode() {
      this.validate();
      const { u, c, dim, time, mixed_form } = this.settings;
      // A parameterized descriptor is emitted as `Symbol("<symbol>")`.
      // Otherwise its `value` (or `c` itself, when it is a plain value) is inlined.
      // Computing this outside the template avoids the stray `}` that
      // previously produced invalid Python (`c=c},`).
      const cExpr = c.parameterized ? `Symbol("${c.symbol}")` : c.value || c;
      // Render a JS truthiness flag as a Python boolean literal.
      const pyBool = (flag: unknown) => (flag ? "True" : "False");

      return `
    from modulus.eq.pdes.wave_equation import WaveEquation
    ${this.slug} = WaveEquation(
        u=${u},
        c=${cExpr},
        dim=${dim},
        time=${pyBool(time)},
        mixed_form=${pyBool(mixed_form)})
    nodes = nodes + ${this.slug}.make_nodes()
`;
    },
  };
}

/**
 * Default settings merged underneath user overrides by `HelmholtzEquation()`.
 *
 * `k` is a parameter descriptor. It is non-parameterized by default, so its
 * `value` ("k") is inlined into the generated code. Spreading these defaults
 * is shallow, so every creator built without an explicit `k` shares this same
 * nested `k` object by reference.
 */
export const defaultHelmholtzEquationSettings: Partial<HelmholtzEquationType> = {
  k: {
    symbol: "k_HelmholtzEquation",
    value: "k",
    parameterized: false,
  },
  u: "u",
  dim: 3,
  mixed_form: false,
};

/**
 * Creates a code generator for the Modulus `HelmholtzEquation` PDE node.
 *
 * @param settings - Overrides shallow-merged on top of `defaultHelmholtzEquationSettings`.
 * @returns A creator whose `generateCode()` returns a Python snippet. The snippet
 *   imports `HelmholtzEquation`, instantiates it under `slug`, and appends its
 *   `make_nodes()` result to `nodes`.
 */
export function HelmholtzEquation(settings?: HelmholtzEquationType): HelmholtzEquationCreator {
  return {
    id: uuidv4(),
    mode: Object.freeze("HelmholtzEquation"),
    slug: Object.freeze("helmholtz_equation"),
    settings: {
      ...defaultHelmholtzEquationSettings,
      ...settings,
    },
    /** Shallow-merges `settings` into the current settings and returns the merged object. */
    set(settings: Partial<HelmholtzEquationType>) {
      Object.assign(this.settings, settings);
      return this.settings;
    },
    validate() {
      // NOTE(review): `max_distance` is required here but is never read by
      // generateCode(). Confirm whether it belongs in this list.
      checkRequiredFields(["u", "k", "max_distance", "dim"], this.settings);
    },
    /** Validates the settings, then renders the Python snippet for this node. */
    generateCode() {
      this.validate();
      const { u, k, dim, mixed_form } = this.settings;
      // A parameterized descriptor is emitted as `Symbol("<symbol>")`.
      // Otherwise its `value` (or `k` itself, when it is a plain value) is inlined.
      // Computing this outside the template avoids the stray `}` that
      // previously produced invalid Python (`k=k},`).
      const kExpr = k.parameterized ? `Symbol("${k.symbol}")` : k.value || k;
      // Render a JS truthiness flag as a Python boolean literal.
      const pyBool = (flag: unknown) => (flag ? "True" : "False");

      return `
    from modulus.eq.pdes.wave_equation import HelmholtzEquation
    ${this.slug} = HelmholtzEquation(
        u=${u},
        k=${kExpr},
        dim=${dim},
        mixed_form=${pyBool(mixed_form)})
    nodes = nodes + ${this.slug}.make_nodes()
`;
    },
  };
}
