import React, { useState, useEffect } from 'react';
import { useDispatch } from 'react-redux';

import { makeStyles } from "@material-ui/core/styles";
import Button         from "@material-ui/core/Button";
import TextField      from "@material-ui/core/TextField";

import Plotly from "plotly.js"
import createPlotlyComponent from 'react-plotly.js/factory';

import { showError, showWarning, useTranslation } from '../../../core/utils';

// React component produced by the react-plotly.js factory from the Plotly
// module imported above; rendered by ShowMatrix as <Plot/>.
const Plot = createPlotlyComponent(Plotly);

// Initial value of ShowMatrix's `maxSize` state (how many classes are shown),
// and the fallback `max` of its size input when the matrix size is falsy.
const MAX_SIZE = 30;

// Style hook for ShowMatrix. `sizeInput` is the class applied to the numeric
// TextField that selects how many classes are displayed.
const useStyles = makeStyles(theme => ({
  sizeInput: {
    width: 35,
    marginLeft: 12,
    verticalAlign: 'middle',

    // Zeroes the bottom border of the input root's :before pseudo-element
    // (presumably Material-UI's default underline — TODO confirm).
    '& .MuiInputBase-root:before': {
      borderBottom: 0,
    },
  },
}));

/**
 * Extracts labels, matrix and pmatrix from a confusion-matrix record and
 * crops them to at most `maxSize` classes.
 *
 * Two record formats are accepted:
 *  - `cm.is_json` truthy: `matrix`, `labels` and `pmatrix` are JSON strings
 *    and the full size comes from `cm.size`;
 *  - otherwise: the three fields are used as they are and the full size is
 *    `cm.matrix.length`.
 *
 * @param {Object} cm - Confusion-matrix record in either format.
 * @param {number} maxSize - Maximum number of labels / rows / columns to keep.
 * @returns {{labels: Array, matrix: Array, pmatrix: Array, size: number}}
 *   Cropped copies of the data plus the uncropped matrix size.
 */
const getMatrixData = (cm, maxSize) => {
  // Top-left `limit` x `limit` corner of a 2-D array, as fresh arrays.
  const cropSquare = (rows, limit) =>
    rows.slice(0, limit).map(row => row.slice(0, limit));

  let source;
  if (cm.is_json) {
    // new fmt TODO remove old after a while
    const parsedMatrix = JSON.parse(cm.matrix);
    const parsedLabels = JSON.parse(cm.labels);
    const parsedPmatrix = JSON.parse(cm.pmatrix);
    source = {
      labels: parsedLabels,
      matrix: parsedMatrix,
      pmatrix: parsedPmatrix,
      size: cm.size,
    };
  } else {
    source = {
      labels: cm.labels,
      matrix: cm.matrix,
      pmatrix: cm.pmatrix,
      size: cm.matrix.length,
    };
  }

  const croppedLabels = source.labels.slice(0, maxSize);
  const croppedMatrix = cropSquare(source.matrix, maxSize);
  const croppedPmatrix = cropSquare(source.pmatrix, maxSize);

  return {
    labels: croppedLabels,
    matrix: croppedMatrix,
    pmatrix: croppedPmatrix,
    size: source.size,
  };
};

