import { getDeltaDateTimeInDays } from "../../../common/utils/dateFormatUtils/getDeltaDateTimeInDays";
import { SubjectType } from "../../common/types/SubjectType";
import { addMonthsToDate } from "../../common/utils/addMonthsToDate";
import { getBaselineFollowUp } from "../../common/utils/getBaselineFollowUp";
import { getMostRecentFollowUp } from "../../common/utils/getMostRecentFollowUp";

/** Key used for the series derived from a subject's recorded death date. */
export const OBSERVED = "observed";
/** Key used for the series derived from a subject's survival prediction. */
export const PREDICTED = "predicted";

/** The two kinds of death series handled in this module: `"observed"` or `"predicted"`. */
export type DeathType = typeof OBSERVED | typeof PREDICTED;

/**
 * A numeric interval. In this file it carries the lower (`min`) and upper (`max`)
 * bound that accompanies a survival estimate.
 */
export type RangeType = {
  min: number;
  max: number;
};

/**
 * A subject annotated with one time offset per death type.
 * Each offset is the result of `getDeltaDateTimeInDays` measured from the
 * baseline follow-up date (presumably a number of days — confirm against that
 * helper), or `null` when it could not be determined.
 */
interface SubjectWithDeathDates extends SubjectType {
  deathDates: Record<DeathType, number | null>;
}

/** Annotated subjects grouped under their project arm number. */
type SubjectsWithDeathDatesByArm = {
  [armNumber: number]: SubjectWithDeathDates[];
};

/**
 * One point of a single survival curve: the time offset `x`, the survival
 * estimate `value` and the interval `range` attached to that estimate.
 */
type DataPointType = {
  x: number;
  value: number;
  range: RangeType;
};

/**
 * One chart row: the shared `x` position plus, under each arm number, the
 * observed and predicted points for that arm.
 */
export type KaplanMeierDataPoint = {
  x: number;
  [arm: number]: Record<DeathType, DataPointType>;
};

/**
 * Builds the Kaplan-Meier chart rows for a list of subjects.
 *
 * Every subject is first annotated with its observed and predicted death
 * offsets, the annotated subjects are grouped by project arm number, and the
 * groups are handed to `computeKaplanMeierCurves`.
 *
 * @param subjects - Subjects to include in the curves.
 * @returns The data points produced by `computeKaplanMeierCurves`.
 */
export function generateKaplanMeierData(subjects: SubjectType[]): KaplanMeierDataPoint[] {
  const annotatedSubjects = subjects.map((subject) => getSubjectWithDeathDates(subject));

  return computeKaplanMeierCurves(sortSubjectsWithDeathDatesByProjectArm(annotatedSubjects));
}

/**
 * Returns a shallow copy of `subject` extended with a `deathDates` record that
 * holds the observed and the predicted death offset (each may be `null`).
 *
 * @param subject - The subject to annotate; it is not modified.
 * @returns The copied subject including `deathDates`.
 */
function getSubjectWithDeathDates(subject: SubjectType): SubjectWithDeathDates {
  return {
    ...subject,
    deathDates: {
      observed: getSubjectObservedDeath(subject),
      predicted: getSubjectPredictedDeath(subject),
    },
  };
}

/**
 * Computes the offset between the baseline follow-up date and the subject's
 * recorded death date using `getDeltaDateTimeInDays`.
 *
 * @param subject - Subject whose `deathDate` and `followUps` are read.
 * @returns The offset, or `null` when the subject has no death date, no
 *          baseline follow-up, or the baseline follow-up has no date.
 */
function getSubjectObservedDeath(subject: SubjectType): number | null {
  const baseline = getBaselineFollowUp(subject.followUps);
  const death = subject.deathDate;

  if (death && baseline) {
    const baselineDate = baseline.date;
    if (baselineDate) {
      return getDeltaDateTimeInDays(baselineDate, death);
    }
  }

  return null;
}

