import cx from "classnames";
import * as d3 from "d3";
import {
  ReactNode,
  useCallback,
  useEffect,
  useMemo,
  useRef,
  useState,
} from "react";
import { Lock, Minus, Plus } from "react-feather";
import { IoSparkles } from "react-icons/io5";
import { Button } from "src/Common/Button";
import {
  CHART_AXIS_HEIGHT_PX,
  CLUSTER_CARD_HEIGHT_PX,
} from "src/PhenoFinder/constants";
import { findLeafNodes } from "src/PhenoFinder/treeUtils";
import {
  BranchDataNode,
  Clusters,
  DataNode,
  LeafDataNode,
  TreeModificationStep,
} from "src/PhenoFinder/types";
import { useUserPref } from "src/hooks/prefs";
import { DatasetId } from "src/types";
import { Tooltip } from "@spring/ui/Tooltip";
import { PopoverMessage } from "../../Common/PopoverMessage";
import { ImageSet } from "../../imaging/types";
import { usePhenoFinderContext } from "../Context";
import { ClusterCard } from "./ClusterCard";
import { useChartDimensions } from "./hooks";
import { getDefaultClusterDisplayName, getVisualizationHeight } from "./utils";

/**
 * Wraps `children` in a `PopoverMessage` that displays `contents`; used to introduce the
 * dendrogram controls on the suggested split button.
 *
 * The popover starts open unless the "hasSeenPhenofinderDendrogramPopover" user preference is
 * already set. The preference is set as soon as this component has rendered without it, and any
 * open-change event from the popover closes it.
 */
function WithIntroTooltip(props: { children: ReactNode; contents: ReactNode }) {
  const { children, contents } = props;

  const [alreadySeen, markAsSeen] = useUserPref(
    "hasSeenPhenofinderDendrogramPopover",
  );

  // Only the initial preference value decides whether the popover starts open
  const [isOpen, setIsOpen] = useState<boolean>(!alreadySeen);

  useEffect(() => {
    if (alreadySeen) {
      return;
    }
    markAsSeen(true);
  }, [alreadySeen, markAsSeen]);

  // Whatever the requested open state is, the popover only ever transitions to closed
  const closePopover = () => setIsOpen(false);

  return (
    <PopoverMessage
      isOpen={isOpen}
      onOpenChange={closePopover}
      contents={contents}
      sideOffset={0}
    >
      {children}
    </PopoverMessage>
  );
}

/**
 * Renders the cluster tree as a horizontal dendrogram: an SVG layer with the tree edges and an
 * absolutely-positioned HTML layer with one `ClusterCard` per leaf, a split (or "locked") button
 * per leaf, and a combine button per non-root internal node.
 *
 * Side effects:
 * - Dispatches `updateClusters` to the PhenoFinder context whenever the laid-out leaves change.
 * - When `newClusterNames` (from context) is non-empty, scrolls to the first new cluster (only if
 *   `shouldScrollToNewClusters` is set) and highlights the new clusters for 2 seconds.
 *
 * @param dataset Forwarded to every `ClusterCard`.
 * @param imageSet Forwarded to every `ClusterCard`.
 * @param className Extra class names applied to the root container.
 * @param data Root node of the tree to lay out.
 * @param cropSize Forwarded to every `ClusterCard`.
 * @param onModifyTree Called with a `split` step (leaf button) or `combine` step (branch button).
 * @param onHoverDataNode Called with the node under the pointer, or `null` when the pointer
 *   leaves a button (and right before a split/combine is requested).
 */
