import * as RadixSwitch from "@radix-ui/react-switch";
import cx from "classnames";
import type { Dayjs } from "dayjs";
import dayjs from "dayjs";
import React, {
  ReactElement,
  ReactNode,
  useCallback,
  useEffect,
  useMemo,
  useRef,
  useState,
} from "react";
import { usePopper } from "react-popper";
import { useActiveWorkspaceId } from "src/Workspace/hooks";
import { metadataToKey } from "src/imaging/util";
import { DEFAULT_TIMEPOINT } from "src/timeseries/constants";
import { useDebouncedCallback } from "use-debounce";
import { Fetchable, Success } from "@spring/core/result";
import {
  colorSchemeByWellMetadata,
  colorValuesByScheme,
} from "../Control/ColorSchemeSelector";
import { FilterSqlClause } from "../Control/FilterSelector/types";
import { NormalizationConfig } from "../SupervisedLearning/types";
import { useToastContext } from "../Toast/context";
import { useFeatureFlag } from "../Workspace/feature-flags";
import { getDevOverride } from "../dev-options";
import { ENV } from "../env";
import { UmapResult, useUmap } from "../hooks/features";
import {
  DatasetId,
  MetadataColumnValue,
  UntypedSampleMetadataRow,
  UntypedWellSampleMetadataRow,
} from "../types";
import { FULL_GRADIENT, colorForValue } from "../util/colors";
import { inferInterestingColumns } from "../util/dataset-util";
import { defaultComparator } from "../util/sorting";
import { sql, useQueryAsRecords } from "../util/sql";
import { useComponentSpan } from "../util/tracing";
import { MULTI_FEATURE_TOAST_CONTAINER_ID } from "./FeatureSetManagementPage.constants";
import HoverPopup from "./HoverPopup";
import { LabMate } from "./MultiFeature/LabMate";
import LeftNavSectionTitle from "./MultiFeature/LeftNavSectionTitle";
import MetadataSelector from "./MultiFeature/MetadataSelector";
import NormalizationOptions from "./MultiFeature/NormalizationOptions";
import PlotLegend, { DiscreteLegend, Legend } from "./MultiFeature/PlotLegend";
import { PlotLoadingMessage } from "./MultiFeature/PlotLoadingMessage";
import { PointHoverPopup } from "./MultiFeature/PointHoverPopup";
import SidebarSelectionView from "./MultiFeature/SidebarSelectionView";
import { useSimilaritiesViewState } from "./MultiFeature/SimilaritiesViewStateContext";
import { setUnion } from "./MultiFeature/set-utils";
import {
  MetadataSelection,
  ScatterplotSelection,
  SelectionKind,
} from "./MultiFeature/types";
import {
  createFilterForSamplingAsync,
  getDedupedMetadata,
  getGroupedMetadata,
  handleDownloadData,
} from "./MultiFeature/utils";
import ScatterPlot from "./ScatterPlot";
import { ScatterPlotPoint, ScatterPlotPointState } from "./ScatterPlot/types";
import { useFeatureSetManagementContext } from "./context";
import { MetadataColors, PointsByKey, UmapRow } from "./types";

// Both well limits below are expressed as a number of 384-well plates.
const WELLS_PER_PLATE = 384;

// CPU infra: 20 plates is where the linear runtime estimate in
// `getEstimatedUmapFetchTimeSecs` reaches roughly 60s. Larger requests must go
// to batch infra.
const MAX_WELLS_BEFORE_REQUIRING_BATCH_INFRA = WELLS_PER_PLATE * 20;

// GPU infra: the computation is much faster (10x-100x depending on the case),
// but there are far fewer instances, so we aim for a tighter ~30s budget.
// Taking the conservative 10x speedup, 30s on GPU corresponds to ~300s of CPU
// runtime, i.e. about 100 plates.
const MAX_WELLS_BEFORE_REQUIRING_BATCH_INFRA_WITH_GPUS = WELLS_PER_PLATE * 100;

/**
 * Rough wall-clock estimate, in seconds, for computing a UMAP over `numWells`
 * wells.
 *
 * The result comes from a linear fit, so it can be negative for very small
 * inputs. The caller in this file clamps it to a minimum.
 *
 * @param numWells number of wells in the request
 * @param usingGPU whether the request will run on GPU infra (assumed ~10x faster)
 */
const getEstimatedUmapFetchTimeSecs = (
  numWells: number,
  usingGPU: boolean = false,
) => {
  // Linear model fitted to a benchmark (R^2 of 0.96):
  // https://springdiscovery.slack.com/archives/C04LZ2A72VA/p1685025730103519
  const cpuEstimateSecs = 0.007906 * numWells - 0.46;

  // A T4 GPU is typically ~10x faster; it can be much more depending on the
  // parameters, so this is the conservative figure.
  const gpuSpeedup = 10;

  return usingGPU ? cpuEstimateSecs / gpuSpeedup : cpuEstimateSecs;
};

/**
 * The outcome of a (possibly sampled) UMAP request, together with the inputs
 * that were actually used to make it.
 */
