module Solver
(solver, Square)
where

import qualified Data.List as L
import qualified Data.Set as S
import Control.Monad -- guard
import Control.Parallel.Strategies hiding (parMap)

-- | A (possibly partially filled) square grid given as a list of rows.
-- The solver treats the value @-1@ as an empty cell.
type Square = [[Int]]

-- | A fixed, completely filled 4x4 magic square: every row, column and both
-- diagonals sum to 34. Not part of the module's export list.
magic :: Square
magic =
  [ [2, 16, 13, 3]
  , [11, 5, 8, 10]
  , [7, 9, 12, 6]
  , [14, 4, 1, 15]
  ]

-- prettyPrint :: Maybe Square -> IO()
-- prettyPrint (Just s) = putStrLn $ unlines $ L.map (unwords . L.map show) s
-- prettyPrint Nothing = putStrLn ""

-- | Elements along the leading diagonal, beginning at column @i@ of the
-- first row and moving one column to the right per row:
-- @[row0 !! i, row1 !! (i+1), ...]@. Stops when the rows run out.
fDiagonal :: Square -> Int -> [Int]
fDiagonal s i = zipWith (!!) s [i ..]

-- | Elements along the anti-diagonal of a width-@n@ square. The row that is
-- @j@ positions after the first contributes column @n - 1 - (i + j)@, so with
-- @i = 0@ this walks from the top-right corner towards the bottom-left.
aDiagonal :: Square -> Int -> Int -> [Int]
aDiagonal s i n = zipWith (\row k -> row !! (n - 1 - k)) s [i ..]

-- | Partial-consistency check for a square in which @-1@ marks empty cells.
--
-- Inspects every row plus the leading and anti diagonals. Columns are not
-- examined here; they are covered by calling this on the transpose, as
-- 'nextSteps' does. After dropping empty cells:
--
-- * a line with exactly @n@ values must sum to @(1 + n*n) * n `div` 2@;
-- * a line with fewer than @n@ values must have a sum strictly below that.
validator :: Square -> Bool
validator s = all (== target) fullSums && all (< target) partSums
      where size = length s
            target = (1 + (size * size)) * size `div` 2
            known = filter (/= (-1))
            candidates = L.map known s
                      ++ [known (fDiagonal s 0), known (aDiagonal s 0 size)]
            fullSums = [sum l | l <- candidates, length l == size]
            partSums = [sum l | l <- candidates, length l < size]


-- | True when no cell of the square holds the empty marker @-1@.
squareFilled :: Square -> Bool
squareFilled = all (notElem (-1))

-- | Locate the first empty (@-1@) cell, scanning rows in order and each row
-- left to right. The first row is numbered @curRow@; the column is 0-based.
-- Returns @(-1, -1)@ when the square has no empty cell.
findEmptyPos :: Square -> Int -> (Int, Int)
findEmptyPos rows curRow = scan (zip [curRow ..] rows)
      where scan [] = (-1, -1)
            scan ((r, row) : rest) =
              case break (== (-1)) row of
                (before, _ : _) -> (r, length before)
                _               -> scan rest


-- | Return a copy of the square with the cell at row @x@, column @y@
-- (both 0-based) replaced by @val@.
--
-- Coordinates outside the square leave it unchanged instead of crashing.
-- Negative indices are rejected explicitly: 'splitAt' clamps them to 0, so
-- without this guard a sentinel such as @(-1, -1)@ would silently overwrite
-- row 0 / column 0.
updateAtPos :: Square -> Int -> Int -> Int -> Square
updateAtPos s x y val
      | x < 0 || y < 0 = s
      | otherwise = case splitAt x s of
          (prevRows, curRow : nextRows) -> case splitAt y curRow of
            (prevElems, _ : nextElems) ->
              prevRows ++ (prevElems ++ val : nextElems) : nextRows
            _ -> s
          _ -> s


-- | Successor states for cell (@x@, @y@): try each candidate value in list
-- order, and keep a placement only when both the rows and the columns
-- (checked via the transpose) of the resulting square pass 'validator'.
-- Each kept state has the placed value removed from the remaining-value set.
nextSteps :: (Square, S.Set Int) -> Int -> Int -> [Int] -> [(Square, S.Set Int)]
nextSteps state x y candidates =
      [ (placed, S.delete v pool)
      | v <- candidates
      , let placed = updateAtPos grid x y v
      , validator placed && validator (L.transpose placed)
      ]
      where (grid, pool) = state

