/**
 * Component to render correlations for arbitrary data.
 *
 * Given an array of entries containing a `value` property and a specified
 * column against which to compute correlations, renders both a scatterplot and
 * a line of best fit atop it.
 *
 * Note that this component will not react to changes in the input data after
 * the initial render.
 */
import { regressionLinear, regressionPoly } from "d3-regression";
import dl from "datalib";
import { useEffect, useMemo } from "react";
import { TopLevelSpec as VlSpec } from "vega-lite";
import VegaLite from "../Vega/VegaLite";
import { toMetadataColumns } from "../util/vega-util";
import { CheckOutcome, Datum, RegressionModel } from "./types";

/**
 * Inputs accepted by the CorrelationsRenderer component.
 */
type Props = {
  // Entries to plot. Each entry's `metadata` fields are hoisted to the top
  // level before plotting, and its `value` supplies the y coordinate.
  data: Datum[];

  // NOTE(review): not read anywhere in the visible part of this file —
  // confirm whether callers still need to pass it.
  metricName: string;
  // Scales the chart: the chart width is set to 24 * size.
  size: number;

  // The column to correlate against, and the range of values in its domain.
  target: string;
  domain: string[];

  // Callback triggered when analysis completes, containing the outcome of the
  // analysis.
  onCheckComplete: (checkOutcome: CheckOutcome) => void;

  // The rotation angle of the x axis tick labels.
  xLabelAngle?: number;
  // Whether to show y axis labels. Defaults to true.
  showYLabels?: boolean;
  // Label for the x axis. Defaults to the value of the target prop.
  xLabel?: string;

  // Which regression to fit; it determines both the best-fit line and the
  // reported correlation coefficient.
  regressionModel: RegressionModel;
};

/**
 * Outcome of fitting a regression model to the data.
 */
type RegressionResult = {
  // Correlation coefficient shown in the chart title and used to classify the
  // check outcome. The linear model reports it with the sign of the fitted
  // slope; the cubic model reports the square root of R² without a sign.
  R: number;
  // Points along the fitted line, one per domain value: the domain value is
  // stored under the target column's name and the prediction under `value`.
  points: Array<{ [target: string]: string | number; value: number }>;
};

/**
 * Hoist each entry's `metadata` fields to the top level of the entry.
 *
 * The entry's own properties are spread last, so they take precedence over a
 * metadata field with the same name. A new array of new objects is returned;
 * the `metadata` key itself is dropped.
 */
function flattenData(data: Props["data"]) {
  return data.map(({ metadata, ...ownFields }) => ({
    ...metadata,
    ...ownFields,
  }));
}

/**
 * Sample a fitted model once per domain entry to build the best-fit line.
 *
 * The model is evaluated at each entry's index within `domain` (0, 1, 2, …).
 * Every sample stores the domain entry under the `target` key and the
 * prediction under `value`.
 */
function getBestFitPoints(
  target: string,
  domain: Array<string>,
  predict: (x: number) => number,
) {
  const samples: Array<{ [key: string]: string | number; value: number }> = [];
  domain.forEach((domainValue, position) => {
    samples.push({ [target]: domainValue, value: predict(position) });
  });
  return samples;
}

/**
 * Fit a regression of the per-group median `value` against `target`.
 *
 * Rows are grouped by every column of the first row other than `value`, and
 * each group is reduced to the median of its values. A group's x coordinate is
 * the index of its `target` value within `domain` (-1 when the value does not
 * appear in `domain`).
 *
 * @returns The correlation coefficient `R` together with one best-fit point
 *   per `domain` entry. For the linear model `R` carries the sign of the
 *   fitted slope; for the cubic model it is the square root of the fit's R².
 *   When `flatData` is empty nothing can be fitted, so `R` is `NaN` and
 *   `points` is empty.
 * @throws Error if `regressionModel` is not one of the supported models.
 */
function computeRegression({
  flatData,
  target,
  domain,
  regressionModel,
}: { flatData: ReturnType<typeof flattenData> } & Pick<
  Props,
  "target" | "domain" | "regressionModel"
>): RegressionResult {
  // Without any rows there are no columns to group by and nothing to fit;
  // bail out instead of dereferencing a missing first row.
  const [firstRow] = flatData;
  if (firstRow === undefined) {
    return { R: NaN, points: [] };
  }

  const groupby = Object.keys(firstRow).filter((k) => k !== "value");

  const aggregatedData = dl
    .groupby(groupby)
    .summarize([{ name: "value", ops: ["median"], as: ["aggregated"] }])
    .execute(flatData);

  // Both models regress the group median against the position of the group's
  // target value within the domain.
  const x = (d: any) => domain.indexOf(d[target]);
  const y = (d: any) => d.aggregated;

  switch (regressionModel) {
    case "linear": {
      const result = regressionLinear().x(x).y(y)(aggregatedData);
      return {
        points: getBestFitPoints(target, domain, result.predict),
        // R² discards the direction of the relationship, so take it from the
        // sign of the fitted slope `a`.
        R: Math.sqrt(result.rSquared) * (result.a >= 0 ? 1 : -1),
      };
    }

    case "cubic": {
      const result = regressionPoly().x(x).y(y).order(3)(aggregatedData);
      return {
        points: getBestFitPoints(target, domain, result.predict),
        R: Math.sqrt(result.rSquared),
      };
    }

    default:
      // Fail loudly rather than returning undefined to callers that expect a
      // RegressionResult.
      throw new Error(
        `Unsupported regression model: ${String(regressionModel)}`,
      );
  }
}
/**
 * Classify how strong a fitted correlation is.
 *
 * Only the magnitude of `R` matters: 0.6 or more is an "error", 0.3 or more is
 * a "warning", and anything else (including a `NaN` coefficient) is `null`.
 */