type SampledResult = {
  /** The UMAP fetch itself. */
  result: Fetchable<UmapResult>;

  /** Whether the request was made against a sample rather than every matching well. */
  isSampled: boolean;

  /**
   * The filter that is actually in effect.
   *
   * With sampling, this is the combined filter: the user's filter plus the
   * forced sampling filter. Without sampling, it is just the user's filter.
   */
  filterSerialized: FilterSqlClause;

  /**
   * The metadata that is actually in effect.
   *
   * With sampling, this is the sampled metadata. Without sampling, it is the
   * metadata from the analysis context, which already reflects the user's
   * filters.
   */
  metadata: UntypedWellSampleMetadataRow[];

  /** Projected time at which the request should complete, if known. */
  estimatedFinishTime?: Dayjs;
};

/**
 * All the parameters that define one UMAP computation. It is built by the UMAP
 * view component and converted into a `useUmap` request.
 */
type UmapFetchConfig = {
  dataset: DatasetId;
  /** Plates to include in the request. */
  plates: string[];
  /** Features to compute the embedding from. */
  features: string[];
  columns?: string[];
  /** Metadata columns to group by, if any. */
  groupByColumns?: string[];
  normalizationConfig: NormalizationConfig;
  /** Whether features are z-scored before embedding. */
  zScore: boolean;
  /** Distance metric. When omitted, the request presumably uses a backend default (TODO confirm). */
  distanceMetric?: "euclidean" | "cosine";
};

/**
 * Resolves the filter and metadata that a UMAP request should be made with.
 *
 * - When `isSampled` is false, these are the user's filter and filtered
 *   metadata from the feature-set-management context.
 * - When `isSampled` is true, a sampling filter is built asynchronously with
 *   `createFilterForSamplingAsync`, capped at the batch threshold for the
 *   current infra.
 *
 * Returns `undefined` until both the filter and metadata are available, e.g.
 * while a sampling filter is still being computed.
 *
 * @param config the fetch config; only `plates` is read here
 * @param isSampled whether the request should be sampled
 */
function useUmapRequestOptions(
  config: UmapFetchConfig | null,
  isSampled: boolean,
):
  | {
      isSampled: boolean;
      useBatch: boolean;
      requestFilter: string;
      requestedMetadata: UntypedWellSampleMetadataRow[];
      useGPU: boolean;
    }
  | undefined {
  const context = useFeatureSetManagementContext();
  const { filter, filterSerialized, filteredMetadata, metadataDB } = context;

  const [requestFilter, setRequestFilter] = useState<string | undefined>();
  const [requestedMetadata, setRequestedMetadata] = useState<
    UntypedWellSampleMetadataRow[] | undefined
  >();

  // GPU infra is requested in production, or when a dev override sets a host URL.
  const useGPU = !!getDevOverride("hostUrl") || ENV === "production";

  const maxWellsBeforeBatch = useGPU
    ? MAX_WELLS_BEFORE_REQUIRING_BATCH_INFRA_WITH_GPUS
    : MAX_WELLS_BEFORE_REQUIRING_BATCH_INFRA;

  const plates = config?.plates;

  const computeSamplingFilter = useCallback(
    () =>
      createFilterForSamplingAsync(
        metadataDB,
        plates ?? [],
        filter,
        maxWellsBeforeBatch,
      ),
    [plates, filter, metadataDB, maxWellsBeforeBatch],
  );

  useEffect(() => {
    // Clear to prevent a stale filter while waiting for the async sample call
    setRequestFilter(undefined);
    setRequestedMetadata(undefined);

    // Set by the cleanup below. The sampling call is async, so inputs may
    // change (or sampling may be turned off) before it resolves. Its result is
    // then dropped so it cannot overwrite newer state.
    let isStale = false;

    if (isSampled) {
      // NOTE(review): rejections are not caught here and propagate as an
      // unhandled rejection. The options then stay `undefined`.
      void computeSamplingFilter().then(([sampleFilter, sampleMetadata]) => {
        if (isStale) {
          return;
        }
        setRequestFilter(sampleFilter);
        setRequestedMetadata(sampleMetadata);
      });
    } else {
      setRequestFilter(filterSerialized);
      setRequestedMetadata(filteredMetadata);
    }

    return () => {
      isStale = true;
    };
  }, [isSampled, computeSamplingFilter, filterSerialized, filteredMetadata]);

  return useMemo(() => {
    return requestFilter === undefined || requestedMetadata === undefined
      ? undefined
      : {
          isSampled,
          requestFilter,
          requestedMetadata,
          // Requests over the threshold for the current infra go to batch infra.
          useBatch: requestedMetadata.length > maxWellsBeforeBatch,
          useGPU,
        };
  }, [
    requestFilter,
    requestedMetadata,
    isSampled,
    useGPU,
    maxWellsBeforeBatch,
  ]);
}

/**
 * Issues a UMAP request for `config`, sampling the wells first when
 * `isSampled` is true.
 *
 * Returns `undefined` until the request filter is ready. After that it returns
 * a `Success` wrapping the inner UMAP fetch plus the filter and metadata
 * actually used. The inner `result` may still be loading or failed.
 *
 * @param config the fetch config, or `null` (the request is skipped until a filter exists)
 * @param isSampled whether to sample the wells before requesting the UMAP
 */
