{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE PackageImports #-}

module Lib
    (
        -- app
        gomokuMain
        -- testing
        , Element (Empty, Black, White)
        , Board
        , showBoard
        , getChildren
        , initializeBoard
        , move
        , isTerminal
        , heuristic
        , scoreLine2
        , scoreLine3
        , scoreLine4
        , scoreLine5
        , loopSerial
        , loopPar
    ) where

import Data.List (sortBy)
import Data.Maybe
import qualified Data.HashSet as HSet
import qualified Data.Matrix as M
import Control.Parallel.Strategies
import Control.DeepSeq
import System.Environment (getArgs)
import System.Exit (die)

-- | Component-wise sum of two coordinate pairs.
addTuple :: (Int, Int) -> (Int, Int) -> (Int, Int)
addTuple (x1, y1) (x2, y2) = (x1 + x2, y1 + y2)

-- | Scale both components of a coordinate pair by the same factor.
multTuple :: Int -> (Int, Int) -> (Int, Int)
multTuple factor (x, y) = (x * factor, y * factor)

-- | Positions around @position@ that are members of @availableSpaces@.
--
-- Candidates are taken one step away in each of the eight directions and
-- additionally @amount@ steps away in each direction; anything that is not
-- in @availableSpaces@ is discarded.
generateNeighbors :: HSet.HashSet (Int, Int) -> Int -> (Int, Int) -> HSet.HashSet (Int, Int)
generateNeighbors availableSpaces amount position =
    HSet.filter isOpen (HSet.fromList (adjacent ++ distant))
  where isOpen candidate = HSet.member candidate availableSpaces
        adjacent = [addTuple position step | step <- directions]
        distant = [addTuple position (multTuple amount step) | step <- directions]
        directions = [(-1, 0), (1, 0), (0, -1), (0, 1), (-1, 1), (1, -1), (-1, -1), (1, 1)]

-- | Contents of a single board cell.
--
-- The derived 'Enum' instance numbers the constructors in declaration order
-- (@Empty@ = 0, @Black@ = 1, @White@ = 2); 'move' stores @fromEnum color@ in
-- the matrix and 'toElement' converts the stored integers back.
data Element = Empty | Black | White deriving (Enum, Show, Eq)

-- | Convert a stored cell value back into an 'Element' via its 'Enum'
-- instance. Values outside the constructor range make 'toEnum' fail.
toElement :: Int -> Element
toElement = toEnum

-- | Zero-based board coordinate; 'move' and 'get' add 1 to each component
-- before indexing the matrix.
type StonePosition = (Int, Int)
-- | A set of board coordinates.
type StoneSet = HSet.HashSet StonePosition
-- | Board grid holding one 'Int' per cell (the 'fromEnum' of an 'Element').
type Matrix = M.Matrix Int

-- | A single placement: which color was played and where.
data Move = Move
    { moveColor :: Element           -- ^ color of the stone that was placed
    , movePosition :: StonePosition  -- ^ zero-based coordinate of the stone
    }

-- | Full game state. The stone sets mirror the contents of 'matrix' so that
-- occupied cells can be enumerated without scanning the grid.
data Board = Board
    { matrix :: Matrix          -- ^ grid of cell values written by 'move'
    , blackStones :: StoneSet   -- ^ positions where 'move' placed a Black stone
    , whiteStones :: StoneSet   -- ^ positions where 'move' placed a White stone
    , stones :: StoneSet        -- ^ every position 'move' has placed a stone on
    , mostRecentMove :: Move    -- ^ the move that produced this board
    }

-- | Reduce a 'Board' to normal form.
--
-- Every field is evaluated, not just the outer constructor, so that
-- 'force' (used by the @parMap (rpar . force)@ calls in this module) really
-- leaves no thunks behind. 'Move' and 'Element' have no 'NFData' instance
-- here, so the move is taken apart by hand: 'Element' has only nullary
-- constructors, so 'seq' is enough for it.
instance NFData Board where
  rnf (Board m b w s (Move color pos)) =
      rnf m `seq` rnf b `seq` rnf w `seq` rnf s `seq` color `seq` rnf pos

-- | Expose the board's underlying grid, e.g. for printing.
showBoard :: Board -> Matrix
showBoard board = matrix board

-- | True when both coordinates lie in the inclusive range 0..8.
--
-- NOTE(review): 'initializeBoard' builds a 15x15 matrix, so this 9x9 window
-- is narrower than the grid. Confirm whether the smaller playable area is
-- intentional.
isWithinBounds :: (Int, Int) -> Bool
isWithinBounds (row, col) = inRange row && inRange col
  where inRange k = k >= 0 && k <= 8

-- | A position can be played when it is inside the playable window and no
-- stone has been placed on it yet.
isAvailable :: Board -> StonePosition -> Bool
isAvailable board position = isWithinBounds position && not occupied
  where occupied = position `HSet.member` stones board

-- | Place a stone of @color@ at the zero-based position @pos@ and return the
-- resulting board.
--
-- The matrix is 1-indexed, hence the @+ 1@ on both coordinates. The position
-- is always added to 'stones', and to the per-color set matching @color@.
-- No check is made that the position is free or in bounds.
move :: Board -> Element -> (Int, Int) -> Board
move board color pos@(x, y) =
    Board
        { matrix = M.setElem (fromEnum color) (x + 1, y + 1) (matrix board)
        , blackStones = addWhen Black (blackStones board)
        , whiteStones = addWhen White (whiteStones board)
        , stones = HSet.insert pos (stones board)
        , mostRecentMove = Move color pos
        }
  where addWhen owner set
            | color == owner = HSet.insert pos set
            | otherwise = set

-- | Build the starting position: an all-zero 15x15 grid with a single Black
-- stone at the given coordinate.
initializeBoard :: (Int, Int) -> Board
initializeBoard firstStone = move blank Black firstStone
  where blank =
            Board
                { matrix = M.zero 15 15
                , blackStones = HSet.empty
                , whiteStones = HSet.empty
                , stones = HSet.empty
                -- Placeholder only; 'move' overwrites it immediately.
                , mostRecentMove = Move Empty (-1, -1)
                }

-- | Free positions directly adjacent (eight directions) to @position@,
-- restricted to the 0..8 x 0..8 window.
getStoneChildren :: Board -> StonePosition -> HSet.HashSet (Int, Int)
getStoneChildren board position = HSet.filter (isAvailable board) candidates
  where candidates = generateNeighbors window 1 position
        window = HSet.fromList [(row, col) | row <- [0..8], col <- [0..8]]

-- | Union of a list of position sets; the empty list yields the empty set.
childUnion :: [HSet.HashSet (Int, Int)] -> HSet.HashSet (Int, Int)
childUnion sets = case sets of
    [] -> HSet.empty
    (first : rest) -> foldr HSet.union first rest

-- | All boards reachable by placing one stone of @color@ on a free position
-- adjacent to any stone already on the board.
getChildren :: Board -> Element -> [Board]
getChildren board color = [move board color spot | spot <- HSet.toList frontier]
  where frontier = childUnion perStone
        perStone = [getStoneChildren board stone | stone <- HSet.toList (stones board)]

-- | Bounds-checked lookup of a zero-based position in the 1-indexed matrix;
-- 'Nothing' when the position falls outside the grid.
get :: Matrix -> (Int, Int) -> Maybe Int
get grid (row, col) = M.safeGet (row + 1) (col + 1) grid

-- | The other player's color. Anything that is not 'Black' (including
-- 'Empty') maps to 'Black'.
oppositeColor :: Element -> Element
oppositeColor color
    | color == Black = White
    | otherwise = Black

