//#include <io.h>
//#include <system.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

// Dumps a rows x cols matrix to stdout.
//
// mat_name:   label printed in front of every entry / row.
// mat:        row-pointer matrix; always read as int ** regardless of mat_is_int.
// rows, cols: dimensions to print.
// mat_is_int: ignored by this implementation.
// in_line:    0 prints one "name[r][c] = v" line per element; any other value
//             prints each row as "name[r] = {" followed by its elements and "};".
void print_matrix(char mat_name[], void **mat, int rows, int cols, char mat_is_int, char in_line)
{
    int **cells = (int **)mat;
    int r, c;

    (void)mat_is_int;

    for (r = 0; r < rows; r++)
    {
        if (in_line != 0)
            printf("%s[%d] = {", mat_name, r);

        for (c = 0; c < cols; c++)
        {
            if (in_line != 0)
            {
                // NOTE(review): every element is followed by a newline even in
                // this row-wise mode - confirm that this is intended.
                printf("%d, \n", cells[r][c]);
            }
            else
            {
                printf("%s[%d][%d] = %d\n", mat_name, r, c, cells[r][c]);
            }
        }

        if (in_line != 0)
            printf("};\n\n");
    }
}

// Dumps a vector of `rows` elements to stdout.
//
// vec_name:   label printed with the values.
// vec:        element storage; always read as int * regardless of vec_is_int.
// rows:       number of elements to print.
// vec_is_int: ignored by this implementation.
// in_line:    0 prints one "name[i] = v" line per element; any other value
//             prints everything on a single line as "name = {v, v, ...};".
void print_vector(char vec_name[], void *vec, int rows, char vec_is_int, char in_line)
{
    const int *values = (const int *)vec;
    int idx = 0;

    (void)vec_is_int;

    if (in_line != 0)
    {
        printf("%s = {", vec_name);

        while (idx < rows)
        {
            printf("%d, ", values[idx]);
            idx++;
        }

        printf("};\n");
        return;
    }

    while (idx < rows)
    {
        printf("%s[%d] = %d\n", vec_name, idx, values[idx]);
        idx++;
    }
}

// square root of the sum of the squares (L2 norm)
int vec_norm(int *vector, int vec_size)
{
    long int cur_element;
    long long int sum = 0;
    long int product = 0;
    
    int i;
    
    //printf("vec_norm\n");
    
    for(i = 0; i < vec_size; i++)
    {
        cur_element = vector[i];
        //printf("%d: %d\n",i,cur_element);
        product = cur_element * cur_element;
        //printf("product: %ld\n",product);
        sum += product;
        //printf("sum: %lld\n",sum);
    }
    
    //printf("sum: %lld\n",sum);
    
    //printf("vec_norm: %d\n",(int)sqrt(sum));
    return (int)sqrt(sum);
}

// Integer matrix-vector product; prints "mat_vec_mul" to stdout on every call.
//
// mat:        mat_rows x mat_cols matrix addressed as mat[row][col]; only read.
// vec:        input vector: mat_rows elements when trans_mat is non-zero,
//             mat_cols elements when trans_mat is zero.
// trans_mat:  non-zero -> result = mat^T * vec (mat_cols outputs);
//             zero     -> result = mat   * vec (mat_rows outputs).
// result:     output vector; must not overlap vec.
// mat_is_int: ignored by this implementation.
void mat_vec_mul(int **mat, int *vec, int mat_rows, int mat_cols, char trans_mat, int *result, char mat_is_int)
{
    int i, j;
    int running_sum;

    (void)mat_is_int;

    printf("mat_vec_mul\n");

    if (trans_mat)
    {
        // Output i is the dot product of column i of mat with vec.
        for (i = mat_cols - 1; i >= 0; i--)
        {
            running_sum = 0;

            for (j = mat_rows - 1; j >= 0; j--)
                running_sum += mat[j][i] * vec[j];

            result[i] = running_sum;
        }
    }
    else
    {
        // Output i is the dot product of row i of mat with vec.
        for (i = mat_rows - 1; i >= 0; i--)
        {
            running_sum = 0;

            for (j = mat_cols - 1; j >= 0; j--)
                running_sum += mat[i][j] * vec[j];

            result[i] = running_sum;
        }
    }
}

// Solves a triangular system in integer arithmetic, given an upper-triangular
// factor R_I (only entries R_I[r][c] with r <= c are read).
//
// R_I:           factor addressed as R_I[row][col]; the leading
//                activeSetSize x activeSetSize part is used.
// vec:           right-hand side, activeSetSize elements.
// activeSetSize: order of the system.
// result:        receives the solution, activeSetSize elements.
// trans_mat:     zero     -> solve R_I   * result = vec by back substitution;
//                non-zero -> solve R_I^T * result = vec by forward substitution.
//
// All divisions are integer divisions, so quotients are truncated. Every
// diagonal entry R_I[i][i] must be non-zero; a zero pivot divides by zero.
void linsolve(int **R_I, int *vec, int activeSetSize, int *result, char trans_mat)
{
  int row, col;
  int acc;

  if (trans_mat)
  {
    // R_I^T is lower triangular: walk the rows top-down and eliminate the
    // unknowns that are already solved (R_I^T[row][col] == R_I[col][row]).
    for (row = 0; row < activeSetSize; row++)
    {
      acc = vec[row];

      for (col = row - 1; col >= 0; col--)
        acc -= R_I[col][row] * result[col];

      result[row] = acc / R_I[row][row];
    }

    return;
  }

  // R_I is upper triangular: walk the rows bottom-up.
  for (row = activeSetSize - 1; row >= 0; row--)
  {
    acc = vec[row];

    for (col = row + 1; col < activeSetSize; col++)
      acc -= R_I[row][col] * result[col];

    result[row] = acc / R_I[row][row];
  }
}

