// medsvm.cpp 
//
// REGULAR DOUBLE PRECISION ON INPUT DATA
//
//

#include "stdio.h"
#include "stdlib.h"
#include <sys/types.h>
#include <netdb.h>
#include <termios.h>
#include <unistd.h>
#include <sys/file.h>
#include <sys/ioctl.h>
#include <string.h>
#include <limits.h>
#include <fstream.h>
#include <strstream.h>
#include <math.h>
#include "time.h"
#include "assert.h"

#include "memory.h"

#include "quickmatrix.cpp"

// Pseudo-random double in [0,1] built from rand(). Note that the value 1.0
// itself can be produced (when rand() returns RAND_MAX), so callers that
// scale it to an index must clamp the result.
#ifndef DRAND
#define DRAND ((double)rand()/(double)RAND_MAX)
#endif

// Fallback definition of pi for platforms whose <math.h> does not provide it.
#ifndef M_PI
#define M_PI 3.1415927
#endif


/* For Keyboard Interface */
struct termios TermDescr;   // working copy of the terminal settings that main() modifies and installs
struct termios OrigTerm;    // terminal settings as read by main() before any modification
int TermMask;               // bit mask main() ANDs into TermDescr.c_lflag
int Key;                    // pending-byte count from ioctl(FIONREAD), then the character returned by getchar()
int ch;                     // not referenced in the code visible in this file



// Integrates over unobserved y's (indicated by a y value > 999)
// with a white Gaussian, or with a uniform distribution from 0 to 1
// if the -uni switch is used.

// This version forms an HMM transition matrix between the indices of
// the lambdas. This effectively does reinforcement learning with the
// reward equal to the delta J value...

// Square of x. The argument is evaluated twice, so it must be free of
// side effects.
#ifndef SQRD
#define SQRD(x) ((x)*(x))
#endif

// Smaller of x and y. Arguments are evaluated more than once. When x is NaN
// the comparison is false and y is returned.
#ifndef MIN
#define MIN(x,y)  ( (x)<(y)  ?   (x):(y) )
#endif

// Larger of x and y. Arguments are evaluated more than once. When x is NaN
// the comparison is false and y is returned.
#ifndef MAX
#define MAX(x,y)  ( (x)>(y)  ?   (x):(y) )
#endif


// MAXLINE sizes the text line buffer in main() and is the limit passed to
// fgets(). VTINY and TINY are not referenced in the code visible in this file.
#define MAXLINE  500000
#define VTINY     1e-9
#define TINY      1e-3




// Command Line Arguments
// (filled in by parse_args(); the initialisers below are the defaults for
// options that are not given)
char   trainFile[256];       // -train: data file, one row per line, label in the last column
char   extensionFile[256];   // -ext: suffix appended to the three output file names
char   lambdas_fname[256];   // "lambdas." + extension
char   model_fname[256];     // "model." + extension
char   outputs_fname[256];   // "outputs." + extension
char   lambdaFile[256];      // -lambda: optional file of initial lambdas, one per line
int    ntrain = -1;          // -ntrain: number of leading rows used for training; <= 0 means all rows
double c     = 30.0;         // -c: upper bound on every lambda (log-barrier parameter)
double negc  = 10.0;         // set to the (negative) -c value to request the growing-c schedule in main()
double sigma = 3.0;          // -sigma: weight of the (lambda . Y)^2 term and of the bias in getModel()
int    iters = -5;           // -iters: iteration cap; a negative value is replaced by 64000000 in main()
double tol   = 1e-8;         // margin keeping each lambda strictly below c

// Values accepted by -kernel (stored in kerneltype, dispatched in kernel()).
#define KERNEL_LINEAR   0
#define KERNEL_POLY     1
#define KERNEL_RBF      2
#define KERNEL_ERBF     3
#define KERNEL_SIGMOID  4

int kerneltype = KERNEL_LINEAR;
double p1 = 0.1;             // -p1: polynomial exponent, RBF/ERBF width, or sigmoid scale
double p2 = 3.0;             // -p2: sigmoid offset