-- | Walk from @pos@ in direction @dir@, consing each cell value onto the
-- accumulator, so the result lists cells in reverse walking order.
--
-- The walk ends when it
--
-- * reaches an empty cell (0) or a stone of the opposite color: that cell
--   is included as the head of the result; or
-- * steps off the grid: nothing further is added.
goInDirHelper :: Matrix -> [Int] -> (Int, Int) -> (Int, Int) -> Element -> [Int]
goInDirHelper m acc pos dir color =
    case fromMaybe (-1) (get m pos) of
        (-1) -> acc
        cell
            | cell == 0 || cell == blocker -> cell : acc
            | otherwise -> goInDirHelper m (cell : acc) (addTuple pos dir) dir color
  where blocker = fromEnum (oppositeColor color)

-- | The line of cells through @pos@ along @dir@, read from the far end of
-- the backward walk to the far end of the forward walk.
--
-- Both walks start at @pos@, so the backward half drops its final element
-- (@pos@ itself) with 'init' to avoid listing it twice.
goInDir :: M.Matrix Int -> (Int, Int) -> (Int, Int) -> Element -> [Element]
goInDir m pos dir color = map toElement (behind ++ ahead)
  where walk step = goInDirHelper m [] pos step color
        behind = init (walk (multTuple (-1) dir))
        ahead = reverse (walk dir)

