#include "keypoints.h"
#include "im_utils.h"
#include "lin_alg.h"
#include "sift.h"
#include "sift_utils.h"
#include "vga_ball.h"
#include "vga_ball_utils.h"

#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <float.h>

// Returns the contrast of a localized keypoint: the DoG value at the sample
// pixel (kp->x, kp->y) of DoG diffs[kp->DoG_idx], plus half the dot product of
// kp->jacobian with kp->offset (both 3-vectors set by localize_kp).
// Requires kp->jacobian and kp->offset to be valid pointers.
double compute_kp_contrast(Kp *kp, DoG **diffs) {
  DoG *dog = diffs[kp->DoG_idx];
  int flat_idx = kp->y * dog->width + kp->x;

  double grad_dot_offset;
  dot(kp->jacobian, kp->offset, &grad_dot_offset, 1, 3, 1);

  return dog->array[flat_idx] + 0.5 * grad_dot_offset;
}

// Returns the index of the smallest element of arr[0..size-1].
// Ties resolve to the first occurrence; NaN elements are never selected.
// Returns 0 when size <= 0 or when no element compares below +infinity.
// The sentinel is +INFINITY rather than a magic finite constant, so inputs of
// any magnitude are handled correctly.
int get_argmin(double *arr, int size) {
  int argmin = 0;
  double min = INFINITY;
  for (int idx = 0; idx < size; idx++) {
    if (arr[idx] < min) {
      argmin = idx;
      min = arr[idx];
    }
  }
  return argmin;
}

// Returns the index of the largest element of arr[0..size-1].
// Ties resolve to the first occurrence; NaN elements are never selected.
// Returns 0 when size <= 0 or when no element compares above -infinity.
// The sentinel is -INFINITY rather than a magic finite constant, so inputs of
// any magnitude are handled correctly.
int get_argmax(double *arr, int size) {
  int argmax = 0;
  double max = -INFINITY;
  for (int idx = 0; idx < size; idx++) {
    if (arr[idx] > max) {
      argmax = idx;
      max = arr[idx];
    }
  }
  return argmax;
}

// find the 3x3 neighborhood around pix_idx. pix_idx is indexing into a 1-D
// array, we must find the 2D 3x3 window.
// Copies the 3x3 window centred on flat index pix_idx of cur_DoG->array into
// neighborhood[0..8], in row-major order (top-left first, centre at [4]).
// The caller must ensure pix_idx is not on any image border; no bounds
// checking is done here.
void get_neighborhood(DoG *cur_DoG, double *neighborhood, int pix_idx) {
  const int w = cur_DoG->width;
  const int centre_row = pix_idx / w;
  const int centre_col = pix_idx % w;

  int out = 0;
  for (int dr = -1; dr <= 1; ++dr) {
    const int row_start = (centre_row + dr) * w + centre_col;
    for (int dc = -1; dc <= 1; ++dc) {
      neighborhood[out++] = cur_DoG->array[row_start + dc];
    }
  }
}

