module Tree where
import Data.List
import Control.DeepSeq
import Data.List.Split
import Control.Parallel.Strategies


-- | An attraction point the tree grows towards.
--
-- The 'Bool' is the "reached" flag: 'closestBranch' sets it to True when the
-- nearest branch is within the minimum distance. 'None' is an inert
-- placeholder: 'detectLeaves' treats it as out of range and 'closestBranch'
-- maps it to @(None, Empty)@.
data Leaf = Leaf Point Bool | None


-- | A 2-D (x, y) pair, used both for positions and for direction vectors.
type Point = (Float,Float)
-- | A tree segment.
--
-- 'Empty' is a sentinel: it is the root's parent (see 'initialTree') and is
-- what 'closestBranch' returns when no branch is close enough to a leaf.
data Branch = Empty | Branch {
                            position :: Point, -- location of this segment
                            parent :: Branch, -- segment this one grew from ('Empty' for the root)
                            direction :: Point -- Vector Representation of direction (growth direction; (0,1) for the root)
                            }

-- | Growth state of the simulation.
--
-- 'DONE' wraps a tree for which 'grow' produced no new branches;
-- 'nextBranch' then returns it unchanged. The record selectors below are
-- partial: applying any of them to a 'DONE' value is a runtime error.
--
-- NOTE(review): the 'Leaves' and 'Branches' aliases are not defined in this
-- view. 'initialTree' stores a @[Leaf]@ and a @[Branch]@ in these fields, so
-- they presumably alias those list types — confirm.
data Tree = DONE Tree | Tree {
                        leaves :: Leaves, -- attraction points still in play
                        root :: Branch, -- the initial seed branch
                        branches  :: Branches, -- every branch, newest first ('addBranch' prepends)
                        max_dist :: Float, -- leaves at or beyond this distance from a branch are ignored
                        min_dist :: Float, -- leaves at or within this distance count as reached
                        window_size :: Float, -- in this view only used to place the root at (0, -size/2)
                        detected :: Bool -- False: 'step' extends the trunk; True: 'grow' takes over
                        }
{-
  Point Arithmetic Helpers
-}
-- | Component-wise vector addition.
add :: Point -> Point -> Point
add (ax, ay) (bx, by) = (ax + bx, ay + by)

-- | Component-wise vector subtraction: the first argument minus the second.
sub :: Point -> Point -> Point
sub (ax, ay) (bx, by) = (ax - bx, ay - by)

-- | Divide both components of a vector by a scalar.
vdiv :: Point -> Float -> Point
vdiv (vx, vy) k = (vx / k, vy / k)

-- | Scale a vector by a scalar: multiply both components by @a@.
--
-- This is the inverse of 'vdiv' (for non-zero @a@), so it must multiply,
-- not divide.
vmult :: Point -> Float -> Point
vmult (x1, y1) a = (x1 * a, y1 * a)

-- | Euclidean distance between two points.
distance :: Point -> Point -> Float
distance (ax, ay) (bx, by) = sqrt (dx * dx + dy * dy)
  where
    dx = ax - bx
    dy = ay - by

-- | Scale a 2-D vector to unit length.
--
-- No zero check: a zero vector divides 0 by 0 and yields NaN components.
normalize :: Floating b => (b, b) -> (b, b)
normalize (x, y) = (x / len, y / len)
  where
    len = sqrt (x * x + y * y)

{-
  Tree,Branch,Leaf helpers
-}


-- | Wrap each coordinate pair as a not-yet-reached 'Leaf', evaluating the
-- resulting elements in parallel.
pointsToLeaves :: [(Float, Float)] -> [Leaf]
pointsToLeaves pts = parMap rseq toLeaf pts
  where
    toLeaf (px, py) = Leaf (px, py) False


-- | Check if a Branch is a real segment: False for the 'Empty' sentinel,
-- True for every 'Branch' record.
--
-- Uses a wildcard for the catch-all case instead of binding a variable named
-- @otherwise@, which would shadow Prelude's 'otherwise'.
notEmpty :: Branch -> Bool
notEmpty Empty = False
notEmpty _ = True

-- | Build the starting 'Tree' for a set of attraction points.
--
-- Arguments: the points, the window size, the maximum (detection) distance
-- and the minimum (reached) distance. The tree starts with a single root
-- branch at @(0, -size/2)@ pointing straight up, @(0, 1)@, with no parent.
initialTree :: [(Float, Float)] -> Float -> Float -> Float -> Tree
initialTree pts size maxD minD =
  Tree
    { leaves = pointsToLeaves pts
    , root = seed
    , branches = [seed]
    , max_dist = maxD
    , min_dist = minD
    , window_size = size
    , detected = False
    }
  where
    seed = Branch {position = (0, negate (size / 2)), parent = Empty, direction = (0, 1)}

-- | Prepend one branch, making it the newest (head) branch of the tree.
addBranch :: Tree -> Branch -> Tree
addBranch tree branch =
  let grown = branch : branches tree
   in tree {branches = grown}

-- | Prepend a list of branches (in the given order) in front of the tree's
-- existing branches.
addBranches :: Tree -> [Branch] -> Tree
addBranches tree b = tree {branches = foldr (:) (branches tree) b}


