import { AsyncDuckDB } from "@duckdb/duckdb-wasm";
import * as Sentry from "@sentry/react";
import cx from "classnames";
import {
  MouseEvent,
  forwardRef,
  useCallback,
  useEffect,
  useMemo,
  useRef,
  useState,
} from "react";
import { ChevronLeft, ChevronRight } from "react-feather";
import MultiChannelView from "src/immunofluorescence/MultiChannelView";
import { Fetchable } from "@spring/core/result";
import { Label } from "@spring/ui/typography";
import { useDialog } from "../../Common/useDialog";
import ControlHeader from "../../Control/ControlHeader";
import FilterSelector from "../../Control/FilterSelector";
import { TypedColumn } from "../../Control/FilterSelector/backend-types";
import { Operator } from "../../Control/FilterSelector/operations/filter-by";
import { Filter, FilterSet } from "../../Control/FilterSelector/types";
import {
  serializeToSqlClause,
  updateFilters,
  updateOperator,
  validateFilterFetchable,
} from "../../Control/FilterSelector/utils";
import { useAutoImageSet } from "../../hooks/immunofluorescence";
import useWindowDimensions from "../../hooks/utils";
import Close from "../../icons/Close.svg";
import {
  Color,
  DatasetPlateWell,
  DatasetPlateWellField,
  Field,
  ImageSet,
} from "../../imaging/types";
import { metadataToKey, toNumericField } from "../../imaging/util";
import { DatasetId } from "../../types";
import { useSize } from "../../util/hooks";
import { columnComparator, defaultComparator } from "../../util/sorting";
import { sql, useCacheableQueryAsRecords } from "../../util/sql";
import { IMAGE_SIZES } from "../constants";
import { FieldAggMetadataRow } from "../types";

// Base seed for the per-row sampling hash. ImageGroup adds its groupIndex to
// this value so that each group orders its sampled rows differently.
const ROW_SAMPLING_SEED = 42;

/**
 * One filterable group of sampled field images.
 *
 * Renders a header (label plus an optional remove button), a FilterSelector
 * bound to `filterSet`, and an ImageViewport showing up to `imageCount` rows
 * queried from the `sample_metadata` table of `metadataDB`. When `metadataDB`
 * is not successful, an error message is rendered instead.
 *
 * @param groupIndex Zero-based position of the group; used for the fallback
 *   label and to offset the sampling seed.
 * @param label Optional display label; falls back to `Group <n>` when empty.
 * @param metadataDB Fetchable DuckDB handle the metadata query runs against.
 * @param metadataColumns Columns offered by the FilterSelector.
 * @param dataset Dataset the images belong to.
 * @param filterSet Currently applied filters for this group.
 * @param onChangeFilterSet Called with the validated filter set after a change.
 * @param showSegmentation Forwarded to the viewport to toggle mask rendering.
 * @param showTimepoints When true, the timepoint is included in the sampling
 *   hash and forwarded to the viewport.
 * @param selectedMaskColor Mask color forwarded to the viewport.
 * @param imageSize Edge length (px) of each thumbnail.
 * @param onRemoveGroup Remove handler; when null no close button is rendered.
 * @param isLastGroup When false, a right margin is added to the group.
 */
