{-# LANGUAGE PatternGuards, BangPatterns #-}
{-|
This module implements a decision procedure for quantifier-free linear
arithmetic.  The algorithm is based on the following paper:

  An Online Proof-Producing Decision Procedure for
  Mixed-Integer Linear Arithmetic
  by
  Sergey Berezin, Vijay Ganesh, and David L. Dill
-}
module Data.Integer.SAT
  ( PropSet
  , noProps
  , checkSat
  , checkSatPar
  , assert
  , Prop(..)
  , Expr(..)
  , BoundType(..)
  , getExprBound
  , getExprRange
  , Name
  , toName
  , fromName
  -- * Iterators
  , allSolutions
  , slnCurrent
  , slnNextVal
  , slnNextVar
  , slnEnumerate


  -- * Debug
  , dotPropSet
  , sizePropSet
  , allInerts
  , ppInerts

  -- * For QuickCheck
  , iPickBounded
  , Bound(..)
  , tConst
  ) where

-- import Debug.Trace

import           Data.List(partition)
import           Data.Maybe(maybeToList,fromMaybe,mapMaybe)
import           Control.Applicative(Alternative(..))
import qualified Control.Monad.Fail as Fail
import           Control.Monad(liftM,ap,MonadPlus(..),guard,when)
import           Text.PrettyPrint
import qualified Data.IntMap.Strict as IntM

-- import Control.DeepSeq (NFData (..), force)
-- import Control.Parallel.Strategies (rpar, rseq, runEval)
import Control.Parallel (par)

-- Fixities for the proposition/expression DSL.  Disjunction binds loosest,
-- then conjunction, then the (non-associative) comparisons, then
-- addition/subtraction, with constant multiplication binding tightest.
-- For example, @p :&& Var x :== e1 :|| q@ parses as
-- @(p :&& (Var x :== e1)) :|| q@.
infixr 2 :||
infixr 3 :&&
infix  4 :==, :/=, :<, :<=, :>, :>=
infixl 6 :+, :-
infixl 7 :*

--------------------------------------------------------------------------------
-- Solver interface

-- | A collection of propositions.
--
-- Internally this is a tree of solver states.  'One' leaves hold the
-- state of a path that is still consistent; 'None' leaves mark paths
-- that have failed.
newtype PropSet = State (Answer RW)
                  deriving Show

-- | Render the solver-state tree of a 'PropSet' in Graphviz @dot@ syntax.
-- Each leaf is labelled with its pretty-printed inert set.
dotPropSet :: PropSet -> Doc
dotPropSet (State tree) = dotAnswer showLeaf tree
  where
    showLeaf rw = ppInerts (inerts rw)

-- | Count the failed leaves, successful leaves and choice nodes of the
-- solver-state tree, returned in that order.
sizePropSet :: PropSet -> (Integer,Integer,Integer)
sizePropSet (State tree) = answerSize tree

-- | An empty collection of propositions.
noProps :: PropSet
noProps = State (pure initRW)

-- | Add a new proposition to an existing collection.
--
-- The proposition is reordered first (equalities before other conjuncts)
-- and is then solved on every path of the existing collection.
assert :: Prop -> PropSet -> PropSet
assert p (State paths) =
  let S solve = prop (reorderProp p)
  in State (snd <$> (paths >>= solve))

-- | Reorder the top-level conjuncts of a proposition so that equalities
-- are processed before the other constraints.
--
-- Equalities become variable definitions that simplify the remaining
-- constraints, whereas inequalities can trigger expensive shadow
-- computations.  Solving equalities first therefore tends to reach a
-- contradiction, or a simpler system, sooner.
--
-- The proposition is returned unchanged in two cases: it has no equality
-- conjunct, or it consists only of equality conjuncts.  Otherwise the
-- result is a right-nested conjunction terminated by 'PTrue'.
reorderProp :: Prop -> Prop
reorderProp p
  | null eqs || null others = p
  | otherwise               = conjoin eqs (conjoin others PTrue)
  where
    (eqs, others) = partitionProp p
    conjoin qs end = foldr (:&&) end qs

-- | Split the top-level conjuncts of a proposition into equalities
-- (@:==@) and everything else.  Left-to-right order is kept within each
-- group.  Only ':&&' is descended into; any other connective counts as a
-- single conjunct.
partitionProp :: Prop -> ([Prop], [Prop])
partitionProp p = partition isEquality (conjuncts p [])
  where
    -- Flatten nested conjunctions onto an accumulator, so that
    -- left-nested chains do not pay for repeated (++).
    conjuncts :: Prop -> [Prop] -> [Prop]
    conjuncts (l :&& r) rest = conjuncts l (conjuncts r rest)
    conjuncts q         rest = q : rest

    isEquality :: Prop -> Bool
    isEquality (_ :== _) = True
    isEquality _         = False

-- | Extract a model from a consistent set of propositions.
-- Returns 'Nothing' if the assertions have no model.
--
-- Paths are searched left to right, and the model of the first surviving
-- path is returned.  Only user variables ('UserName') appear in the
-- result.
-- NOTE(review): a variable that does not appear in the assignment is
-- presumably 0 — confirm against 'iModel'.
checkSat :: PropSet -> Maybe [(Int,Integer)]
checkSat (State tree) = search tree
  where
  search ans = case ans of
    None       -> Nothing
    One rw     -> Just (userAssignments rw)
    Choice l r -> search l <|> search r

  userAssignments rw = [ (x,v) | (UserName x, v) <- iModel (inerts rw) ]

-- | Like 'checkSat', but explores the solver-state tree in parallel.
--
-- Within the first @cutoff@ levels of choice nodes, the right branch is
-- sparked with 'par' while the left branch is evaluated.  The left
-- branch's model is still preferred ('<|>' on 'Maybe' is left-biased),
-- so the result agrees with 'checkSat'.  Below the cutoff, or when
-- @cutoff <= 0@, the search is sequential.
checkSatPar :: Int -> PropSet -> Maybe [(Int, Integer)]
checkSatPar cutoff (State m) = go cutoff m
  where
    go _ None = Nothing
    go _ (One rw) = Just [(x, v) | (UserName x, v) <- iModel (inerts rw)]
    go d (Choice m1 m2)
      | d <= 0 = go 0 m1 <|> go 0 m2
      | otherwise =
          let l = go (d - 1) m1
              r = go (d - 1) m2
           in r `par` (l <|> r)
           
-- | The inert sets of all surviving (non-failed) paths, left to right.
allInerts :: PropSet -> [Inerts]
allInerts (State tree) = [ inerts rw | rw <- toList tree ]

-- | One solution cursor per surviving path.
allSolutions :: PropSet -> [Solutions]
allSolutions ps = map startIter (allInerts ps)


-- | Computes a bound on the expression that is intended to hold on every
-- surviving path of the proposition set.
--
-- A 'Lower' bound is the least of the per-path lower bounds.  An 'Upper'
-- bound is the greatest of the per-path upper bounds.  This matches how
-- 'getExprRange' combines paths.
--
-- Returns `Nothing` if the bound is not known on some path, or if no
-- path survives.
getExprBound :: BoundType -> Expr -> PropSet -> Maybe Integer
getExprBound bt e (State s) =
  do let S m          = expr e
         check (t,s1) = iTermBound bt t (inerts s1)
     bs <- mapM check $ toList $ s >>= m
     case bs of
       [] -> Nothing
       _  -> Just (combine bs)
  where
  -- A bound valid on every path must be the loosest of the per-path bounds.
  combine = case bt of
              Lower -> minimum
              Upper -> maximum

-- | Compute the range of possible values for an expression.
--
-- The range runs from the least per-path lower bound to the greatest
-- per-path upper bound, over all surviving paths.  Returns `Nothing` if
-- a bound is not known on some path, or if no path survives.
getExprRange :: Expr -> PropSet -> Maybe [Integer]
getExprRange e (State s) =
  do let S run = expr e
         paths = toList (s >>= run)
     ranges <- traverse pathRange paths
     case unzip ranges of
       ([], _)  -> Nothing
       (ls, us) -> Just [minimum ls .. maximum us]
  where
  pathRange (t, rw) =
    let is = inerts rw
    in (,) <$> iTermBound Lower t is <*> iTermBound Upper t is



-- | The type of proposition.
--
-- Propositions are built from 'PTrue' and 'PFalse', the connectives
-- ':||', ':&&' and 'Not', and comparisons between integer 'Expr'essions.
data Prop = PTrue
          | PFalse
          | Prop :|| Prop   -- ^ Disjunction
          | Prop :&& Prop   -- ^ Conjunction
          | Not Prop        -- ^ Negation
          | Expr :== Expr   -- ^ Equal
          | Expr :/= Expr   -- ^ Not equal
          | Expr :<  Expr   -- ^ Strictly less than
          | Expr :>  Expr   -- ^ Strictly greater than
          | Expr :<= Expr   -- ^ Less than or equal
          | Expr :>= Expr   -- ^ Greater than or equal
            deriving (Read,Show)

-- | The type of integer expressions.
-- Variable names must be non-negative.
--
-- 'Div' and 'Mod' by 0 make the path on which they occur unsatisfiable.
-- NOTE(review): a negative divisor currently does too, because the
-- remainder is constrained to @0 <= r < k@.  Confirm this is intended.
data Expr = Expr :+ Expr          -- ^ Addition
          | Expr :- Expr          -- ^ Subtraction
          | Integer :* Expr       -- ^ Multiplication by a constant
          | Negate Expr           -- ^ Negation
          | Var Name              -- ^ Variable
          | K Integer             -- ^ Constant
          | If Prop Expr Expr     -- ^ A conditional expression
          | Div Expr Integer      -- ^ Division, rounds down
          | Mod Expr Integer      -- ^ Non-negative remainder
            deriving (Read,Show)

-- | Translate a proposition into solver actions on the current path.
--
-- Disjunctions become choice points ('mplus') and conjunctions are solved
-- in sequence.  Negations are pushed inward until they reach a
-- comparison.  Every comparison is reduced to one of two primitive
-- constraints on a 'Term': @t = 0@ ('solveIs0') or @t < 0@ ('solveIsNeg').
prop :: Prop -> S ()
prop PTrue       = return ()
prop PFalse      = mzero
prop (p1 :|| p2) = prop p1 `mplus` prop p2
prop (p1 :&& p2) = prop p1 >> prop p2
prop (Not p)     = prop (neg p)
  where
  -- Push a negation one level inward: De Morgan for the connectives,
  -- the complementary comparison for atoms.
  neg PTrue       = PFalse
  neg PFalse      = PTrue
  neg (p1 :&& p2) = neg p1 :|| neg p2
  neg (p1 :|| p2) = neg p1 :&& neg p2
  neg (Not q)     = q
  neg (e1 :== e2) = e1 :/= e2
  neg (e1 :/= e2) = e1 :== e2
  neg (e1 :<  e2) = e1 :>= e2
  neg (e1 :<= e2) = e1 :>  e2
  neg (e1 :>  e2) = e1 :<= e2
  neg (e1 :>= e2) = e1 :<  e2

-- e1 = e2   <=>  e1 - e2 = 0
prop (e1 :== e2) = do t1 <- expr e1
                      t2 <- expr e2
                      solveIs0 (t1 |-| t2)

-- e1 /= e2  <=>  e1 - e2 < 0  or  e2 - e1 < 0  (two separate branches)
prop (e1 :/= e2)  = do t1 <- expr e1
                       t2 <- expr e2
                       let t = t1 |-| t2
                       solveIsNeg t `orElse` solveIsNeg (tNeg t)

-- e1 < e2   <=>  e1 - e2 < 0
prop (e1 :< e2)   = do t1 <- expr e1
                       t2 <- expr e2
                       solveIsNeg (t1 |-| t2)

-- Over the integers: e1 <= e2  <=>  e1 - e2 - 1 < 0
prop (e1 :<= e2)  = do t1 <- expr e1
                       t2 <- expr e2
                       let t = t1 |-| t2 |-| tConst 1
                       solveIsNeg t

-- The "greater" comparisons are the flipped "less" comparisons.
prop (e1 :> e2)   = prop (e2 :<  e1)
prop (e1 :>= e2)  = prop (e2 :<= e1)


-- | Translate an integer expression into a linear 'Term'.
--
-- Linear operations map directly onto 'Term' arithmetic.  Conditionals,
-- division and remainder are not linear.  Each of them introduces fresh
-- variables, constrained by additional propositions on the current path.
expr :: Expr -> S Term
expr ex = case ex of
  a :+ b   -> (|+|) <$> expr a <*> expr b
  a :- b   -> (|-|) <$> expr a <*> expr b
  k :* b   -> (k |*|) <$> expr b
  Negate a -> tNeg <$> expr a
  Var x    -> pure (tVar x)
  K n      -> pure (tConst n)
  If p a b ->
    -- Name the result with a fresh variable x, and assert
    -- (p /\ x = a) \/ (not p /\ x = b).
    do x <- newVar
       prop ((p :&& (Var x :== a)) :|| (Not p :&& (Var x :== b)))
       pure (tVar x)
  Div a k  -> fst <$> exprDivMod a k
  Mod a k  -> snd <$> exprDivMod a k

-- | Introduce fresh variables for the quotient @q@ and remainder @r@ of
-- dividing @e@ by the constant @k@, and assert @k*q + r = e@ and
-- @0 <= r < k@.  Returns the terms for @(q, r)@.
--
-- Dividing by zero fails the current path.
-- NOTE(review): when @k < 0@ the constraint @0 <= r < k@ can never hold,
-- so a negative divisor also fails the path.  Confirm this is intended;
-- Euclidean division would bound @r@ by @abs k@ instead.
exprDivMod :: Expr -> Integer -> S (Term,Term)
exprDivMod e k = do
  guard (k /= 0) -- Always unsat
  quo <- newVar
  rm  <- newVar
  let rv = Var rm
  prop (k :* Var quo :+ rv :== e :&& rv :< K k :&& K 0 :<= rv)
  pure (tVar quo, tVar rm)





--------------------------------------------------------------------------------

-- | The solver state on one path.
data RW = RW { nameSource :: !Int
               -- ^ Counter for generating fresh 'SysName' variables.
             , inerts     :: Inerts
               -- ^ The constraints known on this path.
             } deriving Show

-- | The initial solver state: no fresh names used yet and no constraints.
initRW :: RW
initRW = RW 0 iNone

--------------------------------------------------------------------------------
-- Constraints and Bound on Variables

-- | @ctLt t1 t2@ is the term @t@ for which @t1 < t2@ holds iff @t < 0@.
ctLt :: Term -> Term -> Term
ctLt = (|-|)

-- | @ctEq t1 t2@ is the term @t@ for which @t1 = t2@ holds iff @t = 0@.
ctEq :: Term -> Term -> Term
ctEq = (|-|)

-- | A bound on a variable: a coefficient and a term.  Whether it is a
-- lower or an upper bound depends on the list it is stored in; 'toCtK'
-- shows the constraint each kind denotes.
data Bound      = Bound !Integer !Term  -- ^ The integer is strictly positive
                  deriving Show

-- | Which side of a variable a 'Bound' constrains.
data BoundType  = Lower | Upper
                  deriving Show

-- toCt :: BoundType -> Name -> Bound -> Term
-- toCt Lower x (Bound c t) = ctLt t              (c |*| tVar x)
-- toCt Upper x (Bound c t) = ctLt (c |*| tVar x) t

-- | Convert a bound on the variable with key @k@ back into a constraint
-- term @t@ that means @t < 0@:
--
--   * 'Lower': @t < c*x@ becomes @t - c*x@
--   * 'Upper': @c*x < t@ becomes @c*x - t@
toCtK :: BoundType -> Int -> Bound -> Term
toCtK bt k (Bound c t) =
  case bt of
    Lower -> ctLt t scaled
    Upper -> ctLt scaled t
  where
    scaled = c |*| tVarK k

--------------------------------------------------------------------------------
-- Inert set

-- | The inert contains the solver state on one possible path.
-- Both maps are keyed by 'nameKey'.
data Inerts = Inerts
  { bounds :: !(NameMap ([Bound],[Bound]))
    -- ^ Known lower and upper bounds for variables.
    -- Each bound @(c,t)@ in the first list asserts that  @t < c * x@
    -- Each bound @(c,t)@ in the second list asserts that @c * x < t@

  , solved :: !(NameMap Term)
    -- ^ Definitions for resolved variables.
    -- These form an idempotent substitution.
  } deriving Show

-- | Pretty-print an inert set, one constraint per line: all lower bounds,
-- then all upper bounds, then all definitions.
ppInerts :: Inerts -> Doc
ppInerts is = vcat (lowers ++ uppers ++ defs)
  where
    entries = IntM.toList (bounds is)

    lowers = [ ppTerm t <+> text "<" <+> scaled c k
             | (k, (ls, _)) <- entries, Bound c t <- ls ]
    uppers = [ scaled c k <+> text "<" <+> ppTerm t
             | (k, (_, us)) <- entries, Bound c t <- us ]
    defs   = [ ppName (keyName k) <+> text "=" <+> ppTerm t
             | (k, t) <- IntM.toList (solved is) ]

    scaled c k = ppTerm (c |*| tVar (keyName k))



-- | An empty inert set: no bounds and no definitions.
iNone :: Inerts
iNone = Inerts IntM.empty IntM.empty

-- | Rewrite a term using the definitions from an inert set.
--
-- Each variable that has a definition in 'solved' is replaced by that
-- definition, scaled by the variable's coefficient.  Other variables are
-- kept as they are.  The term's constant part is carried over, and any
-- constants contributed by the definitions are added to it.
iApSubst :: Inerts -> Term -> Term
iApSubst is (T k0 coeffs) = IntM.foldlWithKey' substOne (tConst k0) coeffs
  where
    substOne acc v c
      | c == 0                                = acc
      | Just def <- IntM.lookup v (solved is) = addScaledTerm c def acc
      | otherwise                             = addCoeffK v c acc

-- | @addScaledTerm c t acc@ computes @acc + c * t@.
--
-- Preserves the 'Term' invariant that no coefficient is stored as 0.
-- A zero scale returns the accumulator unchanged; otherwise scaling
-- would introduce zero coefficients.  Coefficients that cancel are
-- dropped by 'mergeAddDrop0'.
addScaledTerm :: Integer -> Term -> Term -> Term
addScaledTerm 0 _ acc = acc
addScaledTerm c (T n1 m1) (T n2 m2) =
  let !n'      = n2 + c * n1
      mScaled  = scaleMap c m1
      m' | IntM.null mScaled = m2
         | IntM.null m2      = mScaled
         | otherwise         = mergeAddDrop0 mScaled m2
  in T n' m'

-- | @addCoeffK k c t@ adds @c@ to the coefficient of the variable with key
-- @k@ in @t@.  The entry is removed if the sum cancels to 0.
addCoeffK :: Int -> Integer -> Term -> Term
addCoeffK k c t@(T n m)
  | c == 0    = t
  | otherwise = T n (IntM.alter bump k m)
  where
    bump Nothing   = Just c
    bump (Just c0) = let s = c0 + c
                     in if s == 0 then Nothing else Just s

-- | Add a definition.  Upper and lower bound constraints that mention
-- the variable are "kicked-out" so that they can be reinserted in the
-- context of the new knowledge.
--
--    * Assumes substitution has already been applied.
--
--    * The kicked-out constraints are NOT rewritten, this happens
--      when they get inserted in the work queue.
--
-- Returns two things.  The first is the kicked-out constraints, each as
-- a term @c@ meaning @c < 0@ (built with 'toCtK').  The second is the
-- updated inert set, which records @x = t@ and has @x@ replaced in the
-- existing definitions via @tLetK kx t@.
iSolved :: Name -> Term -> Inerts -> ([Term], Inerts)
iSolved x t i =
  ( kickedOut,
    Inerts
      { bounds = otherBounds,
        solved = IntM.insert kx t (IntM.map (tLetK kx t) (solved i))
      }
  )
  where
    kx :: Int
    kx = nameKey x

    -- kickedOut: x's own bounds, plus every other variable's bounds whose
    -- term mentions x.
    -- otherBounds: everything else.  x's entry is removed; entries of
    -- other variables may be left with empty lists.
    (kickedOut, otherBounds) =
      let -- eliminate entry for x
          (mb, mp1) = IntM.updateLookupWithKey (\_ _ -> Nothing) kx (bounds i)

          -- eliminate constraints mentioning x in other bounds
          mp2 = IntM.mapWithKey extractBounds mp1
       in ( [ ct | (lbs, ubs) <- maybeToList mb, ct <- map (toCtK Lower kx) lbs ++ map (toCtK Upper kx) ubs
            ]
              ++ [ ct | (_, cts) <- IntM.elems mp2, ct <- cts
                 ],
            fmap fst mp2
          )

    -- Split one variable's bounds into those that do not mention x (kept)
    -- and those that do (converted to constraint terms and kicked out).
    extractBounds :: Int -> ([Bound], [Bound]) -> (([Bound], [Bound]), [Term])
    extractBounds ky (lbs, ubs) =
      let (lbsStay, lbsKick) = partition stay lbs
          (ubsStay, ubsKick) = partition stay ubs
       in ( (lbsStay, ubsStay),
            map (toCtK Lower ky) lbsKick
              ++ map (toCtK Upper ky) ubsKick
          )

    -- A bound stays in the inert set iff its term does not mention x.
    stay :: Bound -> Bool
    stay (Bound _ bnd) = not (tHasVarK kx bnd)



-- | Given some lower and upper bounds, find the interval that satisfies
-- them.  The bounds are strict: a lower bound @Bound c t@ means
-- @t < c*x@, and an upper bound @Bound c t@ means @c*x < t@.
--
-- Returns 'Nothing' if some bound term is not a constant, or if the
-- resulting interval is empty.  Otherwise returns the inclusive interval
-- @(lo, hi)@; a missing side means unbounded.
boundInterval :: [Bound] -> [Bound] -> Maybe (Maybe Integer, Maybe Integer)
boundInterval lbs ubs =
  do ls <- mapM (normBound Lower) lbs
     us <- mapM (normBound Upper) ubs
     let lb = case ls of
                [] -> Nothing
                _  -> Just (maximum ls)
         ub = case us of
                [] -> Nothing
                _  -> Just (minimum us)
     case (lb,ub) of
       (Just l, Just u) -> guard (l <= u)
       _                -> return ()
     return (lb,ub)
  where
  -- Least integer x with k < c*x (c > 0):     x = floor(k/c) + 1
  normBound Lower (Bound c t) = do k <- isConst t
                                   return (div k c + 1)
  -- Greatest integer x with c*x < k (c > 0):  x = floor((k-1)/c)
  normBound Upper (Bound c t) = do k <- isConst t
                                   return (div (k - 1) c)

-- | A cursor over the solutions of one inert set.
data Solutions = Done
                 -- ^ No variables left to assign.
               | TopVar Name Integer (Maybe Integer) (Maybe Integer) Inerts
                 -- ^ The variable currently being varied.  Fields: its
                 -- current value; its optional inclusive lower and upper
                 -- limits; and the remaining constraints, which are not
                 -- yet specialised to the current value.
               | FixedVar Name Integer Solutions
                 -- ^ A variable whose value has been fixed, followed by
                 -- the cursor for the remaining variables.
                  deriving Show

-- | The assignment to user variables at the current position of a
-- solution cursor.
slnCurrent :: Solutions -> [(Int,Integer)]
slnCurrent = userOnly . assignments
  where
  userOnly xs = [ (x,v) | (UserName x, v) <- xs ]

  assignments sol = case sol of
    Done              -> []
    TopVar x v _ _ is -> (x, v) : iModel (iLet x v is)
    FixedVar x v rest -> (x, v) : assignments rest

-- | Replace occurrences of a variable with an integer, in both the bounds
-- and the definitions of an inert set.
-- WARNING: The integer should be a valid value for the variable.
iLet :: Name -> Integer -> Inerts -> Inerts
iLet x v is =
  is { bounds = fmap substBoth (bounds is)
     , solved = fmap subst (solved is)
     }
  where
  subst = tLetNum x v
  substBound (Bound c t) = Bound c (subst t)
  substBoth (ls, us) = (map substBound ls, map substBound us)


-- | Build a solution cursor for an inert set.
--
-- The variable with the largest key that has bounds is considered first:
--
--   * If its bounds still mention another variable @y@, then @y@ is
--     varied first instead, starting from 0 with no limits.
--   * Otherwise the interval allowed by its (constant) bounds is
--     computed.  The cursor starts at the lower limit, or else the upper
--     limit, or else 0.
--
-- When no bounds remain, the definition with the largest key is used:
--
--   * If its right-hand side mentions a variable, that variable is
--     varied, unlimited and starting from 0.
--   * Otherwise the defined variable is pinned to the constant.
--
-- Returns 'Done' when there are neither bounds nor definitions.
-- NOTE(review): the unlimited range chosen for @y@ ignores any bounds
-- recorded on @y@ itself.  Confirm this is intended.
startIter :: Inerts -> Solutions
startIter is =
  case IntM.maxViewWithKey (bounds is) of
    Nothing ->
      case IntM.maxViewWithKey (solved is) of
        Nothing -> Done
        Just ((kx, t), mp1) ->
          case tFirstVar t of
            Just y -> TopVar y 0 Nothing Nothing is
            Nothing ->
              let v = tConstPart t
               in TopVar (keyName kx) v (Just v) (Just v) $ is {solved = mp1}
    Just ((kx, (lbs, ubs)), mp1) ->
      let x = keyName kx
       in case firstVarInBounds lbs <|> firstVarInBounds ubs of
            Just y -> TopVar y 0 Nothing Nothing is
            Nothing ->
              -- All bounds are constant here; an empty interval would
              -- mean the inert set was inconsistent.
              case boundInterval lbs ubs of
                Nothing -> error "bug: cannot compute interval?"
                Just (lb, ub) ->
                  let v = fromMaybe 0 (mplus lb ub)
                   in TopVar x v lb ub $ is {bounds = mp1}


-- | Enumerate the complete assignments reachable from a cursor, depth
-- first.  For each value of the current variable, taken in 'slnNextVal'
-- order, all assignments of the remaining variables are listed before
-- moving on to the next value.  A cursor with no further variable to fix
-- ('slnNextVar' gives 'Nothing') is emitted, followed by its further
-- values.  The list is infinite when a variable's range is unbounded in
-- the direction it is walked.
slnEnumerate :: Solutions -> [ Solutions ]
slnEnumerate s0 = go s0 []
  where
  -- go s k: every solution reachable from s, followed by k.
  go s k  = case slnNextVar s of
              Nothing -> hor s k
              Just s1 -> go s1 $ case slnNextVal s of
                                   Nothing -> k
                                   Just s2 -> go s2 k

  -- hor s k: s and each later value at the same level, followed by k.
  hor s k = s
          : case slnNextVal s of
              Nothing -> k
              Just s1 -> hor s1 k

-- | Move the current variable to its next candidate value, if there is
-- one.  Values are walked upward when the variable has a lower limit,
-- and downward otherwise.
slnNextVal :: Solutions -> Maybe Solutions
slnNextVal sol = case sol of
  Done              -> Nothing
  FixedVar x v rest -> FixedVar x v <$> slnNextVal rest
  TopVar _ _ lb _ _ -> slnNextValWith (maybe (subtract 1) (const (+ 1)) lb) sol


-- | Step the current variable's value with the given function.  Fails if
-- the new value leaves the variable's inclusive limits.
slnNextValWith :: (Integer -> Integer) -> Solutions -> Maybe Solutions
slnNextValWith step sol = case sol of
  Done              -> Nothing
  FixedVar x v rest -> FixedVar x v <$> slnNextValWith step rest
  TopVar x v lb ub is ->
    let v1      = step v
        aboveLo = maybe True (<= v1) lb
        belowHi = maybe True (v1 <=) ub
    in if aboveLo && belowHi
         then Just (TopVar x v1 lb ub is)
         else Nothing

-- | Fix the current variable at its current value and advance to the next
-- variable, specialising the remaining constraints to the fixed value.
slnNextVar :: Solutions -> Maybe Solutions
slnNextVar sol = case sol of
  Done              -> Nothing
  TopVar x v _ _ is -> Just (FixedVar x v (startIter (iLet x v is)))
  FixedVar x v rest -> FixedVar x v <$> slnNextVar rest




-- | Given a list of lower (resp. upper) bounds with constant terms,
-- compute the least (resp. largest) integer that satisfies them all.
-- Returns 'Nothing' for an empty list, or if some bound term is not a
-- constant.
iPickBounded :: BoundType -> [Bound] -> Maybe Integer
iPickBounded _ [] = Nothing
iPickBounded bt bs = combine <$> traverse tightest bs
  where
  combine = case bt of
              Lower -> maximum
              Upper -> minimum

  tightest (Bound c t) = limit c <$> isConst t

  -- Lower:  k < c*x  <=>  k+1 <= c*x  <=>  k `div` c + 1 <= x
  -- Upper:  c*x < k  <=>  c*x <= k-1  <=>  x <= (k-1) `div` c
  limit c k = case bt of
                Lower -> k `div` c + 1
                Upper -> (k - 1) `div` c


-- | The largest (resp. least) upper (resp. lower) bound on a term
-- that will satisfy the model.
--
-- Each variable contributes its coefficient times its own bound.  A
-- variable with a negative coefficient uses the bound from the opposite
-- side.  Returns 'Nothing' if some variable lacks the bound it needs.
iTermBound :: BoundType -> Term -> Inerts -> Maybe Integer
iTermBound bt (T k xs) is = IntM.foldlWithKey' addVar (Just k) xs
  where
    addVar acc v c = do
      total <- acc
      vb    <- iVarBoundK (sideFor c) v is
      let !total' = total + c * vb
      pure total'

    sideFor c
      | c > 0     = bt
      | otherwise = flipSide bt

    flipSide Lower = Upper
    flipSide Upper = Lower


-- | The largest (resp. least) upper (resp. lower) bound on the variable
-- with key @k@.
--
-- A defined variable is bounded through its definition.  Otherwise the
-- tightest of its recorded bounds is used, considering only bounds whose
-- terms can themselves be bounded.  Returns 'Nothing' if no such bound
-- exists.
iVarBoundK :: BoundType -> Int -> Inerts -> Maybe Integer
iVarBoundK bt k is =
  case IntM.lookup k (solved is) of
    Just def -> iTermBound bt def is
    Nothing  -> do
      (lows, ups) <- IntM.lookup k (bounds is)
      let relevant = case bt of
                       Lower -> lows
                       Upper -> ups
      case mapMaybe viaBound relevant of
        []         -> Nothing
        candidates -> Just (tightest candidates)
  where
    -- Lower: t < c*x  gives  x >= tmin `div` c + 1
    -- Upper: c*x < t  gives  x <= (tmax - 1) `div` c
    viaBound (Bound c t) = scale c <$> iTermBound bt t is

    scale c b = case bt of
                  Lower -> b `div` c + 1
                  Upper -> (b - 1) `div` c

    tightest = case bt of
                 Lower -> maximum
                 Upper -> minimum

-- | The largest (resp. least) upper (resp. lower) bound on a variable
-- that will satisfy the model.
-- iVarBound :: BoundType -> Name -> Inerts -> Maybe Integer
-- iVarBound bt x = iVarBoundK bt (nameKey x)

-- | Assign 0 to every variable occurring in the term, in increasing key
-- order, and prepend these assignments to an existing assignment list.
prependZeroVars :: Term -> [(Name, Integer)] -> [(Name, Integer)]
prependZeroVars (T _ coeffs) rest =
  [ (keyName v, 0) | v <- IntM.keys coeffs ] ++ rest

-- | Construct a concrete assignment for the variables of an inert set.
--
-- Bounded variables are assigned first, in decreasing key order.  Each
-- variable's bounds are instantiated with the assignments made so far.
-- The variable then gets the least value its lower bounds allow, or
-- failing that the greatest value its upper bounds allow, or 0 if
-- neither can be computed ('iPickBounded' needs constant bound terms).
--
-- Next, every definition is instantiated with the assignments so far.
-- Any variables still left in it are assigned 0, and the defined
-- variable gets the remaining constant part.
--
-- New assignments are prepended, so the most recent come first.
iModel :: Inerts -> [(Name, Integer)]
iModel i = goBounds [] (bounds i)
  where
    goBounds su mp =
      case IntM.maxViewWithKey mp of
        Nothing ->
          IntM.foldlWithKey' goEq su (solved i)
        Just ((kx, (lbs0, ubs0)), mp1) ->
          let x = keyName kx
              lbs = [Bound c (tLetNums su t) | Bound c t <- lbs0]
              ubs = [Bound c (tLetNums su t) | Bound c t <- ubs0]
              sln =
                fromMaybe 0 $
                  mplus (iPickBounded Lower lbs) (iPickBounded Upper ubs)
           in goBounds ((x, sln) : su) mp1

    -- Assign a defined variable.  Free variables remaining in its
    -- definition are set to 0, so its value is the constant part.
    goEq su kx t =
      let x = keyName kx
          t1 = tLetNums su t
          su' = prependZeroVars t1 ((x, tConstPart t1) : su)
       in su'

--------------------------------------------------------------------------------
-- Solving constraints

-- | Solve a constraint of the form @t = 0@.  The current definitions are
-- applied to @t@ first.
solveIs0 :: Term -> S ()
solveIs0 t = apSubst t >>= solveIs0'

-- | Solve a constraint of the form @t = 0@.
-- Assumes substitution has already been applied.
--
-- The cases are tried in order:
--
--   * A constant: succeed iff it is 0.
--   * A single variable, @a + b*x = 0@: define @x@ as @-a/b@ when @b@
--     divides @-a@; otherwise fail.
--   * A variable with coefficient 1 or -1: define it in terms of the
--     rest of the term.
--   * A common factor (@A * S = 0@): solve @S = 0@.
--   * Otherwise: the step from Section 3.1 of the paper.  It introduces
--     a fresh variable and continues with an equivalent equation that has
--     smaller coefficients.
solveIs0' :: Term -> S ()
solveIs0' t

  -- A == 0
  | Just a <- isConst t = guard (a == 0)

  -- A + B * x = 0
  | Just (a,b,x) <- tIsOneVar t =
    case divMod (-a) b of
      (q,0) -> addDef x (tConst q)
      _     -> mzero

  --  x + S = 0
  -- -x + S = 0
  | Just (xc,x,s) <- tGetSimpleCoeff t =
    addDef x (if xc > 0 then tNeg s else s)

  -- A * S = 0
  | Just (_, s) <- tFactor t  = solveIs0 s

  -- See Section 3.1 of paper for details.
  -- We obtain an equivalent formulation but with smaller coefficients.
  | Just (ak,xk,s) <- tLeastAbsCoeff t =
      do let m = abs ak + 1
         v <- newVar
         let sgn  = signum ak
             soln =     (negate sgn * m) |*| tVar v
                    |+| tMapCoeff (\c -> sgn * modulus c m) s
         -- Eliminate xk by expressing it through the fresh variable v.
         addDef xk soln

         -- Continue with the reduced equation in terms of v.
         let upd i = div (2*i + m) (2*m) + modulus i m
         solveIs0 (negate (abs ak) |*| tVar v |+| tMapCoeff upd s)

  | otherwise = error "solveIs0: unreachable"

-- | Symmetric remainder: @a - m * round(a / m)@, with halves rounded up.
-- The result lies in the half-open range @[-m/2, m/2)@ (for @m > 0@).
modulus :: Integer -> Integer -> Integer
modulus a m = a - m * nearest
  where
    -- floor(a/m + 1/2), computed exactly in integer arithmetic.
    nearest = (2 * a + m) `div` (2 * m)


-- | Solve a constraint of the form @t < 0@.  The current definitions are
-- applied to @t@ first.
solveIsNeg :: Term -> S ()
solveIsNeg t = apSubst t >>= solveIsNeg'


-- | Solve a constraint of the form @t < 0@.
-- Assumes that substitution has been applied
--
-- The constant and common-factor cases are handled first.  Otherwise a
-- variable @x@ is selected with 'tLeastVar', and the constraint is
-- recorded as a lower or an upper bound on @x@, depending on the sign of
-- its coefficient.  Each bound already recorded on the opposite side of
-- @x@ is then combined with the new one.  For each pair, the real shadow
-- is asserted, followed by a choice between the dark shadow and the gray
-- shadows (see Note [Shadows]).
solveIsNeg' :: Term -> S ()
solveIsNeg' t
  -- A < 0
  | Just a <- isConst t = guard (a < 0)
  -- A * S < 0
  | Just (_, s) <- tFactor t = solveIsNeg s
  -- See Section 5.1 of the paper
  | Just (xc, x, s) <- tLeastVar t = do
      -- Record the new bound on x, and collect the pairs
      -- (a, alpha, b, beta) meaning  a*x < alpha  and  beta < b*x.
      ctrs <-
        if xc < 0
          -- -XC*x + S < 0
          -- S < XC*x
          then do
            ubs <- getBounds Upper x
            let b = negate xc
                beta = s
            addBound Lower x (Bound b beta)
            return [(a, alpha, b, beta) | Bound a alpha <- ubs]
          -- XC*x + S < 0
          -- XC*x < -S
          else do
            lbs <- getBounds Lower x
            let a = xc
                alpha = tNeg s
            addBound Upper x (Bound a alpha)
            return [(a, alpha, b, beta) | Bound b beta <- lbs]

      -- See Note [Shadows]
      mapM_
        ( \(a, alpha, b, beta) -> do
            let real = ctLt (a |*| beta) (b |*| alpha)
                dark = ctLt (tConst (a * b)) (b |*| alpha |-| a |*| beta)

                -- Build: (((dark OR gray1) OR gray2) ... OR gray(b-1))
                -- but WITHOUT allocating the gray list.
                grayOrDark :: S ()
                grayOrDark =
                    solveIsNeg dark `orElse` grayRange 1 (b - 1)
                    where
                      -- gray shadow i:  b*x = beta + i
                      grayAt :: Integer -> S ()
                      grayAt i =
                        let eqi = ctEq (b |*| tVar x) (tConst i |+| beta)
                        in solveIs0 eqi
                      -- Choice over gray shadows lo..hi, split in halves
                      -- so the choice tree stays balanced.
                      grayRange :: Integer -> Integer -> S ()
                      grayRange lo hi
                        | lo > hi = mzero
                        | lo == hi = grayAt lo
                        | otherwise =
                            let mid = (lo + hi) `div` 2
                            in grayRange lo mid `orElse` grayRange (mid + 1) hi

            solveIsNeg real
            grayOrDark
        )
        ctrs
  | otherwise = error "solveIsNeg: unreachable"

-- | Explore both alternatives as separate branches of the search tree.
orElse :: S () -> S () -> S ()
orElse = (<|>)

{- Note [Shadows]

  P: beta < b * x
  Q: a * x < alpha

real: a * beta < b * alpha

  beta     < b * x      -- from P
  a * beta < a * b * x  -- (a *)
  a * beta < b * alpha  -- comm. and Q


dark: b * alpha - a * beta > a * b


gray: b * x = beta + 1 \/
      b * x = beta + 2 \/
      ...
      b * x = beta + (b-1)

We stop at @b - 1@ because if:

> b * x                >= beta + b
> a * b * x            >= a * (beta + b)     -- (a *)
> a * b * x            >= a * beta + a * b   -- distrib.
> b * alpha            >  a * beta + a * b   -- comm. and Q
> b * alpha - a * beta > a * b               -- subtract (a * beta)

which is covered by the dark shadow.
-}


--------------------------------------------------------------------------------
-- Monads

-- | A search tree of results.  'None' is a failed branch, 'One' holds a
-- single result, and 'Choice' splits into two alternatives.
data Answer a = None | One a | Choice (Answer a) (Answer a)
                deriving Show

-- instance (NFData a) => NFData (Answer a) where
--   rnf None = ()
--   rnf (One a) = rnf a
--   rnf (Choice l r) = rnf l `seq` rnf r


-- | Count the 'None' leaves, 'One' leaves and 'Choice' nodes of an
-- answer tree, returned in that order.
answerSize :: Answer a -> (Integer,Integer,Integer)
answerSize = tally (0, 0, 0)
  where
  tally (!nones, !ones, !choices) ans = case ans of
    None       -> (nones + 1, ones, choices)
    One _      -> (nones, ones + 1, choices)
    Choice l r -> tally (tally (nones, ones, choices + 1) l) r


-- | Render an answer tree in Graphviz @dot@ syntax.
--
-- Nodes are numbered in pre-order starting from 0.  'Choice' nodes are
-- labelled @|@ and 'None' leaves get an empty label.  'One' leaves are
-- labelled with the shown output of the supplied printer.
dotAnswer :: (a -> Doc) -> Answer a -> Doc
dotAnswer pp g0 = vcat [text "digraph {", nest 2 body, text "}"]
  where
  (body, _) = walk 0 g0

  nodeDoc n lbl = integer n <+> brackets (text "label=" <+> text (show lbl))
                            <+> semi
  edgeDoc from to = integer from <+> text "->" <+> integer to

  -- walk n t: the dot lines for subtree t rooted at node n, together with
  -- the next unused node number.
  walk n ans =
    let next = n + 1
    in next `seq` case ans of
         None         -> (nodeDoc n "", next)
         One a        -> (nodeDoc n (show (pp a)), next)
         Choice c1 c2 ->
           let (doc1, n1) = walk next c1
               (doc2, n2) = walk n1 c2
           in ( vcat [ nodeDoc n "|", edgeDoc n next, edgeDoc n n1, doc1, doc2 ]
              , n2 )
-- | The results in the 'One' leaves of an answer tree, left to right.
toList :: Answer a -> [a]
toList = flip collect []
  where
  collect ans acc = case ans of
    None       -> acc
    One x      -> x : acc
    Choice l r -> collect l (collect r acc)


-- Binding distributes over choices: every branch continues independently.
instance Monad Answer where
  m >>= k = case m of
    None         -> None
    One a        -> k a
    Choice m1 m2 -> Choice (m1 >>= k) (m2 >>= k)


-- Pattern-match failure in do-notation becomes a failed branch.
instance Fail.MonadFail Answer where
  fail _ = None

instance Alternative Answer where
  empty = None
  (<|>) = Choice

instance MonadPlus Answer where
  mzero = None
  mplus = Choice

instance Functor Answer where
  fmap f = go
    where
      go None         = None
      go (One x)      = One (f x)
      go (Choice l r) = Choice (go l) (go r)

instance Applicative Answer where
  pure  = One
  (<*>) = ap


-- | The solver monad: a state transformer over 'RW' whose result is an
-- 'Answer' tree.  Each branch of a choice carries its own copy of the
-- state.
newtype S a = S (RW -> Answer (a,RW))

-- Thread the state through each branch of the underlying answer tree.
instance Monad S where
  S run >>= k = S $ \s0 ->
    run s0 >>= \(a, s1) -> let S next = k a in next s1

instance Alternative S where
  empty = mzero
  (<|>) = mplus

-- A choice runs both alternatives from the same starting state.
instance MonadPlus S where
  mzero             = S (const None)
  mplus (S l) (S r) = S $ \s -> Choice (l s) (r s)

instance Functor S where
  fmap = liftM

instance Applicative S where
  pure a = S $ \s -> One (a, s)
  (<*>)  = ap

-- | Run a pure state transition and return its result.  This never
-- branches and never fails.
updS :: (RW -> (a,RW)) -> S a
updS f = S (One . f)

-- | Modify the state, returning no result.
updS_ :: (RW -> RW) -> S ()
updS_ f = updS (\rw -> ((), f rw))

-- | Read a component of the state without modifying it.
get :: (RW -> a) -> S a
get f = updS (\rw -> (f rw, rw))

-- | Allocate a fresh system variable.
newVar :: S Name
newVar = updS fresh
  where
    fresh rw = let n = nameSource rw
               in (SysName n, rw { nameSource = n + 1 })

-- | Get the lower or the upper bounds recorded for a variable.  Returns
-- the empty list when the variable has none.
getBounds :: BoundType -> Name -> S [Bound]
getBounds bt x = get lookupSide
  where
    lookupSide rw =
      let (lows, ups) = IntM.findWithDefault ([], []) (nameKey x) (bounds (inerts rw))
      in case bt of
           Lower -> lows
           Upper -> ups

-- | Normalize a bound to coefficient 1 when its term is a constant.
--
-- A lower-bound coefficient @b@ leads to up to @b - 1@ gray-shadow
-- alternatives in 'solveIsNeg'', so large coefficients on constant
-- bounds are wasteful.  For an integer @x@ and @c > 0@:
--
--   * Lower, @k < c*x@: this means @x >= floor(k/c) + 1@, i.e.
--     @floor(k/c) < x@, which is @Bound 1 (floor(k/c))@.
--   * Upper, @c*x < k@: this means @x <= floor((k-1)/c)@, i.e.
--     @x < floor((k-1)/c) + 1@, which is @Bound 1 (floor((k-1)/c) + 1)@.
--
-- Bounds whose term is not constant are returned unchanged.
normalizeBound :: BoundType -> Bound -> Bound
normalizeBound _ b@(Bound 1 _) = b
normalizeBound bt b@(Bound c t) =
  case isConst t of
    Nothing -> b
    Just k  -> Bound 1 (tConst (unitLimit k))
  where
    unitLimit k = case bt of
      Lower -> k `div` c
      Upper -> (k - 1) `div` c + 1

-- | Record a new bound on a variable and check the variable's bounds for
-- an immediate conflict.  The bound is normalized first when possible,
-- and is placed in front of the existing bounds on the same side.
addBound :: BoundType -> Name -> Bound -> S ()
addBound bt x b = do
  let nb = normalizeBound bt b
      fresh = case bt of
                Lower -> ([nb], [])
                Upper -> ([], [nb])
      prepend (newL, newU) (oldL, oldU) = (newL ++ oldL, newU ++ oldU)
      insertInto i = i { bounds = IntM.insertWith prepend (nameKey x) fresh (bounds i) }
  updS_ (\rw -> rw { inerts = insertInto (inerts rw) })
  checkBoundsFeasibility x

-- | Check whether a variable's bounds are feasible, failing the current
-- path if they are not.
--
-- The check runs only when the variable has both lower and upper bounds
-- and every bound term is a constant; otherwise it does nothing.  This
-- catches some conflicts before any shadow computation.
checkBoundsFeasibility :: Name -> S ()
checkBoundsFeasibility x = do
  lbs <- getBounds Lower x
  ubs <- getBounds Upper x
  let bothSides = not (null lbs) && not (null ubs)
      allConst  = all constBound lbs && all constBound ubs
  when (bothSides && allConst) $
    maybe mzero (const (return ())) (boundInterval lbs ubs)
  where
    constBound (Bound _ t) = maybe False (const True) (isConst t)

-- | Add a new definition @x = t@, then re-solve the bound constraints that
-- 'iSolved' kicks out (bounds on @x@ or mentioning @x@).
-- Assumes substitution has already been applied.
addDef :: Name -> Term -> S ()
addDef x t = do
  kicked <- updS $ \rw ->
    let (work, is') = iSolved x t (inerts rw)
    in (work, rw { inerts = is' })
  mapM_ solveIsNeg kicked

-- | Apply the current path's definitions to a term.
apSubst :: Term -> S Term
apSubst t = (`iApSubst` t) <$> get inerts




--------------------------------------------------------------------------------


-- | Variable names.  'UserName's are created by clients via 'toName';
-- 'SysName's are generated internally by 'newVar'.
data Name = UserName !Int | SysName !Int
            deriving (Read,Show,Eq,Ord)

-- | Pretty-print a name as @u N@ (user) or @s N@ (system).
ppName :: Name -> Doc
ppName n = case n of
  UserName x -> text "u" <+> int x
  SysName  x -> text "s" <+> int x

-- | Make a user-visible name.
toName :: Int -> Name
toName x = UserName x

-- | The number of a user name, or 'Nothing' for a system name.
fromName :: Name -> Maybe Int
fromName n = case n of
  UserName x -> Just x
  SysName _  -> Nothing



-- | The 'Int' key under which a 'Name' is stored in a 'NameMap'.
type NameKey = Int

-- | Interleave user and system names in one key space: user names get
-- even keys and system names get odd keys.  'keyName' is the inverse.
nameKey :: Name -> NameKey
nameKey n = case n of
  UserName i -> 2 * i
  SysName i  -> 2 * i + 1

-- | Recover the 'Name' stored under a key produced by 'nameKey'.
keyName :: Int -> Name
keyName k = (if r == 0 then UserName else SysName) q
  where
    (q, r) = k `divMod` 2

-- | Maps keyed by 'nameKey'.
type NameMap = IntM.IntMap
-- | The variable part of a term: each variable's 'nameKey' mapped to its
-- coefficient.
type CoeffMap = IntM.IntMap Integer

-- | The type of terms.  The integer is the constant part of the term,
-- and the map sends each variable (encoded as an @Int@ key by 'nameKey')
-- to its coefficient.  The term is the sum of its parts.
-- INVARIANT: the map does not map anything to 0.  The derived 'Eq' and
-- 'Ord' compare the maps directly, so equal terms only compare equal
-- when this invariant holds.
data Term = T !Integer !CoeffMap
              deriving (Eq,Ord)

-- Term arithmetic: '|+|' and '|-|' bind like '+' and '-' (precedence 6,
-- left-associative).  '|*|' (constant times term) binds tighter at
-- precedence 7.  It is right-associative, so @a |*| b |*| t@ parses as
-- @a |*| (b |*| t)@.
infixl 6 |+|, |-|
infixr 7 |*|

-- | A constant term: the given integer with no variables.
tConst :: Integer -> Term
tConst = (`T` IntM.empty)

-- | The term consisting of a single variable with coefficient 1.
tVar :: Name -> Term
tVar = tVarK . nameKey

-- | Like 'tVar', but the variable is given directly by its key.
tVarK :: Int -> Term
tVarK = T 0 . flip IntM.singleton 1

-- | The variable with the smallest key in a term, or 'Nothing' for a
-- constant term.
tFirstVar :: Term -> Maybe Name
tFirstVar (T _ m) = keyName . fst . fst <$> IntM.minViewWithKey m

-- | Scan the bounds left to right and return the first variable found
-- (as chosen by 'tFirstVar') in any bound's term.  Stops at the first hit.
firstVarInBounds :: [Bound] -> Maybe Name
firstVarInBounds = foldr (\(Bound _ t) rest -> tFirstVar t <|> rest) Nothing

-- | Add two terms.  Constants add.  Coefficients of shared variables are
-- summed, and a variable whose coefficient cancels to 0 is removed (via
-- 'mergeAddDrop0'), so the 'Term' invariant is preserved.
(|+|) :: Term -> Term -> Term
T n1 m1 |+| T n2 m2 = T (n1 + n2) (addCoeffs m1 m2)
  where
    -- Skip the merge entirely when one operand has no variables.
    addCoeffs a b
      | IntM.null a = b
      | IntM.null b = a
      | otherwise   = mergeAddDrop0 a b

-- | Multiply a term by a constant.  Scaling by 0 gives the constant 0
-- without inspecting the term.  Scaling by 1 returns the term unchanged.
-- Any other factor multiplies the constant and every coefficient.
(|*|) :: Integer -> Term -> Term
k |*| t =
  case k of
    0 -> tConst 0
    1 -> t
    _ -> case t of
           T n m -> T (k * n) (IntM.map (k *) m)

-- | Negate a term (scale it by -1).
tNeg :: Term -> Term
tNeg = ((-1) |*|)

-- | Subtract the second term from the first (add its negation).
(|-|) :: Term -> Term -> Term
(|-|) t1 = (t1 |+|) . tNeg

-- | Pointwise sum of two coefficient maps.  A key present in only one map
-- keeps its coefficient.  A key present in both gets the sum, and is
-- dropped when that sum is 0, so no zero coefficients are introduced.
mergeAddDrop0 :: CoeffMap -> CoeffMap -> CoeffMap
mergeAddDrop0 = IntM.mergeWithKey addNonZero id id
  where
    addNonZero _ a b =
      case a + b of
        0 -> Nothing
        s -> Just s

-- | Multiply every coefficient in a map by a constant.
-- Scaling by 0 yields the empty map rather than a map full of zero
-- coefficients, keeping the 'Term' invariant (no zero coefficients)
-- intact.  Factors 1 and -1 avoid the general multiplication.
scaleMap :: Integer -> CoeffMap -> CoeffMap
scaleMap 0 _ = IntM.empty
scaleMap 1 m = m
scaleMap (-1) m = IntM.map negate m
scaleMap k m = IntM.map (k *) m

-- | Substitute the term @t1@ for the variable with key @kx@ in @t2@.
-- If @t2@ does not mention that variable it is returned unchanged.
-- Otherwise, with @a@ the variable's coefficient in @t2@, the result is
-- @t2@ without that variable plus @a * t1@.  Any coefficients that cancel
-- are dropped by the merge.
tLetK :: Int -> Term -> Term -> Term
tLetK kx (T n1 m1) t2@(T n2 m2) = maybe t2 substitute coeff
  where
    -- A single traversal both finds the coefficient and removes the entry.
    (coeff, others) = IntM.updateLookupWithKey (\_ _ -> Nothing) kx m2

    substitute a = T (n2 + a * n1) (combine (scaleMap a m1) others)

    -- Skip the merge when either side has no variables.
    combine xs ys
      | IntM.null xs = ys
      | IntM.null ys = xs
      | otherwise    = mergeAddDrop0 xs ys


-- Replace a variable with a term (commented out; kept for reference).
-- tLet :: Name -> Term -> Term -> Term
-- tLet x = tLetK (nameKey x)

-- | Replace the variable with key @kx@ by the constant @v@: its coefficient
-- times @v@ is added to the constant part and the variable is removed.
-- A term that does not mention the variable is returned unchanged.
tLetNumK :: Int -> Integer -> Term -> Term
tLetNumK kx v t@(T n m) =
  let (coeff, rest) = IntM.updateLookupWithKey (\_ _ -> Nothing) kx m
  in case coeff of
       Nothing -> t
       Just c  -> T (n + c * v) rest

-- | Replace a variable, given by name, with a constant.
tLetNum :: Name -> Integer -> Term -> Term
tLetNum = tLetNumK . nameKey

-- | Replace the given variables with constants.  The list is folded from
-- the right, so if a variable is listed more than once, its last binding
-- is the one that takes effect.
tLetNums :: [(Name,Integer)] -> Term -> Term
tLetNums xs t = foldr (uncurry tLetNum) t xs




-- Terms are shown via their pretty-printed form ('ppTerm'), rendered and
-- then shown as a quoted string.
instance Show Term where
  showsPrec p = showsPrec p . show . ppTerm

-- | Pretty-print a term as a sum.  A zero constant is omitted unless the
-- term has no variables.  Unit coefficients are printed without a
-- multiplier.  Every summand after the first carries an explicit
-- @+@ or @-@ sign.
ppTerm :: Term -> Doc
ppTerm (T k m) =
  case (k, IntM.toList m) of
    (_, [])           -> integer k
    (0, first : more) -> hsep (leading first : map trailing more)
    (_, more)         -> hsep (integer k : map trailing more)
  where
    -- The first summand when there is no constant: no leading "+".
    leading (key, c)
      | c == 1    = ppName (keyName key)
      | c == -1   = text "-" <+> ppName (keyName key)
      | otherwise = scaled c key

    -- Subsequent summands: the sign is printed separately from |c|.
    trailing (key, c)
      | c == 1    = text "+" <+> ppName (keyName key)
      | c == -1   = text "-" <+> ppName (keyName key)
      | c > 0     = text "+" <+> scaled c key
      | otherwise = text "-" <+> scaled (abs c) key

    scaled c key = integer c <+> text "*" <+> ppName (keyName key)

-- Remove a variable from the term, and return its coefficient.
-- If the variable is not present in the term, the coefficient is 0.
-- (Commented out; kept for reference.)
-- tSplitVar :: Name -> Term -> (Integer, Term)
-- tSplitVar x t@(T n m) =
--   case IntM.updateLookupWithKey (\_ _ -> Nothing) (nameKey x) m of
--     (Nothing,_) -> (0,t)
--     (Just k,m1) -> (k, T n m1)

-- Does the term contain this variable? (Commented out; kept for reference.)
-- tHasVar :: Name -> Term -> Bool
-- tHasVar x (T _ m) = IntM.member (nameKey x) m

-- | Does the term mention the variable with this key?
tHasVarK :: Int -> Term -> Bool
tHasVarK k t = case t of T _ m -> k `IntM.member` m

-- | Is this term just an integer?  Returns the constant when the term has
-- no variables, and 'Nothing' otherwise.
isConst :: Term -> Maybe Integer
isConst (T n m) = n <$ guard (IntM.null m)

-- | The constant part of a term (its variables are ignored).
tConstPart :: Term -> Integer
tConstPart t = case t of T n _ -> n

-- | Returns @Just (a, b, x)@ when the term has exactly one variable,
-- i.e. it is of the form @a + b * x@; otherwise 'Nothing'.
tIsOneVar :: Term -> Maybe (Integer, Integer, Name)
tIsOneVar (T a m) =
  case IntM.toList m of
    [(k, b)] -> Just (a, b, keyName k)
    _        -> Nothing

-- | Spots terms that contain variables with unit coefficients
-- (i.e., of the form @x + t@ or @t - x@).
-- Returns @(coeff, var, rest of term)@ for the unit-coefficient variable
-- with the smallest key, where @coeff@ is 1 or -1; 'Nothing' if there is
-- no such variable.
tGetSimpleCoeff :: Term -> Maybe (Integer, Name, Term)
tGetSimpleCoeff (T a m) =
  -- The lazy ascending scan stops at the first unit coefficient, so there
  -- is no need to partition and re-union the whole map.
  case [ kc | kc@(_, xc) <- IntM.toList m, abs xc == 1 ] of
    (k, xc) : _ -> Just (xc, keyName k, T a (IntM.delete k m))
    []          -> Nothing

-- tVarList :: Term -> [Name]
-- tVarList (T _ m) = map keyName (IntM.keys m)


-- | Try to factor out a common constant (> 1) from a term.
-- For example, @2 + 4x@ becomes @2 * (1 + 2x)@.
-- The factor is the gcd of the constant and all coefficients.  'Nothing'
-- is returned when that gcd is 0 or 1.
tFactor :: Term -> Maybe (Integer, Term)
tFactor (T c m)
  | d > 1     = Just (d, T (c `div` d) (IntM.map (`div` d) m))
  | otherwise = Nothing
  where
    -- 'gcd' ignores signs, so this is gcd of |c| and every |coefficient|.
    d = IntM.foldl' gcd (abs c) m

-- | Extract a variable whose coefficient has the smallest absolute value.
-- Returns @(coefficient, variable, remaining term)@, or 'Nothing' when the
-- term has no variables.  Ties go to the variable with the smallest key.
tLeastAbsCoeff :: Term -> Maybe (Integer, Name, Term)
tLeastAbsCoeff (T c m) = do
  -- Seed the search with the smallest-key entry.  A constant term fails
  -- here, which removes the need for an unreachable 'error' case.
  ((k0, xc0), _) <- IntM.minViewWithKey m
  let step best@(!_, !cBest) k xc
        | abs xc < abs cBest = (k, xc)
        | otherwise          = best
      (kBest, xcBest) = IntM.foldlWithKey' step (k0, xc0) m
  pure (xcBest, keyName kBest, T c (IntM.delete kBest m))

-- | Extract the variable with the smallest key from a term, returning
-- @(coefficient, variable, remaining term)@, or 'Nothing' for a constant.
tLeastVar :: Term -> Maybe (Integer, Name, Term)
tLeastVar (T c m) = fmap toTriple (IntM.minViewWithKey m)
  where
    toTriple ((k, xc), rest) = (xc, keyName k, T c rest)

-- | Apply a function to all coefficients, including the constant.
-- Variables whose new coefficient is 0 are removed, so the result still
-- satisfies the 'Term' invariant (no zero coefficients).
tMapCoeff :: (Integer -> Integer) -> Term -> Term
tMapCoeff f (T c m) = T (f c) (IntM.mapMaybe nonZero m)
  where
    nonZero x =
      case f x of
        0 -> Nothing
        y -> Just y