#ifndef _MEMORY_H
#define _MEMORY_H

#include <linux/ioctl.h>

#define MEMORY_MAGIC 'M'

/*
 * Argument for MEM_READ / MEM_WRITE: access to a single 16-bit word.
 * The layout is shared with userspace, so fields must not be reordered.
 */
typedef struct {
	int pos;	/* word index; the driver accesses byte offset pos * 2 */
	short value;	/* value to write (MEM_WRITE) or value read (MEM_READ) */
} mem_arg_t;

/*
 * Argument for CONV_WRITE and FC_WRITE: one layer to run on the accelerator.
 * The pointer members are userspace addresses. The driver copies
 * num_input / num_weights / num_bias shorts in from them and copies the
 * layer result back to the buffer at output. The layout is shared with
 * userspace, so fields must not be reordered.
 *
 * NOTE(review): the CONV path's write_conv_arg() reads
 * src_chan * cols * rows elements of the input, independently of
 * num_input, so the two must agree.
 * NOTE(review): the pointer members make sizeof(conv_arg_t), and therefore
 * the CONV_WRITE/FC_WRITE command numbers, depend on the userspace pointer
 * size. No compat_ioctl handler is registered in this file.
 */
typedef struct {
	int num_input;		/* number of shorts copied from input_p */
	int num_bias;		/* number of shorts copied from bias */
	int num_weights;	/* number of shorts copied from network */
	int src_chan;		/* written to SRCCHAN; CONV: number of input planes */
	int dst_chan;		/* written to DSTCHAN; FC: output length */
	int cols;		/* written to NUMCOL (CONV writes cols + 2) */
	int rows;		/* written to NUMROW (CONV writes rows + 2) */
	int pool_size;		/* CONV only: 0 disables pooling, else POOLSZ value */
	short *input_p;		/* user pointer: input data */
	short *network;		/* user pointer: weights */
	short *bias;		/* user pointer: bias values */
	short *output;		/* user pointer: receives the layer output */
} conv_arg_t;

/*
 * ioctl command numbers for /dev/memory. _IOW/_IOR encode the direction
 * bits and sizeof() the argument type, so these values are userspace ABI.
 * NOTE(review): MEM_READ is declared _IOR, yet the driver also copies the
 * argument in (to obtain pos). CONV_WRITE/FC_WRITE also return data through
 * conv_arg_t.output. Correcting the direction bits would change the command
 * numbers, so they are left unchanged.
 */
#define MEM_WRITE _IOW(MEMORY_MAGIC, 1, mem_arg_t)
#define CONV_WRITE _IOW(MEMORY_MAGIC, 3, conv_arg_t)
#define FC_WRITE _IOW(MEMORY_MAGIC, 4, conv_arg_t)
#define MEM_READ _IOR(MEMORY_MAGIC, 2, mem_arg_t)
#include <linux/module.h>
#include <linux/init.h>
#include <linux/errno.h>
#include <linux/err.h>
#include <linux/version.h>
#include <linux/kernel.h>
#include <linux/platform_device.h>
#include <linux/miscdevice.h>
#include <linux/mutex.h>
#include <linux/slab.h>
#include <linux/string.h>
#include <linux/io.h>
#include <linux/of.h>
#include <linux/of_address.h>
#include <linux/fs.h>
#include <linux/uaccess.h>
#include "memory.h"

#define DRIVER_NAME "memory"

/* NOTE(review): not referenced in this file; its unit is unknown. */
#define MEM_SIZE 163840

/* Address of 16-bit word @y in the window whose base is @x (2 bytes/word). */
#define MEM(x, y) ((x) + (y) * 2)

/*
 * Parameter/control registers: consecutive 16-bit registers starting at
 * word offset 1 << 19 of the window. CNN_REG(x, n) is the n-th register.
 */
#define CNN_REG_BASE(x)	((x) + (1 << 19) * 2)
#define CNN_REG(x, n)	(CNN_REG_BASE(x) + (n) * 2)