// Data, Parameters and Costs
// (all allocated and filled by main())
int    N, D;        // N: number of training rows; D: number of feature columns
Matrix X, Xt;       // training rows (N x D) and remaining "test" rows
Matrix HMMi;        // per-coordinate history: coordinate visited next (or -1 for an empty slot)
Matrix HMMv;        // per-coordinate history: log of the objective gain obtained
Matrix HMMt;        // per-coordinate history: iteration number at which the entry was stored
Matrix K;           // kernel matrix over training rows followed by test rows
Matrix H;           // H[i][j] = K[i][j]*Y[i]*Y[j] over the training rows
Vector Y, Yt;       // labels of the training rows and of the test rows
Vector lambdas;     // dual variables being optimised
Vector Tll;         // receives H*lambdas when the initial objective is computed


// Sign of x as +1 or -1; zero maps to -1.
#ifndef SIGNT
#define SIGNT(x)  ( (x)>(0)  ?   (1):(-1) )
#endif


// Not referenced in the code visible in this file.
#define THUGE 1e48

// Evaluate the kernel selected by the global `kerneltype` between row R of X
// and row C of Xt. Both rows are read over X->cols columns.
//
//   KERNEL_LINEAR   <x,y>
//   KERNEL_POLY     (<x,y> + 1)^p1
//   KERNEL_RBF      exp(-|x-y|^2 / (2 p1^2))
//   KERNEL_ERBF     exp(-|x-y|   / (2 p1^2))
//   KERNEL_SIGMOID  tanh(p1 <x,y> / cols + p2)
//
// Any other kerneltype yields 0.0.
double
kernel(Matrix X, int R, Matrix Xt, int C)
{
  const int ncols = X->cols;
  double acc = 0.0;
  int k;

  // First pass: accumulate either the inner product or the squared distance.
  switch (kerneltype)
    {
    case KERNEL_LINEAR:
    case KERNEL_POLY:
    case KERNEL_SIGMOID:
      for (k = 0; k < ncols; k++)
        acc += X->d2[R][k]*Xt->d2[C][k];
      break;

    case KERNEL_RBF:
    case KERNEL_ERBF:
      for (k = 0; k < ncols; k++)
        acc += SQRD(X->d2[R][k]-Xt->d2[C][k]);
      break;

    default:
      return(acc);
    }

  // Second pass: apply the kernel-specific nonlinearity.
  switch (kerneltype)
    {
    case KERNEL_POLY:
      acc = pow(acc+1,p1);
      break;
    case KERNEL_RBF:
      acc = exp(-acc / (2.0*p1*p1));
      break;
    case KERNEL_ERBF:
      acc = exp(-sqrt(acc) / (2.0*p1*p1));
      break;
    case KERNEL_SIGMOID:
      acc = tanh(p1*acc/((double) ncols) + p2);
      break;
    default:
      // KERNEL_LINEAR: the inner product is the result.
      break;
    }

  return(acc);
}


// Print the usage banner for program `prog` on stderr and exit with status 2.
static void
medsvm_usage(const char *prog)
{
	fprintf(stderr,"Usage: %s -train <fname> -ntrain <int> -ext <fname> -kernel <int> -p1 <float> -p2 <float> -iters <int> -lambda <fname> -sigma <double> -c <double>\n",prog);
	fprintf(stderr,"MED for Kernel SVM Classification.\n");
	exit (2);
}

// Return the value that follows option argv[*i] and advance *i onto it.
// If the option is the last word on the command line, report it and exit
// through the usage banner instead of reading past the end of argv.
static char *
medsvm_option_value(int argc, char **argv, int *i)
{
	if (*i + 1 >= argc)
	{
		fprintf(stderr,"%s: option %s requires a value\n",argv[0],argv[*i]);
		medsvm_usage(argv[0]);
	}
	*i += 1;
	return argv[*i];
}

