{-# LANGUAGE BangPatterns #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# HLINT ignore "Use camelCase" #-}
{-# HLINT ignore "Use ++" #-}

module KnapsackCore
  ( sort_items_vec
  , fractional_bound_vec
  , dfs_best_value_vec
  , build_frontier
  ) where

import Data.List (sortOn)
import Data.Ord (Down (..))
import Debug.Trace (trace)

import qualified Data.Vector as V

import Types (Item, State(..))

-- | Master switch for 'dbg' tracing. Set to 'True' to emit debug messages.
debug_on :: Bool
debug_on = False

-- | @dbg msg x@ always evaluates to @x@. When 'debug_on' is set, @msg@ is
-- emitted via 'trace' at the moment the result is demanded; otherwise the
-- message is ignored. Because output is tied to evaluation, the returned
-- value must actually be used for the message to appear.
dbg :: String -> a -> a
dbg msg x
  | debug_on  = trace msg x
  | otherwise = x

-- | Sort items by value/weight ratio (value density), highest first, and
-- pack them into a vector. The sort is stable, so items with equal ratio
-- keep their input order.
--
-- An item with zero value and zero weight would have ratio @0/0 = NaN@.
-- 'compare' on 'Double' is inconsistent with NaN, which can scramble the
-- relative order of the /other/ items and so invalidate the fractional
-- bound used by the search. Such items are therefore given ratio 0.
-- Other zero-weight items get @+Infinity@ (positive value) or @-Infinity@
-- (negative value), as plain division would give.
sort_items_vec :: [Item] -> V.Vector Item
sort_items_vec items_list =
  V.fromList (sortOn (Down . density) items_list)
  where
    density (v, w)
      | v == 0 && w == 0 = 0
      | otherwise        = v / w

-- | Optimistic estimate of the extra value obtainable from the items at
-- indices @start_index@ onwards, given @remaining_capacity@.
--
-- Items are consumed in vector order. Each is taken whole while it fits;
-- the first item that does not fit is taken fractionally, and the walk
-- stops there. This is a valid upper bound only when the items are sorted
-- by value/weight ratio, descending.
--
-- Returns 0 when the capacity is already non-positive or no items remain.
fractional_bound_vec
  :: Double         -- ^ remaining capacity
  -> Int            -- ^ index of the first item to consider
  -> V.Vector Item  -- ^ items, consumed in index order
  -> Double
fractional_bound_vec remaining_capacity start_index items_vec =
  go start_index remaining_capacity 0
  where
    len = V.length items_vec

    go !i !cap !total
      | cap <= 0 || i >= len = total
      | weight <= cap        = go (i + 1) (cap - weight) (total + value)
      | otherwise            = total + (cap / weight) * value
      where
        -- Only forced once the bounds check above has passed.
        (value, weight) = items_vec V.! i

-- | Pure, sequential branch-and-bound DFS returning the best total value
-- found in the subtree that starts at @index@.
--
-- NOTE: this is used inside each worker/thread in parallel versions.
--
-- The pruning bound comes from 'fractional_bound_vec', which is only an
-- admissible upper bound when @items_vec@ is sorted by value/weight ratio,
-- descending (as 'sort_items_vec' produces).
--
-- The result is never smaller than @best_so_far@.
dfs_best_value_vec
  :: Double         -- ^ capacity
  -> V.Vector Item  -- ^ items (ratio-sorted)
  -> Int            -- ^ index of the next item to decide
  -> Double         -- ^ current_weight accumulated so far
  -> Double         -- ^ current_value accumulated so far
  -> Double         -- ^ best_so_far: incumbent value used for pruning
  -> Double
dfs_best_value_vec capacity items_vec !index !current_weight !current_value !best_so_far
  | current_weight > capacity = best_so_far
  | index >= n_items          = max best_so_far current_value
  | upper <= best_so_far      = best_so_far  -- cannot beat incumbent: prune
  | otherwise                 = best_excl
  where
    n_items = V.length items_vec

    remaining_capacity = capacity - current_weight

    -- Value so far plus the fractional relaxation of the remaining items.
    upper =
      current_value + fractional_bound_vec remaining_capacity index items_vec

    (v, w) = items_vec V.! index

    -- The traces wrap values that are actually demanded. A `let _ = dbg ...`
    -- binding is never forced, so such a trace would never print.
    best_incl =
      dbg (concat ["include? v=", show v, " w=", show w, " idx=", show index]) $
        if current_weight + w <= capacity
          then dfs_best_value_vec capacity items_vec (index + 1)
                 (current_weight + w) (current_value + v) best_so_far
          else best_so_far

    best_after_incl = max best_so_far best_incl

    -- `seq` makes the include branch (and its trace) run before the exclude
    -- trace. The recursive call below is strict in this argument anyway,
    -- so results are unchanged.
    best_excl =
      best_after_incl `seq`
        dbg (concat ["exclude? idx=", show index])
          (dfs_best_value_vec capacity items_vec (index + 1)
             current_weight current_value best_after_incl)

-- | Enumerate the partial solutions (the "frontier") reached by deciding
-- the first @depth_limit@ items, in include-first DFS order.
--
-- Each branch whose accumulated weight stays within @capacity@ yields
-- @State index current_value current_weight@. Branches that become
-- overweight are dropped. A branch stops after @depth_limit@ decisions or
-- when all items have been decided, whichever comes first. A negative
-- @depth_limit@ behaves like 0: at most the root state is returned.
build_frontier
  :: Double           -- capacity
  -> Int              -- depth_limit
  -> V.Vector Item    -- items_vec
  -> [State]
build_frontier capacity depth_limit items_vec = build 0 0 0 0 []
  where
    n_items = V.length items_vec

    -- Prepends this subtree's states to @rest@. This lets the include and
    -- exclude subtrees be chained lazily without (++).
    build
      :: Int      -- depth (number of items decided so far)
      -> Int      -- index of the next item
      -> Double   -- current_weight
      -> Double   -- current_value
      -> [State]  -- states that follow this subtree
      -> [State]
    build depth index current_weight current_value rest
      | current_weight > capacity = rest
      -- `>=` rather than `==`: a negative depth_limit must not fall
      -- through to a full 2^n enumeration.
      | depth >= depth_limit || index >= n_items =
          State index current_value current_weight : rest
      | otherwise =
          let (v, w) = items_vec V.! index
              -- Exclude branch first in evaluation, but emitted after the
              -- include branch's states.
              excl = build (depth + 1) (index + 1) current_weight current_value rest
          in build (depth + 1) (index + 1)
               (current_weight + w) (current_value + v) excl