#define SRCCHAN(x)	CNN_REG(x, 0)
#define DSTCHAN(x)	CNN_REG(x, 1)
#define NUMCOL(x)	CNN_REG(x, 2)
#define NUMROW(x)	CNN_REG(x, 3)
#define FC(x)		CNN_REG(x, 4)
#define DOPOOL(x)	CNN_REG(x, 5)
#define POOLSZ(x)	CNN_REG(x, 6)
#define POOLSD(x)	CNN_REG(x, 7)
#define DATASL(x)	CNN_REG(x, 8)
#define DATASH(x)	CNN_REG(x, 9)
#define WEITSL(x)	CNN_REG(x, 10)
#define WEITSH(x)	CNN_REG(x, 11)
#define BIASSL(x)	CNN_REG(x, 12)
#define BIASSH(x)	CNN_REG(x, 13)
#define OUTPSL(x)	CNN_REG(x, 14)
#define OUTPSH(x)	CNN_REG(x, 15)
#define START(x)	CNN_REG(x, 16)

/* State of the single accelerator instance this driver drives. */
struct memory_dev {
	struct resource res;	/* register window described by the device tree */
	void __iomem *virtbase;	/* kernel mapping of that window */
};

struct memory_dev dev;
short read_value(int pos) {
	return ioread16(MEM(dev.virtbase, pos));
}

void write_value(int pos, short value) {
	iowrite16(value, MEM(dev.virtbase, pos));
}

/*
 * write_fc_arg - load a fully connected layer into the accelerator and start it.
 * @conv_arg: layer description whose input_p/network/bias point to kernel memory
 *
 * Packs num_input input words, num_weights weight words and num_bias bias
 * words back to back from word 0 of the window. It then programs the shape
 * registers, FC = 1 with pooling disabled, and the start word of each region
 * (the output region follows the bias). Finally it sets START.
 */
void write_fc_arg(conv_arg_t conv_arg)
{
	void __iomem *base = dev.virtbase;
	const int data_start = 0;
	const int weight_start = conv_arg.num_input;
	const int bias_start = weight_start + conv_arg.num_weights;
	const int output_start = bias_start + conv_arg.num_bias;
	int word = 0;
	int n;

	for (n = 0; n < conv_arg.num_input; n++, word++)
		iowrite16(conv_arg.input_p[n], MEM(base, word));
	for (n = 0; n < conv_arg.num_weights; n++, word++)
		iowrite16(conv_arg.network[n], MEM(base, word));
	for (n = 0; n < conv_arg.num_bias; n++, word++)
		iowrite16(conv_arg.bias[n], MEM(base, word));
	printk("here\n");

	/* Layer shape. */
	iowrite16(conv_arg.src_chan, SRCCHAN(base));
	iowrite16(conv_arg.dst_chan, DSTCHAN(base));
	iowrite16(conv_arg.cols, NUMCOL(base));
	iowrite16(conv_arg.rows, NUMROW(base));

	/* Mode: fully connected, no pooling. */
	iowrite16(1, FC(base));
	iowrite16(0, DOPOOL(base));
	iowrite16(0, POOLSZ(base));
	iowrite16(0, POOLSD(base));

	/* Region start words, each split into low and high 16-bit halves. */
	iowrite16((short)(data_start & 0xFFFF), DATASL(base));
	iowrite16((short)(data_start >> 16), DATASH(base));
	iowrite16((short)(weight_start & 0xFFFF), WEITSL(base));
	iowrite16((short)(weight_start >> 16), WEITSH(base));
	iowrite16((short)(bias_start & 0xFFFF), BIASSL(base));
	iowrite16((short)(bias_start >> 16), BIASSH(base));
	iowrite16((short)(output_start & 0xFFFF), OUTPSL(base));
	iowrite16((short)(output_start >> 16), OUTPSH(base));

	iowrite16(1, START(base));
}

/*
 * read_fc - wait for a fully connected layer and fetch its output.
 * @conv_arg: layer description previously passed to write_fc_arg()
 *
 * Busy-waits until the FC register reads non-zero. It then copies dst_chan
 * output words into a newly kmalloc()ed buffer stored in conv_arg->output.
 * The output starts right after the input, weight and bias data, at word
 * num_input + num_bias + num_weights. The caller owns and must kfree() the
 * buffer. On allocation failure conv_arg->output is NULL and nothing is read.
 *
 * NOTE(review): the wait has no timeout. write_fc_arg() writes 1 to this
 * same register before starting, so whether this actually waits depends on
 * what the hardware returns on read; confirm against the hardware.
 */
void read_fc(conv_arg_t *conv_arg)
{
	int output_size = conv_arg->dst_chan;
	int s = conv_arg->num_input + conv_arg->num_bias + conv_arg->num_weights;
	int i;

	while (!ioread16(FC(dev.virtbase)))
		;

	pr_debug("%d\n", output_size);
	/* kmalloc_array() also rejects a negative (huge as size_t) count. */
	conv_arg->output = kmalloc_array(output_size, sizeof(short), GFP_KERNEL);
	if (!conv_arg->output)
		return;

	for (i = 0; i < output_size; i++)
		conv_arg->output[i] = ioread16(MEM(dev.virtbase, i + s));

	pr_debug("finish\n");
}