-- | Score a line (as produced by 'goInDir') that consists of exactly two
-- stones of @color@ plus its terminating cells.
--
-- * 100: the pair is open at both ends, @[Empty, c, c, Empty]@.
-- * 50: the pair is open at exactly one end; the other end either has no
--   terminating cell (3-element line) or is an opposing stone (4-element
--   line).
-- * 0: anything else.
scoreLine2 :: Element -> [Element] -> Int
scoreLine2 color line
    | line == [Empty, color, color, Empty] = 100
    | line `elem` halfOpen = 50
    | otherwise = 0
  where rival = oppositeColor color
        halfOpen =
            [ [Empty, color, color]
            , [color, color, Empty]
            , [Empty, color, color, rival]
            -- This shape used to be written with a single @color@, giving a
            -- 3-element list that a 4-element line could never equal; the
            -- left-blocked pair therefore always scored 0.
            , [rival, color, color, Empty]
            ]

-- | Score a line that consists of exactly three stones of @color@ plus its
-- terminating cells: 500 when open at both ends, 250 when open at exactly
-- one end (other end missing or an opposing stone), 0 otherwise.
scoreLine3 :: Element -> [Element] -> Int
scoreLine3 color line
    | line == Empty : run ++ [Empty] = 500
    | line `elem` halfOpen = 250
    | otherwise = 0
  where run = replicate 3 color
        rival = oppositeColor color
        halfOpen =
            [ Empty : run
            , run ++ [Empty]
            , Empty : run ++ [rival]
            , rival : run ++ [Empty]
            ]

-- | Score a line that consists of exactly four stones of @color@ plus its
-- terminating cells: 1000000 when open at both ends, 500000 when open at
-- exactly one end (other end missing or an opposing stone), 0 otherwise.
scoreLine4 :: Element -> [Element] -> Int
scoreLine4 color line
    | line == Empty : run ++ [Empty] = 1000000
    | line `elem` halfOpen = 500000
    | otherwise = 0
  where run = replicate 4 color
        rival = oppositeColor color
        halfOpen =
            [ Empty : run
            , run ++ [Empty]
            , Empty : run ++ [rival]
            , rival : run ++ [Empty]
            ]

-- | Score a winning run: 10000000000 when a line of 5 to 7 cells contains
-- five consecutive stones of @color@, 0 otherwise.
--
-- Lines longer than 7 cells score 0 even if they contain five in a row.
scoreLine5 :: Element -> [Element] -> Int
scoreLine5 color line
    | len < 5 || len > 7 = 0
    | any isFive windows = 10000000000
    | otherwise = 0
  where len = length line
        -- Every 5-cell window that fits inside the line.
        windows = [take 5 (drop offset line) | offset <- [0 .. len - 5]]
        isFive = all (== color)

-- | One representative per axis (row step, column step). 'goInDir' walks
-- both ways along a direction, so four of the eight directions are enough.
halfDirections :: [(Int, Int)]
halfDirections =
    [ (1, 0)
    , (0, 1)
    , (1, 1)
    , (1, -1)
    ]

-- | Combine per-stone line scores into one total.
--
-- Each list is summed and integer-divided by its run length (2, 3, 4, 5)
-- before the four results are added together.
reduce :: [Int] -> [Int] -> [Int] -> [Int] -> Int
reduce two three four five = sum (zipWith div totals [2, 3, 4, 5])
  where totals = map sum [two, three, four, five]

-- | Static evaluation of a board: twice Black's pattern score minus White's
-- pattern score.
--
-- For each color, every stone contributes the line through it along each of
-- the 'halfDirections'; those lines are scored with 'scoreLine2' ..
-- 'scoreLine5' and combined with 'reduce'. When @isSerial@ is 'False' the
-- scoring maps run through @parMap (rpar . force)@ instead of 'map'.
heuristic :: Board -> Bool -> Int
heuristic board isSerial = 2 * tally Black (blackStones board) - tally White (whiteStones board)
  where grid = matrix board

        -- Serial or parallel map, fixed to the one type it is used at.
        runMap :: ([Element] -> Int) -> [[Element]] -> [Int]
        runMap
            | isSerial = map
            | otherwise = parMap (rpar . force)

        tally color stoneSet =
            reduce (score scoreLine2) (score scoreLine3) (score scoreLine4) (score scoreLine5)
          where score scorer = runMap (scorer color) colorLines
                colorLines =
                    [ goInDir grid pos dir color
                    | pos <- HSet.toList stoneSet
                    , dir <- halfDirections
                    ]