-- | Is any leaf strictly closer than @maxDist@ to the branch's position?
--
-- The per-leaf checks are evaluated in parallel. The reached flag of a leaf
-- is not consulted, and 'None' leaves never count as in range.
detectLeaves :: Branch -> [Leaf] -> Float -> Bool
detectLeaves branch lvs maxDist = or (parMap rseq inRange lvs)
  where
    inRange None = False
    inRange (Leaf (x, y) _) = distance (x, y) (position branch) < maxDist
                      
-- | Find the branch nearest to a leaf and decide how the leaf affects growth.
--
-- Returns the leaf (with its reached flag recomputed) paired with a
-- candidate branch:
--
--   * distance >= @maxDist@: leaf not reached, candidate is 'Empty';
--   * distance <= @minDist@: leaf marked reached, candidate still produced;
--   * otherwise: leaf not reached, candidate produced.
--
-- The candidate is a copy of the nearest branch (same position) whose parent
-- is that branch and whose direction is the unit vector towards the leaf.
-- If the leaf lies exactly on the branch, that direction is NaN (see
-- 'normalize'). A 'None' leaf yields @(None, Empty)@. Partial: 'minimumBy'
-- fails on an empty branch list.
closestBranch :: Leaf -> [Branch] -> Float -> Float -> (Leaf, Branch)
closestBranch None _ _ _ = (None, Empty)
closestBranch (Leaf (x, y) _) br minDist maxDist
  | gap >= maxDist = (Leaf target False, Empty)
  | gap <= minDist = (Leaf target True, candidate)
  | otherwise = (Leaf target False, candidate)
  where
    target = (x, y)
    distTo b = distance (position b) target
    nearest = minimumBy (\a b -> compare (distTo a) (distTo b)) br
    gap = distTo nearest
    candidate = nearest {parent = nearest, direction = normalize (sub target (position nearest))}
-- | Collapse a group of candidate branches that share a parent into one new
-- branch.
--
-- The new direction is the normalized sum of the parent's current direction
-- and every candidate's direction. The new branch is placed one unit step
-- from the first candidate's position and keeps that candidate's parent.
-- The sum is divided by the group length before normalizing; that positive
-- scaling cancels out in 'normalize'. Partial: uses 'head', so the group must
-- be non-empty.
averageDir :: [Branch] -> Branch
averageDir brches =
  Branch {position = add (position ref) heading, parent = parent ref, direction = heading}
  where
    ref = head brches
    total = foldl' (\acc b -> add acc (direction b)) (direction (parent ref)) brches
    count = fromIntegral (length brches)
    heading = normalize (vdiv total count)

-- | Merge candidate branches that grow from the same branch into one new
-- branch each (via 'averageDir').
--
-- Candidates are keyed by 'position': each one sits at the position of the
-- branch it grows from (see 'closestBranch'). 'groupBy' only merges
-- *adjacent* equal elements, so the candidates are sorted by position first.
-- Without sorting, leaves attracting the same branch but listed
-- non-contiguously would spawn several separate branches from it instead of
-- one averaged branch.
calculateNewBranches :: [Branch] -> [Branch]
calculateNewBranches closests = map averageDir grouped
  where
    grouped = groupBy branchPos (sortOn position closests)
    branchPos a b = position a == position b


-- | One trunk-growing step, used until some leaf comes within 'max_dist'.
--
-- Looks at the newest branch (head of 'branches'). If some leaf is in range,
-- it only sets 'detected' to True. Otherwise it adds a new branch one
-- 'direction' vector further along, keeping the same direction.
-- Partial: fails on a 'DONE' tree or an empty branch list.
step :: Tree -> Tree
step tree
  | detectLeaves tip (leaves tree) (max_dist tree) = tree {detected = True}
  | otherwise = addBranch tree extension
  where
    tip = head (branches tree)
    extension =
      Branch
        { position = add (position tip) (direction tip)
        , parent = tip
        , direction = direction tip
        }

-- | One space-colonization step, used once leaves have been detected.
--
-- Each unreached leaf picks its nearest branch ('closestBranch'), evaluated
-- in parallel. The resulting candidates are merged per branch
-- ('calculateNewBranches') and prepended to the tree. The leaf list is
-- replaced by the leaves considered this round, so leaves reached in an
-- earlier round are dropped. When no new branch is produced, the tree is
-- wrapped in 'DONE'. Partial: fails on a 'DONE' tree (record selectors).
grow :: Tree -> Tree
grow tree =
  let unreached = filter isUnreached (leaves tree)
      (newLeaves, closests) =
        unzip
          ( parMap
              rpar
              (\x -> closestBranch x (branches tree) (min_dist tree) (max_dist tree))
              unreached
          )
      filteredClosests = filter notEmpty closests
      newBranches = calculateNewBranches filteredClosests
   in case newBranches of
        [] -> DONE tree
        _ -> addBranches (tree {leaves = newLeaves}) newBranches
  where
    -- Total predicate: a 'None' placeholder can never attract a branch, so
    -- it is dropped instead of hitting an incomplete pattern match.
    isUnreached (Leaf _ reached) = not reached
    isUnreached None = False

-- | Advance the simulation by one step.
--
-- The first two arguments are ignored (NOTE(review): presumably they exist
-- to fit a simulation-callback signature such as gloss's @simulate@ step —
-- confirm). A 'DONE' tree is returned as-is. Otherwise 'step' extends the
-- trunk until leaves are detected, after which 'grow' is used.
nextBranch :: p1 -> p2 -> Tree -> Tree
nextBranch _ _ finished@(DONE _) = finished
nextBranch _ _ tree
  | detected tree = grow tree
  | otherwise = step tree