import { AsyncDuckDB } from "@duckdb/duckdb-wasm";
import cx from "classnames";
import isEqual from "lodash.isequal";
import pluralize from "pluralize";
import {
  Dispatch,
  Fragment,
  SetStateAction,
  useCallback,
  useEffect,
  useMemo,
  useRef,
  useState,
} from "react";
import { Check } from "react-feather";
import AutoSizer, { Size } from "react-virtualized-auto-sizer";
import {
  FixedSizeGrid as Grid,
  GridChildComponentProps,
  FixedSizeList as List,
  ListChildComponentProps,
} from "react-window";
import { Button } from "src/Common/Button";
import { Checkbox } from "@spring/ui/Checkbox";
import {
  colorSchemeByWellMetadata,
  colorValuesByScheme,
} from "../Control/ColorSchemeSelector";
import { FilterSqlClause } from "../Control/FilterSelector/types";
import { inferPlateKind } from "../ImageViewer/FullPlateView";
import PlateMap, { PlateKind } from "../ImageViewer/PlateMap";
import { usePrevious } from "../hooks/utils";
import { DatasetId, MetadataColumnValue } from "../types";
import { inferInterestingColumnsDB } from "../util/dataset-util";
import { useAsyncValue } from "../util/hooks";
import { columnMatchesValueClause } from "../util/sql";
import { DataClass, EditorStepCommonProps, ModelTrainingConfig } from "./types";
import {
  findControlEntries,
  inferClassColumnAndValuesFromConfig,
  queryCandidateValuesForColumn,
  queryPlates,
} from "./utils";

/**
 * Number of plates shown side by side in each row of the plate-preview grid.
 * Cell index math in the grid renderer depends on this value.
 */
const PLATES_PER_ROW = 2;

/**
 * Model-training wizard step for choosing the classes the model should learn.
 *
 * The user first picks a metadata column (hovering a column previews its
 * values on the plate map), then toggles which of that column's values become
 * classes. Every change to the selected values is written to
 * `config.data.classes` as one class per value, and `onReadyToAdvanceChanged`
 * reports whether more than one class is selected. If `config` already holds
 * classes, the column and values are inferred from them on load.
 */
