import { AsyncDuckDB } from "@duckdb/duckdb-wasm";
import {
  ReactNode,
  createContext,
  useCallback,
  useContext,
  useEffect,
  useMemo,
  useRef,
  useState,
} from "react";
import { useActiveWorkspace } from "src/Workspace/hooks";
import { useDatasetApi } from "src/hooks/api";
import { useAccessToken } from "src/hooks/auth0";
import { DatasetId } from "src/types";
import { FullScreenContainer } from "../Common/FullScreenContainer";
import Loader from "../Common/Loader";
import { useDatasetSampleMetadataDB } from "../hooks/datasets";
import { useFetchClassificationModels } from "./hooks";
import { ModelListing, SavedModelListing } from "./types";
import { saveModel } from "./utils";

/** Value exposed to descendants by {@link SupervisedLearningContextProvider}. */
type SupervisedLearningContext = {
  /** Models currently held in the provider's local cache, as an array. */
  models: SavedModelListing[];
  /** Create a new model and return a {@link SavedModelListing} with the reified ID. */
  createModel: (model: ModelListing) => Promise<SavedModelListing>;

  /**
   * Update model information, writing back metadata and config to the server.
   *
   * TODO(benkomalo): note that model _status_ is not written back to the server here,
   * as it's mutated on the server only on server events like initiating training
   * or training completion (and isn't initiated by client edits, since it doesn't
   * make sense for clients to edit those things directly). Updates to those
   * parameters via this method are only reflected in the local cache of model state.
   * It might be cleaner to just separate out the APIs to have an updateModelConfig
   * and only update statuses when coupled with calling submitTrainingJob?
   */
  updateModel: (model: SavedModelListing) => Promise<void>;
  /**
   * Remove the model from the local cache immediately and request its deletion
   * on the server without waiting for the result.
   */
  deleteModel: (model: SavedModelListing) => void;
  /** Refresh the model state from the server. */
  refreshModels: () => void;
  /** The dataset's sample-metadata database (as loaded by `useDatasetSampleMetadataDB`). */
  metadata: AsyncDuckDB;
};

/**
 * React context backing {@link SupervisedLearningContextProvider}.
 *
 * The default value is an empty object. Consumers rendered outside the provider
 * receive that object, so none of the context fields exist for them.
 */
const Context = createContext<
  SupervisedLearningContext | Record<string, never>
>({});

/**
 * The context/store for the supervised learning pages.
 *
 * The context holds the list of models and caches them in-memory, but writes
 * through to the server asynchronously on specific operations. The local cache
 * is populated from the server on first load and again after each call to
 * `refreshModels()`.
 *
 * Rendering depends on the load state:
 * - While the model list or the sample-metadata DB is still loading, a
 *   full-screen loader is rendered.
 * - If either load failed, a full-screen error message is rendered.
 */
