import cx from "classnames";
import KDBush from "kdbush";
import lodashMax from "lodash.max";
import lodashMin from "lodash.min";
import React, {
  ReactNode,
  useCallback,
  useEffect,
  useMemo,
  useRef,
  useState,
} from "react";
import { Disc, Minus, Plus } from "react-feather";
import { DeprecatedButton } from "../../Common/DeprecatedButton";
import { CLICK_TOLERANCE, LEGEND_MARGIN } from "./constants";
import { useRenderLoop } from "./hooks";
import { BackgroundRenderer } from "./renderBackground";
import { ClusterRenderer } from "./renderClusters";
import { OverlayRenderer } from "./renderOverlay";
import {
  Bounds,
  DragPoint,
  RenderOptions,
  Renderer,
  RendererOptions,
  ScatterPlotPoint,
  ScatterPlotPointState,
} from "./types";
import {
  createGetClosestPointFromDifferentCluster,
  createPointsIndex,
  enclosedPoints,
  getInitialBounds,
  getVisiblePointsForBounds,
  moveBounds,
  toPixel,
  zoomBounds,
} from "./util";

/** `MouseEvent.button` value for the primary (normally left) mouse button. */
const LEFT_MOUSE_BUTTON = 0;
/** `MouseEvent.button` value for the secondary (normally right) mouse button; used for panning. */
const RIGHT_MOUSE_BUTTON = 2;

/** Props accepted by the `ScatterPlot` component. */
interface Props {
  /** Points to plot. */
  data: ScatterPlotPoint[];
  /** Per-point render state, looked up by each point's `key`. */
  dataStates: Record<string, ScatterPlotPointState>;
  /** Point size used when rendering and when hit-testing the hovered point. */
  pointSize: number;
  /** Keys of the currently selected points. */
  selectedPoints: Set<string>;
  /** Cluster colors — presumably keyed by cluster id; TODO confirm against the cluster renderer. */
  clusterColors: Map<number, string>;
  /** Whether the cluster background layer is drawn. */
  showClusters: boolean;
  /** Cluster to highlight (presumably a cluster id), or `null` for none. */
  hoverCluster: number | null;
  /** Invoked with the hovered point's key, or `null` once nothing is hovered. */
  onHoverPoint?: (key: string | null) => void;
  /** Invoked with the keys of the points enclosed by a drag-selection box. */
  onSelectPoints?: (keys: Set<string>) => void;
  /** Invoked with the key of a single clicked point. */
  onSelectPoint?: (key: string) => void;
  /** Optional legend, rendered in the top-right corner of the plot. */
  legend?: ReactNode;
  /** Render prop for extra content layered over the plot. */
  children: (renderArgs: {
    /** Converts data coordinates into pixel coordinates inside the plot. */
    pointToPx: (x: number, y: number) => [number, number];
  }) => ReactNode;
}

/**
 * Interactive, canvas-backed scatter plot.
 *
 * Points are drawn on stacked canvases (an optional cluster background, the
 * background points, and an overlay for hover/selection/drag box) so the cheap
 * overlay can be redrawn without repainting every point.
 *
 * Interaction:
 * - Moving the mouse hovers the closest visible point within `pointSize * 1.2`
 *   pixels and reports it through `onHoverPoint`.
 * - A non-right-button drag draws a selection box. Releasing a left-button drag
 *   selects the hovered point when the drag is shorter than `CLICK_TOLERANCE`,
 *   otherwise all enclosed points (`onSelectPoints`). A click without dragging
 *   selects the hovered point (`onSelectPoint`).
 * - Right-button drag and the mouse wheel pan; Ctrl + wheel zooms around the cursor.
 * - Arrow keys pan and `=` / `-` zoom, but only while focus is on `<body>`.
 * - The side buttons zoom in, reset to the initial view, and zoom out.
 *
 * `children` is a render prop that receives `pointToPx`, which maps data
 * coordinates to pixel coordinates within the plot container.
 */