export default function ClassesConfigurator({
  dataset,
  metadata,
  config,
  setConfig,
  onReadyToAdvanceChanged,
}: EditorStepCommonProps & {
  dataset: DatasetId;
  metadata: AsyncDuckDB;
  config: Partial<ModelTrainingConfig>;
  setConfig: Dispatch<SetStateAction<Partial<ModelTrainingConfig>>>;
}) {
  // Row filter applied to the metadata before any column/value is considered.
  // It is a dependency of every query below so results never go stale when it
  // changes.
  const prefilter = config.data?.filter ?? "TRUE";

  const candidateColumns = useAsyncValue(
    () => inferInterestingColumnsDB(metadata, prefilter),
    [metadata, prefilter],
  );
  const [hoveredColumn, _setHoveredColumn] = useState<string | undefined>(
    undefined,
  );
  const [selectedColumn, _setSelectedColumn] = useState<string | undefined>(
    undefined,
  );
  const [candidateValues, setCandidateValues] = useState<MetadataColumnValue[]>(
    [],
  );
  const [selectedValues, setSelectedValues] = useState<MetadataColumnValue[]>(
    [],
  );
  const [hoveredValue, setHoveredValue] = useState<
    MetadataColumnValue | undefined
  >(undefined);

  // Search filtering for the column values.
  const [searchValue, setSearchValue] = useState("");
  const filteredCandidateValues = useMemo(() => {
    if (!searchValue) {
      return candidateValues;
    }
    const searchValueLower = searchValue.toLowerCase();
    return candidateValues.filter((v) =>
      String(v).toLowerCase().includes(searchValueLower),
    );
  }, [candidateValues, searchValue]);
  const areAllValuesSelected = useMemo(() => {
    return (
      filteredCandidateValues.length > 0 &&
      filteredCandidateValues.every((value) => selectedValues.includes(value))
    );
  }, [selectedValues, filteredCandidateValues]);

  // "Select all" acts on the values currently listed. While a search is active,
  // selections hidden by the search are kept instead of being overwritten
  // (select) or wiped (deselect).
  const toggleSelectAllVisible = () => {
    if (!searchValue) {
      setSelectedValues(areAllValuesSelected ? [] : candidateValues);
    } else if (areAllValuesSelected) {
      setSelectedValues((current) =>
        current.filter((v) => !filteredCandidateValues.includes(v)),
      );
    } else {
      setSelectedValues((current) => [
        ...current,
        ...filteredCandidateValues.filter((v) => !current.includes(v)),
      ]);
    }
  };

  // Candidate-value queries are async and can be fired in quick succession
  // (e.g. sweeping the mouse across several columns). Each request takes a
  // ticket and only the latest ticket may publish its result, so a slow
  // response for an earlier column can't overwrite a newer one or repopulate
  // the list after it was cleared.
  const candidateValuesRequestId = useRef(0);

  const clearCandidateValues = useCallback(() => {
    candidateValuesRequestId.current += 1;
    setCandidateValues([]);
  }, []);

  const requeryCandidateControlValues = useCallback(
    (column: string) => {
      if (!column) {
        clearCandidateValues();
        return;
      }
      const requestId = ++candidateValuesRequestId.current;
      queryCandidateValuesForColumn(metadata, column, prefilter).then(
        (values) => {
          if (requestId === candidateValuesRequestId.current) {
            setCandidateValues(values);
          }
        },
      );
    },
    [metadata, prefilter, clearCandidateValues],
  );

  const setSelectedColumn = useCallback(
    (column: string | undefined) => {
      _setSelectedColumn(column);
      if (column) {
        setConfig({
          ...config,
          split: config.split
            ? {
                ...config.split,
                stratifySamplingColumn: column,
              }
            : undefined,
        });
        requeryCandidateControlValues(column);
      } else {
        // Going back to column selection discards the previous column's
        // classes. Reset the value selection and search too; otherwise they
        // would leak into the next column's classes, and "ready" would stay
        // true while `classes` is empty.
        setSelectedValues([]);
        setSearchValue("");
        onReadyToAdvanceChanged(false);
        setConfig({
          ...config,
          data: {
            dataset,
            filter: prefilter,
            classes: [],
          },
        });
      }
    },
    [
      dataset,
      config,
      prefilter,
      setConfig,
      requeryCandidateControlValues,
      onReadyToAdvanceChanged,
    ],
  );

  const setHoveredColumn = useCallback(
    (column: string | undefined) => {
      _setHoveredColumn(column);
      if (column) {
        requeryCandidateControlValues(column);
      } else {
        clearCandidateValues();
      }
    },
    [requeryCandidateControlValues, clearCandidateValues],
  );

  // Restore the column/values from classes already present in `config`.
  // The cleanup flag drops an inference that resolves after the user has
  // already picked a column (or after the inputs changed).
  useEffect(() => {
    let cancelled = false;
    if (!selectedColumn && candidateColumns && candidateColumns.length > 0) {
      inferClassColumnAndValuesFromConfig(
        config.data?.classes ?? [],
        candidateColumns,
        metadata,
        prefilter,
        false,
      ).then((result) => {
        if (!cancelled && result) {
          const [column, values] = result;
          setSelectedColumn(column);
          setSelectedValues(values);
        }
      });
    }
    return () => {
      cancelled = true;
    };
  }, [
    metadata,
    selectedColumn,
    setSelectedColumn,
    candidateColumns,
    config,
    prefilter,
  ]);

  // Mirror the value selection into `config.data.classes` whenever it changes.
  const previousValues = usePrevious(selectedValues);
  useEffect(() => {
    if (selectedColumn && !isEqual(previousValues, selectedValues)) {
      onReadyToAdvanceChanged(selectedValues.length > 1);
      setConfig({
        ...config,
        data: {
          dataset,
          filter: prefilter,
          classes: selectedValues.map<DataClass>((value) => {
            return {
              name: String(value),
              filter: columnMatchesValueClause(selectedColumn, value),
            };
          }),
        },
      });
    }
  }, [
    dataset,
    selectedColumn,
    selectedValues,
    previousValues,
    config,
    prefilter,
    setConfig,
    onReadyToAdvanceChanged,
  ]);

  const colorMap = useMemo(() => {
    const colorScheme = colorSchemeByWellMetadata(candidateValues);
    return colorValuesByScheme(candidateValues, colorScheme);
  }, [candidateValues]);

  // Number of metadata entries per value of the selected column.
  const wellCountPerClass = useAsyncValue(async () => {
    const counts = new Map<MetadataColumnValue, number>();
    if (!selectedColumn) {
      return counts;
    }
    const allWells = await findControlEntries(
      metadata,
      prefilter,
      selectedColumn,
      undefined,
    );
    for (const { value } of allWells) {
      counts.set(value, (counts.get(value) ?? 0) + 1);
    }
    return counts;
  }, [selectedColumn, metadata, prefilter]);

  // The preview follows the selected column, or the hovered one while choosing.
  const previewColumn = selectedColumn ?? hoveredColumn;

  return (
    <div className={"tw-flex tw-bg-slate-100 tw-h-full tw-overflow-hidden"}>
      <div className={"tw-flex-1 tw-p-8 tw-bg-white tw-flex tw-flex-col"}>
        <div className={"tw-text-lg tw-mb-4"}>
          What are the classes you want the model to identify?
        </div>
        {candidateColumns && !selectedColumn && (
          <div
            className={
              "tw-flex tw-flex-col tw-flex-1 tw-overflow-y-auto tw-p-2 -tw-m-2"
            }
          >
            {candidateColumns.map((column) => (
              <ColumnOption
                key={column}
                column={column}
                prefilter={prefilter}
                metadata={metadata}
                onMouseEnter={() => setHoveredColumn(column)}
                onMouseLeave={() => setHoveredColumn(undefined)}
                onSelect={() => setSelectedColumn(column)}
              />
            ))}
          </div>
        )}
        {selectedColumn && (
          <div
            className={
              "tw-flex tw-flex-col tw-overflow-y-auto tw-overflow-x-hidden tw-my-sm tw-h-full"
            }
          >
            <div className={"tw-flex tw-mb-sm"}>
              <button
                className={"tw-text-sm tw-text-gray-500"}
                onClick={() => setSelectedColumn(undefined)}
              >
                Based on{" "}
                <span className={"tw-text-black tw-underline"}>
                  {selectedColumn}
                </span>
              </button>
              <div className={"tw-flex-1"} />
              <label
                className={
                  "tw-flex tw-items-center tw-cursor-pointer tw-text-sm"
                }
              >
                Select all
                <Checkbox
                  className={"tw-ml-sm"}
                  checked={areAllValuesSelected}
                  onClick={toggleSelectAllVisible}
                />
              </label>
            </div>
            {candidateValues.length > 20 && (
              <input
                className={"tw-border tw-rounded tw-py-sm tw-px-md"}
                placeholder={`Search for a ${selectedColumn} value`}
                value={searchValue}
                onChange={(e) => setSearchValue(e.target.value)}
              />
            )}
            <div className={"tw-flex-auto tw-h-full"}>
              <AutoSizer>
                {({ height, width }: Size) => (
                  <List<ItemData>
                    height={height}
                    itemCount={filteredCandidateValues.length}
                    itemSize={45}
                    width={width}
                    itemData={{
                      candidateValues: filteredCandidateValues,
                      colorMap: colorMap,
                      wellCountPerClass: wellCountPerClass,
                      selectedValues: selectedValues,
                      setSelectedValues: setSelectedValues,
                      setHoveredValue: setHoveredValue,
                    }}
                  >
                    {ClassOptionRow}
                  </List>
                )}
              </AutoSizer>
            </div>
          </div>
        )}
      </div>
      <div className={"tw-flex-1 tw-p-8 tw-flex tw-flex-col"}>
        {previewColumn && (
          <>
            <div className={"tw-uppercase tw-text-sm tw-text-gray-500 tw-mb-4"}>
              Plate Preview
            </div>
            <div className={"tw-w-full tw-flex-1 tw-overflow-hidden tw-border"}>
              <ModelControlsVisualizer
                metadata={metadata}
                prefilter={prefilter}
                selectedColumn={previewColumn}
                selectedValues={selectedColumn ? selectedValues : []}
                colorMap={colorMap}
                hoveredValue={hoveredValue}
              />
            </div>
          </>
        )}
      </div>
    </div>
  );
}

