#include <signal.h>
#include <string.h>
#include <stdlib.h>
#include <stdio.h>
#include <time.h>
#include <sys/types.h>
#include <netinet/in.h>
#include <netdb.h>
#include <fstream.h>
#include <strstream.h>
#include <libc.h>
#include <termios.h>
#include <unistd.h>
#include <iostream.h>
#include <map.h>
#include <math.h>
extern "C" {
#include <tpm/matrix.h>
}

// This version forms an HMM transition matrix between the indices of
// the lambdas. This effectively does reinforcement learning with the
// reward equal to the delta J value...

// Small arithmetic helpers. Each is only defined when no earlier header
// has already provided a macro of the same name. All three are plain
// textual macros, so an argument with side effects is evaluated more
// than once.

// Square of x.
#ifndef SQRD
#define SQRD(x) ((x)*(x))
#endif


// Smaller of x and y.
#ifndef MIN
#define MIN(x,y)  ( (x)<(y)  ?   (x):(y) )
#endif

// Larger of x and y.
#ifndef MAX
#define MAX(x,y)  ( (x)>(y)  ?   (x):(y) )
#endif


// MAXLINE: size passed to fgets() for every text line read in main().
// VTINY:   offset added to 0.0 to form the lower lambda bound in main().
// TINY:    offset subtracted from c to form the upper lambda bound in main().
#define MAXLINE  100000
#define VTINY     1e-9
#define TINY      1e-3

// For Keyboard Interface
// TermDescr: terminal settings that main() reads, modifies and installs.
// OrigTerm:  terminal settings as read by main() before any modification.
// TermMask:  bit mask that main() ANDs into TermDescr.c_lflag.
// Key:       receives the FIONREAD result, then the character from getchar().
// ch:        not referenced by any code visible in this file.
struct termios TermDescr;
struct termios OrigTerm;
int TermMask;
int Key;
int ch;

// Command Line Arguments
// All of these are filled in by parse_args().
//   trainFile      -train  : data file read by main()
//   extensionFile  -ext    : suffix appended to the three output names below
//   lambdas_fname          : "lambdas." + extension
//   model_fname            : "model."   + extension
//   outputs_fname          : "outputs." + extension
//   lambdaFile     -lambda : optional file of initial lambdas (one per line)
//   ntrain         -ntrain : number of leading rows used for training;
//                            values <= 0 mean "use every row"
//   c              -c      : divisor inside logZgammaT(); main() uses c-TINY
//                            as the upper bound on each lambda
//   negc                   : set to the -c value when that value is negative
//                            (parse_args then resets c to 1)
//   sigma          -sigma  : scale of the squared bias term 0.5*sigma*lZb1^2
//   iters          -iters  : iteration count; negative means "not given"
char   trainFile[256];
char   extensionFile[256];
char   lambdas_fname[256];
char   model_fname[256];
char   outputs_fname[256];
char   lambdaFile[256];
int    ntrain = -1;
double p     = 1.0;    // probability the feature is on
double c     = 30.0;
double negc  = 10.0;
double sigma = 3.0;
int    iters = -5;

// Per-feature log-partition term.
//   p : mixing weight applied to the "on" component (weight 1-p goes to
//       the exponentially damped component)
//   W : accumulated weight for the feature
// Returns 0.5*W^2 + log(p + (1-p)*exp(-0.5*W^2)).
double
logZthetaI(double p, double W)
{
  const double halfSq  = 0.5*W*W;
  const double mixture = p + (1.0-p)*exp(-halfSq);

  return halfSq + log(mixture);
}

// Per-sample log-partition term.
//   c  : divisor applied to the multiplier inside the logarithm
//   lt : current multiplier value
// Returns lt + log(1 - lt/c).
double
logZgammaT(double c, double lt)
{
  const double fraction = lt/c;

  return lt + log(1.0 - fraction);
}

// Dot product of l and y, taken over the first l->len entries.
double
logZb1(Vector l, Vector y)
{
  double dot = 0.0;
  int    k   = 0;

  while (k < l->len)
    {
      dot += l->d[k]*y->d[k];
      k++;
    }
  return dot;
}

// Sum over rows r < l->len of l[r]*y[r]*X[r][index], i.e. the weighted
// sum of column `index` of X.
double
getW(int index, Vector l, Vector y, Matrix X)
{
  double acc = 0.0;
  int    row = 0;

  while (row < l->len)
    {
      acc += (l->d[row]*y->d[row])*X->d2[row][index];
      row++;
    }
  return acc;
}