function outcomeForRegressionResult(
  regressionResult: RegressionResult,
): CheckOutcome {
  const strength = Math.abs(regressionResult.R);
  if (strength >= 0.6) {
    return "error";
  }
  if (strength >= 0.3) {
    return "warning";
  }
  return null;
}

/**
 * Scatterplot of `value` against `target` with the fitted regression line
 * drawn over it, plus a banner when the fitted correlation is high.
 *
 * The check outcome is reported through `onCheckComplete` from an effect, so
 * the callback runs after render and again whenever the outcome or the
 * callback itself changes.
 */
function CorrelationsRenderer({
  data,
  target,
  domain,
  size,
  regressionModel,
  onCheckComplete,
  xLabelAngle,
  showYLabels,
  xLabel,
}: Props) {
  // Hoist metadata fields to the top level so they can be used as columns.
  const metadataColumns = useMemo(() => toMetadataColumns(data), [data]);
  const flatData = useMemo(() => flattenData(data), [data]);

  // Fit the regression; it is refitted only when one of its inputs changes.
  const regressionResult = useMemo(
    () => computeRegression({ flatData, target, domain, regressionModel }),
    [flatData, target, domain, regressionModel],
  );
  const outcome = outcomeForRegressionResult(regressionResult);

  // Report the outcome to the parent once rendering has committed.
  useEffect(() => {
    onCheckComplete(outcome);
  }, [onCheckComplete, outcome]);

  const yLabelsVisible = showYLabels ?? true;

  // One tooltip row per metadata column, shown after the value itself.
  const metadataTooltips = metadataColumns.map(
    (column) =>
      ({
        field: column,
        type: "ordinal",
      }) as any,
  );

  const spec: VlSpec = {
    $schema: "https://vega.github.io/schema/vega-lite/v4.json",
    title: `r = ${regressionResult.R.toPrecision(2)}`,
    layer: [
      // Scatterplot of the individual (flattened, unaggregated) rows.
      {
        data: { values: flatData },
        mark: "point",
        encoding: {
          x: {
            field: target,
            type: "ordinal",
            scale: { domain },
            title: xLabel ?? target,
            axis: { labelAngle: xLabelAngle ?? 0 },
          },
          y: {
            field: "value",
            title: null,
            type: "quantitative",
            scale: { zero: false },
            // Labels and ticks are shown or hidden together.
            axis: { labels: yLabelsVisible, ticks: yLabelsVisible },
          },
          tooltip: [
            { field: "value", type: "quantitative" },
            ...metadataTooltips,
          ],
        },
      },
      // Line of best fit, drawn through the per-domain-value predictions.
      {
        data: { values: regressionResult.points },
        mark: {
          type: "line",
          color: "red",
          interpolate: regressionModel === "linear" ? "linear" : "monotone",
        },
        encoding: {
          x: {
            field: target,
            type: "ordinal",
            scale: { domain },
          },
          y: {
            field: "value",
            type: "quantitative",
            scale: { zero: false },
          },
        },
      },
    ],
    autosize: { resize: true },
    height: 240,
    width: size * 24,
  };

  // Banner for the current outcome; `false` renders nothing.
  const renderBanner = () => {
    if (outcome === "error") {
      return (
        <div className="tw-w-full tw-p-sm tw-rounded-lg tw-bg-red-error">
          Warning: <b>extreme</b> correlation detected
        </div>
      );
    }
    if (outcome === "warning") {
      return (
        <div className="tw-w-full tw-p-sm tw-rounded-lg tw-bg-yellow-warning">
          Warning: <i>high</i> correlation detected.
        </div>
      );
    }
    return false;
  };

  return (
    <div className="tw-flex tw-flex-col tw-gap-sm tw-text-center">
      {renderBanner()}
      <VegaLite spec={spec} renderer="canvas" />
    </div>
  );
}

export default CorrelationsRenderer;
