import Color from "color";
import KDBush from "kdbush";
import lodashMax from "lodash.max";
import lodashMin from "lodash.min";
import RBush from "rbush";
import {
  CANVAS_PADDING,
  CLOSE_DISTANCE,
  LEGEND_COLLAPSED_HEIGHT,
  LEGEND_MARGIN,
} from "./constants";
import {
  Bounds,
  DragPoint,
  GetClosestPointFromDifferentCluster,
  RenderOptions,
  RendererOptions,
  ScatterPlotPoint,
} from "./types";

/**
 * Shallowly compares two renderer option objects.
 *
 * They are the same when they are the same reference, or when both are non-null and
 * have the same set of own enumerable keys whose values are strictly equal (`===`).
 * Property insertion order does not matter.
 *
 * @param a - First options object, or `null`.
 * @param b - Second options object, or `null`.
 * @returns `true` when the options are shallowly equal.
 */
export function sameRenderOptions<T extends RendererOptions>(
  a: T | null,
  b: T | null,
): boolean {
  if (a === b) {
    return true;
  }
  if (a === null || b === null) {
    return false;
  }

  const aEntries = Object.entries(a);
  // Look values up by key rather than by position: the same options built with a
  // different property insertion order must still compare equal.
  const bValues = new Map(Object.entries(b));

  return (
    aEntries.length === bValues.size &&
    aEntries.every(
      ([key, value]) => bValues.has(key) && bValues.get(key) === value,
    )
  );
}

/**
 * Returns `color` faded via the `color` library as a CSS colour string.
 *
 * The colour is faded by a ratio of `1 - amount`, so `amount` presumably acts as the
 * fraction of the original opacity to keep (confirm against the `color` package's
 * `fade` semantics).
 *
 * @param color - Any colour string the `color` package can parse.
 * @param amount - Opacity fraction to keep.
 */
export function transparentify(color: string, amount: number): string {
  const fadeRatio = 1 - amount;
  const faded = Color(color).fade(fadeRatio);
  return faded.string();
}

/**
 * Turns the two corners of a drag gesture into an axis-aligned rectangle.
 *
 * The corners may come in any order; the result is normalised with `Math.min` /
 * `Math.max` so left/top hold the smaller pixel coordinates.
 *
 * @returns The rectangle edges, or `null` while either drag corner is missing.
 */
export function getSelectionBounds({
  dragStart,
  dragEnd,
}: Pick<RenderOptions, "dragEnd" | "dragStart">) {
  if (!dragStart || !dragEnd) {
    return null;
  }

  const { x: startX, y: startY } = dragStart;
  const { x: endX, y: endY } = dragEnd;

  return {
    boundsLeft: Math.min(startX, endX),
    boundsRight: Math.max(startX, endX),
    boundsTop: Math.min(startY, endY),
    boundsBottom: Math.max(startY, endY),
  };
}

/**
 * Returns the keys of the points that lie inside the current drag-selection rectangle.
 *
 * The selection rectangle (canvas pixels) is converted to data coordinates and queried
 * against `pointsIndex`, so the index is expected to hold data-space coordinates with
 * entry `i` corresponding to `points[i]`.
 *
 * @returns The enclosed point keys; empty when there is no index or no complete drag.
 */
export function enclosedPoints(
  options: Pick<
    RenderOptions,
    | "dragEnd"
    | "dragStart"
    | "points"
    | "pointsIndex"
    | "viewBounds"
    | "canvasHeight"
    | "canvasWidth"
  >,
): Set<string> {
  const { points, pointsIndex } = options;
  // `getSelectionBounds` is null exactly when either drag corner is missing
  const selection = getSelectionBounds(options);

  if (!pointsIndex || selection === null) {
    return new Set();
  }

  // Pixel y grows downwards while data y grows upwards, so the rectangle's bottom-left
  // pixel corner maps to the data-space minimum and its top-right corner to the maximum
  const [dataMinX, dataMinY] = toPointCoordinates(
    selection.boundsLeft,
    selection.boundsBottom,
    options,
  );
  const [dataMaxX, dataMaxY] = toPointCoordinates(
    selection.boundsRight,
    selection.boundsTop,
    options,
  );

  const selectedKeys = new Set<string>();
  for (const pointIndex of pointsIndex.range(
    dataMinX,
    dataMinY,
    dataMaxX,
    dataMaxY,
  )) {
    selectedKeys.add(points[pointIndex].key);
  }
  return selectedKeys;
}

