import {isNaN, isNumber, mean, meanBy, sum} from 'lodash'

import {levelPredictionMap, levelStrengthMap} from '../constants'
import {CsMaterialSample, StrengthLevel} from '../declarations/MaterialData'

import {filterSamplesByTimeRange, TimeRange} from '@predict/UtilsLib/dateTime'

/** One observation pairing a model prediction with the observed value it is scored against. */
export interface PredictionRecord {
  /** Model output for this observation. */
  predicted: number
  /** Observed value that `predicted` is compared to. */
  actual: number
}

/** Squares `x`; used for the sum-of-squares terms of the R² computation. */
const pow2 = (x: number): number => x * x

/**
 * Mean absolute error between `predicted` and `actual` across `records`.
 *
 * @param records - paired actual/predicted values.
 * @returns the MAE, or `undefined` when `records` is empty or the result is NaN
 *   (e.g. a NaN slipped into one of the records).
 */
export const calcMeanAbsoluteError = (records: PredictionRecord[]): number | undefined => {
  if (records.length === 0) {
    return undefined
  }
  let totalAbsError = 0
  for (const {actual, predicted} of records) {
    totalAbsError += Math.abs(predicted - actual)
  }
  const mae = totalAbsError / records.length
  return Number.isNaN(mae) ? undefined : mae
}

/**
 * Coefficient of determination (R²) of `predicted` against `actual`:
 * `1 - SS_res / SS_tot`, where SS_res is the sum of squared residuals
 * (`actual - predicted`) and SS_tot is the sum of squared deviations of
 * `actual` from its mean.
 *
 * @param records - paired actual/predicted values.
 * @returns R², or `undefined` when there are no records, when every `actual`
 *   value is identical (SS_tot is 0, so R² is undefined), or when the result
 *   is not a finite number (e.g. a NaN in the inputs).
 */
export const calcR2Score = (records: PredictionRecord[]): number | undefined => {
  if (records.length === 0) {
    return undefined
  }
  // Named `actualMean` (not `mean`) so lodash's `mean` import is not shadowed.
  const actualMean = records.reduce((acc, r) => acc + r.actual, 0) / records.length
  let ssRes = 0
  let ssTot = 0
  for (const {actual, predicted} of records) {
    ssRes += pow2(actual - predicted)
    ssTot += pow2(actual - actualMean)
  }
  // Written as a positive check so a NaN ssTot is also rejected.
  if (!(ssTot > 0)) {
    return undefined
  }
  const r2 = 1 - ssRes / ssTot
  return Number.isFinite(r2) ? r2 : undefined
}

/**
 * Mean of the predicted values for `strengthLevel` across `samples`.
 *
 * The prediction field is looked up through `levelPredictionMap`. Non-numeric
 * values are skipped, and so is NaN: lodash `isNumber` classifies NaN as a
 * number, so without the extra check a single NaN prediction would turn the
 * mean into NaN and the whole average into `undefined`.
 *
 * @returns the mean, or `undefined` when no usable values remain (lodash
 *   `mean([])` is NaN) or the mean itself is NaN.
 */
function calcAverage(
  samples: CsMaterialSample[],
  strengthLevel: StrengthLevel
): number | undefined {
  const predictionKey = levelPredictionMap[strengthLevel]
  const predictedValues = samples
    .map((s) => s[predictionKey])
    .filter((v): v is number => isNumber(v) && !isNaN(v))
  const avg = mean(predictedValues)
  return isNaN(avg) ? undefined : avg
}

/**
 * Summary statistics produced by `calcCementStrengthStats`.
 * Every field is optional: it is left `undefined` (or absent) when it cannot be computed.
 */
export interface PredictionStats {
  /** Mean predicted value over the evaluated samples. */
  average?: number
  /** `average - target`, when both are available. */
  deviationFromTarget?: number
  /** Coefficient of determination of predictions vs. actual values. */
  r2?: number
  /** Mean absolute error of predictions vs. actual values. */
  meanAbsoluteError?: number
}
/**
 * Summarises predictions for `strengthLevel` over the samples that fall inside `timeRange`.
 *
 * @param samples - candidate samples; narrowed with `filterSamplesByTimeRange`.
 * @param timeRange - range passed to `filterSamplesByTimeRange`.
 * @param strengthLevel - selects which actual (`levelStrengthMap`) and predicted
 *   (`levelPredictionMap`) fields are read from each sample.
 * @param target - optional target; when it and the average are numbers,
 *   `deviationFromTarget = average - target`.
 * @returns `average` and `deviationFromTarget` always; `meanAbsoluteError` and `r2`
 *   only when at least one in-range sample has both a numeric, non-NaN actual and
 *   predicted value (otherwise those keys are omitted).
 */
export const calcCementStrengthStats = (
  samples: CsMaterialSample[],
  timeRange: TimeRange,
  strengthLevel: StrengthLevel,
  target?: number
): PredictionStats => {
  const samplesInRange = filterSamplesByTimeRange(samples, timeRange)
  const average = calcAverage(samplesInRange, strengthLevel)
  const deviationFromTarget = isNumber(target) && isNumber(average) ? average - target : undefined

  const actualKey = levelStrengthMap[strengthLevel]
  const predictedKey = levelPredictionMap[strengthLevel]
  const records = samplesInRange
    .map((sample) => ({
      actual: sample[actualKey],
      predicted: sample[predictedKey]
    }))
    .filter(
      // lodash `isNumber` accepts NaN, so NaN is rejected explicitly; otherwise a
      // single NaN reading would make the MAE NaN (reported as undefined) for all samples.
      (record): record is PredictionRecord =>
        isNumber(record.actual) &&
        !isNaN(record.actual) &&
        isNumber(record.predicted) &&
        !isNaN(record.predicted)
    )
  if (records.length === 0) {
    // No comparable pairs: return without the error-metric keys.
    return {
      average,
      deviationFromTarget
    }
  }
  const meanAbsoluteError = calcMeanAbsoluteError(records)
  const r2Value = calcR2Score(records)

  return {
    average,
    deviationFromTarget,
    meanAbsoluteError,
    // Defensive: never report NaN as an R² value.
    r2: isNaN(r2Value) ? undefined : r2Value
  }
}
