import React from "react";
import {
  Bar,
  BarChart,
  Label,
  LabelList,
  Legend as ReChartsLegend,
  ResponsiveContainer,
  Tooltip,
  XAxis,
  YAxis,
} from "recharts";
import { Payload } from "recharts/types/component/DefaultLegendContent";

import { Legend } from "../../../Analysis/common/components/Legend";
import { getXAxisProps } from "../../../Analysis/common/utils/getXAxisProps";
import { axisLabelProps } from "../../../Analysis/common/utils/getYAxisLabel";
import { getYAxisProps } from "../../../Analysis/common/utils/getYAxisProps";
import { TEXT_OPACITY } from "../../../Analysis/common/utils/textOpacity";
import { ImageCaptureContextMenuContainer } from "../../../common/components/ImageCaptureContextMenuContainer";
import { formatTotal, GenericTooltip } from "../../../common/utils/barChartUtils/GenericTooltip";
import { getBarShape } from "../../../common/utils/barChartUtils/getBarShape";
import { sortData } from "../../../common/utils/barChartUtils/sortData";
import { LabelsPerCohortData } from "../utils/buildAnnotationsPerCohortData";

/** One chart row: the `label` column plus one numeric column per cohort name. */
type TableDataType = { [column: string]: string | number };

/**
 * Converts the nested per-label/per-cohort structure into flat rows that
 * recharts can consume directly.
 *
 * Each row gets a `label` key holding the entry's label, followed by one key
 * per cohort (`cohort.name`) holding that cohort's `studies` value. Later
 * cohorts with a repeated name overwrite earlier ones.
 *
 * @param data - per-label cohort breakdown
 * @returns one flat row per entry of `data`, in the same order
 */
export function flattenData(data: LabelsPerCohortData[]): TableDataType[] {
  return data.map((entry) => {
    const flattened: TableDataType = { label: entry.label };
    for (const cohort of entry.cohorts) {
      flattened[cohort.name] = cohort.studies;
    }
    return flattened;
  });
}

/**
 * Builds the legend entries for the chart: one square swatch per distinct
 * cohort name, ordered by where each name first appears in `data`.
 *
 * When the same cohort name occurs under several labels, the colour of its
 * first occurrence is used.
 *
 * @param data - per-label cohort breakdown
 * @returns legend payload entries (`type: "square"`, cohort colour, cohort name)
 */
function getLegendPayload(data: LabelsPerCohortData[]): Payload[] {
  // Single pass: a Map iterates in insertion order, so recording only the
  // first cohort seen per name reproduces "unique names in first-seen order,
  // colour of the first match" without re-scanning every cohort per name.
  const payloadByName = new Map<string, Payload>();

  for (const { cohorts } of data) {
    for (const { name, color } of cohorts) {
      if (!payloadByName.has(name)) {
        payloadByName.set(name, {
          type: "square",
          color,
          value: name,
        });
      }
    }
  }

  return [...payloadByName.values()];
}

/** Props for {@link AnnotationsPerCohortChart}: the per-label cohort breakdown to plot. */
type AnnotationsPerCohortChartProps = { data: LabelsPerCohortData[] };

/**
 * Stacked bar chart of study counts per annotation label, with one stacked
 * `Bar` per distinct cohort (colours and order come from `getLegendPayload`).
 *
 * Rows are produced by `flattenData` and ordered by `sortData` using each
 * row's summed cohort counts. The chart is rendered inside
 * `ImageCaptureContextMenuContainer`, whose render prop supplies the
 * `reference` attached to the `ResponsiveContainer`.
 *
 * @param props.data - per-label cohort breakdown to plot
 */
export function AnnotationsPerCohortChart({ data }: AnnotationsPerCohortChartProps): JSX.Element {
  const tableData = flattenData(data);

  const xAxisProps = getXAxisProps();
  const yAxisProps = getYAxisProps(true);

  // Seeded with 0 so an empty `data` gives 0 instead of -Infinity.
  const longestLabelLength = Math.max(0, ...data.map(({ label }) => label.length));

  const legendPayload = getLegendPayload(data);
  // Cohort names are also the per-cohort column keys written by flattenData,
  // so they double as the bars' dataKeys.
  const bars = legendPayload.map(({ value }) => value);

  // The row type cannot pin cohort columns to numbers (see: https://stackoverflow.com/a/57371967),
  // so non-numeric values (e.g. the "label" column) are filtered out before summing.
  const getDataCount = (row: TableDataType) =>
    bars
      .map((key) => row[key])
      .filter((count): count is number => typeof count === "number")
      .reduce((partialSum, count) => partialSum + count, 0);

  const sortedData = sortData(tableData, getDataCount);

  // With more than 8 labels the x-axis ticks are rotated to vertical and the
  // axis height is grown to fit the longest label.
  const dense = data.length > 8;

  return (
    <ImageCaptureContextMenuContainer>
      {({ reference }) => (
        <ResponsiveContainer ref={reference}>
          <BarChart
            layout={"horizontal"}
            data={sortedData}
            maxBarSize={100}
            margin={{ top: 0, right: 0, left: 30, bottom: 15 }}
          >
            <ReChartsLegend
              align="center"
              verticalAlign="bottom"
              content={<Legend />}
              payload={legendPayload}
              wrapperStyle={{ bottom: 0 }}
            />
            <XAxis
              {...xAxisProps}
              type={"category"}
              dataKey={"label"}
              interval={0}
              angle={dense ? -90 : 0}
              textAnchor={dense ? "end" : "middle"}
              height={dense ? 20 + 6 * longestLabelLength : undefined}
            >
              <Label
                value="Label"
                offset={-5}
                position="insideBottom"
                fontFamily={"Inter"}
                fontStyle={"normal"}
                fontWeight={600}
                fontSize={"11px"}
                opacity={TEXT_OPACITY}
              />
            </XAxis>
            <YAxis
              type={"number"}
              {...yAxisProps}
              allowDecimals={false}
              domain={[0, (dataMax: number) => Math.ceil(dataMax * 1.05)]}
            >
              <Label {...axisLabelProps} position="left" value={"# of Studies"} offset={15} />
            </YAxis>
            <Tooltip
              content={<GenericTooltip summaryFormatter={formatTotal} />}
              isAnimationActive={false}
              cursor={false}
            />
            {/* Iterate the payload entries directly so each bar's colour comes with its
                name, instead of a parallel-array lookup that may be undefined. */}
            {legendPayload.map(({ value: bar, color }, index) => (
              <Bar
                isAnimationActive={false}
                dataKey={bar}
                stackId={"y"}
                key={bar}
                legendType={"square"}
                fill={color}
                shape={(props) => getBarShape(props, bars, bar)}
              >
                {/* Only the last bar in the stack gets a label; it shows value[1],
                    the upper end of that stacked segment. */}
                {index === legendPayload.length - 1 && (
                  <LabelList
                    position="top"
                    valueAccessor={({ value }: { value: number[] }) => value[1] ?? "N/A"}
                  />
                )}
              </Bar>
            ))}
          </BarChart>
        </ResponsiveContainer>
      )}
    </ImageCaptureContextMenuContainer>
  );
}
