/**
 * Component to render a categorical scatter plot for arbitrary data.
 *
 * Given an array of entries containing a `value` property and a specified
 * category column to use on the X-axis, renders a scatterplot.
 *
 * Note that this component will not react to changes in the input data after
 * the initial render.
 */
import { StyleSheet, css } from "aphrodite";
import dl from "datalib";
import { Component } from "react";
import { Alert } from "src/Common/Alert";
import { TopLevelSpec as VlSpec } from "vega-lite";
import { UnitSpec } from "vega-lite/build/src/spec";
import { StandardType } from "vega-lite/build/src/type";
import { arrayEquals } from "@spring/core/utils";
import VegaLite from "../Vega/VegaLite";
import { MetadataColumnValue } from "../types";
import { toMetadataColumns } from "../util/vega-util";
import { CheckOutcome, Datum } from "./types";

/**
 * One entry of an outlier-detection result.
 *
 * Besides the verdict, an entry carries arbitrary extra fields; the renderer
 * in this file reads the field named after the `target` column to label the
 * flagged group.
 */
type OutlierResult = {
  // Verdict for this entry. Within this file `null` means "not flagged",
  // while "warning" and "error" mark possible and likely outliers.
  outlier: CheckOutcome;
  // Any other fields of the entry (e.g. the grouped row it was derived from).
  [key: string]: string | CheckOutcome | undefined;
};
/**
 * Signature of a pluggable outlier-detection strategy.
 *
 * Receives the per-well aggregated measurements and the name of the column to
 * group by, and returns one `OutlierResult` per group.
 *
 * NOTE(review): `byWellAggregatedData` is declared as a single record, but
 * the visible caller passes the output of a datalib `execute` call and the
 * default implementation feeds it straight back into `execute` -- it looks
 * like this should be an array of such records; confirm before tightening.
 */
type OutlierDetectionFunction = (
  byWellAggregatedData: {
    // The aggregated measurement for the record.
    aggregated: number;
    // The remaining (grouping) columns of the record.
    [key: string]: string | number | undefined;
  },
  // Name of the column whose groups are checked for outliers.
  target: string,
) => OutlierResult[];

/**
 * Props accepted by CategoricalScatterPlotRenderer.
 */
type Props = {
  // Entries to plot. Each entry's `metadata` is flattened alongside its other
  // fields before plotting; the `value` field is drawn on the Y-axis.
  data: Datum[];

  // Name of the metric being plotted; matched (together with `target`)
  // against `outlierChecks` to decide whether outlier detection runs.
  metricName: string;
  // Horizontal space per domain entry: the plot width is
  // `domain.length * size` unless `minWidth` is larger.
  size: number;

  // Optional lower bounds on the plot dimensions. The height is never less
  // than 240, regardless of `minHeight`.
  minHeight?: number;
  minWidth?: number;

  // The column to correlate against, and the range of values in its domain.
  target: string;
  domain: MetadataColumnValue[];

  // When true, overlay rule marks at the mean, first quartile and third
  // quartile of `value`.
  markQuartiles: boolean;
  // Called with the overall outlier status once the component has mounted,
  // and again after the outlier state is recomputed on update.
  onCheckComplete: (checkOutcome: CheckOutcome) => void;
  // Metrics on which we want to flag outliers.
  outlierChecks: Array<{ metricName: string; target: string }>;
  // Strategy used to flag outliers; the component's default props set this to
  // `madOutliers`.
  outlierDetectionFunction: OutlierDetectionFunction;
};

/**
 * Outlier findings shown in the alerts above the plot.
 */
type State = {
  // `target`-column values of the groups flagged as "error" (likely outliers).
  errors: string[];
  // `target`-column values of the groups flagged as "warning" (possible
  // outliers).
  warnings: string[];
};

// Outlier detection only runs for the (metric, target) pairs listed in
// `outlierChecks`; report whether the plot described by `props` is one of them.
const shouldDoOutlierCheck = (props: Props) => {
  const { metricName, target, outlierChecks } = props;
  const matchesThisPlot = (check: { metricName: string; target: string }) =>
    check.target === target && check.metricName === metricName;
  return outlierChecks.some(matchesThisPlot);
};