/**
 * Clickable card for one candidate class column. Next to the column name it
 * previews the first few candidate values returned for the column under
 * `prefilter`, plus a "+N more" suffix when there are more.
 */
function ColumnOption({
  column,
  metadata,
  prefilter,
  onSelect,
  onMouseEnter,
  onMouseLeave,
}: {
  column: string;
  metadata: AsyncDuckDB;
  prefilter: FilterSqlClause;
  onSelect: () => void;
  onMouseEnter?: () => void;
  onMouseLeave?: () => void;
}) {
  // `prefilter` is a dependency: the preview must refresh when the filter
  // changes, not only when the column or database does.
  const candidateValues = useAsyncValue(
    () => queryCandidateValuesForColumn(metadata, column, prefilter),
    [column, metadata, prefilter],
  );

  // TODO(benkomalo): just hardcoding a preview of the first 3 items but we should
  //  probably be smarter and determine the length of the values in text and truncate
  //  when we're out of space. Right now really long values get awkwardly pushed out
  //  even at 3 values, and some short ones don't make good use of the whitespace.
  const previewCount = 3;
  return (
    <Button
      onMouseEnter={onMouseEnter}
      onMouseLeave={onMouseLeave}
      className={
        "tw-border tw-rounded tw-w-full tw-my-1 tw-p-4 tw-flex tw-items-center"
      }
      onClick={() => onSelect()}
      disableTracking={true}
    >
      {column}
      {candidateValues && (
        <div className={"tw-ml-4 tw-flex tw-text-sm tw-text-gray-400"}>
          {candidateValues.slice(0, previewCount).map((value, i) => (
            <Fragment key={String(value)}>
              {i > 0 && <span className={"tw-mr-1"}>,</span>}
              <div className={"tw-truncate tw-max-w-[64px]"}>
                {/* String() so non-string values (e.g. booleans, which React
                    renders as nothing) show up, matching ClassOption. */}
                {value === null ? "<null>" : String(value)}
              </div>
            </Fragment>
          ))}
          {candidateValues.length > previewCount && (
            <span className={"tw-ml-1 tw-whitespace-nowrap"}>
              , +{candidateValues.length - previewCount} more
            </span>
          )}
        </div>
      )}
    </Button>
  );
}

