import { atom, selector } from 'recoil';
import { ProjectType, projectType } from '~/src/Project/store';
import { productionFileStatsState } from '~/src/TableFilters/store';
import { variablesState } from '~/src/TrainingFile/store';
import { algorithmStatsState, dependentVariableSelector } from './prediction-table.store';
import { confusionMatrixState, modelNameMapper, modelSelector, predictionStatsState } from './results.store';

// Placeholder shown for any summary value that is missing or cannot be computed.
// NOTE(review): the identifier is misspelled ("UNCKNOWN"); it is not exported, so a rename only
// has to touch the uses inside this module.
const UNCKNOWN_VALUE = 'N/A';

/**
 * Per-class counts for the target column, keyed by class name.
 * The binary and multi-class summaries divide each count by the production row total to get a
 * percentage distribution. Starts out empty.
 */
export const summaryTargetColumnStatsState = atom<Record<string, number>>({
  default: {},
  key: 'Summary/summaryTargetColumnStatsState',
});

/**
 * Boolean flag paired with `summaryTargetColumnStatsState`; `false` by default.
 * NOTE(review): presumably `true` while those stats are being requested — confirm with the setter.
 */
export const pendingSummaryTargetColumnStatsState = atom<boolean>({
  default: false,
  key: 'Summary/pendingSummaryTargetColumnStatsState',
});

/**
 * Time-series automodeler payload as a loosely typed object; `null` by default.
 * The summary reads `predictionSteps` from it and checks for `null` first.
 *
 * The value type includes `null` so that the `null` default type-checks under `strictNullChecks`,
 * and so consumers are forced to handle the empty case.
 */
export const tsAutomodelerState = atom<{ [key: string]: any } | null>({
  key: 'Summary/tsAutomodelerState',
  default: null,
});

/**
 * Boolean flag paired with `tsAutomodelerState`; `false` by default.
 * NOTE(review): presumably `true` while the automodeler data is being requested — confirm with the setter.
 */
export const pendingTSAutomodelerState = atom<boolean>({
  default: false,
  key: 'Summary/pendingTSAutomodelerState',
});

/**
 * Precision, sensitivity (recall) and specificity of a binary model, each formatted to two decimals.
 *
 * Cell mapping (row 0 is the header):
 * - TP = [2][1], FP = [1][1], TN = [1][0], FN = [2][0].
 * This mapping implies that rows are the actual class and columns are the predicted class, with
 * row 1 / column 0 being the negative class.
 *
 * A metric whose ratio is not finite (zero denominator, non-numeric or missing cell) is reported
 * as UNCKNOWN_VALUE. All three are UNCKNOWN_VALUE when there is no matrix.
 */
export const confusionMatrixBinarySummaryStatsState = selector({
  key: 'Summary/confusionMatrixBinarySummaryStatsState',
  get: ({ get }) => {
    const matrix = get(confusionMatrixState);
    if (!matrix) return { precision: UNCKNOWN_VALUE, sensitivity: UNCKNOWN_VALUE, specificity: UNCKNOWN_VALUE };

    // `toFixed` never returns an empty string, so an `|| UNCKNOWN_VALUE` fallback cannot catch
    // "NaN"/"Infinity"; check finiteness explicitly instead.
    const formatRatio = (numerator: number, denominator: number): string => {
      const ratio = numerator / denominator;
      return Number.isFinite(ratio) ? ratio.toFixed(2) : UNCKNOWN_VALUE;
    };

    const truePositive = Number(matrix[2]?.[1]);
    const falsePositive = Number(matrix[1]?.[1]);
    const trueNegative = Number(matrix[1]?.[0]);
    const falseNegative = Number(matrix[2]?.[0]);

    return {
      precision: formatRatio(truePositive, truePositive + falsePositive),
      sensitivity: formatRatio(truePositive, truePositive + falseNegative),
      // Specificity (true-negative rate) is TN / (TN + FP); TN / (TN + FN) would be the negative
      // predictive value. This matches the multi-class selector's `tn / (tn + fp)`.
      specificity: formatRatio(trueNegative, trueNegative + falsePositive),
    };
  },
});

/** One-vs-rest confusion counts for a single class. */
interface ClassPositiveNegativeData {
  /** Samples of this class predicted as this class (the diagonal cell). */
  tp: number;
  /** Samples of any other class predicted as any other class. */
  tn: number;
  /** Samples of other classes predicted as this class. */
  fp: number;
  /** Samples of this class predicted as another class. */
  fn: number;
}

/** Result of `getPositiveNegativeCoefs`. */
interface MultiClassConfusionMatrixCoeficients {
  /** Sum of the diagonal, i.e. all correctly classified samples. */
  truePositive: number;
  /** Per-class one-vs-rest counts, keyed by the class name taken from the header row. */
  classes: {
    [key: string]: ClassPositiveNegativeData;
  };
}

