import Data.Maybe (catMaybes, listToMaybe)
import System.Environment (getArgs, getProgName)
import System.Exit (die)
import Text.Read (readMaybe)

import qualified Data.Map as Map
import qualified Data.Set as S

import Control.Parallel.Strategies (using, parList, rdeepseq)


-- | Adjacency-list representation: each node ID maps to the list of IDs
-- written after its colon in the input file (see 'loadGraph').
type Graph = Map.Map Int [Int]
-- | Node-to-color mapping. 0 means "unassigned"; 1 to k are the actual colors.
-- NOTE(review): the solvers seed this with 'initColoring' over IDs 0..n-1 and
-- walk v from 0 up to the graph size, so node IDs are assumed to be exactly
-- 0..n-1 — confirm input files follow that convention.
type GraphColoring = Map.Map Int Int

-- | Entry point: expects exactly three arguments (input file, number of
-- colors k, algorithm name) and hands them to 'runSolver'; any other
-- argument count prints a usage line and exits with failure.
main :: IO ()
main = getArgs >>= dispatch
  where
    dispatch [fileName, kStr, algo] = runSolver fileName kStr algo
    dispatch _ = do
      progName <- getProgName
      die ("Usage: " ++ progName ++ " <filename>" ++ " k" ++ " algo")

-- general utils

-- | Parse the graph input: one node per line, written as
-- @<node>: <neighbor> <neighbor> ...@ (a line without a colon yields a node
-- with no neighbors).
--
-- Lines that are empty or contain only whitespace (e.g. a blank line at the
-- end of the file) are skipped; previously they were passed to 'read' and
-- aborted the program with "no parse". Node and neighbor IDs are still parsed
-- with 'read', so a genuinely malformed number is still fatal. If a node
-- appears on several lines, the last line wins ('Map.fromList').
loadGraph :: String -> Graph
loadGraph =
    Map.fromList . map processNode . filter (not . isBlank) . lines
  where
    -- A line is blank when it has no whitespace-separated tokens at all.
    isBlank :: String -> Bool
    isBlank = null . words

    processNode :: String -> (Int, [Int])
    processNode nodeLine =
      let (nodeIDStr, neighborsStr) = break (== ':') nodeLine
          -- drop the ':' separator itself (no-op when the line has none)
          children = map read (words (drop 1 neighborsStr))
          node = read nodeIDStr
      in (node, children)

-- | True when @coloring@ is a proper (possibly partial) coloring of @graph@:
-- every assigned value lies in [0, k], and no colored node (non-zero value)
-- shares its color with any node in its adjacency list. Uncolored nodes (0)
-- never conflict, and neighbors missing from @coloring@ count as color 0.
checkColoring :: Graph -> GraphColoring -> Int -> Bool
checkColoring graph coloring k = all nodeIsValid (Map.toList coloring)
  where
    nodeIsValid :: (Int, Int) -> Bool
    nodeIsValid (node, color) =
      colorInRange color && (color == 0 || all (clashFree color) (neighborsOf node))

    colorInRange color = 0 <= color && color <= k

    neighborsOf node = Map.findWithDefault [] node graph

    -- a neighbor absent from the coloring is treated as unassigned (0)
    clashFree color neighbor = Map.findWithDefault 0 neighbor coloring /= color

-- | Wrap @coloring@ in 'Just' when 'checkColoring' accepts it for @k@
-- colors; otherwise 'Nothing'.
returnMaybeColoring :: Graph -> GraphColoring -> Int -> Maybe (GraphColoring)
returnMaybeColoring graph coloring k =
    if checkColoring graph coloring k then Just coloring else Nothing

-- | The fully-unassigned coloring for nodes @0 .. n-1@: every node maps to 0.
-- Yields an empty map when @n <= 0@.
initColoring :: Int -> GraphColoring
initColoring n = Map.fromList (zip [0 .. n - 1] (repeat 0))

-- BRUTE FORCE

