import { AsyncDuckDB } from "@duckdb/duckdb-wasm";
import * as Dialog from "@radix-ui/react-dialog";
import cx from "classnames";
import {
  Dispatch,
  ReactNode,
  SetStateAction,
  useCallback,
  useEffect,
  useMemo,
  useState,
} from "react";
import { HelpCircle } from "react-feather";
import { Button } from "@spring/ui/Button";
import { Tooltip } from "@spring/ui/Tooltip";
import {
  colorSchemeByWellMetadata,
  colorValuesByScheme,
} from "../Control/ColorSchemeSelector";
import { FilterSqlClause } from "../Control/FilterSelector/types";
import { MetadataColumnValue } from "../types";
import { inferInterestingColumnsDB } from "../util/dataset-util";
import { useAsyncValue } from "../util/hooks";
import {
  ValidatedSQL,
  getTableColumns,
  queryDBAsRecords,
  sanitizedColumn,
  sanitizedTextValue,
  sql,
  sqlOr,
} from "../util/sql";
import { normalizeDomain } from "../util/vega-util";
import { DataClass, ModelTrainingConfig, SplitConfig } from "./types";
import {
  createExampleSplit,
  inferClassColumnAndValuesFromConfig,
  inferDefaultSplitColumn,
} from "./utils";

/**
 * Shows a one-line summary of the current split configuration, with an "Edit" link
 * that opens a dialog for choosing how samples are partitioned into train/test sets.
 *
 * Loads three things from `sample_metadata`:
 * - the class column;
 * - per-column minimum distinct-value counts within a class;
 * - per-column maximum class counts.
 *
 * Once all three are available and `config.split` is still undefined, it writes a
 * default cross-validation split into the config (this happens at most once, because
 * the effect only fires while `config.split` is undefined).
 *
 * @param config - the (possibly incomplete) training config being edited.
 * @param setConfig - state setter for `config`.
 * @param metadata - database holding the `sample_metadata` table.
 */