/**
 * Grows `viewBounds` so the data fits inside the canvas with room for padding, then
 * re-derives one axis' extent from the other canvas dimension (the in-code comment
 * says the goal is to "end up with a square").
 *
 * The destructured `minX`/`maxX`/`minY`/`maxY` locals are reassigned; the caller's
 * `viewBounds` object itself is not mutated.
 *
 * @param viewBounds - Data-space extent to fit.
 * @param canvasWidth - Canvas width in pixels; used as a divisor, so it must be non-zero.
 * @param canvasHeight - Canvas height in pixels; used as a divisor, so it must be non-zero.
 * @returns `{ viewBounds }`, centred on the midpoint of the incoming bounds.
 */
export function getInitialBounds({
  viewBounds: { minX, maxX, minY, maxY },
  canvasWidth,
  canvasHeight,
}: {
  viewBounds: Bounds;
  canvasWidth: number;
  canvasHeight: number;
}) {
  // Both branches below re-centre the result on the original midpoint
  const centerX = (minX + maxX) / 2;
  const centerY = (minY + maxY) / 2;
  // Scale each data extent by (pixels + reserved pixels) / pixels so the original
  // extent covers only the unreserved part of the canvas. CANVAS_PADDING is reserved
  // on both sides; vertically, LEGEND_COLLAPSED_HEIGHT + LEGEND_MARGIN is reserved
  // twice as well (presumably room for a collapsed legend — confirm against layout).
  const paddedWidth =
    ((maxX - minX) * (canvasWidth + CANVAS_PADDING * 2)) / canvasWidth;
  const paddedHeight =
    ((maxY - minY) *
      (canvasHeight +
        CANVAS_PADDING * 2 +
        (LEGEND_COLLAPSED_HEIGHT + LEGEND_MARGIN) * 2)) /
    canvasHeight;

  // If we've got more horizontal space than vertical space, add some extra width
  // so we end up with a square
  if (
    canvasWidth >
    canvasHeight - (LEGEND_COLLAPSED_HEIGHT + LEGEND_MARGIN) * 2
  ) {
    // Wide canvas: keep the padded height; derive the width from the x extent,
    // dividing by canvasHeight rather than canvasWidth.
    // NOTE(review): this branch adds the legend allowance to the *width* and divides
    // by canvasHeight, while the else branch omits the legend allowance and divides
    // by canvasWidth. The asymmetry may be unintentional — confirm.
    // `actualWidth` must be computed before minX/maxX are overwritten below.
    const actualWidth =
      ((maxX - minX) *
        (canvasWidth +
          CANVAS_PADDING * 2 +
          (LEGEND_COLLAPSED_HEIGHT + LEGEND_MARGIN) * 2)) /
      canvasHeight;
    minY = centerY - paddedHeight / 2;
    maxY = centerY + paddedHeight / 2;
    minX = centerX - actualWidth / 2;
    maxX = centerX + actualWidth / 2;
  } else {
    // Tall canvas: keep the padded width; derive the height from the y extent,
    // dividing by canvasWidth rather than canvasHeight. `actualHeight` must be
    // computed before minY/maxY are overwritten below.
    const actualHeight =
      ((maxY - minY) * (canvasHeight + CANVAS_PADDING * 2)) / canvasWidth;
    minX = centerX - paddedWidth / 2;
    maxX = centerX + paddedWidth / 2;
    minY = centerY - actualHeight / 2;
    maxY = centerY + actualHeight / 2;
  }

  return { viewBounds: { minX, maxX, minY, maxY } };
}

/**
 * Projects a data-space point onto the canvas.
 *
 * `x`/`y` are floored pixel positions relative to the current `viewBounds`; y is
 * flipped so larger data y values appear higher on the canvas.
 *
 * `locX`/`locY` apply the same scale to the raw coordinates without subtracting
 * `minX`/`minY`. They therefore stay constant when the view is panned (the span is
 * unchanged), but they still change with zoom.
 *
 * `location` snaps `locX`/`locY` to a grid of `CLOSE_DISTANCE`-pixel cells and
 * flattens (row, column) into one number. It is presumably used to group nearby
 * points.
 * NOTE(review): `locX` is not bounded by the canvas, so distinct cells can share an
 * id once a column index exceeds `Math.ceil(canvasWidth)`.
 */