export function ScatterPlot({
  children,
  data,
  dataStates,
  legend,
  clusterColors,
  showClusters,
  onHoverPoint,
  onSelectPoint,
  onSelectPoints,
  pointSize,
  selectedPoints,
  hoverCluster,
}: Props) {
  const refContainer = useRef<HTMLDivElement | null>(null);
  const refClusterBackground = useRef<HTMLCanvasElement>(null);
  const refBackground = useRef<HTMLCanvasElement>(null);
  const refOverlay = useRef<HTMLCanvasElement>(null);
  const refResizeObserver = useRef<ResizeObserver | null>(null);
  const refSelectedPoints = useRef<Set<string>>(selectedPoints);
  const refClusterHulls = useRef<Map<number, ScatterPlotPoint[]> | null>(null);
  const [isMouseOverPlot, setIsMouseOverPlot] = useState(false);
  const [currentDraggingButton, setCurrentDraggingButton] = useState<
    number | null
  >(null);
  const [defaultBounds, setDefaultBounds] = useState<Bounds | null>(null);

  // Mirror the selection into a ref so the render loop and the visible-point
  // computations read the latest value without re-creating callbacks.
  useEffect(() => {
    refSelectedPoints.current = selectedPoints;
  }, [selectedPoints]);

  // Fast spatial index for pixel coordinates of visible points, used for finding the hovered point
  // Resets to null whenever bounds/visible points change, and is instantiated on first hover
  const visiblePointsIndex = useRef<KDBush | null>(null);
  const resetVisiblePointsIndex = useCallback(() => {
    visiblePointsIndex.current = null;
  }, []);

  // Render the different parts of the scatter plot as different layers
  // which we stack on top of each other. This lets us quickly render
  // an "overlay" (with highlights, hovers, selection box) without having
  // to re-render every single other point.
  const renderers = useMemo<Renderer<RenderOptions, RendererOptions>[]>(
    () => [
      ...(showClusters ? [ClusterRenderer(refClusterBackground)] : []),
      BackgroundRenderer(refBackground),
      OverlayRenderer(refOverlay, refSelectedPoints, refClusterHulls),
    ],
    [showClusters],
  );

  const [renderOptions, updateRenderOptions] = useRenderLoop<RenderOptions>(
    {
      hoverPoint: null,
      hoverCluster: null,
      clusterColors,
      points: [],
      pointStates: {},
      visiblePoints: [],
      pointsIndex: null,
      getClosestPointFromDifferentCluster: () => null,
      pointSize,
      pointToPx(x, y) {
        return [x, y];
      },
      dragStart: null,
      dragEnd: null,
      canvasWidth: 0,
      canvasHeight: 0,
      viewBounds: {
        minX: Infinity,
        maxX: -Infinity,
        minY: Infinity,
        maxY: -Infinity,
      },
      initializedData: false,
      initializedCanvas: false,
    },
    renderers,
    useCallback(
      (options, update) => {
        // Once both the data and the canvas size are known (whichever arrives
        // last), compute the initial view bounds and the points visible in them.
        const afterUpdate: RenderOptions = { ...options, ...update };
        if (
          !(options.initializedData && options.initializedCanvas) &&
          afterUpdate.initializedData &&
          afterUpdate.initializedCanvas
        ) {
          const initialBounds = getInitialBounds(afterUpdate);
          resetVisiblePointsIndex();
          return {
            ...initialBounds,
            visiblePoints: getVisiblePointsForBounds(
              refSelectedPoints.current,
              {
                ...afterUpdate,
                ...initialBounds,
              },
            ),
          };
        } else {
          return {};
        }
      },
      [resetVisiblePointsIndex],
    ),
  );

  // Props below are only handed to useRenderLoop above as part of its first
  // argument (presumably initial state), so later changes are forwarded here.
  useEffect(() => {
    updateRenderOptions({ hoverCluster });
  }, [hoverCluster, updateRenderOptions]);

  useEffect(() => {
    updateRenderOptions({ clusterColors });
  }, [clusterColors, updateRenderOptions]);

  /** Maps data coordinates to pixel coordinates using the current view bounds and canvas size. */
  const pointToPx = useCallback(
    (x: number, y: number): [number, number] => {
      const {
        viewBounds: { minX, maxX, minY, maxY },
        canvasWidth,
        canvasHeight,
      } = renderOptions.current;

      // The y axis is flipped: larger data y values are drawn higher up.
      return [
        (canvasWidth * (x - minX)) / (maxX - minX),
        canvasHeight * (1 - (y - minY) / (maxY - minY)),
      ];
    },
    [renderOptions],
  );

  /** Stores the hovered point and notifies `onHoverPoint`, but only when it actually changes. */
  const updateHoverPoint = useCallback(
    (hoverPoint: ScatterPlotPoint | null) => {
      if (hoverPoint !== renderOptions.current.hoverPoint) {
        updateRenderOptions({ hoverPoint });
        onHoverPoint?.(hoverPoint?.key ?? null);
      }
    },
    [onHoverPoint, renderOptions, updateRenderOptions],
  );

  // Sync data-derived render options whenever the data, point states or point size change.
  useEffect(() => {
    const xValues = data.map(({ x }) => x);
    const yValues = data.map(({ y }) => y);
    const minX = lodashMin(xValues) ?? Infinity;
    const maxX = lodashMax(xValues) ?? -Infinity;
    const minY = lodashMin(yValues) ?? Infinity;
    const maxY = lodashMax(yValues) ?? -Infinity;

    const bounds: Bounds = { minX, maxX, minY, maxY };
    let { points, pointsIndex, getClosestPointFromDifferentCluster } =
      renderOptions.current;

    // These computed variables never change for a given dataset.
    // NOTE(review): after the first initialization, later `data` changes do not
    // replace `points`/`pointsIndex`; presumably callers remount the component
    // for a new dataset — confirm.
    if (!renderOptions.current.initializedData) {
      setDefaultBounds(bounds);

      points = data;
      // Initialize a KD-tree as a fast static spatial index for the points
      pointsIndex = createPointsIndex(points, renderOptions.current);
      getClosestPointFromDifferentCluster =
        createGetClosestPointFromDifferentCluster(points);
    }

    const currentCanvasBounds = {
      canvasWidth: renderOptions.current.canvasWidth,
      canvasHeight: renderOptions.current.canvasHeight,
      viewBounds: renderOptions.current.initializedData
        ? renderOptions.current.viewBounds
        : bounds,
    };

    // Get visible points based on the canvas bounds
    const pointStates = dataStates;
    const visiblePoints = getVisiblePointsForBounds(refSelectedPoints.current, {
      points,
      pointStates,
      pointsIndex,
      pointSize,
      ...currentCanvasBounds,
    });
    resetVisiblePointsIndex();

    updateRenderOptions({
      points,
      pointStates,
      visiblePoints,
      getClosestPointFromDifferentCluster,
      pointsIndex,
      // Keep the render options' point size in sync with the prop; renderers,
      // hover hit-testing and zoom/pan visibility all read it from there.
      pointSize,
      pointToPx,
      ...(!renderOptions.current.initializedData
        ? {
            initializedData: true,
            viewBounds: bounds,
          }
        : {}),
    });
    updateHoverPoint(null);
  }, [
    data,
    dataStates,
    pointSize,
    pointToPx,
    renderOptions,
    resetVisiblePointsIndex,
    updateHoverPoint,
    updateRenderOptions,
  ]);

  /** Restores the view computed from the data's bounds on first initialization. */
  const handleResetBounds = useCallback(() => {
    if (defaultBounds === null) {
      return;
    }

    const newBounds = getInitialBounds({
      viewBounds: defaultBounds,
      canvasWidth: renderOptions.current.canvasWidth,
      canvasHeight: renderOptions.current.canvasHeight,
    });
    resetVisiblePointsIndex();

    updateRenderOptions({
      ...newBounds,
      visiblePoints: getVisiblePointsForBounds(refSelectedPoints.current, {
        ...renderOptions.current,
        ...newBounds,
      }),
    });
  }, [
    updateRenderOptions,
    defaultBounds,
    renderOptions,
    resetVisiblePointsIndex,
  ]);

  /**
   * Callback ref for the plot container. React passes `null` when the element
   * detaches, so the parameter is nullable.
   */
  const onContainer = useCallback(
    (el: HTMLDivElement | null) => {
      refContainer.current = el;

      // Tear down any previous observer before (re)attaching, so at most one
      // observer is ever alive.
      refResizeObserver.current?.disconnect();
      refResizeObserver.current = null;

      if (el === null) {
        return;
      }

      // Each resize reports the container size and marks the canvas as
      // initialized; the render-loop reducer above computes the initial view
      // once both data and canvas are initialized.
      const observer = new ResizeObserver(() => {
        updateRenderOptions({
          canvasWidth: el.offsetWidth,
          canvasHeight: el.offsetHeight,
          initializedCanvas: true,
        });
      });
      observer.observe(el);
      refResizeObserver.current = observer;
    },
    [updateRenderOptions],
  );

  /** Zooms by `delta` around `origin` (pixel coordinates), defaulting to the plot centre. */
  const zoomPlot = useCallback(
    (delta: number, origin?: DragPoint) => {
      const options = renderOptions.current;

      // If no origin is provided, use the center of the plot
      const zoomPoint = origin ?? {
        x: options.canvasWidth / 2,
        y: options.canvasHeight / 2,
      };

      const newBounds = zoomBounds(options, delta, zoomPoint);
      resetVisiblePointsIndex();
      updateRenderOptions({
        ...newBounds,
        visiblePoints: getVisiblePointsForBounds(refSelectedPoints.current, {
          ...renderOptions.current,
          ...newBounds,
        }),
      });

      // If there's a hovered point, clear it while zooming
      if (renderOptions.current.hoverPoint !== null) {
        updateHoverPoint(null);
      }
    },
    [
      renderOptions,
      updateRenderOptions,
      updateHoverPoint,
      resetVisiblePointsIndex,
    ],
  );

  /** Shifts the view by a pixel delta and recomputes the visible points. */
  const panPlot = useCallback(
    (delta: { deltaX: number; deltaY: number }) => {
      const newBounds = moveBounds(renderOptions.current, delta);
      resetVisiblePointsIndex();
      updateRenderOptions({
        ...newBounds,
        visiblePoints: getVisiblePointsForBounds(refSelectedPoints.current, {
          ...renderOptions.current,
          ...newBounds,
        }),
      });

      // If there's a hovered point, clear it while panning
      if (renderOptions.current.hoverPoint !== null) {
        updateHoverPoint(null);
      }
    },
    [
      renderOptions,
      updateRenderOptions,
      updateHoverPoint,
      resetVisiblePointsIndex,
    ],
  );

  /** Clears any in-progress selection box and the tracked mouse button. */
  const onCancelDrag = useCallback<
    React.MouseEventHandler<HTMLDivElement>
  >(() => {
    updateRenderOptions({
      dragStart: null,
      dragEnd: null,
    });
    setCurrentDraggingButton(null);
  }, [updateRenderOptions]);

  // Block page scrolling while the pointer is over the plot. This is done on
  // the document with `passive: false` so that `preventDefault` is honoured.
  useEffect(() => {
    const onWheel = (e: WheelEvent) => {
      if (isMouseOverPlot) {
        e.preventDefault();
      }
    };

    document.addEventListener("wheel", onWheel, { passive: false });

    return () => {
      document.removeEventListener("wheel", onWheel);
    };
  }, [isMouseOverPlot]);

  const handleWheel = useCallback<React.WheelEventHandler<HTMLDivElement>>(
    (e) => {
      // Ctrl + wheel zooms around the cursor; a plain wheel pans by the wheel deltas.
      if (e.ctrlKey) {
        zoomPlot(e.deltaY, {
          x: e.nativeEvent.offsetX,
          y: e.nativeEvent.offsetY,
        });
      } else {
        panPlot(e);
      }
    },
    [panPlot, zoomPlot],
  );

  const handleContextMenu = useCallback<React.MouseEventHandler<HTMLElement>>(
    (e) => {
      // Don't open the context menu when right-clicking on the plot
      // Right click and drag is used for panning the plot instead
      if (isMouseOverPlot) {
        e.preventDefault();
      }
    },
    [isMouseOverPlot],
  );

  const handleMouseDown = useCallback<React.MouseEventHandler<HTMLDivElement>>(
    (e) => {
      // If a different button was already held down, cancel it to prevent getting
      // into a weird state
      if (currentDraggingButton !== null) {
        onCancelDrag(e);
      }

      // If currently showing a hover popup, close it to prevent getting into a weird state
      if (
        e.button === RIGHT_MOUSE_BUTTON &&
        renderOptions.current.hoverPoint !== null
      ) {
        updateHoverPoint(null);
      }

      if (e.button !== RIGHT_MOUSE_BUTTON) {
        // Right click and drag pans the plot, but otherwise we render a selection box
        updateRenderOptions({
          dragStart: { x: e.nativeEvent.offsetX, y: e.nativeEvent.offsetY },
          dragEnd: null,
        });
      }

      setCurrentDraggingButton(e.button);
    },
    [
      currentDraggingButton,
      onCancelDrag,
      renderOptions,
      updateHoverPoint,
      updateRenderOptions,
    ],
  );

  const handleMouseMove = useCallback<React.MouseEventHandler<HTMLElement>>(
    (e) => {
      // The "mouse over plot" flag is raised here rather than on mouse-enter,
      // which also covers a page that loads with the cursor already over the plot.
      if (!isMouseOverPlot) {
        setIsMouseOverPlot(true);
      }

      const { pointSize, visiblePoints, dragStart } = renderOptions.current;

      // Dragging with right mouse button should pan the plot
      if (currentDraggingButton === RIGHT_MOUSE_BUTTON) {
        panPlot({
          deltaX: -e.movementX,
          deltaY: -e.movementY,
        });
        return;
      }

      // While a selection box is being drawn, just extend it.
      if (dragStart) {
        updateRenderOptions({
          dragEnd: { x: e.nativeEvent.offsetX, y: e.nativeEvent.offsetY },
        });
        return;
      }

      const pxX = e.nativeEvent.offsetX;
      const pxY = e.nativeEvent.offsetY;

      // Instantiate the points index if needed (it resets whenever visible points change)
      const index = (visiblePointsIndex.current ??= createPointsIndex(
        visiblePoints,
        renderOptions.current,
        { pixel: true },
      ));

      const candidates = index
        .within(pxX, pxY, pointSize * 1.2)
        .map((i) => visiblePoints[i]);

      // Pick the candidate with the smallest squared pixel distance to the cursor.
      let closest: ScatterPlotPoint | null = null;
      let minDist2 = Infinity;
      for (const pt of candidates) {
        const { x, y } = toPixel(pt, renderOptions.current);
        const dist2 = Math.pow(x - pxX, 2) + Math.pow(y - pxY, 2);
        if (dist2 < minDist2) {
          minDist2 = dist2;
          closest = pt;
        }
      }

      if (closest !== renderOptions.current.hoverPoint) {
        updateHoverPoint(closest);
      }
    },
    [
      currentDraggingButton,
      isMouseOverPlot,
      renderOptions,
      panPlot,
      updateHoverPoint,
      updateRenderOptions,
    ],
  );

  // Keyboard navigation; ignored unless the key event targets <body>, so typing
  // in inputs elsewhere on the page is unaffected.
  useEffect(() => {
    const onKeyDown = (e: KeyboardEvent) => {
      if ((e.target as Element).tagName !== "BODY") {
        return;
      }

      switch (e.key) {
        case "ArrowUp":
          panPlot({ deltaX: 0, deltaY: -50 });
          break;
        case "ArrowDown":
          panPlot({ deltaX: 0, deltaY: 50 });
          break;
        case "ArrowLeft":
          panPlot({ deltaX: -50, deltaY: 0 });
          break;
        case "ArrowRight":
          panPlot({ deltaX: 50, deltaY: 0 });
          break;
        case "=":
          zoomPlot(-10);
          break;
        case "-":
          zoomPlot(10);
          break;
      }
    };

    document.addEventListener("keydown", onKeyDown);

    return () => {
      document.removeEventListener("keydown", onKeyDown);
    };
  }, [panPlot, zoomPlot]);

  const handleMouseUp = useCallback<React.MouseEventHandler<HTMLDivElement>>(
    (e) => {
      const { dragStart, dragEnd, hoverPoint } = renderOptions.current;

      if (currentDraggingButton === LEFT_MOUSE_BUTTON && dragStart && dragEnd) {
        // A drag shorter than CLICK_TOLERANCE over a hovered point counts as a click.
        if (
          hoverPoint &&
          Math.sqrt(
            Math.pow(dragStart.x - dragEnd.x, 2) +
              Math.pow(dragStart.y - dragEnd.y, 2),
          ) < CLICK_TOLERANCE
        ) {
          onSelectPoint?.(hoverPoint.key);
        } else {
          onSelectPoints?.(enclosedPoints(renderOptions.current));
        }
      } else if (hoverPoint) {
        onSelectPoint?.(hoverPoint.key);
      }
      onCancelDrag(e);
    },
    [
      currentDraggingButton,
      onCancelDrag,
      onSelectPoint,
      onSelectPoints,
      renderOptions,
    ],
  );

  const handleMouseLeave = useCallback<React.MouseEventHandler<HTMLDivElement>>(
    (e) => {
      setIsMouseOverPlot(false);
      onCancelDrag(e);
    },
    [onCancelDrag, setIsMouseOverPlot],
  );

  const handleMouseLeaveContainer = useCallback<
    React.MouseEventHandler<HTMLDivElement>
  >(() => {
    if (renderOptions.current.hoverPoint !== null) {
      updateHoverPoint(null);
    }
  }, [renderOptions, updateHoverPoint]);

  // Whether any point is highlighted; used to dim the background layer.
  const anyHighlights = useMemo(
    () => data.some((pt) => dataStates[pt.key].highlighted),
    [data, dataStates],
  );

  return (
    <>
      <div className="tw-absolute tw-right-4 tw-h-full tw-flex tw-flex-col tw-justify-center tw-items-center">
        <DeprecatedButton
          onClick={() => zoomPlot(-10)}
          className="tw-z-10 tw-bg-white"
        >
          <Plus size={16} />
        </DeprecatedButton>
        <DeprecatedButton
          onClick={handleResetBounds}
          className="tw-z-10 tw-bg-white tw-text-gray-300"
        >
          <Disc size={16} />
        </DeprecatedButton>
        <DeprecatedButton
          onClick={() => zoomPlot(10)}
          className="tw-z-10 tw-bg-white"
        >
          <Minus size={16} />
        </DeprecatedButton>
      </div>
      <div
        className={cx(
          "tw-w-full tw-h-full tw-min-h-[32px] tw-relative tw-cursor-[crosshair]",
        )}
        onMouseLeave={handleMouseLeaveContainer}
        ref={onContainer}
      >
        <canvas
          className={cx(
            "tw-absolute tw-inset-0 tw-opacity-20",
            !showClusters && "tw-hidden",
          )}
          ref={refClusterBackground}
        />
        <canvas
          className={cx(
            "tw-absolute tw-inset-0",
            selectedPoints.size > 0
              ? "tw-opacity-20"
              : anyHighlights && "tw-opacity-30",
          )}
          ref={refBackground}
        />
        <canvas className="tw-absolute tw-inset-0" ref={refOverlay} />
        {legend && (
          <div
            className={cx(
              "tw-absolute tw-z-10",
              currentDraggingButton !== null && "tw-pointer-events-none",
            )}
            style={{ top: LEGEND_MARGIN, right: LEGEND_MARGIN }}
          >
            {legend}
          </div>
        )}
        {/* A div that sits on top of all of the canvases, to handle mouse events */}
        <div
          className="tw-absolute tw-inset-0"
          onWheel={handleWheel}
          onMouseMove={handleMouseMove}
          onMouseDown={handleMouseDown}
          onMouseUp={handleMouseUp}
          onMouseLeave={handleMouseLeave}
          onContextMenu={handleContextMenu}
        ></div>
        {/* TODO(you): Fix this no-unnecessary-condition rule violation */}
        {/* eslint-disable-next-line @typescript-eslint/no-unnecessary-condition */}
        {children?.({
          pointToPx: renderOptions.current.pointToPx,
        })}
      </div>
    </>
  );
}
