import { AsyncDuckDB } from "@duckdb/duckdb-wasm";
import * as Sentry from "@sentry/react";
import type { AccessToken } from "src/Auth0/accessToken";
import { datasetApi } from "src/util/api-client";
import { omit } from "vega-lite";
import { FilterSqlClause } from "../Control/FilterSelector/types";
import { DatasetId, MetadataColumnValue, WorkspaceId } from "../types";
import {
  columnMatchesValueClause,
  queryDBAsRecords,
  sanitizedColumn,
  sql,
} from "../util/sql";
import {
  ConfusionCounts,
  DataClass,
  ModelListing,
  ModelTrainingConfig,
  SavedModelListing,
  SplitConfig,
} from "./types";

/**
 * Build the starting (partial) training configuration for a dataset.
 *
 * The result targets the given dataset with a pass-through `"TRUE"` filter, no
 * classes, no feature inputs and an `"xgb"` model.
 *
 * @param dataset - Dataset the new configuration should point at.
 * @returns A fresh partial config; `normalization` and `split` are present but
 *   `undefined`.
 */
export function getDefaultModelTrainingConfig(
  dataset: DatasetId,
): Partial<ModelTrainingConfig> {
  const defaults: Partial<ModelTrainingConfig> = {
    data: { dataset, filter: "TRUE", classes: [] },
    featureInputs: [],
    model: { type: "xgb" },
    // Good defaults for these depend on the shape of the dataset, so the
    // respective components fill them in later.
    normalization: undefined,
    split: undefined,
  };
  return defaults;
}

/**
 * List the distinct values of a metadata column among the prefiltered rows.
 *
 * Values are ordered by how many `sample_metadata` rows carry them (most
 * frequent first), with ties broken by the column value in ascending order.
 *
 * @param metadata - Database holding the `sample_metadata` table.
 * @param column - Column name; passed through `sanitizedColumn` before use.
 * @param prefilter - SQL clause used as the `WHERE` condition.
 * @returns The column values only; the per-value counts are not returned.
 */
export async function queryCandidateValuesForColumn(
  metadata: AsyncDuckDB,
  column: string,
  prefilter: FilterSqlClause,
): Promise<MetadataColumnValue[]> {
  const columnExpr = sanitizedColumn(column);
  const query = sql`SELECT ${columnExpr} AS value, COUNT(1) AS n
        FROM sample_metadata
        WHERE ${prefilter}
        GROUP BY ${columnExpr}
        ORDER BY n DESC, ${columnExpr} ASC`;
  const rows = await queryDBAsRecords<{ value: MetadataColumnValue }>(
    metadata,
    query,
  );
  return rows.map<MetadataColumnValue>(({ value }) => value);
}

/**
 * Infer the class column and values from the given config, if possible.
 *
 * The {@link ModelTrainingConfig} and {@link DataClass} types are designed to be
 * flexible and allow arbitrary definitions of classes, so that the backend can
 * support progressively more complex use cases in the future. However, we've
 * designed a UI (for now) that takes the simplistic approach of having the user
 * select a single column and different values of that column. Since we don't
 * store that column explicitly in the config, we have to derive it when
 * restoring state.
 *
 * The first candidate column whose sanitized name prefixes every class filter
 * is taken to be the class column.
 *
 * @param classes - Class definitions from the config.
 * @param candidateColumns - Columns to try, in priority order.
 * @param metadata - Database queried for the column's values.
 * @param prefilter - SQL clause restricting which rows are considered.
 * @param returnClassColumnOnly - When true, resolve to just the column name
 *   and skip the database query entirely.
 * @returns The column (and, unless `returnClassColumnOnly`, the values of that
 *   column whose match clause equals one of the class filters), or `undefined`
 *   when there are no classes or no candidate column fits.
 */
