#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>

#include <verilated.h>
#include <verilated_vcd_c.h>

#include "VtbOpt.h"

#define STB_IMAGE_IMPLEMENTATION
// Copy stb_image.h into obj_dir to make this compile
#include "stb_image.h"
// #define DEBUG_OCM_RW

// Decode an IEEE-754 binary16 bit pattern into a double by re-biasing its
// exponent and widening its mantissa directly in the binary64 bit layout.
// Signed zeros/infinities keep their sign, subnormal halves are normalised,
// and any NaN becomes a quiet NaN (sign kept, payload reduced to the quiet bit).
// Adapted from https://github.com/skeeto/scratch/blob/master/misc/float16.c
static double
float64_16(uint16_t x) {
    const uint64_t sign_bits = (uint64_t)(x & 0x8000) << 48;
    int biased_exp = (x >> 10) & 0x1f;   // binary16 exponent, bias 15
    int frac = x & 0x03ff;

    if (biased_exp == 0x1f) {
        // Inf keeps a zero fraction; any NaN collapses to the quiet bit.
        biased_exp = 2047;
        frac = frac ? 0x200 : 0;
    } else if (biased_exp == 0) {
        if (frac != 0) {
            // Subnormal: shift left until the implicit leading 1 appears,
            // lowering the exponent once per shift.
            biased_exp = 1023 - 15 + 1;
            while ((frac & 0x400) == 0) {
                frac <<= 1;
                --biased_exp;
            }
            frac &= 0x3ff;
        }
        // frac == 0: signed zero, exponent field stays 0.
    } else {
        biased_exp += 1023 - 15;   // re-bias 15 -> 1023
    }

    const uint64_t bits = sign_bits
                        | ((uint64_t)biased_exp << 52)
                        | ((uint64_t)frac << 42);
    double result;
    memcpy(&result, &bits, sizeof result);
    return result;
}

// Print `len` native-endian fp16 values stored in the byte buffer `mem`,
// starting at fp16 element `offset`, as decimals, ten values per line,
// followed by a final newline.
//   mem    - byte buffer holding packed 16-bit values (no alignment assumed)
//   offset - index of the first fp16 element (an element index, not bytes)
//   len    - number of fp16 elements to print
void printOCM(unsigned char* mem, int offset, int len) {
  for (int i = offset; i < offset + len; i++) {
    // memcpy instead of ((uint16_t *) mem)[i]: the buffer is a plain byte
    // array, so that cast may be misaligned and violates strict aliasing.
    uint16_t half;
    memcpy(&half, mem + 2 * i, sizeof half);
    printf("%.3f\t", float64_16(half));
    // Wrap after every tenth printed value, counted from `offset`.
    if ((i - offset + 1) % 10 == 0)
      printf("\n");
  }
  printf("\n");
}

// Print `len` native-endian fp16 values packed into the word array `mem`
// (e.g. a DUT feature-map signal), starting at fp16 element `offset`, as
// decimals, ten values per line, followed by a final newline.
//   mem    - word array whose bytes hold packed 16-bit values
//   offset - index of the first fp16 element (an element index, not bytes)
//   len    - number of fp16 elements to print
void printFeatMap(unsigned int* mem, int offset, int len) {
  // Go through unsigned char + memcpy: reading unsigned int storage via a
  // uint16_t* (as the old cast did) violates strict aliasing.
  const unsigned char* bytes = reinterpret_cast<const unsigned char*>(mem);
  for (int i = offset; i < offset + len; i++) {
    uint16_t half;
    memcpy(&half, bytes + 2 * i, sizeof half);
    printf("%.3f\t", float64_16(half));
    // Wrap after every tenth printed value, counted from `offset`.
    if ((i - offset + 1) % 10 == 0)
      printf("\n");
  }
  printf("\n");
}

// Print `len` 16-bit elements of the byte buffer `mem` as raw hex byte pairs
// (in memory order), starting at element `offset`, ten elements per line,
// followed by a final newline.
//   mem    - byte buffer holding packed 16-bit values
//   offset - index of the first element; the first byte shown is mem[2*offset]
//   len    - number of elements to print
void printOCMHex(unsigned char* mem, int offset, int len) {
  // Iterate over elements (not bytes) so a nonzero offset starts on the right
  // byte pair and exactly `len` pairs are printed.
  for (int i = offset; i < offset + len; i++) {
    const unsigned char* elem = mem + 2 * i;
    printf("%02x %02x\t", elem[0], elem[1]);
    // Wrap after every tenth element, counted from `offset` (the previous
    // byte-index test could never fire because the index was always even).
    if ((i - offset + 1) % 10 == 0)
      printf("\n");
  }
  printf("\n");
}

