import cx from "classnames";
import { Fragment, ReactNode, useMemo } from "react";

/**
 * Renders a confusion matrix as a CSS grid of shaded, optionally clickable cells.
 *
 * Rows are actual classes and columns are predicted classes: `counts[i][j]` is
 * shown in row `i` (labelled `classes[i]`), column `j`. Each cell is shaded by
 * its count divided by the total of its row.
 *
 * @param classes - Class names; used as row labels, as the default column
 *   labels, and as the values reported to `onSelectCell` / matched against
 *   `selectedCell`.
 * @param counts - Count matrix indexed as `counts[actual][predicted]`.
 * @param predictionLabels - Optional display text for the column labels;
 *   defaults to `classes`. Does not affect selection.
 * @param includeRowSummaries - When true, appends a column showing each row's
 *   diagonal count over its total, with the percentage (or "n/a" for an empty row).
 * @param selectedCell - Cell to highlight, matched by class name.
 * @param onSelectCell - Invoked with `{ actual, predicted }` class names when a
 *   cell with a non-zero count is clicked; zero-count cells are not clickable.
 * @param cellSize - Cell height and maximum column width, in px.
 */
export default function ConfusionMatrix({
  classes,
  counts,
  predictionLabels,
  selectedCell,
  onSelectCell,
  includeRowSummaries = false,
  cellSize = 40,
}: {
  classes: string[];
  counts: number[][];
  predictionLabels?: string[];
  includeRowSummaries?: boolean;
  selectedCell?: { predicted: string; actual: string } | null;
  onSelectCell?: (selection: { predicted: string; actual: string }) => void;
  // TODO(benkomalo): maybe also accept an 'auto' so it calculates based on available
  // width? Either way, we need an explicit variable for this, since we can't
  // just rely on CSS grid just to lay things out completely if we want other elements
  // (like labels outside of the grid) to be proportional to the cell grids themselves.
  // (Also -- we want the cells to be square).
  cellSize?: number;
}) {
  // Column labels are display-only; selection always reports/matches `classes`.
  const columnLabels = predictionLabels ?? classes;

  // Per-row totals and row-normalized fractions, recomputed only when `counts`
  // changes. A row with no samples would otherwise compute 0 / 0 = NaN, which
  // would leak into the cell opacity and the summary text, so its fractions are 0.
  const rows = useMemo(
    () =>
      counts.map((values) => {
        const total = values.reduce((a, b) => a + b, 0);
        return {
          values,
          total,
          fractions: values.map((count) => (total > 0 ? count / total : 0)),
        };
      }),
    [counts],
  );

  // TODO(benkomalo): parameterize this or dynamically infer based on content length?
  const labelSize = 120;

  // Leading label column, one track per class, and an optional summary column.
  const cellColumns = `repeat(${classes.length}, minmax(0, ${cellSize}px))`;
  const gridTemplateColumns = includeRowSummaries
    ? `${labelSize}px ${cellColumns} ${labelSize}px`
    : `${labelSize}px ${cellColumns}`;

  return (
    <div className={"tw-flex tw-w-full"}>
      <div
        className={
          "tw-self-stretch tw-flex tw-justify-center tw-items-center tw-w-[2rem]"
        }
      >
        <div
          className={
            "-tw-rotate-90 tw-origin-center tw-whitespace-nowrap tw-flex tw-text-sm tw-text-gray-500"
          }
        >
          {/* Spacer shares `labelSize` with the grid's label tracks (it was a
              hard-coded 120px) so the rotated title's offset stays in sync. */}
          <div style={{ width: labelSize }}></div>
          Actual classes
        </div>
      </div>
      <div className={"tw-grid tw-flex-1"} style={{ gridTemplateColumns }}>
        {rows.map(({ values, total, fractions }, i) => {
          const actual = classes[i];
          // Diagonal entry of this row (actual === predicted when the columns
          // are ordered like `classes`).
          const correct = values[i] ?? 0;
          const percent =
            total > 0 ? `${((100 * correct) / total).toFixed(1)}%` : "n/a";
          return (
            <Fragment key={actual}>
              <div
                className={"tw-px-4 tw-h-full tw-flex tw-items-center"}
                title={actual}
              >
                <div className={"tw-truncate"}>{actual}</div>
              </div>
              {values.map((count, j) => {
                const predicted = classes[j];
                return (
                  <Cell
                    key={`${i}-${j}`}
                    label={count}
                    accuracy={fractions[j]}
                    size={cellSize}
                    selected={
                      selectedCell != null &&
                      selectedCell.actual === actual &&
                      selectedCell.predicted === predicted
                    }
                    onClick={
                      onSelectCell && count > 0
                        ? () => onSelectCell({ actual, predicted })
                        : undefined
                    }
                  />
                );
              })}
              {includeRowSummaries && (
                <div
                  className={
                    "tw-text-sm tw-mx-2 tw-text-right tw-whitespace-nowrap tw-flex tw-items-center"
                  }
                >
                  {correct} / {total} ({percent})
                </div>
              )}
            </Fragment>
          );
        })}
        <Fragment key={"bottom-row"}>
          <div
            key={"leading-spacer"}
            style={{
              // Height of a label of width `labelSize` after a 45° rotation, plus padding.
              height: 16 + Math.round(labelSize * Math.cos(Math.PI / 4)),
            }}
          ></div>
          {columnLabels.map((label, j) => (
            // Index-qualified key: prediction labels are not guaranteed unique.
            <div
              key={`${j}-${label}`}
              className={
                "-tw-translate-x-1/2 -tw-rotate-45 tw-origin-top-right"
              }
            >
              <div
                className="tw-truncate tw-text-right tw-absolute tw-right-0"
                style={{ width: labelSize }}
              >
                {label}
              </div>
            </div>
          ))}
          {includeRowSummaries && <div key={"trailing-spacer"}></div>}
        </Fragment>
        <Fragment key={"axis-label"}>
          <div key={"leading-spacer"}></div>
          <div
            className={
              "tw-text-sm tw-text-gray-500 tw-col-start-2 tw-col-end-[-1] tw-flex tw-justify-center tw-whitespace-nowrap"
            }
          >
            Predicted classes
          </div>
        </Fragment>
      </div>
    </div>
  );
}

