import { Activation } from "modulus-interop/activation-functions";
import { checkRequiredFields } from "modulus-interop/utils";
import { v4 as uuidv4 } from "uuid";
import type { FourierNetArch, FourierNetCreator, ModifiedFourierNetArch, ModifiedFourierNetCreator } from "./types";
import { Parameter, SimulatorComponentSettings } from "modulus-interop/types";

/**
 * Baseline settings for the Fourier network architecture.
 *
 * `FourierNetArch` spreads this object first and the caller's settings second,
 * so any field the caller omits falls back to the value given here. The spread
 * is shallow: the arrays and nested objects below are shared by reference with
 * every creator that does not override them.
 */
export const defaultFourierNetArchSettings: Partial<FourierNetArch> = {
  input_keys: ["x", "y", "z"],
  // Not read by the code generator in this file.
  parameterized_inputs: {},
  output_keys: ["u", "v", "w", "p"],
  detach_keys: [],
  // When `freqs` is a string it is emitted verbatim into the generated Python,
  // so the value here is a Python list comprehension rather than a JS array.
  frequencies: { frequency_type: "axis", freqs: "[i for i in range(10)]" },
  frequencies_params: { frequency_type: "axis", freqs: "[i for i in range(10)]" },
  layer_size: 512,
  nr_layers: 6,
  skip_connections: false,
  weight_norm: true,
  adaptive_activations: false,
  // NOTE(review): generateCode() in this file does not emit activation_fn.
  activation_fn: Activation.SILU,
};

/**
 * Builds a creator object that renders the Python snippet instantiating the
 * `modulus_models["fourier"]` architecture.
 *
 * @param settings - Architecture settings, shallow-merged over
 *   `defaultFourierNetArchSettings` (caller values win).
 * @param variableParameters - Optional parameterized equations/inputs. The
 *   symbol of every entry is appended to the network's input keys. May be
 *   omitted, in which case only `settings.input_keys` are used.
 * @returns A creator exposing `set` (shallow-merge new settings), `validate`
 *   (delegates to `checkRequiredFields` for `input_keys` / `output_keys`) and
 *   `generateCode` (validates, then returns the Python source as a string).
 */
export function FourierNetArch(
  settings?: FourierNetArch,
  variableParameters?: SimulatorComponentSettings["variable_parameters"]
): FourierNetCreator {
  return {
    id: uuidv4(),
    mode: Object.freeze("FourierNetArch"),
    slug: Object.freeze("fourier"),
    settings: {
      ...defaultFourierNetArchSettings,
      ...settings,
    },
    set(settings: Partial<FourierNetArch>) {
      Object.assign(this.settings, settings);
      return this.settings;
    },
    validate() {
      checkRequiredFields(["input_keys", "output_keys"], this.settings);
    },
    generateCode() {
      this.validate();
      // NOTE(review): `activation_fn` is part of the settings but is not emitted
      // into the generated call, so a non-default activation is ignored here.
      const {
        input_keys,
        output_keys,
        detach_keys,
        frequencies,
        frequencies_params,
        layer_size,
        nr_layers,
        skip_connections,
        weight_norm,
        adaptive_activations,
      } = this.settings;

      // Collect the raw symbols only; the Key("...") wrapper is applied exactly
      // once below, together with the regular input keys.
      const symbolsOf = (params: Parameter[]): string[] =>
        params.filter((param) => param?.symbol != null).map((param) => param.symbol.toString());

      // `variableParameters` is optional, so tolerate it (or either group) being absent.
      const equationSymbols = symbolsOf(Object.values(variableParameters?.equations ?? {}));
      const inputSymbols = symbolsOf(Object.values(variableParameters?.inputs ?? {}));

      const asKeyList = (keys: string[]): string => keys.map((key: string) => `Key("${key}")`).join(", ");

      // A string `freqs` is a Python expression emitted verbatim; an array is
      // rendered as a Python list literal.
      const renderFreqs = (spec: typeof frequencies): string => {
        const values = spec?.freqs;
        if (typeof values === "string") {
          return values;
        }
        return `[${(values ?? []).join(", ")}]`;
      };

      const pyBool = (flag?: boolean): string => (flag ? "True" : "False");

      const inputKeyList = asKeyList([...(input_keys ?? []), ...equationSymbols, ...inputSymbols]);
      const outputKeyList = asKeyList(output_keys ?? []);
      // NOTE(review): detach keys are emitted as plain quoted strings (as before),
      // not as Key(...) — confirm this matches what the Python side expects.
      const detachKeyList = (detach_keys ?? []).map((key: string) => `"${key}"`).join(", ");

      return `
    architecture_${this.slug} = modulus_models["fourier"](
        input_keys=[${inputKeyList}],
        output_keys=[${outputKeyList}],
        detach_keys=[${detachKeyList}],
        frequencies=("${frequencies?.frequency_type}", ${renderFreqs(frequencies)}),
        frequencies_params=["${frequencies_params?.frequency_type}", ${renderFreqs(frequencies_params)}],
        layer_size=${layer_size},
        nr_layers=${nr_layers},
        weight_norm=${pyBool(weight_norm)},
        skip_connections=${pyBool(skip_connections)},
        adaptive_activations=${pyBool(adaptive_activations)})

    nodes = nodes + [architecture_${this.slug}.make_node(name="architecture_${this.slug}")]`;
    },
  };
}