/**
 * Derives one-vs-rest TP/TN/FP/FN counts for every class of a multi-class confusion matrix.
 *
 * Expected layout, with k = matrix.length - 2 classes:
 * - Row 0 is the header; its first k cells are the class names.
 * - Rows 1..k are the actual classes.
 * - Within each class row, the first k cells are the counts per predicted class. Any trailing
 *   cells are not counts and are ignored.
 * - The last row holds totals. It is ignored; all sums are recomputed from the k×k block so the
 *   four counts are mutually consistent.
 *
 * TN is N - TP - FP - FN, where N is the grand total. Summing only the other classes' diagonal
 * cells would miss samples of one other class that were predicted as a third class.
 *
 * @param matrix Confusion matrix whose numeric cells are encoded as strings.
 * @returns The diagonal sum and per-class counts. Both are empty/zero when the matrix has no class rows.
 */
const getPositiveNegativeCoefs = (matrix: string[][]): MultiClassConfusionMatrixCoeficients => {
  const result: MultiClassConfusionMatrixCoeficients = { truePositive: 0, classes: {} };
  const numberOfClasses = matrix.length - 2; // all rows minus the header row and the totals row
  if (numberOfClasses <= 0) return result;

  const classNames = matrix[0].slice(0, numberOfClasses);
  const counts = matrix
    .slice(1, numberOfClasses + 1)
    .map((row) => Array.from({ length: numberOfClasses }, (_, col) => +row[col]));

  const rowTotals = counts.map((row) => row.reduce((sum, value) => sum + value, 0));
  const columnTotals = classNames.map((_, col) => counts.reduce((sum, row) => sum + row[col], 0));
  const grandTotal = rowTotals.reduce((sum, value) => sum + value, 0);

  classNames.forEach((className, i) => {
    const tp = counts[i][i];
    const fn = rowTotals[i] - tp; // actual class i, predicted as something else
    const fp = columnTotals[i] - tp; // predicted as class i, actually something else
    result.truePositive += tp;
    result.classes[className] = { tp, fn, fp, tn: grandTotal - tp - fn - fp };
  });

  return result;
};

/**
 * Unweighted (macro) means of per-class precision, sensitivity and specificity for a multi-class
 * model, each formatted to two decimals.
 *
 * Per-class counts are the one-vs-rest values returned by `getPositiveNegativeCoefs`. If any
 * class's ratio is undefined (e.g. 0/0), or no classes are found, the corresponding mean is
 * UNCKNOWN_VALUE. All three are UNCKNOWN_VALUE when there is no matrix.
 */
export const confusionMatrixMultiClassSummaryStatsState = selector({
  key: 'Summary/confusionMatrixMultiClassSummaryStatsState',
  get: ({ get }) => {
    const matrix = get(confusionMatrixState);
    if (!matrix) return { precision: UNCKNOWN_VALUE, sensitivity: UNCKNOWN_VALUE, specificity: UNCKNOWN_VALUE };

    const perClass = Object.values(getPositiveNegativeCoefs(matrix).classes);

    // Divide by the number of classes actually found, so the mean matches the values summed.
    // `toFixed` never returns '', so a NaN/Infinity mean is mapped to UNCKNOWN_VALUE explicitly.
    const averageOf = (ratio: (counts: (typeof perClass)[number]) => number): string => {
      const mean = perClass.reduce((sum, counts) => sum + ratio(counts), 0) / perClass.length;
      return Number.isFinite(mean) ? mean.toFixed(2) : UNCKNOWN_VALUE;
    };

    return {
      precision: averageOf(({ tp, fp }) => tp / (tp + fp)),
      sensitivity: averageOf(({ tp, fn }) => tp / (tp + fn)),
      specificity: averageOf(({ tn, fp }) => tn / (tn + fp)),
    };
  },
});

/**
 * Headline algorithm metrics shared by the binary and multi-class summaries, each pre-formatted
 * as a string:
 * - `accuracy` (percent, two decimals):
 *   - Binary projects use `accuracy[0][1]` from the algorithm stats.
 *   - Multi-class projects use the penultimate cell of the confusion matrix's last row.
 *   - Otherwise, or when that input is missing, it is UNCKNOWN_VALUE.
 * - `predictivePower` (percent, two decimals): AUC rescaled so 0.5 maps to 0 and 1.0 maps to 100.
 * - `maxF1` (percent, one decimal): `F1[0][1]` from the algorithm stats.
 * - `r2` (four decimals): the algorithm's R².
 *
 * Falsy inputs, and values that are not finite numbers, are reported as UNCKNOWN_VALUE.
 */