// Print `len` 16-bit elements packed into the word array `mem` (e.g. a DUT
// feature-map signal) as raw hex byte pairs in memory order, starting at
// element `offset`, ten elements per line, followed by a final newline.
//   mem    - word array whose bytes hold packed 16-bit values
//   offset - index of the first element; the first byte shown is byte 2*offset
//   len    - number of elements to print
void printFeatMapHex(unsigned int* mem, int offset, int len) {
  // Byte-wise view of the words; unsigned char may alias any object.
  const unsigned char* bytes = reinterpret_cast<const unsigned char*>(mem);
  // Iterate over elements (not bytes) so a nonzero offset starts on the right
  // byte pair and exactly `len` pairs are printed.
  for (int i = offset; i < offset + len; i++) {
    const unsigned char* elem = bytes + 2 * i;
    printf("%02x %02x\t", elem[0], elem[1]);
    // Wrap after every tenth element, counted from `offset` (the previous
    // byte-index test could never fire because the index was always even).
    if ((i - offset + 1) % 10 == 0)
      printf("\n");
  }
  printf("\n");
}

// Zero `len` bytes of simulated on-chip memory beginning at `mem`.
// A zero length touches nothing.
void cleanOCM(uint8_t* mem, size_t len) {
  uint8_t* const stop = mem + len;
  while (mem != stop) {
    *mem++ = 0;
  }
}

// Return a pointer to the 3x3 fp16 kernel that connects `in_channel` to
// `out_channel` of conv layer `epoch` inside the packed weight/bias blob.
//
// Layer k occupies fixed_index[k] fp16 values (2 bytes each): the
// out*in*3*3 weights, laid out as [out][in][3][3], followed by `out` biases
// (e.g. layer 0: 1792 = 64*3*9 + 64).
//
//   in_channel  - input channel, must be < fixed_in_channel[epoch]
//   out_channel - output channel, must be < fixed_out_channel[epoch]
//   epoch       - layer index, 0..7
//   image_data  - start of the (byte-swapped) fp16 blob
// Returns nullptr when any index is out of range; previously such indices
// read past the tables or silently aliased a different kernel.
unsigned char *
get_sending_weight(uint16_t in_channel, uint16_t out_channel, uint16_t epoch, unsigned char *image_data) {
    static const int fixed_index[] = {1792, 73856, 295168, 590080, 1180160, 2359808, 2359808, 2359808};
    static const int fixed_in_channel[] = {3, 64, 128, 256, 256, 512, 512, 512};
    static const int fixed_out_channel[] = {64, 128, 256, 256, 512, 512, 512, 512};
    const int num_layers = (int)(sizeof(fixed_index) / sizeof(fixed_index[0]));

    if (epoch >= num_layers ||
        in_channel >= fixed_in_channel[epoch] ||
        out_channel >= fixed_out_channel[epoch]) {
        return nullptr;
    }

    // Element offset of the start of layer `epoch`.
    size_t offset = 0;
    for (int i = 0; i < epoch; i++) {
        offset += (size_t)fixed_index[i];
    }
    offset += (size_t)out_channel * fixed_in_channel[epoch] * 3 * 3;
    offset += (size_t)in_channel * 3 * 3;
    return image_data + offset * 2;   // fp16 elements -> bytes
}

// Return a pointer to the fp16 bias of `out_channel` in conv layer `epoch`
// inside the packed weight/bias blob.
//
// Layer k occupies fixed_index[k] fp16 values (2 bytes each), and the last
// fixed_bias_len[k] of them are the biases, one per output channel. So the
// bias is at: end_of_layer - bias_count + out_channel.
//
//   in_channel  - ignored (biases are indexed by output channel only)
//   out_channel - output channel, must be < fixed_bias_len[epoch]
//   epoch       - layer index, 0..7
//   image_data  - start of the (byte-swapped) fp16 blob
// Returns nullptr when epoch or out_channel is out of range; previously such
// indices read past the tables or pointed into the next layer.
unsigned char *get_sending_bias(uint16_t in_channel, uint16_t out_channel, uint16_t epoch, unsigned char *image_data) {
    (void)in_channel;
    static const int fixed_index[] = {1792, 73856, 295168, 590080, 1180160, 2359808, 2359808, 2359808};
    static const int fixed_bias_len[] = {64, 128, 256, 256, 512, 512, 512, 512};
    const int num_layers = (int)(sizeof(fixed_index) / sizeof(fixed_index[0]));

    if (epoch >= num_layers || out_channel >= fixed_bias_len[epoch]) {
        return nullptr;
    }

    // Element offset of the end of layer `epoch` (inclusive of its biases).
    size_t offset = 0;
    for (int i = 0; i <= epoch; i++) {
        offset += (size_t)fixed_index[i];
    }
    offset -= (size_t)fixed_bias_len[epoch];
    offset += out_channel;
    return image_data + offset * 2;   // fp16 elements -> bytes
}

