module ParallelDFS (parallelDFS) where

import Data.Set (Set, empty, insert, member, toList, union, unions, fromList)
import Data.List ((\\))
import DFSCommon (getNeighbors)
import Control.Monad.Par

type Node = (Int, Int)

-- | Backtrack along an ancestor path to the nearest node that still has a
-- neighbour outside @visited@.
--
-- The path is nearest-first: its head is the most recently reached node.
-- Returns that node's first unvisited neighbour (in 'getNeighbors' order)
-- consed onto the path starting at that node, or @[]@ when no node on the
-- path has an unvisited neighbour.
nextNode :: [[Char]] -> Set Node -> [Node] -> [Node]
nextNode _ _ [] = []
nextNode maze visited path@(parent : ancestors) =
  -- Filter by Set membership: O(k log n) instead of the O(k * n) of
  -- @neighbours \\ toList visited@. It also drops every visited occurrence
  -- rather than only the first one, as '\\' would.
  case filter (\n -> not (n `member` visited)) (getNeighbors maze parent) of
    [] -> nextNode maze visited ancestors
    next : _ -> next : path

-- | Continue one depth-first search for at most @depth@ steps.
--
-- The search state is a nearest-first ancestor path. Its head is the node to
-- visit next, and its tail is the chain of nodes it was reached from.
-- Returns @(found, visited, path)@:
--
--   * @found@ is 'True' when the head of the path is @goal@. @path@ then runs
--     from the goal back through its ancestors.
--   * Otherwise @path@ is the state to resume from. It is @[]@ when this
--     search has no node with an unvisited neighbour left.
--
-- The returned @visited@ always includes the nodes this call visited, so the
-- caller can merge it into the shared set.
step :: [[Char]] -> Node -> Set Node -> Int -> [Node] -> (Bool, Set Node, [Node])
step _ _ visited _ [] =
  -- The region reachable from this path is exhausted (e.g. the thread has
  -- searched the whole maze). Keep what was visited: returning 'empty' would
  -- discard those nodes and let other threads re-explore them.
  (False, visited, [])
step maze goal visited depth path@(start : ancestors)
  | start == goal = (True, visited, path)
  -- '<=' rather than '==' so a negative budget stops instead of running
  -- without bound.
  | depth <= 0 = (False, visited, path)
  -- Already visited (possibly by another thread in an earlier round): skip it
  -- and backtrack to the nearest ancestor with an unvisited neighbour.
  | start `member` visited =
      step maze goal visited (depth - 1) (nextNode maze visited ancestors)
  | otherwise =
      let newVisited = insert start visited
       in step maze goal newVisited (depth - 1) (nextNode maze newVisited path)

-- | Choose the path that thread slot @thread@ should search next.
--
-- A slot whose context is non-empty keeps it. An idle (or out-of-range) slot
-- takes the @thread@-th candidate, or @[]@ if there are not that many.
--
-- Candidates come from walking every context's ancestor chain. Each node that
-- does not head a context (is not in @assigned@) and has a neighbour that is
-- neither visited nor assigned yields that first such neighbour consed onto
-- the chain from that node. Candidates with the same head are offered only
-- once.
updateContext :: [[Char]] -> [[Node]] -> Set Node -> Set Node -> Int -> [Node]
updateContext maze contexts visited assigned thread =
  case drop thread contexts of
    own@(_ : _) : _ -> own
    _ ->
      case drop thread allAvailable of
        candidate : _ -> candidate
        [] -> []
  where
    excludedNodes = union visited assigned

    -- Branch points along one context, nearest first.
    available [] = []
    available (node : rest)
      -- Skip context heads so each thread searches a different path.
      | node `member` assigned = available rest
      | otherwise =
          case filter (\n -> not (n `member` excludedNodes)) (getNeighbors maze node) of
            [] -> available rest
            neighbor : _ -> (neighbor : node : rest) : available rest

    -- Contexts share ancestor suffixes, so the same neighbour can be offered
    -- by several contexts. Keep only the first candidate for each head so two
    -- idle threads are never sent to the same node.
    allAvailable = distinctHeads empty (concatMap available contexts)

    distinctHeads _ [] = []
    distinctHeads seen (candidate@(h : _) : more)
      | h `member` seen = distinctHeads seen more
      | otherwise = candidate : distinctHeads (insert h seen) more
    distinctHeads seen ([] : more) = distinctHeads seen more

-- | The parallel part of the algorithm: run 'step' on every context at once.
--
-- Every search starts from the same @visited@ set and the same @maxDepth@
-- budget. The result list has one entry per context, in the same order as
-- the contexts.
parallelStep :: [[Char]] -> [[Node]] -> Node -> Set Node -> Int -> [(Bool, Set Node, [Node])]
parallelStep maze contexts goal visited maxDepth = runPar (parMap searchFrom contexts)
  where
    searchFrom = step maze goal visited maxDepth

-- | Run rounds of bounded parallel search until the goal is found or nothing
-- is left to explore.
--
-- Each round does the following:
--
--   1. Advance every context by at most @maxDepth@ steps with 'parallelStep'.
--   2. Merge the visited sets.
--   3. Refill idle thread slots with 'parallelUpdate'.
--
-- Returns the first successful path (goal first, then its ancestors), or
-- @[]@ when no context reached the goal and no work remains.
parallelDFS :: [[Char]] -> [[Node]] -> Node -> Set Node -> Int -> Int -> IO [Node]
parallelDFS maze contexts goal visited threads maxDepth =
  case [p | (True, _, p) <- results] of
    found : _ -> return found
    []
      | null updatedContexts -> return [] -- No more nodes to explore
      | otherwise -> parallelDFS maze updatedContexts goal newVisited threads maxDepth
  where
    -- With a budget of zero no context can move past its head, so every
    -- round would repeat the previous one forever. Always allow at least one
    -- step per round.
    depth = max 1 maxDepth
    results = parallelStep maze contexts goal visited depth
    newVisited = unions (visited : [v | (_, v, _) <- results])
    newContexts = filter (not . null) [c | (_, _, c) <- results]
    -- Heads of the live contexts; 'updateContext' avoids handing these out.
    assigned = fromList [h | h : _ <- newContexts]
    updatedContexts =
      filter (not . null) (parallelUpdate maze newContexts newVisited assigned threads)

-- | Compute the next context for every thread slot in parallel.
--
-- The live contexts are padded at the front with empty contexts up to
-- @threads@ slots. Idle slots therefore get the lowest indices and take the
-- first candidates 'updateContext' offers.
--
-- If there are already more contexts than @threads@, every context keeps its
-- own slot. The surplus used to be silently dropped, which lost its part of
-- the search frontier.
parallelUpdate :: [[Char]] -> [[Node]] -> Set Node -> Set Node -> Int -> [[Node]]
parallelUpdate maze contexts visited assigned threads =
  runPar $ parMap (updateContext maze slots visited assigned) [0 .. length slots - 1]
  where
    -- 'replicate' yields [] for a non-positive count, so no padding is added
    -- when contexts already fill (or exceed) the thread count.
    slots = replicate (threads - length contexts) [] ++ contexts
