import { AsyncDuckDB } from "@duckdb/duckdb-wasm";
import * as Sentry from "@sentry/react";
import cx from "classnames";
import pluralize from "pluralize";
import {
  Dispatch,
  Fragment,
  ReactNode,
  SetStateAction,
  useCallback,
  useEffect,
  useMemo,
  useState,
} from "react";
import { HelpCircle } from "react-feather";
import {
  generatePath,
  useHistory,
  useLocation,
  useRouteMatch,
} from "react-router-dom";
import { Button } from "src/Common/Button";
import { Select } from "src/Common/Select";
import { useActiveWorkspace } from "src/Workspace/hooks";
import { useAccessToken } from "src/hooks/auth0";
import { SpinningChevron } from "@spring/ui/Select/SpinningChevron";
import { Tooltip } from "@spring/ui/Tooltip/Tooltip";
import {
  SelectedEmbeddingInfo,
  SelectedMeasurementInfo,
} from "../FeatureSelector/CollapsedFeatureSelectorInfo";
import {
  groupFeaturesByPrefix,
  useFeatureSetsGrouped,
} from "../FeatureSelector/utils";
import Close from "../icons/Close.svg";
import Strut from "../imaging/Strut";
import { VisualizationContextProvider } from "../imaging/context";
import { DatasetId, MetadataColumnValue } from "../types";
import { inferInterestingColumnsDB } from "../util/dataset-util";
import { useAsyncValue } from "../util/hooks";
import { columnMatchesValueClause, getTableColumns } from "../util/sql";
import { useModelTrainingAndInferenceAreAllowed } from "../util/users";
import ClassesConfigurator from "./ClassesConfigurator";
import FeaturesConfigurator from "./FeaturesConfigurator";
import FilterConfigurator from "./FilterConfigurator";
import ReviewResultsPage from "./ReviewResultsPage";
import SplitConfigurator from "./SplitConfigurator";
import { useSavedModel, useSupervisedLearningContext } from "./context";
import { ModelTrainingConfig, SavedModelListing } from "./types";
import {
  findControlEntries,
  formatFilter,
  generateValueClauses,
  getDefaultModelTrainingConfig,
  getUnsavedClone,
  inferClassColumnAndValuesFromConfig,
  isConfigurationValid,
  submitTrainingJob,
} from "./utils";

// Route params for the model editor route, read via `useRouteMatch` in both
// ModelEditor and FinalizeAndSubmit.
// NOTE(review): kept as a `type` alias rather than an `interface` — object-literal
// type aliases get an implicit index signature, which presumably is what lets
// `pathParams` be passed directly to `generatePath` in FinalizeAndSubmit. Confirm
// before converting.
type RouteParams = {
  // Typed as `string`, but FinalizeAndSubmit checks it for truthiness because the
  // bare "create" route has no model ID segment.
  modelId: string;
};

/**
 * Route component for the supervised-learning model editor.
 *
 * The editor is in one of three modes, depending on the URL:
 * 1. No model at all (bare URL): start from the dataset's default config.
 * 2. `modelId` names a saved model: edit/view it. If its status is "applied",
 *    "pending", or "trained" with a successful result, the results review page is
 *    shown instead of the editor.
 * 3. `?from=<id>[&name=<name>]`: start from an unsaved clone of the `from` model,
 *    named `name` (falling back to "Model copy").
 */
