module Main where

import Control.Monad (unless)
import Data.List (elemIndices, nub)
import System.Environment (getArgs)
import System.Exit (exitFailure)
import System.IO (hPutStrLn, stderr)

import Control.Parallel.Strategies
import qualified Data.Map as M
import Text.CSV


-- Type synonyms that document what kind of data each value carries as it
-- is passed between the tree-building functions.

-- | Shannon entropy, in bits.
type Entropy = Double

-- | A single attribute value of a sample.
type Feature = String

-- | The class value of a sample (taken from the last CSV column).
type Label = String

-- | Labelled samples: each row's attribute values paired with its label.
type DataSet = [([String], Label)]

-- | A decision tree as constructed by the tree builder.
data DTree
    -- | An internal node: the branch value that led here ("root" for the
    -- top of the tree) and one subtree per value of the attribute split on.
    = DTree { feature :: String, children :: [DTree] }
    -- | A leaf: the branch value that led here and the label it predicts.
    | Node String String
    deriving Show

-- | Entry point. Expects a CSV file name (looked up under @src/@) and a
-- strategy number "1".."5". Bad arguments print a usage message to stderr
-- and exit with a failure status; otherwise the file is parsed and handed
-- to 'doWork' (or 'handleError' if parsing fails).
main :: IO ()
main = do
    args <- getArgs
    (filename, strat) <- case args of
        [f, s] -> return (f, s)
        _      -> usageFailure "Usage: stack run <filename> <strat> -- +RTS -N<numHEC> -ls -s"
    unless (strat `elem` ["1", "2", "3", "4", "5"]) $
        usageFailure "Usage: Valid strat options are: <1> - Sequential, <2> - Single-Choice (Entropy), <3> - Single-Choice (InformationGain), <4> - Both (w/o Misc.), <5> - Both (with Misc)"
    rawCSV <- parseCSVFromFile ("src/" ++ filename)
    either handleError doWork rawCSV strat
  where
    -- Write the message to stderr and exit non-zero. Using 'error' here
    -- would throw before 'exitFailure' could ever run.
    usageFailure :: String -> IO a
    usageFailure msg = hPutStrLn stderr msg >> exitFailure

-- | Report a CSV parse failure on stderr and exit with a failure status.
-- The second argument (the chosen strategy) is ignored; it exists only so
-- this handler has the same shape as 'doWork' for use with 'either'.
handleError :: Show e => e -> String -> IO ()
handleError err _ = do
    hPutStrLn stderr ("invalid file: " ++ show err)
    exitFailure


-- | Given a successfully parsed CSV file, drop malformed records (rows with
-- fewer than two fields), then build and print a decision tree from the
-- rest. Each row's last field is its label; the others are attributes.
-- Fails with an IO error if no valid records remain.
doWork :: CSV -> String -> IO ()
doWork fcsv strat
    | null myData = ioError (userError "doWork: the CSV file contains no valid records")
    | otherwise = do
        -- WARNING: printing the entire tree can take really long if it is very large!
        print "Final Decision Tree:"
        print result

        -- print "Number of child nodes at root:"
        -- print firstNodeChildren
  where
    -- The length guard makes 'init' and 'last' safe below.
    myData = [ (init row, last row) | row <- fcsv, length row > 1 ]
    result = dtree strat "root" myData
    -- Matched totally: a leaf root simply has no children.
    firstNodeChildren = case result of
        DTree _ cs -> length cs
        Node _ _   -> 0


-- Helpers that split a DataSet's (attributes, label) pairs into the list of
-- attribute rows or the list of labels. Strategy "5" additionally evaluates
-- the projected list in parallel with a 'parBuffer' of 100 elements.
samples :: String -> DataSet -> [[String]]
samples strat sdata = case strat of
    "5" -> withStrategy (parBuffer 100 rdeepseq) rows
    _   -> rows
  where
    rows = map fst sdata

labels :: String -> DataSet -> [Label]
labels strat sdata = case strat of
    "5" -> withStrategy (parBuffer 100 rdeepseq) classes
    _   -> classes
  where
    classes = map snd sdata


-- | Shannon entropy (in bits) of the empirical distribution of @xs@: the sum
-- over each distinct value of @-p * log2 p@, where @p@ is that value's
-- relative frequency in @xs@. Strategies "2", "4" and "5" evaluate the
-- per-value terms with 'parMap'. An empty list has entropy 0.
entropy :: (Eq a) => String -> [a] -> Entropy
entropy strat xs
    | strat `elem` ["2","4","5"] = sum $ parMap rseq term distinct
    | otherwise                  = sum $ map term distinct
  where
    distinct = nub xs
    -- Hoisted: the length of xs is computed once, not once per distinct value.
    total = fromIntegral (length xs) :: Double
    -- Each probability is computed once and reused for both factors.
    term x = let p = fromIntegral (length (filter (x ==) xs)) / total
             in p * negate (logBase 2 p)


