import { checkRequiredFields } from "modulus-interop/utils";
import { v4 as uuidv4 } from "uuid";
import type { AFNOArch, AFNOCreator } from "./types";
import { Parameter, SimulatorComponentSettings } from "modulus-interop/types";

/**
 * Default settings for the AFNO architecture component.
 *
 * `AFNOArch` spreads this object first and the caller-supplied settings on top, so any field
 * given by the caller overrides the value here. `img_height` and `img_width` have no default and
 * must be supplied by the caller (they are listed as required in `validate`).
 *
 * NOTE(review): the merge is shallow, so the arrays and the `parameterized_inputs` object below
 * are shared by reference with every component created from these defaults — do not mutate
 * them in place.
 */
export const defaultAFNOArchSettings: Partial<AFNOArch> = {
  // Base coordinate inputs of the network.
  input_keys: ["x", "y", "z"],
  parameterized_inputs: {},
  // Network outputs.
  output_keys: ["u", "v", "w", "p"],
  detach_keys: [],
  patch_size: 16,
  embed_dim: 256,
  depth: 4,
  num_blocks: 4,
};

/**
 * Creates an AFNO architecture component that can render its Modulus Python setup code.
 *
 * @param settings - Overrides merged (shallowly) on top of `defaultAFNOArchSettings`.
 * @param variableParameters - Optional variable parameters. The symbols of every entry in its
 *   `equations` and `inputs` groups are appended to the architecture's input keys. May be omitted,
 *   in which case only `input_keys` are used.
 * @returns A creator with a fresh `id`, the merged `settings`, and `set` / `validate` /
 *   `generateCode` methods.
 */
export function AFNOArch(
  settings?: AFNOArch,
  variableParameters?: SimulatorComponentSettings["variable_parameters"]
): AFNOCreator {
  return {
    id: uuidv4(),
    mode: Object.freeze("AFNO"),
    slug: Object.freeze("afno"),
    // NOTE(review): shallow merge — array/object defaults (e.g. input_keys) are shared with
    // defaultAFNOArchSettings; replace them via `set` rather than mutating them in place.
    settings: {
      ...defaultAFNOArchSettings,
      ...settings,
    },
    /** Shallow-merges `settings` into this component's settings and returns the updated settings. */
    set(settings: Partial<AFNOArch>) {
      Object.assign(this.settings, settings);
      return this.settings;
    },
    /** Checks (via `checkRequiredFields`) that the fields `generateCode` depends on are present. */
    validate() {
      checkRequiredFields(["input_keys", "output_keys", "img_height", "img_width"], this.settings);
    },
    /**
     * Renders the Python snippet that instantiates this AFNO architecture through
     * `modulus_models["afno"]` and appends its node to `nodes`. Calls `validate()` first.
     */
    generateCode() {
      this.validate();
      const { input_keys, output_keys, img_height, img_width, detach_keys, patch_size, embed_dim, depth, num_blocks } =
        this.settings;

      // Wraps a raw key name in a Modulus `Key("...")` constructor call.
      const toKey = (key: string) => `Key("${key.toString()}")`;
      // Raw symbol names of a parameter group; a missing group or empty (nullish) slots are skipped.
      const symbolsOf = (group: object | undefined): string[] =>
        Object.values(group ?? {})
          .filter((param) => param != null)
          .map((param: Parameter) => param.symbol.toString());

      // Keep every entry as a raw name here; each one is wrapped in Key(...) exactly once below.
      const allInputs: string[] = [
        ...input_keys,
        ...symbolsOf(variableParameters?.equations),
        ...symbolsOf(variableParameters?.inputs),
      ];
      // A missing detach_keys renders as an empty Python list, not a nested `[[]]`.
      const detachKeys: string[] = detach_keys ?? [];

      return `
    architecture_${this.slug} = modulus_models["afno"](
        input_keys=[${allInputs.map(toKey).join(", ")}],
        output_keys=[${output_keys?.map(toKey).join(", ")}],
        img_shape=(${img_height}, ${img_width}),
        detach_keys=[${detachKeys.map((key: string) => `"${key.toString()}"`).join(", ")}],
        patch_size=${patch_size},
        embed_dim=${embed_dim},
        depth=${depth},
        num_blocks=${num_blocks})

    nodes = nodes + [architecture_${this.slug}.make_node(name="architecture_${this.slug}")]`;
    },
  };
}