/*
 * write_conv_arg - load a convolution layer into the accelerator and start it.
 * @conv_arg: layer description whose input_p/network/bias point to kernel memory
 *
 * Each of the src_chan input planes is written as a (cols + 2) x (rows + 2)
 * plane whose outer ring is zero. For non-negative cols/rows this consumes
 * src_chan * cols * rows elements of input_p. Weights and bias follow
 * immediately. The registers are then programmed with the padded
 * dimensions, FC = 0, pooling enabled iff pool_size != 0 (pool stride
 * register fixed at 2), and the start word of each region. Finally START
 * is set.
 */
void write_conv_arg(conv_arg_t conv_arg)
{
	void __iomem *base = dev.virtbase;
	const int chans = conv_arg.src_chan;
	const int width = conv_arg.cols + 2;
	const int height = conv_arg.rows + 2;
	const short *src = conv_arg.input_p;
	int data_start, weight_start, bias_start, output_start;
	int word = 0;
	int c, x, y, n;

	for (c = 0; c < chans; c++) {
		for (x = 0; x < width; x++) {
			for (y = 0; y < height; y++) {
				int on_border = x == 0 || x == width - 1 ||
						y == 0 || y == height - 1;
				short val = 0;

				if (!on_border)
					val = *src++;
				iowrite16(val, MEM(base, word));
				word++;
			}
		}
	}
	for (n = 0; n < conv_arg.num_weights; n++, word++)
		iowrite16(conv_arg.network[n], MEM(base, word));
	for (n = 0; n < conv_arg.num_bias; n++, word++)
		iowrite16(conv_arg.bias[n], MEM(base, word));
	printk("here\n");

	/* Layer shape, using the padded plane dimensions. */
	iowrite16(conv_arg.src_chan, SRCCHAN(base));
	iowrite16(conv_arg.dst_chan, DSTCHAN(base));
	iowrite16(width, NUMCOL(base));
	iowrite16(height, NUMROW(base));

	data_start = 0;
	weight_start = chans * width * height;
	bias_start = weight_start + conv_arg.num_weights;
	output_start = bias_start + conv_arg.num_bias;

	/* Mode: convolution, optional pooling. */
	iowrite16(0, FC(base));
	iowrite16(conv_arg.pool_size != 0 ? 1 : 0, DOPOOL(base));
	iowrite16(conv_arg.pool_size, POOLSZ(base));
	iowrite16(2, POOLSD(base));

	/* Region start words, each split into low and high 16-bit halves. */
	iowrite16((short)(data_start & 0xFFFF), DATASL(base));
	iowrite16((short)(data_start >> 16), DATASH(base));
	iowrite16((short)(weight_start & 0xFFFF), WEITSL(base));
	iowrite16((short)(weight_start >> 16), WEITSH(base));
	iowrite16((short)(bias_start & 0xFFFF), BIASSL(base));
	iowrite16((short)(bias_start >> 16), BIASSH(base));
	iowrite16((short)(output_start & 0xFFFF), OUTPSL(base));
	iowrite16((short)(output_start >> 16), OUTPSH(base));

	iowrite16(1, START(base));
}

/*
 * read_conv - wait for a convolution layer and fetch its output.
 * @conv_arg: layer description previously passed to write_conv_arg()
 *
 * Busy-waits until the FC register reads non-zero. It then copies
 * dst_chan * cols * rows output words, divided by pool_size^2 when
 * pool_size is non-zero, into a newly kmalloc()ed buffer stored in
 * conv_arg->output. The output starts right after the zero-padded input,
 * weights and bias laid out by write_conv_arg(). The caller owns and must
 * kfree() the buffer. On allocation failure conv_arg->output is NULL and
 * nothing is read.
 *
 * NOTE(review): the wait has no timeout. write_conv_arg() writes 0 to FC
 * before starting, so this presumably relies on the hardware setting FC
 * when the layer is done; confirm against the hardware.
 */