-- | Group an attribute column's labels by feature value. Within each group
-- the labels appear in reverse order of their occurrence in the input.
-- 'M.fromListWith' folds strictly over the input, avoiding the thunk
-- chain that a lazy 'foldl' would build up.
splitAttr :: [(Feature, Label)] -> M.Map Feature [Label]
splitAttr dc = M.fromListWith (++) [ (f, [c]) | (f, c) <- dc ]


-- | Entropy of each feature value's label group, keyed by that feature value.
splitEntropy :: String -> M.Map Feature [Label] -> M.Map Feature Entropy
splitEntropy strat = fmap (entropy strat) -- TODO: Parallelize


-- | Information gained by splitting the labels @s@ on an attribute column
-- @a@ (pairs of feature value and label). It equals the entropy of @s@ minus
-- the remainder: each feature value's label-group entropy, weighted by that
-- group's size divided by @length s@.
informationGain :: String -> [Label] -> [(Feature, Label)] -> Double
informationGain strat s a = entropy strat s - remainder
  where
    groups    = splitAttr a                    -- built once, reused below
    entropies = splitEntropy strat groups
    total     = fromIntegral (length s) :: Double
    weight ls = fromIntegral (length ls) / total
    remainder = M.foldrWithKey (\k ls acc -> acc + weight ls * (entropies M.! k)) 0 groups


-- | Index of the attribute column with the highest information gain; ties
-- go to the larger index. The number of attributes is taken from the first
-- sample. Strategies "3", "4" and "5" compute the per-attribute gains with
-- 'parMap'. Fails with an explicit error if there is no attribute to split on.
highestInformationGain :: String -> DataSet -> Int
highestInformationGain strat d
    | null attrs = error "highestInformationGain: dataset has no attributes to split on"
    | otherwise  = snd (maximum (zip gains [0 ..]))
  where
    gainOf = informationGain strat (labels strat d)
    gains
        | strat `elem` ["3", "4", "5"] = parMap rseq gainOf attrs
        | otherwise                    = map gainOf attrs
    width = case d of
        ((xs, _) : _) -> length xs
        []            -> 0
    -- Column n as (attribute value, label) pairs.
    attrs = [ [ (xs !! n, l) | (xs, l) <- d ] | n <- [0 .. width - 1] ]


-- | Partition the dataset on the attribute with the highest information
-- gain. Each key is a value of that attribute; its rows have that column
-- removed and appear in reverse order of occurrence.
datatrees :: String -> DataSet -> M.Map String DataSet
datatrees strat d =
  M.fromListWith (++)
    [ (x !! i, [(x `dropAt` i, l)]) | (x, l) <- zip (samples strat d) (labels strat d) ]
  where
    i = highestInformationGain strat d
    -- Remove the element at index j.
    dropAt xs j = let (before, after) = splitAt j xs in before ++ drop 1 after


-- | True when every element equals its right-hand neighbour, which is
-- trivially the case for empty and singleton lists. Used to stop splitting
-- a dataset once all of its labels are identical.
allEqual :: Eq a => [a] -> Bool
allEqual xs = and (zipWith (==) xs (drop 1 xs))


-- | Build a decision tree for dataset @d@, tagging its root with @f@ (the
-- branch value that led here). Results by case:
--
-- * Labels all agree: a leaf with that label.
-- * Labels disagree but no attributes remain (identical attribute values
--   with different labels): a leaf with the most frequent label, instead
--   of crashing inside the split search.
-- * Otherwise: an internal node with one subtree per branch value, in
--   descending key order.
--
-- An empty dataset is rejected with an explicit error.
dtree :: String -> String -> DataSet -> DTree
dtree strat f d = case labels strat d of
    [] -> error "dtree: cannot build a tree from an empty dataset"
    ls@(l : _)
        | allEqual ls      -> Node f l
        | noAttributesLeft -> Node f (majority ls)
        | otherwise        ->
            -- Consing during an ascending left fold yields descending key order in O(n).
            DTree f $ M.foldlWithKey (\acc k sub -> dtree strat k sub : acc) [] (datatrees strat d)
  where
    -- The attribute count is read from the first sample, as the split search does.
    noAttributesLeft = case d of
        ((xs, _) : _) -> null xs
        []            -> True
    -- Most frequent label; ties go to the greatest label so the choice is deterministic.
    majority ls =
        let counts = M.fromListWith (+) [ (lbl, 1 :: Int) | lbl <- ls ]
        in snd (maximum [ (n, lbl) | (lbl, n) <- M.toList counts ])
