module Main where

import Data.List (permutations)
import Data.List.Split (chunksOf)
import Enigma
    ( EnigmaConfig(..),
      OrientedRotor(OrientedRotor),
      Rotor,
      getWiring,
      calculateIC,
      cipher,
      combinations,
      normalize )
import Enigma.Static (m3RotorSet, makeOriented, refIC)
import Options.Applicative
  ( Parser,
    ParserInfo,
    argument,
    command,
    helper,
    info,
    metavar,
    progDesc,
    str,
    subparser,
    prefs,
    showHelpOnEmpty,
    customExecParser,
  )
import Control.Parallel.Strategies
    ( parBuffer,
      parList,
      parListChunk,
      rdeepseq,
      rpar,
      rseq,
      runEval,
      using,
      withStrategy,
      Eval,
      Strategy )
import Control.DeepSeq ( force )

-- | Path of an input file given on the command line.
type FileName = String

-- | How the decryption search over every rotor setting is evaluated
-- (dispatched in 'run').
data DecryptStrategy
  = -- | Plain sequential search.
    Sequential
  | -- | 'parBuffer' of the given size with 'rdeepseq'.
    ParBuffer Int
  | -- | 'parList' with 'rdeepseq'.
    RDeepSeq
  | -- | 'parList' with 'rseq'.
    RSeq
  | -- | 'parList' with 'rpar'.
    RPar
  | -- | Static partition of the search space, one spark per partition.
    SixWay
  | -- | 'parListChunk' over individual results.
    Chunks
  | -- | 'parBuffer' over pre-chunked batches of results.
    BufferChunks

-- | A parsed command-line subcommand.
data Command
  = -- | Encrypt the named file with the default configuration.
    Encrypt FileName
  | -- | Decrypt the named file using the given evaluation strategy.
    Decrypt DecryptStrategy FileName

-- | Build a 'ParBuffer' decrypt command from a single argument of the form
-- @"bufferSize filename"@. Both parts arrive as one quoted shell word; see
-- the @decryptParBuffer@ help text.
--
-- Calls 'error' with a descriptive message when the argument is not exactly
-- two words, or when the size is not a positive integer.
makeParBuffer :: String -> Command
makeParBuffer s =
  case words s of
    [sizeStr, fn] -> Decrypt (ParBuffer (parseSize sizeStr)) fn
    _ -> error "Missing size or file argument"
  where
    -- 'reads' rather than 'read' so a malformed size is reported clearly
    -- instead of failing with the opaque "Prelude.read: no parse". A
    -- buffer of zero or fewer elements is meaningless, so it is rejected too.
    parseSize sizeText =
      case reads sizeText of
        [(n, "")] | n > 0 -> n
        _ -> error ("Invalid buffer size (expected a positive integer): " ++ sizeText)

-- | Attach a @--help@ option and a program description to a parser, turning
-- it into a 'ParserInfo' usable as a subcommand or as the top-level program.
withInfo :: Parser a -> String -> ParserInfo a
withInfo parser description = info withHelp (progDesc description)
  where
    withHelp = helper <*> parser

-- | The positional @filename@ argument shared by most subcommands.
filenameArgument :: Parser FileName
filenameArgument = argument str (metavar "filename")

-- | Arguments of the @encrypt@ subcommand.
parseEncrypt :: Parser Command
parseEncrypt = Encrypt <$> filenameArgument

-- | Arguments of @decryptSequential@.
parseDecryptSequential :: Parser Command
parseDecryptSequential = Decrypt Sequential <$> filenameArgument

-- | Arguments of @decryptParBuffer@. Size and file name arrive as a single
-- argument and are split by 'makeParBuffer'.
parseDecryptParBuffer :: Parser Command
parseDecryptParBuffer = fmap makeParBuffer (argument str (metavar "bufferSize filename"))

-- | Arguments of @decryptRDeepSeq@.
parseDecryptRDeepSeq :: Parser Command
parseDecryptRDeepSeq = Decrypt RDeepSeq <$> filenameArgument

-- | Arguments of @decryptRSeq@.
parseDecryptRSeq :: Parser Command
parseDecryptRSeq = Decrypt RSeq <$> filenameArgument

-- | Arguments of @decryptRPar@.
parseDecryptRPar :: Parser Command
parseDecryptRPar = Decrypt RPar <$> filenameArgument

-- | Arguments of @decryptSixWay@.
parseDecryptSixWay :: Parser Command
parseDecryptSixWay = Decrypt SixWay <$> filenameArgument

-- | Arguments of @decryptChunks@.
parseDecryptChunks :: Parser Command
parseDecryptChunks = Decrypt Chunks <$> filenameArgument