-- | True when the most recent move completed a winning run, i.e. when any of
-- the four lines through that stone gets the 'scoreLine5' winning score.
isTerminal :: Board -> Bool
isTerminal board = any wins halfDirections
  where Move color origin = mostRecentMove board
        wins dir = scoreLine5 color (goInDir (matrix board) origin dir color) == 10000000000

-- | Largest representable 'Int', used as the unbounded search window.
infinity :: Int
infinity = maxBound

-- Based on the "star lines" of http://www.cs.columbia.edu/~sedwards/classes/2021/4995-fall/reports/Gomokururu.pdf
-- | Cheap move-ordering score: evaluates only the four lines that pass
-- through the most recently placed stone, for that stone's color.
recentMoveHeuristic :: Board -> Int
recentMoveHeuristic board =
    reduce (scoreWith scoreLine2) (scoreWith scoreLine3) (scoreWith scoreLine4) (scoreWith scoreLine5)
  where Move color origin = mostRecentMove board
        starLines = [goInDir (matrix board) origin dir color | dir <- halfDirections]
        scoreWith scorer = map (scorer color) starLines

-- | Sort candidate boards so that those with the highest
-- 'recentMoveHeuristic' come first.
--
-- With @isSerial@ 'False' both the scoring and the final projection run
-- through @parMap (rpar . force)@.
orderMoves :: Bool -> [Board] -> [Board]
orderMoves isSerial moves
    | isSerial = map snd (ranked (map recentMoveHeuristic moves))
    | otherwise =
        parMap (rpar . force) snd (ranked (parMap (rpar . force) recentMoveHeuristic moves))
  where ranked scores = sortBy bestFirst (zip scores moves)
        -- Deliberately never yields EQ: ties compare as GT, which fixes the
        -- relative order of equally scored moves.
        bestFirst (ha, _) (hb, _) = if ha > hb then LT else GT