-- | Exhaustive search: give node @v@ color @c@, then try every color 1..k for
-- node @v + 1@, and so on, validating only complete colorings.
--
-- @col@ is the coloring built so far (nodes below @v@ assigned), @v@ the next
-- node to color and @c@ the color to give it. Returns the first valid
-- complete coloring found, or 'Nothing'.
--
-- When @v@ is the last node, the completed coloring is validated exactly
-- once. Recursing one more level here (as before) produced k calls that each
-- re-validated the identical coloring.
runBruteForce :: Graph -> Int -> GraphColoring -> Int -> Int -> Maybe (GraphColoring)
runBruteForce graph k col v c
    | v == n = returnMaybeColoring graph col k
    | v + 1 == n = returnMaybeColoring graph nextColoring k
    | otherwise = listToMaybe (catMaybes [ runBruteForce graph k nextColoring (v + 1) c_ | c_ <- [1..k] ])
  where
    n = Map.size graph
    nextColoring = Map.insert v c col

-- | Try each color 1..k for node 0 in turn and return the first valid
-- complete coloring 'runBruteForce' finds, or 'Nothing' if none exists.
bruteForceAlgorithm :: Graph -> Int -> Int -> Maybe (GraphColoring)
bruteForceAlgorithm graph k n = firstHit [1 .. k]
  where
    start = initColoring n
    firstHit [] = Nothing
    firstHit (color : rest) =
      case runBruteForce graph k start 0 color of
        Nothing -> firstHit rest
        found -> found

-- Pruning

-- | Backtracking search that colors nodes in increasing ID order starting at
-- @v@, abandoning a branch as soon as the partial coloring is invalid.
--
-- @col@ is the coloring built so far, @v@ the next node to color and @c@ the
-- color to give it. Returns the first valid complete coloring, or 'Nothing'.
--
-- The partial coloring is validated once, immediately after @v@ is assigned.
-- Previously each of the k child calls re-ran 'checkColoring' on the same
-- coloring, so every search node was checked k times. When @v@ is the last
-- node, the coloring that just passed the check is the answer.
runPruning :: Graph -> Int -> GraphColoring -> Int -> Int -> Maybe (GraphColoring)
runPruning graph k col v c
    | v == n = returnMaybeColoring graph col k
    | not (checkColoring graph nextColoring k) = Nothing
    | v + 1 == n = Just nextColoring
    | otherwise = listToMaybe (catMaybes [ runPruning graph k nextColoring (v + 1) c_ | c_ <- [1..k] ])
  where
    n = Map.size graph
    nextColoring = Map.insert v c col

-- | Run 'runPruning' once per candidate color for node 0 and return the first
-- successful coloring. Colors after the first success are never tried.
pruningAlgorithm :: Graph -> Int -> Int -> Maybe (GraphColoring)
pruningAlgorithm graph k n = foldr tryColor Nothing [1 .. k]
  where
    start = initColoring n
    tryColor color fallback = maybe fallback Just (runPruning graph k start 0 color)


-- DSATUR

-- | IDs of nodes whose color is still 0 (unassigned), in ascending order.
unassignedNodes :: GraphColoring -> [Int]
unassignedNodes = Map.keys . Map.filter (== 0)

-- | Saturation degree of @node@: the number of distinct non-zero colors among
-- its neighbors. Neighbors missing from @coloring@ count as uncolored.
getSatur :: Graph -> GraphColoring -> Int -> Int
getSatur graph coloring node = S.size distinctColors
  where
    distinctColors = S.fromList
      [ color
      | neighbour <- Map.findWithDefault [] node graph
      , let color = Map.findWithDefault 0 neighbour coloring
      , color /= 0
      ]

-- | Saturation degree of every still-uncolored node (value 0), keyed by node.
saturs :: Graph -> GraphColoring -> Map.Map Int Int
saturs graph coloring =
    Map.mapWithKey (\node _ -> getSatur graph coloring node) (Map.filter (== 0) coloring)

-- | Next node for DSATUR: the uncolored node with the highest saturation
-- degree, with ties going to the largest node ID. Returns -1 when every node
-- is already colored.
maxDSATUR :: Graph -> GraphColoring -> Int
maxDSATUR graph coloring = bestNode
  where
    (bestNode, _) = Map.foldlWithKey keepBest (-1, -1) (saturs graph coloring)
    -- Keys arrive in ascending order; '>=' lets a later (larger) key win a tie.
    keepBest (node, sat) candidate candidateSat
      | candidateSat >= sat = (candidate, candidateSat)
      | otherwise = (node, sat)