function useMaybeSampledUmap(
  config: UmapFetchConfig | null,
  isSampled: boolean,
): Fetchable<SampledResult> {
  const { dataset, features, plates, columns, groupByColumns, distanceMetric } =
    config || {};
  const umapOptions = useUmapRequestOptions(config, isSampled);
  const umapResult = useUmap(
    // This is a hack to force a no-op until the filter is ready, since we can't
    // conditionally call this hook
    umapOptions?.requestFilter
      ? {
          dataset,
          sqlFilter: umapOptions.requestFilter,
          features: features ?? [],
          plates: plates ?? [],
          columns,
          groupByColumns,
          useBatch: umapOptions.useBatch,
          distanceMetric,
          zScore: config?.zScore,
          useGPU: umapOptions.useGPU,
          normalizationConfig: config?.normalizationConfig,
        }
      : { skip: true },
  );

  // Restart the clock whenever the request parameters change.
  // eslint-disable-next-line react-hooks/exhaustive-deps
  const timeRequestStarted = useMemo(() => dayjs(), [config, isSampled]);

  return useMemo(() => {
    // We won't make the call until we define a filter
    if (umapOptions === undefined) {
      return undefined;
    }

    // Account for the GPU speedup when the request will run on GPU infra.
    // Always allow at least 5s so the loading message has a sensible target.
    const estimatedDurationSecs = Math.max(
      getEstimatedUmapFetchTimeSecs(
        umapOptions.requestedMetadata.length,
        umapOptions.useGPU,
      ),
      5,
    );

    return Success.of({
      isSampled: umapOptions.isSampled,
      metadata: umapOptions.requestedMetadata,
      filterSerialized: umapOptions.requestFilter,
      estimatedFinishTime: timeRequestStarted.add(estimatedDurationSecs, "s"),
      result: umapResult,
    });
  }, [timeRequestStarted, umapOptions, umapResult]);
}

/**
 * Entry point for the UMAP view.
 *
 * Responsibilities:
 * - Builds the UMAP fetch config from the caller's parameters and the shared
 *   similarities view state.
 * - Decides whether the request is sampled, based on feature flags and the
 *   dataset size relative to the batch threshold.
 * - Wires up the sampling toast, the download handler and the filter-header
 *   toggle.
 * - Passes everything to `UmapScatterPlotWithData` for rendering.
 */
export const UmapScatterPlot = React.forwardRef<
  HTMLDivElement,
  {
    dataset: DatasetId;
    plates: string[];
    features: string[];
    columns?: string[];
    onOpenFilterSelector: () => void;
    onSetFilterHeader: (header: ReactNode) => void;
    normalizationColumns: string[];
    onChangeNormalizationColumns: (columns: string[]) => void;
  }
