{-# LANGUAGE BangPatterns #-}

module Go5x5
  ( Go(..)
  , Move(..)
  , emptyGo
  ) where

import qualified Data.Vector as V
import qualified Data.Set as S
import Data.Maybe (catMaybes)
import Control.DeepSeq (NFData(..))
import MCTS

-- Board

-- | Contents of one board point. 'Black' stones are played by 'P1' and
-- 'White' stones by 'P2'.
data Cell
  = Empty  -- ^ No stone on this point.
  | Black  -- ^ A stone belonging to 'P1'.
  | White  -- ^ A stone belonging to 'P2'.
  deriving (Eq, Show)

-- | Every 'Cell' constructor is nullary, so evaluating to weak head normal
-- form already yields normal form.
instance NFData Cell where
  rnf cell = cell `seq` ()

-- | An action a player can take on their turn.
data Move
  = Place Int Int  -- ^ Put a stone at (row, col).
  | Pass           -- ^ Skip the turn without placing a stone.
  deriving (Eq, Ord, Show)

-- | Forces both coordinates of a 'Place'; 'Pass' carries no payload.
instance NFData Move where
  rnf move = case move of
    Pass          -> ()
    Place row col -> rnf row `seq` rnf col

-- | Complete state of a 5x5 Go game.
data Go = Go
  { board  :: !(V.Vector Cell)
    -- ^ The 25 board points in row-major order (index @row * 5 + col@).
  , turn   :: !Player
    -- ^ The player whose move it is.
  , passes :: !Int
    -- ^ Passes played in a row; any stone placement resets it to 0.
  } deriving (Eq, Show)

-- | Deeply forces the board, the player to move and the pass counter.
instance NFData Go where
  rnf g = rnf (board g) `seq` rnf (turn g) `seq` rnf (passes g)

-- | Starting position: all 25 points empty, 'P1' to move, no passes yet.
emptyGo :: Go
emptyGo = Go { board = V.replicate 25 Empty, turn = P1, passes = 0 }


-- Game instance

instance Game Go Move where
  currentPlayer = turn

  -- Pass is always offered. A placement is offered for every empty point
  -- (scanned row by row) that 'legalPlacement' accepts. Legality is only
  -- occupancy plus the suicide check: there is no ko or superko rule, so
  -- earlier positions may recur.
  legalMoves s = Pass : map (uncurry Place) (filter playable points)
    where
      points = [ (r, c) | r <- [0 .. 4], c <- [0 .. 4] ]
      playable (r, c) = board s V.! idx r c == Empty && legalPlacement s r c

  -- Passing hands the turn over and extends the run of consecutive passes.
  applyMove s Pass = Go (board s) (other (turn s)) (passes s + 1)

  -- Placing a stone removes opponent groups left without liberties, hands
  -- the turn over and resets the pass counter. The move itself is not
  -- validated here.
  applyMove s (Place r c) = Go captured (other mover) 0
    where
      mover    = turn s
      placed   = board s V.// [(idx r c, stone mover)]
      captured = removeCaptured placed (other mover) r c

  -- The game ends once two passes have been played in a row.
  terminal s =
    if passes s >= 2
      then Just (scorePosition s)
      else Nothing




-- | Row-major vector index of (row, col) on the 5x5 board.
-- No bounds checking is done.
idx :: Int -> Int -> Int
idx row col = 5 * row + col

-- | The stone colour each player places.
stone :: Player -> Cell
stone player = case player of
  P1 -> Black
  P2 -> White

-- | The opponent of the given player.
other :: Player -> Player
other player = case player of
  P1 -> P2
  P2 -> P1



-- | Suicide check for the player to move placing at (r, c).
--
-- The stone is placed, then any opponent groups it leaves without
-- liberties are removed. The placement is accepted if the new stone's
-- group still has a liberty. Whether (r, c) was empty beforehand is not
-- checked here.
legalPlacement :: Go -> Int -> Int -> Bool
legalPlacement s r c = hasLiberty afterCaptures r c
  where
    mover         = turn s
    withStone     = board s V.// [(idx r c, stone mover)]
    afterCaptures = removeCaptured withStone (other mover) r c

-- | Remove every group of player @p@ that is orthogonally adjacent to
-- (r, c) and has no liberties on board @b@.
--
-- In this module it is called right after a stone is placed at (r, c),
-- with @p@ set to the opponent of the player who placed it. The effect is
-- to perform the captures caused by that move. Only groups touching
-- (r, c) are examined. Returns @b@ unchanged when nothing is captured.
removeCaptured :: V.Vector Cell -> Player -> Int -> Int -> V.Vector Cell
removeCaptured b p r c
  | S.null dead = b
  | otherwise   = removeGroup b dead
  where
    target = stone p
    -- The colour test comes first, so the flood fill in 'groupAt' only
    -- runs from p's stones. A group touching (r, c) from several sides
    -- shows up several times; 'S.unions' merges those duplicates. One
    -- bulk clear then replaces a lazy left fold that copied the vector
    -- once per group.
    dead =
      S.unions
        [ g
        | (r', c') <- neighbors r c
        , stoneAt b r' c' == target
        , let g = groupAt b r' c'
        , not (groupHasLiberty b g)
        ]

-- | Reset every board index in the given set to 'Empty'.
removeGroup :: V.Vector Cell -> S.Set Int -> V.Vector Cell
removeGroup b g = b V.// zip (S.toList g) (repeat Empty)


-- Neighbors

-- | Orthogonal neighbours of (row, col), in the order up, down, left,
-- right.
--
-- Only the coordinate being stepped is bounds-checked against 0..4, so
-- (row, col) itself is assumed to be on the board.
neighbors :: Int -> Int -> [(Int,Int)]
neighbors row col = [ point | (onBoard, point) <- candidates, onBoard ]
  where
    candidates =
      [ (row > 0, (row - 1, col))
      , (row < 4, (row + 1, col))
      , (col > 0, (row, col - 1))
      , (col < 4, (row, col + 1))
      ]

-- | Contents of the point at (row, col). Partial: a computed index
-- outside 0..24 makes @V.!@ fail.
stoneAt :: V.Vector Cell -> Int -> Int -> Cell
stoneAt b r = (b V.!) . idx r

-- | Board indices of the group containing (r, c): every point reachable
-- from (r, c) through orthogonal neighbours of the same colour.
-- Returns the empty set when (r, c) is empty.
groupAt :: V.Vector Cell -> Int -> Int -> S.Set Int
groupAt b r c
  | colour == Empty = S.empty
  | otherwise       = flood S.empty [(r, c)]
  where
    colour = stoneAt b r c
    -- Depth-first flood fill. New neighbours are pushed onto the front
    -- of the pending list.
    flood visited [] = visited
    flood visited ((x, y) : pending)
      | here `S.member` visited || stoneAt b x y /= colour =
          flood visited pending
      | otherwise =
          flood (S.insert here visited) (neighbors x y ++ pending)
      where
        here = idx x y

-- | Whether any point in group @g@ (a set of row-major board indices) is
-- orthogonally adjacent to an empty point on board @b@.
-- An empty set has no liberties.
groupHasLiberty :: V.Vector Cell -> S.Set Int -> Bool
groupHasLiberty b g =
  or
    [ stoneAt b nr nc == Empty
    | i <- S.toList g
    , let (r, c) = i `divMod` 5   -- invert the row-major index
      -- Distinct names for the neighbour, so the group point's (r, c)
      -- is not shadowed.
    , (nr, nc) <- neighbors r c
    ]

-- | Whether the group containing (r, c) touches at least one empty point.
-- False when (r, c) is itself empty, since that yields an empty group.
hasLiberty :: V.Vector Cell -> Int -> Int -> Bool
hasLiberty b r = groupHasLiberty b . groupAt b r

-- Scoring

-- | Final result by stone counting.
--
-- The side with more stones on the board wins ('Black' for 'P1',
-- 'White' for 'P2'), and equal counts are a draw. Empty points are not
-- credited to either side, and no komi is applied.
scorePosition :: Go -> Outcome
scorePosition s
  | blacks > whites = Win P1
  | blacks < whites = Win P2
  | otherwise       = Draw
  where
    count colour = V.length (V.filter (== colour) (board s))
    blacks = count Black
    whites = count White
