module FFT.ParCTMultPoly
(
  be_fft,
  splitChunks,
  mult_polys
) where

import Data.Complex
import Control.Parallel(pseq)
import Control.Parallel.Strategies(rseq, rpar, runEval, parListChunk, parMap, rdeepseq, NFData, Eval)
import Control.DeepSeq(force)

import FFT.FMultPoly(split, convert)

-- | Reorder a list by recursively applying 'split' and concatenating the
-- reordered parts. Presumably 'split' separates even- and odd-indexed
-- elements, which would make this the bit-reversal permutation applied
-- before the iterative FFT (TODO confirm against "FFT.FMultPoly").
--
-- When @depth@ is 0 the recursion is purely sequential. For any other depth
-- both recursive calls are sparked with 'rpar', each result is fully
-- evaluated with 'force', and the depth drops by one per level.
bit_reverse_l :: NFData a => [a] -> Int -> [a]
bit_reverse_l xs depth =
    case xs of
        [] -> []
        [single] -> [single]
        _ | depth == 0 -> bit_reverse_l lhs 0 ++ bit_reverse_l rhs 0
          | otherwise -> runEval $ do
              revL <- rpar (force (bit_reverse_l lhs (depth - 1)))
              revR <- rpar (force (bit_reverse_l rhs (depth - 1)))
              -- wait for both sparked halves before concatenating
              _ <- rseq revL
              _ <- rseq revR
              return (revL ++ revR)
  where
    (lhs, rhs) = split xs

-- | Partition a list by position within consecutive blocks of @interval@
-- elements. An element at index @i@ goes to the first list when
-- @i `mod` interval < interval `div` 2@ and to the second list otherwise.
-- Relative order is kept in both outputs. The result is produced lazily.
split_l :: [a] -> Int -> ([a], [a])
split_l xs interval = foldr place ([], []) (zip [0 ..] xs)
  where
    half = interval `div` 2
    -- irrefutable match keeps the fold lazy, like the original recursion
    place (i, x) ~(front, back)
      | i `mod` interval < half = (x : front, back)
      | otherwise = (front, x : back)

-- | Interleave two lists block by block, acting as the inverse of 'split_l'.
-- Each block takes @interval `div` 2@ elements from the first list, then
-- elements from the second list until the in-block position reaches
-- @interval - 1@, and then the position resets. Once either list runs out,
-- the remainder of the other is appended unchanged.
combine_l :: [a] -> [a] -> Int -> [a]
combine_l left right interval = weave left right 0
  where
    half = interval `div` 2
    lastPos = interval - 1
    weave [] ys _ = ys
    weave xs [] _ = xs
    weave xs@(x : xs') ys@(y : ys') pos
      | pos < half = x : weave xs' ys (pos + 1)
      | otherwise = y : weave xs ys' (if pos == lastPos then 0 else pos + 1)

-- | Split a list into consecutive chunks of @n@ elements. The final chunk
-- holds whatever remains and may be shorter.
--
-- An empty input yields @[]@ for any @n@. For a non-empty input @n@ must be
-- positive: with @n <= 0@ @splitAt@ would never consume anything and the
-- result would be an infinite list of empty chunks, so that case is rejected
-- with an error instead of hanging the consumer.
splitChunks :: Int -> [a] -> [[a]]
splitChunks _ [] = []
splitChunks n l
  | n <= 0 = error "splitChunks: chunk size must be positive"
  | otherwise = chunk : splitChunks n rest
  where
    (chunk, rest) = splitAt n l

