module Parallel
  ( runSimulationParallel
  , runSimulationRedBlack
  , simulationStepParallel
  , simulationStepRedBlack
  ) where

import Data.List (foldl')

import Control.DeepSeq (force)
import Control.Parallel.Strategies
import qualified Data.Vector as V
import System.Random (mkStdGen, RandomGen, split)

import FireRules
import Grid
import Serial (igniteInitialPoints)
import Types

-- Strategy 1: Spatial Decomposition (Domain Decomposition)

-- | Run the spatially decomposed simulation from @initialGrid@ until it
-- stops, returning the final grid and the per-step statistics.
--
-- The configured ignition points are lit before the first step. Randomness
-- comes from a 'StdGen' seeded with the fixed value 42, so identical inputs
-- always give identical results.
runSimulationParallel :: SimConfig -> Grid -> (Grid, [SimStats])
runSimulationParallel config initialGrid =
  runStepsParallel config ignited (mkStdGen 42) 0 []
  where
    ignited = igniteInitialPoints (ignitionPoints config) initialGrid

-- | Drive the parallel simulation one step at a time, recording a 'SimStats'
-- snapshot of the grid before each step.
--
-- Stops after @maxTimeSteps config@ steps, or (from step 1 on) as soon as no
-- cell is 'Burning'; in the latter case the final snapshot is included.
-- @stats@ is accumulated newest-first and reversed on exit.
runStepsParallel :: RandomGen g => SimConfig -> Grid -> g -> Int -> [SimStats] -> (Grid, [SimStats])
runStepsParallel config grid gen step stats
  | step >= maxTimeSteps config = (grid, reverse stats)
  | otherwise =
      let burning = countCellsByState grid Burning
          burned = countCellsByState grid Burned
          -- NOTE(review): the last two fields both receive 'burned'; confirm
          -- against the SimStats definition that this is intended.
          currentStats = SimStats step burning burned burned
          stats' = currentStats : stats
      in
        -- 'burning' is forced by the termination test below, but 'burned'
        -- is not: left as a thunk inside the stats it would keep this
        -- step's grid alive, retaining every intermediate grid in memory.
        burned `seq`
          if burning == 0 && step > 0
            then (grid, reverse stats')
            else
              let (grid', gen') = simulationStepParallel config grid gen
              in runStepsParallel config grid' gen' (step + 1) stats'

-- | One synchronous simulation step using spatial decomposition: every cell
-- is updated from the same snapshot of @grid@, with the work spread over
-- parallel chunks by 'processPositionsParallel'.
--
-- The generator returned to the caller is split off from @gen@ before any
-- random numbers are drawn for this step. It is never handed to the
-- per-cell work, so the next step does not reuse a stream consumed here.
-- The result grid is deep-forced when demanded.
simulationStepParallel :: RandomGen g => SimConfig -> Grid -> g -> (Grid, g)
simulationStepParallel config grid gen =
  let rows = V.length grid
      -- NOTE(review): the column count is taken from row 0, so this assumes
      -- a rectangular grid.
      cols = if rows > 0 then V.length (grid V.! 0) else 0
      allPositions = [(r, c) | r <- [0 .. rows - 1], c <- [0 .. cols - 1]]
      (genNext, genStep) = split gen
      (newGrid, _) = processPositionsParallel config grid genStep allPositions
  in (force newGrid, genNext)

-- | Update every position in @positions@ against the fixed snapshot
-- @originalGrid@. The positions are split into at most 8 chunks, which are
-- evaluated in parallel with 'parMap' 'rdeepseq'. The updates are then
-- merged into @originalGrid@ in chunk order.
--
-- Each chunk receives its own generator derived with 'split'. The generator
-- returned to the caller is split off /before/ the chunk generators are
-- derived and is never itself given to a chunk. An empty position list
-- returns the grid unchanged instead of failing.
processPositionsParallel :: RandomGen g => SimConfig -> Grid -> g -> [(Int, Int)] -> (Grid, g)
processPositionsParallel config originalGrid gen positions =
  let numChunks = 8 :: Int
      -- Ceiling division so that no more than numChunks chunks are made
      -- (floor division could produce up to 2 * numChunks - 1).
      chunkSize = max 1 ((length positions + numChunks - 1) `div` numChunks)
      chunks = chunksOf chunkSize positions
      -- Reserve the continuation generator first; only genWork is consumed.
      (genNext, genWork) = split gen
      gens = splitGen genWork (length chunks)

      -- Process each chunk in parallel
      results = parMap rdeepseq
                  (\(chunk, g) -> processChunk config originalGrid chunk g)
                  (zip chunks gens)

      -- Merge all results strictly to avoid a chain of pending updates
      newGrid = foldl' mergeChunkResult originalGrid results
  in (newGrid, genNext)

-- | Run 'processPosition' over each position of one chunk in order, threading
-- @gen@ through them. Keeps only the cells that changed, as
-- @(row, col, newCell)@ triples. The generator left over at the end is
-- discarded.
processChunk :: RandomGen g => SimConfig -> Grid -> [(Int, Int)] -> g -> [(Int, Int, Cell)]
processChunk config originalGrid positions gen = foldr keepChanged [] outcomes
  where
    outcomes = fst (foldlWithGen (processPosition config originalGrid) gen positions)
    keepChanged Nothing acc = acc
    keepChanged (Just update) acc = update : acc

-- | Evaluate one grid position against @originalGrid@.
--
-- Yields @Just (row, col, newCell)@ when 'processOneCell' reports a change,
-- otherwise 'Nothing'. A position for which 'getCell' returns 'Nothing'
-- yields 'Nothing' and leaves the generator untouched.
processPosition :: RandomGen g => SimConfig -> Grid -> (Int, Int) -> g -> (Maybe (Int, Int, Cell), g)
processPosition config originalGrid pos gen =
  maybe (Nothing, gen) visit (getCell originalGrid pos)
  where
    (row, col) = pos
    visit cell = case processOneCell config originalGrid pos cell gen of
      (Nothing, gen') -> (Nothing, gen')
      (Just newCell, gen') -> (Just (row, col, newCell), gen')

-- | Compute the next state of a single cell, returning @Just newCell@ only if
-- it changed.
--
-- * 'Unburned': cells with no fuel or on 'Water' never ignite. Otherwise the
--   ignition probability from the 8 neighbours is tested with
--   'shouldIgnite', which consumes the generator. An igniting cell becomes
--   'Burning' with @burnSteps = 0@.
-- * 'Burning': advanced by 'updateCellState' with a terrain-dependent burn
--   limit (Forest 3, Grassland 2, anything else 1). It is reported only if
--   the result differs from the input.
-- * 'Burned': never changes.
--
-- Only the 'Unburned' ignition path draws from the generator.
processOneCell :: RandomGen g => SimConfig -> Grid -> (Int, Int) -> Cell -> g -> (Maybe Cell, g)
processOneCell config originalGrid pos cell gen =
  case cellState cell of
    Unburned
      | cannotIgnite -> unchanged
      | otherwise    -> attemptIgnition
    Burning
      | advanced /= cell -> (Just advanced, gen)
      | otherwise        -> unchanged
    Burned -> unchanged
  where
    unchanged = (Nothing, gen)

    cannotIgnite = fuelLevel cell <= 0.0 || terrainType cell == Water

    attemptIgnition =
      let neighbourCells = getNeighbors8 originalGrid pos
          ignitionChance = calculateIgnitionProbability config cell neighbourCells originalGrid
      in case shouldIgnite config ignitionChance gen of
           (True, gen')  -> (Just (cell { cellState = Burning, burnSteps = 0 }), gen')
           (False, gen') -> (Nothing, gen')

    burnLimit = case terrainType cell of
      Forest    -> 3
      Grassland -> 2
      _         -> 1

    advanced = updateCellState cell burnLimit

-- | Write a list of @(row, col, cell)@ updates into @grid@, one 'setCell' per
-- entry, applied left to right.
mergeChunkResult :: Grid -> [(Int, Int, Cell)] -> Grid
mergeChunkResult grid updates =
  -- Strict left fold: a lazy 'foldl' would build one nested 'setCell' thunk
  -- per update before evaluating any of them.
  foldl' (\g (r, c, cell) -> setCell g (r, c) cell) grid updates

-- | Map @step@ over a list while threading a generator (or any state) from
-- left to right. Returns the mapped results in input order together with
-- the final state. Results are produced lazily.
foldlWithGen :: (a -> g -> (b, g)) -> g -> [a] -> ([b], g)
foldlWithGen step = go
  where
    go s [] = ([], s)
    go s (x : xs) = (y : ys, sFinal)
      where
        (y, s') = step x s
        (ys, sFinal) = go s' xs

-- | Produce @n@ generators from @gen@ by repeated 'split'. Each step keeps the
-- first half and continues splitting the second; the last element is the
-- final second half. Returns @[]@ for @n <= 0@ and @[gen]@ for @n == 1@.
splitGen :: RandomGen g => g -> Int -> [g]
splitGen gen n
  | n <= 0    = []
  | otherwise = go gen n
  where
    go g 1 = [g]
    go g k =
      let (kept, remaining) = split g
      in kept : go remaining (k - 1)

-- | Split a list into consecutive chunks of @n@ elements. The last chunk
-- holds whatever remains and may be shorter; an empty list gives @[]@.
--
-- A non-positive @n@ is treated as 1. Taking 0 elements at a time never
-- consumes the input, which would produce an infinite list of empty chunks.
chunksOf :: Int -> [a] -> [[a]]
chunksOf n = go
  where
    size = max 1 n
    go [] = []
    go xs =
      let (chunk, rest) = splitAt size xs
      in chunk : go rest

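-- Strategy 2: Red-Black Checkerboard
-- Cells with even (row + col) are updated first; cells with odd (row + col)
-- are then updated against the grid that already includes those updates.
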
-- | Run the red-black checkerboard simulation from @initialGrid@ until it
-- stops, returning the final grid and the per-step statistics.
--
-- The configured ignition points are lit before the first step. Randomness
-- comes from a 'StdGen' seeded with the fixed value 42, so identical inputs
-- always give identical results.
runSimulationRedBlack :: SimConfig -> Grid -> (Grid, [SimStats])
runSimulationRedBlack config initialGrid =
  runStepsRedBlack config ignited (mkStdGen 42) 0 []
  where
    ignited = igniteInitialPoints (ignitionPoints config) initialGrid

-- | Drive the red-black simulation one step at a time, recording a
-- 'SimStats' snapshot of the grid before each step.
--
-- Stops after @maxTimeSteps config@ steps, or (from step 1 on) as soon as no
-- cell is 'Burning'; in the latter case the final snapshot is included.
-- @stats@ is accumulated newest-first and reversed on exit.
runStepsRedBlack :: RandomGen g => SimConfig -> Grid -> g -> Int -> [SimStats] -> (Grid, [SimStats])
runStepsRedBlack config grid gen step stats
  | step >= maxTimeSteps config = (grid, reverse stats)
  | otherwise =
      let burning = countCellsByState grid Burning
          burned = countCellsByState grid Burned
          -- NOTE(review): the last two fields both receive 'burned'; confirm
          -- against the SimStats definition that this is intended.
          currentStats = SimStats step burning burned burned
          stats' = currentStats : stats
      in
        -- 'burning' is forced by the termination test below, but 'burned'
        -- is not: left as a thunk inside the stats it would keep this
        -- step's grid alive, retaining every intermediate grid in memory.
        burned `seq`
          if burning == 0 && step > 0
            then (grid, reverse stats')
            else
              let (grid', gen') = simulationStepRedBlack config grid gen
              in runStepsRedBlack config grid' gen' (step + 1) stats'

-- | One red-black checkerboard step.
--
-- Red cells (even @row + col@) are updated in parallel from the incoming
-- @grid@. Black cells (odd @row + col@) are then updated in parallel from
-- the grid that already contains the red updates.
--
-- @gen@ is split three ways: one generator each for the red and black
-- phases, and one reserved up front and returned to the caller. The
-- returned generator is therefore never one that drove any cell update in
-- this step. The result grid is deep-forced when demanded.
simulationStepRedBlack :: RandomGen g => SimConfig -> Grid -> g -> (Grid, g)
simulationStepRedBlack config grid gen =
  let rows = V.length grid
      -- NOTE(review): the column count is taken from row 0, so this assumes
      -- a rectangular grid.
      cols = if rows > 0 then V.length (grid V.! 0) else 0

      -- Get red and black cell positions
      redCells = [(r, c) | r <- [0 .. rows - 1], c <- [0 .. cols - 1], (r + c) `mod` 2 == 0]
      blackCells = [(r, c) | r <- [0 .. rows - 1], c <- [0 .. cols - 1], (r + c) `mod` 2 == 1]

      -- Reserve the continuation generator, then split the rest per phase
      (genNext, genPhases) = split gen
      (genRed, genBlack) = split genPhases

      -- Red phase reads the original grid; black phase reads the red result
      (grid1, _) = processPositionsParallel config grid genRed redCells
      (grid2, _) = processPositionsParallel config grid1 genBlack blackCells
  in (force grid2, genNext)