import cx from "classnames";
import * as d3 from "d3";
import { useCallback, useMemo } from "react";
import { CHART_HEIGHT_PX } from "src/PhenoFinder/constants";
import { ClusterMetadata } from "src/PhenoFinder/types";
import { usePhenoFinderContext } from "../Context";
import { MetadataBars } from "./MetadataBars";
import { useChartDimensions, useDataStacks } from "./hooks";

/**
 * Renders the metadata bars for a single group (`displayName`) together with a
 * 0–100% x-axis and a title describing what the bars show.
 *
 * Layout: the bars (`MetadataBars`) are rendered inside an absolutely
 * positioned div offset by the chart margins, and the axis labels/title are
 * drawn in an SVG after it whose content group is offset by the same margins,
 * so both layers share one coordinate system for the bounded chart area.
 *
 * @param className - Extra classes applied to the outer container.
 * @param data - Metadata rows, turned into stacks via `useDataStacks`.
 * @param colorScale - Ordinal color scale forwarded to `MetadataBars`.
 * @param shouldHighlightMeaningfulMetadata - Forwarded to `MetadataBars`.
 * @param numTotalRows - Forwarded to `MetadataBars`.
 * @param displayName - Name of the group, used in the axis title.
 */
export function SingleMetadataChart({
  className,
  data,
  colorScale,
  shouldHighlightMeaningfulMetadata,
  numTotalRows,
  displayName,
}: {
  className?: string;
  data: ClusterMetadata[];
  colorScale: d3.ScaleOrdinal<string, string, string>;
  shouldHighlightMeaningfulMetadata: boolean;
  numTotalRows: number;
  displayName: string;
}) {
  const [state] = usePhenoFinderContext();

  const { selectedVisualizationColumn } = state;
  const stacks = useDataStacks(data);

  const [setContainerRef, chartDimensions] = useChartDimensions();
  const { width, boundedWidth, height, boundedHeight, margin } =
    chartDimensions;

  // Fractions in [0, 1] are mapped across the inner (bounded) chart width.
  const xScale = useMemo(
    () => d3.scaleLinear().domain([0, 1]).range([0, boundedWidth]),
    [boundedWidth],
  );
  const columnDisplayName = selectedVisualizationColumn?.name ?? "";

  // Since we only ever show one chart, we can use a constant y scale
  const yScale = useCallback(() => 0, []);

  const axis = useMemo(() => {
    const ticks = xScale.ticks(5);
    const lastTickIndex = ticks.length - 1;

    return (
      <>
        {ticks.map((value, i) => (
          <g key={value}>
            <text
              className="tw-fill-slate-800"
              x={xScale(value)}
              y={CHART_HEIGHT_PX + 16}
              // Anchor the outermost labels inward so they stay within the
              // bar extent. SVG's keyword is "start" (not "left"), and the
              // last tick is found from the actual tick count.
              textAnchor={
                i === 0 ? "start" : i === lastTickIndex ? "end" : "middle"
              }
              // dominant-baseline is the baseline property that applies to
              // <text>; alignment-baseline is not honored there everywhere.
              dominantBaseline="central"
              fontSize={14}
            >
              {/* Round to avoid floating-point artifacts in the label. */}
              {`${Math.round(value * 100)}%`}
            </text>
          </g>
        ))}

        <g>
          <text
            className="tw-fill-slate-800 tw-font-bold"
            // The parent <g> is already translated by margin.left, so the
            // centre of the bounded area is boundedWidth / 2 in local coords.
            x={boundedWidth / 2}
            y={CHART_HEIGHT_PX + 42}
            textAnchor="middle"
            dominantBaseline="central"
            fontSize={14}
          >
            Percent of cells in {displayName} by "{columnDisplayName}"
          </text>
        </g>
      </>
    );
  }, [xScale, boundedWidth, displayName, columnDisplayName]);

  return (
    <div
      className={cx(className, "tw-relative tw-w-full")}
      style={{ height: CHART_HEIGHT_PX * 2.5 }}
      ref={setContainerRef}
    >
      <div className="tw-absolute tw-w-full tw-h-full">
        <div
          className="tw-relative"
          style={{
            width: boundedWidth,
            height: boundedHeight,
            transform: `translateX(${margin.left}px) translateY(${margin.top}px)`,
          }}
        >
          <MetadataBars
            stacks={stacks}
            xScale={xScale}
            yScale={yScale}
            colorScale={colorScale}
            numTotalRows={numTotalRows}
            shouldHighlightMeaningfulMetadata={
              shouldHighlightMeaningfulMetadata
            }
          />
        </div>
      </div>

      <svg width={width} height={height}>
        {/* Offset by the same margins as the bars layer above. */}
        <g transform={`translate(${margin.left},${margin.top})`}>{axis}</g>
      </svg>
    </div>
  );
}
