module Data.Vector.Parallel where

import           RIO
import qualified RIO.List                  as L
import           RIO.List.Partial          (tail)
import qualified RIO.Vector                as V

import           Control.Monad.Par

-- | Splits a vector into chunks of at most @n@ elements.
--
-- The chunks are returned in /reverse/ order: the last slice of the input
-- is the head of the result. Every chunk except that head has exactly @n@
-- elements; the head holds whatever remained and may be shorter. An empty
-- vector yields @[]@.
--
-- A non-positive @n@ cannot make progress (splitting off zero elements
-- leaves the input unchanged). In that case the whole vector is returned as
-- a single chunk instead of looping forever.
--
-- prop> concatV (chunksOf n v) == v
chunksOf :: Int -> Vector a -> [Vector a]
chunksOf n v
  | n <= 0 = [v | not (null v)]
  | otherwise = go v []
  where
    -- Accumulates chunks by consing, which is what reverses their order.
    go v' acc
      | null v' = acc
      | otherwise =
          let (c, t) = V.splitAt n v'
           in go t (c : acc)

-- | Concatenates chunks that are stored in /reverse/ order, as produced by
-- 'chunksOf'. The last element of the list becomes the front of the result:
--
-- > concatV [c, b, a] == a <> b <> c
--
-- prop> concatV (chunksOf n v) == v
--
-- The reversed list is concatenated in one pass with 'V.concat'. Folding
-- with '<>' instead would copy the growing accumulator on every step,
-- which is quadratic in the number of chunks.
concatV :: [Vector a] -> Vector a
concatV = V.concat . L.reverse

-- | Parallel version of 'V.map'.
--
-- The input is cut into 4096-element chunks with 'chunksOf'. Each chunk is
-- mapped in its own 'spawnP' task, so results are evaluated to normal form
-- ('NFData') inside the worker. The futures are then collected with 'get',
-- and the pieces are reassembled with 'concatV', which undoes the reversed
-- chunk order.
parMapV :: NFData b => (a -> b) -> Vector a -> Vector b
parMapV f va = concatV $ runPar $ do
  pending <- traverse (spawnP . fmap f) (chunksOf 4096 va)
  mapM get pending

-- | Parallel version of 'V.imap'.
--
-- Like 'parMapV', the input is processed in 4096-element chunks, each in its
-- own 'spawnP' task. 'V.imap' numbers every chunk from zero, so each task
-- shifts the local index by the chunk's starting offset in the original
-- vector. This means @f@ always receives global indices.
parIMapV :: NFData b => (Int -> a -> b) -> Vector a -> Vector b
parIMapV f va = concatV $ runPar $ do
  pending <- zipWithM spawnChunk offsets pieces
  traverse get pending
  where
    pieces = chunksOf 4096 va
    -- 'chunksOf' lists chunks last-first. A chunk's start offset is
    -- therefore the total length of every chunk after it in this list.
    -- 'L.scanr' always returns a non-empty list, so 'tail' cannot fail.
    offsets = tail (L.scanr (\piece total -> V.length piece + total) 0 pieces)
    spawnChunk base piece = spawnP (V.imap (\j x -> f (base + j) x) piece)