/** `itemData` shared by every cell of the plate-preview grid. */
interface PlateData {
  /** Plate identifiers, placed into grid cells row by row. */
  plates: string[];
  /** Color for each value of the previewed column. */
  colorMap: Map<MetadataColumnValue, string>;
  /** All entries for the previewed column, across plates. */
  allWells: WellValue[];
  /** Entries whose value is currently selected as a class. */
  selectedWells: WellValue[];
  /** Entries whose value is currently hovered in the class list. */
  hoveredWells: WellValue[];
  /** Plate format used to draw every plate in the grid. */
  plateKind: PlateKind;
}

/** `itemData` shared by every row of the virtualized class-value list. */
interface ItemData {
  /** Values listed, in display order (already search-filtered). */
  candidateValues: MetadataColumnValue[];
  /** Swatch color for each value. */
  colorMap: Map<MetadataColumnValue, string>;
  /** Entry count per value, or `undefined` while still loading. */
  wellCountPerClass: Map<MetadataColumnValue, number> | undefined;
  /** Values currently chosen as classes. */
  selectedValues: MetadataColumnValue[];
  /** State setter for `selectedValues`. */
  setSelectedValues: Dispatch<SetStateAction<MetadataColumnValue[]>>;
  /** Reports the value under the pointer (`undefined` on leave). */
  setHoveredValue: (value: MetadataColumnValue | undefined) => void;
}

