// fsm.c
// Single, flattened FSM for full Gaussian elimination (pivot, swap, eliminate, back-sub)
#include "fsm.h"

// Initialize
// Prepares a flattened Gaussian-elimination FSM for its first tick.
//   fsm    - FSM instance to set up
//   start  - initial value of the start input (sampled by the tick in GS_IDLE)
//   n      - dimension of the system; used by the tick as the row/column bound
//   sub_u  - FP unit driven by the elimination subtract states
//   mul_u  - FP unit driven by the elimination and back-substitution multiply states
//   div_u  - FP unit driven by the elimination and back-substitution divide states
// Only the fields listed below are written; the remaining working registers
// are assigned by the tick function before it reads them.
void gaussian_flat_init(gaussian_flat_t *fsm, int start, int n,
    fp_unit_t *sub_u, fp_unit_t *mul_u, fp_unit_t *div_u) {
    // attach the shared floating-point units
    fsm->div_u = div_u;
    fsm->mul_u = mul_u;
    fsm->sub_u = sub_u;

    // latch the inputs
    fsm->n     = n;
    fsm->start = start;

    // clear all status outputs
    fsm->singular = 0;
    fsm->success  = 0;
    fsm->done     = 0;

    // begin idle, at the first pivot column
    fsm->k     = 0;
    fsm->state = GS_IDLE;
}

