module LossFun where

-- | A loss function bundled with its negated derivative.
--
-- Both fields take their arguments in the order @a@, then @y@; the
-- derivative is taken with respect to the first argument @a@.
data Loss = Loss
  { loss  :: Double -> Double -> Double
    -- ^ @loss a y@: the value of the loss at @(a, y)@.
  , dloss :: Double -> Double -> Double
    -- ^ @dloss a y@: @-dloss(a,y)/da@, i.e. the partial derivative of
    -- 'loss' with respect to @a@, negated.
  }
  
-- | The logistic loss, @logloss(a,y) = log(1+exp(-a*y))@.
--
-- Both fields depend on the arguments through the margin @z = a*y@.
-- When @|z| > 18@ the exact expressions are replaced by their
-- leading-order asymptotic forms:
--
--   * 'loss':  @exp (-z)@ if @z > 18@, @-z@ if @z < -18@,
--     otherwise @log (1 + exp (-z))@.
--   * 'dloss': @y * exp (-z)@ if @z > 18@, @y@ if @z < -18@,
--     otherwise @y / (1 + exp z)@, which equals @-dloss(a,y)/da@.
logLoss :: Loss
logLoss = Loss
  { loss  = \a y -> lossOfMargin (a * y)
  , dloss = negGrad
  }
  where
    -- log (1 + exp (-z)); outside [-18, 18] only the dominant term is kept
    -- (exp (-z) when it is tiny, -z when exp (-z) is huge).
    lossOfMargin :: Double -> Double
    lossOfMargin z
      | z > 18    = exp (negate z)
      | z < -18   = negate z
      | otherwise = log (1 + exp (negate z))

    -- -d/da log (1 + exp (-a*y)) = y / (1 + exp (a*y)), with the same
    -- asymptotic shortcuts on the margin z = a*y.
    negGrad :: Double -> Double -> Double
    negGrad a y =
      let z = a * y
      in if z > 18
           then y * exp (negate z)
           else if z < -18
             then y
             else y / (1 + exp z)
