module Bestmove where

import Bitboard
import Control.Monad.State
import Control.Parallel.Strategies
import Data.Bits (popCount,testBit, shiftL, (.|.))
import Data.Function (on)
import Data.List (maximumBy, minimumBy)
import qualified Data.Map.Strict as Map
import Data.Ord (comparing)
import Debug.Trace (trace)
-- import Fen



-- | Global switch for 'debugTrace' output; set to 'True' to trace the search.
debug :: Bool
debug = False

-- Move origins and destinations (`moveFrom` / `moveTo`) are square indices 0-63,
-- i.e. the shift amounts used to build single-bit masks.
-- #################################################################
-- Plan for finding the best move: a fixed-depth minimax where each root subtree is
-- searched as its own parallel task. The original idea was for each subtree to report
-- what it found, so that pruning could also happen across subtrees.
--   when (dMode == 1) (putStrLn "Engine recieved: " ++ show guiIn)
-- {-Pawn: 1, Knights: 3, Bishop: 3, Rook: 5, Queen: 9, King: 1000-}
-- Add a position mask into here at some point
-- Example piece-square tables for White (64 elements, indexed from a1=0 to h8=63).
-- Rank increases every 8 squares, with a1=0, b1=1,... h1=7; a2=8, etc.
-- https://www.chessprogramming.org/Simplified_Evaluation_Function
-- | Pawn piece-square table for White, indexed a1 = 0 .. h8 = 63 (rank 1 first).
--
-- The rows used to be listed rank 8 first, which is the order the
-- chessprogramming.org table is printed in. With a1 = 0 indexing, that
-- rewarded White for keeping pawns on rank 2 and penalised reaching rank 7.
-- The rows are now in rank-1-first order; the row contents are unchanged.
--
-- NOTE(review): the d4/e4 entries (-2) have the opposite sign to the
-- reference table (+20 there, i.e. about +2 at this scale). Confirm whether
-- that is intentional.
pawnTableWhite :: [Int]
pawnTableWhite =
  concat
    [ [0, 0, 0, 0, 0, 0, 0, 0] -- rank 1
    , [1, 2, 2, -2, -2, 2, 2, 1] -- rank 2
    , [1, -1, -2, 0, 0, -2, -1, 1] -- rank 3
    , [0, 0, 0, -2, -2, 0, 0, 0] -- rank 4
    , [0, 0, 0, 2, 2, 0, 0, 0] -- rank 5
    , [1, 1, 2, 3, 3, 2, 1, 1] -- rank 6
    , [5, 5, 5, 5, 5, 5, 5, 5] -- rank 7
    , [0, 0, 0, 0, 0, 0, 0, 0] -- rank 8
    ]

-- | Knight piece-square table for White, indexed a1 = 0 .. h8 = 63, one row per
-- rank starting from rank 1. It is vertically symmetric, so it reads the same
-- from either side; knights are rewarded for central squares and penalised
-- on the rim.
knightTableWhite :: [Int]
knightTableWhite =
  concat
    [ [-5, -4, -3, -3, -3, -3, -4, -5] -- rank 1
    , [-4, -2, 0, 0, 0, 0, -2, -4] -- rank 2
    , [-3, 0, 1, 1, 1, 1, 0, -3] -- rank 3
    , [-3, 0, 1, 2, 2, 1, 0, -3] -- rank 4
    , [-3, 0, 1, 2, 2, 1, 0, -3] -- rank 5
    , [-3, 0, 1, 1, 1, 1, 0, -3] -- rank 6
    , [-4, -2, 0, 0, 0, 0, -2, -4] -- rank 7
    , [-5, -4, -3, -3, -3, -3, -4, -5] -- rank 8
    ]

