module ParallelMCTS (rootNode, mctsAlgo, bestMove) where

import Board
import System.Random (randomRIO)
import Data.List (maximumBy)
import Control.Parallel.Strategies
import Control.Concurrent.Async (mapConcurrently)

-- | One position in the search tree together with its playout statistics.
data Node = Node
    { nodeId    :: Int     -- ^ Identifier used to find this node again among its parent's children.
    , board     :: Board   -- ^ Position represented by this node.
    , color     :: Color   -- ^ Colour of the tile placed by the next move from this position.
    , children  :: [Node]  -- ^ Positions attached below this one; empty for a leaf.
    , numVisits :: Int     -- ^ Number of playout results recorded on this node.
    , numWins   :: Int     -- ^ Recorded playouts whose winner equals 'color'.
    } deriving (Show, Eq)

-- | Build a fresh, childless tree root for position @b@ with @c@ to move.
-- The root always receives id 0 and starts with no recorded playouts.
rootNode :: Board -> Color -> Node
rootNode b c =
    Node
        { nodeId    = 0
        , board     = b
        , color     = c
        , children  = []
        , numVisits = 0
        , numWins   = 0
        }

-- | Swap the two player colours.
otherColor :: Color -> Color
otherColor c =
    case c of
        Red    -> Yellow
        Yellow -> Red

-- | Generate the child nodes of @parent@, one per playable column.
--
-- Children are numbered consecutively from @startId@ in the order the
-- columns are reported by 'availableCols'. Columns for which 'placeTile'
-- returns 'Nothing' are skipped and do not consume an id.
--
-- Returns the children together with the next unused id.
--
-- NOTE(review): 'checkWin' is not consulted here, so whether an already-won
-- position can still be expanded depends on what 'availableCols' reports for
-- such a board -- confirm against the Board module.
expand :: Node -> Int -> ([Node], Int)
expand parent startId =
    (zipWith mkChild [startId ..] nextBoards, startId + length nextBoards)
    where
        mover = color parent
        -- Boards reachable in one move, in column order; rejected moves drop out.
        nextBoards =
            [ b'
            | col <- availableCols (board parent)
            , Just b' <- [placeTile (board parent) col mover]
            ]
        mkChild ident b' =
            Node
                { nodeId    = ident
                , board     = b'
                , color     = otherColor mover
                , children  = []
                , numVisits = 0
                , numWins   = 0
                }

-- | Run one random playout per @(board, colour-to-move)@ pair concurrently
-- and collect the winners in the same order as the inputs.
parallelPlay :: [(Board, Color)] -> IO [Color]
parallelPlay = mapConcurrently (\(b, c) -> play b c)

-- | Play uniformly random moves from board @b@ with @c@ to move until the
-- game is decided, and return the winning colour.
--
-- If there is no winner and no column is available, @c@ (the side that
-- would have moved) is returned. If 'placeTile' rejects the chosen column,
-- the same position is simply retried with a new random choice.
play :: Board -> Color -> IO Color
play b c =
    case checkWin b of
        Just winner -> return winner
        Nothing     -> continue (availableCols b)
  where
    -- No playable column left: report the side to move.
    continue [] = return c
    continue choices = do
        idx <- randomRIO (0, length choices - 1)
        case placeTile b (choices !! idx) c of
            Just newBoard -> play newBoard (otherColor c)
            Nothing       -> play b c

-- | Run @n@ MCTS iterations on the tree rooted at @node@ and return the
-- updated root.
--
-- @nextId@ is the first id handed to newly created nodes; it is threaded
-- through the iterations so that ids keep increasing. 'rootNode' uses id 0,
-- so callers presumably start at 1 -- TODO confirm at the call site.
-- @plays@ is forwarded unchanged to every iteration as the playout count.
--
-- A non-positive @n@ returns the tree unchanged. Previously only @n == 0@
-- stopped the recursion, so a negative count never terminated.
mctsAlgo :: Node -> Int -> Int -> Int -> IO Node
mctsAlgo node nextId plays n
    | n <= 0 = return node
    | otherwise = do
        (newNode, newNextId) <- parallelOneMcts node nextId plays
        mctsAlgo newNode newNextId plays (n - 1)

-- | UCB1 score of a child @node@ as seen by the player choosing at its
-- parent, where @parentVisits@ is the parent's visit count (must be >= 1 so
-- that the logarithm is defined).
--
-- An unvisited child scores infinity so that it is explored first.
--
-- 'numWins' of a node counts playouts won by the side to move /at that
-- node/, which is the opponent of the player choosing at the parent. The
-- exploitation term is therefore the complement of the child's own win rate;
-- using the win rate directly would steer the search towards moves that are
-- good for the opponent.
score :: Int -> Node -> Double
score parentVisits node
    | numVisits node == 0 = (1 / 0) -- Set equal to infinity. Explore unvisited!
    | otherwise = exploitation + exploration
    where
        visits = fromIntegral (numVisits node)
        -- Fraction of playouts won by the side to move at this child.
        childWinRate = fromIntegral (numWins node) / visits
        -- Every playout has exactly one winner, so the rest were won by the
        -- player who is choosing at the parent.
        exploitation = 1 - childWinRate
        -- Choosing sqrt 2 as exploration parameter
        exploration = sqrt (2 * log (fromIntegral parentVisits) / visits)
    