-- | Map @f@ over the list, evaluating the results in parallel with
-- @parList rdeepseq@ (each result fully forced).
parMapDeep :: NFData b => (a -> b) -> [a] -> [b]
parMapDeep f = (`using` parList rdeepseq) . L.map f

-- | Map @f@ over the list, evaluating the results in parallel with
-- @parList rseq@ (each result only to weak head normal form). Shadows the
-- library 'parMap', which is hidden in the imports.
parMap :: NFData b => (a -> b) -> [a] -> [b]
parMap f = (`using` parList rseq) . L.map f

-- | Cut a list into consecutive chunks of @s@ elements; the last chunk may be
-- shorter. An empty list, or a chunk size of 0 or less, yields @[]@.
splitHelper :: [Int] -> Int -> [[Int]]
splitHelper l s = case splitAt s l of
      ([], _)       -> []
      (chunk, rest) -> chunk : splitHelper rest s

-- | Split a list into chunks of @len `div` round (sqrt len)@ elements,
-- giving roughly @sqrt len@ chunks (the last one may be shorter). Lists
-- with fewer than 4 elements are returned as a single chunk.
splitList :: [Int] -> [[Int]]
splitList l
      | len >= 4  = splitHelper l chunkSize
      | otherwise = [l]
      where len = length l
            roots = round (sqrt (fromIntegral len :: Double)) :: Int
            chunkSize = len `div` roots

-- | Depth-first backtracking search from a partial square.
--
-- @state@ pairs the current square with the set of values not yet placed.
--
-- * If the square is full, the result is a single element: @Just s@ when
--   every row, column and both diagonals pass 'validator', else @Nothing@.
-- * Otherwise the first empty cell (see 'findEmptyPos') is filled with each
--   surviving candidate from 'nextSteps', and the search recurses into each.
--
-- @parLayer@ controls parallelism:
--
-- * above 1, children are evaluated with 'parMap' (@rseq@);
-- * at exactly 1, with 'parMapDeep' (@rdeepseq@);
-- * at 0 or below, sequentially, passing @parLayer@ down unchanged.
solverHelper :: Int -> (Square, S.Set Int) -> [Maybe Square]
solverHelper parLayer state
      | squareFilled s = [guard (isMagic s) >> Just s]
      | parLayer > 1   = concat $ parMap (solverHelper (parLayer - 1)) next
      | parLayer == 1  = concat $ parMapDeep (solverHelper (parLayer - 1)) next
      | otherwise      = concatMap (solverHelper parLayer) next
      where (s, choiceS) = state
            (x, y) = findEmptyPos s 0
            next = nextSteps state x y (S.toList choiceS)
            -- 'nextSteps' checks columns (via the transpose) on every
            -- placement. A square that is already full when it reaches this
            -- function, such as a fully filled input to 'solver', never went
            -- through that check, so validate the columns here too.
            isMagic sq = validator sq && validator (L.transpose sq)

-- | Enumerate completions of a partially filled @n x n@ square, where @-1@
-- marks an empty cell. Empty cells are filled from the values @1 .. n*n@ that
-- do not already appear in the square.
--
-- Each result is either @Just@ a completed square that passed the final
-- check, or @Nothing@ for a full square that failed it. The first @n + 1@
-- levels of the search evaluate their children in parallel.
solver :: Square -> [Maybe Square]
solver s = solverHelper (n + 1) (s, S.difference allState curState)
      where n = length s
            allState = S.fromList [1 .. n * n]
            -- Every value already present. The -1 marker ends up in here too,
            -- which is harmless because it is never a member of allState.
            curState = S.fromList (concat s)

-- parseHelper :: [String] -> [Int]
-- parseHelper (x:xs) = (read x :: Int) : parseHelper xs
-- parseHelper _ = []

-- parseInput :: [String] -> Square
-- parseInput (x:xs) = (parseHelper $ words x) : parseInput xs
-- parseInput _ = []

-- main :: IO ()
-- main = do
--         args <- getArgs
--         (filename) <- case args of
--             [filename] -> return (filename)
--             _ -> do pn <- getProgName
--                     die $ "Usage: " ++ pn ++ " <filename> +RTS -N4 -ls"
--         file <- readFile filename
--         let inputMagic = parseInput $ lines file
--         mapM_ prettyPrint (solver inputMagic)
