import { AsyncDuckDB } from "@duckdb/duckdb-wasm";
import * as d3 from "d3";
import type { AccessToken } from "src/Auth0/accessToken";
import {
  colorSchemeByWellMetadata,
  colorValuesByScheme,
} from "src/Control/ColorSchemeSelector";
import fishersExactTest from "src/FeatureSetManagementPage/fishersExactTest";
import { findLeafNodes } from "src/PhenoFinder/treeUtils";
import {
  BaseClusterFeatureData,
  BaseClusterMetadata,
  ClusterFeatureData,
  ClusterFeatureDataStats,
  ClusterMetadata,
  DataNode,
  LeafDataNode,
  TreeModificationStep,
  VisualizationColumn,
} from "src/PhenoFinder/types";
import { DatasetId, WorkspaceId } from "src/types";
import { datasetApi } from "src/util/api-client";
import { defaultComparator } from "src/util/sorting";
import { queryDBAsRecords, sanitizedColumn, sql } from "src/util/sql";
import { uTest } from "statly";
import invariant from "tiny-invariant";
import { FilterSet } from "../Control/FilterSelector/types";
import { serializeToSqlClause } from "../Control/FilterSelector/utils";
import { cleanUpIndex } from "../PhenotypicLearner/util";
import {
  MAX_BAR_COUNT,
  MIN_METADATA_PERCENT,
  OTHER_COLUMN_NAME,
} from "./constants";

/**
 * Publishes the leaf clusters of `currentTree` as a labeled set by POSTing to
 * the dataset API's `tree/<cutTree>/publish` route.
 *
 * @param accessToken - Token used to authenticate the API call.
 * @param description - Free-text description sent with the labeled set.
 * @param defaultNamesMap - Fallback names for leaves without a `displayName`.
 * @param currentTree - Tree whose leaf nodes become the labels.
 * @param workspace - Workspace the dataset belongs to.
 * @param cutTreesPath - Value substituted for `<cutTree>` in the route.
 * @param dataset - Dataset being labeled.
 */
export async function requestInference(
  accessToken: AccessToken,
  description: string,
  defaultNamesMap: Map<string, string>,
  currentTree: DataNode,
  workspace: WorkspaceId,
  cutTreesPath: string,
  dataset: DatasetId,
) {
  const client = datasetApi({ accessToken, workspace, dataset });
  const publishRoute = client.route("tree/<cutTree>/publish", {
    cutTree: cutTreesPath,
  });
  const request = publishRoute.post({
    labels: convertTreeToLabeledSetPayload(currentTree, defaultNamesMap),
    description,
  });
  await request.finish();
}

/**
 * Asks the dataset API to create a cut tree named `cutTreeName` by POSTing to
 * `cut_trees/<cutTree>/create`, passing the chosen stains as `channels`.
 *
 * @returns The parsed JSON response of the create call.
 */
export async function requestClustering(
  accessToken: AccessToken,
  workspace: WorkspaceId,
  cutTreeName: string,
  dataset: DatasetId,
  stains: string[],
) {
  const createRoute = datasetApi({ accessToken, workspace, dataset }).route(
    "cut_trees/<cutTree>/create",
    { cutTree: cutTreeName },
  );
  // No request body; the stains go in the second (`channels`) argument.
  const response = createRoute.post(undefined, {
    channels: stains,
  });
  return response.json();
}

/**
 * Builds the `labels` payload for publishing a tree: one entry per leaf node,
 * named by its display name (or the default name) and listing its cells.
 */
// TODO(you): Fix this no-unused-exports rule violation
// ts-unused-exports:disable-next-line
export function convertTreeToLabeledSetPayload(
  currentTree: DataNode,
  defaultNamesMap: Map<string, string>,
) {
  return leafNodesToLabeledSet(findLeafNodes(currentTree), defaultNamesMap);
}

/**
 * Maps leaf nodes to `{ name, examples }` label entries. A node's
 * `displayName` wins; otherwise the name is looked up in `defaultNamesMap`
 * (and may be `undefined` if absent). Cell ids pass through `cleanUpIndex`.
 */
function leafNodesToLabeledSet(
  leafNodes: LeafDataNode[],
  defaultNamesMap: Map<string, string>,
) {
  return leafNodes.map(({ displayName, name, cells }) => ({
    name: displayName ?? defaultNamesMap.get(name),
    examples: cells.map(cleanUpIndex),
  }));
}

/**
 * Turns one cluster's `(value, count)` rows into chart entries.
 *
 * When the cluster has more than MAX_BAR_COUNT distinct values, every value
 * contributing less than MIN_METADATA_PERCENT of the cluster's total is merged
 * into a single OTHER_COLUMN_NAME entry, which is appended last. NULL values
 * are rendered as the string "<null>".
 */