// Convert an IEEE-754 binary16 (fp16) bit pattern to a float.
// Adapted from https://blog.csdn.net/ysaeeiderup/article/details/124104042
//
// The conversion is exact: every finite half, including subnormals (which
// are renormalised), is representable as a float. Signed zeros and signed
// infinities are preserved. Every NaN, whatever its sign or payload, is
// returned as the positive quiet NaN 0x7fc00000.
//
// The result bits are moved into the float with memcpy. A *(float *)&bits
// cast would violate strict aliasing (undefined behaviour).
float f16_to_f32(uint16_t half) {
  const uint32_t x = half;
  const uint32_t sign = x & 0x8000;                   // sign bit
  const uint32_t exponent_f16 = (x & 0x7c00) >> 10;   // half exponent field
  const uint32_t mantissa_f16 = x & 0x03ff;           // half mantissa field
  const uint32_t hx = x & 0x7fff;                     // magnitude bits

  uint32_t y = sign << 16;                            // float sign bit; +/-0 as-is
  if (hx == 0x7c00) {
    y |= 0x7f800000;                                  // +/-infinity
  } else if (hx > 0x7c00) {
    y = 0x7fc00000;                                   // any NaN -> canonical quiet NaN
  } else if (hx != 0) {
    uint32_t exponent_f32 = 0x70 + exponent_f16;      // re-bias: 127 - 15 = 0x70
    uint32_t mantissa_f32 = mantissa_f16 << 13;       // widen 10 -> 23 bits
    if (exponent_f16 == 0) {
      // Subnormal half: find the position of the mantissa's leading 1, shift
      // it into the implicit-bit position and lower the exponent to match.
      uint32_t first_1_pos = 0;
      while (first_1_pos < 10 && (mantissa_f16 >> (first_1_pos + 1)) != 0) {
        ++first_1_pos;
      }
      exponent_f32 = exponent_f32 - (10 - first_1_pos) + 1;
      mantissa_f32 = (mantissa_f32 << (10 - first_1_pos)) & ((1u << 23) - 1);
    }
    y |= (exponent_f32 << 23) | mantissa_f32;
  }

  float result;
  memcpy(&result, &y, sizeof result);
  return result;
}

// Truncating float -> IEEE-754 binary16 (fp16) bit-pattern conversion used to
// encode the input image. Magnitudes below the smallest normal half (2^-14),
// including zero, flush to a signed zero. Magnitudes above the largest finite
// half saturate to infinity, and NaN becomes a quiet NaN. Surplus mantissa
// bits are truncated, not rounded.
static uint16_t f32_to_f16(float value) {
  uint32_t bits;
  memcpy(&bits, &value, sizeof bits);
  const uint32_t sign = (bits >> 16) & 0x8000;
  const uint32_t exp32 = (bits >> 23) & 0xff;
  const uint32_t mant32 = bits & 0x007fffff;

  if (exp32 == 0xff)          // float Inf / NaN
    return (uint16_t)(sign | (mant32 ? 0x7e00 : 0x7c00));
  if (exp32 < 127 - 14)       // too small for a normal half (incl. zero)
    return (uint16_t)sign;
  if (exp32 > 127 + 15)       // too large for a finite half
    return (uint16_t)(sign | 0x7c00);
  // Re-bias the exponent (127 -> 15) and keep the top 10 mantissa bits.
  return (uint16_t)(sign | ((exp32 - (127 - 15)) << 10) | (mant32 >> 13));
}

