#include <stdio.h>
#include "memory.h"
#include <sys/ioctl.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <stdlib.h>
#include <fcntl.h> 
#include <string.h> 
#include <unistd.h>
#include <stdbool.h>
#include "network.h"
#include "get_jgp.h"

/*
 * NOTE(review): INIT_X / INIT_Y are not referenced anywhere in this file;
 * presumably leftover constants - confirm they are still needed or remove.
 */
#define INIT_X 300
#define INIT_Y 200

/*
 * File descriptor for the accelerator device ("/dev/memory"). It is opened by
 * main() and passed to every ioctl() call in this file.
 */
int memory_fd;

/*
 * Issue a MEM_READ ioctl for position 10 of the device and print the
 * returned value on stdout.
 *
 * On ioctl failure the error is reported through perror() and nothing is
 * printed, because vla.value would not have been filled in by the driver.
 */
void read_value(void) {
	mem_arg_t vla = {0};

	vla.pos = 10;
	if (ioctl(memory_fd, MEM_READ, &vla)) {
		perror("ioctl(MEM_READ) failed");
		return;
	}
	printf("%d\n", vla.value);
}

/*
 * Issue a MEM_WRITE ioctl that stores the value 20 at position 10 of the
 * device. Failures are reported through perror().
 */
void write_value(void) {
	mem_arg_t vla = {0};

	vla.pos = 10;
	vla.value = 20;
	if (ioctl(memory_fd, MEM_WRITE, &vla)) {
		/* The message previously said MEM_READ, which misattributed failures. */
		perror("ioctl(MEM_WRITE) failed");
	}
}

/*
 * Load a convolution input tensor from the text file "input".
 *
 * Expected format: three integers "d w h", then w*h*d floats. The floats are
 * stored in file order into the buffer created by CreateConv(w, h, d).
 *
 * The caller has no way to receive an error, so the program exits if the
 * file cannot be opened or is malformed. The file is always closed.
 */
Conv_IO get_input(void) {
	FILE *f_in = fopen("input", "r");
	if (f_in == NULL) {
		perror("fopen(input) failed");
		exit(EXIT_FAILURE);
	}

	int w, h, d;
	if (fscanf(f_in, "%d%d%d", &d, &w, &h) != 3 || w < 0 || h < 0 || d < 0) {
		fprintf(stderr, "input: malformed header, expected \"d w h\"\n");
		fclose(f_in);
		exit(EXIT_FAILURE);
	}

	Conv_IO input = CreateConv(w, h, d);
	int count = w * h * d;
	for (int i = 0; i < count; i++) {
		if (fscanf(f_in, "%f", &input.parmas[i]) != 1) {
			fprintf(stderr, "input: expected %d values, read %d\n", count, i);
			fclose(f_in);
			exit(EXIT_FAILURE);
		}
	}

	fclose(f_in);
	return input;
}

/*
 * calloc() wrapper used by run_conv().
 *
 * Exits the program on failure. calloc also rejects an n * size overflow.
 * A zero-sized request may legitimately return NULL, so that case is not
 * treated as an error.
 */
static void *run_conv_xcalloc(size_t n, size_t size)
{
	void *p = calloc(n, size);
	if (p == NULL && n != 0 && size != 0) {
		fprintf(stderr, "run_conv: out of memory\n");
		exit(EXIT_FAILURE);
	}
	return p;
}

/*
 * Run one convolution layer on the hardware accelerator.
 *
 * The input activations, weights and biases are converted to fixed point
 * with float_to_fixed() and handed to the driver through the CONV_WRITE
 * ioctl. The fixed-point result is then converted back with fixed_to_float().
 *
 * input      activations; input.w * input.h * input.d floats are read.
 * conv       layer parameters; conv.out * conv.in * conv.kw * conv.kh
 *            weights and conv.out biases are read.
 * pool_size  0 means no pooling. Otherwise the output width and height are
 *            divided by pool_size. A pool_size of 8 also multiplies the
 *            fixed-point inputs by 4 and right-shifts the weights by 2.
 *
 * Returns a Conv_IO of depth conv.out. Its parmas buffer is newly allocated
 * and owned by the caller. The temporary fixed-point buffers are freed
 * before returning. Exits the program on allocation or ioctl failure, since
 * the result would otherwise be built from data the driver never wrote.
 */
