/**
 * Component to render plate-wise correlations for one or more features.
 *
 * Includes a heatmap, along with row- and column-wise correlations.
 */
import { useCallback, useMemo, useState } from "react";
import "react-select-plus/dist/react-select-plus.css";
import { DatasetId } from "src/types";
import { typedValues } from "@spring/core/utils";
import { ColorScheme } from "../Control/ColorSchemeSelector";
import CheckOutcomesStatusBar from "./CheckOutcomesStatusBar";
import FeatureCorrelationView from "./FeatureCorrelationView";
import LabeledFeatureView from "./LabeledFeatureView";
import type { CheckOutcome, FieldFeature, RegressionModel } from "./types";
import { featureImportanceComparator } from "./util";

/**
 * Collapses the two-level check-outcome map (feature-set column -> check key
 * -> outcome) into one flat list holding every inner value, grouped by column
 * in the map's key order.
 *
 * NOTE(review): inner values are typed `CheckOutcome | null` and are passed
 * through unfiltered, so the returned array may contain `null` despite the
 * `CheckOutcome[]` annotation — confirm the consumer expects that.
 */
function flattenChecks(checkOutcomes: {
  [key: string]: {
    [key: string]: CheckOutcome | null;
  };
}): CheckOutcome[] {
  const outcomesPerColumn = typedValues(checkOutcomes);
  return outcomesPerColumn.flatMap((outcomesByKey) =>
    typedValues(outcomesByKey),
  );
}

/**
 * Renders one `FeatureCorrelationView` per feature-set column, each wrapped in
 * a `LabeledFeatureView`, below a `CheckOutcomesStatusBar` that summarises the
 * QC check outcomes reported back by those views.
 *
 * Columns are displayed in the order defined by `featureImportanceComparator`.
 * The per-column views are rendered only once `colorScheme` is non-null; the
 * status bar is always rendered.
 *
 * @param colorScheme - Color scheme passed to every correlation view; while
 *   `null`, no correlation views are rendered.
 * @param dataset - Forwarded unchanged to each correlation view.
 * @param featureSet - Forwarded unchanged to each correlation view.
 * @param featureSetColumns - Columns to render, one view each. This array is
 *   not mutated; a sorted copy is used for display.
 * @param features - Forwarded unchanged to each correlation view.
 * @param metadataColumns - Forwarded unchanged to each correlation view.
 * @param plate - Forwarded unchanged to each correlation view.
 * @param stratifyColumn - Forwarded to each correlation view as `stratifyOn`.
 * @param regressionModel - Forwarded unchanged to each correlation view.
 */
export default function FeatureCorrelationsView({
  colorScheme,
  dataset,
  featureSet,
  featureSetColumns,
  features,
  metadataColumns,
  plate,
  stratifyColumn,
  regressionModel,
}: {
  colorScheme: ColorScheme | null;
  dataset: DatasetId;
  featureSet: string;
  featureSetColumns: string[];
  features: FieldFeature[];
  metadataColumns: string[];
  plate: string;
  stratifyColumn: string | null;
  regressionModel: RegressionModel;
}) {
  // Outcomes of constituent QC checks, indexed by FeatureSetColumn, and then
  // by an opaque, unique key (e.g., "rowCorrelations").
  const [checkOutcomes, setCheckOutcomes] = useState<{
    [key: string]: {
      [key: string]: CheckOutcome | null;
    };
  }>({});

  // Removes every recorded outcome for one column. It only uses the functional
  // form of the state setter, so it captures nothing from render scope and a
  // single stable instance can be shared by all child views.
  const handleClearChecks = useCallback((featureSetColumn: string) => {
    setCheckOutcomes((previous) =>
      Object.fromEntries(
        Object.entries(previous).filter(([key]) => key !== featureSetColumn),
      ),
    );
  }, []);

  // One completion callback per column, memoised so each child keeps a
  // referentially stable `onCheckComplete` until `featureSetColumns` changes.
  const handleCheckCompleteForFeatureSetColumn = useMemo(
    () =>
      Object.fromEntries(
        featureSetColumns.map((featureSetColumn) => [
          featureSetColumn,
          (context: string, outcome: CheckOutcome) => {
            setCheckOutcomes((previous) => ({
              ...previous,
              [featureSetColumn]: {
                ...previous[featureSetColumn],
                [context]: outcome,
              },
            }));
          },
        ]),
      ),
    [featureSetColumns],
  );

  // Sort a copy: `Array.prototype.sort` sorts in place, and sorting the
  // `featureSetColumns` prop directly would mutate the caller's array on
  // every render.
  const sortedFeatureSetColumns = useMemo(
    () => [...featureSetColumns].sort(featureImportanceComparator),
    [featureSetColumns],
  );

  return (
    <div className={"tw-flex tw-flex-col tw-gap-lg tw-min-w-[1440px]"}>
      <CheckOutcomesStatusBar checkOutcomes={flattenChecks(checkOutcomes)} />

      <div className={"tw-z-0 tw-w-full tw-flex tw-flex-col tw-gap-lg"}>
        {colorScheme &&
          sortedFeatureSetColumns.map((featureSetColumn) => (
            <LabeledFeatureView
              key={featureSetColumn}
              featureSetColumn={featureSetColumn}
            >
              <FeatureCorrelationView
                dataset={dataset}
                plate={plate}
                featureSet={featureSet}
                featureSetColumn={featureSetColumn}
                metadataColumns={metadataColumns}
                features={features}
                colorScheme={colorScheme}
                stratifyOn={stratifyColumn}
                regressionModel={regressionModel}
                onCheckComplete={
                  handleCheckCompleteForFeatureSetColumn[featureSetColumn]
                }
                onClearChecks={handleClearChecks}
              />
            </LabeledFeatureView>
          ))}
      </div>
    </div>
  );
}
