#include "./decmp/header.h"

// Integer square root: returns floor(sqrt(v)) exactly, using the
// bit-by-bit (digit-by-digit) method, so no floating point is involved.
static unsigned long long vec_norm_isqrt(unsigned long long v)
{
    unsigned long long root = 0;
    unsigned long long bit = 1ULL << 62;   // highest power of four in 64 bits

    while (bit > v)
        bit >>= 2;

    while (bit != 0)
    {
        if (v >= root + bit)
        {
            v -= root + bit;
            root = (root >> 1) + bit;
        }
        else
        {
            root >>= 1;
        }
        bit >>= 2;
    }
    return root;
}

// vec_norm: Euclidean (L2) norm of an int vector -- the square root of the
// sum of the squares -- rounded down to an int.
//
//   vector   - the entries to measure (vec_size of them are read)
//   vec_size - number of entries; a value <= 0 yields 0
//
// Each square is formed in 64-bit arithmetic: where `long` is 32 bits,
// squaring an entry larger than 46340 in magnitude would overflow.  A single
// square is at most 2^62, so the unsigned 64-bit sum only wraps if the
// squares total 2^64 or more.  The root is an exact integer floor, so large
// sums do not lose precision through a conversion to double.
// NOTE: the result exceeds INT_MAX only if the sum of squares reaches 2^62.
int vec_norm(int *vector, int vec_size)
{
    unsigned long long sum = 0;
    int i;

    for (i = 0; i < vec_size; i++)
    {
        long long v = vector[i];
        sum += (unsigned long long)(v * v);
    }

    return (int)vec_norm_isqrt(sum);
}

// linsolve: solves a triangular system built from the upper-triangular
// Cholesky factor R_I, using only rows/columns 0..activeSetSize-1.
//
//   trans_mat == 0 : R_I  * result = vec   (back substitution, last row first)
//   trans_mat != 0 : R_I' * result = vec   (forward substitution, first row first)
//
// All arithmetic is integer; every division truncates toward zero.
// vec and result must each hold at least activeSetSize entries.
// A zero diagonal entry (singular factor) would be a division by zero, which
// is undefined behavior; the corresponding component of result is set to 0.
void linsolve(int R_I[SPARSITY][SPARSITY], int *vec, int activeSetSize, int *result, char trans_mat)
{
  int i, k;
  int sum;

  if (!trans_mat)
  {
    // R_I is upper triangular: row i only involves result[i..n-1].
    for (i = activeSetSize - 1; i >= 0; i--)
    {
      sum = vec[i];
      for (k = i + 1; k < activeSetSize; k++)
        sum -= R_I[i][k] * result[k];

      result[i] = (R_I[i][i] != 0) ? sum / R_I[i][i] : 0;
    }
  }
  else
  {
    // R_I' is lower triangular with R_I'[i][k] == R_I[k][i]: row i only
    // involves result[0..i].
    for (i = 0; i < activeSetSize; i++)
    {
      sum = vec[i];
      for (k = i - 1; k >= 0; k--)
        sum -= R_I[k][i] * result[k];

      result[i] = (R_I[i][i] != 0) ? sum / R_I[i][i] : 0;
    }
  }
}

