import cx from "classnames";
import lodashMax from "lodash.max";
import lodashMin from "lodash.min";
import { useCallback, useMemo } from "react";
import { useLabeledSetContext } from "./Context";

/** Props accepted by the `ConfusionMatrix` component. */
type Props = {
  /** Classification name to highlight; compared against each `classification.name`. */
  highlight: string;
  /** Optional extra class names merged onto the outermost wrapper element. */
  className?: string;
};

/** Inline style for a highlighted cell: text colour plus Tailwind's background-opacity variable. */
type HighlightStyle = React.CSSProperties & { "--tw-bg-opacity": number };

/**
 * Position of `value` within the `[min, max]` range, clamped to `[0, 1]`.
 *
 * Returns 0 when the range is empty (`max <= min`) or not a number, so callers
 * never receive NaN/Infinity (which the original `(v - min) / (max - min)` did
 * when all values were equal).
 */
function normalizedOpacity(value: number, min: number, max: number): number {
  const range = max - min;
  if (!(range > 0)) {
    return 0;
  }
  return Math.min(1, Math.max(0, (value - min) / range));
}

/**
 * Renders the labeled set's confusion matrix as a grid of cells.
 *
 * Rows are labeled "Actual" and columns "Predicted". Each row and column
 * corresponds to one entry of `classifications` from the labeled-set context.
 * Cells in the row and column of the `highlight` class are emphasized, with
 * a background opacity scaled between the smallest and largest value shown in
 * that row/column. Cells the matrix has no value for are rendered as "N/A".
 */
export function ConfusionMatrix({ className, highlight }: Props) {
  const {
    state: { confusionMatrix, classifications },
  } = useLabeledSetContext();

  const size = classifications.length;

  // It's possible that the user added classes after we loaded the confusion
  // matrix. Pad (or trim) it to exactly size x size so every current class has
  // a row and a column; cells without data are `undefined`.
  const expandedConfusionMatrix = useMemo<(number | undefined)[][]>(
    () =>
      Array.from({ length: size }, (_, rowIndex) =>
        Array.from(
          { length: size },
          (__, colIndex) => confusionMatrix?.[rowIndex]?.[colIndex],
        ),
      ),
    [size, confusionMatrix],
  );

  // -1 when no classification matches `highlight`.
  const highlightIndex = useMemo(
    () =>
      classifications.findIndex(
        (classification) => classification.name === highlight,
      ),
    [classifications, highlight],
  );

  // Defined values in the highlighted row and column of the *displayed* matrix.
  // Reading from the padded matrix avoids spreading a missing row (which threw
  // when the highlighted class was newer than the matrix).
  const relevantValues = useMemo(() => {
    if (highlightIndex < 0) {
      return [];
    }
    const rowValues = expandedConfusionMatrix[highlightIndex] ?? [];
    const colValues = expandedConfusionMatrix.map(
      (row) => row[highlightIndex],
    );
    return [...rowValues, ...colValues].filter(
      (value): value is number => value !== undefined,
    );
  }, [expandedConfusionMatrix, highlightIndex]);

  const minValue: number = useMemo(
    () => lodashMin(relevantValues) ?? 0,
    [relevantValues],
  );
  const maxValue: number = useMemo(
    () => lodashMax(relevantValues) ?? 0,
    [relevantValues],
  );

  // Background opacity scales with the value; text switches to white on the
  // darker half so it stays readable. (`color` is the real CSS property; the
  // previous `textColor` key was silently ignored.)
  const bgStyle = useCallback(
    (value: number): HighlightStyle => {
      const opacity = normalizedOpacity(value, minValue, maxValue);
      return {
        color: opacity > 0.5 ? "#fff" : "#000",
        "--tw-bg-opacity": opacity,
      };
    },
    [maxValue, minValue],
  );

  const shortNames = useMemo(
    () =>
      classifications.every(
        (classification) => classification.name.length <= 8,
      ),
    [classifications],
  );
  const smallText = size > 3;
  const overflow = size > 7;

  return (
    <div
      className={cx(
        "tw-flex tw-items-center tw-justify-center",
        smallText ? "tw-text-xs" : "tw-text-base",
        className,
      )}
    >
      <div
        className={cx(
          "tw-grid tw-max-w-[400px]",
          overflow ? "tw-overflow-auto" : "tw-overflow-hidden",
        )}
        style={{ gridTemplateColumns: "40px 1fr" }}
      >
        <div></div>
        <div className="tw-flex tw-items-center tw-justify-center">
          Predicted
        </div>
        <div className="tw-flex tw-items-center tw-justify-center">
          <div
            className={cx(
              "-tw-rotate-90",
              shortNames ? "tw-pr-[64px]" : "tw-pr-[128px]",
            )}
          >
            Actual
          </div>
        </div>
        <div>
          <div
            className="tw-grid tw-gap-1 tw-overflow-hidden"
            style={{
              gridTemplateColumns: `minmax(0, 1fr) repeat(${size}, ${
                smallText ? "32px" : "42px"
              }) minmax(0, 1fr)`,
            }}
          >
            {[
              <div key="corner"></div>,
              ...classifications.map((classification, colIndex) => (
                <div
                  key={["col", colIndex].join("-")}
                  className={cx(
                    "tw-flex tw-items-end tw-justify-center",
                    shortNames ? "tw-h-[64px]" : "tw-h-[128px]",
                    "tw-relative",
                  )}
                >
                  <div className="-tw-rotate-45 tw-origin-left tw-absolute -tw-bottom-[8px] tw-left-1/2">
                    <span
                      className={cx(
                        "tw-max-w-[150px] tw-inline-block",
                        "tw-truncate",
                        colIndex === highlightIndex
                          ? "tw-text-gray-700"
                          : "tw-text-gray-300",
                      )}
                      title={classification.name}
                    >
                      {classification.name}
                    </span>
                  </div>
                </div>
              )),
              <div key={["extra", "top"].join("-")}></div>,
              ...expandedConfusionMatrix.map((row, rowIndex) => [
                <div
                  key={["row", rowIndex].join("-")}
                  className="tw-flex tw-items-center tw-justify-end"
                >
                  <span
                    className={cx(
                      "tw-truncate",
                      rowIndex === highlightIndex
                        ? "tw-text-gray-700"
                        : "tw-text-gray-300",
                    )}
                    title={classifications[rowIndex].name}
                  >
                    {classifications[rowIndex].name}
                  </span>
                </div>,
                ...row.map((cell, colIndex) => {
                  const inHighlight =
                    rowIndex === highlightIndex || colIndex === highlightIndex;
                  // Only defined cells in the highlighted row/column get the
                  // scaled background; missing cells must never reach bgStyle.
                  return (
                    <div
                      key={[rowIndex, colIndex].join(",")}
                      className={cx(
                        "tw-flex tw-items-center tw-justify-center",
                        smallText
                          ? "tw-w-[32px] tw-h-[32px]"
                          : "tw-w-[42px] tw-h-[42px]",
                        "tw-border",
                        cell === undefined
                          ? "tw-border-gray-100 tw-text-gray-100"
                          : inHighlight
                            ? "tw-border-gray-700 tw-text-black tw-bg-purple-500"
                            : "tw-border-gray-300 tw-text-gray-300",
                      )}
                      style={
                        cell !== undefined && inHighlight
                          ? bgStyle(cell)
                          : undefined
                      }
                    >
                      {cell !== undefined ? cell : "N/A"}
                    </div>
                  );
                }),
                <div key={["extra", rowIndex].join("-")}></div>,
              ]),
            ]}
          </div>
        </div>
      </div>
    </div>
  );
}