function bucketClusterMetadataRecords(
  clusterName: BaseClusterMetadata["clusterName"],
  records: readonly { value: string | boolean | null; count: bigint }[],
): BaseClusterMetadata[] {
  const total = records.reduce((acc, row) => acc + Number(row.count), 0);
  // Bucketing only kicks in when there are too many bars to show them all.
  const tooManyBars = records.length > MAX_BAR_COUNT;

  const entries: BaseClusterMetadata[] = [];
  const otherEntry = {
    clusterName,
    value: OTHER_COLUMN_NAME,
    count: 0,
    numEntries: 0,
  };

  for (const row of records) {
    const count = Number(row.count);
    if (tooManyBars && (count / total) * 100 < MIN_METADATA_PERCENT) {
      otherEntry.count += count;
      otherEntry.numEntries += 1;
    } else {
      entries.push({
        clusterName,
        value: row.value === null ? "<null>" : row.value.toString(),
        count,
        numEntries: 1,
      });
    }
  }

  if (otherEntry.count > 0) {
    entries.push(otherEntry);
  }
  return entries;
}

/**
 * Counts the values of a metadata column within every leaf cluster of `data`
 * (restricted by `filter`) and annotates each (cluster, value) count with its
 * within-cluster ratio, a Fisher's exact p-value and an odds ratio.
 *
 * @param metadataDB - DuckDB instance holding `single_cell_metadata`.
 * @param column - Metadata column to group by (quoted via `sanitizedColumn`).
 * @param data - Tree whose leaf nodes define the clusters.
 * @param filter - Filter serialized into the WHERE clause.
 * @returns Entries grouped by leaf node in `findLeafNodes` order. Within a
 *   cluster they follow query order, with any "other" bucket last.
 */
export async function getMetadataForColumn(
  metadataDB: AsyncDuckDB,
  column: string,
  data: DataNode,
  filter: FilterSet,
): Promise<ClusterMetadata[]> {
  // TODO(michaelwiest): Make this slightly more efficient by only querying leaf nodes
  // that have changed. Probably would need to do that at the call site.
  const leafNodes: LeafDataNode[] = findLeafNodes(data);
  const filterSQL = serializeToSqlClause(filter);

  // Collect per-cluster results through Promise.all instead of pushing into a
  // shared array from each callback. That way the output order follows
  // leafNodes rather than whichever query happens to finish first.
  const perClusterMetadata = await Promise.all(
    leafNodes.map(async (node) => {
      // NOTE(review): `cluster_${node.level}` relies on the `sql` tag emitting
      // the level as a bare number so it forms a column name; confirm.
      const query = sql`
      SELECT ${sanitizedColumn(column)} AS value, COUNT(*) AS count
      FROM single_cell_metadata
      WHERE cluster_${node.level} = ${node.name}
      AND ${filterSQL}
      GROUP BY 1
  `;
      const records = await queryDBAsRecords<{
        value: string | boolean | null;
        count: bigint;
      }>(metadataDB, query);
      return bucketClusterMetadataRecords(node.name, records);
    }),
  );

  return calculateMetadataStats(perClusterMetadata.flat());
}

/**
 * Reads the raw values of a feature column for every leaf cluster of `data`
 * (restricted by `filter`), then attaches box-plot summary statistics and a
 * one-vs-rest Mann-Whitney p-value to each cluster.
 *
 * @param featureDB - DuckDB instance holding `single_cell_features`.
 * @param column - Feature column to read (quoted via `sanitizedColumn`).
 * @param data - Tree whose leaf nodes define the clusters.
 * @param filter - Filter serialized into the WHERE clause.
 * @returns One entry per leaf node, in `findLeafNodes` order.
 */
// TODO(you): Fix this no-unused-exports rule violation
// ts-unused-exports:disable-next-line
export async function getFeatureData(
  featureDB: AsyncDuckDB,
  column: string,
  data: DataNode,
  filter: FilterSet,
): Promise<ClusterFeatureData[]> {
  const leafNodes: LeafDataNode[] = findLeafNodes(data);
  const filterSQL = serializeToSqlClause(filter);

  // Promise.all keeps results in leafNodes order. Pushing from each callback,
  // as before, ordered them by query completion.
  const features: BaseClusterFeatureData[] = await Promise.all(
    leafNodes.map(async (node) => {
      // Quote the column the same way getMetadataForColumn does. A bare
      // string interpolation is not guaranteed to be treated as an
      // identifier by the `sql` tag.
      const query = sql`
      SELECT ${sanitizedColumn(column)} AS value
      FROM single_cell_features
      WHERE cluster_${node.level} = ${node.name}
      AND ${filterSQL}
  `;
      const records = await queryDBAsRecords<{ value: number | null }>(
        featureDB,
        query,
      );
      return {
        clusterName: node.name,
        // NOTE(review): NULL feature values are counted as 0, which shifts the
        // summary statistics and the U test; confirm this is intended.
        values: records.map((row) => row.value ?? 0),
      };
    }),
  );

  return calculateFeatureStats(features);
}