/**
 * `react-window` row renderer for the class-value list. It looks up the value
 * at `index` in the shared item data and renders it as a `ClassOption`,
 * wired to the shared selection and hover state.
 */
function ClassOptionRow({
  index,
  style,
  data,
}: ListChildComponentProps<ItemData>) {
  const {
    candidateValues,
    colorMap,
    wellCountPerClass,
    selectedValues,
    setSelectedValues,
    setHoveredValue,
  } = data;
  const rowValue = candidateValues[index];

  // Add or remove this row's value using functional updates, so toggles are
  // applied to the latest selection.
  const handleToggle = (shouldSelect: boolean) => {
    if (shouldSelect) {
      setSelectedValues((current) => [...current, rowValue]);
    } else {
      setSelectedValues((current) => current.filter((v) => v !== rowValue));
    }
  };

  return (
    <ClassOption
      key={String(rowValue)}
      style={style}
      value={rowValue}
      count={wellCountPerClass?.get(rowValue)}
      color={colorMap.get(rowValue)}
      onSelect={handleToggle}
      selected={selectedValues.includes(rowValue)}
      onMouseEnter={() => setHoveredValue(rowValue)}
      onMouseLeave={() => setHoveredValue(undefined)}
    />
  );
}

/**
 * One selectable class value. It shows a color swatch, the value (`<null>` for
 * null), an optional entry count, and a check mark when selected. Clicking
 * calls `onSelect(!selected)`.
 */
function ClassOption({
  style,
  value,
  count,
  color,
  onSelect,
  selected,
  onMouseEnter,
  onMouseLeave,
}: {
  style: React.CSSProperties;
  value: MetadataColumnValue;
  count: number | undefined;
  color: string | undefined;
  onSelect: (selected: boolean) => void;
  selected: boolean;
  onMouseEnter: () => void;
  onMouseLeave: () => void;
}) {
  return (
    <div style={style}>
      <Button
        className={
          "tw-border tw-rounded tw-w-full tw-my-1 tw-p-4 tw-flex tw-items-center"
        }
        onClick={() => onSelect(!selected)}
        onMouseEnter={onMouseEnter}
        onMouseLeave={onMouseLeave}
        disableTracking={true}
      >
        <div className="tw-flex-1 tw-truncate">
          <div className={cx("tw-flex tw-flex-row tw-items-center")}>
            <div
              style={{ backgroundColor: color }}
              className={"tw-w-[20px] tw-h-[20px] tw-rounded"}
            />
            <span className={cx("tw-truncate tw-ml-4 tw-font-mono")}>
              {value === null ? "<null>" : String(value)}
              {/* Explicit comparison: `count && …` would render a literal "0".
                  A <span> (not <p>) keeps the markup valid inside this span. */}
              {count !== undefined && count > 0 && (
                <span className={"tw-text-xs tw-inline"}>
                  {" "}
                  ({pluralize("well", count, true)})
                </span>
              )}
            </span>
          </div>
        </div>
        {selected && (
          <div>
            <Check size={16} />
          </div>
        )}
      </Button>
    </div>
  );
}

/** One metadata entry located on a plate, as used by the plate-map preview. */
interface WellValue {
  /** Plate identifier. */
  plate: string;
  /** Well identifier on that plate. */
  well: string;
  /** The entry's value for the previewed column. */
  value: string;
}

/**
 * Plate-map preview of `selectedColumn` across the plates returned for
 * `prefilter`. Plates are laid out in a virtualized grid with `PLATES_PER_ROW`
 * plates per row. Wells are colored through `colorMap`:
 * - entries matching `selectedValues` are drawn as controls;
 * - entries matching `hoveredValue` are highlighted;
 * - when nothing is selected, all entries for the column are shown faded.
 *
 * Renders nothing until both the plate list and the first well query have
 * loaded.
 */