/**
 * Computes the offset between the baseline follow-up date and the predicted
 * death date of a subject.
 *
 * The predicted death date is obtained by passing the most recent follow-up's
 * date and its first survival prediction to `addMonthsToDate` (so the
 * prediction is presumably a number of months — confirm against that helper).
 *
 * @param subject - Subject whose `followUps` are read.
 * @returns The offset from `getDeltaDateTimeInDays`, or `null` when a
 *          follow-up, one of the two dates, or the first prediction is missing.
 */
function getSubjectPredictedDeath(subject: SubjectType): number | null {
  const latest = getMostRecentFollowUp(subject.followUps);
  const baseline = getBaselineFollowUp(subject.followUps);
  if (!(latest && baseline)) {
    return null;
  }

  const baselineDate = baseline.date;
  const latestDate = latest.date;

  // Only the first prediction of the most recent follow-up is considered.
  const predictions = latest.survivalPredictions;
  const firstPrediction = predictions.length > 0 ? predictions[0] : null;

  if (firstPrediction === null || !baselineDate || !latestDate) {
    return null;
  }

  return getDeltaDateTimeInDays(baselineDate, addMonthsToDate(latestDate, firstPrediction));
}

/**
 * Turns subjects grouped by arm into chart rows, one row per x position.
 *
 * Steps, as implemented below:
 * 1. For every arm, build a predicted and an observed curve with `getKaplanMeierCurve`.
 * 2. Gather the x values of all of those curves, prepend 0 and sort ascending.
 * 3. For every x, emit a row `{ x, [arm]: { predicted, observed } }`.
 * 4. Delete an entry when it and the entry before it (same arm, same death
 *    type) both have value 0.
 *
 * NOTE(review): the observed entry of every row is a constant value-1
 * placeholder; the observed curve built in step 1 only contributes its x
 * values. Confirm whether this is intentional.
 *
 * @param subjects - Subjects with death dates, keyed by arm number.
 * @returns One data point per entry of the sorted x list.
 * @throws Error if no predicted point was determined for an arm. With the two
 *         keys assigned per arm in this function a point is always assigned,
 *         so this only acts as a safeguard.
 */
