import cx from "classnames";
import { ReactNode, useEffect, useMemo, useRef } from "react";
import Selecto from "react-selecto";
import { DatasetId, UnlabeledCellSampleMetadata } from "src/types";
import { getCellMaskPropsForCrop } from "src/util/get-mask-props-for-crop";
import PredictionWarning from "../assets/predictionWarning.svg";
import { DEFAULT_CROP_SIZE_PX } from "../env";
import { ImageSet } from "../imaging/types";
import { toNumericField } from "../imaging/util";
import MultiChannelView from "../immunofluorescence/MultiChannelView";
import { PredictionsMap } from "./Context";
import { MIN_PREDICTION_CONFIDENCE } from "./constants";
import { SampleState } from "./types";
import { withoutSampleState } from "./util";

/**
 * Renders one `MultiChannelView` crop per sample in a CSS grid and wires the
 * sample elements up to Selecto so they can be selected.
 *
 * Selection is controlled by the caller: the `selected` flag on each
 * `metadata` entry is mirrored into Selecto whenever `metadata` changes, and
 * selections made by the user are reported through `onSetSelection` after
 * being passed through `withoutSampleState`.
 *
 * @param cropSize - Crop size for each view; falls back to `DEFAULT_CROP_SIZE_PX`.
 * @param dataset - Dataset id used in the image index of every view.
 * @param imageSet - Passed through to every `MultiChannelView`.
 * @param disabled - When true, Selecto `onSelect` events are ignored and the
 *   hover styling is not applied. Right-click selection is not gated by it.
 * @param forwardRef - Passed through to every `MultiChannelView`.
 * @param metadata - Samples to render, each carrying its selection state.
 * @param minColumns - When set (and non-zero), gives the grid a `minWidth`
 *   wide enough for that many cells including the gaps between them.
 * @param placeholders - Extra nodes rendered inside the grid after the samples.
 * @param predictionMap - Predictions keyed by sample id. When `null`, no
 *   prediction badge is rendered; otherwise the predicted class is shown, or a
 *   warning icon when there is no prediction at or above
 *   `MIN_PREDICTION_CONFIDENCE`.
 * @param size - Size passed to each view; grid tracks are `size + 8` px.
 * @param onClick - Called with the clicked sample (state included) and the
 *   click event.
 * @param onSetSelection - Called with the new selection.
 */
export function SelectableExampleGridViewV2<
  T extends UnlabeledCellSampleMetadata,
