{- In stack.yaml:

snapshot: lts-22.33
packages:
- .
extra-deps:
- repa-3.4.2.0

-}

{- In package.yaml:

name: repa-ex
dependencies:
- base >= 4.7 && < 5
ghc-options:
- -Wall
executables:
  repa-exe:
    main: Main.hs
    source-dirs: app
    ghc-options:
    - -threaded
    - -rtsopts
    - -with-rtsopts=-N
    - -XTypeOperators
    dependencies:
    - repa

-}

-- Then place this file in app/Main.hs, and run `stack build` or `stack ghci`

{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE CPP, BangPatterns #-}
{-# OPTIONS_GHC -Wall -fno-warn-name-shadowing #-}

import Data.Functor.Identity
import System.Environment
import System.Exit (die)
import Text.Read (readMaybe)

import Data.Array.Repa

-- An unboxed, one-dimensional array holding the ten values 1 through 10.
-- Individual elements are read with the infix `!` operator; shape
-- information is available through `extent`, `size` and `rank`.
a :: Array U DIM1 Int
a = fromListUnboxed oneDim [1 .. 10]
  where
    oneDim = Z :. 10

-- An unboxed, two-dimensional array (a 5x6 matrix) filled with the
-- values 1 through 30.
b :: Array U DIM2 Int
b = fromListUnboxed (Z :. rows :. cols) [1 .. rows * cols]
  where
    rows = 5
    cols = 6

-- A Delayed array built with `fromFunction`. It has the same extent as `b`,
-- and each element is the sum of its row and column indices.
c :: Array D DIM2 Int
c = fromFunction (extent b) indexSum
  where
    indexSum (Z :. row :. col) = row + col

-- Apply a function to every element of an array of any `Source`
-- representation. The result is a Delayed array with the same extent whose
-- elements are defined by applying `f` to the corresponding source element.
repaMap :: (Source r t, Shape sh) => (t -> a) -> Array r sh t -> Array D sh a
repaMap f arr = fromFunction (extent arr) (f . (arr !))

-- To evaluate a Delayed array into a manifest (here unboxed) array sequentially, use `computeS`:
-- computeS $ repaMap (+1) b :: Array U DIM2 Int
-- To evaluate its elements in parallel, use `computeP`, which must be run from within a Monad:
-- runIdentity $ computeP $ repaMap (+1) b :: Array U DIM2 Int

-- Why Repa and Accelerate? Massively parallel computations demand basically zero overhead!
-- Eval and Par monads have overhead that prevents us from just creating a million threads.

---------------------------------------------------------------------------------------
-- Parallel Floyd-Warshall algorithm on a dense adjacency matrix with repa and computeP

-- The weight stored in each cell of the adjacency matrix.
type Weight = Int
-- A graph as a two-dimensional Repa array of weights. The parameter `r`
-- selects the array representation (e.g. `U` for unboxed, `D` for delayed).
type Graph r = Array r DIM2 Weight

-- All-pairs shortest path weights (Floyd-Warshall) for a graph given as an
-- adjacency matrix.
--
-- The node count `n` is taken from the innermost dimension of the input's
-- extent, and the matrix is treated as n x n. Round `k` (for k = 0 .. n-1)
-- builds a fresh unboxed matrix with `computeP`, in which each entry is the
-- smaller of the current weight for (i, j) and the weight of going from i to
-- j through node k.
shortestPaths :: Graph U -> Graph U
shortestPaths g0 = runIdentity (loop g0 0)
  where
    Z :. _ :. n = extent g0

    -- The bang patterns force the matrix produced by each round before the
    -- next round starts indexing into it.
    loop !dist !k
        | k /= n    = computeP (fromFunction (Z :. n :. n) relax) >>= \next -> loop next (k + 1)
        | otherwise = return dist
      where
        -- Keep the current weight, or replace it with the route via node k.
        relax ix@(Z :. i :. j) =
            min (dist ! ix) (dist ! (Z :. i :. k) + dist ! (Z :. k :. j))

-- Entry point. Expects exactly one command-line argument N, a non-negative
-- integer. Builds an N x N graph whose weights are 1 .. N*N in row-major
-- order, runs `shortestPaths` on it and prints the sum of all resulting
-- weights.
--
-- Any other invocation (wrong number of arguments, a non-numeric argument,
-- or a negative N) prints a usage message to stderr and exits with a failure
-- code. A negative N is rejected because it would produce an array with a
-- negative extent, and the `k == n` loop in `shortestPaths` would never
-- terminate.
main :: IO ()
main = do
    args <- getArgs
    case args of
        [arg] | Just n <- readMaybe arg, n >= 0 -> do
            let g = fromListUnboxed (Z :. n :. n) [1 .. n * n] :: Graph U
            print (sumAllS (shortestPaths g))
        _ -> do
            prog <- getProgName
            -- `unwords` rather than Prelude's (++): the unqualified
            -- Data.Array.Repa import also exports (++), which would make
            -- the operator ambiguous here.
            die (unwords ["usage:", prog, "N", "(N must be a non-negative integer)"])