-- | One radix-2 butterfly stage with block size @n@ over @l@.
--
-- 'split_l' pairs, position by position, the first @n/2@ elements of every
-- @n@-element block (@j@) with the last @n/2@ (@k@). Each pair becomes
-- @j + w*k@ and @j - w*k@, and 'combine_l' writes the results back into the
-- same block layout. The twiddle for the i-th pair is
-- @exp (-2*pi*i' * ((i `mod` (n/2)) * m) / len)@, so @m@ is the stride into
-- the length-@len@ roots of unity. The callers pass @m = 2^(num_bits-iter)@
-- and @n = 2^iter@.
butterfly :: Int -> Int -> [Complex Double] -> Int -> [Complex Double]
butterfly m n l len = combine_l plus minus n
  where
    halfBlock = n `div` 2
    (upper, lower) = split_l l n
    twiddle i = exp (-2 * pi * (0 :+ 1) * fromIntegral ((i `mod` halfBlock) * m) / fromIntegral len)
    twiddles = map twiddle [0 .. length l `div` 2 - 1]
    plus = zipWith3 (\j k w -> j + w * k) upper lower twiddles
    minus = zipWith3 (\j k w -> j - w * k) upper lower twiddles

-- | Run butterfly stages @iter_start .. iter_end@ (stage @iter@ has block
-- size @2^iter@) over every chunk independently. Each chunk is one
-- 'parMap' task and its result is fully evaluated ('rdeepseq'). @len@ is
-- the length of the whole transform, which 'butterfly' needs for the
-- twiddle factors.
par_butterfly :: Int -> Int -> [[Complex Double]] -> Int -> Int -> [[Complex Double]]
par_butterfly iter_start iter_end par_lists num_bits len =
  parMap rdeepseq runStages par_lists
  where
    stage iter chunk = butterfly (2 ^ (num_bits - iter)) (2 ^ iter) chunk len
    runStages chunk = go chunk [iter_start .. iter_end]
    -- apply the stages left to right, lowest iteration first
    go acc [] = acc
    go acc (i : is) = go (stage i acc) is

-- | Same butterfly stage as 'butterfly' (block size @n@, twiddle stride @m@,
-- transform length @len@), applied to the whole list. Used for the stages
-- whose blocks span several of the independently processed chunks. Both the
-- sum list and the difference list are evaluated with
-- @parListChunk par_size rdeepseq@ before 'combine_l' merges them.
butterfly_interact :: Int -> Int -> [Complex Double] -> Int -> Int -> [Complex Double]
butterfly_interact m n l par_size len = runEval $ do
  -- NOTE(original author): "not balanced in the end" — presumably refers to
  -- uneven work distribution across chunks; confirm.
  plus <- parListChunk par_size rdeepseq (zipWith3 (\j k w -> j + w * k) upper lower twiddles)
  minus <- parListChunk par_size rdeepseq (zipWith3 (\j k w -> j - w * k) upper lower twiddles)
  _ <- rseq plus
  _ <- rseq minus
  return (combine_l plus minus n)
  where
    halfBlock = n `div` 2
    (upper, lower) = split_l l n
    twiddle i = exp (-2 * pi * (0 :+ 1) * fromIntegral ((i `mod` halfBlock) * m) / fromIntegral len)
    twiddles = map twiddle [0 .. length l `div` 2 - 1]

-- | Parallel iterative radix-2 FFT.
--
-- @num_c@ is the intended degree of parallelism. It is clamped to
-- @[1, length l]@ and rounded up to a power of two, @2^num_c_bits@. The
-- input is reordered with 'bit_reverse_l' and cut into that many contiguous
-- chunks.
--
-- * The first @num_bits - num_c_bits@ butterfly stages stay inside a chunk,
--   so they run independently per chunk ('par_butterfly').
-- * The remaining stages span chunk boundaries and run over the whole list
--   with chunked parallelism ('butterfly_interact').
--
-- NOTE(review): the twiddle factors are only correct when the input length
-- is a power of two. 'mult_polys' pads its inputs to one, but other lengths
-- are not padded here.
be_fft :: [Complex Double] -> Int -> [Complex Double]
be_fft [] _ = []
be_fft l num_c = end_list `pseq` interact_list
  where
    len = length l
    num_bits = ceilLog2 len
    -- Clamp to [1, len]: more chunks than elements is pointless, and a
    -- non-positive count has no meaningful log2.
    num_c_bits = ceilLog2 (max 1 (min num_c len))
    c_partition_size = 2 ^ (num_bits - num_c_bits)

    fold_butterfly_interact acc iter = butterfly_interact (2 ^ (num_bits - iter)) (2 ^ iter) acc (c_partition_size `div` 2) len

    -- Spark depth d gives 2^d parallel leaves, so recurse in parallel for
    -- num_c_bits levels (one leaf per chunk), not num_c levels.
    rev_l = bit_reverse_l l num_c_bits

    split_lists = splitChunks c_partition_size rev_l
    end_list = concat (par_butterfly 1 (num_bits - num_c_bits) split_lists num_bits len)
    interact_list = foldl fold_butterfly_interact end_list [(num_bits - num_c_bits + 1) .. num_bits]

    -- Smallest b with 2^b >= k (0 for k <= 1), in exact integer arithmetic.
    -- ceiling . logBase 2 works in floating point and may land just above an
    -- integer for an exact power of two, which would add a spurious stage.
    ceilLog2 :: Int -> Int
    ceilLog2 k = length (takeWhile (< k) (iterate (* 2) 1))

-- | Inverse transform built on the forward 'be_fft' using the conjugation
-- identity @ifft x = conj (fft (conj x)) / N@, where @N@ is the input length.
-- @num_c@ is passed through unchanged.
inverse_be_fft :: [Complex Double] -> Int -> [Complex Double]
inverse_be_fft l num_c =
  [conjugate y / size | y <- be_fft (map conjugate l) num_c]
  where
    size = fromIntegral (length l)

-- | Multiply two real-coefficient polynomials via the parallel FFT.
--
-- Both inputs are padded (by 'convert', presumably with zeros — TODO
-- confirm) to a power of two at least twice the longer input. Each is
-- transformed with 'be_fft', the spectra are multiplied pointwise, and the
-- product is transformed back. The first @length x + length y - 1@ real
-- parts are returned. These are floating-point approximations and are not
-- rounded. @num_c@ is the parallelism hint forwarded to the FFTs.
mult_polys :: Int -> [Double] -> [Double] -> [Double]
mult_polys num_c x y =
  take (length_x + length_y - 1) $ map realPart (inverse_be_fft fft_r num_c)
  where
    length_x = length x
    length_y = length y
    -- A transform length >= 2 * max(length) exceeds the product's
    -- length_x + length_y - 1 coefficients, so the cyclic convolution
    -- computed by the FFT does not wrap around.
    num_bits = ceilLog2 (2 * max length_x length_y)
    n = 2 ^ num_bits
    fft_x = be_fft (convert x (n - length_x)) num_c
    fft_y = be_fft (convert y (n - length_y)) num_c
    fft_r = zipWith (*) fft_x fft_y

    -- Smallest b with 2^b >= k (0 for k <= 1), in exact integer arithmetic,
    -- avoiding the floating-point rounding of ceiling . logBase 2.
    ceilLog2 :: Int -> Int
    ceilLog2 k = length (takeWhile (< k) (iterate (* 2) 1))