void read_conv(conv_arg_t *conv_arg)
{
	int d = conv_arg->src_chan;
	int w = conv_arg->cols + 2;
	int h = conv_arg->rows + 2;
	int s = d * w * h + conv_arg->num_bias + conv_arg->num_weights;
	int output_size = conv_arg->dst_chan * conv_arg->cols * conv_arg->rows;
	int i;

	while (!ioread16(FC(dev.virtbase)))
		;

	if (conv_arg->pool_size)
		output_size /= conv_arg->pool_size * conv_arg->pool_size;
	pr_debug("poolsize:%d\n", conv_arg->pool_size);
	pr_debug("%d\n", output_size);

	/* kmalloc_array() also rejects a negative (huge as size_t) count. */
	conv_arg->output = kmalloc_array(output_size, sizeof(short), GFP_KERNEL);
	if (!conv_arg->output)
		return;

	pr_debug("%d\n", s);
	for (i = 0; i < output_size; i++)
		conv_arg->output[i] = ioread16(MEM(dev.virtbase, i + s));

	pr_debug("finish\n");
}


/*
 * Handle ioctl() calls from userspace:
 * Read or write the segments on single digits.
 * Note extensive error checking of arguments
 */
static long memory_ioctl(struct file *f, unsigned int cmd, unsigned long arg)
{
	mem_arg_t vla;
	int i;
	conv_arg_t conv_arg;
	conv_arg_t arg2;

	switch (cmd) {
		case CONV_WRITE:
			if (copy_from_user(&conv_arg, (conv_arg_t *) arg, sizeof(conv_arg_t)))
				return -EACCES;
			arg2.input_p = (short*)kmalloc(sizeof(short)*conv_arg.num_input, GFP_KERNEL);
			arg2.network = (short*)kmalloc(sizeof(short)*conv_arg.num_weights, GFP_KERNEL);
			arg2.bias = (short*)kmalloc(sizeof(short)*conv_arg.num_bias, GFP_KERNEL);
			arg2.output = conv_arg.output;
			if (copy_from_user(arg2.input_p, conv_arg.input_p, sizeof(short)*conv_arg.num_input))
				return -EACCES;
			if (copy_from_user(arg2.network, conv_arg.network, sizeof(short)*conv_arg.num_weights))
				return -EACCES;
			if (copy_from_user(arg2.bias, conv_arg.bias, sizeof(short)*conv_arg.num_bias))
				return -EACCES;
			conv_arg.input_p = arg2.input_p;
			conv_arg.network = arg2.network;
			conv_arg.bias = arg2.bias;
			write_conv_arg(conv_arg);
			read_conv(&conv_arg);
			int output_size = conv_arg.dst_chan*(conv_arg.cols)*(conv_arg.rows);
			printk("%d %d %d", conv_arg.cols, conv_arg.rows, conv_arg.dst_chan);
			if (conv_arg.pool_size) {
				output_size /= conv_arg.pool_size * conv_arg.pool_size;
			}
			if (copy_to_user(arg2.output, conv_arg.output, sizeof(short)*output_size)) {
				return -EACCES;
			}
			break;

		case FC_WRITE:
			if (copy_from_user(&conv_arg, (conv_arg_t *) arg, sizeof(conv_arg_t)))
				return -EACCES;
				printk("%d\n", conv_arg.num_input);
			arg2.input_p = (short*)kmalloc(sizeof(short)*conv_arg.num_input, GFP_KERNEL);
			arg2.network = (short*)kmalloc(sizeof(short)*conv_arg.num_weights, GFP_KERNEL);
			arg2.bias = (short*)kmalloc(sizeof(short)*conv_arg.num_bias, GFP_KERNEL);
			arg2.output = conv_arg.output;
			if (copy_from_user(arg2.input_p, conv_arg.input_p, sizeof(short)*conv_arg.num_input))
				return -EACCES;
			if (copy_from_user(arg2.network, conv_arg.network, sizeof(short)*conv_arg.num_weights))
				return -EACCES;
			if (copy_from_user(arg2.bias, conv_arg.bias, sizeof(short)*conv_arg.num_bias))
				return -EACCES;
			conv_arg.input_p = arg2.input_p;
			conv_arg.network = arg2.network;
			conv_arg.bias = arg2.bias;
			write_fc_arg(conv_arg);		
			read_fc(&conv_arg);
			printk("%d\n", conv_arg.dst_chan);
			if (copy_to_user(arg2.output, conv_arg.output, sizeof(short)*conv_arg.dst_chan)) {
				return -EACCES;
			}
			break;

		case MEM_READ:
			if (copy_from_user(&vla, (mem_arg_t *) arg, sizeof(mem_arg_t)))
				return -EACCES;
			vla.value = read_value(vla.pos);
			if (copy_to_user((mem_arg_t *) arg, &vla, sizeof(mem_arg_t)))
				return -EACCES;
			break;
		case MEM_WRITE:
			if (copy_from_user(&vla, (mem_arg_t *) arg, sizeof(mem_arg_t)))
				return -EACCES;
			write_value(vla.pos, vla.value);
			break;
	}
	return 0;
}
/* File operations for the misc device: only ioctl is implemented. */
static const struct file_operations memory_fops = {
	.unlocked_ioctl	= memory_ioctl,
	.owner		= THIS_MODULE,
};

