module Train (train) where

import qualified Read as R
import Vector
import LossFun 
import Control.Parallel.Strategies
import Data.List (foldl')

-- | A training sample: the sample type from "Read" with its type parameter
-- fixed to 'Int'. 'trainOne' pattern-matches it as a pair @(x, y)@.
type Sample = R.Sample Int
-- | Model parameters @(w, wDivisor, wBias)@. 'trainOne' scores an input @x@ as
-- @dotFS w x / wDivisor + wBias@, i.e. the weights are stored together with a
-- separate divisor.
type Model = (FullVector, Double, Double)       -- model parameter (w, wDivisor, wBias)
--type PredLoss = (Double, Double, Double)        -- prediction error on the Model (los, cost, num.error)
-- | A contiguous run of elements taken from a larger list.
type Chunk a = [a]

-- | Split a list into consecutive chunks of at most @n@ elements each.
--
-- Every chunk except possibly the last has exactly @n@ elements, and
-- concatenating the chunks gives back the input. An empty input yields no
-- chunks.
--
-- A non-positive @n@ cannot make progress with 'splitAt' (it would produce an
-- endless stream of empty chunks), so in that case the whole input is returned
-- as a single chunk instead.
chunkData :: Int -> [a] -> [Chunk a]
chunkData _ [] = []
chunkData n xs
  | n <= 0    = [xs]
  | otherwise = go xs
  where
    go [] = []
    go ys = let (chunk, rest) = splitAt n ys in chunk : go rest

-- Constants

-- | Number of samples per chunk; 'train' passes this to 'chunkData' when
-- splitting the data set into the pieces that are trained in parallel.
chunkSize :: Int
chunkSize = 200 -- chosen through multiple tests

-- | Build the starting model for a problem with @dimVar@ input variables.
--
-- The weight vector comes from @rep (dimVar + 1) 0@ (presumably @dimVar + 1@
-- zero entries -- confirm against "Vector"), the divisor starts at 1 and the
-- bias at 0.
initParam :: Int -> Model
initParam dimVar = (weights0, divisor0, bias0)
  where
    weights0 = rep (dimVar + 1) 0
    divisor0 = 1
    bias0    = 0

-- | Train a model by running @epochs@ passes over the samples.
--
-- Arguments, in order: the loss @l@ (its derivative @dloss l@ drives the
-- updates), the training samples @x@, the regularisation weight @lambda@, and
-- the number of epochs.
--
-- In every epoch the samples are cut into chunks of 'chunkSize'. Each chunk is
-- trained independently, starting from the model produced by the previous
-- epoch, and the chunks are evaluated in parallel with @parMap rdeepseq@. The
-- per-chunk models are then merged by 'combineModels', which also receives the
-- initial model.
--
-- The step size is computed once by 'calculateEta' and used unchanged for
-- every sample in every epoch.
train :: Loss -> [Sample] -> Double -> Int -> Model
train l x lambda epochs = foldl' runEpoch initial [1 .. epochs]
  where
    initial  = initParam (R.dimSample x)
    stepSize = calculateEta lambda epochs

    -- The chunking does not depend on the current model, so it is shared by
    -- all epochs.
    batches = chunkData chunkSize x

    -- One epoch; the epoch counter itself is not needed.
    -- (Sequential alternative: runEpoch model _ = fitBatch model x)
    runEpoch model _ =
      combineModels (parMap rdeepseq (fitBatch model) batches) initial

    -- Sequential SGD over a single chunk, starting from the given model.
    fitBatch model batch = foldl' step model batch

    step model sample = trainOne (dloss l) sample lambda model stepSize

-- | Average a list of models component-wise (weights, divisor and bias).
--
-- The second argument is a fallback that is returned unchanged when the list
-- is empty, so averaging zero models no longer divides by zero.
--
-- The fallback is deliberately /not/ added into the sum: the sum is divided by
-- the number of models in the list, so mixing in an extra model (for example
-- an initial model whose divisor is 1) would bias the averaged divisor and
-- thereby shrink the effective weights @w / wDivisor@.
combineModels :: [Model] -> Model -> Model
combineModels [] wParam0 = wParam0
combineModels (firstModel : otherModels) _ =
    scaleModel summedModels numModels
  where
    numModels    = fromIntegral (1 + length otherModels)
    summedModels = foldl' addModels firstModel otherModels

-- | Component-wise sum of two models: the weight vectors are combined with
-- 'addFF', the divisors and the biases are added as plain numbers.
addModels :: Model -> Model -> Model
addModels (wA, divA, biasA) (wB, divB, biasB) = (wSum, divSum, biasSum)
  where
    wSum    = addFF wA wB
    divSum  = divA + divB
    biasSum = biasA + biasB

-- | Divide every component of a model by @scaleFactor@.
--
-- The weight vector is passed to 'scale' together with the reciprocal of the
-- factor; the divisor and the bias are divided directly.
scaleModel :: Model -> Double -> Model
scaleModel (w, divisor, bias) scaleFactor = (scaledW, shrink divisor, shrink bias)
  where
    scaledW  = scale w (1 / scaleFactor)
    shrink v = v / scaleFactor

-- | Step size used for training:
-- @eta0 / (1 + lambda * eta0 * epochs)@, where @lambda@ is the regularisation
-- weight and @epochs@ the total number of epochs.
calculateEta :: Double -> Int -> Double
calculateEta lambda epochs = eta0 / denominator
  where
    eta0        = 1 -- initial step size; change here to tune it
    denominator = 1 + lambda * eta0 * fromIntegral epochs

-- | Apply one SGD update for a single sample.
--
-- Arguments, in order: the loss derivative (called with the current score and
-- the sample's response), the sample @(x, y)@, the regularisation weight
-- @lambda@, the current model, and the step size @eta@.
--
-- Returns the updated @(w, wDivisor, wBias)@.
trainOne :: (Double -> Double -> Double)
          -> Sample
          -> Double
          -> Model
          -> Double
          -> Model
trainOne lossDeriv (features, response) lambda (w, wDiv, wBias) eta =
    (updatedW, renormedDiv, updatedBias)
  where
    -- Score of the sample under the model as it was before this update.
    score = dotFS w features / wDiv + wBias
    grad  = lossDeriv score response

    -- The regularisation step only touches the divisor; 'renorm' folds it
    -- back into the weights once it has grown too large.
    decayedDiv = wDiv / (1 - eta * lambda)
    (renormedW, renormedDiv) = renorm w decayedDiv

    -- The step is multiplied by the divisor because the stored weights are
    -- later divided by it when scoring.
    updatedW = addFS renormedW (mul features (eta * grad * renormedDiv))

    -- The bias is updated with a step that is 100 times smaller.
    updatedBias = wBias + grad * eta * 0.01

-- | Fold the divisor back into the weight vector once it has grown too large.
--
-- While @wDiv@ is at most @1e5@ (which includes the value 1) both inputs are
-- returned untouched. Otherwise the weights are passed through
-- @scale w (1 / wDiv)@ and the divisor is reset to 1.
renorm :: FullVector -> Double -> (FullVector, Double)
renorm w wDiv =
  if wDiv <= 1e5
    then (w, wDiv)
    else (rescaled, 1.0)
  where
    rescaled = scale w (1 / wDiv)

{-wnorm (w, wDiv, wBias) = (Vector.dot w w ) / wDiv / wDiv 

trainMany :: (Double -> Double -> Double)     -- dloss function
          -> [Sample]                         -- list of samples
          -> Double                           -- regularizer 
          -> Model                            -- current model parameter 
          -> [Double]                         -- sgd gain eta for each iteration (or sample)
          -> Model      
trainMany dloss x lambda wParam0 eta= foldl go wParam0 $ zip x eta 
  where
    go wParam (xt, etat) = trainOne dloss xt lambda wParam etat 

testMany :: (Double -> Double -> Double)      -- loss function
          -> [Sample]                         -- list of samples
          -> Double                           -- regularizer 
          -> Model                            -- model parameter 
          -> PredLoss
testMany loss x lambda wParam = (los, cost, nerr)
  where
    los = ploss / fromIntegral (length x)
    nerr = (fromIntegral pnerr) / fromIntegral (length x)
    cost = los + 0.5 * lambda * (wnorm wParam)
    (ploss, pnerr) = (\(t1, t2, _) -> (sum t1, sum t2)) . unzip3 . map go $ x
      where 
        go x = testOne loss x lambda wParam
   
testOne :: (Double -> Double -> Double)    -- loss function
          -> (SparseVector, Double)        -- single input and response (x, y)
          -> Double                        -- regularizer 
          -> Model                         -- model parameter 
          -> (Double, Int, Double)
testOne loss (x, y) lambda (w, wDiv, wBias) = (ploss, pnerr, s)
  where 
    s = (dotFS w x) / wDiv + wBias
    ploss = loss s y
    pnerr = if s * y <= 0 then 1 else 0
-}