-- | Bishop piece-square table for White, indexed a1 = 0 .. h8 = 63, one row per
-- rank starting from rank 1. It is vertically symmetric; corners and edges are
-- penalised and the centre is favoured.
bishopTableWhite :: [Int]
bishopTableWhite =
  concat
    [ [-2, -1, -1, -1, -1, -1, -1, -2] -- rank 1
    , [-1, 0, 0, 0, 0, 0, 0, -1] -- rank 2
    , [-1, 0, 1, 1, 1, 1, 0, -1] -- rank 3
    , [-1, 0, 1, 2, 2, 1, 0, -1] -- rank 4
    , [-1, 0, 1, 2, 2, 1, 0, -1] -- rank 5
    , [-1, 0, 1, 1, 1, 1, 0, -1] -- rank 6
    , [-1, 0, 0, 0, 0, 0, 0, -1] -- rank 7
    , [-2, -1, -1, -1, -1, -1, -1, -2] -- rank 8
    ]

-- | Rook piece-square table for White, indexed a1 = 0 .. h8 = 63: zero on
-- ranks 1-6, a flat +5 on rank 7, and +10 on d8/e8.
--
-- NOTE(review): the reference table (chessprogramming.org) puts the d/e-file
-- bonus on rank 1 (d1/e1). Here it sits on rank 8, while the rank-7 bonus
-- follows a1 = 0 orientation. Confirm which orientation was intended.
rookTableWhite :: [Int]
rookTableWhite =
  concat (replicate 6 (replicate 8 0)) -- ranks 1-6
    ++ replicate 8 5 -- rank 7
    ++ [0, 0, 0, 10, 10, 0, 0, 0] -- rank 8

-- | Queen piece-square table for White, indexed a1 = 0 .. h8 = 63, one row per
-- rank starting from rank 1. It is vertically symmetric, with a mild
-- preference for the centre.
queenTableWhite :: [Int]
queenTableWhite =
  concat
    [ [-2, -1, -1, 0, 0, -1, -1, -2] -- rank 1
    , [-1, 0, 0, 0, 0, 0, 0, -1] -- rank 2
    , [-1, 0, 1, 1, 1, 1, 0, -1] -- rank 3
    , [0, 0, 1, 2, 2, 1, 0, 0] -- rank 4
    , [0, 0, 1, 2, 2, 1, 0, 0] -- rank 5
    , [-1, 0, 1, 1, 1, 1, 0, -1] -- rank 6
    , [-1, 0, 0, 0, 0, 0, 0, -1] -- rank 7
    , [-2, -1, -1, 0, 0, -1, -1, -2] -- rank 8
    ]

-- | King (middlegame) piece-square table for White, indexed a1 = 0 .. h8 = 63
-- (rank 1 first).
--
-- The rows used to be listed rank 8 first, which is the order the
-- chessprogramming.org table is printed in. With a1 = 0 indexing, the
-- castled-king bonus (b/g files) landed on rank 8 and the heavy penalties on
-- rank 1, so the king was encouraged to walk up the board. The rows are now
-- in rank-1-first order; the row contents are unchanged.
kingTableWhite :: [Int]
kingTableWhite =
  concat
    [ [2, 3, 1, 0, 0, 1, 3, 2] -- rank 1
    , [2, 2, 0, 0, 0, 0, 2, 2] -- rank 2
    , [-1, -2, -2, -2, -2, -2, -2, -1] -- rank 3
    , [-2, -3, -3, -4, -4, -3, -3, -2] -- rank 4
    , [-3, -4, -4, -5, -5, -4, -4, -3] -- rank 5
    , [-3, -4, -4, -5, -5, -4, -4, -3] -- rank 6
    , [-3, -4, -4, -5, -5, -4, -4, -3] -- rank 7
    , [-3, -4, -4, -5, -5, -4, -4, -3] -- rank 8
    ]

-- | Map a Black piece's square onto the White tables by reflecting it through
-- the board centre (63 - sq). This flips both rank and file. It equals a pure
-- rank flip only when a table is left/right symmetric, which holds for every
-- table in this module.
indexForBlack :: Int -> Int
indexForBlack = (63 -)

