{-# LANGUAGE TupleSections #-}

module HuffmanTree where

import Data.Function (on)
import Data.List (insertBy, sort, sortBy)
import Data.Maybe (fromJust, isNothing)
import Data.Tuple (swap)
import Data.Word (Word8)

import Data.ByteString.Lazy (ByteString, unpack)
import Data.Map (Map, fromList, fromListWith, toList)
import qualified Data.Map.Strict as StrictMap

import BitHelper (wordToBits)

-- | A Huffman tree in which every node records its weight (frequency).
data HuffTree
  = -- | A single byte symbol together with its frequency.
    HuffLeaf Word8 Int
  | -- | An internal node: left subtree, right subtree, and the node's weight.
    HuffNode HuffTree HuffTree Int
  deriving (Show)

-- | Encoding dictionary: maps each byte value to its code as a list of bits.
type EncDict = Map Word8 [Bool]

-- | The weight stored at the root of a (sub)tree, for leaves and
-- internal nodes alike.
weight :: HuffTree -> Int
weight tree =
  case tree of
    HuffLeaf _ leafWeight -> leafWeight
    HuffNode _ _ nodeWeight -> nodeWeight

-- | Count how often each byte value occurs in the input.
--
-- Returns one @(byte, count)@ pair per distinct byte, in ascending byte
-- order. Only bytes that actually occur are listed; empty input gives @[]@.
--
-- The strict map forces each running count as it is updated. The lazy
-- 'Data.Map.fromListWith' would instead keep a @1 + 1 + ...@ thunk chain per
-- byte value, whose length grows with the input size.
toFreqList :: ByteString -> [(Word8, Int)]
toFreqList bs = toList $ StrictMap.fromListWith (+) $ map (\b -> (b, 1)) $ unpack bs

-- | Join two subtrees under a new internal node. The first argument becomes
-- the left child, and the node's weight is the sum of both children's weights.
mergeTrees :: HuffTree -> HuffTree -> HuffTree
mergeTrees left right = HuffNode left right combined
  where
    combined = weight left + weight right

-- | Build a Huffman tree from @(symbol, frequency)@ pairs.
--
-- The pairs are sorted by ascending frequency and turned into leaves.
-- 'construct'' then merges them into a single tree. An empty input list
-- reaches 'construct'' and ends in an 'error'.
construct :: [(Word8, Int)] -> HuffTree
construct ts = construct' leaves
  where
    byFrequency a b = compare (snd a) (snd b)
    leaves = [HuffLeaf sym freq | (sym, freq) <- sortBy byFrequency ts]

-- | Merge a list of trees into one Huffman tree.
--
-- Expects the list ordered by ascending 'weight', as 'construct' supplies it.
-- Under that assumption, the first two elements are the lightest trees. They
-- are merged, and the result is re-inserted with 'insertBy' so the list stays
-- ordered. This repeats until one tree remains. An empty list is an 'error'.
construct' :: [HuffTree] -> HuffTree
construct' trees =
  case trees of
    [] -> error "empty huffman tree"
    [root] -> root
    (lightest : nextLightest : rest) ->
      construct' (insertBy byWeight (mergeTrees lightest nextLightest) rest)
  where
    byWeight = compare `on` weight

-- | Compute the code length of every symbol in a Huffman tree.
--
-- The result groups symbols by length: each entry is
-- @(code length, symbols with that length)@, in ascending order of length.
-- RFC 1951's @bl_count[L]@ is the size of the symbol list for length @L@.
--
-- A tree that is a single leaf would have depth 0. In RFC 1951 a code length
-- of 0 means "symbol unused", and 'buildEncTree' only assigns codes for
-- lengths 1..15. The only symbol would therefore get no code at all, so it
-- is given length 1 instead.
buildBitLen :: HuffTree -> [(Int, [Word8])]
buildBitLen (HuffLeaf c _) = [(1, [c])]
buildBitLen tree = toList $ fromListWith (++) $ buildBitLen' tree 0

-- | Walk the tree and pair every leaf's symbol with its depth, which is its
-- code length.
--
-- The second argument is the depth of the current node. Left subtrees are
-- emitted before right subtrees. A leaf deeper than 15 triggers an 'error',
-- because RFC 1951 limits code lengths to 15 bits and length limiting is
-- not implemented here.
buildBitLen' :: HuffTree -> Int -> [(Int, [Word8])]
buildBitLen' tree depth =
  case tree of
    HuffNode left right _ -> concatMap (`buildBitLen'` (depth + 1)) [left, right]
    HuffLeaf sym _
      | depth > 15 -> error "unsupported yet"
      | otherwise -> [(depth, [sym])]

-- | Turn a code-length table (code length -> symbols, as produced by
-- 'buildBitLen') into a map from each symbol to its canonical Huffman code.
--
-- Code assignment starts at length 1 with first code 0. Only lengths 1..15
-- are visited, so table entries outside that range are ignored.
buildEncTree :: [(Int, [Word8])] -> EncDict
buildEncTree = fromList . buildEncTree' 1 0

{-
Assign canonical Huffman codes one code length at a time, from 'blen' up
to 15. This uses the next_code recurrence of RFC 1951, section 3.2.2:
the first code of length L+1 is
(first code of length L + number of codes of length L) * 2.

blen     -> code length (in bits) currently being assigned
start    -> first code value for this length (next_code[blen])
lenTable -> code length -> symbols having that length
            (the shape produced by buildBitLen)

A length missing from lenTable contributes no codes but still advances
'start'. Recursion stops when blen reaches 16.
-}
buildEncTree' :: Int -> Int -> [(Int, [Word8])] -> [(Word8, [Bool])]
buildEncTree' blen start lenTable
  | blen == 16 = []
  | otherwise =
      case lookup blen lenTable of
        Nothing -> buildEncTree' (blen + 1) (start * 2) lenTable
        Just syms ->
          genCodeForBitLen syms blen start
            ++ buildEncTree' (blen + 1) ((start + length syms) * 2) lenTable

-- | Assign consecutive canonical codes to all symbols of one code length.
--
-- The symbols are sorted in ascending order and paired with the code values
-- @start, start + 1, ...@. Each value is converted to bits by 'wordToBits',
-- which receives the code value and the bit length. NOTE(review): presumably
-- this yields the low 'bitlen' bits of the value — confirm in BitHelper.
genCodeForBitLen :: [Word8] -> Int -> Int -> [(Word8, [Bool])]
genCodeForBitLen syms bitlen start = zip (sort syms) codes
  where
    -- One code value per symbol; zip stops as soon as the symbols run out,
    -- so an open-ended range is enough (no extra value is ever converted).
    codes = map (\code -> wordToBits (code, fromIntegral bitlen)) [start ..]