{-# LANGUAGE BangPatterns #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# HLINT ignore "Use camelCase" #-}

module KnapsackIORefPar (solve_instance_ioref) where

import Control.Concurrent (forkFinally, newEmptyMVar, putMVar, takeMVar)
import Control.Exception (throwIO)
import Control.Monad (replicateM, forM_)
import Data.IORef
import qualified Data.Vector as V

import Types (State(..))
import KnapsackCore (sort_items_vec, build_frontier, dfs_best_value_vec)

-- | Solve one knapsack instance with a pool of worker threads.
--
-- The items are sorted with 'sort_items_vec', a list of frontier states is
-- built with 'build_frontier', and the worker threads drain that list via
-- 'worker_loop', sharing the best value found so far through an 'IORef'
-- that starts at 0.
--
-- A non-positive @n_workers@ is treated as 1 so the frontier is always
-- processed. If any worker terminates with an exception, the first such
-- exception is rethrown here (after every worker has reported) instead of
-- returning a possibly incomplete result.
solve_instance_ioref
  :: Double            -- ^ capacity
  -> [(Double,Double)] -- ^ items (value, weight)
  -> Int               -- ^ depth_limit
  -> Int               -- ^ n_workers
  -> IO Double
solve_instance_ioref capacity items_list depth_limit n_workers = do
  let items_vec        = sort_items_vec items_list
      frontier_states  = build_frontier capacity depth_limit items_vec
      -- With zero (or negative) workers nobody would drain the frontier and
      -- the initial 0 would be returned as if it were the answer.
      worker_count     = max 1 n_workers

  best_ref <- newIORef 0
  work_ref <- newIORef frontier_states

  -- One MVar per worker; each receives that worker's final outcome.
  done_vars <- replicateM worker_count newEmptyMVar

  forM_ done_vars $ \done_var -> do
    -- The handler stores the worker's 'Either' outcome rather than
    -- discarding it, so failures can be detected after joining.
    _ <- forkFinally
          (worker_loop capacity items_vec best_ref work_ref)
          (putMVar done_var)
    pure ()

  -- Join every worker first, then surface the first failure (if any).
  outcomes <- mapM takeMVar done_vars
  mapM_ (either throwIO pure) outcomes
  readIORef best_ref


-- | Body of a single worker thread.
--
-- Repeatedly removes the head of the shared task list in @work_ref@, runs
-- 'dfs_best_value_vec' starting from that state, and folds the result into
-- @best_ref@ with 'max'. Returns once the task list is empty.
worker_loop
  :: Double
  -> V.Vector (Double,Double)
  -> IORef Double
  -> IORef [State]
  -> IO ()
worker_loop capacity items_vec best_ref work_ref = next_task
  where
    -- Atomically claim one task; stop when none are left.
    next_task =
      atomicModifyIORef' work_ref pop_task >>= maybe (pure ()) run_task

    -- Split the task list into (remaining tasks, claimed task).
    pop_task []              = ([], Nothing)
    pop_task (task : others) = (others, Just task)

    run_task task = do
      incumbent <- readIORef best_ref
      -- The last argument is the larger of the current shared best and the
      -- task's own value; presumably the search's starting bound -- confirm
      -- against KnapsackCore.
      let initial_best = max incumbent (state_value task)
          !subtree_best =
            dfs_best_value_vec capacity items_vec
              (state_level task)
              (state_weight task)
              (state_value task)
              initial_best

      -- Another worker may have improved best_ref meanwhile, so merge with
      -- 'max' instead of overwriting.
      atomicModifyIORef' best_ref $ \current ->
        let merged = max current subtree_best
        in merged `seq` (merged, ())

      next_task

