// This file contains the header container for the Phenotypic Learner.
// It shows the current labeling step with suggested next steps, the save
// status, and the Apply Model / Train controls.
import cx from "classnames";
import sampleSize from "lodash.samplesize";
import { useCallback, useMemo } from "react";
import { ArrowRight, Check } from "react-feather";
import { Button } from "src/Common/Button";
import { DatasetId } from "src/types";
import { Title } from "@spring/ui/typography";
import { usePalettes } from "../hooks/immunofluorescence";
import { ApplyModel } from "./ApplyModel";
import { Classification, useLabeledSetContext } from "./Context";
import { LabeledSetSaveStatus } from "./LabeledSetSaveStatus";
import { TrainButton } from "./TrainButton";
import {
  APPLY_MODEL_REQUIRED_SAMPLES,
  CLASS_CARD_TOP_N,
  CLASS_CARD_TO_FIND_SIMILAR,
  REQUIRED_SAMPLES,
} from "./constants";
import { useLabeledSetStatus } from "./hooks";
import { LabeledSetStatus, NextStep } from "./types";
import { findSimilar, removeSamples, sortQueue } from "./util";

/** Props for {@link Header}. */
interface Props {
  /** Dataset handed to the palette lookup and to the Apply Model / Train controls. */
  dataset: DatasetId;
  /** Extra class names merged onto the header's root element. */
  className?: string;
}

/**
 * Returns the classification with the fewest labeled examples.
 *
 * Ties resolve to the classification that appears first in `classifications`.
 *
 * @returns `null` when `classifications` is empty.
 */
function findClassWithFewestSamples(
  classifications: readonly Classification[],
): Classification | null {
  let fewest: Classification | null = null;
  for (const classification of classifications) {
    if (!fewest || classification.examples.length < fewest.examples.length) {
      fewest = classification;
    }
  }
  return fewest;
}

/**
 * Header bar of the labeled-set workspace.
 *
 * Renders the current step (derived from the labeled-set status) with its
 * caption and suggested next-step buttons, followed by the save status, the
 * Apply Model control and, when some plate's palette contains every stain of
 * the labeled set, the Train button for that plate.
 */
