{-# LANGUAGE BangPatterns #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# HLINT ignore "Use camelCase" #-}

module KnapsackStealPool (solve_instance_steal) where

import Control.Concurrent
import Control.Exception (bracket_, mask_, throwIO)
import Control.Monad
import Data.IORef
import Data.List (transpose)

import qualified Data.Vector as V

import Types (State(..))
import KnapsackCore (sort_items_vec, build_frontier, dfs_best_value_vec)

-- | Solve one knapsack instance with a pool of work-stealing worker threads.
--
-- The items are sorted ('sort_items_vec') and expanded into frontier states
-- ('build_frontier' with @depth_limit@). The frontier is dealt round-robin to
-- the workers, and each worker runs 'worker_loop' on its share, stealing from
-- the other workers' deques when its own runs dry. Returns the best value
-- recorded by any worker, starting from 0.
--
-- A non-positive @n_workers@ is treated as 1, so the frontier is always
-- processed. If any worker thread dies with an exception, this function
-- waits for all workers to finish and then rethrows the first such exception
-- (in worker order) instead of returning a possibly incomplete result.
solve_instance_steal
  :: Double            -- ^ capacity, forwarded to the frontier/search code
  -> [(Double,Double)] -- ^ items, in the layout 'sort_items_vec' expects
  -> Int               -- ^ depth limit passed to 'build_frontier'
  -> Int               -- ^ number of worker threads (clamped to >= 1)
  -> IO Double
solve_instance_steal capacity items_list depth_limit n_workers = do
  let n_threads       = max 1 n_workers
      items_vec       = sort_items_vec items_list
      frontier_states = build_frontier capacity depth_limit items_vec
      chunks          = chunk_list_round_robin n_threads frontier_states

  best_mvar  <- newMVar 0
  active_ref <- newIORef 0                   -- #workers currently processing a task

  -- Create and publish every deque *before* forking any worker. Each worker
  -- can then steal from all the others from its first iteration, and its
  -- termination check sees every deque. The list is reversed so that
  -- stealers try the last worker's deque first.
  local_refs   <- mapM newIORef chunks
  victims_mvar <- newMVar (reverse local_refs)

  -- forkFinally hands each thread's outcome, including any exception, to
  -- its MVar, so a crashed worker cannot go unnoticed.
  done_vars <- forM local_refs $ \local_ref -> do
    done_var <- newEmptyMVar
    _ <- forkFinally
           (worker_loop capacity items_vec best_mvar victims_mvar active_ref local_ref)
           (putMVar done_var)
    pure done_var

  -- Wait for every worker first (no orphaned threads), then surface the
  -- first failure, if any.
  outcomes <- mapM takeMVar done_vars
  mapM_ (either throwIO pure) outcomes
  readMVar best_mvar


-- | Body of one worker thread; returns when no work is left anywhere.
--
-- Each iteration does the following:
--
-- * Take a state from this worker's own deque ('pop_local') or, failing
--   that, steal one ('steal_work').
-- * Evaluate 'dfs_best_value_vec' from that state.
-- * Fold the result into @best_mvar@ with 'max'.
--
-- The search's final argument is @max global_best (state_value st)@.
-- NOTE(review): this is presumably a pruning bound; confirm in KnapsackCore.
--
-- Termination: an idle worker polls every 2 ms. It exits only once
-- @active_ref@ is 0 and every deque listed in @victims_mvar@ is empty.
-- A task leaves its deque slightly before @active_ref@ is incremented, so
-- another worker can occasionally exit a little early. The task itself is
-- still run by the worker that took it.
worker_loop
  :: Double                    -- ^ capacity, forwarded to 'dfs_best_value_vec'
  -> V.Vector (Double,Double)  -- ^ item vector, forwarded to 'dfs_best_value_vec'
  -> MVar Double               -- ^ best value found so far (shared)
  -> MVar [IORef [State]]      -- ^ all workers' deques (steal targets)
  -> IORef Int                 -- ^ number of workers currently running a task
  -> IORef [State]             -- ^ this worker's own deque
  -> IO ()
worker_loop capacity items_vec best_mvar victims_mvar active_ref local_ref = work
  where
    work = do
      m_task0 <- pop_local local_ref
      m_task  <- case m_task0 of
        Just st -> pure (Just st)
        Nothing -> steal_work victims_mvar local_ref

      case m_task of
        Just st -> do
          -- bracket_ guarantees the decrement even if the search throws. A
          -- leaked increment would keep every idle worker polling forever,
          -- because they only stop once the active count is back to 0.
          bracket_ (adjust_active 1) (adjust_active (-1)) (run_task st)
          work

        Nothing -> do
          -- Robust termination: stop only if everyone is idle AND all deques
          -- are empty.
          active <- readIORef active_ref
          if active /= 0
            then pause_and_retry
            else do
              refs    <- readMVar victims_mvar
              empties <- mapM (fmap null . readIORef) refs
              unless (and empties) pause_and_retry

    pause_and_retry = threadDelay 2000 >> work

    adjust_active delta = atomicModifyIORef' active_ref $ \k -> (k + delta, ())

    run_task st = do
      best_global <- readMVar best_mvar
      let dfs = dfs_best_value_vec capacity items_vec
          !local_best =
            dfs (state_level st) (state_weight st) (state_value st)
                (max best_global (state_value st))
      -- Store the new best strictly ($!), so the shared MVar never holds a
      -- growing chain of unevaluated 'max' thunks.
      modifyMVar_ best_mvar $ \old_best -> pure $! max old_best local_best


-- | Atomically remove and return the top (head) of a worker's own stack.
-- Returns 'Nothing' and leaves the stack empty if it was already empty.
pop_local :: IORef [State] -> IO (Maybe State)
pop_local local_ref = atomicModifyIORef' local_ref take_top
  where
    take_top (top : rest) = (rest, Just top)
    take_top []           = ([], Nothing)


-- | Try each deque in @victims_mvar@ in list order, stealing from the first
-- one for which 'steal_half' succeeds.
--
-- Returns the stolen state for the thief to run now, or 'Nothing' if no
-- deque yielded work. The list may include the thief's own deque; stealing
-- from it only moves the thief's own states around.
steal_work :: MVar [IORef [State]] -> IORef [State] -> IO (Maybe State)
steal_work victims_mvar thief_ref = readMVar victims_mvar >>= attempt
  where
    attempt []                = pure Nothing
    attempt (victim : others) =
      steal_half victim thief_ref >>= maybe (attempt others) (pure . Just)


-- | Push a list of elements onto a stack in one go, preserving their order:
-- the head of the first argument becomes the new top of the stack.
-- Prepending the whole list this way is exactly list append.
push_many :: [a] -> [a] -> [a]
push_many pending stack = pending ++ stack


-- | Steal roughly the top half (at least one element) of @victim_ref@'s
-- stack.
--
-- The first stolen state is returned for the thief to run immediately; the
-- remaining stolen states are pushed onto @thief_ref@ in their original
-- order. Returns 'Nothing' if the victim's stack was empty.
--
-- Both IORef updates run under 'mask_'. Asynchronous exceptions are
-- therefore deferred while the stolen states sit in neither stack, so they
-- cannot be dropped.
steal_half :: IORef [State] -> IORef [State] -> IO (Maybe State)
steal_half victim_ref thief_ref = mask_ $ do
  grabbed <- atomicModifyIORef' victim_ref grab_half
  case grabbed of
    Just (first_task : rest) -> do
      atomicModifyIORef' thief_ref $ \own -> (push_many rest own, ())
      pure (Just first_task)
    _ -> pure Nothing
  where
    grab_half []  = ([], Nothing)
    grab_half stk = let (taken, kept) = split_at_half stk
                    in (kept, Just taken)


-- | Split a list so that the front part holds half of the elements (rounded
-- down), but never fewer than one; the back part holds the rest.
-- An empty list yields two empty parts.
split_at_half :: [a] -> ([a],[a])
split_at_half xs = splitAt cut xs
  where
    cut = max 1 (length xs `quot` 2)


-- | Deal the elements of a list into @n = max 1 k@ buckets round-robin.
--
-- Element @i@ goes to bucket @i `mod` n@, and each bucket keeps the input
-- order. The result always has exactly @n@ buckets; trailing buckets may be
-- empty when the input is shorter than @n@.
--
-- The input is cut into rows of width @n@ and the rows are transposed, so
-- bucket @j@ is column @j@. Each element is handled a constant number of
-- times, and no partial indexing is needed.
chunk_list_round_robin :: Int -> [a] -> [[a]]
chunk_list_round_robin k xs = take n (transpose (rows xs) ++ repeat [])
  where
    n = max 1 k
    rows [] = []
    rows ys = let (row, rest) = splitAt n ys in row : rows rest
