#pragma once
#ifndef NETWORK
#define NETWORK
#include <stdio.h>
#include <stdint.h>

/*
 * A w x h x d block of floats, presumably a convolution input/output
 * feature map (inferred from the name -- confirm).
 * When produced by CreateConv, parmas (sic) points to a heap buffer of
 * d*w*h floats that the caller owns.
 */
typedef struct {
	int w;         /* first dimension, as passed to CreateConv */
	int h;         /* second dimension, as passed to CreateConv */
	int d;         /* third dimension (presumably channels -- confirm) */
	float *parmas; /* element storage; field name kept for compatibility */
} Conv_IO;

/*
 * One convolution layer as read by ReadConv.
 * parmas (sic) holds out*in*kh*kw weights in the order they appear in the
 * weight file; bias holds the bias values that follow them.
 */
typedef struct {
	int out;       /* first header value: output count */
	int in;        /* second header value: input count */
	int kh;        /* third header value: kernel height */
	int kw;        /* fourth header value: kernel width */
	float *parmas; /* heap buffer of out*in*kh*kw weights */
	float *bias;   /* heap buffer of bias values */
} Conv;

/*
 * One fully connected layer as read by ReadFc.
 * parmas (sic) holds out*in weights in weight-file order; bias holds the
 * bias values that follow them.
 */
typedef struct {
	int out;       /* first header value: output count */
	int in;        /* second header value: input count */
	float *parmas; /* heap buffer of out*in weights */
	float *bias;   /* heap buffer of bias values */
} Fc;

/*
 * A flat vector of floats, presumably the input/output of a fully connected
 * layer (inferred from the name; not used in this file -- confirm).
 */
typedef struct {
	int width;     /* presumably the number of elements in parmas -- confirm */
	float *parmas; /* element storage; field name kept for compatibility */
} Fc_IO;

/*
 * The whole model: four convolution layers followed by two fully connected
 * layers, filled in this order by ReadNetwork and written in the same order
 * by LoadNetwork.
 */
typedef struct {
	Conv con1; /* first convolution layer in the file */
	Conv con2;
	Conv con3;
	Conv con4; /* last convolution layer in the file */
	Fc fc1;    /* first fully connected layer in the file */
	Fc fc2;    /* second fully connected layer in the file */
} Network;

/* Allocate a buffer of d*w*h floats described by w, h and d. The caller owns
   the returned parmas (free() it); it may be NULL if allocation fails. */
Conv_IO CreateConv(int w, int h, int d);

/* Read four convolution layers followed by two fully connected layers from
   the text file named by filename. */
Network ReadNetwork(char *filename);

/* Read the next fully connected layer from fp: "out in", out*in weights, a
   bias count, then that many biases. Buffers are heap-allocated. */
Fc ReadFc(FILE *fp);

/* Read the next convolution layer from fp: "out in kh kw", out*in*kh*kw
   weights, a bias count, then that many biases. Buffers are heap-allocated. */
Conv ReadConv(FILE *fp);

/* Convert to signed 16-bit fixed point by scaling by 128 (7 fractional bits). */
int16_t float_to_fixed(float input);

/* Convert 16-bit fixed point back to float by dividing by 128. */
float fixed_to_float(int16_t input);

/* Write every layer of nt as hex fixed-point values to "weight.txt". */
void LoadNetwork(Network nt);

#endif
#include "network.h"
#include <math.h>
#include <stdlib.h>

/*
 * Convert a float to signed 16-bit fixed point with 7 fractional bits,
 * i.e. input * 128 truncated toward zero.
 *
 * Results that do not fit in int16_t saturate to INT16_MIN / INT16_MAX and
 * NaN maps to 0. A plain cast is undefined behaviour for such values
 * (anything outside roughly [-256, 256) after scaling by 128).
 */
int16_t float_to_fixed(float input) {
    float scaled = input * 128.0f;

    if (isnan(scaled)) {
        return 0;
    }
    if (scaled >= (float)INT16_MAX) {
        return INT16_MAX;
    }
    if (scaled <= (float)INT16_MIN) {
        return INT16_MIN;
    }
    /* In range: the cast truncates toward zero exactly as before. */
    return (int16_t)scaled;
}
/*
 * Inverse of the 7-fractional-bit encoding: divide by 2^7 = 128.
 * Any int16_t divided by 128 is exactly representable as a float, so the
 * result carries no rounding error.
 */
float fixed_to_float(int16_t input) {
    const float one = 128.0f; /* fixed-point representation of 1.0 */
    float value = input;
    return value / one;
}

