import { AsyncDuckDB } from "@duckdb/duckdb-wasm";
import * as Tabs from "@radix-ui/react-tabs";
import cx from "classnames";
import { ReactNode, useCallback, useEffect, useMemo, useState } from "react";
import { ChevronDown, HelpCircle } from "react-feather";
import { useHistory } from "react-router-dom";
import { Button } from "src/Common/Button";
import { Select } from "src/Common/Select";
import { useActiveWorkspace } from "src/Workspace/hooks";
import VisualizationControls from "src/immunofluorescence/VisualizationControls";
import { Tooltip } from "@spring/ui/Tooltip";
import { FullScreenContainer } from "../Common/FullScreenContainer";
import Loader from "../Common/Loader";
import { FilterSqlClause } from "../Control/FilterSelector/types";
import { TabStrip } from "../Control/TabStrip";
import {
  SelectedEmbeddingInfo,
  SelectedMeasurementInfo,
} from "../FeatureSelector/CollapsedFeatureSelectorInfo";
import {
  groupFeaturesByPrefix,
  useFeatureSetsGrouped,
} from "../FeatureSelector/utils";
import ContributionsPreviewUrl from "../assets/screenshots/contributions-to-preds-preview.png";
import { DatasetId, MetadataColumnValue } from "../types";
import { inferInterestingColumnsDB } from "../util/dataset-util";
import { useAsyncValue } from "../util/hooks";
import {
  columnMatchesValueClause,
  queryDBAsRecords,
  sanitizedColumn,
  sql,
} from "../util/sql";
import ApplyModel from "./ApplyModel";
import ConfusionMatrix from "./ConfusionMatrix";
import { ImagesForPrediction } from "./ImagesForPrediction";
import { getSimplifiedSplitSummary } from "./SplitConfigurator";
import { useConfusionCounts, useModelTrainingResults } from "./hooks";
import { SavedModelListing } from "./types";
import {
  accuracyForCounts,
  findControlEntries,
  formatFilter,
  generateValueClauses,
  inferClassColumnAndValuesFromConfig,
  predictionsTableNameFor,
} from "./utils";

/**
 * Page for a single saved model.
 *
 * Shows a header (model name, last-trained date, "Clone model" and "Apply"
 * actions) above the model's training results. Where results are not
 * available, it shows a loader, a "still training" notice, or a load-error
 * message instead.
 *
 * @throws Error when `model.status` is "incomplete".
 */
export default function ReviewResultsPage({
  dataset,
  model,
  metadata,
  onBack,
}: {
  dataset: DatasetId;
  model: SavedModelListing;
  metadata: AsyncDuckDB;
  onBack: () => void;
}) {
  const workspace = useActiveWorkspace();
  const history = useHistory();
  const results = useModelTrainingResults(
    workspace.id,
    dataset,
    model.id,
    metadata,
  );
  // Incomplete models have nothing to review.
  if (model.status === "incomplete") {
    throw new Error("AssertionError: Model is not yet complete");
  }

  // TODO(benkomalo): use "X mins ago" format.
  const lastTrained = useMemo(
    () =>
      (model.status === "trained" || model.status === "applied") &&
      model.result.status === "success"
        ? model.result.lastTrained
        : undefined,
    [model],
  );

  // Opens the model editor pre-populated from this model.
  const openCloneEditor = () => {
    history.push(
      `/workspace/${workspace.id}/e/${dataset}/sl/edit?from=${model.id}`,
    );
  };

  // Picks what to render below the header based on the results' state.
  function renderResults(): ReactNode {
    if (!results) {
      return (
        <>
          <hr />
          {model.status === "pending" ? (
            <WaitingForTrainingPage />
          ) : (
            <FullScreenContainer center>
              <Loader size={"large"} />
            </FullScreenContainer>
          )}
        </>
      );
    }

    if (!results.successful) {
      return (
        <div
          className={"tw-flex tw-items-center tw-justify-center tw-h-[20rem]"}
        >
          Ooops. Something went wrong loading your data. Please refresh and try
          again.
        </div>
      );
    }

    return (
      <ReviewResultsPageInner
        metadata={metadata}
        dataset={dataset}
        model={model}
        results={results.value}
      />
    );
  }

  return (
    <div className={"tw-flex tw-flex-col"}>
      <div className={"tw-p-8"}>
        <button
          className={
            "tw-mb-8 tw-uppercase tw-text-sm tw-text-gray-500 tw-flex tw-items-center"
          }
          onClick={() => onBack()}
        >
          <ChevronDown className={"tw-rotate-90"} /> Back to all models
        </button>
        <div className={"tw-flex tw-items-start"}>
          <div className={"tw-flex-1"}>
            <div className={"tw-text-xl"}>{model.name}</div>
            <div className={"tw-text-gray-500 tw-mt-2"}>
              {lastTrained ? (
                <>Last trained: {lastTrained.toDateString()}</>
              ) : null}
            </div>
          </div>
          <Button
            variant={"secondary"}
            className={"tw-mr-4"}
            onClick={openCloneEditor}
          >
            Clone model
          </Button>
          <ApplyModel dataset={dataset} model={model}>
            Apply
          </ApplyModel>
        </div>
      </div>

      {renderResults()}
    </div>
  );
}