// Data, Parameters and Costs
// All of these are allocated and filled by main(); the cached terms are
// maintained by logPartition() and logPartitionFast().
//   N, D        : number of training rows and number of feature columns
//   X, Xt       : training rows (N x D) and the remaining rows of the file
//   HMMi/v/t    : per-direction history tables (N x 20): stored successor
//                 index, stored log(increase), and iteration stamp
//   Y, Yt       : last column of each row, for X and Xt respectively
//   W           : per-feature weighted sums, see getW()
//   lambdas     : multipliers written to the lambdas file
//   oldLambdas  : working copy of the multipliers used by the line search
//   lZtheta     : cached logZthetaI(p, W[i]) for each feature
//   lZgamma     : cached logZgammaT(c, lambda[i]) for each row
//   lZb         : 0.5*sigma*lZb1*lZb1
//   lZb1        : sum of lambda[i]*Y[i]
//   lZgam       : minus the sum of lZgamma
//   lZt         : sum of lZtheta
//   linear      : per-feature model weights filled by getModel()
int    N, D;
Matrix X, Xt;
Matrix HMMi;
Matrix HMMv;
Matrix HMMt;
Vector Y, Yt;
Vector W;
Vector lambdas;
Vector oldLambdas;
Vector lZtheta;
Vector lZgamma;
double lZb;
double lZb1;
double lZgam;
double lZt;
Vector linear;


// Incrementally refresh the cached log-partition terms after a single
// multiplier changes.
//   INDEX : row whose multiplier moved
//   newl  : its new value
//   oldl  : its previous value
// Updates lZgamma[INDEX], lZgam, every W[i] and lZtheta[i], lZb1, lZb and
// lZt in place, then returns lZgam + lZt + lZb.
double
logPartitionFast(int INDEX, double newl, double oldl)
{
  // Change in lambda*Y for the affected row; shared by W and lZb1.
  const double shift = (newl-oldl)*Y->d[INDEX];
  int          col;

  // Swap the stale gamma term for the fresh one (lZgam holds minus the sum).
  lZgam             += lZgamma->d[INDEX];
  lZgamma->d[INDEX]  = logZgammaT(c,newl);
  lZgam             -= lZgamma->d[INDEX];

  // Every feature weight moves by shift times the row's feature value.
  for (col=0; col<X->cols; col++)
    {
      W->d[col]       += shift*X->d2[INDEX][col];
      lZtheta->d[col]  = logZthetaI(p,W->d[col]);
    }

  lZb1 += shift;
  lZb   = 0.5*sigma*lZb1*lZb1;
  lZt   = VectorSum(lZtheta);

  return lZgam+lZt+lZb;
}

// Recompute every cached log-partition term from scratch for the
// multipliers in l, and return lZgam + lZt + lZb.
// Overwrites lZgamma, W, lZtheta, lZb1, lZb, lZgam and lZt.
double
logPartition(Vector l)
{
  int row;
  int col;

  // Per-row gamma terms; lZgam stores minus their sum.
  for (row=0; row<l->len; row++)
    lZgamma->d[row] = logZgammaT(c,l->d[row]);
  lZgam = -VectorSum(lZgamma);

  // Per-feature weights and theta terms.
  for (col=0; col<X->cols; col++)
    {
      W->d[col]       = getW(col,l,Y,X);
      lZtheta->d[col] = logZthetaI(p,W->d[col]);
    }
  lZt = VectorSum(lZtheta);

  // Bias term.
  lZb1 = logZb1(l,Y);
  lZb  = 0.5*sigma*lZb1*lZb1;

  return lZgam+lZt+lZb;
}


// Two-valued sign: 1 when x is strictly positive, -1 otherwise (so zero
// maps to -1). Defined only if no earlier header supplied SIGNT.
#ifndef SIGNT
#define SIGNT(x)  ( (x)>(0)  ?   (1):(-1) )
#endif

// Compute the interval of steps that keeps a multiplier inside its box.
//   lnow   : current value of the multiplier
//   bot    : smallest allowed value
//   top    : largest allowed value
//   bounds : receives the result; bounds->d[0] is the most negative step
//            allowed (bot - lnow) and bounds->d[1] the most positive step
//            allowed (top - lnow)
//
// For bot < lnow < top this is exactly what the earlier sign-based
// bookkeeping produced. That version, however, left one side at +/-HUGE
// whenever lnow sat exactly on a bound or outside [bot, top] (both offsets
// then have the same sign, or one is zero), which handed the line search
// an effectively unbounded interval. Returning the two offsets directly
// keeps the search interval finite in every case.
void
getExtent(double lnow, double bot, double top, Vector bounds)
{
  bounds->d[0] = bot - lnow;
  bounds->d[1] = top - lnow;
}


