// ----------------------------------------------------------------------------
//
// Copyright (C) Columbia University, 1994. All Rights Reserved.
// Sameer A. Nene, Shree K. Nayar, Hiroshi Murase
//
// See file LICENSE for details of software license agreement.
//
// ----------------------------------------------------------------------------
//
// matrix.C
//
// Author:               Sameer Nene
// Date:                 05/23/94
// Version:              1.0
// Modification History:
//   10/27/94: (Post Release). Made Matrix::getMaxEigenVectors more efficient.
//
// Bugs:
//   Matrix::element() performs no bounds checking, so an invalid (row, col)
//   pair reads or writes outside the element array and may crash.
//   Similarly, MatrixUtil::getMaxEigenVectors does not check whether the
//   requested number of eigenvectors is realizable.
//
//   MatrixUtil::getMaxEigenVectors works only for symmetric matrices. Note
//   that this is a limitation of the algorithm.
//
//   There are no out of memory checks.
//
// Classes:
//   Matrix
//   MatrixUtil
//
// Notes:
//   This module contains implementation of classes declared in matrix.h
//   
// ----------------------------------------------------------------------------

#include <iostream.h>
#include <limits.h>
#include <math.h>
#include <memory.h>
#include "errorscope.h"
#include "registry.h"
#include "persistent.h"
#include "vector.h"
#include "matrix.h"

// Protocol tag used by hasProtocol()/safeCast(): the variable stores its own
// address, which is unique to Matrix and therefore identifies the class.
void* Matrix::d_protocol = &d_protocol;

// Class name under which Matrix's factory function is registered.
char* Matrix::d_name_p = "Matrix";

// Static registrar object constructed with the class name and the factory;
// presumably its constructor enters createFunc into the registry.
Registrar Matrix::d_registrar(d_name_p, &createFunc);

// Factory handed to the registrar: heap-allocates an empty (0 x 0) Matrix
// and returns it through its Persistent base. The caller owns the object.
Persistent* Matrix::createFunc()
{
  Matrix *fresh = new Matrix();
  return fresh;
}

// Returns 1 if `protocol` is Matrix's protocol tag (d_protocol), 0 otherwise.
int Matrix::hasProtocol(void* protocol) const
{
  const int matches = (d_protocol == protocol);
  return matches;
}
  
// Default constructor: an empty 0 x 0 matrix that owns no element storage.
Matrix::Matrix() : d_rsize(0), d_csize(0), d_data_p(0)
{
}

// Constructs a row x col matrix. Storage for row * col floats is allocated
// but left uninitialized; the caller is expected to fill it.
Matrix::Matrix(int row, int col) : d_data_p(new float[row * col]),
d_rsize(row), d_csize(col)
{
}

// Copy constructor: allocates a new buffer and deep-copies m's dimensions
// and elements.
Matrix::Matrix(const Matrix &m) : d_rsize(m.d_rsize), d_csize(m.d_csize),
d_data_p(new float[m.d_rsize * m.d_csize])
{
  int count = d_rsize * d_csize;

  // A default-constructed source has d_data_p == 0. memcpy must never be
  // given a null pointer, even for a zero-byte copy.
  if(count > 0)
    memcpy(d_data_p, m.d_data_p, count * sizeof(float));
}

// Destructor: releases the element storage. For a default-constructed
// matrix the pointer is null, and delete[] of null is a no-op.
Matrix::~Matrix()
{
  float *doomed = d_data_p;
  delete[] doomed;
}

// Assignment: makes *this a deep copy of m. Self-assignment is a no-op.
//
// The existing buffer is reused whenever it already holds the right number
// of elements. Otherwise the replacement is allocated *before* the old
// buffer is released, so if the allocation throws, *this is left unchanged.
Matrix& Matrix::operator=(const Matrix &m)
{
  if(&m == this)
    return *this;

  int count = m.d_rsize * m.d_csize;

  if(count != d_rsize * d_csize) {
    float *data = new float[count];
    delete[] d_data_p;
    d_data_p = data;
  }
  d_rsize = m.d_rsize;
  d_csize = m.d_csize;

  // m.d_data_p is null for a default-constructed matrix. memcpy must never
  // be given a null pointer, even for a zero-byte copy.
  if(count > 0)
    memcpy(d_data_p, m.d_data_p, count * sizeof(float));

  return *this;
}

// Stores v, narrowed to float, at (row, col) of the row-major element array.
// No bounds checking is performed.
void Matrix::element(int row, int col, double v)
{
  float *cell = d_data_p + row * d_csize + col;
  *cell = v;
}

