module AI where

import           Board
import           Control.Parallel.Strategies
import           Data.List
import           Data.Maybe
import qualified Data.Set                    as Set
import           Data.Tree

-- | Initial lower bound (alpha) of the search window used by 'generateMove'.
-- Equal to -2^29, the bottom of the range the Haskell 2010 Report guarantees
-- for 'Int'.
minInt :: Int
minInt = negate (2 ^ (29 :: Int))

-- | Initial upper bound (beta) of the search window used by 'generateMove'.
-- Equal to 2^29 - 1, the top of the range the Haskell 2010 Report guarantees
-- for 'Int'.
maxInt :: Int
maxInt = pred (2 ^ (29 :: Int))

-- | Choose and play a move for @color@ on @board@.
--
-- On an empty board the stone is placed at
-- @(row board `div` 2, col board `div` 2)@ (presumably the centre).
-- Otherwise every candidate from 'nextMoves' is played via 'buildTree'. Each
-- resulting position is scored with a depth-3 'minBeta' search, and the first
-- position with the highest score is returned.
--
-- If there are no candidate moves, the board is returned unchanged rather
-- than crashing on an empty list.
generateMove :: Board -> Color -> Board
generateMove board color
  | isEmptyBoard board =
      addPointToBoard (Point color (row board `div` 2, col board `div` 2)) board
  | otherwise = pickBest (zip scores children)
  where
    Node _ children = buildTree color board (nextMoves board)
    -- Each root child is searched independently with the full
    -- (minInt, maxInt) window; parMap rdeepseq evaluates them in parallel.
    scores = parMap rdeepseq (minBeta color 3 minInt maxInt) children
    -- No legal move: leave the position as it is.
    pickBest [] = board
    pickBest candidates = rootLabel (snd (foldl1 keepFirstMax candidates))
    -- Strict '>' keeps the earliest of equally scored candidates.
    keepFirstMax best@(bestScore, _) candidate@(candidateScore, _)
      | candidateScore > bestScore = candidate
      | otherwise = best


-- | Candidate moves for whichever side is to play.
--
-- Returns the distinct cells accepted by 'stepFromPoint' around every point
-- that 'filterBoard' returns for White and for Black. The cells are returned
-- as 'Empty' points in 'Set' order. The result does not depend on whose turn
-- it is.
nextMoves :: Board -> [Point]
nextMoves board = Set.toList (stepBoard board stones)
  where
    stones = filterBoard board White ++ filterBoard board Black

-- | The set of all cells produced by 'stepFromPoint' for any of the given
-- points, with duplicates removed.
stepBoard :: Board -> [Point] -> Set.Set Point
stepBoard board points =
  Set.unions [Set.fromList (stepFromPoint board p) | p <- points]

-- | The up-to-eight cells surrounding @(x, y)@ that pass both 'isValid' and
-- 'isEmpty' on @board@, returned as 'Empty' points.
--
-- The colour of the input point is ignored. Results are ordered by x offset,
-- then y offset (-1, 0, 1).
stepFromPoint :: Board -> Point -> [Point]
stepFromPoint board (Point _ (x, y)) =
  filter playable [Point Empty (x + dx, y + dy) | (dx, dy) <- offsets]
  where
    offsets = [(dx, dy) | dx <- [-1 .. 1], dy <- [-1 .. 1], (dx, dy) /= (0, 0)]
    playable cell = isValid cell board && isEmpty cell board

-- | Lazily build the game tree rooted at @board@, with @color@ to move.
--
-- Each child plays @color@ on one of @neighbors@ (in list order) and recurses
-- with the opposite colour. A child's candidate cells are:
--
-- * the parent's candidates minus the cell just played, plus
-- * the empty cells around the played cell ('stepFromPoint'),
--
-- deduplicated through a 'Set'. The tree itself has no depth limit; the
-- search functions in this module stop at a fixed depth.
buildTree :: Color -> Board -> [Point] -> Tree Board
buildTree color board neighbors = Node board (map playAt neighbors)
  where
    nextColor = oppositeColor color
    playAt cell@(Point _ pos) =
      buildTree nextColor (addPointToBoard (Point color pos) board) (candidatesAfter cell)
    candidatesAfter cell =
      Set.toList $
        Set.fromList (Data.List.delete cell neighbors)
          `Set.union` Set.fromList (stepFromPoint board cell)

-- | Maximising step of a depth-limited, fail-hard alpha-beta search.
--
-- The node's board is scored from @color@'s point of view with 'scoreBoard'.
-- That static score is returned when @level@ is 0, or when its magnitude
-- exceeds 100000 (a run of five is worth 100000 in 'sumScores'). This check
-- is made even if the node has no children.
--
-- Otherwise each child is evaluated with 'minBeta' at @level - 1@, raising
-- @alpha@ as better replies are found. Once @alpha >= beta@ the remaining
-- children are pruned and @beta@ is returned. When all children have been
-- searched (or there are none), the accumulated @alpha@ is returned.
--
-- NOTE(review): 'scoreBoard' subtracts the opponent's score. A position
-- containing a five can therefore score <= 100000 and not be treated as
-- decided. Confirm the threshold.
maxAlpha :: Color -> Int -> Int -> Int -> Tree Board -> Int
maxAlpha color level alpha beta (Node b children)
  | level == 0 || isDecided = staticScore
  | otherwise = searchChildren alpha children
  where
    -- Depends only on this node's board, so it is computed once per node
    -- rather than once per child.
    staticScore = scoreBoard b color
    isDecided = staticScore > 100000 || staticScore < (-100000)
    searchChildren acc [] = acc
    searchChildren acc (child : rest)
      | acc' >= beta = beta
      | otherwise = searchChildren acc' rest
      where
        acc' = max acc (minBeta color (level - 1) acc beta child)

