module MDVector
  ( MDVector(..),
    zeroVector,
    vectorAdd,
    vectorSubtract,
    vectorMultiply,
    dotProduct,
    vectorNorm,
    displacement,
    distance,
    unitVector,
    matrixMultiply,
    addMatrixList
  ) where

import Data.Fixed (mod')

-- | Fundamental 3-component vector of 'Double's holding the x, y and z
-- coordinates. All fields are strict, so a constructed vector is always
-- fully evaluated; the pragmas ask GHC to store the doubles unboxed.
data MDVector
  = MDVector
      {-# UNPACK #-} !Double -- ^ x component
      {-# UNPACK #-} !Double -- ^ y component
      {-# UNPACK #-} !Double -- ^ z component

-- | Renders a vector as @[ x , y , z ]@, formatting each component with
-- 'show'.
instance Show MDVector where
  show (MDVector px py pz) =
    concat ["[ ", show px, " , ", show py, " , ", show pz, " ]"]

-- | Approximate equality that tolerates small floating-point error: two
-- vectors are equal when every pair of corresponding components differs
-- by less than an absolute tolerance of @1e-10@.
--
-- NOTE: because the tolerance is absolute, this relation is not
-- transitive (@a == b@ and @b == c@ do not imply @a == c@), so it does not
-- satisfy the usual 'Eq' laws. A component that is NaN or infinite makes
-- the difference NaN or infinite, so such vectors never compare equal.
instance Eq MDVector where
  MDVector ax ay az == MDVector bx by bz =
    near ax bx && near ay by && near az bz
    where
      near p q = abs (p - q) < 1e-10

-- | Convenience constant: the vector whose three components are all zero.
zeroVector :: MDVector
zeroVector = MDVector 0 0 0

-- | Component-wise sum of two vectors.
vectorAdd :: MDVector -> MDVector -> MDVector
vectorAdd u v =
  case (u, v) of
    (MDVector ax ay az, MDVector bx by bz) ->
      MDVector (ax + bx) (ay + by) (az + bz)

-- | Component-wise difference: the second vector is subtracted from the
-- first, i.e. @vectorSubtract a b = a - b@.
vectorSubtract :: MDVector -> MDVector -> MDVector
vectorSubtract (MDVector ax ay az) (MDVector bx by bz) =
  MDVector dx dy dz
  where
    dx = ax - bx
    dy = ay - by
    dz = az - bz

-- | Scales every component of a vector by the given scalar factor.
vectorMultiply :: Double -> MDVector -> MDVector
vectorMultiply k (MDVector x y z) = MDVector (scale x) (scale y) (scale z)
  where
    scale = (k *)

-- | Euclidean inner product of two vectors,
-- @ax*bx + ay*by + az*bz@, accumulated left to right.
dotProduct :: MDVector -> MDVector -> Double
dotProduct (MDVector ax ay az) (MDVector bx by bz) =
  ax * bx + ay * by + az * bz

-- | Euclidean length of a vector: the square root of its dot product with
-- itself.
vectorNorm :: MDVector -> Double
vectorNorm v = sqrt (v `dotProduct` v)

-- | Given position vectors @v1@ and @v2@, computes the displacement
-- @v_{12} = v2 - v1@ under periodic boundaries in a box whose side is
-- @boxLength@ along all three axes. Each component is wrapped into the
-- half-open interval @[-L/2, L/2)@, where @L = boxLength@.
--
-- Components that already lie in that interval are returned unchanged.
-- Pushing them through the wrap formula would shift by half a box and back,
-- and that round trip discards low-order bits. For example, with @L = 10@ a
-- component of @1e-17@ would become @0@, and one just below @5@ would round
-- up and come back as @-5@.
--
-- NOTE(review): a @boxLength@ of zero is not supported, because 'mod'' would
-- divide by zero.
displacement :: MDVector -> MDVector -> Double -> MDVector
displacement v1 v2 boxLength =
  wrapDisplacement $ vectorSubtract v2 v1
  where
    wrapDisplacement (MDVector x y z) = MDVector (boxMod x) (boxMod y) (boxMod z)
    halfBox = boxLength / 2.0
    boxMod c
      -- Already in the primary image: exact, and avoids the costly 'mod''.
      | c >= negate halfBox && c < halfBox = c
      -- Otherwise shift into [0, L), reduce modulo L, and shift back.
      | otherwise = mod' (c + halfBox) boxLength - halfBox

-- | Distance between two position vectors in a periodic box with side
-- @boxLength@: the length of their wrapped 'displacement'.
distance :: MDVector -> MDVector -> Double -> Double
distance vec1 vec2 boxLength = vectorNorm wrapped
  where
    wrapped = displacement vec1 vec2 boxLength

-- | Rescales a vector to unit length while keeping its direction.
--
-- A vector that compares equal to 'zeroVector' under the approximate 'Eq'
-- instance (every component within @1e-10@ of zero) is returned as
-- 'zeroVector' instead. This avoids dividing by a zero or near-zero norm.
unitVector :: MDVector -> MDVector
unitVector vec =
  if vec == zeroVector
    then zeroVector
    else vectorMultiply (1.0 / vectorNorm vec) vec

-- | Scales every vector in a list of vectors (a "matrix") by the same
-- scalar factor, preserving order and length.
matrixMultiply :: Double -> [MDVector] -> [MDVector]
matrixMultiply c m = [vectorMultiply c v | v <- m]

-- | Element-wise sum of a base "matrix" @m0@ (a list of vectors) and every
-- matrix in @matrixList@.
--
-- Matrices are combined right to left, so the last matrix in the list is
-- added to @m0@ first. Because rows are paired with 'zipWith', the result
-- is silently truncated to the length of the shortest matrix involved. An
-- empty @matrixList@ returns @m0@ unchanged.
addMatrixList :: [MDVector] -> [[MDVector]] -> [MDVector]
addMatrixList m0 matrixList = foldr addInto m0 matrixList
  where
    addInto m acc = zipWith vectorAdd m acc