>({
  cropSize,
  dataset,
  imageSet,
  disabled,
  forwardRef,
  metadata,
  minColumns,
  placeholders,
  predictionMap,
  size,
  onClick,
  onSetSelection,
}: {
  cropSize?: number;
  dataset: DatasetId;
  imageSet: ImageSet | null;
  disabled?: boolean;
  forwardRef?: string | { [K: string]: any } | null;
  metadata: (T & SampleState)[];
  minColumns?: number;
  placeholders?: ReactNode[];
  predictionMap: PredictionsMap | null;
  size: number;
  onClick?: (type: T, e: MouseEvent) => void;
  onSetSelection: (selected: T[]) => void;
}) {
  const refSelecto = useRef<Selecto>(null);
  const refGridViewContainer = useRef<HTMLDivElement>(null);

  // Values that are identical for every sample are computed once per render
  // instead of once per grid cell.
  const effectiveCropSize = cropSize ?? DEFAULT_CROP_SIZE_PX;
  const cellMaskProps = getCellMaskPropsForCrop(effectiveCropSize);
  // Crop size plus the 4px border on each side.
  const cellSizePx = size + 8;

  /**
   * Returns the predicted class and its confidence (as a rounded percentage)
   * for a sample, or nulls when there is no prediction, the confidence for the
   * predicted class is missing, or it is below `MIN_PREDICTION_CONFIDENCE`.
   */
  const getMaxPredictionDisplay = (
    metadataId: string,
  ): {
    name: string | null;
    value: string | null;
  } => {
    const prediction = predictionMap?.get(metadataId);
    if (!prediction) {
      return { name: null, value: null };
    }

    const { predictedClass } = prediction;
    const confidence = prediction.predictions.get(predictedClass);
    // A missing confidence is treated like a low-confidence prediction rather
    // than being rendered as "NaN%".
    if (confidence === undefined || confidence < MIN_PREDICTION_CONFIDENCE) {
      return { name: null, value: null };
    }

    return {
      name: predictedClass,
      value: Math.round(100 * confidence) + "%",
    };
  };

  const sampleElementIds = useMemo(
    // These are used by selecto to query for the selectable elements.
    // We may have some IDs that contain characters that would break document queries (for example,
    // a "." would attempt to access a class). To prevent that, we need to pass escaped IDs.
    () => metadata.map((sample) => `#${CSS.escape(sample.id)}`),
    [metadata],
  );

  // Lookup from sample id (which is also the DOM element id) to the sample.
  // Built in a single pass; later duplicates of an id win.
  const sampleById = useMemo(
    () =>
      new Map<string, T & SampleState>(
        metadata.map((sample) => [sample.id, sample]),
      ),
    [metadata],
  );

  // Use metadata to determine selections
  useEffect(() => {
    const selectedElements = metadata
      .filter((sample) => sample.selected)
      .map((sample) =>
        document.querySelector<HTMLDivElement>(`#${CSS.escape(sample.id)}`),
      )
      .filter((element): element is HTMLDivElement => element !== null);

    refSelecto.current?.setSelectedTargets(selectedElements);
  }, [metadata]);

  // If there are no selections yet, we select the one right-clicked on
  // to use for the context menu
  const handleRightClick = (sampleId: string) => {
    if (refSelecto.current === null) {
      return;
    }

    if (refSelecto.current.getSelectedTargets().length !== 0) {
      return;
    }

    const sample = sampleById.get(sampleId);
    if (sample === undefined) {
      return;
    }

    onSetSelection([withoutSampleState(sample)]);
  };

  return (
    <>
      <Selecto
        ref={refSelecto}
        // NOTE(review): this ref is still null during the first render, so
        // Selecto receives `null` here until a re-render — confirm that is
        // handled as intended.
        container={refGridViewContainer.current}
        selectableTargets={sampleElementIds}
        // Allow multiselect
        continueSelect={true}
        keyContainer={window}
        // The rate at which the target overlaps the drag area to be selected
        hitRate={1}
        onSelect={(e) => {
          if (disabled) {
            return;
          }

          // Elements whose id is not a known sample are skipped instead of
          // being passed on as `undefined`.
          const selectedSamples: T[] = [];
          for (const element of e.selected) {
            const sample = sampleById.get(element.id);
            if (sample !== undefined) {
              selectedSamples.push(withoutSampleState(sample));
            }
          }
          onSetSelection(selectedSamples);
        }}
      />

      <div
        ref={refGridViewContainer}
        className={cx("tw-grid tw-flex-1 tw-gap-[2px] tw-justify-center")}
        style={{
          gridTemplateColumns: `repeat(auto-fit, ${cellSizePx}px)`,
          gridTemplateRows: `repeat(auto-fit, ${cellSizePx}px)`,
          ...(minColumns
            ? {
                minWidth:
                  // Cells with the 4px border
                  cellSizePx * minColumns +
                  // 2px gap
                  2 * (minColumns - 1),
              }
            : {}),
        }}
      >
        {metadata.map((sample) => {
          const maxPrediction = getMaxPredictionDisplay(sample.id);
          return (
            <div
              className={cx(
                "tw-group",
                "tw-cursor-pointer",
                "tw-rounded tw-border-4",
                sample.selected
                  ? "tw-border-purple-500 tw-shadow-md"
                  : "tw-border-transparent",
                !disabled &&
                  (sample.selected
                    ? "tw-opacity-70 hover:tw-opacity-100"
                    : "hover:tw-shadow-md hover:tw-shadow-purple"),
              )}
              key={sample.id}
            >
              <div
                // NOTE(review): React hands over its synthetic event here while
                // the `onClick` prop is typed with the DOM `MouseEvent`; the
                // `any` bridges the two. Confirm which one callers expect.
                onClick={(e: any) => onClick?.(sample, e)}
                onContextMenu={() => handleRightClick(sample.id)}
                // This id is what Selecto's selectable-target selectors match.
                id={sample.id}
                className={cx(
                  "tw-relative tw-flex tw-flex-row tw-z-20 tw-text-white tw-text-xs tw-h-full",
                )}
              >
                <div className="tw-absolute tw-inset-x-[2px] tw-flex tw-items-center tw-justify-between tw-mt-2 tw-flex-wrap">
                  {/* Predicted class, or a warning icon when predictions exist
                      but none is confident enough for this sample. */}
                  <div
                    className={cx(
                      "tw-z-20 tw-px-2 tw-py-1 tw-rounded tw-truncate tw-max-w-[65%] tw-mb-1",
                      maxPrediction.name &&
                        "tw-bg-black tw-bg-opacity-50 tw-truncate",
                    )}
                  >
                    {maxPrediction.name ? (
                      maxPrediction.name
                    ) : predictionMap === null ? null : (
                      <PredictionWarning className={"tw-scale-125"} />
                    )}
                  </div>
                  {/* Confidence percentage, only visible while hovering. */}
                  <div
                    className={cx(
                      "tw-max-w-[65%] tw-mb-1",
                      "tw-opacity-0 group-hover:tw-opacity-100 tw-transition-opacity",
                      maxPrediction.name
                        ? "tw-bg-black tw-bg-opacity-50 tw-z-20 tw-rounded " +
                            "tw-border-white tw-px-2 tw-py-1"
                        : "",
                    )}
                  >
                    {maxPrediction.name !== null ? maxPrediction.value : ""}
                  </div>
                </div>
                <MultiChannelView
                  key={sample.id}
                  index={{
                    dataset,
                    plate: sample.plate,
                    well: sample.well,
                    field: toNumericField(sample.field),
                    t: 0,
                    z: 0,
                  }}
                  size={size}
                  imageSet={imageSet}
                  crop={{
                    size: effectiveCropSize,
                    // NOTE(review): `row` feeds `x` and `column` feeds `y` —
                    // confirm this matches the crop coordinate convention.
                    location: { x: sample.row, y: sample.column },
                  }}
                  forwardRef={forwardRef}
                  {...cellMaskProps}
                />
              </div>
            </div>
          );
        })}
        {placeholders}
      </div>
    </>
  );
}