-- | Depth-limited minimax with alpha-beta pruning.
--
-- Black maximizes and White minimizes. Returns the value of @board@ for the
-- side to move (@color@) together with the child board that achieves it; at
-- a leaf (depth 0 or a finished game) the board itself is returned with its
-- 'heuristic' value.
--
-- Arguments: board, remaining depth, alpha (best value Black can already
-- guarantee), beta (best value White can already guarantee), side to move,
-- and whether heuristics and move ordering run serially.
minimax :: Board -> Int -> Int -> Int -> Element -> Bool -> (Int, Board)
minimax board depth alpha beta color isSerial
    | depth == 0 || isTerminal board = (heuristic board isSerial, board)
    | color == Black = playBlack (-infinity) board alpha beta children
    | otherwise = playWhite infinity board alpha beta children
  where children = orderMoves isSerial $ getChildren board color

        -- Maximizing loop over the remaining children.
        playBlack maxValue maxChild _ _ [] = (maxValue, maxChild)
        playBlack maxValue maxChild a b (c:cs)
            -- The cutoff must use the value that already includes the
            -- current child; testing the previous best delays every cutoff
            -- by one child and searches a subtree that cannot matter.
            | maxValue' >= b = (maxValue', maxChild')
            | otherwise = playBlack maxValue' maxChild' (max a maxValue') b cs
          where (pvalue, _) = minimax c (depth-1) a b White isSerial
                (maxValue', maxChild')
                    | pvalue > maxValue = (pvalue, c)
                    | otherwise = (maxValue, maxChild)

        -- Minimizing loop over the remaining children.
        playWhite minValue minChild _ _ [] = (minValue, minChild)
        playWhite minValue minChild a b (c:cs)
            -- Same rule as above, mirrored for the minimizing side.
            | minValue' <= a = (minValue', minChild')
            | otherwise = playWhite minValue' minChild' a (min b minValue') cs
          where (pvalue, _) = minimax c (depth-1) a b Black isSerial
                (minValue', minChild')
                    | pvalue < minValue = (pvalue, c)
                    | otherwise = (minValue, minChild)

-- | Pick the best scored move for @color@: the highest score for Black, the
-- lowest for anyone else. The list must be non-empty ('head'/'last').
chooseMove :: Element -> [(Int, Board)] -> (Int, Board)
chooseMove color moves
    | color == Black = last ranked
    | otherwise = head ranked
  where ranked = sortBy ascending moves
        -- Never yields EQ: ties compare as LT.
        ascending (ha, _) (hb, _) = if ha > hb then GT else LT

-- | Score every child of @board@ in parallel, pairing each child with its
-- value.
--
-- With @depth == 0@ each child is searched by a serial depth-4 'minimax'.
-- With a larger @depth@ the children are expanded by a further parallel
-- level first.
parmapMinimax :: Int -> Board -> Element -> [(Int, Board)]
parmapMinimax depth board color = parMap (rpar . force) evaluate (getChildren board color)
  where rival = oppositeColor color
        evaluate
            | depth == 0 = play
            -- playP was used during debugging, but I found that partial parallelization beyond one level didn't help
            | otherwise = playP
        play child = (fst (minimax child 4 (-infinity) infinity rival True), child)
        -- playP was used during debugging, but I found that partial parallelization beyond one level didn't help
        -- NOTE(review): the nested level generates the rival's moves but
        -- selects among them with @color@; confirm before reviving this path.
        playP child = (fst (chooseMove color (parmapMinimax (depth-1) child rival)), child)

-- | Serial counterpart of 'parmapMinimax': score every child of @board@ with
-- a depth-4 'minimax' from the opponent's point of view.
mapMinimax :: Board -> Element -> [(Int, Board)]
mapMinimax board color = [(valueOf child, child) | child <- getChildren board color]
  where valueOf child = fst (minimax child 4 (-infinity) infinity (oppositeColor color) True)

-- | Play @n@ alternating moves starting with @color@, choosing each move with
-- a single depth-5 'minimax' call. Returns the boards in the order played;
-- @boards@ is the accumulator (most recent first) and is normally @[]@.
loopNoMap :: Board -> Element -> Int -> [Board] -> [Board]
loopNoMap _ _ 0 boards = reverse boards
loopNoMap board color n boards =
    loopNoMap next (oppositeColor color) (n - 1) (next : boards)
  where next = snd (minimax board 5 (-infinity) infinity color True)

-- | Play @n@ alternating moves starting with @color@, choosing each move from
-- the serially scored children ('mapMinimax'). Returns the boards in the
-- order played; @boards@ is the accumulator (most recent first).
loopSerial :: Board -> Element -> Int -> [Board] -> [Board]
loopSerial _ _ 0 boards = reverse boards
loopSerial board color n boards =
    loopSerial next (oppositeColor color) (n - 1) (next : boards)
  where next = snd (chooseMove color (mapMinimax board color))

-- | Play @n@ alternating moves starting with @color@, choosing each move from
-- the children scored in parallel ('parmapMinimax' with one parallel level).
-- Returns the boards in the order played; @boards@ is the accumulator.
loopPar :: Board -> Element -> Int -> [Board] -> [Board]
loopPar _ _ 0 boards = reverse boards
loopPar board color n boards =
    loopPar next (oppositeColor color) (n - 1) (next : boards)
  where next = snd (chooseMove color (parmapMinimax 0 board color))

-- | Program entry point: plays ten moves from a fixed opening and prints the
-- heuristic value and grid of every resulting board.
--
-- Exactly one command-line argument is required. @serial@ and @parallel@
-- select the matching driver; any other single argument runs the no-map
-- driver.
gomokuMain :: IO ()
gomokuMain = do
    putStrLn "BEGIN GAME"

    let board = initializeBoard (7, 7)

    args <- getArgs
    case args of
        [mode]
            | mode == "serial" -> report "SERIAL" (loopSerial board White 10 [])
            | mode == "parallel" -> report "PARALLEL" (loopPar board White 10 [])
            | otherwise -> report "NO MAP" (loopNoMap board White 10 [])
        _ -> die "Usage: stack exec gomokuku-exe <argument>\n<argument> may be serial, parallel, or no-map"
  where -- Print the banner, then all heuristic values, then all grids.
        report banner solutions = do
            putStrLn banner
            mapM_ (putStrLn . show . (`heuristic` True)) solutions
            mapM_ (print . showBoard) solutions
