import { TrainingTelemetry } from "types/trainingtelemetry";

/**
 * Groups the items of `array` by a composite key.
 *
 * @param array - Items to group, e.g. `[{...}, {...}]`.
 * @param f - Callback that receives one item and returns the array of field
 *   values to group by, e.g. `(item) => [item.field1, item.field2]`.
 * @returns One array per distinct key, each holding the items that produced
 *   that key. Keys are compared through their `JSON.stringify` representation.
 */
const groupByFields = (array: Array<any>, f: any): Array<any> => {
  // TODO: fix `any` type, specify obj. type
  const buckets = array.reduce((acc: Record<string, Array<any>>, item: any) => {
    // The serialized field list identifies the bucket the item belongs to.
    const bucketKey = JSON.stringify(f(item));
    const bucket = acc[bucketKey];
    if (bucket) {
      bucket.push(item);
    } else {
      acc[bucketKey] = [item];
    }
    return acc;
  }, {});

  return Object.keys(buckets).map((bucketKey) => buckets[bucketKey]);
};

/**
 * Removes outliers from a list of numbers using the 1.5 * IQR rule.
 *
 * Inputs with fewer than 4 elements are returned unchanged (the very same
 * array instance). For longer inputs the result is a new array, sorted in
 * ascending order, that keeps only the values inside
 * `[Q1 - 1.5 * IQR, Q3 + 1.5 * IQR]`. The input array is never mutated.
 *
 * Quartiles are linearly interpolated between the two closest ranks, so every
 * index stays inside the array. The previous index arithmetic read past the
 * end of the array for inputs of 4 to 7 elements, which turned the bounds into
 * NaN and made the function return an empty array.
 *
 * For inputs of 4 or more elements, NaN and +/-Infinity are ignored when the
 * quartiles are computed and are never part of the result.
 *
 * @param someArray - Values to filter.
 * @returns The values that are not outliers.
 */
function filterOutliers(someArray: Array<number>) {
  if (someArray.length < 4) return someArray;

  // `filter` already copies, so sorting does not touch the caller's array.
  const values = someArray.filter((x) => Number.isFinite(x)).sort((a, b) => a - b);
  if (values.length === 0) return values;

  // Value at the given fraction (0..1) of the sorted data, interpolated
  // between neighbouring elements; both indices are always within range.
  const quantile = (fraction: number): number => {
    const position = (values.length - 1) * fraction;
    const lower = Math.floor(position);
    const upper = Math.ceil(position);
    return values[lower] + (values[upper] - values[lower]) * (position - lower);
  };

  const q1 = quantile(0.25);
  const q3 = quantile(0.75);
  const iqr = q3 - q1;
  const maxValue = q3 + iqr * 1.5;
  const minValue = q1 - iqr * 1.5;

  return values.filter((x) => x >= minValue && x <= maxValue);
}

/**
 * Builds the data points for the train_loss_aggregated chart with outliers
 * removed, because outlying points distort the chart.
 *
 * `filterOutliers` cannot be reused here: it returns bare numbers, while the
 * chart needs `{ step, train_loss_aggregated }` pairs, so the 1.5 * IQR rule is
 * applied to the telemetry entries directly.
 *
 * With fewer than 4 entries every entry is returned. Otherwise only entries
 * whose `train_loss_aggregated` lies inside `[Q1 - 1.5 * IQR, Q3 + 1.5 * IQR]`
 * are kept, in their original order. Quartiles are linearly interpolated so
 * all indices stay inside the array; the previous index arithmetic read past
 * the end for 4 to 7 entries and produced an empty result. Non-finite loss
 * values are ignored when computing the quartiles and are filtered out.
 *
 * @param telemetryArray - Telemetry entries to plot.
 * @returns One `{ step, train_loss_aggregated }` object per kept entry.
 */