// One-cycle tick
//
// Advances the flattened Gaussian-elimination FSM by one state transition and
// ticks the three FP units once. The state groups are:
//   PF_*  pivot finder    - scans column k from row k downward for the largest
//                           |G[row][k]| and records its row in f->pivot
//   PS_*  pivot swapper   - exchanges rows k and pivot of G (columns k..n-1)
//                           and the matching entries of I
//   EL_*  elimination     - for every row ei > k computes m = G[ei][k]/G[k][k]
//                           and replaces G[ei][ej] (ej = k..n-1) and I[ei] with
//                           (row value) - m * (pivot-row value)
//   BS_*  back-substitute - for bi = n-1 down to 0 computes
//                           v[bi] = (I[bi] - sum_{bj>bi} G[bi][bj]*v[bj]) / G[bi][bi]
//
// Outputs:
//   done     - set in GS_DONE and GS_FAILED. After GS_DONE the FSM returns to
//              GS_IDLE, which clears done/success/singular on the next tick.
//   success  - set in GS_DONE only when no singularity was flagged.
//   singular - set when the best pivot magnitude (or a diagonal used as a
//              divisor) is below EPSILON.
// GS_FAILED is entered when an FP unit reports an exception or a NaN result;
// the FSM then stays in GS_FAILED (and prints "failed") on every further tick.
//
// NOTE(review): f->n is not validated here; n < 1 appears to lead to accesses
// outside the matrix - confirm callers always pass n >= 1.
void gaussian_flat_tick(gaussian_flat_t *f) {
    // The FP pipelines advance every cycle, independent of the FSM state.
    altera_fp_tick(f->div_u);
    altera_fp_tick(f->mul_u);
    altera_fp_tick(f->sub_u);

    switch (f->state) {
    case GS_IDLE:
        f->done = 0;
        f->success = 0;
        f->singular = 0;
        if (f->start) {
            // Rewind the column counter so a run that follows a completed one
            // starts at column 0 instead of the stale k == n.
            f->k = 0;
            f->state = PF_INIT;
        }
        break;

    // ----- Pivot Finder -----
    case PF_INIT:
        f->row = f->k;
        f->pivot = f->k;
        f->state = PF_READ_DIAG;
        break;

    case PF_READ_DIAG:
        f->max_val = fabsf(read_float(get_G(f->k, f->k)));   // implemented with bit mask
        f->state = PF_SCAN_CHECK;
        break;

    case PF_SCAN_CHECK:
        if (++f->row < f->n) {
            f->state = PF_READ_VAL;
        } else if (f->max_val < EPSILON) {
            // No usable pivot in this column: stop without swapping or
            // eliminating anything further.
            f->singular = 1;
            f->state = GS_DONE;
        } else {
            // Pivot search done: the row swap starts at column k.
            f->j = f->k;
            f->state = PS_SWAP_G_A;
        }
        break;

    case PF_READ_VAL:
        f->val_buf = fabsf(read_float(get_G(f->row, f->k)));
        f->state = PF_EVALUATE;
        break;

    case PF_EVALUATE:
        if (f->val_buf > f->max_val) {
            f->max_val = f->val_buf;
            f->pivot = f->row;
        }
        f->state = PF_SCAN_CHECK;
        break;

    // ----- Pivot Swapper (G) -----
    case PS_SWAP_G_A:
        if (f->pivot == f->k) {
            // Pivot already on the diagonal: skip the G swap.
            f->state = PS_SWAP_I_A;
        } else {
            f->temp_a = read_float(get_G(f->k, f->j));
            f->state = PS_SWAP_G_B;
        }
        break;

    case PS_SWAP_G_B:
        f->temp_b = read_float(get_G(f->pivot, f->j));
        f->state = PS_SWAP_G_WA;
        break;

    case PS_SWAP_G_WA:
        write_float(get_G(f->k, f->j), f->temp_b);
        f->state = PS_SWAP_G_WB;
        break;

    case PS_SWAP_G_WB:
        write_float(get_G(f->pivot, f->j), f->temp_a);
        if (++f->j < f->n) {
            f->state = PS_SWAP_G_A;
        } else {
            f->state = PS_SWAP_I_A;
        }
        break;

    // ----- Pivot Swapper (I) -----
    // Runs even when pivot == k; in that case it rewrites I[k] with itself.
    case PS_SWAP_I_A:
        f->temp_a = read_float(get_I(f->k));
        f->state = PS_SWAP_I_B;
        break;

    case PS_SWAP_I_B:
        f->temp_b = read_float(get_I(f->pivot));
        f->state = PS_SWAP_I_WA;
        break;

    case PS_SWAP_I_WA:
        write_float(get_I(f->k), f->temp_b);
        f->state = PS_SWAP_I_WB;
        break;

    case PS_SWAP_I_WB:
        write_float(get_I(f->pivot), f->temp_a);
        f->state = EL_INIT_READ_PIVOT;
        break;

    // ----- Elimination (augmented) -----
    case EL_INIT_READ_PIVOT:
        f->pivot_val = read_float(get_G(f->k, f->k));
        f->state = EL_INIT_SETUP;
        break;

    case EL_INIT_SETUP:
        if (fabsf(f->pivot_val) < EPSILON) {
            // Defensive: never divide by a near-zero pivot.
            f->singular = 1;
            f->state = GS_DONE;
        } else {
            f->ei = f->k + 1;
            f->state = EL_READ_AIK;
        }
        break;

    case EL_READ_AIK:
        if (f->ei < f->n) {
            f->buf_col = read_float(get_G(f->ei, f->k));
            f->state = EL_DIV_START;
        } else {
            f->state = GS_CHECK_K;
        }
        break;

    case EL_DIV_START:
        // m = G[ei][k] / G[k][k]
        f->div_u->a = f->buf_col;
        f->div_u->b = f->pivot_val;
        f->div_u->start = 1;
        f->state = EL_DIV_WAIT;
        break;

    case EL_DIV_WAIT:
        f->div_u->start = 0;
        if (!f->div_u->result_valid) {
            break;
        }
        /* In Verilog,
        wire is_nan = (exp == 8'b1111_1111) && (frac != 0);
        */
        if (f->div_u->exception || isnan(f->div_u->result)) {
            f->state = GS_FAILED;
            break;
        }
        f->m = f->div_u->result;
        f->state = EL_COL_SETUP;
        break;

    case EL_COL_SETUP:
        f->ej = f->k;
        f->state = EL_READ_COL;
        break;

    case EL_READ_COL:
        // ej == n addresses the augmented column I instead of G.
        if (f->ej < f->n) {
            f->buf_col = read_float(get_G(f->k, f->ej));
        } else {
            f->buf_col = read_float(get_I(f->k));
        }
        f->state = EL_READ_ROW;
        break;

    case EL_READ_ROW:
        if (f->ej < f->n) {
            f->buf_row = read_float(get_G(f->ei, f->ej));
        } else {
            f->buf_row = read_float(get_I(f->ei));
        }
        f->state = EL_MUL_START;
        break;

    case EL_MUL_START:
        f->mul_u->a = f->m;
        f->mul_u->b = f->buf_col;
        f->mul_u->start = 1;
        f->state = EL_MUL_WAIT;
        break;

    case EL_MUL_WAIT:
        f->mul_u->start = 0;
        if (!f->mul_u->result_valid) {
            break;
        }
        if (f->mul_u->exception || isnan(f->mul_u->result)) {
            f->state = GS_FAILED;
            break;
        }
        f->state = EL_SUB_START;
        break;

    case EL_SUB_START:
        f->sub_u->a = f->buf_row;
        f->sub_u->b = f->mul_u->result;
        f->sub_u->start = 1;
        f->state = EL_SUB_WAIT;
        break;

    case EL_SUB_WAIT:
        f->sub_u->start = 0;
        if (!f->sub_u->result_valid) {
            break;
        }
        if (f->sub_u->exception || isnan(f->sub_u->result)) {
            f->state = GS_FAILED;
            break;
        }
        f->state = EL_WRITE_COL;
        break;

    case EL_WRITE_COL:
        if (f->ej < f->n) {
            write_float(get_G(f->ei, f->ej), f->sub_u->result);
        } else {
            write_float(get_I(f->ei), f->sub_u->result);
        }
        f->state = EL_COL_INCREMENT;
        break;

    case EL_COL_INCREMENT:
        // "<=" on purpose: column index n is the augmented column I.
        if (++f->ej <= f->n) {
            f->state = EL_READ_COL;
        } else {
            f->state = EL_ROW_INCREMENT;
        }
        break;

    case EL_ROW_INCREMENT:
        f->ei++;
        if (f->ei < f->n) {
            f->state = EL_READ_AIK;
        } else {
            f->state = GS_CHECK_K;
        }
        break;

    // ----- Check k and move to back-sub -----
    case GS_CHECK_K:
        if (++f->k < f->n) {
            f->state = PF_INIT;
        } else {
            f->bi = f->n - 1;
            f->state = BS_READ_I;
        }
        break;

    // ----- Back-substitution -----
    case BS_READ_I:
        f->mem_I = read_float(get_I(f->bi));
        f->state = BS_SETUP;
        break;

    case BS_SETUP:
        f->sum = f->mem_I;
        f->bj = f->bi + 1;
        // The last row has no terms to the right of the diagonal.
        f->state = (f->bj < f->n ? BS_READ_A : BS_READ_DIAG);
        break;

    case BS_READ_A:
        f->mem_A = read_float(get_G(f->bi, f->bj));
        f->state = BS_READ_V;
        break;

    case BS_READ_V:
        f->mem_v = read_float(get_v(f->bj));
        f->state = BS_MUL_START;
        break;

    case BS_MUL_START:
        f->mul_u->a = f->mem_A;
        f->mul_u->b = f->mem_v;
        f->mul_u->start = 1;
        f->state = BS_MUL_WAIT;
        break;

    case BS_MUL_WAIT:
        f->mul_u->start = 0;
        if (!f->mul_u->result_valid) {
            break;
        }
        if (f->mul_u->exception || isnan(f->mul_u->result)) {
            f->state = GS_FAILED;
            break;
        }
        // Accumulated with a native C subtraction, not through sub_u.
        f->sum -= f->mul_u->result;
        f->state = BS_NEXT_CHECK;
        break;

    case BS_NEXT_CHECK:
        if (++f->bj < f->n) {
            f->state = BS_READ_A;
        } else {
            f->state = BS_READ_DIAG;
        }
        break;

    case BS_READ_DIAG:
        f->mem_diag = read_float(get_G(f->bi, f->bi));
        f->state = BS_DIV_START;
        break;

    case BS_DIV_START:
        if (fabsf(f->mem_diag) < EPSILON) {
            // Defensive: never divide by a near-zero diagonal.
            f->singular = 1;
            f->state = GS_DONE;
        } else {
            f->div_u->a = f->sum;
            f->div_u->b = f->mem_diag;
            f->div_u->start = 1;
            f->state = BS_DIV_WAIT;
        }
        break;

    case BS_DIV_WAIT:
        // Drop the start strobe after one cycle, as every other *_WAIT state does.
        f->div_u->start = 0;
        if (!f->div_u->result_valid) {
            break;
        }
        if (f->div_u->exception || isnan(f->div_u->result)) {
            f->state = GS_FAILED;
            break;
        }
        write_float(get_v(f->bi), f->div_u->result);    //?? should we hold this for another cycle?
        f->state = BS_ROW_DEC;
        break;

    case BS_ROW_DEC:
        f->bi--;
        if (f->bi >= 0) {
            f->state = BS_READ_I;
        } else {
            f->state = GS_DONE;
        }
        break;

    case GS_DONE:
        // Run finished. Success is reported only when no singularity was
        // flagged; a singular run ends here without a solution in v.
        f->done = 1;
        f->success = !f->singular;
        f->state = GS_IDLE;
        break;

    case GS_FAILED:
        // An FP unit reported an exception or NaN. The FSM stays in this
        // state, so the message is printed on every subsequent tick.
        printf("failed\n");
        f->done = 1;
        f->success = 0;
        break;

    default:
        // Unrecognized state value: hold position.
        break;
    }
}