export default function ModelEditor({ dataset }: { dataset: DatasetId }) {
  const history = useHistory();
  const { pathname, search } = useLocation();
  const {
    params: { modelId },
  } = useRouteMatch<RouteParams>();
  const { metadata } = useSupervisedLearningContext();

  // `URLSearchParams.get` only returns null when the key is absent.
  const cloneParams = useMemo(() => {
    const query = new URLSearchParams(search);
    const from = query.get("from");
    if (from === null) {
      return undefined;
    }
    return { from, name: query.get("name") };
  }, [search]);

  const navigatedModel = useSavedModel(modelId);
  const sourceModel = useSavedModel(cloneParams?.from);
  const clonedModel = useMemo(
    () =>
      sourceModel && cloneParams
        ? getUnsavedClone(sourceModel, cloneParams.name ?? "Model copy")
        : null,
    [cloneParams, sourceModel],
  );

  // Drop the "/edit..." suffix; presumably this lands on the models list (the
  // editor's close button is labelled "Back to all models").
  const handleBack = useCallback(() => {
    history.push(`${pathname.split("/edit")[0]}`);
  }, [history, pathname]);

  const showResults =
    !!navigatedModel &&
    (navigatedModel.status === "applied" ||
      navigatedModel.status === "pending" ||
      (navigatedModel.status === "trained" &&
        navigatedModel.result.status === "success"));

  // Without a navigated model `showResults` is false, so the default config is
  // only ever built when the editor branch below actually renders.
  const initialConfig = navigatedModel
    ? navigatedModel.config
    : clonedModel
      ? clonedModel.config
      : getDefaultModelTrainingConfig(dataset);

  const content = showResults ? (
    <ReviewResultsPage
      dataset={dataset}
      model={navigatedModel}
      metadata={metadata}
      onBack={handleBack}
    />
  ) : (
    <div
      className={cx(
        "tw-max-w-[calc(100vw-theme(spacing.16))]",
        "tw-h-[calc(100vh-theme(spacing.global-nav-height)-theme(spacing.16))] ",
        "tw-mx-auto tw-my-8",
        "tw-border tw-shadow-xl tw-overflow-y-auto",
      )}
    >
      <ModelEditorInner
        dataset={dataset}
        metadata={metadata}
        existingModelId={navigatedModel?.id}
        savedConfig={initialConfig}
        presetName={cloneParams?.name ?? undefined}
        onBack={handleBack}
      />
    </div>
  );

  return <VisualizationContextProvider>{content}</VisualizationContextProvider>;
}

/**
 * Steps of the model-editor wizard. Declaration order is display order: the
 * editor derives step indices from `Object.values(EditorFlowStep)`.
 */
const EditorFlowStep = {
  FilterDataset: "filter-dataset",
  SelectClasses: "select-classes",
  SelectFeatures: "select-features",
  Finalize: "finalize",
} as const;
type EditorFlowStep = (typeof EditorFlowStep)[keyof typeof EditorFlowStep];

/** Wizard steps in display order (the declaration order of `EditorFlowStep`). */
const EDITOR_FLOW_STEPS = Object.values(EditorFlowStep);

/**
 * Walk the user through a wizard to configure the training parameters of a model.
 *
 * Steps: filter the dataset, select classes, select features, then review and
 * submit. Each configurator reports readiness through `onReadyToAdvanceChanged`,
 * which gates the "Next" button.
 *
 * `savedConfig` only seeds the initial editor state; it is not re-read later.
 */
