import * as Sentry from "@sentry/react";
import { useCallback, useMemo, useState } from "react";
import { useActiveWorkspace } from "src/Workspace/hooks";
import { useAccessToken } from "src/hooks/auth0";
import { DatasetId } from "src/types";
import { Tooltip } from "@spring/ui/Tooltip";
import { DeprecatedButton } from "../Common/DeprecatedButton";
import { PulseGuiderRoot } from "../Insights/PulseGuider";
import ToastContainer from "../Toast/ToastContainer";
import { useToastContext } from "../Toast/context";
import { extractPredictionData, predictUnlabeled } from "../hooks/examples";
import { useFeatureSetsForPlate } from "../hooks/features";
import { useModelTrainingAndInferenceAreAllowed } from "../util/users";
import { PredictionsMap, useLabeledSetContext } from "./Context";
import { FlaskTwinkle } from "./FlaskTwinkle";
import { REQUIRED_SAMPLES } from "./constants";
import {
  TrainLabeledSetResponse,
  filteredFeatures,
  saveLabeledSet,
  trainLabeledSet,
} from "./util";

// Id of the ToastContainer rendered by TrainButton; training-error toasts are
// routed to that container by passing this id to `setToast`.
const MODEL_TRAINING_TOAST_CONTAINER_ID = "modelTrainingToast";

/**
 * "Train Model" button for the labeled set held in the labeled-set context.
 *
 * On click it saves the labeled set, trains a model on the feature sets for
 * the selected stains, runs inference on every displayed, queued and skipped
 * example, and writes the new model path, per-class accuracies (recalls),
 * confusion matrix, feature sets and predictions into the context via
 * `updateState`. It then saves the labeled set again so the persisted copy
 * references the new model.
 *
 * The button is disabled while training is in flight, when training is not
 * allowed (`useModelTrainingAndInferenceAreAllowed`), or when the labeled set
 * does not yet satisfy the requirements checked in `canTrain`.
 *
 * Any failure in the save/train/predict pipeline is logged, reported to Sentry
 * and shown in a toast inside this component's own ToastContainer; the
 * `training` flag is always cleared so the button cannot stay stuck.
 *
 * @param dataset - dataset the labeled set belongs to.
 * @param plate - plate whose feature sets are used for training/inference.
 */
