{-# LANGUAGE DeriveGeneric, DeriveAnyClass #-}

module Puzzle
  ( Board, Move(..)
  , boardSize, numTiles
  , startBoardUnsolvable
  , startBoard3, startBoard8, startBoard16, startBoard17, startBoard30, startBoard36, startBoard38, startBoard40
  , startBoardByMoves
  , getRandomBoard, getRandomSolvableBoard
  , isSolvableBoard, isSolved
  , printBoard
  , getManhattanDistanceOfBoard
  , validMoves, applyMove
  , iddfs, iddfsWithDepthLimit
  , iddfspruning, iddfsWithPruning
  , ida, idaWithDepthLimit
  , parallelIDA, getNodesFromRoot
  , parallelIddfs, parallelIddfsPruning,
  ) where

import Data.Maybe (fromJust)
import System.Random.Shuffle (shuffleM)
import Data.List (intercalate, elemIndex, minimumBy)
import Data.Ord  (comparing)
import Control.Parallel.Strategies (withStrategy, parList, rdeepseq)
import Control.DeepSeq (NFData)
import GHC.Generics (Generic)

-- | A puzzle board stored as a flat list of tile values, read row by row
-- (the row of a cell is its index divided by 'boardSize').
-- The value 0 marks the empty slot.
type Board = [Int]

-- | Node of the search tree, used when running in parallel: a board paired
-- with the moves recorded for it (the parallel solvers store the moves that
-- lead from the root to that board).
type Node = (Board, [Move])

-- | A single slide, named after the direction the moved tile travels (not the
-- empty slot): 'MoveLeft' slides the tile to the right of the empty slot
-- leftwards, 'MoveRight' the tile to its left rightwards, 'MoveUp' the tile
-- below it upwards and 'MoveDown' the tile above it downwards.
data Move = MoveLeft | MoveRight | MoveUp | MoveDown
  deriving (Show, Eq, Generic, NFData)

-- | Side length of the square board; 4 gives the 4x4 board (15-puzzle).
boardSize :: Int
boardSize = 4

-- | Total number of cells on the board, the empty slot included.
numTiles :: Int
numTiles = boardSize * boardSize

-----------------
-- BOARD SETUP --
-----------------

-- | Sample board that cannot be solved: the solved layout with tiles 14 and 15
-- swapped.
startBoardUnsolvable :: Board
startBoardUnsolvable   =
  [ 1,  2,  3,  4
  , 5,  6,  7,  8
  , 9, 10, 11, 12
  ,13, 15, 14,  0
  ]

-- | Sample board three slides away from the solved layout: the empty slot sits
-- at the bottom-left and tiles 13, 14 and 15 each have to slide left once.
startBoard3 :: Board
startBoard3 =
  [ 1,  2,  3,  4
  , 5,  6,  7,  8
  , 9, 10, 11, 12
  , 0, 13, 14, 15
  ]

-- | Sample start position. Judging by its name (and its key in
-- 'startBoardByMoves') it is meant to need 8 moves; that count is not checked
-- anywhere in this file.
startBoard8 :: Board
startBoard8 =
  [ 1,  2,  3,  4
  , 5,  6, 12,  7
  , 9, 10,  8, 15
  ,13, 14, 11,  0
  ]

-- | Sample start position. Judging by its name (and its key in
-- 'startBoardByMoves') it is meant to need 16 moves; that count is not checked
-- anywhere in this file.
startBoard16 :: Board
startBoard16 =
  [ 2,  3,  4,  8
  , 1,  6,  7, 12
  , 0, 10, 11, 15
  , 5,  9, 13, 14
  ]

-- | Sample start position: 'startBoard16' with tile 10 and the empty slot
-- swapped, i.e. one slide further away. Judging by its name it is meant to
-- need 17 moves; that count is not checked anywhere in this file.
startBoard17 :: Board
startBoard17 =
  [ 2,  3,  4,  8
  , 1,  6,  7, 12
  ,10,  0, 11, 15
  , 5,  9, 13, 14
  ]

-- | Sample start position. Judging by its name (and its key in
-- 'startBoardByMoves') it is meant to need 30 moves; that count is not checked
-- anywhere in this file.
startBoard30 :: Board
startBoard30 =
  [ 2,  7,  3,  4
  ,10,  1, 12,  8
  , 6, 11, 15, 14
  , 5,  9, 13,  0
  ]

-- | Sample start position. Judging by its name (and its key in
-- 'startBoardByMoves') it is meant to need 36 moves; that count is not checked
-- anywhere in this file.
startBoard36 :: Board
startBoard36 =
  [ 2,  7,  0,  3
  ,10,  1,  8,  4
  , 6, 11, 12, 14
  , 5,  9, 15, 13
  ]

-- | Sample start position. Judging by its name (and its key in
-- 'startBoardByMoves') it is meant to need 38 moves; that count is not checked
-- anywhere in this file.
startBoard38 :: Board
startBoard38 =
  [10,  2,  7,  3
  , 6,  1,  8,  0
  ,11, 12, 14,  4
  , 5,  9, 15, 13
  ]

-- | Sample start position. Judging by its name (and its key in
-- 'startBoardByMoves') it is meant to need 40 moves; that count is not checked
-- anywhere in this file.
startBoard40 :: Board
startBoard40 =
  [10,  2,  0,  7
  , 6,  1,  8,  3
  ,11, 12, 14,  4
  , 5,  9, 15, 13
  ]

-- | Select one of the sample boards by the number in its name, given as text
-- (for example @"16"@ selects 'startBoard16').
-- Any key that is not listed falls back to 'startBoard3'.
startBoardByMoves :: String -> Board
startBoardByMoves key =
  case lookup key knownBoards of
    Just board -> board
    Nothing    -> startBoard3 -- this is the default one
  where
    knownBoards :: [(String, Board)]
    knownBoards =
      [ ("3",  startBoard3)
      , ("8",  startBoard8)
      , ("16", startBoard16)
      , ("17", startBoard17)
      , ("30", startBoard30)
      , ("36", startBoard36)
      , ("38", startBoard38)
      , ("40", startBoard40)
      ]

---------------------------------------------------------------------------------------------------------------------------
-- HELPERS TO GENERATE RANDOM BOARD AND CHECK WHETHER IT CAN BE SOLVED (was actually not used in the project in the end) --
-----------------------------------------------------------------------------------------------------------------------------

-- | Produce a randomly shuffled board (via 'shuffleM') that contains every
-- value from 0 up to @numTiles - 1@ exactly once.
--
-- The range is derived from 'numTiles' instead of a literal 15 so that it
-- stays in step with 'boardSize'. The result is not guaranteed to be
-- solvable; see 'getRandomSolvableBoard' for that.
getRandomBoard :: IO Board
getRandomBoard = shuffleM [0 .. numTiles - 1]

-- | Keep drawing random boards until one passes 'isSolvableBoard', then
-- return that board.
getRandomSolvableBoard :: IO Board
getRandomSolvableBoard = getRandomBoard >>= keepIfSolvable
  where
    -- Accept the candidate, or start over with a fresh shuffle.
    keepIfSolvable candidate
      | isSolvableBoard candidate = pure candidate
      | otherwise                 = getRandomSolvableBoard

-- | Decide whether a board can reach the solved layout.
--
-- The test is the usual inversion-parity argument. An inversion is a pair of
-- tiles (the empty slot is ignored) where the larger value comes first.
--
-- * Even board width (the 4x4 case): solvable exactly when the number of
--   inversions plus the row of the empty slot, counted from the bottom
--   starting at 1, is odd.
-- * Odd board width: solvable exactly when the number of inversions is even.
--
-- A list without an empty slot (no 0) is not a board at all and is reported as
-- unsolvable instead of crashing. The list is otherwise not validated.
isSolvableBoard :: Board -> Bool
isSolvableBoard b =
  case elemIndex 0 b of
    Nothing -> False
    Just blankIndex
      | even boardSize -> odd (inversions + blankRowFromBottom blankIndex)
      | otherwise      -> even inversions
  where
    inversions :: Int
    inversions = countInversions (filter (/= 0) b)

    -- For every tile, count the smaller tiles that appear after it.
    countInversions :: [Int] -> Int
    countInversions []         = 0
    countInversions (x : rest) =
      length (filter (< x) rest) + countInversions rest

    -- 1 for the bottom row, 'boardSize' for the top row.
    blankRowFromBottom :: Int -> Int
    blankRowFromBottom blankIndex = boardSize - (blankIndex `div` boardSize)

-- | True exactly when the board equals the goal layout: the tiles
-- 1 .. numTiles - 1 in ascending order, followed by the empty slot (0).
isSolved :: Board -> Bool
isSolved = (== goalBoard)
  where
    goalBoard :: Board
    goalBoard = [1 .. numTiles - 1] ++ [0]

-- | Write the board to standard output, one row per line, framed by a line of
-- dashes above and below. Every cell is two characters wide and the empty
-- slot is shown as a dot.
printBoard :: Board -> IO ()
printBoard b = do
  putStrLn border
  mapM_ (putStrLn . renderRow) (chunk boardSize b)
  putStrLn border
  where
    border = "---------"

    -- Cells of one row, separated by single spaces.
    renderRow row = unwords (map renderTile row)

    renderTile 0 = " ."   -- use "." for empty
    renderTile n
      | n < 10    = ' ' : show n   -- left-pad single digits to width 2
      | otherwise = show n

-----------------------------------
-- HELPERS FOR MOVES AND SOLVING --
-----------------------------------

-- | Split a list into consecutive chunks of @n@ elements; the last chunk may
-- be shorter. An empty list yields no chunks.
--
-- A non-positive chunk size yields no chunks as well: taking zero elements at
-- a time would never consume the input and produce an endless list of empty
-- chunks.
chunk :: Int -> [a] -> [[a]]
chunk n xs
  | n <= 0    = []
  | null xs   = []
  | otherwise = piece : chunk n rest
  where
    (piece, rest) = splitAt n xs

-- | Position (starting at 0) of the first occurrence of a value in a list.
--
-- Partial: the value must be present. It is used to locate the empty slot,
-- which every board is expected to contain; a missing value makes 'fromJust'
-- fail.
indexOf :: Eq a => a -> [a] -> Int
indexOf target = fromJust . elemIndex target

-- | Sum, over all tiles, of the horizontal plus vertical distance between the
-- cell a tile occupies and the cell it occupies in the solved layout (tile
-- @t@ belongs at index @t - 1@). The empty slot does not contribute.
--
-- The board is walked once instead of being indexed with @!!@ for every cell;
-- this function runs for every node the heuristic searches visit. Only the
-- first 'numTiles' cells are considered, and a list shorter than that simply
-- contributes the cells it has instead of failing on an out-of-range index.
getManhattanDistanceOfBoard :: Board -> Int
getManhattanDistanceOfBoard b =
  sum
    [ tileDistance index tile
    | (index, tile) <- zip [0 .. numTiles - 1] b
    , tile /= 0
    ]
  where
    tileDistance :: Int -> Int -> Int
    tileDistance index tile = abs (row - goalRow) + abs (col - goalCol)
      where
        (row, col)         = index `divMod` boardSize
        (goalRow, goalCol) = (tile - 1) `divMod` boardSize


-- | Moves that are possible on the given board, always listed in the order
-- 'MoveLeft', 'MoveRight', 'MoveUp', 'MoveDown'. A move is possible when the
-- tile it would slide exists, i.e. when the empty slot is not on the
-- corresponding edge of the board.
validMoves :: Board -> [Move]
validMoves b =
  concat
    [ [MoveLeft  | blankCol < lastIndex]  -- the tile to the right moves left
    , [MoveRight | blankCol > 0]          -- the tile to the left moves right
    , [MoveUp    | blankRow < lastIndex]  -- the tile below moves up
    , [MoveDown  | blankRow > 0]          -- the tile above moves down
    ]
  where
    (blankRow, blankCol) = indexOf 0 b `divMod` boardSize
    lastIndex            = boardSize - 1


-- | Slide the tile at @tileIndex@ into the empty slot: the two cells swap
-- their contents and every other cell stays as it is.
--
-- An index outside the board leaves the board unchanged. Splicing with
-- @take@/@drop@ would otherwise lengthen the list for a negative index and
-- plant a failing element in it.
slideIntoEmpty :: Board -> Int -> Board
slideIntoEmpty b tileIndex
  | tileIndex < 0 || tileIndex >= length b = b
  | otherwise                              = zipWith place [0 ..] b
  where
    emptyIndex = indexOf 0 b
    tileValue  = b !! tileIndex

    -- New content of the cell at position i, which currently holds v.
    place :: Int -> Int -> Int
    place i v
      | i == tileIndex  = 0          -- the tile's old cell becomes empty
      | i == emptyIndex = tileValue  -- the tile lands in the old empty slot
      | otherwise       = v

-- | Apply a move to a board and return the resulting board.
--
-- The move is assumed to be one of those reported by 'validMoves' for this
-- board; nothing is checked here.
applyMove :: Board -> Move -> Board
applyMove b move = slideIntoEmpty b (blankIndex + offset)
  where
    blankIndex = indexOf 0 b

    -- Where the tile that is going to move sits, relative to the empty slot.
    offset = case move of
      MoveLeft  -> 1
      MoveRight -> -1
      MoveUp    -> boardSize
      MoveDown  -> negate boardSize

-----------------------
-- SOLVING FUNCTIONS --
-----------------------

-- | Iterative-deepening depth-first search: run 'iddfsWithDepthLimit' with the
-- limits 0, 1, 2, ... and return the first solution that turns up, as the
-- list of moves in the order they have to be applied.
--
-- Returns 'Nothing' for a board that 'isSolvableBoard' rejects. Without that
-- check the deepening loop has no reason to stop and would run forever on
-- such a board.
iddfs :: Board -> Maybe [Move]
iddfs b
  | isSolvableBoard b = deepen 0
  | otherwise         = Nothing
  where
    deepen :: Int -> Maybe [Move]
    deepen limit =
      case iddfsWithDepthLimit b limit of
        Nothing -> deepen (limit + 1)
        found   -> found

-- | Depth-first search that uses at most @depth@ moves. Moves are tried in the
-- order given by 'validMoves' and the first solution met is returned, as the
-- list of moves in the order they have to be applied; 'Nothing' means no
-- solution lies within the limit.
--
-- A limit of zero or less only checks whether the board is already solved.
-- (Comparing the remaining depth with @== 0@ would let a negative limit count
-- down forever and turn the search into an unbounded one.)
iddfsWithDepthLimit :: Board -> Int -> Maybe [Move]
iddfsWithDepthLimit b depth = dfs b depth []
  where
    -- The path is collected newest move first and reversed on success.
    dfs :: Board -> Int -> [Move] -> Maybe [Move]
    dfs board remaining path
      | isSolved board = Just (reverse path)
      | remaining <= 0 = Nothing
      | otherwise      =
          foldr orElse Nothing
            [ dfs (applyMove board m) (remaining - 1) (m : path)
            | m <- validMoves board
            ]

    -- Keep the first success; later branches are only evaluated on failure.
    orElse :: Maybe a -> Maybe a -> Maybe a
    orElse found@(Just _) _ = found
    orElse Nothing fallback = fallback

-- | Depth-first search bounded by @depthLimit@ moves that also gives up on a
-- branch as soon as the moves already made plus the Manhattan-distance
-- estimate exceed the limit. Moves are tried in the order given by
-- 'validMoves'; the result lists the moves in the order they have to be
-- applied, and 'Nothing' means nothing was found within the limit.
iddfsPruningWithDepthLimit :: Board -> Int -> Maybe [Move]
iddfsPruningWithDepthLimit b depthLimit = explore b 0 []
  where
    -- The trail is collected newest move first and reversed on success.
    explore :: Board -> Int -> [Move] -> Maybe [Move]
    explore current used trail
      | isSolved current = Just (reverse trail)
      | outOfBudget      = Nothing
      | otherwise        = foldr orElse Nothing children
      where
        estimate = getManhattanDistanceOfBoard current

        -- Limit reached, or even the best case from here would be too long.
        outOfBudget = used == depthLimit || used + estimate > depthLimit

        children =
          [ explore (applyMove current m) (used + 1) (m : trail)
          | m <- validMoves current
          ]

    -- Keep the first success; later branches are only evaluated on failure.
    orElse :: Maybe a -> Maybe a -> Maybe a
    orElse found@(Just _) _ = found
    orElse Nothing fallback = fallback

-- | Iterative deepening on top of 'iddfsPruningWithDepthLimit': try the depth
-- limits 0, 1, 2, ... and return the first solution that turns up, as the
-- list of moves in the order they have to be applied.
--
-- Returns 'Nothing' for a board that 'isSolvableBoard' rejects. Without that
-- check the deepening loop has no reason to stop and would run forever on
-- such a board.
iddfspruning :: Board -> Maybe [Move]
iddfspruning b
  | isSolvableBoard b = deepen 0
  | otherwise         = Nothing
  where
    deepen :: Int -> Maybe [Move]
    deepen depthLimit =
      case iddfsPruningWithDepthLimit b depthLimit of
        Nothing -> deepen (depthLimit + 1)
        found   -> found

-- | Alternative name for 'iddfspruning'; the two behave identically.
iddfsWithPruning :: Board -> Maybe [Move]
iddfsWithPruning board = iddfspruning board

-- | IDA*-style search: run 'idaWithDepthLimit' with a growing bound on
-- moves made plus Manhattan-distance estimate, starting at the estimate of
-- the start board, and return the first solution found as the list of moves
-- in the order they have to be applied.
--
-- Returns 'Nothing' for a board that 'isSolvableBoard' rejects. Without that
-- check the loop has no reason to stop and would run forever on such a board.
ida :: Board -> Maybe [Move]
ida b
  | isSolvableBoard b = deepen (getManhattanDistanceOfBoard b)
  | otherwise         = Nothing
  where
    -- The bound covers g + h (moves made plus estimate).
    --
    -- It grows in steps of two: a slide adds one move and changes the
    -- Manhattan distance by exactly one, so g + h either stays the same or
    -- grows by two along any path. A bound that is only one larger therefore
    -- admits exactly the same nodes as the previous one and repeats its
    -- (failed) search.
    deepen :: Int -> Maybe [Move]
    deepen bound =
      case idaWithDepthLimit b bound of
        Nothing -> deepen (bound + 2)
        found   -> found


-- | Depth-first search that abandons a node as soon as the moves made so far
-- plus the Manhattan-distance estimate of its board exceed @depth@. Moves are
-- tried in the order given by 'validMoves'; the result lists the moves in the
-- order they have to be applied, and 'Nothing' means nothing was found within
-- the bound.
idaWithDepthLimit :: Board -> Int -> Maybe [Move]
idaWithDepthLimit b depth = explore b 0 []
  where
    -- The trail is collected newest move first and reversed on success.
    explore :: Board -> Int -> [Move] -> Maybe [Move]
    explore current cost trail
      | cost + estimate > depth = Nothing
      | isSolved current        = Just (reverse trail)
      | otherwise               = foldr orElse Nothing children
      where
        estimate = getManhattanDistanceOfBoard current

        children =
          [ explore (applyMove current m) (cost + 1) (m : trail)
          | m <- validMoves current
          ]

    -- Keep the first success; later branches are only evaluated on failure.
    orElse :: Maybe a -> Maybe a -> Maybe a
    orElse found@(Just _) _ = found
    orElse Nothing fallback = fallback

------------------------
-- PARALLEL FUNCTIONS --
------------------------

-- | All nodes reached by making exactly @d0@ moves from @start@, each paired
-- with the moves that lead to it, in the order they are applied. Nothing is
-- deduplicated: a board reachable along several move sequences (including
-- ones that undo an earlier move) appears once per sequence.
--
-- A depth of zero or less yields just the start node. (Matching the depth
-- against the literal 0 would make a negative depth expand forever.)
getNodesFromRoot :: Board -> Int -> [Node]
getNodesFromRoot start d0 = go [(start, [])] d0
  where
    go :: [Node] -> Int -> [Node]
    go frontier remaining
      | remaining <= 0 = frontier
      | otherwise      = go (concatMap expandNode frontier) (remaining - 1)

  
-- | Children of a search node: one node per move that 'validMoves' allows, in
-- that order, each with the move appended to the end of the node's path.
expandNode :: Node -> [Node]
expandNode (b, path) = map child (validMoves b)
  where
    child m = (applyMove b m, path ++ [m])
      
-- | Solve the board of a node with 'ida' and put the node's own moves in
-- front of the result; 'Nothing' when 'ida' finds no solution.
solveFromNode :: Node -> Maybe [Move]
solveFromNode (b, currentMoves) = fmap (currentMoves ++) (ida b)

-- | Solve a board with 'ida', spreading the work over the subtrees that start
-- a fixed number of moves below the root, and print the shortest solution
-- found (its length, then the moves).
--
-- Two cases are handled before any work is split up:
--
-- * A board that 'isSolvableBoard' rejects is reported as unsolvable; the
--   per-leaf searches would otherwise never finish.
-- * Solutions shorter than the split depth are looked for from the root
--   first. Every leaf already carries a prefix as long as the split depth, so
--   the leaves alone cannot produce anything shorter: an already solved board
--   would be reported as needing two moves.
parallelIDA :: Board -> IO ()
parallelIDA b
  | not (isSolvableBoard b) =
      putStrLn "No solution exists: the board is unsolvable."
  | otherwise =
      case shallowSolutions of
        (sol : _) -> report sol
        []        ->
          case leafSolutions of
            []   -> putStrLn "No solution found from any leaf."
            sols -> report (minimumBy (comparing length) sols)
  where
    -- How many moves are expanded sequentially before the work is split.
    initialDepthSequential :: Int
    initialDepthSequential = 2

    -- Solutions with fewer moves than the split depth, shortest first
    -- (the bounds are tried in increasing order).
    shallowSolutions :: [[Move]]
    shallowSolutions =
      [ s
      | bound <- [0 .. initialDepthSequential - 1]
      , Just s <- [idaWithDepthLimit b bound]
      ]

    allLeaves :: [Node]
    allLeaves = getNodesFromRoot b initialDepthSequential

    -- Every leaf is solved completely, all of them in parallel.
    leafResults :: [Maybe [Move]]
    leafResults = withStrategy (parList rdeepseq) (map solveFromNode allLeaves)

    leafSolutions :: [[Move]]
    leafSolutions = [ s | Just s <- leafResults ] -- get rid of the ones that failed

    report :: [Move] -> IO ()
    report best = do
      putStrLn $ "Found solution in " ++ show (length best) ++ " moves:"
      print best

-- | Iterative deepening over a global depth limit, parallelized over the
-- subtrees that start at a shallow split depth. Prints the length of the
-- solution, the depth limit at which it was found, and the moves.
--
-- Depth limits below the split depth are searched sequentially from the root.
-- From the split depth on, every frontier node is searched in parallel with
-- the remaining depth, and the shortest result of that round is reported.
--
-- A board that 'isSolvableBoard' rejects is reported as unsolvable up front;
-- the deepening loop would otherwise never finish.
parallelIddfs :: Board -> IO ()
parallelIddfs startBoard
  | isSolvableBoard startBoard = iter 0
  | otherwise = putStrLn "No solution exists: the board is unsolvable."
  where
    -- how deep we go sequentially before splitting work
    splitDepth :: Int
    splitDepth = 3

    -- precomputed frontier at depth = splitDepth;
    -- each Node = (boardAtDepthK, movesFromRootToHere)
    frontier :: [Node]
    frontier = getNodesFromRoot startBoard splitDepth

    iter :: Int -> IO ()
    iter depthLimit =
      case solutionsAt depthLimit of
        []   -> iter (depthLimit + 1)
        sols -> report depthLimit (minimumBy (comparing length) sols)

    -- Every solution found for one depth limit (at most one while the limit
    -- is too shallow to bother splitting).
    solutionsAt :: Int -> [[Move]]
    solutionsAt depthLimit
      | depthLimit < splitDepth =
          [ s | Just s <- [iddfsWithDepthLimit startBoard depthLimit] ]
      | otherwise =
          [ s | Just s <- withStrategy (parList rdeepseq) leafResults ]
      where
        remaining = depthLimit - splitDepth

        -- for this depth limit, search each subtree (evaluated in parallel)
        leafResults :: [Maybe [Move]]
        leafResults =
          [ fmap (prefixMoves ++) (iddfsWithDepthLimit board remaining)
          | (board, prefixMoves) <- frontier
          ]

    report :: Int -> [Move] -> IO ()
    report depthLimit best = do
      putStrLn $ "Found solution in " ++ show (length best)
               ++ " moves (depth limit " ++ show depthLimit ++ "):"
      print best

-- | Global iterative deepening over the depth limit with heuristic pruning
-- ('iddfsPruningWithDepthLimit'), parallelized over the subtrees that start
-- at a shallow split depth. Prints the length of the solution, the depth
-- limit at which it was found, whether that round ran sequentially or in
-- parallel, and the moves.
--
-- A board that 'isSolvableBoard' rejects is reported as unsolvable up front;
-- the deepening loop would otherwise never finish.
parallelIddfsPruning :: Board -> IO ()
parallelIddfsPruning startBoard
  | isSolvableBoard startBoard = iter 0
  | otherwise = putStrLn "No solution exists: the board is unsolvable."
  where
    -- how deep to go sequentially before we split work
    splitDepth :: Int
    splitDepth = 6

    -- frontier at depth = splitDepth (boards and prefix move sequences)
    frontier :: [Node]
    frontier = getNodesFromRoot startBoard splitDepth

    iter :: Int -> IO ()
    iter depthLimit
      -- For very shallow limits, just do everything from the root sequentially
      | depthLimit < splitDepth =
          case iddfsPruningWithDepthLimit startBoard depthLimit of
            Nothing      -> iter (depthLimit + 1)
            Just solPath -> report depthLimit "sequential" solPath

      | otherwise =
          case parallelSolutions (depthLimit - splitDepth) of
            []   -> iter (depthLimit + 1)
            sols -> report depthLimit "parallel" (minimumBy (comparing length) sols)

    -- Search every frontier subtree with the remaining depth, in parallel,
    -- and keep the searches that succeeded (root prefix already prepended).
    parallelSolutions :: Int -> [[Move]]
    parallelSolutions remaining =
      [ s | Just s <- withStrategy (parList rdeepseq) leafResults ]
      where
        leafResults :: [Maybe [Move]]
        leafResults =
          [ fmap (prefixMoves ++) (iddfsPruningWithDepthLimit board remaining)
          | (board, prefixMoves) <- frontier
          ]

    -- mode is "sequential" or "parallel" and only affects the printed text.
    report :: Int -> String -> [Move] -> IO ()
    report depthLimit mode path = do
      putStrLn $ "Found solution in "
               ++ show (length path)
               ++ " moves (depth limit "
               ++ show depthLimit
               ++ ", pruned, " ++ mode ++ "):"
      print path