export default function Dendrogram({
  dataset,
  imageSet,
  className,
  data,
  cropSize,
  onModifyTree,
  onHoverDataNode,
}: {
  dataset: DatasetId;
  imageSet: ImageSet | null;
  className: string;
  data: DataNode;
  onModifyTree: (
    modification: TreeModificationStep,
    shouldScrollToNewClusters?: boolean,
  ) => void;
  onHoverDataNode: (node: DataNode | null) => void;

  cropSize: number;
}) {
  const [state, dispatch] = usePhenoFinderContext();
  const {
    totalNumCells,
    maxTreeLevel,
    selectedClusterName,
    newClusterNames,
    shouldScrollToNewClusters,
    hoveredDataNode,
  } = state;

  const clusterRefs = useRef<Map<string, HTMLButtonElement> | null>();
  const [shouldHighlightNewClusters, setShouldHighlightNewClusters] =
    useState(false);

  // Lazily creates the cluster-name -> DOM-node map. Declared before its users and kept
  // referentially stable so it can be listed in hook dependency arrays.
  const getClusterNameToDomNodeMap = useCallback(() => {
    if (!clusterRefs.current) {
      clusterRefs.current = new Map();
    }
    return clusterRefs.current;
  }, []);

  const scrollToClusterName = useCallback(
    (clusterName: string) => {
      const map = getClusterNameToDomNodeMap();
      const domNode = map.get(clusterName);

      if (shouldScrollToNewClusters) {
        domNode?.scrollIntoView({
          behavior: "smooth",
          block: "center",
          inline: "nearest",
        });
      }
    },
    [getClusterNameToDomNodeMap, shouldScrollToNewClusters],
  );

  useEffect(() => {
    if (newClusterNames.length === 0) {
      // The cleanup below may have cancelled the timer that would have switched the highlight
      // off, so make sure it does not stay on once there is nothing new to show.
      setShouldHighlightNewClusters(false);
      return;
    }

    // If there are multiple new clusters, a node has been split; scrolling to the first should
    // also show the second
    scrollToClusterName(newClusterNames[0]);
    setShouldHighlightNewClusters(true);

    const highlightTimeoutId = setTimeout(() => {
      setShouldHighlightNewClusters(false);
    }, 2 * 1000); // 2 seconds

    // Cancel the pending timer when the effect re-runs or the component unmounts; otherwise a
    // timer from an earlier modification would cut the newest highlight short.
    return () => {
      clearTimeout(highlightTimeoutId);
    };
  }, [newClusterNames, scrollToClusterName]);

  const handleSplitNode = useCallback(
    (node: LeafDataNode) => {
      onModifyTree({ type: "split", nodeName: node.name });
    },
    [onModifyTree],
  );

  const handleCombineAtNode = useCallback(
    (node: BranchDataNode) => {
      onModifyTree({ type: "combine", nodeName: node.name });
    },
    [onModifyTree],
  );

  const handleSelectCluster = useCallback(
    (clusterName: string) => {
      dispatch({
        type: "selectCluster",
        clusterName,
      });
    },
    [dispatch],
  );

  const hierarchy = useMemo(() => {
    return d3.hierarchy(data);
  }, [data]);

  const dimensionsConfig = useMemo(
    () => ({
      margin: {
        // 36px corresponds to the rubberband scrolling buffer in PhenoFinder component
        // 8px is padding
        top: 36 + 8,
        left: 48,
        // Leave some margin at the bottom for the bar chart axis, otherwise the last chart
        // will be covered. The bar chart y-positioning is determined by the dendrogram, so
        // we need to set this here.
        bottom: CHART_AXIS_HEIGHT_PX + 8,
      },
    }),
    [],
  );
  const [setContainerRef, chartDimensions] =
    useChartDimensions(dimensionsConfig);
  const { width, boundedWidth, margin } = chartDimensions;

  // Cards take three quarters of the bounded width; the tree gets what is left after the margin
  const [cardWidth, chartWidth] = useMemo(() => {
    const cardWidth = Math.round((boundedWidth * 3) / 4);
    return [cardWidth, boundedWidth - cardWidth - margin.left];
  }, [boundedWidth, margin]);

  const leafNodes = useMemo(() => {
    const leafNodes: LeafDataNode[] = findLeafNodes(data);
    return leafNodes;
  }, [data]);

  const [boundedHeight, height] = useMemo(() => {
    return getVisualizationHeight(leafNodes.length, margin);
  }, [leafNodes, margin]);

  // Draw dendrogram
  const [dendrogram, dendrogramLeaves, dendrogramInternalNodes] =
    useMemo(() => {
      const dendrogramGenerator = d3
        .cluster<DataNode>()
        .separation(() => 1) // Uniform separation between clusters
        .size([boundedHeight, chartWidth]);

      const dendrogram = dendrogramGenerator(hierarchy);

      const leaves = dendrogram
        .descendants()
        .filter((node) => node.data.type === "leaf")
        // Get all the leaves sorted from top to bottom for numbering
        .sort(
          (nodeA, nodeB) => nodeA.x - nodeB.x,
        ) as d3.HierarchyPointNode<LeafDataNode>[];

      const internalNodes = dendrogram
        .descendants()
        .filter(
          (node) => node.data.type !== "leaf",
        ) as d3.HierarchyPointNode<BranchDataNode>[];
      return [dendrogram, leaves, internalNodes];
    }, [hierarchy, boundedHeight, chartWidth]);

  // The laid-out node for the hovered branch, or `undefined` when nothing (or a leaf) is hovered.
  // `find` keeps the "may be missing" case visible in the type instead of asserting it away.
  const hoveredBranchD3Node = dendrogramInternalNodes.find(
    (node) =>
      hoveredDataNode?.type === "branch" &&
      node.data.name === hoveredDataNode.name,
  );

  const horizontalLinkGenerator = d3.link(d3.curveStepBefore);
  const highlightedEdges = useMemo<string[]>(() => {
    // If our node cannot split any further then we don't want to highlight any edges
    if (
      hoveredDataNode !== null &&
      maxTreeLevel !== null &&
      hoveredDataNode.level >= maxTreeLevel
    ) {
      return [];
    }
    // If we're thinking of splitting a cluster, just highlight the path to that cluster
    if (hoveredDataNode !== null && hoveredDataNode.type === "leaf") {
      return [hoveredDataNode.name];
    }

    // No branch is hovered, so there is nothing to highlight
    if (hoveredBranchD3Node === undefined) {
      return [];
    }

    // If we're thinking of combining down to an internal node, highlight the paths to all clusters
    // that are descendants of that node (`descendants()` includes the node itself, so drop it)
    return hoveredBranchD3Node
      .descendants()
      .map((node) => node.data.name)
      .filter((nodeName) => nodeName !== hoveredBranchD3Node.data.name);
  }, [hoveredDataNode, hoveredBranchD3Node, maxTreeLevel]);

  const allEdges = dendrogram.descendants().map((node) => {
    // The root has no incoming edge
    if (!node.parent) {
      return null;
    }

    return (
      <path
        className={cx(
          highlightedEdges.includes(node.data.name)
            ? "tw-stroke-primary-500"
            : "tw-stroke-slate-300",
          "tw-stroke-2",
        )}
        key={"line" + node.data.name}
        fill="none"
        d={
          // Switch x/y into the generator to draw the dendrogram horizontally
          horizontalLinkGenerator({
            source: [node.parent.y, node.parent.x],
            target: [node.y, node.x],
          }) ?? undefined
        }
      />
    );
  });

  // Publish the laid-out leaves (with their positions) to the shared PhenoFinder state
  useEffect(() => {
    dispatch({
      type: "updateClusters",
      clusters: dendrogramLeaves.map((node) => ({
        ...node.data,
        // These are swapped because the dendrogram is horizontal
        x: node.y,
        y: node.x,
      })) as Clusters,
    });
  }, [dendrogramLeaves, dispatch]);

  const clusterCards = dendrogramLeaves.map((node, i) => {
    // Wait until we're ready to draw them
    if (cardWidth === 0) {
      return;
    }

    const shouldHighlightCluster =
      highlightedEdges.includes(node.data.name) ||
      (shouldHighlightNewClusters && newClusterNames.includes(node.data.name));

    return (
      <ClusterCard
        key={node.data.name}
        ref={(domNode) => {
          // Track mounted cards by cluster name so they can be scrolled into view
          const map = getClusterNameToDomNodeMap();
          if (domNode) {
            map.set(node.data.name, domNode);
          } else {
            map.delete(node.data.name);
          }
        }}
        id={node.data.name}
        name={node.data.displayName ?? getDefaultClusterDisplayName(i)}
        dataset={dataset}
        imageSet={imageSet}
        cells={node.data.cells}
        // NOTE(review): assumes `totalNumCells` is set and non-zero whenever leaves are rendered;
        // otherwise this formats as "Infinity"/"NaN" - confirm against the context provider.
        populationPercent={(
          (node.data.cells.length / totalNumCells!) *
          100
        ).toFixed(1)}
        cropSize={cropSize}
        style={{
          // Switch x/y as related to left/top because the dendrogram is horizontal
          left: node.y,
          // Align the dendrogram path to the middle of the cluster card
          top: node.x - CLUSTER_CARD_HEIGHT_PX / 2,
        }}
        width={cardWidth}
        onClick={() => handleSelectCluster(node.data.name)}
        highlighted={shouldHighlightCluster}
        selected={selectedClusterName === node.data.name}
      />
    );
  });

  const splitButtons = dendrogramLeaves.map((node) => {
    // A leaf can only be split while it sits above the maximum tree level
    const nodeCanSplit =
      maxTreeLevel !== null && node.data.level < maxTreeLevel;
    const splittableNode = (
      <div
        key={node.data.name}
        className={cx("tw-absolute tw-group")}
        style={{
          left: node.y - 32,
          top: node.x - 12,
        }}
        onMouseOver={() => onHoverDataNode(node.data)}
        onMouseLeave={() => onHoverDataNode(null)}
      >
        {node.data.isNextSplit && hoveredDataNode?.name === node.data.name && (
          <IoSparkles className="tw-text-sm tw-mr-sm tw-text-yellow tw-group tw-absolute tw-right-4 tw-bottom-4" />
        )}
        <Button
          name="Split node"
          className={cx(
            "tw-text-xl tw-justify-center",
            node.data.name === hoveredDataNode?.name
              ? "tw-visible"
              : "tw-invisible group-hover:tw-visible",
          )}
          variant={"primary"}
          style={{
            height: 24,
            width: 24,
          }}
          noSpacing={true}
          onClick={() => {
            // It's possible this button will unmount and get into a weird state
            // To prevent that, unhover before splitting
            onHoverDataNode(null);
            handleSplitNode(node.data);
          }}
        >
          <Plus style={{ width: 18, height: 18 }} />
        </Button>
      </div>
    );
    return nodeCanSplit ? (
      node.data.isNextSplit ? (
        <WithIntroTooltip
          key={node.data.name}
          contents={
            <div className={cx("tw-max-w-[280px] tw-p-md tw-text-center")}>
              Click on the tree to split or merge clusters!
            </div>
          }
        >
          {splittableNode}
        </WithIntroTooltip>
      ) : (
        splittableNode
      )
    ) : (
      <div
        key={node.data.name}
        className={cx("tw-absolute tw-group")}
        style={{
          left: node.y - 32,
          top: node.x - 12,
        }}
        onMouseOver={() => onHoverDataNode(node.data)}
        onMouseLeave={() => onHoverDataNode(null)}
      >
        <Tooltip
          contents={
            <div className={"tw-p-xs"}>
              This cluster is too small to split further
            </div>
          }
          showArrow={true}
        >
          <div
            className={
              "tw-group tw-bg-gray-300 tw-invisible tw-rounded tw-border tw-border-gray-300 tw-outline-gray-300 tw-p-0.5 group-hover:tw-visible"
            }
          >
            <Lock
              style={{ width: 18, height: 18 }}
              className={cx(
                "tw-text-xl tw-p-0.5 tw-text-white tw-justify-center group-hover:tw-visible tw-border-gray-200 tw-invisible",
              )}
            />
          </div>
        </Tooltip>
      </div>
    );
  });

  const combineButtons = dendrogramInternalNodes.map((node) => {
    // Don't allow any actions on the root node
    if (!node.parent) {
      return;
    }
    return (
      <div
        key={node.data.name}
        className="tw-absolute tw-group"
        style={{
          left: node.y - 12,
          top: node.x - 12,
        }}
        onMouseOver={() => {
          onHoverDataNode(node.data);

          // If we mouse over, stop the new cluster highlighting early so that we don't have
          // conflicting highlights
          if (shouldHighlightNewClusters) {
            setShouldHighlightNewClusters(false);
          }
        }}
        onMouseLeave={() => {
          onHoverDataNode(null);
        }}
      >
        {node.data.isNextCombine &&
          hoveredDataNode?.name === node.data.name && (
            <IoSparkles className="tw-text-sm tw-mr-sm tw-text-yellow tw-group tw-absolute tw-right-4 tw-bottom-4" />
          )}
        <Button
          key={node.data.name}
          name="Combine at node"
          className={cx(
            "tw-text-xl tw-justify-center",
            node.data.name === hoveredDataNode?.name
              ? "tw-visible"
              : "tw-invisible group-hover:tw-visible",
          )}
          variant={"primary"}
          style={{
            height: 24,
            width: 24,
          }}
          noSpacing={true}
          onClick={() => {
            // It's possible this button will unmount and get into a weird state
            // To prevent that, unhover before combining
            onHoverDataNode(null);
            handleCombineAtNode(node.data);
          }}
        >
          <Minus size={16} />
        </Button>
      </div>
    );
  });

  return (
    <div className={cx(className, "tw-relative")} ref={setContainerRef}>
      {/* SVG layer: tree edges only */}
      <svg width={width} height={height}>
        <g
          width={chartWidth}
          height={boundedHeight}
          transform={`translate(${[margin.left, margin.top].join(",")})`}
        >
          {allEdges}
        </g>
      </svg>

      {/* HTML layer: cards and buttons, offset by the same margins as the SVG group */}
      <div className="tw-absolute tw-w-full tw-h-full">
        <div
          className="tw-relative"
          style={{
            width: boundedWidth,
            height: boundedHeight,
            transform: `translateX(${margin.left}px) translateY(${margin.top}px)`,
          }}
        >
          {clusterCards}
          {splitButtons}
          {combineButtons}
        </div>
      </div>
    </div>
  );
}
