{-# LANGUAGE BangPatterns #-}
{-# OPTIONS_GHC -Wall #-}
module IDAStarPara where

import Control.Concurrent (forkIO, newEmptyMVar, putMVar, takeMVar, MVar)
import Control.Monad (forM, forM_, when)
import System.Environment (getArgs)
import System.Exit (exitFailure)
import System.IO (hPutStrLn, stderr)
import Text.Read (readMaybe)

import RubikCube
import PDB
import IDAStar (heuristic, isGoal)
import Utils (parseLineOfMoves, logMsg, LogLevel(..))


-- | Solve a cube with IDA*.
--
-- The first f-cost threshold is the heuristic estimate of the start state;
-- 'iterativeDeepening' then raises it pass by pass. On success the result
-- is the list of moves to apply to @start@, in order.
idaStar :: Cube                    -- ^ Start cube
        -> (Cube -> Int)           -- ^ Heuristic function
        -> [Move]                  -- ^ List of possible moves
        -> IO (Maybe [Move])       -- ^ Returns Maybe solution path
idaStar start h moves = iterativeDeepening start h moves (h start)

-- | Run depth-limited passes from @start@ with an increasing f-cost bound.
--
-- After a pass without a solution, the next bound is the smallest f-cost
-- that pass pruned. The loop gives up with 'Nothing' when a pass reports
-- 'maxBound' as that value, i.e. no pruned node produced a finite bound.
iterativeDeepening :: Cube -> (Cube -> Int) -> [Move] -> Int -> IO (Maybe [Move])
iterativeDeepening start h moves = go
  where
    -- One pass at the given bound; a solution ends the loop immediately.
    go bound = depthLimitedSearch start [] 0 bound h moves >>= either (return . Just) next

    -- Decide whether another pass is worthwhile.
    next newBound
      | newBound == maxBound = return Nothing
      | otherwise            = go newBound

-- | depthLimitedSearch: Search up to the given threshold.
-- Returns either Left solution or Right minimum f-cost that exceeded threshold.
--
-- At the root (@g == 0@) every successor subtree is searched in its own
-- thread and the results are combined by 'gatherResults'; below the root
-- the search is sequential. When no node exceeded the threshold, the
-- returned cutoff is 'maxBound'.
depthLimitedSearch :: Cube            -- ^ Current node
                   -> [Move]          -- ^ Moves taken so far, most recent first
                   -> Int             -- ^ g: number of moves taken to reach this node
                   -> Int             -- ^ Current f-cost threshold
                   -> (Cube -> Int)   -- ^ Heuristic function
                   -> [Move]          -- ^ Moves tried at every node
                   -> IO (Either [Move] Int)
depthLimitedSearch current path g threshold h moves
  | f > threshold  = return (Right f)
  | isGoal current = return (Left (reverse path))
  | g == 0         = searchRootInParallel
  | otherwise      = expand moves maxBound
  where
    f = g + h current

    -- Fork one worker per root successor; each worker writes its result to
    -- its own MVar exactly once, so the ThreadId is not needed afterwards.
    searchRootInParallel = do
      resultVars <- forM moves $ \m -> do
        resVar <- newEmptyMVar
        _ <- forkIO $ do
          r <- depthLimitedSearch (applyMove m current) (m:path) (g+1) threshold h moves
          putMVar resVar r
        return resVar
      gatherResults resultVars maxBound Nothing

    -- Sequential expansion: stop at the first solution, otherwise return the
    -- smallest cutoff. The accumulator is strict so the running minimum is
    -- evaluated at every step (and inside the worker thread). A lazy one
    -- would build a chain of 'min' thunks spanning the whole explored
    -- subtree, which is only forced when the caller inspects the cutoff.
    expand [] !bestCutoff = return (Right bestCutoff)
    expand (m:ms) !bestCutoff = do
      res <- depthLimitedSearch (applyMove m current) (m:path) (g+1) threshold h moves
      case res of
        Left solution -> return (Left solution)
        Right cutoff  -> expand ms (min bestCutoff cutoff)


-- | Combine the results of the root-level worker threads.
--
-- Every MVar is read, in list order, so this waits until all workers have
-- finished (they are not cancelled once a solution is known). If any worker
-- found a solution, the first one in list order is returned; this matches
-- the sequential search, which returns the first solution in move order.
-- Otherwise the smallest cutoff (starting from the given one) is returned.
gatherResults :: [MVar (Either [Move] Int)]
              -> Int           -- ^ Smallest cutoff seen so far
              -> Maybe [Move]  -- ^ First solution seen so far, if any
              -> IO (Either [Move] Int)
gatherResults [] !bestCutoff maybeSol = return (maybe (Right bestCutoff) Left maybeSol)
gatherResults (resVar:rest) !bestCutoff maybeSol = do
  result <- takeMVar resVar
  case (maybeSol, result) of
    -- A solution was already found earlier in the list: keep it and only
    -- drain the remaining workers.
    (Just _, _)              -> gatherResults rest bestCutoff maybeSol
    (Nothing, Left solution) -> gatherResults rest bestCutoff (Just solution)
    (Nothing, Right cutoff)  -> gatherResults rest (min bestCutoff cutoff) Nothing


-- | Command-line entry point.
--
-- Usage: @./IDAStarPara <pdb_file> <n> <scramble_file> [log_level]@
--
-- Loads the pattern database, then treats each line of @scramble_file@ as a
-- scramble sequence. Each scramble is applied to a solved @n@-cube and the
-- parallel IDA* search is run. The returned solution is replayed on the
-- scrambled cube to verify it. Argument errors are reported on stderr and
-- end the program with a failure exit status.
main :: IO ()
main = do
    args <- getArgs
    (pdbFile, nStr, scrambleFile, logLevelArg) <- case args of
      [p, k, s]    -> return (p, k, s, Nothing)
      [p, k, s, l] -> return (p, k, s, Just l)
      _            -> usageAndExit

    -- Parse n. Non-numeric input is rejected like any other invalid size
    -- instead of crashing inside 'read'.
    n <- case readMaybe nStr :: Maybe Int of
      Just k | k == 2 || k == 3 -> return k
      _                         -> failWith "Error: n must be 2 or 3."

    -- Parse optional log level (default INFO). An unknown value is reported
    -- up front rather than crashing lazily at the first log call.
    logLevel <- case logLevelArg of
      Nothing  -> return INFO
      Just str -> case readMaybe str :: Maybe LogLevel of
        Just lvl -> return lvl
        Nothing  -> do
          hPutStrLn stderr $ "Error: invalid log_level " ++ show str ++ "."
          usageAndExit

    -- Load the PDB
    pdb <- loadPDB pdbFile

    -- Initialize solved cube
    let solvedCube = initCube n

    -- Read scramble file; each line is one scramble sequence
    contents <- readFile scrambleFile
    let linesOfMoves = lines contents

    forM_ (zip ([1..] :: [Int]) linesOfMoves) $ \(idx, line) -> do
      logMsg logLevel INFO $ "\nSolving cube #" ++ show idx

      let scrambleMoves = parseLineOfMoves line
      let scrambledCube = applyMoves scrambleMoves solvedCube

      logMsg logLevel DEBUG "Scrambled Cube:"
      when (logLevel <= DEBUG) $ printCube scrambledCube

      logMsg logLevel INFO "Starting IDA* search in parallel..."
      maybeSolution <- idaStar scrambledCube (heuristic pdb) allMoves
      case maybeSolution of
        Just solMoves -> do
          -- Apply the solution moves to scrambledCube and verify it's solved:
          let verifiedCube = applyMoves solMoves scrambledCube
          logMsg logLevel DEBUG "After applying solution moves, cube state:"
          when (logLevel <= DEBUG) $ printCube verifiedCube
          if isGoal verifiedCube
            then logMsg logLevel INFO $ "Solution for cube #" ++ show idx ++ ": " ++ unwords (map show solMoves)
            -- A solution WAS returned here; it just failed the replay check.
            else logMsg logLevel ERROR $ "Error: solution for cube #" ++ show idx ++ " failed verification."

        Nothing -> logMsg logLevel ERROR $ "Error: no solution found for cube #" ++ show idx ++ "."
  where
    -- Print the usage text to stderr and exit with a failure status.
    usageAndExit :: IO a
    usageAndExit = do
      hPutStrLn stderr "Usage: ./IDAStarPara <pdb_file> <n> <scramble_file> [log_level]"
      hPutStrLn stderr "log_level can be DEBUG, INFO, ERROR"
      exitFailure

    -- Print an error message to stderr and exit with a failure status.
    failWith :: String -> IO a
    failWith msg = hPutStrLn stderr msg >> exitFailure