{-
COMPILE THE CODE: stack --resolver lts-21.9 ghc -- -O2 -threaded -rtsopts --make -Wall tsp

USAGE: ./tsp [problem_file] +RTS -N8 -ls
[problem_file] format:
 - Each line represents a city
 - Each line takes the form:
    [index] [x-coord] [y-coord]
 - City indices must run from 1 to N (the number of cities), each used exactly once
-}

import System.Environment (getArgs,getProgName)
import System.IO (hPutStrLn,stderr)
import Control.Parallel.Strategies (parList,rpar,using,rseq,parMap)
import System.Exit (exitFailure)
import qualified Data.Map as Map
import System.Random.Shuffle (shuffleM)
import System.Random (randomRIO)
import Data.List (sortBy,maximumBy)
import Data.Function (on)
import Control.Monad (replicateM)


-- | A city parsed from one line of the problem file: (index, x, y).
type City = (Int,Float,Float)
-- | A tour as a list of city indices. Routes built by newRandomRoute repeat
-- the starting city at the end, so a tour over N cities has N + 1 entries.
type Route = [Int]

-- | Swap mutation: draws two random positions (they may coincide) and
-- exchanges the cities stored there.  TUNE MUTATION FUNCTION HERE
mutate :: Route -> IO Route
mutate tour = swapCities <$> pickPosition <*> pickPosition <*> pure tour
  where
    -- A uniformly random position in [0, length tour - 1].
    pickPosition = randomRIO (0, length tour - 1)

-- | Exchanges the elements at 0-based positions @i@ and @j@ of a list.
-- Every other position keeps its element; @i == j@ leaves the list unchanged.
--
-- Runs in O(n). The two swapped values are looked up once each (lazily)
-- instead of re-indexing the list with '!!' for every position.
swapCities :: Int -> Int -> [a] -> [a]
swapCities i j tour = zipWith pick [0 ..] tour
  where
    atI = tour !! i
    atJ = tour !! j
    pick k city
      | k == i    = atJ
      | k == j    = atI
      | otherwise = city

-- | Breeds one child from every (parent1, parent2) pair with
-- 'crossoverSingle'. Children come back in the same order as the pairs.
crossover :: [(Route, Route)] -> IO [Route]
crossover = traverse (uncurry crossoverSingle)
-- | Ordered Crossover (OX). Copies a random contiguous slice of the first
-- parent, then fills in the missing cities in the order they appear in the
-- second parent. The child is mutated with 5% probability and then closed
-- by repeating its first city at the end.
--
-- Parents arrive as closed routes (first city repeated at the end, as built
-- by 'newRandomRoute'), so both are opened before crossing over. With the
-- repeat left in, the second parent's start city could survive the filter
-- twice. The child would then visit that city twice and gain an extra entry.
crossoverSingle :: Route -> Route -> IO Route
crossoverSingle parent1 parent2 = do
    let open1 = openRoute parent1
        open2 = openRoute parent2
    -- Two distinct cut points drawn from [0, length open1], put in order.
    cuts <- shuffleM [0 .. length open1]
    let (start, end) = case cuts of
            (a : rest@(_ : _)) -> let b = last rest in (min a b, max a b)
            _                  -> (0, 0)
        sliceP1     = take (end - start) (drop start open1)
        remainderP2 = filter (`notElem` sliceP1) open2
        offspring   = sliceP1 ++ remainderP2
    mutationProb <- randomRIO (1, 100) :: IO Int
    finalOffspring <- if mutationProb <= 5       -- TUNE MUTATION PROBABILITY HERE
                      then mutate offspring
                      else return offspring
    return (closeRoute finalOffspring)
  where
    -- Strips the trailing repeat of the first city; other lists pass through.
    openRoute route = case route of
        (c : rest@(_ : _)) | last rest == c -> init route
        _                                   -> route
    -- Repeats the first city at the end; an empty route stays empty.
    closeRoute route = case route of
        []      -> []
        (c : _) -> route ++ [c]

