import cx from "classnames";
import * as d3 from "d3";
import { memo, useMemo } from "react";
import {
  CHART_AXIS_HEIGHT_PX,
  CHART_HEIGHT_PX,
  METADATA_LEGEND_WIDTH_PX,
} from "src/PhenoFinder/constants";
import { ClusterMetadata, Clusters } from "src/PhenoFinder/types";
import { MetadataBars } from "./MetadataBars";
import { useChartDimensions, useDataStacks } from "./hooks";
import { getVisualizationHeight } from "./utils";

/** Number of ticks requested from the x scale; shared by the grid lines and the axis labels. */
const X_TICK_COUNT = 5;

/**
 * Renders one `MetadataBars` row per cluster over vertical percentage grid lines, plus a
 * fixed-position axis panel showing the percent tick labels and a title naming the selected
 * metadata column.
 *
 * @param className - Optional extra class names for the outer container.
 * @param clusters - Clusters to draw. Each cluster's `name` is a y-scale domain value and its
 *   `y` (shifted up by half of `CHART_HEIGHT_PX`) is the matching range value.
 * @param data - Metadata rows; stacked via `useDataStacks`, and their count is passed to
 *   `MetadataBars` as `numTotalRows`.
 * @param colorScale - Ordinal color scale passed through to `MetadataBars`.
 * @param selectedMetadataColumn - Column name shown (quoted) in the axis title.
 * @param shouldHighlightMeaningfulMetadata - Passed through to `MetadataBars`.
 */