// updateChol: Updates the Cholesky factor R of the matrix A(:,activeSet)'*A(:,activeSet) by adding A(:,newIndex)
void updateChol(int R[SPARSITY][SPARSITY], char A[K][SPARSITY], int *activeSet, int activeSetSize, int newIndex)
{
  int i,j,ix;
  int sum = 0, q = 0;
  
  int newVec[K];
  int result[activeSetSize];
  int p[activeSetSize];
  
  char tmp = 0;
  
  int newVecSqrSum = 0;
  
  // INTERFACE WITH HARDWARE 
  
  // add a column to the active set of A
  
  // write newIndex to its spot in the block RAM
  IOWR_32DIRECT(NEWVEC_BASE, j*4, newIndex);
  
  // write something other than EMPTY to ctrl_reg(2)
  IOWR_32DIRECT(MATRIX_VECTOR_MULT_INST_BASE, 8, 0x11110000);
        
  // write MODE1 to ctrl_reg(1)
  IOWR_32DIRECT(MATRIX_VECTOR_MULT_INST_BASE, 4, 0xFFFF0000);
    
  // write GO to ctrl_reg(0)
  IOWR_32DIRECT(MATRIX_VECTOR_MULT_INST_BASE, 0, 0xFFFFFFFF);

  unsigned int notdone = 0;
  while(notdone != 0xffffffff)
       notdone = IORD_32DIRECT(MATRIX_VECTOR_MULT_INST_BASE, 8);
        
  // fill A[*][newIndex] with the LFSR data from block RAM
  for (i = 0; i < K; i++)
        A[i][newIndex] = 2*IORD_32DIRECT(IVEC_BASE, i*4) - 1;
  
  for( i = 0; i < K; i++ )
  { 
    newVec[i] = A[i][newIndex];
    newVecSqrSum += newVec[i]*newVec[i];
  }

  if( !activeSetSize ) 
  {
    for( i = 0; i < K; i++ )
    {
      sum += (int)pow(newVec[i],2);
    }
    
    R[0][0] = (int)sqrt(sum);
  }

  else
  {
    // hard-coding mat_vec_mul( (void **)A, newVec, n, activeSetSize, 1, result, 1 );
    for ( i = 0; i < activeSetSize; i++ )
    {
      ix = activeSet[i];
      result[i] = 0;
      for ( j = 0; j < K; j++ )
      {
        tmp = A[j][ix];
        
        if (tmp == 1)
            result[i] += newVec[j];
        else
            result[i] -= newVec[j];
      }
    }
    
    // solve R*p = result
    linsolve( R, result, activeSetSize, p, 1 );
    
    for( i = 0; i < activeSetSize; i++ )
    {
      q += p[i]*p[i];
      R[i][activeSetSize] = p[i];
    }
    
    R[activeSetSize][activeSetSize] = (int)sqrt(newVecSqrSum - q);
  }
    
  return;
}

void add_dx_activeSet_to_x_activeSet(int *dx, int *x, int *activeSet)
{
    int j;  
    for(j = 0; activeSet[j] != -1; j++)
        x[activeSet[j]] += dx[activeSet[j]];
    
    return;
}

void res_update(int *y, int n, char A[K][SPARSITY], int *x, int *activeSet, int *res)
{
    // Matlab code: res = y - A(:,activeSet) * x(activeSet)
    // res, y are nx1
    
    // this uses the subset of the random A matrix generated by the matrix_vector_mult block
    
    int sum;
    char tmp = 0;
    
    int j, k;
    for(j = 0; j < n; j++)
    {
        sum = 0;
    
        for(k = 0; activeSet[k] != -1; k++)
        {   
            tmp = A[j][activeSet[k]];
            
            if (tmp == 1)
                sum += x[activeSet[k]];
            else
                sum -= x[activeSet[k]];
        }
           
        res[j] = y[j] - sum;
    }
   
    return;
}

// Runs the matrix_vector_mult block in MODE0 and copies its N outputs into
// corr.  Per the original comments this is corr = A' * res, where A is the
// K x N sensing matrix the hardware generates.
static void decomp_correlate_hw(const int *res, int *corr)
{
    unsigned int notdone;
    int base, j;

    // Write all K residual entries to the input vector space.  updateChol's
    // MODE1 run leaves K words of its own data in IVEC, so writing fewer
    // (the code previously stopped at K/2) would leave stale data behind.
    for (j = 0; j < K; j++)
        IOWR_32DIRECT(IVEC_BASE, j*4, res[j]);

    // write something other than EMPTY to ctrl_reg(2)
    IOWR_32DIRECT(MATRIX_VECTOR_MULT_INST_BASE, 8, 0x11110000);

    // write MODE0 to ctrl_reg(1)
    IOWR_32DIRECT(MATRIX_VECTOR_MULT_INST_BASE, 4, 0x0000FFFF);

    // write GO to ctrl_reg(0)
    IOWR_32DIRECT(MATRIX_VECTOR_MULT_INST_BASE, 0, 0xFFFFFFFF);

    // busy-wait until ctrl_reg(2) reads back 0xFFFFFFFF
    notdone = 0;
    while (notdone != 0xffffffff)
        notdone = IORD_32DIRECT(MATRIX_VECTOR_MULT_INST_BASE, 8);

    // Fill corr P_NUM entries at a time.  The base + j < N bound keeps the
    // last chunk inside corr when N is not a multiple of P_NUM.
    // NOTE(review): every chunk reads OVEC words 0..P_NUM-1 with no handshake
    // in between -- confirm the block advances its output window per chunk.
    for (base = 0; base < N; base += P_NUM)
        for (j = 0; j < P_NUM && base + j < N; j++)
            corr[base + j] = IORD_32DIRECT(OVEC_BASE, j*4);
}

