-- --------------------------------------------------------------------------------------------------------
import Data.List as List
import System.IO
import Control.Parallel(par)
import Control.Parallel.Strategies(withStrategy, parList, rseq)
import System.Environment


-- --------------------------------------------------------------------------------------------------------
-- Union Find Implementation
-- https://github.com/jgrant27/jngmisc/blob/master/haskell/quick_union.hs
import Data.Sequence as Seq
-- Persistent union-find (disjoint-set forest) over elements numbered from 1.
-- Element p is stored at position (p - 1) of both sequences.
data DisjointSet = DisjointSet
  { count :: Int      -- number of disjoint sets; decremented by each merging union
  , ids   :: Seq Int  -- parent of every element; a root is its own parent
  , sizes :: Seq Int  -- tree size, kept up to date at root positions
  }
  deriving (Read, Show)
-- Return the id of the root of the tree containing element p.
-- Parent links are followed until an element that is its own parent is found.
-- Elements are 1-based, hence the (node - 1) when indexing the sequence.
findRoot :: DisjointSet -> Int -> Int
findRoot set = climb
  where
    parents = ids set
    climb node =
      let up = index parents (node - 1)
      in if node == up then node else climb up
-- Are objects P and Q connected, i.e. do they have the same root?
--
-- The root of Q is sparked with `par` while the current thread computes the
-- root of P. Both roots are bound to names so that the sparked thunk is the
-- very one the comparison consumes; sparking a second, textually identical
-- expression would only help if the optimizer happened to share it.
connected :: DisjointSet -> Int -> Int -> Bool
connected set p q = rootQ `par` (rootP == rootQ)
  where
    rootP = findRoot set p
    rootQ = findRoot set q
-- Replace the sets containing P and Q with their union.
--
-- Returns the set unchanged when P and Q already share a root. Otherwise the
-- root of the smaller tree is re-parented under the root of the larger one,
-- the surviving root's size becomes the sum of both sizes, and the set count
-- drops by one.
quickUnion :: DisjointSet -> Int -> Int -> DisjointSet
quickUnion set p q
  | i == j    = set
  | otherwise = DisjointSet cnt rids rsizes
  where
    -- Bind each root once so the thunk sparked by `par` is the same one that
    -- is consumed below, instead of an unshared duplicate computation.
    rootP = findRoot set p
    rootQ = findRoot set q
    (i, j) = rootQ `par` (rootP, rootQ)
    (i1, j1) = (index (sizes set) (i - 1), index (sizes set) (j - 1))
    (cnt, psmaller, size) = (count set - 1, i1 < j1, i1 + j1)
    -- Always make the smaller root point to the larger one
    (rids, rsizes) =
      if psmaller
        then (update (i - 1) j (ids set), update (j - 1) size (sizes set))
        else (update (j - 1) i (ids set), update (i - 1) size (sizes set))
-- --------------------------------------------------------------------------------------------------------




-- --------------------------------------------------------------------------------------------------------
-- Read one line from the handle and parse it as an edge "u v w": three
-- whitespace-separated integers (two endpoints followed by the weight).
--
-- The line is validated immediately. A line that does not consist of exactly
-- three integers raises an IOError quoting the offending text, rather than
-- returning a tuple of thunks that blows up later with an anonymous
-- pattern-match or `read` failure.
getEdge :: Handle -> IO (Int, Int, Int)
getEdge fileHandle = do
    line <- hGetLine fileHandle
    case mapM parseInt (words line) of
      Just [u, v, w] -> return (u, v, w)
      _ -> ioError (userError ("getEdge: expected three integers \"u v w\", got: " ++ show line))
  where
    -- Total replacement for `read`: succeeds only if the whole token is an Int.
    parseInt :: String -> Maybe Int
    parseInt token = case reads token of
      [(n, "")] -> Just n
      _         -> Nothing


-- Open the named file for reading and apply handleFunction to the handle
-- repeatedly until end of file, collecting the results (via readLine, starting
-- from an empty accumulator).
--
-- The file is opened with withFile, so the handle is closed once reading
-- finishes, including when an exception is thrown part-way through.
-- handleFunction must therefore not return values that still read lazily from
-- the handle after it returns.
readGraphFile :: String -> (Handle -> IO a) -> IO [a]
readGraphFile fileName handleFunction =
  withFile fileName ReadMode $ \fileHandle ->
    readLine fileHandle handleFunction []

