{-# LANGUAGE DeriveDataTypeable, RecordWildCards #-}
 
module Main where

import qualified Data.ByteString.Char8 as C
import System.Console.CmdArgs
import System.Exit
import Control.Monad (when)
import qualified Read as R
import qualified Train as T
import qualified LossFun as L
import Data.Time

-- | Command-line options of the trainer.
--
-- The 'Data' instance is what allows 'cmdArgs' to consume this record, so the
-- fields are kept in their declared order.
data SgdOpts = SgdOpts
    { trainFile :: String   -- Training data file.
    -- testFile :: String   -- Disabled: testing data file.
    , lambda    :: Double   -- Regularization parameter.
    , epochs    :: Int      -- Number of training epochs.
    }
    deriving (Eq, Show, Data, Typeable)

-- | Option template handed to cmdargs: each field carries its default value
-- together with the annotations attached to it via '&='.
--
-- * 'trainFile' is taken from positional argument 0 and is labelled
--   @TRAINING-FILE@.
-- * 'lambda' defaults to 1e-5 and 'epochs' to 200, each with its own help
--   text.
--
-- NOTE(review): cmdargs' '&=' annotations are impure, so the exact shape of
-- these expressions matters; check the cmdargs documentation before
-- refactoring them.
sgdOpts :: SgdOpts
sgdOpts = SgdOpts
    { trainFile = def &= argPos 0 &= typ "TRAINING-FILE"
    --, testFile = def &= typ "TESTING-FILE" &= help "Testing data" &= name "T" &= name "test"
    , lambda = 1e-5   &= help "Regularization parameter, default 1e-5"
    , epochs = 200      &= help "Number of training epochs, default 200"
    } 

-- | Parse the process command line into 'SgdOpts' using 'cmdArgs'.
--
-- On top of 'sgdOpts' this attaches the program-level annotations: a version
-- flag named explicitly @version@ / @v@, a help flag named explicitly
-- @help@ / @h@, the summary and help texts, and the program name.
getOpts :: IO SgdOpts
getOpts = cmdArgs $ sgdOpts
    &= versionArg [explicit, name "version", name "v", summary _PROGRAM_INFO]
    &= summary (_PROGRAM_INFO ++ ", " ++ _COPYRIGHT)
    &= help _PROGRAM_ABOUT
    &= helpArg [explicit, name "help", name "h"]
    &= program _PROGRAM_NAME

-- | Name of the executable.
_PROGRAM_NAME :: String
_PROGRAM_NAME = "svmsgd"

-- | Release version string.
_PROGRAM_VERSION :: String
_PROGRAM_VERSION = "0.0.1"

-- | Program name and version joined into a single banner line.
_PROGRAM_INFO :: String
_PROGRAM_INFO = unwords [_PROGRAM_NAME, "version", _PROGRAM_VERSION]

-- | One-line description of the program.
_PROGRAM_ABOUT :: String
_PROGRAM_ABOUT = "svm + sgd + haskell"

-- | Copyright notice.
_COPYRIGHT :: String
_COPYRIGHT = "(C) Dazhuo Li 2013"
 
-- | Program entry point: parse the command line, then hand the options to
-- 'optionHandler'.
main :: IO ()
main = getOpts >>= optionHandler

-- | Validate the parsed options and, if they are acceptable, run 'exec'.
--
-- On the first failed check the matching message is printed and the process
-- exits with status 1:
--
-- * 'trainFile' must not be empty;
-- * 'lambda' must lie in (0, 1e4];
-- * 'epochs' must lie in (0, 1e6].
optionHandler :: SgdOpts -> IO ()
optionHandler opts@SgdOpts{..} = do
    when (null trainFile) $ bail "--trainFile is blank!"
    -- Written as "not in range" rather than "below or above": every ordering
    -- comparison with NaN is False, so the latter form would accept
    -- --lambda=NaN.
    when (not (lambda > 0 && lambda <= 10000)) $ bail "--lambda must be in (0, 1e4]"
    when (epochs <= 0 || epochs > 1000000) $ bail "--epochs must be in (0, 1e6]"
    exec opts
  where
    -- Print the complaint and terminate with a failure status.
    bail msg = putStrLn msg >> exitWith (ExitFailure 1)

-- | Train a model on the file named by 'trainFile' and report the result.
--
-- Reads the training file, parses it with 'R.read', trains with 'T.train'
-- using 'L.logLoss', prints the model and then the elapsed wall-clock time.
-- If the file cannot be parsed, the program writes @Wrong input format@ to
-- stderr and exits with status 1.
exec :: SgdOpts -> IO ()
exec SgdOpts{..} = do
    contents <- C.readFile trainFile

    -- The clock starts before the parse result is inspected, so any parsing
    -- work forced by the match below stays inside the timed region.
    start <- getCurrentTime

    case R.read contents :: Maybe [R.Sample Int] of
        -- Fail explicitly here instead of hiding a partial 'error' in a lazy
        -- binding that would only surface once the model is printed.
        Nothing  -> die "Wrong input format"
        Just dat -> do
            let model = T.train L.logLoss dat lambda epochs
            --let predLoss = T.testMany (L.loss L.logLoss ) dat lambda model
            -- 'model' is a lazy binding; printing it is what forces training.
            print model
            --print predLoss

            end <- getCurrentTime
            putStrLn $ "Execution Time: " ++ show (diffUTCTime end start)
    



