import { AsyncDuckDB } from "@duckdb/duckdb-wasm";
import * as Dialog from "@radix-ui/react-dialog";
import useResizeObserver from "@react-hook/resize-observer";
import cx from "classnames";
import {
  memo,
  useCallback,
  useLayoutEffect,
  useMemo,
  useRef,
  useState,
} from "react";
import { Sliders, X } from "react-feather";
import { VariableSizeList as List, areEqual } from "react-window";
import { Button } from "src/Common/Button";
import { FilterSqlClause } from "src/Control/FilterSelector/types";
import { SelectionBreadcrumbs } from "src/ImageViewer/ImageViewer";
import { useAutoImageSet, useFields } from "src/hooks/immunofluorescence";
import MultiChannelView from "src/immunofluorescence/MultiChannelView";
import { DatasetId } from "src/types";
import { useSize } from "src/util/hooks";
import {
  sanitizedColumn,
  sanitizedTextValue,
  sql,
  sqlAnd,
  useQueryAsRecords,
} from "src/util/sql";
import {
  DatasetPlate,
  DatasetPlateWellField,
  Field,
  ImageSet,
} from "../imaging/types";
import { toNumericField } from "../imaging/util";

// Layout metrics for the well gallery. All values are in CSS pixels (they are
// passed as numeric `style` values and `size` props).

/** Edge length of each square field thumbnail. */
const THUMBNAIL_SIZE = 64;
/** Horizontal `gap` between neighbouring thumbnails in a well's strip. */
const THUMBNAIL_GAP = 16;
/** Extra space reserved on each side of a thumbnail when sizing the strip. */
const THUMBNAIL_BORDER = 8;
/** Height of a thumbnail strip: the thumbnail plus the reserve on both sides. */
const THUMBNAIL_FULL_SIZE = 2 * THUMBNAIL_BORDER + THUMBNAIL_SIZE;
/** Height of the heading (breadcrumbs + prediction score) above each strip. */
const HEADING_SIZE = 48;
/** Padding applied around each well row. */
const WELL_ROW_PADDING = 8;
/** Full height of one well row: heading, strip, and top + bottom padding. */
const WELL_ROW_HEIGHT =
  HEADING_SIZE + THUMBNAIL_FULL_SIZE + 2 * WELL_ROW_PADDING;

/** Requested edge length of the zoomed image (the modal may shrink it). */
const IMAGE_SIZE_MAGNIFIED = 800;

/** Props accepted by `ImagesForPrediction`. */
interface Props {
  /** Dataset whose plates, fields and images are displayed. */
  dataset: DatasetId;
  /** DuckDB instance the prediction query is run against. */
  db: AsyncDuckDB;
  /** Table holding `plate`, `well`, `_class` and one score column per class. */
  tableName: string;
  /** Optional SQL predicate applied to `sample_metadata` before joining. */
  filter?: FilterSqlClause;
  /** Class whose score column is shown and must be the (tied) maximum. */
  predicted: string;
  /** Ground-truth class; only rows with `_class` equal to it are shown. */
  actual: string;
  /** All class names; every class other than `predicted` is compared against it. */
  classes: string[];
  /** Extra class names merged onto the root element. */
  className?: string;
  /** Invoked when the "Image" button in the header is pressed. */
  onToggleVisualizationControls: () => void;
}

/** Props for `PlateImages`: one plate and the wells to render from it. */
interface PlateImagesProps {
  /** Dataset and plate that every entry belongs to. */
  index: DatasetPlate;
  /** Wells to render, in display order, each with its prediction score. */
  entries: Array<{ well: string; prediction: number }>;
  /** Image set forwarded to every thumbnail `MultiChannelView`. */
  imageSet: ImageSet | null;
  /** Called with the plate, well and field of a clicked thumbnail. */
  onClick?: (plate: string, well: string, field: Field) => void;
}

/**
 * Renders the given wells of a single plate, one fixed-height row
 * (`WELL_ROW_HEIGHT`) per well. Each row has a heading with the selection
 * breadcrumbs and the prediction score (three decimals). Below it is a
 * horizontally scrollable strip with one thumbnail per field that `useFields`
 * reports for the plate's acquisition.
 *
 * The strips stay empty until the field request has succeeded. When `onClick`
 * is provided, clicking a thumbnail calls it with `(plate, well, field)`.
 */
