import sampleSize from "lodash.samplesize";
import { useEffect, useMemo, useRef, useState } from "react";
import type { AccessToken } from "src/Auth0/accessToken";
import { examplesApi, handleAborted } from "src/util/api-client";
import { combine } from "@spring/core/result";
import { FilterSqlClause } from "../Control/FilterSelector/types";
import {
  LabeledSet,
  NeighborsMap,
  Prediction,
  PredictionsMap,
  useLabeledSetContext,
} from "../PhenotypicLearner/Context";
import { usePrefilterForStains } from "../PhenotypicLearner/ImageFilter";
import { NEIGHBOR_COUNT } from "../PhenotypicLearner/constants";
import {
  ModelMetrics,
  addIdToMetadata,
  cleanUpIndex,
  filteredFeatures,
  keyForSample,
  removeSamples,
  sortQueue,
} from "../PhenotypicLearner/util";
import { Workspace } from "../Workspace/types";
import { DisplayRange } from "../imaging/types";
import {
  CellSampleMetadata,
  DatasetId,
  LabeledCellSampleMetadata,
  NeighborData,
  PredictionData,
  UnlabeledCellSampleMetadata,
} from "../types";
import { queryDBAsRecords, sql } from "../util/sql";
import { makeWorkspaceApi, useDatasetApi, useExamplesApi } from "./api";
import { useDatasetSampleMetadataDB } from "./datasets";
import { useFeatureSetsForPlate } from "./features";

/**
 * Shape of the JSON payload that `useLabeledSet` parses when it fetches a
 * labeled set.
 */
export type LabeledSetResponse = {
  stains: string[];
  displayName: string;
  classNames: string[]; // All classes, which may or may not have corresponding labels.
  // Labeled samples as returned by the API, i.e. without the `id` field.
  labels: Omit<LabeledCellSampleMetadata, "id">[];
  latestModelPath: string;
  // Keyed by string; presumably the stain name — TODO confirm against the API.
  stainChannelIndices: Record<string, number>;
  // Keyed by string; presumably the stain name — TODO confirm against the API.
  stainDisplayRanges: Record<string, DisplayRange>;
  accuracies: number[];
  latestFeatureSets: string[];
};

/**
 * Result of `useLabeledSet`, discriminated on `loading`: `{ loading: true }`
 * before a response has been received, and `{ loading: false, value }` once
 * the labeled set has been fetched.
 */
type UseLabeledSetResult =
  | { loading: true }
  | { loading: false; value: LabeledSetResponse };

/**
 * Hook that loads the labeled set identified by `id`.
 *
 * Starts out as `{ loading: true }` and switches to
 * `{ loading: false, value }` once the JSON response has been parsed. The
 * in-flight request is aborted when `id` or the API client changes, or when
 * the component unmounts; rejections are passed to `handleAborted`.
 */
export function useLabeledSet(id: string): UseLabeledSetResult {
  const [result, setResult] = useState<UseLabeledSetResult>({
    loading: true,
  });

  const api = useExamplesApi();

  useEffect(() => {
    const pending = api.route("<labeledSet>", { labeledSet: id }).get();

    const storeLoaded = (value: LabeledSetResponse) =>
      setResult({ value, loading: false });

    pending.json<LabeledSetResponse>().then(storeLoaded).catch(handleAborted);

    // Abort the request if the inputs change before it completes.
    return () => pending.abort();
  }, [api, id]);

  return result;
}

/**
 * Builds a map from sample key to its neighbor information.
 *
 * Each entry of `items` lists its neighbors as positions within `items`;
 * those positions are translated into sample keys. Every sample in `samples`
 * that has no entry afterwards gets a default one with a score of 0 and no
 * neighbors.
 */
function extractNeighborData(
  samples: CellSampleMetadata[],
  items: (CellSampleMetadata & NeighborData)[],
): NeighborsMap {
  // Pair each item with its sample key, keeping response order so that the
  // positional neighbor references can be resolved below.
  const keyed = items.map(({ inDegreeScore, neighbors, ...metadata }) => ({
    id: keyForSample(metadata),
    inDegreeScore,
    neighbors,
  }));

  const result: NeighborsMap = new Map();

  for (const entry of keyed) {
    result.set(entry.id, {
      inDegreeScore: entry.inDegreeScore,
      neighbors: entry.neighbors.map((position) => keyed[position].id),
    });
  }

  // Make sure every requested sample is present, even without neighbor data.
  for (const sample of samples) {
    const id = keyForSample(sample);
    if (result.has(id)) {
      continue;
    }
    result.set(id, { inDegreeScore: 0, neighbors: [] });
  }

  return result;
}

