module PSO
  ( PSOState(..)
  , runPSO
  ) where

import Control.Monad.State
import Control.Parallel.Strategies
import Control.DeepSeq (deepseq)
import Particle as P

-- | State threaded through the PSO loop by 'runPSO' and 'step'.
data PSOState = PSOState
  { swarm        :: [Particle]          -- ^ Particles of the swarm; replaced on every 'step'
  , swarmSize    :: Int                 -- ^ Size of the particle swarm (not read or updated in this module)
  , costFn       :: [Double] -> Double  -- ^ Cost function, handed to 'P.update'; the best particle is
                                        --   the one with the smallest 'P.valMin' (see 'minParticle')
  , searchDomain :: [(Double, Double)]  -- ^ Search domain of the cost function, handed to 'P.update';
                                        --   presumably one @(low, high)@ pair per dimension — TODO confirm
  , currOptima   :: [Double]            -- ^ Current global optimum (a position); recomputed by every 'step'
  , trueOptima   :: [Double]            -- ^ Known true optimum; not read in this module (presumably used
                                        --   elsewhere to evaluate the result — verify against callers)
  , timestamp    :: Int                 -- ^ Iteration counter; incremented by one on every 'step'
  , weight       :: Double              -- ^ Weight parameter handed to 'P.update' (presumably the inertia
                                        --   weight — confirm in the Particle module)
  , psoLog       :: [String]            -- ^ Trace log, newest entry first; each entry is
                                        --   @show (timestamp, optimum)@ using the pre-increment timestamp
  }

-- | @runPSO strat chunksize n@ runs 'step' @n@ times, threading the
-- 'PSOState' through every iteration.
--
-- @strat@ (the parallelisation strategy index) and @chunksize@ are passed
-- unchanged to 'step'. A non-positive @n@ performs no iterations; guarding
-- with @n <= 0@ rather than matching only @0@ guarantees that a negative
-- count terminates instead of recursing forever.
runPSO :: Int -> Int -> Int -> State PSOState ()
runPSO strat chunksize n
  | n <= 0    = return ()
  | otherwise = do
      step strat chunksize
      -- go to next iteration
      runPSO strat chunksize (n - 1)

{- | @step strat chunksize@ runs one PSO iteration.

   Every particle is updated with 'P.update', using the current cost
   function, search domain, global optimum and weight. The new global
   optimum is the 'P.posBest' of the updated particle with the smallest
   'P.valMin' (see 'minParticle'). This step then:

   * stores the new optimum in 'currOptima';
   * increments 'timestamp';
   * prepends @(timestamp, newOptimum)@ to 'psoLog', using the
     pre-increment timestamp.

   @strat@ selects how the particle updates are evaluated:

   * @1@: split the swarm into sub-swarms of @chunksize@ particles and
     spark each fully evaluated sub-swarm with 'rpar';
   * @2@: 'parList' 'rdeepseq', one spark per particle (@chunksize@ unused);
   * @3@: 'parListChunk' @chunksize@ 'rdeepseq', one spark per chunk;
   * anything else: plain sequential 'map'.

   So @chunksize@ is the number of particles grouped into one parallel
   task, not the number updated simultaneously. A non-positive @chunksize@
   is clamped to @1@, so strategy 1 cannot loop forever splitting off empty
   chunks. An empty swarm keeps the current optimum instead of failing on a
   @(p:ps)@ pattern match.
-}
step :: Int -> Int -> State PSOState ()
step strat chunksize = do
  s       <- gets swarm
  fn      <- gets costFn
  domain  <- gets searchDomain
  currOpt <- gets currOptima
  w       <- gets weight
  let upd   = P.update fn domain currOpt w
      chunk = max 1 chunksize
      s'    = case strat of
        -- vanilla split: one rpar spark per fully forced sub-swarm, then
        -- wait for all of them with rseq
        1 -> runEval $ do
          subSwarms <- mapM (rpar . deep . map upd) (splitN chunk s)
          mapM_ rseq subSwarms
          return $ concat subSwarms
        -- parList: one spark per particle
        2 -> map upd s `using` parList rdeepseq
        -- parListChunk: one spark per chunk of particles
        3 -> map upd s `using` parListChunk chunk rdeepseq
        -- sequential
        _ -> map upd s
      -- after all particles are updated, the best one supplies the new
      -- global optimum; an empty swarm keeps the current optimum
      newOpt = case s' of
        []       -> currOpt
        (p : ps) -> P.posBest $ foldr minParticle p ps
  -- update the PSO state
  modify $ \st -> st { swarm      = s'
                     , currOptima = newOpt
                     , timestamp  = timestamp st + 1
                     , psoLog     = show (timestamp st, newOpt) : psoLog st }

-- | @minParticle p1 p2@ returns whichever particle has the smaller
-- 'P.valMin'. The first argument is returned only when its value is
-- strictly smaller; ties go to the second argument.
minParticle :: Particle -> Particle -> Particle
minParticle lhs rhs =
  let lhsScore = P.valMin lhs
      rhsScore = P.valMin rhs
  in if lhsScore < rhsScore then lhs else rhs

-- | Fully evaluate a value (to normal form) and return it. Wrapping a
-- computation in 'deep' before sparking it with 'rpar' makes the spark do
-- all of the work, not just evaluation to the outermost constructor.
deep :: NFData a => a -> a
deep x = x `deepseq` x

-- | @splitN chunksize xs@ splits @xs@ into consecutive chunks of
-- @chunksize@ elements; the last chunk holds whatever remains. An input no
-- longer than @chunksize@ (including the empty list) yields a single chunk.
--
-- A non-positive @chunksize@ is treated as @1@, mirroring 'parListChunk',
-- which degrades to one spark per element for sizes below 1. Without the
-- clamp, @splitAt 0@ would hand back the whole list as the remainder and
-- the recursion would never terminate. Each step inspects only the
-- remainder from 'splitAt' instead of recomputing 'length', so the split
-- is linear in the input. The element type is generalised from 'Particle'.
splitN :: Int -> [a] -> [[a]]
splitN chunksize = go
  where
    size = max 1 chunksize
    go xs = case splitAt size xs of
      (chunk, [])   -> [chunk]
      (chunk, rest) -> chunk : go rest
