module Main (main) where

import Control.Exception (evaluate)
import Control.Monad (forM_)
import Data.Integer.SAT
  ( Expr (..),
    Prop (..),
    assert,
    checkSat,
    noProps,
    toName,
  )
import Data.List (sortBy)
import System.CPUTime (getCPUTime)
import System.Environment (getArgs)
import System.IO (hFlush, stdout)
import Text.Printf (printf)
import Text.Read (readMaybe)

-- | State of the small linear congruential generator used to build the
-- benchmark instances. The wrapped 'Integer' is the current generator state;
-- the same starting state always yields the same sequence of draws.
newtype RNG = RNG Integer

-- | Wrap an 'Int' seed as the initial generator state.
mkRNG :: Int -> RNG
mkRNG = RNG . toInteger

-- | Advance the generator one step and draw a value from the inclusive range
-- @[lo, hi]@, returning the value together with the new generator state.
--
-- The state update is the linear congruential step
-- @next = (1103515245 * state + 12345) mod 2^31@ and the drawn value is
-- @lo + next mod (hi - lo + 1)@.
--
-- The range must be non-empty (@lo <= hi@): with @hi == lo - 1@ the modulus
-- is zero and 'mod' throws a divide-by-zero exception.
randInt :: Integer -> Integer -> RNG -> (Integer, RNG)
randInt lo hi (RNG state) = (lo + next `mod` width, RNG next)
  where
    -- Number of distinct values in [lo, hi].
    width = hi - lo + 1
    modulus = 2 ^ (31 :: Int)
    next = (1103515245 * state + 12345) `mod` modulus

-- | Draw @count@ values from the inclusive range @[lo, hi]@, threading the
-- generator through every draw and returning its final state.
--
-- A non-positive @count@ yields an empty list and leaves the generator
-- untouched (mirroring 'take' and 'replicate'); previously a negative count
-- never reached the base case and recursed without end.
randInts :: Int -> Integer -> Integer -> RNG -> ([Integer], RNG)
randInts count lo hi gen
  | count <= 0 = ([], gen)
  | otherwise = (value : rest, genFinal)
  where
    (value, genNext) = randInt lo hi gen
    (rest, genFinal) = randInts (count - 1) lo hi genNext

-- | Produce a pseudo-random permutation of a list, returning it together with
-- the advanced generator.
--
-- Each step draws an index into the elements that are still unplaced, moves
-- the element at that index to the output and continues with the remainder,
-- so one 'randInt' draw is consumed per element.
shuffle :: [a] -> RNG -> ([a], RNG)
shuffle items = draw (length items) items
  where
    -- Invariant: @remaining@ equals the length of @pool@. Every drawn index
    -- is therefore in bounds and the lazy @picked : back@ pattern cannot fail.
    draw 0 pool gen = (pool, gen)
    draw remaining pool gen = (picked : others, genFinal)
      where
        (index, genNext) = randInt 0 (fromIntegral remaining - 1) gen
        (front, picked : back) = splitAt (fromIntegral index) pool
        (others, genFinal) = draw (remaining - 1) (front ++ back) genNext

-- Benchmark Family 1: Weighted Ternary Assignment