function getFilteredDataForTrainLossAggregatedChartWithoutOutliers(telemetryArray: TrainingTelemetry[]) {
  type DataForChart = {
    step: number;
    train_loss_aggregated: number;
  };

  // Projects a telemetry entry onto the two fields the chart consumes.
  const toChartPoint = (telemetry: TrainingTelemetry): DataForChart => ({
    step: telemetry.step,
    train_loss_aggregated: telemetry.train_loss_aggregated,
  });

  if (telemetryArray.length < 4) {
    return telemetryArray.map(toChartPoint);
  }

  const values = telemetryArray
    .map((telemetry) => telemetry.train_loss_aggregated)
    .filter((value) => Number.isFinite(value))
    .sort((a, b) => a - b);

  // No finite value means nothing can fall inside any bounds.
  if (values.length === 0) {
    return [];
  }

  // Value at the given fraction (0..1) of the sorted losses, interpolated
  // between neighbouring elements; both indices are always within range.
  const quantile = (fraction: number): number => {
    const position = (values.length - 1) * fraction;
    const lower = Math.floor(position);
    const upper = Math.ceil(position);
    return values[lower] + (values[upper] - values[lower]) * (position - lower);
  };

  const q1 = quantile(0.25);
  const q3 = quantile(0.75);
  const iqr = q3 - q1;
  const maxValue = q3 + iqr * 1.5;
  const minValue = q1 - iqr * 1.5;

  return telemetryArray
    .filter((telemetry) => telemetry.train_loss_aggregated >= minValue && telemetry.train_loss_aggregated <= maxValue)
    .map(toChartPoint);
}

/**
 * Identifies one individual training loss.
 *
 * `key` is the set of names accepted as `whichParameter` by
 * `getFilteredIndividualLossesDataWithoutOutliers`, which uses it to index
 * the `train_losses` record of a telemetry entry.
 */
export type Parameter = {
  key:
    | "Train/loss_p"
    | "Train/loss_u"
    | "Train/loss_v"
    | "Train/loss_w"
    | "Train/loss_continuity"
    | "Train/loss_momentum_x"
    | "Train/loss_momentum_y"
    | "Train/loss_momentum_z";
};

/**
 * Builds chart data points for one individual loss with outliers removed.
 *
 * With fewer than 4 entries every entry is returned. Otherwise only entries
 * whose `train_losses[whichParameter]` lies inside
 * `[Q1 - 1.5 * IQR, Q3 + 1.5 * IQR]` are kept, in their original order.
 * Quartiles are linearly interpolated so all indices stay inside the array;
 * the previous index arithmetic read past the end for 4 to 7 entries and
 * produced an empty result. Non-finite loss values are ignored when computing
 * the quartiles and are filtered out.
 *
 * @param telemetryArray - Telemetry entries to plot.
 * @param whichParameter - Key of the loss to read from each entry's
 *   `train_losses` record.
 * @returns One `{ step, parameterBasedMetric }` object per kept entry.
 */
function getFilteredIndividualLossesDataWithoutOutliers (telemetryArray: TrainingTelemetry[], whichParameter: Parameter['key']) {
  type DataForChart = {
    step: number;
    parameterBasedMetric: number;
  };

  // Projects a telemetry entry onto the step and the selected loss.
  const toChartPoint = (telemetry: TrainingTelemetry): DataForChart => ({
    step: telemetry.step,
    parameterBasedMetric: telemetry.train_losses[whichParameter],
  });

  if (telemetryArray.length < 4) {
    return telemetryArray.map(toChartPoint);
  }

  const values = telemetryArray
    .map((telemetry) => telemetry.train_losses[whichParameter])
    .filter((value) => Number.isFinite(value))
    .sort((a, b) => a - b);

  // No finite value means nothing can fall inside any bounds.
  if (values.length === 0) {
    return [];
  }

  // Value at the given fraction (0..1) of the sorted losses, interpolated
  // between neighbouring elements; both indices are always within range.
  const quantile = (fraction: number): number => {
    const position = (values.length - 1) * fraction;
    const lower = Math.floor(position);
    const upper = Math.ceil(position);
    return values[lower] + (values[upper] - values[lower]) * (position - lower);
  };

  const q1 = quantile(0.25);
  const q3 = quantile(0.75);
  const iqr = q3 - q1;
  const maxValue = q3 + iqr * 1.5;
  const minValue = q1 - iqr * 1.5;

  return telemetryArray
    .filter((telemetry) => telemetry.train_losses[whichParameter] >= minValue && telemetry.train_losses[whichParameter] <= maxValue)
    .map(toChartPoint);
}

// Helpers exposed by this module.
export {
  groupByFields,
  filterOutliers,
  getFilteredDataForTrainLossAggregatedChartWithoutOutliers,
  getFilteredIndividualLossesDataWithoutOutliers,
};