export function toPixel(
  pt: Pick<ScatterPlotPoint, "x" | "y">,
  options: Pick<RenderOptions, "canvasWidth" | "canvasHeight" | "viewBounds">,
): {
  x: number;
  y: number;
  // Same scale as x,y but measured from the data origin, so unaffected by panning
  locX: number;
  locY: number;
  // Grid-cell id derived from locX/locY, also unaffected by panning
  location: number;
} {
  const { canvasWidth, canvasHeight, viewBounds } = options;
  const { minX, minY } = viewBounds;
  const spanX = viewBounds.maxX - minX;
  const spanY = viewBounds.maxY - minY;

  // Viewport-relative pixels; canvas y grows downwards, so the y fraction is flipped
  const x = Math.floor((canvasWidth * (pt.x - minX)) / spanX);
  const y = Math.floor(canvasHeight * (1 - (pt.y - minY) / spanY));

  // Same scale, but measured from the data origin instead of the viewport corner
  const locX = Math.floor((canvasWidth * pt.x) / spanX);
  const locY = Math.floor(canvasHeight * (1 - pt.y / spanY));

  // Snap to CLOSE_DISTANCE-sized cells and flatten (row, column) into one number
  const row = Math.round(locY / CLOSE_DISTANCE);
  const column = Math.round(locX / CLOSE_DISTANCE);
  const location = row * Math.ceil(canvasWidth) + column;

  return { x, y, location, locX, locY };
}

/**
 * Prepares a canvas for a fresh frame: resizes it to the requested size and clears it.
 *
 * @param canvas - The canvas element to prepare.
 * @returns The canvas's 2D context, or `null` when none is available (the canvas is
 *   then left untouched).
 */
export function initCanvas(
  canvas: HTMLCanvasElement,
  {
    canvasWidth,
    canvasHeight,
  }: Pick<RenderOptions, "canvasWidth" | "canvasHeight">,
) {
  const context = canvas.getContext("2d");
  if (!context) {
    return context;
  }

  // Assigning width/height resets the canvas even when the value is unchanged,
  // so only write them when they actually differ
  if (canvas.width !== canvasWidth) {
    canvas.width = canvasWidth;
  }
  if (canvas.height !== canvasHeight) {
    canvas.height = canvasHeight;
  }

  context.clearRect(0, 0, canvas.width, canvas.height);
  return context;
}

/**
 * Pans the view by a mouse movement given in canvas pixels.
 *
 * @param options - Current render options; the canvas size and `viewBounds` are read.
 * @param delta - Mouse movement in pixels. Screen y grows downwards while data y grows
 *   upwards, so `deltaY` is applied with the opposite sign.
 * @returns The shifted `{ viewBounds }`; the span on each axis is unchanged.
 */
export function moveBounds(
  options: RenderOptions,
  delta: { deltaX: number; deltaY: number },
) {
  const { canvasWidth, canvasHeight, viewBounds } = options;

  // Data units covered by one screen pixel along each axis
  const unitsPerPixelX = (viewBounds.maxX - viewBounds.minX) / canvasWidth;
  const unitsPerPixelY = (viewBounds.maxY - viewBounds.minY) / canvasHeight;

  const shiftX = delta.deltaX * unitsPerPixelX;
  const shiftY = -delta.deltaY * unitsPerPixelY;

  return {
    viewBounds: {
      minX: viewBounds.minX + shiftX,
      maxX: viewBounds.maxX + shiftX,
      minY: viewBounds.minY + shiftY,
      maxY: viewBounds.maxY + shiftY,
    },
  };
}

/**
 * Zooms the view about a canvas pixel, keeping the data point under that pixel fixed.
 *
 * @param options - Current render options; the canvas size and `viewBounds` are read.
 * @param delta - Zoom amount: the visible range is scaled by 2^(delta / 20), so +20
 *   doubles it (zoom out) and -20 halves it (zoom in).
 * @param origin - Canvas pixel position to zoom about.
 * @returns The new `{ viewBounds }`.
 */