function ModelEditorInner({
  dataset,
  metadata,
  savedConfig,
  existingModelId,
  presetName,
  onBack,
}: {
  dataset: DatasetId;
  metadata: AsyncDuckDB;
  savedConfig: Partial<ModelTrainingConfig> | null;
  existingModelId?: string;
  presetName?: string;
  onBack: () => void;
}) {
  const [config, setConfig] = useState<Partial<ModelTrainingConfig>>(
    () => savedConfig ?? {},
  );
  const [[step, advancable], setState] = useState<
    [step: EditorFlowStep, advancable: boolean]
  >([EditorFlowStep.FilterDataset, false]);
  const currentStepI = EDITOR_FLOW_STEPS.indexOf(step);
  const hasNextStep = currentStepI < EDITOR_FLOW_STEPS.length - 1;

  // The "Prev" path marks the previous step as advancable right away.
  const goBackOneStep = useCallback(() => {
    setState([EDITOR_FLOW_STEPS[currentStepI - 1], true]);
  }, [currentStepI]);
  const advance = useCallback(() => {
    setState([EDITOR_FLOW_STEPS[currentStepI + 1], false]);
  }, [currentStepI]);
  const handleAdvancableChanged = useCallback((isAdvancable: boolean) => {
    setState(([currentStep]) => [currentStep, isAdvancable]);
  }, []);

  // TODO(benkomalo): wrap the onBack handler with a check to see if any state has
  // been changed. If it has, then prompt the user to confirm they want to leave.
  // Alternatively, we should just auto-save the state and save it as an "incomplete
  // model" in the list.

  let stepContent: ReactNode = null;
  switch (step) {
    case EditorFlowStep.FilterDataset:
      stepContent = (
        <FilterConfigurator
          dataset={dataset}
          metadata={metadata}
          config={config}
          setConfig={setConfig}
          onReadyToAdvanceChanged={handleAdvancableChanged}
        />
      );
      break;
    case EditorFlowStep.SelectClasses:
      stepContent = (
        <ClassesConfigurator
          dataset={dataset}
          metadata={metadata}
          config={config}
          setConfig={setConfig}
          onReadyToAdvanceChanged={handleAdvancableChanged}
        />
      );
      break;
    case EditorFlowStep.SelectFeatures:
      stepContent = (
        <FeaturesConfigurator
          dataset={dataset}
          config={config}
          setConfig={setConfig}
          onReadyToAdvanceChanged={handleAdvancableChanged}
        />
      );
      break;
    case EditorFlowStep.Finalize:
      stepContent = (
        <FinalizeAndSubmit
          dataset={dataset}
          metadata={metadata}
          existingModelId={existingModelId}
          config={config}
          setConfig={setConfig}
          presetName={presetName}
          onGoBack={goBackOneStep}
        />
      );
      break;
  }

  return (
    <div className={"tw-flex tw-flex-col tw-h-full"}>
      <div className={"tw-flex tw-items-center tw-border-b"}>
        <div
          className={
            "tw-border-r tw-w-16 tw-h-12 tw-flex tw-items-center tw-justify-center"
          }
        >
          <button onClick={onBack}>
            <Close
              className={"tw-w-4 tw-h-4 tw-text-gray-500"}
              aria-label={"Back to all models"}
            />
          </button>
        </div>
        <div className={"tw-flex-1 tw-px-4"}>
          Create supervised learning model
        </div>
        <div className={"tw-flex tw-items-center tw-px-8"}>
          {EDITOR_FLOW_STEPS.map((flowStep, i) => {
            const isDone = i < currentStepI;
            const isCurrent = i === currentStepI;
            const isUpcoming = i > currentStepI;
            return (
              <Fragment key={flowStep}>
                {i > 0 && (
                  <div
                    className={cx(
                      "tw-w-12 -tw-mx-4 tw-h-[2px]",
                      i <= currentStepI ? "tw-bg-purple" : "tw-bg-gray-300",
                    )}
                  />
                )}
                {/* Only already-completed steps can be jumped back to. */}
                <button
                  className={"tw-p-4"}
                  onClick={() => setState([flowStep, false])}
                  disabled={i >= currentStepI}
                >
                  <div
                    className={cx(
                      "tw-inline-flex tw-justify-center tw-items-center",
                      "tw-w-8 tw-h-8 tw-border-2 tw-rounded-full",
                      isDone && "tw-bg-purple tw-border-purple tw-text-white",
                      isCurrent && "tw-border-purple",
                      isUpcoming && "tw-border-gray-300",
                    )}
                  >
                    {isDone ? <>✓</> : <>{i + 1}</>}
                  </div>
                </button>
              </Fragment>
            );
          })}
        </div>
      </div>

      <div className={"tw-flex-1 tw-relative tw-overflow-y-hidden"}>
        {stepContent}
      </div>

      {/* FinalizeAndSubmit renders its own navigation buttons. */}
      {step !== EditorFlowStep.Finalize && (
        <div className={"tw-p-8 tw-flex tw-justify-end"}>
          {currentStepI > 0 && (
            <Button
              variant={"secondary"}
              onClick={goBackOneStep}
              disableTracking={true}
            >
              Prev
            </Button>
          )}
          <Strut size={8} />
          {hasNextStep && (
            <Button
              variant={"primary"}
              onClick={advance}
              disabled={!advancable}
              disableTracking={true}
            >
              Next
            </Button>
          )}
        </div>
      )}
    </div>
  );
}

/**
 * Inline control for choosing the normalization applied before training.
 *
 * The config type allows an arbitrary filter for the negative-control target (for
 * future-proofing), but this UI deliberately offers only a simple choice: no
 * normalization, or z-score by plate fitted either to all data on the plate or to
 * the rows matching a single class value.
 *
 * Side effect: once the class column/values are known and `config.normalization`
 * is unset, this writes a default (stratified by "plate", targeting the first class
 * value). Renders nothing until the class column/values have been inferred.
 */