type Tab = "summary" | "performance" | "contributions";

/**
 * Tabbed body of the review page, with three tabs:
 * - "summary": the model's configuration.
 * - "performance": confusion matrices, image previews and visualization
 *   controls.
 * - "contributions": a coming-soon placeholder.
 *
 * The "performance" tab is selected initially.
 */
function ReviewResultsPageInner({
  metadata,
  dataset,
  model,
  results,
}: {
  metadata: AsyncDuckDB;
  dataset: DatasetId;
  model: SavedModelListing;
  results: AsyncDuckDB;
}) {
  const [tab, setTab] = useState<Tab>("performance");
  const [plateForControls, setPlateForControls] = useState<string | null>(null);
  const [showControls, setShowControls] = useState(false);

  // The functional update keeps this callback stable (no dependency on the
  // current value), so it is never stale and consumers don't re-render
  // needlessly.
  const handleToggleVisualizationControls = useCallback(() => {
    setShowControls((prev) => !prev);
  }, []);

  return (
    <Tabs.Root
      value={tab}
      onValueChange={(next) => setTab(next as Tab)}
      orientation="vertical"
    >
      <div className={"tw-inline-block tw-px-4"}>
        <TabStrip
          value={tab}
          options={[
            { value: "summary", label: "Summary" },
            { value: "performance", label: "Performance" },
            { value: "contributions", label: "Contributions to Predictions" },
          ]}
          customRenderFunc={(value, label, isSelected) => (
            <div
              className={cx(
                "tw-whitespace-nowrap tw-py-2 tw-px-4",
                isSelected && "tw-text-purple",
              )}
            >
              {label}
            </div>
          )}
          ariaLabel={"View mode"}
        />
      </div>

      <div
        className={
          "tw-border-t tw-min-h-[100vh] tw-bg-slate-100 tw-flex tw-flex-row"
        }
      >
        <div className="tw-flex-1 tw-p-8 tw-overflow-auto">
          <Tabs.Content
            value="summary"
            className={"tw-flex tw-flex-col tw-p-8 tw-bg-white tw-border"}
          >
            <ModelSummary metadata={metadata} dataset={dataset} model={model} />
          </Tabs.Content>

          <Tabs.Content value="performance" className={"tw-p-8 tw-bg-white"}>
            <div className={"tw-flex"}>
              <PerformanceResultsPage
                dataset={dataset}
                model={model}
                results={results}
                onSetPlateForVisualizationControls={setPlateForControls}
                onToggleVisualizationControls={
                  handleToggleVisualizationControls
                }
                showControls={showControls}
                plateForControls={plateForControls}
              />
            </div>
          </Tabs.Content>
          <Tabs.Content
            value="contributions"
            className={
              "tw-bg-white tw-flex tw-flex-col tw-items-center tw-p-lg"
            }
          >
            <div className={"tw-text-lg"}>
              Coming soon: tools to help better understand your model's
              predictions.
            </div>
            <div
              className={
                "tw-w-[640px] tw-rounded-xl tw-bg-gray-200 tw-p-lg tw-my-xl"
              }
            >
              <img
                src={ContributionsPreviewUrl}
                alt={"Preview of upcoming contributions-to-predictions tools"}
                className={"tw-opacity-80 tw-w-full"}
              />
            </div>
          </Tabs.Content>
        </div>
      </div>
    </Tabs.Root>
  );
}