-- | Minimising step of a depth-limited, fail-hard alpha-beta search.
--
-- The node's board is scored from @color@'s point of view with 'scoreBoard'.
-- That static score is returned when @level@ is 0, or when its magnitude
-- exceeds 100000 (a run of five is worth 100000 in 'sumScores'). This check
-- is made even if the node has no children.
--
-- Otherwise each child is evaluated with 'maxAlpha' at @level - 1@, lowering
-- @beta@ as stronger replies are found. Once @alpha >= beta@ the remaining
-- children are pruned and @alpha@ is returned. When all children have been
-- searched (or there are none), the accumulated @beta@ is returned.
--
-- NOTE(review): 'scoreBoard' subtracts the opponent's score. A position
-- containing a five can therefore score >= -100000 and not be treated as
-- decided. Confirm the threshold.
minBeta :: Color -> Int -> Int -> Int -> Tree Board -> Int
minBeta color level alpha beta (Node b children)
  | level == 0 || isDecided = staticScore
  | otherwise = searchChildren beta children
  where
    -- Depends only on this node's board, so it is computed once per node
    -- rather than once per child.
    staticScore = scoreBoard b color
    isDecided = staticScore > 100000 || staticScore < (-100000)
    searchChildren acc [] = acc
    searchChildren acc (child : rest)
      | alpha >= acc' = alpha
      | otherwise = searchChildren acc' rest
      where
        acc' = min acc (maxAlpha color (level - 1) alpha acc child)

-- | Static evaluation of @board@ from @color@'s point of view.
--
-- Computed as the run score of @color@'s stones minus the run score of the
-- opponent's stones.
scoreBoard :: Board -> Color -> Int
scoreBoard board color = runScore color - runScore (oppositeColor color)
  where
    -- Total of 'sumScores' over each direction's run lengths, with the
    -- directions evaluated in parallel.
    runScore side =
      sum (parMap rdeepseq sumScores (scoreDirections (filterBoard board side)))

-- | Total value of a list of run lengths.
--
-- Weights: 5 -> 100000, 4 -> 5000, 3 -> 300, 2 -> 10. Any other length,
-- including runs longer than five, contributes nothing.
sumScores :: [Int] -> Int
sumScores = sum . map runValue
  where
    runValue 5 = 100000
    runValue 4 = 5000
    runValue 3 = 300
    runValue 2 = 10
    runValue _ = 0

-- | Run lengths of the given stones along each of four directions:
-- (0,1), (1,-1), (1,0) and (1,1), computed in parallel.
--
-- An empty list yields @[[0]]@. Each walk starts from the first stone.
--
-- 'scoreDirection' only counts a run whole if the run's first stone is
-- reached first. Presumably the stones arrive sorted ascending by
-- coordinates, which makes that hold for these four directions — TODO
-- confirm against 'filterBoard'.
scoreDirections :: [Point] -> [[Int]]
scoreDirections [] = [[0]]
scoreDirections stones@(first : _) =
  parMap rdeepseq (scoreDirection first stones 0) directions
  where
    directions = [(0, 1), (1, -1), (1, 0), (1, 1)]

-- | Lengths of the runs of consecutive stones along direction @(xDir, yDir)@.
--
-- @scoreDirection start stones count dir@ works as follows:
--
-- * It walks from @start@ along @dir@, removing each stone it finds from
--   @stones@ and incrementing @count@.
-- * When the next cell is not in @stones@, it emits the run length and starts
--   a new run at the head of the remaining stones.
-- * When @stones@ is exhausted, it emits the final count.
--
-- A run is only counted whole if its first stone (the one farthest against
-- @dir@) is reached before its other stones. Presumably callers pass stones
-- sorted ascending by position so that this holds — TODO confirm.
scoreDirection :: Point -> [Point] -> Int -> (Int, Int) -> [Int]
scoreDirection _ [] cont _ = [cont]
scoreDirection here@(Point c (x, y)) ps@(Point c1 (x1, y1) : rest) cont dir@(xDir, yDir)
  | here `elem` ps =
      scoreDirection (Point c (x + xDir, y + yDir)) (Data.List.delete here ps) (cont + 1) dir
  -- The head of the remaining stones starts the next run. It is dropped
  -- (@rest@) and already counted (1), so the walk continues from the cell
  -- after it. Restarting *at* the head would find it missing from @rest@
  -- and record every later stone as a run of length 1.
  | otherwise =
      cont : scoreDirection (Point c1 (x1 + xDir, y1 + yDir)) rest 1 dir

