//#include <io.h>
//#include <system.h>
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include "decomp_funcs.c"

// Helpers implemented outside this section (presumably in decomp_funcs.c,
// which is #included above -- TODO confirm).

// Debug dump of a rows x cols matrix labelled mat_name.
// NOTE(review): mat_is_int / in_line appear to select element type and
// layout of the output -- confirm against the implementation.
extern void print_matrix(char mat_name[], void **mat, int rows, int cols, char mat_is_int, char in_line);

// Debug dump of a vector of length rows labelled vec_name.
// NOTE(review): vec_is_int / in_line semantics assumed to mirror print_matrix.
extern void print_vector(char vec_name[], void *vec, int rows, char vec_is_int, char in_line);

// Returns a norm of vector[0..vec_size-1] as an int.
// NOTE(review): presumably the Euclidean norm truncated to int -- confirm.
extern int vec_norm(int *vector, int vec_size);

// Writes mat * vec (or mat^T * vec) into result; decomp calls it with
// trans_mat = 1 to form corr = A^T * res.
// trans_mat = 1 if mat should be transposed, 0 otherwise
extern void mat_vec_mul(int **mat, int *vec, int mat_rows, int mat_cols, char trans_mat, int *result, char mat_is_int);

// Updates the Cholesky factor R of A_I for the current active set
// (activeSet[0..activeSetSize-1]) so that column newIndex can be appended;
// decomp calls it before adding newIndex to activeSet.
extern void updateChol(int **R, int n, int N, int **A, int *activeSet, int activeSetSize, int newIndex);

// Solves a system with the activeSetSize x activeSetSize factor R_I.
// NOTE(review): presumably a triangular solve where trans_mat = 1 uses
// R_I^T -- decomp chains a trans_mat = 1 call and a trans_mat = 0 call to
// solve (A_I'*A_I) dx_I = corr_I; confirm.
extern void linsolve(int **R_I, int *vec, int activeSetSize, int *result, char trans_mat);

//void res_update(int *y, int n, int **A, int *x, int *activeSet, int *res);

//void add_dx_activeSet_to_x_activeSet(int *dx, int *x, int *activeSet);

/*
 * x[idx] += dx[idx] for every index idx listed in activeSet.
 *
 * activeSet is a list of column indices terminated by -1; the caller must
 * guarantee the terminator is present. Entries of x whose index is not in
 * activeSet are left untouched, and dx is only read at listed indices.
 */
void add_dx_activeSet_to_x_activeSet(int *dx, int *x, int *activeSet)
{
	const int *idx = activeSet;

	while (*idx != -1) {
		x[*idx] += dx[*idx];
		++idx;
	}
}

/*
 * Recompute the residual restricted to the active columns.
 * Matlab code: res = y - A(:,activeSet) * x(activeSet)
 *
 * y, res     vectors of length n (res is overwritten)
 * A          matrix given as row pointers; A[row][col]
 * x          solution vector, read only at indices listed in activeSet
 * activeSet  column indices terminated by -1
 */
void res_update(int *y, int n, int **A, int *x, int *activeSet, int *res)
{
	int row = 0;

	while (row < n) {
		const int *Arow = A[row];
		int acc = 0;

		for (const int *idx = activeSet; *idx != -1; ++idx)
			acc += Arow[*idx] * x[*idx];

		res[row] = y[row] - acc;
		++row;
	}
}

/* Free a matrix built by decomp_alloc_int_matrix; NULL-safe. */
static void decomp_free_int_matrix(int **mat, int rows)
{
	if (mat == NULL)
		return;
	for (int r = 0; r < rows; r++)
		free(mat[r]);
	free(mat);
}

/* Allocate a zero-filled rows x cols int matrix as an array of row pointers.
 * Returns NULL (leaking nothing) if any allocation fails. */
static int **decomp_alloc_int_matrix(int rows, int cols)
{
	int **mat = calloc((size_t)rows, sizeof *mat);
	if (mat == NULL)
		return NULL;
	for (int r = 0; r < rows; r++) {
		mat[r] = calloc((size_t)cols, sizeof *mat[r]);
		if (mat[r] == NULL) {
			decomp_free_int_matrix(mat, r);
			return NULL;
		}
	}
	return mat;
}

/* [maxcorr i] = max(abs(corr)) over columns not yet in the active set.
 * Returns the first index attaining the maximum and stores the magnitude in
 * *maxcorr; returns -1 (with *maxcorr = 0) if every inactive |corr| is 0. */
static int decomp_select_index(const int *corr, const char *isActive, int N, int *maxcorr)
{
	int best = -1;
	int bestMag = 0;

	for (int j = 0; j < N; j++) {
		if (isActive[j])
			continue;
		int mag = abs(corr[j]);
		if (mag > bestMag) {
			bestMag = mag;
			best = j;
		}
	}
	*maxcorr = bestMag;
	return best;
}