// In-place element-wise addition: adds each element of m to the matching
// element of *this. m must hold at least as many elements as *this (not
// checked).
Matrix& Matrix::operator+=(const Matrix &m)
{
  const float *src = m.d_data_p;
  float *dst = d_data_p;
  float *end = d_data_p + d_rsize * d_csize;

  while(dst != end)
    *(dst++) += *(src++);

  return *this;
}

// In-place element-wise subtraction: subtracts each element of m from the
// matching element of *this. m must hold at least as many elements as *this
// (not checked).
Matrix& Matrix::operator-=(const Matrix &m)
{
  const float *src = m.d_data_p;
  float *dst = d_data_p;
  float *end = d_data_p + d_rsize * d_csize;

  while(dst != end)
    *(dst++) -= *(src++);

  return *this;
}

// In-place matrix product: *this = (*this) * m.
//
// *this is d_rsize x d_csize, and m must be d_csize x m.d_csize (not
// checked). Afterwards *this has m.d_csize columns. Each entry is the sum of
// float products, accumulated in a double and stored back as float.
Matrix& Matrix::operator*=(const Matrix &m)
{
  int i, j, k;
  float *data = new float[d_rsize * m.d_csize], *c = data;
  const float *a;
  double t;

  // The result goes into a separate buffer because *this (and m, when
  // &m == this) is still being read until the loops finish.
  for(i = 0; i < d_rsize; ++i) {
    a = d_data_p + i * d_csize;                  // row i of *this
    for(j = 0; j < m.d_csize; ++j) {
      // Column j of m is addressed by index for every row of *this, so
      // each dot product starts again from row 0 of m.
      t = 0.;
      for(k = 0; k < d_csize; ++k)
        t += a[k] * m.d_data_p[k * m.d_csize + j];
      *(c++) = t;
    }
  }

  delete[] d_data_p;
  d_data_p = data;
  d_csize = m.d_csize;
  return *this;
}

// In-place scalar multiplication: scales every element of *this by d.
Matrix& Matrix::operator*=(double d)
{
  int n = d_rsize * d_csize;

  while(n-- > 0)
    d_data_p[n] *= d;

  return *this;
}

// Checked downcast: returns p as a Matrix* if it reports Matrix's protocol
// tag, or 0 otherwise. p itself must not be null.
Matrix* Matrix::safeCast(Persistent *p)
{
  if(!p -> hasProtocol(d_protocol))
    return 0;
  return (Matrix*)p;
}
	
// Returns the element at (row, col), widened to double. Not bounds-checked.
double Matrix::element(int row, int col) const
{
  const float *rowStart = d_data_p + row * d_csize;
  return rowStart[col];
}

// Number of rows in the matrix.
int Matrix::numRows() const
{
  const int rows = d_rsize;
  return rows;
}

// Number of columns in the matrix.
int Matrix::numCols() const
{
  const int cols = d_csize;
  return cols;
}

// Element-wise sum. Returns a new matrix with *this's shape; m must hold at
// least as many elements (not checked).
Matrix Matrix::operator+(const Matrix &m) const
{
  Matrix sum(d_rsize, d_csize);
  const float *lhs = d_data_p, *rhs = m.d_data_p;
  float *out = sum.d_data_p, *end = out + d_rsize * d_csize;

  while(out != end)
    *(out++) = *(lhs++) + *(rhs++);

  return sum;
}

// Element-wise difference (*this - m). Returns a new matrix with *this's
// shape; m must hold at least as many elements (not checked).
Matrix Matrix::operator-(const Matrix &m) const
{
  Matrix diff(d_rsize, d_csize);
  const float *lhs = d_data_p, *rhs = m.d_data_p;
  float *out = diff.d_data_p, *end = out + d_rsize * d_csize;

  while(out != end)
    *(out++) = *(lhs++) - *(rhs++);

  return diff;
}

// Matrix product: returns (*this) * m as a new d_rsize x m.d_csize matrix.
//
// m must be d_csize x m.d_csize (not checked). Each entry is the sum of
// float products, accumulated in a double and stored as float.
Matrix Matrix::operator*(const Matrix &m) const
{
  int i, j, k;
  const float *a;
  float *c;
  double t;
  Matrix m2(d_rsize, m.d_csize);

  c = m2.d_data_p;
  for(i = 0; i < d_rsize; ++i) {
    a = d_data_p + i * d_csize;                  // row i of *this
    for(j = 0; j < m.d_csize; ++j) {
      // Column j of m is addressed by index for every row of *this, so
      // each dot product starts again from row 0 of m.
      t = 0.;
      for(k = 0; k < d_csize; ++k)
        t += a[k] * m.d_data_p[k * m.d_csize + j];
      *(c++) = t;
    }
  }

  return m2;
}