export default function ImageGroup({
  groupIndex,
  label,
  metadataDB,
  metadataColumns,
  dataset,
  filterSet,
  onChangeFilterSet,
  showSegmentation,
  showTimepoints,
  selectedMaskColor,
  imageSize,
  onRemoveGroup,
  isLastGroup,
}: {
  groupIndex: number;
  label: string | undefined;
  metadataDB: Fetchable<AsyncDuckDB>;
  metadataColumns: TypedColumn[];
  dataset: DatasetId;
  filterSet: FilterSet;
  onChangeFilterSet: (filterSet: FilterSet) => void;
  showSegmentation: boolean;
  showTimepoints: boolean;
  selectedMaskColor: Color;
  imageSize: number;
  onRemoveGroup: (() => void) | null;
  isLastGroup: boolean;
}) {
  const groupLabel = label || `Group ${groupIndex + 1}`;
  const windowSize = useWindowDimensions();

  // Image loading/row count strategy:
  // 1. Load N images by default across Y rows
  // 2. When the container is initially resized, calculate if more images will fit in the default rows
  // 3. Load the additional images if applicable

  const defaultImageCount = 6;
  const initialRows = 3;
  const [imageCount, setImageCount] = useState<number>(defaultImageCount);

  const [ignoredFilters, setIgnoredFilters] = useState<Filter[]>([]);

  const containerRef = useRef<HTMLDivElement | null>(null);
  const containerSize = useSize(containerRef, "content-box");
  const containerWidth = containerSize?.width;

  // Grow (never shrink) the image count so the initial rows are filled once
  // the container width is known.
  useEffect(() => {
    const gutterWidth = 15;

    const calculatedImagesPerRow = containerWidth
      ? Math.floor(containerWidth / (imageSize + gutterWidth / 2))
      : null;

    const newDefaultImageCount = calculatedImagesPerRow
      ? calculatedImagesPerRow * initialRows
      : null;

    if (newDefaultImageCount && newDefaultImageCount > imageCount) {
      setImageCount(newDefaultImageCount);
    }
  }, [containerWidth, imageCount, imageSize]);

  // Requests another batch of images (triggered by the viewport on scroll).
  const handleUpdateImageCount = () => {
    setImageCount((prevImageCount) => prevImageCount + defaultImageCount);
  };

  // Validates the candidate filter set against the database, then publishes
  // it. Ignored filters are only tracked for AND filter sets.
  const validateAndSetFilter = (candidate: FilterSet) => {
    validateFilterFetchable(metadataDB, candidate)
      .then((validated) => {
        onChangeFilterSet(validated);
        setIgnoredFilters(
          validated.operator === Operator.AND ? validated.ignoredFilters : [],
        );
      })
      .catch((e: unknown) => {
        // Report instead of leaving an unhandled promise rejection.
        Sentry.captureException(e);
      });
  };

  const filterSerialized = useMemo(
    () => serializeToSqlClause(filterSet),
    [filterSet],
  );

  const filterSetWithIgnoredFilters = useMemo(() => {
    return {
      ...filterSet,
      ignoredFilters: ignoredFilters,
    };
  }, [filterSet, ignoredFilters]);

  // Shuffle the rows to get a random sample of images. This is particularly
  // relevant since we are returning fields and ideally don't want to show
  // clusters of fields for the same well. Further, we increment the default
  // seed so that each group gets a new order.
  const sampleSeed = ROW_SAMPLING_SEED + groupIndex;

  const fieldAggMetadata = useCacheableQueryAsRecords<
    FieldAggMetadataRow & { hash_value: string }
  >(
    metadataDB?.successful ? metadataDB.value : null,
    // 1) Filter the data
    // 2) Randomly sort the fields (seeded)
    // 3) Aggregate the fields for each plate+well group
    // 4) Join and select
    // 5) Re-apply the seeded order and limit the output. The ORDER BY inside
    //    the CTE is not guaranteed to survive the JOIN, so without the outer
    //    ORDER BY the LIMIT could keep an arbitrary subset of rows.
    sql`
      WITH filtered_data AS (
        SELECT *,
        MD5(CAST(${sampleSeed} AS VARCHAR) || plate || well || field${
          showTimepoints ? "|| timepoint" : ""
        }) AS hash_value
        FROM sample_metadata
        WHERE ${filterSerialized}
        ORDER BY hash_value
      ),
      aggregated_fields AS (
        SELECT
          plate,
          well,
          ARRAY_AGG(field) AS allFields
        FROM filtered_data
        GROUP BY plate, well
      )
      SELECT
        filtered_data.*,
        aggregated_fields.allFields
      FROM
        filtered_data
      JOIN
        aggregated_fields
      ON
        filtered_data.plate = aggregated_fields.plate AND filtered_data.well = aggregated_fields.well
      ORDER BY filtered_data.hash_value
      LIMIT ${imageCount};
  `,
  );

  // Note(davidsharff):
  // 1. DuckDB returns some sort of vector object for ARRAY_AGG calls.
  // 2. Cleanup from query shuffle: reorder the fields list for the carousel and drop
  //    the synthesized hash_value.
  const filteredWells: FieldAggMetadataRow[] = useMemo(
    () =>
      fieldAggMetadata.result?.successful
        ? fieldAggMetadata.result.value.map((row) => {
            const { hash_value, ...rest } = row;
            return {
              ...rest,
              allFields: Array.from(new Set(row.allFields)).sort(
                defaultComparator,
              ),
            };
          })
        : [],
    [fieldAggMetadata.result],
  );

  // Used to support inner scrolling beneath the filter toolbar
  const viewableContainerHeight = containerRef.current
    ? windowSize.height - containerRef.current.offsetTop
    : null;

  return (
    <>
      {metadataDB?.successful ? (
        <div
          ref={containerRef}
          className={cx(
            "tw-flex tw-flex-col tw-flex-1 tw-overflow-hidden",
            "tw-shadow-lg tw-rounded-md",
            // Override default box sizing so that padding/border are independent of width
            "tw-box-content tw-p-4 tw-border tw-border-slate-100",
            { "tw-mr-8": !isLastGroup },
          )}
          style={
            viewableContainerHeight
              ? {
                  height: viewableContainerHeight,
                  minWidth: imageSize,
                }
              : {}
          }
        >
          <div>
            <ControlHeader extraClasses="tw-flex tw-justify-between tw-items-center">
              <span>{groupLabel}</span>
              {onRemoveGroup && (
                <Close
                  className="hover:tw-text-black tw-cursor-pointer"
                  onClick={onRemoveGroup}
                />
              )}
            </ControlHeader>
            <FilterSelector
              columns={metadataColumns}
              filterSet={filterSetWithIgnoredFilters}
              onChangeFilters={(filters) =>
                validateAndSetFilter(
                  updateFilters(filterSetWithIgnoredFilters, filters),
                )
              }
              onChangeOperator={(operator, newFilters) =>
                validateAndSetFilter(
                  updateOperator(
                    filterSetWithIgnoredFilters,
                    operator,
                    newFilters,
                  ),
                )
              }
              metadata={metadataDB.value}
            />
          </div>
          <ImageViewport
            metadataRows={filteredWells}
            dataset={dataset}
            imageSize={imageSize}
            onLoadMoreImages={handleUpdateImageCount}
            showSegmentation={showSegmentation}
            showTimepoints={showTimepoints}
            selectedMaskColor={selectedMaskColor}
          />
        </div>
      ) : (
        <div>There was an error loading the images.</div>
      )}
    </>
  );
}