// Extends the upper-triangular factor R by one column for column `newIndex`
// of A, using integer arithmetic throughout.
//
// R:             factor addressed as R[row][col]; column `activeSetSize`
//                (rows 0..activeSetSize) is written, so R needs at least
//                activeSetSize + 1 rows and columns.
// n:             number of rows of A.
// N:             unused.
// A:             matrix addressed as A[row][col]; only read.
// activeSet:     column indices of A already represented in R; the first
//                activeSetSize entries are read.
// activeSetSize: current order of R; must be >= 0.
// newIndex:      column of A being added.
void updateChol(int **R, int n, int N, int **A, int *activeSet, int activeSetSize, int newIndex)
{
  int i, j;
  int newVecSqrSum = 0;   // squared norm of the new column

  (void)N;

  for (i = 0; i < n; i++)
    newVecSqrSum += A[i][newIndex] * A[i][newIndex];

  if (activeSetSize == 0)
  {
    // First column: the factor is just the column's norm.
    R[0][0] = (int)sqrt((double)newVecSqrSum);
    return;
  }

  {
    // Declared only on this path: a variable-length array must have a
    // strictly positive length, which the empty-set case cannot provide.
    int result[activeSetSize];
    int p[activeSetSize];
    int q = 0;

    // result = A(:,activeSet)^T * A(:,newIndex)
    for (i = 0; i < activeSetSize; i++)
    {
      int ix = activeSet[i];

      result[i] = 0;
      for (j = 0; j < n; j++)
        result[i] += A[j][ix] * A[j][newIndex];
    }

    // Solve R^T * p = result.
    linsolve(R, result, activeSetSize, p, 1);

    for (i = 0; i < activeSetSize; i++)
    {
      q += p[i] * p[i];
      R[i][activeSetSize] = p[i];
    }

    // NOTE(review): with truncated integer solves q can exceed newVecSqrSum;
    // sqrt() of a negative value is NaN and its conversion to int is
    // undefined. The computation is left as-is - confirm how callers want
    // this case handled.
    R[activeSetSize][activeSetSize] = (int)sqrt((double)(newVecSqrSum - q));
  }
}

// Adds the update dx to x at every index listed in activeSet.
//
// dx:        update vector, indexed by the values stored in activeSet.
// x:         vector updated in place, indexed the same way.
// activeSet: list of indices terminated by -1; the terminator is required,
//            because no length is passed.
void add_dx_activeSet_to_x_activeSet(int *dx, int *x, int *activeSet)
{
	int *cursor;

	for (cursor = activeSet; *cursor != -1; cursor++)
	{
		int idx = *cursor;

		x[idx] = x[idx] + dx[idx];
	}
}

// Recomputes the residual res = y - A(:,activeSet) * x(activeSet).
//
// y:         measurement vector, n elements; only read.
// n:         number of rows of A and length of y and res.
// A:         matrix addressed as A[row][col]; only read.
// x:         coefficient vector, indexed by the values stored in activeSet.
// activeSet: list of column indices terminated by -1.
// res:       receives the n residual values.
void res_update(int *y, int n, int **A, int *x, int *activeSet, int *res)
{
	int row;

	for (row = 0; row < n; row++)
	{
		int *a_row = A[row];
		int acc = 0;
		int pos = 0;

		// Dot product of this row with x, restricted to the active columns.
		while (activeSet[pos] != -1)
		{
			int col = activeSet[pos];

			acc += a_row[col] * x[col];
			pos++;
		}

		res[row] = y[row] - acc;
	}
}