export const algorithmSummaryStatsState = selector({
  key: 'Summary/algorithmSummaryStatsState',
  get: ({ get }) => {
    const type = get(projectType);
    const algoritmStats = get(algorithmStatsState);
    const matrix = get(confusionMatrixState);

    // `toFixed` happily produces "NaN"/"Infinity"; show the placeholder instead.
    const formatFinite = (value: number, digits: number): string =>
      Number.isFinite(value) ? value.toFixed(digits) : UNCKNOWN_VALUE;

    const predictivePower = algoritmStats?.auc_on_cross_val
      ? formatFinite(((algoritmStats.auc_on_cross_val - 0.5) / 0.5) * 100, 2)
      : UNCKNOWN_VALUE;

    let accuracy = UNCKNOWN_VALUE;
    if (type === ProjectType.Binary) {
      // Only replace the placeholder when the stats actually carry an accuracy entry.
      if (algoritmStats?.accuracy) accuracy = formatFinite(algoritmStats.accuracy[0][1] * 100, 2);
    } else if (type === ProjectType.MultiClass && matrix) {
      accuracy = formatFinite(+matrix[matrix.length - 1][matrix.length - 2] * 100, 2); // take the penultimate value from last row
    }

    return {
      accuracy,
      predictivePower,
      maxF1: algoritmStats?.F1 ? formatFinite(algoritmStats.F1[0][1] * 100, 1) : UNCKNOWN_VALUE,
      r2: algoritmStats?.R2 ? formatFinite(+algoritmStats.R2, 4) : UNCKNOWN_VALUE,
    };
  },
});

/**
 * Summary card data for a binary project. It contains:
 * - the model's display name;
 * - the shared algorithm metrics and the binary confusion-matrix metrics;
 * - the production row count;
 * - the target column name and its classes;
 * - the percentage of production rows per class (`targetDistributions`).
 */
export const summaryBinaryDataState = selector({
  key: 'Summary/summaryBinaryDataState',
  get: ({ get }) => {
    const model = get(modelSelector);
    const algoritmStats = get(algorithmSummaryStatsState);
    const matrixStats = get(confusionMatrixBinarySummaryStatsState) || {};
    const productionFileStats = get(productionFileStatsState);
    const dependentVariable = get(dependentVariableSelector);
    const summaryTargetColumnStats = get(summaryTargetColumnStatsState);
    const predictionStats = get(predictionStatsState);
    const targetColumnName = dependentVariable?.displayName || predictionStats?.target || '';

    // Production row count = cell 0 of the first column's stats. Falls back to UNCKNOWN_VALUE
    // when there are no file stats, `column_stats` has no columns, or the cell is missing.
    const firstColumnStats = productionFileStats ? Object.values(productionFileStats.column_stats)[0] : undefined;
    const prodTotalRows = firstColumnStats?.[0] ?? UNCKNOWN_VALUE;
    const classes = (targetColumnName && productionFileStats && productionFileStats.column_levels?.[targetColumnName]) || [];

    // Percentage of production rows per class. A class without a count, or a zero/non-numeric
    // total, yields UNCKNOWN_VALUE instead of the strings "NaN"/"Infinity".
    const totalRowsCount = Number(prodTotalRows);
    const targetDistributions =
      Object.keys(summaryTargetColumnStats).length && classes.length && prodTotalRows !== UNCKNOWN_VALUE
        ? classes.map((className) => {
            const share = (summaryTargetColumnStats[className] / totalRowsCount) * 100;
            return Number.isFinite(share) ? share.toFixed(2) : UNCKNOWN_VALUE;
          })
        : [];

    return {
      model: modelNameMapper[model?.model_id] || model?.model_id || UNCKNOWN_VALUE,
      ...algoritmStats,
      ...matrixStats,
      prodTotalRows,
      targetColumn: targetColumnName,
      classes,
      targetDistributions,
    };
  },
});

/**
 * Summary card data for a multi-class project. It contains:
 * - the model's display name;
 * - the shared algorithm metrics and the macro-averaged confusion-matrix metrics;
 * - the production row count;
 * - the target column name and its classes;
 * - the percentage of production rows per class (`targetDistributions`).
 */