/**
 * Scrollable, wrapping grid of images for one group.
 *
 * Renders one `Image` per metadata row and calls `onLoadMoreImages` when the
 * user scrolls down far enough that the bottom of the last image is above the
 * bottom of the window.
 *
 * @param metadataRows Rows to render, one image each.
 * @param dataset Dataset the images belong to.
 * @param imageSize Edge length (px) of each thumbnail.
 * @param onLoadMoreImages Invoked to request another batch of rows.
 * @param showSegmentation Forwarded to each image to toggle mask rendering.
 * @param showTimepoints Forwarded to each image to toggle timepoint display.
 * @param selectedMaskColor Mask color forwarded to each image.
 */
function ImageViewport({
  metadataRows,
  dataset,
  imageSize,
  onLoadMoreImages,
  showSegmentation,
  showTimepoints,
  selectedMaskColor,
}: {
  metadataRows: FieldAggMetadataRow[];
  dataset: DatasetId;
  imageSize: number;
  onLoadMoreImages: () => void;
  showSegmentation: boolean;
  showTimepoints: boolean;
  selectedMaskColor: Color;
}) {
  // Scroll bookkeeping lives in refs rather than state: nothing rendered
  // depends on it, so state would re-render every image on each scroll event
  // and the handler would read stale values between renders.
  const isLoadingRef = useRef<boolean>(false);
  const prevScrollTopRef = useRef<number>(0);
  const loadCooldownTimerRef = useRef<ReturnType<typeof setTimeout> | null>(
    null,
  );

  const viewportRef = useRef<HTMLDivElement | null>(null);
  const finalImageRef = useRef<HTMLDivElement | null>(null);

  const { height: windowHeight } = useWindowDimensions();

  // Cancel a pending cooldown timer when the viewport unmounts.
  useEffect(() => {
    return () => {
      if (loadCooldownTimerRef.current !== null) {
        clearTimeout(loadCooldownTimerRef.current);
      }
    };
  }, []);

  // TODO(davidsharff): there is probably a better way to do this using an npm windowing package
  // but this wasn't hard to pull off, but we should explore standard options if it is a problem.
  const handleScroll = (e: React.UIEvent<HTMLDivElement, UIEvent>) => {
    const finalImage = finalImageRef.current;
    if (!finalImage) {
      return;
    }

    const currentScrollTop = e.currentTarget.scrollTop;
    const isScrollingDown = currentScrollTop > prevScrollTopRef.current;
    prevScrollTopRef.current = currentScrollTop;

    // Only load more images if we are scrolling down and the final image is in view.
    if (isLoadingRef.current || !isScrollingDown) {
      return;
    }

    const { bottom } = finalImage.getBoundingClientRect();
    // Subtract 2 to account for the border (container is content-box)
    if (bottom - 2 >= windowHeight) {
      return;
    }

    isLoadingRef.current = true;

    // TODO(davidsharff): hack to prevent multiple load calls as the scroll bar reaches the bottom.
    // Ideally, we could tie into the actual loading state buried within the MultiChannelView.
    loadCooldownTimerRef.current = setTimeout(() => {
      isLoadingRef.current = false;
      loadCooldownTimerRef.current = null;
    }, 100);

    onLoadMoreImages();
  };

  // TODO(davidsharff): it could be nice to fallback on overlays in the future or even just prevent refetches if
  // we know the mask path is bad.
  const handleFetchMaskError = useCallback((e: Error) => {
    Sentry.captureException(e);
  }, []);

  // TODO(davidsharff): I'm not sure we need the colorMatrix or binarize filter anymore if we stick with a single color border.
  // However, I'm leaving a backdoor available for the color picker for now.
  return (
    <div
      ref={viewportRef}
      className="tw-flex tw-flex-wrap tw-gap-4 tw-overflow-y-auto"
      onScroll={handleScroll}
    >
      {metadataRows.map((metadata, i) => (
        <Image
          key={metadataToKey(metadata)}
          index={{
            dataset: dataset,
            plate: metadata.plate,
            well: metadata.well,
          }}
          metadata={metadata}
          // Only the last image carries the ref used for the load-more check.
          imageRef={i === metadataRows.length - 1 ? finalImageRef : undefined}
          imageSize={imageSize}
          showSegmentation={showSegmentation}
          showTimepoint={showTimepoints}
          selectedMaskColor={selectedMaskColor}
          onFetchError={handleFetchMaskError}
        />
      ))}
    </div>
  );
}

