#include "pivot_swapper.h"

// Prepare `st` for exchanging row `k` with row `pivot`.
//
//   st      swapper state to (re)initialise
//   k       index of the current row; also the first column visited
//   n_rows  stored in st->n and used by pivot_swapper_tick() as the exclusive
//           upper bound of the column counter
//           NOTE(review): this implies G has as many columns as rows — confirm
//   pivot   index of the row to exchange with row k
//
// When pivot == k there is nothing to exchange, so the machine starts in
// PS_DONE. pivot_swap_done is cleared in every case; it is only raised later
// by pivot_swapper_tick().
void pivot_swapper_init(pivot_swapper_t *st, int k, int n_rows, int pivot) {
    // Row indices and column bound.
    st->k     = k;
    st->pivot = pivot;
    st->n     = n_rows;

    // The column counter starts at k rather than 0.
    // NOTE(review): presumably columns left of k need no swapping — confirm.
    st->j = k;

    st->pivot_swap_done = 0;

    if (pivot == k) {
        st->state = PS_DONE;        // no swap needed
    } else {
        st->state = PS_SWAP_G_A;    // begin with the first G element
    }
}

// Advance the swap state machine by one step.
//
// Every call performs at most one read_float() or write_float() access:
//
//   G phase (repeated for j = k .. n-1):
//     PS_SWAP_G_A   temp_a <- G[k][j]
//     PS_SWAP_G_B   temp_b <- G[pivot][j]
//     PS_SWAP_G_WA  G[k][j]     <- temp_b
//     PS_SWAP_G_WB  G[pivot][j] <- temp_a, then advance j
//   I phase (once):
//     PS_SWAP_I_A   temp_a <- I[k]
//     PS_SWAP_I_B   temp_b <- I[pivot]
//     PS_SWAP_I_WA  I[k]     <- temp_b
//     PS_SWAP_I_WB  I[pivot] <- temp_a, then enter PS_DONE
//
// pivot_swap_done is set to 1 only by a call made while the machine is
// already in PS_DONE, i.e. one tick after the final write (or on the first
// tick when init started in PS_DONE).
void pivot_swapper_tick(pivot_swapper_t *st) {
    switch (st->state) {
    case PS_DONE:
        // Nothing left to do; report completion.
        st->pivot_swap_done = 1;
        break;

    // ---- G rows: exchange element j of row k and row pivot ----
    case PS_SWAP_G_A:
        st->temp_a = read_float(get_G(st->k, st->j));
        st->state = PS_SWAP_G_B;
        break;

    case PS_SWAP_G_B:
        st->temp_b = read_float(get_G(st->pivot, st->j));
        st->state = PS_SWAP_G_WA;
        break;

    case PS_SWAP_G_WA:
        write_float(get_G(st->k, st->j), st->temp_b);
        st->state = PS_SWAP_G_WB;
        break;

    case PS_SWAP_G_WB:
        write_float(get_G(st->pivot, st->j), st->temp_a);
        // Step to the next column; once past the last one, continue with I.
        st->j++;
        st->state = (st->j < st->n) ? PS_SWAP_G_A : PS_SWAP_I_A;
        break;

    // ---- I entries: exchange element k and element pivot ----
    case PS_SWAP_I_A:
        st->temp_a = read_float(get_I(st->k));
        st->state = PS_SWAP_I_B;
        break;

    case PS_SWAP_I_B:
        st->temp_b = read_float(get_I(st->pivot));
        st->state = PS_SWAP_I_WA;
        break;

    case PS_SWAP_I_WA:
        write_float(get_I(st->k), st->temp_b);
        st->state = PS_SWAP_I_WB;
        break;

    case PS_SWAP_I_WB:
        write_float(get_I(st->pivot), st->temp_a);
        st->state = PS_DONE;
        break;

    default:
        // Unknown state: should never happen. Park the machine in PS_DONE;
        // the completion flag is raised by the following tick.
        st->state = PS_DONE;
        break;
    }
}