-- | Piece-square bonus for the position: White's table entries minus Black's.
-- Black pieces use the same (White) tables, with squares mirrored through
-- 'indexForBlack'.
positionalScore :: Board -> Int
positionalScore board = sideTotal id whitePieces - sideTotal indexForBlack blackPieces
  where
    -- Sum of table entries over every square set in one bitboard.
    squaresTotal toIndex (table, bb) = sum [table !! toIndex sq | sq <- bitboardSquares bb]
    -- Sum over all of one side's (table, bitboard) pairs.
    sideTotal toIndex = sum . map (squaresTotal toIndex)
    whitePieces =
      [ (pawnTableWhite, pawnsWhite board)
      , (knightTableWhite, knightsWhite board)
      , (bishopTableWhite, bishopsWhite board)
      , (rookTableWhite, rooksWhite board)
      , (queenTableWhite, queensWhite board)
      , (kingTableWhite, kingsWhite board)
      ]
    blackPieces =
      [ (pawnTableWhite, pawnsBlack board)
      , (knightTableWhite, knightsBlack board)
      , (bishopTableWhite, bishopsBlack board)
      , (rookTableWhite, rooksBlack board)
      , (queenTableWhite, queensBlack board)
      , (kingTableWhite, kingsBlack board)
      ]

-- | Static evaluation: material plus piece-square bonus, each computed as
-- (White - Black), then multiplied by @color@. With @color = 1@ the result is
-- from White's point of view; with @color = -1@ it is negated.
--
-- Material weights: pawn 10, knight 30, bishop 30, rook 50, queen 90, king 500.
evaluateBoard :: Board -> Int -> Int
evaluateBoard board color = color * (material + positionalScore board)
  where
    -- (White count - Black count) * weight for one piece type.
    balance white black weight = (popCount (white board) - popCount (black board)) * weight
    material =
      balance pawnsWhite pawnsBlack 10
        + balance knightsWhite knightsBlack 30
        + balance bishopsWhite bishopsBlack 30
        + balance rooksWhite rooksBlack 50
        + balance queensWhite queensBlack 90
        + balance kingsWhite kingsBlack 500

-- | Move the piece on 'moveFrom' to 'moveTo'. Both squares are cleared in
-- every bitboard first, so anything standing on the destination is removed.
-- Errors if no piece is found on the origin square.
--
-- NOTE(review): only 'moveFrom' / 'moveTo' are consulted. The 'promotion',
-- 'isCastling' and 'isEnPassant' fields of 'Move' are not used here, so
-- promotions, the rook half of castling and en-passant captures are not
-- applied by this function. Confirm whether that is handled elsewhere.
applyMove2 :: Move -> Board -> Board
applyMove2 move board =
  case findPiece board origin of
    Nothing -> error "No piece found on from-square (invalid move)"
    Just (piece, _) -> placePiece piece target (clearSquare (clearSquare board origin) target)
  where
    origin = moveFrom move
    target = moveTo move

-- | Switch on square @sq@ (0-63) in the bitboard named by @pieceName@ (one of
-- "pawnsWhite", "pawnsBlack", ..., "kingsBlack"). Only that bitboard changes;
-- nothing is cleared. Any other name is an error.
placePiece :: String -> Int -> Board -> Board
placePiece pieceName sq board =
  case pieceName of
    "pawnsWhite" -> board {pawnsWhite = withSquare pawnsWhite}
    "pawnsBlack" -> board {pawnsBlack = withSquare pawnsBlack}
    "knightsWhite" -> board {knightsWhite = withSquare knightsWhite}
    "knightsBlack" -> board {knightsBlack = withSquare knightsBlack}
    "bishopsWhite" -> board {bishopsWhite = withSquare bishopsWhite}
    "bishopsBlack" -> board {bishopsBlack = withSquare bishopsBlack}
    "rooksWhite" -> board {rooksWhite = withSquare rooksWhite}
    "rooksBlack" -> board {rooksBlack = withSquare rooksBlack}
    "queensWhite" -> board {queensWhite = withSquare queensWhite}
    "queensBlack" -> board {queensBlack = withSquare queensBlack}
    "kingsWhite" -> board {kingsWhite = withSquare kingsWhite}
    "kingsBlack" -> board {kingsBlack = withSquare kingsBlack}
    _ -> error "Invalid piece type"
  where
    -- The selected bitboard with the bit for @sq@ set.
    withSquare field = field board .|. (1 `shiftL` sq)

