module Main where

import Data.Array
import Data.List
import System.Environment(getArgs, getProgName)
import System.Exit(die)
import Text.Read(readMaybe)

import Control.Parallel.Strategies
import Numeric.Transform.Fourier.FFT(rfft, irfft)
import System.Directory(doesFileExist)

import Utils

-- | Integer-indexed array of 'Float' samples. Used throughout this module both
-- for the input time series and for per-window derived series (rolling means,
-- inverse standard deviations, ...).
type IFArray = Array Int Float

{-
Efficient algorithm for computing the sliding dot product using the FFT.

For a query q of length m and a series t of length n, returns the n - m + 1
dot products of q with every length-m subsequence of t, in order. The query is
reversed and both inputs are zero-padded to length 2n. The wanted products then
sit at indices m-1 .. n-1 of the inverse transform of the product of the two
spectra (assuming rfft/irfft are the forward/inverse real DFT pair).
-}

slidingDotProduct :: [Float] -> [Float] -> [Float]

slidingDotProduct q t = [conv ! idx | idx <- [m - 1 .. n - 1]]
  where
    n = length t
    m = length q
    padLen = 2 * n
    paddedT = t ++ replicate n 0
    paddedRevQ = reverse q ++ replicate (padLen - m) 0
    spectrumQ = rfft (listArray (0, padLen - 1) paddedRevQ)
    spectrumT = rfft (listArray (0, padLen - 1) paddedT)
    conv = irfft (arrZipWith (*) spectrumQ spectrumT)
{-
Welford's online algorithm adapted to compute rolling means and variances.

welford k w a returns the means and population variances of the length-w
windows of a, newest first: the head of each list describes the window ending
at index k, and the last element describes the first window (indices
0 .. w-1). The first window is computed directly. Each later window is derived
from the previous one in O(1) by adding a!k and removing a!(k-w).

The first window's variance uses the two-pass formula sum((x - mean)^2) / w
rather than E[x^2] - mean^2. With Float samples the latter suffers
catastrophic cancellation when the mean is large relative to the spread, and
can even come out negative (which becomes NaN after sqrt).

NOTE(review): assumes a is indexed from 0 and k >= w - 1 (i.e. a has at least
w elements); with fewer elements the first window still divides by w.
-}

welford :: Int -> Int -> IFArray -> ([Float], [Float])

welford k w a
    | k < w = ([firstMean], [firstVar])
    | otherwise = (newMean : prevMeans, newVar : prevVars)
  where
    fw = fromIntegral w
    -- First window, computed directly (two-pass for numerical stability).
    firstWindow = take w (elems a)
    firstMean = sum firstWindow / fw
    firstVar = sum [(x - firstMean) * (x - firstMean) | x <- firstWindow] / fw
    -- Sliding update from the window ending at k-1 to the one ending at k.
    (prevMeans, prevVars) = welford (k - 1) w a
    (prevMean : _) = prevMeans
    (prevVar : _) = prevVars
    added = a ! k
    removed = a ! (k - w)
    newMean = prevMean + (added - removed) / fw
    newVar = prevVar + (added - removed) * (added - newMean + removed - prevMean) / fw
{-
Computes the rolling means and standard deviations of the length-m windows of
a. Both results are built with toZeroIndexedArray from oldest to newest
window, so entry i describes the i-th window (the one starting at index i when
a starts at index 0).
-}

computeMeanStd :: IFArray -> Int -> (IFArray, IFArray)

computeMeanStd a m = (toZeroIndexedArray (reverse rMean), toZeroIndexedArray (map std (reverse rVar))) where
    (rMean, rVar) = welford end m a
    (_, end) = bounds a
    -- Float rounding in the sliding variance update can push a (near-)zero
    -- variance slightly below zero; clamp so sqrt yields 0 instead of NaN.
    std v = sqrt (max 0 v)
{-
Preprocesses all the auxiliary series STOMP needs for a time series t and a
window length m:
  * the rolling means of the length-m windows,
  * the reciprocals of their standard deviations,
  * the rolling means of the length-(m-1) windows (used by the incremental
    covariance update).
-}

preprocessT :: IFArray -> Int -> (IFArray, IFArray, IFArray)

preprocessT t m =
    let (windowMeans, windowStds) = computeMeanStd t m
        (shorterMeans, _) = computeMeanStd t (m - 1)
    in (windowMeans, fmap recip windowStds, shorterMeans)