/**
 * One label/value row of the model summary.
 *
 * Declared at module level rather than created inside ModelSummary's render,
 * so the component type is stable and no memoization hook is needed. (The
 * original memoizing hook ran after a conditional throw.)
 */
function SummaryRow({
  label,
  children,
}: {
  label: string;
  children: ReactNode;
}) {
  return (
    <div className={"tw-py-4 tw-flex"}>
      <div className={"tw-w-[320px]"}>{label}</div>
      <div className={"tw-flex-1"}>{children}</div>
    </div>
  );
}

/**
 * "Summary" tab: describes a saved model's configuration.
 *
 * Covers the training filter (and how many wells it selects), the classes,
 * the input features, the normalization, and the evaluation split.
 *
 * @throws Error when `model.status` is "incomplete".
 */
function ModelSummary({
  metadata,
  dataset,
  model,
}: {
  metadata: AsyncDuckDB;
  dataset: DatasetId;
  model: SavedModelListing;
}) {
  // Column holding the class labels, and the values used for each class.
  const [classColumn, classValues] = useAsyncValue(async () => {
    const allColumns = await inferInterestingColumnsDB(metadata);
    const result = await inferClassColumnAndValuesFromConfig(
      model.config.data?.classes ?? [],
      allColumns,
      metadata,
      model.config.data?.filter ?? "TRUE",
      false,
    );
    return result ?? null;
  }, [model, metadata]) ?? [undefined, undefined];

  // Wells selected by the user filter combined with the class filter.
  // Every input read here is listed as a dependency, so a new `metadata` or
  // filter triggers a recomputation instead of returning stale counts.
  const selectedWells = useAsyncValue(async () => {
    if (!classColumn) {
      return null;
    }
    return await findControlEntries(
      metadata,
      model.config.data?.filter ?? "TRUE",
      classColumn,
      classValues,
    );
  }, [metadata, model.config.data?.filter, classColumn, classValues]);

  const allFeatures = useFeatureSetsGrouped(dataset);
  const allEmbeddingPrefixes = useMemo(() => {
    if (!allFeatures || !allFeatures.successful) {
      return [];
    }
    return Object.keys(groupFeaturesByPrefix(allFeatures.value["embedding"]));
  }, [allFeatures]);

  if (model.status === "incomplete") {
    throw new Error("Summary page only relevant for complete and saved models");
  }

  return (
    <div className={"tw-flex"}>
      <div className={"tw-flex tw-flex-col"}>
        <div className={"tw-uppercase tw-text-purple"}>Model Information</div>

        <div className={"tw-flex tw-flex-col"}>
          <SummaryRow label={"Filter"}>
            {selectedWells && (
              <>
                Training on {selectedWells.length} wells
                <Tooltip
                  side="right"
                  showArrow
                  contents={
                    <>
                      <div className={"tw-font-semibold"}>User filters</div>
                      {formatFilter(model.config.data.filter)?.map((f) => {
                        return <div key={f}>{f}</div>;
                      })}
                      <div className={"tw-font-semibold tw-mt-4"}>
                        Class filters
                      </div>
                      {classColumn &&
                        formatFilter(
                          generateValueClauses(classColumn, classValues),
                        )?.map((f) => {
                          return <div key={f}>{f}</div>;
                        })}
                    </>
                  }
                >
                  <HelpCircle
                    id={"filter"}
                    className={
                      "tw-text-gray-400 tw-px-1 tw-no-underline hover:tw-text-gray-800 hover:tw-bg-gray-200"
                    }
                  ></HelpCircle>
                </Tooltip>
              </>
            )}
          </SummaryRow>
          <SummaryRow label={"Classes"}>
            <>
              {model.config.data.classes.map(({ name }) => {
                // TODO(benkomalo): show well counts of each class.
                return <div key={name}>{name}</div>;
              })}
            </>
          </SummaryRow>

          <hr />

          <SummaryRow label={"Features"}>
            {model.config.featureInputs.map((selection) => {
              switch (selection.type) {
                case "embedding":
                  return (
                    <div
                      className={"tw-flex tw-flex-col"}
                      key={selection.names.join("-")}
                    >
                      {selection.names.map((name) => (
                        <div className={"tw-mb-sm"} key={name}>
                          <SelectedEmbeddingInfo
                            key={name}
                            selection={{
                              type: "embedding",
                              names: [name],
                            }}
                            allEmbeddingPrefixes={allEmbeddingPrefixes}
                          />
                        </div>
                      ))}
                    </div>
                  );
                case "numerical":
                case "prediction":
                  return (
                    <SelectedMeasurementInfo
                      key={selection.name}
                      selection={selection}
                    />
                  );
              }
            })}
          </SummaryRow>

          <hr />

          <SummaryRow label={"Normalization"}>
            {/* TODO(benkomalo): flesh this out */}
            {model.config.normalization.type === "none"
              ? "No normalization"
              : "Z-score normalization"}
          </SummaryRow>

          <hr />

          <SummaryRow label={"Evaluation configuration"}>
            {getSimplifiedSplitSummary(model.config.split)}
          </SummaryRow>
        </div>
      </div>
    </div>
  );
}

