module Trie (Trie(..), insertIntoTrie, populateTrie, searchTrie, fanTrie, fanTriePar) where

import qualified Data.List as List
import qualified Data.Map as Map
import Control.Monad
import Data.Maybe
import Control.Parallel.Strategies (parMap, runEval, using, rdeepseq, parList, rpar)


-- | A character trie.
--
-- 'searchTrie' reports an 'EmptyNode' with an error, and 'fanTrie' yields no
-- words for it.
data Trie
  = -- | A node: its children keyed by the next character, the frequency
    -- recorded for the word that ends here, and whether a word ends here.
    TrieNode (Map.Map Char Trie) Int Bool
  | -- | Sentinel node that holds no words.
    EmptyNode
  deriving (Show, Read, Eq)

-- | Insert a word into the trie, recording @count@ as that word's frequency.
--
-- The node reached by the word's last character is marked as a word end. Its
-- frequency is overwritten with @count@, so re-inserting a word replaces its
-- previous count rather than adding to it. Its existing children are kept.
-- Intermediate nodes created along the way get frequency 0 and are not
-- marked as word ends.
--
-- Inserting the empty string leaves the trie unchanged. Inserting into an
-- 'EmptyNode' root returns 'EmptyNode'. An 'EmptyNode' found as a /child/ is
-- treated like a missing child and replaced by a fresh node, whether it is
-- on the word's last character or part-way down the path.
insertIntoTrie :: String -> Int -> Trie -> Trie
insertIntoTrie _ _ EmptyNode = EmptyNode
insertIntoTrie [] _ t = t
insertIntoTrie (c : cs) count (TrieNode children freq end) =
  TrieNode (Map.insert c updatedChild children) freq end
  where
    -- The child to descend into. A missing or 'EmptyNode' child becomes a
    -- fresh, unmarked node, so a longer word is never swallowed by an
    -- 'EmptyNode' in the middle of its path.
    existing = case Map.lookup c children of
      Just node@TrieNode {} -> node
      _ -> TrieNode Map.empty 0 False

    updatedChild
      | null cs = markEnd existing
      | otherwise = insertIntoTrie cs count existing

    -- Mark a node as the end of the inserted word, keeping its children.
    markEnd (TrieNode grandchildren _ _) = TrieNode grandchildren count True
    markEnd EmptyNode = TrieNode Map.empty count True


-- | Build a trie from @(word, count)@ pairs.
--
-- The pairs are inserted left to right with 'insertIntoTrie', starting from an
-- empty root (no children, frequency 0, not a word end). When a word occurs
-- more than once, the count from its last pair wins.
--
-- NOTE(review): the first argument is ignored; the pairs come from the second
-- argument. It is kept only so existing two-argument call sites still
-- compile. Confirm with callers before removing it.
populateTrie :: Foldable t => p -> t (String, Int) -> Trie
populateTrie _unused wordCounts =
  -- Strict left fold: the lazy 'foldl' would build one thunk per pair before
  -- performing any insertion.
  List.foldl' (\trie (word, count) -> insertIntoTrie word count trie) emptyRoot wordCounts
  where
    emptyRoot = TrieNode Map.empty 0 False


-- | Walk down the trie following the characters of a prefix.
--
-- Returns the node reached after consuming the whole prefix. If a character
-- has no matching child, the walk stops and the node reached so far is
-- returned. The result is therefore the node for the longest matching prefix,
-- not a "not found" marker. Calls 'error' with @"Bad Trie"@ when the walk
-- reaches an 'EmptyNode'.
--
-- The third argument accumulates the characters consumed so far but is never
-- inspected.
--
-- NOTE(review): callers cannot distinguish a full match from a partial one.
-- Confirm the longest-prefix fallback is intended.
searchTrie :: String -> Trie -> String -> Trie
searchTrie prefix node consumed = case (prefix, node) of
  (_, EmptyNode) -> error "Bad Trie"
  ([], _) -> node
  (c : rest, TrieNode children _ _) ->
    maybe node (\next -> searchTrie rest next (consumed ++ [c])) (Map.lookup c children)


-- | List every word stored at or below a node, paired with its frequency.
--
-- @prefix@ is the string spelled by the path to this node. It is the word
-- reported for the node itself, and it is extended by one character per level
-- for descendants. Words come out in pre-order: the node's own word first (if
-- it is a word end), then each child's words in ascending character order
-- ('Map.toList' order). An 'EmptyNode' contributes nothing.
fanTrie :: Trie -> String -> [(String, Int)]
fanTrie EmptyNode _ = []
fanTrie (TrieNode children freq end) prefix =
  -- 'concatMap' nests the appends to the right, so each sibling's output is
  -- copied once. A left fold of (++) re-traverses earlier siblings at every
  -- step, and 'foldl1' is partial on an empty list.
  [(prefix, freq) | end] ++ concatMap descend (Map.toList children)
  where
    descend (c, child) = fanTrie child (prefix ++ [c])


-- | Parallel variant of 'fanTrie'. It returns the same words, frequencies and
-- order.
--
-- @depth@ limits how many levels of the trie spawn parallel work:
--
-- * While @depth > 0@, the child subtrees' word lists are evaluated with
--   @parList rdeepseq@, so each is fully forced in parallel. Children are
--   explored with @depth - 1@.
-- * Once @depth <= 0@, the remaining subtree is listed sequentially by
--   'fanTrie'.
fanTriePar :: Trie -> String -> Int -> [(String, Int)]
fanTriePar EmptyNode _ _ = []
fanTriePar node@(TrieNode children freq end) prefix depth
  | depth <= 0 = fanTrie node prefix
  | otherwise = [(prefix, freq) | end] ++ concat (subtreeWords `using` parList rdeepseq)
  where
    -- One word list per child, in ascending character order.
    subtreeWords =
      [fanTriePar child (prefix ++ [c]) (depth - 1) | (c, child) <- Map.toList children]