export async function inferClassColumnAndValuesFromConfig(
  classes: DataClass[],
  candidateColumns: string[],
  metadata: AsyncDuckDB,
  prefilter: FilterSqlClause,
  returnClassColumnOnly: false,
): Promise<[string, MetadataColumnValue[]] | undefined>;

export async function inferClassColumnAndValuesFromConfig(
  classes: DataClass[],
  candidateColumns: string[],
  metadata: AsyncDuckDB,
  prefilter: FilterSqlClause,
  returnClassColumnOnly: true,
): Promise<string | undefined>;

export async function inferClassColumnAndValuesFromConfig(
  classes: DataClass[],
  candidateColumns: string[],
  metadata: AsyncDuckDB,
  prefilter: FilterSqlClause,
  returnClassColumnOnly: boolean = false,
): Promise<[string, MetadataColumnValue[]] | string | undefined> {
  if (classes.length === 0) {
    return undefined;
  }

  for (const column of candidateColumns) {
    // The sanitized column is arbitrary text, so it must be escaped before it
    // is embedded in a pattern; otherwise names containing regex
    // metacharacters would throw or match the wrong filters.
    const columnPrefix = new RegExp(
      `^${escapeForRegExp(`${sanitizedColumn(column)}`)} *`,
    );
    if (!classes.every(({ filter }) => columnPrefix.test(filter))) {
      continue;
    }

    // The column alone is enough for this mode; avoid a needless query.
    if (returnClassColumnOnly) {
      return column;
    }

    const values = await queryCandidateValuesForColumn(
      metadata,
      column,
      prefilter,
    );
    const classFilters = new Set<string>(classes.map(({ filter }) => filter));
    const selectedValues = values.filter((value) =>
      classFilters.has(columnMatchesValueClause(column, value)),
    );
    return [column, selectedValues];
  }

  return undefined;
}

/** Escape every regex metacharacter in `text` so it matches literally. */
function escapeForRegExp(text: string): string {
  return text.replace(/[.*+?^${}()|[\]\\]/g, "\\$&");
}

/**
 * List the distinct `plate` values among the prefiltered metadata rows.
 *
 * @param metadata - Database holding the `sample_metadata` table.
 * @param prefilter - SQL clause used as the `WHERE` condition.
 * @returns Plate identifiers in ascending order.
 */
export async function queryPlates(
  metadata: AsyncDuckDB,
  prefilter: FilterSqlClause,
): Promise<string[]> {
  const query = sql`SELECT DISTINCT plate
        FROM sample_metadata
        WHERE ${prefilter}
      ORDER BY plate ASC`;
  const rows = await queryDBAsRecords<{ plate: string }>(metadata, query);
  return rows.map(({ plate }) => plate);
}

/**
 * Choose a default column to split on, or `null` when none is usable.
 *
 * Only columns whose entry in `minCountWithinClass` exceeds 1 are eligible.
 * Among those, the order of preference is:
 *   1. an "interestingly named" column whose `maxClassCountWithinColumn` is 1,
 *   2. the first column whose `maxClassCountWithinColumn` is 1,
 *   3. an "interestingly named" eligible column,
 *   4. the first eligible column.
 *
 * A name is "interesting" when it relates to one of a fixed list of terms
 * (donor, plate, sample, batch, experiment); exact matches (optionally with an
 * `_id` / ` id` suffix, case-insensitive) win over substring matches.
 *
 * @param maxClassCountWithinColumn - Per-column count; presumably the largest
 *   number of classes seen for a single value of the column (confirm with caller).
 * @param minCountWithinClass - Per-column count; its keys define the candidate
 *   columns and their iteration order.
 */