export function zoomBounds(
  options: RenderOptions,
  delta: number,
  origin: DragPoint,
) {
  const scaleFactor = Math.pow(2, delta / 20);
  const [anchorX, anchorY] = toPointCoordinates(origin.x, origin.y, options);

  // Move an edge towards/away from the anchor so its distance scales by scaleFactor
  const scaleAbout = (edge: number, anchor: number) =>
    anchor + (edge - anchor) * scaleFactor;

  const { minX, maxX, minY, maxY } = options.viewBounds;
  return {
    viewBounds: {
      minX: scaleAbout(minX, anchorX),
      maxX: scaleAbout(maxX, anchorX),
      minY: scaleAbout(minY, anchorY),
      maxY: scaleAbout(maxY, anchorY),
    },
  };
}

/**
 * Maps a canvas pixel position back to data coordinates for the current view.
 *
 * This is the inverse of the (unfloored) viewport mapping used for `x`/`y` in
 * `toPixel`. Pixel y grows downwards, so `py = 0` maps to `maxY` and
 * `py = canvasHeight` maps to `minY`.
 *
 * @returns `[x, y]` in data coordinates.
 */
// TODO(you): Fix this no-unused-exports rule violation
// ts-unused-exports:disable-next-line
export function toPointCoordinates(
  px: number,
  py: number,
  bounds: Pick<RenderOptions, "canvasWidth" | "canvasHeight" | "viewBounds">,
): [number, number] {
  const {
    canvasWidth,
    canvasHeight,
    viewBounds: { minX, maxX, minY, maxY },
  } = bounds;

  const pointX = minX + ((maxX - minX) * px) / canvasWidth;
  // Measure from the bottom edge because canvas y is flipped relative to data y
  const pointY = minY + (maxY - minY) * (1 - py / canvasHeight);

  return [pointX, pointY];
}

// Module-level cache of pre-rendered circles. The key covers every input that affects
// the rendered pixels (point size, scale, colour and alpha); entries are never evicted.
const CIRCLE_CANVASES: Map<string, OffscreenCanvas> = new Map();

/**
 * Draws a filled circle of radius `pointSize * scale` centred on (`x`, `y`).
 *
 * Each distinct circle is rendered once at 3x resolution onto an `OffscreenCanvas`
 * cached in `CIRCLE_CANVASES`. Each draw then copies it onto `ctx` scaled down by 3,
 * so it doesn't look fuzzy on retina displays.
 *
 * @param ctx - Destination 2D context.
 * @param x - Centre x on `ctx`, in canvas pixels.
 * @param y - Centre y on `ctx`, in canvas pixels.
 * @param color - Fill colour; used as-is when `alpha` is 1, otherwise passed through
 *   `transparentify`.
 * @param pointSize - Circle radius before `scale` is applied.
 * @param alpha - Opacity argument for `transparentify` (default 1.0).
 * @param scale - Additional radius multiplier (default 1).
 */
export function drawCircle({
  ctx,
  x,
  y,
  color,
  pointSize,
  alpha = 1.0,
  scale = 1,
}: {
  ctx: CanvasRenderingContext2D;
  x: number;
  y: number;
  color: string;
  pointSize: number;
  alpha?: number;
  scale?: number;
}) {
  const radius = pointSize * scale;
  // A non-positive (or NaN) radius has nothing to draw
  if (!(radius > 0)) {
    return;
  }

  // 3x scale the pre-rendered point so that things won't look too fuzzy on retina
  const prerenderRadius = pointSize * 3 * scale;
  const prerenderDiameter = prerenderRadius * 2;
  const cacheKey = `${pointSize}-${scale}-${color}-${alpha}`;

  let circleCanvas = CIRCLE_CANVASES.get(cacheKey);
  if (circleCanvas === undefined) {
    // OffscreenCanvas truncates fractional dimensions (e.g. 7.5 -> 7), which would clip
    // the right and bottom of the circle, so round the backing size up instead
    const backingSize = Math.ceil(prerenderDiameter);
    const freshCanvas = new OffscreenCanvas(backingSize, backingSize);
    // Named differently from `ctx` so the destination context is never shadowed
    const circleCtx = freshCanvas.getContext("2d");
    if (circleCtx === null) {
      // Nothing can be rendered without a context; skip caching so a later call retries
      return;
    }
    circleCtx.fillStyle =
      alpha === 1.0 ? color : transparentify(color, alpha);
    circleCtx.beginPath();
    circleCtx.arc(
      prerenderRadius,
      prerenderRadius,
      prerenderRadius,
      0,
      Math.PI * 2,
    );
    circleCtx.fill();

    CIRCLE_CANVASES.set(cacheKey, freshCanvas);
    circleCanvas = freshCanvas;
  }

  // Copy exactly the circle's square (ignoring any rounding margin), shrink it 3x, and
  // offset it by the radius so the circle's centre lands on (x, y)
  ctx.drawImage(
    circleCanvas,
    0,
    0,
    prerenderDiameter,
    prerenderDiameter,
    x - radius,
    y - radius,
    radius * 2,
    radius * 2,
  );
}