/**
 * Compute list of possible outliers grouped by a target metadata column.
 *
 * (At present we use this to flag outlier donors / animal ids.)
 *
 * The per-well values are first reduced to one median per `target` group
 * (e.g. per donor). We then take the median of those group medians and the
 * median absolute deviation (MAD) of the group medians around it. A group is
 * a possible outlier ("warning") if its median is more than 3 MADs away from
 * the overall median, and a likely outlier ("error") if more than 5 MADs.
 * While these are pretty stringent cutoffs, since we usually run > 20 donors
 * per experiment, that means that if we chose below 3 standard deviations,
 * we'd always expect to see an "outlier."
 *
 * Note that when the MAD is 0 (more than half of the groups share the same
 * median), every group whose median differs at all is flagged as "error".
 *
 * TODO(colin): using MADs here isn't ideal because it implicitly imposes the
 * prior that we think all donors should have exactly the same mean plating
 * density, which is true in the common case but not always, as well as what
 * the probability that something isn't clustered is. Instead, we should use a
 * Bayesian method that more explicitly models the experiment protocol and
 * sources of error to estimate the likelihood that a sample was incorrectly
 * plated, using a model and priors specific to each type of experiment.
 * TODO(colin): when switching to Bayesian methods, change the signature of
 * these outlier-detection functions to return probabilities, not
 * errors/warnings, and then impose standard probability cutoffs (this is not
 * easy to do with the MAD).
 *
 * @param byWellAggregatedData per-well records carrying an `aggregated`
 *   value plus the grouping columns (passed straight to datalib).
 * @param target name of the column whose groups are compared.
 * @returns one entry per `target` group: the group's row (including its
 *   `median`) plus an `outlier` verdict of `null`, "warning" or "error".
 *   Empty input yields an empty list.
 */
const madOutliers: OutlierDetectionFunction = (
  byWellAggregatedData,
  target,
) => {
  // Number of MADs beyond which a group is a possible / likely outlier.
  const WARNING_MADS = 3;
  const ERROR_MADS = 5;

  // With no rows there are no groups to compare; bail out before the overall
  // summary below is indexed at [0], which would otherwise fail.
  const rows: unknown = byWellAggregatedData;
  if (Array.isArray(rows) && rows.length === 0) {
    return [];
  }

  // One median per `target` group.
  const aggregatedByTargetField = dl
    .groupby(target)
    .summarize([
      {
        name: "aggregated",
        ops: ["median"],
        as: ["median"],
      },
    ])
    .execute(byWellAggregatedData);

  // Median of the group medians; the constant key collapses everything into
  // a single group.
  const overallAggregated = dl
    .groupby(() => 1)
    .summarize([
      {
        name: "median",
        ops: ["median"],
        as: ["median"],
      },
    ])
    .execute(aggregatedByTargetField);

  const { median } = overallAggregated[0];

  // Median absolute deviation of the group medians around the overall median.
  const absDevs = aggregatedByTargetField.map((d: any) => ({
    absDev: Math.abs(d.median - median),
  }));
  const { mad } = dl
    .groupby(() => 1)
    .summarize([{ name: "absDev", ops: ["median"], as: ["mad"] }])
    .execute(absDevs)[0];

  return aggregatedByTargetField.map((agg: any) => {
    // TODO(colin): we might want to try to take the within-donor variance into
    // account? That is, if we have one donor whose mean is a moderate outlier
    // but its variance is very small, that should probably be more worrying
    // than one that is a moderate outlier but more variable?
    const { median: singleMedian } = agg;
    const result: OutlierResult = {
      ...agg,
      outlier: null,
    };
    if (
      singleMedian > median + WARNING_MADS * mad ||
      singleMedian < median - WARNING_MADS * mad
    ) {
      result.outlier = "warning";
    }
    // Checked second so that the stricter verdict overrides the warning.
    if (
      singleMedian > median + ERROR_MADS * mad ||
      singleMedian < median - ERROR_MADS * mad
    ) {
      result.outlier = "error";
    }
    return result;
  });
};

/**
 * Compute list of possible outliers grouped by a target metadata column.
 *
 * (At present we use this to flag outlier donors / animal ids.)
 *
 * Each entry's `metadata` is flattened next to its other fields, the entries
 * are grouped by every field except `value`, and the median `value` of each
 * group is stored as `aggregated`. Those aggregated rows are then handed to
 * the passed "outlierDetectionFunction", which defines how we flag outliers
 * and will eventually allow us to use different outlier models for different
 * experiment types / metrics.
 *
 * @returns the detection function's results, or an empty list when `data`
 *   is empty (in which case the detection function is not called).
 */
