import { AsyncDuckDB } from "@duckdb/duckdb-wasm";
import { useCallback, useEffect, useMemo } from "react";
import { makeDatasetApi } from "src/hooks/api";
import { DatasetId, WorkspaceId } from "src/types";
import { datasetApi } from "src/util/api-client";
import { Failure, Fetchable, Success } from "@spring/core/result";
import { FilterSqlClause } from "../Control/FilterSelector/types";
import { useAuthenticatedFetch } from "../hooks/fetch";
import { insertModelEvalTableDB, sql, useQueryAsRecords } from "../util/sql";
import { ConfusionCounts, SavedModelListing } from "./types";
import { predictionsTableNameFor } from "./utils";

const classificationModelsApi = makeDatasetApi("models");

/**
 * Hook that fetches the list of saved classification models from the server
 * via the `"models"` dataset API endpoint.
 *
 * The hook yields the Fetchable result together with a function that can be
 * called imperatively to force a refresh.
 */
export const useFetchClassificationModels =
  classificationModelsApi<SavedModelListing[]>();

/**
 * Build a confusion matrix for a model's predictions.
 *
 * Joins the predictions table against `sample_metadata` on `(plate, well)`,
 * optionally restricting `sample_metadata` by `filter`, and tallies how often
 * each actual class was predicted as each class.
 *
 * @param db - DuckDB instance containing both `tableName` and `sample_metadata`.
 * @param tableName - Predictions table. Each row is read for a `_class` column
 *   (the actual label) and one numeric score column per entry in `classes`.
 * @param classes - Class labels; defines both matrix axes and the score
 *   columns read from each row.
 * @param filter - Optional SQL clause applied to `sample_metadata`; when
 *   omitted, `TRUE` is used so every row is included.
 * @returns A Fetchable of a `classes.length × classes.length` matrix where
 *   `counts[actualIx][predictedIx]` counts matching rows. Rows whose actual
 *   label is not in `classes`, or whose predicted label cannot be determined
 *   (e.g. a NaN/undefined score), are skipped instead of throwing.
 */
export function useConfusionCounts(
  db: AsyncDuckDB,
  tableName: string,
  classes: string[],
  filter?: FilterSqlClause,
): Fetchable<ConfusionCounts> {
  const predictions = useQueryAsRecords(
    db,
    sql`
        SELECT * FROM ${tableName}
        INNER JOIN (
            SELECT *
            FROM sample_metadata
            WHERE ${filter || "TRUE"}
        ) sample_metadata
        ON ${tableName}.plate = sample_metadata.plate
            AND
           ${tableName}.well = sample_metadata.well
  `,
  );
  return useMemo(() => {
    // Nothing fetched yet (falsy): pass it through unchanged.
    if (!predictions) {
      return predictions;
    }
    // Re-wrap the error in a fresh Failure (presumably so its value type
    // matches Fetchable<ConfusionCounts> rather than the row type).
    if (!predictions.successful) {
      return Failure.of(predictions.error);
    }

    return predictions.map((rows) => {
      const counts = Array.from({ length: classes.length }, () =>
        Array.from({ length: classes.length }, () => 0),
      );

      for (const row of rows) {
        // The actual class label is a single column, label-encoded.
        const actualIx = classes.indexOf(row["_class"] as string);

        // The predictions are stored as individual columns per class; we have to
        // do an idxmax() to determine the predicted class label.
        const scores = classes.map((c) => row[c] as number);
        const predictedIx = scores.indexOf(Math.max(...scores));

        // indexOf() returns -1 for an actual label outside `classes`, and for a
        // NaN max score (Math.max propagates NaN; NaN never equals itself).
        // Indexing counts[-1] would throw a TypeError mid-render, so skip rows
        // that cannot be placed in the matrix.
        if (actualIx < 0 || predictedIx < 0) {
          continue;
        }
        counts[actualIx][predictedIx] += 1;
      }

      return counts;
    });
  }, [classes, predictions]);
}

/** Delay between re-fetches while the server reports results as not ready. */
const RESULTS_POLL_INTERVAL_MS = 5000;

/** Value produced by the response transform: a sentinel or the loaded DB. */
type ModelResultsPayload = "not-ready" | AsyncDuckDB;

/**
 * Fetch and load the results of a model training run.
 *
 * Requests `models/<model>/results` from the dataset API. A 202 response is
 * treated as "not ready". While in that state, `mutate` is called every
 * RESULTS_POLL_INTERVAL_MS milliseconds (presumably re-running the fetch).
 * Any other response body is handed to `insertModelEvalTableDB`, which writes
 * it into `metadata` under `predictionsTableNameFor(modelId)`.
 *
 * @param workspace - Workspace identifier used to build the results URL.
 * @param dataset - Dataset identifier used to build the results URL.
 * @param modelId - Model whose results are requested; also names the table.
 * @param metadata - DuckDB instance the predictions table is inserted into.
 * @returns `Failure` if the fetch errored, `Success` wrapping the loaded
 *   AsyncDuckDB once results are in, otherwise `undefined` (loading or not
 *   ready yet).
 */
export function useModelTrainingResults(
  workspace: WorkspaceId,
  dataset: DatasetId,
  modelId: string,
  metadata: AsyncDuckDB,
): Fetchable<AsyncDuckDB> {
  const api = datasetApi({ workspace, dataset });
  const url = api.url("models/<model>/results", { model: modelId });
  const tableName = predictionsTableNameFor(modelId);

  // A 202 status maps to the "not-ready" sentinel; anything else is treated
  // as the results payload and loaded into DuckDB.
  // TODO(benkomalo): if there's a transition from "not ready" to "ready",
  // we should probably also trigger a model refresh.
  const transformResponse = useCallback(
    (blob, response) =>
      response.status === 202
        ? Promise.resolve<ModelResultsPayload>("not-ready")
        : (insertModelEvalTableDB(
            metadata,
            blob,
            tableName,
            true,
          ) as Promise<ModelResultsPayload>),
    [metadata, tableName],
  );

  const { data, error, mutate } = useAuthenticatedFetch(
    url,
    undefined,
    undefined,
    "blob",
    transformResponse,
  );

  // Poll while the server says "not ready"; the cleanup clears the interval
  // whenever the effect re-runs or the component unmounts.
  useEffect(() => {
    if (data === "not-ready") {
      const pollHandle = setInterval(mutate, RESULTS_POLL_INTERVAL_MS);
      return () => clearInterval(pollHandle);
    }
  }, [data, error, mutate]);

  return useMemo(() => {
    if (error) {
      return Failure.of(error);
    }
    // Still loading, or the server has not finished preparing the results.
    if (!data || data === "not-ready") {
      return undefined;
    }
    return Success.of<AsyncDuckDB>(data);
  }, [data, error]);
}