-- | Backtracking search that always colors the uncolored node with the
-- highest saturation next (see 'maxDSATUR').
--
-- @col@ is the coloring built so far, @v@ the node to color next (-1 means
-- every node is already colored) and @c@ the color to give it. Returns the
-- first valid complete coloring, or 'Nothing'.
--
-- The partial coloring is validated once, right after @v@ is assigned,
-- instead of at the top of each of the k child calls. The next node is
-- computed once per call rather than once per candidate color.
runDSATUR :: Graph -> Int -> GraphColoring -> Int -> Int -> Maybe (GraphColoring)
runDSATUR graph k col v c
    | v == -1 = returnMaybeColoring graph col k
    | not (checkColoring graph nextColoring k) = Nothing
    | nextNode == -1 = Just nextColoring
    | otherwise = listToMaybe (catMaybes [ runDSATUR graph k nextColoring nextNode c_ | c_ <- [1..k] ])
  where
    nextColoring = Map.insert v c col
    nextNode = maxDSATUR graph nextColoring

-- | Exact k-coloring via DSATUR-ordered backtracking. Tries each color 1..k
-- for the start node and returns the first success.
--
-- Node 0 is a valid start because all nodes begin with saturation 0. An
-- empty graph has no node 0, so the search starts from the "nothing left to
-- color" sentinel -1 and simply validates the empty coloring. Starting at 0
-- there would insert a phantom node 0 into the result.
dSATURAlgorithm :: Graph -> Int -> Int -> Maybe (GraphColoring)
dSATURAlgorithm graph k n = listToMaybe (catMaybes [ runDSATUR graph k (initColoring n) startNode c | c <- [1..k] ])
  where
    startNode = if n > 0 then 0 else -1

-- Parallel DSATUR

-- | Like 'saturs', but each (node, saturation) pair for the uncolored nodes is
-- sparked and fully evaluated in parallel before the map is built.
parallelSaturs :: Graph -> GraphColoring -> Map.Map Int Int
parallelSaturs graph coloring = Map.fromList scored
  where
    scored = (map (\node -> (node, getSatur graph coloring node)) (unassignedNodes coloring))
               `using` parList rdeepseq


-- | 'maxDSATUR' over saturations computed by 'parallelSaturs'. Picks the
-- highest saturation, ties going to the largest node ID, and returns -1 when
-- nothing is left to color.
parallelMaxDSATUR :: Graph -> GraphColoring -> Int
parallelMaxDSATUR graph coloring = winner
  where
    (winner, _) = Map.foldlWithKey prefer (-1, -1) (parallelSaturs graph coloring)
    -- ascending key order plus '>=' means the largest ID wins a tie
    prefer current@(_, bestSat) node sat
      | sat >= bestSat = (node, sat)
      | otherwise = current

-- | Parallel variant of 'runDSATUR'. @d@ is the recursion depth (1 at the
-- top-level call). While @d < parallelDepth@, the k child searches are
-- sparked with @parList rdeepseq@, so sibling branches may be evaluated
-- speculatively even after one of them succeeds. Below that depth, the
-- search is sequential and lazy.
--
-- As in 'runDSATUR', the partial coloring is validated once after @v@ is
-- assigned, and the next node is computed once per call rather than once per
-- candidate color.
runParallelDSATUR :: Graph -> Int -> GraphColoring -> Int -> Int -> Int -> Maybe (GraphColoring)
runParallelDSATUR graph k col v c d
    | v == -1 = returnMaybeColoring graph col k
    | not (checkColoring graph nextColoring k) = Nothing
    | nextNode == -1 = Just nextColoring
    | d < parallelDepth = firstJust (branches `using` parList rdeepseq)
    | otherwise = firstJust branches
  where
    -- depth below which child searches are run in parallel
    parallelDepth :: Int
    parallelDepth = 4
    nextColoring = Map.insert v c col
    nextNode = parallelMaxDSATUR graph nextColoring
    branches = [ runParallelDSATUR graph k nextColoring nextNode c_ (d + 1) | c_ <- [1..k] ]
    firstJust = listToMaybe . catMaybes

