module DLog (runBabyStepGiantStep,
             runBabyStepGiantStepPar,
             runBabyStepGiantStepParWithChunks) where

import qualified Data.Bits as Bits (shift)
import Data.Either
import Text.Read (readMaybe)

import Control.Parallel.Strategies
import Data.List.Split

-- Parallel variant: every giant step is evaluated as its own spark (parList).
-- | Parse a line of the form @"x base modulus"@ and solve
-- @base^k = x (mod modulus)@ with 'babyStepGiantStepPar'.
--
-- Input that is not exactly three integers yields a 'Left' instead of
-- crashing (the previous @read@ / refutable pattern threw an exception).
runBabyStepGiantStepPar :: String -> Either String Integer
runBabyStepGiantStepPar line =
  case mapM readMaybe (words line) of
    Just [x, base, modulus] -> babyStepGiantStepPar base x modulus
    _                       -> Left "Expected three integers: x base modulus"

-- | Parallel baby-step giant-step: find @k >= 0@ with @base^k = x (mod m)@.
--
-- Writes @k = i + sqrtM * j@ with @0 <= i < sqrtM@.  The baby-step table
-- holds @(base^i mod m, i)@ and each giant step @j@ in @[0 .. sqrtM]@ is
-- checked independently in its own spark.  Both ranges start at 0 so that
-- exponents below @sqrtM@ and exact multiples of @sqrtM@ are reachable.
-- Results stay in @j@ order, so the first hit has the smallest @j@.
--
-- Returns 'Left' for a non-positive modulus, a base that is not coprime to
-- the modulus, or when no exponent is found.
babyStepGiantStepPar :: Integer -> Integer -> Integer -> Either String Integer
babyStepGiantStepPar base x m
  | m < 1                          = Left "Modulus must be positive"
  | not (isRelativelyPrime base m) = Left "Base and modulus must be relatively prime"
  | m == 1                         = Right 0 -- every integer is congruent modulo 1
  | otherwise                      =
      case filter isRight allSols of
        (sol : _) -> sol
        []        -> Left "No solution"
  where
    sqrtM    = ceiling (sqrt (fromIntegral m :: Double))
    lhsTable = [(powMod base i m, i) | i <- [0 .. sqrtM - 1]]
    iterMap  = [0 .. sqrtM]
    -- rdeepseq: each spark computes the full result, not just the Left/Right tag.
    allSols  =
      map (babyStepGiantStepPar' base x m sqrtM lhsTable) iterMap
        `using` parList rdeepseq

-- | Check a single giant step @j = numIter@ for the parallel solver.
--
-- Computes @x * (base^(sqrtM*j))^-1 mod m@ and looks it up in the baby-step
-- table of @(value, i)@ pairs.  On a hit, returns @Right (i + sqrtM*j)@ for
-- the first matching entry; otherwise a 'Left' marker the caller filters out.
babyStepGiantStepPar' :: Integer -> Integer -> Integer -> Integer -> [(Integer, Integer)] -> Integer -> Either String Integer
babyStepGiantStepPar' base x m sqrtM lhsTable numIter =
  case lookup rhsSol lhsTable of
    Just currLhsIdx -> Right (currLhsIdx + currRhsIdx)
    Nothing         -> Left "No match this iteration"
  where
    currRhsIdx = sqrtM * numIter
    -- Reduce base^currRhsIdx modulo m before inverting: the exponent can be
    -- as large as ~m, so the unreduced power is enormous.  Reducing also keeps
    -- the input to inverseEuclid non-negative.
    giantInverse
      | m == 1    = 0 -- everything is 0 mod 1; also avoids inverting 0
      | otherwise = inverseEuclid (powMod base currRhsIdx m) m
    rhsSol = (x * giantInverse) `mod` m

-- Chunked parallel variant: giant steps are grouped into chunks of 50, one spark per chunk.
-- | Parse a line of the form @"x base modulus"@ and solve
-- @base^k = x (mod modulus)@ with 'babyStepGiantStepParWithChunks'.
--
-- Input that is not exactly three integers yields a 'Left' instead of
-- crashing (the previous @read@ / refutable pattern threw an exception).
runBabyStepGiantStepParWithChunks :: String -> Either String Integer
runBabyStepGiantStepParWithChunks line =
  case mapM readMaybe (words line) of
    Just [x, base, modulus] -> babyStepGiantStepParWithChunks base x modulus
    _                       -> Left "Expected three integers: x base modulus"

-- | Chunked parallel baby-step giant-step: find @k >= 0@ with
-- @base^k = x (mod m)@.
--
-- Same decomposition as the unchunked solver (@k = i + sqrtM * j@, with
-- @i@ in @[0 .. sqrtM - 1]@ and @j@ in @[0 .. sqrtM]@), but giant steps are
-- grouped into chunks of 50 and each chunk is evaluated in one spark.
-- Both ranges start at 0 so that exponents below @sqrtM@ and exact
-- multiples of @sqrtM@ are reachable.
--
-- Returns 'Left' for a non-positive modulus, a base that is not coprime to
-- the modulus, or when no exponent is found.
babyStepGiantStepParWithChunks :: Integer -> Integer -> Integer -> Either String Integer
babyStepGiantStepParWithChunks base x m
  | m < 1                          = Left "Modulus must be positive"
  | not (isRelativelyPrime base m) = Left "Base and modulus must be relatively prime"
  | m == 1                         = Right 0 -- every integer is congruent modulo 1
  | otherwise                      =
      case filter isRight (concat allSols) of
        (sol : _) -> sol
        []        -> Left "No solution"
  where
    sqrtM    = ceiling (sqrt (fromIntegral m :: Double))
    lhsTable = [(powMod base i m, i) | i <- [0 .. sqrtM - 1]]
    iterMap  = chunksOf 50 [0 .. sqrtM]
    -- rdeepseq, not rpar: rpar only evaluates each chunk to weak head normal
    -- form (its first (:) cell), which left all real work to the sequential
    -- consumer below.
    allSols  =
      map (babyStepGiantStepParWithChunks' base x m sqrtM lhsTable) iterMap
        `using` parList rdeepseq

-- | Apply 'babyStepGiantStepPar'' to every giant-step index in one chunk,
-- keeping the results in the chunk's order.
babyStepGiantStepParWithChunks' :: Integer -> Integer -> Integer -> Integer -> [(Integer, Integer)] -> [Integer] -> [Either String Integer]
babyStepGiantStepParWithChunks' base x m sqrtM lhsTable chunk =
  [ babyStepGiantStepPar' base x m sqrtM lhsTable j | j <- chunk ]

-- Sequential variant: giant steps are tried one at a time until a match is found.
-- | Parse a line of the form @"x base modulus"@ and solve
-- @base^k = x (mod modulus)@ with 'babyStepGiantStep'.
--
-- Input that is not exactly three integers yields a 'Left' instead of
-- crashing (the previous @read@ / refutable pattern threw an exception).
runBabyStepGiantStep :: String -> Either String Integer
runBabyStepGiantStep line =
  case mapM readMaybe (words line) of
    Just [x, base, modulus] -> babyStepGiantStep base x modulus
    _                       -> Left "Expected three integers: x base modulus"

-- | Sequential baby-step giant-step: find @k >= 0@ with @base^k = x (mod m)@.
--
-- Writes @k = i + sqrtM * j@ with @0 <= i < sqrtM@.  The baby-step table
-- holds @(base^i mod m, i)@ for every such @i@, and giant steps are tried
-- starting from @j = 0@.  Both ranges start at 0 so that exponents below
-- @sqrtM@ and exact multiples of @sqrtM@ are reachable; starting at 1
-- makes e.g. @3^k = 1 (mod 7)@ report no solution.
--
-- Returns 'Left' for a non-positive modulus, a base that is not coprime to
-- the modulus, or when no exponent is found.
babyStepGiantStep :: Integer -> Integer -> Integer -> Either String Integer
babyStepGiantStep base x m
  | m < 1                          = Left "Modulus must be positive"
  | not (isRelativelyPrime base m) = Left "Base and modulus must be relatively prime"
  | m == 1                         = Right 0 -- every integer is congruent modulo 1
  | otherwise                      = babyStepGiantStep' base x m sqrtM 0 lhsTable
  where
    sqrtM    = ceiling (sqrt (fromIntegral m :: Double))
    lhsTable = [(powMod base i m, i) | i <- [0 .. sqrtM - 1]]

-- | Giant-step loop for 'babyStepGiantStep'.
--
-- For giant step @j = numIter@, computes @x * (base^(sqrtM*j))^-1 mod m@
-- and looks it up in the baby-step table of @(value, i)@ pairs.  A hit
-- returns @Right (i + sqrtM*j)@ for the first matching entry; otherwise the
-- next @j@ is tried.  It gives up with "No solution" once a step with
-- @numIter > sqrtM@ has also missed.
babyStepGiantStep' :: Integer -> Integer -> Integer -> Integer -> Integer -> [(Integer, Integer)] -> Either String Integer
babyStepGiantStep' base x m sqrtM numIter lhsTable =
  case lookup rhsSol lhsTable of
    Just currLhsIdx -> Right (currLhsIdx + currRhsIdx)
    Nothing
      | numIter > sqrtM -> Left "No solution"
      | otherwise       -> babyStepGiantStep' base x m sqrtM (numIter + 1) lhsTable
  where
    currRhsIdx = sqrtM * numIter
    -- Reduce base^currRhsIdx modulo m before inverting: the exponent can be
    -- as large as ~m, so the unreduced power is enormous.  Reducing also keeps
    -- the input to inverseEuclid non-negative.
    giantInverse
      | m == 1    = 0 -- everything is 0 mod 1; also avoids inverting 0
      | otherwise = inverseEuclid (powMod base currRhsIdx m) m
    rhsSol = (x * giantInverse) `mod` m

-- HELPER FUNCTIONS --
-- | Modular exponentiation by repeated squaring: @powMod b e m == b^e `mod` m@.
--
-- Requires @e >= 0@ and @m >= 1@.  A negative exponent raises an error;
-- otherwise the worker would loop forever, since shifting -1 right yields -1.
powMod :: Integer -> Integer -> Integer -> Integer
powMod b e m
  | e < 0     = error "powMod: negative exponent"
  -- Seed with 1 `mod` m so that e == 0 with m == 1 yields 0, not 1.
  | otherwise = powMod' b e m (1 `mod` m)

-- | Worker for 'powMod'.  @result@ accumulates the product of the squared
-- bases selected by the set bits of @e@; one bit of @e@ is consumed per step.
-- @e@ must be non-negative.
powMod' :: Integer -> Integer -> Integer -> Integer -> Integer
powMod' _ 0 _ result = result
powMod' b e m result = powMod' bNew eNew m resultNew
  where resultNew = if odd e then (result * b) `mod` m else result
        eNew      = Bits.shift e (-1)
        bNew      = (b * b) `mod` m

-- | Modular inverse by the extended Euclidean algorithm.
--
-- Starting from remainders @(m, x)@ and coefficients @(0, 1)@, each step
-- keeps the invariant @r = t * x (mod m)@.  When the next remainder is 0,
-- the coefficient paired with the last non-zero remainder is returned,
-- reduced into @[0, m)@.  For @x > 0@ with @gcd x m == 1@ that remainder is
-- 1, so the result is the inverse of @x@.
--
-- NOTE(review): for negative @x@ the last remainder may be -1, giving the
-- negated inverse.  @x == 0@ throws a divide-by-zero exception.
inverseEuclid :: Integer -> Integer -> Integer
inverseEuclid x m = go m x 0 1
  where
    go r0 r1 t0 t1
      | r2 == 0   = t1 `mod` m
      | otherwise = go r1 r2 t1 (t0 - q * t1)
      where (q, r2) = r0 `divMod` r1

-- | True exactly when the two numbers have greatest common divisor 1.
--
-- Uses 'gcd', which is always non-negative.  The result therefore does not
-- depend on the signs of the arguments.  The hand-rolled Euclid it replaces
-- could end on -1 and report e.g. 5 and -1 as not coprime.
isRelativelyPrime :: Integer -> Integer -> Bool
isRelativelyPrime num1 num2 = gcd num1 num2 == 1

-- | Return the second component of the first pair whose first component
-- equals @word@.  Partial: calls 'error' when no pair matches.
getI :: Integer -> [(Integer, Integer)] -> Integer
getI word table =
  case lookup word table of
    Just idx -> idx
    Nothing  -> error "programming error!"
-- HELPER FUNCTIONS --
