module Matrix
    (
        ppmi
    ) where

import qualified Data.Map as M
import qualified Data.Vector.Unboxed as V
import qualified Control.Parallel.Strategies as ST

-- | Row totals of a sparse count matrix keyed by @(row, column)@.
--
-- Element @i@ of the result is the sum of every value stored under a key
-- @(i, _)@; rows with no stored cells total 0. The result has length @dim@,
-- so every row index in @m@ must lie in @[0, dim)@ for 'V.accum' to accept it.
rowSum :: M.Map (Int, Int) Int -> Int -> V.Vector Int
rowSum m dim = V.accum (+) zeros perRow
    where zeros  = V.replicate dim 0
          perRow = [ (row, v) | ((row, _), v) <- M.toList m ]

-- | Column totals of a sparse count matrix keyed by @(row, column)@.
--
-- Element @j@ of the result is the sum of every value stored under a key
-- @(_, j)@; columns with no stored cells total 0. The result has length
-- @dim@, so every column index in @m@ must lie in @[0, dim)@ for 'V.accum'
-- to accept it.
colSum :: M.Map (Int, Int) Int -> Int -> V.Vector Int
colSum m dim = V.accum (+) zeros perCol
    where zeros  = V.replicate dim 0
          perCol = [ (col, v) | ((_, col), v) <- M.toList m ]

-- | Sum of every count stored in the matrix.
--
-- NOTE: 'Int' arithmetic wraps silently on overflow. The author's corpus
-- (@wc -w brown.txt@ reports 1161192 words) keeps the total far below
-- @maxBound :: Int@ (2^63 - 1 with a 64-bit GHC), so this is safe for now.
-- A strict left fold is used so no chain of thunks builds up on large maps.
matSum :: M.Map (Int, Int) Int -> Int
matSum = M.foldl' (+) 0

-- | Positive pointwise mutual information of a sparse co-occurrence matrix.
--
-- With @n@ the sum of all counts, each stored cell @(i, j)@ with count @c@
-- maps to @max 0 (log2 (p(i,j) / (p(i) * p(j))))@, where
--
-- * @p(i,j) = c / n@,
-- * @p(i)@ is the total of row @i@ divided by @n@, and
-- * @p(j)@ is the total of column @j@ divided by @n@.
--
-- Cells absent from the input are absent from the output.
ppmi :: Int                      -- ^ levels of parallel splitting (0 = sequential)
     -> M.Map (Int, Int) Int     -- ^ sparse counts keyed by @(row, column)@
     -> Int                      -- ^ marginal vector length; all row/column indices must be below it
     -> M.Map (Int, Int) Double
ppmi depth m dim = _ppmiMatPar depth pwc pw pc
    where
        n = matSum m
        -- The three probability tables are independent, so spark them in parallel.
        (pw, pc, pwc) = ST.runEval $ do
            rowP   <- ST.rpar $ V.map (`floatDiv` n) $ rowSum m dim
            colP   <- ST.rpar $ V.map (`floatDiv` n) $ colSum m dim
            jointP <- ST.rpar $ M.map (`floatDiv` n) m
            -- Order must match the (pw, pc, pwc) pattern: '_ppmiCell' indexes
            -- pw by the row index and pc by the column index.
            return (rowP, colP, jointP)

-- | Parallel driver for '_ppmiMat'.
--
-- Splits the joint-probability map into two halves by row index, evaluates
-- the halves in parallel, and merges the results. Recursion stops after @d@
-- levels (immediately when @d <= 0@), or when a piece holds fewer than two
-- cells. It then falls back to the sequential '_ppmiMat'. The result equals
-- @_ppmiMat pwc pw pc@ for every @d@.
_ppmiMatPar :: Int -> M.Map (Int, Int) Double -> V.Vector Double -> V.Vector Double -> M.Map (Int, Int) Double
_ppmiMatPar d pwc pw pc
    -- The size guard also covers the empty map, on which findMin/findMax would
    -- crash. The depth guard stops negative depths, which would otherwise
    -- recurse without ever reaching 0.
    | d <= 0 || M.size pwc < 2 = _ppmiMat pwc pw pc
    | otherwise = M.union lo' hi'   -- halves have disjoint keys
    where
        minRow = fst . fst $ M.findMin pwc
        maxRow = fst . fst $ M.findMax pwc
        -- Round up (and avoid overflowing minRow + maxRow) so that two
        -- adjacent rows land in different halves.
        mid = minRow + (maxRow - minRow + 1) `div` 2
        pivot = (mid, minBound)
        -- splitLookup rather than split, so a cell sitting exactly on the
        -- pivot key is kept instead of silently dropped.
        (lo, atPivot, hi0) = M.splitLookup pivot pwc
        hi = maybe hi0 (\v -> M.insert pivot v hi0) atPivot
        recurse part = _ppmiMatPar (d - 1) part pw pc
        -- Spark the lower half and evaluate the upper half on this thread.
        -- rdeepseq forces the cell values as well. With lazy Data.Map only the
        -- spine would otherwise be built, leaving all log work to the consumer.
        (lo', hi') = ST.runEval $ do
            a <- ST.rpar (recurse lo `ST.using` ST.rdeepseq)
            b <- ST.rseq (recurse hi `ST.using` ST.rdeepseq)
            _ <- ST.rseq a
            return (a, b)

-- | Sequential PPMI over a whole joint-probability map.
--
-- Every stored cell is replaced by '_ppmiCell' applied to its key and
-- value, using @pw@ for row marginals and @pc@ for column marginals. Keys
-- are unchanged.
_ppmiMat :: M.Map (Int, Int) Double -> V.Vector Double -> V.Vector Double -> M.Map (Int, Int) Double
_ppmiMat joint rowMarg colMarg = M.mapWithKey cellScore joint
    where cellScore key p = _ppmiCell rowMarg colMarg key p

-- | PPMI of a single cell: @max 0 (log2 (pwc / (pw[i] * pc[j])))@.
--
-- Returns 0 when the marginal product @pw[i] * pc[j]@ is zero; there the
-- quotient would be +infinity or NaN. It also returns 0 when the joint
-- probability @pwc@ is zero; there the logarithm would be -infinity and
-- would be clamped to 0 anyway. Both lookups use bounds-checked 'V.!'.
_ppmiCell :: V.Vector Double -> V.Vector Double -> (Int, Int) -> Double -> Double
_ppmiCell pw pc (i, j) pwc
    | marginalProduct == 0 || pwc == 0 = 0
    | otherwise = max 0 (logBase 2 (pwc / marginalProduct))
    where marginalProduct = (pw V.! i) * (pc V.! j)

-- | Divide two 'Int's, producing a 'Double' quotient.
--
-- A zero divisor is not special-cased. IEEE semantics apply, so the result
-- is an infinity, or NaN for @0 / 0@.
floatDiv :: Int -> Int -> Double
floatDiv numerator denominator = num / den
    where num = fromIntegral numerator
          den = fromIntegral denominator