Conv_IO CreateConv(int w, int h, int d) 
{
    Conv_IO conv;
    conv.w = w;
    conv.h = h;
    conv.d = d;

    float * parmas;
    parmas = (float*)malloc(d*w*h*sizeof(float));

    conv.parmas = parmas;
    return conv;
}

/*
 * Read the next convolution layer from fp.
 *
 * Expected whitespace-separated text layout:
 *   out in kh kw
 *   out*in*kh*kw weight values
 *   bias count (must equal out)
 *   that many bias values
 *
 * On success, parmas and bias are heap buffers owned by the caller.
 * On malformed/short input or allocation failure, any partial allocations
 * are freed and a zeroed Conv (all dims 0, NULL buffers) is returned.
 */
Conv ReadConv(FILE* fp)
{
	Conv conv = {0};
	float *weights = NULL;
	float *biases = NULL;
	int out = 0, in = 0, kh = 0, kw = 0, bias = 0;
	size_t len = 1;

	if (fp == NULL || fscanf(fp, "%d%d%d%d", &out, &in, &kh, &kw) != 4) {
		goto fail;
	}
	if (out <= 0 || in <= 0 || kh <= 0 || kw <= 0) {
		goto fail;
	}

	/* Element count in size_t with overflow checks; the previous int
	   product kh*kw*out*in could overflow (undefined behaviour). */
	const int dims[4] = { out, in, kh, kw };
	for (int k = 0; k < 4; k++) {
		if ((size_t)dims[k] > SIZE_MAX / sizeof(float) / len) {
			goto fail;
		}
		len *= (size_t)dims[k];
	}

	weights = malloc(len * sizeof *weights);
	if (weights == NULL) {
		goto fail;
	}
	for (size_t i = 0; i < len; i++) {
		if (fscanf(fp, "%f", &weights[i]) != 1) {
			goto fail;
		}
	}

	if (fscanf(fp, "%d", &bias) != 1) {
		goto fail;
	}
	printf("bias size:%d\n", bias);
	/* One bias per output channel: LoadConv emits conv.out biases, so any
	   other count would make it read past the end of the bias buffer. */
	if (bias != out) {
		goto fail;
	}
	biases = malloc((size_t)bias * sizeof *biases);
	if (biases == NULL) {
		goto fail;
	}
	for (int i = 0; i < bias; i++) {
		if (fscanf(fp, "%f", &biases[i]) != 1) {
			goto fail;
		}
	}

	conv.out = out;
	conv.in = in;
	conv.kh = kh;
	conv.kw = kw;
	conv.parmas = weights;
	conv.bias = biases;
	return conv;

fail:
	free(weights);
	free(biases);
	return conv; /* still zero-initialised on every failure path */
}

/*
 * Write one convolution layer to f as fixed-point hex, one value per line:
 * the out*in*kh*kw weights, then conv.out biases, then a blank-line separator.
 *
 * NOTE(review): values are printed as unsigned int, so on platforms with a
 * 32-bit int a negative value is sign-extended (e.g. -1.0 -> "ffffff80"
 * rather than a 16-bit "ff80"). This preserves the existing weight.txt
 * contents; confirm the consumer expects it before narrowing to 16 bits.
 */
void LoadConv(Conv conv, FILE *f) {
	/* Use the kernel size ReadConv actually read (kh*kw) instead of a
	   hard-coded 9, which over-read parmas for kernels smaller than 3x3 and
	   dropped weights for larger ones. Computed in size_t to avoid int
	   overflow; non-positive dimensions mean there is nothing to write. */
	size_t nweights = 0;
	if (conv.in > 0 && conv.out > 0 && conv.kh > 0 && conv.kw > 0) {
		nweights = (size_t)conv.in * (size_t)conv.out
		         * (size_t)conv.kh * (size_t)conv.kw;
	}

	for (size_t i = 0; i < nweights; i++) {
		/* %x requires an unsigned int; the conversion keeps the same bits. */
		fprintf(f, "%x\n", (unsigned int)float_to_fixed(conv.parmas[i]));
	}
	for (int i = 0; i < conv.out; i++) {
		fprintf(f, "%x\n", (unsigned int)float_to_fixed(conv.bias[i]));
	}
	fprintf(f, "\n\n");
} 

/*
 * Write one fully connected layer to f as fixed-point hex, one value per
 * line: the in*out weights, then fc.out biases, then a blank-line separator.
 *
 * NOTE(review): values are printed as unsigned int, so on platforms with a
 * 32-bit int a negative value is sign-extended (e.g. -1.0 -> "ffffff80").
 * This preserves the existing weight.txt contents; confirm the consumer
 * expects it before narrowing to 16 bits.
 */