-- #######################################################################
-- | FEN letter of the piece on square @sq@ (0-63): upper case for White, lower
-- case for Black, '.' for an empty square. This is used to build the position
-- key that identifies previously seen positions at the same depth.
getPieceChar :: Board -> Int -> Char
getPieceChar board sq =
  case [symbol | (layer, symbol) <- layers, testBit (layer board) sq] of
    symbol : _ -> symbol
    [] -> '.' -- empty square
  where
    -- Checked in this order; the first bitboard containing sq wins.
    layers =
      [ (pawnsWhite, 'P'), (pawnsBlack, 'p')
      , (knightsWhite, 'N'), (knightsBlack, 'n')
      , (bishopsWhite, 'B'), (bishopsBlack, 'b')
      , (rooksWhite, 'R'), (rooksBlack, 'r')
      , (queensWhite, 'Q'), (queensBlack, 'q')
      , (kingsWhite, 'K'), (kingsBlack, 'k')
      ]

-- | Piece-placement part of a FEN string (no side to move, castling, etc.):
-- ranks 8 down to 1, separated by '/'.
boardToFEN :: Board -> String
boardToFEN board = drop 1 (concatMap (\r -> '/' : rankToFen board r) [7, 6 .. 0])

-- | FEN text for one rank (0 = rank 1 .. 7 = rank 8), files a to h, with runs
-- of empty squares compressed to digits.
rankToFen :: Board -> Int -> String
rankToFen board r = compressEmpty (map (getPieceChar board) [8 * r .. 8 * r + 7])

-- | Replace each run of empty squares ('.') with its length, as in FEN.
-- Other characters are copied through unchanged.
compressEmpty :: [Char] -> String
compressEmpty [] = ""
compressEmpty squares@(c : rest)
  | c == '.' =
    let (dots, remaining) = span (== '.') squares
     in show (length dots) ++ compressEmpty remaining
  | otherwise = c : compressEmpty rest


-- | Memo table for the searches, holding previously scored positions so they
-- need not be searched again.
-- Key: (board FEN from 'boardToFEN', colour to move, remaining depth);
-- value: the score stored for that key.
type Cache = Map.Map (String, Int, Int) Int

-- Small helpers over 'Cache'.

-- | Union of two caches; on a duplicate key the entry from the first argument wins.
mergeCaches :: Cache -> Cache -> Cache
mergeCaches = Map.unionWith const

-- | Look up a stored score, or 'Nothing' if the key has not been scored yet.
getCachedScore :: Cache -> (String, Int, Int) -> Maybe Int
getCachedScore = flip Map.lookup

-- | Store a score under a key, replacing any previous entry for that key.
insertCacheScore :: Cache -> (String, Int, Int) -> Int -> Cache
insertCacheScore cache key score = Map.insertWith const key score cache
-- #######################################################################
-- Minimax search: moves for the side to move come from generateAllMovesWhite / generateAllMovesBlack; colour 1 = White (maximiser), -1 = Black (minimiser).


-- | Like 'trace', but only prints the message when 'debug' is 'True';
-- otherwise it returns the second argument untouched.
debugTrace :: String -> a -> a
debugTrace msg
  | debug = trace msg
  | otherwise = id

-- | Plain minimax search (no pruning, no caching).
--
-- @color@ is the side to move: 1 = White (maximiser), -1 = Black (minimiser).
-- Scores are always from White's point of view.
--
-- Fix: leaves used to be evaluated with the side to move
-- (@evaluateBoard board color@), which negates the score whenever Black is on
-- move. The max/min split assumes a single perspective, so that negation made
-- the search choose the worst move at odd depths, and for Black at any depth.
simpleMinimax :: Int -> Board -> Int -> Int
simpleMinimax color board depth = debugTrace msg result
  where
    moves =
      if color == 1
        then generateAllMovesWhite board
        else generateAllMovesBlack board
    msg = "Minimax called at depth: " ++ show depth ++ ", color: " ++ show color ++ ", moves: " ++ show (length moves)
    -- Scores of the positions one move away, each searched one ply shallower.
    childScores = [simpleMinimax (-color) (applyMove2 m board) (depth - 1) | m <- moves]
    result
      | depth == 0 || null moves =
        -- Always White's perspective (colour 1), whoever is to move.
        let score = evaluateBoard board 1
         in debugTrace ("Terminal node reached with score: " ++ show score) score
      | color == 1 =
        let best = maximum childScores
         in debugTrace ("Maximizer choosing best score: " ++ show best) best
      | otherwise =
        let best = minimum childScores
         in debugTrace ("Minimizer choosing best score: " ++ show best) best