-- | Build one instance of the weighted ternary assignment family.
--
-- The instance has @n@ variables, each restricted to the values 0, 1 or 2.
-- The variables must sum to @n@, and a weighted sum of the variables must
-- equal a target. The weights are @n@ pseudo-random integers from
-- @[1, 2 * n]@ determined by @seed@; the target is derived from
-- @difficulty@, with 0.3 as the switch between the "below the minimum" and
-- the "at or above the minimum" branch.
--
-- Calls 'error' when @n < 1@, because both sums need at least one variable.
--
-- NOTE(review): @lowest@ and @highest@ each sum the first @n@ entries of an
-- @n@-element list, i.e. all weights, so @spread@ is always 0. As written the
-- target is the total weight for @difficulty >= 0.3@ (met by setting every
-- variable to 1) and the total weight minus one otherwise; @difficulty@ has
-- no further effect. This looks unintended, but it is preserved here so that
-- existing benchmark numbers stay comparable. Confirm the intended bounds.
mkTernaryProp :: Int -> Int -> Double -> Prop
mkTernaryProp n seed difficulty
  | n < 1 = error ("mkTernaryProp: n must be at least 1, got " ++ show n)
  | otherwise = domainProp :&& countProp :&& weightedProp
  where
    vars = [Var (toName i) | i <- [0 .. n - 1]]

    -- Every variable takes one of the three values 0, 1 or 2.
    inDomain v = (v :== K 0) :|| (v :== K 1) :|| (v :== K 2)
    domainProp = foldr ((:&&) . inDomain) PTrue vars

    -- The variables sum to exactly n.
    countProp = foldr1 (:+) vars :== K (fromIntegral n)

    (weights, _) = randInts n 1 (fromIntegral (2 * n)) (mkRNG seed)

    -- Weights and variables are paired positionally; both lists have n
    -- elements, so zipWith drops nothing.
    weightedProp = foldr1 (:+) (zipWith (:*) weights vars) :== K target

    ascending = sortBy compare weights
    lowest = sum (take n ascending)
    highest = sum (take n (reverse ascending))
    spread = highest - lowest
    target
      | difficulty < 0.3 = lowest - 1 - floor (fromIntegral spread * (0.3 - difficulty))
      | otherwise = lowest + floor (fromIntegral spread * (difficulty - 0.3))

-- Benchmark Family 2: Prime Subset Sum

-- | The first fifty prime numbers in ascending order (2 through 229), laid
-- out ten per row.
primes50 :: [Integer]
primes50 =
  concat
    [ [2, 3, 5, 7, 11, 13, 17, 19, 23, 29],
      [31, 37, 41, 43, 47, 53, 59, 61, 67, 71],
      [73, 79, 83, 89, 97, 101, 103, 107, 109, 113],
      [127, 131, 137, 139, 149, 151, 157, 163, 167, 173],
      [179, 181, 191, 193, 197, 199, 211, 223, 227, 229]
    ]

-- | Build one instance of the prime subset-sum family.
--
-- The instance has @n@ 0/1 variables of which exactly @n `div` 2@ must be 1.
-- The weights are the first @n@ entries of 'primes50', shuffled according to
-- @seed@, and the weighted sum of the variables must equal a target derived
-- from @difficulty@:
--
-- * below 0.3 the target lies strictly under the smallest sum that any
--   choice of @n `div` 2@ weights can reach;
-- * otherwise it is that smallest sum plus a @difficulty - 0.3@ fraction of
--   the distance to the largest such sum.
--
-- Calls 'error' unless @1 <= n <= length primes50@: each variable needs its
-- own prime weight, and the sums need at least one variable. Sizes above the
-- pool previously failed with an index-out-of-range error.
mkSubsetSumProp :: Int -> Int -> Double -> Prop
mkSubsetSumProp n seed difficulty
  | n < 1 || n > poolSize =
      error
        ( "mkSubsetSumProp: n must be between 1 and "
            ++ show poolSize
            ++ ", got "
            ++ show n
        )
  | otherwise = bitProp :&& countProp :&& weightedProp
  where
    poolSize = length primes50
    vars = [Var (toName i) | i <- [0 .. n - 1]]

    -- Every variable is a 0/1 selector.
    isBit v = (v :== K 0) :|| (v :== K 1)
    bitProp = foldr ((:&&) . isBit) PTrue vars

    -- Exactly half of the variables (rounded down) are selected.
    chosen = n `div` 2
    countProp = foldr1 (:+) vars :== K (fromIntegral chosen)

    (weights, _) = shuffle (take n primes50) (mkRNG seed)

    -- Weights and variables are paired positionally; the guard above ensures
    -- both lists have n elements.
    weightedProp = foldr1 (:+) (zipWith (:*) weights vars) :== K target

    ascending = sortBy compare weights
    lowest = sum (take chosen ascending)
    highest = sum (take chosen (reverse ascending))
    spread = highest - lowest
    target
      | difficulty < 0.3 = lowest - 1 - floor (fromIntegral spread * (0.3 - difficulty) * 0.5)
      | otherwise = lowest + floor (fromIntegral spread * (difficulty - 0.3))

