{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# HLINT ignore "Use camelCase" #-}

module KnapsackSeq
  ( knapsackSeq
  , knapsackSeqFromState
  , fractional_bound
  ) where

import Data.List (sortOn)
import qualified Data.Vector as V
import Data.Vector (Vector)
import Types (Item, State(..))

import Control.Parallel (par, pseq)
import Control.DeepSeq (deepseq)
-- import Control.DeepSeq (NFData(..), deepseq)
-- import Debug.Trace

------------------------------------------------------------
-- Fractional-knapsack (greedy) upper bound
------------------------------------------------------------

-- | Upper bound on the value obtainable from items @start_index@ onward
-- when @remaining_capacity@ is left.
--
-- The bound is a greedy fractional fill. Whole items are taken while they
-- fit; then a proportional fraction of the first item that does not fit
-- is added. Items are @(value, weight)@ pairs. The result is only a valid
-- upper bound if @items_sorted@ is ordered by non-increasing value/weight.
--
-- A remaining capacity of exactly 0 does not end the scan. Zero-weight
-- items still fit and their value must be counted; otherwise the bound
-- underestimates and branch-and-bound can prune an optimal branch.
fractional_bound
  :: Double      -- ^ capacity (not used)
  -> Double      -- ^ remaining_capacity
  -> Int         -- ^ start_index
  -> Vector Item -- ^ items_sorted
  -> Double
fractional_bound _ remaining_capacity start_index items_sorted =
  bound remaining_capacity start_index 0
  where
    n_items = V.length items_sorted

    bound rem_cap idx acc_value
      | rem_cap < 0    = acc_value
      | idx >= n_items = acc_value
      -- Item fits entirely: take it and continue.
      | w <= rem_cap   = bound (rem_cap - w) (idx + 1) (acc_value + v)
      -- Item does not fit: here w > rem_cap >= 0, so the division is safe.
      | otherwise      = acc_value + (rem_cap / w) * v
      where
        (v, w) = items_sorted V.! idx


------------------------------------------------------------
-- Sequential branch-and-bound with simple parallelism
------------------------------------------------------------

-- | Solve the 0/1 knapsack problem by depth-first branch-and-bound. Returns
-- the best total value whose total weight does not exceed @capacity@.
--
-- Items are @(value, weight)@ pairs. They are visited in order of
-- non-increasing value density, which is what makes 'fractional_bound' a
-- valid upper bound for pruning. At each node the "include" branch is
-- explored first. Its result becomes the incumbent for the "exclude"
-- branch, so the two branches are inherently sequential.
knapsackSeq
  :: Double   -- ^ capacity
  -> [Item]   -- ^ items
  -> Double
knapsackSeq capacity items = dfs 0 0 0 0
  where
    -- Sorted by descending density; a Vector gives O(1) indexing.
    items_sorted :: Vector Item
    items_sorted = V.fromList (sortOn (negate . density) items)

    n_items = V.length items_sorted

    -- Value per unit weight. A (0, 0) item would give 0/0 = NaN, and NaN
    -- makes 'compare' inconsistent, which can leave the sort out of order.
    -- Such an item contributes nothing, so giving it density 0 is safe.
    density (v, w)
      | v == 0 && w == 0 = 0
      | otherwise        = v / w

    dfs
      :: Int    -- index of the next item to decide
      -> Double -- current_weight
      -> Double -- current_value
      -> Double -- best_so_far (incumbent)
      -> Double
    dfs index current_weight current_value best_so_far
      -- Infeasible node: it cannot improve the incumbent.
      | current_weight > capacity = best_so_far
      | index == n_items          = max best_so_far current_value
      -- Even the optimistic bound cannot beat the incumbent: prune.
      | upper <= best_so_far      = best_so_far
      -- The exclude branch needs the include branch's result as its
      -- incumbent, so evaluate include first, then exclude.
      | otherwise                 = best_incl `pseq` best_excl
      where
        upper =
          current_value
            + fractional_bound capacity (capacity - current_weight) index items_sorted
        (v, w) = items_sorted V.! index
        best_incl =
          max best_so_far
              (dfs (index + 1) (current_weight + w) (current_value + v) best_so_far)
        best_excl = dfs (index + 1) current_weight current_value best_incl


------------------------------------------------------------
-- DFS from a partial state (used for frontier-based parallelism)
------------------------------------------------------------

-- | Resume the branch-and-bound search from a partially decided node.
-- Returns the best total value reachable from it within @capacity@.
--
-- @start_state@ supplies three things:
--
-- * the index of the next undecided item ('state_level') in the sorted
--   item order;
-- * the weight already committed ('state_weight');
-- * the value already committed ('state_value').
--
-- NOTE(review): 'state_level' is only meaningful if the caller built the
-- state against the same ordering used here, i.e. a stable sort on
-- descending value/weight with this exact key. Confirm against the
-- frontier generator before changing the key.
--
-- If the start state already exceeds @capacity@, no completion of it is
-- feasible. The search is then seeded with 0 rather than the state's
-- value, which would otherwise be returned as if it were a valid packing.
knapsackSeqFromState
  :: Double  -- ^ capacity
  -> [Item]  -- ^ items
  -> State   -- ^ partial assignment to resume from
  -> Double
knapsackSeqFromState capacity items start_state =
  dfs (state_level start_state) start_weight start_value seed
  where
    start_weight = state_weight start_state
    start_value  = state_value start_state

    -- Incumbent lower bound: the start state itself, but only if it fits.
    seed
      | start_weight > capacity = 0
      | otherwise               = start_value

    -- Sort key kept identical to the original so item indices stay aligned
    -- with whatever produced 'state_level'.
    items_sorted :: Vector Item
    items_sorted = V.fromList (sortOn (\(v, w) -> -(v / w)) items)

    n_items = V.length items_sorted

    dfs
      :: Int    -- index of the next item to decide
      -> Double -- current_weight
      -> Double -- current_value
      -> Double -- best_so_far (incumbent)
      -> Double
    dfs index current_weight current_value best_so_far
      | current_weight > capacity = best_so_far
      -- '>=' also covers a start level past the end of the item list.
      | index >= n_items          = max best_so_far current_value
      | upper <= best_so_far      = best_so_far
      -- Exclude depends on include's result, so evaluate include first.
      | otherwise                 = best_incl `pseq` best_excl
      where
        upper =
          current_value
            + fractional_bound capacity (capacity - current_weight) index items_sorted
        (v, w) = items_sorted V.! index
        best_incl =
          max best_so_far
              (dfs (index + 1) (current_weight + w) (current_value + v) best_so_far)
        best_excl = dfs (index + 1) current_weight current_value best_incl
