{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FunctionalDependencies #-}

module MCTS
  ( Player(..)
  , Outcome(..)
  , Game(..)
  , mctsBestMoveParallel
  , chunk
  ) where

import Control.DeepSeq (NFData(..))
import Control.Parallel.Strategies
import Data.List (foldl', maximumBy)
import Data.Ord (comparing)
import qualified Data.Map.Strict as M
import System.Random

-- Game Interface

-- | The two sides of a two-player game. 'P1' is the side whose
-- perspective the search statistics are stored from.
data Player
  = P1
  | P2
  deriving (Eq, Ord, Show)

-- | 'Player' has only nullary constructors, so reducing a value to weak
-- head normal form already evaluates it completely.
instance NFData Player where
  rnf p = p `seq` ()

-- | How a finished game ended: either a 'Win' for one player or a 'Draw'.
data Outcome = Win Player | Draw deriving (Eq, Show)

-- | Rules of a two-player, turn-based game with state type @s@ and move
-- type @m@. The functional dependency @s -> m@ fixes one move type per
-- state type. Moves need 'Ord' because they key the children map.
class ( Eq s, Show s
      , Eq m, Ord m, Show m
      ) => Game s m | s -> m where
  -- | The player whose turn it is in this state.
  currentPlayer :: s -> Player
  -- | The moves that may be played from this state.
  legalMoves :: s -> [m]
  -- | The state reached by playing the given move.
  applyMove :: s -> m -> s
  -- | @Just@ the result when the game is over, @Nothing@ while it continues.
  terminal :: s -> Maybe Outcome

-- Tree

-- | One node of the search tree. Every field is strict, so building or
-- updating a node evaluates its fields to WHNF rather than storing thunks.
data Node s m = Node
  { st          :: !s                     -- ^ game state at this node
  , untried     :: ![m]                   -- ^ legal moves not yet expanded into children
  , kids        :: !(M.Map m (Node s m))  -- ^ expanded children, keyed by the move reaching them
  , visits      :: !Int                   -- ^ number of playouts recorded at this node
  , winsForP1   :: !Double  -- ^ summed P1 score of those playouts (win 1, draw 0.5, loss 0)
  } deriving (Show)

-- | Fully evaluate a node and, through 'kids', its entire subtree.
instance (NFData s, NFData m) => NFData (Node s m) where
  rnf node =
    rnf (st node)
      `seq` rnf (untried node)
      `seq` rnf (kids node)
      `seq` rnf (visits node)
      `seq` rnf (winsForP1 node)

-- | A fresh tree node for state @s@: no visits, no score, no children yet,
-- and every legal move still waiting to be expanded.
newNode :: Game s m => s -> Node s m
newNode s = Node s (legalMoves s) M.empty 0 0


-- | Exploration constant for the UCT formula, c = sqrt 2 (the usual UCB1
-- choice when rewards lie in [0, 1], as 'outcomeToP1Score' values do).
-- It is computed exactly instead of written as a truncated decimal.
uctC :: Double
uctC = sqrt 2

-- | UCT score of @child@ as seen by @playerToMove@, the player choosing
-- among the parent's children. The score is that player's mean result in
-- the child (P1 score for 'P1', one minus it for 'P2') plus an exploration
-- bonus @uctC * sqrt (ln parentVisits / childVisits)@.
--
-- An unvisited child scores +Infinity, so it is preferred over any visited
-- sibling. Without this guard the 0/0 division would yield NaN.
childUctScore :: Player -> Int -> Node s m -> Double
childUctScore playerToMove parentVisits child
  | visits child <= 0 = 1 / 0
  | otherwise         = exploit + explore
  where
    v   = fromIntegral (visits child) :: Double
    wP1 = winsForP1 child
    exploit =
      case playerToMove of
        P1 -> wP1 / v
        P2 -> (v - wP1) / v
    -- Clamp to 1: log 0 = -Infinity would turn the bonus into NaN.
    explore = uctC * sqrt (log (fromIntegral (max 1 parentVisits)) / v)

-- | The child to descend into: the one with the highest UCT score for the
-- player to move at @node@, paired with the move that reaches it. Ties are
-- resolved as 'maximumBy' resolves them. It uses 'maximumBy' directly, so
-- it fails if @node@ has no children.
selectChild :: Game s m => Node s m -> (m, Node s m)
selectChild node = maximumBy byUct (M.toList (kids node))
  where
    mover        = currentPlayer (st node)
    parentVisits = max 1 (visits node)
    uctOf (_, c) = childUctScore mover parentVisits c
    byUct a b    = compare (uctOf a) (uctOf b)

-- Simulation

-- | Play moves chosen by 'randomR' over the legal-move indices from @s0@
-- until the game ends. Returns the outcome and the advanced generator.
-- A non-terminal state with no legal moves is scored as a 'Draw'.
rollout :: Game s m => s -> StdGen -> (Outcome, StdGen)
rollout s0 g0 = maybe playOn (\out -> (out, g0)) (terminal s0)
  where
    playOn =
      case legalMoves s0 of
        [] -> (Draw, g0)
        ms ->
          let (idx, g1) = randomR (0, length ms - 1) g0
          in rollout (applyMove s0 (ms !! idx)) g1

-- | Reward of a finished game from P1's point of view: 1 when P1 wins,
-- 0 when P2 wins, and 0.5 for a draw.
outcomeToP1Score :: Outcome -> Double
outcomeToP1Score outcome =
  case outcome of
    Win winner -> if winner == P1 then 1 else 0
    Draw       -> 0.5

-- Full Iteration

-- | One MCTS iteration from @root@: select down fully expanded nodes by
-- UCT, expand one untried move, simulate with 'rollout', and backpropagate
-- the P1 score. Returns the updated tree and the advanced generator.
--
-- Every node on the visited path gets exactly one visit and the playout's
-- score. That includes the root and a terminal leaf, so a parent's visit
-- count stays consistent with its children's.
mctsIter :: forall s m. Game s m => Node s m -> StdGen -> (Node s m, StdGen)
mctsIter root g0 =
  -- 'go' already records the playout on the root; do not count it again.
  let (root', _, g1) = go root g0
  in (root', g1)
  where
    -- Backpropagate one playout result into a single node.
    record :: Double -> Node s m -> Node s m
    record score n = n { visits    = visits n + 1
                       , winsForP1 = winsForP1 n + score
                       }

    go :: Node s m -> StdGen -> (Node s m, Outcome, StdGen)
    go node g =
      case terminal (st node) of
        -- Terminal leaf: the outcome is known, but the visit is still
        -- recorded so this node's statistics track how often it is reached.
        Just out -> (record (outcomeToP1Score out) node, out, g)
        Nothing ->
          case untried node of
            -- Expand: add a child for one untried move and simulate from it.
            (m:ms) ->
              let s'        = applyMove (st node) m
                  (out, g') = rollout s' g
                  score     = outcomeToP1Score out
                  child     = record score (newNode s')
                  node'     = record score
                                (node { untried = ms
                                      , kids    = M.insert m child (kids node)
                                      })
              in (node', out, g')
            []
              -- Not terminal, yet no moves and no children. Score it as a
              -- draw, matching 'rollout', instead of letting 'selectChild'
              -- fail on an empty map.
              | M.null (kids node) ->
                  (record (outcomeToP1Score Draw) node, Draw, g)
              -- Fully expanded: descend into the UCT-best child.
              | otherwise ->
                  let (m, child)        = selectChild node
                      (child', out, g') = go child g
                      node'             = record (outcomeToP1Score out)
                                            (node { kids = M.insert m child' (kids node) })
                  in (node', out, g')

-- Run iterations

-- | Build a search tree for state @s@ by running @iters@ iterations of
-- 'mctsIter', threading the generator through them. A zero or negative
-- @iters@ returns the fresh root without searching. Previously a negative
-- count never reached the base case and looped forever.
runMcts :: Game s m => Int -> s -> StdGen -> Node s m
runMcts iters s g0 = loop iters (newNode s) g0
  where
    -- The bang patterns force the tree and generator at every step, so no
    -- chain of unevaluated iterations builds up.
    loop k !t !g
      | k <= 0    = t
      | otherwise =
          let (t', g') = mctsIter t g
          in loop (k - 1) t' g'


-- Results Merge

-- | Aggregated statistics for one root move, as (summed P1 score, visit count).
type Stats = (Double, Int) -- (winsForP1, visits)

-- | For each expanded child of the root, its (winsForP1, visits) pair,
-- keyed by the move that reaches it.
rootStats :: Game s m => Node s m -> M.Map m Stats
rootStats t = M.map statsOf (kids t)
  where
    statsOf child = (winsForP1 child, visits child)

-- | Union two per-move statistic maps. For a move present in both, the
-- scores and the visit counts are added component-wise.
mergeStats :: Ord m => M.Map m Stats -> M.Map m Stats -> M.Map m Stats
mergeStats = M.unionWith addStats
  where
    addStats (scoreA, visitsA) (scoreB, visitsB) =
      (scoreA + scoreB, visitsA + visitsB)

-- | The move with the best mean result for the player to move in @s@. The
-- mean is the P1 win rate for 'P1' and its complement for 'P2'; it is not
-- weighted by visit count. Visit counts are clamped to at least 1 before
-- dividing. Calls 'error' when @mp@ is empty.
bestMoveFromStats :: Game s m => s -> M.Map m Stats -> m
bestMoveFromStats s mp
  | M.null mp = error "No legal moves."
  | otherwise = fst . maximumBy (comparing (valueFor . snd)) $ M.toList mp
  where
    mover = currentPlayer s

    valueFor :: Stats -> Double
    valueFor (w, v) =
      let rateP1 = w / fromIntegral (max 1 v)
      in if mover == P1 then rateP1 else 1 - rateP1

-- Parallel MCTS

-- | Choose a move for @s0@ using root-parallel MCTS.
--
-- @workers@ independent searches of @itersPerWorker@ iterations each run
-- from @s0@. Each search gets its own generator, split from a fresh
-- 'newStdGen'. Workers are grouped @chunkSize@ per spark. Each spark merges
-- its workers' root statistics, and the per-spark results are merged
-- before 'bestMoveFromStats' picks the move.
--
-- A non-positive @workers@ or @chunkSize@ is treated as 1. A non-positive
-- chunk size previously produced an infinite list of empty batches, and
-- zero workers produced a misleading "No legal moves." error. The chosen
-- move is forced before the action returns, so the search (and any error
-- it raises) happens here rather than wherever the result is first used.
mctsBestMoveParallel :: (Game s m, NFData s, NFData m)
                     => Int -> Int -> Int -> s -> IO m
mctsBestMoveParallel workers chunkSize itersPerWorker s0 = do
  base <- newStdGen

  let gens    = take (max 1 workers) (splitGens base)
      batches = chunk (max 1 chunkSize) gens

      -- Each spark runs a batch of workers and merges their statistics.
      partialStats =
        withStrategy (parList rdeepseq)
          [ foldl' mergeStats M.empty
              [ runMctsRootStats itersPerWorker s0 g
              | g <- gs
              ]
          | gs <- batches
          ]

      -- Strict fold: each merge is evaluated before the next one starts.
      finalStats = foldl' mergeStats M.empty partialStats

  pure $! bestMoveFromStats s0 finalStats




-- | Run one worker's search of @iters@ iterations from @s@ and keep only
-- the statistics of the root's children.
runMctsRootStats :: Game s m => Int -> s -> StdGen -> M.Map m Stats
runMctsRootStats iters s = rootStats . runMcts iters s

-- | An infinite list of generators derived from @g@ by repeated splitting.
-- Each step keeps the first half and splits the second half again.
splitGens :: StdGen -> [StdGen]
splitGens g = map fst (iterate (splitGen . snd) (splitGen g))

-- | Split a list into consecutive chunks of @k@ elements. The last chunk
-- may be shorter, and an empty list yields no chunks.
--
-- A non-positive @k@ is treated as 1. Previously @splitAt 0@ made no
-- progress and the result was an infinite list of empty chunks.
chunk :: Int -> [a] -> [[a]]
chunk k = go
  where
    size = max 1 k
    go [] = []
    go xs =
      let (front, rest) = splitAt size xs
      in front : go rest