// finds the local min maxima for each DoG w.r.t its DoG neighboring pixels in
// different blurs
// Finds candidate keypoints in one octave: pixels that are the minimum or
// maximum of their 3x3x3 neighbourhood (3x3 in the DoG itself plus 3x3 in the
// DoGs one blur level below and above).
//
// DoGs is laid out octave-major with NUM_BLURS - 1 DoGs per octave. Only the
// interior DoGs (blur_idx 1 .. NUM_BLURS - 3) have a neighbour on both sides
// in scale. Only interior pixels (1 .. width-2, 1 .. height-2) are tested, so
// every 3x3 window lies inside the image rather than wrapping into the
// adjacent row.
//
// Each candidate is heap-allocated (zeroed, so its jacobian/hessian/offset/
// kp_precise pointers start out NULL on conventional platforms), marked
// is_valid = 1, and appended to keypoints->kp_list. The list grows by
// MAX_KEYPOINTS entries whenever it is full. On allocation failure an error is
// printed to stderr and the scan stops; keypoints collected so far stay valid.
void get_candidate_keypoints_for_octave(DoG **DoGs, Keypoints *keypoints,
                                        int octave_num) {
  for (int blur_idx = 1; blur_idx < NUM_BLURS - 1 - 1; blur_idx++) {
    int cur_DoG_idx = octave_num * (NUM_BLURS - 1) + blur_idx;
    DoG *left_DoG = DoGs[cur_DoG_idx - 1];
    DoG *cur_DoG = DoGs[cur_DoG_idx];
    DoG *right_DoG = DoGs[cur_DoG_idx + 1];

    int im_height = cur_DoG->height;
    int im_width = cur_DoG->width;

    for (int row = 1; row < im_height - 1; row++) {
      for (int col = 1; col < im_width - 1; col++) {
        int pix_idx = row * im_width + col;

        double neighborhood[27] = {0};
        get_neighborhood(left_DoG, neighborhood, pix_idx);
        get_neighborhood(cur_DoG, neighborhood + 9, pix_idx);
        get_neighborhood(right_DoG, neighborhood + 18, pix_idx);

        // Index 13 (= 9 + 4) is the centre pixel of the current DoG.
        int argmax = get_argmax(neighborhood, 27);
        int argmin = get_argmin(neighborhood, 27);
        if (argmin != 13 && argmax != 13) {
          continue;
        }

        // Grow *before* writing so kp_list[count] is always in bounds
        // (assumes kp_list holds _max_capacity entries — allocated elsewhere).
        if (keypoints->count >= keypoints->_max_capacity) {
          printf("[DEBUG] MAX KEYPOINTS EXCEEDED, REALLOCING!\n");
          Kp **grown = realloc(keypoints->kp_list,
                               (size_t)(keypoints->_max_capacity + MAX_KEYPOINTS) *
                                   sizeof *grown);
          if (grown == NULL) {
            // Keep the old list intact; it is still owned by keypoints.
            fprintf(stderr, "Failed to grow keypoint list; "
                            "dropping remaining candidates\n");
            return;
          }
          keypoints->kp_list = grown;
          keypoints->_max_capacity += MAX_KEYPOINTS;
          printf("[DEBUG] MAX KEYPOINTS NOW EQUALS %d\n", keypoints->_max_capacity);
        }

        Kp *cur_kp = calloc(1, sizeof *cur_kp);
        if (cur_kp == NULL) {
          fprintf(stderr, "Failed to allocate keypoint; "
                          "dropping remaining candidates\n");
          return;
        }
        cur_kp->DoG = cur_DoG;
        cur_kp->DoG_idx = cur_DoG_idx;
        cur_kp->x = col;
        cur_kp->y = row;
        // Every candidate starts valid; localization may invalidate it later.
        cur_kp->is_valid = 1;

        keypoints->kp_list[keypoints->count] = cur_kp;
        keypoints->count++;
      }
    }
  }
}

// Runs the candidate-extremum search on every octave (0 .. NUM_OCTAVES-1),
// appending the results to keypoints. The list must be empty on entry.
void get_candidate_keypoints(DoG **DoGs, Keypoints *keypoints) {
  assert(keypoints->count == 0);

  int octave = 0;
  while (octave < NUM_OCTAVES) {
    printf("Getting candidate keypoints for octave %d\n", octave);
    get_candidate_keypoints_for_octave(DoGs, keypoints, octave);
    ++octave;
  }
}

