import { Box, useTheme } from '@northvolt/ui'
import * as vg from '@uwdata/vgplot'
import type { SelectableAttribute } from 'components/DataLoader/DataLoaderTypes'
import { isCategorical, isNumeric } from 'components/Utils'
import type React from 'react'
import { useEffect, useRef, useState } from 'react'
import { colorScheme, colorSchemeCat } from './PlotUtils'

/**
 * Props for the `Plot2D` component.
 */
interface Plot2DProps {
  /** Attribute plotted on the x axis; no chart is built unless both axes are set. */
  attributeX?: SelectableAttribute
  /** Attribute plotted on the y axis; no chart is built unless both axes are set. */
  attributeY?: SelectableAttribute
  /** Optional colour encoding: dot stroke (scatter) or per-bin median colour (hexbin). */
  attributeColor?: SelectableAttribute
  /** Optional dot-radius encoding (scatter only); dots use radius 3 when unset. */
  attributeSize?: SelectableAttribute
  /** Optional dot-symbol encoding (scatter only); dots are circles when unset. */
  attributeSymbol?: SelectableAttribute
  /** Chart width passed to `vg.width`; also the width of the empty spacer shown instead of a chart. No chart is built for a zero width. */
  width: number
  /** Chart height passed to `vg.height`; no chart is built for an infinite height. */
  height: number
  /** Chart types to draw: 'Scatter' takes precedence over 'Hexgrid'; if neither is present no chart of either type is built. */
  plotTypes: string[]
  /** 'panzoom' attaches pan/zoom bound to `$panZoomX`/`$panZoomY`; any other value attaches an XY interval brush writing to `$filterParam`. */
  plotTool: string
  /** When true the x/y domains are fixed; the scatter chart then draws all rows and highlights the selection instead of filtering by it. */
  lockAxis: boolean
  /** Shared selection used as brush target and row filter. NOTE(review): presumably a Mosaic Selection — consider typing it. */
  $filterParam: any
  /** Param bound to the x axis by the pan/zoom interactor. */
  $panZoomX: any
  /** Param bound to the y axis by the pan/zoom interactor. */
  $panZoomY: any
  /** Not read when building the chart; changing it forces the chart to be rebuilt. */
  renderId?: number
}

/**
 * Renders a 2D Mosaic vgplot chart (scatter or hexbin) of `attributeX` vs `attributeY`.
 *
 * The chart element is rebuilt whenever an input that affects it changes, and is
 * mounted imperatively into the wrapping `Box`. When the size is not drawable,
 * either axis attribute is missing, or no known chart type is requested, an empty
 * horizontal spacer of `width` is mounted instead so no stale chart remains visible.
 */
const Plot2D: React.FC<Plot2DProps> = props => {
  const [plot, setPlot] = useState<any>(null)
  const theme = useTheme()
  // Not read by the chart builders; listed as an effect dependency so that a
  // theme change rebuilds the chart.
  const themeGrey = theme.palette.grey[700]

  const plotRef = useRef<any>(null)

  // Build (or rebuild) the chart element.
  useEffect(() => {
    // Both dimensions must be positive and finite: reject 0, negatives, NaN
    // and Infinity alike instead of checking each dimension differently.
    const drawable =
      props.width > 0 &&
      Number.isFinite(props.width) &&
      props.height > 0 &&
      Number.isFinite(props.height)
    if (!drawable || !props.attributeX || !props.attributeY) {
      setPlot(vg.hspace(props.width))
      return
    }

    // Re-apply the selection's current clauses before building the new chart.
    // NOTE(review): presumably so views re-query with the active filter — confirm.
    props.$filterParam.update(props.$filterParam.clauses)

    if (props.plotTypes.includes('Scatter')) {
      setPlot(getScatterPlot(props))
    } else if (props.plotTypes.includes('Hexgrid')) {
      setPlot(getHexPlot(props))
    } else {
      console.warn('Unknown plot type', props.plotTypes)
      // Replace any previously built chart rather than leaving it on screen.
      setPlot(vg.hspace(props.width))
    }
  }, [
    props.attributeX,
    props.attributeY,
    props.attributeColor,
    props.attributeSize,
    props.attributeSymbol,
    props.width,
    props.height,
    props.plotTool,
    props.lockAxis,
    props.plotTypes,
    props.renderId,
    // The builders bind these params into the chart; a new identity needs a rebuild.
    props.$filterParam,
    props.$panZoomX,
    props.$panZoomY,
    themeGrey,
  ])

  // Mount the current chart element into the container; detach it when it is
  // replaced or the component unmounts.
  useEffect(() => {
    // Capture the container now: `plotRef.current` may be null by the time the
    // cleanup runs (React detaches refs on unmount).
    const container = plotRef.current
    if (!container || !plot) {
      return
    }
    container.appendChild(plot)
    return () => {
      // removeChild throws if the node is no longer a child of `container`.
      if (plot.parentNode === container) {
        container.removeChild(plot)
      }
    }
  }, [plot])

  return <Box ref={plotRef} sx={{ width: '100%' }} />
}

export default Plot2D

