module MDVector
  ( MDVector,
    MDMatrix,
    zeroVector,
    vectorAdd,
    vectorSubtract,
    vectorMultiply,
    dotProduct,
    vectorNorm,
    displacement,
    distance,
    unitVector,
    matrixAdd,
    matrixMultiply
  ) where

import Data.Fixed (mod')
import qualified Data.Vector.Unboxed as U

-- | The fundamental 3-vector type for physical objects, stored as a plain
-- triple of @(x, y, z)@ components.
type MDVector = (Double, Double, Double)

-- | A collection of 'MDVector's held in an unboxed vector (one 'MDVector'
-- per row).
type MDMatrix = U.Vector MDVector

-- | The all-zero 'MDVector', provided as a convenience constant.
zeroVector :: MDVector
zeroVector = (0, 0, 0)

-- | Component-wise sum of two 'MDVector's.
vectorAdd :: MDVector -> MDVector -> MDVector
vectorAdd (ax, ay, az) (bx, by, bz) = (ax + bx, ay + by, az + bz)

-- | Component-wise difference of two 'MDVector's: the first argument minus
-- the second.
vectorSubtract :: MDVector -> MDVector -> MDVector
vectorSubtract (ax, ay, az) (bx, by, bz) = (ax - bx, ay - by, az - bz)

-- | Scales every component of an 'MDVector' by the given scalar.
vectorMultiply :: Double -> MDVector -> MDVector
vectorMultiply s (x, y, z) = (s * x, s * y, s * z)

-- | Dot (inner) product of two 'MDVector's: the sum of the products of
-- corresponding components.
dotProduct :: MDVector -> MDVector -> Double
dotProduct (ax, ay, az) (bx, by, bz) = xx + yy + zz
  where
    -- Summed left to right, (xx + yy) + zz, exactly like a plain chain of (+).
    xx = ax * bx
    yy = ay * by
    zz = az * bz

-- | Euclidean length of an 'MDVector': the square root of its dot product
-- with itself.
vectorNorm :: MDVector -> Double
vectorNorm v = sqrt (v `dotProduct` v)

-- | Wrapped displacement vector from @v1@ to @v2@ (i.e. @v2 - v1@) in a
-- periodic cubic box of side @boxLength@.
--
-- Each component @c@ is shifted by half a box, reduced with 'mod'', and
-- shifted back. This maps it into the interval
-- @[-boxLength/2, boxLength/2)@, up to floating-point rounding.
--
-- NOTE(review): assumes @boxLength > 0@; zero or negative box lengths are
-- not guarded against here.
-- NOTE(review): 'mod'' is the generic 'Real'-class function from
-- "Data.Fixed"; if this is called in a per-pair inner loop it may be worth
-- profiling.
displacement :: MDVector -> MDVector -> Double -> MDVector
displacement v1 v2 boxLength =
  case vectorSubtract v2 v1 of
    (dx, dy, dz) -> (wrap dx, wrap dy, wrap dz)
  where
    half = boxLength / 2.0
    wrap c = mod' (c + half) boxLength - half

-- | Distance between two position vectors, measured as the length of their
-- wrapped 'displacement' in a periodic box of side @boxLength@.
distance :: MDVector -> MDVector -> Double -> Double
distance vec1 vec2 = vectorNorm . displacement vec1 vec2

-- | Unit vector pointing in the direction of the given 'MDVector', or
-- 'zeroVector' when the input is exactly the zero vector.
--
-- The common case multiplies by the reciprocal of 'vectorNorm', exactly as
-- before. That reciprocal is not a positive finite number in two cases:
--
-- * the squared norm underflowed to 0 for very small components
--   (previously giving Infinity/NaN components);
-- * it overflowed to Infinity for very large components (previously giving
--   a spurious zero vector).
--
-- In those cases the vector is first divided by its largest-magnitude
-- component so its norm can be computed without underflow or overflow.
-- Inputs containing NaN or infinite components yield NaN components.
unitVector :: MDVector -> MDVector
unitVector vec@(x, y, z)
  | vec == zeroVector = zeroVector
  | inv > 0 && not (isInfinite inv) = vectorMultiply inv vec
  | otherwise = vectorMultiply (1.0 / vectorNorm scaled) scaled
  where
    inv = 1.0 / vectorNorm vec
    -- Dividing by the largest component magnitude puts every finite component
    -- in [-1, 1] with at least one of magnitude 1. The rescaled norm therefore
    -- lies in [1, sqrt 3] and its reciprocal is safe to use.
    biggest = max (abs x) (max (abs y) (abs z))
    scaled = (x / biggest, y / biggest, z / biggest)

-- | Row-wise sum of two 'MDMatrix' values: corresponding 'MDVector's are
-- combined with 'vectorAdd'.
--
-- NOTE(review): 'U.zipWith' stops at the shorter input, so matrices of
-- different lengths are silently truncated rather than rejected.
matrixAdd :: MDMatrix -> MDMatrix -> MDMatrix
matrixAdd = U.zipWith vectorAdd

-- | Scales every row of an 'MDMatrix' by the given scalar, using
-- 'vectorMultiply'. This is a scalar-times-matrix product, not a
-- matrix-matrix product.
matrixMultiply :: Double -> MDMatrix -> MDMatrix
matrixMultiply c = U.map (vectorMultiply c)