-- | Best root move for @color@ (1 = White maximises, -1 = Black minimises),
-- scoring each move with 'simpleMinimax' searched @depth - 1@ plies deeper.
-- Errors with "No moves available" if the side to move has no moves.
chooseBestMove :: Int -> Board -> Int -> Move
chooseBestMove color board depth
  | null scored = error "No moves available"
  | color == 1 = snd (maximumBy (comparing fst) scored)
  | otherwise = snd (minimumBy (comparing fst) scored)
  where
    candidates =
      if color == 1
        then generateAllMovesWhite board
        else generateAllMovesBlack board
    -- (score of resulting position, move) for every candidate.
    scored = [(simpleMinimax (negate color) (applyMove2 m board) (depth - 1), m) | m <- candidates]


-- | Minimax with alpha-beta pruning, memoised in a 'State' 'Cache'.
--
-- @color@ is the side to move (1 = White / maximiser, -1 = Black / minimiser);
-- scores are from White's point of view. The cache key is
-- (board FEN, colour, remaining depth).
--
-- Fixes:
--  * Leaves are evaluated from White's perspective. Using the side to move
--    negated Black-to-move leaves and inverted the search at odd depths.
--  * Only exact scores are cached: leaf evaluations, and inner-node results
--    strictly inside @(alpha, beta)@. A result at or beyond a window edge is
--    only a bound produced by pruning. Storing it as exact let a later
--    look-up, made under a different window or from another root move,
--    return a wrong score.
simpleMinimaxCache :: Int -> Board -> Int -> Int -> Int -> State Cache Int
simpleMinimaxCache color board depth alpha beta = do
  let fen = boardToFEN board
      key = (fen, color, depth)
  cache <- get
  case Map.lookup key cache of
    Just score -> return score
    Nothing -> do
      let moves =
            if color == 1
              then generateAllMovesWhite board
              else generateAllMovesBlack board
      if depth == 0 || null moves
        then do
          -- Always White's perspective, to match the max/min split below.
          let score = evaluateBoard board 1
          debugTrace
            ("Terminal node: "
               ++ show key
               ++ ", Score: "
               ++ show score
               ++ ", Cache: "
               ++ show (Map.lookup key cache))
            $ return ()
          modify (Map.insert key score)
          return score
        else do
          let helper =
                if color == 1
                  then simpleHelperMaxC
                  else simpleHelperMinC
          score <- helper moves board depth alpha beta
          -- Cache only exact results; bounds from pruning are not reusable.
          if alpha < score && score < beta
            then modify (Map.insert key score)
            else return ()
          return score

-- | Max-node move loop for White in the 'State'-cached search. Each move is
-- searched as a Black (-1) node and alpha is raised as scores come in. The
-- loop stops early once alpha reaches beta. The result is never below the
-- incoming alpha.
simpleHelperMaxC :: [Move] -> Board -> Int -> Int -> Int -> State Cache Int
simpleHelperMaxC moves board depth alpha beta = case moves of
  [] -> return alpha
  (m : rest) -> do
    childScore <- simpleMinimaxCache (-1) (applyMove2 m board) (depth - 1) alpha beta
    let alpha' = max alpha childScore
    if alpha' < beta
      then max childScore <$> simpleHelperMaxC rest board depth alpha' beta
      else return alpha' -- Prune