void LoadFully(Fc fc, FILE *f) {
	/* size_t product avoids int overflow; non-positive dims write nothing. */
	size_t nweights = 0;
	if (fc.in > 0 && fc.out > 0) {
		nweights = (size_t)fc.in * (size_t)fc.out;
	}

	for (size_t i = 0; i < nweights; i++) {
		/* %x requires an unsigned int; the conversion keeps the same bits. */
		fprintf(f, "%x\n", (unsigned int)float_to_fixed(fc.parmas[i]));
	}
	for (int i = 0; i < fc.out; i++) {
		fprintf(f, "%x\n", (unsigned int)float_to_fixed(fc.bias[i]));
	}
	fprintf(f, "\n\n");
}

/*
 * Write every layer of nt to "weight.txt" (relative to the current working
 * directory), in the order con1..con4, fc1, fc2.
 *
 * If the file cannot be opened, or flushing/closing it fails, an error is
 * reported on stderr via perror. Previously a failed fopen passed NULL to
 * fprintf.
 */
void LoadNetwork(Network nt) {
	FILE *f = fopen("weight.txt", "w");
	if (f == NULL) {
		perror("weight.txt");
		return;
	}

	LoadConv(nt.con1, f);
	LoadConv(nt.con2, f);
	LoadConv(nt.con3, f);
	LoadConv(nt.con4, f);
	LoadFully(nt.fc1, f);
	LoadFully(nt.fc2, f);

	/* fclose flushes buffered output; a failure here means data was lost. */
	if (fclose(f) != 0) {
		perror("weight.txt");
	}
}

/*
 * Read the next fully connected layer from fp.
 *
 * Expected whitespace-separated text layout:
 *   out in
 *   out*in weight values
 *   bias count (must equal out)
 *   that many bias values
 *
 * On success, parmas and bias are heap buffers owned by the caller.
 * On malformed/short input or allocation failure, any partial allocations
 * are freed and a zeroed Fc (dims 0, NULL buffers) is returned.
 */
Fc ReadFc(FILE* fp)
{
	Fc fc = {0};
	float *weights = NULL;
	float *biases = NULL;
	int out = 0, in = 0, bias = 0;
	size_t len = 0;

	if (fp == NULL || fscanf(fp, "%d%d", &out, &in) != 2) {
		goto fail;
	}
	printf("Reading fully: %d %d\n", out, in);

	/* Reject non-positive sizes and counts whose byte size would overflow
	   size_t; the previous int product out*in could overflow (UB). */
	if (out <= 0 || in <= 0
	    || (size_t)in > SIZE_MAX / sizeof(float) / (size_t)out) {
		goto fail;
	}
	len = (size_t)out * (size_t)in;

	weights = malloc(len * sizeof *weights);
	if (weights == NULL) {
		goto fail;
	}
	for (size_t i = 0; i < len; i++) {
		if (fscanf(fp, "%f", &weights[i]) != 1) {
			goto fail;
		}
	}

	/* One bias per output neuron: LoadFully emits fc.out biases, so any
	   other count would make it read past the end of the bias buffer. */
	if (fscanf(fp, "%d", &bias) != 1 || bias != out) {
		goto fail;
	}
	biases = malloc((size_t)bias * sizeof *biases);
	if (biases == NULL) {
		goto fail;
	}
	for (int i = 0; i < bias; i++) {
		if (fscanf(fp, "%f", &biases[i]) != 1) {
			goto fail;
		}
	}

	fc.out = out;
	fc.in = in;
	fc.parmas = weights;
	fc.bias = biases;
	return fc;

fail:
	free(weights);
	free(biases);
	return fc; /* still zero-initialised on every failure path */
}

/*
 * Read a whole network from the text file named by filename: four
 * convolution layers followed by two fully connected layers, in that order.
 *
 * If filename is NULL or the file cannot be opened, an error is reported
 * (perror) and a zero-initialised Network is returned. The file is closed
 * before returning; previously the stream was leaked.
 */
Network ReadNetwork(char* filename)
{
	Network nt = {0};

	if (filename == NULL) {
		return nt;
	}
	FILE *fp = fopen(filename, "r");
	if (fp == NULL) {
		perror(filename);
		return nt;
	}

	nt.con1 = ReadConv(fp);
	nt.con2 = ReadConv(fp);
	nt.con3 = ReadConv(fp);
	nt.con4 = ReadConv(fp);
	nt.fc1 = ReadFc(fp);
	nt.fc2 = ReadFc(fp);

	fclose(fp);
	return nt;
}
