#include <algorithm>
#include <cctype>
#include <cmath>
#include <fstream>
#include <functional>
#include <initializer_list>
#include <iostream>
#include <limits>
#include <locale>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "circuit.h"

// reference: https://web.stanford.edu/class/ee133/handouts/general/spice_ref.pdf

// ---------------------------------------------------------------------------
// Global simulator state
// ---------------------------------------------------------------------------

// Matrix-sized buffers (MAT_SIZE comes from circuit.h). Only `v` is read in
// this file, when results are written; G, Ivec and v_prev are presumably used
// by the stamping/solve code behind circuit.h (NOTE(review): confirm).
float G[MAT_SIZE][MAT_SIZE]{};
float Ivec[MAT_SIZE]{};
float v[MAT_SIZE]{};
float v_prev[MAT_SIZE]{};

// Current simulation time; main() sets it once per step before update_all().
float t{0.0f};

// Deferred element constructors. main() invokes each one with the next free
// matrix index after all nodes have been counted.
std::vector<std::function<void(int)>> component_adders{};

// Diode saturation current (IS), keyed by .MODEL name.
std::unordered_map<std::string, float> model_is{};

// MOSFET .MODEL parameters keyed by model name:
// (KP, VTO, LAMBDA, "NMOS" or "PMOS").
std::unordered_map<std::string, std::tuple<float, float, float, std::string>> model_mosfet{};

// Every 0-based node index referenced by the netlist; SPICE node 0 appears as -1.
std::unordered_set<int> nodes{};

/**
 * Convert a SPICE numeric literal such as "10", "4.7u", "1.5k", "2MEG" or
 * "1e-3" into a float.
 *
 * Scale suffixes are case-insensitive and follow the SPICE reference above:
 *   T=1e12  G=1e9  MEG=1e6 (X=1e6 also accepted)  K=1e3
 *   M=1e-3  MIL=25.4e-6  U=1e-6  N=1e-9  P=1e-12  F=1e-15
 * An 'e'/'E' introduces a decimal exponent ("1.5e3"). Scanning stops at the
 * first suffix or exponent marker, so trailing unit letters ("10uF") are
 * ignored. A string with no recognised suffix is handed to std::stof as-is.
 *
 * @param input numeric token, optionally followed by a suffix or exponent.
 * @return the scaled value.
 * @throws std::invalid_argument / std::out_of_range (from std::stof or
 *         std::stoi) when the mantissa or exponent is not a valid number.
 */
float convertSPICENumToFloat(std::string input) {
    // Single-letter suffix -> power of ten. "MEG" and "MIL" share their first
    // letter with milli and are matched explicitly before this table.
    static const std::unordered_map<char, int> suffixToExponent = {
        {'t', 12},
        {'g', 9},
        {'x', 6},
        {'k', 3},
        {'m', -3},
        {'u', -6},
        {'n', -9},
        {'p', -12},
        {'f', -15}
    };

    // Case-insensitive test whether the lower-case `word` starts at `pos`.
    auto matchesAt = [&input](std::size_t pos, const char* word) {
        for (std::size_t k = 0; word[k] != '\0'; k++) {
            if (pos + k >= input.size() ||
                std::tolower(static_cast<unsigned char>(input[pos + k])) != word[k]) {
                return false;
            }
        }
        return true;
    };

    for (std::size_t i = 0; i < input.size(); i++) {
        const char curr = static_cast<char>(std::tolower(static_cast<unsigned char>(input[i])));

        double scale;
        if (curr == 'm' && matchesAt(i, "meg")) {
            scale = 1e6;
        } else if (curr == 'm' && matchesAt(i, "mil")) {
            scale = 25.4e-6;
        } else if (suffixToExponent.count(curr) > 0) {
            scale = std::pow(10.0, suffixToExponent.at(curr));
        } else if (curr == 'e') {
            // Explicit exponent; a bare trailing 'e' means exponent 0.
            int exponent = 0;
            if (i + 1 < input.size()) {
                exponent = std::stoi(input.substr(i + 1));
            }
            scale = std::pow(10.0, exponent);
        } else {
            continue;
        }

        // The mantissa must stay floating point: storing it in an int used to
        // turn "1.5k" into 1000 and "4.7u" into 4e-6.
        const float mantissa = std::stof(input.substr(0, i));
        return static_cast<float>(mantissa * scale);
    }

    return std::stof(input);
}

