import {
  CorrelationLabel,
  MAPELabel,
  MSELabel,
  MaxDiffLabel,
  MinDiffLabel,
  RMSELabel,
  SMAPELabel,
} from "constants/chartConstants";
import { EUnit, EValFormat, IFormatObject } from "../src/types/format";

/** A single computed metric value together with the format used to display it. */
export interface Metric {
  /** Display label for the metric. */
  label: string;
  /** The metric's numeric value. */
  value: number;
  /** Display format (unit, suffix, decimal places, numeric vs percentage), as built by `createMetric`. */
  format: IFormatObject;
}

/**
 * Computes the comparison metrics between two series and wraps each one as a
 * display-ready {@link Metric}.
 *
 * Metrics whose value cannot be formatted (`createMetric` returns `undefined`,
 * e.g. for NaN from invalid input or a zero reference value in MAPE) are left
 * out of the result.
 *
 * @param sanitizedArrays - Two equal-length series. Differences are taken as
 *   `sanitizedArrays[1][i] - sanitizedArrays[0][i]`, and MAPE is relative to
 *   `sanitizedArrays[0]`.
 * @param overwritePercentage - When true, MSE, RMSE, min diff and max diff are
 *   formatted as percentages, and MSE is additionally multiplied by 100.
 *   Correlation is always numeric (3 decimal places); MAPE and sMAPE are
 *   always percentages.
 * @returns The available metrics in display order: MSE, RMSE, correlation,
 *   MAPE, sMAPE, min diff, max diff.
 */
export function getMetrics(sanitizedArrays: number[][], overwritePercentage = false): Metric[] {
  // NOTE(review): only MSE is rescaled by 100 in percentage mode; RMSE and the
  // min/max diffs are not. Preserved as-is — confirm this asymmetry is intended.
  const mseScale = overwritePercentage ? 100 : 1;

  const candidates: Array<Metric | undefined> = [
    createMetric(MSELabel, mse(sanitizedArrays) * mseScale, overwritePercentage),
    createMetric(RMSELabel, rmse(sanitizedArrays), overwritePercentage),
    createMetric(CorrelationLabel, correlation(sanitizedArrays), false, 3),
    createMetric(MAPELabel, mape(sanitizedArrays), true),
    createMetric(SMAPELabel, smape(sanitizedArrays), true),
    createMetric(MinDiffLabel, minDiff(sanitizedArrays), overwritePercentage),
    createMetric(MaxDiffLabel, maxDiff(sanitizedArrays), overwritePercentage),
  ];

  return candidates.filter((metric): metric is Metric => metric !== undefined);
}

/**
 * Pairs a metric value with the display format to render it with.
 *
 * Values with magnitude of at least 1e3 / 1e6 / 1e9 / 1e12 are given the
 * THOUSANDS / MILLIONS / BILLIONS / TRILLIONS unit with a suffix; smaller
 * values stay RAW with no suffix.
 *
 * @param label - Display label for the metric.
 * @param value - The metric value.
 * @param isPercentage - Use the PERCENTAGE value format instead of NUMERIC.
 * @param decimalPlaces - Decimal places to display (default 1).
 * @returns The metric, or `undefined` when `value` is NaN or ±Infinity, which
 *   cannot be displayed meaningfully.
 */
export function createMetric(
  label: string,
  value: number,
  isPercentage: boolean,
  decimalPlaces = 1
): Metric | undefined {
  // Reject Infinity as well as NaN; otherwise it would be shown with a "trillions" suffix.
  if (!Number.isFinite(value)) return undefined;

  const format: IFormatObject = { decimalPlaces, hasSuffix: false, unit: EUnit.RAW };

  // `>=` so exact powers of 1000 take the larger unit (1e6 → "1.0M", not "1000.0K").
  const magnitude = Math.abs(value);
  if (magnitude >= 1e12) {
    format.unit = EUnit.TRILLIONS;
    format.hasSuffix = true;
  } else if (magnitude >= 1e9) {
    format.unit = EUnit.BILLIONS;
    format.hasSuffix = true;
  } else if (magnitude >= 1e6) {
    format.unit = EUnit.MILLIONS;
    format.hasSuffix = true;
  } else if (magnitude >= 1e3) {
    format.unit = EUnit.THOUSANDS;
    format.hasSuffix = true;
  }

  format.valFormat = isPercentage ? EValFormat.PERCENTAGE : EValFormat.NUMERIC;

  return { label, value, format };
}