function computeOutlierStats({
  data,
  target,
  outlierDetectionFunction,
}: Props): OutlierResult[] {
  // No entries means no groups and therefore no outliers. This also protects
  // the `flatData[0]` lookup below, which would throw on an empty array.
  if (data.length === 0) {
    return [];
  }

  const flatData = data.map((d) => {
    const { metadata, ...rest } = d;
    return { ...metadata, ...rest };
  });
  // Group by every column other than the measurement itself. The column list
  // is taken from the first entry only.
  const groupby = Object.keys(flatData[0]).filter((k) => k !== "value");

  const aggregatedData = dl
    .groupby(groupby)
    .summarize([{ name: "value", ops: ["median"], as: ["aggregated"] }])
    .execute(flatData);

  return outlierDetectionFunction(aggregatedData, target);
}

/**
 * Renders a categorical scatter plot of `data`, with `target` on the X-axis
 * and `value` on the Y-axis, plus alerts listing any groups flagged as
 * outliers.
 *
 * Outlier detection only runs when the (`metricName`, `target`) pair appears
 * in `outlierChecks`. The overall status is reported through
 * `onCheckComplete` after mounting and whenever it is recomputed after a
 * props change.
 *
 * As a performance optimization the component only re-renders when one of
 * the cheap-to-compare props, the set of metadata columns, or the outlier
 * findings change; a change to the data values alone does not trigger a
 * re-render.
 */
export default class CategoricalScatterPlotRenderer extends Component<
  Props,
  State
