import { Box, useTheme } from '@northvolt/ui'
import * as vg from '@uwdata/vgplot'
import type { SelectableAttribute } from 'components/DataLoader/DataLoaderTypes'
import { useEffect, useRef, useState } from 'react'
import { colorSchemeCat } from 'routes/tools/plotter/PlotUtils'

/** Props accepted by the `PlotHist` histogram component. */
interface PlotHistProps {
  /** Column that is binned along the x axis and summarised in the statistics query. */
  attribute: SelectableAttribute
  /** Optional grouping column: used as bar fill / line stroke and as the statistics group key. */
  colorAttr?: SelectableAttribute
  /** Optional time column; the statistics query groups by a date derived from it for cp/cpl/cpu/cpk. */
  timeAttr?: SelectableAttribute
  /** Lower specification limit: drawn as a red rule and used in the capability equations. */
  LSL?: number
  /** Upper specification limit: drawn as a red rule and used in the capability equations. */
  USL?: number
  /** Plot width, forwarded to `vg.width`. */
  width: number
  /** Plot height, forwarded to `vg.height`. */
  height: number
  /**
   * Passed as `filterBy` for the histogram data and as the `as` target of the
   * colour legend. Presumably a Mosaic selection — TODO confirm and type it.
   */
  $filterParam: any
  /** Optional callback that receives the per-group statistics once they are computed. */
  setStats?: (stats: Stats[]) => void
  /** Only appears in the effect dependency list: changing it rebuilds the plot. */
  renderId?: number
}

/**
 * Histogram of `attribute` with optional spec-limit rules and a Gaussian
 * overlay, rendered into a `Box` by appending the element built by `getPlot`.
 *
 * The plot is rebuilt whenever one of the plotting inputs (or `renderId`)
 * changes. Building is asynchronous, so each run is cancelled when a newer
 * run starts or the component unmounts; a cancelled run neither replaces the
 * plot nor reports statistics.
 */
const PlotHist: React.FC<PlotHistProps> = ({
  attribute,
  colorAttr,
  timeAttr,
  LSL,
  USL,
  width,
  height,
  $filterParam,
  setStats,
  renderId,
}) => {
  const [plot, setPlot] = useState<any>(null)

  const theme = useTheme()
  const themeGrey = theme.palette.grey[700]

  const plotRef = useRef<any>(null)

  useEffect(() => {
    // Without an attribute or a usable size there is nothing to draw: show an
    // empty placeholder element instead.
    if (!(attribute && width && height < Number.POSITIVE_INFINITY)) {
      setPlot(vg.vspace(0))
      return undefined
    }

    // Set by the cleanup below so that a slow, superseded query cannot
    // overwrite the result of a newer one.
    let cancelled = false

    // Stale runs must not publish their statistics either.
    const reportStats = setStats
      ? (stats: Stats[]) => {
          if (!cancelled) {
            setStats(stats)
          }
        }
      : undefined

    getPlot(
      attribute,
      colorAttr,
      timeAttr,
      LSL,
      USL,
      $filterParam,
      width,
      height,
      themeGrey,
      reportStats,
    )
      .then(p => {
        if (!cancelled) {
          setPlot(p)
        }
      })
      .catch((error: unknown) => {
        if (cancelled) {
          return
        }
        // Surface the failure instead of leaving an unhandled rejection, and
        // drop any plot that belongs to the previous inputs.
        console.error('PlotHist: failed to build histogram', error)
        setPlot(vg.vspace(0))
      })

    return () => {
      cancelled = true
    }
    // `$filterParam` and `setStats` are not listed, exactly as before.
    // NOTE(review): confirm both are stable references, or add them here.
  }, [
    attribute,
    colorAttr,
    timeAttr,
    LSL,
    USL,
    width,
    height,
    themeGrey,
    renderId,
  ])

  useEffect(() => {
    // Capture the container now: the ref may point elsewhere (or be null) by
    // the time the cleanup runs.
    const container = plotRef.current
    if (!container || !plot) {
      return undefined
    }
    container.appendChild(plot)
    return () => {
      // `removeChild` throws if the node is no longer a child of `container`.
      if (plot.parentNode === container) {
        container.removeChild(plot)
      }
    }
  }, [plot, plotRef])

  return <Box ref={plotRef} />
}

export default PlotHist