-- | Arguments of @decryptBufferChunks@.
parseDecryptBufferChunks :: Parser Command
parseDecryptBufferChunks = Decrypt BufferChunks <$> filenameArgument

-- | Parser covering every subcommand. Each one selects either encryption or
-- a single decryption strategy.
parseCommand :: Parser Command
parseCommand =
  subparser $
    command "encrypt" (parseEncrypt `withInfo` "Encrypt the file")
      <> command "decryptSequential" (parseDecryptSequential `withInfo` "Decrypt the file sequentially")
      <> command "decryptParBuffer" (parseDecryptParBuffer `withInfo` "Decrypt the file with a parBuffer limited to the given size; use quotes: e.g. \"1000 /path/to/file\"")
      <> command "decryptRDeepSeq" (parseDecryptRDeepSeq `withInfo` "Decrypt the file with parList rDeepSeq")
      <> command "decryptRSeq" (parseDecryptRSeq `withInfo` "Decrypt the file with parList rSeq")
      <> command "decryptRPar" (parseDecryptRPar `withInfo` "Decrypt the file with parList rPar")
      <> command "decryptSixWay" (parseDecryptSixWay `withInfo` "Decrypt the file with a static six-way partition")
      <> command "decryptChunks" (parseDecryptChunks `withInfo` "Decrypt the file with parListChunk 1000 rdeepseq")
      -- Must use parseDecryptBufferChunks: wiring parseDecryptChunks here
      -- makes this subcommand silently run the Chunks strategy instead.
      <> command "decryptBufferChunks" (parseDecryptBufferChunks `withInfo` "Decrypt the file in chunks of 1000 with parBuffer 128")


-- | Top-level command-line parser; currently the same as 'parseCommand'.
parseOptions :: Parser Command
parseOptions = parseCommand

-- | Parse the command line (showing help when no arguments are given) and
-- hand the resulting command to 'run'.
main :: IO ()
main = customExecParser parserPrefs programInfo >>= run
  where
    parserPrefs = prefs showHelpOnEmpty
    programInfo = parseOptions `withInfo` ""

-- | Fixed machine configuration used by the @encrypt@ subcommand. It uses
-- the "refB" reflector wiring and an empty plugboard. The rotors, in list
-- order, are "III", "II" and "I", built with 'makeOriented' from the letters
-- 'K', 'D', 'O' and a second letter of 'A' each.
defaultConfig :: EnigmaConfig
defaultConfig =
  EnigmaConfig
    { plugboard = [],
      reflector = Enigma.getWiring "refB",
      rotors = zipWith3 makeOriented ["III", "II", "I"] "KDO" (repeat 'A')
    }

-- | Execute a parsed command.
--
-- @encrypt@ normalises the file contents and prints the result of 'cipher'
-- under 'defaultConfig'. Every decrypt variant normalises the file, searches
-- the rotor settings with the chosen strategy, and 'print's the winning
-- candidate plaintext (shown via 'show', so it appears quoted).
run :: Command -> IO ()
run (Encrypt fn) =
  putStrLn . Enigma.cipher defaultConfig . Enigma.normalize =<< readFile fn
run (Decrypt strat fn) =
  print . decryptWith strat . normalize =<< readFile fn
  where
    -- Map each strategy onto the search implementation it selects.
    decryptWith :: DecryptStrategy -> String -> String
    decryptWith Sequential = sequentialDecrypt
    decryptWith (ParBuffer size) = parListDecrypt (parBuffer size rdeepseq)
    decryptWith RDeepSeq = parListDecrypt (parList rdeepseq)
    decryptWith RSeq = parListDecrypt (parList rseq)
    decryptWith RPar = parListDecrypt (parList rpar)
    decryptWith SixWay = sixWayDecrypt
    decryptWith Chunks = snd . minimum . parChunks allPermutations
    decryptWith BufferChunks = snd . minimum . bufferChunks

-- | Absolute difference between a target value and the index of coincidence
-- that 'Enigma.calculateIC' reports for the given text. Smaller values mean
-- the text is closer to the target.
icDistance :: (Foldable t, Ord k) => Double -> t k -> Double
icDistance target = abs . (target -) . Enigma.calculateIC

-- | Every ordering of every 3-rotor combination (as produced by
-- 'combinations') drawn from 'm3RotorSet'.
m3RotorPermutations :: [[Rotor]]
m3RotorPermutations =
  [ ordering
  | chosen <- combinations 3 m3RotorSet,
    ordering <- permutations chosen
  ]