// Copy src into the dstSize-byte buffer dst; exit with status 2 if it does
// not fit (including the terminating NUL).
static void
medsvm_copy_arg(char *dst, size_t dstSize, const char *src, const char *prog)
{
	if (strlen(src) >= dstSize)
	{
		fprintf(stderr,"%s: argument too long (limit %lu characters): %s\n",prog,(unsigned long) (dstSize-1),src);
		exit (2);
	}
	strcpy(dst, src);
}

// Append ext to the string already held in the dstSize-byte buffer dst;
// exit with status 2 if the result would not fit.
static void
medsvm_append_ext(char *dst, size_t dstSize, const char *ext, const char *prog)
{
	if (strlen(dst) + strlen(ext) >= dstSize)
	{
		fprintf(stderr,"%s: extension too long for output name %s: %s\n",prog,dst,ext);
		exit (2);
	}
	strcat(dst, ext);
}

// Parse the command line into the file-level option globals.
//
// Recognised options (each takes exactly one value):
//   -train <fname>   data file                      -> trainFile
//   -ext <fname>     output-name suffix             -> extensionFile (default "ext")
//   -c <double>      lambda upper bound             -> c; a negative value is stored
//                    in negc and c restarts at 1 (main() then grows c step by step)
//   -sigma <double>                                 -> sigma
//   -ntrain <int>                                   -> ntrain
//   -p1 / -p2 <float>  kernel parameters            -> p1, p2
//   -lambda <fname>  initial lambdas                -> lambdaFile
//   -kernel <int>    one of the KERNEL_* ids        -> kerneltype
//   -iters <int>                                    -> iters
//
// After parsing, the output names become "lambdas."/"model."/"outputs."
// followed by the extension.
//
// An unknown option, an option with no value, or a string too long for its
// 256-byte buffer prints a message on stderr and exits with status 2.
void
parse_args (int argc, char **argv)
{
	int i;
	strcpy(trainFile,"");
	strcpy(extensionFile,"ext");
	strcpy(lambdas_fname,"lambdas.");
	strcpy(model_fname,"model.");
	strcpy(outputs_fname,"outputs.");
	strcpy(lambdaFile,"");

	for(i=1; i<argc; i++)
	{
		const char *opt = argv[i];

		if (!strcmp(opt, "-train"))
			medsvm_copy_arg(trainFile, sizeof(trainFile), medsvm_option_value(argc,argv,&i), argv[0]);
		else if (!strcmp(opt, "-ext"))
			medsvm_copy_arg(extensionFile, sizeof(extensionFile), medsvm_option_value(argc,argv,&i), argv[0]);
		else if (!strcmp(opt, "-c"))
		{
			c = atof(medsvm_option_value(argc,argv,&i));
			if (c<0)
			{
				// Negative -c: remember the target magnitude and start small.
				negc =  c;
				c    =  1;
			}
		}
		else if (!strcmp(opt, "-sigma"))
			sigma = atof(medsvm_option_value(argc,argv,&i));
		else if (!strcmp(opt, "-ntrain"))
			ntrain = atoi(medsvm_option_value(argc,argv,&i));
		else if (!strcmp(opt, "-p1"))
			p1    = atof(medsvm_option_value(argc,argv,&i));
		else if (!strcmp(opt, "-p2"))
			p2    = atof(medsvm_option_value(argc,argv,&i));
		else if (!strcmp(opt, "-lambda"))
			medsvm_copy_arg(lambdaFile, sizeof(lambdaFile), medsvm_option_value(argc,argv,&i), argv[0]);
		else if  (!strcmp(opt, "-kernel"))
			kerneltype = atoi(medsvm_option_value(argc,argv,&i));
		else if (!strcmp(opt, "-iters"))
			iters = atoi(medsvm_option_value(argc,argv,&i));
		else
			medsvm_usage(argv[0]);
	}

	medsvm_append_ext(lambdas_fname, sizeof(lambdas_fname), extensionFile, argv[0]);
	medsvm_append_ext(model_fname,   sizeof(model_fname),   extensionFile, argv[0]);
	medsvm_append_ext(outputs_fname, sizeof(outputs_fname), extensionFile, argv[0]);
}