// Computes a sub-pixel / sub-scale offset for a candidate keypoint.
//
// Finite differences over diffs[kp->DoG_idx - 1 .. kp->DoG_idx + 1] at pixel
// (kp->x, kp->y) give the gradient J = [dx, dy, ds] and the symmetric 3x3
// Hessian H. The offset is solved as -H^-1 * J.
//
// On success it heap-allocates and stores:
//   kp->jacobian : 3 doubles {dx, dy, ds}
//   kp->hessian  : 4 doubles, the spatial 2x2 block {dxx, dxy, dxy, dyy}
//   kp->offset   : 3 doubles, offset in (x, y, scale)
// If H has no inverse, or an allocation fails, kp->is_valid is set to 0 and
// the three pointers are set to NULL.
//
// The caller must ensure the pixel is not on an image border and that the DoG
// has a neighbour on both sides in scale; neither is checked at runtime
// beyond the assert below.
void localize_kp(Kp *kp, DoG **diffs) {
  int row = kp->y;
  int col = kp->x;
  int width = kp->DoG->width;
  int DoG_idx = kp->DoG_idx;

  // Each octave holds NUM_BLURS - 1 DoGs. Only interior ones
  // (1 .. NUM_BLURS - 3) have DoGs on both sides in scale.
  int blur_in_octave = DoG_idx % (NUM_BLURS - 1);
  assert(blur_in_octave >= 1 && blur_in_octave <= NUM_BLURS - 3);
  (void)blur_in_octave; // only read by assert; avoids NDEBUG warnings

  DoG *prev = diffs[DoG_idx - 1];
  DoG *cur = diffs[DoG_idx];
  DoG *next = diffs[DoG_idx + 1];
  int p = row * width + col;

  // Central first differences.
  double dx = (cur->array[p + 1] - cur->array[p - 1]) / 2;
  double dy = (cur->array[p + width] - cur->array[p - width]) / 2;
  double ds = (next->array[p] - prev->array[p]) / 2;

  // Second differences along each axis.
  double dxx = cur->array[p + 1] - 2 * cur->array[p] + cur->array[p - 1];
  double dyy = cur->array[p + width] - 2 * cur->array[p] + cur->array[p - width];
  double dss = next->array[p] - 2 * cur->array[p] + prev->array[p];

  // Mixed second differences: a difference of central differences, each of
  // which spans two samples, hence the division by 4.
  double dxy = ((cur->array[p + 1 + width] - cur->array[p - 1 + width]) -
                (cur->array[p + 1 - width] - cur->array[p - 1 - width])) / 4.0;
  double dxs = ((next->array[p + 1] - next->array[p - 1]) -
                (prev->array[p + 1] - prev->array[p - 1])) / 4.0;
  double dys = ((next->array[p + width] - next->array[p - width]) -
                (prev->array[p + width] - prev->array[p - width])) / 4.0;

  double J[3] = {dx, dy, ds};
  double HD[9] = {dxx, dxy, dxs, dxy, dyy, dys, dxs, dys, dss};
  double HD_inv[9];
  if (inv_3x3(HD, HD_inv) == -1) {
    printf("There is no inv for hessian!\n");
    print_3x3(HD);
    // Leave no dangling/uninitialised pointers for later stages.
    kp->jacobian = NULL;
    kp->hessian = NULL;
    kp->offset = NULL;
    kp->is_valid = 0;
    return;
  }

  double *offset = malloc(3 * sizeof *offset);
  double *jacobian = malloc(3 * sizeof *jacobian);
  double *hessian = malloc(4 * sizeof *hessian);
  if (offset == NULL || jacobian == NULL || hessian == NULL) {
    fprintf(stderr, "Out of memory while localizing keypoint\n");
    free(offset);
    free(jacobian);
    free(hessian);
    kp->jacobian = NULL;
    kp->hessian = NULL;
    kp->offset = NULL;
    kp->is_valid = 0;
    return;
  }

  // offset = -H^-1 * J
  dot(HD_inv, J, offset, 3, 3, 1);
  for (int i = 0; i < 3; i++) {
    offset[i] = -offset[i];
  }

  memcpy(jacobian, J, 3 * sizeof *jacobian);

  // Spatial 2x2 Hessian, used by the curvature-ratio test.
  hessian[0] = dxx;
  hessian[1] = dxy;
  hessian[2] = dxy;
  hessian[3] = dyy;

  kp->jacobian = jacobian;
  kp->hessian = hessian;
  kp->offset = offset;
}

