module Lib
    ( ExpressionTree(..)
    , Clauses(..)
    , Literal(..)
    , dpll
    , parDpll
    , satisfyParDPLL
    , constructTree
    ) where

import qualified Data.Vector as V
import Control.Parallel.Strategies(Strategy, using, rpar)
import Control.DeepSeq(NFData)

---------------------------------
-- Data Types: CNF Representation
---------------------------------
-- | A CNF formula: the conjunction of its clauses. A formula with no clauses
-- left counts as satisfied by the solvers below.
data ExpressionTree
  = Expr [Clauses]
  deriving (Eq, Read, Show)

-- | One clause: the disjunction of its literals. Despite the plural type
-- name, a single value represents a single clause; an empty clause can never
-- be satisfied.
data Clauses
  = Clause [Literal]
  deriving (Eq, Read, Show)

-- | A literal over an integer-indexed variable: either the variable itself
-- ('LiteralPos') or its negation ('LiteralNeg').
data Literal
  = LiteralPos Int
  | LiteralNeg Int
  deriving (Eq, Read, Show)


-- | Build a CNF formula from a list of clauses (a thin wrapper around 'Expr').
constructTree :: [Clauses] -> ExpressionTree
constructTree = Expr

-------------------
-- Helper functions
-------------------
-- | True when the formula has no clauses left, i.e. every clause has been
-- satisfied by the current assignment (SAT).
isFormulaEmpty :: ExpressionTree -> Bool
isFormulaEmpty (Expr []) = True
isFormulaEmpty (Expr _)  = False

-- | True when some clause has lost all of its literals. Such a clause can no
-- longer be satisfied, so the formula is UNSAT under the current assignment.
containsEmptyClause :: ExpressionTree -> Bool
containsEmptyClause (Expr clauses) = or [null lits | Clause lits <- clauses]

-- | Simplify the formula under the single assignment @var := val@.
--
-- Every clause containing a literal that the assignment makes true is
-- dropped. In each remaining clause, every literal that the assignment makes
-- false is removed; this may leave an empty, unsatisfiable clause. Literals
-- over other variables are left untouched.
applyAssignment :: ExpressionTree -> (Int, Bool) -> ExpressionTree
applyAssignment (Expr clauses) (var, val) =
  Expr
    [ Clause (filter (not . falsified) lits)
    | Clause lits <- clauses
    , not (any satisfied lits)
    ]
  where
    -- A literal over 'var' is true exactly when its polarity equals 'val'.
    satisfied lit = mentionsVar lit && polarity lit == val
    falsified lit = mentionsVar lit && polarity lit /= val

    mentionsVar (LiteralPos x) = x == var
    mentionsVar (LiteralNeg x) = x == var

    polarity (LiteralPos _) = True
    polarity (LiteralNeg _) = False

-- | Return the literal of the first unit clause (a clause with exactly one
-- literal), if there is one; unit propagation must make that literal true.
--
-- Matching on @[lit]@ inspects at most two list cells per clause, whereas
-- @length lits == 1@ had to walk each clause to its end.
findUnitClause :: ExpressionTree -> Maybe Literal
findUnitClause (Expr clauses) =
  case [lit | Clause [lit] <- clauses] of
    (lit:_) -> Just lit
    []      -> Nothing

-- | Find a pure literal: a variable that occurs with only one polarity across
-- the whole formula. Returns the variable paired with the value that makes
-- its occurrences true.
--
-- Variables that occur only positively are preferred over those that occur
-- only negatively. Within each group, the earliest occurrence wins.
--
-- NOTE(review): the 'notElem' membership checks make this quadratic in the
-- number of literal occurrences.
findPureLiteral :: ExpressionTree -> Maybe (Int, Bool)
findPureLiteral (Expr clauses) =
  case purePos ++ pureNeg of
    (hit:_) -> Just hit
    []      -> Nothing
  where
    occurrences = [lit | Clause lits <- clauses, lit <- lits]
    posVars     = [v | LiteralPos v <- occurrences]
    negVars     = [v | LiteralNeg v <- occurrences]
    purePos     = [(v, True)  | v <- posVars, v `notElem` negVars]
    pureNeg     = [(v, False) | v <- negVars, v `notElem` posVars]

-- | Record @var := val@ in the assignment vector and return the new vector;
-- the input vector is not modified.
--
-- Per the "Data.Vector" documentation, 'V.update' (like 'V.//') is O(m + n),
-- so every call copies the whole vector.
setAssignment :: V.Vector (Maybe Bool) -> Int -> Bool -> V.Vector (Maybe Bool)
setAssignment asg var val = V.update asg (V.singleton (var, Just val))

-- | Evaluation strategy for the branch results in 'parDpll'. It sparks every
-- list element with 'rpar' (evaluation to weak head normal form), in list
-- order.
--
-- It is applied to two-element lists (the True and False branches). Unlike
-- a pattern that only accepts exactly two elements, it is total on lists of
-- any length.
--
-- NOTE(review): the 'NFData' constraint is not used ('rpar' evaluates only
-- to WHNF); it is kept so the signature is unchanged. Consider
-- @rparWith rdeepseq@ if deeper evaluation inside the sparks is wanted.
pairParStrat :: (NFData a) => Strategy [a]
pairParStrat = mapM rpar

