{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
-- | Core Blokus game representation and rules for a 4-player MaxN search.
--
-- This module models the 20x20 Blokus board, the 21 polyomino pieces per
-- player (mono–pentominoes), legal move generation under Blokus rules, and
-- a heuristic evaluation function suitable for MaxN-based game tree search.
module Blokus where

import Prelude hiding (foldr)
import qualified Prelude         as Pre
import qualified Data.List       as L
import qualified Data.Array      as A
import qualified Data.Map.Strict as M
import qualified Data.Set        as S
import qualified Data.Vector     as V

import GHC.Generics (Generic)
import Control.DeepSeq (NFData)

------------------------------------------------------------
-- Core Types
------------------------------------------------------------

-- | The four Blokus players. Turn order follows constructor order
-- (P1, P2, P3, P4, then back to P1).
data Player
    = P1
    | P2
    | P3
    | P4
    deriving (Eq, Ord, Show, Enum, Bounded)

-- | A board position @(x, y)@. On-board values range over 0..19 in each
-- component, so (0,0) and (19,19) are opposite corners of the board.
type Coord = (Int, Int)

-- | What a single board square holds.
data Cell
    = Empty            -- ^ No piece covers this square.
    | Occupied Player  -- ^ Covered by a piece owned by this player.
    deriving (Eq, Show)

-- | The 20x20 Blokus board: an immutable array of 'Cell's indexed by 'Coord'.
type Board = A.Array Coord Cell

-- | Names for the 21 Blokus polyominoes, grouped by size. Constructor order
-- is significant: it fixes the derived 'Enum' and 'Ord' instances.
data PieceId
    = Mono                                   -- size 1
    | Domino                                 -- size 2
    | TrominoI
    | TrominoV                               -- size 3
    | TetrominoI
    | TetrominoL
    | TetrominoT
    | TetrominoS
    | TetrominoSquare                        -- size 4
    | PentominoI
    | PentominoL
    | PentominoS
    | PentominoP
    | PentominoU
    | PentominoF
    | PentominoT
    | PentominoV
    | PentominoM
    | PentominoZ
    | PentominoY
    | PentominoX                             -- size 5
    deriving (Eq, Ord, Show, Enum, Bounded)

-- | A polyomino given as the list of squares it covers, in local coordinates
-- measured from the piece's origin. Orientations produced by
-- 'orientationsOf' are normalized with 'translateToOrigin'.
type Shape = [Coord]

-- | Position of a shape within a piece's vector of distinct orientations.
-- Depending on the piece's symmetry, there are between 1 and 8 of them.
newtype Orientation = Orientation Int
    deriving (Eq, Ord, Show)

-- | A Blokus piece together with every distinct way it can be oriented.
data Piece = Piece
    { pieceId      :: PieceId        -- ^ Which piece this is (e.g. 'TetrominoL').
    , orientations :: V.Vector Shape -- ^ Its distinct, normalized orientations.
    }

-- | Renders the piece name on its own line, followed by one indented line
-- per orientation of the form @index: shape@.
instance Show Piece where
    show (Piece name shapes) =
        show name ++ ":\n" ++ unlines (L.zipWith describe [0 :: Integer ..] (V.toList shapes))
      where
        describe idx shape = "  " ++ show idx ++ ": " ++ show shape

-- | One placement: a player puts one of their remaining pieces on the board
-- in a chosen orientation, with the shape's local origin at a board square.
data Move = Move
    { mvPlayer      :: Player
      -- ^ Who is placing the piece.
    , mvPieceId     :: PieceId
      -- ^ The piece being placed.
    , mvOrientation :: Orientation
      -- ^ Index into the piece's 'orientations'.
    , mvAnchorCoord :: Coord
      -- ^ Board square that the shape's local (0,0) is mapped to.
    }
    deriving (Eq, Show)

-- | Full game state at a node in the MaxN search tree.
data GameState = GameState
    { gsBoard           :: Board
      -- ^ Current board occupancy.
    , gsPlayerToMove    :: Player
      -- ^ Player whose turn it is.
    , gsRemainingPieces :: M.Map Player (S.Set PieceId)
      -- ^ Remaining (unused) pieces per player.
    , gsCorners         :: M.Map Player (S.Set Coord)
      -- ^ Candidate anchor cells per player, used by move generation. Each
      -- player starts with the on-board corner its first piece must cover
      -- ('initCorners'), and 'updateCorners' only keeps on-board cells.
      -- These are a search aid only; legality is always decided by
      -- 'isValidFirstMove' or 'isValidMove'.
    , gsPassed          :: S.Set Player
      -- ^ Players who passed on their most recent turn; once all four have
      -- passed, the game is terminal (see 'isTerminal').
    }
    deriving (Eq, Show)

------------------------------------------------------------
-- Piece base shapes and geometry
------------------------------------------------------------

-- | Canonical base shapes for each piece in local coordinates.
--
-- Each shape is a connected polyomino with non-negative coordinates.
-- Rotations and reflections are generated from these bases by
-- 'orientationsOf'. Pentomino names follow the standard letter convention,
-- except that this module calls the N pentomino 'PentominoS' and the W
-- pentomino 'PentominoM'.
--
-- * F: a column of three, with one square attached beside an end square and
--   another attached on the opposite side of the middle square.
-- * Y: a line of four, with one square attached beside an inner square.
baseShapes :: M.Map PieceId Shape
baseShapes = M.fromList
    [ (Mono, [(0,0)])

    , (Domino, [(0,0), (0,1)])

    , (TrominoI, [(0,0), (0,1), (0,2)])
    , (TrominoV, [(0,0), (1,0), (0,1)])

    , (TetrominoI,      [(0,0), (0,1), (0,2), (0,3)])
    , (TetrominoL,      [(0,0), (1,0), (1,1), (1,2)])
    , (TetrominoT,      [(0,0), (0,1), (0,2), (1,1)])
    , (TetrominoS,      [(0,0), (0,1), (1,1), (1,2)])
    , (TetrominoSquare, [(0,0), (0,1), (1,0), (1,1)])

    , (PentominoI, [(0,0), (0,1), (0,2), (0,3), (0,4)])
    , (PentominoL, [(0,0), (1,0), (1,1), (1,2), (1,3)])
    , (PentominoS, [(0,0), (0,1), (1,1), (1,2), (1,3)])
    , (PentominoP, [(0,0), (0,1), (0,2), (1,1), (1,2)])
    , (PentominoU, [(0,0), (0,1), (1,0), (2,0), (2,1)])
    , (PentominoF, [(0,0), (1,0), (1,1), (1,2), (2,1)])
    , (PentominoT, [(0,0), (0,1), (0,2), (1,1), (2,1)])
    , (PentominoV, [(0,0), (1,0), (2,0), (0,1), (0,2)])
    , (PentominoM, [(0,0), (0,1), (1,1), (1,2), (2,2)])
    , (PentominoZ, [(0,0), (1,0), (1,1), (1,2), (2,2)])
    , (PentominoY, [(0,0), (0,1), (0,2), (0,3), (1,2)])
    , (PentominoX, [(1,0), (0,1), (1,1), (2,1), (1,2)])
    ]

-- | Reflect a shape across the Y-axis in local coordinates. Each x component
-- is negated and each y component is left alone.
mirrorShapeY :: Shape -> Shape
mirrorShapeY = L.map flipX
  where
    flipX (x, y) = (negate x, y)

-- | Reflect a shape across the X-axis in local coordinates. Each y component
-- is negated and each x component is left alone.
mirrorShapeX :: Shape -> Shape
mirrorShapeX = L.map flipY
  where
    flipY (x, y) = (x, negate y)

-- | Rotate a shape a quarter turn clockwise in local coordinates, mapping
-- @(x, y)@ to @(y, -x)@. The result is not re-normalized.
rotateShape90 :: Shape -> Shape
rotateShape90 shape = [ (y, negate x) | (x, y) <- shape ]

-- | Canonicalize a shape. It is shifted so that its smallest x and smallest
-- y become 0, and the squares are then sorted. Two shapes that differ only
-- by translation or square order therefore compare equal afterwards, which
-- is what lets duplicate orientations be removed. The empty shape is
-- returned unchanged.
translateToOrigin :: Shape -> Shape
translateToOrigin shape = case shape of
    [] -> []
    _  -> L.sort (L.map shift shape)
  where
    lowX = minimum (L.map fst shape)
    lowY = minimum (L.map snd shape)
    shift (x, y) = (x - lowX, y - lowY)

-- | Compute all distinct orientations (rotations and reflections) of a base
-- shape. The resulting vector is indexed by 'Orientation'.
--
-- Every candidate, including the untransformed base, is normalized with
-- 'translateToOrigin' before deduplication. Without this, an unsorted or
-- offset base would never compare equal to its normalized copies, and 'L.nub'
-- would keep spurious duplicates that later become duplicate moves.
--
-- The eight transforms are: identity, the two axis mirrors, the half turn
-- (both mirrors), and each of those followed by a quarter turn. Together
-- they make up the full symmetry group of the square.
orientationsOf :: Shape -> V.Vector Shape
orientationsOf base =
    V.fromList $ L.nub
        [ translateToOrigin (transform base)
        | transform <-
            [ id
            , mirrorShapeY
            , mirrorShapeX
            , mirrorShapeX . mirrorShapeY
            , rotateShape90
            , rotateShape90 . mirrorShapeY
            , rotateShape90 . mirrorShapeX
            , rotateShape90 . mirrorShapeX . mirrorShapeY
            ]
        ]

-- | Every piece, keyed by 'PieceId', with its orientations precomputed once
-- from 'baseShapes'.
allPieces :: M.Map PieceId Piece
allPieces = M.mapWithKey (\pid shape -> Piece pid (orientationsOf shape)) baseShapes

-- | Fetch the precomputed 'Piece' for an id. Every 'PieceId' constructor has
-- an entry in 'baseShapes', so the lookup always succeeds.
getPieceById :: PieceId -> Piece
getPieceById = (M.!) allPieces

-- | Board squares covered by a piece in the given orientation when its
-- local origin is placed at the anchor coordinate. Each local square is
-- offset by the anchor; no bounds checking is done here.
getPieceCoords :: Piece -> Orientation -> Coord -> [Coord]
getPieceCoords piece (Orientation idx) (ax, ay) =
    L.map (\(x, y) -> (x + ax, y + ay)) (orientations piece V.! idx)

------------------------------------------------------------
-- Board / shape neighborhood helpers
------------------------------------------------------------

-- | True when both components lie in 0..19, i.e. the square is on the
-- 20x20 board.
isInBounds :: Coord -> Bool
isInBounds (x, y) = onAxis x && onAxis y
  where
    onAxis v = 0 <= v && v <= 19

-- | The four diagonal (corner-touching) neighbours of a square. The result
-- may include off-board coordinates.
findCoordCorners :: Coord -> [Coord]
findCoordCorners (x, y) =
    [ (x + dx, y + dy) | (dx, dy) <- [(1, 1), (1, -1), (-1, -1), (-1, 1)] ]

-- | The four orthogonal (edge-sharing) neighbours of a square. The result
-- may include off-board coordinates.
findAdjacentCoords :: Coord -> [Coord]
findAdjacentCoords (x, y) =
    [ (x + dx, y + dy) | (dx, dy) <- [(1, 0), (0, 1), (-1, 0), (0, -1)] ]

-- | Every square sharing an edge with some square of the shape, excluding
-- the shape's own squares. Off-board squares are not removed.
findShapeAdjacent :: Shape -> S.Set Coord
findShapeAdjacent s = neighbours S.\\ S.fromList s
  where
    neighbours = S.fromList (s >>= findAdjacentCoords)

-- | Every square touching some square of the shape at a corner, excluding
-- the shape's own squares. Off-board squares are not removed.
findShapeCorners :: Shape -> S.Set Coord
findShapeCorners s =
    S.difference
        (S.unions [ S.fromList (findCoordCorners c) | c <- s ])
        (S.fromList s)

------------------------------------------------------------
-- Game initialisation
------------------------------------------------------------

-- | A fresh 20x20 board, indexed from (0,0) to (19,19), with every square
-- 'Empty'.
initBoard :: Board
initBoard = A.listArray ((0, 0), (19, 19)) (repeat Empty)

-- | At the start of the game, every player owns the complete set of 21
-- pieces (all keys of 'baseShapes').
initPieces :: M.Map Player (S.Set PieceId)
initPieces = M.fromList (map withFullSet [minBound .. maxBound])
  where
    withFullSet p = (p, M.keysSet baseShapes)

-- | Initial candidate corners per player.
--
-- Each player starts with exactly one on-board corner square, which is the
-- square its first piece must cover. Move generation aligns each square of a
-- piece with this coordinate (see 'candidatePlacements'), and
-- 'isValidFirstMove' checks that the required corner is covered.
initCorners :: M.Map Player (S.Set Coord)
initCorners =
    -- One board corner per player, listed in turn order.
    M.fromList
        [ (P1, S.singleton (0,0))    -- first piece must cover (0,0)
        , (P2, S.singleton (0,19))   -- first piece must cover (0,19)
        , (P3, S.singleton (19,19))  -- first piece must cover (19,19)
        , (P4, S.singleton (19,0))   -- first piece must cover (19,0)
        ]

-- | The state before any move: an empty board, P1 to move, every piece still
-- in hand, each player's starting corner as its only candidate, and no
-- passes recorded.
initGameState :: GameState
initGameState = GameState initBoard P1 initPieces initCorners S.empty

------------------------------------------------------------
-- Corners and move legality
------------------------------------------------------------

-- | Bounds-checked cell lookup. Indexing off the board calls 'error' with a
-- message that names the offending coordinate, instead of relying on the
-- array's own index failure.
cellAt :: Board -> Coord -> Cell
cellAt board c =
    if isInBounds c
        then board A.! c
        else error ("cellAt: out of bounds " ++ show c)

-- | Whether the square holds no piece. Because it goes through 'cellAt', an
-- off-board coordinate raises an error.
isCoordEmpty :: Board -> Coord -> Bool
isCoordEmpty board c = cellAt board c == Empty

-- | Check whether a candidate placement of a piece is a legal non-first
-- Blokus move for the given player on the given board.
--
-- The rules enforced are:
--
-- * All piece cells lie within the 20x20 board.
-- * All piece cells are empty on the board.
-- * No piece cell is edge-adjacent to another piece of the same player.
-- * At least one piece cell is corner-adjacent to another piece of the same
--   player.
--
-- Only on-board neighbours are inspected. This predicate therefore never
-- accepts a player's first placement, since there is no own piece to touch;
-- first moves are checked by 'isValidFirstMove' instead.
isValidMove :: Board -> Player -> [Coord] -> Bool
isValidMove board p coords =
       all isInBounds coords
    -- Bounds are checked first so that 'isCoordEmpty' never indexes off-board.
    && all (isCoordEmpty board) coords
    && not (any ownedByPlayer (findShapeAdjacent coords))
    && any ownedByPlayer (findShapeCorners coords)
  where
    -- True for on-board squares already holding one of this player's pieces.
    -- Off-board neighbours are ignored rather than treated as errors.
    ownedByPlayer :: Coord -> Bool
    ownedByPlayer c = isInBounds c && cellAt board c == Occupied p

-- | Legality test for a player's opening placement. Every square must be on
-- the board and empty, and the piece must cover that player's starting
-- corner. No adjacency rules apply to the first move.
isValidFirstMove :: Board -> Player -> [Coord] -> Bool
isValidFirstMove board p coords =
    and [ all isInBounds coords
        , all (isCoordEmpty board) coords
        , startCorner `elem` coords
        ]
  where
    startCorner :: Coord
    startCorner = case p of
        P1 -> (0,0)
        P2 -> (0,19)
        P3 -> (19,19)
        P4 -> (19,0)

-- | Every way to place a shape so that one of its squares lands exactly on
-- the target square. There is one result per shape square, in shape order.
-- Each result pairs the anchor (where the shape's local origin goes) with
-- the board squares covered. No legality checks are performed; callers
-- filter the results, e.g. via 'validMovesWith' with frontier corners from
-- 'gsCorners'.
candidatePlacements :: Shape -> Coord -> [(Coord, [Coord])]
candidatePlacements shape (tx, ty) =
    [ (anchor, L.map (shiftBy anchor) shape)
    | (sx, sy) <- shape
    , let anchor = (tx - sx, ty - sy)
    ]
  where
    shiftBy (ax, ay) (dx, dy) = (ax + dx, ay + dy)

-- | All legal moves for the player to move. While that player still holds
-- every piece, placements are judged by 'isValidFirstMove'; afterwards they
-- are judged by 'isValidMove'.
validMoves :: GameState -> [Move]
validMoves gs =
    if isFirstMove gs
        then validFirstMoves gs
        else validNonFirstMoves gs

-- | Whether the player to move has not placed anything yet, i.e. still holds
-- as many pieces as there are base shapes. A player missing from
-- 'gsRemainingPieces' counts as not being on their first move.
isFirstMove :: GameState -> Bool
isFirstMove gs =
    maybe False (\owned -> S.size owned == M.size baseShapes)
          (M.lookup (gsPlayerToMove gs) (gsRemainingPieces gs))

-- | Opening-move generation: candidate placements are filtered with
-- 'isValidFirstMove'.
validFirstMoves :: GameState -> [Move]
validFirstMoves gs = validMovesWith isValidFirstMove gs

-- | Regular move generation after the opening: candidate placements are
-- filtered with the full rule set in 'isValidMove'.
validNonFirstMoves :: GameState -> [Move]
validNonFirstMoves gs = validMovesWith isValidMove gs


-- | Enumerate legal moves for the player to move, judged by the supplied
-- legality predicate ('isValidFirstMove' or 'isValidMove').
--
-- Candidates come from aligning each square of every orientation of every
-- remaining piece with each of the player's frontier corners. A placement
-- that covers several frontier corners is reached once per corner, so
-- candidates are keyed by anchor per (piece, orientation) before the
-- legality check. Each returned 'Move' is therefore distinct, and each
-- placement is validated only once.
--
-- A player with no entry in 'gsRemainingPieces' or 'gsCorners' simply has
-- no moves.
validMovesWith
    :: (Board -> Player -> [Coord] -> Bool)
    -> GameState
    -> [Move]
validMovesWith isValid gs =
    [ Move p pid (Orientation oi) anchor
    | pid <- S.toList remPieces
    , (oi, shape) <- zip [0..] (V.toList (orientations (getPieceById pid)))
    , (anchor, coords) <- M.toList (placementsFor shape)
    , isValid board p coords
    ]
  where
    p         = gsPlayerToMove gs
    board     = gsBoard gs
    remPieces = M.findWithDefault S.empty p (gsRemainingPieces gs)
    -- Hoisted: the frontier is the same for every piece and orientation.
    corners   = S.toList (M.findWithDefault S.empty p (gsCorners gs))

    -- Distinct placements of one oriented shape touching any frontier
    -- corner. For a fixed shape the anchor determines the covered squares,
    -- so keying by anchor removes exact duplicates.
    placementsFor :: Shape -> M.Map Coord [Coord]
    placementsFor shape =
        M.fromList
            [ placement
            | corner    <- corners
            , placement <- candidatePlacements shape corner
            ]

------------------------------------------------------------
-- Game state updates, and checks
------------------------------------------------------------

-- | Recompute every player's corner frontier after a move.
--
-- * Mover: the old frontier plus the diagonal neighbours of the new piece,
--   keeping only squares that could still anchor one of the mover's pieces.
--   Such a square is on the board, empty, not covered by the new piece, and
--   not edge-adjacent to any of the mover's pieces (including the new one).
-- * Everyone else: squares now covered by the new piece are removed.
--
-- This pruning never loses a future legal move. A move's anchor square is
-- one of its own cells, and once a square is occupied, or edge-adjacent to a
-- player's piece, it can never again host that player's piece. Pruning also
-- stops 'cornersCount' from counting dead squares as mobility.
updateCorners :: GameState -> Move -> M.Map Player (S.Set Coord)
updateCorners gs m =
    M.insert p moverCorners (M.map (S.\\ placed) (gsCorners gs))
  where
    p      = mvPlayer m
    board  = gsBoard gs
    coords = getPieceCoords (getPieceById (mvPieceId m))
                            (mvOrientation m)
                            (mvAnchorCoord m)
    placed = S.fromList coords

    oldCorners   = M.findWithDefault S.empty p (gsCorners gs)
    moverCorners =
        S.filter isLive (oldCorners `S.union` findShapeCorners coords)

    -- Whether a square belongs to the mover once this move has been applied.
    ownedByMover :: Coord -> Bool
    ownedByMover c =
        c `S.member` placed || (isInBounds c && cellAt board c == Occupied p)

    -- Whether a square can still be covered by one of the mover's pieces.
    -- The bounds check comes first so that 'isCoordEmpty' stays on-board.
    isLive :: Coord -> Bool
    isLive c =
           isInBounds c
        && not (c `S.member` placed)
        && isCoordEmpty board c
        && not (any ownedByMover (findAdjacentCoords c))

-- | Write a move's piece onto the board, marking every covered square as
-- 'Occupied' by the moving player.
--
-- Calls 'error' if any covered square lies off the board. Emptiness and
-- adjacency are not re-checked here; the move is assumed to have passed
-- 'isValidFirstMove' or 'isValidMove'.
updateBoard :: Board -> Move -> Board
updateBoard b mv
    | all isInBounds coords =
        b A.// [ (c, Occupied (mvPlayer mv)) | c <- coords ]
    | otherwise =
        error ("updateBoard: out of bounds coords " ++ show coords
               ++ " for move " ++ show mv)
  where
    coords = getPieceCoords (getPieceById (mvPieceId mv))
                            (mvOrientation mv)
                            (mvAnchorCoord mv)

-- | Apply an already-validated move to the game state. This:
--
-- * fills the board squares covered by the mover's piece ('updateBoard');
-- * hands the turn to the next player ('nextPlayer');
-- * removes the used piece from the mover's remaining set;
-- * refreshes the corner frontiers ('updateCorners');
-- * clears any \"passed\" flag the mover had in 'gsPassed'.
--
-- Legality is not re-checked here.
applyValidMove :: GameState -> Move -> GameState
applyValidMove gs m =
    let mover      = mvPlayer m
        remaining' = M.adjust (S.delete (mvPieceId m)) mover (gsRemainingPieces gs)
    in GameState
        { gsBoard           = updateBoard (gsBoard gs) m
        , gsPlayerToMove    = nextPlayer mover
        , gsRemainingPieces = remaining'
        , gsCorners         = updateCorners gs m
        , gsPassed          = S.delete mover (gsPassed gs)
        }

-- | The player after this one in turn order, wrapping from the last
-- constructor (P4) back to the first (P1).
nextPlayer :: Player -> Player
nextPlayer p
    | p == maxBound = minBound
    | otherwise     = succ p

-- | The player to move passes: they are recorded in 'gsPassed', and the turn
-- moves on to the following player. The board, pieces and corners are left
-- untouched.
passTurn :: GameState -> GameState
passTurn gs = gs { gsPlayerToMove = nextPlayer passer
                 , gsPassed       = S.insert passer (gsPassed gs)
                 }
  where
    passer = gsPlayerToMove gs

-- | The game is over once every player's most recent turn was a pass, i.e.
-- all four players are in 'gsPassed'.
isTerminal :: GameState -> Bool
isTerminal gs = all (`S.member` gsPassed gs) [minBound .. maxBound]

------------------------------------------------------------
-- Scoring / evaluation
------------------------------------------------------------

-- | Heuristic scores, one component per player in the order P1, P2, P3, P4
-- (see 'scoreOfPlayer'). Higher is better for the corresponding player.
--
-- The fields are strict. A 'Scores' value cached or compared during search
-- is then fully evaluated, and cannot retain unevaluated arithmetic thunks
-- (such as those 'evaluate' builds, which reference the evaluated
-- 'GameState').
data Scores = Scores !Int !Int !Int !Int
    deriving (Generic, NFData)



-- | Compact single-line rendering, e.g. @P1=10 P2=3 P3=0 P4=-2@.
instance Show Scores where
    show (Scores s1 s2 s3 s4) =
        concat [ "P1=", show s1
               , " P2=", show s2
               , " P3=", show s3
               , " P4=", show s4
               ]

-- | Select one player's component from a 'Scores' value.
scoreOfPlayer :: Player -> Scores -> Int
scoreOfPlayer who (Scores a b c d) =
    case who of
        P1 -> a
        P2 -> b
        P3 -> c
        P4 -> d

-- | Material: how many board squares are covered by the given player's
-- pieces.
tilesPlaced :: Player -> GameState -> Int
tilesPlaced p gs = length (filter (== Occupied p) (A.elems (gsBoard gs)))

-- | Mobility proxy: the size of the player's candidate-corner frontier in
-- 'gsCorners'. The player must have an entry in that map.
cornersCount :: Player -> GameState -> Int
cornersCount p gs = S.size frontier
  where
    frontier = gsCorners gs M.! p

-- | Total area, in squares, of the pieces the player has not yet placed.
-- Large unplaced pieces get harder to fit as the board fills, so the
-- evaluation treats this as a penalty.
remainingPieces :: Player -> GameState -> Int
remainingPieces p gs =
    S.foldl' (\acc pid -> acc + length (baseShapes M.! pid)) 0
             (gsRemainingPieces gs M.! p)

-- | Heuristic evaluation of a game state. Each player's score is
--
-- > 10 * tilesPlaced + 3 * cornersCount - remainingPieces
--
-- This rewards material on the board and corner-frontier size (mobility),
-- and penalizes unplaced piece area. The weights are a tunable starting
-- point.
--
-- NOTE(review): if only legal moves are ever applied, 'tilesPlaced' plus
-- 'remainingPieces' is constant (the 89-square total of a full piece set).
-- The first and last terms are then collinear; confirm whether a different
-- penalty was intended.
evaluate :: GameState -> Scores
evaluate gs =
    let score p = 10 * tilesPlaced p gs + 3 * cornersCount p gs - remainingPieces p gs
    in Scores (score P1) (score P2) (score P3) (score P4)
