module DPLL.DpllSolver (
    SatSolver(..),
    newSatSolver, isSolved,
    selectBranchVar, solve,
    guess
) where

import Data.Maybe (mapMaybe)
import DPLL.Clause
import DPLL.Literal
import qualified Data.IntMap as IM
import Control.Applicative (Alternative(..))
import Data.List (sortBy)
import Control.DeepSeq (NFData, rnf)

-- | DPLL search state: the clauses that still have to be satisfied, plus the
-- variable assignments made so far.
--
-- The bang annotations evaluate each field only to weak head normal form
-- (the outermost list / map constructor), not their contents; use the
-- 'NFData' instance when a fully evaluated state is needed.
data SatSolver = SatSolver
  { clauses  :: ![Clause]          -- ^ remaining clauses; satisfied ones are dropped as literals get assigned
  , bindings :: !(IM.IntMap Bool)  -- ^ value recorded per variable, keyed by 'var' of the assigned literal
  }
  deriving (Show, Eq)

-- | Deep evaluation: fully forces the clause list, then the bindings map.
instance NFData SatSolver where
    rnf (SatSolver cs bs) = rnf cs `seq` rnf bs

-- | A solver with no clauses and no assignments (trivially solved).
newSatSolver :: SatSolver
newSatSolver = SatSolver { clauses = [], bindings = IM.empty }

-- | Choose the next branching variable: the variable of the first literal of
-- the first shortest clause (by 'clauseSize'). 'sortBy' is stable, so among
-- equally short clauses the earliest one in 'clauses' wins.
--
-- Partial by necessity (the result type is a bare 'Var'): it fails with a
-- descriptive error when there are no clauses at all (the solver is already
-- solved) or when the chosen shortest clause is empty (a conflict that should
-- have been rejected before branching).
selectBranchVar :: SatSolver -> Var
selectBranchVar solver =
    case sortBy shorterClause (clauses solver) of
        [] -> error "DPLL.DpllSolver.selectBranchVar: no clauses to branch on"
        (shortest : _) ->
            case literals shortest of
                (firstLit : _) -> var firstLit
                [] -> error "DPLL.DpllSolver.selectBranchVar: shortest clause is empty (unhandled conflict)"

-- | A solver is solved once no clauses remain to be satisfied.
isSolved :: SatSolver -> Bool
isSolved solver = null (clauses solver)

-- | Entry point: run unit propagation ('simplify'), then the DPLL search.
--
-- An initial propagation conflict yields 'empty'. Otherwise the result is
-- whatever the search produces; the choice of @m@ decides how alternative
-- branches are combined (e.g. first success for 'Maybe').
solve :: (Monad m, Alternative m) => SatSolver -> m SatSolver
solve solver =
    case simplify solver of
        Nothing         -> empty
        Just simplified -> solveRecursively simplified

-- | DPLL search on an already-simplified solver.
--
-- Returns the solver as-is once no clauses remain. Otherwise it branches on
-- the variable picked by 'selectBranchVar'. Each branch ('guessAndRecurse')
-- recurses back into this function and therefore only ever yields solvers for
-- which 'isSolved' holds, so the branch result is returned directly instead of
-- being passed through another (no-op) round of solving.
solveRecursively :: (Monad m, Alternative m) => SatSolver -> m SatSolver
solveRecursively solver
    | isSolved solver = pure solver
    | otherwise = branchOnUnbound (selectBranchVar solver) solver

-- | Branch on variable @name@: try the literal built with polarity 'True'
-- first and, via '<|>', fall back to the 'False' literal.
branchOnUnbound :: (Monad m, Alternative m) => Var -> SatSolver -> m SatSolver
branchOnUnbound name solver = tryPolarity True <|> tryPolarity False
  where
    tryPolarity polarity = guessAndRecurse (mkLit name polarity) solver

-- | Assign @lit@ with 'guess' and keep searching. A conflict found while
-- propagating the assignment ('Nothing' from 'guess') becomes 'empty', which
-- lets the enclosing '<|>' backtrack to the other branch.
guessAndRecurse :: (Monad m, Alternative m) => Lit -> SatSolver -> m SatSolver
guessAndRecurse lit solver =
    maybe empty solveRecursively (guess lit solver)

