module ParQueueProcessing
    ( tspSearch
    ) where

import AStarLib
import Structures
import qualified Data.HashMap.Strict as HashMap
import qualified Data.Set as Set
import qualified Data.PQueue.Min as PQ
import Data.Maybe (fromJust)
import Data.List
import Control.Parallel.Strategies

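-- Assumes a zero heuristic to keep the code simple.
-- Parallelism comes from evaluation strategies (Control.Parallel.Strategies);
-- whether that is the best approach is still an open question.
-- Items are taken off the priority queue in batches and processed in parallel.
-- This should scale better than the other attempts because the search still
-- prioritizes the most promising Nodes.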

-- | Solve the travelling-salesman instance described by @cityGraph@ with a
-- batched, parallel best-first search ('parAstarSearch').
--
-- The tour starts (and must end) at the first city returned by 'getCities';
-- the choice is arbitrary. The best node's path is reversed to build the
-- route (presumably the node stores its path most-recent-first — confirm
-- against 'expandNode').
--
-- An empty graph yields an empty route instead of crashing on 'head'. An
-- empty route is also what the search produces when it finds no tour.
tspSearch :: CityGraph -> Route
tspSearch cityGraph = case getCities cityGraph of
  [] -> Route []
  cities@(startCity : _) ->
    let goalState = Set.fromList cities
        initialNode = Node startCity [startCity] 0 0
        frontier = PQ.singleton initialNode
        visitedStates = HashMap.empty
        -- Maximum number of nodes taken off the queue and processed in
        -- parallel per iteration.
        k = 600
        -- Maximum depth to explore from each node.
        maxDepth = 1
        bestNode = parAstarSearch frontier visitedStates cityGraph goalState startCity k maxDepth
    in Route (reverse (path bestNode))


-- | Batched best-first search over @frontier@.
--
-- Each iteration works as follows:
--
-- * If the cheapest frontier node is a goal, return it.
-- * Otherwise, take up to @k@ nodes off the queue (at least one).
-- * Drop those whose state was already reached at an equal or lower cost.
--   Goal nodes are always kept, so they stay in play.
-- * Expand the survivors in parallel with 'depthLimitedExpansion'.
-- * Merge the children back into the queue.
--
-- If the frontier empties without reaching a goal, return a sentinel node
-- with an empty path and a huge cost instead of erroring.
parAstarSearch :: PQ.MinQueue Node -> HashMap.HashMap (City, Set.Set City) Int -> CityGraph -> Set.Set City -> City -> Int -> Int -> Node
parAstarSearch frontier visitedStates cityGraph goalState startCity k maxDepth =
  case PQ.getMin frontier of
    -- Frontier exhausted: return a clearly bad solution instead of erroring.
    Nothing -> Node startCity [] 9999999999 9999999999
    Just currentNode
      -- Only the front of the queue is checked; if it is a goal we can exit.
      | isGoal currentNode goalState startCity -> currentNode
      | otherwise ->
          parAstarSearch updatedFrontier newVisitedStates cityGraph goalState startCity k maxDepth
  where
    -- Always take at least one node. With k <= 0, 'PQ.splitAt' would return
    -- an empty batch, the frontier would never shrink, and the search would
    -- loop forever.
    batchSize = max 1 k
    (nodeBatch, reducedFrontier) = PQ.splitAt batchSize frontier

    -- Keep nodes that are goals or have not been reached before at a lower
    -- or equal cost. The check is evaluated in parallel, 10 nodes per chunk.
    keepFlags = fmap (\n -> (n, isGoal n goalState startCity || not (hasVisitedBefore n visitedStates))) nodeBatch `using` parListChunk 10 rdeepseq
    filteredNodeBatch = [n | (n, True) <- keepFlags]

    -- Generate successors for every surviving node in parallel.
    expansions = fmap (\node -> depthLimitedExpansion node maxDepth cityGraph goalState) filteredNodeBatch `using` parListChunk 10 rdeepseq

    -- Record the cheapest cost seen for each state. 'insertWith min' keeps
    -- the lowest cost even when a costlier duplicate of the same state is
    -- also in the batch. Goal nodes bypass the visited check, so such
    -- duplicates can occur.
    newVisitedStates = foldl' (\acc n -> HashMap.insertWith min (city n, Set.fromList (path n)) (gCost n) acc) visitedStates filteredNodeBatch
    updatedFrontier = PQ.union reducedFrontier (PQ.unions expansions)

-- | A node's state is its current city paired with the set of cities on its
-- path. The state counts as visited when @visitedStates@ already records it
-- with a cost no greater than the node's own 'gCost'. States that are absent
-- from the map (including every state when the map is empty) are not
-- visited.
hasVisitedBefore :: Node -> HashMap.HashMap (City, Set.Set City) Int -> Bool
hasVisitedBefore n visitedStates =
  not (HashMap.null visitedStates)
    && maybe False (\recordedCost -> recordedCost <= gCost n) (HashMap.lookup stateKey visitedStates)
  where
    stateKey = (city n, Set.fromList (path n))

-- | Generate the successors of @startNode@ down to @maxDepth@ levels. Nothing
-- is marked as explored here; that bookkeeping stays with the caller.
--
-- At each level every node is expanded with 'expandNode'. Goal nodes from
-- that level are carried forward alongside the new children so they are not
-- lost. Non-goal nodes without successors drop out. The nodes of the final
-- level are returned as a min-queue. A @maxDepth@ of 0 or less returns just
-- @startNode@.
--
-- NOTE(review): the literal 0 passed to 'isGoal' looks like a hard-coded
-- start city. It only agrees with the start city chosen in 'tspSearch' when
-- that city is 0 — confirm. The same literal is passed to 'expandNode';
-- its meaning there is not visible in this module.
depthLimitedExpansion :: Node -> Int -> CityGraph -> Set.Set City -> PQ.MinQueue Node
depthLimitedExpansion startNode maxDepth cityGraph goalState = go 0 [startNode]
  where
    go :: Int -> [Node] -> PQ.MinQueue Node
    go _ [] = PQ.empty
    go depth level
      | depth >= maxDepth = PQ.fromList level
      | otherwise = go (depth + 1) (children ++ finished)
      where
        children = concatMap (\n -> expandNode n cityGraph goalState 0) level
        finished = [n | n <- level, isGoal n goalState 0]

-- This alternative depth-limited expansion is interesting, but too complicated to trust as correct.
-- Depth expansion should only generate children; it should not mark nodes as 'explored'.
-- depthLimitedExpansion :: Node -> Int -> CityGraph -> Set.Set City -> PQ.MinQueue Node
-- depthLimitedExpansion startNode maxDepth cityGraph goalState = explore [(startNode, 0)] Set.empty PQ.empty
--   where
--     explore :: [(Node, Int)] -> Set.Set (City, [City]) -> PQ.MinQueue Node -> PQ.MinQueue Node
--     explore [] _ acc = acc
--     explore ((n, d):ns) visited acc
--       | d > maxDepth = acc
--       | (city n, path n) `Set.member` visited = explore ns visited acc
--       | otherwise =
--           let newVisited = Set.insert (city n, path n) visited
--               -- successors = PQ.toList (expandNode n cityGraph Set.empty 0)
--               successors = expandNode n cityGraph goalState 0
--               unvisitedSuccessors = [(s, d + 1) | s <- successors, (city s, path s) `Set.notMember` newVisited]
--               newAcc = PQ.insert n acc
--           in explore (ns ++ unvisitedSuccessors) newVisited newAcc


-- depthLimitedExpansion :: Node -> Int -> CityGraph -> Set.Set City -> PQ.MinQueue Node
-- depthLimitedExpansion startNode maxDepth cityGraph goalState = explore [(startNode, 0)] PQ.empty
--   where
--     explore :: [(Node, Int)] -> PQ.MinQueue Node -> PQ.MinQueue Node
--     explore [] acc = acc
--     explore ((n, d):ns) acc
--       | d >= maxDepth = acc
--       | otherwise =
--           let successors = expandNode n cityGraph goalState 0
--               newAcc = foldr PQ.insert acc successors
--           in explore (ns ++ [(s, d+1) | s<-successors]) newAcc

-- Checks if the current node is the goal state
-- isGoal :: Node -> Set.Set City -> City -> Bool
-- isGoal n goalState startCity = Set.fromList (path n) == goalState && city n == startCity