/**
 * Builds a vgplot scatter (dot) chart of `attributeX` vs `attributeY` from `active_table`.
 *
 * Optional encodings fall back to constants: black stroke, radius 3, circle symbol.
 * With `lockAxis` set, rows are not filtered by `$filterParam`; the x/y domains are
 * fixed and selected rows are highlighted (others dimmed to opacity 0.05). Otherwise
 * the dots are filtered by `$filterParam`. The interactor is pan/zoom when
 * `plotTool === 'panzoom'`, else an XY interval brush writing to `$filterParam`.
 */
function getScatterPlot(props: Plot2DProps) {
  const {
    attributeX: xAttr,
    attributeY: yAttr,
    attributeColor: colorAttr,
    attributeSize: sizeAttr,
    attributeSymbol: symbolAttr,
    lockAxis,
    plotTool,
    $filterParam,
  } = props

  const strokeChannel = colorAttr ? colorAttr.unique_name : 'black'
  const radiusChannel = sizeAttr ? sizeAttr.unique_name : 3
  const symbolChannel = symbolAttr ? symbolAttr.unique_name : 'circle'

  // Locked axes draw every row (selection shown via highlight below);
  // unlocked axes let the shared selection filter which rows are drawn.
  const fromOptions = lockAxis ? [] : [{ filterBy: $filterParam }]
  const dots = vg.dot(vg.from('active_table', ...fromOptions), {
    x: xAttr?.unique_name,
    y: yAttr?.unique_name,
    stroke: strokeChannel,
    r: radiusChannel,
    symbol: symbolChannel,
    tip: true,
  })

  const interactor =
    plotTool === 'panzoom'
      ? vg.panZoom({ x: props.$panZoomX, y: props.$panZoomY })
      : vg.intervalXY({ as: $filterParam })

  const frame = [
    vg.width(props.width),
    vg.height(props.height),
    vg.grid(true),
    vg.marginBottom(40),
  ]

  const axisLock = lockAxis
    ? [
        vg.xDomain(vg.Fixed),
        vg.yDomain(vg.Fixed),
        vg.highlight({ by: $filterParam, opacity: 0.05 }),
      ]
    : []

  const scales = [
    vg.colorDomain(vg.Fixed),
    vg.colorScheme(isCategorical(colorAttr) ? colorSchemeCat : colorScheme),
    vg.rDomain(vg.Fixed),
    vg.rRange([2, 8]),
    vg.rScale(isNumeric(sizeAttr) ? 'log' : 'ordinal'),
    vg.colorScale(isNumeric(colorAttr) ? 'log' : 'ordinal'),
    vg.colorClamp(true),
  ]

  const labels = [
    vg.xLabel(`${xAttr?.column_name} →`),
    vg.yLabel(`↑ ${yAttr?.column_name}`),
    vg.colorLabel(colorAttr ? colorAttr.column_name : ''),
  ]

  return vg.plot(dots, interactor, ...frame, ...axisLock, ...scales, ...labels)
}

/**
 * Builds a vgplot hexbin chart of `attributeX` vs `attributeY` from `active_table`.
 *
 * Rows are always filtered by `$filterParam`. Each hexagon's stroke and fill use
 * the median of `attributeColor` (black when no colour attribute is set), and
 * fill opacity encodes the row count on a log scale. With `lockAxis` set, the
 * x/y domains are fixed. The interactor is pan/zoom when `plotTool === 'panzoom'`,
 * else an XY interval brush writing to `$filterParam`.
 */
function getHexPlot(props: Plot2DProps) {
  const colorAttr = props.attributeColor
  // Stroke and fill share this one encoding.
  const hexColor = colorAttr ? vg.median(colorAttr.unique_name) : 'black'

  const source = vg.from('active_table', { filterBy: props.$filterParam })
  const hexes = vg.hexbin(source, {
    x: props.attributeX?.unique_name,
    y: props.attributeY?.unique_name,
    stroke: hexColor,
    fill: hexColor,
    fillOpacity: vg.count(),
    binWidth: 5,
    r: 5,
    tip: true,
  })

  const usePanZoom = props.plotTool === 'panzoom'
  const interactor = usePanZoom
    ? vg.panZoom({ x: props.$panZoomX, y: props.$panZoomY })
    : vg.intervalXY({ as: props.$filterParam })

  const frame = [
    vg.width(props.width),
    vg.height(props.height),
    vg.grid(true),
    vg.marginBottom(40),
  ]

  const axisLock = props.lockAxis
    ? [vg.xDomain(vg.Fixed), vg.yDomain(vg.Fixed)]
    : []

  const scales = [
    vg.colorDomain(vg.Fixed),
    vg.colorScheme(colorScheme),
    vg.colorScale(isNumeric(colorAttr) ? 'log' : 'ordinal'),
    vg.opacityScale('log'),
    vg.colorClamp(true),
  ]

  const labels = [
    vg.xLabel(`${props.attributeX?.column_name} →`),
    vg.yLabel(`↑ ${props.attributeY?.column_name}`),
    vg.colorLabel(colorAttr ? colorAttr.column_name : ''),
    vg.opacityLabel('count'),
  ]

  return vg.plot(hexes, interactor, ...frame, ...axisLock, ...scales, ...labels)
}