void parseLine(const std::string& line) {
    std::string new_line = (std::string)line;
    std::replace(new_line.begin(), new_line.end(), '(', ' '); 
    std::replace(new_line.begin(), new_line.end(), ')', ' '); 
    std::transform(new_line.begin(), new_line.end(), new_line.begin(), ::toupper);

    std::stringstream ss(new_line);
    std::string token;

    std::vector<std::string> tokens;
    while (ss >> token) {
        tokens.push_back(token);
    }

    std::unique_ptr<Component> componentPtr = nullptr;

    if (tokens[0] == ".MODEL") {
        // model: .MODEL MODName D (IS= N= Rs= CJO= Tt= BV= IBV=)
        std::string model_name = tokens[1];
        if (tokens[2] == "D") { // Diode model
            for (size_t i = 2; i < tokens.size(); i++) {
                if (tokens[i].find("IS=") == 0) {
                    model_is[model_name] = convertSPICENumToFloat(tokens[i].substr(3));
                    break;
                }
            }
        } else if (tokens[2] == "NMOS" || tokens[2] == "PMOS") { // Nmos model
            std::get<0>(model_mosfet[model_name]) = 0.00002;
            std::get<1>(model_mosfet[model_name]) = 0;
            std::get<2>(model_mosfet[model_name]) = 0;

            std::get<3>(model_mosfet[model_name]) = tokens[2];

            for (size_t i = 2; i < tokens.size(); i++) {
                if (tokens[i].find("KP=") == 0 || tokens[i].find("KN=") == 0) {
                    std::get<0>(model_mosfet[model_name]) = convertSPICENumToFloat(tokens[i].substr(3));
                }
                else if (tokens[i].find("VTO=") == 0) {
                    std::get<1>(model_mosfet[model_name]) = convertSPICENumToFloat(tokens[i].substr(4));
                }
                else if (tokens[i].find("LAMBDA=") == 0) {
                    std::get<2>(model_mosfet[model_name]) = convertSPICENumToFloat(tokens[i].substr(7));
                }
            }
        } else {
            std::cout << "Unknown model type: " << tokens[2] << std::endl;
        }
       
        return;
    }

    if (tokens[0][0] == 'V' || tokens[0][0] == 'I') {
        // voltage source: Vname N+ N- DCValue
        // first case is DC
        SpiceFuncType source_type;
        std::string new_line = (std::string)line;
        std::replace(new_line.begin(), new_line.end(), '(', ' '); 
        std::replace(new_line.begin(), new_line.end(), ')', ' '); 

        std::stringstream ss(new_line);
        std::string token;
        bool is_vsrc = tokens[0][0] == 'V';

        std::string source_name;
        if (!(ss >> source_name) || source_name.empty()) {
            throw std::runtime_error("Source name is empty or invalid.");
        }

        std::string node1, node2;
        if (!(ss >> node1 >> node2)) {
            throw std::runtime_error("Could not read source nodes for " + source_name);
        }

        if (!(ss >> token)) {
            throw std::runtime_error("Could not read transient function or DC value for " + source_name);
        }

        std::string upper_token = token;
        std::transform(upper_token.begin(), upper_token.end(), upper_token.begin(), ::toupper);

        // Check for explicit DC keyword
        if (upper_token == "DC") {
            source_type = SPICE_FUNC_DC;
            if (!(ss >> token)) {
                throw std::runtime_error("Expected DC value after 'DC' keyword for " + source_name);
            }
            // The next token *is* the DC value
            try {
                auto dc_source_ptr = std::make_unique<TransientSource>();
                dc_source_ptr->type = SPICE_FUNC_DC;
                dc_source_ptr->params.dc_value = convertSPICENumToFloat(token);
                int n1 = std::stoi(node1) - 1;
                int n2 = std::stoi(node2) - 1;
                nodes.insert(n1);   
                nodes.insert(n2);

                if (!is_vsrc) {
                    add_isrc(n1, n2, dc_source_ptr.release());
                } else {
                    component_adders.push_back([n1, n2, src = dc_source_ptr.release()](int ni) mutable {
                        add_vsrc(n1, n2, ni, src);
                    });
                }
                nnodes = std::max({nnodes, n1, n2});
            } catch (const std::exception& e) {
                throw std::runtime_error("Expected DC value after 'DC' keyword for " + source_name + ": " + e.what());
            }
        }
        // Check for transient function keywords
        else if (upper_token == "PULSE") source_type = SPICE_FUNC_PULSE;
        else if (upper_token == "SIN") source_type = SPICE_FUNC_SIN;
        else {
            // Not a keyword, assume it's an implicit DC value
            source_type = SPICE_FUNC_DC;
            try {
                float value = convertSPICENumToFloat(token);
                int n1 = std::stoi(node1) - 1;
                int n2 = std::stoi(node2) - 1;

                nodes.insert(n1);   
                nodes.insert(n2);

                auto dc_source = std::make_unique<TransientSource>();
                dc_source->type = SPICE_FUNC_DC;
                dc_source->params.dc_value = value;
                if (!is_vsrc) {
                    add_isrc(n1, n2, dc_source.release());
                } else {
                    component_adders.push_back([n1, n2, src = dc_source.release()](int ni) mutable {
                        add_vsrc(n1, n2, ni, src);
                    });
                }
                nnodes = std::max({nnodes, n1, n2});
            } catch (const std::exception& e) {
                throw std::runtime_error("Expected function keyword or DC value, got '" + token + "' for " + source_name + ": " + e.what());    
            }
        }

        if (source_type != SPICE_FUNC_DC) {
            // Get the rest of the line containing parameters
            std::string params_str;
            std::getline(ss, params_str);

            size_t first = params_str.find_first_not_of(" \t(");
            size_t last = params_str.find_last_not_of(" \t)");
            if (first == std::string::npos || last == std::string::npos) {
                params_str = "";
            } else {
                params_str = params_str.substr(first, (last - first + 1));
            }


            std::stringstream param_ss(params_str);
            std::string param_token;
            std::vector<float> values;

            try {
                // Read all parameter tokens into a temporary vector
                while (param_ss >> param_token) {
                    values.push_back(convertSPICENumToFloat(param_token));
                }
                auto source_ptr = std::make_unique<TransientSource>();
                // Assign parameters based on type and expected count
                switch (source_type) {
                    case SPICE_FUNC_DC:
                        // handled above
                        break;
                    case SPICE_FUNC_PULSE:
                        if (values.size() == 7) {
                            source_ptr->params.pulse = SpicePulseParams{values[0], values[1], values[2], values[3], values[4], values[5], values[6]};
                            source_ptr->type = SPICE_FUNC_PULSE;
                            int n1 = std::stoi(node1) - 1;
                            int n2 = std::stoi(node2) - 1;

                            nodes.insert(n1);   
                            nodes.insert(n2);

                            if (!is_vsrc) {
                                add_isrc(n1, n2, source_ptr.release());
                            } else {
                                component_adders.push_back([n1, n2,src = source_ptr.release()](int ni) mutable {
                                    add_vsrc(n1, n2, ni, src);
                                });
                            }
                            nnodes = std::max({nnodes, n1, n2});
                        } else { throw std::runtime_error("Expected 7 parameters for Pulse, got " + std::to_string(values.size())); }
                        break;
                    case SPICE_FUNC_SIN:
                        if (values.size() == 6) {
                            source_ptr->params.sin = SpiceSinParams{values[0], values[1], values[2], values[3], values[4], values[5]};
                        } else if (values.size() == 5) { // Handle optional td
                            source_ptr->params.sin = SpiceSinParams{values[0], values[1], values[2], values[3], values[4], 0.0};
                        } else if (values.size() == 4) { // Handle optional td, a
                            source_ptr->params.sin = SpiceSinParams{values[0], values[1], values[2], values[3], 0.0, 0.0};
                        } else if (values.size() == 3) {// Handle optional td, a, phase
                            source_ptr->params.sin = SpiceSinParams{values[0], values[1], values[2], 0.0, 0.0, 0.0};
                        } 
                        else { throw std::runtime_error("Expected 4, 5, or 6 parameters for Sin, got " + std::to_string(values.size())); }
                        source_ptr->type = SPICE_FUNC_SIN;
                        int n1 = std::stoi(node1) - 1;
                        int n2 = std::stoi(node2) - 1;

                        nodes.insert(n1);   
                        nodes.insert(n2);
                        
                        if (!is_vsrc) {
                            add_isrc(n1, n2, source_ptr.release());
                        } else {
                            component_adders.push_back([n1, n2, src = source_ptr.release()](int ni) mutable {
                                add_vsrc(n1, n2, ni, src);
                            });
                        }
                        nnodes = std::max({nnodes, n1, n2});
                        break;
                }
            } catch (const std::exception& e) {
                throw std::runtime_error("Error parsing parameters for " + source_name + " (" + upper_token + "): " + e.what());
            }
        }
        
    } else if (tokens[0][0] == 'R') {
        // resistor: Rname N+ N- Value
        if (tokens.size() != 4) {
            std::cout << "Error: " << line << std::endl;
            return;
        }
        int n1 = std::stoi(tokens[1]) - 1;
        int n2 = std::stoi(tokens[2]) - 1;

        nodes.insert(n1);   
        nodes.insert(n2);

        add_res(n1, n2, convertSPICENumToFloat(tokens[3]));
        nnodes = std::max({nnodes, n1, n2});
    } else if (tokens[0][0] == 'C') {
        // capacitor: Cname N+ N- Value <IC=Initial Condition>

        if (tokens.size() != 5 && tokens.size() != 4) {
            std::cout << "Error: " << line << std::endl;
            return;
        }
        float v0 = 0.0f;

        if (tokens.size() == 5) {
            v0 = convertSPICENumToFloat(tokens[4]);
        }

        int n1 = std::stoi(tokens[1]) - 1;
        int n2 = std::stoi(tokens[2]) - 1;

        nodes.insert(n1);   
        nodes.insert(n2);

        nnodes = std::max({nnodes, n1, n2});
        add_cap(n1, n2, convertSPICENumToFloat(tokens[3]), 5e-6f, v0);
    } else if (tokens[0][0] == 'L') {
        // inductor: Lname N+ N- Value <IC=Initial Condition>
        if (tokens.size() != 5 && tokens.size() != 4) {
            std::cout << "Error: " << line << std::endl;
            return;
        }

        float i0 = 0.0f;

        if (tokens.size() == 5) {
            i0 = convertSPICENumToFloat(tokens[4]);
        }

        int n1 = std::stoi(tokens[1]) - 1;
        int n2 = std::stoi(tokens[2]) - 1;
        nnodes = std::max({nnodes, n1, n2});

        add_ind(n1, n2, convertSPICENumToFloat(tokens[3]), 5e-6f, i0);
    } else if (tokens[0][0] == 'E' || tokens[0][0] == 'G') {
        // voltage controlled voltage source: Ename N+ N- NC+ NC- Gain
        // voltage controlled current source: Gname N+ N- NC+ NC- Gain
        if (tokens.size() != 6) {
            std::cout << "Dependent source has incorrect number of arguments: " << line << std::endl;
            return;
        }
        std::string source_name = tokens[0];
        try {
            int n1 = std::stoi(tokens[1]) - 1;
            int n2 = std::stoi(tokens[2]) - 1;
            int n3 = std::stoi(tokens[3]) - 1;
            int n4 = std::stoi(tokens[4]) - 1;
            float gain = convertSPICENumToFloat(tokens[5]);

            nodes.insert(n1);   
            nodes.insert(n2);
            nodes.insert(n3);
            nodes.insert(n4);
            
            nnodes = std::max({nnodes, n1, n2, n3, n4});
            if (tokens[0][0] == 'G') {
                // Voltage controlled current source
                add_vccs(n1, n2, n3, n4, gain);
            } else {
                // Voltage controlled voltage source
                component_adders.push_back([n1, n2, n3, n4, gain](int ni) {
                    add_vcvs(n1, n2, n3, n4, ni, gain);
                });
            }
        } catch (const std::exception& e) {
            throw std::runtime_error("Unexpected function keyword for dependent source, got '" + token + "' for " + source_name + ": " + e.what());    
        }
    } 
    else if (tokens[0][0] == 'D') {
        // diode: Dname N+ N- ModelName <IS=ISValue> <N=NValue> <VJ=VJValue> <XTI=XTIValue> <BV=BvValue> <IBV=IBVValue>
        if (tokens.size() != 4) {
            std::cout << "Incorrect number of arguments for Diode: " << line << std::endl;
            return;
        }
        std::string diode_name = tokens[0];
        try {
            int n1 = std::stoi(tokens[1]) - 1;
            int n2 = std::stoi(tokens[2]) - 1;
            std::string model_name = tokens[3];
            float is = 1e-15f;
            if (model_is.find(model_name) != model_is.end()) {
                is = model_is[model_name];
            }
            
            nodes.insert(n1);   
            nodes.insert(n2);
            nnodes = std::max({nnodes, n1, n2});

            add_diode(n1, n2, is, VT_DEFAULT);
        } catch (const std::exception& e) {
            throw std::runtime_error("Unexpected arguments for Diode " + diode_name + ": " + e.what());
        }
    } 
    else if (tokens[0][0] == 'M') {
        // MOSFET: Mname D G S B ModelName <L=LValue> <W=WValue> <AD=ADValue> <AS=ASValue> <PS=PSValue> <PD=PDValue>
        if (tokens.size() < 6) {
            std::cout << "Insufficient arguments for MOSFET: " << line << std::endl;
            return;
        }
        std::string mosfet_name = tokens[0];
        try {
            int n1 = std::stoi(tokens[1]) - 1; // ND
            int n2 = std::stoi(tokens[2]) - 1; // NG
            int n3 = std::stoi(tokens[3]) - 1; // NS
            std::string model_name = tokens[4];
            float l;
            float w;
            for (size_t i = 5; i < tokens.size(); i++) {
                if (tokens[i].find("L=") == 0) {
                    l = convertSPICENumToFloat(tokens[i].substr(2));
                } else if (tokens[i].find("W=") == 0) {
                    w = convertSPICENumToFloat(tokens[i].substr(2));
                }
            }

            if (l == 0 || w == 0) {
                std::cout << "Error: L and W must be specified for MOSFET " << mosfet_name << std::endl;
                return;
            }
            
            float kp = std::get<0>(model_mosfet[model_name]);
            float vto = std::get<1>(model_mosfet[model_name]);
            if (vto < 0) {
                vto *= -1;
            }
            float lambda = std::get<2>(model_mosfet[model_name]);

            nodes.insert(n1);   
            nodes.insert(n2);
            nodes.insert(n3);
            nnodes = std::max({nnodes, n1, n2, n3});

            if (std::get<3>(model_mosfet[model_name]) == "NMOS") {
                add_nmos(n2, n1, n3, kp * w / l, vto, lambda);
            } else if (std::get<3>(model_mosfet[model_name]) == "PMOS") {
                add_pmos(n2, n1, n3, kp * w / l, vto, lambda);
            } else {
                std::cout << "Unknown MOSFET model: " << model_name << std::endl;
                return;
            }
        } catch (const std::exception& e) {
            throw std::runtime_error("Unexpected arguments for MOSFET " + mosfet_name + ": " + e.what());
        }
    } else {
        // something went wrong!
        std::cout << "Not implemented yet or error: " << line << std::endl;
        return;
    }

    //TODO: analysis information

}