/**
 * Baseline settings for the modified Fourier network architecture.
 *
 * `ModifiedFourierNetArch` spreads this object first and the caller's settings
 * second, so any field the caller omits falls back to the value given here.
 * The spread is shallow: the arrays and nested objects below are shared by
 * reference with every creator that does not override them.
 */
export const defaultModifiedFourierNetArchSettings: Partial<ModifiedFourierNetArch> = {
  input_keys: ["x", "y", "z"],
  // Not read by the code generator in this file.
  parameterized_inputs: {},
  output_keys: ["u", "v", "w", "p"],
  detach_keys: [],
  // When `freqs` is a string it is emitted verbatim into the generated Python,
  // so the value here is a Python list comprehension rather than a JS array.
  frequencies: { frequency_type: "axis", freqs: "[i for i in range(10)]" },
  frequencies_params: { frequency_type: "axis", freqs: "[i for i in range(10)]" },
  layer_size: 512,
  nr_layers: 6,
  skip_connections: false,
  weight_norm: true,
  adaptive_activations: false,
  // NOTE(review): generateCode() in this file does not emit activation_fn.
  activation_fn: Activation.SILU,
};

/**
 * Builds a creator object that renders the Python snippet instantiating the
 * `modulus_models["modified_fourier"]` architecture.
 *
 * @param settings - Architecture settings, shallow-merged over
 *   `defaultModifiedFourierNetArchSettings` (caller values win).
 * @param variableParameters - Optional parameterized equations/inputs. The
 *   symbol of every entry is appended to the network's input keys. May be
 *   omitted, in which case only `settings.input_keys` are used.
 * @returns A creator exposing `set` (shallow-merge new settings), `validate`
 *   (delegates to `checkRequiredFields` for `input_keys` / `output_keys`) and
 *   `generateCode` (validates, then returns the Python source as a string).
 */
export function ModifiedFourierNetArch(
  settings?: ModifiedFourierNetArch,
  variableParameters?: SimulatorComponentSettings["variable_parameters"]
): ModifiedFourierNetCreator {
  return {
    id: uuidv4(),
    mode: Object.freeze("ModifiedFourierNetArch"),
    slug: Object.freeze("modified_fourier"),
    settings: {
      ...defaultModifiedFourierNetArchSettings,
      ...settings,
    },
    set(settings: Partial<ModifiedFourierNetArch>) {
      Object.assign(this.settings, settings);
      return this.settings;
    },
    validate() {
      checkRequiredFields(["input_keys", "output_keys"], this.settings);
    },
    generateCode() {
      this.validate();
      // NOTE(review): `activation_fn` is part of the settings but is not emitted
      // into the generated call, so a non-default activation is ignored here.
      const {
        input_keys,
        output_keys,
        detach_keys,
        frequencies,
        frequencies_params,
        layer_size,
        nr_layers,
        skip_connections,
        weight_norm,
        adaptive_activations,
      } = this.settings;

      // Collect the raw symbols only; the Key("...") wrapper is applied exactly
      // once below, together with the regular input keys.
      const symbolsOf = (params: Parameter[]): string[] =>
        params.filter((param) => param?.symbol != null).map((param) => param.symbol.toString());

      // `variableParameters` is optional, so tolerate it (or either group) being absent.
      const equationSymbols = symbolsOf(Object.values(variableParameters?.equations ?? {}));
      const inputSymbols = symbolsOf(Object.values(variableParameters?.inputs ?? {}));

      const asKeyList = (keys: string[]): string => keys.map((key: string) => `Key("${key}")`).join(", ");

      // A string `freqs` is a Python expression emitted verbatim; an array is
      // rendered as a Python list literal.
      const renderFreqs = (spec: typeof frequencies): string => {
        const values = spec?.freqs;
        if (typeof values === "string") {
          return values;
        }
        return `[${(values ?? []).join(", ")}]`;
      };

      const pyBool = (flag?: boolean): string => (flag ? "True" : "False");

      const inputKeyList = asKeyList([...(input_keys ?? []), ...equationSymbols, ...inputSymbols]);
      const outputKeyList = asKeyList(output_keys ?? []);
      // NOTE(review): detach keys are emitted as plain quoted strings (as before),
      // not as Key(...) — confirm this matches what the Python side expects.
      const detachKeyList = (detach_keys ?? []).map((key: string) => `"${key}"`).join(", ");

      return `
    architecture_${this.slug} = modulus_models["modified_fourier"](
        input_keys=[${inputKeyList}],
        output_keys=[${outputKeyList}],
        detach_keys=[${detachKeyList}],
        frequencies=("${frequencies?.frequency_type}", ${renderFreqs(frequencies)}),
        frequencies_params=["${frequencies_params?.frequency_type}", ${renderFreqs(frequencies_params)}],
        layer_size=${layer_size},
        nr_layers=${nr_layers},
        weight_norm=${pyBool(weight_norm)},
        skip_connections=${pyBool(skip_connections)},
        adaptive_activations=${pyBool(adaptive_activations)})

    nodes = nodes + [architecture_${this.slug}.make_node(name="architecture_${this.slug}")]`;
    },
  };
}