/*
 * Misc ("character-like") device registration. The node is named
 * DRIVER_NAME, gets a dynamically assigned minor, and dispatches to
 * memory_fops.
 */
static struct miscdevice memory_misc_device = {
	.fops		= &memory_fops,
	.name		= DRIVER_NAME,
	.minor		= MISC_DYNAMIC_MINOR,
};

/*
 * memory_probe - claim and map the accelerator's register window, then
 * expose it to userspace as a misc device (/dev/memory).
 *
 * The misc device is registered last, so userspace can only reach
 * memory_ioctl() once dev.virtbase is valid. misc_register()'s result is
 * checked, and each failure unwinds only what was already acquired.
 *
 * Return: 0, -ENOENT (no address in the device tree), -EBUSY (region
 * already claimed), -ENOMEM (mapping failed), or misc_register()'s error.
 */
static int __init memory_probe(struct platform_device *pdev)
{
	int ret;

	/* Get the address of our registers from the device tree. */
	ret = of_address_to_resource(pdev->dev.of_node, 0, &dev.res);
	if (ret)
		return -ENOENT;

	/* Make sure nobody else is using these registers. */
	if (request_mem_region(dev.res.start, resource_size(&dev.res),
			       DRIVER_NAME) == NULL)
		return -EBUSY;

	/* Arrange access to our registers. */
	dev.virtbase = of_iomap(pdev->dev.of_node, 0);
	if (dev.virtbase == NULL) {
		ret = -ENOMEM;
		goto out_release_mem_region;
	}

	/* Only now create /dev/memory, once the registers are usable. */
	ret = misc_register(&memory_misc_device);
	if (ret)
		goto out_unmap;

	return 0;

out_unmap:
	iounmap(dev.virtbase);
	dev.virtbase = NULL;
out_release_mem_region:
	release_mem_region(dev.res.start, resource_size(&dev.res));
	return ret;
}

/* Clean-up code: release resources */
static int memory_remove(struct platform_device *pdev)
{
	iounmap(dev.virtbase);
	release_mem_region(dev.res.start, resource_size(&dev.res));
	misc_deregister(&memory_misc_device);
	return 0;
}

/* Device Tree "compatible" strings this driver binds to (CONFIG_OF only). */
#ifdef CONFIG_OF
static const struct of_device_id memory_of_match[] = {
	{ .compatible = "csee4840,cnn_interface-1.0" },
	{ /* sentinel */ },
};
MODULE_DEVICE_TABLE(of, memory_of_match);
#endif

/*
 * Platform driver description. There is no .probe member: memory_init()
 * passes memory_probe() to platform_driver_probe() instead.
 */
static struct platform_driver memory_driver = {
	.remove	= __exit_p(memory_remove),
	.driver	= {
		.of_match_table	= of_match_ptr(memory_of_match),
		.owner		= THIS_MODULE,
		.name		= DRIVER_NAME,
	},
};

/*
 * Module entry point: logs a message, then registers memory_driver and
 * probes it once through platform_driver_probe() with memory_probe().
 */
static int __init memory_init(void)
{
	int err;

	pr_info(DRIVER_NAME ": init\n");
	err = platform_driver_probe(&memory_driver, memory_probe);
	return err;
}

/*
 * Called when the module is unloaded: unregisters the platform driver
 * (which runs memory_remove()), then logs a message.
 */
static void __exit memory_exit(void)
{
	platform_driver_unregister(&memory_driver);
	pr_info(DRIVER_NAME ": exit\n");
}

/* Module metadata. */
MODULE_LICENSE("GPL");
MODULE_AUTHOR("Xincheng Yu, Embedded System");
MODULE_DESCRIPTION("memory driver");

/* Load/unload hooks. */
module_init(memory_init);
module_exit(memory_exit);

#endif