/**
 * Converts raw per-sample prediction scores into a map keyed by sample key.
 *
 * For each item, every class score is recorded and the class with the highest
 * score strictly greater than 0 becomes `predictedClass` (it stays `""` when
 * no score is positive). If several items share a key, the first one wins.
 *
 * @param samples Not consulted by this function.
 * @param items Prediction results, or `null` when no predictions were made.
 * @returns The predictions map, or `null` when `items` is `null`.
 */
export function extractPredictionData(
  samples: CellSampleMetadata[],
  items: (CellSampleMetadata & PredictionData)[] | null,
): PredictionsMap | null {
  if (!items) {
    return null;
  }

  const result: PredictionsMap = new Map();

  for (const { predictions, ...metadata } of items) {
    const id = keyForSample(metadata);

    const entry: Prediction = {
      predictedClass: "",
      predictions: new Map(),
    };

    // Track the best-scoring class while copying over every class score.
    let bestScore = 0.0;
    let bestClass = "";
    for (const [className, score] of Object.entries(predictions)) {
      if (score > bestScore) {
        bestScore = score;
        bestClass = className;
      }
      entry.predictions.set(className, score);
    }
    entry.predictedClass = bestClass;

    // Keep the first entry seen for a given sample key.
    if (!result.has(id)) {
      result.set(id, entry);
    }
  }

  return result;
}

/**
 * Hook returning the distinct plates of `dataset` whose sample metadata match
 * both the prefilter for `stains` and the optional `samplingFilter`.
 *
 * No query is issued while the sample metadata DB / prefilter are unavailable
 * or while `stains` is empty.
 *
 * @param dataset The dataset whose sample metadata is queried
 * @param stains Stains used to build the prefilter
 * @param samplingFilter Extra SQL filter clause; `null` means no extra filter
 * @returns The matching plates, or `null` while no results are available for
 * the current `samplingFilter`
 */