export default function SplitConfigurator({
  config,
  setConfig,
  metadata,
}: {
  // TODO(benkomalo): do a better job with typing here -- we know that the config
  // should be mostly filled out at this point (at least, classes are defined).
  config: Partial<ModelTrainingConfig>;
  setConfig: Dispatch<SetStateAction<Partial<ModelTrainingConfig>>>;
  metadata: AsyncDuckDB;
}) {
  // The filter is an input to the inference, so it is part of the dependencies too;
  // otherwise the class column would go stale when only the filter changes.
  const classColumn = useAsyncValue(async () => {
    const allColumns = await getTableColumns(metadata, "sample_metadata");
    const classColumn = await inferClassColumnAndValuesFromConfig(
      config.data?.classes ?? [],
      allColumns,
      metadata,
      config.data?.filter ?? "TRUE",
      true,
    );
    if (!classColumn) {
      throw new Error("Unable to infer class column from config");
    }
    return classColumn;
  }, [config.data?.classes, config.data?.filter, metadata]);

  const minCounts = useAsyncValue(
    () =>
      getMinColumnCountWithinClass(
        config.data?.classes ?? [],
        metadata,
        config.data?.filter ?? "TRUE",
      ),
    [config.data?.classes, metadata, config.data?.filter],
  );
  // Only columns that take more than one value in every class can act as a split.
  const filteredColumns = useMemo(() => {
    return minCounts
      ? Object.entries(minCounts)
          .filter(([, count]) => count > 1)
          .map(([column]) => column)
      : [];
  }, [minCounts]);

  const maxClassCountWithinColumn = useAsyncValue(async () => {
    // Until `minCounts` loads, `filteredColumns` is only an empty placeholder.
    // Resolving to `{}` from it could let the one-shot default-split effect below
    // run with an empty map before the real column data arrives, so stay "not ready".
    if (!minCounts) {
      return undefined;
    }

    return getMaxClassCountWithinColumnValue(
      config.data?.classes ?? [],
      metadata,
      filteredColumns,
    );
  }, [config.data?.classes, metadata, filteredColumns, minCounts]);

  // Columns where every value belongs to at most one class.
  const potentiallyConfoundedColumns = useMemo(() => {
    if (!maxClassCountWithinColumn) {
      return [];
    }
    return filteredColumns.filter(
      (column) => maxClassCountWithinColumn[column] === 1,
    );
  }, [filteredColumns, maxClassCountWithinColumn]);

  // Seed a default split once everything needed to infer it has loaded.
  useEffect(() => {
    if (
      config.split === undefined &&
      classColumn &&
      maxClassCountWithinColumn &&
      minCounts
    ) {
      setConfig((prev) => {
        const idColumn = inferDefaultSplitColumn(
          maxClassCountWithinColumn,
          minCounts,
        );
        return {
          ...prev,
          split: {
            type: "cross-validation",
            idColumn,
            numCVFolds: idColumn ? Math.min(5, minCounts[idColumn]!) : 5,
            stratifySamplingColumn: classColumn,
          },
        };
      });
    }
  }, [
    config.split,
    setConfig,
    classColumn,
    maxClassCountWithinColumn,
    minCounts,
  ]);

  const allDataLoaded =
    !!classColumn && !!minCounts && !!maxClassCountWithinColumn;

  return (
    <div className={"tw-flex"}>
      {getSimplifiedSplitSummary(config.split)}
      <Dialog.Root>
        <Dialog.Trigger
          className={
            "tw-text-purple tw-underline tw-ml-lg hover:tw-text-purple/80"
          }
        >
          Edit
        </Dialog.Trigger>

        <Dialog.Overlay
          className={cx(
            "tw-fixed tw-left-0 tw-top-0 tw-w-full tw-h-full tw-z-popup",
            "tw-flex tw-flex-col tw-justify-center tw-bg-gray-500/50",
          )}
        >
          <Dialog.Content
            className={cx(
              "tw-bg-white tw-shadow-xl tw-border tw-rounded",
              "tw-mx-auto tw-top-[50%] tw-w-[1080px] tw-h-[540px]",
              "tw-flex tw-flex-col",
            )}
          >
            <div className={"tw-flex tw-flex-1 tw-overflow-hidden"}>
              <div
                className={
                  "tw-flex-1 tw-border-r tw-p-lg tw-flex tw-flex-col tw-overflow-hidden"
                }
              >
                <div className={"tw-flex tw-items-center tw-mb-lg"}>
                  <div className={"tw-text-xl tw-text-purple"}>Split by</div>
                  <Tooltip
                    contents={
                      <div
                        className={
                          "tw-max-w-[400px] tw-text-slate-500 tw-text-sm"
                        }
                      >
                        <p>
                          Your data will be automatically partitioned into
                          training and test sets so that your model can be
                          evaluated on data it hasn't seen before.
                        </p>
                        <p className={"tw-mt-4"}>
                          Selecting how to partition the data is an important
                          way for you to build confidence in the results of your
                          model. Try to pick a split such that the model is
                          forced to generalize across dimensions that are
                          important to your experimental setup.
                        </p>

                        <p className={"tw-mt-4"}>
                          For example, if your data is generated over several
                          batches, you can try to see if your model will
                          generalize by splitting the data by experimental
                          batch.
                        </p>
                      </div>
                    }
                    side={"bottom"}
                    showArrow
                  >
                    <HelpCircle
                      className={"tw-ml-sm tw-text-slate-500 tw-w-4"}
                    />
                  </Tooltip>
                </div>
                {allDataLoaded && (
                  <SplitOptionsList
                    config={config}
                    setConfig={setConfig}
                    classColumn={classColumn}
                    minCounts={minCounts}
                    columns={filteredColumns}
                    potentiallyConfoundedColumns={potentiallyConfoundedColumns}
                  />
                )}
              </div>
              <div className={"tw-flex-1 tw-p-lg tw-overflow-y-auto"}>
                {allDataLoaded && (
                  <SplitPreview
                    config={config}
                    classColumn={classColumn}
                    metadata={metadata}
                  />
                )}
              </div>
            </div>
            <div className={"tw-flex tw-justify-end tw-p-md tw-border-t"}>
              {/* NOTE(review): Radix's Dialog.Close renders a <button> by default, so
                  this likely nests a button inside a button; consider `asChild` if
                  @spring/ui's Button forwards refs and props. */}
              <Dialog.Close>
                <Button variant={"primary"}>Apply</Button>
              </Dialog.Close>
            </div>
          </Dialog.Content>
        </Dialog.Overlay>
      </Dialog.Root>
    </div>
  );
}

