{-# LANGUAGE BangPatterns #-}
{-# OPTIONS_GHC -Wall #-}
module IDAStar (heuristic, isGoal, idaStar) where

import qualified Data.Map.Strict as Map
import Data.Maybe (fromMaybe)

import RubikCube
import PDB

-- | heuristic: Estimate how many moves are needed to solve @cube@ by looking
-- its key up in a (partial) pattern database.
--
-- States absent from the PDB get a fallback estimate of 8. That keeps the
-- heuristic admissible (never an overestimate) only if every state missing
-- from the PDB truly needs at least 8 moves — e.g. a PDB holding all states
-- up to depth 7.
-- NOTE(review): the fallback is hard-coded; confirm it matches the depth the
-- PDB was actually built to, otherwise IDA* may return non-optimal solutions.
heuristic :: Map.Map Word8Vector Int -> Cube -> Int
heuristic pdb = fromMaybe unseenFallback . lookupKey . cubeToKey
  where
    -- Lower bound assumed for any state the PDB does not contain.
    unseenFallback :: Int
    unseenFallback = 8

    lookupKey key = Map.lookup key pdb

-- | isGoal: True exactly when @cube@ equals what 'initCube' builds for the
-- same size (presumably the solved cube). The size passed is
-- @length (front cube)@ — assumed to be the cube's edge length; confirm
-- against "RubikCube".
isGoal :: Cube -> Bool
isGoal cube =
  let solvedCube = initCube (length (front cube))
   in cube == solvedCube

-- The IDA* algorithm searches for a path to the goal by repeatedly running a
-- bounded depth-first search, incrementally increasing a threshold on f = g + h.
-- f(n) = g(n) + h(n), where:
--   g(n) = cost so far (depth)
--   h(n) = heuristic estimate of the remaining cost to the goal

-- | idaStar: Search for a move sequence that solves the start cube with IDA*.
-- The first threshold is simply the heuristic estimate of the start state;
-- each later threshold comes from the previous bounded pass.
idaStar :: Cube                    -- ^ Start cube
        -> (Cube -> Int)           -- ^ Heuristic function
        -> [Move]                  -- ^ List of possible moves
        -> IO (Maybe [Move])       -- ^ Returns Maybe solution path
idaStar start h moves = iterativeDeepening start h moves (h start)


-- | iterativeDeepening: Run one depth-limited pass at @threshold@, always
-- starting from an empty visited map, then repeat with the next threshold
-- that pass reports until a solution is found. Gives up with 'Nothing' when
-- a pass reports 'maxBound', meaning no explored state exceeded the bound.
iterativeDeepening :: Cube                    -- ^ Start cube
                   -> (Cube -> Int)           -- ^ Heuristic function
                   -> [Move]                  -- ^ List of possible moves
                   -> Int                     -- ^ Current threshold
                   -> IO (Maybe [Move])       -- ^ Returns Maybe solution path
iterativeDeepening start h moves threshold =
  depthLimitedSearch start [] 0 threshold h moves Map.empty >>= \(outcome, _) ->
    case outcome of
      Left solution -> pure (Just solution)
      Right next
        | next == maxBound -> pure Nothing
        | otherwise        -> iterativeDeepening start h moves next


-- | depthLimitedSearch: One threshold-bounded depth-first pass of IDA*.
--
-- Returns @(Either [Move] Int, Map Word8Vector Int)@:
--
--   * @Left path@ — the moves from the start state to a solved cube, in order;
--   * @Right cutoff@ — the smallest @f = g + h@ that exceeded the threshold in
--     the explored subtree, or 'maxBound' if none did;
--
-- paired with the updated visited map.
--
-- The visited map records, for every state key reached in this pass, the
-- shallowest depth at which it was reached. A state reached again at an equal
-- or deeper depth is skipped, because the earlier visit had at least as much
-- remaining budget under the same threshold. This also cuts cycles back to
-- states on the current path, since those are recorded at a smaller depth.
depthLimitedSearch :: Cube                  -- ^ Current state
                   -> [Move]                -- ^ Path (moves taken so far, reversed)
                   -> Int                   -- ^ g(n): current depth
                   -> Int                   -- ^ threshold
                   -> (Cube -> Int)         -- ^ heuristic
                   -> [Move]                -- ^ all moves
                   -> Map.Map Word8Vector Int  -- ^ visited map (state -> minimal depth)
                   -> IO (Either [Move] Int, Map.Map Word8Vector Int)
depthLimitedSearch current path g threshold h moves visited =
  case Map.lookup key visited of
    -- Seen before at an equal or shallower depth: nothing new to find here.
    Just bestDepthSoFar
      | g >= bestDepthSoFar -> return (Right maxBound, visited)
    -- Never seen, or only seen deeper: record this depth and explore.
    _ -> exploreState (Map.insert key g visited)
  where
    key = cubeToKey current

    exploreState updatedVisited
      | f > threshold  = return (Right f, updatedVisited)
      | isGoal current = return (Left (reverse path), updatedVisited)
      | otherwise      = expand moves maxBound updatedVisited
      where
        f = g + h current

        -- Try each move in turn, threading the visited map through and keeping
        -- the smallest cutoff reported by any child. The accumulator is forced
        -- with a bang pattern (BangPatterns is enabled for this module): left
        -- lazy, every 'min' stays an unevaluated thunk until the root inspects
        -- the result, retaining one thunk per visited node for the whole pass.
        expand [] !bestCutoff vMap = return (Right bestCutoff, vMap)
        expand (m:ms) !bestCutoff vMap = do
          let successor = applyMove m current
          (res, vMap') <-
            depthLimitedSearch successor (m:path) (g + 1) threshold h moves vMap
          case res of
            Left solution -> return (Left solution, vMap')
            Right cutoff  -> expand ms (min bestCutoff cutoff) vMap'