function InlineNormalizationConfigurator({
  config,
  setConfig,
  metadata,
}: {
  config: Partial<ModelTrainingConfig>;
  setConfig: Dispatch<SetStateAction<Partial<ModelTrainingConfig>>>;
  metadata: AsyncDuckDB;
}) {
  const classes = config.data?.classes;
  const filter = config.data?.filter;

  // Depend only on the inputs the inference actually reads. Depending on the whole
  // `config` re-ran these DuckDB queries on every unrelated edit, including the
  // normalization changes made by this very component.
  const [classColumn, classValues] = useAsyncValue(async () => {
    const allColumns = await getTableColumns(metadata, "sample_metadata");
    const [column, values] =
      (await inferClassColumnAndValuesFromConfig(
        classes ?? [],
        allColumns,
        metadata,
        filter ?? "TRUE",
        false,
      )) ?? [undefined, undefined];
    return [column, values];
  }, [metadata, classes, filter]) ?? [undefined, undefined];

  const normalization = config.normalization;
  useEffect(() => {
    // Set a default normalization if one is not set.
    // Note: this should only happen for the creation of new configs. Cloning
    // a config should carry over the normalization from the source config, and
    // this will be a noop (and reminder: no normalization would result in
    // `normalization` being defined, but with type "none").
    if (!normalization && classColumn && classValues) {
      setConfig((current) => ({
        ...current,
        normalization: {
          type: "stratified",
          stratificationColumns: ["plate"],
          targetFilter: columnMatchesValueClause(classColumn, classValues[0]),
        },
      }));
    }
  }, [normalization, classColumn, classValues, setConfig]);

  // Map the stored normalization back to the single class value it targets
  // (undefined = "use all data on plate", or no match among the current classes).
  const negativeControlValue: MetadataColumnValue | undefined = useMemo(() => {
    if (!normalization) {
      return undefined;
    }
    if (!classColumn || !classValues) {
      return undefined;
    }

    switch (normalization.type) {
      case "none":
        return undefined;
      case "stratified":
        if (normalization.targetFilter === null) {
          return undefined;
        } else {
          return classValues.find(
            (value) =>
              columnMatchesValueClause(classColumn, value) ===
              normalization.targetFilter,
          );
        }
    }
  }, [normalization, classColumn, classValues]);

  if (!classColumn || !classValues) {
    return null;
  }

  return (
    <div className={"tw-flex tw-flex-col"}>
      <label>
        Configure normalization
        <Select
          name="Configure normalization"
          items={[
            {
              value: "none",
              text: "None",
            },
            {
              value: "z-score-plate",
              text: "Z-score by plate",
            },
          ]}
          value={
            normalization?.type === "stratified" ? "z-score-plate" : "none"
          }
          onChange={(value) => {
            if (value === "none") {
              setConfig((current) => ({
                ...current,
                normalization: {
                  type: "none",
                },
              }));
            } else {
              setConfig((current) => ({
                ...current,
                normalization: {
                  type: "stratified",
                  stratificationColumns: ["plate"],
                  targetFilter: columnMatchesValueClause(
                    classColumn,
                    classValues[0],
                  ),
                },
              }));
            }
          }}
        />
      </label>

      {normalization?.type === "stratified" && (
        <div className={"tw-mt-4"}>
          <label>
            Select your negative controls to fit normalization to:
            <Select
              name="Select negative controls"
              items={[
                { value: "__all_data__", text: "Use all data on plate" },
                ...classValues.map((value) => ({
                  value,
                  text: value === null ? "<null>" : String(value),
                })),
              ]}
              value={
                negativeControlValue === undefined
                  ? "__all_data__"
                  : negativeControlValue
              }
              onChange={(value) => {
                setConfig((current) => {
                  // Build on the latest normalization held in state rather than
                  // the one captured at render time, so a queued update to it is
                  // not overwritten.
                  const currentNormalization = current.normalization;
                  if (currentNormalization?.type !== "stratified") {
                    return current;
                  }
                  return {
                    ...current,
                    normalization: {
                      ...currentNormalization,
                      targetFilter:
                        value === "__all_data__"
                          ? null
                          : columnMatchesValueClause(classColumn, value),
                    },
                  };
                });
              }}
            />
          </label>
        </div>
      )}
    </div>
  );
}