/**
 * Returns the points to draw for the current view.
 *
 * Points inside `viewBounds` (found via `pointsIndex`) that are not marked hidden in
 * `pointStates` are ordered top to bottom (descending y), then left to right. When more
 * than 5000 remain, overlapping points are thinned out with `getNonCollidingPoints`,
 * which always keeps the keys in `selectedPoints`.
 *
 * @returns The points to render, or `[]` when no index is available.
 */
export function getVisiblePointsForBounds(
  selectedPoints: Set<string>,
  options: Pick<
    RenderOptions,
    | "points"
    | "pointStates"
    | "pointsIndex"
    | "viewBounds"
    | "canvasWidth"
    | "canvasHeight"
    | "pointSize"
  >,
): ScatterPlotPoint[] {
  const { points, pointStates, pointsIndex, viewBounds } = options;
  if (pointsIndex === null) {
    return [];
  }

  const { minX, minY, maxX, maxY } = viewBounds;
  const candidates: ScatterPlotPoint[] = [];
  for (const pointIndex of pointsIndex.range(minX, minY, maxX, maxY)) {
    const candidate = points[pointIndex];
    if (!pointStates[candidate.key].hidden) {
      candidates.push(candidate);
    }
  }

  // Top to bottom, then left to right, so points appear cleaner on the plot
  candidates.sort((a, b) => b.y - a.y || a.x - b.x);

  // Below this count the plot renders fast enough that deduping collisions
  // isn't worth the computation time
  const DEDUPE_THRESHOLD = 5000;
  return candidates.length > DEDUPE_THRESHOLD
    ? getNonCollidingPoints(candidates, selectedPoints, options)
    : candidates;
}

/**
 * Thins out `points` so that no two kept points have overlapping hit boxes.
 *
 * A hit box is a square of ±`ceil(pointSize / 4)` pixels around the pan-independent
 * `locX`/`locY` from `toPixel`. Points are considered in the given order, so an earlier
 * point wins over a later one it overlaps. Points whose key is in `selectedPoints` are
 * always kept, even when they overlap an earlier point.
 *
 * @returns The kept points, in their original order.
 */
function getNonCollidingPoints(
  points: ScatterPlotPoint[],
  selectedPoints: Set<string>,
  options: Pick<
    RenderOptions,
    "points" | "canvasWidth" | "canvasHeight" | "pointSize" | "viewBounds"
  >,
): ScatterPlotPoint[] {
  const padding = Math.ceil(options.pointSize / 4);
  // Dynamic spatial index holding the hit box of every point kept so far
  const occupied = new RBush();

  return points.filter((point) => {
    const { locX, locY } = toPixel(point, options);
    const hitBox = {
      minX: Math.floor(locX - padding),
      maxX: Math.ceil(locX + padding),
      minY: Math.floor(locY - padding),
      maxY: Math.ceil(locY + padding),
    };

    const keep = selectedPoints.has(point.key) || !occupied.collides(hitBox);
    if (keep) {
      occupied.insert(hitBox);
    }
    return keep;
  });
}

/**
 * Builds a static KDBush index over `points`; index entry `i` corresponds to `points[i]`.
 *
 * By default the raw data coordinates are indexed. With `pixel: true` the viewport
 * pixel coordinates from `toPixel` are indexed instead, which ties the index to the
 * given `viewBounds` and canvas size.
 */
