import os
import csv
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np



# ---------------------------------------------------------------------------
# CONSTANTS
# ---------------------------------------------------------------------------

# Algorithm identifiers; matched against the "algo" column of the results CSV.
BRUTE_FORCE = "brute_force"
PRUNING = "pruning"
DSATUR = "DSATUR"
PARALLEL_PRUNING = "parallel_pruning"
PARALLEL_DSATUR = "parallel_DSATUR"

ALL_ALGOS = [
    PRUNING,
    DSATUR,
    PARALLEL_PRUNING,
    PARALLEL_DSATUR,
]

# Benchmark graph files; matched against the "graph" column of the results CSV.
PETERSEN = "petersen.txt"
QUEEN_5 = "queen5_5.txt"
QUEEN_6 = "queen6_6.txt"
MYCIEL_4 = "myciel4.txt"
COMPLETE_10 = "complete_10.txt"

ALL_GRAPHS = [
    PETERSEN,
    QUEEN_5,
    QUEEN_6,
    MYCIEL_4,
    COMPLETE_10,
]

# Reference chromatic number of each graph. Runs are analysed at
# k = chromatic - 1 (expected answer "NO") and k = chromatic (expected "YES").
CHROMATIC_NUMS = {
    PETERSEN: 3,
    QUEEN_5: 5,
    QUEEN_6: 7,
    MYCIEL_4: 5,
    COMPLETE_10: 10,
}

# Parallel runs are looked up for every thread count 1..MAX_NUM_THREADS.
MAX_NUM_THREADS = 8

def parse_line(line):
    """Split one CSV row into its nine benchmark fields, in column order.

    The returned tuple is (algo, graph, n_threads, trial, k, chromatic_number,
    is_k_colorable, real_time_seconds, allocated_bytes). Values are passed
    through unconverted, columns beyond the ninth are ignored, and a row with
    fewer than nine columns raises IndexError.
    """
    return tuple(line[column] for column in range(9))

def load_data(path="benchmark_summary4.csv", timeout_seconds=120):
    """Read benchmark results from a CSV file into a list of typed records.

    Each non-empty row is split with ``parse_line``. The header row (first
    column ``"algo"``) and blank rows are skipped. Every record is a dict with
    keys algo, graph, n_threads, trial, k, chromatic_number, is_k_colorable,
    real_time_seconds and allocated_kbytes (allocated bytes // 1024).

    Every row is sanity-checked. A run with k equal to the chromatic number
    must report "YES" unless it took longer than ``timeout_seconds``. Any other
    k must report "NO".

    Args:
        path: CSV file to read (defaults to the original hard-coded file).
        timeout_seconds: runs slower than this at k == chromatic number are
            exempt from the "YES" check.

    Returns:
        List of record dicts in file order.

    Raises:
        ValueError: if a numeric field cannot be parsed or a row fails the
            sanity check.
    """
    data_list = []
    # The csv module expects files opened with newline="".
    with open(path, "r", newline="") as f:
        for line in csv.reader(f):
            if not line:
                # csv.reader yields [] for blank lines; parse_line would crash on them.
                continue
            (algo, graph, n_threads, trial, k, chromatic_number,
             is_k_colorable, real_time_seconds, allocated_bytes) = parse_line(line)
            if algo == "algo":
                # Header row.
                continue

            parsed_obj = {
                "algo": algo,
                "graph": graph,
                "n_threads": int(n_threads),
                "trial": int(trial),
                "k": int(k),
                "chromatic_number": int(chromatic_number),
                "is_k_colorable": is_k_colorable,
                "real_time_seconds": float(real_time_seconds),
                "allocated_kbytes": int(allocated_bytes) // 1024,
            }

            # Compare the parsed integers, not the raw CSV strings, so that
            # formatting differences (e.g. "05" vs "5") cannot flip the check.
            if parsed_obj["k"] == parsed_obj["chromatic_number"]:
                finished_in_time = parsed_obj["real_time_seconds"] <= timeout_seconds
                if finished_in_time and is_k_colorable != "YES":
                    raise ValueError(
                        f"Expected YES for k == chromatic number in row {line!r}"
                    )
            elif is_k_colorable != "NO":
                raise ValueError(
                    f"Expected NO for k != chromatic number in row {line!r}"
                )
            data_list.append(parsed_obj)

    return data_list

# print(data_list[0])

