import { Activation } from "modulus-interop/activation-functions";
import { checkRequiredFields } from "modulus-interop/utils";
import { v4 as uuidv4 } from "uuid";
import type { FNOArch, FNOCreator } from "./types";
import { Parameter, SimulatorComponentSettings } from "modulus-interop/types";

/**
 * Baseline settings for the FNO architecture generator.
 *
 * `FNOArch()` spreads this object first and then overlays any caller-supplied
 * settings, so each value below acts only as a fallback.
 */
export const defaultFNOArchSettings: Partial<FNOArch> = {
  // Inputs: coordinate keys x/y/z; no parameterized inputs by default.
  input_keys: ["x", "y", "z"],
  parameterized_inputs: {},

  // Empty by default; rendered as the `detach_keys` argument of the generated call.
  detach_keys: [],

  // Problem dimensionality, matching the three default input keys.
  dimension: 3,

  // Required by validate(); intentionally blank so the caller must choose a decoder.
  decoder_net: "",

  // Spectral stack shape: layer count and retained Fourier modes.
  nr_fno_layers: 4,
  fno_modes: [16],

  // Padding amount and mode passed through to the generated call.
  padding: 8,
  padding_type: "constant",

  activation_fn: Activation.GELU,

  // Rendered as Python `True`/`False` in the generated code.
  coord_features: true,
};

/**
 * Creates a code generator for a Modulus Fourier Neural Operator (FNO) architecture.
 *
 * @param settings - Architecture settings. Any field left out falls back to
 *   `defaultFNOArchSettings`.
 * @param variableParameters - Optional simulator variable parameters. The `symbol` of
 *   every entry in `equations` and `inputs` is appended to the architecture's input keys.
 *   When omitted, only `settings.input_keys` are used.
 * @returns An `FNOCreator` whose `generateCode()` validates the settings and returns a
 *   Python snippet that builds the architecture and appends its node to `nodes`.
 */
export function FNOArch(
  settings?: FNOArch,
  variableParameters?: SimulatorComponentSettings["variable_parameters"]
): FNOCreator {
  return {
    id: uuidv4(),
    mode: Object.freeze("FNO"),
    slug: Object.freeze("fno"),
    settings: {
      ...defaultFNOArchSettings,
      ...settings,
    },
    /** Shallow-merges `settings` into the current settings and returns the result. */
    set(settings: Partial<FNOArch>) {
      Object.assign(this.settings, settings);
      return this.settings;
    },
    /** Delegates to `checkRequiredFields` for the fields code generation depends on. */
    validate() {
      checkRequiredFields(["dimension", "input_keys", "decoder_net"], this.settings);
    },
    /** Validates the settings, then renders the Python construction code for this architecture. */
    generateCode() {
      this.validate();
      const {
        dimension,
        input_keys,
        decoder_net,
        nr_fno_layers,
        fno_modes,
        detach_keys,
        padding,
        padding_type,
        coord_features,
      } = this.settings;
      // NOTE(review): `activation_fn` exists in settings but is not emitted into the
      // generated call; confirm how Activation values should be rendered in Python.

      // Collect raw symbol names only. Every input key is wrapped in Key("...") exactly
      // once when rendered below, so wrapping here would produce Key("Key(...)").
      // `variableParameters` is optional, so guard before calling Object.values.
      const equationSymbols = variableParameters?.equations
        ? Object.values(variableParameters.equations).map((param: Parameter) =>
            param?.symbol.toString()
          )
        : [];
      const inputSymbols = variableParameters?.inputs
        ? Object.values(variableParameters.inputs).map((param: Parameter) =>
            param?.symbol.toString()
          )
        : [];
      const allInputs = [...input_keys, ...equationSymbols, ...inputSymbols];

      const renderedInputKeys = allInputs.map((key: string) => `Key("${key.toString()}")`).join(", ");
      // Missing lists render as empty Python lists (the brackets live in the template),
      // never as a nested `[[]]` or `[undefined]`.
      const renderedDetachKeys = (detach_keys ?? []).map((key: string) => `"${key.toString()}"`).join(", ");
      const renderedFnoModes = (fno_modes ?? []).map((mode: number) => `${mode}`).join(", ");

      // padding_type is a string option, so it must be quoted in the Python output.
      return `
    architecture_${this.slug} = modulus_models["fno"](
        input_keys=[${renderedInputKeys}],
        dimension=${dimension},
        decoder_net="${decoder_net}",
        detach_keys=[${renderedDetachKeys}],
        nr_fno_layers=${nr_fno_layers},
        fno_modes=[${renderedFnoModes}],
        padding=${padding},
        padding_type="${padding_type}",
        coord_features=${coord_features ? "True" : "False"})

    nodes = nodes + [architecture_${this.slug}.make_node(name="architecture_${this.slug}")]`;
    },
  };
}