// Return the bias term of the model: sigma times the sum over the entries
// of l of l[i]*Y[i], where Y is the global training-label vector and sigma
// the global option value.
double
getModel(Vector l)
{
	double bias = 0.0;
	int    idx  = 0;

	while (idx < l->len)
	{
		bias += sigma*(l->d[idx]*Y->d[idx]);
		idx++;
	}
	return(bias);
}

void
main(int argc, char *argv[]) 
{
	int     i,j,k;
	FILE    *fp;
	char    line[MAXLINE+1];
	char    b2[50];
	int     stop = -1;
	int     dump = -1;
	
	// Get Command Line Arguments
	parse_args(argc,argv);

  // Setup the Keyboard
  ioctl(STDIN_FILENO, TCGETS, &OrigTerm);
  ioctl(STDIN_FILENO, TCGETS, &TermDescr);
  TermMask=04401;
  TermDescr.c_lflag=TermDescr.c_lflag & TermMask;
  TermDescr.c_cc[VEOF]=1;
  TermDescr.c_cc[VEOL]=0;
  ioctl(STDIN_FILENO, TCSETS, &TermDescr);
	


	// Open the File to Determine its Dimensions
	if ((fp = fopen(trainFile,"r"))==NULL)
    {
		fprintf(stderr,"%s can't open %s for input.\n",argv[0],trainFile);
		exit(0);
    }
	fgets(line,MAXLINE,fp);
	istrstream ist(line,strlen(line));
	D = -1;
	while (ist>>b2) D++;
	
	printf("Dimensionality=%d\n",D);
	
	int NN = 1;
	while (fgets(line, MAXLINE, fp) != NULL) if (strlen(line)>2) NN++;
	fclose(fp);
	

	if (ntrain>0) N = ntrain;
	else N = NN;

	printf("Train Count = %d Total Count = %d\n",N,NN);
	


	// Allocate the Right Size Arrays
	X       = MatrixCreate(N,D);
	Xt      = MatrixCreate(NN-N,D);  // contains test data...
	Y       = VectorCreate(N);
	Yt      = VectorCreate(NN-N);
	K       = MatrixCreate(NN,NN);

	// Reread the File and Load it into Data Array
	fp = fopen(trainFile,"r");
	for (i=0; i<N; i++)
	{
		fgets(line, MAXLINE, fp);
		istrstream ist(line,strlen(line));
		for (j=0; j<D; j++)
		{
			ist>>b2;
			X->d2[i][j] = atof(b2);
		}
		ist>>b2;
		Y->d[i] = atof(b2);
	} 
	for (i=0; i<(NN-N); i++)
	{
		fgets(line, MAXLINE, fp);
		istrstream ist(line,strlen(line));
		for (j=0; j<D; j++)
		{
			ist>>b2;
			Xt->d2[i][j] = atof(b2);
		}
		ist>>b2;
		Yt->d[i] = atof(b2);
	}
	fclose(fp);
	
	// Form Kernel Matrix
	for (i=0; i<N; i++)
    {
		for (j=0; j<N; j++) K->d2[i][j] = kernel(X,i,X,j);
		for (j=0; j<(NN-N); j++) K->d2[i][j+N] = kernel(X,i,Xt,j);
    }
	for (i=0; i<(NN-N); i++)
    {
		for (j=0; j<N; j++) K->d2[i+N][j] = kernel(Xt,i,X,j);
		for (j=0; j<(NN-N); j++) K->d2[i+N][j+N] = kernel(Xt,i,Xt,j);
    }
	

	// Deallacoate the stuff we don't need and reallacoate more things
	MatrixFree(X);
	MatrixFree(Xt);
	HMMv    = MatrixCreate(N,20);
	HMMi    = MatrixCreate(N,20);
	HMMt    = MatrixCreate(N,20);
	VectorSet((Vector) HMMi,-1.0);
	VectorSet((Vector) HMMv,-1.0);
	VectorSet((Vector) HMMt,-1.0);
	H       = MatrixCreate(N,N);
	lambdas = VectorCreate(N);
	Tll     = VectorCreate(N);
	

	// Form H Matrix
	for (i=0; i<N; i++)
    {
		for (j=0; j<N; j++) H->d2[i][j] = K->d2[i][j]*Y->d[i]*Y->d[j];
    }

	// Must Learn Lambdas
	VectorSet(lambdas,0.0);
	if (strlen(lambdaFile)>2)
	{
		fp = fopen(lambdaFile,"r");
		for (i=0; i<N; i++)
		{
			fgets(line,MAXLINE,fp);
			lambdas->d[i] = atof(line);
		}
		fclose(fp);
	}

	
	// Setup	
	double J = 0.0;
	J   += VectorSum(lambdas);
	J   += -0.5*sigma*SQRD(VectorDot(lambdas,Y));
	MatrixVectorMultiply(Tll,H,lambdas);
	J   += -0.5*VectorDot(lambdas,Tll);
	for (i=0; i<N; i++) J += log(1.0-(1.0/c)*lambdas->d[i]);
    double sumylXll = VectorDot(Y,lambdas);
	fprintf(stderr,"Initial Objective = %e\n", J);

	// For axis reinforcement speedup
	double ACC_1  = 0.0;
	double ACC_t  = 0.0;
	double ACC_tt = 0.0;
	double ACC_l  = 0.0;
	double ACC_lt = 0.0;
	double PAR_M  = 0.0;
	double PAR_B  = 0.0;
	int    direction = 0;
	int    oldDir    = -100;
	int    itr       = 1;
	double tolerance = 1e-10;
	double increase  = 0.0;
	int    past      = HMMi->cols;
	double lastCheck = -2.0*fabs(J);
	if (iters < 0) iters = 64000000;

	while (stop<0)
	{ 
      // Grab from Keyboard
      ioctl(0,FIONREAD,&Key);
      if (Key>0)
	{
	  Key = getchar();
	if (Key=='q') stop = 1;
	if (Key=='s') dump = 1;
		}
		itr++;	
		if (itr>iters) stop = 1;
		
		
		// Pick a direction to optimize
		if ((DRAND>0.2) && (oldDir>=0))
		{
			// pick an HMM direction
			double accV    =  0.0;
			double sumV    =  0.0;
			double maxV    = -1e30;
			int    maxJ    =  0;
			double value   =  0.0;
			double decay   =  0.0;
			double curprob =  0.0;
			double totprob =  0.0;
			int    sample  = -100;
			for (j=0; j<past; j++) 
				if ((HMMi->d2[oldDir][j]>=0) && (HMMi->d2[oldDir][j]!=oldDir))
				{
					value  = (HMMv->d2[oldDir][j] - HMMt->d2[oldDir][j]*PAR_M - PAR_B);
					if (value>maxV)
					{
						maxV = value;
						maxJ = j;
					}
					sumV  += value;
					accV  += 1.0;
				}
				if (accV>0.0)
				{
					if ((maxV*accV-sumV)!=0.0)
					  decay = accV / (maxV*accV - sumV);
					else
					  decay = 1.0;
					for (j=0; j<past; j++) 
						if ((HMMi->d2[oldDir][j]>=0) && (HMMi->d2[oldDir][j]!=oldDir))
						{
							value    = (HMMv->d2[oldDir][j] - HMMt->d2[oldDir][j]*PAR_M - PAR_B);
							totprob += decay*exp(-decay*(maxV-value));
						}
						totprob *= DRAND;
						for (j=0; j<past; j++) 
							if ((HMMi->d2[oldDir][j]>=0) && (HMMi->d2[oldDir][j]!=oldDir))
							{
								value    = (HMMv->d2[oldDir][j] - HMMt->d2[oldDir][j]*PAR_M - PAR_B);
								curprob += decay*exp(-decay*(maxV-value));
								if ((sample<0) && (curprob>=totprob)) sample = j;
							}
							direction = (int) HMMi->d2[oldDir][sample];
							direction = (int) HMMi->d2[oldDir][maxJ];
							if (direction>=(N)) direction = N-1;
							if (direction<0) direction = 0;
				}
				else
				{
					// pick a random direction
					direction = (int) (DRAND*(N+1.0));
					if (direction>=N) direction = N-1;
					if (direction<0) direction = 0;
				}
		}
		else
		{
			// pick a random direction
			direction = (int) (DRAND*(N+1.0));
			if (direction>=N) direction = N-1;
			if (direction<0) direction = 0;
		}
		


		// Do analytic line search
		i = direction;
		double T1 = sigma;
		double Jtop = -3.0;
		double Jtry = -3.0;
		double Jold = -3.0;
		double l1   = lambdas->d[i];
		double io   = l1;
		double opt  = l1;
		double t2   = H->d2[i][i];
		double t3   = Y->d[i]*(sumylXll-Y->d[i]*l1);
		double t4   = -H->d2[i][i]*l1;
		for (j=0; j<N; j++) t4 += H->d2[i][j]*lambdas->d[j];
		Jold        = l1 + log(1.0-l1/c) - 0.5*T1*(SQRD(l1)+2.0*l1*t3)-0.5*(t2*SQRD(l1)+2.0*l1*t4);
		double d0   = -1.0 + c - c*T1*t3 - c*t4;
		double d1   = -1.0 - c*T1 - c*t2 + T1*t3 + t4;
		double d2   = t2 + T1;
		double p1   = sqrt(d1*d1-4.0*d0*d2);
		double sol1 = ((-d1+p1)/(2.0*d2));
		double sol2 = ((-d1-p1)/(2.0*d2));
		
		sol1 = MIN(sol1, (c-tol));
		sol1 = MAX(sol1, 0.0);
		sol2 = MIN(sol2, (c-tol));
		sol2 = MAX(sol2, 0.0);
		// KATHERINE: if (_isnan(sol1)) sol1 = 0.0;
		// KATHERINE: if (_isnan(sol2)) sol2 = 0.0;
		
		l1   = sol1;
		Jtry = l1 + log(1.0-l1/c) - 0.5*T1*(SQRD(l1)+2.0*l1*t3)-0.5*(t2*SQRD(l1)+2.0*l1*t4);  
		Jtry = Jtry - Jold;
		if (Jtry>Jtop)
		{
			Jtop = Jtry;
			io   = l1;
		}
		
		l1   = sol2;
		Jtry = l1 + log(1.0-l1/c) - 0.5*T1*(SQRD(l1)+2.0*l1*t3)-0.5*(t2*SQRD(l1)+2.0*l1*t4);   
		Jtry = Jtry - Jold;
		if (Jtry>Jtop)
		{
			Jtop = Jtry;
			io   = l1;
		}
		
		if (Jtop>0)
		{
			opt           = io;
			J             = J + Jtop;
			increase      = Jtop;
			sumylXll      = sumylXll - Y->d[i]*lambdas->d[i] + Y->d[i]*opt;
			lambdas->d[i] = opt;
		}
		else increase = 0.0;
		

		
		if ((itr%(MAX(N,500)))==0)
		{
			printf("%e %d\n",J,itr); 
			fflush(stdout); 
		}
		
		
		if ((itr%(10*N))==0)
		{
			if (((J-lastCheck)<=(0.00000002)*fabs(J)))
			{
				if (negc<0)
				{
					if (c>fabs(negc))
						stop = 1;
					else
					{
						c = c*1.5;
						dump = 1;
					}
				}
				else stop = 1;
			}
			lastCheck = J;
		}
		
		
		// Add to the history model the current data point. Replace the least valuable previous one.
		if (oldDir>=0)
		{
			int worst = -5;
			double t;
			double value = 0.0;
			double lowestvalue = 1e30;
			for (j=0; j<past; j++)
				if (HMMi->d2[oldDir][j]==direction) worst = j;
				if (worst<0) 
					for (j=0; j<past; j++)
						if (HMMi->d2[oldDir][j]<0) worst = j;
						if (worst<0)
						{
							for (j=0; j<past; j++)
							{
								value = HMMv->d2[oldDir][j] - HMMt->d2[oldDir][j]*PAR_M - PAR_B;
								if (value<lowestvalue)
								{
									worst       = j;
									lowestvalue = value;
								}
							}
						}
						if (worst<0) worst = 0;
						HMMi->d2[oldDir][worst] = direction;
						HMMv->d2[oldDir][worst] = log(increase+1e-10);
						HMMt->d2[oldDir][worst] = itr;
						t = itr;
						ACC_1  = t;
						ACC_t  = ACC_t + t;
						ACC_tt = ACC_tt + t*t;
						ACC_l  = ACC_l + log(increase+1e-10);
						ACC_lt = ACC_lt + log(increase+1e-10)*t;
						PAR_M  = (ACC_lt - ACC_t*ACC_l/ACC_1) / (ACC_tt - ACC_t*ACC_t/ACC_1);
						PAR_B  = (ACC_l - PAR_M*ACC_t) / ACC_1;
		}
		oldDir = direction;
		
		
		
		// Check if any output needs to be generated
		if ((stop==1) || (dump==1))
		{
			// Save the lambdas
			fp = fopen(lambdas_fname,"w");
			for (k=0; k<N; k++) fprintf(fp,"%e\n",lambdas->d[k]);
			fclose(fp);
			
			// Compute the actual final linear model and bias
			double bias;
			FILE *fplog = fopen("log","a");
			for (k=0; k<argc; k++) fprintf(fplog,"%s ",argv[k]);
			fprintf(fplog,"\n");
			fprintf(fplog,"J=%e i=%d ",J,itr);
			fprintf(stderr,"J=%e i=%d ",J,itr);
			bias = getModel(lambdas);
			fp = fopen(model_fname,"w");
			fprintf(fp,"%e\n",bias);
			fclose(fp);

			
			// Recompute the newoutputs and the prediction error
			double val;
			double err1 = 0.0;
			double err2 = 0.0; 
			fp = fopen(outputs_fname,"w");
			for (k=0; k<N; k++)
			{
				val = bias;
				for (j=0; j<N; j++) val += Y->d[j]*lambdas->d[j]*K->d2[j][k];
				fprintf(fp,"%e %e\n",val,Y->d[k]);
				if (SIGNT(val)!=SIGNT(Y->d[k])) err1++;
			}
			fprintf(fplog,"err1=%e ",err1);
			fprintf(stderr,"err1=%e ",err1);
			for (k=0; k<(NN-N); k++)
			{
				if (Yt->d[k]!=0)
				{
					val = bias;
					for (j=0; j<N; j++) val += Y->d[j]*lambdas->d[j]*K->d2[j][k+N];
					fprintf(fp,"%e %e\n",val,Yt->d[k]);
					if (SIGNT(val)!=SIGNT(Yt->d[k])) err2++;
				}
			}
			fprintf(fplog,"err2=%e\n",err2);
			fprintf(fplog,"%e %e\n",c,err2);
			fprintf(stderr,"c=%e err2=%e\n",c,err2);
			fclose(fplog);
			fclose(fp);
			
			dump = -1;
		}
    }
	
	
	// Free Arrays
	MatrixFree(H);
	MatrixFree(K);
	VectorFree(lambdas);
	VectorFree(Tll);
	VectorFree(Y);
	VectorFree(Yt);
	MatrixFree(HMMi);
	MatrixFree(HMMv);
	MatrixFree(HMMt);
}