/**
 * Builds the histogram plot for `attribute`.
 *
 * Runs the statistics query first, hands the result to `setStats` (if given),
 * then assembles the plot: normalised histogram bars, optional LSL/USL rules,
 * the Gaussian overlay, and a colour legend when `colorAttr` is set.
 *
 * @param attribute Column to bin on the x axis.
 * @param colorAttr Optional grouping column used for fill colour and legend.
 * @param timeAttr Optional time column, forwarded to the statistics query.
 * @param LSL Lower specification limit; drawn whenever it is a number, including 0.
 * @param USL Upper specification limit; drawn whenever it is a number, including 0.
 * @param $filterParam Passed as `filterBy` for the data and as the legend's `as` target.
 * @param width Plot width.
 * @param height Plot height.
 * @param themeGrey Bar fill colour used when there is no `colorAttr`.
 * @param setStats Optional callback receiving the computed statistics.
 * @returns Promise resolving to whatever `vg.plot` returns.
 */
function getPlot(
  attribute: SelectableAttribute,
  colorAttr: SelectableAttribute | undefined,
  timeAttr: SelectableAttribute | undefined,
  LSL: number | undefined,
  USL: number | undefined,
  $filterParam: any,
  width: number,
  height: number,
  themeGrey: string,
  setStats?: (stats: Stats[]) => void,
) {
  const bins = 50
  return getStatistics(attribute, colorAttr, timeAttr, LSL, USL).then(stats => {
    const gaussPlots = getGaussianLine(stats, bins, colorAttr)
    const normalizationConstant = getNormalizationConstantSql(stats, colorAttr)
    setStats?.(stats)
    const fillColor = colorAttr ? colorAttr.unique_name : themeGrey
    return vg.plot(
      vg.rectY(vg.from('active_table', { filterBy: $filterParam }), {
        x: vg.bin(attribute?.unique_name, { steps: bins }),
        // Divide each bin count by the group's sample count.
        y2: vg.agg`count(*) / mean(${normalizationConstant})`,
        stroke: 'black',
        fill: fillColor,
        inset: 0.5,
        tip: true,
        // Grouped histograms overlap, so make them translucent.
        opacity: colorAttr ? 0.4 : 1,
      }),
      ...specLimitMarks(LSL, 'LSL'),
      ...specLimitMarks(USL, 'USL'),
      gaussPlots,
      vg.width(width),
      vg.height(height),
      vg.marginBottom(40),
      vg.colorScheme(colorSchemeCat),
      vg.colorDomain(vg.Fixed),
      vg.yLabel('↑ Frequency'),
      vg.xLabel(`${attribute.column_name} →`),
      ...(colorAttr ? [vg.colorLegend({ as: $filterParam, columns: 1 })] : []),
    )
  })
}

/**
 * Red vertical rule plus text label for one specification limit.
 *
 * Returns no marks when the limit is unset (or NaN). A limit of 0 is a valid
 * value and is drawn, consistent with the `!== undefined` checks in
 * `getStatistics`.
 */
function specLimitMarks(limit: number | undefined, label: string) {
  if (limit == null || Number.isNaN(limit)) {
    return []
  }
  return [
    vg.ruleX([limit], { stroke: 'red', strokeWidth: 2 }),
    vg.textX([limit], {
      x: limit,
      fill: 'red',
      lineAnchor: 'bottom',
      frameAnchor: 'top',
      dy: -2,
      fontSize: 16,
      text: () => label,
    }),
  ]
}

/**
 * One row of the statistics query built in `getStatistics`: summary statistics
 * for a single colour group.
 *
 * NOTE(review): several columns can be SQL NULL (e.g. the capability indices
 * when no spec limit is given) although they are typed as `number` here —
 * confirm how NULL arrives in the JSON result.
 */
export interface Stats {
  /** Group key cast to VARCHAR, or the literal 'all' when no colour attribute is used. */
  color: string
  /** COUNT of the attribute column within the group. */
  n: number
  /** MIN of the attribute. */
  minimum: number
  /** MAX of the attribute. */
  maximum: number
  /** Population standard deviation (`stddev_pop`). */
  std_dev: number
  /** AVG of the attribute. */
  mean: number
  /** MEDIAN of the attribute. */
  median: number
  /** 0.25 quantile (`quantile_cont`). */
  q1: number
  /** 0.75 quantile (`quantile_cont`). */
  q3: number
  /** Index from both spec limits over the whole group; NULL unless both LSL and USL are set. */
  pp: number
  /** One-sided index over the whole group, as computed in `getStatistics`. */
  ppl: number
  /** One-sided index over the whole group, as computed in `getStatistics`. */
  ppu: number
  /** Overall index over the whole group: the smaller of the one-sided values, or the only one available. */
  ppk: number
  /** MAX of the non-infinite per-date `pp` values; NULL without a time attribute. */
  cp: number
  /** MAX of the non-infinite per-date `ppu` values; NULL without a time attribute. */
  cpu: number
  /** MAX of the non-infinite per-date `ppl` values; NULL without a time attribute. */
  cpl: number
  /** MAX of the non-infinite per-date `ppk` values; NULL without a time attribute. */
  cpk: number
}