{-
z-normalized Euclidean distance used by MASS, computed from the dot product qt
of the query with one length-m subsequence, plus the mean and standard
deviation of the query (meanQ, sigmaQ) and of the subsequence (meanT, sigmaT):

  d = sqrt |2m (1 - (qt - m meanQ meanT) / (m sigmaT sigmaQ))|

The abs keeps rounding errors that push the argument slightly below zero from
producing NaN.
-}

distanceComp :: Float -> Float -> Float -> (Float, Float, Float) -> Float

distanceComp meanQ sigmaQ m (meanT, sigmaT, qt) = sqrt (abs scaled)
  where
    correlation = (qt - m * meanQ * meanT) / (m * sigmaT * sigmaQ)
    scaled = 2 * m * (1 - correlation)

{-
Mueen's Algorithm for Similarity Search (MASS).

Arguments: a series t, the rolling means and standard deviations of its
windows of length (length q), and a query q. Returns the distance profile of
q: the z-normalized distance between q and every window of t. The dot
products come from the FFT-based slidingDotProduct. The distances are
evaluated in parallel with parBuffer 100 rdeepseq.
-}

mass :: [Float] -> [Float] -> [Float] -> [Float] -> [Float]

mass t meanT stdT q = distances `using` parBuffer 100 rdeepseq
  where
    qLen = length q
    -- q is a single window, so its statistics are the first (only) entries
    (qMeans, qStds) = computeMeanStd (toZeroIndexedArray q) qLen
    dist = distanceComp (qMeans ! 0) (qStds ! 0) (fromIntegral qLen)
    distances = zipWith3 (\mu sigma dot -> dist (mu, sigma, dot)) meanT stdT (slidingDotProduct q t)

{-
Computes the Pearson correlations along one diagonal of the correlation
matrix between the length-m windows of tA and tB.

Diagonal k pairs the tA window starting at index i with the tB window starting
at index i + k. For each series the arguments are the series itself, the
rolling means of its length-m windows, the reciprocals of their standard
deviations, and the rolling means of its length-(m-1) windows.

The last argument lists the tA start indices to produce, in descending order
with step 1 (newest first). The window starting at max 0 (-k) is not in the
list; it is the base case and is computed directly. Every other covariance is
obtained in O(1) from the covariance at i - 1 (the recursive call on the rest
of the list):

  cov_i = cov_(i-1) + (m-1)/m^2 *
            ( (tB[i+k+m-1] - mu'B[i+k]) * (tA[i+m-1] - mu'A[i])
            - (tB[i+k-1]   - mu'B[i+k]) * (tA[i-1]   - mu'A[i]) )

Here mu' are the length-(m-1) means, i.e. the means of the part shared by two
consecutive windows. Each correlation is cov * stdsInvB * stdsInvA.

Returns (covariance for the head of the index list, correlations). The
correlation list follows the index list and ends with the base-case window,
i.e. newest first.

NOTE(review): the recurrence is only valid if the indices are consecutive and
the last one is max 0 (-k) + 1; confirm this holds for every caller.
-}

computePearsons :: IFArray -> IFArray -> IFArray -> IFArray -> IFArray -> IFArray -> IFArray -> IFArray -> Int -> Int -> [Int] -> (Float, [Float])

-- Base case: the first window pair on the diagonal (tA starting at
-- max 0 (-k)). Its covariance is the mean of the products of the deviations
-- from the length-m means.
computePearsons tA meansA stdsInvA meansDecA tB meansB stdsInvB meansDecB m k [] = 
    (cov, [cov * stdsInvB!(i + k) * stdsInvA!i]) where
        cov = sum (zipWith (*) aVec bVec ) * 1/(fromIntegral m)
        -- note: despite the names, aVec holds the tB window and bVec the tA window
        aVec = map (subtract (meansB!(i + k))) (subArrayList (i+k) (i+k+m-1) tB)
        bVec = map (subtract (meansA!i)) (subArrayList i (i+m-1) tA) 
        i = if k >= 0 then 0 else -k