-- | Min-node move loop for Black in the 'State'-cached search. Each move is
-- searched as a White (1) node and beta is lowered as scores come in. The
-- loop stops early once beta drops to alpha. The result is never above the
-- incoming beta.
simpleHelperMinC :: [Move] -> Board -> Int -> Int -> Int -> State Cache Int
simpleHelperMinC moves board depth alpha beta = case moves of
  [] -> return beta
  (m : rest) -> do
    childScore <- simpleMinimaxCache 1 (applyMove2 m board) (depth - 1) alpha beta
    let beta' = min beta childScore
    if beta' > alpha
      then min childScore <$> simpleHelperMinC rest board depth alpha beta'
      else return beta' -- Prune


-- | Best root move for @color@ (1 = White maximises, -1 = Black minimises).
-- Each move is scored with 'simpleMinimaxCache' using the window
-- (-1000000, 1000000), and the 'State' 'Cache' is shared across all root
-- moves. Returns 'Nothing' when the side to move has no moves.
chooseBestMoveCache :: Int -> Board -> Int -> State Cache (Maybe Move)
chooseBestMoveCache color board depth
  | null moves = return Nothing
  | otherwise = do
      scored <- traverse scoreMove moves
      let pick = if color == 1 then maximumBy else minimumBy
      return (Just (fst (pick (comparing snd) scored)))
  where
    moves
      | color == 1 = generateAllMovesWhite board
      | otherwise = generateAllMovesBlack board
    -- Pair a root move with the search score of the position it leads to.
    scoreMove m = do
      score <- simpleMinimaxCache (-color) (applyMove2 m board) (depth - 1) (-1000000) 1000000
      return (m, score)


-- | Pure minimax with alpha-beta pruning, threading a 'Cache' explicitly.
--
-- @color@ is the side to move (1 = White / maximiser, -1 = Black / minimiser);
-- scores are from White's point of view. Returns the score and the updated
-- cache.
--
-- Fixes:
--  * Leaves are evaluated from White's perspective. Using the side to move
--    negated Black-to-move leaves and inverted the search at odd depths.
--  * Only exact scores are cached: leaf evaluations, and inner-node results
--    strictly inside @(alpha, beta)@. A result at or beyond a window edge is
--    only a bound produced by pruning. Caching it as exact let later
--    look-ups under a different window return wrong scores.
simpleMinimaxAB :: Int -> Board -> Int -> Int -> Int -> Cache -> (Int, Cache)
simpleMinimaxAB color board depth alpha beta cache =
  case getCachedScore cache key of
    Just score -> (score, cache)
    Nothing
      | depth == 0 || null moves ->
          -- Always White's perspective, to match the max/min split.
          let score = evaluateBoard board 1
           in (score, insertCacheScore cache key score)
      | otherwise ->
          let helper =
                if color == 1
                  then simpleHelperMax
                  else simpleHelperMin
              (score, updatedCache) = helper moves board depth alpha beta cache
              -- Cache only exact results; bounds from pruning are not reusable.
              finalCache
                | alpha < score && score < beta = insertCacheScore updatedCache key score
                | otherwise = updatedCache
           in (score, finalCache)
  where
    key = (boardToFEN board, color, depth)
    moves =
      if color == 1
        then generateAllMovesWhite board
        else generateAllMovesBlack board


