/*
Back substitution module, step 4/4
*/
#include "backsub.h"
#include "gaussian.h"

//start = 128

/*
 * Reset a back-substitution state machine so that the first backsub_tick()
 * call starts on the bottom row (i = MAT_SIZE - 1) in state SBS_INIT.
 *
 * st     - state object to (re)initialise
 * start  - copied into st->start
 *          (NOTE(review): backsub_tick in this file does not read it)
 * sub_u  - floating-point unit backsub_tick uses for subtraction
 * mul_u  - floating-point unit backsub_tick uses for multiplication
 * div_u  - floating-point unit backsub_tick uses for division
 */
void backsub_init(backsub_t *st, int start, fp_unit_t *sub_u, fp_unit_t *mul_u, fp_unit_t *div_u){
    /* Attach the shared floating-point units. */
    st->sub_u = sub_u;
    st->mul_u = mul_u;
    st->div_u = div_u;

    /* Iteration counters: begin on the last row; j is re-seeded per row in SBS_INIT. */
    st->i   = MAT_SIZE - 1;
    st->j   = MAT_SIZE;
    st->sum = 0.0f;

    /* Control and status flags. */
    st->start         = start;
    st->state         = SBS_INIT;
    st->backsub_error = 0;
    st->backsub_done  = 0;
}

/*
 * Advance the back-substitution state machine by one clock cycle.
 *
 * Rows are processed from i = MAT_SIZE - 1 down to 0, computing
 *     v[i] = (I[i] - sum_{j > i} G[i][j] * v[j]) / G[i][i]
 * with one floating-point operation in flight at a time. Each *_START state
 * loads operands into a unit and raises its start flag. The matching *_WAIT
 * state stalls until that unit reports result_valid.
 *
 * When finished, st->backsub_done is set to 1. st->backsub_error is set to 1
 * in any of these cases:
 *   - the divide unit raises its exception flag;
 *   - a pivot G[i][i] has magnitude below EPSILON;
 *   - the state variable holds an unexpected value.
 * Calling this after completion does nothing.
 */
void backsub_tick(backsub_t *st){
    if (st->backsub_done) return;

    /* NOTE(review): only the divider's exception flag is checked here; the
     * multiply and subtract units are not. Original TODO: "need to count for
     * different exception types". */
    if (st->div_u->exception) {
        st->backsub_error = 1;
        st->state         = SBS_DONE;   /* finalised by the switch below, this same tick */
    }

    /* Tick the IP cores by one cycle as well. */
    altera_fp_tick(st->mul_u);
    altera_fp_tick(st->sub_u);
    altera_fp_tick(st->div_u);

    switch (st->state) {
        case SBS_INIT:
            /* Start row i: sum <- I[i], then accumulate over j = i+1 .. MAT_SIZE-1. */
            st->sum   = read_float(get_I(st->i));   /* separate I-vector read */
            st->j     = st->i + 1;
            st->state = (st->j < MAT_SIZE) ? SBS_MUL_START : SBS_DIV_START;
            break;

        case SBS_MUL_START: {
            /* Braces required: before C23 a declaration cannot directly follow a case label. */
            float a = read_float(get_G(st->i, st->j));
            float v = read_float(get_v(st->j));
            st->mul_u->a     = a;
            st->mul_u->b     = v;
            st->mul_u->start = 1;
            st->state        = SBS_MUL_WAIT;
            break;
        }

        case SBS_MUL_WAIT:
            if (!st->mul_u->result_valid) break;    /* stall until done */
            st->mul_u->start = 0;
            st->state        = SBS_SUB_START;
            break;

        case SBS_SUB_START:
            /* sum <- sum - G[i][j] * v[j] */
            st->sub_u->a     = st->sum;
            st->sub_u->b     = st->mul_u->result;
            st->sub_u->start = 1;
            st->state        = SBS_SUB_WAIT;
            break;

        case SBS_SUB_WAIT:
            if (!st->sub_u->result_valid) break;
            st->sub_u->start = 0;
            st->sum          = st->sub_u->result;
            st->state        = SBS_NEXT_CHECK;
            break;

        case SBS_NEXT_CHECK:
            st->j++;
            st->state = (st->j < MAT_SIZE) ? SBS_MUL_START : SBS_DIV_START;
            break;

        case SBS_DIV_START: {
            float diag = read_float(get_G(st->i, st->i));
            if (fabsf(diag) < EPSILON) {
                /* Pivot too close to zero: abort rather than divide. */
                st->backsub_error = 1;
                st->state         = SBS_DONE;
            } else {
                st->div_u->a     = st->sum;
                st->div_u->b     = diag;
                st->div_u->start = 1;
                st->state        = SBS_DIV_WAIT;
            }
            break;
        }

        case SBS_DIV_WAIT:
            if (!st->div_u->result_valid) break;
            st->div_u->start = 0;
            write_float(get_v(st->i), st->div_u->result);
            st->state = (st->i == 0) ? SBS_DONE : SBS_ROW_DONE;
            break;

        case SBS_ROW_DONE:
            st->i--;
            st->state = SBS_INIT;
            break;

        case SBS_DONE:
            st->backsub_done = 1;
            break;

        default:
            /* Unexpected state value: fail safely instead of stalling forever. */
            st->backsub_error = 1;
            st->state         = SBS_DONE;
            break;
    }
}