/**
 * Collapsible section with a clickable title.
 *
 * @param title - Label shown on the toggle button.
 * @param defaultOpened - Whether the section starts expanded (default: collapsed).
 * @param children - Section body. It stays mounted while collapsed and is only
 *   hidden with CSS.
 */
function Section({
  title,
  defaultOpened = false,
  children,
}: {
  title: string;
  defaultOpened?: boolean;
  children: ReactNode;
}) {
  const [isOpen, setIsOpen] = useState(defaultOpened);

  // TODO(benkomalo): sort of a hack, but we still render the component but just make it
  // hidden because some components might have side effects (e.g. the normalization
  // section actually infers the default target and sets that in the config so it
  // needs to be mounted).
  return (
    <div className={"tw-my-4"}>
      {/* `aria-expanded` exposes the open/closed state to assistive tech. */}
      <button
        type="button"
        className={"tw-flex tw-text-purple"}
        aria-expanded={isOpen}
        onClick={() => setIsOpen((open) => !open)}
      >
        <span className={"tw-w-6"}>
          <SpinningChevron variant="DownOpenRightClose" isOpen={isOpen} />
        </span>
        <span className={"tw-px-2"}>{title}</span>
      </button>
      <div
        className={cx("tw-p-2 tw-ml-6", isOpen ? "tw-visible" : "tw-hidden")}
      >
        {children}
      </div>
    </div>
  );
}

/**
 * A single label/content row in FinalizeAndSubmit's review sections.
 *
 * It is defined at module level, not inside the component, so React always sees
 * the same component type.
 */
function ReviewRow({
  label,
  children,
}: {
  label: string;
  children: ReactNode;
}) {
  return (
    <div className={"tw-py-2 tw-flex"}>
      <div className={"tw-w-[320px]"}>{label}</div>
      <div className={"tw-flex-1"}>{children}</div>
    </div>
  );
}

/**
 * Final wizard step: review the assembled config, name and describe the model,
 * adjust normalization and the evaluation split, then submit for training.
 *
 * Submitting:
 * 1. creates a new saved model with status "pending". When retrying an existing
 *    model, the old one is deleted after the new one is created;
 * 2. fires the training job without awaiting it. The model list is refreshed when
 *    the job call resolves, and any error is surfaced on the submit button;
 * 3. replaces the URL so it points at the newly created model.
 */
