module Terrain where

import Data.List (find)
import Numeric.Noise.Perlin
import Codec.Picture
import Control.Parallel.Strategies (parMap, rdeepseq)

-- | Map a noise sample to a colour using a list of @(threshold, colour)@
-- bands.
--
-- The bands are scanned in order. The colour of the first band whose
-- threshold the value strictly exceeds is returned, so bands should be listed
-- from highest threshold to lowest. If no band matches, the result is black.
heatmap :: [(Double, PixelRGB8)] -> Double -> PixelRGB8
heatmap thresholds value =
  maybe black snd (find ((value >) . fst) thresholds)
  where
    black = PixelRGB8 0 0 0

-- Earlier version of generateNoise (one parallel spark per row), kept for reference:
-- generateNoise :: Int -> Int -> Perlin -> [[Double]]
-- generateNoise width height noise = parMap rdeepseq (map noiseFn) coords
--   where
--     coords = [[(x, y) | x <- [0..width-1]] | y <- [0..height-1]]
--     noiseFn (x, y) = noiseValue noise (fromIntegral x, fromIntegral y, 0)
-- Generate Perlin noise

-- Splitting into chunks to parallelize even better?
generateNoise :: Int -> Int -> Perlin -> [[Double]]
generateNoise width height noise = parMap rdeepseq processChunk chunks
  where
    coords = [[(x, y) | x <- [0..width-1]] | y <- [0..height-1]]
    chunks = chunkList numChunks coords
    numChunks = 4 -- or any number depending on your parallelization needs
    processChunk = concatMap (map noiseFn)
    noiseFn (x, y) = noiseValue noise (fromIntegral x, fromIntegral y, 0)

-- Function to split a list into n chunks
chunkList :: Int -> [a] -> [[a]]
chunkList n xs = go n (length xs) xs
  where
    go _ _ [] = []
    go k l ys = let size = (l + k - 1) `div` k
                in take size ys : go (k-1) (l-size) (drop size ys)


-- Parallel version with convoluted smoothing
generateSmoothNoise :: Int -> Int -> Int -> Perlin -> [[Double]]
generateSmoothNoise width height iterations noise = iterateSmooth iterations initialNoise
  where
    initialNoise = generateNoiseSeq width height noise
    iterateSmooth 0 noiseData = noiseData
    iterateSmooth n noiseData = iterateSmooth (n - 1) (parMap rdeepseq (map (smoothNoise width height)) coords)
      where
        coords = [[(x, y) | x <- [0..width-1]] | y <- [0..height-1]]
        smoothNoise w h (x, y) = averageSurroundingNoise x y noiseData width height

averageSurroundingNoise :: Int -> Int -> [[Double]] -> Int -> Int -> Double
averageSurroundingNoise x y noiseData width height = let
    points = [(dx, dy) | dx <- [-1..1], dy <- [-1..1], inBounds (x + dx) (y + dy) width height]
    total = sum [noiseData !! (y + dy) !! (x + dx) | (dx, dy) <- points]
    avg = total / fromIntegral (length points)
  in avg

inBounds :: Int -> Int -> Int -> Int -> Bool
inBounds x y width height = x >= 0 && y >= 0 && x < width && y < height

generateSmoothNoiseSeq :: Int -> Int -> Int -> Perlin -> [[Double]]
generateSmoothNoiseSeq width height iterations noise = iterateSmooth iterations initialNoise
  where
    initialNoise = generateNoiseSeq width height noise
    iterateSmooth 0 noiseData = noiseData
    iterateSmooth n noiseData = iterateSmooth (n - 1) (map (map (smoothNoise width height)) coords)
      where
        coords = [[(x, y) | x <- [0..width-1]] | y <- [0..height-1]]
        smoothNoise w h (x, y) = averageSurroundingNoise x y noiseData width height


generateNoiseSeq :: Int -> Int -> Perlin -> [[Double]]
generateNoiseSeq width height noise = map (map noiseFn) coords
  where
    coords = [[(x, y) | x <- [0..width-1]] | y <- [0..height-1]]
    noiseFn (x, y) = noiseValue noise (fromIntegral x, fromIntegral y, 0)

-- Render heatmap from noise
renderHeatMap :: Int -> Int -> [[Double]] -> Image PixelRGB8
renderHeatMap width height noiseData = generateImage pixelRenderer width height
  where
    pixelRenderer x y = heatmap thresholds (noiseData !! y !! x)
    thresholds = [snow, mountains, forest, land, sand, shallowWater, depths]
    snow         = (0.85, PixelRGB8 255 255 255)
    mountains    = ( 0.5, PixelRGB8 200 200 200)
    forest       = ( 0.1, PixelRGB8 116 151  62)
    land         = ( 0, PixelRGB8 139 181  74)
    sand         = ( -0.1, PixelRGB8 227 221 188)
    shallowWater = ( -2, PixelRGB8 156 213 226)
    depths       = ( -25, PixelRGB8  74 138 125)
    -- Add other thresholds here