-- | Pick the child of @node@ with the highest 'score'; a node without
-- children is returned as is.
selectNextNode :: Node -> Node
selectNextNode node
    | null kids = node
    | otherwise = maximumBy byScore kids
    where
        kids = children node
        -- Clamp to 1 so the logarithm inside 'score' never sees 0.
        visitsForLog = max 1 (numVisits node)
        ucb = score visitsForLog
        byScore a b = ucb a `compare` ucb b

-- | Descend from @node@ by repeated 'selectNextNode' until a childless node
-- is reached. The result lists every node visited, starting with @node@
-- itself, and is therefore never empty.
selectPath :: Node -> [Node]
selectPath node =
    case children node of
        [] -> [node]
        _  -> node : selectPath (selectNextNode node)

-- | Record one playout result on @node@: the visit count always grows by
-- one, the win count only when @winner@ is the node's own 'color'.
updateNode :: Color -> Node -> Node
updateNode winner node = node { numVisits = visits', numWins = wins' }
  where
    visits' = numVisits node + 1
    wins'
        | winner == color node = numWins node + 1
        | otherwise            = numWins node

-- | Run one MCTS iteration: select a leaf, expand it, simulate concurrently
-- and fold /every/ simulation result back into the tree.
--
-- Returns the updated root together with the next unused node id.
--
-- * If 'expand' yields no children, @playCount@ playouts are run from the
--   leaf itself.
-- * Otherwise at most @playCount@ of the new children are attached to the
--   leaf and one playout is run from each attached child.
--
-- Previously the fold over the winners ignored its accumulator, so only the
-- last playout result survived. The winners are now applied one after
-- another to the progressively updated tree.
--
-- NOTE(review): children beyond @playCount@ are dropped and the leaf is not
-- expanded again afterwards, so those moves are never searched when
-- @playCount@ is smaller than the number of legal moves -- confirm this is
-- intended. The ids of the dropped children are still consumed.
parallelOneMcts :: Node -> Int -> Int -> IO (Node, Int)
parallelOneMcts root nextId playCount = do
    -- 'selectPath' always returns at least the node it was given, so 'last'
    -- and 'init' below are safe.
    let path = selectPath root
        leaf = last path
        (newChildren, finalId) = expand leaf nextId
    if null newChildren
    then do
        winners <- parallelPlay (replicate playCount (board leaf, color leaf))
        return (backpropagateAll root path winners, finalId)
    else do
        let childrenToAttach = take playCount newChildren
            leafWithChildren = leaf { children = childrenToAttach }
            fullPath = init path ++ [leafWithChildren]
            tasks = [ (board c, color c) | c <- childrenToAttach ]
        winners <- parallelPlay tasks
        return (backpropagateAll root fullPath winners, finalId)

-- | Backpropagate each winner in order along the path described by
-- @firstPath@. With no winners, @unchanged@ is returned untouched.
--
-- The first winner is applied to @firstPath@ directly (it may carry freshly
-- attached children that the tree does not contain yet). For every further
-- winner the path is re-read from the tree produced so far, so that the
-- statistics accumulate.
backpropagateAll :: Node -> [Node] -> [Color] -> Node
backpropagateAll unchanged _ [] = unchanged
backpropagateAll _ firstPath (w : ws) = go (backpropagate w firstPath) ws
  where
    -- Ids of the nodes below the root, in descent order.
    stepIds = drop 1 (map nodeId firstPath)
    go tree [] = tree
    go tree (x : xs) = go (backpropagate x (retracePath tree stepIds)) xs

-- | Walk down from @top@ following the given child ids and return the nodes
-- visited, @top@ included. The walk stops early if an id is not found among
-- the current node's children.
retracePath :: Node -> [Int] -> [Node]
retracePath top [] = [top]
retracePath top (i : rest) =
    case filter ((== i) . nodeId) (children top) of
        next : _ -> top : retracePath next rest
        []       -> [top]

-- | Record one playout result on every node of a root-to-leaf @path@ and
-- return the rebuilt root.
--
-- Each node on the path, /including the last one/, is passed through
-- 'updateNode'. Previously the last node was only used as the seed of the
-- fold and never counted, so a leaf kept zero visits (and hence an infinite
-- selection score) no matter how many playouts started from it.
--
-- While walking back up, each parent replaces the child whose 'nodeId'
-- matches the updated node below it; its other children are left untouched.
--
-- The path must be non-empty; an empty path is a programming error.
backpropagate :: Color -> [Node] -> Node
backpropagate _ [] = error "ParallelMCTS.backpropagate: empty path"
backpropagate winner (node : rest) =
    case rest of
        [] -> counted
        _  -> relink counted (backpropagate winner rest)
  where
    counted = updateNode winner node
    -- Swap the freshly updated child into the parent's child list.
    relink parent updatedChild =
        parent { children = map swapIn (children parent) }
      where
        swapIn c
            | nodeId c == nodeId updatedChild = updatedChild
            | otherwise                       = c

-- | The child of @root@ with the most recorded visits, or 'Nothing' when
-- @root@ has no children.
bestMove :: Node -> Maybe Node
bestMove root
    | null kids = Nothing
    | otherwise = Just (maximumBy moreVisited kids)
    where
        kids = children root
        moreVisited a b = numVisits a `compare` numVisits b