function computeKaplanMeierCurves(subjects: SubjectsWithDeathDatesByArm): KaplanMeierDataPoint[] {
  //TODO this is super gross and can most definitely be refactored
  // Step 1: one predicted and one observed curve per arm.
  const deathsByArm: {
    [arm: number]: Record<DeathType, DataPointType[]>;
  } = {};
  for (const arm in subjects) {
    if (!Object.prototype.hasOwnProperty.call(subjects, arm)) {
      continue;
    }

    const armSubjects = subjects[arm];

    // PREDICTED is inserted first, so it is also the first key visited when
    // this record is iterated with for...in further down.
    deathsByArm[arm] = {
      [PREDICTED]: getKaplanMeierCurve(armSubjects, PREDICTED),
      [OBSERVED]: getKaplanMeierCurve(armSubjects, OBSERVED),
    };
  }

  // Step 2: collect every x value that occurs in any curve of any arm.
  const allXValuesSet = new Set<number>();

  for (const thingsKey in deathsByArm) {
    if (!Object.prototype.hasOwnProperty.call(deathsByArm, thingsKey)) {
      continue;
    }

    // thing2 holds the curves of one arm; thing2Key is the death type.
    const thing2 = deathsByArm[thingsKey];
    for (const thing2Key in thing2) {
      if (!Object.prototype.hasOwnProperty.call(thing2, thing2Key)) {
        continue;
      }

      const xValues = thing2[thing2Key as DeathType].map((t) => t.x);
      for (const xValue of xValues) {
        allXValuesSet.add(xValue);
      }
    }
  }

  const kaplanMeierDataPoints: KaplanMeierDataPoint[] = [];

  // 0 is prepended so the x list always has an entry at 0.
  // NOTE(review): 0 is added outside the Set, so if a curve already contains
  // x === 0 the value 0 appears twice here and two rows are produced for it.
  const allXValues = [0, ...allXValuesSet].sort((a, b) => a - b);

  // Step 3: one row per x, with one entry per arm.
  for (let i = 0; i < allXValues.length; i++) {
    const x = allXValues[i];
    const kaplanMeierDataPoint: KaplanMeierDataPoint = { x };

    for (const arm in deathsByArm) {
      let predictedDataPoint: DataPointType | undefined;

      if (!Object.prototype.hasOwnProperty.call(deathsByArm, arm)) {
        continue;
      }

      const deathsByDeathType = deathsByArm[arm];
      // Every path through this loop body ends in `break`, so only the first
      // key (PREDICTED, see above) is ever processed.
      for (const deathType in deathsByDeathType) {
        if (!Object.prototype.hasOwnProperty.call(deathsByDeathType, deathType)) {
          continue;
        }

        const deaths = deathsByDeathType[deathType as DeathType];
        for (let j = 0; j < deaths.length; j++) {
          const deathPoint = deaths[j];
          if (deathPoint.x === x) {
            // Exact hit: reuse the curve's own point object.
            predictedDataPoint = deathPoint;
            break;
          } else {
            // First row with a first curve point that is not at this x:
            // start the series at value 1.
            if (i === 0) {
              predictedDataPoint = {
                x,
                value: 1,
                range: {
                  min: 1,
                  max: 1,
                },
              };
              break;
            } else if (deathPoint.x > x) {
              // Stop scanning once past x (relies on the curve being ordered
              // by ascending x).
              break;
            }
          }
        }

        if (predictedDataPoint) {
          break;
        } else {
          // No exact hit: carry forward the latest curve point before x. The
          // carried point keeps its own `x`, not the x of this row. If there
          // is no earlier point, fall back to value 1.
          const reversed = [...deaths].reverse();
          predictedDataPoint = reversed.find((death) => death.x < x);
          if (!predictedDataPoint) {
            predictedDataPoint = {
              x,
              value: 1,
              range: {
                min: 1,
                max: 1,
              },
            };
          }
          break;
        }
      }

      if (!predictedDataPoint) {
        throw new Error("KAPLAN MEIER CALCULATION: This should not happen");
      }

      // The observed entry is a fixed placeholder at value 1 (see the
      // NOTE(review) in the function header).
      kaplanMeierDataPoint[arm] = {
        [PREDICTED]: predictedDataPoint,
        [OBSERVED]: {
          x,
          value: 1,
          range: {
            min: 1,
            max: 1,
          },
        },
      };
    }

    kaplanMeierDataPoints.push(kaplanMeierDataPoint);
  }

  // Step 4: remove an entry when it and its predecessor are both 0.
  for (let i = 0; i < kaplanMeierDataPoints.length; i++) {
    const kaplanMeierDataPoint = kaplanMeierDataPoints[i];
    // This also visits the "x" key; its value is a number, so the inner
    // for...in has nothing to iterate for it.
    for (const armKey in kaplanMeierDataPoint) {
      if (!Object.prototype.hasOwnProperty.call(kaplanMeierDataPoint, armKey)) {
        continue;
      }

      const deaths = kaplanMeierDataPoint[armKey];
      for (const deathKey in deaths) {
        if (!Object.prototype.hasOwnProperty.call(deaths, deathKey)) {
          continue;
        }

        const deathType = deathKey as DeathType;

        // The last row has no successor to compare against.
        if (i === kaplanMeierDataPoints.length - 1) {
          continue;
        }

        const { value } = deaths[deathType];
        const { value: nextValue } = kaplanMeierDataPoints[i + 1][armKey][deathType];
        // NOTE(review): an entry deleted here is no longer enumerated when its
        // row becomes row i, so in a run of zeros only every other entry is
        // removed. Confirm whether all entries after the first zero were meant
        // to be dropped.
        if (value === 0 && nextValue === 0) {
          delete kaplanMeierDataPoints[i + 1][armKey][deathType];
        }
      }
    }
  }

  return kaplanMeierDataPoints;
}