// Scalar product: returns a copy of *this with every element scaled by d.
Matrix Matrix::operator*(double d) const
{
  Matrix scaled(d_rsize, d_csize);
  const float *src = d_data_p;
  float *dst = scaled.d_data_p, *end = dst + d_rsize * d_csize;

  while(dst != end)
    *(dst++) = *(src++) * d;

  return scaled;
}

// Matrix-vector product: returns the d_rsize-element vector (*this) * v.
// v must hold at least d_csize elements (not checked). Each entry is
// accumulated in a double before being stored.
Vector Matrix::operator*(const Vector &v) const
{
  int i, j;
  const float *p = d_data_p;
  double t;
  Vector v2(d_rsize);

  for(i = 0; i < d_rsize; ++i) {
    // Start from zero rather than from the first product, so that a matrix
    // with no columns yields zeros instead of reading past both buffers.
    t = 0.;
    for(j = 0; j < d_csize; ++j)
      t += *(p++) * v.d_data_p[j];
    v2.d_data_p[i] = t;
  }

  return v2;
}

// Scalar-on-the-left product: d * m is computed exactly as m * d, i.e. each
// element is multiplied by d.
Matrix operator*(double d, const Matrix &m)
{
  return m * d;
}

// Writes m as nested brackets, one bracketed group per row, with a space
// after every element and after every closing bracket. For example, a
// 2 x 2 matrix prints as "[[1 2 ] [3 4 ] ] ".
ostream& operator<<(ostream &o, const Matrix &m)
{
  const float *cell = m.d_data_p;

  o << '[';
  for(int row = 0; row < m.d_rsize; ++row) {
    o << '[';
    for(int col = 0; col < m.d_csize; ++col)
      o << *(cell++) << ' ';
    o << "] ";
  }
  return o << "] ";
}
      
// Reads a matrix written by put() from handle. The format is the row count
// and column count (native ints), followed by rows * cols native floats in
// row-major order.
//
// Returns FILE_READ_ERROR on a short read or a corrupt header (negative
// dimensions, or an element count that overflows int), and OK otherwise. A
// rejected header leaves *this untouched. A short element read can leave
// *this resized with partially filled contents.
ErrorScope::Error Matrix::get(FILE *handle)
{
  int rsize, csize;

  if(fread((char *)&rsize, sizeof(int), 1, handle) != 1)
    return FILE_READ_ERROR;

  if(fread((char *)&csize, sizeof(int), 1, handle) != 1)
    return FILE_READ_ERROR;

  // The header comes from a file and may be corrupt. Validate it before it
  // reaches operator new[].
  if(rsize < 0 || csize < 0 || (csize != 0 && rsize > INT_MAX / csize))
    return FILE_READ_ERROR;

  size_t count = size_t(rsize) * size_t(csize);

  if(rsize != d_rsize || csize != d_csize) {
    // Allocate before releasing, so *this stays valid if new[] throws.
    float *data = new float[count];
    delete[] d_data_p;
    d_data_p = data;
    d_rsize = rsize;
    d_csize = csize;
  }

  if(fread((char *)d_data_p, sizeof(float), count, handle) != count)
    return FILE_READ_ERROR;

  return OK;
}

// Writes the matrix to handle as raw binary: the row count and column count
// (native ints), followed by the row-major float elements. Stops at the
// first short write and returns FILE_WRITE_ERROR; returns OK otherwise.
ErrorScope::Error Matrix::put(FILE *handle) const
{
  int count = d_rsize * d_csize;

  if(fwrite((char *)&d_rsize, sizeof(int), 1, handle) != 1 ||
     fwrite((char *)&d_csize, sizeof(int), 1, handle) != 1 ||
     fwrite((char *)d_data_p, sizeof(float), count, handle) != count)
    return FILE_WRITE_ERROR;

  return OK;
}

// Returns the registered class name, "Matrix".
const char* Matrix::name() const
{
  const char *className = d_name_p;
  return className;
}

