-- Polynomial multiplication using the iterative fast Fourier transform (Cooley-Tukey algorithm), serial implementation.

module FFT.CTMultPoly
(
  mult_polys,
  iter_fft,
  inverse_iter_fft
) where

import Data.Complex
import Data.List (foldl')

import FFT.FMultPoly (split, convert)

-- | Permute a list into the order consumed by the bottom-up butterfly
-- stages of 'iter_fft': divide it with 'split' (from "FFT.FMultPoly"),
-- reorder each part recursively, and concatenate the two reorderings.
-- Lists of length 0 or 1 come back unchanged.
--
-- NOTE(review): this gives bit-reversed index order assuming 'split'
-- returns the even-indexed and the odd-indexed elements — confirm in
-- FFT.FMultPoly.
bit_reverse_l :: [a] -> [a]
bit_reverse_l xs = case xs of
  []  -> []
  [_] -> xs
  _   ->
    let (front, back) = split xs
    in bit_reverse_l front ++ bit_reverse_l back

-- | Partition a list by position inside consecutive windows of @interval@
-- elements. Positions falling in the first half of a window (index
-- @mod@ interval < interval @div@ 2) go to the left list, the rest to
-- the right list. Relative order is preserved in both outputs.
split_l :: [a] -> Int -> ([a], [a])
split_l xs interval = foldr place ([], []) (zip [0 ..] xs)
  where
    half = interval `div` 2
    -- Lazy pattern so the rest of the fold is only forced on demand,
    -- matching a plain recursive definition.
    place (pos, x) ~(lefts, rights)
      | pos `mod` interval < half = (x : lefts, rights)
      | otherwise                 = (lefts, x : rights)

-- | Interleave two lists using the same windowing rule as 'split_l'.
-- Output position @i@ takes the next element of the left list when
-- @i `mod` interval < interval `div` 2@, otherwise the next element of
-- the right list. Once either list runs out, the remainder of the other
-- is appended as-is. Given the two outputs of @split_l xs interval@, this
-- rebuilds @xs@.
combine_l :: [a] -> [a] -> Int -> [a]
combine_l xs ys interval = go 0 xs ys
  where
    half = interval `div` 2
    go _ [] rest = rest
    go _ rest [] = rest
    go pos ls@(l : ls') rs@(r : rs')
      | pos `mod` interval < half = l : go (pos + 1) ls' rs
      | otherwise                 = r : go (pos + 1) ls rs'

-- | Forward discrete Fourier transform using the iterative radix-2
-- Cooley-Tukey algorithm.
--
-- The input is permuted with 'bit_reverse_l' and then passed through
-- @num_bits@ butterfly stages. Stage @s@ (1-based) works on blocks of
-- @2^s@ elements: each element in the first half of a block is paired
-- with the element half a block later and combined with a twiddle factor.
--
-- NOTE(review): the stages only line up when @length l@ is a power of
-- two ('mult_polys' sizes its inputs to one). For other lengths the
-- output is not the DFT of the input, so callers are expected to pad first.
iter_fft :: [Complex Double] -> [Complex Double]
iter_fft l = foldl' fold_butterfly rev_l [1 .. num_bits]
  where
    len = length l

    -- Smallest k with 2^k >= len, computed with integers only. The
    -- floating-point form @ceiling (logBase 2 len)@ can round an exact
    -- power of two up to k + epsilon, which adds a spurious stage. It
    -- also handled len == 0 only by way of @ceiling (-Infinity)@.
    num_bits :: Int
    num_bits = length (takeWhile (< len) (iterate (* 2) 1))

    rev_l = bit_reverse_l l  -- O(n log n)

    -- One butterfly stage, O(n), on blocks of @n@ elements. @m@ scales
    -- the twiddle exponent so that m / len == 1 / n for power-of-two len.
    butterfly :: Int -> Int -> [Complex Double] -> [Complex Double]
    butterfly m n xs = combine_l sums diffs n
      where
        half_n = n `div` 2
        half_len = length xs `div` 2
        (j_l, k_l) = split_l xs n
        twiddle i =
          exp (-2 * pi * (0 :+ 1) * fromIntegral ((i `mod` half_n) * m) / fromIntegral len)
        w_l = map twiddle [0 .. half_len - 1]
        sums = zipWith3 (\j k w -> j + w * k) j_l k_l w_l
        diffs = zipWith3 (\j k w -> j - w * k) j_l k_l w_l

    fold_butterfly xs iter = butterfly (2 ^ (num_bits - iter)) (2 ^ iter) xs

-- | Inverse discrete Fourier transform, expressed through the forward
-- transform with the conjugation identity
-- @IDFT(x) = conj(DFT(conj x)) / N@, where @N = length x@.
-- The same power-of-two length requirement as 'iter_fft' applies.
inverse_iter_fft :: [Complex Double] -> [Complex Double]
inverse_iter_fft l =
  [ conjugate v / size | v <- iter_fft (map conjugate l) ]
  where
    size = fromIntegral (length l)

-- | Multiply two polynomials given as coefficient lists, using FFT-based
-- convolution.
--
-- Both inputs are padded with 'convert' to a common power-of-two length
-- @n@, transformed with 'iter_fft', multiplied pointwise and transformed
-- back with 'inverse_iter_fft'. The result is the real parts of the first
-- @length x + length y - 1@ entries, or nothing when that count is not
-- positive.
mult_polys :: [Double] -> [Double] -> [Double]
mult_polys x y =
  take result_len $ map realPart (inverse_iter_fft fft_r)
  where
    length_x = length x
    length_y = length y
    result_len = length_x + length_y - 1

    -- Circular convolution of length n equals the linear product once
    -- n >= result_len, so the smallest such power of two suffices. The
    -- previous 2 * max-length bound could be up to twice as large, and it
    -- used floating-point logBase. n must also cover each input so the
    -- padding counts below stay non-negative; that only matters when one
    -- input is empty.
    n = until (>= max result_len (max length_x length_y)) (* 2) 1

    fft_x = iter_fft (convert x (n - length_x))
    fft_y = iter_fft (convert y (n - length_y))
    fft_r = zipWith (*) fft_x fft_y