/**
 * One selectable row in the split-column list; outlined in purple when `selected`.
 * Any other button attributes (e.g. `onClick`, `className`) are forwarded to the
 * underlying `<button>`.
 */
function OptionButton({
  selected,
  className,
  children,
  ...rest
}: React.ButtonHTMLAttributes<HTMLButtonElement> & {
  children: ReactNode;
  selected: boolean;
}) {
  return (
    <button
      className={cx(
        "tw-flex tw-border tw-rounded tw-p-sm hover:tw-border-purple tw-mb-1",
        "tw-text-left",
        selected && "tw-border-purple tw-text-purple",
        className,
      )}
      {...rest}
    >
      {children}
    </button>
  );
}

/**
 * Lists the ways the data can be split: a "Random split" option, then the
 * potentially confounded columns (marked with a red asterisk), then the remaining
 * candidate columns. Choosing an option writes a cross-validation split into the
 * config, stratified by `classColumn`, with at most 5 folds (capped by the column's
 * minimum distinct-value count within a class).
 *
 * @param config - current config; `config.split.idColumn` marks the selected option.
 * @param setConfig - state setter used to store the chosen split.
 * @param classColumn - name of the target/class column (used for stratification).
 * @param minCounts - per-column minimum distinct-value count within a class.
 * @param columns - candidate split columns.
 * @param potentiallyConfoundedColumns - subset of `columns` flagged as confounded.
 */
function SplitOptionsList({
  config,
  setConfig,
  classColumn,
  minCounts,
  columns,
  potentiallyConfoundedColumns,
}: {
  config: Partial<ModelTrainingConfig>;
  setConfig: Dispatch<SetStateAction<Partial<ModelTrainingConfig>>>;
  classColumn: string;
  minCounts: { [column: string]: number };
  columns: string[];
  potentiallyConfoundedColumns: string[];
}) {
  const idColumn = config.split?.idColumn ?? null;

  const handleSelectColumn = useCallback(
    (column) => {
      setConfig((prev) => ({
        ...prev,
        split: {
          type: "cross-validation",
          idColumn: column,
          numCVFolds: column ? Math.min(5, minCounts[column]!) : 5,
          stratifySamplingColumn: classColumn,
        },
      }));
    },
    [classColumn, minCounts, setConfig],
  );

  return (
    <>
      <div className={"tw-flex tw-flex-col tw-flex-1 tw-overflow-y-auto"}>
        <OptionButton
          key={"random"}
          onClick={() => handleSelectColumn(null)}
          selected={!idColumn}
        >
          Random split
        </OptionButton>
        {potentiallyConfoundedColumns.map((column) => {
          return (
            <OptionButton
              key={column}
              onClick={() => handleSelectColumn(column)}
              selected={idColumn === column}
            >
              <>
                {column}
                <div className={"tw-ml-xs tw-text-red-500"}>*</div>
              </>
            </OptionButton>
          );
        })}
        {columns
          .filter((column) => !potentiallyConfoundedColumns.includes(column))
          .map((column) => {
            return (
              <OptionButton
                key={column}
                onClick={() => handleSelectColumn(column)}
                selected={idColumn === column}
              >
                {column}
              </OptionButton>
            );
          })}
      </div>
      {potentiallyConfoundedColumns.length > 0 && (
        <div
          className={
            "tw-text-sm tw-mt-md tw-text-slate-500 tw-whitespace-nowrap"
          }
        >
          <span className={"tw-text-red-500"}>*</span> indicates a field that is
          confounded by your target variable{" "}
          <span className={"tw-text-purple"}>{classColumn}</span>
        </div>
      )}
    </>
  );
}

/**
 * Previews the configured split.
 *
 * When an id column is selected, it shows:
 * - one segmented bar per class, with that class's sample counts broken down by
 *   id-column value;
 * - a legend of an example train/test assignment of those values (hovering a legend
 *   side highlights its values in the bars).
 *
 * For a random split, only the empty container is rendered.
 *
 * Both queries keep only rows that are assigned to a class and match
 * `config.data.filter`, so the bars, colours and legend describe the same samples.
 *
 * @param config - current config; reads `data.classes`, `data.filter` and `split`.
 * @param classColumn - name of the target/class column, shown in the caption.
 * @param metadata - database holding the `sample_metadata` table.
 */
