import System.Exit (die)
import System.Environment (getArgs, getProgName)
import Data.Char (isAlpha, isSpace, toLower)
import System.IO
import qualified Data.Map.Strict as Map
import Control.Parallel.Strategies (using, parList, parMap, rdeepseq)
import GHC.Conc (numCapabilities)
import Control.DeepSeq (deepseq)

-- Normalise a text into a list of lower-case words.
-- Characters that are neither alphabetic nor whitespace are dropped (so
-- punctuation and digits vanish without splitting a word), the remaining
-- characters are lower-cased, and the result is split on whitespace.
cleanAndSplit :: String -> [String]
cleanAndSplit s = words [toLower ch | ch <- s, isAlpha ch || isSpace ch]

-- Map and reduce stage of MapReduce for a single piece of text.
-- Map: every word produced by cleanAndSplit is paired with the count 1.
-- Reduce: pairs with the same word are merged by adding their counts.
doMapReduce :: String -> Map.Map String Int
doMapReduce text = Map.fromListWith (+) [(w, 1) | w <- cleanAndSplit text]

-- Split a list into consecutive, non-overlapping sub-lists of the given size.
-- The last sub-list may be shorter. An empty input yields no sub-lists.
--
-- A requested size below 1 is treated as 1: with a size of 0 (or less)
-- 'splitAt' would never consume any input, which would produce an infinite
-- list of empty chunks and hang every caller that forces the result.
chunk :: Int -> [a] -> [[a]]
chunk _ [] = []
chunk n xs = ys : chunk size zs
  where
    size = max 1 n
    (ys, zs) = splitAt size xs

-- Find the known word most similar to the search word.
-- Returns that word together with its count in the frequency map; the count
-- is -1 if the returned word is not a key of the map.
getClosestWord :: Map.Map String Int -> String -> (String, Int)
getClosestWord wordFreq target =
  let nearest = getClosestWord' (Map.keys wordFreq) target
      frequency = Map.findWithDefault (-1) nearest wordFreq
  in (nearest, frequency)

-- getClosestWord helper function.
-- Splits the candidate words into chunks (roughly one per capability),
-- finds the best (distance, word) pair of every chunk in parallel and returns
-- the word of the overall smallest pair.
--
-- An empty candidate list yields "" instead of crashing in 'minimum'.
getClosestWord' :: [String] -> String -> String
getClosestWord' [] _ = ""
getClosestWord' listWords target' =
  snd $ minimum $ parMap rdeepseq (findMinDist startMin startWord target') chunks
  where
    -- With fewer words than capabilities the division yields 0, and a chunk
    -- size of 0 never consumes input; clamp it so chunking always terminates.
    chunkSize = max 1 (length listWords `div` numCapabilities)
    chunks = chunk chunkSize listWords
    -- Sentinel accumulator understood by findMinDist: "nothing seen yet".
    startMin = -1
    startWord = ""

-- Find the minimum Levenshtein distance to the search word among a list of
-- candidates, together with the candidate that achieves it.
-- The first two arguments are the running best distance and word; a running
-- distance of -1 means "no candidate seen yet". On equal distances the
-- candidate seen first is kept.
findMinDist :: Int -> String -> String -> [String] -> (Int, String)
findMinDist minDist minWord _ [] = (minDist, minWord)
findMinDist minDist minWord searchWord (x:xs) =
  findMinDist nextDist nextWord searchWord xs
  where
    (dist, word) = calcLevenshteinDist searchWord x
    -- The candidate wins if nothing was seen before or it is strictly closer.
    improves = minDist == -1 || dist < minDist
    (nextDist, nextWord)
      | improves  = (dist, word)
      | otherwise = (minDist, minWord)

-- Calculate the Levenshtein distance between two strings.
-- Returns the distance paired with the second string. The distance table is
-- built one row per character of the second string; the last cell of the
-- final row is the result.
calcLevenshteinDist :: String -> String -> (Int, String)
calcLevenshteinDist w1 w2 = (last finalRow, w2)
  where
    -- Row for the empty prefix of w2: distance i to the first i chars of w1.
    firstRow = [0 .. length w1]
    finalRow = Prelude.foldl nextRow firstRow w2

    nextRow [] _ = []
    nextRow prev@(corner : prevTail) ch =
      scanl step (corner + 1) (zip3 w1 prev prevTail)
      where
        -- left: cell just computed in this row; diag/up: cells of the
        -- previous row diagonally above and directly above.
        step left (ch', diag, up) =
          minimum [up + 1, left + 1, diag + substitution]
          where
            substitution = if ch /= ch' then 1 else 0

-- Helper function to convert a Maybe into an Int.
-- A present value is returned unchanged; an absent value becomes -1.
resultInt :: Maybe Int -> Int
resultInt (Just x) = x
resultInt Nothing  = -1

-- Repeatedly get a search word from the user and report its count.
-- The loop ends when the user types "exit" or when standard input reaches
-- end-of-file (Ctrl-D or exhausted piped input); without the EOF check
-- 'getLine' would throw and crash the program.
--
-- The word is normalised before the lookup (non-alphabetic characters
-- removed, lower-cased) because the keys of the frequency map are built that
-- way; otherwise e.g. "The" would never match the key "the". If the
-- normalised word is unknown, the closest known word and its count are shown.
getUserInput :: Map.Map String Int -> IO ()
getUserInput wordFreq = do
  putStr "Enter a search word (or 'exit' to quit): "
  hFlush stdout
  eof <- isEOF
  if eof
    then putStrLn "\nExiting..."
    else do
      target <- getLine
      if target == "exit"
        then putStrLn "Exiting..."
        else do
          putStrLn ("  You entered: \"" ++ target ++ "\"")
          let query = map toLower (filter isAlpha target)
          case Map.lookup query wordFreq of
            Just count -> putStrLn $ "  count: " ++ show count
            Nothing -> do
              let (closest, count) = getClosestWord wordFreq query
              putStrLn "  search word not found:"
              putStrLn $ "  closest word: \"" ++ closest ++ "\"\n  count: " ++ show count
          getUserInput wordFreq


-- Entry point: expects exactly one argument, the file to analyse.
-- Counts the words of the file in parallel, then starts the interactive
-- search loop. With any other number of arguments a usage message is
-- printed and the program exits with a failure status.
main :: IO ()
main = do
  args <- getArgs
  case args of
    [filename] -> do
      content <- readFile filename
      putStrLn "Starting MapReduce word counting..."
      -- Chunk by whole lines rather than by raw characters: a character
      -- boundary can fall inside a word, which would then be counted as two
      -- separate fragments. Line breaks are whitespace, so they never split
      -- a word.
      let fileLines = lines content
          -- Clamp to 1: a size of 0 (fewer lines than capabilities) would
          -- make the chunking never consume any input.
          chunkSize = max 1 (length fileLines `div` numCapabilities)
          chunks = map unlines (chunk chunkSize fileLines)
          mappedReduced = map doMapReduce chunks `using` parList rdeepseq
          wordFreq = Map.unionsWith (+) mappedReduced
      wordFreq `deepseq` return ()    -- force computation
      putStrLn "MapReduce completed..."
      getUserInput wordFreq
    _ -> do
      pn <- getProgName
      die $ "Usage: " ++ pn ++ " <filename>"
