{-

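WordLadder solver: finds a chain of dictionary words leading from one word to another, where consecutive words differ in exactly one letter.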

Shows the performance difference between parallel and sequential execution.
See run.sh for performance tests.

Usage: wordLadder <[par|seq]> <dictionary-filename> <from-word> <to-word>
$ ./wordLadder par words.txt bar none

-}

{-# LANGUAGE OverloadedStrings #-}

import qualified Prelude
import Prelude hiding (words, putStr, length, readFile)

import System.Environment
import Data.Char
import qualified Data.Set as Set

import Control.Parallel.Strategies

import qualified Data.ByteString as BS
import qualified Data.ByteString.Char8 as BSU
import Data.ByteString hiding (map, filter, head, concat, reverse, all)

-- | Usage text printed when the command line does not have the expected
-- four arguments. The first argument selects parallel or sequential mode.
noArgErrorMessage :: String
noArgErrorMessage = "Usage: wordLadder <[par|seq]> <dictionary-filename> <from-word> <to-word>"

-- | Extracts candidate words from the dictionary file contents. It keeps
-- the whitespace-separated tokens that have exactly @desiredLength@
-- characters, all of them lower-case letters.
createDict :: ByteString -> Int -> [ByteString]
createDict content desiredLength =
  [ candidate
  | candidate <- BSU.words content
  , allLowerAlphaLength desiredLength candidate
  ]

-- | True when @word@ is exactly @desiredLength@ characters long and every
-- character is both alphabetic and lower-case. The length is checked first,
-- so the character scan only runs on words of the right size.
allLowerAlphaLength :: Int -> ByteString -> Bool
allLowerAlphaLength desiredLength word =
  BS.length word == desiredLength && BSU.all isLowerLetter word
  where
    isLowerLetter c = isAlpha c && isLower c

-- | A set of dictionary words. The search uses it to hold the words that may
-- still be visited; words are removed from it once they have been reached.
type Words = Set.Set ByteString

-- | Breadth-first search for a word ladder (sequential version).
--
-- Paths are stored newest word first, so a path "ends with" a word when that
-- word is its head. The ladder that is returned is reversed into
-- oldest-word-first order.
--
-- Each recursive call handles one level of the search:
--
--  * If any current path already ends with @to@, return it. This check runs
--    at the start of every call, so it also covers the paths built by the
--    previous level. On the very first call it makes @from == to@ return the
--    one-word ladder.
--  * Remove the newest word of every current path from the dictionary, so
--    the search never steps back onto a word it has already reached. Without
--    this the start word stays in the dictionary, and ladders can loop back
--    to it (e.g. @cat -> bat -> cat@).
--  * For every path, find all dictionary words one letter away from its
--    newest word and extend the path with each of them. Example: the path
--    @["ac", "ab", "aa"]@ with hops @"bc"@ and @"cc"@ becomes
--    @[["bc", "ac", "ab", "aa"], ["cc", "ac", "ab", "aa"]]@.
--  * Remove every hop found at this level from the dictionary, then repeat
--    with the extended paths and @limit - 1@.
--
-- Returns 'Nothing' when @limit@ levels have been explored, when no
-- dictionary words are left, or when no path can be extended any further.
process :: Int -> [[ByteString]] -> ByteString -> Words -> Maybe [ByteString]
process limit paths to dict
  | (p : _) <- filter reachedTarget paths = Just (reverse p)
  | limit <= 0 || Set.null remaining || Prelude.null paths = Nothing
  | otherwise = process (limit - 1) paths' to dict'
  where
    reachedTarget (w : _) = w == to
    reachedTarget [] = False

    -- Words at the head of a path have already been reached.
    remaining = Set.difference dict (Set.fromList [w | (w : _) <- paths])

    results = map (eval remaining) paths
    paths' = Prelude.concatMap fst results
    dict' = Set.difference remaining (Set.unions (map snd results))

    -- Extends one path by every neighbour of its newest word, and also
    -- returns those neighbours so they can be removed from the dictionary.
    eval d p@(w : _) =
      let next = findNextWords w d
      in (map (: p) (Set.toList next), next)
    eval _ [] = ([], Set.empty)

-- | Breadth-first search for a word ladder, where the neighbour lookups for
-- all paths of one level are sparked in parallel with 'parMap'.
--
-- Paths are stored newest word first; the ladder that is returned is
-- reversed into oldest-word-first order. Each call:
--
--  * returns a current path that already ends with @to@ (this also makes
--    @from == to@ return the one-word ladder);
--  * removes the newest word of every current path from the dictionary, so
--    the start word cannot be revisited;
--  * extends every path by each dictionary word one letter away from its
--    newest word, removes those hops from the dictionary, and recurses
--    with @limit - 1@.
--
-- Returns 'Nothing' when @limit@ levels have been explored, when no
-- dictionary words are left, or when no path can be extended any further.
processPar :: Int -> [[ByteString]] -> ByteString -> Words -> Maybe [ByteString]
processPar limit paths to dict
  | (p : _) <- filter reachedTarget paths = Just (Prelude.reverse p)
  | limit <= 0 || Set.null remaining || Prelude.null paths = Nothing
  | otherwise = processPar (limit - 1) paths' to dict'
  where
    reachedTarget (w : _) = w == to
    reachedTarget [] = False

    -- Words at the head of a path have already been reached.
    remaining = Set.difference dict (Set.fromList [w | (w : _) <- paths])

    -- parMap rpar evaluates each result only to WHNF. 'eval' therefore
    -- forces the neighbour set before building the pair, so that the
    -- expensive findNextWords call runs inside the spark rather than later
    -- on the main thread.
    results = parMap rpar (eval remaining) paths
    paths' = Prelude.concatMap fst results
    dict' = Set.difference remaining (Set.unions (map snd results))

    eval d p@(w : _) =
      let next = findNextWords w d
      in next `seq` (map (: p) (Set.toList next), next)
    eval _ [] = ([], Set.empty)

-- | Selects the dictionary words that differ from @currentWord@ in exactly
-- one position, i.e. the possible next rungs of the ladder.
findNextWords :: ByteString -> Words -> Words
findNextWords currentWord = Set.filter (`oneCharDifference` currentWord)

-- | True when both words have the same length and differ in exactly one
-- position. Words of different lengths are never neighbours.
oneCharDifference :: ByteString -> ByteString -> Bool
oneCharDifference a b
  | BSU.length a /= BSU.length b = False
  | otherwise =
      -- Exactly one mismatching position between the two words.
      case filter id (BSU.zipWith (/=) a b) of
        [_] -> True
        _ -> False

-- | Renders the search result for printing. A ladder is shown one word per
-- line; otherwise a failure message is shown. Both forms end with a
-- newline, so the shell prompt is not glued onto the output.
--
-- NOTE: the "20" in the message must match the search limit used in main.
showResults :: Maybe [ByteString] -> ByteString
showResults Nothing = BSU.pack "Unable to find a ladder in 20 steps\n"
showResults (Just path) = BSU.unlines path

-- | Entry point. Expects @<par|seq> <dictionary-filename> <from-word> <to-word>@,
-- loads the dictionary words of the same length as @from-word@, and prints
-- the ladder found by the selected search (or a failure message).
main :: IO ()
main = do
  args <- getArgs
  case args of
    [_, _, fromWord, toWord]
      | Prelude.length fromWord /= Prelude.length toWord ->
          Prelude.putStrLn ("\"" ++ fromWord ++ "\" and \"" ++ toWord ++ "\" must be the same length")
    [parflag, filename, fromWord, toWord] ->
      -- Validate the mode before reading the (possibly large) dictionary,
      -- and report a bad flag as an IO error rather than a pure 'error'.
      case lookup parflag searchModes of
        Nothing -> ioError (userError "Set the par or seq flag")
        Just (banner, search) -> do
          contents <- readFile filename
          let dict = Set.fromList $ createDict contents (Prelude.length fromWord)
          Prelude.putStrLn banner
          BSU.putStr (showResults (search searchLimit [[BSU.pack fromWord]] (BSU.pack toWord) dict))
    _ -> Prelude.putStrLn noArgErrorMessage
  where
    -- Maximum number of levels the search explores. The failure text
    -- produced by showResults mentions the same number.
    searchLimit = 20
    -- Mode flag -> (banner printed before searching, search function).
    searchModes =
      [ ("seq", ("Sequential mode", process))
      , ("par", ("Parallel mode", processPar))
      ]