-- | Total distance of a tour: the sum of the distances between consecutive
-- cities. Any pair missing from the distance table contributes 0.
tourLength :: Map.Map (Int, Int) Float -> Route -> Float
tourLength distances tour = sum (zipWith leg tour (drop 1 tour))
  where
    leg from to = Map.findWithDefault 0 (from, to) distances

-- | Fitness of a tour: the reciprocal of its total distance, so shorter
-- tours score higher.
tourFitness :: Map.Map (Int, Int) Float -> Route -> Float
tourFitness distances = (1 /) . tourLength distances  -- TUNE FITNESS METRIC HERE

-- | Builds a random closed route over cities @1..numCities@. Each city is
-- visited exactly once in random order, then the first city is repeated at
-- the end to return to the start.
--
-- Yields @[]@ when @numCities <= 0@ instead of failing on 'head' of an empty
-- list.
newRandomRoute :: Int -> IO Route
newRandomRoute numCities = closeRoute <$> shuffleM [1 .. numCities]
  where
    closeRoute cities = case cities of
        []      -> []
        (c : _) -> cities ++ [c]

-- | Creates @populationSize@ independent random routes over @numCities@
-- cities. A non-positive @populationSize@ gives an empty population.
generateInitPop :: Int -> Int -> IO [Route]
generateInitPop numCities populationSize =
    replicateM populationSize (newRandomRoute numCities)

-- | Parses one line of the problem file, @[index] [x-coord] [y-coord]@, into
-- a 'City'.
--
-- Calls 'error', quoting the offending line, if the line does not hold
-- exactly three whitespace-separated fields that read as an Int and two
-- Floats. Previously a bad number failed with an uninformative
-- "Prelude.read: no parse".
parseCity :: String -> City
parseCity line = case words line of
    [index, x, y]
        | Just i  <- readField index
        , Just xc <- readField x
        , Just yc <- readField y -> (i, xc, yc)
    _ -> error ("Invalid problem file format: " ++ show line)
  where
    -- Total replacement for 'read': succeeds only if the whole field parses.
    readField :: Read a => String -> Maybe a
    readField field = case reads field of
        [(value, "")] -> Just value
        _             -> Nothing

-- | Builds a table of the Euclidean distance between every ordered pair of
-- distinct city indices. Both (a, b) and (b, a) are stored.
calculateDistances :: [City] -> Map.Map (Int,Int) Float
calculateDistances cities = Map.fromList (concatMap edgesFrom cities)
  where
    edgesFrom (from, fx, fy) =
        [ ((from, to), distance (fx, fy) (tx, ty))
        | (to, tx, ty) <- cities
        , to /= from
        ]

-- | Straight-line (Euclidean) distance between two points.
distance :: (Float,Float) -> (Float,Float) -> Float
distance (x1,y1) (x2,y2) = sqrt (dx * dx + dy * dy)
  where
    dx = x2 - x1
    dy = y2 - y1

-- | Picks one route uniformly at random from the list. The list must be
-- non-empty; indexing an empty list fails.
selectParent :: [Route] -> IO Route
selectParent elites = (elites !!) <$> randomRIO (0, length elites - 1)

-- | Splits a list into consecutive chunks of @chunkSize@ elements. The last
-- chunk holds whatever is left over and may be shorter.
--
-- An empty list always gives @[]@. A non-positive @chunkSize@ on a non-empty
-- list is rejected with an error: 'splitAt' would never consume any input,
-- and the result would be an infinite list of empty chunks.
makeChunks :: Int -> [a] -> [[a]]
makeChunks _ [] = []
makeChunks chunkSize lst
    | chunkSize <= 0 = error "makeChunks: chunk size must be positive"
    | otherwise      = chunk : makeChunks chunkSize rest
  where
    (chunk, rest) = splitAt chunkSize lst

