import { useEffect } from 'react'
import { IResultData } from 'analyses/analysis.model'
import { contours, geoPath } from 'd3'
import { flatten } from 'lodash-es'
import { DrawingArea } from './useD3'
import { getColorFromMap, ColorMap } from './colormap'

/**
 * Returns a linear map with `min` → 0 and `max` → 1 (values outside the range are not clamped).
 * A zero-width range (`min === max`) maps every value to 0, like matplotlib's `Normalize`
 * does when `vmin == vmax`, instead of dividing by zero and yielding NaN/±Infinity.
 */
const normalize = (min: number, max: number) => {
  const range = max - min
  return (val: number) => (range === 0 ? 0 : (val - min) / range)
}

/**
 * Calculates the height limits for each contour level (ported from legacy code).
 *
 * Splits the data range into `levels + 1` equal bands and returns all `levels + 2`
 * band edges, from the data minimum up to the data maximum inclusive.
 *
 * @param data Grid values to contour.
 * @param levels Number of thresholds strictly between the minimum and the maximum.
 * @returns Ascending thresholds of length `levels + 2`.
 */
const calculateThresholds = (data: number[], levels: number) => {
  // Single pass instead of Math.min(...data) / Math.max(...data): spreading a large grid
  // into call arguments exceeds the engine's argument limit and throws a RangeError.
  let min = Infinity
  let max = -Infinity
  for (const value of data) {
    min = Math.min(min, value)
    max = Math.max(max, value)
  }
  const step = (max - min) / (levels + 1)
  return Array.from({ length: levels + 2 }, (_, i) => min + i * step)
}

/**
 * Returns a d3 attribute callback that picks a fill color for each contour level.
 *
 * Each threshold is normalized against `[midpoint of the first band, highest threshold]`
 * and negatives are clamped to 0 before sampling `colors`. Using the first band's midpoint
 * as the lower bound mimics the `vmin` parameter of matplotlib's contour functions
 * (https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.pyplot.contour.html), which the
 * legacy implementation relied on.
 *
 * @param thresholds Ascending level limits (e.g. from `calculateThresholds`).
 * @param colors Colormap passed through to `getColorFromMap`.
 * @returns `(datum, index) => color`, where `index` is the contour's position in `thresholds`.
 */
const createGetLevelColor = (thresholds: number[], colors: ColorMap) => {
  const highest = thresholds[thresholds.length - 1]
  // vmin, kept in the same (raw height) units as the thresholds it is compared against.
  const colorMin = (thresholds[0] + thresholds[1]) / 2

  // Flat data (all thresholds equal) leaves no range to normalize over; fall back to the
  // bottom of the colormap instead of producing NaN colors.
  const normalThresholds =
    highest > colorMin ? thresholds.map(normalize(colorMin, highest)) : thresholds.map(() => 0)

  return (_: unknown, i: number) => getColorFromMap(colors, Math.max(0, normalThresholds[i]))
}

/**
 * Calculates the filled contour polygons for the grid at the given level limits.
 *
 * @param rawData Flattened grid values. The grid is assumed to be square
 *   (width = height = sqrt(length)); NOTE(review): a non-square grid would be mis-sized.
 * @param thresholds Ascending level limits; d3 returns one contour per threshold, so
 *   result index `i` corresponds to `thresholds[i]`.
 * @returns The contours, with the outermost (lowest-threshold) one emptied.
 */
const calculateData = (rawData: number[], thresholds: number[]) => {
  const dimension = Math.sqrt(rawData.length)
  return (
    contours()
      .size([dimension, dimension])
      .thresholds(thresholds)(rawData)
      // Every cell is >= the lowest threshold, so its contour is just a rectangle over the
      // whole draw area. Empty it (rather than filter it out) so indices keep matching the
      // thresholds, which the fill color lookup depends on.
      .map((item, i) => (i === 0 ? { ...item, coordinates: [] } : item))
  )
}

/**
 * Draws a filled contour plot of `contour.height_grid` into `drawingArea.g`.
 *
 * Re-runs when the drawing area, level count, grid, scale, line color or colormap change.
 * Each run first removes every existing `path` under `g`, then draws one path per level.
 *
 * @param drawingArea Target area; nothing happens while it is falsy.
 * @param contourLevels Number of thresholds between the grid minimum and maximum.
 * @param initialScale Divisor for the 0.5 stroke width — presumably the initial zoom factor,
 *   so lines keep a constant on-screen width (TODO confirm against useD3).
 * @param lineColor Stroke color of the contour lines.
 * @param colorMap Colormap used for the level fills.
 * @param contour Result whose `height_grid` is contoured.
 */
export const useContour = (
  drawingArea: DrawingArea,
  contourLevels: number,
  initialScale: number,
  lineColor: string,
  colorMap: ColorMap,
  contour: IResultData
) => {
  useEffect(() => {
    if (!drawingArea) return
    const { g } = drawingArea
    g.selectAll('path').remove()

    // Rows are reversed (on a copy, the prop is not mutated) before flattening —
    // presumably so the first row ends up at the bottom of the SVG.
    const rawData = flatten([...contour.height_grid].reverse())
    // An empty grid would give NaN thresholds (min/max of nothing); leave the area cleared.
    if (rawData.length === 0) return

    const thresholds = calculateThresholds(rawData, contourLevels)
    const data = calculateData(rawData, thresholds)

    const getLevelColor = createGetLevelColor(thresholds, colorMap)

    g.selectAll('path')
      .data(data)
      .enter()
      .append('path')
      .attr('d', geoPath())
      .attr('fill', getLevelColor)
      .attr('stroke', lineColor)
      .attr('stroke-width', 0.5 / initialScale)
      .attr('stroke-linejoin', 'round')
  }, [drawingArea, contourLevels, contour.height_grid, initialScale, lineColor, colorMap])
}