function SplitPreview({
  config,
  classColumn,
  metadata,
}: {
  config: Partial<ModelTrainingConfig>;
  classColumn: string;
  metadata: AsyncDuckDB;
}) {
  // TODO(benkomalo): tighten up these types. The config is a Partial because we're
  // in the model editor and that is flexible to incomplete configs, but we know if
  // we've gotten to this stage we should have "data" and "split" defined.
  const idColumn = config.split?.idColumn ?? null;
  const classes = config.data?.classes ?? [];
  // The hooks below depend on these raw config fields rather than on `classes`,
  // which is a fresh `[]` on every render when no classes are configured.
  const classesConfig = config.data?.classes;
  const filterConfig = config.data?.filter;

  // All unique values of the selected id column.
  const allValues = useAsyncValue(async () => {
    if (!idColumn) {
      return null;
    }
    const classProjection = getClassProjection(classesConfig ?? []);
    const filter = filterConfig ?? "TRUE";
    const raw = await queryDBAsRecords<{
      [column: string]: MetadataColumnValue;
    }>(
      metadata,
      sql`
                SELECT 
                    DISTINCT ${sanitizedColumn(idColumn)}
                FROM sample_metadata
                WHERE ${classProjection} IS NOT NULL AND (${filter})
                ORDER BY ${sanitizedColumn(idColumn)}`,
    );

    return normalizeDomain(raw.map((row) => row[idColumn]));
  }, [metadata, idColumn, classesConfig, filterConfig]);
  const colorMap = useMemo(() => {
    if (!allValues) {
      return null;
    }
    const colorScheme = colorSchemeByWellMetadata(allValues);
    return colorValuesByScheme(allValues, colorScheme);
  }, [allValues]);
  // Position of each value in `allValues` (first occurrence). Bar segments are
  // ordered by it, so every class lists values in the same order without an
  // O(n) `indexOf` per sort comparison.
  const valueOrder = useMemo(() => {
    const order = new Map<unknown, number>();
    allValues?.forEach((value, index) => {
      if (!order.has(value)) {
        order.set(value, index);
      }
    });
    return order;
  }, [allValues]);
  const [hoveredValues, setHoveredValues] = useState<Set<MetadataColumnValue>>(
    new Set(),
  );

  // Fetch the grouped counts (class -> id value -> sample count), if a column is set.
  // (null is used if no column is set, and undefined is used for async-not-ready).
  const counts: Map<string, Map<MetadataColumnValue, number>> | null | undefined =
    useAsyncValue(async () => {
      if (!idColumn) {
        return null;
      }

      const classProjection = getClassProjection(classesConfig ?? []);
      const filter = filterConfig ?? "TRUE";
      const rawRecords = await queryDBAsRecords<{
        __class: string;
        count: number;
        [idColumn: string]: MetadataColumnValue;
      }>(
        metadata,
        sql`
          SELECT
              ${classProjection} AS __class,
              ${sanitizedColumn(idColumn)},
              COUNT(1) AS count
          FROM
              sample_metadata
          WHERE __class IS NOT NULL AND (${filter})
          GROUP BY
              __class, ${sanitizedColumn(idColumn)}
        `,
      );

      const results = new Map<string, Map<MetadataColumnValue, number>>();
      for (const row of rawRecords) {
        let countsInClass = results.get(row.__class);
        if (!countsInClass) {
          countsInClass = new Map();
          results.set(row.__class, countsInClass);
        }
        const idColumnValue = row[idColumn];
        countsInClass.set(
          idColumnValue,
          (countsInClass.get(idColumnValue) ?? 0) + Number(row.count),
        );
      }
      return results;
    }, [metadata, idColumn, classesConfig, filterConfig]);

  // Simulate an example split to make it salient what's being configured.
  const exampleSplit = useMemo(() => {
    if (!counts || !config.data || !config.split) {
      return null;
    }
    const memberships = Object.fromEntries(
      Array.from(
        counts,
        ([klass, valueCounts]) => [klass, new Set(valueCounts.keys())] as const,
      ),
    );
    return createExampleSplit(config.split, memberships);
  }, [config, counts]);

  return (
    <div className={"tw-flex tw-flex-col"}>
      {idColumn && (
        <div className={"tw-mb-sm tw-text-sm tw-text-slate-500"}>
          Distribution of <span className={"tw-text-purple"}>{idColumn}</span>{" "}
          segmented by <span className={"tw-text-purple"}>{classColumn}</span>.
        </div>
      )}
      {allValues &&
        counts &&
        classes.map((klass) => {
          // A class can have no rows left after filtering; show an empty bar for it
          // instead of crashing on the missing entry.
          const classCounts =
            counts.get(klass.name) ?? new Map<MetadataColumnValue, number>();
          const data = Array.from(classCounts, ([value, count]) => ({
            count: Number(count),
            label: value,
            color: colorMap?.get(value) ?? "#000",
          }));
          data.sort(
            (a, b) =>
              (valueOrder.get(a.label) ?? -1) - (valueOrder.get(b.label) ?? -1),
          );
          return (
            <div className={"tw-flex"} key={klass.name}>
              <div className={"tw-w-[100px] tw-truncate"}>{klass.name}</div>
              <div className={"tw-flex-1"}>
                <SegmentedBar
                  data={data}
                  totalWidth={400}
                  hoveredValues={hoveredValues}
                />
              </div>
            </div>
          );
        })}
      {exampleSplit && colorMap && (
        <SplitLegend
          colorMap={colorMap}
          trainingSet={exampleSplit[0]}
          testSet={exampleSplit[1]}
          onValuesHovered={setHoveredValues}
        />
      )}
    </div>
  );
}