// Turn the multipliers into the final linear model.
//   l      : multipliers
//   linear : receives one weight per feature, computed from the cached W
//            as p*W / (p + (1-p)*exp(-0.5*W^2))
// Returns the bias, the sum over rows of sigma*l[i]*Y[i].
// Reads the globals sigma, p, Y and W; W must already reflect l.
double
getModel(Vector l, Vector linear)
{
  double bias = 0.0;
  int    row  = 0;
  int    col  = 0;

  while (row < l->len)
    {
      bias += sigma*(l->d[row]*Y->d[row]);
      row++;
    }

  while (col < linear->len)
    {
      const double w = W->d[col];
      linear->d[col] = p*w / (p + (1.0-p)*( exp(-0.5*w*w) ));
      col++;
    }

  return bias;
}

// Print the usage banner to stderr and terminate with status 2.
static void
parse_args_usage(const char *prog)
{
  fprintf(stderr,"Usage: %s -train <fname> -ntrain <int> -ext <fname> -iters <int> -lambda <fname> -sigma <double> -p <double> -c <double>\n",prog);
  fprintf(stderr,"Feature Selection for Linear Classification.\n");
  exit (2);
}

// Return the value that follows option argv[*i] and advance *i onto it.
// Exits through the usage banner when the option is the last argument.
static const char *
parse_args_value(int argc, char **argv, int *i)
{
  if (*i + 1 >= argc)
    {
      fprintf(stderr,"%s: option %s needs a value\n",argv[0],argv[*i]);
      parse_args_usage(argv[0]);
    }
  *i += 1;
  return argv[*i];
}

// Copy src into the size-byte buffer dst; exit with status 2 instead of
// overflowing (or silently truncating a file name) when it does not fit.
static void
parse_args_copy(char *dst, size_t size, const char *src, const char *prog)
{
  if (strlen(src) >= size)
    {
      fprintf(stderr,"%s: argument too long: %s\n",prog,src);
      exit (2);
    }
  strcpy(dst,src);
}

// Append src to the string already in the size-byte buffer dst; exit with
// status 2 when the result would not fit.
static void
parse_args_append(char *dst, size_t size, const char *src, const char *prog)
{
  if (strlen(dst) + strlen(src) >= size)
    {
      fprintf(stderr,"%s: file name too long: %s%s\n",prog,dst,src);
      exit (2);
    }
  strcat(dst,src);
}

// Parse the command line into the option globals.
//
// Recognised options (each takes one value):
//   -train <fname>   -> trainFile
//   -ext <fname>     -> extensionFile (default "ext")
//   -lambda <fname>  -> lambdaFile
//   -ntrain <int>    -> ntrain
//   -iters <int>     -> iters
//   -sigma <double>  -> sigma
//   -p <double>      -> p (echoed to stderr)
//   -c <double>      -> c; a negative value is stored in negc and c is
//                       reset to 1
// Any other argument, or an option without its value, prints the usage
// banner and exits with status 2. Afterwards the output names are built
// as "lambdas.", "model." and "outputs." followed by the extension.
void
parse_args (int argc, char **argv)
{
  int i;

  strcpy(trainFile,"");
  strcpy(extensionFile,"ext");
  strcpy(lambdas_fname,"lambdas.");
  strcpy(model_fname,"model.");
  strcpy(outputs_fname,"outputs.");
  strcpy(lambdaFile,"");

  for(i=1; i<argc; i++)
    {
      if (!strcmp(argv[i], "-train"))
        parse_args_copy(trainFile, sizeof(trainFile),
                        parse_args_value(argc,argv,&i), argv[0]);
      else if (!strcmp(argv[i], "-p"))
        {
          p = atof(parse_args_value(argc,argv,&i));
          fprintf(stderr,"setting p=%e\n",p);
        }
      else if (!strcmp(argv[i], "-ext"))
        parse_args_copy(extensionFile, sizeof(extensionFile),
                        parse_args_value(argc,argv,&i), argv[0]);
      else if (!strcmp(argv[i], "-c"))
        {
          c = atof(parse_args_value(argc,argv,&i));
          if (c<0)
            {
              // A negative -c is remembered in negc; c itself restarts at 1.
              negc =  c;
              c    =  1;
            }
        }
      else if (!strcmp(argv[i], "-sigma"))
        sigma = atof(parse_args_value(argc,argv,&i));
      else if (!strcmp(argv[i], "-ntrain"))
        ntrain = atoi(parse_args_value(argc,argv,&i));
      else if (!strcmp(argv[i], "-lambda"))
        parse_args_copy(lambdaFile, sizeof(lambdaFile),
                        parse_args_value(argc,argv,&i), argv[0]);
      else if (!strcmp(argv[i], "-iters"))
        iters = atoi(parse_args_value(argc,argv,&i));
      else
        parse_args_usage(argv[0]);
    }

  parse_args_append(lambdas_fname, sizeof(lambdas_fname), extensionFile, argv[0]);
  parse_args_append(model_fname,   sizeof(model_fname),   extensionFile, argv[0]);
  parse_args_append(outputs_fname, sizeof(outputs_fname), extensionFile, argv[0]);
}


