{-# OPTIONS_GHC -Wall #-}
module PDB (cubeToKey, generatePDB, savePDB, loadPDB, word8ToColor, Word8Vector(..)) where

import qualified Data.Vector.Unboxed as V
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import qualified Data.Sequence as Seq
import qualified Data.Binary as Bin
import Data.Sequence (Seq((:<|)), (|>))
import Data.Word (Word8)
import Data.List (foldl')

import RubikCube

-- | Encode a sticker color character as a small numeric code:
-- @W=0, Y=1, O=2, R=3, G=4, B=5@.
--
-- Partial: any other character aborts with @error "Unknown color"@.
colorToWord8 :: Color -> Word8
colorToWord8 'W' = 0
colorToWord8 'Y' = 1
colorToWord8 'O' = 2
colorToWord8 'R' = 3
colorToWord8 'G' = 4
colorToWord8 'B' = 5
colorToWord8 _   = error "Unknown color"

-- | Decode a numeric color code back into its sticker character:
-- @0=W, 1=Y, 2=O, 3=R, 4=G, 5=B@.
--
-- Partial: any value outside @0..5@ aborts with
-- @error "Unknown Word8 value"@.
word8ToColor :: Word8 -> Color
word8ToColor 0 = 'W'
word8ToColor 1 = 'Y'
word8ToColor 2 = 'O'
word8ToColor 3 = 'R'
word8ToColor 4 = 'G'
word8ToColor 5 = 'B'
word8ToColor _ = error "Unknown Word8 value"

-- | Flatten a cube into its lookup key: every sticker of the faces, taken in
-- the fixed order up, down, left, right, front, back, encoded with
-- 'colorToWord8' and packed into an unboxed vector.
cubeToKey :: Cube -> Word8Vector
cubeToKey cube = Word8Vector (V.fromList encoded)
  where
    -- All stickers as one flat list, face by face in the fixed order.
    stickers =
        concat
            [ concat (up cube)
            , concat (down cube)
            , concat (left cube)
            , concat (right cube)
            , concat (front cube)
            , concat (back cube)
            ]

    encoded = map colorToWord8 stickers

-- | Wrapper around an unboxed vector of 'Word8', used as the key type of the
-- pattern database. The wrapper exists so that the 'Bin.Binary' instance is
-- declared on a type owned by this module, which suppresses the warning GHC
-- would emit for an instance on the bare vector type (presumably the
-- orphan-instance warning enabled by @-Wall@).
newtype Word8Vector
    = Word8Vector (V.Vector Word8)
    deriving (Eq, Ord, Show)

-- | Serialization of 'Word8Vector' goes through a plain list of bytes: the
-- vector is written as @[Word8]@ and rebuilt from one when reading.
instance Bin.Binary Word8Vector where
    put = Bin.put . bytesOf
      where
        bytesOf (Word8Vector vec) = V.toList vec

    get = do
        bytes <- Bin.get
        return (Word8Vector (V.fromList bytes))

-- | Pattern database: a map from an encoded cube state (see 'cubeToKey') to
-- the depth at which 'generatePDB' first reached that state when searching
-- outward from the solved cube, i.e. its distance to the solved state.
type PDB = Map.Map Word8Vector Int

-- | Build a pattern database by breadth-first search from the solved cube.
--
-- Every state reachable in at most @depthLimit@ moves (drawn from 'allMoves')
-- is recorded under its 'cubeToKey' key with the depth at which it was first
-- reached. The solved state is always present at depth 0, so a depth limit
-- of zero or less yields a single-entry database.
generatePDB :: Int        -- ^ Cube size (n)
             -> Int        -- ^ Depth limit
             -> PDB
generatePDB n depthLimit = bfs initialPDB initialVisited initialFrontier
  where
    -- Start from the solved cube and its key.
    solved = initCube n
    solvedKey = cubeToKey solved

    -- The solved state is known at depth 0 and already visited.
    initialPDB = Map.singleton solvedKey 0
    initialVisited = Set.singleton solvedKey

    -- FIFO queue of states still to expand, paired with their depth.
    initialFrontier = Seq.singleton (solved, 0)

    bfs :: PDB -> Set.Set Word8Vector -> Seq.Seq (Cube, Int) -> PDB
    bfs pdb visited frontier =
      case frontier of
        (state, depth) :<| rest
          -- Only the start state can hit this branch (when depthLimit <= 0):
          -- 'discover' never enqueues a state that sits at the depth limit.
          | depth >= depthLimit -> bfs pdb visited rest
          | otherwise ->
              case foldl' (discover (depth + 1)) (pdb, visited, rest)
                          [ applyMove m state | m <- allMoves ] of
                (pdb', visited', rest') -> bfs pdb' visited' rest'

        -- Empty frontier: everything within the depth limit has been explored.
        Seq.Empty -> pdb

    -- Record one successor found at depth @d@. The key is computed once and
    -- shared by the map and the set. Threading the visited set through the
    -- fold means two moves that reach the same state record and enqueue it
    -- only once.
    discover :: Int
             -> (PDB, Set.Set Word8Vector, Seq.Seq (Cube, Int))
             -> Cube
             -> (PDB, Set.Set Word8Vector, Seq.Seq (Cube, Int))
    discover d acc@(pdb, visited, queue) s
      | Set.member key visited = acc
      | otherwise =
          let pdb'     = Map.insert key d pdb
              visited' = Set.insert key visited
              -- A state at the depth limit is never expanded, so keeping it
              -- in the queue would only cost memory.
              queue'
                | d < depthLimit = queue |> (s, d)
                | otherwise      = queue
          -- Force each component so the strict fold does not build thunks
          -- inside the accumulator tuple.
          in pdb' `seq` visited' `seq` queue' `seq` (pdb', visited', queue')
      where
        key = cubeToKey s


-- | Write a pattern database to the given file using its 'Bin.Binary'
-- encoding.
savePDB :: FilePath -> PDB -> IO ()
savePDB = Bin.encodeFile

-- | Read back a pattern database previously written by 'savePDB'.
loadPDB :: FilePath -> IO PDB
loadPDB = Bin.decodeFile