function PlateImages({ index, imageSet, entries, onClick }: PlateImagesProps) {
  const plate = index.plate;
  const fieldsRequest = useFields({
    dataset: index.dataset,
    acquisition: plate,
  });
  const fieldNames = fieldsRequest?.successful ? fieldsRequest.value : null;

  const wellRows = entries.map((entry) => {
    const { well } = entry;
    const wellKey = `${plate}>${well}`;

    return (
      <div
        key={wellKey}
        className="tw-flex tw-flex-col"
        style={{ padding: WELL_ROW_PADDING, height: WELL_ROW_HEIGHT }}
      >
        <div
          className="tw-flex tw-flex-row tw-items-start"
          style={{ height: HEADING_SIZE }}
        >
          <div className="tw-flex tw-items-center">
            <SelectionBreadcrumbs plate={plate} well={well} field={null} />
            <div className="tw-text-sm tw-text-slate-500 tw-ml-md">
              Prediction: {entry.prediction.toFixed(3)}
            </div>
          </div>
        </div>
        <div
          className="tw-flex tw-flex-row tw-max-w-full tw-overflow-x-auto tw-overflow-y-hidden"
          style={{ gap: THUMBNAIL_GAP, height: THUMBNAIL_FULL_SIZE }}
        >
          {fieldNames?.map((field) => {
            // Only bind a handler when the parent asked for clicks.
            const handleClick = onClick
              ? () => onClick(plate, well, field)
              : undefined;

            return (
              <div
                key={field}
                className="tw-border"
                style={{ width: THUMBNAIL_SIZE, height: THUMBNAIL_SIZE }}
              >
                <MultiChannelView
                  key={field}
                  index={{
                    ...index,
                    well,
                    field: toNumericField(field),
                    t: 0,
                    z: 0,
                  }}
                  size={THUMBNAIL_SIZE}
                  imageSet={imageSet}
                  crop={null}
                  onClick={handleClick}
                />
              </div>
            );
          })}
        </div>
      </div>
    );
  });

  return <>{wellRows}</>;
}

/**
 * `react-window` item renderer. It picks the plate at `index` out of the
 * list's `itemData` and renders all of that plate's wells inside the
 * absolutely-positioned `style` box supplied by the list. It is memoized with
 * react-window's `areEqual`, so a row only re-renders when its own props
 * change.
 */
const Row = memo(
  function PlateRow({
    index: rowIndex,
    style,
    data,
  }: {
    index: number;
    style: React.CSSProperties;
    data: PlateImagesProps[];
  }) {
    const plateProps = data[rowIndex];
    const plateId = plateProps.index.plate;

    return (
      <div key={plateId} data-key={plateId} style={style}>
        <PlateImages {...plateProps} />
      </div>
    );
  },
  areEqual,
);
Row.displayName = "Row";

/**
 * Gallery of the wells whose ground-truth class (`_class`) is `actual` while
 * the `predicted` score is greater than or equal to every other class's
 * score (ties are included).
 *
 * How it works:
 * - Rows come from `tableName` inner-joined with `sample_metadata` on plate
 *   and well. `sample_metadata` is first restricted by `filter` when one is
 *   given.
 * - Results are ordered by plate, then descending prediction, then well.
 * - They are grouped by plate and rendered in a `VariableSizeList` with one
 *   item per plate. Each item is `WELL_ROW_HEIGHT` tall per well.
 * - Clicking a thumbnail opens that field in `FullSizeImageModal` at
 *   `IMAGE_SIZE_MAGNIFIED`. If no magnified image set is available, the
 *   thumbnail image set is used instead.
 */