-- | Assign @lit@ and unit-propagate the consequences.
--
-- Records @not (sign lit)@ for the literal's variable in 'bindings'.
-- NOTE(review): presumably this is the value that makes @lit@ true under
-- DPLL.Literal's sign convention; confirm. Each clause is passed through
-- 'filterClause' (clauses containing @lit@ are dropped, @neg lit@ is removed
-- from the rest), and the result is then handed to 'simplify'. 'Nothing' means
-- a conflict was detected.
guess :: Lit -> SatSolver -> Maybe SatSolver
guess lit solver = simplify assigned
  where
    assigned = solver
        { clauses  = mapMaybe (filterClause lit) (clauses solver)
        , bindings = IM.insert (var lit) (not (sign lit)) (bindings solver)
        }

-- | Unit propagation to a fixpoint.
--
-- Fails with 'empty' when the clause set already contains an empty clause
-- (all of its literals falsified), or when propagating a unit literal produces
-- one. Otherwise it repeatedly takes the first unit clause found by
-- 'findUnitClause', records @not (sign lit)@ for its variable in 'bindings',
-- and rewrites the clauses with 'propagate', stopping when no unit clause is
-- left.
simplify :: (Monad m, Alternative m) => SatSolver -> m SatSolver
simplify solver
    -- Without this check an empty clause survives whenever no unit clause
    -- exists (e.g. an empty input clause, or one emptied by 'guess' from a
    -- clause with a repeated literal), and 'selectBranchVar' later fails on it
    -- instead of the search backtracking.
    | any (null . literals) (clauses solver) = empty
    | otherwise =
        case findUnitClause (clauses solver) of
            Nothing -> pure solver
            Just lit ->
                let updatedSolver = solver { bindings = IM.insert (var lit) (not (sign lit)) (bindings solver) }
                in case propagate lit (clauses updatedSolver) of
                    Nothing -> empty
                    Just updatedClauses ->
                        simplify $ updatedSolver { clauses = updatedClauses }

-- | Apply @lit@ to every clause with 'processClause'. Returns 'Nothing' if any
-- resulting clause has no literals left (a conflict), otherwise the rewritten
-- clause list.
propagate :: Lit -> [Clause] -> Maybe [Clause]
propagate lit inputClauses
    | any isEmptyClause remaining = Nothing
    | otherwise = Just remaining
  where
    remaining = mapMaybe (processClause lit) inputClauses
    isEmptyClause = null . literals

-- | The first literal of the first clause whose 'clauseSize' is 1, if any.
-- The lazy right fold stops at the first match.
findUnitClause :: [Clause] -> Maybe Lit
findUnitClause = foldr pick Nothing
  where
    pick candidate rest
        | clauseSize candidate == 1 = Just (head (literals candidate))
        | otherwise = rest

-- | Rewrite one clause under the assignment of @lit@.
--
-- * Clause contains @lit@: it is satisfied, so it is dropped ('Nothing').
-- * Clause contains @neg lit@: every occurrence is removed. The result may be
--   an empty clause; callers such as 'propagate' treat that as a conflict.
-- * Otherwise the clause is returned unchanged.
--
-- 'learnt' and 'activity' are carried over from the input clause.
processClause :: Lit -> Clause -> Maybe Clause
processClause lit clause
    | lit `elem` lits = Nothing
    | falsified `elem` lits =
        -- An empty filtered list is fine here: it is exactly the empty clause.
        Just $ Clause (filter (/= falsified) lits) (learnt clause) (activity clause)
    | otherwise = Just clause
  where
    lits = literals clause
    falsified = neg lit

-- | Rewrite one clause under the assignment of @lit@ (used by 'guess').
--
-- This has exactly the same semantics as 'processClause': clauses containing
-- @lit@ are dropped ('Nothing'), occurrences of @neg lit@ are removed, and
-- other clauses are unchanged. It delegates so that the two copies of this
-- logic cannot drift apart.
filterClause :: Lit -> Clause -> Maybe Clause
filterClause lit clause = processClause lit clause

-- | Compare clauses by 'clauseSize' (smaller first); used to pick the
-- branching clause in 'selectBranchVar'.
shorterClause :: Clause -> Clause -> Ordering
shorterClause a b = clauseSize a `compare` clauseSize b
