-- Polynomial multiplication using an FFT parallelised with the Eval monad
-- (rpar / rseq from Control.Parallel.Strategies)

module FFT.ParFMultPoly
(
  fft,
  mult_polys
) where

import Data.Complex
import Control.Parallel.Strategies(rseq, rpar, runEval)
import Control.DeepSeq(force, NFData)

import FFT.FMultPoly(split, convert)
import qualified FFT.FMultPoly(fft)

-- | Merge the transforms @p1@ and @p2@ of the two parts produced by 'split'
-- into the transform of the whole input (one butterfly stage).
--
-- For each index @k@ the twiddled term @(w ** k) * b@ is added to the
-- matching element @a@ of @p1@ (first half of the result) and subtracted
-- from it (second half). Both halves are sparked with 'rpar' and fully
-- evaluated with 'force'. Without parallelising this step, the run ended
-- with a long stretch where only one core was busy. The same trick could
-- not be applied to 'split'.
combine :: [Complex Double] -> [Complex Double] -> Complex Double -> [Complex Double]
combine p1 p2 w = runEval $ do
  front <- rpar (force (butterfly (+)))
  back  <- rpar (force (butterfly (-)))
  _ <- rseq front
  _ <- rseq back
  return (front ++ back)
  where
    -- shared by both halves so the zipped triples are built only once
    indexed = zip3 [0 .. length p1 - 1] p1 p2
    butterfly op = [ a `op` ((w ** fromIntegral k) * b) | (k, a, b) <- indexed ]

-- | Parallel FFT of @l@ (length @n@) with root of unity @w@ and parallel
-- depth @d@.
--
-- * @n == 1@: the input is returned unchanged.
-- * @d == 0@: the sequential 'FFT.FMultPoly.fft' does the whole transform.
-- * Otherwise: @l@ is 'split' in two. Each part is transformed recursively
--   (size @n `div` 2@, root @w ** 2@, depth @d - 1@) in its own 'rpar'
--   spark. The two results are merged with 'combine'.
--
-- NOTE(review): a negative @d@ never reaches 0, so the recursion stays
-- parallel until @n@ hits 1. This presumes @n@ is a power of two; confirm
-- with callers.
fft :: [Complex Double] -> Int -> Complex Double -> Int -> [Complex Double]
fft l n w d
  | n == 1 = l
  | d == 0 = FFT.FMultPoly.fft l n w
  | otherwise = runEval $ do
      left  <- rpar (force (subFft firstPart))
      right <- rpar (force (subFft secondPart))
      _ <- rseq left
      _ <- rseq right
      return (combine left right w)
  where
    (firstPart, secondPart) = split l
    -- one level deeper: half the size, squared root, one less parallel level
    subFft part = fft part (n `div` 2) (w ** 2) (d - 1)

-- fft :: [Complex Double] -> Int -> Complex Double -> [Complex Double]
-- fft l n w = fft_d l n w 12 -- depth = 12

-- | Inverse FFT.
--
-- Runs the forward 'fft' with the reciprocal root @1 / w@, then divides
-- every output value by @n@. @d@ is passed through unchanged as the
-- parallel depth.
ifft :: [Complex Double] -> Int -> Complex Double -> Int -> [Complex Double]
ifft l n w d = map (/ scale) spectrum
  where
    scale = fromIntegral n
    spectrum = fft l n (1 / w) d

-- | Multiply two polynomials given as coefficient lists using the parallel
-- FFT.
--
-- @depth@ is forwarded to 'fft' / 'ifft' as the parallel recursion depth.
-- The result has @length x + length y - 1@ coefficients (empty if both
-- inputs are empty). They are the real parts of a floating-point inverse
-- transform, so they carry rounding error: an exact integer may come back
-- as e.g. 2.9999999999999996.
mult_polys :: Int -> [Double] -> [Double] -> [Double]
mult_polys depth x y =
  take (length_x + length_y - 1) $ map realPart (ifft fft_r n w depth)
  where
    length_x = length x
    length_y = length y
    min_size = 2 * max length_x length_y
    -- Transform size: the smallest power of two with n >= 2 * max length.
    -- This is always > length_x + length_y - 1, so the cyclic convolution
    -- computed via the FFT never wraps around.
    -- It is computed by exact integer doubling. The old
    -- 2 ^ ceiling (logBase 2 ...) form relied on the floating-point
    -- log y / log x, which can round slightly above an exact integer for
    -- powers of two and then doubles n for no benefit.
    n = until (>= min_size) (* 2) 1
    -- primitive n-th root of unity exp(-2*pi*i/n)
    w = exp (- 2 * pi * (0 :+ 1) / (fromIntegral n))
    fft_x = fft (convert x (n - length_x)) n w depth
    fft_y = fft (convert y (n - length_y)) n w depth
    -- pointwise product in the frequency domain = convolution of coefficients
    fft_r = zipWith (*) fft_x fft_y