Conv_IO run_conv(Conv_IO input, Conv conv, int pool_size) {
	int num_input = input.w * input.h * input.d;
	int num_bias = conv.out;
	int num_weights = conv.out * conv.in * conv.kw * conv.kh;

	int output_size = conv.out * input.w * input.h;
	if (pool_size) {
		/*
		 * NOTE(review): dividing the whole product can differ from
		 * conv.out * (w / pool_size) * (h / pool_size), which next_IO uses
		 * below, when w or h is not a multiple of pool_size. Kept as-is
		 * because the driver's actual output length is not visible here.
		 */
		output_size /= pool_size * pool_size;
	}

	short *input_p = run_conv_xcalloc((size_t)num_input, sizeof *input_p);
	short *network = run_conv_xcalloc((size_t)num_weights, sizeof *network);
	short *bias = run_conv_xcalloc((size_t)num_bias, sizeof *bias);
	short *output = run_conv_xcalloc((size_t)output_size, sizeof *output);
	float *n_in = run_conv_xcalloc((size_t)output_size, sizeof *n_in);

	for (int i = 0; i < num_input; i++) {
		input_p[i] = float_to_fixed(input.parmas[i]);
		if (pool_size == 8) {
			/*
			 * Scale by 4. This is a multiplication rather than "<<= 2"
			 * because the short is promoted to int, and left-shifting a
			 * negative int is undefined behaviour.
			 */
			input_p[i] = (short)(input_p[i] * 4);
		}
	}
	for (int i = 0; i < num_weights; i++) {
		network[i] = float_to_fixed(conv.parmas[i]);
		if (pool_size == 8) {
			/* Implementation-defined for negatives (arithmetic on GCC/Clang). */
			network[i] >>= 2;
		}
	}
	for (int i = 0; i < num_bias; i++) {
		bias[i] = float_to_fixed(conv.bias[i]);
	}

	conv_arg_t args = {num_input, num_bias, num_weights,
	                   conv.in, conv.out, input.w, input.h, pool_size,
	                   input_p, network, bias, output};
	if (ioctl(memory_fd, CONV_WRITE, &args)) {
		perror("ioctl(CONV_WRITE) failed");
		exit(EXIT_FAILURE);
	}
	printf("Conv run ok!\n");
	printf("%d\n", output_size);

	for (int i = 0; i < output_size; i++) {
		if (pool_size == 8) {
			/* Debug dump of the raw 16-bit fixed-point result. */
			printf("%x\n", (unsigned)(unsigned short)output[i]);
		}
		n_in[i] = fixed_to_float(output[i]);
	}

	free(input_p);
	free(network);
	free(bias);
	free(output);

	Conv_IO next_IO = {input.w, input.h, conv.out, n_in};
	if (pool_size != 0) {
		next_IO.w /= pool_size;
		next_IO.h /= pool_size;
	}
	return next_IO;
}

/*
 * Load a fully connected layer input vector from the text file "input".
 *
 * Expected format: an integer length, then that many floats.
 *
 * Returns an Fc_IO with width == length. Its parmas buffer is allocated with
 * malloc and owned by the caller. Exits the program if the file cannot be
 * opened, is malformed, or memory runs out. The file is always closed.
 */
Fc_IO fc_input(void) {
	Fc_IO input;
	FILE *f = fopen("input", "r");
	if (f == NULL) {
		perror("fopen(input) failed");
		exit(EXIT_FAILURE);
	}

	int len = 0;
	printf("reading fc input");
	if (fscanf(f, "%d\n", &len) != 1 || len < 0) {
		fprintf(stderr, "input: malformed length header\n");
		fclose(f);
		exit(EXIT_FAILURE);
	}

	float *parma = malloc(sizeof *parma * (size_t)len);
	if (parma == NULL && len > 0) {
		fprintf(stderr, "fc_input: out of memory\n");
		fclose(f);
		exit(EXIT_FAILURE);
	}
	for (int i = 0; i < len; i++) {
		if (fscanf(f, "%f", &parma[i]) != 1) {
			fprintf(stderr, "input: expected %d values, read %d\n", len, i);
			free(parma);
			fclose(f);
			exit(EXIT_FAILURE);
		}
	}
	fclose(f);

	input.parmas = parma;
	input.width = len;
	return input;
}

/*
 * calloc() wrapper used by run_fc().
 *
 * Exits the program on failure. calloc also rejects an n * size overflow.
 * A zero-sized request may legitimately return NULL, so that case is not
 * treated as an error.
 */
static void *run_fc_xcalloc(size_t n, size_t size)
{
	void *p = calloc(n, size);
	if (p == NULL && n != 0 && size != 0) {
		fprintf(stderr, "run_fc: out of memory\n");
		exit(EXIT_FAILURE);
	}
	return p;
}

/*
 * Run one fully connected layer on the hardware accelerator.
 *
 * input.width input values, fc.in * fc.out weights and fc.out biases are
 * converted to fixed point and handed to the driver through the FC_WRITE
 * ioctl. The fc.out results are converted back to floats, and every
 * negative result is clamped to 0 (ReLU).
 *
 * Returns an Fc_IO of width fc.out. Its parmas buffer is newly allocated and
 * owned by the caller. The temporary fixed-point buffers are freed before
 * returning. Exits the program on allocation or ioctl failure, since the
 * output would otherwise be data the driver never wrote.
 */