// decomp: takes a compressed data set y [K rows, 1 col] and produces the
// solution x [N rows, 1 col].  This is mostly a direct translation of the
// Matlab code, except that A' * res is computed in hardware.
//
// Each iteration:
//   1. picks the column with the largest |A' * res|,
//   2. adds it to the active set and updates the Cholesky factor R_I,
//   3. solves (A_I' * A_I) dx_I = corr_I and sets x_I += dx_I,
//   4. recomputes res = y - A_I * x_I.
//
// It stops when ||res|| <= RESNORM_TOL, when the active set reaches SPARSITY
// columns (R_I and A have no room for more), or when every correlation is 0
// (no column can reduce the residual).
//
//   y - K measurements (read only)
//   x - N-entry output; zeroed on entry, since the residual starts at y
void decomp(int *y, int *x)
{
    enum { RESNORM_TOL = 1000 };   // residual-norm stopping threshold

    int i, j;
    int ind = 1;
    int activeSetSize = 0;
    int newIndex;
    int maxcorr;
    int resnorm;
    char done = 0;

    // K x SPARSITY working copy of sensing-matrix columns; filled by
    // updateChol and read by res_update.
    char A[K][SPARSITY];
    int R_I[SPARSITY][SPARSITY];   // upper-triangular Cholesky factor of A_I'A_I

    int activeSet[N];              // chosen global indices, -1 terminated
    int res[K];                    // residual y - A_I * x_I
    int corr[N];                   // A' * res from the hardware
    int dx[N];                     // update, indexed by global column

    // Active-set-sized work vectors (at most SPARSITY entries are used).
    int corr_activeSet[SPARSITY];
    int z[SPARSITY];               // intermediate of the two triangular solves
    int dx_activeSet[SPARSITY];

    for (i = 0; i < K; i++)
        for (j = 0; j < SPARSITY; j++)
            A[i][j] = 0;

    for (i = 0; i < SPARSITY; i++)
    {
        for (j = 0; j < SPARSITY; j++)
            R_I[i][j] = 0;
        corr_activeSet[i] = 0;
        z[i] = 0;
        dx_activeSet[i] = 0;
    }

    for (i = 0; i < N; i++)
    {
        activeSet[i] = -1;
        dx[i] = 0;
        x[i] = 0;
    }

    for (i = 0; i < K; i++)
        res[i] = y[i];

    while (done == 0)
    {
        // corr = A' * res (hardware)
        decomp_correlate_hw(res, corr);

        // [maxcorr newIndex] = max(abs(corr))
        maxcorr = 0;
        newIndex = -1;
        for (j = 0; j < N; j++)
        {
            int a = abs(corr[j]);
            if (a > maxcorr)
            {
                maxcorr = a;
                newIndex = j;
            }
        }

        // All correlations are zero: adding any column would not change the
        // residual, and the loop would never terminate.
        if (newIndex < 0)
            break;

        // Update Cholesky factorization of A_I
        updateChol(R_I, A, activeSet, activeSetSize, newIndex);

        activeSet[activeSetSize] = newIndex;
        activeSetSize++;

        // Solve for the least squares update: (A_I'*A_I)dx_I = corr_I,
        // i.e. R_I' z = corr_I, then R_I dx_I = z.
        for (j = 0; j < activeSetSize; j++)
            corr_activeSet[j] = corr[activeSet[j]];

        linsolve(R_I, corr_activeSet, activeSetSize, z, 1);
        linsolve(R_I, z, activeSetSize, dx_activeSet, 0);

        // Every active index is overwritten here, and only active indices of
        // dx are read, so dx needs no per-iteration reset.
        for (j = 0; j < activeSetSize; j++)
            dx[activeSet[j]] = dx_activeSet[j];

        add_dx_activeSet_to_x_activeSet(dx, x, activeSet);

        // compute new residual
        res_update(y, K, A, x, activeSet, res);

        resnorm = vec_norm(res, K);
        printf("resnorm = %d\n",resnorm);

        done = (resnorm <= RESNORM_TOL || activeSetSize >= SPARSITY) ? 1 : 0;

        printf("Iteration %d: Adding variable %d\n", ind, newIndex);

        ind += 1;
    }
}