import { Row, Col, Button } from "react-bootstrap";
import yaml from "js-yaml";
import { saveAs } from "file-saver";
import { Simulator } from "types";
import set from "lodash/set";
import cloneDeep from "lodash/cloneDeep";
import { ModulusHydraConfig } from "modulus-interop/types";
import { useUpdateSimulator } from "mutations";
import Select from "react-select";
import { useState } from "react";

/**
 * Baseline Modulus Hydra configuration.
 *
 * Used as the initial form state when the simulator has no stored `conf`, and as the
 * per-field fallback whenever a value is missing from the edited config.
 */
const defaultConfig: ModulusHydraConfig = {
  // Names of the scheduler / optimizer / loss components to use.
  defaults: {
    scheduler: "tf_exponential_lr",
    optimizer: "adam",
    loss: "sum",
  },
  // Parameters for the scheduler selected above.
  scheduler: {
    decay_rate: 0.1,
    decay_steps: 10_000,
  },
  // Training loop settings; frequencies are expressed in training steps.
  training: {
    rec_results_freq: 5_000,
    rec_constraint_freq: 10_000, // not changeable by user
    max_steps: 1_000_000,
    rec_validation_freq: 5_000,
    rec_inference_freq: 5_000,
    rec_monitor_freq: 5_000,
    save_network_freq: 1_000,
    print_stats_freq: 100, // not changeable by user
    summary_freq: 1_000, // not changeable by user
    batch_size: 10_000,
    amp: false,
  },
  // Top-level execution toggles.
  cuda_graphs: true,
  cuda_graph_warmup: 20,
  jit: false,
};

/** One entry of a react-select dropdown: the stored `value` and its display `label`. */
interface TSelectOption {
  value: string;
  label: string;
}

/** Props accepted by the simulator configuration form. */
interface SimulatorConfiguratorProps {
  /** Simulator whose `modulus_components.conf` is being edited. */
  simulator: Simulator;
}

/** Per-field flags marking which inputs currently hold an invalid value. */
type WrongInput = Record<string, boolean>;
/** Minimum value accepted by the validated step-count / frequency inputs of the "Training" section. */
const MIN_STEP_VALUE = 1000;

/** Training fields whose inputs must hold an integer of at least `MIN_STEP_VALUE`. */
type ValidatedTrainingField =
  | "max_steps"
  | "save_network_freq"
  | "rec_results_freq"
  | "rec_validation_freq"
  | "rec_monitor_freq"
  | "rec_inference_freq";

// Option lists for the "Defaults" selects. Each list must contain the matching value from
// `defaultConfig.defaults`; otherwise the select cannot show the value that actually gets saved.
const schedulerOptions: TSelectOption[] = [{ value: "tf_exponential_lr", label: "Exponential learning rate" }];
const optimizerOptions: TSelectOption[] = [{ value: "adam", label: "Adam" }];
const lossOptions: TSelectOption[] = [{ value: "sum", label: "Sum" }];

/**
 * Pick the option to display for a stored select value.
 * Values missing from `options` get an ad-hoc option labelled with the raw value,
 * so the select never displays something different from what will be saved.
 */
const optionFor = (options: TSelectOption[], value: string): TSelectOption =>
  options.find((o) => o.value === value) ?? { value, label: value };

/** Inline style that highlights an input whose value failed validation. */
const invalidInputStyle = (invalid: boolean) => ({
  border: invalid ? "1px solid red" : "",
  color: invalid ? "red" : "",
});

/**
 * Form for editing a simulator's Modulus Hydra configuration.
 *
 * Edits are kept in local state. "Save" persists them through `useUpdateSimulator`,
 * and "Download config.yaml" exports them as YAML. Any value missing from the stored
 * config falls back to `defaultConfig`.
 */