/**
 * Loads chart data for the selected column. Feature columns are read from
 * `featureDB`; every other column type is treated as metadata and read from
 * `metadataDB`.
 *
 * Throws (via `invariant`) if a feature column is selected while `featureDB`
 * is null.
 */
export async function getChartData(
  metadataDB: AsyncDuckDB,
  featureDB: AsyncDuckDB | null,
  selectedVisualizationColumn: VisualizationColumn,
  data: DataNode,
  filter: FilterSet,
): Promise<ClusterFeatureData[] | ClusterMetadata[]> {
  const { type, name } = selectedVisualizationColumn;

  if (type !== "feature") {
    return getMetadataForColumn(metadataDB, name, data, filter);
  }

  // NOTE(review): this is only enforced at runtime. A compile-time guarantee
  // would need the column type and the DB to be modeled together, e.g. a
  // discriminated union that carries the featureDB for feature columns.
  invariant(
    featureDB !== null,
    "featureDB can't be null if selecting a feature column",
  );
  return getFeatureData(featureDB, name, data, filter);
}

/**
 * Annotates every (cluster, value) count with:
 * - `ratio`: the share of the cluster's cells that have this value;
 * - `uncorrectedPValue`: Fisher's exact test on the 2x2 table
 *   (in cluster / not in cluster) x (has value / does not), or 1.0 for the
 *   OTHER_COLUMN_NAME bucket, which is not a real value;
 * - `oddsRatio`: (a*d)/(b*c) on the same table. This is Infinity or NaN when
 *   b*c is 0.
 *
 * Totals are kept in `Map`s rather than plain objects. Metadata values are
 * arbitrary data strings, and keys such as "constructor" or "toString" would
 * otherwise hit Object.prototype members and corrupt the sums.
 */
function calculateMetadataStats(
  metadata: BaseClusterMetadata[],
): ClusterMetadata[] {
  let totalCount = 0;
  const totalCountByValue = new Map<BaseClusterMetadata["value"], number>();
  const totalCountByCluster = new Map<
    BaseClusterMetadata["clusterName"],
    number
  >();

  for (const row of metadata) {
    totalCount += row.count;
    totalCountByValue.set(
      row.value,
      (totalCountByValue.get(row.value) ?? 0) + row.count,
    );
    totalCountByCluster.set(
      row.clusterName,
      (totalCountByCluster.get(row.clusterName) ?? 0) + row.count,
    );
  }

  return metadata.map((row) => {
    const valueTotal = totalCountByValue.get(row.value) ?? 0;
    const clusterTotal = totalCountByCluster.get(row.clusterName) ?? 0;

    // Contingency table:
    // a: in this cluster, has this value
    const a = row.count;
    // b: not in this cluster, has this value
    const b = valueTotal - row.count;
    // c: in this cluster, does not have this value
    const c = clusterTotal - row.count;
    // d: not in this cluster, does not have this value. Add back this cell's
    // count because it was subtracted twice.
    const d = totalCount - clusterTotal - valueTotal + row.count;

    return {
      ...row,
      ratio: row.count / clusterTotal,
      // Only run the test for real values; the "other" bucket always gets 1.0.
      uncorrectedPValue:
        row.value !== OTHER_COLUMN_NAME ? fishersExactTest(a, b, c, d) : 1.0,
      oddsRatio: (a * d) / (b * c),
    };
  });
}

/**
 * For each cluster, runs a one-vs-rest Mann-Whitney U test: the cluster's
 * values against the pooled values of every other cluster. Each row is then
 * extended with its summary statistics and the resulting p-value.
 */
function calculateFeatureStats(
  featureData: BaseClusterFeatureData[],
): ClusterFeatureData[] {
  // All p-values are computed before any summary statistics are taken.
  const pValues = featureData.map((target) => {
    const restValues: number[] = featureData.flatMap((other) =>
      other.clusterName !== target.clusterName ? other.values : [],
    );
    return uTest(target.values, restValues).p;
  });

  return featureData.map((row, index) => {
    const summary = getFeatureSummaryStats(row);
    return { ...row, ...summary, uncorrectedPValue: pValues[index] };
  });
}

/**
 * Box-plot summary of one cluster's feature values: quartiles, median and
 * mean. `min`/`max` are the Tukey fences (Q1 - 1.5*IQR and Q3 + 1.5*IQR), not
 * the extreme data values.
 *
 * The input is not modified. Throws (via `invariant`) if `values` is empty,
 * because no statistic can be computed. A statistic that is exactly 0 is
 * valid and accepted.
 */