-- Recursive case: derive the covariance at tA start index i from the one at
-- i - 1, which the recursive call on the remaining indices returns.
computePearsons tA meansA stdsInvA meansDecA tB meansB stdsInvB meansDecB m k (i : indices) =
    (newCov, ((newCov * stdsInvB!(i + k) * stdsInvA!(i)) : pearsons) ) where
        newCov = cov + (c * ((termA * termB) -  (termC*termD))) 
        -- element entering the tB window, centred on the shared length-(m-1) mean
        termA = (tB!(j+ m) - meansDecB!(i + k))
        -- element entering the tA window
        termB =  ((tA!((i + m) - 1)) - (meansDecA!i))
        -- element leaving the tB window
        termC = tB!j - (meansDecB!(i + k))
        -- element leaving the tA window
        termD =  (tA!(i - 1) - meansDecA!i)
        -- start of the previous tB window on this diagonal
        j = (i+k)-1
        (cov, pearsons) = computePearsons tA meansA stdsInvA meansDecA tB meansB stdsInvB meansDecB m k indices
        -- (m-1)/m^2 scale factor of the incremental update
        c = (fromIntegral (m-1))/(fromIntegral (m*m))

{-
Computes every correlation on diagonal k, padded with -1 so that all the
diagonals line up vertically. Position i of the result holds the correlation
of the tA window starting at i with the tB window starting at i + k, or -1 when
that pair does not exist. The result always has length (length tA - m + 1).
-}

computeFullPearsons :: IFArray -> IFArray -> IFArray -> IFArray -> IFArray -> IFArray -> IFArray -> IFArray -> Int -> Int -> [Float]
computeFullPearsons tA meansA stdsInvA meansDecA tB meansB stdsInvB meansDecB m k =
    take profileLen (leadingPad ++ reverse newestFirst ++ repeat (-1))
  where
    profileLen = length tA - m + 1
    -- on a negative diagonal the tA windows before index -k have no partner
    leadingPad = replicate (-k) (-1)
    firstI = max 0 (-k)
    lastI = min (length tA - m) (length tB - m - k)
    (_, newestFirst) =
        computePearsons tA meansA stdsInvA meansDecA tB meansB stdsInvB meansDecB m k
            [lastI, lastI - 1 .. firstI + 1]
            `using` rdeepseq