/**
 * Run a transient simulation of a SPICE netlist and write the results as CSV.
 *
 * Usage: <prog> <input_file> <output_file> [node ...]
 *
 * The first line of <input_file> is the title and is skipped. Element cards
 * go to parseLine(). ".TRAN tstep tstop [tstart]" and ".END" are handled here,
 * case-insensitively.
 *
 * Optional trailing arguments select which 1-based indices to print. By
 * default every index up to the final matrix size is printed, including the
 * extra indices assigned to deferred sources.
 *
 * The header row is "time,nodeK,...". Each following row is "t,v[...]...".
 * Steps with t < tstart are simulated but not written.
 *
 * @return 0 on success, 1 on a usage, I/O or netlist error.
 */
int main(int argc, char **argv) {
    if (argc < 3) {
        std::cerr << "Usage: " << argv[0] << " <input_file> <output_file>" << std::endl;
        return 1;
    }

    const std::string input_file = argv[1];
    std::ifstream file(input_file);
    if (!file.is_open()) {
        std::cerr << "File cannot be opened: " << input_file << std::endl;
        return 1;
    }

    const std::string output_file = argv[2];
    std::ofstream output(output_file);
    if (!output.is_open()) {
        std::cerr << "File cannot be opened: " << output_file << std::endl;
        return 1;
    }

    float time_step = -1;
    int nsteps = -1;
    float start_time = 0;
    bool tran_seen = false;

    std::string line;
    // The first line of a SPICE deck is its title and is never parsed.
    std::getline(file, line);

    try {
        while (std::getline(file, line)) {
            if (line.empty() || line[0] == '*') {
                continue;  // blank line or comment card
            }

            std::stringstream ss(line);
            std::vector<std::string> tokens;
            for (std::string token; ss >> token;) {
                tokens.push_back(token);
            }

            // Whitespace-only lines (e.g. a lone '\r' in CRLF files) have no tokens.
            if (tokens.empty()) {
                continue;
            }

            // SPICE control cards are case-insensitive.
            std::string keyword = tokens[0];
            std::transform(keyword.begin(), keyword.end(), keyword.begin(),
                           [](unsigned char c) { return static_cast<char>(std::toupper(c)); });

            if (keyword == ".END") {
                break;
            }

            if (keyword == ".TRAN") {
                // .TRAN tstep tstop [tstart]
                if (tran_seen) {
                    std::cerr << "Error: Conflicting .TRAN statements are specified." << std::endl;
                    return 1;
                }

                if (tokens.size() < 3 || tokens.size() > 4) {
                    std::cerr << "Error: .TRAN has incorrect number of arguments." << std::endl;
                    return 1;
                }

                time_step = convertSPICENumToFloat(tokens[1]);
                const float stop_time = convertSPICENumToFloat(tokens[2]);
                if (tokens.size() == 4) {
                    start_time = convertSPICENumToFloat(tokens[3]);
                }

                // Written as negated comparisons so NaN is rejected as well.
                if (!(time_step > 0) || !(stop_time >= 0)) {
                    std::cerr << "Error: .TRAN requires a positive time step and a non-negative stop time." << std::endl;
                    return 1;
                }

                // The small relative tolerance keeps float rounding in the
                // parsed values (a ratio like 999.99999 instead of 1000) from
                // dropping the final step.
                const double ratio =
                    static_cast<double>(stop_time) / static_cast<double>(time_step) * (1.0 + 1e-6);
                if (ratio >= static_cast<double>(std::numeric_limits<int>::max())) {
                    std::cerr << "Error: .TRAN requests too many time steps." << std::endl;
                    return 1;
                }
                nsteps = static_cast<int>(ratio);
                tran_seen = true;
                continue;
            }

            parseLine(line);
        }
    } catch (const std::exception& e) {
        std::cerr << "Error: " << e.what() << std::endl;
        return 1;
    }

    if (!tran_seen || nsteps < 0) {
        std::cerr << "Error: .TRAN not found or invalid." << std::endl;
        return 1;
    }

    // Turn the highest 0-based node index into a node count.
    nnodes++;

    // `nodes` also contains -1 (SPICE node 0), hence the +1. A mismatch
    // presumably means the netlist's node numbers have gaps.
    if (nodes.size() != static_cast<std::size_t>(nnodes + 1)) {
        std::cerr << "Error: Number of nodes does not match the number of unique nodes found." << std::endl;
        return 1;
    }

    // Deferred elements (voltage sources, VCVS) each receive the next free matrix index.
    for (const auto& add_component : component_adders) {
        add_component(nnodes++);
    }

    std::vector<int> printed_nodes;
    for (int i = 3; i < argc; i++) {
        try {
            const int node = std::stoi(argv[i]) - 1;
            if (node >= 0 && node < nnodes) {
                printed_nodes.push_back(node);
            } else {
                std::cerr << "Invalid node number: " << argv[i] << std::endl;
            }
        } catch (const std::logic_error&) {
            // Covers both std::invalid_argument and std::out_of_range from stoi.
            std::cerr << "Invalid node number: " << argv[i] << std::endl;
        }
    }
    if (printed_nodes.empty()) {
        for (int i = 0; i < nnodes; i++) {
            printed_nodes.push_back(i);
        }
    }

    stamp_static();

    output << "time";
    for (const int i : printed_nodes) {
        output << ",node" << i + 1;
    }
    output << "\n";

    for (int n = 0; n < nsteps; n++) {
        t = n * time_step;
        update_all(t);  // the solver advances on every step, even unwritten ones

        if (t < start_time) {
            continue;
        }

        output << t;
        for (const int i : printed_nodes) {
            output << "," << v[i];
        }
        output << "\n";
    }

    return 0;
}