// nonlinear.c
#include "circuit.h"
#define CUTOFF_I 1e-8f

// nonlinear components: diode, nmos, pmos

/*
 * Diode Newton-Raphson companion stamp.
 *
 * Linearizes Id = Is*(exp(Vd/Vt) - 1), with Vd = V(n1) - V(n2), around the
 * voltage stored by dio_update() in d->v_prev. The result is a conductance
 * Geq = dId/dVd in parallel with a current source Ieq = Id - Geq*Vd (current
 * flowing n1 -> n2), added to the nodal matrix Gm and right-hand side I.
 * Node index -1 denotes ground; its row/column is skipped.
 *
 * expf() overflows single precision for arguments above ~88.7. That would
 * write inf/NaN into Gm and I on a large forward-bias iterate. Above
 * exp_arg_max the exponential is therefore continued along its tangent line,
 * so value and slope stay continuous at the limit. Below the limit the
 * stamp is exactly the plain exponential model.
 */
void dio_stamp_nl(Component *c, float Gm[][MAT_SIZE], float I[]) {
    /* Largest exponent evaluated directly; expf(80) ~ 5.5e34 is still finite. */
    const float exp_arg_max = 80.0f;
    Diode *d = &c->u.dio;

    // 1. read last stage's voltage
    float Vd = d->v_prev;

    // 2. compute new iteration parameters
    float arg = Vd / d->Vt;
    float ExpV;    // exp(arg), or its linear continuation past exp_arg_max
    float dExpV;   // d(ExpV)/d(arg)
    if (arg <= exp_arg_max) {
        ExpV  = expf(arg);
        dExpV = ExpV;
    } else {
        float e_lim = expf(exp_arg_max);
        ExpV  = e_lim * (1.0f + (arg - exp_arg_max));   // tangent at exp_arg_max
        dExpV = e_lim;
    }
    float Geq = (d->Is / d->Vt) * dExpV;            // Geq = dId/dVd
    float Ieq = d->Is * (ExpV - 1.0f) - Geq * Vd;   // Ieq = Id - Geq*Vd
    if (PRINT_NL)
        printf("diode_stamp_nl: Vd=%.6g  Geq=%.6g  Ieq=%.6g v_prev=%.6g\n", Vd, Geq, Ieq, Vd);

    // 3. stamp G and I
    //    Current leaving n1 is Geq*(V1 - V2) + Ieq; current leaving n2 is its negative.
    int n1 = d->n1, n2 = d->n2;
    if (n1 != -1) Gm[n1][n1] += (Geq);
    if (n2 != -1) Gm[n2][n2] += (Geq);
    if (n1 != -1 && n2 != -1) {
        Gm[n1][n2] -= (Geq);
        Gm[n2][n1] -= (Geq);
    }
    if (n1 != -1) I[n1] -= (Ieq);
    if (n2 != -1) I[n2] += (Ieq);
}

/*
 * Record the diode voltage Vd = V(n1) - V(n2) from the latest solution v[]
 * as the next linearization point (node index -1 reads as 0 V).
 * No DV_MAX step limiting is applied here, unlike the MOS update routines.
 */
void dio_update(Component *c) {
    Diode *dio = &c->u.dio;
    int anode = dio->n1;
    int cathode = dio->n2;

    /* Kept as one expression so the subtraction happens in v[]'s own type. */
    dio->v_prev = (anode == -1 ? 0 : v[anode]) - (cathode == -1 ? 0 : v[cathode]);

    if (PRINT_NL) {
        printf("diode_update: v_prev=%.6g\n", dio->v_prev);
    }
}


/*
 * NMOS Newton-Raphson companion stamp.
 *
 * Uses a square-law model with channel-length modulation, evaluated at the
 * operating point held in m->vgs_prev / m->vds_prev (written, step-limited,
 * by nmos_update()). With Vov = Vgs - Vt:
 *   cutoff     (Vgs <= Vt)  : Ids = CUTOFF_I*Vds, g_d = CUTOFF_I, g_m = 0
 *   triode     (Vds <  Vov) : Ids = beta*(Vds*Vov - Vds^2/2)
 *   saturation (Vds >= Vov) : Ids = beta/2*Vov^2*(1 + lambda*(Vds - Vov))
 *
 * Ids flows drain -> source and is linearized as
 *   Ids ~= g_d*(Vd - Vs) + g_m*(Vg - Vs) + Ieq,
 * where g_d = dIds/dVds, g_m = dIds/dVgs and Ieq = Ids - g_d*Vds - g_m*Vgs.
 * This linearization is added to the nodal matrix Gm and right-hand side I.
 * Node index -1 denotes ground; its row and column are skipped.
 */
