#include "BLAS/blas_sparse.h"
#include "BLAS/nist_spblas.cc"

#include <cassert>
#include <math.h>

// Simple dense 4 by 4 matrix of doubles; used below to hold the local
// stiffness matrix computed for a single edge stencil (four vertices).
// a[i][j] refers to row i, column j
typedef double LocalStiffnessMatrix[4][4];

//         x2
//         /\
//        /  \
//     e1/    \e3
//      /  t0  \
//     /        \
//    /    e0    \
//  x0------------x1
//    \          /
//     \   t1   /
//      \      /
//     e2\    /e4
//        \  /
//         \/
//         x3
//
// Edge orientation: e0,e1,e2 point away from x0
//                      e3,e4 point away from x1


// Quadratic bending model for a triangle mesh.
//
// Keeps a handle to a global sparse stiffness matrix, assembled from the
// interior edges of the mesh, and uses it to evaluate elastic and damping
// bending forces as sparse matrix-vector products.
class QuadraticBending
{
public:

  // Stores the mesh pointer (the mesh is not copied) and marks the
  // stiffness matrix as "not built yet" by setting the handle Q to -1.
  // AddForces() checks for that sentinel and builds the matrix on demand.
  QuadraticBending(TriangleMesh *_mesh)
    : Q(-1),
      mesh(_mesh)
  {}

  // Assemble the global stiffness matrix from the interior mesh edges and
  // store its handle in Q.
  void BuildGlobalStiffness();

  // Compute elastic forces (scaled by elasticStiffness, applied to vertex
  // positions) and damping forces (scaled by dampingCoef, applied to
  // vertex velocities).
  void AddForces(double elasticStiffness, double dampingCoef);

private:

  // Cotangent of the angle between v and w.
  double cotTheta(const Vector3d v, const Vector3d w);

  // Fill Q with the 4 by 4 local stiffness matrix of the stencil formed by
  // the edge x0-x1 and the two opposite vertices x2 and x3 (see the
  // diagram above this class).
  void ComputeLocalStiffness(const Vector3d x0, const Vector3d x1,
                             const Vector3d x2, const Vector3d x3,
                             LocalStiffnessMatrix &Q);

  // Handle of the global stiffness matrix; negative until the matrix has
  // been built.
  // NOTE(review): nothing in this class releases the handle - confirm who
  // is expected to call BLAS_usds on it, given that Q is private.
  blas_sparse_matrix Q;

  // Mesh whose vertices and edges define the stiffness matrix.
  TriangleMesh *mesh;
};


// Returns the cotangent of the angle theta between v and w, evaluated as
// dot(v,w) / |cross(v,w)|.  Assuming dot and cross are the usual vector
// products, the numerator is |v||w|cos(theta) and the denominator is
// |v||w|sin(theta), so the lengths cancel and no normalization is needed.
//
// Debug builds assert that both vectors have finite, strictly positive
// length.  A zero cross product (parallel vectors) is not checked here,
// so the division can produce a non-finite result in that case.
inline
double QuadraticBending::cotTheta(const Vector3d v, const Vector3d w)
{
  // input sanity checks (compiled out when NDEBUG is defined)
  assert(finite(v.length()));
  assert(finite(w.length()));
  assert(v.length() > 0);
  assert(w.length() > 0);

  // both quantities carry the same |v||w| factor
  const double scaledCos = dot(v, w);
  const double scaledSin = cross(v, w).length();

  return scaledCos / scaledSin;
}


// Builds the 4 by 4 local stiffness matrix Q(e0) for the diamond stencil
// drawn at the top of this file: e0 runs from x0 to x1, and x2, x3 are the
// vertices opposite e0 in the two adjacent triangles.
//
// The result is the symmetric matrix coef * K0 * K0^T, where K0 holds sums
// of cotangents of the angles adjacent to e0 and
// coef = -3 / (2 * (A0 + A1)), with A0 and A1 computed as half the length
// of the cross product of the edge vectors of each triangle.
// Rows and columns of Q follow the argument order x0, x1, x2, x3.
inline
void QuadraticBending::ComputeLocalStiffness(const Vector3d x0, const Vector3d x1,
                                             const Vector3d x2, const Vector3d x3,
                                             LocalStiffnessMatrix &Q)
{
  // edge vectors, oriented as in the diagram
  const Vector3d e0 = x1 - x0;
  const Vector3d e1 = x2 - x0;
  const Vector3d e2 = x3 - x0;
  const Vector3d e3 = x2 - x1;
  const Vector3d e4 = x3 - x1;

  // cotangents of the angles between e0 and each neighbouring edge;
  // e0 is reversed for the two angles located at x1
  const double c01 = cotTheta( e0, e1);
  const double c02 = cotTheta( e0, e2);
  const double c03 = cotTheta(-e0, e3);
  const double c04 = cotTheta(-e0, e4);

  // areas of the two triangles sharing e0
  const double A0 = 0.5 * cross(e0, e1).length();
  const double A1 = 0.5 * cross(e0, e2).length();

  const double totalArea = A0 + A1;
  const double coef = -3. / (2. * totalArea);

  // a degenerate stencil (zero area, collinear edges) shows up here
  assert(finite(coef));
  assert(finite(c01));
  assert(finite(c02));
  assert(finite(c03));
  assert(finite(c04));

  // one entry per stencil vertex, in the order x0, x1, x2, x3
  const double K0[] = {c03 + c04, c01 + c02, -c01 - c03, -c02 - c04};

  // Q = coef * (outer product of K0 with itself).  Only the lower triangle
  // (diagonal included) is evaluated; each value is mirrored across the
  // diagonal so Q is exactly symmetric.
  for (int row = 0; row < 4; ++row) {
    const double scaledRow = coef * K0[row];
    for (int col = 0; col <= row; ++col) {
      const double entry = scaledRow * K0[col];
      Q[row][col] = entry;
      Q[col][row] = entry;
    }
  }
}



