module Enigma.Internal where

import Data.Array (Array)
import qualified Data.Array as A
import Data.Char (chr, ord)
import Data.List (elemIndex, foldl')
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M

-- | Number of positions on a rotor. Used as the modulus for the position
-- arithmetic in 'mapRotor', 'mapRotors' and 'stepRotor'.
nRotorPos :: Int
nRotorPos = 26

-- | Wiring table: the element at contact position @i@ is the letter that
-- contact is connected to. 'mapRotor' indexes it with positions in the
-- range 0..25, so the array bounds are presumably @(0, 25)@
-- (TODO: confirm where wirings are constructed).
type Wiring = Array Int Char

-- | Plugboard connections as letter pairs. 'mapPlug' treats each pair as a
-- two-way swap and uses the first pair in the list that mentions a letter.
type Plugboard = [(Char, Char)]

-- | The fixed properties of a rotor, independent of how it is currently
-- oriented in the machine (see 'OrientedRotor' for the orientation).
data Rotor = Rotor
  { -- | Rotor identifier (presumably a human-readable name; it is not
    -- consulted by any function visible in this module).
    rotId :: String,
    -- | Wiring table selected by 'mapRotorsRightLeft', i.e. used on the
    -- way towards the reflector.
    wiring :: Wiring,
    -- | Wiring table selected by 'mapRotorsLeftRight', i.e. used on the
    -- way back from the reflector. Presumably the inverse permutation of
    -- 'wiring'; nothing in this module enforces that.
    invWiring :: Wiring,
    -- | Top letters at which 'stepRotors' considers this rotor to be at a
    -- turnover position.
    turnovers :: [Char]
  }
  deriving (Show, Eq)

-- | A 'Rotor' together with its current orientation in the machine.
data OrientedRotor = OrientedRotor
  { -- | The underlying rotor.
    rotor :: Rotor,
    -- | Current rotational position as a letter. 'stepRotor' advances it by
    -- one letter (wrapping from Z back to A), and 'mapRotor' uses its
    -- 'index' as the rotational offset.
    topLetter :: Char,
    -- | Ring setting as a letter. 'mapRotors' subtracts its alphabet offset
    -- from the incoming position and adds it back to the outgoing one.
    ringSetting :: Char
  }
  deriving (Show, Eq)

-- | Complete machine state: read by 'cipherChar' and advanced by 'step'.
data EnigmaConfig = EnigmaConfig
  { -- | Reflector wiring, applied by 'mapReflector' between the two passes
    -- through the rotors.
    reflector :: Wiring,
    -- | Rotors in left-to-right order: 'step' treats the last element as the
    -- rightmost rotor, which it advances on every call.
    rotors :: [OrientedRotor],
    -- | Plugboard pairs, applied by 'cipherChar' both before and after the
    -- rotor passes.
    plugboard :: Plugboard
  }
  deriving (Show, Eq)

-- | Convert an uppercase letter to its zero-based alphabet position
-- (A maps to 0, Z maps to 25).
--
-- Calls 'error' with @"invalid index"@ for any character outside A..Z.
index :: Char -> Int
index c
  | pos >= 0 && pos <= 25 = pos
  | otherwise = error "invalid index"
  where
    -- Distance of the character from 'A' in code points.
    pos = ord c - ord 'A'

-- | Convert a zero-based alphabet position back to its uppercase letter
-- (0 maps to A, 25 maps to Z). Inverse of 'index' on the range 0..25.
--
-- Calls 'error' with @"invalid index"@ for any position outside 0..25
-- (0 is the lowest rotor index, 25 the highest).
revIndex :: Int -> Char
revIndex i
  | i >= 0 && i <= 25 = chr (ord 'A' + i)
  | otherwise = error "invalid index"

-- | Pass a signal through one wiring table that has been rotated so that
-- the given letter is on top.
--
-- Arguments: the wiring table, the current top letter, and the input
-- position (0..25, relative to the machine body). The result is the
-- output position on the other side of the wiring, again in 0..25.
mapRotor :: Wiring -> Char -> Int -> Int
mapRotor contacts windowLetter inputPos =
  let -- How far the wiring is rotated relative to the machine body.
      shift = index windowLetter
      -- Contact of the rotated wiring that the input position touches.
      entry = (inputPos + shift) `mod` nRotorPos
      -- Contact the signal leaves from, according to the wiring table.
      exit = index (contacts A.! entry)
   in -- Undo the rotation to get back to a body-relative position.
      (exit - shift) `mod` nRotorPos

-- | Pass a position through the reflector. The reflector is handled as a
-- wiring fixed with A on top, so no rotational offset is applied.
mapReflector :: Wiring -> Int -> Int
mapReflector reflWiring pos = mapRotor reflWiring 'A' pos

-- | Apply the plugboard to a letter. The first pair in the list that
-- contains the letter decides the result (the letter is swapped with its
-- partner); a letter that appears in no pair is returned unchanged.
mapPlug :: Plugboard -> Char -> Char
mapPlug pairs c = foldr swapIfPlugged c pairs
  where
    -- 'foldr' is lazy in its second argument, so later pairs are only
    -- inspected when the current pair does not mention the letter.
    swapIfPlugged (a, b) viaLaterPairs
      | c == a = b
      | c == b = a
      | otherwise = viaLaterPairs

-- | Thread a position through a list of rotors, in list order.
--
-- Arguments: the rotors to traverse (first element is traversed first), a
-- selector choosing which wiring table of each 'Rotor' to use, the
-- function that maps a position through a single wiring, and the starting
-- position. The result is the position after the last rotor; an empty
-- rotor list returns the starting position unchanged.
mapRotors :: [OrientedRotor] -> (Rotor -> Wiring) -> (Wiring -> Char -> Int -> Int) -> Int -> Int
mapRotors rotorList wGetter mapper startPos = go rotorList startPos
  where
    go [] pos = pos
    go (r : remaining) pos = go remaining (throughOne r pos)

    -- Map a position through a single rotor, compensating for its ring
    -- setting: shift into the ring-relative frame, map, then shift back.
    throughOne r pos =
      let ringOffset = ord (ringSetting r) - ord 'A'
          ringRelative = (pos - ringOffset) `mod` nRotorPos
          mapped = mapper (wGetter (rotor r)) (topLetter r) ringRelative
       in (ringOffset + mapped) `mod` nRotorPos

-- | Pass a position through the rotor stack from its last element to its
-- first (the list is reversed before traversal), using each rotor's
-- 'wiring' table.
mapRotorsRightLeft :: [OrientedRotor] -> Int -> Int
mapRotorsRightLeft stack pos = mapRotors lastToFirst wiring mapRotor pos
  where
    lastToFirst = reverse stack

-- | Pass a position through the rotor stack from its first element to its
-- last (list order), using each rotor's 'invWiring' table.
mapRotorsLeftRight :: [OrientedRotor] -> Int -> Int
mapRotorsLeftRight stack pos = mapRotors stack invWiring mapRotor pos

-- | Encipher a single character with the machine in its current state.
-- The configuration is not stepped here; see 'step'.
--
-- The signal path is: plugboard, rotors towards the reflector, reflector,
-- rotors back from the reflector, plugboard. 'index' calls 'error' if the
-- character leaving the first plugboard pass is not in A..Z.
cipherChar :: EnigmaConfig -> Char -> Char
cipherChar config c = signalPath c
  where
    plugs = plugboard config
    rots = rotors config
    -- Read bottom-to-top: the last line is the first stage the signal meets.
    signalPath =
      mapPlug plugs
        . revIndex
        . mapRotorsLeftRight rots
        . mapReflector (reflector config)
        . mapRotorsRightLeft rots
        . index
        . mapPlug plugs

-- | Advance a single rotor by one position: its top letter moves to the
-- next letter of the alphabet, wrapping from Z back to A. All other fields
-- are left untouched.
stepRotor :: OrientedRotor -> OrientedRotor
stepRotor rot = rot {topLetter = successor (topLetter rot)}
  where
    successor letter = revIndex ((index letter + 1) `mod` nRotorPos)

-- | Conditionally step a chain of rotors, given the rotor that precedes
-- the chain (its state before it was stepped).
--
-- Each rotor in the list is judged against the /unstepped/ state of its
-- predecessor (the second argument for the head of the list, the previous
-- list element afterwards):
--
-- * the last rotor in the list steps only when its predecessor is at one
--   of its turnover letters;
-- * every other rotor steps when its predecessor is at a turnover letter
--   or when it is itself at one of its own turnover letters.
--
-- The result has the same length and order as the input list.
stepRotors :: [OrientedRotor] -> OrientedRotor -> [OrientedRotor]
stepRotors chain prev =
  case chain of
    [] -> []
    [curr]
      | atTurnover prev -> [stepRotor curr]
      | otherwise -> [curr]
    curr : rest ->
      let moved
            | atTurnover prev || atTurnover curr = stepRotor curr
            | otherwise = curr
       in -- The recursion receives the unstepped rotor as predecessor.
          moved : stepRotors rest curr
  where
    -- True when the rotor's current top letter is one of its turnovers.
    atTurnover r = topLetter r `elem` turnovers (rotor r)

-- | Advance the whole machine by one key press.
--
-- * No rotors: the configuration is returned unchanged.
-- * Three rotors @[left, mid, right]@: the right rotor always steps; the
--   middle and left rotors step according to 'stepRotors'.
-- * Four rotors: the leftmost rotor never moves; the remaining three are
--   stepped exactly like a three-rotor machine.
-- * Any other rotor count calls 'error'.
step :: EnigmaConfig -> EnigmaConfig
step cfg =
  case rotors cfg of
    [] -> cfg
    [fixed, ml, mr, right] -> cfg {rotors = fixed : stepTrio ml mr right}
    [left, mid, right] -> cfg {rotors = stepTrio left mid right}
    _ -> error "stepping behavior undefined for configs with < 3 or > 4 rotors"
  where
    -- Step three moving rotors (given left-to-right) and return them in
    -- left-to-right order. The list is built right-to-left because
    -- 'stepRotors' walks away from the always-stepping right rotor.
    stepTrio l m r = reverse (stepRotor r : stepRotors [m, l] r)

-- | Count how often each distinct item occurs in a collection.
--
-- Only the counts are returned, one per distinct item, ordered by
-- /descending/ item (the count of the largest item comes first). An empty
-- collection yields an empty list.
freqs :: (Foldable t, Ord k) => t k -> [Int]
freqs items = reverse (M.elems tally)
  where
    -- Occurrence count per distinct item; 'M.elems' lists the counts in
    -- ascending key order, hence the 'reverse' above.
    tally = foldl' bump M.empty items
    bump counts item = M.insertWith (+) item 1 counts

-- | Index of coincidence for a list of occurrence counts (for example the
-- output of 'freqs'):
--
-- > sum [f * (f - 1) | f <- counts] / (n * (n - 1))   where n = sum counts
--
-- With fewer than two observations in total the formula is 0/0; in that
-- case 0 is returned instead of NaN, so callers can always compare and
-- rank the result.
ic :: [Int] -> Double
ic counts
  | total < 2 = 0
  | otherwise = coincidences / (n * (n - 1))
  where
    total = sum counts
    -- The denominator is computed in 'Double', as before, so results for
    -- well-defined inputs are bit-for-bit unchanged.
    n = fromIntegral total :: Double
    coincidences = fromIntegral (sum [f * (f - 1) | f <- counts])