-- | Max-node move loop for White in the pure alpha-beta search. Each move is
-- searched as a Black (-1) node, threading the cache through, and alpha is
-- raised as scores come in. The loop stops early once alpha reaches beta.
-- The score returned is never below the incoming alpha.
simpleHelperMax :: [Move] -> Board -> Int -> Int -> Int -> Cache -> (Int, Cache)
simpleHelperMax [] _ _ alpha _ cache = (alpha, cache)
simpleHelperMax (m : ms) board depth alpha beta cache
  | alpha' >= beta = (alpha', cache') -- Prune
  | otherwise = (max childScore restScore, restCache)
  where
    (childScore, cache') = simpleMinimaxAB (-1) (applyMove2 m board) (depth - 1) alpha beta cache
    alpha' = max alpha childScore
    (restScore, restCache) = simpleHelperMax ms board depth alpha' beta cache'

-- | Min-node move loop for Black in the pure alpha-beta search. Each move is
-- searched as a White (1) node, threading the cache through, and beta is
-- lowered as scores come in. The loop stops early once beta drops to alpha.
-- The score returned is never above the incoming beta.
simpleHelperMin :: [Move] -> Board -> Int -> Int -> Int -> Cache -> (Int, Cache)
simpleHelperMin [] _ _ _ beta cache = (beta, cache)
simpleHelperMin (m : ms) board depth alpha beta cache
  | beta' <= alpha = (beta', cache') -- Prune
  | otherwise = (min childScore restScore, restCache)
  where
    (childScore, cache') = simpleMinimaxAB 1 (applyMove2 m board) (depth - 1) alpha beta cache
    beta' = min beta childScore
    (restScore, restCache) = simpleHelperMin ms board depth alpha beta' cache'


-- | Best root move for @color@ (1 = White maximises, -1 = Black minimises).
-- This function handles the first ply itself: every root move is scored by
-- 'evaluateMovesInParallel' and the best score is chosen. Returns 'Nothing'
-- when the side to move has no moves.
--
-- Change: the "no moves" message now goes through 'debugTrace', so it obeys
-- the module's 'debug' switch like every other trace here. Previously it was
-- an unconditional 'trace'.
chooseBestMovePar :: Int -> Board -> Int -> State Cache (Maybe Move)
chooseBestMovePar color board depth
  | null moves =
      debugTrace "No moves available for the current player." $
        return Nothing
  | otherwise = do
      evaluatedScores <- evaluateMovesInParallel color board depth moves
      let pick = if color == 1 then maximumBy else minimumBy
      return (Just (snd (pick (comparing fst) (zip evaluatedScores moves))))
  where
    moves
      | color == 1 = generateAllMovesWhite board
      | otherwise = generateAllMovesBlack board

-- | Score every root move in parallel, one spark per move via 'parMap'.
--
-- Each move is searched by 'simpleMinimaxAB' for the opponent with the full
-- window (minBound, maxBound) and a fresh, empty cache, so the searches are
-- independent of one another. The 'State' 'Cache' is neither read nor
-- written here.
--
-- Change: the result was previously forced with two strategies
-- (@parMap rpar ... `using` parList rseq@), which traversed and sparked the
-- list twice. A single @parMap rseq@ already evaluates each 'Int' score
-- (WHNF is normal form for 'Int') in its own spark.
evaluateMovesInParallel :: Int -> Board -> Int -> [Move] -> State Cache [Int]
evaluateMovesInParallel color board depth moves =
  return (parMap rseq evaluateMove moves)
  where
    evaluateMove :: Move -> Int
    evaluateMove move =
      let newBoard = applyMove2 move board
          (score, _) =
            simpleMinimaxAB (-color) newBoard (depth - 1) (minBound :: Int) (maxBound :: Int) Map.empty
       in score


-- | Human-readable description of an optional move, e.g.
-- "Move from e2 to e4 for color 1", or "No move" for 'Nothing'.
showMove2 :: Maybe Move -> Int -> String
showMove2 mMove color = maybe "No move" describe mMove
  where
    describe m =
      concat ["Move from ", showSquare (moveFrom m), " to ", showSquare (moveTo m), " for color ", show color]

-- | Unwrap a 'Maybe' 'Move'. For 'Nothing', return a placeholder a1 -> a1 move
-- with an empty promotion and all flags set to 0.
extractMove :: Maybe Move -> Move
extractMove = maybe placeholder id
  where
    placeholder = Move {moveFrom = 0, moveTo = 0, promotion = "", isCapture = 0, isCastling = 0, isEnPassant = 0}

-- | Algebraic name of a square index, with a1 = 0, b1 = 1, ..., h8 = 63.
showSquare :: Int -> String
showSquare sq = fileChar : show (rankIdx + 1)
  where
    (rankIdx, fileIdx) = sq `divMod` 8
    fileChar = "abcdefgh" !! fileIdx