Fc_IO run_fc(Fc_IO input, Fc fc) {
	int num_input = input.width;
	int num_bias = fc.out;
	int num_weights = fc.in * fc.out;

	short *input_p = run_fc_xcalloc((size_t)num_input, sizeof *input_p);
	short *network = run_fc_xcalloc((size_t)num_weights, sizeof *network);
	short *bias = run_fc_xcalloc((size_t)num_bias, sizeof *bias);
	short *output = run_fc_xcalloc((size_t)fc.out, sizeof *output);
	float *f = run_fc_xcalloc((size_t)fc.out, sizeof *f);

	for (int i = 0; i < num_input; i++) {
		input_p[i] = float_to_fixed(input.parmas[i]);
	}
	for (int i = 0; i < num_weights; i++) {
		network[i] = float_to_fixed(fc.parmas[i]);
	}
	for (int i = 0; i < num_bias; i++) {
		bias[i] = float_to_fixed(fc.bias[i]);
	}

	/* FC layers have no spatial size or pooling, so those fields are 0. */
	conv_arg_t args = {num_input, num_bias, num_weights,
	                   fc.in, fc.out, 0, 0, 0,
	                   input_p, network, bias, output};
	if (ioctl(memory_fd, FC_WRITE, &args)) {
		/* The message previously named CONV_WRITE, which misattributed failures. */
		perror("ioctl(FC_WRITE) failed");
		exit(EXIT_FAILURE);
	}
	printf("FCINFO:%d %d\n", fc.in, fc.out);

	for (int i = 0; i < fc.out; i++) {
		f[i] = fixed_to_float(output[i]);
		if (f[i] < 0) {
			f[i] = 0;
		}
	}

	free(input_p);
	free(network);
	free(bias);
	free(output);

	printf("fully connected layer ok\n");
	Fc_IO out = {fc.out, f};
	return out;
}

/*
 * Run the whole network on one input and print the final outputs.
 *
 * Pipeline: con1 (no pooling), then con2, con3 and con4 (pool_size 2 each),
 * then fc1 and fc2. For debugging, up to the first 60 values after con3 are
 * printed. Every fc2 output is then printed, one per line.
 *
 * input.parmas is only read, never freed, so its ownership stays with the
 * caller. Each intermediate buffer returned by run_conv/run_fc is freed once
 * the next layer has consumed it.
 */
void compute_network(Network net, Conv_IO input) {
	Conv_IO input_2 = run_conv(input, net.con1, 0);
	Conv_IO input_3 = run_conv(input_2, net.con2, 2);
	free(input_2.parmas);
	Conv_IO input_4 = run_conv(input_3, net.con3, 2);
	free(input_3.parmas);

	/* Debug dump, capped at the real size so a small feature map is not over-read. */
	int dump = input_4.w * input_4.h * input_4.d;
	if (dump > 60) {
		dump = 60;
	}
	for (int i = 0; i < dump; i++) {
		printf("%f\n", input_4.parmas[i]);
	}

	Conv_IO input_5 = run_conv(input_4, net.con4, 2);
	free(input_4.parmas);

	/*
	 * NOTE(review): the FC input width is input_5.d, not w * h * d. That
	 * covers the whole feature map only when input_5.w == input_5.h == 1.
	 * Confirm against net.fc1.in.
	 */
	Fc_IO fc = {input_5.d, input_5.parmas};

	Fc_IO fc1 = run_fc(fc, net.fc1);
	free(input_5.parmas);
	Fc_IO fc2 = run_fc(fc1, net.fc2);
	free(fc1.parmas);

	for (int i = 0; i < fc2.width; i++) {
		printf("%f\n", fc2.parmas[i]);
	}
	free(fc2.parmas);
}

/*
 * Entry point. Opens the accelerator device, loads the weights from
 * ./weights.txt and the image ./b.jpg, and runs the network on the image
 * (3 channels, g.width x g.height).
 *
 * Returns 0 on success, or -1 if the device cannot be opened or the image
 * cannot be read. The device descriptor is closed before returning.
 */
int main(void)
{
	static const char filename[] = "/dev/memory";

	printf("Memory Userspace program started\n");

	if ((memory_fd = open(filename, O_RDWR)) == -1) {
		fprintf(stderr, "could not open %s\n", filename);
		return -1;
	}

	graph_t g;
	Network net = ReadNetwork("./weights.txt");
	float *d = read_jpg("./b.jpg", &g);
	/* NOTE(review): assumes read_jpg signals failure by returning NULL - confirm. */
	if (d == NULL) {
		fprintf(stderr, "could not read %s\n", "./b.jpg");
		close(memory_fd);
		return -1;
	}

	Conv_IO input = {g.width, g.height, 3, d};
	compute_network(net, input);

	close(memory_fd);

	printf("VGA BALL Userspace program terminating\n");
	return 0;
}