// Greedy sparse decomposition of y over the columns of A in integer
// arithmetic (the structure follows orthogonal matching pursuit): each
// iteration picks the column most correlated with the residual, extends the
// triangular factor R_I, solves (A_I' * A_I) dx_I = corr_I through two
// triangular solves, adds dx to x and recomputes the residual.
//
// A:        n x N matrix addressed as A[row][col]; only read.
// y:        measurement vector, n elements; only read.
// n:        length of y.
// N:        length of the solution x.
// sparsity: order of the factor R_I, i.e. the maximum number of columns
//           that can be selected.
// x:        solution vector, N elements; updates are accumulated into it and
//           it is never cleared here (presumably the caller passes it zeroed -
//           confirm against the caller).
//
// Iteration stops when the residual norm is <= 1000, when `sparsity` (or N)
// columns have been selected, or when no column correlates with the residual.
// Progress is logged to stdout. Nothing is done when n, N or sparsity is not
// positive; if an allocation fails a message is written to stderr and the
// function returns without running any iteration.
void decomp(int **A, int *y, int n, int N, int sparsity, int *x)
{
	// Iteration stops once the residual norm drops to this value or below.
	const int resnormStop = 1000;

	int i, j;
	int k = 1;                  // 1-based iteration counter, used for logging
	int rowsAllocated = 0;      // rows of R_I that were successfully allocated
	int activeSetSize = 0;
	int newIndex = 0;
	int maxcorr;
	int resnorm;
	char done = 0;

	int **R_I = NULL;           // sparsity x sparsity triangular factor
	int *activeSet = NULL;      // selected column indices, terminated by -1
	int *res = NULL;            // residual, n x 1
	int *corr = NULL;           // corr = A^T * res, N x 1
	int *corr_activeSet = NULL; // corr restricted to the active set
	int *dx = NULL;             // update scattered to full length N
	int *dx_activeSet = NULL;   // update in active-set order
	int *z = NULL;              // intermediate result of the first solve

	if (n <= 0 || N <= 0 || sparsity <= 0)
		return;

	printf("Creating R_I 2D array\n");

	R_I = malloc((size_t)sparsity * sizeof *R_I);
	if (R_I == NULL)
		goto oom;

	for (rowsAllocated = 0; rowsAllocated < sparsity; rowsAllocated++)
	{
		R_I[rowsAllocated] = calloc((size_t)sparsity, sizeof *R_I[rowsAllocated]);
		if (R_I[rowsAllocated] == NULL)
			goto oom;
	}

	printf("Creating activeSet array\n");

	// One slot more than N so that a -1 terminator is always present, even
	// when every column has been selected: add_dx_activeSet_to_x_activeSet
	// and res_update scan the list until they find -1.
	activeSet = malloc(((size_t)N + 1) * sizeof *activeSet);
	res = malloc((size_t)n * sizeof *res);
	corr = malloc((size_t)N * sizeof *corr);
	corr_activeSet = calloc((size_t)N, sizeof *corr_activeSet);
	dx = calloc((size_t)N, sizeof *dx);
	dx_activeSet = calloc((size_t)N, sizeof *dx_activeSet);
	z = calloc((size_t)N, sizeof *z);

	if (activeSet == NULL || res == NULL || corr == NULL || corr_activeSet == NULL ||
	    dx == NULL || dx_activeSet == NULL || z == NULL)
		goto oom;

	for (i = 0; i <= N; i++)
		activeSet[i] = -1;

	// The initial residual is y itself.
	for (i = 0; i < n; i++)
		res[i] = y[i];

	// updateChol writes row/column `activeSetSize` of R_I and activeSet holds
	// at most N indices, so both limits are checked before adding a column.
	while (done == 0 && activeSetSize < sparsity && activeSetSize < N)
	{
		// corr = A^T * res
		mat_vec_mul(A, res, n, N, 1, corr, 1);

		// [maxcorr newIndex] = max(abs(corr))
		maxcorr = 0;
		for (j = 0; j < N; j++)
		{
			int magnitude = (int)abs(corr[j]);

			if (magnitude > maxcorr)
			{
				newIndex = j;
				maxcorr = magnitude;
			}
		}

		printf("maxcorr = %d\n", maxcorr);

		// No column correlates with the residual: nothing can be gained, and
		// re-adding an already selected column would make the factor singular.
		if (maxcorr == 0)
			break;

		// Update Cholesky factorization of A_I
		updateChol(R_I, n, N, A, activeSet, activeSetSize, newIndex);

		activeSet[activeSetSize] = newIndex;
		activeSetSize++;

		// Solve for the least squares update: (A_I'*A_I)dx_I = corr_I
		for (j = 0; j < N; j++)
			dx[j] = 0;

		for (j = 0; j < activeSetSize; j++)
			corr_activeSet[j] = corr[activeSet[j]];

		linsolve(R_I, corr_activeSet, activeSetSize, z, 1);
		linsolve(R_I, z, activeSetSize, dx_activeSet, 0);

		for (j = 0; j < activeSetSize; j++)
			dx[activeSet[j]] = dx_activeSet[j];

		add_dx_activeSet_to_x_activeSet(dx, x, activeSet);

		// compute new residual
		res_update(y, n, A, x, activeSet, res);

		// Debug dump of corr for iterations 269..273, kept from the original.
		if (k <= 273 && k >= 269)
		{
			print_vector("corr",corr,N,1,0);
		}

		resnorm = vec_norm(res, n);
		printf("resnorm = %d\n",resnorm);

		done = (resnorm <= resnormStop) ? 1 : 0;

		printf("Iteration %d: Adding variable %d\n", k, newIndex);

		k += 1;
	}

	goto cleanup;

oom:
	fprintf(stderr, "decomp: out of memory\n");

cleanup:
	free(z);
	free(dx_activeSet);
	free(dx);
	free(corr_activeSet);
	free(corr);
	free(res);
	free(activeSet);

	if (R_I != NULL)
	{
		for (i = 0; i < rowsAllocated; i++)
			free(R_I[i]);
		free(R_I);
	}

	return;
}