/**
 * One metadata row rendered as a thumbnail with a field carousel and a
 * click-to-open dialog that shows a larger image next to the row's metadata.
 *
 * @param index Dataset/plate/well the image belongs to.
 * @param metadata Metadata row; `allFields` drives the carousel.
 * @param imageRef Optional ref attached to the outer wrapper element.
 * @param imageSize Edge length (px) of the thumbnail.
 * @param showSegmentation Whether masks are requested (only applied once an
 *   image set is available).
 * @param showTimepoint Whether the timepoint is shown in the caption and used
 *   for the `t` index.
 * @param selectedMaskColor Mask color forwarded to the image display.
 *
 * NOTE(review): `onFetchError` is part of the props type but is not read by
 * this component.
 */
function Image({
  index,
  metadata,
  imageRef,
  imageSize,
  showSegmentation,
  showTimepoint,
  selectedMaskColor,
}: {
  index: DatasetPlateWell;
  metadata: FieldAggMetadataRow;
  imageRef?: React.Ref<HTMLDivElement>;
  imageSize: number;
  showSegmentation: boolean;
  showTimepoint: boolean;
  selectedMaskColor: Color;
  onFetchError: (e: Error) => void;
}) {
  const [selectedField, setSelectedField] = useState<Field>(
    metadata.field as Field,
  );

  useEffect(() => {
    // The group filter could have dropped the selected field.
    if (!metadata.allFields.includes(selectedField)) {
      setSelectedField(metadata.allFields[0]);
    }
  }, [metadata.allFields, selectedField]);

  // Metadata as displayed in the dialog: the row with the carousel's
  // currently selected field substituted in.
  const updatedMetadata: FieldAggMetadataRow = useMemo(
    () => ({
      ...metadata,
      field: selectedField,
    }),
    [selectedField, metadata],
  );

  const fullIndex = {
    ...index,
    field: toNumericField(selectedField),
    t: showTimepoint ? metadata.timepoint : 0,
    z: 0,
  };

  const { DialogComponent, setOpenDialog } = useDialog();

  const { height: windowHeight } = useWindowDimensions();

  // Cap the enlarged image so it fits in the window with a small margin.
  const maximizedSize = Math.min(IMAGE_SIZES.max, windowHeight - 20);

  const imageSet = useAutoImageSet({
    dataset: index.dataset,
    params: {
      imageSize,
      processingMode: "illumination-corrected",
    },
  });

  // Steps the carousel one field forward or backward, wrapping at the ends.
  const handleChangeField = (direction: "forward" | "backward") => {
    const { allFields } = metadata;

    const currentIndex = allFields.indexOf(selectedField);
    if (currentIndex === -1) {
      // The selected field is no longer in the list (e.g. the filter just
      // changed); restart from the first field instead of stepping from -1.
      setSelectedField(allFields[0]);
      return;
    }

    const step = direction === "forward" ? 1 : -1;
    const nextIndex =
      (currentIndex + step + allFields.length) % allFields.length;

    setSelectedField(allFields[nextIndex]);
  };

  const carouselControlClasses =
    "tw-absolute tw-top-[calc(50%-11px)] tw-text-white tw-outline-none tw-invisible group-hover:tw-visible";

  // `imageRef` is attached to the outer wrapper only. Attaching the same ref
  // to a second element would make `ref.current` depend on attach order.
  return (
    <div className="tw-mb-4 tw-border-b" ref={imageRef}>
      <div className="tw-text-slate-500 tw-font-mono">
        <span>
          {showTimepoint
            ? `${index.well} > ${selectedField} > t${metadata.timepoint}`
            : `${index.well} > ${selectedField}`}
        </span>
      </div>
      <div className="tw-mb-4 tw-relative tw-group">
        <div className="tw-cursor-pointer" onClick={() => setOpenDialog(true)}>
          <ImageDisplay
            size={imageSize}
            index={fullIndex}
            imageSet={imageSet}
            showSegmentation={showSegmentation && !!imageSet}
            selectedMaskColor={selectedMaskColor}
          />
        </div>
        {metadata.allFields.length > 1 && (
          <>
            <button
              className={cx(carouselControlClasses, "tw-left-0")}
              onClick={() => handleChangeField("backward")}
            >
              <ChevronLeft />
            </button>

            <button
              className={cx(carouselControlClasses, "tw-right-0")}
              onClick={() => handleChangeField("forward")}
            >
              <ChevronRight />
            </button>
          </>
        )}
      </div>
      <DialogComponent>
        <div className="tw-flex">
          <ImageDisplay
            size={maximizedSize}
            index={fullIndex}
            imageSet={imageSet}
            showSegmentation={showSegmentation && !!imageSet}
            selectedMaskColor={selectedMaskColor}
          />
          <div
            className="tw-w-[400px] tw-bg-slate-700 tw-p-8 tw-overflow-y-auto"
            style={{
              height: maximizedSize,
            }}
          >
            <Label className="tw-text-white tw-mb-4">Metadata</Label>
            <div className="tw-grow tw-text-sm tw-text-slate-300">
              {Object.keys(updatedMetadata)
                .filter((key) => key !== "allFields" && key !== "well_id")
                .sort(columnComparator)
                .map((key) => (
                  <div key={key} className={"tw-mb-4 tw-flex"}>
                    <div
                      title={key}
                      className="tw-capitalize tw-truncate tw-mr-1"
                    >
                      {key}
                    </div>
                    <div
                      title={`${updatedMetadata[key]}`}
                      className="tw-truncate tw-font-mono tw-text-right tw-flex-1"
                    >
                      {updatedMetadata[key] === null
                        ? "<null>"
                        : updatedMetadata[key]}
                    </div>
                  </div>
                ))}
            </div>
          </div>
        </div>
      </DialogComponent>
    </div>
  );
}