-- Apply handleFunction to the handle until the handle reports end of file.
-- Each result is prepended to the accumulator, so the returned list holds the
-- items in reverse reading order, followed by the initial accumulator.
readLine :: Handle -> (Handle -> IO a) -> [a] -> IO [a]
readLine fileHandle handleFunction = loop
  where
    loop acc = do
      atEnd <- hIsEOF fileHandle
      if atEnd
        then return acc
        else handleFunction fileHandle >>= \item -> loop (item : acc)


-- Read every edge of the graph stored in the named file, using getEdge as the
-- per-line reader passed to readGraphFile.
readGraph :: String -> IO [(Int, Int, Int)]
readGraph = flip readGraphFile getEdge


-- --------------------------------------------------------------------------------------------------------

-- True when the two endpoints of the edge are not connected in the given set.
-- Endpoints are shifted by one before the lookup because the union-find
-- indexes its elements from 1.
diff_edge :: DisjointSet -> (Int, Int, Int) -> Bool
diff_edge set edge = case edge of
  (src, dst, _) -> not (connected set (src + 1) (dst + 1))

-- Total order on edges: by weight first, then by first endpoint, then by
-- second endpoint.
compareEdges :: (Int, Int, Int) -> (Int, Int, Int) -> Ordering
compareEdges (u1, v1, w1) (u2, v2, w2) =
  case compare w1 w2 of
    EQ ->
      case compare u1 u2 of
        EQ      -> compare v1 v2
        byFirst -> byFirst
    byWeight -> byWeight

-- True when the weight of the edge does not exceed the pivot p.
compare_pivot :: Int -> (Int, Int, Int) -> Bool
compare_pivot p edge = case edge of
  (_, _, weight) -> p >= weight


-- --------------------------------------------------------------------------------------------------------
-- Parallel Stuff
-- --------------------------------------------------------------------------------------------------------


-- Quicksort of edges under compareEdges, using the first edge as the pivot.
-- While the depth budget n is still positive, the upper partition is sparked
-- with `par`; once it is exhausted the same result is built without a spark.
-- The budget decreases by one per recursion level.
sortEdges :: Int -> [(Int, Int, Int)] -> [(Int, Int, Int)]
sortEdges _ [] = []
sortEdges n (pivot : rest) =
  if n > 0
    then upper `par` joined
    else joined
  where
    joined = lower ++ pivot : upper
    relation edge = compareEdges edge pivot
    -- edges not greater than the pivot go before it, the rest after it
    lower = sortEdges (n - 1) (parFilter ((/= GT) . relation) rest)
    upper = sortEdges (n - 1) (parFilter ((== GT) . relation) rest)


-- Keep the elements satisfying p, then evaluate the surviving list under the
-- `parList rseq` strategy.
parFilter :: (a -> Bool) -> [a] -> [a]
parFilter p xs = withStrategy (parList rseq) kept
  where
    kept = List.filter p xs


-- Keep only the edges whose endpoints are not yet connected in the given set
-- (the filtering step of Filter-Kruskal).
filter' :: [(Int, Int, Int)] -> DisjointSet -> [(Int, Int, Int)]
filter' e m = parFilter stillSeparate e
  where
    stillSeparate = diff_edge m

-- Split the edges around a pivot weight: the first component holds the edges
-- with weight <= pivot, the second those with weight > pivot. The second
-- component is sparked with `par` before the pair is returned.
partitionEdges :: [(Int, Int, Int)] -> Int -> ([(Int, Int, Int)], [(Int, Int, Int)])
partitionEdges e pivot =
  let atMost = compare_pivot pivot
      light  = parFilter atMost e
      heavy  = parFilter (\edge -> not (atMost edge)) e
  in heavy `par` (light, heavy)



-- --------------------------------------------------------------------------------------------------------
-- Core Algorithm
-- --------------------------------------------------------------------------------------------------------