def select_samples(search_params, data_list):
    """Return the records of ``data_list`` that satisfy every filter in ``search_params``.

    Each key of ``search_params`` names a record field, and the record's value
    for it must equal the filter value. The wildcard ``"*"`` matches anything,
    and in that case the field is not looked up at all. Input order is preserved.
    """
    def _matches(record):
        return all(
            wanted == "*" or record[field] == wanted
            for field, wanted in search_params.items()
        )

    return [record for record in data_list if _matches(record)]
        

def reduce_median(list_of_three, return_all=False):
    """Pick the median-time record out of exactly three trials of one configuration.

    All three records must share algo, graph, n_threads and k, and must have
    pairwise-distinct trial numbers. Violations raise AssertionError.

    Args:
        list_of_three: the three trial records.
        return_all: if true, return all three records sorted by
            real_time_seconds instead of only the middle one.
    """
    count = len(list_of_three)
    assert count == 3, f"Got {count}"

    reference = list_of_three[0]
    for sample in list_of_three:
        for field in ("algo", "graph", "n_threads", "k"):
            assert sample[field] == reference[field]

    t0, t1, t2 = (sample["trial"] for sample in list_of_three)
    assert t0 != t1 and t0 != t2 and t2 != t1

    by_time = sorted(list_of_three, key=lambda sample: sample["real_time_seconds"])
    return by_time if return_all else by_time[1]

def project(project_dims, d):
    """Keep only the fields named in ``project_dims``.

    ``d`` may be a single record (dict) or a list of records. The result is
    always a list of new dicts, one per input record, whose keys follow
    ``project_dims`` order. A missing field raises KeyError.
    """
    records = d if isinstance(d, list) else [d]
    return [{dim: record[dim] for dim in project_dims} for record in records]

def get_data_for_pruning_algo(data_list):
    """Collect the median-time PRUNING run for every benchmark graph.

    For each graph in ALL_GRAPHS, the samples at k = chromatic-1 and
    k = chromatic are reduced to their median-time trial and projected onto
    the reporting fields. The original version computed these rows and then
    discarded them. They are now returned.

    Args:
        data_list: records as returned by ``load_data``.

    Returns:
        A flat list of projected record dicts, two per graph (k = chromatic-1
        first, then k = chromatic), in ALL_GRAPHS order.

    Raises:
        AssertionError: propagated from ``reduce_median`` when a (graph, k)
            selection is not a valid set of three trials.
    """
    fields = ["algo", "graph", "k", "n_threads", "is_k_colorable",
              "real_time_seconds", "allocated_kbytes", "trial"]
    rows = []
    for graph in ALL_GRAPHS:
        chromatic = CHROMATIC_NUMS[graph]
        for k in (chromatic - 1, chromatic):
            samples = select_samples({"algo": PRUNING, "graph": graph, "k": k}, data_list)
            # project() returns a list; for a single record it holds one dict.
            rows.extend(project(fields, reduce_median(samples)))
    return rows

def get_data_for_DSATUR_algo(data_list):
    """Collect the median-time DSATUR run for every benchmark graph.

    For each graph in ALL_GRAPHS, the samples at k = chromatic-1 and
    k = chromatic are reduced to their median-time trial and projected onto
    the reporting fields. The original version assigned these rows to an
    unused local. They are now returned.

    Args:
        data_list: records as returned by ``load_data``.

    Returns:
        A flat list of projected record dicts, two per graph (k = chromatic-1
        first, then k = chromatic), in ALL_GRAPHS order.

    Raises:
        AssertionError: propagated from ``reduce_median`` when a (graph, k)
            selection is not a valid set of three trials.
    """
    fields = ["algo", "graph", "k", "n_threads", "is_k_colorable",
              "real_time_seconds", "allocated_kbytes", "trial"]
    rows = []
    for graph in ALL_GRAPHS:
        chromatic = CHROMATIC_NUMS[graph]
        for k in (chromatic - 1, chromatic):
            samples = select_samples({"algo": DSATUR, "graph": graph, "k": k}, data_list)
            # project() returns a list; for a single record it holds one dict.
            rows.extend(project(fields, reduce_median(samples)))
    return rows

