import { checkRequiredFields } from "modulus-interop/utils";
import { v4 as uuidv4 } from "uuid";
import {
  CompressibleIntegralContinuityCreator,
  CompressibleIntegralContinuityType,
  CurlCreator,
  CurlType,
  FluxContinuityCreator,
  FluxContinuityType,
  GradNormalCreator,
  GradNormalType,
  NavierStokesCreator,
  NavierStokesType,
} from "./types";

/**
 * Throws if a required physics parameter is absent from `settings`.
 *
 * Only `"nu"` and `"rho"` are recognized; other names in `requiredFields` are ignored.
 * A field counts as absent when it is not an own property of `settings` OR when its
 * value is `null`/`undefined` (e.g. after `set({ nu: undefined })`). Code generation
 * then fails with a clear message instead of a `TypeError` when reading `nu.parameterized`.
 * A value of `0` is considered present.
 *
 * @param requiredFields - names of the parameters the caller needs.
 * @param settings - the settings object to inspect.
 * @throws Error naming the first missing parameter (`nu` is checked before `rho`).
 */
function checkRequiredPhysicsFields(requiredFields: readonly string[], settings: object): void {
  // Ordered so that nu is reported before rho, matching the historical behavior.
  const knownFields: ReadonlyArray<readonly [string, string]> = [
    ["nu", "Kinematic viscosity (nu) not set!"],
    ["rho", "Density (rho) not set!"],
  ];
  const values = settings as Record<string, unknown>;

  for (const [field, message] of knownFields) {
    if (!requiredFields.includes(field)) continue;
    const present = Object.prototype.hasOwnProperty.call(settings, field) && values[field] != null;
    if (!present) throw new Error(message);
  }
}

/**
 * Default settings merged underneath user-supplied settings by {@link NavierStokes}.
 * `nu` and `rho` are `{ symbol, value, parameterized }` records.
 */
export const defaultNavierStokesSettings: Partial<NavierStokesType> = {
  nu: { symbol: "nu_NavierStokes", value: 0.025, parameterized: false },
  rho: { symbol: "rho_NavierStokes", value: 1.0, parameterized: false },
  dim: 3,
  time: false,
};

/**
 * Creates a code-generator node for Modulus' `NavierStokes` PDE.
 *
 * @param settings - overrides applied on top of {@link defaultNavierStokesSettings}.
 * @returns a creator whose `generateCode()` emits a Python snippet that instantiates
 *   the PDE and appends its `make_nodes()` output to `nodes`.
 */
export function NavierStokes(settings?: NavierStokesType): NavierStokesCreator {
  return {
    id: uuidv4(),
    mode: Object.freeze("NavierStokes"),
    slug: Object.freeze("navier_stokes"),
    // Shallow merge: the nested nu/rho default objects are shared, not copied.
    settings: {
      ...defaultNavierStokesSettings,
      ...settings,
    },
    /** Shallow-merges `settings` into the current settings and returns them. */
    set(settings: Partial<NavierStokesType>) {
      Object.assign(this.settings, settings);
      return this.settings;
    },
    /** Throws if `nu` or `rho` is missing. */
    validate() {
      checkRequiredPhysicsFields(["nu", "rho"], this.settings);
    },
    /**
     * Validates the settings and renders the Python snippet.
     * @throws Error when validation fails.
     */
    generateCode() {
      this.validate();
      const { nu, rho, dim, time, mixed_form } = this.settings;
      // Parameterized values become a `Symbol("<name>")` expression; otherwise the numeric
      // value is emitted. `??` (not `||`) keeps a legitimate 0 (e.g. nu = 0). The fallback
      // to the raw setting presumably tolerates a bare number instead of a record.
      const nuCode = nu.parameterized ? `Symbol("${nu.symbol}")` : nu.value ?? nu;
      const rhoCode = rho.parameterized ? `Symbol("${rho.symbol}")` : rho.value ?? rho;

      return `
    from modulus.eq.pdes.navier_stokes import NavierStokes
    ${this.slug} = NavierStokes(
        nu=${nuCode},
        rho=${rhoCode},
        dim=${dim},
        time=${time ? "True" : "False"},
        mixed_form=${mixed_form ? "True" : "False"})
    nodes = nodes + ${this.slug}.make_nodes()
`;
    },
  };
}