{-
Folds all the diagonals together into the maximum shadow of the Pearson
correlation matrix. Entry i is the largest correlation found for the tA window
starting at i over the diagonals in diags, starting from a floor of -1.
Diagonals are grouped with chunk 100 (presumably chunks of 100 diagonals).
Each group is reduced in parallel, and the partial maxima are then combined.

Despite the name, the result holds maximum correlations; these correspond to
the minimum z-normalized distances.

Each fold step is forced to normal form (foldl' plus rdeepseq). A lazy
foldl (zipWith max) instead keeps one unevaluated max thunk per folded row
in every entry until the end.
-}

minDiag :: IFArray -> IFArray -> IFArray -> IFArray -> IFArray -> IFArray -> IFArray -> IFArray -> Int -> [Int] -> [Float]

minDiag tA meansA stdsInvA meansDecA tB meansB stdsInvB meansDecB m diags = maxPearsons reducedPearsonMatrix
  where
    reducedPearsonMatrix = map maxPearsons (chunk 100 pearsonMatrix) `using` parBuffer 100 rdeepseq
    pearsonMatrix = map (computeFullPearsons tA meansA stdsInvA meansDecA tB meansB stdsInvB meansDecB m) diags
    maxPearsons = foldl' step floorPearsons
    -- element-wise max, fully evaluated before the next row is folded in
    step acc row = zipWith max acc row `using` rdeepseq
    floorPearsons = replicate (length tA - m + 1) (-1)

{-
STAMP (Scalable Time series Anytime Matrix Profile).
A relatively inefficient algorithm for computing the matrix profile. For every
length-m window of tA it runs MASS against tB and keeps the smallest distance.
-}

stamp ::  IFArray -> IFArray -> Int -> [Float]

stamp tA tB m = map nearest (rollingWindow m (elems tA))
  where
    (meansB, stdsB) = computeMeanStd tB m
    distanceProfile = mass (elems tB) (elems meansB) (elems stdsB)
    nearest window = minimum (distanceProfile window)

{-
STOMP (Scalable Time series Ordered-search Matrix Profile), further optimized.
A much more efficient algorithm for computing the matrix profile. The maximum
Pearson correlation r of each length-m window of tA over all diagonals is
converted into a z-normalized distance sqrt |2m (1 - r)|.
-}

stompOpt :: IFArray -> IFArray -> Int -> [Float]

stompOpt tA tB m = [sqrt (abs (2 * fromIntegral m * (1 - r))) | r <- bestPearsons]
  where
    -- every diagonal k that pairs at least one tA window with a tB window
    diagonals = [m - length tA .. length tB - m]
    bestPearsons = minDiag tA meansA stdsInvA meansDecA tB meansB stdsInvB meansDecB m diagonals
    (meansA, stdsInvA, meansDecA) = preprocessT tA m `using` parTuple3 rdeepseq rdeepseq rdeepseq
    (meansB, stdsInvB, meansDecB) = preprocessT tB m `using` parTuple3 rdeepseq rdeepseq rdeepseq

{-
Computes the MPdist between two time series. The matrix profile algorithm
mpFunc is a parameter (stamp or stompOpt). The join profiles A-against-B and
B-against-A are concatenated, and getKPartition selects the element at
k = ceiling (p * (nA + nB)). k is capped at one less than the number of
profile entries.
-}

mpdist :: (IFArray -> IFArray -> Int -> [Float]) -> Int -> Float -> [Float] -> [Float] -> Float

mpdist mpFunc m p tAList tBList = getKPartition k (profileAB ++ profileBA)
  where
    lenA = length tAList
    lenB = length tBList
    arrA = toZeroIndexedArray tAList
    arrB = toZeroIndexedArray tBList
    profileAB = mpFunc arrA arrB m
    profileBA = mpFunc arrB arrA m
    -- the concatenated profile has (lenA - m + 1) + (lenB - m + 1) entries;
    -- the cap presumably keeps k a valid 0-based index for getKPartition
    entryCount = (lenA - m + 1) + (lenB - m + 1)
    k = min (ceiling (p * fromIntegral (lenA + lenB))) (entryCount - 1)

{-
Runs mpdist with the algorithm named on the command line ("stamp" or
"stompopt") and renders the result with show. Any other name makes the
returned string raise an error when it is evaluated.
-}

mpdistCmd :: String -> Int -> Float -> [Float] -> [Float] -> String

mpdistCmd algorithm m p tAList tBList = show $ mpdist mpFunc m p tAList tBList where
    mpFunc = case algorithm of
                  "stamp" -> stamp
                  "stompopt" -> stompOpt
                  -- wildcard, not a variable named `otherwise` (which shadowed Prelude's)
                  _ -> error "specified algorithm not found: should be either stamp or stompopt"

-- | Command-line entry point:
--
-- > <prog> <stompopt/stamp> <time-series-1-filename> <time-series-2-filename> <window-size> <threshold>
--
-- Validates the window size, checks that both input files exist (printing
-- "Found <file>" for each), reads them with getFloats and prints the MPdist
-- computed by the selected algorithm. Exits via 'die' on invalid input.
main :: IO ()
main = do
        args <- getArgs
        pn <- getProgName
        (algo, filename1, filename2, mArg, p) <- case args of
            [algo, f1, f2, mArg, p] -> return (algo, f1, f2, mArg, p)
            _ -> die $ "Usage: " ++ pn ++ " <stompopt/stamp> <time-series-1-filename> <time-series-2-filename> <window-size> <threshold>"
        -- Parse the window size up front instead of letting a partial `read`
        -- fail with "no parse" deep inside the computation. A single-point
        -- window has zero standard deviation, so at least 2 is required.
        m <- case readMaybe mArg :: Maybe Int of
            Just w | w >= 2 -> return w
            _ -> die $ pn ++ ": invalid window size " ++ show mArg ++ ": expected an integer >= 2"
        requireFile pn filename1
        requireFile pn filename2
        tA <- getFloats filename1
        tB <- getFloats filename2
        -- every series needs at least one full window
        if m > length tA || m > length tB
            then die $ pn ++ ": window size " ++ show m ++ " is longer than one of the time series"
            else putStrLn (mpdistCmd algo m (readFloat p) tA tB)
  where
    -- Fails with the same message openFile would give, before any reading.
    requireFile pn filename = do
        exist <- doesFileExist filename
        if exist
            then putStrLn ("Found " ++ filename)
            else die $ pn ++ ": " ++ filename ++ ": openFile: does not exist (No such file or directory)"