export function ShowMatrix(props) {
  const { tm, exportHandler } = props;
  const { t } = useTranslation();
  const classes = useStyles();
  const dispatch = useDispatch();

  const [maxSize, setMaxSize] = useState(MAX_SIZE);
  const [error, setError] = useState(false);

  const { labels: lab, pmatrix, matrix, size: CM_SIZE } = getMatrixData(tm.confusionMatrix || tm.data.cm, maxSize);
  const { report: rep } = tm.data?.int || { report: [] };

  const LABEL_INDEX = 0;
  const PRECISION_INDEX = 1;
  const RECALL_INDEX = 2;
  const FAKE_HEADER_INDEX = 1;
  const FOOTER_DIVIDER_INDEX = rep.length - 3;

  const setErr = (message, ...args) => {
    showError(dispatch, t)(message, ...args);
    setError(true);
  };

  const fix_empty = label => label.replace(/^$/,'0'); // TODO: remove after some time
  const fix_numbers = label => label.replace(/^-?\d+$/, '_'+label) // labels-numbers confuse Plotly
  const labels = lab.map(fix_empty).map(fix_numbers);
  const cutReport = rep.slice(FAKE_HEADER_INDEX, FOOTER_DIVIDER_INDEX);

  useEffect(() => {
    if (!tm.confusionMatrix)
      showWarning(dispatch, t)('tests.cm_warn_deprecated_data_src');

    if (!matrix.length || !pmatrix.length)
      return (setErr('tests.cm_error_matrix_empty'), undefined);

    const cutReportLabels = cutReport.map(r => r[LABEL_INDEX]).map(fix_numbers);
    const missingMatrixLabels = labels.filter(label => !cutReportLabels.includes(label));

    !cutReport.length && setErr('tests.cm_error_report_empty');
    missingMatrixLabels.length && setErr('tests.cm_error_report_mismatch', missingMatrixLabels.join(', '));

    if (maxSize >= CM_SIZE) {
      const missingReportLabels = cutReportLabels.filter(r => !labels.includes(r));
      missingReportLabels.length && setErr('tests.cm_error_matrix_mismatch', missingReportLabels.join(', '));
    }
  }, []);

  const report = labels.map(label => cutReport.find(item => label === item[LABEL_INDEX]));
  const getValues = (i) => report.map(item => ((item?.[i] || 0)*100).toFixed(0));
  const precisionValues = getValues(PRECISION_INDEX);
  const recallValues = getValues(RECALL_INDEX).reverse();

  const xValues = labels;
  const yValues = [...labels].reverse();
  const matrixValues = [...matrix].reverse();
  const zValues = [...pmatrix].reverse().map(row => row.map(v => Math.round(v)));
  const text = zValues.map(row => row.map(v => v +'%'));

  const call2 = f => (x,y) => f(x,y);
  const min = a => a.reduce(call2(Math.min),0);
  const max = a => a.reduce(call2(Math.max),0);
  const ct = a => min(a) + 0.3*(max(a) - min(a)); // color switch threshold
  const [ct1, ct2, ct3] = [precisionValues, zValues.flat(), recallValues].map(ct);

  const colorscaleValue = [
    ['0.0', 'rgb(255, 255, 255)'],
    ['0.1', 'rgb(230, 231, 241)'],
    ['0.2', 'rgb(204, 207, 227)'],
    ['0.3', 'rgb(179, 183, 213)'],
    ['0.4', 'rgb(153, 159, 199)'],
    ['0.5', 'rgb(128, 135, 185)'],
    ['0.6', 'rgb(103, 110, 170)'],
    ['0.7', 'rgb(77, 86, 156)'],
    ['0.8', 'rgb(52, 62, 142)'],
    ['0.9', 'rgb(26, 38, 128)'],
    ['1', 'rgb(1, 14, 114)'],
  ];

  const trace1 = {
    z: [precisionValues],
    xaxis: 'x1',
    yaxis: 'y1',
    y: 0,
    hovertemplate: 
    '<extra></extra>',
    type: 'heatmap',
    colorscale: colorscaleValue,
    showscale: false,
  };

  const trace2 = {
    x: xValues,
    y: yValues,
    z: zValues,
    xaxis: 'x2',
    yaxis: 'y2',
    text: text,
    colorscale: colorscaleValue,
    hovertemplate: 
    'predicted: %{x}<br>' +
    'true: %{y}<br>' +
    'share: %{text}' +
    '<extra></extra>',
    type: 'heatmap',
    showscale: false,
  };

  const trace3 = {
    z: recallValues.map(r => [r]),
    xaxis: 'x3',
    yaxis: 'y3',
    hovertemplate: 
    '<extra></extra>',
    type: 'heatmap',
    colorscale: colorscaleValue,
    showscale: false,
  };

  const data = [trace1, trace2, trace3];

  const layout = {
    font: {
      size: 11,
    },
    xaxis1: {
      overlaying: 'x1',
      position: 0.95,
      side: 'bottom',
      title: {
        text: 'Precision',
        font: {
          family: 'Courier New, monospace',
          size: 15,
          color: 'black'
        },
      },
      domain: [0.1, 0.9],
      showgrid: false,
      zeroline: false,
      showline: false,
      showticklabels: false,
      ticks: '',
    },
    yaxis1: {
      domain: [0.95, 1],
      showgrid: false,
      zeroline: false,
      showline: false,
      showticklabels: false,
      ticks: '',
    },
    xaxis2: {
      domain: [0.1, 0.9],
      position: 0.1,
      overlaying: 'x2',
      side: 'bottom',
      title: {
        text: 'Predicted Class',
        font: {
          family: 'Arial, sans-serif',
          size: 20,
          color: 'lightgrey',
        },
      },
    },
    yaxis2: {
      domain: [0.1, 0.85],
      title: {
        text: 'True Class',
        font: {
          family: 'Arial, sans-serif',
          size: 20,
          color: 'lightgrey',
        },
        side: 'right',
      },
    },
    xaxis3: {
      domain: [0.95, 1],
      showgrid: false,
      zeroline: false,
      showline: false,
      showticklabels: false,
      ticks: ''
    },
    yaxis3: {
      domain: [0.1, 0.85],
      overlaying: 'y3',
      position: 0.95,
      side: 'left',
      title: {
        text: 'Recall',
        font: {
          family: 'Courier New, monospace',
          size: 15,
          color: 'black',
        },
      },
      showgrid: false,
      zeroline: false,
      showline: false,
      showticklabels: false,
      ticks: '',
    },
    annotations: [],
    autosize: false,
    width: labels.length > 20? (labels.length * 40 + 400): 1000,
    height: labels.length > 20? labels.length * 35 + 200: 700,
  };

  const layoutUpdate = (xref, yref, x, y, text, i, ct) => {
    layout.annotations.push({
        xref: xref,
        yref: yref,
        x: x,
        y: y,
        text: text,
        font: {
          family: 'Arial',
          size: 11,
          color: i >= ct? 'white': 'black',
        },
        showarrow: false,
    })
  };

  precisionValues.forEach((item, index) => 
    layoutUpdate('x1', 'y1', index, 0, `${item}%`, item, ct1));

  yValues.forEach((y, yIndex) => xValues.forEach((x, xIndex) => 
    layoutUpdate('x2', 'y2', x, y, matrixValues[yIndex][xIndex], zValues[yIndex][xIndex], ct2)));

  recallValues.map(r => [r]).forEach((item, index) =>
    layoutUpdate('x3', 'y3', 0, index, `${item[0]}%`, item, ct3));

  return data && layout ? (
    <div style={{ width: '100%' }}>
      <div>
      <div style={{float: 'left', marginRight: 20, paddingTop: 2}}>
        {t('tests.cm_show')}
        <TextField
          type="number"
          size="small"
          className={classes.sizeInput}
          defaultValue={Math.min(maxSize, CM_SIZE)}
          inputProps={{min: 1, max: CM_SIZE || MAX_SIZE}}
          onKeyDown={event => {
            if (event.key != 'Enter')
              return;
            const { min: _min, value: _value, max: _max } = event.target;
            const [ min, v, max ] = [ _min, _value, _max ].map(Number);
            const value = v < min ? min : v > max ? max : v;
            setMaxSize(value);
            event.target.value = value;
          }}
        />
        {t('tests.cm_of')}
        <span style={{fontSize: 16, marginLeft: 5}}>{CM_SIZE}</span>
      </div>
      <Button
        variant="outlined"
        size="small"
        style={{float: 'left'}}
        onClick={exportHandler}
      >
        {t("common.export_as_csv")}
      </Button>
        </div>
      <div>
      <Plot data={data} layout={layout}/>
        </div>
    </div>
  ): 'plotting...';
}
