module MDEngine
  ( forceMatrix,
    velocityVerlet,
    isf,
    mdIsf,
    mdTraj
  ) where

import MDVector
import qualified Data.Vector.Unboxed as U
import Control.Parallel.Strategies (Eval, rpar, runEval)

-- Simulation Parameters

-- | Energy scale of the pair interaction: the prefactor of the
-- Lennard-Jones force computed in 'forceMatrix'.
epsilon :: Double
epsilon = 1

-- | Length scale of the pair interaction.
--
-- The force in 'forceMatrix' vanishes where @2 * (sigma / r) ^ 6 == 1@, i.e.
-- at @r_min = 2 ^ (1/6) * sigma@; this value of @sigma@ places that point at
-- @r_min = 1@.
sigma :: Double
sigma = 2 ** (-1 / 6)

-- | Particle mass; a single value shared by every particle, used by
-- 'velocityVerlet' to turn forces into accelerations.
mass :: Double
mass = 1

-- | Integration timestep advanced by one call to 'velocityVerlet'.
dt :: Double
dt = 0.001

-- | Parallel map over an unboxed vector.
--
-- Each application @f x@ is passed to 'rpar', which sparks its evaluation so
-- the runtime may compute the elements in parallel.  The results are returned
-- in the same order as the inputs; an empty input yields an empty vector.
--
-- The sparks are created while walking a list view of the input, and the
-- output vector is built once with 'U.fromList'.  Assembling it with one
-- 'U.cons' per element would copy the vector at every step and make the map
-- quadratic in the vector length.
--
-- NOTE(review): 'rpar' evaluates only to weak head normal form; whether that
-- fully computes an element depends on the element type -- confirm for the
-- types this is used with.
parMapVec :: (U.Unbox a, U.Unbox b) => (a -> b) -> U.Vector a -> Eval (U.Vector b)
parMapVec f v = U.fromList <$> mapM (rpar . f) (U.toList v)

-- | Net pair force on every particle of a configuration.
--
-- @rs@ holds the particle positions and @boxLength@ is forwarded unchanged to
-- 'displacement'.  Entry @i@ of the result is the sum, over every position in
-- @rs@, of the force that position exerts on particle @i@.  The per-particle
-- sums are evaluated through 'parMapVec'.
forceMatrix :: MDMatrix -> Double -> MDMatrix
forceMatrix rs boxLength = runEval (parMapVec netForceOn rs)
  where
    -- Sum of the pair forces acting on the particle at ri, taken over every
    -- position in the configuration (including ri itself, which adds zero).
    netForceOn ri = U.foldr vectorAdd zeroVector (U.map (pairForce ri) rs)

    -- Force on the particle at ri due to the particle at rj.  Equal positions
    -- contribute the zero vector: this excludes a particle's interaction with
    -- itself, and also silences any two distinct particles that coincide.
    pairForce ri rj
      | ri == rj  = zeroVector
      | otherwise = vectorMultiply magnitude (unitVector separation)
      where
        separation = displacement rj ri boxLength
        dist       = vectorNorm separation
        ratio      = sigma / dist
        -- Lennard-Jones force magnitude:
        --   24 * epsilon * (2 * (sigma / d) ^ 12 - (sigma / d) ^ 6) / d
        magnitude  = 24.0 * epsilon * (2 * (ratio ** 12.0) - (ratio ** 6.0)) / dist

-- | One velocity-Verlet step of length 'dt'.
--
-- Takes the positions, velocities and forces at the current time together
-- with the box length, and returns the triple
-- @(positions, velocities, forces)@ one timestep later:
--
-- > r' = r + dt * v + dt^2 / (2 * mass) * f
-- > f' = forceMatrix r' boxLength
-- > v' = v + dt / (2 * mass) * f + dt / (2 * mass) * f'
--
-- The forces passed in must be those of the positions passed in; they are
-- reused rather than recomputed.
velocityVerlet :: MDMatrix -> MDMatrix -> MDMatrix -> Double -> (MDMatrix,MDMatrix,MDMatrix)
velocityVerlet rt1 vt1 ft1 boxLength =
  let -- Coefficient of the force in the position update.
      posCoeff = (dt ** 2.0) / (2.0 * mass)
      -- Coefficient of each of the two forces in the velocity update.
      velCoeff = dt / (2.0 * mass)
      rt2 = (rt1 `matrixAdd` matrixMultiply dt vt1) `matrixAdd` matrixMultiply posCoeff ft1
      ft2 = forceMatrix rt2 boxLength
      vt2 = (vt1 `matrixAdd` matrixMultiply velCoeff ft1) `matrixAdd` matrixMultiply velCoeff ft2
  in (rt2, vt2, ft2)

-- | Intermediate scattering function value between two configurations.
--
-- For wave vector @k@, reference configuration @r0@ and later configuration
-- @rt@, this is the average over particles of @cos (k . (rt_i - r0_i))@.
-- Particles are paired by index.  With no particles the average is @0 / 0@,
-- i.e. NaN.
isf :: MDVector -> MDMatrix -> MDMatrix -> Double
isf k r0 rt = U.sum phases / fromIntegral (U.length phases)
  where
    -- Per-particle cosine of k dotted with the displacement since r0.
    phases = U.map (cos . dotProduct k) (U.zipWith vectorSubtract rt r0)

-- | Self intermediate scattering function along a trajectory.
--
-- Starting from positions @r0@ and velocities @v0@, the system is advanced
-- with 'velocityVerlet' in a box of side @boxLength@.  The result has one
-- entry per timestep: element @n@ is @isf k r0 r_n@, where @r_n@ is the
-- configuration after @n@ integration steps, so the first element compares
-- @r0@ with itself.
--
-- The list has exactly @timesteps@ elements and is empty when @timesteps@ is
-- zero or negative.  It is produced lazily: the configuration following the
-- last reported one is never computed.
mdIsf :: MDMatrix -> MDMatrix -> Int -> Double -> MDVector -> [Double]
mdIsf r0 v0 timesteps boxLength k =
  map correlate (take timesteps (iterate advance (r0, v0, f0)))
  where
    -- Forces belonging to the initial configuration.
    f0 = forceMatrix r0 boxLength
    -- Advance a (positions, velocities, forces) state by one timestep.
    advance (rt, vt, ft) = velocityVerlet rt vt ft boxLength
    -- Correlate a state's positions with the initial configuration.
    correlate (rt, _, _) = isf k r0 rt

-- | Trajectory of the first particle.
--
-- Starting from positions @r0@ and velocities @v0@, the system is advanced
-- with 'velocityVerlet' in a box of side @boxLength@.  Element @n@ of the
-- result is the position of the first particle after @n@ integration steps,
-- so the first element is its initial position.
--
-- The list has exactly @timesteps@ elements and is empty when @timesteps@ is
-- zero or negative.  It is produced lazily: the configuration following the
-- last reported one is never computed.
--
-- The configuration must contain at least one particle: 'U.head' raises an
-- error on an empty vector as soon as an element of the result is evaluated.
mdTraj :: MDMatrix -> MDMatrix -> Int -> Double -> [MDVector]
mdTraj r0 v0 timesteps boxLength =
  map firstParticle (take timesteps (iterate advance (r0, v0, f0)))
  where
    -- Forces belonging to the initial configuration.
    f0 = forceMatrix r0 boxLength
    -- Advance a (positions, velocities, forces) state by one timestep.
    advance (rt, vt, ft) = velocityVerlet rt vt ft boxLength
    -- Position of the first particle in a state.
    firstParticle (rt, _, _) = U.head rt