/**
 * Square wrapper around MultiChannelView for a single field image.
 *
 * @param size Edge length (px) of the rendered square.
 * @param index Dataset/plate/well/field to display. If the object also
 *   carries numeric `t` and/or `z` values they are used; otherwise each
 *   defaults to 0.
 * @param imageSet Image set passed through to MultiChannelView.
 * @param showSegmentation When true, masks are shown using `selectedMaskColor`.
 * @param selectedMaskColor Color used for the masks.
 * @param ref Forwarded to the outermost div.
 */
const ImageDisplay = forwardRef(function ImageDisplay(
  {
    size,
    index,
    imageSet,
    showSegmentation,
    selectedMaskColor,
  }: {
    size: number;
    index: DatasetPlateWellField;
    imageSet: ImageSet | null;
    showSegmentation: boolean;
    selectedMaskColor: Color;
  },
  ref: React.Ref<HTMLDivElement>,
) {
  // Callers may hand in an index that already carries `t`/`z` (e.g. the
  // selected timepoint). Unconditionally overwriting them with 0 would discard
  // that selection, so only fall back to 0 when no numeric value is present.
  const { t, z } = index as DatasetPlateWellField & {
    t?: unknown;
    z?: unknown;
  };
  const fullIndex = {
    ...index,
    t: typeof t === "number" ? t : 0,
    z: typeof z === "number" ? z : 0,
  };

  return (
    <div ref={ref}>
      <div
        style={{
          width: size,
          height: size,
        }}
      >
        <MultiChannelView
          index={fullIndex}
          imageSet={imageSet}
          size={size}
          crop={null}
          {...(showSegmentation
            ? {
                showMasks: true,
                maskOptions: { maskColor: selectedMaskColor },
              }
            : { showMasks: false })}
        />
      </div>
    </div>
  );
});