function ModelControlsVisualizer({
  metadata,
  prefilter,
  selectedColumn,
  selectedValues,
  colorMap,
  hoveredValue,
}: {
  metadata: AsyncDuckDB;
  prefilter: FilterSqlClause;
  selectedColumn: string;
  selectedValues: MetadataColumnValue[];
  colorMap: Map<MetadataColumnValue, string>;
  hoveredValue: MetadataColumnValue | undefined;
}) {
  // `undefined` until the plate list has loaded, so the loading guard below
  // actually fires (an initial `[]` made it dead code).
  const [datasetPlates, setDatasetPlates] = useState<string[] | undefined>(
    undefined,
  );

  useEffect(() => {
    // Ignore a response for inputs we have already moved past.
    let cancelled = false;
    queryPlates(metadata, prefilter).then((plates) => {
      if (!cancelled) {
        setDatasetPlates(plates);
      }
    });
    return () => {
      cancelled = true;
    };
  }, [prefilter, metadata]);

  // [all entries for the column, entries matching selectedValues, entries
  // matching hoveredValue]. The three queries are independent, so run them
  // concurrently instead of awaiting each in turn.
  const data = useAsyncValue(
    () =>
      Promise.all([
        findControlEntries(metadata, prefilter, selectedColumn, undefined),
        selectedValues.length > 0
          ? findControlEntries(
              metadata,
              prefilter,
              selectedColumn,
              selectedValues,
            )
          : [],
        hoveredValue !== undefined
          ? findControlEntries(metadata, prefilter, selectedColumn, [
              hoveredValue,
            ])
          : [],
      ]),
    [prefilter, selectedColumn, hoveredValue, metadata, selectedValues],
  );

  // Keep track of the last data so we can render it even as the next data is loading.
  // Since the queries to duckdb are async, there is always an intermediate
  // emission where data is undefined. And even though it's really fast, it flickers
  // unless we keep the previous data rendered.
  const previousData = useRef(data);
  useEffect(() => {
    if (data) {
      previousData.current = data;
    }
  }, [data]);

  const plateKind = useMemo(() => {
    if (data) {
      return inferPlateKind(data[0]);
    } else if (previousData.current) {
      return inferPlateKind(previousData.current[0]);
    } else {
      return "384-well";
    }
  }, [data]);

  if (!datasetPlates) {
    // TODO(benkomalo): maybe a loading state, but this should be really fast.
    return null;
  }

  const shownData = data ?? previousData.current;
  if (!shownData) {
    return null;
  }
  const [allWells, selectedWells, hoveredWells] = shownData as [
    WellValue[],
    WellValue[],
    WellValue[],
  ];

  return (
    <div className="tw-bg-white tw-h-full">
      <AutoSizer>
        {({ height, width }: Size) => {
          // Leave a little horizontal slack (presumably for the scrollbar).
          const paddedWidth = width - 16;
          const columnWidth = paddedWidth / PLATES_PER_ROW;
          return (
            <Grid<PlateData>
              height={height}
              width={width}
              columnWidth={columnWidth}
              columnCount={PLATES_PER_ROW}
              rowHeight={columnWidth * 0.75}
              rowCount={Math.ceil(datasetPlates.length / PLATES_PER_ROW)}
              itemData={{
                plates: datasetPlates,
                colorMap: colorMap,
                allWells: allWells,
                selectedWells: selectedWells,
                hoveredWells: hoveredWells,
                plateKind: plateKind,
              }}
            >
              {PlateMapItem}
            </Grid>
          );
        }}
      </AutoSizer>
    </div>
  );
}

/**
 * `react-window` grid cell that renders one plate: its name plus a plate map
 * of the entries on that plate. Cells past the end of `plates` (possible in
 * the last grid row) render nothing.
 */