/**
 * Checks that `arrs` is suitable for element-wise mathematical operations.
 *
 * All of the following must hold:
 * - `arrs` contains exactly `arrNum` arrays (and `arrNum` is at least 1);
 * - every array has the same, non-zero length;
 * - every element is a finite number. The check uses `Number.isFinite`, so
 *   NaN, ±Infinity and non-number values are all rejected.
 *
 * @param arrs - The arrays to check.
 * @param arrNum - The exact number of arrays expected.
 * @returns `true` when every condition holds, otherwise `false`.
 */
export function checkArrayValidityForMathsOps(arrs: number[][], arrNum: number): boolean {
  // With arrNum < 1 there is no first array to take a reference length from.
  if (arrNum < 1 || arrs.length !== arrNum) {
    return false;
  }

  const expectedLength = arrs[0].length;
  if (expectedLength === 0) {
    return false;
  }

  // Use `every` rather than `forEach`: a `return` inside a `forEach` callback
  // only exits that callback and can never make this function return false.
  return arrs.every(
    (arr) => arr.length === expectedLength && arr.every((element) => Number.isFinite(element))
  );
}

/**
 * Arithmetic mean of `arr`.
 *
 * @param arr - Values to average.
 * @returns The mean, or NaN when `arr` is empty.
 */
export function mean(arr: number[]): number {
  const count = arr.length;
  if (count === 0) {
    return NaN;
  }
  const total = arr.reduce((acc, value) => acc + value, 0);
  return total / count;
}

/**
 * Mean of the squared elements of `arr`: (Σ x²) / arr.length.
 *
 * NOTE(review): returns 0 for an empty array, whereas `mean` returns NaN.
 * Callers that must distinguish "no data" should check the length first.
 *
 * @param arr - Values whose squares are averaged.
 * @returns The mean square, or 0 when `arr` is empty.
 */
export function squareMean(arr: number[]): number {
  const count = arr.length;
  if (!count) return 0;

  const totalOfSquares = arr.reduce((acc, x) => acc + Math.pow(x, 2), 0);
  return totalOfSquares / count;
}

/**
 * Element-wise differences between the second and first series:
 * `result[i] = arrs[1][i] - arrs[0][i]`.
 *
 * @param arrs - Two series to subtract.
 * @returns The differences, or an empty array when `arrs` fails
 *   `checkArrayValidityForMathsOps(arrs, 2)`.
 */
export function arrayDifferences(arrs: number[][]): number[] {
  if (!checkArrayValidityForMathsOps(arrs, 2)) return [];

  const baseline = arrs[0];
  const compared = arrs[1];
  return Array.from({ length: baseline.length }, (_, index) => compared[index] - baseline[index]);
}

/**
 * Pearson correlation coefficient between `arrs[0]` (x) and `arrs[1]` (y):
 * Σ(x - x̄)(y - ȳ) / (√Σ(x - x̄)² · √Σ(y - ȳ)²).
 *
 * @param arrs - Two equal-length series.
 * @returns The coefficient (nominally in [-1, 1]), or NaN when the input fails
 *   validation or either series is constant. A constant series has zero
 *   variance, so the ratio is undefined.
 */
export function correlation(arrs: number[][]): number {
  if (!checkArrayValidityForMathsOps(arrs, 2)) {
    return NaN;
  }
  const xs = arrs[0];
  const ys = arrs[1];

  const xBar = mean(xs);
  const yBar = mean(ys);
  // Reject only NaN. A mean of exactly 0 is valid (e.g. series centred on zero)
  // and must not be treated as "no data".
  if (Number.isNaN(xBar) || Number.isNaN(yBar)) {
    return NaN;
  }

  // Accumulate the cross-product and both squared-deviation sums in one pass.
  let sumProducts = 0;
  let sumXSquaredDevs = 0;
  let sumYSquaredDevs = 0;
  for (let i = 0; i < xs.length; i++) {
    const dx = xs[i] - xBar;
    const dy = ys[i] - yBar;
    sumProducts += dx * dy;
    sumXSquaredDevs += dx * dx;
    sumYSquaredDevs += dy * dy;
  }

  // Take the roots separately (rather than √(Sxx·Syy)) to avoid overflow in the product.
  const rootSumXSquaredDevs = Math.sqrt(sumXSquaredDevs);
  const rootSumYSquaredDevs = Math.sqrt(sumYSquaredDevs);
  if (rootSumXSquaredDevs === 0 || rootSumYSquaredDevs === 0) {
    return NaN;
  }

  return sumProducts / (rootSumXSquaredDevs * rootSumYSquaredDevs);
}