export function inferDefaultSplitColumn(
  maxClassCountWithinColumn: { [column: string]: number },
  minCountWithinClass: { [column: string]: number },
): string | null {
  const eligible: string[] = [];
  for (const [name, minCount] of Object.entries(minCountWithinClass)) {
    if (minCount > 1) {
      eligible.push(name);
    }
  }
  if (eligible.length === 0) {
    return null;
  }

  const singleClassColumns = eligible.filter(
    (name) => maxClassCountWithinColumn[name] === 1,
  );

  const keywords = ["donor", "plate", "sample", "batch", "experiment"];
  // One pattern per keyword, built once rather than per column comparison.
  const exactPatterns = keywords.map(
    (term) => new RegExp(`^${term}([_ ](id))?$`, "i"),
  );

  const pickByName = (names: string[]): string | undefined => {
    // First pass: "exact" matches (or matches with an ID suffix).
    const exact = names.find((name) =>
      exactPatterns.some((pattern) => pattern.test(name)),
    );
    if (exact !== undefined) {
      return exact;
    }
    // Second pass: fuzzier, substring-based matches.
    return names.find((name) => {
      const lowered = name.toLowerCase();
      return keywords.some((term) => lowered.includes(term));
    });
  };

  const namedSingleClass = pickByName(singleClassColumns);
  if (namedSingleClass !== undefined) {
    return namedSingleClass;
  }
  if (singleClassColumns.length > 0) {
    return singleClassColumns[0];
  }
  return pickByName(eligible) ?? eligible[0];
}

/**
 * Simulate a single fold/split for the given split config.
 *
 * Note: producing a real sampling is rather complex, and under the hood we
 * delegate to
 * https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.KFold.html.
 * Replicating that logic on the client is non-trivial and not worth it, but we
 * still need examples to visualize for illustrative purposes, so this
 * algorithm only approximates a split.
 *
 * @param split - Split configuration; a `"random"` split uses `trainFraction`,
 *   otherwise the training fraction is `1 - 1 / numCVFolds`.
 * @param classMemberships - For each class, the set of members belonging to it.
 *   A member may appear in more than one class.
 * @returns A `[trainingSet, testSet]` pair.
 */
export function createExampleSplit(
  split: SplitConfig,
  classMemberships: { [klass: string]: Set<MetadataColumnValue> },
): [Set<MetadataColumnValue>, Set<MetadataColumnValue>] {
  const trainingSet = new Set<MetadataColumnValue>();
  const testSet = new Set<MetadataColumnValue>();

  const trainingFraction =
    split.type === "random" ? split.trainFraction : 1 - 1 / split.numCVFolds;

  // Walk the classes in a stratified manner: each class gets its own target
  // training count derived from the training fraction, and its members are
  // divided to meet that target.
  for (const members of Object.values(classMemberships)) {
    const memberList = Array.from(members);
    const trainingQuota = Math.max(
      1,
      Math.floor(members.size * trainingFraction),
    );

    // Members placed while processing an earlier class keep their assignment;
    // those already in training count towards this class's quota.
    let trainingCount = memberList.filter((member) =>
      trainingSet.has(member),
    ).length;
    const unassigned = memberList.filter(
      (member) => !trainingSet.has(member) && !testSet.has(member),
    );

    // Fill training eagerly up to the quota; the remainder goes to test.
    for (const member of unassigned) {
      if (trainingCount < trainingQuota) {
        trainingSet.add(member);
        trainingCount += 1;
      } else {
        testSet.add(member);
      }
    }
  }

  // Note: adversarial setups can produce non-ideal splits. For instance, with
  // several classes of disjoint members followed by classes made up entirely
  // of already-seen members, some classes may end up with nothing in test. In
  // practice this is very unlikely and fixing it would be tricky, so for now
  // we only report the case where the test set is completely empty.
  if (testSet.size === 0) {
    Sentry.captureMessage("No members assigned to test set in split");
  }
  return [trainingSet, testSet];
}

/**
 * Check whether a partial training config is complete enough to be used.
 *
 * A config is valid when `data`, `normalization`, `featureInputs`, `model` and
 * `split` are all set, at least two classes are defined, and at least one
 * feature input is selected.
 *
 * @param config - Possibly incomplete configuration.
 * @returns `true` (narrowing the type to {@link ModelTrainingConfig}) when all
 *   of the conditions above hold.
 */