/**
 * One confusion-matrix cell, with a purple background whose opacity is `accuracy`.
 *
 * @param label - Content rendered inside the cell (the caller passes the count).
 * @param accuracy - Value used as the background opacity. Values above 0.5
 *   switch the text to white for contrast. Non-finite values (e.g. NaN) are
 *   treated as 0.
 * @param size - Cell height, in px.
 * @param selected - When true, draws a thicker red border.
 * @param onClick - When provided, the cell acts as a button: it is clickable,
 *   focusable, and activates on Enter/Space. Otherwise the cell is inert.
 */
function Cell({
  label,
  accuracy,
  size,
  selected,
  onClick,
}: {
  label: ReactNode;
  accuracy: number;
  size: number;
  selected?: boolean;
  onClick?: () => void;
}) {
  // Keep NaN/Infinity out of the CSS custom property and the contrast check.
  const opacity = Number.isFinite(accuracy) ? accuracy : 0;

  // role="button" alone is not keyboard-operable; mirror native button
  // activation keys so keyboard users can select the cell.
  const handleKeyDown = (event: React.KeyboardEvent<HTMLDivElement>) => {
    if (onClick && (event.key === "Enter" || event.key === " ")) {
      // Space would otherwise scroll the page.
      event.preventDefault();
      onClick();
    }
  };

  return (
    <div
      className={cx(
        "tw-truncate tw-flex tw-justify-center tw-items-center tw-border",
        "tw-bg-purple",
        opacity > 0.5 ? "tw-text-white" : "tw-text-black",
        selected && "tw-border-red-500 tw-border-2",
        onClick && !selected && "hover:tw-border-red-300 hover:tw-border-2",
        !onClick && "tw-cursor-default",
      )}
      role={onClick ? "button" : "none"}
      tabIndex={onClick ? 0 : undefined}
      style={
        {
          "--tw-bg-opacity": opacity,
          height: size,
        } as React.CSSProperties
      }
      onClick={onClick}
      onKeyDown={onClick ? handleKeyDown : undefined}
    >
      {label}
    </div>
  );
}