export function useFilteredPlates(
  dataset: DatasetId,
  stains: string[],
  samplingFilter: FilterSqlClause | null,
): string[] | null {
  const [plates, setPlates] = useState<string[] | null>(null);
  // The sampling filter that the current `plates` were computed for; used to
  // avoid handing out results that belong to a previous filter.
  const filterForResults = useRef<FilterSqlClause | null>(null);

  const fetch = combine({
    sampleMetadata: useDatasetSampleMetadataDB({ dataset }),
    prefilter: usePrefilterForStains(dataset, stains),
  });

  useEffect(() => {
    if (!fetch?.successful || stains.length === 0) {
      return;
    }

    // Set by the cleanup below so that a query started for outdated inputs
    // (or for an unmounted component) cannot overwrite newer results when it
    // happens to resolve late.
    let cancelled = false;

    const { sampleMetadata, prefilter } = fetch.value;
    queryDBAsRecords<{ plate: string }>(
      sampleMetadata,
      sql`SELECT DISTINCT(plate) FROM sample_metadata WHERE (${prefilter}) AND (${
        samplingFilter || "TRUE"
      })`,
    ).then((records) => {
      if (cancelled) {
        return;
      }
      filterForResults.current = samplingFilter;
      setPlates(records.map(({ plate }) => plate));
    });

    return () => {
      cancelled = true;
    };
    // HACK(benkomalo): specifying fetch as a dependency here results in infinite
    // updates, since it doesn't appear as if the results of combine() above is stable.
    // So instead, specify whether the fetch has returned, and the inputs to the fetch.
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [
    fetch?.successful,
    dataset,
    stains,
    samplingFilter,
    filterForResults,
    setPlates,
  ]);

  return filterForResults.current === samplingFilter ? plates : null;
}

/**
 * Hook that picks one plate at random from `plates`. The choice is memoized,
 * so it only changes when the `plates` array reference changes.
 */
function useRandomPlate({ plates }: { plates: [string, ...string[]] }): string {
  return useMemo((): string => {
    const picked = sampleSize(plates, 1);
    return picked[0];
  }, [plates]);
}

/**
 * Merges freshly loaded samples into the labeled-set state.
 *
 * New samples already present in the queue or on display are dropped, the
 * remainder is appended to the existing queue, anything that has already been
 * labeled is removed, and the queue is re-sorted using `neighbors`. The
 * `skipped` list is reset, and `outOfSamples` is set when the resulting queue
 * is empty.
 */
function cleanUpStateAfterLoad(
  state: LabeledSet,
  loaded: {
    samples: UnlabeledCellSampleMetadata[];
    neighbors: NeighborsMap;
    predictions: PredictionsMap | null;
    smallN: boolean;
    smallSamplingFilter: FilterSqlClause | null;
  },
): LabeledSet {
  const { samples, neighbors, predictions, smallN, smallSamplingFilter } =
    loaded;

  const labeledExamples = state.classifications.flatMap(
    ({ examples }) => examples,
  );

  // Drop incoming samples the user can already see or will see shortly.
  const alreadyInView = [...state.queue, ...state.displayed];
  const freshSamples = removeSamples(samples, alreadyInView);

  const unlabeledCandidates = removeSamples(
    [...state.queue, ...freshSamples],
    labeledExamples,
  );
  const queue = sortQueue(unlabeledCandidates, neighbors);

  const skipped: UnlabeledCellSampleMetadata[] = [];
  const outOfSamples = queue.length === 0;

  return {
    ...state,
    queue,
    skipped,
    smallN,
    neighbors,
    predictions,
    outOfSamples,
    smallSamplingFilter,
  };
}

/**
 * Requests model predictions for the given samples from the `inference` route
 * of the examples API.
 *
 * @returns The parsed JSON response, or `null` (without issuing a request)
 * when there are no feature sets, no model path, or nothing to predict.
 */
export async function predictUnlabeled({
  dataset,
  workspace,
  accessToken,
  latestFeatureSets,
  latestModelPath,
  toPredict,
}: {
  dataset: DatasetId;
  workspace: Workspace;
  accessToken: AccessToken;
  latestFeatureSets: string[];
  latestModelPath: string;
  toPredict: CellSampleMetadata[];
}): Promise<any | null> {
  const canPredict =
    latestFeatureSets.length !== 0 &&
    latestModelPath !== "" &&
    toPredict.length !== 0;
  if (!canPredict) {
    return null;
  }

  const inferenceRoute = examplesApi({
    accessToken,
    workspace: workspace.id,
    dataset,
  }).route("inference");

  // The request body carries the samples themselves.
  const body = toPredict.map(cleanUpIndex);
  // Each plate is listed once, regardless of how many samples come from it.
  const uniquePlates = new Set(toPredict.map((sample) => sample.plate));

  return inferenceRoute
    .post(body, {
      modelPath: latestModelPath,
      featureSets: latestFeatureSets,
      plates: Array.from(uniquePlates),
    })
    .json();
}

/**
 * Hook to refill the sample queue whenever it's running low on samples.
 *
 * When a refill is needed, the effect requests new samples, then (if any were
 * returned) requests predictions and k-nearest-neighbor data in parallel, and
 * finally merges everything into the labeled-set state via
 * `cleanUpStateAfterLoad`. On failure the error is stored in
 * `state.loadingError`, which in turn stops further refill attempts.
 *
 * @param config Configuration for the hook
 * @param config.dataset The dataset being used
 * @param config.workspace The workspace, forwarded to `predictUnlabeled`
 * @param config.accessToken The access token, forwarded to `predictUnlabeled`
 * @param config.threshold Will attempt to refill the queue if the number of samples
 * drops below this value
 * @param config.numSamples The number of samples to fetch when refilling the queue
 * @param config.samplingFilter A filtering clause used when requesting new samples
 * @param config.plates A list of plates to choose from when requesting new samples
 */
export function useReplenishQueue({
  dataset,
  workspace,
  accessToken,
  threshold,
  numSamples,
  samplingFilter,
  plates,
}: {
  dataset: DatasetId;
  workspace: Workspace;
  accessToken: AccessToken;
  threshold: number;
  numSamples: number;
  samplingFilter: FilterSqlClause | null;
  plates: [string, ...string[]];
}) {
  const { state, updateState } = useLabeledSetContext();
  // Plate used only to look up the available feature sets; memoized, so it is
  // independent of the plate sampled for each refill request below.
  const plateForFeatures = useRandomPlate({ plates });

  // Refill when the queue is below the threshold; once `smallN` is set, only
  // refill after the queue has been fully drained.
  const needsMoreSamples = useMemo(
    () =>
      state.queue.length < threshold &&
      (!state.smallN || state.queue.length === 0),
    [state.queue.length, threshold, state.smallN],
  );
  // In theory this isn't needed as all plates should have all features.
  const features = useFeatureSetsForPlate({ dataset, plate: plateForFeatures });
  // Feature sets restricted to the current stains; `null` until the feature
  // lookup has succeeded.
  const featureSets = useMemo(() => {
    if (features?.successful) {
      return filteredFeatures(features.value, state.stains);
    } else {
      return null;
    }
  }, [features, state.stains]);

  const api = useDatasetApi();

  useEffect(() => {
    // Skip the refill when nothing is needed, the feature sets are not ready,
    // a previous load failed, or the current sampling filter is already known
    // to yield too few samples.
    // NOTE(review): the filter comparison below uses loose `==`, whereas
    // `useOutOfSamples` compares the same two fields with `===` — confirm
    // whether treating `null` and `undefined` as equal is intended here.
    if (
      !needsMoreSamples ||
      !featureSets ||
      state.loadingError ||
      (state.smallN && state.smallSamplingFilter == state.samplingFilter)
    ) {
      return;
    }

    // Set by the cleanup function so that results of an outdated run are not
    // written into the state.
    let cancelled = false;
    const existingClassifications = state.classifications.flatMap(
      (classification) => classification.examples,
    );
    const existingQueueDisplayed = [...state.displayed, ...state.queue].map(
      cleanUpIndex,
    );
    // A plate is sampled anew for every refill attempt.
    const plate = sampleSize(plates, 1)[0];

    // Declared outside `update` because the error handler attached to
    // `update()` below also reads `smallSamplingFilter`.
    let samplesResponse: CellSampleMetadata[] = [];
    let smallSamplingFilter: FilterSqlClause | null = null;
    const update = async () => {
      try {
        // The already-classified samples are sent as the request body.
        // NOTE(review): `samplingFilter || ""` here vs `samplingFilter ?? ""`
        // in the knn request below — presumably equivalent for this type;
        // confirm.
        samplesResponse = await api
          .route("sample_examples")
          .post(existingClassifications.map(cleanUpIndex), {
            n: numSamples,
            samplingFilter: samplingFilter || "",
            plate,
          })
          .json();
        // Remember the filter when it produced fewer samples than the refill
        // threshold.
        smallSamplingFilter =
          samplesResponse.length < threshold ? samplingFilter : null;
      } catch (e: any) {
        // An empty filtering result is not treated as a failure: continue
        // with no new samples and remember the filter. Anything else is
        // rethrown and handled by the `.catch` on `update()` below.
        if (
          !(
            e.name === "FetchError" &&
            e.message === "Filtering resulted in empty Dataset."
          )
        ) {
          throw e;
        } else {
          smallSamplingFilter = samplingFilter;
        }
      }

      // Neighbors are computed over the new samples together with everything
      // already classified, displayed or queued.
      const cells: Omit<CellSampleMetadata, "type">[] = [
        ...samplesResponse,
        ...existingClassifications,
        ...existingQueueDisplayed,
      ]
        // The API only wants the required fields and nothing extra
        .map(cleanUpIndex);

      // Predictions and neighbors are only requested when new samples were
      // returned; both requests run in parallel.
      const [predictionsResponse, knnResponse] =
        samplesResponse.length > 0
          ? await Promise.all([
              predictUnlabeled({
                dataset,
                workspace,
                accessToken,
                latestFeatureSets: state.latestFeatureSets,
                latestModelPath: state.latestModelPath,
                toPredict: samplesResponse,
              }),
              api
                .route("knn")
                .post(cells, {
                  k: NEIGHBOR_COUNT,
                  samplingFilter: samplingFilter ?? "",
                  featureSets,
                })
                .json<(CellSampleMetadata & NeighborData)[]>(),
            ])
          : [null, []];
      const predictionsMap: PredictionsMap | null = extractPredictionData(
        samplesResponse,
        predictionsResponse,
      );

      const neighborsMap = extractNeighborData(samplesResponse, knnResponse);

      if (!cancelled) {
        updateState((state) =>
          cleanUpStateAfterLoad(state, {
            samples: sortQueue(
              samplesResponse.map(addIdToMetadata),
              neighborsMap,
            ),
            neighbors: neighborsMap,
            // Fewer samples than requested came back.
            smallN: samplesResponse.length < numSamples,
            predictions: predictionsMap,
            smallSamplingFilter: smallSamplingFilter,
          }),
        );
      }
    };

    // On failure, record the error in the state (with no new samples) rather
    // than leaving the rejection unhandled.
    update().catch((ex) => {
      if (!cancelled) {
        updateState((state) => ({
          ...cleanUpStateAfterLoad(state, {
            samples: [],
            neighbors: new Map(),
            smallN: true,
            predictions: new Map(),
            smallSamplingFilter,
          }),
          loadingError: ex,
          // Overrides the value computed by `cleanUpStateAfterLoad`.
          outOfSamples: false,
        }));
      }
    });

    return () => {
      cancelled = true;
    };
  }, [
    accessToken,
    dataset,
    featureSets,
    needsMoreSamples,
    samplingFilter,
    state.classifications,
    state.queue,
    state.displayed,
    state.latestFeatureSets,
    state.latestModelPath,
    state.samplingFilter,
    state.smallSamplingFilter,
    state.smallN,
    numSamples,
    threshold,
    updateState,
    workspace,
    plates,
    state.loadingError,
    api,
  ]);
}

/**
 * Hook that moves samples from the queue to the display list once the display
 * list has been emptied and the queue still holds samples.
 * @param numSamples Maximum number of samples to move to the display list per refill
 */
export function useRefillDisplayedWhenEmpty(numSamples: number): void {
  const { state, updateState } = useLabeledSetContext();

  useEffect(() => {
    const displayIsEmpty = state.displayed.length === 0;
    const queueHasSamples = state.queue.length > 0;
    if (!displayIsEmpty || !queueHasSamples) {
      return;
    }

    // Take the next batch from the front of the queue and remove it from the
    // queue in the same state update.
    const nextBatch = state.queue.slice(0, numSamples);
    updateState((current) => ({
      ...current,
      displayed: nextBatch,
      queue: removeSamples(current.queue, nextBatch),
    }));
  }, [
    numSamples,
    state.displayed.length,
    state.neighbors,
    state.queue,
    updateState,
  ]);
}

/**
 * Hook that flags the labeled set as out of samples when the queue is empty,
 * the flag is not set yet, and the active sampling filter is the one recorded
 * in `smallSamplingFilter`.
 */
export function useOutOfSamples(): void {
  const { state, updateState } = useLabeledSetContext();

  useEffect(() => {
    if (state.queue.length !== 0 || state.outOfSamples) {
      return;
    }
    if (state.samplingFilter !== state.smallSamplingFilter) {
      return;
    }

    // Re-check the queue against the latest state inside the updater.
    updateState((current) => ({
      ...current,
      outOfSamples: current.queue.length === 0,
    }));
  }, [
    state.queue,
    state.smallN,
    state.outOfSamples,
    state.samplingFilter,
    state.smallSamplingFilter,
    updateState,
  ]);
}

/**
 * Hook to fetch a given model's performance information.
 *
 * Built with `makeWorkspaceApi` for the "examples/models" route. The result
 * type is `ModelMetrics` and the hook takes `{ modelPath }`, which the mapping
 * callback passes through unchanged.
 * NOTE(review): presumably the mapped object becomes the request's query
 * parameters — confirm against `makeWorkspaceApi`.
 */
export const useModelMetrics = makeWorkspaceApi("examples/models")<
  ModelMetrics,
  { modelPath: string }
>(({ modelPath }) => ({ modelPath }));