export function createPointsIndex(
  points: { x: number; y: number }[],
  options: Pick<RenderOptions, "canvasWidth" | "canvasHeight" | "viewBounds">,
  { pixel = false }: { pixel?: boolean } = {},
): KDBush {
  // Decide once which coordinates get indexed
  const coordinatesOf = pixel
    ? (pt: { x: number; y: number }) => toPixel(pt, options)
    : (pt: { x: number; y: number }) => pt;

  const index = new KDBush(points.length);
  points.forEach((pt) => {
    const { x, y } = coordinatesOf(pt);
    index.add(x, y);
  });
  index.finish();
  return index;
}

/** One non-empty cell of the grid built by `getBucketing`. */
interface Bucket {
  /** Column index of the cell within the grid. */
  bucketX: number;
  /** Row index of the cell within the grid. */
  bucketY: number;
  /** Points assigned to this cell. */
  points: ScatterPlotPoint[];
  /** Distinct cluster labels among `points`. */
  clusterLabels: Set<number>;
  /** Tight bounding box of `points` (grown point by point, not the cell's nominal extent). */
  bounds: { minX: number; maxX: number; minY: number; maxY: number };
}

/** A bucket paired with the squared distance from a query point to its `bounds`. */
interface BucketWithDistance {
  bucket: Bucket;
  minDistance2: number;
}

/**
 * Folds the points of `bucket` into a running "nearest other-cluster point" record.
 *
 * Points sharing `point`'s cluster label are ignored. `current` is updated in place
 * (and returned) whenever a strictly closer point is found, so ties keep the earlier
 * candidate. Distances are squared Euclidean distances.
 *
 * @param point - The query point; if it has no cluster label, `current` is returned as is.
 * @param bucket - The bucket whose points are scanned.
 * @param current - Best match so far; defaults to a fresh "nothing found" record.
 * @returns `current`, possibly updated.
 */
function updateNearestPointInBucket(
  point: ScatterPlotPoint,
  bucket: Bucket,
  current: { nearest: ScatterPlotPoint | null; nearestDist2: number } = {
    nearest: null,
    nearestDist2: Infinity,
  },
): { nearest: ScatterPlotPoint | null; nearestDist2: number } {
  const ownCluster = point.clusterLabel;
  if (ownCluster !== undefined) {
    for (const candidate of bucket.points) {
      if (candidate.clusterLabel !== ownCluster) {
        const dx = candidate.x - point.x;
        const dy = candidate.y - point.y;
        const candidateDist2 = dx * dx + dy * dy;
        if (candidateDist2 < current.nearestDist2) {
          current.nearest = candidate;
          current.nearestDist2 = candidateDist2;
        }
      }
    }
  }
  return current;
}

/**
 * Squared distance from `point` to the nearest point of `bucket`'s bounding box
 * (0 when the point lies inside it). Nothing in the bucket can be closer than this.
 */
function minDistanceToBucket2(point: ScatterPlotPoint, bucket: Bucket) {
  const { minX, maxX, minY, maxY } = bucket.bounds;

  // Clamp the point into the box; the clamped position is the box's closest point
  let closestX = point.x;
  if (point.x < minX) {
    closestX = minX;
  } else if (point.x > maxX) {
    closestX = maxX;
  }

  let closestY = point.y;
  if (point.y < minY) {
    closestY = minY;
  } else if (point.y > maxY) {
    closestY = maxY;
  }

  return Math.pow(closestX - point.x, 2) + Math.pow(closestY - point.y, 2);
}

/** Grows `bounds` in place so that it also contains `point`. */
function expandBounds(bounds: Bucket["bounds"], point: ScatterPlotPoint) {
  const { x, y } = point;
  Object.assign(bounds, {
    minX: Math.min(bounds.minX, x),
    maxX: Math.max(bounds.maxX, x),
    minY: Math.min(bounds.minY, y),
    maxY: Math.max(bounds.maxY, y),
  });
}

/** Map key for the grid cell at column `bucketX`, row `bucketY` (e.g. `"3:-1"`). */
function getBucketKey(bucketX: number, bucketY: number) {
  return `${bucketX}:${bucketY}`;
}

/**
 * Divides the clustered points into a square grid of buckets so that nearest-neighbour
 * searches only need to visit nearby cells.
 *
 * The grid spans the bounding box of all `points` and has between 4 and 32 cells per
 * side, aiming for roughly `TARGET_POINTS_PER_BUCKET` points per cell. Points without a
 * `clusterLabel` are left out of every bucket, though they still count towards the
 * grid's extent.
 *
 * @returns
 *   - `bucketsByKey`: the non-empty buckets, keyed by `getBucketKey`.
 *   - `getBucketValues`: maps a point to its cell indices, always in
 *     [0, bucketsPerDimension - 1].
 *   - `minSideLength`: the smaller cell side length.
 *   - `bucketsPerDimension`: the grid size.
 */