/** Renders a segment value for display; SQL NULL is shown as "<null>". */
function formatSegmentValue(value: MetadataColumnValue): string {
  return value === null ? "<null>" : String(value);
}

/**
 * "Performance" tab for a trained model.
 *
 * Shows:
 * - the overall confusion matrix of its predictions;
 * - an image preview for the selected matrix cell;
 * - the (optionally hidden) visualization controls;
 * - per-value confusion matrices for a user-chosen "segment by" column.
 *
 * @throws Error when `model.status` is "incomplete".
 */
function PerformanceResultsPage({
  dataset,
  model,
  results,
  onSetPlateForVisualizationControls,
  onToggleVisualizationControls,
  showControls,
  plateForControls,
}: {
  dataset: DatasetId;
  model: SavedModelListing;
  results: AsyncDuckDB;
  onSetPlateForVisualizationControls: (plate: string) => void;
  onToggleVisualizationControls: () => void;
  showControls: boolean;
  plateForControls: string | null;
}) {
  if (model.status === "incomplete") {
    throw new Error("AssertionError: Model is not yet complete");
  }

  const predictionsTableName = predictionsTableNameFor(model.id);
  const config = model.config;
  const classes = config["data"]["classes"].map(({ name }) => name);

  // Columns offered by the "Segment results by" picker: every interesting
  // column except the class column the model predicts.
  const segmentColumns = useAsyncValue(async () => {
    const columns = await inferInterestingColumnsDB(results);
    const classColumnAndValues = await inferClassColumnAndValuesFromConfig(
      config.data.classes,
      columns,
      results,
      config.data.filter,
      true,
    );
    // The helper resolves to a [column, values] pair (possibly nullish), not
    // a bare column name, so compare the column names against its first
    // element.
    const targetColumn = classColumnAndValues?.[0];
    return columns.filter((c) => c !== targetColumn);
  }, [config, results]);

  const [segmentByColumn, setSegmentByColumn] = useState<string | undefined>(
    undefined,
  );
  // Split of the segment column's values into those with predictions (shown
  // as matrices) and those without (listed as ignored). The split is tagged
  // with the column it was computed for, so results for a previously selected
  // column are never rendered under a newly selected one.
  const [segmentation, setSegmentation] = useState<
    | {
        column: string;
        values: MetadataColumnValue[];
        ignoredValues: MetadataColumnValue[];
      }
    | undefined
  >(undefined);

  // Cell selection for previewing well images.
  const [selectedCell, setSelectedCell] = useState<{
    predicted: string;
    actual: string;
  } | null>(null);

  useEffect(() => {
    if (!segmentByColumn) {
      // Clearing the picker also clears any previously shown segments.
      setSegmentation(undefined);
      return undefined;
    }
    // Drop responses that arrive after the column or DB has changed.
    let cancelled = false;
    queryDBAsRecords<{ value: MetadataColumnValue; has_preds: boolean }>(
      results,
      sql`
            SELECT
                sample_metadata.${sanitizedColumn(segmentByColumn)} AS value,
                MAX(CASE WHEN _class IS NULL THEN 0 ELSE 1 END) = 1 AS has_preds
            FROM sample_metadata
            LEFT JOIN ${predictionsTableName}
                ON sample_metadata.plate = ${predictionsTableName}.plate AND
                    sample_metadata.well = ${predictionsTableName}.well
            GROUP BY sample_metadata.${sanitizedColumn(segmentByColumn)}
            ORDER BY sample_metadata.${sanitizedColumn(segmentByColumn)}`,
    )
      .then((records) => {
        if (cancelled) {
          return;
        }
        setSegmentation({
          column: segmentByColumn,
          values: records
            .filter(({ has_preds }) => has_preds)
            .map((r) => r.value),
          ignoredValues: records
            .filter(({ has_preds }) => !has_preds)
            .map((r) => r.value),
        });
      })
      .catch((error: unknown) => {
        if (!cancelled) {
          console.error(
            `Failed to segment results by "${segmentByColumn}"`,
            error,
          );
        }
      });
    return () => {
      cancelled = true;
    };
  }, [segmentByColumn, results, predictionsTableName]);

  // Point the visualization controls at the first plate in `ORDER BY plate`
  // order.
  useEffect(() => {
    let cancelled = false;
    queryDBAsRecords<{ plate: string }>(
      results,
      sql`SELECT plate FROM sample_metadata ORDER BY plate LIMIT 1`,
    )
      .then((records) => {
        const first = records[0];
        if (!cancelled && first) {
          onSetPlateForVisualizationControls(first.plate);
        }
      })
      .catch((error: unknown) => {
        if (!cancelled) {
          console.error("Failed to load a plate for visualization", error);
        }
      });
    return () => {
      cancelled = true;
    };
  }, [onSetPlateForVisualizationControls, results]);

  // Only render segment results computed for the currently selected column.
  const activeSegmentation =
    segmentation && segmentation.column === segmentByColumn
      ? segmentation
      : undefined;

  return (
    <div className={"tw-flex tw-flex-col tw-w-full"}>
      <div className={"tw-flex tw-gap-xl"}>
        <LabelledConfusionMatrix
          db={results}
          classes={classes}
          tableName={predictionsTableName}
          selectedCell={selectedCell}
          onSelectCell={setSelectedCell}
        />
        {selectedCell && (
          <ImagesForPrediction
            className="tw-border tw-flex-1 tw-overflow-hidden tw-bg-white"
            dataset={dataset}
            predicted={selectedCell.predicted}
            actual={selectedCell.actual}
            classes={classes}
            db={results}
            tableName={predictionsTableName}
            onToggleVisualizationControls={onToggleVisualizationControls}
          />
        )}
        <div
          className={cx("tw-p-lg tw-w-[300px]", !showControls && "tw-hidden")}
        >
          <VisualizationControls dataset={dataset} plate={plateForControls} />
        </div>
      </div>

      {segmentColumns && (
        <label className={"tw-flex tw-items-center tw-mt-8"}>
          <span className={"tw-mr-4"}>Segment results by:</span>
          <Select
            name="Segment results by"
            className={"tw-w-[400px]"}
            items={segmentColumns}
            value={segmentByColumn}
            onChange={(column: string | undefined) =>
              setSegmentByColumn(column)
            }
          />
        </label>
      )}

      {activeSegmentation && activeSegmentation.ignoredValues.length > 0 && (
        <div className={"tw-text-left tw-mt-4 tw-text-gray-500"}>
          The following values are not displayed as they do not exist in the
          test set:{" "}
          {activeSegmentation.ignoredValues.map(formatSegmentValue).join(",")}
        </div>
      )}

      {activeSegmentation && (
        <div className={"tw-grid tw-grid-cols-2 tw-gap-4 tw-mt-4"}>
          {activeSegmentation.values.map((value) => (
            <div className={"tw-border tw-w-full"} key={String(value)}>
              <div
                className={"tw-text-center tw-m-2 tw-text-lg tw-text-gray-500"}
              >
                {activeSegmentation.column}:{" "}
                <span className={"tw-font-semibold"}>
                  {" "}
                  {formatSegmentValue(value)}
                </span>
              </div>
              <div className={"tw-p-md"}>
                <LabelledConfusionMatrix
                  db={results}
                  tableName={predictionsTableName}
                  classes={classes}
                  filter={columnMatchesValueClause(
                    activeSegmentation.column,
                    value,
                  )}
                />
              </div>
            </div>
          ))}
        </div>
      )}
    </div>
  );
}