// Localizes every candidate keypoint and filters out unstable ones.
//
// For each keypoint, localize_kp() computes the sub-pixel offset. The keypoint
// is then marked invalid (is_valid = 0) if any of these hold:
//   * localize_kp already invalidated it (singular Hessian),
//   * |contrast| < CONTRAST_THRESHOLD,
//   * the 2x2 spatial Hessian determinant is <= 0,
//   * trace^2 / det >= PRINCIPLE_CURVATURE_THRESHOLD,
//   * the refined x or y is negative or exceeds the DoG width/height.
// Surviving keypoints get a heap-allocated Kp_precise holding:
//   * the refined coordinates,
//   * real_x / real_y (scaled by 2^octave plus half the size difference
//     between get_octave_im_dim(...) and the DoG),
// and have is_valid = 1. Finishes by printing counts via
// print_keypoints_stats().
void localize_keypoints(Keypoints *keypoints, DoG **diffs) {
  printf("Localizing %d keypoints...\n", keypoints->count);
  for (int idx = 0; idx < keypoints->count; idx++) {
    Kp *cur_kp = keypoints->kp_list[idx];
    localize_kp(cur_kp, diffs);

    // localize_kp clears is_valid when the Hessian is singular. In that case
    // jacobian/offset/hessian were never computed and must not be touched.
    if (!cur_kp->is_valid) {
      continue;
    }

    if (fabs(compute_kp_contrast(cur_kp, diffs)) < CONTRAST_THRESHOLD) {
      cur_kp->is_valid = 0;
      continue;
    }

    // Reject det == 0 as well as det < 0. With det == 0 and trace == 0 the
    // ratio below would be 0/0 = NaN, which compares false and would wrongly
    // accept the keypoint.
    double det = get_det_2x2(cur_kp->hessian);
    if (det <= 0) {
      cur_kp->is_valid = 0;
      continue;
    }
    double trace = get_trace_2x2(cur_kp->hessian);
    if (pow(trace, 2) / det >= PRINCIPLE_CURVATURE_THRESHOLD) {
      cur_kp->is_valid = 0;
      continue;
    }

    double loc_x = cur_kp->x + cur_kp->offset[0];
    double loc_y = cur_kp->y + cur_kp->offset[1];
    // NOTE(review): the upper bound admits values in (width - 1, width], past
    // the last pixel — confirm whether that is intended.
    if (loc_x > cur_kp->DoG->width || loc_y > cur_kp->DoG->height ||
        loc_x < 0 || loc_y < 0) {
      cur_kp->is_valid = 0;
      continue;
    }

    Kp_precise *kp_prec = malloc(sizeof *kp_prec);
    if (kp_prec == NULL) {
      fprintf(stderr, "Out of memory allocating precise keypoint; dropping it\n");
      cur_kp->is_valid = 0;
      continue;
    }
    kp_prec->x = loc_x;
    kp_prec->y = loc_y;
    // NOTE(review): if Kp_precise.DoG_idx is an integer field, this truncates
    // the fractional scale offset toward zero — confirm intended.
    kp_prec->DoG_idx = cur_kp->DoG_idx + cur_kp->offset[2];

    cur_kp->is_valid = 1;

    int octave_num = cur_kp->DoG_idx / (NUM_BLURS - 1);
    int blur_idx = cur_kp->DoG_idx % (NUM_BLURS - 1);

    int width = cur_kp->DoG->width;
    int height = cur_kp->DoG->height;
    int width_diff = get_octave_im_dim(IMAGE_WIDTH, octave_num) - width;
    int height_diff = get_octave_im_dim(IMAGE_HEIGHT, octave_num) - height;

    // NOTE(review): real_y uses the integer cur_kp->y while real_x uses the
    // sub-pixel kp_prec->x. An earlier version used kp_prec->y. Confirm which
    // is intended; behavior is kept as-is here.
    kp_prec->real_y = (height_diff / 2) * (1 << octave_num) + cur_kp->y * (1 << octave_num);
    kp_prec->real_x = (width_diff / 2) * (1 << octave_num) + kp_prec->x * (1 << octave_num);

    // NOTE(review): hard-coded expectation for the current image
    // configuration; fires for any other IMAGE_WIDTH/IMAGE_HEIGHT setup.
    assert(height_diff == 24);
    assert(width_diff == 24);

    // Sample ~1% of accepted keypoints for debug output.
    if ((double)rand() / RAND_MAX < 0.01) {
      printf("get_octave_im_dim(IMAGE_WIDTH, octave_num): %d, width: %d\n", get_octave_im_dim(IMAGE_WIDTH, octave_num), width);
      printf("get_octave_im_dim(IMAGE_HEIGHT, octave_num) %d, height: %d\n", get_octave_im_dim(IMAGE_HEIGHT, octave_num), height);
      printf("Octave: %d, Blur: %d\n", octave_num, blur_idx);
      printf("DoG width, height: %d, %d, width_diff, height_diff: %d, %d\n", width, height, width_diff, height_diff);
      printf("Orig x, y: %d, %d\n", cur_kp->x, cur_kp->y);
      printf("real_x: %lf, real_y: %lf, kp_prec->x: %f, kp_prec->y: %f\n", kp_prec->real_x, kp_prec->real_y, kp_prec->x, kp_prec->y);
    }

    cur_kp->kp_precise = kp_prec;
  }
  print_keypoints_stats(keypoints);
}

// Prints how many keypoints are in the list and how many are valid/invalid.
void print_keypoints_stats(Keypoints *keypoints) {
  const int total = keypoints->count;
  int valid = 0;
  for (int i = 0; i < total; i++) {
    valid += keypoints->kp_list[i]->is_valid ? 1 : 0;
  }
  printf("Total keypoints: %d. Total valid: %d, Total invalid: %d\n", total, valid, total - valid);
}

// Writes every valid, localized keypoint to file_path, one per line:
//   real_y real_x octave_num blur_idx
// where real_y/real_x come from kp_precise and octave_num/blur_idx are derived
// from DoG_idx. Invalid keypoints, and keypoints with no kp_precise, are
// skipped, so there is no is_valid column. Errors opening, writing or closing
// the file are reported with perror.
void write_kps_to_txt_file(Keypoints *keypoints, const char *file_path) {
    printf("Writing %d keypoints to file %s...\n", keypoints->count, file_path);
    FILE *file = fopen(file_path, "w");
    if (!file) {
        perror("Error opening file for writing");
        return;
    }

    int num_written = 0;
    for (int i = 0; i < keypoints->count; i++) {
      Kp *cur_kp = keypoints->kp_list[i];
      if (!cur_kp->is_valid || cur_kp->kp_precise == NULL) {
        continue;
      }
      if (fprintf(file, "%f %f %d %d\n", cur_kp->kp_precise->real_y,
                  cur_kp->kp_precise->real_x, cur_kp->DoG_idx / (NUM_BLURS - 1),
                  cur_kp->DoG_idx % (NUM_BLURS - 1)) < 0) {
        perror("Error writing keypoint");
        break;
      }
      num_written++;
    }

    // fclose flushes buffered output, so a write error may only surface here.
    if (fclose(file) != 0) {
        perror("Error closing keypoint file");
    }
    printf("Wrote %d keypoints to file %s!\n", num_written, file_path);
}