// Testbench entry point. It performs these steps:
//   1. Loads the packed fp16 weight/bias blob and the RGB input image.
//   2. Stages one weight/bias/input-channel slice in a simulated OCM0.
//   3. Clocks the VtbOpt model while emulating OCM0 reads (data returned one
//      cycle after the address) and OCM1 writes.
//   4. Prints the DUT feature maps and OCM1.
int main(int argc, const char ** argv, const char ** env) {
  // Geometry of the input image the DUT is configured for (read_data_dim).
  const int kImgDim = 32;
  const int kImgPixels = kImgDim * kImgDim;
  const int kImgChannels = 3;   // components requested from stbi_load (RGB)

  Verilated::commandArgs(argc, argv);

  // ---------- Load weight bias and input image ----------
  // Inputs are loaded and validated before the model and VCD are created so
  // any failure can exit without a half-initialised simulation.
  int width, height, channels;
  unsigned char *weight_bias_all = stbi_load("/home/scott/4840_final/hw/fpga/hps_system/weight_bias_old.jpg", &width, &height, &channels, 1);
  if (weight_bias_all == nullptr) {
    fprintf(stderr, "Failed to load weight/bias image\n");
    return 1;
  }
  const size_t wb_len = (size_t)width * (size_t)height;
  unsigned char *weight_bias_little_endian = (unsigned char *)malloc(wb_len);
  if (weight_bias_little_endian == nullptr) {
    fprintf(stderr, "Out of memory for weight/bias buffer\n");
    return 1;
  }
  // Swap the two bytes of every fp16 value. A trailing odd byte (if any) is
  // copied unchanged instead of reading past the end of weight_bias_all.
  for (size_t idx = 0; idx + 1 < wb_len; idx += 2) {
    weight_bias_little_endian[idx] = weight_bias_all[idx + 1];
    weight_bias_little_endian[idx + 1] = weight_bias_all[idx];
  }
  if (wb_len % 2 != 0)
    weight_bias_little_endian[wb_len - 1] = weight_bias_all[wb_len - 1];
  printf("\nWeight&Bias loaded with fp16 %dx%d\n", width/2, height/2);

  unsigned char *input_pic = stbi_load("/home/scott/4840_final/hw/fpga/hps_system/dog.jpg", &width, &height, &channels, 3);
  if (input_pic == nullptr) {
    fprintf(stderr, "Failed to load input image\n");
    return 1;
  }
  printf("InputPic width %d height %d channels %d\n", width, height, channels);
  // The reshape below reads kImgPixels RGB pixels; a smaller image would be
  // read out of bounds.
  if ((long long)width * height < kImgPixels) {
    fprintf(stderr, "Input image must contain at least %d pixels\n", kImgPixels);
    return 1;
  }
  // Sized from the requested RGB layout. `channels` reports the file's own
  // component count, which need not be 3. This is also a fixed-size array
  // rather than a (non-standard) VLA.
  unsigned char input_pic_fp16[kImgChannels * kImgPixels * 2] = {0};
  // Reshape input image: R11 G11 B11 R12 G12 B12 -> B11 B12 ... G11 G12 ... R11 R12
  // idx 2, 2+3, 2+6, ...,
  int fp16_idx = 0;
  for (int channel = kImgChannels - 1; channel >= 0; channel--) {
    for (int nth_pixel = 0; nth_pixel < kImgPixels; nth_pixel++) {
      float output = (float) input_pic[channel + kImgChannels * nth_pixel] / (float) 255.0;
      // Float to half-precision float (fp16). Black pixels (0.0) must encode
      // as 0x0000; a bare exponent re-bias would turn them into 0x4000 (2.0).
      uint16_t h = f32_to_f16(output);
      memcpy(input_pic_fp16 + 2 * fp16_idx++, &h, sizeof h);
    }
  }

  printf("Image:\t\t");
  printOCM(input_pic_fp16, 0, 10);
  printf("\t\t");
  printOCMHex(input_pic_fp16, 0, 10);
  printf("\n");

  // ------------------------------------------------------

  float feat_map_out_0[64*kImgPixels] = {0};
  unsigned char feat_map_out_0_temp[64*kImgPixels*2] = {0};

  unsigned char ocm0[4096] = {0};

  // Compile one send
  int input_ch = 0;
  int output_ch = 0;
  unsigned char *base_weight = get_sending_weight(input_ch, output_ch, 0, weight_bias_little_endian);
  unsigned char *base_bias = get_sending_bias(input_ch, output_ch, 0, weight_bias_little_endian);
  // Refuse to copy a slice that is missing or lies past the end of the blob.
  if (base_weight == nullptr || base_bias == nullptr ||
      (size_t)(base_weight - weight_bias_little_endian) + 9 * sizeof(uint16_t) > wb_len ||
      (size_t)(base_bias - weight_bias_little_endian) + sizeof(uint16_t) > wb_len) {
    fprintf(stderr, "Weight/bias slice out of range of the loaded blob\n");
    return 1;
  }
  memcpy(ocm0, base_weight, 9 * sizeof(uint16_t));
  memcpy(ocm0 + 9 * sizeof(uint16_t), base_bias, 1 * sizeof(uint16_t));
  memcpy(ocm0 + 10 * sizeof(uint16_t), input_pic_fp16 + kImgPixels*input_ch*sizeof(uint16_t), kImgPixels*sizeof(uint16_t));

  // memcpy(ocm0, all_weights, 2048*sizeof(uint16_t));
  printf("\nocm0:\t\t");
  printOCM(ocm0, 0, 10);
  printf("\t\t");
  printOCMHex(ocm0, 0, 10);
  printf("\n");

  // ---------- Build the model and trace ----------
  VtbOpt * dut = new VtbOpt;

  Verilated::traceEverOn(true);
  VerilatedVcdC * tfp = new VerilatedVcdC;
  dut->trace(tfp, 99);
  tfp->open("tbOpt.vcd");

  dut->clk = 0;
  dut->reset = 0;

  dut->start = 0;
  dut->read_length = 0;
  dut->read_data_dim = 0;

  // Simulate ocm read behavior: data come out next cycle
  int ocm0_nextcycle_index = -1;

  uint8_t ocm1[4096] = {0};

  int time = 0;

  for ( ; time < 2000000 ; time += 10) {
    // 50MHz Clock
    dut->clk = ((time % 20) >= 10) ? 1 : 0;
    // Mem (rising edge): serve the read latched last cycle, commit writes.
    if (dut->clk) {
      if (dut->ocm0_chip && dut->ocm0_clk_enab) {
        if (ocm0_nextcycle_index >= 0 &&
            (size_t)ocm0_nextcycle_index < sizeof(ocm0)) {
          dut->ocm0_readdata = ocm0[ocm0_nextcycle_index];
        #ifdef DEBUG_OCM_RW
          // Printed after the assignment so it shows the value actually driven.
          printf("[DEBUG] ocm0[%d]->0x%02x\n", ocm0_nextcycle_index, dut->ocm0_readdata);
        #endif
        } else if (ocm0_nextcycle_index != -1) {
          fprintf(stderr, "[WARN] ocm0 read address %d out of range at t=%d\n",
                  ocm0_nextcycle_index, time);
        }
      }
      if (dut->ocm1_chip && dut->ocm1_clk_enab) {
        #ifdef DEBUG_OCM_RW
        printf("[DEBUG] ocm1[%d]<-0x%02x\n", dut->ocm1_addr, dut->ocm1_writedata);
        #endif
        // Guard the host-side array: an out-of-range address would otherwise
        // corrupt the stack.
        if ((size_t)dut->ocm1_addr < sizeof(ocm1)) {
          ocm1[dut->ocm1_addr] = dut->ocm1_writedata;
        } else {
          fprintf(stderr, "[WARN] ocm1 write address %lu out of range at t=%d\n",
                  (unsigned long)dut->ocm1_addr, time);
        }
      }
    }
    // Initial reset
    if (time == 0) dut->reset = 1;
    if (time == 40) dut->reset = 0;

    // ------------------h2f_1--------------------------
    if (time == 100) {
      dut->read_data_dim = kImgDim;
      dut->read_length = (kImgPixels+10)*2;
      dut->start = 1;
    }

    dut->eval();
    tfp->dump( time );

    // Mem: latch the read address; its data is returned next cycle.
    if (dut->ocm0_chip && dut->ocm0_clk_enab) {
      ocm0_nextcycle_index = dut->ocm0_addr;
    }
  }

  // Copy this output channel's fp16 feature map out of the DUT and accumulate
  // it into the float buffer, one element per pixel (not per byte).
  unsigned char *out_fp16 = feat_map_out_0_temp + output_ch*(kImgPixels*2);
  memcpy(out_fp16, dut->feat_map_out, kImgPixels*2);
  for (int px = 0; px < kImgPixels; px++) {
    feat_map_out_0[output_ch*kImgPixels + px] +=
        f16_to_f32((uint16_t)(out_fp16[2*px + 1] << 8 | out_fp16[2*px]));
  }

  printf("\nfeat_map_in:\t");
  printFeatMap(dut->feat_map_in, 0, 10);
  printf("\t\t");
  printFeatMapHex(dut->feat_map_in, 0, 10);
  printf("\n");

  printf("\nfeat_map_out:\t");
  printFeatMap(dut->feat_map_out, 0, 10);
  printf("\t\t");
  printFeatMapHex(dut->feat_map_out, 0, 10);
  printf("\n");

  printf("\nOCM1:\t\t");
  printOCM(ocm1, 0, 10);
  printf("\t\t");
  printOCMHex(ocm1, 0, 10);
  printf("\n");

  tfp->close();
  delete tfp;

  dut->final();
  delete dut;

  free(weight_bias_little_endian);

  return 0;
}