/**
 * Confusion matrix over `classes` for the predictions in `tableName`,
 * optionally restricted by `filter`, with the overall accuracy shown above it.
 *
 * Renders nothing until the counts have loaded successfully.
 */
function LabelledConfusionMatrix({
  db,
  tableName,
  classes,
  filter,
  selectedCell,
  onSelectCell,
  className,
}: {
  db: AsyncDuckDB;
  tableName: string;
  classes: string[];
  filter?: FilterSqlClause;
  useAsSegment?: boolean;
  selectedCell?: { predicted: string; actual: string } | null;
  onSelectCell?: (selection: { predicted: string; actual: string }) => void;
  className?: string;
}) {
  const counts = useConfusionCounts(db, tableName, classes, filter);

  // Not loaded yet, or failed: a failure here is most likely an
  // unrecoverable SQL error, so there is nothing useful to show.
  if (!counts?.successful) {
    return null;
  }

  const accuracy = counts.map(accuracyForCounts).orElse(() => 0) ?? 0;
  const accuracyPercent = (accuracy * 100).toPrecision(3);
  // Larger matrices get smaller cells.
  const cellSize = classes.length >= 5 ? 40 : 80;

  return (
    <div className={cx("tw-bg-white", className)}>
      <div className={"tw-text-lg tw-text-center tw-mb-2"}>
        Accuracy:{" "}
        <span className={"tw-text-purple"}>{accuracyPercent}%</span>
      </div>
      <ConfusionMatrix
        classes={classes}
        counts={counts.value}
        includeRowSummaries
        cellSize={cellSize}
        onSelectCell={onSelectCell}
        selectedCell={selectedCell}
      />
    </div>
  );
}

/**
 * Placeholder shown while a model is still training.
 *
 * Renders a large spinner plus a note that training may take a few minutes
 * and that the user can leave the page.
 */
function WaitingForTrainingPage() {
  const mutedText = "tw-text-gray-500";
  return (
    <div
      className={"tw-flex tw-flex-col tw-items-center tw-justify-center tw-py-32"}
    >
      <div>
        <Loader size={"large"} />
      </div>
      <div className={mutedText}>
        Training your model. This may take a few minutes.
      </div>
      <div className={`${mutedText} tw-text-xs`}>
        Our machines are hard at work learning your phenotypes. You can
        navigate away from this page.
      </div>
    </div>
  );
}