void nmos_stamp_nl(Component *c, float Gm[][MAT_SIZE], float I[]) {
    Nmos *m = &c->u.nmos;
    int ng = m->ng, nd = m->nd, ns = m->ns;
    float beta = m->beta, VT = m->Vt, lambda = m->lambda;

    // 1. read last iteration's voltages
    float Vgs = m->vgs_prev;
    float Vds = m->vds_prev;

    // 2. compute Ids and small-signal gains g_d = dIds/dVds, g_m = dIds/dVgs
    float Ids, g_d, g_m;
    if (Vgs <= VT) {        // cutoff
        Ids = CUTOFF_I * Vds; // small cutoff current
        g_d = CUTOFF_I;
        g_m = 0.0f;
    } else {
        float Vov = Vgs - VT;
        if (Vds < Vov) {    // triode
            Ids = beta * (Vds*Vov - 0.5f*Vds*Vds);
            g_d = beta * (Vov - Vds);
            g_m = beta * Vds;
        } else {            // saturation
            Ids = 0.5f * beta * Vov*Vov * (1 + lambda*(Vds - Vov));
            g_d = 0.5f * beta * Vov*Vov * lambda;
            g_m = beta * Vov * (1 + lambda*(Vds - Vov)) - 0.5f * beta * Vov*Vov * lambda;
        }
    }

    //    Ieq = Ids_prev - g_d*Vds_prev - g_m*Vgs_prev
    float Ieq = Ids - g_d * Vds - g_m * Vgs;

    // 3. stamp on G and I
    //    Output conductance g_d between drain and source.
    if (nd != -1) {
        Gm[nd][nd] += g_d;
        if (ns != -1) Gm[nd][ns] -= g_d;
    }
    if (ns != -1) {
        Gm[ns][ns] += g_d;
        if (nd != -1) Gm[ns][nd] -= g_d;
    }
    //    Transconductance g_m*(Vg - Vs): leaves the drain, enters the source.
    //    The -Vs part of Vgs exists even when the gate is grounded, so only
    //    the gate column may be gated on ng.
    if (nd != -1) {
        if (ng != -1) Gm[nd][ng] += g_m;
        if (ns != -1) Gm[nd][ns] -= g_m;
    }
    if (ns != -1) {
        if (ng != -1) Gm[ns][ng] -= g_m;
        Gm[ns][ns] += g_m;
    }

    //    Constant part of the current leaving the drain / entering the source.
    if (nd != -1) I[nd] -= Ieq;
    if (ns != -1) I[ns] += Ieq;

    if (PRINT_NL)
        printf("nmos_stamp_nl: vds=%.6g vgs=%.6g Ids=%.6g, g_m=%.6g g_d=%.6g Ieq=%.6g\n", m->vds_prev, m->vgs_prev,  Ids, g_m, g_d, Ieq);
}

/*
 * Refresh the NMOS linearization point from the latest solution v[].
 *
 * Reads the gate, drain and source voltages (node index -1 reads as 0 V) and
 * forms Vgs and Vds. m->vgs_prev / m->vds_prev then move toward those values
 * by at most DV_MAX each, so one Newton step cannot move the device's
 * operating point arbitrarily far.
 */
void nmos_update(Component *c) {
    Nmos *mos = &c->u.nmos;

    float v_gate   = (mos->ng == -1) ? 0.0f : v[mos->ng];
    float v_drain  = (mos->nd == -1) ? 0.0f : v[mos->nd];
    float v_source = (mos->ns == -1) ? 0.0f : v[mos->ns];

    float vgs_target = v_gate - v_source;
    float vds_target = v_drain - v_source;

    /* Both steps are measured from the previous values before either is overwritten. */
    float step_gs = vgs_target - mos->vgs_prev;
    float step_ds = vds_target - mos->vds_prev;

    mos->vgs_prev = (step_gs >  DV_MAX) ? mos->vgs_prev + DV_MAX
                  : (step_gs < -DV_MAX) ? mos->vgs_prev - DV_MAX
                  : vgs_target;
    mos->vds_prev = (step_ds >  DV_MAX) ? mos->vds_prev + DV_MAX
                  : (step_ds < -DV_MAX) ? mos->vds_prev - DV_MAX
                  : vds_target;

    if (PRINT_NL) {
        printf("nmos_update (clamped): vgs=%.6g  vds=%.6g\n", mos->vgs_prev, mos->vds_prev);
    }
}

/*
 * PMOS Newton-Raphson companion stamp.
 *
 * Uses a square-law model with channel-length modulation in source-referenced
 * quantities, evaluated at m->vsg_prev / m->vsd_prev (written, step-limited,
 * by pmos_update()). With Vov = Vsg - Vt:
 *   cutoff     (Vsg <= Vt)  : Isd = CUTOFF_I*Vsd, g_d = CUTOFF_I, g_m = 0
 *   triode     (Vsd <  Vov) : Isd = beta*(Vsd*Vov - Vsd^2/2)
 *   saturation (Vsd >= Vov) : Isd = beta/2*Vov^2*(1 + lambda*(Vsd - Vov))
 *
 * Isd flows source -> drain and is linearized as
 *   Isd ~= g_d*(Vs - Vd) + g_m*(Vs - Vg) + Ieq,
 * where g_d = dIsd/dVsd, g_m = dIsd/dVsg and Ieq = Isd - g_d*Vsd - g_m*Vsg.
 * The current leaving the drain, -Isd, equals g_d*(Vd - Vs) + g_m*(Vg - Vs) - Ieq.
 * The Gm pattern therefore matches an NMOS with drain/source rows as written,
 * and only Ieq enters I with the opposite sign.
 * Node index -1 denotes ground; its row and column are skipped.
 * NOTE(review): m->Vt is compared against Vsg, so it presumably holds the
 * threshold magnitude |Vtp| — confirm against where Pmos is filled in.
 */