>(function UmapScatterPlot(props, ref) {
  const {
    dataset,
    plates,
    features,
    columns,
    onOpenFilterSelector,
    onSetFilterHeader,
    normalizationColumns,
    onChangeNormalizationColumns,
  } = props;

  useComponentSpan("UmapScatterPlot", [dataset, plates, features]);

  const workspaceId = useActiveWorkspaceId();
  const managementContext = useFeatureSetManagementContext();
  const { filterSerialized, filteredMetadata, setOnDownload } =
    managementContext;

  const viewState = useSimilaritiesViewState();
  const { colorBy, groupBy, selections, update } = viewState;

  const setSelections = useCallback(
    (nextSelections: SelectionKind[]) => {
      update({ selections: nextSelections });
    },
    [update],
  );

  const setColorBy = useCallback(
    (nextColorBy: string | null) => {
      update({ colorBy: nextColorBy });
    },
    [update],
  );

  // Combines the caller's "base" parameters (dataset, plates, features,
  // columns) with view-level settings (grouping, normalization) and
  // workspace-specific defaults.
  const config = useMemo((): UmapFetchConfig => {
    const isGileadWorkspace = workspaceId === "gilead";

    const normalizationConfig: NormalizationConfig =
      normalizationColumns.length === 0
        ? { type: "none" }
        : {
            type: "stratified",
            stratificationColumns: normalizationColumns,
            targetFilter: null,
          };

    return {
      dataset,
      features,
      plates,
      columns,
      groupByColumns: groupBy ? [groupBy] : undefined,
      distanceMetric: isGileadWorkspace ? "cosine" : undefined,
      // Z-scoring is on by default because it is nearly always useful. It could
      // become user-configurable for cases where differing feature magnitudes
      // should be preserved.
      // TODO(zb): Make this user-configurable
      //  https://github.com/spring-discovery/spring-experiments/pull/7933
      zScore: !isGileadWorkspace,
      normalizationConfig,
    };
  }, [
    workspaceId,
    dataset,
    features,
    plates,
    columns,
    groupBy,
    normalizationColumns,
  ]);

  const sampledByDefault = useFeatureFlag("umap-sampled-by-default");
  const userSamplingAllowed = useFeatureFlag("umap-user-config-sampling");

  const onGpuInfra = !!getDevOverride("hostUrl") || ENV === "production";
  const wellLimit = onGpuInfra
    ? MAX_WELLS_BEFORE_REQUIRING_BATCH_INFRA_WITH_GPUS
    : MAX_WELLS_BEFORE_REQUIRING_BATCH_INFRA;
  const isDatasetLarge = filteredMetadata.length > wellLimit;

  const [userWantsSampling, setUserWantsSampling] = useState(sampledByDefault);

  // The user's choice only counts when the flag allows it; otherwise the
  // default applies. Either way, only datasets over the threshold are sampled.
  const samplingRequested = userSamplingAllowed
    ? userWantsSampling
    : sampledByDefault;
  const isSampled = samplingRequested && isDatasetLarge;

  const sampledUmap = useMaybeSampledUmap(config, isSampled);
  const sampledResult = sampledUmap?.successful ? sampledUmap.value : null;
  const umapResult = sampledResult?.result;
  const umapDB = umapResult?.successful ? umapResult.value.umapDB : null;
  const data: Fetchable<UmapRow[]> = useQueryAsRecords<UmapRow>(
    umapDB,
    sql`SELECT * FROM umap`,
  );

  const { setToast: showToast, dismissToast } = useToastContext();
  // Only sampled results get a non-empty toast key.
  const toastKey = sampledResult?.isSampled
    ? `${dataset}-${features.join("-")}-${columns?.join("-")}-${filterSerialized}`
    : "";

  const setToast = useCallback(
    (key: string, message: string) => {
      showToast(key, message, MULTI_FEATURE_TOAST_CONTAINER_ID);
    },
    [showToast],
  );

  const isShowingSample = sampledResult?.isSampled;

  // TODO(davidsharff): if the user filter comes in beneath the threshold,
  // we'll briefly flash the toast and then remove it.
  useEffect(() => {
    const shouldWarnAboutSampling = !userSamplingAllowed && isShowingSample;
    if (shouldWarnAboutSampling) {
      setToast(
        toastKey,
        "The view you've selected contains too much data to analyze at once and has been sampled; view the filter to review and adjust what data is shown.",
      );
    }

    return () => dismissToast(toastKey);
  }, [userSamplingAllowed, toastKey, isShowingSample, setToast, dismissToast]);

  const handleDownload = useCallback(
    () => handleDownloadData(dataset, data, filteredMetadata),
    [dataset, data, filteredMetadata],
  );

  const isDataReady = data?.successful;

  // TODO(benkomalo): it'd be nice if we showed immediately and if user clicked it'd
  // just spin, but that requires a bit more machinery.
  useEffect(() => {
    setOnDownload(isDataReady ? handleDownload : null);
    return () => setOnDownload(null);
  }, [isDataReady, handleDownload, setOnDownload]);

  useEffect(() => {
    const showSamplingToggle = userSamplingAllowed && isDatasetLarge;
    onSetFilterHeader(
      showSamplingToggle ? (
        <FilterSamplingToggle
          samplingEnabled={userWantsSampling}
          onSetSamplingEnabled={setUserWantsSampling}
        />
      ) : null,
    );
  }, [
    userSamplingAllowed,
    isDatasetLarge,
    userWantsSampling,
    setUserWantsSampling,
    onSetFilterHeader,
  ]);

  // The header is cleared unconditionally when the component unmounts.
  useEffect(() => {
    return () => onSetFilterHeader(null);
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, []);

  return (
    <UmapScatterPlotWithData
      dataset={dataset}
      filterSerialized={sampledResult?.filterSerialized}
      plates={plates}
      features={features}
      umapResult={umapResult}
      data={data}
      estimatedFinishTime={sampledResult?.estimatedFinishTime}
      metadata={sampledResult?.metadata}
      coloringMetadataColumn={colorBy}
      setColoringMetadataColumn={setColorBy}
      groupByColumn={groupBy}
      selections={selections}
      setSelections={setSelections}
      onOpenFilterSelector={onOpenFilterSelector}
      ref={ref}
      normalizationColumns={normalizationColumns}
      onChangeNormalizationColumns={onChangeNormalizationColumns}
    />
  );
});

/**
 * Rendering half of the UMAP view, for an already-requested UMAP result.
 *
 * Renders:
 * - the left sidebar: color-by selector, normalization options, selections
 *   and LabMate;
 * - the scatter plot, or a loading/error message.
 *
 * It derives view state (selected points, legend, per-point colors, cluster
 * colors) from `data` and `metadata`. Selections are owned by the caller and
 * changed only through `setSelections`.
 */
const UmapScatterPlotWithData = React.forwardRef<
  HTMLDivElement,
  {
    dataset: DatasetId;
    plates: string[];
    features: string[];
    umapResult?: Fetchable<UmapResult>;
    filterSerialized?: FilterSqlClause;
    data: Fetchable<UmapRow[]>;
    estimatedFinishTime?: Dayjs;
    metadata?: UntypedSampleMetadataRow[];
    coloringMetadataColumn: string | null;
    setColoringMetadataColumn: (column: string | null) => void;
    groupByColumn: string | null;
    selections: SelectionKind[];
    setSelections: (selections: SelectionKind[]) => void;
    onOpenFilterSelector: () => void;
    normalizationColumns: string[];
    onChangeNormalizationColumns: (columns: string[]) => void;
  }
>(function UmapScatterPlotWithData(
  {
    dataset,
    plates,
    features,
    umapResult,
    filterSerialized,
    data,
    estimatedFinishTime,
    metadata,
    coloringMetadataColumn,
    setColoringMetadataColumn,
    groupByColumn,
    selections,
    setSelections,
    onOpenFilterSelector,
    normalizationColumns,
    onChangeNormalizationColumns,
  },
  ref,
) {
  const [hoveredLegendKey, setHoveredLegendKey] = useState<string | null>(null);
  const [hoveredClusterLabel, setHoveredClusterLabel] = useState<number | null>(
    null,
  );

  // The last *defined* filter we rendered with. It stays undefined until the
  // first filter arrives, so that selections restored from URL params on load
  // are not wiped by that initial filter.
  const prevFilterSerialized = useRef(filterSerialized);

  useEffect(() => {
    // `filterSerialized` can be undefined while the parent is still resolving
    // the (possibly sampled) request filter. Ignore those gaps rather than
    // treating them as a filter change.
    if (filterSerialized === undefined) {
      return;
    }

    // First filter seen: remember it, but keep selections that may have been
    // restored from URL params on load.
    if (prevFilterSerialized.current === undefined) {
      prevFilterSerialized.current = filterSerialized;
      return;
    }

    // Whenever filters change we recalculate the UMAP, so existing selections
    // are cleared.
    if (filterSerialized !== prevFilterSerialized.current) {
      prevFilterSerialized.current = filterSerialized;
      setSelections([]);
    }
  }, [filterSerialized, setSelections]);

  // Union of all point keys covered by the current selections. Scatterplot
  // selections contribute their points directly. Metadata selections
  // contribute every row whose value for `key` stringifies to `value`.
  const selectedPoints: Set<string> = useMemo(() => {
    if (metadata === undefined) {
      return new Set();
    }

    return selections.reduce(
      (selectedPoints: Set<string>, selection: SelectionKind) => {
        if (selection.kind === "scatterplot") {
          return setUnion(selectedPoints, selection.points);
          // TODO(you): Fix this no-unnecessary-condition rule violation
          // eslint-disable-next-line @typescript-eslint/no-unnecessary-condition
        } else if (selection.kind === "metadata") {
          const { key, value } = selection;

          const matchingKeys = metadata
            .filter((metadataRow) => String(metadataRow[key]) === value)
            .map((metadataRow) => metadataToKey(metadataRow));

          return setUnion(selectedPoints, new Set(matchingKeys));
        } else {
          return selectedPoints;
        }
      },
      new Set(),
    );
  }, [selections, metadata]);

  const handleDownloadSelectedPoints = useCallback(() => {
    if (metadata === undefined) {
      return Promise.reject<void>(
        new Error("Started download prior to data ready"),
      );
    }

    return handleDownloadData(dataset, data, metadata, {
      points: selectedPoints,
    });
  }, [dataset, data, metadata, selectedPoints]);

  const [hideUnselectedPoints, setHideUnselectedPoints] =
    useState<boolean>(false);

  // Hiding unselected points makes no sense with an empty selection, so turn it off.
  useEffect(() => {
    if (hideUnselectedPoints && selectedPoints.size === 0) {
      setHideUnselectedPoints(false);
    }
  }, [hideUnselectedPoints, selectedPoints]);

  const [hoveredPointKey, setHoveredPointKey] = useState<string | null>(null);
  const [popperElement, setPopperElement] = useState<HTMLDivElement | null>(
    null,
  );
  const [popperArrowElement, setPopperArrowElement] =
    useState<HTMLDivElement | null>(null);
  const [hoverReference, setHoverReference] = useState<HTMLDivElement | null>(
    null,
  );

  const [sidebarHoverReference, setSidebarHoverReference] =
    useState<HTMLElement | null>(null);

  // The popper anchors to either a plot point or a sidebar element. Setting
  // one reference clears the other.
  const handleSetPlotHoverRef = (elem: HTMLDivElement) => {
    setSidebarHoverReference(null);
    setHoverReference(elem);
  };

  const handleSetSidebarHoverRef = (elem: HTMLElement | null) => {
    setHoverReference(null);
    setSidebarHoverReference(elem);
  };

  const [plotWrapper, setPlotWrapper] = useState<HTMLDivElement | null>(null);
  const popper = usePopper(
    hoverReference || sidebarHoverReference,
    popperElement,
    {
      placement: "right",
      modifiers: [
        {
          name: "flip",
          options: {
            fallbackPlacements: ["left"],
          },
        },
        {
          name: "preventOverflow",
          options: {
            boundary: plotWrapper ?? undefined,
            padding: 4,
          },
        },
        {
          name: "arrow",
          options: {
            element: popperArrowElement,
            padding: 12,
          },
        },
      ],
    },
  );

  const handleSelectNewPoints = useCallback(
    (newPoints: Set<string>) => {
      if (newPoints.size === 0) {
        return;
      }

      const newScatterplotSelection: ScatterplotSelection = {
        kind: "scatterplot",
        points: newPoints,
      };
      setSelections([...selections, newScatterplotSelection]);
    },
    [selections, setSelections],
  );

  const handleSelectSinglePoint = useCallback(
    (newPointKey: string) => {
      const selectedMetadataRow = metadata?.find(
        (metadataRow) => metadataToKey(metadataRow) === newPointKey,
      );

      if (coloringMetadataColumn && selectedMetadataRow) {
        // If there is a coloring, selecting a single point selects all points
        // with the same metadata value
        const newMetatadataSelection: MetadataSelection = {
          kind: "metadata",
          key: coloringMetadataColumn,
          value: String(selectedMetadataRow[coloringMetadataColumn]),
        };
        setSelections([...selections, newMetatadataSelection]);
      } else {
        const newScatterplotSelection: ScatterplotSelection = {
          kind: "scatterplot",
          points: new Set([newPointKey]),
        };
        setSelections([...selections, newScatterplotSelection]);
      }
    },
    [selections, setSelections, coloringMetadataColumn, metadata],
  );

  // Clicking a legend entry toggles a metadata selection for that value of the
  // current coloring column.
  const onToggleLegendKey = useCallback(
    (key: string) => {
      if (!coloringMetadataColumn) {
        return;
      }

      const selectionIndex = selections.findIndex(
        (selection) =>
          selection.kind === "metadata" &&
          selection.key === coloringMetadataColumn &&
          selection.value === key,
      );
      if (selectionIndex === -1) {
        const newKeySelection: MetadataSelection = {
          kind: "metadata",
          key: coloringMetadataColumn,
          value: key,
        };
        setSelections([...selections, newKeySelection]);
      } else {
        const newSelections = selections.slice();
        newSelections.splice(selectionIndex, 1);
        setSelections(newSelections);
      }
    },
    [selections, setSelections, coloringMetadataColumn],
  );

  const removeSelection = useCallback(
    (selectionIndex: number) => {
      const newSelections = selections.slice();
      newSelections.splice(selectionIndex, 1);
      setSelections(newSelections);
    },
    [selections, setSelections],
  );

  // Adds a metadata selection, unless an identical one already exists.
  const addSelection = useCallback(
    (columnId: string, value: MetadataColumnValue) => {
      if (
        selections.some(
          (existing) =>
            existing.kind === "metadata" &&
            existing.key === columnId &&
            existing.value === String(value),
        )
      ) {
        return;
      }

      setSelections([
        ...selections,
        {
          kind: "metadata",
          key: columnId,
          value: String(value),
        },
      ]);
    },
    [selections, setSelections],
  );

  const dedupedMetadata = useMemo(() => {
    return getDedupedMetadata(metadata);
  }, [metadata]);
  const [groupedByMetadata, groups] = useMemo(() => {
    return getGroupedMetadata(dedupedMetadata, coloringMetadataColumn, data);
  }, [dedupedMetadata, coloringMetadataColumn, data]);
  const scatterPoints = useMemo((): ScatterPlotPoint[] | null => {
    if (!data?.successful) {
      return null;
    }

    // Sort of an edge case, but sometimes we delete metadata post-hoc for QC purposes,
    // after having generated features. In those cases, we want to make sure
    // we reconcile the data with the metadata and don't render points for which
    // we have no metadata on.
    return data.value
      .filter(({ plate, well, timepoint }) => {
        // TODO(you): Fix this no-unnecessary-condition rule violation
        return (
          // eslint-disable-next-line @typescript-eslint/no-unnecessary-condition
          dedupedMetadata[plate]?.[well]?.[timepoint ?? DEFAULT_TIMEPOINT] !==
          undefined
        );
      })
      .map((d) => ({
        key: metadataToKey(d),
        x: d["0"],
        y: d["1"],
        clusterLabel: d["cluster_label"],
      }));
  }, [data, dedupedMetadata]);

  const interestingMetadataColumns = useMemo(() => {
    return metadata !== undefined && metadata.length > 0
      ? inferInterestingColumns(metadata)
      : [];
  }, [metadata]);

  // Derives, in one pass:
  // - the legend;
  // - the metadata selections on the current coloring column;
  // - the points that have metadata, keyed for hover lookup;
  // - the color assigned to each metadata value.
  const [legend, legendSelections, visiblePointsByKey, metadataColors] =
    useMemo((): [
      Legend | undefined,
      MetadataSelection[],
      PointsByKey,
      MetadataColors,
    ] => {
      if (
        !data?.successful ||
        plates.length === 0 ||
        features.length === 0 ||
        metadata === undefined ||
        metadata.length === 0
      ) {
        return [undefined, [], {}, {}];
      }

      const visiblePointsByKey: PointsByKey = {};

      for (const d of data.value) {
        const maybeMetadata =
          dedupedMetadata[d.plate]?.[d.well]?.[
            d.timepoint ?? DEFAULT_TIMEPOINT
          ];
        if (maybeMetadata === undefined) {
          continue;
        }
        visiblePointsByKey[metadataToKey(d)] = {
          x: d["0"],
          y: d["1"],
          clusterLabel: d.cluster_label,
          metadata: maybeMetadata,
        };
      }

      if (coloringMetadataColumn) {
        const metadataValues =
          Object.keys(groupedByMetadata).sort(defaultComparator);

        const colorScheme = colorSchemeByWellMetadata(metadataValues);
        const metadataColors = Object.fromEntries(
          colorValuesByScheme(metadataValues, colorScheme),
        );

        const legendSelections = selections.filter(
          (selection): selection is MetadataSelection =>
            selection.kind === "metadata" &&
            selection.key === coloringMetadataColumn,
        );

        const legend: DiscreteLegend = {
          type: "discrete",
          entries: groups.map(([metaKey]) => ({
            value: metaKey,
            key: metaKey,
            label: metaKey,
            color: metadataColors[metaKey],
            // While any legend value is selected, grey out the unselected ones.
            disabled:
              legendSelections.length > 0 &&
              !legendSelections.some(
                (selection) => selection.value === metaKey,
              ),
          })),
        };

        return [legend, legendSelections, visiblePointsByKey, metadataColors];
      } else {
        return [undefined, [], visiblePointsByKey, {}];
      }
    }, [
      data,
      plates.length,
      features.length,
      metadata,
      coloringMetadataColumn,
      dedupedMetadata,
      groupedByMetadata,
      groups,
      selections,
    ]);

  // Per-point render state: color, hover highlight, and hidden when
  // "hide unselected" is on.
  const scatterPointStates = useMemo((): Record<
    string,
    ScatterPlotPointState
  > => {
    if (!data?.successful) {
      return {};
    }

    if (coloringMetadataColumn) {
      const pointStates: Record<string, ScatterPlotPointState> = {};
      for (const [metaKey, dataPoints] of groups) {
        if (dataPoints === undefined) {
          continue;
        }
        for (const d of dataPoints) {
          const pointKey = metadataToKey(d);
          pointStates[pointKey] = {
            color: metadataColors[metaKey],
            highlighted: metaKey === hoveredLegendKey,
            hidden: hideUnselectedPoints ? !selectedPoints.has(pointKey) : false,
          };
        }
      }

      return pointStates;
    } else {
      const pointStates: Record<string, ScatterPlotPointState> = {};
      for (const d of data.value) {
        const pointKey = metadataToKey(d);
        pointStates[pointKey] = {
          color: "#666",
          highlighted: false,
          hidden: hideUnselectedPoints ? !selectedPoints.has(pointKey) : false,
        };
      }

      return pointStates;
    }
  }, [
    data,
    metadataColors,
    coloringMetadataColumn,
    groups,
    hideUnselectedPoints,
    selectedPoints,
    hoveredLegendKey,
  ]);

  // Clusters are shown only while no legend (metadata) value is selected.
  const showClusters = useMemo(() => {
    return legendSelections.length === 0;
  }, [legendSelections.length]);

  // Spreads the distinct cluster labels across FULL_GRADIENT. A single cluster
  // gets the midpoint. NOTE(review): this assumes labels run 0..n-1 — confirm.
  const clusterColors: Map<number, string> = useMemo(() => {
    if (scatterPoints === null) {
      return new Map();
    }

    const clusterLabels = new Set<number>();
    for (const point of scatterPoints) {
      if (point.clusterLabel !== undefined) {
        clusterLabels.add(point.clusterLabel);
      }
    }

    const labelToColor: Map<number, string> = new Map();
    for (const label of clusterLabels) {
      labelToColor.set(
        label,
        colorForValue(
          clusterLabels.size === 1 ? 0.5 : label / (clusterLabels.size - 1),
          FULL_GRADIENT,
        ),
      );
    }

    return labelToColor;
  }, [scatterPoints]);

  // Don't slow down the UI while we're moving the mouse across points
  const setHoveredPointKeyDebounced = useDebouncedCallback(
    setHoveredPointKey,
    100,
  );
  const isUmapUserConfigSamplingEnabled = useFeatureFlag(
    "umap-user-config-sampling",
  );
  return (
    <div className="tw-w-full tw-h-full tw-flex" ref={ref}>
      <div className="tw-basis-96 tw-shrink-0 tw-border-r tw-flex tw-flex-col tw-overflow-hidden tw-bg-slate-50">
        <div className="tw-bg-slate-50 tw-p-sm tw-pl-md tw-relative">
          <div className="tw-inline-block tw-p-sm  tw-w-full">
            <LeftNavSectionTitle className="tw-mb-sm">
              View Options
            </LeftNavSectionTitle>
            <div className="tw-flex tw-flex-col tw-w-full tw-text-gray-500">
              Color By
              <MetadataSelector
                metadata={metadata ?? []}
                plates={plates}
                interestingMetadataColumns={interestingMetadataColumns}
                selectedMetadata={coloringMetadataColumn}
                onSelectMetadata={setColoringMetadataColumn}
              />
            </div>
          </div>
        </div>
        <NormalizationOptions
          metadataColumns={interestingMetadataColumns}
          normalizationColumns={normalizationColumns}
          onChangeNormalizationColumns={onChangeNormalizationColumns}
        />
        <SidebarSelectionView
          coloringMetadataColumn={coloringMetadataColumn}
          selections={selections}
          selectedPoints={selectedPoints}
          hideUnselectedPoints={hideUnselectedPoints}
          metadataColors={metadataColors}
          metadata={metadata ?? []}
          onRemoveSelection={removeSelection}
          onAddSelection={addSelection}
          setHideUnselectedPoints={setHideUnselectedPoints}
          onDownloadSelectedPoints={handleDownloadSelectedPoints}
        />

        <LabMate
          dataset={dataset}
          filterSerialized={filterSerialized}
          metadata={metadata ?? []}
          umapData={data?.successful ? data.value : undefined}
          scatterPoints={scatterPoints}
          clusterColors={clusterColors}
          selectedPoints={selectedPoints}
          plates={plates}
          features={features}
          coloringMetadataColumn={coloringMetadataColumn}
          metadataSelections={legendSelections}
          metadataColors={metadataColors}
          hoveredKey={hoveredLegendKey}
          hoveredCluster={hoveredClusterLabel}
          onSetHoverRef={handleSetSidebarHoverRef}
          onSelectPoints={handleSelectNewPoints}
          onHoverKey={setHoveredLegendKey}
          onToggleKey={onToggleLegendKey}
          onHoverCluster={setHoveredClusterLabel}
        >
          {(children: ReactElement) =>
            sidebarHoverReference && (
              <HoverPopup
                popperStyles={{
                  ...popper.styles,
                  popper: {
                    ...popper.styles.popper,
                    left: "-5px",
                    zIndex: 1, // Don't let the underlying points poke through
                  },
                }}
                popperAttributes={popper.attributes}
                popperState={popper.state}
                setPopperElement={setPopperElement}
                setPopperArrowElement={setPopperArrowElement}
              >
                {children}
              </HoverPopup>
            )
          }
        </LabMate>
      </div>
      <div ref={setPlotWrapper} className="tw-flex-1 tw-min-w-0 tw-relative">
        {plates.length > 0 && features.length > 0 && !scatterPoints && (
          <div className="tw-flex tw-items-center tw-justify-center tw-h-full">
            {umapResult && !umapResult.successful ? (
              <div className="tw-text-xl">
                Oops! Error loading feature data. Please try again later.
              </div>
            ) : (
              <PlotLoadingMessage
                className="tw-flex tw-flex-col tw-items-center tw-pb-4 tw-px-4"
                estimatedFinishTime={estimatedFinishTime}
                onOpenFilterSelector={onOpenFilterSelector}
              />
            )}
          </div>
        )}
        {scatterPoints && (
          <ScatterPlot
            data={scatterPoints}
            dataStates={scatterPointStates}
            pointSize={3.5}
            selectedPoints={selectedPoints}
            onSelectPoints={handleSelectNewPoints}
            onSelectPoint={handleSelectSinglePoint}
            clusterColors={clusterColors}
            showClusters={showClusters}
            hoverCluster={hoveredClusterLabel}
            onHoverPoint={setHoveredPointKeyDebounced}
            legend={
              coloringMetadataColumn && legend ? (
                <PlotLegend
                  legend={legend}
                  hoveredKey={hoveredLegendKey}
                  onHover={setHoveredLegendKey}
                  onClick={onToggleLegendKey}
                />
              ) : undefined
            }
          >
            {({
              pointToPx,
            }: {
              pointToPx: (x: number, y: number) => [number, number];
            }) => {
              if (hoveredPointKey) {
                const point = visiblePointsByKey[hoveredPointKey];
                if (!point) {
                  return;
                }

                const [pxX, pxY] = pointToPx(point.x, point.y);

                return (
                  <>
                    <div
                      className="tw-absolute"
                      style={{ left: pxX, top: pxY }}
                      ref={handleSetPlotHoverRef}
                    />
                    <PointHoverPopup
                      key={hoveredPointKey}
                      point={point}
                      coloringMetadataColumn={coloringMetadataColumn}
                      dataset={dataset}
                      interestingMetadataColumns={interestingMetadataColumns}
                      groupByColumn={groupByColumn}
                      metadataColors={metadataColors}
                      popperStyles={popper.styles}
                      popperAttributes={popper.attributes}
                      popperState={popper.state}
                      setPopperElement={setPopperElement}
                      setPopperArrowElement={setPopperArrowElement}
                      showCluster={showClusters}
                    />
                  </>
                );
              }
            }}
          </ScatterPlot>
        )}
        {isUmapUserConfigSamplingEnabled && metadata !== undefined ? (
          <button
            className={cx(
              "tw-absolute tw-top-4 tw-left-4",
              "tw-px-2 tw-py-1 tw-rounded-md tw-border",
              "tw-cursor-help",
              "tw-text-xs tw-text-gray-600 tw-bg-white",
            )}
            onClick={onOpenFilterSelector}
          >
            {/* The backends may have filtered out rows that were entirely NaNs so
                if the results are ready, show the length based on that, otherwise,
                fallback to the metadata which we'll have before making the fetch. */}
            Showing {scatterPoints ? scatterPoints.length : metadata.length}{" "}
            wells
          </button>
        ) : null}
      </div>
    </div>
  );
});

/**
 * Switch shown in the filter header that lets the user turn UMAP well
 * sampling on or off.
 *
 * When sampling is enabled, the switch is "on": the track is purple and the
 * thumb sits on the right. `checked` follows `samplingEnabled` directly, so
 * the accessible state (`aria-checked` / `data-state`) matches both the visual
 * state and the label text.
 *
 * @param samplingEnabled whether sampling is currently enabled
 * @param onSetSamplingEnabled called with the new desired sampling state
 */
function FilterSamplingToggle({
  samplingEnabled,
  onSetSamplingEnabled,
}: {
  samplingEnabled: boolean;
  onSetSamplingEnabled: (enabled: boolean) => void;
}) {
  return (
    <>
      <label className="tw-text-sm tw-cursor-pointer tw-flex tw-items-center tw-justify-between tw-px-2 tw-pt-3">
        {samplingEnabled ? "Showing sample of wells" : "Showing all wells"}
        <RadixSwitch.Root
          className={cx(
            "tw-h-5 tw-w-9 tw-rounded-full",
            "tw-inline-flex tw-items-center",
            samplingEnabled ? "tw-bg-purple-500" : "tw-bg-gray-control-off",
          )}
          checked={samplingEnabled}
          onCheckedChange={onSetSamplingEnabled}
        >
          <RadixSwitch.Thumb
            className={cx(
              "tw-inline-block",
              "tw-w-3 tw-h-3 tw-rounded-full",
              "tw-bg-white",
              "tw-transition-transform",
              samplingEnabled ? "tw-translate-x-5" : "tw-translate-x-1",
            )}
          />
        </RadixSwitch.Root>
      </label>
      <hr className="tw-mt-md" />
    </>
  );
}