function FinalizeAndSubmit({
  dataset,
  metadata,
  existingModelId,
  config,
  setConfig,
  presetName,
  onGoBack,
}: {
  dataset: DatasetId;
  metadata: AsyncDuckDB;
  existingModelId?: string;
  config: Partial<ModelTrainingConfig>;
  setConfig: Dispatch<SetStateAction<Partial<ModelTrainingConfig>>>;
  presetName?: string;
  onGoBack: () => void;
}) {
  const workspace = useActiveWorkspace();
  const isTrainingAllowed = useModelTrainingAndInferenceAreAllowed();
  const accessToken = useAccessToken();
  const history = useHistory();
  const { path, params: pathParams } = useRouteMatch<RouteParams>();
  const existingModel = useSavedModel(existingModelId);
  const [name, setName] = useState(presetName ?? existingModel?.name ?? "");
  const [description, setDescription] = useState(
    existingModel?.description ?? "",
  );
  const isReadyToSubmit = isConfigurationValid(config);
  const [submitState, setSubmitState] = useState<
    | { type: "idle" }
    | { type: "submitted" }
    | { type: "error"; message: string }
  >({ type: "idle" });
  const { refreshModels, deleteModel, createModel } =
    useSupervisedLearningContext();
  const allFeatures = useFeatureSetsGrouped(dataset);
  const allEmbeddingPrefixes = useMemo(() => {
    if (!allFeatures || !allFeatures.successful) {
      return [];
    }
    return Object.keys(groupFeaturesByPrefix(allFeatures.value["embedding"]));
  }, [allFeatures]);

  const saveAndSubmit = useCallback(async () => {
    if (!isConfigurationValid(config)) {
      return;
    }
    setSubmitState({ type: "submitted" });

    // When re-trying an existing model (one that failed training), we don't mutate
    // it. Instead we create a fork and delete the old one, transparently. To the
    // user this is the same as editing the model. To the backends it is more or
    // less the same too, because every mutation soft-deletes the existing version
    // for archival purposes. The model ID we're working with does change, though.
    let model: SavedModelListing;
    try {
      model = await createModel({
        name: name.trim() || "Unnamed model",
        description: description.trim(),
        config,
        status: "pending",
      });
    } catch (e) {
      // Surface the failure and re-enable the submit button, rather than leaving
      // the form stuck in the "submitted" state.
      Sentry.captureException(e);
      setSubmitState({ type: "error", message: String(e) });
      return;
    }

    if (existingModel) {
      try {
        await deleteModel(existingModel);
      } catch (e) {
        // The replacement model already exists, so failing to clean up the old one
        // shouldn't block training the new one. Report it and carry on.
        Sentry.captureException(e);
      }
    }

    // Submit the training job but don't wait for it.
    void submitTrainingJob(workspace.id, dataset, accessToken, model)
      .then((result) => {
        refreshModels();
        if (result.error) {
          setSubmitState({ type: "error", message: result.error });
        }
      })
      .catch((e: unknown) => {
        Sentry.captureException(e);
        setSubmitState({ type: "error", message: String(e) });
      });

    if (pathParams.modelId) {
      // In the case where we're re-training from a previous model, update our URL
      // to the new model ID.
      history.replace(generatePath(path, { ...pathParams, modelId: model.id }));
    } else {
      // Otherwise, the existing path won't have the modelID in it; tack it on.
      history.replace(`${generatePath(path, pathParams)}/${model.id}`);
    }
  }, [
    workspace.id,
    dataset,
    accessToken,
    config,
    history,
    name,
    description,
    existingModel,
    deleteModel,
    createModel,
    refreshModels,
    path,
    pathParams,
  ]);

  const classes = config.data?.classes;
  const filter = config.data?.filter ?? "TRUE";

  // Depend on the inference inputs, not the whole `config`. The normalization and
  // split configurators below write defaults into `config` when they mount, and
  // those writes shouldn't trigger fresh DuckDB queries.
  const [classColumn, classValues] = useAsyncValue(async () => {
    const allColumns = await inferInterestingColumnsDB(metadata);
    const result = await inferClassColumnAndValuesFromConfig(
      classes ?? [],
      allColumns,
      metadata,
      filter,
      false,
    );
    return result ?? null;
  }, [classes, filter, metadata]) ?? [undefined, undefined];

  // Entries matching the user filter and the selected classes. Falsy until the
  // class column is known and the query has resolved.
  const selectedWells = useAsyncValue(async () => {
    if (!classColumn) {
      return null;
    }
    return await findControlEntries(metadata, filter, classColumn, classValues);
  }, [metadata, filter, classColumn, classValues]);

  // Per-class counts, derived from `selectedWells` instead of issuing the same
  // query a second time.
  const wellCountPerClass = useMemo(() => {
    if (!selectedWells) {
      return null;
    }
    const counts = new Map<unknown, number>();
    for (const { value } of selectedWells) {
      counts.set(value, (counts.get(value) ?? 0) + 1);
    }
    return counts;
  }, [selectedWells]);

  return (
    <div className={"tw-flex tw-flex-col tw-p-8 tw-h-full"}>
      <div className={"tw-text-xl"}>Review and confirm</div>

      <div className={"tw-my-4 tw-flex-1 tw-overflow-y-auto"}>
        <Section title={"General"} defaultOpened>
          <>
            <ReviewRow key="model-name" label={"Model name"}>
              <input
                className={"tw-border tw-rounded tw-p-2 tw-w-full"}
                placeholder={"Model name"}
                value={name}
                onChange={(e) => setName(e.target.value)}
                autoFocus
              />
            </ReviewRow>
            <ReviewRow key="desc" label={"Description"}>
              <textarea
                className={"tw-border tw-rounded tw-p-2 tw-w-full"}
                placeholder={
                  "(Optional) Add a description to describe relevant details"
                }
                value={description}
                onChange={(e) => setDescription(e.target.value)}
              />
            </ReviewRow>
          </>
        </Section>
        <hr />
        <Section title={"Inputs"} defaultOpened>
          <>
            <ReviewRow key="filter" label={"Filter"}>
              {selectedWells && (
                <>
                  Training on {selectedWells.length} wells
                  <Tooltip
                    side="right"
                    showArrow
                    contents={
                      <>
                        <div className={"tw-font-semibold"}>User filters</div>
                        {formatFilter(filter)?.map((f) => {
                          return <div key={f}>{f}</div>;
                        })}
                        <div className={"tw-font-semibold tw-mt-4"}>
                          Class filters
                        </div>
                        {formatFilter(
                          generateValueClauses(
                            classColumn ?? "",
                            classValues ?? [],
                          ),
                        )?.map((f) => {
                          return <div key={f}>{f}</div>;
                        })}
                      </>
                    }
                  >
                    <HelpCircle
                      id={"filter"}
                      className={
                        "tw-text-gray-400 tw-px-1 tw-no-underline hover:tw-text-gray-800 hover:tw-bg-gray-200"
                      }
                    ></HelpCircle>
                  </Tooltip>
                </>
              )}
            </ReviewRow>
            <ReviewRow key="classes" label={"Classes"}>
              <>
                {classValues?.map((value) => {
                  return (
                    <div key={String(value)}>
                      {value === null ? "<null>" : String(value)}
                      {wellCountPerClass && wellCountPerClass.has(value) ? (
                        <p className={"tw-text-xs tw-inline"}>
                          {" "}
                          (
                          {pluralize(
                            "well",
                            wellCountPerClass.get(value),
                            true,
                          )}
                          )
                        </p>
                      ) : null}
                    </div>
                  );
                })}
              </>
            </ReviewRow>
            <ReviewRow key="features" label={"Features"}>
              {config.featureInputs?.map((selection) => {
                switch (selection.type) {
                  case "embedding":
                    return (
                      <div
                        className={"tw-flex tw-flex-col"}
                        key={selection.names.join("-")}
                      >
                        {selection.names.map((embeddingName) => (
                          <div className={"tw-mb-sm"} key={embeddingName}>
                            <SelectedEmbeddingInfo
                              selection={{
                                type: "embedding",
                                names: [embeddingName],
                              }}
                              allEmbeddingPrefixes={allEmbeddingPrefixes}
                            />
                          </div>
                        ))}
                      </div>
                    );
                  case "numerical":
                  case "prediction":
                    return (
                      <SelectedMeasurementInfo
                        key={selection.name}
                        selection={selection}
                      />
                    );
                }
              })}
            </ReviewRow>
          </>
        </Section>
        <hr />
        <Section title={"Other settings"} defaultOpened>
          <ReviewRow key="normalization" label={"Normalization"}>
            {/* Note: mounting this component has a side effect of choosing a default
                normalization. */}
            <InlineNormalizationConfigurator
              config={config}
              setConfig={setConfig}
              metadata={metadata}
            />
          </ReviewRow>
          <ReviewRow key="split" label={"Evaluation configuration"}>
            {/* Note: mounting this component has a side effect of choosing a default
                split. */}
            <SplitConfigurator
              config={config}
              setConfig={setConfig}
              metadata={metadata}
            />
          </ReviewRow>
        </Section>
      </div>

      <div className={"tw-flex tw-items-center tw-justify-end"}>
        <Button
          key={"prev"}
          variant={"secondary"}
          onClick={onGoBack}
          disableTracking={true}
        >
          Prev
        </Button>
        <Strut size={8} />
        <Button
          key={"submit"}
          name="Finalize and train"
          variant={submitState.type === "error" ? "danger" : "primary"}
          onClick={() => void saveAndSubmit()}
          disabled={
            !isReadyToSubmit ||
            submitState.type === "submitted" ||
            !isTrainingAllowed
          }
        >
          {!isTrainingAllowed ? (
            <Tooltip
              contents={
                <div className={"tw-max-w-[440px]"}>
                  Training is disabled for this demo workspace. Contact
                  support@springscience.com to get your own workspace and train
                  your own models.
                </div>
              }
              showArrow
              side={"top"}
            >
              Finalize and train
            </Tooltip>
          ) : submitState.type === "error" ? (
            <Tooltip
              contents={
                <div className={"tw-max-w-[440px]"}>
                  Oops. There was an error in submitting your training job.
                  Please try again in a few minutes.
                </div>
              }
              showArrow
              side={"top"}
            >
              Finalize and train
            </Tooltip>
          ) : (
            <>Finalize and train</>
          )}
        </Button>
      </div>
    </div>
  );
}