void pmos_stamp_nl(Component *c, float Gm[][MAT_SIZE], float I[]) {
    Pmos *m = &c->u.pmos;
    int ng = m->ng, nd = m->nd, ns = m->ns;
    float beta = m->beta, VT = m->Vt, lambda = m->lambda;

    // 1. read last iteration's voltages
    float Vsg = m->vsg_prev;
    float Vsd = m->vsd_prev;

    // 2. compute Isd and small-signal gains g_d = dIsd/dVsd, g_m = dIsd/dVsg
    float Isd, g_d, g_m;
    if (Vsg <= VT) {        // cutoff
        Isd = CUTOFF_I * Vsd; // small cutoff current
        g_d = CUTOFF_I;
        g_m = 0.0f;
    } else {
        float Vov = Vsg - VT;
        if (Vsd < Vov) {    // triode
            Isd = beta * (Vsd*Vov - 0.5f*Vsd*Vsd);
            g_d = beta * (Vov - Vsd);
            g_m = beta * Vsd;
        } else {            // saturation
            Isd = 0.5f * beta * Vov*Vov * (1 + lambda*(Vsd - Vov));
            g_d = 0.5f * beta * Vov*Vov * lambda;
            g_m = beta * Vov * (1 + lambda*(Vsd - Vov)) - 0.5f * beta * Vov*Vov * lambda;
        }
    }

    //    Ieq = Isd_prev - g_d*Vsd_prev - g_m*Vsg_prev
    float Ieq = Isd - g_d*Vsd - g_m*Vsg;

    // 3. stamp on G and I
    //    Output conductance g_d between drain and source.
    if (nd != -1) {
        Gm[nd][nd] += g_d;
        if (ns != -1) Gm[nd][ns] -= g_d;
    }
    if (ns != -1) {
        Gm[ns][ns] += g_d;
        if (nd != -1) Gm[ns][nd] -= g_d;
    }
    //    Transconductance: g_m*(Vs - Vg) enters the drain and leaves the source.
    //    The Vs part of Vsg exists even when the gate is grounded (e.g. a
    //    grounded-gate pull-up), so only the gate column may be gated on ng.
    if (nd != -1) {
        if (ng != -1) Gm[nd][ng] += g_m;
        if (ns != -1) Gm[nd][ns] -= g_m;
    }
    if (ns != -1) {
        if (ng != -1) Gm[ns][ng] -= g_m;
        Gm[ns][ns] += g_m;
    }

    //    Constant part of the current entering the drain / leaving the source.
    if (nd != -1) I[nd] += Ieq;
    if (ns != -1) I[ns] -= Ieq;

    if (PRINT_NL)
        printf("pmos_stamp_nl: Vsd=%.6g Vsg=%.6g Isd=%.6g, g_m=%.6g g_d=%.6g Ieq=%.6g\n",
            Vsd, Vsg, Isd, g_m, g_d, Ieq);
}

/*
 * Refresh the PMOS linearization point from the latest solution v[].
 *
 * Reads the gate, drain and source voltages (node index -1 reads as 0 V) and
 * forms the source-referenced Vsg and Vsd. m->vsg_prev / m->vsd_prev then
 * move toward those values by at most DV_MAX each.
 */
void pmos_update(Component *c) {
    Pmos *mos = &c->u.pmos;

    float v_gate   = (mos->ng == -1) ? 0.0f : v[mos->ng];
    float v_drain  = (mos->nd == -1) ? 0.0f : v[mos->nd];
    float v_source = (mos->ns == -1) ? 0.0f : v[mos->ns];

    float vsg_target = v_source - v_gate;
    float vsd_target = v_source - v_drain;

    /* Both steps are measured from the previous values before either is overwritten. */
    float step_sg = vsg_target - mos->vsg_prev;
    float step_sd = vsd_target - mos->vsd_prev;

    mos->vsg_prev = (step_sg >  DV_MAX) ? mos->vsg_prev + DV_MAX
                  : (step_sg < -DV_MAX) ? mos->vsg_prev - DV_MAX
                  : vsg_target;
    mos->vsd_prev = (step_sd >  DV_MAX) ? mos->vsd_prev + DV_MAX
                  : (step_sd < -DV_MAX) ? mos->vsd_prev - DV_MAX
                  : vsd_target;

    if (PRINT_NL) {
        printf("pmos_update (clamped): Vsg=%.6g  Vsd=%.6g\n", mos->vsg_prev, mos->vsd_prev);
    }
}