-- | Parallel DPLL using the default 'pairParStrat' strategy. The first
-- argument is the depth budget handed to 'parDpll': how many levels of
-- branching are evaluated in parallel before the search continues with the
-- sequential 'dpll'.
satisfyParDPLL :: Int -> ExpressionTree -> V.Vector (Maybe Bool) -> (Bool, V.Vector (Maybe Bool))
satisfyParDPLL depth formula assignment = parDpll pairParStrat depth formula assignment

-- | Return the first satisfiable result from a list of search outcomes.
--
-- Outcomes are inspected left to right and inspection stops at the first
-- 'True', so later entries are never forced. If no outcome is satisfiable,
-- the result is @(False, V.empty)@.
specialOr :: [(Bool, V.Vector (Maybe Bool))] -> (Bool, V.Vector (Maybe Bool))
specialOr results =
  case filter fst results of
    (winner:_) -> winner
    []         -> (False, V.empty)

------------------------
-- Parallel DPLL Solver
------------------------
-- | Parallel DPLL search.
--
-- @parDpll strat d formula assignment@ runs the same steps as 'dpll':
-- SAT/UNSAT checks, unit propagation and pure-literal elimination.
-- At each branching step, while @d@ is positive, the True and False branches
-- are evaluated under @strat@, each with budget @d - 1@. Once the budget
-- reaches zero (or is negative) the rest of the search is handed to the
-- sequential 'dpll'; this bounds the number of sparks created.
--
-- Returns @(True, assignment')@ with the extended assignment when the formula
-- is satisfiable, and @(False, V.empty)@ otherwise. If both branches succeed,
-- the True branch's assignment is returned.
parDpll :: Strategy [(Bool, V.Vector (Maybe Bool))] -> Int -> ExpressionTree -> V.Vector (Maybe Bool) -> (Bool, V.Vector (Maybe Bool))
parDpll strat d formula assignment
  -- Budget exhausted. Using '<=' rather than matching only 0 means a
  -- negative depth also falls back to 'dpll', instead of never reaching the
  -- cutoff and sparking at every branching step.
  | d <= 0 = dpll formula assignment
  | isFormulaEmpty formula = (True, assignment)       -- Base case: SAT
  | containsEmptyClause formula = (False, V.empty)    -- Base case: UNSAT
  | Just unit <- findUnitClause formula =             -- Unit propagation
      assign (literalValue unit)
  | Just (v, b) <- findPureLiteral formula =          -- Pure literal elimination
      assign (v, b)
  -- Branch on the variable of the first literal of the first clause (naive
  -- heuristic). The two base-case checks above guarantee that the first
  -- clause exists and is non-empty, so this pattern always matches.
  | Expr (Clause (lit : _) : _) <- formula =
      let var        = fst (literalValue lit)
          branch val = parDpll strat (d - 1) (applyAssignment formula (var, val)) (setAssignment assignment var val)
          satTrue    = branch True
          satFalse   = branch False
      in specialOr ([satTrue, satFalse] `using` strat)
  | otherwise = (False, V.empty)                      -- Unreachable, see above
  where
    -- Commit to one assignment and keep searching with the same budget.
    assign (var, val) =
      parDpll strat d (applyAssignment formula (var, val)) (setAssignment assignment var val)
    -- The (variable, value) pair that makes a literal true.
    literalValue (LiteralPos v) = (v, True)
    literalValue (LiteralNeg v) = (v, False)

-------------------------
-- Sequential DPLL Solver
-------------------------
-- | Sequential DPLL search.
--
-- @dpll formula assignment@ does the following, in order:
--
--   * returns SAT when no clauses remain, UNSAT when an empty clause exists;
--   * otherwise applies unit propagation, then pure-literal elimination;
--   * otherwise branches on the variable of the first literal of the first
--     clause, trying True before False.
--
-- Returns @(True, assignment')@ with the extended assignment when the formula
-- is satisfiable, and @(False, V.empty)@ otherwise.
dpll :: ExpressionTree -> V.Vector (Maybe Bool) -> (Bool, V.Vector (Maybe Bool))
dpll formula assignment
  | isFormulaEmpty formula = (True, assignment)       -- Base case: SAT
  | containsEmptyClause formula = (False, V.empty)    -- Base case: UNSAT
  | Just unit <- findUnitClause formula =             -- Unit propagation
      assign (literalValue unit)
  | Just (v, b) <- findPureLiteral formula =          -- Pure literal elimination
      assign (v, b)
  -- Branch (naive heuristic). The two base-case checks above guarantee that
  -- the first clause exists and is non-empty, so this pattern always matches.
  | Expr (Clause (lit : _) : _) <- formula =
      let var = fst (literalValue lit)
      -- 'specialOr' stops at the first satisfiable entry, so the False branch
      -- is only searched when the True branch is UNSAT.
      in specialOr [assign (var, True), assign (var, False)]
  | otherwise = (False, V.empty)                      -- Unreachable, see above
  where
    -- Commit to one assignment and continue the search.
    assign (var, val) =
      dpll (applyAssignment formula (var, val)) (setAssignment assignment var val)
    -- The (variable, value) pair that makes a literal true.
    literalValue (LiteralPos v) = (v, True)
    literalValue (LiteralNeg v) = (v, False)