-- | Every rotor configuration tried during decryption. Each rotor ordering
-- from 'm3RotorPermutations' is combined with all 26^3 choices for the second
-- 'OrientedRotor' field of the three rotors. That field is presumably the
-- start position; confirm against "Enigma". The third field is always 'A'.
allPermutations :: [[OrientedRotor]]
allPermutations =
  [ zipWith3 OrientedRotor rots [lt, mt, rt] (repeat 'A')
  | rots <- m3RotorPermutations,
    lt <- alphabet,
    mt <- alphabet,
    rt <- alphabet
  ]
  where
    alphabet = ['A' .. 'Z']

-- | Run 'cipher' on the message for a single rotor setting (reflector "refB",
-- empty plugboard). Returns the plaintext paired with its IC distance from
-- 'refIC', ready to be compared with 'minimum'.
solveIC :: [Char] -> [OrientedRotor] -> (Double, String)
solveIC msg rot = (score, plaintext)
  where
    settings = EnigmaConfig {plugboard = [], reflector = getWiring "refB", rotors = rot}
    plaintext = cipher settings msg
    score = icDistance refIC plaintext

-- | Score every given rotor setting against the ciphertext. The results are
-- evaluated in parallel chunks of 1000, each fully forced with 'rdeepseq'.
parChunks :: [[OrientedRotor]] -> [Char] -> [(Double, String)]
parChunks rots ctext = withStrategy (parListChunk 1000 rdeepseq) scored
  where
    scored = map (solveIC ctext) rots

-- | Score every setting in 'allPermutations' against the ciphertext.
-- Settings are grouped into batches of 1000, and up to 128 batches at a time
-- are evaluated ahead with 'parBuffer' and 'rdeepseq'.
bufferChunks :: [Char] -> [(Double, String)]
bufferChunks ctext = concat (batches `using` parBuffer 128 rdeepseq)
  where
    batches = map (map (solveIC ctext)) (chunksOf 1000 allPermutations)

-- | Sequential baseline. Tries every setting in 'allPermutations' and returns
-- the candidate plaintext whose IC distance from 'refIC' is smallest; ties
-- are broken by the plaintext itself.
sequentialDecrypt :: String -> String
sequentialDecrypt msg = snd (minimum scored)
  where
    scored = [(icDistance refIC candidate, candidate) | candidate <- solve allPermutations msg]

-- | Run 'cipher' over the message once per rotor setting (reflector "refB",
-- empty plugboard). Returns the candidate plaintexts in the same order as the
-- settings.
solve :: [[OrientedRotor]] -> String -> [String]
solve rotorCfgs msg = map tryRotors rotorCfgs
  where
    tryRotors rots = cipher EnigmaConfig {reflector = getWiring "refB", rotors = rots, plugboard = []} msg

-- | Decrypt using a static partition of the search space.
--
-- The rotor settings in 'allPermutations' are split into at most six
-- contiguous partitions. The best (IC distance, plaintext) pair of each
-- partition is computed in its own 'rpar' spark, and the overall minimum is
-- returned.
--
-- Each partition's winner is wrapped in 'force', so evaluating the spark
-- performs the whole search over that partition, not just WHNF. Partitions
-- are handled as a list instead of being matched against exactly six names,
-- so a permutation count that does not split into six chunks no longer
-- crashes with a pattern-match failure.
sixWayDecrypt :: String -> String
sixWayDecrypt msg = snd $ minimum solutions
  where
    -- Size rounded up so 'chunksOf' yields no more than six partitions.
    partitionSize = length allPermutations `div` 6 + 1
    partitions = chunksOf partitionSize allPermutations
    -- Lowest (IC distance, plaintext) pair among one partition's candidates.
    bestOf rots = minimum $ map (\c -> (icDistance refIC c, c)) $ solve rots msg
    solutions = runEval $ do
      sparked <- mapM (rpar . force . bestOf) partitions
      mapM_ rseq sparked
      return sparked

-- | Map a function over a list, then evaluate the resulting list with the
-- given strategy.
parMapStrategy :: Strategy [b] -> (a -> b) -> [a] -> [b]
parMapStrategy strat f = withStrategy strat . map f

-- | Try every setting in 'allPermutations', evaluating the candidate
-- plaintexts with the supplied list strategy. Returns the candidate whose IC
-- distance from 'refIC' is smallest.
parListDecrypt :: ([String] -> Eval [String]) -> [Char] -> [Char]
parListDecrypt strategy msg = snd (minimum [(icDistance refIC p, p) | p <- candidates])
  where
    candidates = parMapStrategy strategy tryRotors allPermutations
    tryRotors rots = cipher (EnigmaConfig {reflector = getWiring "refB", rotors = rots, plugboard = []}) msg