// Computes one eigenpair of *min for each element of *v. Each pass
// maximizes the Rayleigh quotient R(x) = (x.Ax)/(x.x) with a nonlinear
// conjugate-gradient iteration, then deflates *min before the next pass.
// Here "." is Vector*Vector, presumably the dot product.
//
// Parameters:
//   min - square matrix A; its numRows() sets the vector length. MODIFIED
//         in place: after each eigenpair (lamda, x) is found, the outer
//         product lamda * x * x^T is subtracted from it, so later passes see
//         the deflated matrix. Callers that still need A must pass a copy.
//   v   - the loop visits every element of *v (via VectAryIter). Each
//         element is overwritten with the eigenvector computed for it.
//   ev  - eigenvalue i is stored with ev->element(i, lamda). Presumably ev
//         must already hold at least as many elements as *v.
//
// As the file header notes, this is meant for symmetric matrices, and the
// requested number of eigenvectors is not validated.
//
// NOTE(review): the step-length formula has no guard for d == 0 or a
// negative discriminant. The stopping test takes log10(lamda), which is not
// finite for lamda <= 0, and converting such a value with int() is
// undefined behavior.
void MatrixUtil::getMaxEigenVectors(Matrix *min, VectAry *v, Vector *ev)
{
  // j is the dimension of A. cnt is assigned below but never read.
  int i, j = min -> numRows(), k, l, cnt;
  double b, c, d, lamda, lamda_old, x_sq, pa, pb, pc, pd;
  void *t;   // scratch pointer used to swap the r/r_old and p/p_old buffers
  float *e;
  Vector *p_old, *p, *r_old, *r, *x;
  Vector p1, p2, r1, r2, mx, x0(j);

  // Starting vector shared by every pass: (-1, 1, -1, 1, ...).
  for(i = 0, b = -1.; i < j; ++i) {
    x0.element(i, b);
    b = -b;
  }

  VectAryIter it(*v);
  for(i = 0; it; ++it, ++i) {
    // Iterate directly in the caller's vector slot.
    it() = x0;
    x = &it();

    // First step. lamda is the Rayleigh quotient at x. The initial search
    // direction p1 (= r1) is (lamda*x - Ax)/(x.x), which is minus half the
    // gradient of R at x.
    mx = *min * *x;
    x_sq = *x * *x;
    lamda = (*x * mx) / x_sq;
    p1 = r1 = (lamda * *x - mx) / x_sq;
    // Exact line search along p: R(x + alpha*p) is stationary where
    //   d*alpha^2 + b*alpha + c = 0,
    // with pa = p.Ax, pb = p.Ap, pc = p.x, pd = p.p (each divided by x.x).
    // The code always takes the (-b - sqrt(...)) root, presumably the
    // maximizer.
    pa = (p1 * mx) / x_sq;
    pb = (p1 * (*min * p1)) / x_sq;
    pc = (p1 * *x) / x_sq;
    pd = (p1 * p1) / x_sq;
    d = pb * pc - pa * pd;
    b = pb - lamda * pd;
    c = pa - lamda * pc;
    *x += ((-b - sqrt(b * b - 4. * c * d)) / (2. * d)) * p1;
    p_old = &p1;
    p = &p2;
    r_old = &r1;
    r = &r2;
    cnt = 1;
    do {
      // Recompute lamda and the residual r at the new x. Then form the next
      // conjugate direction using the Fletcher-Reeves coefficient
      // (r.r)/(r_old.r_old).
      lamda_old = lamda;
      mx = *min * *x;
      x_sq = *x * *x;
      lamda = (*x * mx) / x_sq;
      *r = (lamda * *x - mx) / x_sq;
      *p = *r + ((*r * *r) / (*r_old * *r_old)) * *p_old;
      // Same line search as the first step, now along *p.
      pa = (*p * mx) / x_sq;
      pb = (*p * (*min * *p)) / x_sq;
      pc = (*p * *x) / x_sq;
      pd = (*p * *p) / x_sq;
      d = pb * pc - pa * pd;
      b = pb - lamda * pd;
      c = pa - lamda * pc;
      *x += ((-b - sqrt(b * b - 4. * c * d)) / (2. * d)) * *p;
      // Swap buffers: the current r and p become r_old and p_old.
      t = r_old;
      r_old = r;
      r = (Vector *)t;
      t = p_old;
      p_old = p;
      p = (Vector *)t;
    }
    // Stop once lamda grows by less than 1e-7 * 10^int(log10(lamda)),
    // roughly 1e-7 of its order of magnitude, or stops growing at all.
    while(lamda - lamda_old >= 1e-7 * pow(10., int(log10(lamda))));
    ev -> element(i, lamda);
    // Divide x by (*x)(). This is presumably its Euclidean norm, which would
    // give the unit vector that the deflation below relies on.
    *x /= (*x)();

    // Deflation: A[k][l] -= lamda * x[k] * x[l], written directly into A's
    // row-major element storage.
    e = min -> d_data_p;
    for(k = 0; k < min -> d_rsize; ++k) {
      b = x -> element(k) * lamda;
      for(l = 0; l < min -> d_csize; ++l)
	*(e++) -= b * x -> element(l);
    }
  }
}