export function isConfigurationValid(
  config: Partial<ModelTrainingConfig>,
): config is ModelTrainingConfig {
  const { data, normalization, featureInputs, model, split } = config;

  // Every top-level section has to be present before looking inside any of them.
  if (!(data && normalization && featureInputs && model && split)) {
    return false;
  }

  const hasEnoughClasses = data.classes.length >= 2;
  const hasFeatureInputs = featureInputs.length !== 0;
  return hasEnoughClasses && hasFeatureInputs;
}

/**
 * Look up the metadata entries that correspond to the user's selection.
 *
 * Rows of `sample_metadata` are kept when they satisfy both `prefilter` and
 * the clause produced by `generateValueClauses` for the selected column and
 * values. When `selectedValues` is `undefined`, no value restriction is
 * applied and every prefiltered row is returned.
 *
 * @param metadata - Database holding the `sample_metadata` table.
 * @param prefilter - SQL clause restricting which rows are considered.
 * @param selectedColumn - Column whose value is reported as `value`.
 * @param selectedValues - Values of `selectedColumn` to keep, if any.
 * @returns One `{ plate, well, value }` record per matching row.
 */
export async function findControlEntries(
  metadata: AsyncDuckDB,
  prefilter: FilterSqlClause,
  selectedColumn: string,
  selectedValues: MetadataColumnValue[] | undefined,
): Promise<{ plate: string; well: string; value: string }[]> {
  const valueClauses = generateValueClauses(selectedColumn, selectedValues);
  const valueColumn = sanitizedColumn(selectedColumn);
  const query = sql`SELECT plate, well, ${valueColumn} AS value
          FROM sample_metadata
          WHERE (${prefilter}) AND (${valueClauses})`;
  return queryDBAsRecords<{ plate: string; well: string; value: string }>(
    metadata,
    query,
  );
}

/**
 * Build a parenthesized SQL condition matching any of the selected values.
 *
 * Each value contributes one `columnMatchesValueClause` condition; the
 * conditions are wrapped in parentheses and joined with `OR`.
 *
 * When nothing is selected (`selectedValues` is `undefined` or an empty
 * array) the result is `"(TRUE)"`, i.e. no restriction. An empty array
 * previously produced `"()"`, which is not a valid SQL expression.
 *
 * @param selectedColumn - Column the values belong to.
 * @param selectedValues - Values to match, if any.
 * @returns A SQL boolean expression wrapped in parentheses.
 */
export function generateValueClauses(
  selectedColumn: string,
  selectedValues: MetadataColumnValue[] | undefined,
): string {
  if (!selectedValues || selectedValues.length === 0) {
    return "(TRUE)";
  }
  const alternatives = selectedValues.map(
    (value) => `(${columnMatchesValueClause(selectedColumn, value)})`,
  );
  return `(${alternatives.join(" OR ")})`;
}

/**
 * Compute accuracy from a confusion matrix of counts.
 *
 * Entries on the diagonal (row index equal to column index) count as correct;
 * all other entries count as incorrect.
 *
 * @param counts - Confusion counts, indexed as `counts[row][column]`.
 * @returns `correct / (correct + incorrect)`; this is `NaN` when the total
 *   count is zero.
 */
export function accuracyForCounts(counts: ConfusionCounts): number {
  let onDiagonal = 0;
  let offDiagonal = 0;

  for (const [rowIndex, row] of counts.entries()) {
    for (const [columnIndex, count] of row.entries()) {
      if (rowIndex === columnIndex) {
        onDiagonal += count;
      } else {
        offDiagonal += count;
      }
    }
  }

  const total = onDiagonal + offDiagonal;
  return onDiagonal / total;
}

/**
 * Type guard telling saved models apart from unsaved ones.
 *
 * A model counts as saved when it carries a defined `id`.
 *
 * @param model - Saved or unsaved model listing.
 * @returns `true` (narrowing to {@link SavedModelListing}) when `id` is set.
 */