/**
 * Computes a survival step curve for one death type from a group of subjects.
 *
 * Subjects whose `deathDates[key]` is `null` are ignored; every remaining
 * subject counts as one death at its time (no censoring is modelled). For each
 * distinct time the function derives the surviving fraction relative to the
 * initial group size and a 1.96-standard-error band clamped to [0, 1], where
 * the standard error factor is the square root of the running sum of
 * `deaths / (atRisk * (atRisk - deaths))`.
 *
 * No estimate is produced for a time at which everyone still at risk dies, nor
 * for any later time. Of the estimates that are produced, the last one is
 * left out of the result. When at least one point remains, a closing point
 * with value 0 is appended at `x + 1` of the last remaining point, copying
 * that point's range.
 * NOTE(review): dropping the last estimate and placing the closing point at
 * `x + 1` (rather than at the final death time) looks unusual — confirm it is
 * intended before relying on it.
 *
 * @param subjects - Subjects of a single arm.
 * @param key - Which death offset to read from each subject.
 * @returns The curve points ordered by ascending x; empty when fewer than two
 *          estimates could be computed.
 */
function getKaplanMeierCurve(subjects: SubjectWithDeathDates[], key: DeathType): DataPointType[] {
  // Death times of the subjects that have one, in ascending order.
  const deathTimes = subjects
    .map((subject) => subject.deathDates[key])
    .filter((candidate) => candidate !== null) as number[];
  deathTimes.sort((a, b) => (a > b ? 1 : b > a ? -1 : 0));

  const distinctTimes: number[] = [...new Set(deathTimes)];

  // Per distinct time: how many were still alive just before it and how many
  // died at it.
  const atRisk: number[] = [];
  const deathsAt: number[] = [];
  let stillAlive = deathTimes.length;
  for (const moment of distinctTimes) {
    const deaths = deathTimes.filter((candidate) => candidate === moment).length;
    atRisk.push(stillAlive);
    deathsAt.push(deaths);
    stillAlive -= deaths;
  }

  const estimates: DataPointType[] = [];
  let cumulativeDeaths = 0;
  // Becomes null once everyone at risk has died; no estimates are produced
  // from that point on.
  let varianceSum: number | null = 0;

  for (let j = 0; j < deathsAt.length; j++) {
    const survivorsAfter = atRisk[j] - deathsAt[j];
    cumulativeDeaths += deathsAt[j];

    if (survivorsAfter === 0) {
      varianceSum = null;
    } else if (varianceSum !== null) {
      varianceSum += deathsAt[j] / (atRisk[j] * survivorsAfter);
    }

    if (varianceSum === null) {
      continue;
    }

    const survival = (atRisk[0] - cumulativeDeaths) / atRisk[0];
    const halfWidth = 1.96 * Math.sqrt(varianceSum) * survival;
    estimates.push({
      value: survival,
      range: {
        min: Math.max(0, survival - halfWidth),
        max: Math.min(1, survival + halfWidth),
      },
      x: distinctTimes[j],
    });
  }

  // All estimates except the last one make it into the curve.
  const curve = estimates.slice(0, -1);

  if (curve.length > 0) {
    const tail = curve[curve.length - 1];
    curve.push({
      value: 0,
      range: { ...tail.range },
      x: tail.x + 1,
    });
  }

  return curve;
}

/**
 * Annotated subjects grouped under their project arm number.
 * Structurally identical to the `SubjectsWithDeathDatesByArm` alias declared
 * near the top of this file.
 */
type SubjectsWithDeathDatesByArmType = {
  [armNumber: number]: SubjectWithDeathDates[];
};

//TODO this can be generic and re-use the `sortSubjectsByTrialArm` function
/**
 * Groups subjects by `projectArm.number`.
 *
 * Despite the name nothing is sorted: within each group the subjects keep the
 * order they have in the input array.
 *
 * @param subjects - Annotated subjects to group; the array is not modified.
 * @returns An object mapping each arm number to the subjects of that arm.
 */
export function sortSubjectsWithDeathDatesByProjectArm(
  subjects: SubjectWithDeathDates[]
): SubjectsWithDeathDatesByArmType {
  return subjects.reduce((grouped: SubjectsWithDeathDatesByArmType, subject) => {
    const armNumber = subject.projectArm.number;
    const bucket = grouped[armNumber];

    if (bucket) {
      bucket.push(subject);
    } else {
      grouped[armNumber] = [subject];
    }

    return grouped;
  }, {});
}
