import chroma from "chroma-js";
import cx from "classnames";
import Select from "react-select";
import type { TopLevelSpec } from "vega-lite";
import { colorValuesByScheme } from "../../Control/ColorSchemeSelector";
import VegaLite from "../../Vega/VegaLite";
import { fonts } from "../../util/styles";
import { ConfusionMatrix } from "../ConfusionMatrix";
import { Classification, useLabeledSetContext } from "../Context";
import { DemographicColumnOption, DemographicData } from "../types";

/** Props accepted by the `ClassMetrics` panel. */
type ClassMetricsProps = {
  /** The class whose metrics are shown and highlighted in the chart. */
  classification: Classification;
  /** Extra class names merged onto the panel's root element. */
  className?: string;
  /** Choices offered in the "Stratify By" dropdown. */
  demographicColumnOptions: DemographicColumnOption[];
  /** Currently chosen column; `null` when nothing is selected. */
  selectedDemographicColumn: DemographicColumnOption | null;
  /** Invoked by the dropdown whenever the selection changes. */
  onChangeDemographicColumn: (option: DemographicColumnOption | null) => void;
  /** Rows fed to the demographics chart; `null` hides the demographics section. */
  demographicData: DemographicData | null;
};

/**
 * Metrics panel for a single classification.
 *
 * Renders up to two sections:
 * - **Accuracy**: recall, precision and a confusion matrix highlighting this
 *   class. Shown only when `classification` carries both `precision` and
 *   `recall`.
 * - **Demographics**: a "Stratify By" dropdown plus a normalized stacked bar
 *   chart of the summed `count` per demographic `value`, split by
 *   `classificationName`. Shown only when the class has examples and
 *   `demographicData` is non-null.
 *
 * In the chart, the current class keeps its scheme color. Every other class
 * is drawn as a brightened, 50%-alpha variant of its own color.
 */
export function ClassMetrics({
  classification,
  className,
  demographicColumnOptions,
  selectedDemographicColumn,
  onChangeDemographicColumn,
  demographicData,
}: ClassMetricsProps) {
  const { state: labeledSetState } = useLabeledSetContext();
  const { classifications } = labeledSetState;
  const { precision, recall } = classification;
  const classIsEmpty = classification.examples.length === 0;
  // `precision`/`recall` are consts, so TypeScript narrows them to defined
  // values wherever `hasModelMetrics` is checked.
  const hasModelMetrics = precision !== undefined && recall !== undefined;

  // The order of this list defines both the color assignment and the chart's
  // stacking/legend order.
  const classificationNames = classifications.map(
    (otherClassification) => otherClassification.name,
  );
  const fontFamily = fonts.apercuRegular.fontFamily;
  const colorScheme = classifications.length <= 10 ? "tableau10" : "tableau20";
  const colorByClassificationName = colorValuesByScheme(
    classificationNames,
    colorScheme,
  );
  // One color per entry of `classificationNames`, index-aligned with it.
  const colors: string[] = classificationNames.map((name) => {
    // TODO: Shouldn't need this fallback; we should have better typing on the returned map
    const color = colorByClassificationName.get(name) ?? "#ccc";
    // Full color for the class being inspected; de-emphasize all others.
    return name === classification.name
      ? color
      : chroma(color).alpha(0.5).brighten(1).toString();
  });

  const vegaSpec: TopLevelSpec = {
    $schema: "https://vega.github.io/schema/vega-lite/v4.json",
    data: {
      values: demographicData ?? [],
    },
    mark: "bar",
    width: "container",
    encoding: {
      x: {
        field: "count",
        aggregate: "sum",
        stack: "normalize",
        title: "Fraction of cells in each class",
        axis: { format: "%" },
      },
      y: {
        field: "value",
        title: selectedDemographicColumn?.value,
      },
      color: {
        field: "classificationName",
        title: null,
        // Pin the domain to the full class list so that `colors[i]` always
        // maps to `classificationNames[i]`. Without an explicit domain,
        // Vega-Lite derives it from the data. A class with no rows in
        // `demographicData` would then be dropped, and every later color
        // would shift onto the wrong class.
        scale: { domain: classificationNames, range: colors },
        sort: classificationNames,
      },
    },
    config: {
      view: {
        stroke: "transparent",
      },
      axis: { labelFont: fontFamily, titleFont: fontFamily },
      axisY: { titlePadding: 10 },
      legend: { labelFont: fontFamily, titleFont: fontFamily },
      header: { labelFont: fontFamily, titleFont: fontFamily },
      mark: { font: fontFamily },
      title: { font: fontFamily, subtitleFont: fontFamily },
    },
  };

  const demographicsSection =
    !classIsEmpty && demographicData ? (
      <>
        <h3 className="tw-font-normal tw-font-sans tw-text-xl tw-mb-4">
          Demographics
        </h3>
        <label className="tw-inline-block tw-w-full">
          <span className="tw-font-bold tw-text-sm tw-uppercase tw-text-purple-700">
            Stratify By
          </span>
          <Select
            placeholder="Search for a group..."
            options={demographicColumnOptions}
            onChange={onChangeDemographicColumn}
            value={selectedDemographicColumn}
            menuPortalTarget={document.body}
            // We need a very high z-index to get on top of BeB
            styles={{ menuPortal: (base) => ({ ...base, zIndex: 100000 }) }}
          />
        </label>

        <VegaLite className="tw-w-full tw-mt-8" spec={vegaSpec} />
      </>
    ) : null;

  const accuracySection = hasModelMetrics ? (
    <>
      <h3 className="tw-font-normal tw-font-sans tw-text-xl">Accuracy</h3>
      <p className="tw-my-0.5 tw-text-gray-500 tw-text-sm tw-font-bold">
        {(100 * recall).toFixed(1)}% recall
      </p>
      <p className="tw-my-0.5 tw-text-gray-500 tw-text-sm">
        {(100 * precision).toFixed(1)}% precision
      </p>
      <ConfusionMatrix highlight={classification.name} />
      <hr className="tw-flex-none tw-my-8 tw-border-gray-300" />
    </>
  ) : null;

  return (
    <div className={cx(className, "tw-px-8")}>
      {classIsEmpty ? (
        <div className="tw-mx-auto">Label a cell to enable metrics.</div>
      ) : null}
      {accuracySection}
      {demographicsSection}
    </div>
  );
}