function isModelSaved(
  model: ModelListing | SavedModelListing,
): model is SavedModelListing {
  // Reading `id` through Partial<SavedModelListing> makes it optional at the
  // type level, so the check below is meaningful to the compiler and no lint
  // suppression is needed.
  const { id } = model as Partial<SavedModelListing>;
  return id !== undefined;
}

/**
 * Persist a model to the backend.
 *
 * A model that is already saved (it has an `id`) is updated with a `PUT` to
 * `models/<model>`; otherwise it is created with a `POST` to `models`. In both
 * cases the request body carries the model's name, description and config.
 *
 * @param accessToken - Token used to authenticate the request.
 * @param workspace - Workspace the dataset belongs to.
 * @param dataset - Dataset the model belongs to.
 * @param model - Model to create or update.
 * @returns The response body parsed as `{ id }`.
 */
export async function saveModel(
  accessToken: AccessToken,
  workspace: WorkspaceId,
  dataset: DatasetId,
  model: ModelListing | SavedModelListing,
): Promise<{ id: string }> {
  const api = datasetApi({ accessToken, workspace, dataset });

  const { name, description, config } = model;
  const payload = {
    metadata: { name, description },
    config,
  };

  if (isModelSaved(model)) {
    return api
      .route("models/<model>", { model: model.id })
      .put(payload)
      .json<{ id: string }>();
  }
  return api.route("models").post(payload).json<{ id: string }>();
}

/**
 * Submit a training job for a saved model to the backend.
 *
 * @param workspace - Workspace the dataset belongs to.
 * @param dataset - Dataset the model belongs to.
 * @param accessToken - Token used to authenticate the request.
 * @param model - Saved model to train; its config must be valid.
 * @returns The parsed response of the `models/<model>/train` endpoint.
 * @throws Error when the model's configuration fails `isConfigurationValid`.
 */
export async function submitTrainingJob(
  workspace: WorkspaceId,
  dataset: DatasetId,
  accessToken: AccessToken,
  model: SavedModelListing,
): Promise<{ error?: string; result: string | null }> {
  const configIsValid = isConfigurationValid(model.config);
  if (!configIsValid) {
    throw new Error("Invalid model configuration");
  }

  const api = datasetApi({ accessToken, workspace, dataset });
  const request = api
    .route("models/<model>/train", { model: model.id })
    .post();
  return request.json<{ error?: string; result: string | null }>();
}

/**
 * Submit an inference job for a saved model to the backend.
 *
 * @param workspace - Workspace the dataset belongs to.
 * @param dataset - Dataset the model belongs to.
 * @param accessToken - Token used to authenticate the request.
 * @param model - Saved model to apply; its config must be valid.
 * @param name - Name sent as request metadata.
 * @param description - Description sent as request metadata.
 * @returns `"OK"` once the `models/<model>/apply` request has completed.
 * @throws Error when the model's configuration fails `isConfigurationValid`.
 */
export async function submitInferenceJob(
  workspace: WorkspaceId,
  dataset: DatasetId,
  accessToken: AccessToken,
  model: SavedModelListing,
  name: string,
  description: string,
): Promise<"OK"> {
  const configIsValid = isConfigurationValid(model.config);
  if (!configIsValid) {
    throw new Error("Invalid model configuration");
  }

  const api = datasetApi({ accessToken, workspace, dataset });
  const payload = { metadata: { name, description } };
  await api.route("models/<model>/apply", { model: model.id }).post(payload);

  return "OK";
}

/**
 * Create an unsaved copy of a saved model under a new name.
 *
 * The copy keeps every property of the original except `id`, which is dropped,
 * and has its `status` reset to `"incomplete"`.
 *
 * @param model - Saved model to copy; it is not modified by this function.
 * @param name - Name for the copy.
 * @returns The new, not-yet-saved model listing.
 */