-- | Produces the next generation: 50 children bred from randomly chosen
-- elite parents. The result is the elites, then the children, then the
-- remaining population.
--
-- NOTE(review): the `parList rseq` strategy below is applied to a list of IO
-- actions. It only evaluates each action to weak head normal form; the
-- crossovers themselves are executed one after another by `sequence`.
generateNextGeneration :: [Route] -> [Route] -> IO [Route]
generateNextGeneration remainingPopulation elites = do
    let pickPair = (,) <$> selectParent elites <*> selectParent elites
    parentPairs <- replicateM 50 pickPair    -- TUNE NUMBER OF CROSSOVERS PER GENERATION HERE
    let chunks   = makeChunks 1 parentPairs  -- TUNE NUMBER OF PAIRS PER CHUNK HERE
        breeders = map crossover chunks `using` parList rseq
    children <- concat <$> sequence breeders
    return (elites ++ children ++ remainingPopulation)

-- | Runs @gen@ generations of the genetic algorithm over @population@ and
-- returns the final population. A non-positive generation count or an empty
-- population is returned unchanged.
--
-- Each generation does the following:
--
-- * Ranks routes by fitness, highest first.
-- * Takes the top 10% (at least one route) as elites, which parent the
--   offspring.
-- * Trims the new generation back to the incoming population size. It is
--   laid out as elites, then offspring, then the remaining routes ordered
--   best-first, so trimming discards the weakest survivors first.
evolve :: Int -> [Route] -> Map.Map (Int, Int) Float -> IO [Route]
evolve gen population distances
    | gen <= 0 || null population = return population
    | otherwise = do
        let popSize   = length population
            fitnesses = parMap rpar (tourFitness distances) population
            -- `flip` puts the fittest (shortest) routes first. Plain
            -- ascending order would make the *worst* routes the elites.
            sortedPopulation =
                map snd $ sortBy (flip (compare `on` fst)) $ zip fitnesses population
            -- At least one elite, so parent selection never draws from [].
            eliteCount = max 1 (popSize `div` 10)
            (elites, remainingPopulation) = splitAt eliteCount sortedPopulation
        newGeneration <- generateNextGeneration remainingPopulation elites
        -- Keep the population size constant instead of growing every generation.
        evolve (gen - 1) (take popSize newGeneration) distances

-- | Entry point. Reads the problem file named on the command line, evolves
-- a population of routes, and prints the best route found together with
-- its fitness and total distance.
main :: IO ()
main = do
    args <- getArgs
    case args of
        [filename] -> do
            content <- readFile filename
            -- Blank lines (e.g. trailing ones) do not describe cities.
            let cities  = map parseCity (filter (not . null . words) (lines content))
                indices = sortBy compare [i | (i, _, _) <- cities]
            if null cities
                then failWith ("no cities found in " ++ filename)
                -- Routes are built over indices 1..N and missing distance
                -- pairs count as 0, so any other numbering would silently
                -- produce wrong tour lengths.
                else if indices /= [1 .. length cities]
                    then failWith ("city indices in " ++ filename
                                   ++ " must be 1.." ++ show (length cities)
                                   ++ " with each used exactly once")
                    else solve cities
        _ -> do
            pn <- getProgName
            hPutStrLn stderr $ "Usage: " ++ pn ++ " <tsp_filename>"
            exitFailure
  where
    -- Reports a fatal input problem on stderr and exits with failure status.
    failWith :: String -> IO ()
    failWith msg = hPutStrLn stderr ("Error: " ++ msg) >> exitFailure

    -- Evolves a population for the validated cities and prints the result.
    solve :: [City] -> IO ()
    solve cities = do
        let distances = calculateDistances cities
            numCities = length cities
        initialPopulation <- generateInitPop numCities numCities  -- TUNE INITIAL POPULATION SIZE HERE
        let numGenerations = 100  -- TUNE NUMBER OF GENERATIONS HERE
        finalPopulation <- evolve numGenerations initialPopulation distances
        let fitnesses = parMap rpar (tourFitness distances) finalPopulation
            (bestFitness, bestTour) =
                maximumBy (compare `on` fst) (zip fitnesses finalPopulation)
            bestDistance = tourLength distances bestTour
        putStrLn $ "Best Tour: \n" ++ show bestTour ++ "\n"
        putStrLn $ "Fitness: " ++ show bestFitness
        putStrLn $ "Distance: " ++ show bestDistance