/**
 * Default settings merged underneath user-supplied settings by
 * {@link CompressibleIntegralContinuity}.
 */
export const defaultCompressibleIntegralContinuitySettings: Partial<CompressibleIntegralContinuityType> = {
  rho: { symbol: "rho_CompressibleIntegralContinuity", value: 1.0, parameterized: false },
  dim: 3,
  vec: ["u", "v", "w"],
};

/**
 * Creates a code-generator node for Modulus' `CompressibleIntegralContinuity` PDE.
 *
 * @param settings - overrides applied on top of
 *   {@link defaultCompressibleIntegralContinuitySettings}.
 * @returns a creator whose `generateCode()` emits a Python snippet that instantiates
 *   the PDE and appends its `make_nodes()` output to `nodes`.
 */
export function CompressibleIntegralContinuity(
  settings?: CompressibleIntegralContinuityType
): CompressibleIntegralContinuityCreator {
  return {
    id: uuidv4(),
    mode: Object.freeze("CompressibleIntegralContinuity"),
    slug: Object.freeze("compressible_integral_continuity"),
    // Shallow merge: the nested rho default object and vec array are shared, not copied.
    settings: {
      ...defaultCompressibleIntegralContinuitySettings,
      ...settings,
    },
    /** Shallow-merges `settings` into the current settings and returns them. */
    set(settings: Partial<CompressibleIntegralContinuityType>) {
      Object.assign(this.settings, settings);
      return this.settings;
    },
    /** Throws if `rho` is missing. */
    validate() {
      checkRequiredPhysicsFields(["rho"], this.settings);
    },
    /**
     * Validates the settings and renders the Python snippet.
     * @throws Error when validation fails.
     */
    generateCode() {
      this.validate();
      const { rho, dim, vec } = this.settings;
      // `??` (not `||`) so a density value of 0 is emitted instead of the whole record.
      const rhoCode = rho.parameterized ? `Symbol("${rho.symbol}")` : rho.value ?? rho;
      const vecCode = vec.map((v) => `"${v}"`).join(", ");

      return `
    from modulus.eq.pdes.navier_stokes import CompressibleIntegralContinuity
    ${this.slug} = CompressibleIntegralContinuity(
        rho=${rhoCode},
        dim=${dim},
        vec=[${vecCode}])
    nodes = nodes + ${this.slug}.make_nodes()
`;
    },
  };
}

/** Default settings merged underneath user-supplied settings by {@link Curl}. */
export const defaultCurlSettings: Partial<CurlType> = {
  curl_name: ["u", "v", "w"],
};

/**
 * Creates a code-generator node for Modulus' `Curl` operator.
 *
 * @param settings - overrides applied on top of {@link defaultCurlSettings}; `vector`
 *   must be supplied (it has no default).
 * @returns a creator whose `generateCode()` emits a Python snippet that instantiates
 *   the operator and appends its `make_nodes()` output to `nodes`.
 */
export function Curl(settings?: CurlType): CurlCreator {
  return {
    id: uuidv4(),
    mode: Object.freeze("Curl"),
    slug: Object.freeze("curl"),
    // Shallow merge: the default curl_name array is shared, not copied.
    settings: {
      ...defaultCurlSettings,
      ...settings,
    },
    /** Shallow-merges `settings` into the current settings and returns them. */
    set(settings: Partial<CurlType>) {
      Object.assign(this.settings, settings);
      return this.settings;
    },
    /** Delegates to `checkRequiredFields`; presumably throws when `vector` is missing. */
    validate() {
      checkRequiredFields(["vector"], this.settings);
    },
    /**
     * Validates the settings and renders the Python snippet.
     *
     * The vector is emitted as a Python tuple: numeric components stay bare and
     * everything else is quoted. The output names are emitted as ONE list argument.
     * The previous output spread them as separate positional arguments and had an
     * unbalanced closing parenthesis.
     */
    generateCode() {
      this.validate();
      const { vector, curl_name } = this.settings;
      const vectorCode = vector.map((v) => (typeof v === "number" ? v : `"${v}"`)).join(", ");
      const curlNameCode = curl_name.map((v) => `"${v}"`).join(", ");

      return `
    from modulus.eq.pdes.navier_stokes import Curl
    ${this.slug} = Curl((${vectorCode}), [${curlNameCode}])
    nodes = nodes + ${this.slug}.make_nodes()
`;
    },
  };
}