/*
 * Greedy (OMP-style) integer decomposition: repeatedly adds the column of A
 * most correlated with the residual and re-solves least squares on the
 * selected columns.
 *
 * A         n x N matrix given as row pointers; column j is A[*][j].
 * y         observation vector of length n.
 * n         length of y (number of rows of A).
 * N         length of x (number of columns of A / pixels in the image).
 * sparsity  dimension of the sparsity x sparsity factor R_I, and therefore the
 *           maximum number of columns that may be selected.
 * x         output vector of length N. Only entries whose index is selected are
 *           modified (incremented). The residual starts as a copy of y, which
 *           equals y - A*x only if x is all zeros on entry.
 *
 * Each iteration computes corr = A^T * res and picks the inactive column with
 * the largest |corr|. It then calls updateChol and solves
 * (A_I'*A_I) dx_I = corr_I with two linsolve calls on R_I. Finally it adds
 * dx_I to x_I and recomputes res = y - A_I * x_I.
 *
 * The loop stops when any of these holds:
 *  - the residual norm is <= 1000;
 *  - no inactive column has non-zero correlation with the residual;
 *  - min(sparsity, N, n) columns have been selected. That is the capacity of
 *    R_I and of the work buffers, and also the most independent columns an
 *    n-row matrix can have.
 *
 * If an allocation fails, an error is printed and x is left unchanged.
 */
void decomp(int **A, int *y, int n, int N, int sparsity, int *x)
{
	// Absolute stopping threshold on the integer residual norm.
	const int resnormStop = 1000;
	int maxIters = n;

	// Cap the active-set size so updateChol never writes past R_I and the
	// -1 terminator of activeSet can always be kept.
	int maxActive = sparsity;
	if (maxActive > N)
		maxActive = N;
	if (maxActive > maxIters)
		maxActive = maxIters;
	if (maxActive <= 0)
		return;   // nothing can be selected; avoids zero-size buffers

	int **R_I = NULL;
	int *activeSet = NULL;
	char *isActive = NULL;
	int *res = NULL;
	int *corr = NULL;
	int *corr_activeSet = NULL;
	int *dx = NULL;
	int *dx_activeSet = NULL;
	int *z = NULL;
	int activeSetSize = 0;

	printf("Creating R_I 2D array\n");
	R_I = decomp_alloc_int_matrix(sparsity, sparsity);

	printf("Creating activeSet array\n");
	// One extra slot so the -1 terminator survives even a full active set:
	// add_dx_activeSet_to_x_activeSet and res_update scan until -1.
	activeSet = malloc(((size_t)N + 1) * sizeof *activeSet);
	isActive = calloc((size_t)N, sizeof *isActive);
	res = malloc((size_t)n * sizeof *res);
	corr = malloc((size_t)N * sizeof *corr);                // corr = A^T * res, N x 1
	corr_activeSet = calloc((size_t)N, sizeof *corr_activeSet);
	dx = calloc((size_t)N, sizeof *dx);
	dx_activeSet = calloc((size_t)N, sizeof *dx_activeSet);
	z = calloc((size_t)N, sizeof *z);                       // result of the first linsolve

	if (R_I == NULL || activeSet == NULL || isActive == NULL || res == NULL ||
	    corr == NULL || corr_activeSet == NULL || dx == NULL ||
	    dx_activeSet == NULL || z == NULL) {
		fprintf(stderr, "decomp: out of memory\n");
		goto cleanup;
	}

	for (int j = 0; j <= N; j++)
		activeSet[j] = -1;

	// Initial residual: res = y (x assumed zero).
	for (int j = 0; j < n; j++)
		res[j] = y[j];

	for (int k = 1; ; k++) {
		// corr = A^T * res
		mat_vec_mul(A, res, n, N, 1, corr, 1);

		// Restricting the max to inactive columns guarantees no index is
		// added twice, which would double-count it in res_update and
		// add_dx_activeSet_to_x_activeSet.
		int maxcorr;
		int newIndex = decomp_select_index(corr, isActive, N, &maxcorr);
		printf("maxcorr = %d\n", maxcorr);
		if (newIndex < 0) {
			printf("decomp: residual is uncorrelated with every inactive column; stopping\n");
			break;
		}

		// Update Cholesky factorization of A_I (before newIndex is appended)
		updateChol(R_I, n, N, A, activeSet, activeSetSize, newIndex);

		activeSet[activeSetSize] = newIndex;
		isActive[newIndex] = 1;
		activeSetSize++;

		// Solve for the least squares update: (A_I'*A_I)dx_I = corr_I
		for (int j = 0; j < activeSetSize; j++)
			corr_activeSet[j] = corr[activeSet[j]];

		linsolve(R_I, corr_activeSet, activeSetSize, z, 1);
		linsolve(R_I, z, activeSetSize, dx_activeSet, 0);

		// Scatter dx_I into dx. Entries outside the active set are never read,
		// so dx need not be re-zeroed each iteration.
		for (int j = 0; j < activeSetSize; j++)
			dx[activeSet[j]] = dx_activeSet[j];

		add_dx_activeSet_to_x_activeSet(dx, x, activeSet);

		// compute new residual
		res_update(y, n, A, x, activeSet, res);

		int resnorm = vec_norm(res, n);
		printf("resnorm = %d\n", resnorm);
		printf("Iteration %d: Adding variable %d\n", k, newIndex);

		if (resnorm <= resnormStop)
			break;
		if (activeSetSize >= maxActive) {
			printf("decomp: active set reached its limit of %d columns; stopping\n", maxActive);
			break;
		}
	}

cleanup:
	decomp_free_int_matrix(R_I, sparsity);
	free(activeSet);
	free(isActive);
	free(res);
	free(corr);
	free(corr_activeSet);
	free(dx);
	free(dx_activeSet);
	free(z);
}