// State shared between main() and the line-search cost function costIF().
//   oldObjective  : objective value before the current line search
//   newObjective  : objective recomputed from scratch when results are dumped
//   oldObjectiveY : not referenced by any code visible in this file
//   lambdaIndex   : index of the multiplier currently being optimised
//   oldDir        : index optimised in the previous iteration; negative
//                   until the first iteration has completed
double oldObjective, newObjective, oldObjectiveY;
int    lambdaIndex = -100;
int    oldDir      = -100;




// One-dimensional cost handed to brent().
//   lscale : trial step added to the multiplier selected by lambdaIndex
// Starts from -oldObjective and, for each cached log-partition term,
// removes the stored value and adds the value the term would take with
// the multiplier moved from oldLambdas[lambdaIndex] to that value plus
// lscale. No global state is modified.
double
costIF(double lscale)
{
  const int    idx    = lambdaIndex;
  const double before = oldLambdas->d[idx];
  const double after  = before + lscale;
  const double step   = (after-before)*Y->d[idx];
  double       total  = -oldObjective;
  double       biasSum;
  double       trialW;
  int          col;

  // logZgamma: old term in, new term out
  total += (before + log(1.0-before/c));
  total -= (after + log(1.0-after/c));

  // logZb: replace the cached bias term by the shifted one
  total  -= lZb;
  biasSum = lZb1 - before*Y->d[idx] + after*Y->d[idx];
  total  += 0.5*sigma*(biasSum*biasSum);

  // logZtheta: replace the cached sum by the sum at the shifted weights
  total -= lZt;
  for (col=0; col<X->cols; col++)
    {
      trialW = W->d[col] + step*X->d2[idx][col];
      total += logZthetaI(p,trialW);
    }

  return total;
}


#define ITMAX 100
#define CGOLD 0.3819660
#define ZEPS 1.0e-10
#define SHFT(a,b,c,d) (a)=(b);(b)=(c);(c)=(d);
#define SIGN(a,b) ((b) >= 0.0 ? fabs(a) : -fabs(a))