/** Default settings merged underneath user-supplied settings by {@link FluxContinuity}. */
export const defaultFluxContinuitySettings: Partial<FluxContinuityType> = {
  T: "T",
  rho: { symbol: "rho_T", value: 1.0, parameterized: false },
  dim: 3,
  vec: ["u", "v", "w"],
};

/**
 * Creates a code-generator node for Modulus' `FluxContinuity` PDE.
 *
 * @param settings - overrides applied on top of {@link defaultFluxContinuitySettings}.
 * @returns a creator whose `generateCode()` emits a Python snippet that instantiates
 *   the PDE and appends its `make_nodes()` output to `nodes`.
 */
export function FluxContinuity(settings?: FluxContinuityType): FluxContinuityCreator {
  return {
    id: uuidv4(),
    mode: Object.freeze("FluxContinuity"),
    slug: Object.freeze("flux_continuity"),
    // Shallow merge: the nested rho default object and vec array are shared, not copied.
    settings: {
      ...defaultFluxContinuitySettings,
      ...settings,
    },
    /** Shallow-merges `settings` into the current settings and returns them. */
    set(settings: Partial<FluxContinuityType>) {
      Object.assign(this.settings, settings);
      return this.settings;
    },
    /** Delegates to `checkRequiredFields`; presumably throws when T, rho or dim is missing. */
    validate() {
      checkRequiredFields(["T", "rho", "dim"], this.settings);
    },
    /**
     * Validates the settings and renders the Python snippet.
     *
     * `T` is a variable name (default `"T"`), so it is emitted as a quoted string
     * literal, like the `vec` entries. An unquoted `T=T` would reference an
     * undefined Python name.
     */
    generateCode() {
      this.validate();
      const { T, rho, dim, vec } = this.settings;
      // `??` (not `||`) so a density value of 0 is emitted instead of the whole record.
      const rhoCode = rho.parameterized ? `Symbol("${rho.symbol}")` : rho.value ?? rho;
      const vecCode = vec.map((v) => `"${v}"`).join(", ");

      return `
    from modulus.eq.pdes.navier_stokes import FluxContinuity
    ${this.slug} = FluxContinuity(
        T="${T}",
        rho=${rhoCode},
        dim=${dim},
        vec=[${vecCode}])
    nodes = nodes + ${this.slug}.make_nodes()
`;
    },
  };
}

/** Default settings merged underneath user-supplied settings by {@link GradNormal}. */
export const defaultGradNormalSettings: Partial<GradNormalType> = {
  T: "T",
  dim: 3,
  time: false,
};

/**
 * Creates a code-generator node for Modulus' `GradNormal` PDE.
 *
 * @param settings - overrides applied on top of {@link defaultGradNormalSettings}.
 * @returns a creator whose `generateCode()` emits a Python snippet that instantiates
 *   the PDE and appends its `make_nodes()` output to `nodes`.
 */
export function GradNormal(settings?: GradNormalType): GradNormalCreator {
  return {
    id: uuidv4(),
    mode: Object.freeze("GradNormal"),
    slug: Object.freeze("grad_normal"),
    settings: {
      ...defaultGradNormalSettings,
      ...settings,
    },
    /** Shallow-merges `settings` into the current settings and returns them. */
    set(settings: Partial<GradNormalType>) {
      Object.assign(this.settings, settings);
      return this.settings;
    },
    /** Delegates to `checkRequiredFields`; presumably throws when T or dim is missing. */
    validate() {
      checkRequiredFields(["T", "dim"], this.settings);
    },
    /**
     * Validates the settings and renders the Python snippet.
     *
     * `T` is a variable name (default `"T"`), so it is emitted as a quoted string
     * literal. An unquoted `T=T` would reference an undefined Python name.
     */
    generateCode() {
      this.validate();
      const { T, dim, time } = this.settings;

      return `
    from modulus.eq.pdes.navier_stokes import GradNormal
    ${this.slug} = GradNormal(
        T="${T}",
        dim=${dim},
        time=${time ? "True" : "False"})
    nodes = nodes + ${this.slug}.make_nodes()
`;
    },
  };
}
