{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# HLINT ignore "Eta reduce" #-}
module AStarLib
    ( createNode,
        expandNode,
        isGoal,
        getCities,
        createReturnNode,
        heuristic,
    ) where

import Structures
import qualified Data.Set as Set
import qualified Data.Map as Map
import qualified Data.List as List


-- | Goal test for the tour search.
--
-- A node is a goal when all three conditions hold:
--
-- * the set of cities on its path is exactly @goalState@;
-- * it is located at @startCity@;
-- * its path has more than one entry (presumably so the initial node
--   sitting at the start city is not mistaken for a finished tour).
isGoal :: Node -> Set.Set City -> City -> Bool
isGoal node goalState startCity =
    coversGoal && backAtStart && hasSeveralStops
  where
    visited = path node
    coversGoal = Set.fromList visited == goalState
    backAtStart = city node == startCity
    -- Pattern match instead of 'length' so we only inspect two cells.
    hasSeveralStops = case visited of
        (_ : _ : _) -> True
        _ -> False

-- | Every city that appears as a key of the graph's adjacency map, in
-- ascending key order (as returned by 'Map.keys').
getCities :: CityGraph -> [City]
getCities cityGraph = case cityGraph of
    CityGraph adjacency -> Map.keys adjacency

-- | Generate the successors of a search node.
--
-- The successors are:
--
-- * one node (built by 'createNode') per outgoing edge of the current city
--   whose destination is not already on the node's path; and
-- * once the cities on the path are exactly @goalState@ and the node is not
--   already at @startCity@, the node (if any) returned by
--   'createReturnNode' that closes the tour back to @startCity@.
expandNode :: Node -> CityGraph -> Set.Set City -> City -> [Node]
expandNode node cityGraph goalState startCity = successors
  where
    (CityGraph graph) = cityGraph
    currentCity = city node
    nodePath = path node
    visitedCities = Set.fromList nodePath
    possibleEdges = Map.findWithDefault [] currentCity graph
    unvisitedEdges = filter (\(Edge dest _) -> dest `Set.notMember` visitedCities) possibleEdges
    successors = [createNode node edge goalState cityGraph | edge <- unvisitedEdges] ++ returnToStartNode
    -- Compare the sets themselves, not just their sizes, so this agrees with
    -- 'isGoal': equal sizes with different members would yield a return node
    -- that can never satisfy the goal test.
    returnToStartNode
      | visitedCities == goalState && currentCity /= startCity =
          createReturnNode node cityGraph startCity
      | otherwise = []


-- | Build the successor node reached by following an edge out of @node@.
--
-- The edge's destination becomes the new city and is prepended to the path.
-- The edge distance is added to 'gCost'. 'fCost' is that new 'gCost' plus
-- the 'heuristic' estimate for the extended path.
createNode :: Node -> Edge -> Set.Set City -> CityGraph -> Node
createNode node (Edge dest dist) goalState cityGraph =
    let extendedPath = dest : path node
        costSoFar = gCost node + dist
    in Node
        { city = dest
        , path = extendedPath
        , gCost = costSoFar
        , fCost = costSoFar + heuristic dest extendedPath goalState cityGraph
        }


-- | Build the node that closes the tour by travelling from @node@'s city
-- back to @startCity@.
--
-- Uses the first edge in the current city's adjacency list whose destination
-- is @startCity@. Returns a singleton list holding that node, or @[]@ when
-- no such edge exists. No heuristic term is added, so the returned node's
-- 'fCost' equals its 'gCost'.
createReturnNode :: Node -> CityGraph -> City -> [Node]
createReturnNode node (CityGraph graph) startCity =
    maybe [] closeTour (List.find leadsHome outgoing)
  where
    outgoing = Map.findWithDefault [] (city node) graph
    leadsHome (Edge dest _) = dest == startCity
    closeTour (Edge _ dist) =
        let total = gCost node + dist
        in [ Node
                { city = startCity
                , path = startCity : path node
                , gCost = total
                , fCost = total
                }
           ]

-- | Estimate of the remaining tour cost.
--
-- Returns the 'mstCost' over the current city together with every goal
-- city that does not yet appear on @currentPath@.
--
-- Returning a constant @0@ here instead disables the heuristic.
heuristic :: City -> [City] -> Set.Set City -> CityGraph -> Distance
heuristic currentCity currentPath goalState cityGraph =
    mstCost (Set.insert currentCity stillToVisit) cityGraph
  where
    stillToVisit = goalState `Set.difference` Set.fromList currentPath

-- | Weight of the minimum spanning tree built by 'kruskal' over @cities@.
--
-- Only edges whose endpoints both lie in @cities@ are considered (see
-- 'getEdgesBetween'). If those edges do not connect every city, the result
-- is the weight of the resulting spanning forest.
mstCost :: Set.Set City -> CityGraph -> Distance
mstCost cities (CityGraph graph) = fst (kruskal byWeight singletons 0)
  where
    candidates = getEdgesBetween cities graph
    -- Stable sort by ascending distance, as Kruskal's algorithm requires.
    byWeight = List.sortOn (\(_, _, weight) -> weight) candidates
    -- Union-find start state: every city is its own parent.
    singletons = Map.fromSet id cities


-- | An edge as consumed by 'kruskal': @(endpoint, endpoint, distance)@.
type Edge3 = (City, City, Distance)

-- | Collect the edges of @graph@ whose endpoints both belong to @cities@.
--
-- An edge is read only from the adjacency list of the endpoint that compares
-- lower-or-equal (@src <= dst@). An edge listed in both directions is
-- therefore reported once. An edge stored only under its larger endpoint is
-- skipped.
getEdgesBetween :: Set.Set City -> Map.Map City [Edge] -> [Edge3]
getEdgesBetween cities graph =
  [ (src, dst, dist)
  | src <- Set.toList cities
  , Edge dst dist <- Map.findWithDefault [] src graph
  , dst `Set.member` cities
  , src <= dst -- avoid duplicates
  ]

-- | Kruskal's algorithm over an edge list.
--
-- Expects the edges in ascending distance order (as 'mstCost' supplies
-- them). Edges are walked in order. An edge is accepted when its endpoints
-- lie in different union-find components; its distance is then added to the
-- running total and the components are merged.
--
-- Returns the accumulated cost together with the final parent map.
kruskal :: [Edge3] -> Map.Map City City -> Distance -> (Distance, Map.Map City City)
kruskal [] parentMap accCost = (accCost, parentMap)
kruskal ((u, v, d) : es) parentMap accCost
  | rootU == rootV = kruskal es parentMap accCost
  | otherwise =
      let newAccCost = accCost + d
      -- Force the running total so repeated additions do not pile up as an
      -- unevaluated thunk chain across the recursion.
      in newAccCost `seq` kruskal es (union rootU rootV parentMap) newAccCost
  where
    rootU = find u parentMap
    rootV = find v parentMap

-- | Find step of the union-find structure.
--
-- Follows parent links from @city'@ until it reaches a city that is its own
-- parent, or that has no entry in the map. That city is returned as the
-- component's representative. The map is only read; no path compression is
-- performed.
find :: City -> Map.Map City City -> City
find city' parentMap = climb city'
  where
    climb current = case Map.lookup current parentMap of
        Nothing -> current
        Just parent
            | current == parent -> current
            | otherwise -> climb parent

-- | Union step of the union-find structure.
--
-- Records @root@ as the parent of @child@. 'kruskal' passes two
-- representatives, so this merges their components.
union :: City -> City -> Map.Map City City -> Map.Map City City
union child root parentMap = Map.insert child root parentMap