/**
 * Mean absolute percentage error of `arrs[1]` relative to `arrs[0]`, returned
 * as a fraction (not multiplied by 100):
 * mean(|(arrs[1][i] - arrs[0][i]) / arrs[0][i]|).
 *
 * @param arrs - Reference series at index 0, compared series at index 1.
 * @returns The error, or NaN when the input fails validation or any
 *   reference value is zero, since the relative error is undefined there.
 */
export function mape(arrs: number[][]): number {
  if (!checkArrayValidityForMathsOps(arrs, 2)) return NaN;

  const reference = arrs[0];
  // A zero reference value makes its relative error undefined.
  if (reference.some((value) => value === 0)) return NaN;

  const diffs = arrayDifferences(arrs);
  if (diffs.length === 0) return NaN;

  const relativeErrors = diffs.map((diff, i) => Math.abs(diff / reference[i]));
  return mean(relativeErrors);
}

/**
 * Symmetric mean absolute percentage error, returned as a fraction:
 * mean(|arrs[1][i] - arrs[0][i]| / (|arrs[0][i]| + |arrs[1][i]|)).
 *
 * The denominator is not halved, so each term (and therefore the result)
 * lies in [0, 1].
 *
 * @param arrs - Two equal-length series.
 * @returns The error, or NaN when the input fails validation or both values
 *   at some index are zero (zero denominator).
 */
export function smape(arrs: number[][]): number {
  if (!checkArrayValidityForMathsOps(arrs, 2)) return NaN;

  const diffs = arrayDifferences(arrs);
  const count = diffs.length;
  if (count === 0) return NaN;

  const left = arrs[0];
  const right = arrs[1];
  let total = 0;
  for (let idx = 0; idx < count; idx++) {
    const scale = Math.abs(left[idx]) + Math.abs(right[idx]);
    if (scale === 0) return NaN;
    total += Math.abs(diffs[idx]) / scale;
  }
  return total / count;
}

/**
 * Mean squared error between two series: mean((arrs[1][i] - arrs[0][i])²).
 *
 * @param arrs - Two equal-length series.
 * @returns The MSE, or NaN when the input fails validation.
 */
export function mse(arrs: number[][]): number {
  return checkArrayValidityForMathsOps(arrs, 2) ? squareMean(arrayDifferences(arrs)) : NaN;
}

/**
 * Root mean squared error between two series: √MSE.
 *
 * @param arrs - Two equal-length series.
 * @returns The RMSE. It is 0 for identical series and NaN when the input
 *   fails validation.
 */
export function rmse(arrs: number[][]): number {
  // Math.sqrt propagates NaN, so invalid input still yields NaN. No truthiness
  // check: an MSE of exactly 0 (a perfect match) must give an RMSE of 0, not NaN.
  return Math.sqrt(mse(arrs));
}

/**
 * Largest element-wise difference `arrs[1][i] - arrs[0][i]`.
 *
 * @param arrs - Two equal-length series.
 * @returns The maximum difference, or NaN when the input fails validation or
 *   any difference is NaN.
 */
export function maxDiff(arrs: number[][]): number {
  if (!checkArrayValidityForMathsOps(arrs, 2)) {
    return NaN;
  }
  // Fold instead of `Math.max(...diffs)`: spreading passes every element as a
  // call argument, which throws RangeError for very long series. Folding with
  // Math.max keeps the same NaN propagation.
  return arrayDifferences(arrs).reduce((best, diff) => Math.max(best, diff), -Infinity);
}

/**
 * Smallest element-wise difference `arrs[1][i] - arrs[0][i]`.
 *
 * @param arrs - Two equal-length series.
 * @returns The minimum difference, or NaN when the input fails validation or
 *   any difference is NaN.
 */
export function minDiff(arrs: number[][]): number {
  if (!checkArrayValidityForMathsOps(arrs, 2)) {
    return NaN;
  }
  // Fold instead of `Math.min(...diffs)`: spreading passes every element as a
  // call argument, which throws RangeError for very long series. Folding with
  // Math.min keeps the same NaN propagation.
  return arrayDifferences(arrs).reduce((best, diff) => Math.min(best, diff), Infinity);
}