// One-dimensional minimiser (Brent's method).
//   ax, cx : end points of the search interval, in either order
//   bx     : starting abscissa
//   f      : function to minimise
//   tol    : fractional tolerance on the abscissa
//   xmin   : receives the abscissa of the best point found
// Returns f at *xmin. Each step is either a parabolic fit through the
// three best points or, when that fit is rejected, a golden-section step
// into the larger half of the interval. At most ITMAX iterations are run;
// if the tolerance is not met by then the best point so far is returned
// without any error indication.
double
brent(double ax,double bx,double cx,double (*f)(double),double tol,double *xmin)
{
  double lo, hi;                 // current bracket
  double x, w, v;                // best, second best, previous second best
  double fx, fw, fv;             // function values at x, w, v
  double u, fu;                  // most recent trial point and its value
  double mid, tol1, tol2;
  double num, den, r, etemp;
  double d = 0.0;                // step taken this iteration
  double e = 0.0;                // step taken the iteration before last
  int    iter;
  int    golden;

  lo = (ax < cx) ? ax : cx;
  hi = (ax > cx) ? ax : cx;
  x  = bx;
  w  = bx;
  v  = bx;
  fx = (*f)(x);
  fw = fx;
  fv = fx;

  for (iter = 1; iter <= ITMAX; iter++)
    {
      mid  = 0.5*(lo+hi);
      tol1 = tol*fabs(x)+ZEPS;
      tol2 = 2.0*tol1;

      // Converged: the bracket is small enough around x.
      if (fabs(x-mid) <= (tol2-0.5*(hi-lo)))
        break;

      golden = 1;
      if (fabs(e) > tol1)
        {
          // Trial parabolic fit through x, w and v.
          r   = (x-w)*(fx-fv);
          den = (x-v)*(fx-fw);
          num = (x-v)*den-(x-w)*r;
          den = 2.0*(den-r);
          if (den > 0.0)
            num = -num;
          den   = fabs(den);
          etemp = e;
          e     = d;
          if (!(fabs(num) >= fabs(0.5*den*etemp) ||
                num <= den*(lo-x) || num >= den*(hi-x)))
            {
              // Fit accepted; keep the trial point off the bracket ends.
              d = num/den;
              u = x+d;
              if (u-lo < tol2 || hi-u < tol2)
                d = SIGN(tol1,mid-x);
              golden = 0;
            }
        }

      if (golden)
        {
          // Golden-section step into the larger segment.
          e = (x >= mid) ? lo-x : hi-x;
          d = CGOLD*e;
        }

      // Never evaluate closer than tol1 to x.
      if (fabs(d) >= tol1)
        u = x+d;
      else
        u = x+SIGN(tol1,d);
      fu = (*f)(u);

      if (fu <= fx)
        {
          // New best point: shrink the bracket and shift the history.
          if (u >= x)
            lo = x;
          else
            hi = x;
          SHFT(v,w,x,u)
          SHFT(fv,fw,fx,fu)
        }
      else
        {
          // Not better: the trial point becomes a new bracket end.
          if (u < x)
            lo = u;
          else
            hi = u;
          if (fu <= fw || w == x)
            {
              v  = w;
              w  = u;
              fv = fw;
              fw = fu;
            }
          else if (fu <= fv || v == x || v == w)
            {
              v  = u;
              fv = fu;
            }
        }
    }

  *xmin = x;
  return fx;
}
#undef ITMAX
#undef CGOLD
#undef ZEPS
#undef SHFT
#undef SIGN


// Running sums that main() uses to fit a straight line to log(increase)
// as a function of the iteration number t, and the fitted line itself.
//   ACC_1  : assigned the latest t (it is overwritten, not accumulated)
//   ACC_t  : sum of t
//   ACC_tt : sum of t*t
//   ACC_l  : sum of log(increase)
//   ACC_lt : sum of log(increase)*t
//   PAR_M  : fitted slope
//   PAR_B  : fitted intercept
double ACC_1  = 0.0;
double ACC_t  = 0.0;
double ACC_tt = 0.0;
double ACC_l  = 0.0;
double ACC_lt = 0.0;
double PAR_M  = 0.0;
double PAR_B  = 0.0;



// Reinstall the terminal settings that main() saved in OrigTerm.
static void
restoreTerminal(void)
{
  ioctl(STDIN_FILENO, TCSETS, &OrigTerm);
}

// Draw a coordinate index from drand48()*(n+1), clamped into [0, n-1].
static int
randomDirection(int n)
{
  int dir = (int)(drand48()*(n+1.0));
  if (dir>=n) dir = n-1;
  if (dir<0)  dir = 0;
  return dir;
}

