{-all math types and fns needed for ray tracing -}

module RayMath where
  {- Integer pair; resToWin takes the pixel position (px, py) in this form. -}
  type Point2D = (Int, Int)
  {- Position in 3D space as an (x, y, z) triple of Doubles. -}
  type Point3D = (Double, Double, Double)
  {- Direction or displacement as (x, y, z); same representation as Point3D. -}
  type Vector = (Double, Double, Double)
  {- Integer pair; resToWin reads it as (rx, ry) and divides by each component. -}
  type Resolution = (Int, Int)

  type Dimension = (Int, Int) {- Screen window resolution, if a GUI is used -}

  {- A ray: start point followed by direction vector.
     instRay builds one with the direction returned by point2Vec. -}
  data Ray = Ray Point3D Vector
  {- Shapes a ray can be intersected with.
     Sphere rad cent : radius first, then centre.
     Plane (a,b,c,d) : rayIntersectWith treats it as the set of points with
                       a*x + b*y + c*z + d = 0, so (a,b,c) is its normal. -}
  data RenderObj = Sphere Double Point3D
                 | Plane (Double,Double,Double,Double)
 
  {- Component-wise arithmetic on triples of Doubles.
     Each operator is written out per component instead of going through a
     fold, which was slower. No fixity is declared for any of them. -}

  {- Component-wise sum. -}
  (<+>) :: (Double, Double, Double) -> (Double, Double, Double) -> (Double, Double, Double)
  (ax, ay, az) <+> (bx, by, bz) = (ax + bx, ay + by, az + bz)

  {- Component-wise difference: left operand minus right operand. -}
  (<->) :: (Double, Double, Double) -> (Double, Double, Double) ->  (Double, Double, Double)
  (ax, ay, az) <-> (bx, by, bz) = (ax - bx, ay - by, az - bz)

  {- Component-wise product. -}
  (<**>) :: (Double, Double, Double) -> (Double, Double, Double) -> (Double, Double, Double)
  (ax, ay, az) <**> (bx, by, bz) = (ax * bx, ay * by, az * bz)

  {- Scale a triple: every component is multiplied by the scalar on the right. -}
  (**>) :: (Double, Double, Double) -> Double -> (Double,Double,Double)
  (vx, vy, vz) **> k = (vx * k, vy * k, vz * k)
   
  {- Dot product of two vectors: the sum of the pairwise component products. -}
  dot :: Vector -> Vector -> Double
  dot (ax, ay, az) (bx, by, bz) = ax * bx + ay * by + az * bz
  
  {- Euclidean length: square root of the vector's dot product with itself. -}
  len :: Vector -> Double
  len v = sqrt selfDot
    where
      selfDot = v `dot` v

  {- Scale a vector to unit length.
     A vector shorter than 10**(-9) is treated as zero and mapped to the zero
     vector, which sidesteps dividing by a near-zero length. -}
  norm :: Vector -> Vector
  norm v
    | magnitude < tiny = (0.0, 0.0, 0.0)
    | otherwise        = v **> (1 / magnitude)
    where
      magnitude = len v
      tiny      = 10**(-9)
  {- Normalised vector pointing from the first point towards the second. -}
  point2Vec :: Point3D -> Point3D -> Vector
  point2Vec src dst = norm offset
    where
      offset = dst <-> src

  {- Euclidean distance between two points. -}
  dist :: Point3D -> Point3D -> Double
  dist p0 p1 = sqrt (delta `dot` delta)
    where
      delta = p1 <-> p0
 
  {- Clamping helpers used to keep a colour inside [0,1]. -}

  {- Raise every component that is below the bound f up to f. -}
  clipUp :: Double -> (Double, Double, Double) -> (Double, Double, Double)
  clipUp f (r, g, b) = (atLeast r, atLeast g, atLeast b)
    where
      atLeast = (`max` f)

  {- Lower every component that is above the bound f down to f. -}
  clipDown :: Double -> (Double, Double, Double) -> (Double, Double, Double)
  clipDown f (r, g, b) = (atMost r, atMost g, atMost b)
    where
      atMost = (`min` f)

  {- Clamp a colour to [0,1]: cap at 1.0 first, then floor at 0.0. -}
  colorClip :: (Double, Double, Double) -> (Double, Double, Double)
  colorClip rgb = clipUp 0.0 (clipDown 1.0 rgb)

  {- Build the ray that starts at the first point and heads towards the second;
     its direction is whatever point2Vec returns for that pair. -}
  instRay :: Point3D -> Point3D -> Ray
  instRay origin target = Ray origin heading
    where
      heading = point2Vec origin target

  {- Real solutions of a*x^2 + b*x + c = 0, used for ray intersection.

     Argument: the coefficient triple (a, b, c).
     Result:
       a /= 0, discriminant < 0  -> []
       a /= 0, discriminant > 0  -> two roots, (-b - sqrt d)/(2a) first
       a /= 0, discriminant == 0 -> the single repeated root
       a == 0, b /= 0            -> the single root of the linear equation
       a == 0, b == 0            -> [] (no isolated root exists)

     The a == 0 cases are handled separately because the quadratic formula
     divides by 2*a and would otherwise return NaN or infinite "roots". -}
  solveQuad :: (Double,Double,Double) -> [Double]
  solveQuad (a,b,c)
    | a == 0 = linearRoots
    | d<0 = []
    | d>0 = [(-b - sqrt d)/(2*a),(-b+sqrt d)/(2*a)]
    | otherwise = [-b/(2*a)]
    where
      d = b*b - 4*a*c
      {- Degenerate case: the equation reduces to b*x + c = 0. -}
      linearRoots
        | b == 0 = []
        | otherwise = [-c/b]

  {- Ray parameters t at which the point start + t*dir lies on the object.

     Sphere: substitutes the ray into
       (x-cenX)^2 + (y-cenY)^2 + (z-cenZ)^2 = rad^2
     and hands the resulting quadratic in t to solveQuad.
     Plane (a,b,c,d): solves a*x + b*y + c*z + d = 0 for t; yields [] when
     the ray is (numerically) parallel to the plane. -}
  rayIntersectWith :: Ray -> RenderObj -> [Double]
  rayIntersectWith (Ray start dir) obj = case obj of
    Sphere rad cent ->
      let offset = start <-> cent
          qa = dir `dot` dir
          qb = 2*(dir `dot` offset)
          qc = (offset `dot` offset) - rad*rad
      in solveQuad (qa, qb, qc)
    Plane (a,b,c,d)
      | abs denom < 10**(-9) -> []
      | otherwise -> [negate ((d + (planeN `dot` start)) / denom)]
      where
        planeN = (a,b,c)
        denom = planeN `dot` dir
  
  {- Unit surface normal of an object.
     Sphere: the offset from the centre to the given point, divided by the
             radius and then normalised.
     Plane:  the normalised (a,b,c) coefficients; the point is ignored. -}
  normal :: Point3D -> RenderObj -> Vector
  normal p obj = case obj of
    Sphere rad cent -> norm (radial **> (1/rad))
      where
        radial = p <-> cent
    Plane (a,b,c,_) -> norm (a,b,c)
  {- Mirror the direction i about the normal n: i - 2 (n . i) n.
     Both arguments are expected to be normalised. -}
  reflectDir :: Vector -> Vector -> Vector
  reflectDir i n = i <-> bounce
    where
      twiceProj = 2*(n `dot` i)
      bounce = n **> twiceProj

  {- Refracted direction for a normalised incident direction i, a normalised
     surface normal n and a refractive index r (vector form of Snell's law).

     c = -(n . i) is the cosine between the reversed incident direction and n.
       c >= 0 : the ray arrives against the normal; index ratio is 1/r.
       c <  0 : the ray travels along the normal (inside the sphere, leaving
                it); index ratio is r, and the normal is flipped so that it
                opposes the ray. Taking abs c alone is not enough: without the
                flip the normal component gets the wrong sign and every
                non-perpendicular exit ray is bent by the wrong angle.

     Returns the zero vector when v < 0, i.e. on total internal reflection. -}
  refractDir :: Vector -> Vector -> Double -> Vector
  refractDir i n r
    | v < 0 = (0.0,0.0,0.0)
    | otherwise = norm $ (i **> rc) <+> (facing **> (rc*(abs c) - sqrt v))
    where
      c = n `dot` (i **> (-1))
      inside = c < 0 -- if cosine < 0, inside sphere
      rc
        | inside = r
        | otherwise = 1/r
      {- normal oriented against the incoming ray -}
      facing
        | inside = n **> (-1)
        | otherwise = n
      {- 1 - rc^2 * sin^2(incidence); negative means no transmitted ray -}
      v = 1+(rc*rc) * (c*c -1)

  {- Map a pixel position to a point on the z = 0 plane.
     Per axis: shift the pixel coordinate by half the resolution, multiply by
     the matching window dimension, then divide by the resolution.
     The z component of the result is always 0.0. -}
  resToWin :: Resolution -> Dimension -> Point2D -> Point3D
  resToWin (rx,ry) (w,h) (px,py) = (axis px rx w, axis py ry h, 0.0)
    where
      axis :: Int -> Int -> Int -> Double
      axis pix res win = ((pixD - resD/2) * winD) / resD
        where
          pixD = fromIntegral pix
          resD = fromIntegral res
          winD = fromIntegral win
