{-# LANGUAGE DeriveAnyClass, DeriveGeneric #-}
{- Parallel Floyd-Warshall algorithm with 2-D block mapping in Haskell -}

module Parallel_fw_block (
        floyd_warshall_blocked
    ) where 

import Graph
import Control.Parallel
import Control.Parallel.Strategies
import Control.Monad.Par(runPar, get, spawnP)
import Data.List
import Debug.Trace
import GHC.Generics (Generic)

-- | Scan a @b@ x @b@ tile for one pivot @k@ and collect the cells that improve.
--
-- The tile is walked row by row from local position @(i, j)@. The flat position
-- of a cell inside 'input' is @c_index + i*n + j@, i.e. rows are @n@ apart.
-- Whenever the value read from 'input' at that position is greater than
-- 'addWeights' applied to @l_a@ at @i*n + k@ and @l_b@ at @kth + j@, the pair
-- (flat position, combined weight) is prepended to @replaced@.
--
-- 'input', @l_a@ and @l_b@ are only read here; nothing is written back.
-- The accumulated list is returned once @i@ reaches @b@.
--
-- NOTE(review): presumably 'dataAt' is a positional lookup and 'addWeights'
-- is path-length addition; both come from "Graph" -- confirm there.
loops :: Int -> Int -> Int -> Int -> Int -> Int -> [Weight] -> [Weight] -> [Weight] -> [(Int, Weight)] -> Int -> [(Int, Weight)]
loops k n kth i j b l_a l_b input replaced c_index
    -- All rows of the tile have been visited.
    | i == b = replaced
    -- End of the current row: restart at column 0 of the next row.
    | j == b = loops k n kth (i + 1) 0 b l_a l_b input replaced c_index
    -- Going through the pivot is cheaper: record the update and move on.
    | current > viaPivot = loops k n kth i (j + 1) b l_a l_b input ((target, viaPivot) : replaced) c_index
    -- No improvement for this cell.
    | otherwise = loops k n kth i (j + 1) b l_a l_b input replaced c_index
    where
        target = c_index + (i * n + j)
        current = dataAt target input
        viaPivot = addWeights (dataAt (i * n + k) l_a) (dataAt (kth + j) l_b)
                    

-- | Run 'loops' for every pivot from @k@ up to (but excluding) @b@ on one tile
-- and gather all reported updates.
--
-- The updates found for a pivot are placed in front of those already
-- collected in @big_replaced@, so entries for later pivots come first in the
-- result. @c_index@ is the flat offset of the tile being updated; @l_a@ and
-- @l_b@ are the two operand lists handed through to 'loops' unchanged.
--
-- NOTE(review): 'input', @l_a@ and @l_b@ are the same for every pivot, so a
-- cell may appear several times in the result (once per improving pivot).
-- Which entry ends up in the matrix depends on how 'replace_n_list' treats
-- duplicate indices -- confirm against "Graph".
floyd_warshall_in_place :: [Weight] -> [Weight] -> [Weight] -> Int -> Int -> Int -> Int -> [(Int, Weight)] -> [(Int, Weight)]
floyd_warshall_in_place l_a l_b input b n k c_index big_replaced
    | k == b = big_replaced
    | otherwise = floyd_warshall_in_place l_a l_b input b n (k + 1) c_index (pivotUpdates ++ big_replaced)
    where
        -- Row @k@ of @l_b@ starts @k * n@ elements in.
        pivotUpdates = loops k n (k * n) 0 0 b l_a l_b input [] c_index


-- | Compute the updates for tile @(i, j)@ using tiles @(i, k)@ and @(k, j)@
-- as operands.
--
-- Tiles are addressed by dropping a flat offset from 'input': block row @r@
-- and block column @c@ start at @r*b*n + c*b@. Returns the
-- (flat index, weight) pairs produced by 'floyd_warshall_in_place'.
inner_independent_phase :: Int -> Int -> Int -> Int -> Int -> [Weight] -> [(Int, Weight)] 
inner_independent_phase j i k b n input =
    floyd_warshall_in_place tileIK tileKJ input b n 0 (i * b * n + j * b) []
    where
        tileIK = drop (i * b * n + k * b) input
        tileKJ = drop (k * b * n + j * b) input


-- | Process every block row @i@ other than the pivot row @k@.
--
-- For each such row the tile in the pivot column, @(i, k)@, is updated first
-- (operands @(i, k)@ and @(k, k)@) and written into the matrix right away.
-- The remaining tiles of that row are then computed in parallel against that
-- refreshed matrix; their updates are only accumulated in @replaced@ and are
-- applied with 'replace_n_list' once all block rows have been handled.
--
-- The number of block rows/columns is @n `div` b@.
independent_phase :: Int -> Int -> Int -> Int -> [Weight] -> [(Int, Weight)] -> [Weight]
independent_phase i k b n input replaced
    -- Past the last block row: apply everything that was deferred.
    | i == blockCount = replace_n_list replaced input
    -- The pivot row is not handled in this phase.
    | i == k = independent_phase (i + 1) k b n input replaced
    | otherwise = independent_phase (i + 1) k b n withPivotColumn (rowUpdates ++ replaced)
    where
        blockCount = n `div` b
        tileIK = drop (i * b * n + k * b) input
        tileKK = drop (k * b * n + k * b) input
        -- Tile (i, k) must be final before the rest of row i can use it.
        pivotColumnUpdates = floyd_warshall_in_place tileIK tileKK input b n 0 (i * b * n + k * b) []
        withPivotColumn = replace_n_list pivotColumnUpdates input
        otherColumns = removeItem k [0 .. (blockCount - 1)]
        -- One Par task per remaining tile of row i; results are joined in
        -- column order.
        rowUpdates = runPar $ do
            pending <- mapM (\j -> spawnP (inner_independent_phase j i k b n withPivotColumn)) otherColumns
            fmap concat (mapM get pending)

-- | Compute the updates for tile @(k, j)@ in the pivot block row, using the
-- diagonal tile @(k, k)@ and the tile @(k, j)@ itself as operands.
--
-- Returns the (flat index, weight) pairs produced by
-- 'floyd_warshall_in_place'; 'input' is not modified.
partially_dependent_phase :: Int -> [Weight] -> Int -> Int -> Int -> [(Int, Weight)]
partially_dependent_phase j input k n b =
    floyd_warshall_in_place tileKK tileKJ input b n 0 tileStart []
    where
        -- Flat offset of the first element of block row k.
        pivotRowStart = k * b * n
        tileStart = pivotRowStart + j * b
        tileKK = drop (pivotRowStart + k * b) input
        tileKJ = drop tileStart input

-- | Outer loop over the pivot block index @k@, from the given @k@ up to
-- @n `div` b@.
--
-- Each round performs three steps and feeds the resulting matrix into the
-- next round:
--
--   1. update the diagonal tile @(k, k)@ against itself;
--   2. update the other tiles of block row @k@ in parallel
--      ('partially_dependent_phase');
--   3. hand the matrix to 'independent_phase' for all remaining block rows.
--
-- When @k@ reaches the block count the matrix is returned as is.
dependent_phase :: Int -> Int -> Int -> [Weight] -> [Weight]
dependent_phase k b n input
    | k == blockCount = input
    | otherwise = dependent_phase (k + 1) b n afterIndependent
    where
        blockCount = n `div` b
        -- Step 1: the diagonal tile is both operands and the target.
        diagStart = k * b * n + k * b
        diagTile = drop diagStart input
        diagUpdates = floyd_warshall_in_place diagTile diagTile input b n 0 diagStart []
        afterDiag = replace_n_list diagUpdates input
        -- Step 2: one Par task per non-diagonal tile of block row k; every
        -- task reads the matrix that already contains the new diagonal tile.
        otherColumns = removeItem k [0 .. (blockCount - 1)]
        pivotRowUpdates = runPar $ do
            pending <- mapM (\j -> spawnP (partially_dependent_phase j afterDiag k n b)) otherColumns
            fmap concat (mapM get pending)
        afterPivotRow = replace_n_list pivotRowUpdates afterDiag
        -- Step 3: all block rows other than k.
        afterIndependent = independent_phase 0 k b n afterPivotRow []

-- | Entry point: run the blocked Floyd-Warshall passes over a flat matrix.
--
-- @input@ is indexed as @row * n + column@, @n@ is the row length and @b@ is
-- the tile edge length. The work starts at pivot block 0 and is delegated to
-- 'dependent_phase'.
--
-- NOTE(review): the block count is computed as @n `div` b@, so rows/columns
-- beyond the last full tile are never visited when @b@ does not divide @n@,
-- and @b == 0@ raises a divide-by-zero -- confirm callers always pass a
-- positive divisor of @n@.
floyd_warshall_blocked :: [Weight] -> Int -> Int -> [Weight]
floyd_warshall_blocked input n b = dependent_phase firstPivot b n input
    where
        firstPivot = 0