/**
 * Queries `active_table` for per-group summary statistics of `attribute`.
 *
 * The query has three parts:
 *  - `perf`: count, min, max, population std-dev, mean, median, quartiles and
 *    the pp/ppl/ppu/ppk indices over each whole colour group;
 *  - `capa_1`: the same indices per colour group AND per date derived from
 *    `timeAttr` (all NULL when there is no `timeAttr`);
 *  - `capa_2`: the MAX of the non-infinite per-date values, exposed as
 *    cp/cpl/cpu/cpk.
 *
 * Index definitions (σ = `stddev_pop`, μ = `avg`):
 *  - pp  = (USL - LSL) / 6σ   (needs both limits)
 *  - ppl = (μ - LSL) / 3σ     (needs LSL)
 *  - ppu = (USL - μ) / 3σ     (needs USL)
 *  - ppk = LEAST(ppu, ppl), or whichever one-sided value is available
 *
 * @param attribute Column to summarise.
 * @param colorAttr Optional grouping column; without it all rows form the group 'all'.
 * @param timeAttr Optional time column used for the per-date grouping.
 * @param LSL Lower specification limit (0 is a valid limit).
 * @param USL Upper specification limit (0 is a valid limit).
 * @returns One `Stats` row per colour group.
 */
function getStatistics(
  attribute: SelectableAttribute,
  colorAttr: SelectableAttribute | undefined,
  timeAttr: SelectableAttribute | undefined,
  LSL: number | undefined,
  USL: number | undefined,
): Promise<Stats[]> {
  // Identifiers are always double-quoted so names with spaces or mixed case work.
  const valueCol = `"${attribute?.unique_name}"`
  const groupCol = colorAttr ? `"${colorAttr.unique_name}"` : "'all'"
  const threeSigma = `(3 * stddev_pop(${valueCol}))`

  const hasLSL = LSL !== undefined
  const hasUSL = USL !== undefined

  const ppSql =
    hasLSL && hasUSL
      ? `(${USL} - ${LSL}) / (6 * stddev_pop(${valueCol}))`
      : 'NULL'
  const pplSql = hasLSL ? `(avg(${valueCol}) - ${LSL}) / ${threeSigma}` : 'NULL'
  const ppuSql = hasUSL ? `(${USL} - avg(${valueCol})) / ${threeSigma}` : 'NULL'

  let ppkSql = 'NULL'
  if (hasLSL && hasUSL) {
    ppkSql = `LEAST((${ppuSql}), (${pplSql}))`
  } else if (hasLSL) {
    ppkSql = pplSql
  } else if (hasUSL) {
    ppkSql = ppuSql
  }

  const cpk_equations = `
      ${ppSql} as pp,
      ${pplSql} as ppl,
      ${ppuSql} as ppu,
      ${ppkSql} as ppk
  `

  // Per-date indices need a time column; without one every value is NULL.
  // NOTE(review): the 7-hour shift presumably aligns dates with a production
  // day that starts at 07:00 Stockholm time — confirm.
  const capa_equations =
    timeAttr === undefined
      ? `
      ${groupCol}::VARCHAR AS color,
      NULL as date,
      NULL as pp,
      NULL as ppl,
      NULL as ppu,
      NULL as ppk
    `
      : `
    ${groupCol}::VARCHAR AS color,
    ("${timeAttr.unique_name}" AT TIME ZONE 'Europe/Stockholm' - INTERVAL 7 HOUR)::DATE AS date,
    ${cpk_equations}
  `

  const q = `
  WITH capa_1 AS (
    SELECT
      ${capa_equations}
    FROM active_table
    GROUP BY 1, 2
  ),
  capa_2 AS (
    SELECT
      color,
      MAX(pp) FILTER (WHERE NOT isinf(pp)) as cp,
      MAX(ppl) FILTER (WHERE NOT isinf(ppl)) as cpl,
      MAX(ppu) FILTER (WHERE NOT isinf(ppu)) as cpu,
      MAX(ppk) FILTER (WHERE NOT isinf(ppk)) as cpk
    FROM capa_1
    GROUP BY 1
  ),
  perf AS (
    SELECT
      ${groupCol}::VARCHAR AS color,
      COUNT(${valueCol}) AS n,
      MIN(${valueCol}) AS minimum,
      MAX(${valueCol}) AS maximum,
      stddev_pop(${valueCol}) AS std_dev,
      AVG(${valueCol}) as mean,
      MEDIAN(${valueCol}) as median,
      quantile_cont(${valueCol}, 0.25) as q1,
      quantile_cont(${valueCol}, 0.75) as q3,
      ${cpk_equations}
    FROM active_table
    GROUP BY 1
  )

  SELECT
    perf.*,
    capa_2.cp,
    capa_2.cpl,
    capa_2.cpu,
    capa_2.cpk
  FROM perf
  LEFT JOIN capa_2 ON perf.color = capa_2.color;
  `
  return vg.coordinator().query(q, { type: 'json' })
}