-- Kruskal's scan over a list of edges. Edges are visited in list order; an
-- edge whose endpoints lie in different sets is prepended to the accumulated
-- tree t and its endpoints are merged in m, any other edge is skipped.
-- Returns the accumulated tree together with the final disjoint set.
kruskal :: [(Int, Int, Int)] -> [(Int, Int, Int)] -> DisjointSet -> ([(Int, Int, Int)], DisjointSet)
kruskal [] t m = (t, m)
kruskal (edge@(u, v, _) : rest) t m =
  if diff_edge m edge
    then kruskal rest (edge : t) (quickUnion m (u + 1) (v + 1))
    else kruskal rest t m


-- Filter-Kruskal: partition the edges around a pivot weight, solve the light
-- half recursively, drop heavy edges whose endpoints are already connected,
-- then solve what remains of the heavy half.
--
-- Arguments: the edges, the current disjoint set, the current recursion
-- depth, the maximum recursion depth, the edge-count threshold below which
-- plain Kruskal is used, and the spark depth handed to sortEdges.
--
-- Plain Kruskal on the sorted edges is used when the depth limit is exceeded,
-- when fewer than `threshold` edges remain, or when the partition makes no
-- progress. The last case arises when no edge is heavier than the pivot (the
-- head edge's weight): the light half is then the whole input, and recursing
-- would re-partition the same list at every level until the depth limit,
-- only to reach this same base case. It also keeps an empty edge list away
-- from `head` when the threshold is not positive.
filterKruskal :: [(Int, Int, Int)] -> DisjointSet -> Int -> Int -> Int -> Int -> ([(Int, Int, Int)], DisjointSet)
filterKruskal e m depth maxDepth threshold sortDepth
  | depth > maxDepth || edgeCount < threshold || noProgress = kruskal (sortEdges sortDepth e) [] m
  | otherwise = (res_l ++ res_r, new_m_r)
  where
    edgeCount = List.length e
    (_, _, pivot) = head e
    (e_l, e_r) = partitionEdges e pivot
    -- Only evaluated after the first two tests fail. Forces the light half
    -- (which the recursive call would measure anyway) while the heavy half
    -- is still free to be evaluated by its spark.
    noProgress = List.length e_l == edgeCount
    (res_l, new_m_l) = filterKruskal e_l m (depth+1) maxDepth threshold sortDepth
    e_r' = filter' e_r new_m_l
    (res_r, new_m_r) = filterKruskal e_r' new_m_l (depth+1) maxDepth threshold sortDepth



-- Entry point. Expects two command-line arguments: the path of the graph file
-- (one "u v w" edge per line) and the spark depth passed to sortEdges.
-- Prints the total weight of the edges selected by filterKruskal.
main :: IO()
main = do
  -- let edges = [(0, 1, 10), (0, 2, 6),(0, 3, 5),(1, 3, 15), (2, 3, 4)]
  -- let edges = [(0,1,7),(1,2,8),(0,3,5),(1,3,9),(1,4,7),(2,4,5),(3,4,15),(3,5,6),(4,5,8),(4,6,9),(5,6,11)]
  args <- getArgs
  -- Reject a wrong argument count with a readable message instead of an
  -- irrefutable-pattern failure.
  (fileName, depthArg) <- case args of
    [f, d] -> return (f, d)
    _      -> ioError (userError "expected exactly two arguments: <graph-file> <sort-depth>")
  -- Same acceptance rule as `read`, but a bad value is reported up front.
  sortDepth <- case [k | (k, rest) <- reads depthArg, ("", "") <- lex rest] of
    [k] -> return (k :: Int)
    _   -> ioError (userError ("sort depth must be an integer, got: " ++ show depthArg))
  edges <- readGraph fileName
  let maxDepth = 100
  let threshold = 1000
  -- Vertices are 0-based in the file, so the vertex count is the largest
  -- endpoint plus one. Both endpoints must be considered: a vertex that only
  -- ever appears as the first endpoint would otherwise fall outside the
  -- union-find sequences.
  let n = maximum [max a b + 1 | (a, b, _) <- edges]
  let uf = DisjointSet n (Seq.fromList [1..n]) (Seq.replicate n 1)
  let (res, _) = filterKruskal edges uf 0 maxDepth threshold sortDepth
  -- NOTE(review): `print` on a String shows it with surrounding quotes; left
  -- unchanged in case the output format is relied upon.
  print ("Minimum Weight: " ++ (show (sum [c | (_,_,c) <- res])))