// TODO(benkomalo): if there are many, many values, this becomes slow/unwieldly.
// Do the same thing PhenoFinder does by collapsing and creating an "Other" group if
// that's the case. (Bonus: refactor to a general component?)
/**
 * A horizontal stacked bar `totalWidth` pixels wide. Each datum gets one coloured,
 * absolutely positioned segment whose width is proportional to its `count`. While
 * any values are hovered, segments whose label is not hovered fade to 0.2 opacity.
 */
function SegmentedBar({
  totalWidth,
  data,
  hoveredValues,
}: {
  totalWidth: number;
  data: {
    count: number;
    label: MetadataColumnValue;
    color: string;
  }[];
  hoveredValues: Set<MetadataColumnValue>;
}) {
  const isHighlighting = hoveredValues.size > 0;
  const totalCount = useMemo(
    () => data.reduce((sum, datum) => sum + datum.count, 0),
    [data],
  );

  // Lay segments out left to right, each starting where the previous one ended.
  const segments: ReactNode[] = [];
  let left = 0;
  for (const datum of data) {
    const segmentWidth = (totalWidth * datum.count) / totalCount;
    const isFaded = isHighlighting && !hoveredValues.has(datum.label);
    segments.push(
      <div
        key={String(datum.label)}
        title={String(datum.label)}
        className={"tw-absolute"}
        style={{
          left: left,
          width: segmentWidth,
          top: 0,
          bottom: 0,
          backgroundColor: datum.color,
          opacity: isFaded ? 0.2 : undefined,
        }}
      />,
    );
    left += segmentWidth;
  }

  return (
    <div
      className={"tw-relative tw-h-[32px] tw-border tw-overflow-hidden"}
      style={{ width: totalWidth }}
    >
      {segments}
    </div>
  );
}

/**
 * Legend for the example split. It lists the id-column values placed in the
 * training set and in the test set side by side, with a disclaimer tooltip.
 *
 * Hovering a side reports that side's values via `onValuesHovered`; leaving the
 * legend reports an empty set.
 */
function SplitLegend({
  colorMap,
  trainingSet,
  testSet,
  onValuesHovered,
}: {
  colorMap: Map<MetadataColumnValue, string>;
  trainingSet: Set<MetadataColumnValue>;
  testSet: Set<MetadataColumnValue>;
  onValuesHovered: (values: Set<MetadataColumnValue>) => void;
}) {
  const clearHovered = () => onValuesHovered(new Set());
  const sides = [
    {
      key: "train",
      header: "Train",
      values: trainingSet,
      className: "tw-border-r",
    },
    { key: "test", header: "Test", values: testSet, className: "tw-pl-lg" },
  ];
  const disclaimer = (
    <div className={"tw-max-w-[360px] tw-text-slate-500 tw-text-sm"}>
      Visualization for illustrative purposes only. <br />
      Actual configuration during training may differ slightly.
    </div>
  );

  return (
    <div className={"tw-mt-lg"} onMouseLeave={clearHovered}>
      <div className={"tw-text-sm tw-text-slate-500 tw-flex tw-items-center"}>
        Example configuration:
        <Tooltip contents={disclaimer} side={"right"} showArrow>
          <HelpCircle className={"tw-ml-sm tw-text-slate-500 tw-w-4"} />
        </Tooltip>
      </div>
      <div className={"tw-flex"}>
        {sides.map(({ key, header, values, className }) => (
          <Legend
            key={key}
            colorMap={colorMap}
            header={header}
            values={values}
            className={className}
            onMouseEnter={() => onValuesHovered(values)}
          />
        ))}
      </div>
    </div>
  );
}