export function Header({ className, dataset }: Props) {
  const { state: labeledSetState, updateState: updateLabeledSetState } =
    useLabeledSetContext();
  const { classifications, displayed, queue, skipped, neighbors, predictions } =
    labeledSetState;

  const palettes = usePalettes({ dataset });
  // Key of the first palette entry whose stains include every stain required
  // by the labeled set; undefined until palettes load or when none match.
  const plateWithStains = useMemo(() => {
    if (palettes?.successful) {
      return Object.entries(palettes.value).find(([, stains]) =>
        labeledSetState.stains.every((requiredStain) =>
          stains.includes(requiredStain),
        ),
      )?.[0];
    }
    return undefined;
  }, [labeledSetState.stains, palettes]);

  // Shows a random pick (CLASS_CARD_TO_FIND_SIMILAR) from the first
  // CLASS_CARD_TOP_N examples of the class with the fewest examples, plus the
  // cells `findSimilar` selects from everything not already displayed.
  const handleLabelMoreSamples = useCallback(() => {
    const classWithFewestSamples = findClassWithFewestSamples(classifications);
    if (!classWithFewestSamples) {
      return;
    }

    const sampled = sampleSize(
      classWithFewestSamples.examples.slice(0, CLASS_CARD_TOP_N),
      CLASS_CARD_TO_FIND_SIMILAR,
    );

    const candidates = [
      ...removeSamples(displayed, sampled),
      ...queue,
      ...skipped,
    ];
    const similarCells = findSimilar(
      sampled,
      candidates,
      neighbors,
      predictions,
      classWithFewestSamples.name,
    );

    updateLabeledSetState((prevState) => {
      const newDisplayed = [...sampled, ...similarCells];

      // Everything previously displayed or queued that is neither newly
      // displayed nor skipped goes back to the queue, so queue and skipped
      // stay disjoint.
      const newQueue = sortQueue(
        removeSamples(
          removeSamples([...displayed, ...queue], newDisplayed),
          skipped,
        ),
        neighbors,
      );

      const newSkipped = removeSamples(skipped, newDisplayed);
      return {
        ...prevState,
        displayed: newDisplayed,
        skipped: newSkipped,
        queue: newQueue,
      };
    });
  }, [
    classifications,
    displayed,
    neighbors,
    predictions,
    queue,
    skipped,
    updateLabeledSetState,
  ]);

  const status = useLabeledSetStatus({ classifications });

  const allSteps = useMemo(
    (): Record<
      "labelMoreSamples" | "showSamplesWithHighUncertainty",
      NextStep
    > => ({
      labelMoreSamples: {
        caption: "Label more samples",
        onClick: handleLabelMoreSamples,
      },
      showSamplesWithHighUncertainty: {
        caption: "Show me cells that have a high uncertainty",
        disabled: !predictions,
        onClick: predictions
          ? async () => {
              // Rank predicted samples by their highest class score,
              // ascending: the least confident prediction gets index 0.
              const samplePriority = new Map(
                Array.from(predictions.entries())
                  .map(([key, prediction]) => ({
                    key,
                    maxCertainty: Math.max(...prediction.predictions.values()),
                  }))
                  .sort((a, b) => a.maxCertainty - b.maxCertainty)
                  .map(({ key }, index) => [key, index]),
              );

              // Samples without a prediction rank after every predicted one.
              // A finite sentinel keeps the comparator from returning NaN
              // (which sort treats as 0, making the order inconsistent).
              const unranked = samplePriority.size;
              const sortedSamples = [...displayed, ...queue, ...skipped].sort(
                (a, b) =>
                  (samplePriority.get(a.id) ?? unranked) -
                  (samplePriority.get(b.id) ?? unranked),
              );

              // Every sample, including previously skipped ones, now lives in
              // the queue; clear `skipped` so no sample is listed twice.
              updateLabeledSetState((state) => ({
                ...state,
                displayed: [],
                queue: sortedSamples,
                skipped: [],
              }));
            }
          : undefined,
      },
    }),
    [
      displayed,
      handleLabelMoreSamples,
      predictions,
      queue,
      skipped,
      updateLabeledSetState,
    ],
  );

  const step = useMemo((): {
    title: string;
    caption: string;
    nextSteps?: NextStep[];
  } => {
    // Both post-training steps offer the same ways to improve the model.
    const improvementSteps = [
      allSteps.labelMoreSamples,
      allSteps.showSamplesWithHighUncertainty,
    ];

    switch (status) {
      case LabeledSetStatus.NeedsLabels: {
        return {
          title: "Step 1: Label images",
          caption: `In order to train the model, 
          you need to label at least ${REQUIRED_SAMPLES} images for each class.`,
        };
      }
      case LabeledSetStatus.InProgress: {
        return {
          title: "Step 2: Improve model",
          caption: `In order to apply the model,
        you need to label ${APPLY_MODEL_REQUIRED_SAMPLES} images per class.`,
          nextSteps: improvementSteps,
        };
      }
      case LabeledSetStatus.Ready: {
        return {
          title: "Step 3: Apply model",
          caption: `Congratulations!
          You're ready to apply your model.
          You can either Apply your model now or continue to improve it.`,
          nextSteps: improvementSteps,
        };
      }
      default: {
        // Compile-time guard: adding a status without a case fails here.
        const unhandled: never = status;
        throw new Error(`Unhandled labeled set status: ${String(unhandled)}`);
      }
    }
  }, [
    allSteps.labelMoreSamples,
    allSteps.showSamplesWithHighUncertainty,
    status,
  ]);

  return (
    <div
      className={cx(
        "tw-flex tw-border-b tw-flex-row tw-items-center tw-py-md tw-px-xl",
        "tw-bg-gray-50",
        "tw-min-h-[120px]",
        className,
      )}
    >
      <div className="tw-flex-auto tw-h-full tw-flex tw-flex-col tw-gap-sm">
        <Title>{step.title}</Title>
        <div className="tw-text-slate-500">{step.caption}</div>
        {step.nextSteps && step.nextSteps.length > 0 && (
          <div className="tw-flex tw-flex-col tw-gap-sm">
            <div className="tw-font-bold">Suggested next steps</div>
            <div className="tw-flex tw-flex-wrap tw-gap-sm">
              {step.nextSteps.map((nextStep) => (
                <Button
                  key={nextStep.caption}
                  icon={ArrowRight}
                  onClick={nextStep.onClick}
                  disabled={nextStep.disabled}
                >
                  {nextStep.caption}
                </Button>
              ))}
            </div>
          </div>
        )}
      </div>
      <LabeledSetSaveStatus className="tw-mr-2 tw-text-base" />
      <ApplyModel
        className="tw-mr-2 tw-min-w-[163px] tw-bg-white"
        dataset={dataset}
      >
        <Check size={18} className="tw-mr-2" />
        Apply Model
      </ApplyModel>
      {plateWithStains && (
        <TrainButton dataset={dataset} plate={plateWithStains} />
      )}
    </div>
  );
}