/**
 * Returns the divisor used to normalise histogram bin counts.
 *
 * Without a colour attribute this is the sample count of the first (only)
 * statistics row. With a colour attribute it is a SQL `CASE` expression that
 * maps each row's colour value to the sample count of its group, falling back
 * to 1.0 for values that have no statistics row.
 *
 * When `stats` is empty a neutral divisor (1 / '1.0') is returned, because
 * there is no count to read and a `CASE` without any `WHEN` is invalid SQL.
 *
 * @param stats Per-group statistics as returned by `getStatistics`.
 * @param colorAttr Optional grouping column.
 * @returns A number, or a SQL expression string.
 */
function getNormalizationConstantSql(
  stats: Stats[],
  colorAttr?: SelectableAttribute,
) {
  if (colorAttr === undefined) {
    return stats[0]?.n ?? 1
  }
  if (stats.length === 0) {
    return '1.0'
  }

  // Quote the identifier the same way getStatistics does.
  const column = `"${colorAttr.unique_name}"`
  const cases = stats
    .map(group => {
      // A NULL group key has to be matched with IS NULL; other values are
      // emitted as string literals with embedded single quotes doubled.
      const condition =
        group.color == null
          ? `${column} IS NULL`
          : `${column} = '${String(group.color).replace(/'/g, "''")}'`
      return `WHEN ${condition} THEN ${group.n} `
    })
    .join('\n')
  return `CASE ${cases} ELSE 1.0 END`
}

/**
 * One sampled point of a Gaussian overlay curve. Besides `x` and `y`, a point
 * may carry one extra key — the colour attribute's `unique_name` — holding the
 * group value, so the line mark can split the points per group.
 */
type ResultsType = { x: number; y: number; [key: string]: any }

/**
 * Builds the line mark that overlays a Gaussian curve on the histogram, one
 * curve per statistics group.
 *
 * Each curve is sampled at 100 evenly spaced x positions across mean ± 3σ.
 * The density is multiplied by the group's bin width ((max - min) / bins) and
 * by an extra factor of 1.5.
 * NOTE(review): the 1.5 factor is kept from the previous implementation and
 * looks like an empirical adjustment — confirm its origin.
 *
 * Groups whose mean, range or standard deviation is missing or non-finite, or
 * whose standard deviation is not positive, are skipped: the density would
 * evaluate to NaN/Infinity and there is no curve to draw.
 *
 * @param stats Per-group statistics as returned by `getStatistics`.
 * @param bins Number of histogram bins, used to derive the bin width.
 * @param colorAttr Optional grouping column; when set, each point carries the
 *   group value under that column's name and the stroke is keyed on it.
 * @returns The `vg.line` mark for the sampled points.
 */
function getGaussianLine(
  stats: Stats[],
  bins: number,
  colorAttr?: SelectableAttribute,
) {
  const numSteps = 100

  // Per-group constants are computed once rather than once per sample.
  const curves = stats
    .filter(
      group =>
        Number.isFinite(group.mean) &&
        Number.isFinite(group.minimum) &&
        Number.isFinite(group.maximum) &&
        Number.isFinite(group.std_dev) &&
        group.std_dev > 0,
    )
    .map(group => {
      const { mean, std_dev } = group
      return {
        color: group.color,
        mean,
        std_dev,
        xMin: mean - 3 * std_dev,
        xMax: mean + 3 * std_dev,
        peak: 1 / (std_dev * Math.sqrt(2 * Math.PI)),
        binWidth: (group.maximum - group.minimum) / bins,
      }
    })

  // Points are emitted step by step across all groups, as before.
  const data = Array.from({ length: numSteps }, (_, stepIndex) => {
    const fraction = stepIndex / (numSteps - 1)
    return curves.map(curve => {
      const x = curve.xMin + (curve.xMax - curve.xMin) * fraction
      const y =
        curve.peak *
        Math.exp(-((x - curve.mean) ** 2) / (2 * curve.std_dev ** 2)) *
        curve.binWidth *
        1.5
      const point: ResultsType = { x, y }
      if (colorAttr !== undefined) {
        point[colorAttr.unique_name] = curve.color
      }
      return point
    })
  }).flat()

  return vg.line(data, {
    x: 'x',
    y: 'y',
    stroke: colorAttr ? colorAttr.unique_name : 'black',
  })
}