export function ImagesForPrediction({
  className,
  filter,
  dataset,
  predicted,
  actual,
  classes,
  db,
  tableName,
  onToggleVisualizationControls,
}: Props) {
  // `cls` rather than `className` so the prop of that name is not shadowed.
  const otherClasses = classes.filter((cls) => cls !== predicted);
  const refDisplayArea = useRef<HTMLDivElement>(null);
  const refList = useRef<List<PlateImagesProps[]>>(null);

  const displayAreaSize = useSize(refDisplayArea, "border-box");
  const [selected, setSelected] = useState<{
    plate: string;
    well: string;
    field: Field;
  } | null>(null);
  const handleClickToViewZoomedImage = useCallback(
    (plate: string, well: string, field: Field) => {
      setSelected({ plate, well, field });
    },
    [setSelected],
  );

  // Keep `index` referentially stable while the same plate/well/field stays
  // selected; recomputed only when one of those primitives changes.
  const index = useMemo(
    () =>
      (selected && {
        dataset,
        plate: selected.plate,
        well: selected.well,
        field: toNumericField(selected.field),
      }) ??
      null,
    // eslint-disable-next-line react-hooks/exhaustive-deps
    [dataset, selected?.plate, selected?.well, selected?.field],
  );

  const thumbnailImageSet = useAutoImageSet({
    dataset,
    params: {
      imageSize: THUMBNAIL_SIZE,
      processingMode: "illumination-corrected",
    },
  });

  const fullSizeImageSet =
    useAutoImageSet({
      dataset,
      params: {
        imageSize: IMAGE_SIZE_MAGNIFIED,
        processingMode: "illumination-corrected",
      },
    }) ?? thumbnailImageSet;

  const records = useQueryAsRecords<{
    plate: string;
    well: string;
    prediction: number;
  }>(
    db,
    sql`SELECT 
        ${tableName}.plate as plate, 
        ${tableName}.well as well,
        ${tableName}.${sanitizedColumn(predicted)} as prediction
      FROM ${tableName} 
      INNER JOIN (
        SELECT *
        FROM sample_metadata
        WHERE ${filter || "TRUE"}
      ) sample_metadata
      ON ${tableName}.plate = sample_metadata.plate
      AND ${tableName}.well = sample_metadata.well
    WHERE ${sqlAnd([
      sql`_class=${sanitizedTextValue(actual)}`,
      ...otherClasses.map(
        (other) =>
          sql`${sanitizedColumn(predicted)} >= ${sanitizedColumn(other)}`,
      ),
    ])} ORDER BY plate, prediction DESC, well`,
  );

  // One list item per plate, preserving the query's ordering (Map keeps
  // insertion order, so plates and wells stay in SQL ORDER BY order).
  const plateEntries: PlateImagesProps[] | undefined = useMemo(() => {
    if (!records || !records.successful) {
      return undefined;
    }

    const byPlate: Map<string, { well: string; prediction: number }[]> =
      new Map();

    records.value.forEach((record) => {
      const entry = byPlate.get(record.plate);
      if (entry) {
        entry.push({ well: record.well, prediction: record.prediction });
      } else {
        byPlate.set(record.plate, [
          { well: record.well, prediction: record.prediction },
        ]);
      }
    });

    return Array.from(byPlate.entries()).map(([plate, entries]) => ({
      index: { dataset, plate },
      imageSet: thumbnailImageSet,
      entries,
      onClick: handleClickToViewZoomedImage,
    }));
  }, [dataset, thumbnailImageSet, records, handleClickToViewZoomedImage]);

  // VariableSizeList caches each item's size the first time it is measured.
  // When the data changes (e.g. a new filter), the plate at a given index may
  // have a different number of wells, so the cache must be dropped or rows
  // would overlap or leave gaps.
  useLayoutEffect(() => {
    refList.current?.resetAfterIndex(0);
  }, [plateEntries]);

  return (
    <div className={cx("tw-flex tw-flex-col tw-gap-md", className)}>
      <div className="tw-flex tw-flex-row tw-items-center tw-gap-lg tw-px-md tw-pt-md">
        <div className="tw-flex-1 tw-flex tw-flex-wrap tw-gap-lg tw-gap-y-sm">
          <div className="tw-inline-flex tw-flex-row tw-gap-sm tw-items-center">
            {/* TODO: Use the color associated with the class */}
            <div className="tw-w-md tw-h-md tw-bg-red-500 tw-rounded"></div>
            <div className="tw-truncate">
              {predicted} <span className={"tw-text-sm"}>(Predicted)</span>
            </div>
          </div>
          <div className="tw-text-slate-500">vs</div>
          <div className="tw-inline-flex tw-flex-row tw-gap-sm tw-items-center">
            {/* TODO: Use the color associated with the class */}
            <div className="tw-w-md tw-h-md tw-bg-blue tw-rounded"></div>
            <div className="tw-truncate">
              {actual} <span className={"tw-text-sm"}>(Actual)</span>
            </div>
          </div>
          <div></div>
          <div className="tw-flex-1" />
        </div>
        <div className="tw-self-start">
          <Button
            icon={Sliders}
            onClick={() => onToggleVisualizationControls()}
          >
            Image
          </Button>
        </div>
      </div>
      <div
        ref={refDisplayArea}
        className={cx("tw-flex-1 tw-flex tw-flex-col tw-gap-lg")}
      >
        {plateEntries && displayAreaSize && (
          <List
            ref={refList}
            itemData={plateEntries}
            itemCount={plateEntries.length}
            itemSize={(i) => plateEntries[i].entries.length * WELL_ROW_HEIGHT}
            width={displayAreaSize.width}
            height={displayAreaSize.height}
            itemKey={(i, data) => data[i].index.plate}
          >
            {Row}
          </List>
        )}
      </div>
      {index && (
        <FullSizeImageModal
          index={index}
          imageSet={fullSizeImageSet}
          targetSize={IMAGE_SIZE_MAGNIFIED}
          onClose={() => setSelected(null)}
        />
      )}
    </div>
  );
}