function getBucketing(points: ScatterPlotPoint[]) {
  const xValues = points.map(({ x }) => x);
  const yValues = points.map(({ y }) => y);
  const minX = lodashMin(xValues) ?? Infinity;
  const maxX = lodashMax(xValues) ?? -Infinity;
  const minY = lodashMin(yValues) ?? Infinity;
  const maxY = lodashMax(yValues) ?? -Infinity;

  // Figure out how many rows/columns to divide the points into
  const TARGET_POINTS_PER_BUCKET = 100;
  const bucketsPerDimension = Math.max(
    4,
    Math.min(
      32,
      Math.ceil(Math.sqrt(points.length / TARGET_POINTS_PER_BUCKET)),
    ),
  );
  const columnWidth = (maxX - minX) / bucketsPerDimension;
  const rowHeight = (maxY - minY) / bucketsPerDimension;

  // Map an offset from the grid's minimum edge to a cell index. When every point shares
  // a coordinate on some axis, that axis has zero size; dividing would yield 0 / 0 = NaN
  // indices (and NaN keys), so everything on such an axis goes into cell 0 instead.
  const toCellIndex = (offset: number, cellSize: number) =>
    cellSize > 0
      ? Math.floor(Math.min(offset / cellSize, bucketsPerDimension - 1))
      : 0;

  const getBucketValues = (
    point: ScatterPlotPoint,
  ): { bucketX: number; bucketY: number } => ({
    bucketX: toCellIndex(point.x - minX, columnWidth),
    bucketY: toCellIndex(point.y - minY, rowHeight),
  });

  const bucketsByKey: Map<string, Bucket> = new Map();

  for (const point of points) {
    if (point.clusterLabel === undefined) {
      continue;
    }

    const { bucketX, bucketY } = getBucketValues(point);
    // Get the key for the bucket this point is in
    const bucketKey = getBucketKey(bucketX, bucketY);

    const existing = bucketsByKey.get(bucketKey);

    // Add the point to an existing bucket, or else add a new bucket for the point
    if (existing) {
      existing.points.push(point);
      existing.clusterLabels.add(point.clusterLabel);
      expandBounds(existing.bounds, point);
    } else {
      bucketsByKey.set(bucketKey, {
        points: [point],
        clusterLabels: new Set([point.clusterLabel]),
        bucketX,
        bucketY,
        bounds: {
          minX: point.x,
          maxX: point.x,
          minY: point.y,
          maxY: point.y,
        },
      });
    }
  }

  return {
    bucketsByKey,
    getBucketValues,
    minSideLength: Math.min(columnWidth, rowHeight),
    bucketsPerDimension,
  };
}

/**
 * Finds the nearest point to `point` that belongs to a different cluster and lies
 * within `threshold` (Euclidean distance, data units).
 *
 * Only nearby grid cells are visited. The window around the point's own cell extends
 * `ceil(threshold / minSideLength)` cells in every direction, clamped to the grid.
 * Candidate cells are then pruned by the distance to their points' bounding box.
 *
 * @param point - The query point; a point without a `clusterLabel` never has a match.
 * @param bucketing - Grid built by `getBucketing` over the same set of points.
 * @param threshold - Maximum search distance.
 * @returns The nearest other-cluster point within `threshold`, or `null` if there is none.
 */