// compute global stiffness matrix by iterating over interior mesh
// edges and assembling the local stiffness matrices
// WARNING: if this routine is called more than once, it will allocate
// memory for a sparse matrix more than once!
void QuadraticBending::BuildGlobalStiffness(void)
{
  LocalStiffnessMatrix localQ;

  const unsigned N = mesh->Vertices.size();

  // create an N by N sparse matrix
  // this will give Q a nonnegative number (the matrix ID)
  Q = BLAS_duscr_begin(N,N);
  // at program end, use BLAS_usds(Q) to delete this matrix

  // For each edge...
  for (unsigned ei=0; ei < mesh->Edges.size(); ++ei) {
    Edge &e = *mesh->Edges[ei];
    if (mesh->isBoundaryEdge(&e)) continue;

    // grab the four vertices around edge, as shown in the diagram at
    // the top of this file
    int fromVertexIdx, toVertexIdx, leftApexIdx, rightApexIdx;
    e.GetDiamondVertices(fromVertexIdx, toVertexIdx, leftApexIdx, rightApexIdx);

    // For each *interior* mesh edge, obtain vertex indices for
    // diamond-shaped stencil
    int idx[] = {fromVertexIdx, // get index to tail of edge
                 toVertexIdx,   // get index to head of edge
                 leftApexIdx,   // opposite vertex on left-side  triangle
                 rightApexIdx}; // opposite vertex on right-side triangle

    // Build local stiffness matrix
    ComputeLocalStiffness(*meshVertices[idx[0]],
                          *meshVertices[idx[1]],
                          *meshVertices[idx[2]],
                          *meshVertices[idx[3]],
                          localQ);

    // Stick local matrix into global matrix
    // this is not implemented in my BLAS:
    //   BLAS_duscr_insert_clique(Q, 4, 4, &localQ[0][0], 1, 1, idx, idx);
    // so I wrote this code instead:
    for (int i=0; i<4; ++i) {
      for (int j=0; j<4; ++j) {
        BLAS_duscr_insert_entry(Q, localQ[i][j], idx[i], idx[j]);
      }
    }
  }

  // done adding entries
  BLAS_duscr_end(Q);
}


inline
void QuadraticBending::AddForces(double elasticStiffness, double dampingCoef)
{
  // BuildGlobalStiffness is called only once
  if (Q<0) BuildGlobalStiffness();

  const unsigned N = mesh->Vertices.size();

  // mesh->Position is the array of vertex positions,
  // ordered as [ [x1 y1 z1] [x2 y2 z2] [x3 y3 z3] ... ]

  // inputs are arrays of vertex positions (length N)
  const double *x = &(mesh->Position[0][0]);
  const double *y = &(mesh->Position[0][1]);
  const double *z = &(mesh->Position[0][2]);

  // likewise for velocities
  const double *vx = &(mesh->Velocity[0][0]);
  const double *vy = &(mesh->Velocity[0][1]);
  const double *vz = &(mesh->Velocity[0][2]);

  // array of elastic forces (length N)
  std::vector<Vector3d> fe(N);
  double *fex = &fe[0][0];
  double *fey = &fe[0][1];
  double *fez = &fe[0][2];

  // array of damping forces (length N)
  std::vector<Vector3d> fd(N);
  double *fdx = &fd[0][0];
  double *fdy = &fd[0][1];
  double *fdz = &fd[0][2];

  //compute matrix vector products

  // fex += ke*Q*x, fey += ke*Q*y, fez += ke*Q*z
  // where ke = elastic stiffness
  BLAS_dusmv(blas_no_trans, elasticStiffness, Q, x, 3, fex, 3);
  BLAS_dusmv(blas_no_trans, elasticStiffness, Q, y, 3, fey, 3);
  BLAS_dusmv(blas_no_trans, elasticStiffness, Q, z, 3, fez, 3);

  // fdx += kd*Q*vx, fdy += kd*Q*vy, fdz += kd*Q*vz
  // where ke = elastic stiffness
  BLAS_dusmv(blas_no_trans, dampingCoef, Q, vx, 3, fdx, 3);
  BLAS_dusmv(blas_no_trans, dampingCoef, Q, vy, 3, fdy, 3);
  BLAS_dusmv(blas_no_trans, dampingCoef, Q, vz, 3, fdz, 3);

  /*
    now write some code here to apply the forces, fe and fd, to the
    physical system
  */
}