export function getUnsavedClone(
  model: SavedModelListing,
  name: string,
): ModelListing {
  const clone: ModelListing = {
    ...(omit(model, ["id"]) as ModelListing),
    name,
    status: "incomplete",
  };
  return clone;
}

// Model IDs seen so far; a model's position in this list determines its table name.
const modelIds: string[] = [];

/**
 * Generate a unique predictions table name for a model.
 *
 * The same AsyncDuckDB instance stores both sample metadata and model results
 * (predictions), so as users navigate between different models we may add
 * several tables. This keeps their names distinct, and stable for a given
 * model ID for as long as this module stays loaded.
 *
 * @param modelId - Identifier of the model.
 * @returns `pred_<n>`, where `n` is the order in which the ID was first seen.
 */
export function predictionsTableNameFor(modelId: string): string {
  // Model IDs aren't necessarily safe table names, so alias each one to an
  // ordinal value with a prefix.
  let ordinal = modelIds.indexOf(modelId);
  if (ordinal === -1) {
    ordinal = modelIds.push(modelId) - 1;
  }
  return `pred_${ordinal}`;
}

/**
 * Formats SQL filter clauses into a multi-line, more human-readable form.
 *
 * - `undefined` or an empty string yields `undefined`.
 * - The pass-through filter `"TRUE"` yields `["No filters applied."]`.
 * - A filter without any `AND` / `OR` operator is returned unchanged as a
 *   single entry.
 * - Otherwise the enclosing parentheses are removed and the result alternates
 *   between clauses (each with its own enclosing parentheses removed) and the
 *   operators joining them, e.g. `((a = 1) AND (b = 2))` becomes
 *   `["a = 1", "AND", "b = 2"]`.
 *
 * Operators are only recognized as whole, upper-case words outside of quoted
 * text, so names or values such as `"DONOR"` or `'A OR B'` are left intact.
 * Nested groups are still flattened rather than parsed.
 *
 * TODO(Hosny): Rewrite when we have a "parse to a FilterSet" utility.
 *
 * @param filter - SQL filter clause, if any.
 * @returns Display lines, or `undefined` when there is no filter.
 */
export function formatFilter(filter: string | undefined): string[] | undefined {
  if (!filter) {
    return undefined;
  }
  if (filter === "TRUE") {
    return ["No filters applied."];
  }

  const segments = splitOnLogicalOperators(stripEnclosingParens(filter));
  if (segments.length === 1) {
    // Single filter clause, return as is.
    return [filter];
  }

  // Even positions hold clauses, odd positions hold the operators between them.
  return segments.map((segment, index) =>
    index % 2 === 1 ? segment : stripEnclosingParens(segment),
  );
}

/** Trim `text` and drop one pair of enclosing parentheses, if present. */
function stripEnclosingParens(text: string): string {
  const trimmed = text.trim();
  return trimmed.startsWith("(") && trimmed.endsWith(")")
    ? trimmed.slice(1, -1)
    : trimmed;
}

/**
 * Split `expression` at every stand-alone `AND` / `OR` found outside quotes.
 *
 * Returns `[clause, operator, clause, ...]`; a single-element array means no
 * operator was found.
 */
function splitOnLogicalOperators(expression: string): string[] {
  // Quoted literals and identifiers are matched first so that any keyword
  // inside them is consumed without being treated as an operator.
  const token = /'(?:[^']|'')*'|"(?:[^"]|"")*"|\b(AND|OR)\b/g;
  const segments: string[] = [];
  let clauseStart = 0;
  let match: RegExpExecArray | null;
  while ((match = token.exec(expression)) !== null) {
    const operator: string | undefined = match[1];
    if (operator === undefined) {
      continue;
    }
    segments.push(expression.slice(clauseStart, match.index), operator);
    clauseStart = match.index + operator.length;
  }
  segments.push(expression.slice(clauseStart));
  return segments;
}
