-- ParHuff.hs
-- Malcolm Mashig (mjm2396)
-- Parallel Huffman Encoding/Decoding
-- COMS 4995: Parallel Functional Programming
-- Columbia University, Fall 2021
-- 12/22/2021
-- References:
--- http://lambduh.blogspot.com/2010/08/huffman-coding-in-haskell.html
--- https://stackoverflow.com/a/22134743/17467177
--- https://codereview.stackexchange.com/a/48564
--- https://en.wikipedia.org/wiki/Huffman_coding
--- http://www.gutenberg.org/files/100/100-0.txt

-------------------

import Data.Function
import Data.List
import Data.Maybe
import System.IO

import Control.Parallel.Strategies(using, parList, rdeepseq)
import Data.Map
import qualified Data.Map.Strict

-------------------

-- | Character frequencies: one @(character, occurrence count)@ pair per
-- distinct character.
type FreqList = [(Char, Int)]

-- | Count how often each character occurs in the input.
--
-- The counts are accumulated in a map and returned via 'Data.Map.toList',
-- i.e. in ascending character order.  An empty input gives an empty list.
--
-- The fold is strict ('Data.List.foldl'') and the per-character update uses
-- the strict 'Data.Map.Strict.insertWith', so each count is an evaluated
-- 'Int' rather than a chain of pending @(+ 1)@ thunks as long as the number
-- of occurrences of that character.
charFreq :: String -> FreqList
charFreq s = Data.Map.toList $ Data.List.foldl' charFreq' Data.Map.empty s

-- | Record one more occurrence of @c@ in the running frequency map.
charFreq' :: Map Char Int -> Char -> Map Char Int
charFreq' m c = Data.Map.Strict.insertWith (+) c 1 m

-------------------

-- | A Huffman tree.  Every node, inner or leaf, carries a weight in 'wt'.
data HuffTree
    = Branch
        { wt :: Int       -- ^ weight of this node
        , l  :: HuffTree  -- ^ left subtree (followed on bit @'0'@)
        , r  :: HuffTree  -- ^ right subtree (followed on bit @'1'@)
        }
    | Leaf
        { symbol :: Char  -- ^ the character this leaf stands for
        , wt     :: Int   -- ^ weight of this node
        }
    deriving (Eq)

-- | Trees are ordered by weight alone.
--
-- Note that this is coarser than the derived 'Eq': two different trees of
-- equal weight compare as 'EQ' without being '=='.
instance Ord HuffTree where
    compare a b = compare (wt a) (wt b)

-- | Join two trees under a new branch whose weight is the sum of theirs.
-- The first argument becomes the left child, the second the right child.
--
-- The name clashes with the Prelude's @mappend@, so call sites have to
-- qualify it as @Main.mappend@.
mappend :: HuffTree -> HuffTree -> HuffTree
mappend left right =
    Branch { wt = wt left + wt right, l = left, r = right }

-------------------

-- | Build the Huffman tree for a frequency list.
--
-- Each @(character, count)@ pair becomes a 'Leaf'; the leaves are sorted by
-- weight and the two lightest trees are merged repeatedly until one tree is
-- left.  The construction is deterministic, so building from the same
-- frequency list twice yields the same tree.
--
-- Calls 'error' on an empty frequency list.
buildDecTree :: FreqList -> HuffTree
buildDecTree freqlist = (rooted . build . sort . Prelude.map (uncurry Leaf)) freqlist
    where
    -- Merge the two lightest trees and re-insert the result in weight order.
    build :: [ HuffTree ] -> HuffTree
    build [] = error "Empty Leaf list"
    build [t] = t
    build (x:y:ts) = build $ Data.List.insert (Main.mappend x y) ts

    -- With a single distinct symbol the merge loop hands back a bare Leaf.
    -- Its symbol would get the empty code (buildEncDict' returns the prefix
    -- accumulated so far, which is "" at the root), so encoded text would
    -- carry no information about how many symbols there were.  Wrap the
    -- leaf in a branch whose children are both that leaf so the symbol gets
    -- a one-bit code; the branch simply reuses the leaf's weight.
    rooted :: HuffTree -> HuffTree
    rooted leaf@(Leaf _ w) = Branch w leaf leaf
    rooted tree = tree

-------------------

-- | Derive the encoding dictionary (character, bit string) from a Huffman
-- tree by walking it from the root with an empty code prefix.
buildEncDict :: HuffTree -> [(Char, String)]
buildEncDict tree = buildEncDict' "" tree

-- | Worker for 'buildEncDict': the first argument is the bit string that
-- leads from the root to the given subtree.
--
-- A leaf contributes its symbol paired with that prefix; a branch extends
-- the prefix with @'0'@ for its left child and @'1'@ for its right child,
-- listing the left entries first.
buildEncDict' :: String -> HuffTree -> [(Char, String)]
buildEncDict' prefix tree =
    case tree of
        Leaf sy _      -> [(sy, prefix)]
        Branch _ lt rt -> descend '0' lt ++ descend '1' rt
  where
    descend bit = buildEncDict' (prefix ++ [bit])
                                  
-------------------

-- | Encode a string by replacing every character with its bit string from
-- the dictionary and concatenating the results.
--
-- The association list is turned into a map once, so each character costs a
-- logarithmic lookup instead of a linear scan of the list.  If a character
-- is listed more than once the first entry wins, as it would with a plain
-- list lookup.  The result is produced lazily, character by character.
--
-- Calls 'error' (naming the offending character) if the input contains a
-- character that has no entry in the dictionary.
encode :: [(Char, String)] -> String -> String
encode encDict = concatMap codeFor
  where
    -- fromListWith passes (new value, value already present): keep the old one.
    table = Data.Map.fromListWith (\_ firstSeen -> firstSeen) encDict
    codeFor c =
        case Data.Map.lookup c table of
            Just code -> code
            Nothing   -> error ("encode: no code for symbol " ++ show c)

-- | Same as 'encode'; kept as a separate name for existing callers.
encode' :: [(Char, String)] -> String -> String
encode' = encode

-------------------

-- | Decode a string of @'0'@/@'1'@ characters with the given Huffman tree.
--
-- The tree is walked from the root, going left on @'0'@ and right on
-- @'1'@; reaching a leaf emits its symbol and restarts at the root.
-- Decoding stops silently at the end of the input (even in the middle of a
-- code) and at the first character that is neither @'0'@ nor @'1'@.
--
-- A tree that is a bare 'Leaf' has no branch to consume input with, so the
-- walk would restart on the same leaf forever and the result would be
-- infinite.  Such a tree is decoded as if it were a branch whose two
-- children are both that leaf, i.e. one symbol per input bit.
decode :: HuffTree -> String -> String
decode leaf@(Leaf _ w) code = decode (Branch w leaf leaf) code
decode t code = walk t code
      where
      walk (Branch _ lt rt) (c:cs) =
        case c of
        '0' -> walk lt cs
        '1' -> walk rt cs
        _ -> [] -- "otherwise" produces a warning about unused variable
      walk (Leaf s _) cs = s : walk t cs -- we need the whole tree again
      walk _ [] = []

-------------------

-- | Write the given strings to a file, each one terminated by a newline.
-- Note the argument order: the lines come first, the file name second.
writeParHuff :: [String] -> String -> IO ()
writeParHuff sa fname = writeFile fname contents
  where
    contents = unlines sa

-------------------

-- | Read a file and split its contents into lines.
readParHuff :: FilePath -> IO [String]
readParHuff fname = do
    contents <- readFile fname
    return (lines contents)

-------------------

-- | Read the batch split positions from the lines of a container file.
--
-- The first line is parsed as a count; the @count - 1@ lines that follow it
-- are parsed as 'Int's and returned.  Parsing uses 'read', so a malformed
-- line is a runtime error.
getInds :: [String] -> [Int]
getInds content = Prelude.map read indexLines
  where
    batchCount = read (Prelude.head content) :: Int
    indexLines = Prelude.take (batchCount - 1) (Prelude.drop 1 content)

-------------------

-- | The characters of a frequency list, in list order.
freqListToChars :: FreqList -> String
freqListToChars = Prelude.map fst

-- | The occurrence counts of a frequency list, in list order.
freqListToFreqs :: FreqList -> [Int]
freqListToFreqs = Prelude.map snd

-- | Read the number of distinct characters from the lines of a container
-- file.
--
-- The first line is parsed as a number @k@; the value returned is the line
-- at (zero-based) position @k@, parsed as an 'Int'.
getNChars :: [String] -> Int
getNChars content = read countLine
  where
    headerLen = read (Prelude.head content) :: Int
    countLine = Prelude.head (Prelude.drop headerLen content)
        
getFreqs :: [String] -> Int -> [Int]
getFreqs content n_chars = Prelude.map read (Prelude.take n_chars (Prelude.drop (skip + 1) content))
    where
        skip = read (Prelude.head content)

-- | Read the symbol string from the lines of a container file.
--
-- With @k@ being the number on the first line, the symbols occupy
-- everything from (zero-based) line @k + 1 + n_chars@ up to, but not
-- including, the last line.
--
-- The symbol string was a single entry when the file was written, but a
-- newline among the symbols makes 'lines' split it into two pieces.  The
-- pieces are therefore joined with @"\\n"@ again; plain concatenation would
-- drop the newline symbol and shift every later symbol onto the wrong
-- frequency.
getChars :: [String] -> Int -> String
getChars content n_chars = intercalate "\n" symbolLines
    where
        skip = read (Prelude.head content)
        symbolLines = init (Prelude.drop (skip + 1 + n_chars) content)
        
-- | Pair each character with the frequency at the same position.  The
-- result is as long as the shorter of the two inputs.
constructFreqList :: String -> [Int] -> FreqList
constructFreqList = zip

-------------------

-- | The encoded bit string of a container file: its last line.
-- Fails on an empty list of lines.
getBits :: [String] -> String
getBits = last

-------------------

-- | Cut a bit string into batches at the given cumulative positions.
--
-- @inds@ holds ascending end offsets; they are first turned into batch
-- widths (each offset minus the previous one, starting from 0).  Whatever
-- remains after the last offset forms a final batch.  Cutting stops as soon
-- as the bit string is exhausted, so an empty input yields no batches.
batchBits :: String -> [Int] -> [String]
batchBits bits inds = carve bits widths
  where
    widths = Prelude.zipWith (-) inds (0 : inds)

    carve [] _          = []
    carve rest []       = [rest]
    carve rest (w : ws) = chunk : carve leftover ws
      where
        (chunk, leftover) = Prelude.splitAt w rest
    
-------------------

-- | Split a list into consecutive chunks of @n@ elements; the last chunk
-- may be shorter.  An empty list gives no chunks.
--
-- NOTE: for @n <= 0@ no element is ever consumed, so a non-empty input
-- produces an infinite list of empty chunks.
splitEvery :: Int -> [a] -> [[a]]
splitEvery n = unfoldr nextChunk
  where
    nextChunk []   = Nothing
    nextChunk rest = Just (Prelude.splitAt n rest)
  
-------------------

-- | Split the input into batches for separate encoding.
--
-- Every batch holds @length input `div` n_batches + 1@ characters, except
-- the last one which may be shorter.  Because of the @+ 1@ the number of
-- batches never exceeds @n_batches@ (for a positive @n_batches@) but can be
-- smaller: 4 characters with @n_batches = 4@ give 2 batches of 2.
batchInput :: String -> Int -> [String]
batchInput input n_batches = splitEvery chunkLen input
  where
    chunkLen = length input `div` n_batches + 1

-------------------

-- | Apply a per-batch transformation to every batch.  When the flag is set
-- the results are evaluated with @parList rdeepseq@, i.e. one fully
-- evaluated spark per batch; otherwise they are left to ordinary lazy
-- evaluation.
mapBatches :: Bool -> (String -> String) -> [String] -> [String]
mapBatches inParallel f batches
    | inParallel = results `using` parList rdeepseq
    | otherwise  = results
  where
    results = Prelude.map f batches

-- | Round-trip driver.
--
-- Reads the input file, Huffman-encodes it (in @n_batches@ batches when
-- @parallel@ is set, as a single batch otherwise), writes the container
-- file, reads it back, rebuilds the tree from the stored frequencies,
-- decodes, checks that the result equals the input and prints the
-- compression ratio (encoded bits relative to 8 bits per input byte).
--
-- Container layout, one entry per line: number of batches, the cumulative
-- end offset of every batch but the last, number of distinct characters,
-- one frequency per character, the characters themselves, the bit string.
main :: IO ()
main = do
    -- params ---------
    let parallel = True
    let n_batches = 4
    let filepath = "shakespeare.txt" -- (large-scale), other option is "test.txt" (small-scale)
    let outpath = "test-ParHuff.txt"
    -------------------
    test <- readFile filepath
    let freqlist = charFreq test
    let tree = buildDecTree freqlist
    let encDict = buildEncDict tree

    -- Encode.  The sequential mode is just the one-batch case.
    let inputBatches = batchInput test (if parallel then n_batches else 1)
    let encodings = mapBatches parallel (encode encDict) inputBatches
    let encoded = concat encodings
    -- batchInput can return fewer batches than requested, so the header
    -- records the number actually produced; the reader uses it to find the
    -- lines that follow.
    let n_written = length encodings
    let lengths = Prelude.map length encodings
    let inds = Prelude.map show (Prelude.take (n_written - 1) (scanl1 (+) lengths))
    let chars = freqListToChars freqlist
    let freqs = Prelude.map show (freqListToFreqs freqlist)
    let out = [show n_written] ++ inds ++ [show (length chars)] ++ freqs ++ [chars] ++ [encoded]
    writeParHuff out outpath

    -- Decode, using only what is stored in the container file.
    dec_instr <- readParHuff outpath
    let decInds = getInds dec_instr
    let decNChars = getNChars dec_instr
    let decFreqs = getFreqs dec_instr decNChars
    let decChars = getChars dec_instr decNChars
    let decTree = buildDecTree (constructFreqList decChars decFreqs)
    let decBits = getBits dec_instr
    let decBatches = batchBits decBits decInds
    let decoded = concat (mapBatches parallel (decode decTree) decBatches)

    -- Size of the file that was actually compressed (not a fixed name).
    original_size <- withFile filepath ReadMode hFileSize
    if decoded == test
    then putStrLn "Compression was lossless"
    else error "Compression was NOT lossless"
    putStrLn "COMPRESSION RATIO:"
    putStrLn (show (fromIntegral (length encoded) / fromIntegral (original_size * 8) * 100 :: Double) ++ "%")