export function SupervisedLearningContextProvider({
  dataset,
  children,
}: {
  dataset: DatasetId;
  children: ReactNode;
}) {
  const accessToken = useAccessToken();
  const workspace = useActiveWorkspace();
  const { fetchable: modelsFromServer, mutate: forceRefresh } =
    useFetchClassificationModels({ dataset }, "verbose");

  const metadataDB = useDatasetSampleMetadataDB({ dataset });
  const [models, setModels] = useState<{ [id: string]: SavedModelListing }>({});

  // True once the current server snapshot has been copied into `models`.
  // refreshModels() clears it so the next server response replaces the cache.
  const hasInitialized = useRef(false);
  useEffect(() => {
    if (modelsFromServer?.successful && !hasInitialized.current) {
      hasInitialized.current = true;
      setModels(
        Object.fromEntries(modelsFromServer.value.map((m) => [m.id, m])),
      );
    }
    // `models` is deliberately NOT a dependency. Suppose a refresh has been
    // requested but not yet resolved. Re-running this effect on a local edit
    // would then copy the stale server snapshot over the optimistic update.
    // It would also mark the cache as initialized, so the fresh response
    // would be ignored.
  }, [modelsFromServer]);

  const refreshModels = useCallback(() => {
    hasInitialized.current = false;
    forceRefresh();
  }, [forceRefresh]);

  const updateModel = useCallback(
    async (model: SavedModelListing) => {
      const { id } = model;
      // Optimistically update the local version immediately...
      setModels((current) => ({ ...current, [id]: model }));

      // ...then save the model to the server. Once the save finishes we request
      // another refresh. A refresh requested while the save was in flight could
      // have overwritten the optimistic update above with pre-save server state.
      // If the save fails, the rejection propagates to the caller and no
      // refresh is requested.
      await saveModel(accessToken, workspace.id, dataset, model);
      refreshModels();
    },
    [accessToken, workspace.id, dataset, refreshModels],
  );

  const createModel = useCallback(
    async (model: ModelListing) => {
      const { id } = await saveModel(accessToken, workspace.id, dataset, model);
      // `id` goes last so the server-assigned ID always wins. This matters if
      // the caller passed an object that already carries an `id`. The cached
      // entry and the returned value are the same object, so they always agree.
      const saved: SavedModelListing = { ...model, id };
      setModels((current) => ({ ...current, [id]: saved }));
      return saved;
    },
    [accessToken, workspace.id, dataset],
  );

  const api = useDatasetApi();

  const deleteModel = useCallback(
    (model: SavedModelListing) => {
      const { id } = model;
      // Optimistically remove the local version immediately...
      setModels((current) => {
        const copy = { ...current };
        delete copy[id];
        return copy;
      });

      // ...and kick off the delete request without waiting for it. The request
      // may fail. In that case the failure is logged and the cache is still
      // refreshed, so the optimistically removed model reappears. This also
      // avoids an unhandled promise rejection.
      void api
        .route("models/<model>", { model: id })
        .delete()
        .finish()
        .then(
          () => refreshModels(),
          (error: unknown) => {
            console.error(`Failed to delete model ${id}`, error);
            refreshModels();
          },
        );
    },
    [api, refreshModels],
  );

  const modelsFlattened = useMemo(() => Object.values(models), [models]);

  // Memoize the context value so consumers only re-render when something in it
  // actually changes, rather than on every render of this provider.
  const metadata = metadataDB?.successful ? metadataDB.value : undefined;
  const contextValue = useMemo<SupervisedLearningContext | undefined>(
    () =>
      metadata === undefined
        ? undefined
        : {
            models: modelsFlattened,
            updateModel,
            createModel,
            deleteModel,
            refreshModels,
            metadata,
          },
    [
      modelsFlattened,
      updateModel,
      createModel,
      deleteModel,
      refreshModels,
      metadata,
    ],
  );

  if (!modelsFromServer || !metadataDB) {
    return (
      <FullScreenContainer center>
        <Loader />
      </FullScreenContainer>
    );
  } else if (!modelsFromServer.successful || contextValue === undefined) {
    // `metadataDB` is known to be loaded at this point, so `contextValue` is
    // undefined exactly when the metadata DB failed to load.
    return (
      <FullScreenContainer center>
        Oops. There was an error in loading your data. Please refresh and try
        again.
      </FullScreenContainer>
    );
  }

  return <Context.Provider value={contextValue}>{children}</Context.Provider>;
}

/**
 * Read the supervised-learning context.
 *
 * Outside a {@link SupervisedLearningContextProvider} this returns the empty
 * default object, so callers must not assume the fields are present.
 */
export function useSupervisedLearningContext():
  | SupervisedLearningContext
  | Record<string, never> {
  const contextValue = useContext(Context);
  return contextValue;
}

/**
 * Look up a saved model by ID in the supervised-learning context.
 *
 * @param modelId - ID of the model to find (may be `undefined`).
 * @returns The cached model whose `id` equals `modelId`, or `undefined` in two
 *   cases: no such model exists, or the hook is used outside a
 *   {@link SupervisedLearningContextProvider}.
 */
export function useSavedModel(
  modelId: string | undefined,
): SavedModelListing | undefined {
  const { models } = useSupervisedLearningContext();
  return useMemo(
    // Outside the provider the context is the empty default `{}`, so `models`
    // is undefined at runtime even though `Record<string, never>` hides that
    // from the type checker. Guard here instead of throwing a TypeError.
    () => models?.find((model) => model.id === modelId),
    [models, modelId],
  );
}