def plot_speedup(all_rows, algo):
    times = {}
    for entry in all_rows:
        if isinstance(entry, (tuple, list)):
            entry = entry[0]
        g = entry['graph']
        k = int(entry['k'])
        n = int(entry['n_threads'])
        t = float(entry['real_time_seconds'])
        times.setdefault(g, {}).setdefault(k, {})[n] = t
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

    # Create a single figure with two subplots side-by-side and plot both k-modes
    fig, axes = plt.subplots(1, 2, figsize=(18, 6))
    for ai, k_mode in enumerate(("below", "chromatic")):
        ax = axes[ai]
        for gi, graph in enumerate(ALL_GRAPHS):
            chrom = CHROMATIC_NUMS[graph]
            k_val = chrom - 1 if k_mode == "below" else chrom

            gdata = times[graph]
            kdata = gdata[k_val]
            base_time = kdata[1]

            x = []
            y = []
            for n in range(1, MAX_NUM_THREADS+1):
                tn = kdata[n]
                x.append(n)
                y.append(base_time / tn if tn > 0 else 0.0)

            color = colors[gi % len(colors)]
            ax.plot(x, y, marker='o', label=graph.replace('.txt',''), color=color)
            s = base_time/kdata[MAX_NUM_THREADS]
            print(f"Amdahl's law for {k_mode}, {graph}, {algo}: p = ", s/(1 + s*(MAX_NUM_THREADS - 1)))

        # plot ideal y = x line for reference
        x = np.linspace(1, MAX_NUM_THREADS, 500)
        y = x
        ax.plot(x, y, color='red', linewidth=2, label='Ideal')

        ax.set_xlabel('Number of threads')
        ax.set_ylabel('Speedup T(1)/T(n)')
        title = f"Speedup (k = chromatic-1) for {algo}" if k_mode == "below" else f"Speedup (k = chromatic) for {algo}"
        ax.set_title(title)
        ax.set_xticks(list(range(1, MAX_NUM_THREADS+1)))
        ax.grid(True, linestyle='--', alpha=0.5)
        ax.legend()

    plt.tight_layout()

    # ensure output directory exists and save combined figure as PNG
    out_dir = "plots"
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, f"speedup_combined_{algo}.png")
    fig.savefig(out_path)
    plt.close(fig)

def get_data_for_parallel_pruning_algo(data_list):
    """Plot thread-count speedup for the parallel pruning algorithm.

    For every graph, at k = chromatic-1 and k = chromatic, and for each thread
    count 1..MAX_NUM_THREADS, the median-time trial is selected and projected.
    The resulting rows are passed to ``plot_speedup``.
    """
    algo = PARALLEL_PRUNING
    fields = ["algo", "graph", "k", "n_threads", "is_k_colorable",
              "real_time_seconds", "allocated_kbytes", "trial"]

    def median_row(graph, k, n_threads):
        # Returns project()'s one-element list for this configuration.
        query = {"algo": algo, "graph": graph, "k": k, "n_threads": n_threads}
        return project(fields, reduce_median(select_samples(query, data_list)))

    all_rows = []
    for graph in ALL_GRAPHS:
        chromatic = CHROMATIC_NUMS[graph]
        for k in (chromatic - 1, chromatic):
            all_rows += [median_row(graph, k, n) for n in range(1, MAX_NUM_THREADS + 1)]

    plot_speedup(all_rows, algo)


def get_data_for_parallel_DSATUR_algo(data_list):
    """Plot thread-count speedup for the parallel DSATUR algorithm.

    For every graph, at k = chromatic-1 and k = chromatic, and for each thread
    count 1..MAX_NUM_THREADS, the median-time trial is selected and projected.
    The resulting rows are passed to ``plot_speedup``.
    """
    algo = PARALLEL_DSATUR
    fields = ["algo", "graph", "k", "n_threads", "is_k_colorable",
              "real_time_seconds", "allocated_kbytes", "trial"]

    def median_row(graph, k, n_threads):
        # Returns project()'s one-element list for this configuration.
        query = {"algo": algo, "graph": graph, "k": k, "n_threads": n_threads}
        return project(fields, reduce_median(select_samples(query, data_list)))

    all_rows = []
    for graph in ALL_GRAPHS:
        chromatic = CHROMATIC_NUMS[graph]
        for k in (chromatic - 1, chromatic):
            all_rows += [median_row(graph, k, n) for n in range(1, MAX_NUM_THREADS + 1)]

    plot_speedup(all_rows, algo)


def main():
    """Load the benchmark CSV and write speedup plots for both parallel algorithms."""
    samples = load_data()
    # Parallel pruning first, then parallel DSATUR.
    for make_plot in (get_data_for_parallel_pruning_algo,
                      get_data_for_parallel_DSATUR_algo):
        make_plot(samples)
    



if __name__ == "__main__":
    # Run the analysis only when executed as a script, not on import.
    main()

    

    