/**
 * One column of the split legend: a header, then a colour swatch and label for each
 * value in `values`. Rows follow `colorMap`'s iteration order, and values missing
 * from `colorMap` are not shown.
 */
function Legend({
  colorMap,
  header,
  values,
  className,
  onMouseEnter,
}: {
  colorMap: Map<MetadataColumnValue, string>;
  header: string;
  values: Set<MetadataColumnValue>;
  className?: string;
  onMouseEnter: () => void;
}) {
  const sortedMembers = useMemo(() => {
    const members: MetadataColumnValue[] = [];
    colorMap.forEach((_color, value) => {
      if (values.has(value)) {
        members.push(value);
      }
    });
    return members;
  }, [colorMap, values]);

  const rows = sortedMembers.map((member) => (
    <div
      className={"tw-flex tw-items-center tw-text-sm"}
      key={String(member)}
    >
      <div
        className={"tw-w-[16px] tw-h-[16px] tw-rounded"}
        style={{ background: colorMap.get(member) }}
      />
      <div className={"tw-flex-1 tw-ml-sm"}>{String(member)}</div>
    </div>
  ));

  return (
    <div
      className={cx("tw-flex-1 tw-py-md tw-flex tw-flex-col", className)}
      onMouseEnter={onMouseEnter}
    >
      <div className={"tw-text-slate-500"}>{header}</div>
      {rows}
    </div>
  );
}

/**
 * Short, human-readable description of a split: "Split by: <column>" when an id
 * column is set, otherwise the plain string "Randomly split".
 *
 * @param split - the split configuration, if any.
 * @returns a React node suitable for inline display.
 */
export function getSimplifiedSplitSummary(
  split: SplitConfig | undefined,
): ReactNode {
  const idColumn = split?.idColumn;
  if (!idColumn) {
    return "Randomly split";
  }
  return (
    <div>
      <div>
        Split by:{" "}
        <span className={"tw-font-bold tw-font-mono tw-text-sm"}>
          {idColumn}
        </span>
      </div>
    </div>
  );
}

/**
 * Determine, for each candidate column, the minimum number of unique values it takes
 * within a single class.
 *
 * For example, if we're trying to predict a compound condition, and there's only
 * a single donor_id in the experiment, then every class (i.e. every compound) will have
 * a single value for donor_id (so it's not an appropriate split column).
 *
 * Candidate columns come from `inferInterestingColumnsDB`, restricted to rows that
 * match any class filter. The minimum is taken only over classes that have at least
 * one row matching `filter`. If no class has rows, every column is reported as 0.
 *
 * @param classes - class definitions; each class's `filter` selects its rows.
 * @param metadata - database holding the `sample_metadata` table.
 * @param filter - extra SQL predicate that every counted row must satisfy.
 * @returns map from column name to its minimum distinct-value count across classes;
 *   empty when there are no classes or no candidate columns.
 */