function PlateMapItem({
  columnIndex,
  rowIndex,
  style,
  data,
}: GridChildComponentProps<PlateData>) {
  const { plates, colorMap, allWells, selectedWells, hoveredWells, plateKind } =
    data;
  const plate = plates[rowIndex * PLATES_PER_ROW + columnIndex];
  if (!plate) {
    return null;
  }

  // react-window positions every cell with an explicit numeric width in
  // `style`. Using it avoids reading a DOM ref during render: the ref is still
  // null on the first render, which drew plates at the minimum cell size.
  const containerWidth = typeof style.width === "number" ? style.width : 0;
  const isOnThisPlate = (w: WellValue) => w.plate === plate;

  return (
    <div
      key={plate}
      style={style}
      className={"tw-border tw-p-sm tw-flex tw-flex-col tw-items-center"}
    >
      <div className={"tw-text-sm tw-text-gray-500 tw-text-center"}>
        {plate}
      </div>
      <PlateMapWrapper
        allWells={
          selectedWells.length > 0 ? undefined : allWells.filter(isOnThisPlate)
        }
        containerWidth={containerWidth}
        controlWells={selectedWells.filter(isOnThisPlate)}
        hoveredWells={hoveredWells.filter(isOnThisPlate)}
        colorMap={colorMap}
        plateKind={plateKind}
      />
    </div>
  );
}

/** Per-plate-format cell geometry: grid columns and the smallest cell size. */
const PLATE_CELL_LAYOUT: Record<
  PlateKind,
  { columns: number; minCellSize: number }
> = {
  "96-well": { columns: 12, minCellSize: 18 },
  "384-well": { columns: 24, minCellSize: 8 },
};

/**
 * Visualize the wells on a plate based on selection and hover state.
 *
 * Precedence per well:
 * 1. selected (`controlWells`): solid color, enlarged if also hovered;
 * 2. hovered: half opacity, enlarged;
 * 3. present in `allWells` (only passed when nothing is selected): faded;
 * 4. otherwise: a plain gray cell.
 *
 * The cell size is derived from `containerWidth` so that a row of cells fits
 * it, but never drops below the format's minimum. Without a width, the
 * minimum is used.
 */
function PlateMapWrapper({
  allWells,
  controlWells,
  hoveredWells,
  colorMap,
  plateKind,
  containerWidth,
}: {
  // Showing all wells in a faded state is only done if nothing is selected.
  allWells: WellValue[] | undefined;
  controlWells: WellValue[];
  hoveredWells: WellValue[];
  colorMap: Map<MetadataColumnValue, string>;
  plateKind: PlateKind;
  containerWidth?: number;
}) {
  // Typed well -> value lookup (an untyped `new Map()` would leak `any`).
  const valueByWell = (wells: WellValue[]) =>
    new Map(wells.map((w): [string, string] => [w.well, w.value]));
  const controlValues = valueByWell(controlWells);
  const hoveredValues = valueByWell(hoveredWells);
  const backgroundValues = valueByWell(allWells ?? []);

  const gapBetweenCells = 2;
  // Fudge for the row/column labels drawn alongside the cells.
  const labelAllowance = 24;
  const { columns, minCellSize } = PLATE_CELL_LAYOUT[plateKind];
  const size = containerWidth
    ? Math.max(
        minCellSize,
        Math.floor((containerWidth - labelAllowance) / columns) -
          gapBetweenCells,
      )
    : minCellSize;

  return (
    <PlateMap
      size={size}
      plateKind={plateKind}
      labelKind={plateKind === "96-well" ? "normal" : "extra-small"}
    >
      {(well) => {
        const controlValue = controlValues.get(well);
        if (controlValue !== undefined) {
          return (
            <div
              key={well}
              style={{
                backgroundColor: colorMap.get(controlValue),
                transform: hoveredValues.has(well) ? "scale(1.2)" : "",
              }}
            />
          );
        }
        const hoveredValue = hoveredValues.get(well);
        if (hoveredValue !== undefined) {
          return (
            <div
              key={well}
              style={{
                backgroundColor: colorMap.get(hoveredValue),
                opacity: 0.5,
                transform: "scale(1.2)",
              }}
            />
          );
        }
        const backgroundValue = backgroundValues.get(well);
        if (backgroundValue !== undefined) {
          return (
            <div
              key={well}
              style={{
                backgroundColor: colorMap.get(backgroundValue),
                opacity: 0.2,
              }}
            />
          );
        }
        return <div key={well} className={"tw-bg-gray-100"} />;
      }}
    </PlateMap>
  );
}
