-- Polynomial multiplication using a recursive Fast Fourier Transform (serial implementation)

module FFT.FMultPoly
(
  mult_polys, 
  fft,
  ifft,
  convert,
  split,
  combine
) where

import Data.Complex

-- | Deinterleave a list into the elements at positions 0, 2, 4, ... (first
-- component) and those at positions 1, 3, 5, ... (second component).
--
-- The irrefutable pattern keeps the fold lazy, so prefixes of both halves
-- can be consumed even when the input is infinite.
split :: [a] -> ([a], [a])
split = foldr step ([], [])
  where
    -- Prepending an element swaps the roles of the two halves of the rest.
    step x ~(here, there) = (x : there, here)

-- | Butterfly step of the radix-2 FFT.
--
-- For each index @k@ from 0 to @length lo - 1@ this forms @a_k = lo !! k@ and
-- @t_k = w^k * (hi !! k)@. The result is the list of all @a_k + t_k@ followed
-- by the list of all @a_k - t_k@. Surplus elements of the longer input are
-- ignored. Powers of @w@ are computed with '(**)'.
combine :: [Complex Double] -> [Complex Double] -> Complex Double -> [Complex Double]
combine lo hi w = map (uncurry (+)) pairs ++ map (uncurry (-)) pairs
  where
    indices = [0 .. length lo - 1] :: [Int]
    -- Each pair holds the untouched term and the twiddled term.
    pairs = zipWith3 (\k a b -> (a, (w ** fromIntegral k) * b)) indices lo hi

-- | Recursive radix-2 (Cooley-Tukey) Fast Fourier Transform.
--
-- @n@ is the length of @l@ and must be a power of 2. @w@ is a primitive
-- @n@-th root of unity, @exp (2*pi*(0 :+ (-1))/n)@ for the forward transform:
--
-- > fft [1, 2, 3, 4, 5, 6, 7, 8] 8 (exp (-2*pi*(0:+1)/8))
--
-- Sizes @n <= 1@ return @l@ unchanged. The guard is @<= 1@ rather than
-- @== 1@ so that @n = 0@ (e.g. an empty input) terminates: @0 `div` 2 == 0@
-- would otherwise recurse forever.
fft :: [Complex Double] -> Int -> Complex Double -> [Complex Double]
fft l n w
  | n <= 1 = l
  | otherwise = combine p1 p2 w
  where
    (l1, l2) = split l
    -- The square of an n-th root of unity is an (n/2)-th root of unity;
    -- compute it once and share it between both half-size transforms.
    w2 = w ** 2
    p1 = fft l1 (n `div` 2) w2
    p2 = fft l2 (n `div` 2) w2

-- | Inverse FFT. It runs the forward 'fft' with the inverse root @1 / w@ and
-- divides every output coefficient by @n@.
--
-- Pass the same size @n@ and root @w@ that were used for the forward transform.
ifft :: [Complex Double] -> Int -> Complex Double -> [Complex Double]
ifft l n w = [c / scale | c <- fft l n inverse_root]
  where
    inverse_root = 1 / w
    scale = fromIntegral n

-- | Lift real coefficients into the complex plane and append @l@ zeros of
-- padding. A non-positive @l@ adds no padding.
convert :: [Double] -> Int -> [Complex Double]
convert coeffs padding = [re :+ 0 | re <- coeffs ++ replicate padding 0]

-- | Multiply two polynomials given as coefficient lists (lowest degree
-- first) using the FFT. Both inputs are zero-padded and transformed, the
-- transforms are multiplied pointwise, and the product is transformed back.
--
-- The result has @length x + length y - 1@ coefficients. The work is done in
-- floating point, so coefficients carry small rounding errors.
--
-- If either input is empty (the zero polynomial), no transform is performed
-- and the result is @length x + length y - 1@ zeros (none when both are
-- empty).
mult_polys :: [Double] -> [Double] -> [Double]
mult_polys x y
  | null x || null y = replicate result_len 0
  | otherwise = take result_len $ map realPart (ifft fft_r n w)
  where
    length_x = length x
    length_y = length y
    result_len = length_x + length_y - 1
    -- Transform size: the smallest power of 2 with n >= result_len, so the
    -- cyclic convolution computed by the FFT does not wrap around. Integer
    -- arithmetic avoids logBase rounding, which can push the exponent just
    -- above an exact integer and needlessly double n.
    n = until (>= result_len) (* 2) 1
    w = exp (- 2 * pi * (0 :+ 1) / fromIntegral n)
    fft_x = fft (convert x (n - length_x)) n w
    fft_y = fft (convert y (n - length_y)) n w
    fft_r = zipWith (*) fft_x fft_y