async function getMinColumnCountWithinClass(
  classes: DataClass[],
  metadata: AsyncDuckDB,
  filter: FilterSqlClause,
): Promise<{ [column: string]: number }> {
  // With no classes there is nothing to count (and the class CASE projection would
  // have no WHEN arms, which is not valid SQL).
  if (classes.length === 0) {
    return {};
  }
  const prefilter = sqlOr(
    classes.map(({ filter: classFilter }) => sql`(${classFilter})`),
  );
  const allColumns = await inferInterestingColumnsDB(metadata, prefilter);
  if (allColumns.length === 0) {
    // Nothing to count; avoid issuing a SELECT with an empty column list.
    return {};
  }

  const classProjection = getClassProjection(classes);
  const columnsProjection = allColumns
    .map(
      (c) =>
        sql`COUNT(DISTINCT ${sanitizedColumn(c)}) AS ${sanitizedColumn(c)}`,
    )
    .join(",");

  const countsPerClass = await queryDBAsRecords<{
    __class: string;
    [column: string]: number | string;
  }>(
    metadata,
    sql`SELECT
          ${classProjection} AS __class,
          ${columnsProjection}
      FROM sample_metadata
      WHERE __class IS NOT NULL AND (${filter})
      GROUP BY __class`,
  );

  const minCountWithinAClass: { [column: string]: number } = {};
  for (const column of allColumns) {
    const perClassCounts = countsPerClass.map((counts) => Number(counts[column]));
    // Math.min() of nothing is Infinity, which would wrongly pass the "> 1" check
    // callers apply; report 0 instead when no class has any rows.
    minCountWithinAClass[column] =
      perClassCounts.length > 0 ? Math.min(...perClassCounts) : 0;
  }

  return minCountWithinAClass;
}

/**
 * Determine the maximum number of unique class values per unique column value.
 *
 * For example: if our target class is "gender", then the maximum number of unique
 * "gender" values for the column "donor_id" will be 1 (i.e. every donor belongs
 * to at most a single class).
 *
 * On the other hand: if our target class is "compound", and we have a "concentration"
 * column, any given concentration may belong to many compounds.
 *
 * Only rows assigned to some class are considered. One query is issued per column,
 * sequentially. NOTE(review): unlike `getMinColumnCountWithinClass`, no extra row
 * filter is applied here — confirm that is intended.
 *
 * @param classes - class definitions; each class's `filter` selects its rows.
 * @param metadata - database holding the `sample_metadata` table.
 * @param columns - columns to evaluate.
 * @returns map from column name to the max number of distinct classes sharing one of
 *   its values; a column whose values never fall in a class maps to 0. Empty when
 *   `classes` is empty.
 */
async function getMaxClassCountWithinColumnValue(
  classes: DataClass[],
  metadata: AsyncDuckDB,
  columns: string[],
): Promise<{ [column: string]: number }> {
  const counts: { [column: string]: number } = {};
  // With no classes the class CASE projection would have no WHEN arms (invalid SQL).
  if (classes.length === 0) {
    return counts;
  }
  const classProjection = getClassProjection(classes);

  for (const column of columns) {
    // Each row is one distinct value of `column` plus how many classes it spans.
    const columnCounts = await queryDBAsRecords<{
      count: number;
      [column: string]: MetadataColumnValue;
    }>(
      metadata,
      sql`SELECT
        ${sanitizedColumn(column)},
        COUNT(DISTINCT __class) AS count
      FROM (
        SELECT
            ${sanitizedColumn(column)},
            ${classProjection} AS __class
        FROM sample_metadata
        WHERE __class IS NOT NULL
      )
      GROUP BY ${sanitizedColumn(column)}`,
    );
    // Every returned group contains a non-NULL __class, so each count is >= 1.
    // Seeding with 0 therefore still yields the true maximum. It also gives 0
    // (not -Infinity) when there are no rows, and avoids the argument-count limit
    // `Math.max.apply` hits on columns with very many distinct values.
    counts[column] = columnCounts.reduce(
      (max, row) => Math.max(max, Number(row.count)),
      0,
    );
  }

  return counts;
}

/**
 * Build a SQL expression mapping each row to the name of the first class whose
 * filter it matches, or NULL when it matches none.
 *
 * @param classes - class definitions, tested in order (SQL `CASE` takes the first
 *   matching `WHEN`).
 * @returns a SQL `CASE` expression. When `classes` is empty it returns `NULL`,
 *   because a `CASE` with no `WHEN` arms is not valid SQL and would match nothing
 *   anyway.
 */
function getClassProjection(classes: DataClass[]): ValidatedSQL {
  if (classes.length === 0) {
    return sql`NULL`;
  }
  return sql`
  CASE
    ${classes
      .map(
        ({ filter, name }) =>
          sql`WHEN (${filter}) THEN ${sanitizedTextValue(name)}`,
      )
      .join("\n")}
    ELSE NULL
  END`;
}