export const SimulatorConfigurator = ({ simulator }: SimulatorConfiguratorProps) => {
  const [config, setConfig] = useState<ModulusHydraConfig>(simulator.modulus_components?.conf ?? defaultConfig);
  // Per-field "current input is out of range" flags; any `true` disables Save.
  const [isWrongInput, setWrongInput] = useState<WrongInput>({});
  const updateMutation = useUpdateSimulator();

  const hasWrongInput = Object.values(isWrongInput).some(Boolean);

  /** Immutably set a dot-separated `key` path in the config state. */
  const updateConfig = (key: string, value: unknown) => {
    // lodash `set` mutates its target. The initial state object is either the simulator
    // prop's conf or the shared module-level `defaultConfig`, so always write into a copy.
    setConfig((prev) => set(cloneDeep(prev), key, value));
  };

  /** Update a numeric field, ignoring empty/non-numeric input (NaN) so it is never saved. */
  const updateNumber = (key: string, value: number) => {
    if (!Number.isNaN(value)) {
      updateConfig(key, value);
    }
  };

  /** Validate a `training.*` input against `MIN_STEP_VALUE`; only valid values reach the config. */
  const handleValidatedChange = (field: ValidatedTrainingField, rawValue: string) => {
    const value = parseInt(rawValue, 10);
    // `NaN < MIN_STEP_VALUE` is false, so an empty field must be rejected explicitly.
    const invalid = Number.isNaN(value) || value < MIN_STEP_VALUE;
    setWrongInput((prev) => ({ ...prev, [field]: invalid }));
    if (!invalid) {
      updateConfig(`training.${field}`, value);
    }
  };

  /**
   * Produce the config to persist or export: edited values (or their defaults) laid over
   * the stored config. Stored top-level keys that this form does not manage are kept.
   */
  const buildConfig = (): ModulusHydraConfig => {
    const { defaults, scheduler, training } = config;
    const fallback = defaultConfig;
    return {
      // Spread instead of `Object.assign(storedConf, …)`: the stored conf may be absent
      // (Object.assign(undefined) throws), and the simulator prop must not be mutated.
      ...(simulator.modulus_components?.conf ?? {}),
      defaults: {
        scheduler: defaults?.scheduler ?? fallback.defaults.scheduler,
        optimizer: defaults?.optimizer ?? fallback.defaults.optimizer,
        loss: defaults?.loss ?? fallback.defaults.loss,
      },
      scheduler: {
        decay_rate: scheduler?.decay_rate ?? fallback.scheduler.decay_rate,
        decay_steps: scheduler?.decay_steps ?? fallback.scheduler.decay_steps,
      },
      training: {
        rec_results_freq: training?.rec_results_freq ?? fallback.training.rec_results_freq,
        rec_constraint_freq: training?.rec_constraint_freq ?? fallback.training.rec_constraint_freq,
        max_steps: training?.max_steps ?? fallback.training.max_steps,
        rec_validation_freq: training?.rec_validation_freq ?? fallback.training.rec_validation_freq,
        rec_inference_freq: training?.rec_inference_freq ?? fallback.training.rec_inference_freq,
        rec_monitor_freq: training?.rec_monitor_freq ?? fallback.training.rec_monitor_freq,
        save_network_freq: training?.save_network_freq ?? fallback.training.save_network_freq,
        print_stats_freq: training?.print_stats_freq ?? fallback.training.print_stats_freq,
        summary_freq: training?.summary_freq ?? fallback.training.summary_freq,
        batch_size: training?.batch_size ?? fallback.training.batch_size,
        amp: training?.amp ?? fallback.training.amp,
      },
      cuda_graphs: config.cuda_graphs ?? fallback.cuda_graphs,
      cuda_graph_warmup: config.cuda_graph_warmup ?? fallback.cuda_graph_warmup,
      jit: config.jit ?? fallback.jit,
    };
  };

  /** Persist the built config into the simulator's `modulus_components`. */
  const handleSave = (e: { preventDefault: () => void }) => {
    e.preventDefault();
    if (!simulator) {
      return;
    }
    updateMutation.mutate({
      id: simulator.id,
      modulus_components: {
        ...simulator.modulus_components,
        conf: buildConfig(),
      },
    });
  };

  /** Serialize the built config to YAML and, after confirmation, download it as config.yaml. */
  const handleYamlGeneration = () => {
    const yamlDump = yaml.dump(buildConfig(), {
      styles: {
        "!!null": "canonical", // dump null as ~
      },
    });

    if (window.confirm("Download generated config.yaml?")) {
      saveAs(
        new File([yamlDump], "config.yaml", {
          type: "text/yaml;charset=utf-8",
        })
      );
    }
  };

  /** Render a `training.*` number input validated against `MIN_STEP_VALUE`. */
  const renderValidatedInput = (field: ValidatedTrainingField, title: string, perSteps = true) => {
    const invalid = Boolean(isWrongInput[field]);
    return (
      <Row>
        <Col>
          <h5>
            {title} {perSteps && <small>(each X steps)</small>}
            {invalid && <small style={{ color: "red" }}> minimum {MIN_STEP_VALUE}</small>}
          </h5>
          <input
            type="number"
            className="form-control"
            step={MIN_STEP_VALUE}
            style={invalidInputStyle(invalid)}
            defaultValue={config.training?.[field] ?? defaultConfig.training[field]}
            onChange={(e) => handleValidatedChange(field, e.target.value)}
          />
        </Col>
      </Row>
    );
  };

  return (
    <div>
      <Row className="mt-3 mb-5">
        <Col sm="3">
          <h3>Defaults</h3>
          <hr />
          <Row>
            <Col>
              <h5>Scheduler</h5>
              <Select<TSelectOption>
                className="react-select settings-select"
                options={schedulerOptions}
                defaultValue={optionFor(
                  schedulerOptions,
                  config.defaults?.scheduler ?? defaultConfig.defaults.scheduler
                )}
                onChange={(option) => {
                  if (option) updateConfig("defaults.scheduler", option.value);
                }}
              />
            </Col>
          </Row>
          <Row>
            <Col>
              <h5>Optimizer</h5>
              <Select<TSelectOption>
                className="react-select settings-select"
                options={optimizerOptions}
                defaultValue={optionFor(
                  optimizerOptions,
                  config.defaults?.optimizer ?? defaultConfig.defaults.optimizer
                )}
                onChange={(option) => {
                  if (option) updateConfig("defaults.optimizer", option.value);
                }}
              />
            </Col>
          </Row>
          <Row>
            <Col>
              <h5>Loss</h5>
              <Select<TSelectOption>
                className="react-select settings-select"
                options={lossOptions}
                defaultValue={optionFor(lossOptions, config.defaults?.loss ?? defaultConfig.defaults.loss)}
                onChange={(option) => {
                  if (option) updateConfig("defaults.loss", option.value);
                }}
              />
            </Col>
          </Row>
        </Col>
        <Col sm="3">
          <h3>Scheduler</h3>
          <hr />

          <Row>
            <Col>
              <h5>Decay rate</h5>
              <input
                type="number"
                className="form-control"
                defaultValue={config.scheduler?.decay_rate ?? defaultConfig.scheduler.decay_rate}
                onChange={(e) => updateNumber("scheduler.decay_rate", parseFloat(e.target.value))}
              />
            </Col>
          </Row>

          <Row>
            <Col>
              <h5>Decay steps</h5>
              <input
                type="number"
                className="form-control"
                defaultValue={config.scheduler?.decay_steps ?? defaultConfig.scheduler.decay_steps}
                onChange={(e) => updateNumber("scheduler.decay_steps", parseInt(e.target.value, 10))}
              />
            </Col>
          </Row>
        </Col>
      </Row>

      <Row>
        <Col sm="5">
          <h3>Training</h3>
          <hr />

          {renderValidatedInput("max_steps", "Maximum steps before automatically stopping the training", false)}
          {renderValidatedInput("save_network_freq", "Checkpoint saving frequency")}
          {renderValidatedInput("rec_results_freq", "Results recording frequency")}
          {renderValidatedInput("rec_validation_freq", "Validation recording frequency")}
          {renderValidatedInput("rec_monitor_freq", "Value monitoring recording frequency")}
          {renderValidatedInput("rec_inference_freq", "Inference recording frequency")}

          <Row>
            <Col>
              <h5>Batch size</h5>
              <input
                type="number"
                className="form-control"
                step={1000}
                defaultValue={config.training?.batch_size ?? defaultConfig.training.batch_size}
                onChange={(e) => updateNumber("training.batch_size", parseInt(e.target.value, 10))}
              />
            </Col>
          </Row>
        </Col>
        <Col sm="4">
          <h3>Advanced settings</h3>
          <hr />

          <Row>
            <Col>
              <h5>Use Just-in-Time compilation of neural network nodes?</h5>
              <input
                type="checkbox"
                checked={config.jit ?? defaultConfig.jit}
                onChange={(e) => updateConfig("jit", e.target.checked)}
              />
            </Col>
          </Row>

          <Row>
            <Col>
              <h5>Use CUDA graphs?</h5>
              <input
                type="checkbox"
                // Fall back to the default so the checkbox is always controlled.
                checked={config.cuda_graphs ?? defaultConfig.cuda_graphs}
                onChange={(e) => updateConfig("cuda_graphs", e.target.checked)}
              />
            </Col>
          </Row>

          <Row>
            <Col>
              <h5>CUDA graph warmup steps</h5>
              <input
                type="number"
                className="form-control"
                defaultValue={config.cuda_graph_warmup ?? defaultConfig.cuda_graph_warmup}
                onChange={(e) => updateNumber("cuda_graph_warmup", parseInt(e.target.value, 10))}
              />
            </Col>
          </Row>
        </Col>
      </Row>

      <Row className="mt-4">
        <Col>
          {/* Saving is disabled while any validated field holds an out-of-range value. */}
          <Button variant="success" onClick={handleSave} disabled={hasWrongInput}>
            Save
          </Button>
          <Button className="ms-1" variant="warning" onClick={handleYamlGeneration}>
            Download <strong>config.yaml</strong>
          </Button>
        </Col>
      </Row>
    </div>
  );
};