-- | Parallel DSATUR entry point. One search per start color is sparked and
-- fully evaluated in parallel, and the first success in color order is
-- returned.
--
-- Node 0 is a valid start because every saturation is 0 initially. For an
-- empty graph there is no node 0, so the search starts from the sentinel -1
-- and returns the (valid) empty coloring instead of one containing a phantom
-- node 0.
parallelDSATURAlgorithm :: Graph -> Int -> Int -> Maybe (GraphColoring)
parallelDSATURAlgorithm graph k n = listToMaybe (catMaybes ([ runParallelDSATUR graph k (initColoring n) startNode c 1 | c <- [1..k] ] `using` parList rdeepseq))
  where
    startNode = if n > 0 then 0 else -1


-- Parallel Pruning

-- | Parallel variant of 'runPruning'. @d@ is the recursion depth (1 at the
-- top-level call). While @d < parallelDepth@, the k child searches are
-- sparked with @parList rdeepseq@, so sibling branches may be evaluated
-- speculatively even after one succeeds. Deeper levels run sequentially.
--
-- The partial coloring is validated once, right after @v@ is assigned,
-- instead of at the top of each of the k child calls. When @v@ is the last
-- node, the coloring that just passed the check is returned directly.
runParallelPruning :: Graph -> Int -> GraphColoring -> Int -> Int -> Int -> Maybe (GraphColoring)
runParallelPruning graph k col v c d
    | v == n = returnMaybeColoring graph col k
    | not (checkColoring graph nextColoring k) = Nothing
    | v + 1 == n = Just nextColoring
    | d < parallelDepth = firstJust (branches `using` parList rdeepseq)
    | otherwise = firstJust branches
  where
    -- depth below which child searches are run in parallel
    parallelDepth :: Int
    parallelDepth = 4
    n = Map.size graph
    nextColoring = Map.insert v c col
    branches = [ runParallelPruning graph k nextColoring (v + 1) c_ (d + 1) | c_ <- [1..k] ]
    firstJust = listToMaybe . catMaybes

-- | Parallel pruning entry point. One search per color for node 0 is sparked
-- and fully evaluated in parallel, and the first success in color order is
-- returned.
parallelPruningAlgorithm :: Graph -> Int -> Int -> Maybe (GraphColoring)
parallelPruningAlgorithm graph k n =
    case catMaybes attempts of
      (found : _) -> Just found
      [] -> Nothing
  where
    attempts = [ runParallelPruning graph k (initColoring n) 0 c 1 | c <- [1..k] ] `using` parList rdeepseq



-- runner

-- | Dispatch to the solver named by @algo@. An unrecognised name raises an
-- 'error' when the result is demanded.
runAlgorithm :: Graph -> Int -> String -> Int -> Maybe (GraphColoring)
runAlgorithm graph k algo n =
    case algo of
      "parallel_DSATUR"  -> parallelDSATURAlgorithm graph k n
      "parallel_pruning" -> parallelPruningAlgorithm graph k n
      "DSATUR"           -> dSATURAlgorithm graph k n
      "pruning"          -> pruningAlgorithm graph k n
      "brute_force"      -> bruteForceAlgorithm graph k n
      _                  -> error ("Unknown algorithm: " ++ algo)


-- | Parse @kStr@ as the number of colors, load the graph from @fileName@, run
-- the solver named by @algo@, and print @k@ followed by the result ('Nothing'
-- when no valid k-coloring was found).
--
-- @k@ is parsed with 'readMaybe'. A non-numeric or negative value ends the
-- program via 'die' with an explanatory message, instead of the lazy
-- "Prelude.read: no parse" crash that 'read' produced when @k@ was printed.
runSolver :: String -> String -> String -> IO ()
runSolver fileName kStr algo =
    case (readMaybe kStr :: Maybe Int) of
      Just k | k >= 0 -> do
        inputData <- readFile fileName
        let graph = loadGraph inputData
            n = Map.size graph
            coloring = runAlgorithm graph k algo n
        print k
        print coloring
      _ -> die ("Invalid k: " ++ show kStr ++ " (expected a non-negative integer)")