function findNearest(
  point: ScatterPlotPoint,
  {
    getBucketValues,
    bucketsByKey,
    minSideLength,
    bucketsPerDimension,
  }: ReturnType<typeof getBucketing>,
  threshold: number,
): ScatterPlotPoint | null {
  const pointCluster = point.clusterLabel;

  // If this point isn't part of a cluster, or we don't have any buckets, don't worry
  // about calculating nearest
  if (pointCluster === undefined || bucketsByKey.size === 0) {
    return null;
  }

  const { bucketX, bucketY } = getBucketValues(point);
  const thresholdSquared = threshold * threshold;

  // How many cells away a point within `threshold` can be. A zero side length (a
  // degenerate axis) gives no usable bound, so fall back to searching the whole grid.
  const expandSize =
    minSideLength > 0
      ? Math.ceil(threshold / minSideLength)
      : bucketsPerDimension;

  // Buckets only exist at indices 0..bucketsPerDimension - 1. Clamping the window to
  // the grid keeps the loops finite even when `threshold / minSideLength` is huge or
  // infinite (an unclamped window starting at -Infinity never terminates).
  const lastIndex = bucketsPerDimension - 1;
  const xFrom = Math.max(0, bucketX - expandSize);
  const xTo = Math.min(lastIndex, bucketX + expandSize);
  const yFrom = Math.max(0, bucketY - expandSize);
  const yTo = Math.min(lastIndex, bucketY + expandSize);

  // Keep buckets holding at least one point from another cluster whose bounding box
  // comes within the threshold
  const candidates: BucketWithDistance[] = [];
  for (let x = xFrom; x <= xTo; x += 1) {
    for (let y = yFrom; y <= yTo; y += 1) {
      const bucket = bucketsByKey.get(getBucketKey(x, y));
      if (!bucket) {
        continue;
      }
      const hasOtherCluster =
        bucket.clusterLabels.size > 1 ||
        !bucket.clusterLabels.has(pointCluster);
      if (!hasOtherCluster) {
        continue;
      }
      const minDistance2 = minDistanceToBucket2(point, bucket);
      if (minDistance2 <= thresholdSquared) {
        candidates.push({ bucket, minDistance2 });
      }
    }
  }

  // Visit buckets from the closest bounding box outwards. Once a box is farther away
  // than the best point found so far, no remaining bucket can hold a closer point.
  candidates.sort((a, b) => a.minDistance2 - b.minDistance2);
  const [closest, ...rest] = candidates;
  if (closest === undefined) {
    return null;
  }

  let best = updateNearestPointInBucket(point, closest.bucket);
  for (const { minDistance2, bucket } of rest) {
    if (minDistance2 > best.nearestDist2) {
      break;
    }
    best = updateNearestPointInBucket(point, bucket, best);
  }

  // Buckets beyond the threshold were never searched, so a match farther than the
  // threshold may not be the true nearest point. Report "nothing within the threshold"
  // instead, so a later call with a larger threshold searches further.
  return best.nearestDist2 <= thresholdSquared ? best.nearest : null;
}

/**
 * Returns a lookup function `(pointKey, threshold)` that gives the nearest point
 * belonging to a different cluster than the point with `pointKey`.
 *
 * Results are memoised per point. Once a nearest point is found it is returned for
 * every later call. While none has been found, the largest threshold already searched
 * is remembered, and calls with a threshold no larger than it return `null` without
 * searching. The spatial grid is built lazily on the first real search. Unknown keys
 * yield `null`.
 */
export function createGetClosestPointFromDifferentCluster(
  points: ScatterPlotPoint[],
): GetClosestPointFromDifferentCluster {
  // Built on first use, then shared by every lookup
  let bucketing: ReturnType<typeof getBucketing> | null = null;

  const cache = new Map<
    string,
    {
      point: ScatterPlotPoint;
      nearest: ScatterPlotPoint | null;
      threshold: number;
    }
  >();
  for (const point of points) {
    cache.set(point.key, { point, nearest: null, threshold: 0 });
  }

  return (pointKey: string, threshold: number) => {
    const entry = cache.get(pointKey);
    // This shouldn't happen, but if we don't know about the point we can't get the closest
    if (!entry) {
      return null;
    }

    // Already resolved
    if (entry.nearest) {
      return entry.nearest;
    }

    // Nothing was found out to at least this distance last time
    if (entry.threshold >= threshold) {
      return null;
    }

    // Record how far we've searched, then search
    entry.threshold = threshold;
    if (bucketing === null) {
      bucketing = getBucketing(points);
    }
    entry.nearest = findNearest(entry.point, bucketing, threshold);
    return entry.nearest;
  };
}

/**
 * Squared Euclidean distance between (x1, y1) and (x2, y2). The square root is skipped
 * because callers only compare it against other squared distances.
 */
function euclideanDist2(x1: number, y1: number, x2: number, y2: number) {
  const dx = x2 - x1;
  const dy = y2 - y1;
  return Math.pow(dx, 2) + Math.pow(dy, 2);
}