// TODO(you): Fix this no-unused-exports rule violation
// ts-unused-exports:disable-next-line
export function getFeatureSummaryStats(
  featureData: BaseClusterFeatureData,
): Omit<ClusterFeatureDataStats, "uncorrectedPValue"> {
  // Sort a copy numerically. The default Array#sort compares elements as
  // strings (so 10 < 9) and would also reorder the caller's array in place.
  const sortedValues = [...featureData.values].sort((a, b) => a - b);

  const q1 = d3.quantile(sortedValues, 0.25);
  const q3 = d3.quantile(sortedValues, 0.75);
  const median = d3.median(sortedValues);
  const mean = d3.mean(sortedValues);

  // d3 returns undefined only when there are no values. Compare against
  // undefined explicitly so a legitimate 0 does not trip the invariant.
  invariant(q1 !== undefined, "q1 is undefined: no feature values");
  invariant(q3 !== undefined, "q3 is undefined: no feature values");
  invariant(median !== undefined, "median is undefined: no feature values");
  invariant(mean !== undefined, "mean is undefined: no feature values");

  const iqr = q3 - q1;
  return {
    q1,
    median,
    mean,
    q3,
    min: q1 - 1.5 * iqr,
    max: q3 + 1.5 * iqr,
  };
}

/**
 * Normalizes free-text user input: trims surrounding whitespace and returns
 * `undefined` for null, undefined, empty or whitespace-only strings.
 */
export function cleanUserInputString(str: string | undefined | null) {
  // Trim once and reuse the result, so no non-null assertion is needed.
  const trimmed = (str ?? "").trim();
  return trimmed.length > 0 ? trimmed : undefined;
}

/**
 * Comparator that orders strings with `defaultComparator`, except that the
 * OTHER_COLUMN_NAME bucket always sorts after every other string.
 */
export function sortStringsOtherLast(a: string, b: string): number {
  const aIsOther = a === OTHER_COLUMN_NAME;
  const bIsOther = b === OTHER_COLUMN_NAME;
  if (aIsOther !== bIsOther) {
    return aIsOther ? 1 : -1;
  }
  return defaultComparator(a, b);
}

/**
 * Collapses a trailing run of consecutive "changeDisplayName" steps on the
 * same node into just the latest one, so repeated renames produce a single
 * history entry.
 *
 * Examples, where `rename(x)` is a changeDisplayName step on node x:
 * - `[split, rename(b), rename(b)]` becomes `[split, rename(b)]`, keeping the
 *   last rename.
 * - `[]` becomes `[]`.
 * - A list whose latest step is not a rename is returned unchanged (as a
 *   copy).
 *
 * The input array is not modified.
 */
export function consolidateDisplayNameModifications(
  modifications: TreeModificationStep[],
): TreeModificationStep[] {
  if (modifications.length === 0) {
    return [];
  }

  const latestModification = modifications[modifications.length - 1];
  // Only a trailing rename can be consolidated. Without this guard the latest
  // step would be appended a second time.
  if (latestModification.type !== "changeDisplayName") {
    return [...modifications];
  }

  // Walk back over the contiguous renames of the same node.
  let runStart = modifications.length - 1;
  while (
    runStart > 0 &&
    modifications[runStart - 1].type === "changeDisplayName" &&
    modifications[runStart - 1].nodeName === latestModification.nodeName
  ) {
    runStart--;
  }

  return [...modifications.slice(0, runStart), latestModification];
}

/**
 * Builds an ordinal color scale over the distinct metadata values in `data`.
 * Values are sorted with `sortStringsOtherLast`, so any OTHER_COLUMN_NAME
 * bucket comes last. If that bucket is present, its color (the last range
 * entry) is replaced with grey. Values not in the domain also map to grey.
 */
export function getColorScaleForMetadata(data: ClusterMetadata[]) {
  const greyColor = "#ccc";

  const distinctValues = d3.union(data.map((d) => d.value));
  const domain = [...distinctValues.keys()].sort(sortStringsOtherLast);

  const scheme = colorSchemeByWellMetadata(domain);
  const colorMapping = colorValuesByScheme(domain, scheme);

  // NOTE(review): this assumes colorMapping iterates in domain order, so the
  // "other" entry's color is the last one in the range.
  const range = [...colorMapping.values()];
  if (colorMapping.has(OTHER_COLUMN_NAME)) {
    range[range.length - 1] = greyColor;
  }

  const scale = d3.scaleOrdinal<string>().domain(domain).range(range);
  return scale.unknown(greyColor);
}