export function TrainButton({
  dataset,
  plate,
}: {
  dataset: DatasetId;
  plate: string;
}) {
  const { state: labeledSetState, updateState } = useLabeledSetContext();
  const [training, setTraining] = useState(false);
  const accessToken = useAccessToken();
  const workspace = useActiveWorkspace();
  const { setToast, dismissToast } = useToastContext();
  // Error toasts are keyed by the labeled set id.
  const labeledSetId = labeledSetState.id;

  const features = useFeatureSetsForPlate({ dataset, plate });
  // When we perform inference these will be predicted-upon
  const toPredict = useMemo(() => {
    return [
      ...labeledSetState.displayed,
      ...labeledSetState.queue,
      ...labeledSetState.skipped,
    ];
  }, [
    labeledSetState.displayed,
    labeledSetState.queue,
    labeledSetState.skipped,
  ]);
  // Feature sets filtered to the selected stains; null until the plate's
  // features have been fetched successfully.
  const featureSets = useMemo(() => {
    if (features?.successful) {
      return filteredFeatures(features.value, labeledSetState.selectedStains);
    } else {
      return null;
    }
  }, [features, labeledSetState.selectedStains]);

  const isTrainingAllowed = useModelTrainingAndInferenceAreAllowed();
  // Training needs loaded feature sets, at least two classifications that each
  // have REQUIRED_SAMPLES examples, and at least one example to predict on.
  const canTrain = useMemo(
    () =>
      featureSets &&
      labeledSetState.classifications.length >= 2 &&
      toPredict.length > 0 &&
      labeledSetState.classifications.every(
        (classification) => classification.examples.length >= REQUIRED_SAMPLES,
      ),
    [featureSets, labeledSetState.classifications, toPredict],
  );

  // Dismisses this labeled set's error toast, then invokes `cb` (the bug
  // reporter trigger).
  const handleReportBug = useCallback(
    (cb: () => void) => {
      dismissToast(labeledSetId);
      cb();
    },
    [dismissToast, labeledSetId],
  );

  // Shows the "training failed" toast. The bug-report link is only offered
  // when the global `birdeatsbug` object exists and reports browser support.
  // The global is read defensively: an unguarded property access would throw
  // inside the error handler if the object is missing.
  const showTrainingErrorToast = useCallback(() => {
    const birdeatsbug:
      | { isBrowserSupported?: boolean; trigger: () => void }
      | undefined = (window as any).birdeatsbug;
    if (birdeatsbug?.isBrowserSupported) {
      setToast(
        labeledSetId,
        <div>
          There was an error training your model. Please{" "}
          <button
            className="tw-text-purple-600"
            onClick={() => handleReportBug(() => birdeatsbug.trigger())}
          >
            report the bug
          </button>{" "}
          and we’ll look into it as soon as possible!
        </div>,
        MODEL_TRAINING_TOAST_CONTAINER_ID,
      );
    } else {
      setToast(
        labeledSetId,
        <div>
          There was an error training your model. Please reach out to an
          engineer for support!
        </div>,
        MODEL_TRAINING_TOAST_CONTAINER_ID,
      );
    }
  }, [handleReportBug, labeledSetId, setToast]);

  const onTrainModel = useCallback(async () => {
    if (!featureSets) {
      return;
    }
    setTraining(true);

    // Labeled set with this run's results applied; set only when the whole
    // save/train/predict pipeline below succeeds.
    let trainedLabeledSet: typeof labeledSetState | null = null;
    try {
      await saveLabeledSet({
        accessToken,
        workspace,
        dataset,
        labeledSet: labeledSetState,
      });

      const response: TrainLabeledSetResponse = await trainLabeledSet({
        accessToken,
        workspace,
        dataset,
        labeledSetId: labeledSetState.id,
        featureSets,
      });

      // Per-class recall, keyed by class name (class_names and recalls are
      // index-aligned in the response).
      const accuracies = new Map(
        response.class_names.map((name, index) => [
          name,
          response.recalls[index],
        ]),
      );

      const predictionsResponse = await predictUnlabeled({
        dataset,
        workspace,
        accessToken,
        latestFeatureSets: featureSets,
        latestModelPath: response.path,
        toPredict: toPredict,
      });

      const predictionsMap: PredictionsMap | null = extractPredictionData(
        toPredict,
        predictionsResponse,
      );

      // Used for both the in-memory state and the persisted copy so the two
      // record the same model path, feature sets and metrics.
      const withTrainingResults = (
        state: typeof labeledSetState,
      ): typeof labeledSetState => ({
        ...state,
        latestModelPath: response.path,
        classifications: state.classifications.map((classification) => ({
          ...classification,
          accuracy: accuracies.get(classification.name),
        })),
        confusionMatrix: response.confusion_matrix,
        latestFeatureSets: featureSets,
        predictions: predictionsMap,
      });

      updateState((state) => withTrainingResults(state));
      trainedLabeledSet = withTrainingResults(labeledSetState);
    } catch (e) {
      console.error(e);
      Sentry.captureException(e);
      showTrainingErrorToast();
    } finally {
      // Leave the "Training" state no matter which step failed.
      setTraining(false);
    }

    if (trainedLabeledSet === null) {
      return;
    }

    // Need to save again because the model path has changed.
    try {
      await saveLabeledSet({
        accessToken,
        workspace,
        dataset,
        labeledSet: trainedLabeledSet,
      });
    } catch (e) {
      // Training succeeded and the UI is already updated; report the failed
      // save instead of leaving an unhandled rejection.
      console.error(e);
      Sentry.captureException(e);
    }
  }, [
    accessToken,
    dataset,
    featureSets,
    labeledSetState,
    showTrainingErrorToast,
    updateState,
    workspace,
    toPredict,
  ]);

  return (
    <>
      <PulseGuiderRoot
        guiderKey="phenosorter-train"
        position={{ corner: "top-left", offset: { x: -8, y: 12 } }}
        tooltipSide={"left"}
      >
        <DeprecatedButton
          variant="primary"
          disabled={!canTrain || training || !isTrainingAllowed}
          onClick={onTrainModel}
          className="tw-group tw-pr-[16px] tw-min-w-[163px]"
        >
          {training ? (
            <>
              Training
              <FlaskTwinkle className="tw-h-[24px] tw-ml-2" animate="always" />
            </>
          ) : isTrainingAllowed ? (
            <>Train Model</>
          ) : (
            <Tooltip
              contents={
                <div className={"tw-max-w-[440px]"}>
                  Training is disabled for this demo workspace. Please contact
                  support@springscience.com to get your own workspace and train
                  your own models.
                </div>
              }
              showArrow
              side={"left"}
            >
              Train Model
            </Tooltip>
          )}
        </DeprecatedButton>
      </PulseGuiderRoot>
      <div className="tw-absolute tw-inset-x-4 tw-top-32 tw-flex tw-justify-center">
        <ToastContainer
          id={MODEL_TRAINING_TOAST_CONTAINER_ID}
          position="custom"
          disableAutoDismiss
        />
      </div>
    </>
  );
}