-- Benchmark Suite

-- | A single named benchmark instance.
data Benchmark = Benchmark
  { -- | Label printed next to the benchmark's result.
    benchName :: String,
    -- | Proposition whose satisfiability is checked and timed.
    benchProp :: Prop
  }

-- | Assemble the fixed benchmark suite for a base problem size.
--
-- The suite lists nine ternary instances of size @baseN@ (seeds 1 to 3, each
-- with difficulties 0.1, 0.5 and 0.9) followed by six subset-sum instances of
-- size @baseN + 10@ (seeds 1 to 3, each with difficulties 0.2 and 0.8).
-- Within a family the seed varies slowest.
genBenchmarkSuite :: Int -> [Benchmark]
genBenchmarkSuite baseN =
  family "ternary_n%d_s%d_d%.1f" baseN mkTernaryProp [0.1, 0.5, 0.9]
    ++ family "subset_n%d_s%d_d%.1f" (baseN + 10) mkSubsetSumProp [0.2, 0.8]
  where
    -- One benchmark per (seed, difficulty) pair; the name is rendered from
    -- the size, seed and difficulty using the given printf format.
    family :: String -> Int -> (Int -> Int -> Double -> Prop) -> [Double] -> [Benchmark]
    family nameFormat size build difficulties = do
      seed <- [1 .. 3]
      diff <- difficulties
      pure (Benchmark (printf nameFormat size seed diff) (build size seed diff))

-- | Run one benchmark and report the CPU time it took, in seconds, together
-- with whether the proposition was found satisfiable.
--
-- The proposition is asserted into an empty proposition set and checked with
-- 'checkSat'; a 'Just' result counts as satisfiable.
runBenchmark :: Benchmark -> IO (Double, Bool)
runBenchmark bench = do
  start <- getCPUTime
  -- 'evaluate' forces the solver result at this point of the IO sequence,
  -- so the work is done between the two clock reads rather than wherever
  -- laziness would otherwise place it.
  isSat <- evaluate (maybe False (const True) (checkSat (assert (benchProp bench) noProps)))
  end <- getCPUTime
  -- getCPUTime reports picoseconds; 1e12 of them make one second.
  let elapsed = fromIntegral (end - start) / 1e12 :: Double
  return (elapsed, isSat)

-- | Entry point: run the benchmark suite and print one result line per
-- benchmark.
--
-- Accepts an optional single argument giving the base problem size
-- (default 17). Any other argument list, or an argument that is not an
-- integer, aborts with the usage message.
main :: IO ()
main = do
  args <- getArgs
  -- Parse eagerly with readMaybe so a malformed argument is reported as a
  -- usage error up front instead of surfacing later as a 'read' failure.
  baseN <- case args of
    [] -> pure 17
    [arg] | Just value <- readMaybe arg -> pure value
    _ -> error "Usage: bench [baseN]"

  let suite = genBenchmarkSuite baseN
  putStrLn $ printf "Running %d benchmarks (baseN=%d)...\n" (length suite) baseN

  forM_ suite $ \bench -> do
    putStr $ printf "%-35s " (benchName bench)
    -- The name is written without a newline; flush so it is visible while
    -- the (possibly slow) solver run is in progress.
    hFlush stdout
    (elapsed, isSat) <- runBenchmark bench
    let satStr = if isSat then "SAT  " else "UNSAT"
    putStrLn $ printf "%s  %8.3fs" satStr elapsed