function _MetadataCharts({
  className,
  clusters,
  data,
  colorScale,
  selectedMetadataColumn,
  shouldHighlightMeaningfulMetadata,
}: {
  className?: string;
  clusters: Clusters;
  data: ClusterMetadata[];
  colorScale: d3.ScaleOrdinal<string, string, string>;
  selectedMetadataColumn: string;
  shouldHighlightMeaningfulMetadata: boolean;
}) {
  const stacks = useDataStacks(data);
  const dimensionsConfig = useMemo(
    () => ({
      margin: {
        // 36px corresponds to the rubberband scrolling buffer in PhenoFinder component
        // 8px is padding
        top: 36 + 8,
        left: 24,
        right: METADATA_LEGEND_WIDTH_PX,
      },
    }),
    [],
  );
  const [setContainerRef, chartDimensions] =
    useChartDimensions(dimensionsConfig);
  const { width, boundedWidth, margin } = chartDimensions;

  // Only the cluster count feeds the height calculation, so depend on that rather than on
  // the identity of the `clusters` array.
  const numClusters = clusters.length;
  const [boundedHeight, height] = useMemo(
    () => getVisualizationHeight(numClusters, margin),
    [numClusters, margin],
  );

  const xScale = useMemo(
    () => d3.scaleLinear().domain([0, 1]).range([0, boundedWidth]),
    [boundedWidth],
  );
  const yScale = useMemo(
    () =>
      d3
        .scaleOrdinal<number>()
        .domain(clusters.map((cluster) => cluster.name))
        .range(clusters.map((cluster) => cluster.y - CHART_HEIGHT_PX / 2)),
    [clusters],
  );

  // Grid lines and axis labels are drawn from the same tick values so they always line up.
  const xTicks = useMemo(() => xScale.ticks(X_TICK_COUNT), [xScale]);
  // d3's "%" tick formatter derives the precision from the tick step, which avoids
  // floating-point artifacts such as "30.000000000000004%" that `value * 100` can produce.
  const formatTick = useMemo(
    () => xScale.tickFormat(X_TICK_COUNT, "%"),
    [xScale],
  );

  const grid = useMemo(
    () =>
      xTicks.map((value) => (
        <g key={value}>
          <line
            className="tw-stroke-slate-300"
            x1={xScale(value)}
            x2={xScale(value)}
            y1={0}
            y2={height + CHART_AXIS_HEIGHT_PX}
            opacity={0.8}
          />
        </g>
      )),
    [height, xScale, xTicks],
  );

  const axis = useMemo(
    () => (
      <svg style={{ width }}>
        {xTicks.map((value) => (
          <g key={value}>
            <text
              className="tw-fill-slate-500"
              // The axis isn't constrained to the bounded width, so drawn elements need to be
              // shifted over by the left margin
              x={xScale(value) + margin.left}
              y={8}
              textAnchor="middle"
              // `dominant-baseline` is the property that vertically aligns a <text> element;
              // Firefox ignores `alignment-baseline` on <text>.
              dominantBaseline="central"
              fontSize={14}
            >
              {formatTick(value)}
            </text>
          </g>
        ))}

        <g>
          <text
            className="tw-fill-slate-700 tw-font-bold"
            x={boundedWidth / 2 + margin.left}
            y={36}
            textAnchor="middle"
            dominantBaseline="central"
            fontSize={14}
          >
            Percent of cells in each cluster by "{selectedMetadataColumn}"
          </text>
        </g>
      </svg>
    ),
    [
      xScale,
      xTicks,
      formatTick,
      width,
      boundedWidth,
      margin,
      selectedMetadataColumn,
    ],
  );

  // The fixed panels under the charts span the bounded width plus the left margin on both sides.
  const axisPanelWidth = boundedWidth + margin.left * 2;

  // The axis sticks to the bottom of the page if there are enough charts that they scroll,
  // or renders close to the bottom of the last chart if not. The check subtracts the full top
  // margin, but the offset itself only subtracts 8px, because we don't want to render that
  // much space between the last chart and the axis.
  // NOTE(review): assumes `chartDimensions.height` is the measured container height — confirm
  // against `useChartDimensions`.
  const remainingSpace =
    chartDimensions.height - height - CHART_AXIS_HEIGHT_PX - margin.top;
  const axisBottom =
    remainingSpace > 0
      ? chartDimensions.height - height - CHART_AXIS_HEIGHT_PX - 8
      : 0;

  return (
    <div
      className={cx(className, "tw-relative tw-w-full tw-h-full")}
      ref={setContainerRef}
    >
      {/* Draw the grid lines into the axis for rubberband scrolling continuity */}
      <svg width={width} height={height + CHART_AXIS_HEIGHT_PX}>
        <g transform={`translate(${[margin.left, 0].join(",")})`}>{grid}</g>
      </svg>

      <div className="tw-absolute tw-w-full tw-h-full">
        <div
          className="tw-relative"
          style={{
            width: boundedWidth,
            height: boundedHeight,
            transform: `translateX(${margin.left}px) translateY(${margin.top}px)`,
          }}
        >
          <MetadataBars
            stacks={stacks}
            xScale={xScale}
            yScale={yScale}
            colorScale={colorScale}
            numTotalRows={data.length}
            shouldHighlightMeaningfulMetadata={
              shouldHighlightMeaningfulMetadata
            }
          />
        </div>
      </div>

      {/* A colored div under the axis so that grid lines won't scroll below it */}
      <div
        className={cx("tw-fixed tw-w-full", "tw-bg-slate-100")}
        style={{
          width: axisPanelWidth,
          height: chartDimensions.height - height - margin.top,
          bottom: 0,
        }}
      />

      <div
        className={cx(
          "tw-fixed tw-w-full",
          "tw-pt-sm tw-rounded tw-shadow-sm",
          "tw-bg-white tw-border tw-border-slate-300",
        )}
        style={{
          width: axisPanelWidth,
          height: CHART_AXIS_HEIGHT_PX,
          bottom: axisBottom,
        }}
      >
        {axis}
      </div>
    </div>
  );
}

/** Memoized `_MetadataCharts`; re-renders only when its props change (shallow comparison). */
export const MetadataCharts = memo(_MetadataCharts);