/**
 * Always-open modal showing one field of a well.
 *
 * The image edge length is `targetSize`, reduced so that it fits the viewport
 * with a 16px margin on every side. It is never smaller than 1px. Nothing is
 * rendered inside the dialog until the viewport has been measured.
 *
 * The dialog is controlled with `open={true}`, so the parent must unmount it.
 * Whenever Radix reports a close request (e.g. via the X button), `onClose`
 * is called.
 */
function FullSizeImageModal({
  index,
  imageSet,
  targetSize,
  onClose,
}: {
  index: DatasetPlateWellField;
  imageSet: ImageSet | null;
  targetSize: number;
  onClose?: () => void;
}) {
  const [viewportDimensions, setViewportDimensions] = useState<
    [number, number] | null
  >(null);
  const measureViewport = useCallback(() => {
    setViewportDimensions([window.innerWidth, window.innerHeight]);
  }, []);

  useLayoutEffect(() => {
    // Measure before the first paint so the image starts at the right size.
    measureViewport();
    // A ResizeObserver on <body> only fires when the body's own box changes.
    // When the body's height follows its content, resizing the window
    // vertically leaves it untouched and `innerHeight` would go stale, so
    // window resizes are observed as well.
    window.addEventListener("resize", measureViewport);
    return () => window.removeEventListener("resize", measureViewport);
  }, [measureViewport]);
  useResizeObserver(document.body, measureViewport);

  const PADDING = 16;
  const actualSize = viewportDimensions
    ? Math.max(
        1,
        Math.min(
          viewportDimensions[0] - PADDING * 2,
          viewportDimensions[1] - PADDING * 2,
          targetSize,
        ),
      )
    : targetSize;

  return (
    <Dialog.Root
      open={true}
      onOpenChange={(open) => {
        if (!open) {
          onClose?.();
        }
      }}
    >
      <Dialog.Portal>
        <div
          className={cx(
            "tw-fixed tw-left-0 tw-top-0 tw-w-full tw-h-full tw-z-popup-overlay",
            "tw-bg-slate-400 tw-opacity-70",
          )}
        />
        <Dialog.Overlay
          className={cx(
            "tw-fixed tw-left-0 tw-top-0 tw-w-full tw-h-full tw-z-popup",
            "tw-flex tw-flex-col tw-justify-center",
          )}
        >
          <Dialog.Content
            className={cx(
              "tw-relative tw-m-auto tw-overflow-hidden",
              "tw-rounded-lg tw-bg-white tw-shadow-lg",
              "tw-max-w-[100vw] tw-max-h-[100vh]",
            )}
          >
            {viewportDimensions && (
              <div className={"tw-relative"}>
                <MultiChannelView
                  key={JSON.stringify(index)}
                  index={{ ...index, t: 0, z: 0 }}
                  imageSet={imageSet}
                  size={actualSize}
                  crop={null}
                  showDownloadControls={false}
                  showCursorAnnotations={false}
                />
                {/* Rendered as the dialog's accessible title; `asChild` keeps
                    the original overlay label markup and styling. */}
                <Dialog.Title asChild>
                  <div
                    className={"tw-absolute tw-text-white tw-left-2 tw-top-2"}
                  >
                    {index.plate} {index.well} {index.field}
                  </div>
                </Dialog.Title>
                <Dialog.Close
                  asChild
                  className={
                    "tw-absolute tw-right-2 tw-top-2 tw-text-white tw-cursor-pointer"
                  }
                >
                  <X />
                </Dialog.Close>
              </div>
            )}
          </Dialog.Content>
        </Dialog.Overlay>
      </Dialog.Portal>
    </Dialog.Root>
  );
}
