{-
File: par.hs
Author: Max Levatich
-}

import Control.Monad
import System.Environment
import Text.Read (readMaybe)

import Control.Parallel

-- PRELIMINARIES:
-- `stack install parallel` (for access to Control.Parallel and `par`)
-- Install ThreadScope and its dependencies: https://wiki.haskell.org/ThreadScope
-- Test that ThreadScope works with `threadscope --test ch8`

-- COMPILE WITH:
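-- `stack ghc -- -O2 -threaded -rtsopts par`
-- `-O2` enables optimizations; it is optional, but unoptimized timings are less representative.
-- `-threaded` selects the multi-threaded runtime (needed for `-N`), and `-rtsopts`
-- allows RTS options such as `-N2` to be passed on the command line.
-- On GHC versions older than 9.4, also pass `-eventlog` so that `+RTS -l` works.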

-- RUN WITH:
-- `./par 40 +RTS -N2`
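-- `-N2` runs the program on 2 capabilities (OS threads executing Haskell code).
-- Experiment with different N <= the number of cores available on your machine.
-- Add `-s` to print runtime statistics (GC, spark counts) to stderr on exit.
-- Add `-l` to write a `par.eventlog` file for ThreadScope.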
-- Run ThreadScope with `./path/to/your/executable/threadscope par.eventlog`

-- | A naive recursive computation of the nth Fibonacci number.
--
-- This is extraordinarily inefficient (an exponential number of calls, roughly
-- 1.6^n, instead of n steps for @fib n@), but it is a good case study for basic
-- parallelism.
--
-- @f1@ is sparked with 'par' so that another capability may evaluate it while
-- the current thread evaluates @f2@. The 'pseq' is essential: it makes the
-- current thread finish @f2@ before it demands @f1@ in the sum. Without it,
-- @(f1 + f2)@ may evaluate @f1@ first on this thread, in which case the spark
-- fizzles and no parallel work is done.
--
-- Every recursive call creates a spark, so large @n@ floods the spark pool;
-- see 'fibd' for a depth-limited version.
--
-- Precondition: @n >= 0@. A negative @n@ never reaches a base case.
fib :: Integer -> Integer
fib 0 = 0
fib 1 = 1
fib n = f1 `par` (f2 `pseq` (f1 + f2))
    where f1 = fib (n - 1)
          f2 = fib (n - 2)

-- | Depth-limited parallel Fibonacci.
--
-- Sparking at every level of the recursion (as 'fib' does) floods the spark
-- pool with millions of tiny sparks, and creating and discarding them is
-- wasted work. Here the first argument tracks the current recursion depth.
-- Only calls shallower than 'target_d' spark their left subcomputation;
-- deeper calls simply add their two halves sequentially. This achieves a much
-- better speedup.
--
-- Call with a depth of 0 at the top level. Precondition: the second argument
-- is non-negative.
fibd :: Int -> Integer -> Integer
fibd depth k = case k of
    0 -> 0
    1 -> 1
    _ | depth < target_d -> left `par` (right `pseq` (left + right))
      | otherwise        -> left + right
  where
    deeper = depth + 1
    left   = fibd deeper (k - 1)
    right  = fibd deeper (k - 2)

-- | Recursion depth at which 'fibd' stops creating sparks.
--
-- Calls at depths 0 through @target_d - 1@ each spark one subcomputation, so
-- at most @2^target_d - 1@ sparks are created in total. The depth must be high
-- enough (more than about 10) that each spark is a small enough piece of work
-- for GHC to balance the load evenly across capabilities.
target_d :: Int
target_d = 12

-- | Entry point: reads a single non-negative integer @n@ from the command line
-- and prints the nth Fibonacci number, computed by 'fibd' starting at depth 0.
--
-- The argument is parsed with 'readMaybe' rather than the partial 'read'.
-- Non-numeric input therefore falls through to the usage message instead of
-- crashing with "Prelude.read: no parse". The same message is printed for a
-- negative number or the wrong number of arguments.
main :: IO ()
main = do
    args <- getArgs
    case args of
        [s] | Just n <- readMaybe s, n >= 0 -> print (fibd 0 n)
        _ -> putStrLn "one argument, must be integer >= 0"