// Program entry point.
//
// 1. Parses the command line and switches the terminal into the mode used
//    for keyboard polling ('q' stops, 's' dumps the current model).
// 2. Reads the training file twice: once to find the number of columns and
//    rows, once to load the first N rows into X/Y and the rest into Xt/Yt.
// 3. Repeatedly picks one multiplier (either the best-scoring successor of
//    the previous choice from the history tables, or a random one), runs a
//    bounded line search on it with brent(), and updates the cached terms.
// 4. Every 10*N iterations checks relative progress; on a stall it either
//    stops or, when -c was negative, grows c by 1.5x until c > |negc|.
// 5. On stop or dump, writes the lambdas, model and outputs files and
//    appends a summary to the file "log".
//
// Returns 0 on normal completion and exits with status 1 when an input
// file cannot be read.
int
main(int argc, char *argv[])
{
  int     i,j,k;
  FILE    *fp;
  char    line[MAXLINE+1];
  char    b2[50];
  int     stop = -1;
  int     dump = -1;

  // Get Command Line Arguments
  parse_args(argc,argv);

  // Setup the Keyboard. OrigTerm keeps the settings found at start-up so
  // they can be reinstalled on exit.
  // NOTE(review): with the Linux flag values, mask 04401 appears to clear
  // ICANON and ECHO (unbuffered, non-echoing input) - confirm on the
  // target platform.
  ioctl(STDIN_FILENO, TCGETS, &OrigTerm);
  ioctl(STDIN_FILENO, TCGETS, &TermDescr);
  TermMask=04401;
  TermDescr.c_lflag=TermDescr.c_lflag & TermMask;
  TermDescr.c_cc[VEOF]=1;
  TermDescr.c_cc[VEOL]=0;
  ioctl(STDIN_FILENO, TCSETS, &TermDescr);

  // Open the File to Determine its Dimensions
  if ((fp = fopen(trainFile,"r"))==NULL)
    {
      fprintf(stderr,"%s can't open %s for input.\n",argv[0],trainFile);
      restoreTerminal();
      exit(1);
    }
  if (fgets(line,MAXLINE,fp)==NULL)
    {
      fprintf(stderr,"%s: %s is empty.\n",argv[0],trainFile);
      fclose(fp);
      restoreTerminal();
      exit(1);
    }
  // The first line fixes the column count: every token but the last is a
  // feature, the last one is the label.
  istrstream ist(line,strlen(line));
  D = -1;
  while (ist>>b2) D++;
  if (D<1)
    {
      fprintf(stderr,"%s: first line of %s needs at least one feature and a label.\n",argv[0],trainFile);
      fclose(fp);
      restoreTerminal();
      exit(1);
    }
  // Count the remaining rows; lines of two characters or fewer are ignored.
  int NN = 1;
  while (fgets(line, MAXLINE, fp) != NULL)
    if (strlen(line)>2) NN++;
  fclose(fp);

  // -ntrain selects how many leading rows are used for training. A request
  // larger than the file would make the held-out size NN-N negative, so it
  // is clamped to the number of rows actually present.
  if ((ntrain>0) && (ntrain<=NN)) N = ntrain;
  else
    {
      if (ntrain>NN)
        fprintf(stderr,"%s: -ntrain %d exceeds the %d rows in %s; using all rows.\n",argv[0],ntrain,NN,trainFile);
      N = NN;
    }

  // Allocate the Right Size Arrays
  X           = MatrixCreate(N,D);
  HMMv        = MatrixCreate(N,20);
  HMMi        = MatrixCreate(N,20);
  HMMt        = MatrixCreate(N,20);
  VectorSet((Vector) HMMi,-1.0);
  VectorSet((Vector) HMMv,-1.0);
  VectorSet((Vector) HMMt,-1.0);
  Xt          = MatrixCreate(NN-N,D);  // contains test data...
  Y           = VectorCreate(N);
  Yt          = VectorCreate(NN-N);
  lambdas     = VectorCreate(N);
  oldLambdas  = VectorCreate(N);
  lZtheta     = VectorCreate(D);
  lZgamma     = VectorCreate(N);
  W           = VectorCreate(D);
  linear      = VectorCreate(D);

  // Reread the File and Load it into Data Array
  if ((fp = fopen(trainFile,"r"))==NULL)
    {
      fprintf(stderr,"%s can't open %s for input.\n",argv[0],trainFile);
      restoreTerminal();
      exit(1);
    }
  for (i=0; i<N; i++)
    {
      fgets(line, MAXLINE, fp);
      istrstream ist(line,strlen(line));
      for (j=0; j<D; j++)
        {
          ist>>b2;
          X->d2[i][j] = atof(b2);
        }
      ist>>b2;
      Y->d[i] = atof(b2);
    }
  for (i=0; i<(NN-N); i++)
    {
      fgets(line, MAXLINE, fp);
      istrstream ist(line,strlen(line));
      for (j=0; j<D; j++)
        {
          ist>>b2;
          Xt->d2[i][j] = atof(b2);
        }
      ist>>b2;
      Yt->d[i] = atof(b2);
    }
  fclose(fp);

  // Must Learn Lambdas
  VectorSet(lambdas,0.01);

  // Optional warm start: one lambda per line. Rows missing from a short
  // file keep the 0.01 default.
  if (strlen(lambdaFile)>2)
    {
      if ((fp = fopen(lambdaFile,"r"))==NULL)
        {
          fprintf(stderr,"%s can't open %s for input.\n",argv[0],lambdaFile);
          restoreTerminal();
          exit(1);
        }
      for (i=0; i<N; i++)
        {
          if (fgets(line,MAXLINE,fp)==NULL)
            {
              fprintf(stderr,"%s: %s has only %d of %d lambdas.\n",argv[0],lambdaFile,i,N);
              break;
            }
          lambdas->d[i] = atof(line);
        }
      fclose(fp);
    }

  VectorMove(oldLambdas,lambdas);
  VectorSet(W,0.0);
  VectorSet(lZtheta,0.0);
  VectorSet(lZgamma,0.0);
  lZb = 0.0;
  lZt = 0.0;

  if (strlen(lambdaFile)>2)
    {
      fprintf(stderr,"Initial Objective = %e\n", -logPartition(oldLambdas));
    }

  // Setup
  oldObjective     = -logPartition(lambdas);
  Vector extent    =  VectorCreate(2);
  int    direction = 0;
  int    itr       = 1;
  double tolerance = 1e-10;
  double increase  = 0.0;
  int    past      = HMMi->cols;
  double lastCheck = -2.0*fabs(oldObjective);

  // Honour -iters only when it was given explicitly; without it the run
  // ends on convergence or on 'q', as before.
  int    capIters  = (iters > 0);
  if (iters < 0) iters = 64000000;

  while (stop<0)
    {
      // Grab from Keyboard
      ioctl(0,FIONREAD,&Key);
      if (Key>0)
        {
          Key = getchar();
          if (Key=='q') stop = 1;
          if (Key=='s') dump = 1;
        }
      itr++;

      // Pick a direction to optimize
      if ((drand48()>0.2) && (oldDir>=0))
        {
          // pick an HMM direction: the stored successor of oldDir whose
          // detrended log-increase is largest
          double accV  =  0.0;
          double maxV  = -1e30;
          int    maxJ  =  0;
          double value =  0.0;
          for (j=0; j<past; j++)
            if ((HMMi->d2[oldDir][j]>=0) && (HMMi->d2[oldDir][j]!=oldDir))
              {
                value = (HMMv->d2[oldDir][j] - HMMt->d2[oldDir][j]*PAR_M - PAR_B);
                if (value>maxV)
                  {
                    maxV = value;
                    maxJ = j;
                  }
                accV += 1.0;
              }
          if (accV>0.0)
            {
              // Earlier revisions also drew a probabilistic sample here and
              // then immediately overrode it with the greedy choice below;
              // the draw could index the table at -100 when the sampling
              // weights were not finite. Only the random-number draw is
              // kept, so the sequence of later random choices is unchanged.
              (void) drand48();
              direction = (int) HMMi->d2[oldDir][maxJ];
              if (direction>=(N)) direction = N-1;
              if (direction<0) direction = 0;
            }
          else
            {
              // pick a random direction
              direction = randomDirection(N);
            }
        }
      else
        {
          // pick a random direction
          direction = randomDirection(N);
        }

      // Do line search over the step that keeps the lambda in [VTINY, c-TINY]
      double lnow = oldLambdas->d[direction];
      lambdaIndex = direction;
      getExtent(lnow,0.0+VTINY,c-TINY,extent);
      double ax   = extent->d[0];
      double cx   = extent->d[1];
      double bx   = 0.0;
      double xmin = 0.0;
      brent(ax,bx,cx, &costIF, tolerance, &xmin);

      // Update cost function and vector of lambdas
      increase = -oldObjective;
      int I = lambdaIndex;
      oldObjective = -logPartitionFast(I, oldLambdas->d[I]+xmin,oldLambdas->d[I]);
      oldLambdas->d[I] += xmin;
      lambdas->d[I]     = oldLambdas->d[I];
      increase += oldObjective;
      if ((itr%(MAX(N,500)))==0)
        {
          printf("%e %d\n",oldObjective,itr);
          fflush(stdout);
        }
      if ((itr%(10*N))==0)
        {
          if (((oldObjective-lastCheck)<=(0.0000001)*fabs(oldObjective)))
            {
              if (negc<0)
                {
                  // fabs, not abs: negc is a double and the integer abs()
                  // would truncate it before the comparison.
                  if (c>fabs(negc))
                    stop = 1;
                  else
                    {
                      c = c*1.5;
                      dump = 1;
                    }
                }
              else stop = 1;
            }
          lastCheck = oldObjective;
        }

      // Stop once the requested number of line searches has been done.
      if (capIters && (itr>iters)) stop = 1;

      // Add to the history model the current data point. Replace the least valuable previous one.
      if (oldDir>=0)
        {
          int    worst       = -5;
          double t;
          double value       = 0.0;
          double lowestvalue = 1e30;
          double reward      = log(increase+1e-10);

          // Prefer the slot that already holds this successor, then an
          // empty slot, then the slot with the lowest detrended value.
          for (j=0; j<past; j++)
            if (HMMi->d2[oldDir][j]==lambdaIndex) worst = j;
          if (worst<0)
            for (j=0; j<past; j++)
              if (HMMi->d2[oldDir][j]<0) worst = j;
          if (worst<0)
            {
              for (j=0; j<past; j++)
                {
                  value = HMMv->d2[oldDir][j] - HMMt->d2[oldDir][j]*PAR_M - PAR_B;
                  if (value<lowestvalue)
                    {
                      worst       = j;
                      lowestvalue = value;
                    }
                }
            }
          if (worst<0) worst = 0;
          HMMi->d2[oldDir][worst] = lambdaIndex;
          HMMv->d2[oldDir][worst] = reward;
          HMMt->d2[oldDir][worst] = itr;

          // Refit the linear trend of log(increase) against iteration.
          t = itr;
          ACC_1  = t;
          ACC_t  = ACC_t + t;
          ACC_tt = ACC_tt + t*t;
          ACC_l  = ACC_l + reward;
          ACC_lt = ACC_lt + reward*t;
          PAR_M  = (ACC_lt - ACC_t*ACC_l/ACC_1) / (ACC_tt - ACC_t*ACC_t/ACC_1);
          PAR_B  = (ACC_l - PAR_M*ACC_t) / ACC_1;
        }
      oldDir = lambdaIndex;

      // Check if any output needs to be generated
      if ((stop==1) || (dump==1))
        {
          // Save the lambdas and lambdasP
          fp = fopen(lambdas_fname,"w");
          if (fp==NULL)
            fprintf(stderr,"%s can't open %s for output.\n",argv[0],lambdas_fname);
          else
            {
              for (k=0; k<N; k++) fprintf(fp,"%e\n",lambdas->d[k]);
              fclose(fp);
            }

          // Compute the actual final linear model and bias. logPartition()
          // refreshes W, which getModel() reads.
          double bias;
          newObjective = -logPartition(lambdas);
          FILE *fplog = fopen("log","a");
          if (fplog==NULL)
            fprintf(stderr,"%s can't open log for output.\n",argv[0]);
          else
            {
              for (k=0; k<argc; k++) fprintf(fplog,"%s ",argv[k]);
              fprintf(fplog,"\n");
              fprintf(fplog,"J=%e i=%d ",newObjective,itr);
            }
          fprintf(stderr,"J=%e i=%d ",newObjective,itr);
          bias = getModel(lambdas,linear);
          fp = fopen(model_fname,"w");
          if (fp==NULL)
            fprintf(stderr,"%s can't open %s for output.\n",argv[0],model_fname);
          else
            {
              for (k=0; k<linear->len; k++) fprintf(fp,"%e\n",linear->d[k]);
              fprintf(fp,"%e\n",bias);
              fclose(fp);
            }

          // Recompute the newoutputs and the prediction error
          double val;
          double err1 = 0.0;
          double err2 = 0.0;
          fp = fopen(outputs_fname,"w");
          if (fp==NULL)
            fprintf(stderr,"%s can't open %s for output.\n",argv[0],outputs_fname);
          for (k=0; k<N; k++)
            {
              val = bias;
              for (j=0; j<D; j++) val += linear->d[j]*X->d2[k][j];
              if (fp!=NULL) fprintf(fp,"%e %e\n",val,Y->d[k]);
              if (SIGNT(val)!=SIGNT(Y->d[k])) err1++;
            }
          if (fplog!=NULL) fprintf(fplog,"err1=%e ",err1);
          fprintf(stderr,"err1=%e ",err1);
          for (k=0; k<(NN-N); k++)
            {
              // Held-out rows with a zero label are skipped.
              if (Yt->d[k]!=0)
                {
                  val = bias;
                  for (j=0; j<D; j++) val += linear->d[j]*Xt->d2[k][j];
                  if (fp!=NULL) fprintf(fp,"%e %e\n",val,Yt->d[k]);
                  if (SIGNT(val)!=SIGNT(Yt->d[k])) err2++;
                }
            }
          if (fplog!=NULL)
            {
              fprintf(fplog,"err2=%e\n",err2);
              fprintf(fplog,"%e %e %e\n",p,c,err2);
              fclose(fplog);
            }
          fprintf(stderr,"err2=%e\n",err2);
          if (fp!=NULL) fclose(fp);

          dump = -1;
        }
    }


  // Free Arrays
  VectorFree(extent);
  VectorFree(linear);
  VectorFree(lambdas);
  VectorFree(oldLambdas);
  VectorFree(Y);
  MatrixFree(X);
  VectorFree(lZtheta);
  VectorFree(lZgamma);
  VectorFree(W);
  MatrixFree(HMMi);
  MatrixFree(HMMv);
  MatrixFree(HMMt);

  // Hand the terminal back in the state it was found.
  restoreTerminal();
  return 0;
}