> {
  static defaultProps = {
    markQuartiles: false,
    outlierDetectionFunction: madOutliers,
  };

  /**
   * Derive the alert state for `props`: the `target` values of every group
   * flagged as "error" / "warning", or empty lists when no outlier check is
   * configured for these props.
   */
  private static outlierStateFor(props: Props): State {
    if (!shouldDoOutlierCheck(props)) {
      return { errors: [], warnings: [] };
    }
    const outlierStats = computeOutlierStats(props);
    // The index signature on OutlierResult is wider than the values we
    // actually read here, hence the narrowing cast.
    const targetsFlagged = (outcome: "error" | "warning") =>
      outlierStats
        .filter((s) => s.outlier === outcome)
        .map((s) => s[props.target]) as string[];
    return {
      errors: targetsFlagged("error"),
      warnings: targetsFlagged("warning"),
    };
  }

  constructor(props: Props) {
    super(props);
    this.state = CategoricalScatterPlotRenderer.outlierStateFor(props);
  }

  /** Collapse the current findings into a single status; errors win. */
  overallCheckStatus(): CheckOutcome {
    if (this.state.errors.length > 0) {
      return "error";
    }
    if (this.state.warnings.length > 0) {
      return "warning";
    }
    return null;
  }

  componentDidMount() {
    const { onCheckComplete } = this.props;
    onCheckComplete(this.overallCheckStatus());
  }

  componentDidUpdate(prevProps: Props) {
    // An update caused by our own setState below carries the same props
    // object; there is nothing new to compute, and skipping it avoids
    // notifying the parent twice. (Even without this guard the cycle would
    // stop, because shouldComponentUpdate compares the findings by content.)
    if (prevProps === this.props) {
      return;
    }

    const { onCheckComplete } = this.props;
    const hadFindings =
      this.state.errors.length > 0 || this.state.warnings.length > 0;
    // No check configured and nothing stale to clear: leave everything as is.
    if (!shouldDoOutlierCheck(this.props) && !hadFindings) {
      return;
    }

    // Either recompute the findings for the new props, or clear findings left
    // over from props for which a check was configured.
    this.setState(
      CategoricalScatterPlotRenderer.outlierStateFor(this.props),
      () => onCheckComplete(this.overallCheckStatus()),
    );
  }

  shouldComponentUpdate(nextProps: Props, nextState: State) {
    return (
      // As a performance optimization, we only update when the non-data
      // properties change. In practice, the data shouldn't change under our
      // feet without at least one of these other properties changing; and
      // (1) rendering on every component update is expensive, as is (2) doing
      // a deep comparison over a (possibly large) data object.
      this.props.target !== nextProps.target ||
      this.props.metricName !== nextProps.metricName ||
      this.props.size !== nextProps.size ||
      this.props.markQuartiles !== nextProps.markQuartiles ||
      this.props.minHeight !== nextProps.minHeight ||
      this.props.minWidth !== nextProps.minWidth ||
      !arrayEquals(this.props.domain, nextProps.domain) ||
      // The findings are set from componentDidUpdate; without comparing them
      // the alerts would never reflect the recomputed state.
      !arrayEquals(this.state.errors, nextState.errors) ||
      !arrayEquals(this.state.warnings, nextState.warnings) ||
      !arrayEquals(
        toMetadataColumns(this.props.data),
        toMetadataColumns(nextProps.data),
      )
    );
  }

  render() {
    const { data, target, domain, size, markQuartiles, minHeight, minWidth } =
      this.props;

    const { errors, warnings } = this.state;

    // Flatten the metadata.
    const metadataColumns = toMetadataColumns(data);
    const flatData = data.map((d) => {
      const { metadata, ...rest } = d;
      return { ...metadata, ...rest };
    });

    const spec: VlSpec = {
      $schema: "https://vega.github.io/schema/vega-lite/v4.json",
      title: `Grouped by ${target}`,
      data: { values: flatData },
      layer: [
        // Render the aggregated scatterplot.
        {
          mark: "point",
          encoding: {
            x: {
              field: target,
              type: "ordinal",
              scale: {
                domain: domain,
              },
            },
            y: {
              field: "value",
              title: null,
              type: "quantitative",
              scale: { zero: false },
            },
            color: {
              field: target,
              type: "ordinal",
              scale: { scheme: "category20" },
              legend: null,
            },
            tooltip: [
              {
                field: "value",
                type: "quantitative",
              },
              ...metadataColumns.map((column) => ({
                field: column,
                type: "ordinal" as StandardType,
              })),
            ],
          },
        },
        // Optional reference lines: mean (blue) and first/third quartile (red).
        ...(markQuartiles
          ? ([
              {
                mark: "rule",
                encoding: {
                  y: {
                    aggregate: "mean",
                    field: "value",
                    type: "quantitative",
                  },
                  color: { value: "blue" },
                  size: { value: 1 },
                },
              },
              {
                mark: "rule",
                encoding: {
                  y: {
                    // TODO(charlie): Use 10th percentile
                    aggregate: "q1",
                    field: "value",
                    type: "quantitative",
                  },
                  // TODO(charlie): Dash.
                  color: { value: "red" },
                  size: { value: 1 },
                },
              },
              {
                mark: "rule",
                encoding: {
                  y: {
                    // TODO(charlie): Use 90th percentile
                    aggregate: "q3",
                    field: "value",
                    type: "quantitative",
                  },
                  // TODO(charlie): Dash.
                  color: { value: "red" },
                  size: { value: 1 },
                },
              },
            ] as UnitSpec[])
          : []),
      ],

      autosize: {
        resize: true,
      },

      // Never smaller than 240 high / one `size` slot per domain entry wide;
      // the optional minimums can only enlarge the plot.
      height: Math.max(240, minHeight || -1),
      width: Math.max(domain.length * size, minWidth || -1),
    };

    return (
      <div>
        <div className={css(styles.outlierWrapper)}>
          {errors.length !== 0 ? (
            <Alert color={"danger"}>
              Likely outliers for {target}: {errors.join(", ")}
            </Alert>
          ) : null}
          {warnings.length !== 0 ? (
            <Alert color={"warning"}>
              Possible outliers for {target}: {warnings.join(", ")}
            </Alert>
          ) : null}
        </div>
        <VegaLite spec={spec} renderer={"canvas"} />
      </div>
    );
  }
}

// Left/right padding applied to the wrapper around the outlier alerts.
const OUTLIER_WRAPPER_SIDE_PADDING = 15;

// Aphrodite styles used by the renderer above.
const styles = StyleSheet.create({
  outlierWrapper: {
    paddingLeft: OUTLIER_WRAPPER_SIDE_PADDING,
    paddingRight: OUTLIER_WRAPPER_SIDE_PADDING,
  },
});