export const summaryMultiClassDataState = selector({
  key: 'Summary/summaryMultiClassDataState',
  get: ({ get }) => {
    const model = get(modelSelector);
    const algoritmStats = get(algorithmSummaryStatsState);
    const matrixStats = get(confusionMatrixMultiClassSummaryStatsState) || {};
    const productionFileStats = get(productionFileStatsState);
    const dependentVariable = get(dependentVariableSelector);
    const summaryTargetColumnStats = get(summaryTargetColumnStatsState);
    const predictionStats = get(predictionStatsState);
    const targetColumnName = dependentVariable?.displayName || predictionStats?.target || '';

    // Production row count = cell 0 of the first column's stats. Falls back to UNCKNOWN_VALUE
    // when there are no file stats, `column_stats` has no columns, or the cell is missing.
    const firstColumnStats = productionFileStats ? Object.values(productionFileStats.column_stats)[0] : undefined;
    const prodTotalRows = firstColumnStats?.[0] ?? UNCKNOWN_VALUE;
    const classes = (targetColumnName && productionFileStats && productionFileStats.column_levels?.[targetColumnName]) || [];

    // A class without a count, or a zero/non-numeric total, yields UNCKNOWN_VALUE instead of
    // the strings "NaN"/"Infinity".
    const totalRowsCount = Number(prodTotalRows);

    return {
      model: modelNameMapper[model?.model_id] || model?.model_id || UNCKNOWN_VALUE,
      ...algoritmStats,
      ...matrixStats,
      prodTotalRows,
      targetColumn: targetColumnName,
      classes,
      targetDistributions:
        Object.keys(summaryTargetColumnStats).length && classes.length && prodTotalRows !== UNCKNOWN_VALUE
          ? classes.map((className) => {
              const share = (summaryTargetColumnStats[className] / totalRowsCount) * 100;
              return Number.isFinite(share) ? share.toFixed(2) : UNCKNOWN_VALUE;
            })
          : [],
    };
  },
});

/**
 * Summary card data for a regression project. It contains:
 * - the production row count;
 * - R² (four decimals);
 * - MAE (four decimals).
 *
 * MAE is read from the model's `mae` key when that is truthy, otherwise from `MAE`.
 */
export const summaryRegressionDataState = selector({
  key: 'Summary/summaryRegressionDataState',
  get: ({ get }) => {
    const algoritmStats = get(algorithmStatsState);
    const productionFileStats = get(productionFileStatsState);
    const model = get(modelSelector) || {};
    const mae = model[model.mae ? 'mae' : 'MAE'];

    // Production row count = cell 0 of the first column's stats. Falls back to UNCKNOWN_VALUE
    // when there are no file stats, `column_stats` has no columns, or the cell is missing.
    const firstColumnStats = productionFileStats ? Object.values(productionFileStats.column_stats)[0] : undefined;

    return {
      prodTotalRows: firstColumnStats?.[0] ?? UNCKNOWN_VALUE,
      r2: algoritmStats?.R2 ? (+algoritmStats.R2).toFixed(4) : UNCKNOWN_VALUE,
      mae: mae ? (+mae).toFixed(4) : UNCKNOWN_VALUE,
    };
  },
});

/**
 * Summary card data for a time-series forecast project. It contains:
 * - the selected model's RMSE, AIC and BIC, each to two decimals;
 * - `predictionSteps` from `tsAutomodelerState`.
 *
 * A missing model, a metric that is not a finite number, or a missing `predictionSteps` is
 * reported as UNCKNOWN_VALUE.
 */
export const summaryTimeSeriesDataState = selector({
  key: 'Summary/summaryTimeSeriesDataState',
  get: ({ get }) => {
    const tsAutomodeler = get(tsAutomodelerState);
    const model = get(modelSelector);

    // `(+undefined).toFixed(2)` is "NaN"; show the placeholder instead.
    const formatMetric = (value: unknown): string => {
      const numeric = Number(value);
      return Number.isFinite(numeric) ? numeric.toFixed(2) : UNCKNOWN_VALUE;
    };

    return {
      rmse: model ? formatMetric(model.RMSE) : UNCKNOWN_VALUE,
      aic: model ? formatMetric(model.AIC) : UNCKNOWN_VALUE,
      predictionSteps: tsAutomodeler?.predictionSteps ?? UNCKNOWN_VALUE,
      bic: model ? formatMetric(model.BIC) : UNCKNOWN_VALUE,
    };
  },
});

/**
 * Top-level summary for the current project. It contains:
 * - the project type;
 * - `totalRows`, taken from `totalValues` of the first variable header in `variablesState`;
 * - the fields of the type-specific summary selector, spread on top.
 *
 * Project types without a dedicated selector contribute no extra fields.
 */
export const summaryDataState = selector({
  key: 'Summary/summaryDataState',
  get: ({ get }) => {
    const type = get(projectType);
    const variables = get(variablesState);

    // Only the selector that matches the project type is read via `get`.
    let typeSpecificData;
    if (type === ProjectType.Binary) {
      typeSpecificData = get(summaryBinaryDataState);
    } else if (type === ProjectType.MultiClass) {
      typeSpecificData = get(summaryMultiClassDataState);
    } else if (type === ProjectType.Regression) {
      typeSpecificData = get(summaryRegressionDataState);
    } else if (type === ProjectType.TimeSeriesForecast) {
      typeSpecificData = get(summaryTimeSeriesDataState);
    } else {
      typeSpecificData = {};
    }

    return {
      type,
      totalRows: variables?.headers[0]?.totalValues || UNCKNOWN_VALUE,
      ...typeSpecificData,
    };
  },
});
