// cnn_interface
//
// Bus-facing wrapper that shares one single-port block RAM between a host
// port (signals prefixed hps_) and the `conv` compute engine.
//
// Address decoding performed here:
//   address[19] == 0 : address[18:0] is a word address in the shared BRAM
//                      (host read when write == 0, host write when write == 1).
//   address[19] == 1 : writes are forwarded to the conv configuration
//                      registers, selected by address[4:0].
// readdata returns the BRAM word one cycle after a host BRAM read; in every
// other cycle it returns the engine's `valid` flag zero-extended to 16 bits.
//
// BRAM port arbitration:
//   * any write (host or engine) takes the port in preference to a read,
//   * a host write beats an engine write,
//   * an engine read beats a host read.
// NOTE(review): because engine reads win, a host BRAM read issued while the
// engine is reading returns the engine's data, not the host's — confirm the
// host only reads the BRAM while the engine is idle.
module cnn_interface (
    input  logic        clk,
    input  logic        reset,
    input  logic [15:0] writedata,
    input  logic        write,
    input  logic        chipselect,
    input  logic [19:0] address,
    output logic [15:0] readdata
);

    // Number of 16-bit words in the shared BRAM.
    parameter BRAM_SIZE = 19'd174080;

    // ---------------------------------------------------------------
    // Single BRAM port
    // ---------------------------------------------------------------
    logic [15:0] bram_wr_data;
    logic        bram_we;
    logic [18:0] bram_addr;
    logic [15:0] bram_rd_data;
    logic [18:0] bram_rd_addr;
    logic [18:0] bram_wr_addr;

    // ---------------------------------------------------------------
    // Host-side request decode
    // ---------------------------------------------------------------
    logic [15:0] hps_mem_wr_data;
    logic        hps_mem_wr;
    logic [18:0] hps_mem_wr_addr;
    logic        hps_mem_rd;
    logic [18:0] hps_mem_rd_addr;

    always_comb begin
        hps_mem_wr      = chipselect && write && (~address[19]);
        hps_mem_rd      = chipselect && ~write && (~address[19]);
        // Address/data are forced to zero when the matching request is idle.
        hps_mem_wr_addr = hps_mem_wr ? address[18:0] : 19'd0;
        hps_mem_wr_data = hps_mem_wr ? writedata     : 16'd0;
        hps_mem_rd_addr = hps_mem_rd ? address[18:0] : 19'd0;
    end

    bram #(.ADDR_WIDTH(19), .ENTRY_NUMBER(BRAM_SIZE), .DATA_WIDTH(16)) bram_inst (
        .clk   (clk),
        .we    (bram_we),
        .addr  (bram_addr),
        .wdata (bram_wr_data),
        .rdata (bram_rd_data)
    );

    // ---------------------------------------------------------------
    // Compute engine
    // ---------------------------------------------------------------
    logic [15:0] conv_wr_data;
    logic        conv_wr;
    logic [18:0] conv_wr_addr;
    logic        conv_rd;
    logic [18:0] conv_rd_addr;
    logic        valid;

    conv conv_inst (
        .clk            (clk),
        .reset          (reset),
        .mem_rd_data_i  (bram_rd_data),
        .mem_rd_addr_o  (conv_rd_addr),
        .mem_rd         (conv_rd),
        .mem_wr_data_o  (conv_wr_data),
        .mem_wr_addr_o  (conv_wr_addr),
        .mem_wr         (conv_wr),
        // The configuration address port is 5 bits wide; select exactly
        // those bits instead of relying on implicit truncation.
        .config_addr_i  (address[4:0]),
        .config_write_i (write && address[19]),
        .chipselect_i   (chipselect),
        .config_data_i  (writedata),
        .valid          (valid)
    );

    // ---------------------------------------------------------------
    // Write-side arbitration: host first, engine otherwise
    // ---------------------------------------------------------------
    always_comb begin
        if (hps_mem_wr) begin
            bram_we      = 1'b1;
            bram_wr_addr = hps_mem_wr_addr;
            bram_wr_data = hps_mem_wr_data;
        end else begin
            bram_we      = conv_wr;
            bram_wr_addr = conv_wr_addr;
            bram_wr_data = conv_wr_data;
        end
    end

    // ---------------------------------------------------------------
    // Read-side arbitration: engine first, host otherwise
    // ---------------------------------------------------------------
    assign bram_rd_addr = conv_rd ? conv_rd_addr : hps_mem_rd_addr;

    // The BRAM port carries the write address whenever a write is active.
    assign bram_addr = bram_we ? bram_wr_addr : bram_rd_addr;

    // The BRAM read is synchronous, so remember that a host read was issued
    // in order to present its data in the following cycle.
    logic hps_mem_rd_delayed;

    always_ff @(posedge clk) begin
        if (reset) begin
            hps_mem_rd_delayed <= 1'b0;
        end else begin
            hps_mem_rd_delayed <= hps_mem_rd;
        end
    end

    // BRAM data one cycle after a host read, otherwise the status flag.
    assign readdata = hps_mem_rd_delayed ? bram_rd_data : {15'd0, valid};

endmodule

// bram
//
// Single-port memory with one shared address for reads and writes.
//   * On every rising clock edge `rdata` is loaded with the word stored at
//     `addr`, so read data is available one cycle after the address.
//   * When `we` is high on that edge, `wdata` is also stored at `addr`;
//     `rdata` then receives the word that was stored before the write.
//
// Parameters:
//   ADDR_WIDTH   - width of `addr` in bits
//   ENTRY_NUMBER - number of words in the memory
//   DATA_WIDTH   - width of each word in bits
module bram #(
    parameter ADDR_WIDTH   = 19,
    parameter ENTRY_NUMBER = 174080,
    parameter DATA_WIDTH   = 16
)(
    input  logic                   clk,
    input  logic                   we,         // Write enable
    input  logic [ADDR_WIDTH-1:0]  addr,
    input  logic [DATA_WIDTH-1:0]  wdata,
    output logic [DATA_WIDTH-1:0]  rdata
);

    // Storage array; the attribute asks Quartus to map it onto block RAM.
    (* ramstyle = "block" *)
    logic [DATA_WIDTH-1:0] storage [0:ENTRY_NUMBER-1];

    always_ff @(posedge clk) begin
        // Registered read: samples the pre-write contents of the location.
        rdata <= storage[addr];

        if (we) begin
            storage[addr] <= wdata;
        end
    end

endmodule

// conv
//
// Sequential CNN layer engine working on an external word-addressed memory
// with a one-cycle read latency (read data for an address issued in cycle N
// is consumed in cycle N+1).
//
// Operating modes, chosen by the configuration registers:
//   * fc == 0 : 3x3 convolution without padding. For each destination channel
//               and each source channel the engine loads nine weights, then
//               for every interior pixel accumulates the nine products on top
//               of the partial sum previously written for that output pixel.
//               On the last source channel the bias is added and negative
//               results are clamped to zero before write-back.
//               If do_pool == 1, a max-pool pass runs over the finished
//               output channel.
//   * fc == 1 : fully connected layer; num_src_chans inputs and
//               num_dst_chans outputs, one multiply-accumulate at a time.
//
// Arithmetic: products are accumulated in 32 bits and shifted right by 8
// before being stored as 16-bit words; biases and partial sums are shifted
// left by 8 when re-entering the accumulator.
// NOTE(review): this looks like fixed-point with 8 fractional bits — confirm
// against the software that prepares the data.
//
// Configuration register map (config_addr_i, written when chipselect_i and
// config_write_i are both high):
//    0 num_src_chans      1 num_dst_chans     2 num_cols        3 num_rows
//    4 fc (bit 0)         5 do_pool (bit 0)   6 pool_size       7 pool_stride
//    8/9   data_starting_addr   low/high half
//   10/11  weight_starting_addr low/high half
//   12/13  bias_starting_addr   low/high half
//   14/15  output_starting_addr low/high half
//   16     any write raises `start` for one cycle
// pool_stride is stored but not used by the logic in this module.
//
// `valid` is set when a layer completes and cleared when the next `start`
// is accepted in the idle state.
module conv (
	input clk,
	input reset,
	input 	[15:0] 	mem_rd_data_i,
	output 	[18:0] 	mem_rd_addr_o,
	output			mem_rd,
	output  [15:0]	mem_wr_data_o,
	output  [18:0]	mem_wr_addr_o,
	output			mem_wr,
	input	[4:0]	config_addr_i,
	input			config_write_i,
	input			chipselect_i,
	input	[15:0]	config_data_i,
	output logic	valid
);

    // =================================================================
    // Configuration registers
    // =================================================================
    logic [15:0] num_src_chans;
    logic [15:0] num_dst_chans;
    logic [15:0] num_cols;
    logic [15:0] num_rows;
    logic        do_pool;
    logic        fc;
    logic [15:0] pool_size;
    logic [15:0] pool_stride;
    logic [31:0] data_starting_addr;
    logic [31:0] weight_starting_addr;
    logic [31:0] bias_starting_addr;
    logic [31:0] output_starting_addr;
    logic        start;

    always_ff @(posedge clk) begin
        if (reset) begin
            num_src_chans        <= 'd0;
            num_dst_chans        <= 'd0;
            num_cols             <= 'd0;
            num_rows             <= 'd0;
            do_pool              <= 1'b0;
            fc                   <= 1'b0;
            pool_size            <= 'd0;
            pool_stride          <= 'd0;
            data_starting_addr   <= 'd0;
            weight_starting_addr <= 'd0;
            bias_starting_addr   <= 'd0;
            output_starting_addr <= 'd0;
            start                <= 1'b0;
        end else begin
            // `start` is self-clearing: it is high only in the cycle after a
            // write to register 16, even if another configuration register
            // is written immediately afterwards.
            start <= 1'b0;

            if (chipselect_i && config_write_i) begin
                case (config_addr_i)
                    5'd0:  num_src_chans               <= config_data_i;
                    5'd1:  num_dst_chans               <= config_data_i;
                    5'd2:  num_cols                    <= config_data_i;
                    5'd3:  num_rows                    <= config_data_i;
                    5'd4:  fc                          <= config_data_i[0];
                    5'd5:  do_pool                     <= config_data_i[0];
                    5'd6:  pool_size                   <= config_data_i;
                    5'd7:  pool_stride                 <= config_data_i;
                    5'd8:  data_starting_addr[15:0]    <= config_data_i;
                    5'd9:  data_starting_addr[31:16]   <= config_data_i;
                    5'd10: weight_starting_addr[15:0]  <= config_data_i;
                    5'd11: weight_starting_addr[31:16] <= config_data_i;
                    5'd12: bias_starting_addr[15:0]    <= config_data_i;
                    5'd13: bias_starting_addr[31:16]   <= config_data_i;
                    5'd14: output_starting_addr[15:0]  <= config_data_i;
                    5'd15: output_starting_addr[31:16] <= config_data_i;
                    5'd16: start                       <= 1'b1;
                    default: ;
                endcase
            end
        end
    end

    // =================================================================
    // State encoding
    // =================================================================
    localparam logic [7:0] S_INIT               = 8'd0;
    localparam logic [7:0] S_CONV               = 8'd1;
    localparam logic [7:0] S_SET_WEIGHT_ADDR    = 8'd2;
    localparam logic [7:0] S_LOAD_WEIGHT        = 8'd3;
    localparam logic [7:0] S_COMPUTE_POS        = 8'd4;
    localparam logic [7:0] S_KERNEL_LOOP        = 8'd5;
    localparam logic [7:0] S_SET_PREV_ADDR      = 8'd7;
    localparam logic [7:0] S_FETCH_PREV_SRC_CH  = 8'd8;
    localparam logic [7:0] S_CHECK_CH           = 8'd9;
    localparam logic [7:0] S_SET_BIAS_ADDR      = 8'd10;
    localparam logic [7:0] S_LOAD_BIAS          = 8'd11;
    localparam logic [7:0] S_RELU               = 8'd12;
    localparam logic [7:0] S_RESULT_WRITE_BACK  = 8'd13;
    localparam logic [7:0] S_UPDATE_POS_AND_CH  = 8'd14;
    localparam logic [7:0] S_CONV_DONE          = 8'd15;
    localparam logic [7:0] S_POOL               = 8'd16;
    localparam logic [7:0] S_POOL_FETCH_LOOP    = 8'd17;
    localparam logic [7:0] S_POOL_WRITE_BACK    = 8'd19;
    localparam logic [7:0] S_POOL_UPDATE_POS    = 8'd20;
    localparam logic [7:0] S_FC                 = 8'd21;
    localparam logic [7:0] S_FC_LOOP            = 8'd22;
    localparam logic [7:0] S_FC_LOOP_WEIGHT     = 8'd23;
    localparam logic [7:0] S_FC_LOOP_ACC        = 8'd24;
    localparam logic [7:0] S_FC_SET_BIAS_ADDR   = 8'd25;
    localparam logic [7:0] S_FC_BIAS            = 8'd26;
    localparam logic [7:0] S_FC_WRITE_BACK      = 8'd27;
    localparam logic [7:0] S_FC_UPDATE_POS      = 8'd28;
    localparam logic [7:0] S_POOL_DONE          = 8'd29;
    localparam logic [7:0] S_FC_DONE            = 8'd30;
    localparam logic [7:0] S_LAYER_DONE         = 8'd128;

    logic [7:0] curr_state;
    logic [7:0] next_state;

    // =================================================================
    // Datapath / counter declarations
    // =================================================================
    // Convolution position and channel counters
    logic [15:0] curr_input_x;       // kernel-centre column, runs 1 .. num_cols-2
    logic [15:0] curr_input_y;       // kernel-centre row,    runs 1 .. num_rows-2
    logic [15:0] curr_output_x;      // output column, runs 0 .. num_cols-3
    logic [15:0] curr_output_y;      // output row,    runs 0 .. num_rows-3
    logic [15:0] curr_src_chan;
    logic [15:0] curr_dst_chan;
    logic [15:0] curr_kernel_pos;    // linear index of the kernel centre
    logic [15:0] curr_pixel_pos;     // linear index of the tap being read
    logic [4:0]  cycle_cnt;

    // Weight fetch
    logic               mem_weight_rd;
    logic [18:0]        mem_weight_rd_addr;
    logic [18:0]        mem_weight_rd_addr_delayed;
    logic               mem_weight_valid;
    logic signed [15:0] weight_data [8:0];

    // Input pixel fetch
    logic        mem_data_rd;
    logic [18:0] mem_data_rd_addr_delayed;
    logic        mem_data_valid;

    // Partial-sum and bias fetch
    logic               mem_prev_acc_rd;
    logic [18:0]        mem_prev_acc_rd_addr;
    logic               mem_bias_rd;
    logic [18:0]        mem_bias_rd_addr;
    logic signed [15:0] bias_data;

    // Convolution result write-back
    logic        mem_output_wr;
    logic [18:0] mem_output_wr_addr;
    logic [15:0] mem_output_wr_data;

    // Pooling
    logic [15:0]        curr_pool_x;
    logic [15:0]        curr_pool_y;
    logic [15:0]        curr_pool_stride_x;
    logic [15:0]        curr_pool_stride_y;
    logic [15:0]        curr_pool_stride_x_delayed;
    logic [15:0]        curr_pool_stride_y_delayed;
    logic [18:0]        mem_pool_rd_addr;
    logic               mem_pool_rd;
    logic               mem_pool_valid;
    logic [18:0]        mem_pool_wr_addr;
    logic [15:0]        mem_pool_wr_data;
    logic               mem_pool_wr;
    logic signed [15:0] curr_pool_max;
    logic [15:0]        pool_cnt_max;
    logic [15:0]        pool_cnt;
    logic [2:0]         pool_row_shift;

    // Fully connected
    logic [15:0]        curr_fc_w_row;
    logic [15:0]        curr_fc_w_col;
    logic [18:0]        mem_fc_rd_addr;
    logic               mem_fc_rd;
    logic [18:0]        mem_fc_wr_addr;
    logic [15:0]        mem_fc_wr_data;
    logic               mem_fc_wr;
    logic signed [31:0] fc_acc;
    logic signed [15:0] fc_data;

    // Accumulators
    logic signed [31:0] kernel_acc;
    logic signed [31:0] result_to_write;
    logic signed [31:0] relu_result;

    // Memory port muxes
    logic [18:0] mem_rd_addr;
    logic [18:0] mem_wr_addr;
    logic [15:0] mem_wr_data;

    // =================================================================
    // Loop-boundary flags shared by the FSM and the counters
    // =================================================================
    logic last_src_chan;
    logic last_dst_chan;
    logic last_input_col;
    logic last_input_row;
    logic last_input_px;
    logic last_pool_col;
    logic last_pool_row;
    logic pool_window_done;
    logic weights_loaded;
    logic kernel_taps_done;

    assign last_src_chan    = (curr_src_chan == num_src_chans - 'd1);
    assign last_dst_chan    = (curr_dst_chan == num_dst_chans - 'd1);
    assign last_input_col   = (curr_input_x == (num_cols - 'd2));
    assign last_input_row   = (curr_input_y == (num_rows - 'd2));
    assign last_input_px    = last_input_col && last_input_row;
    assign last_pool_col    = ((curr_pool_x * pool_size) == (num_cols - pool_size - 'd2));
    assign last_pool_row    = ((curr_pool_y * pool_size) == (num_rows - pool_size - 'd2));
    assign pool_window_done = (curr_pool_stride_x_delayed == pool_size - 'd1)
                           && (curr_pool_stride_y_delayed == pool_size - 'd1);
    // The ninth weight / ninth tap (index 8) is being returned by the memory.
    assign weights_loaded   = (mem_weight_rd_addr_delayed == 'd8);
    assign kernel_taps_done = (mem_data_rd_addr_delayed == 'd8);

    // Next values of the channel counters: source channel is the inner loop,
    // destination channel the outer loop, both wrapping to zero.
    logic [15:0] next_src_chan;
    logic [15:0] next_dst_chan;

    always_comb begin
        next_src_chan = last_src_chan ? 'd0 : (curr_src_chan + 'd1);
        next_dst_chan = curr_dst_chan;
        if (last_src_chan) begin
            next_dst_chan = last_dst_chan ? 'd0 : (curr_dst_chan + 'd1);
        end
    end

    // =================================================================
    // FSM
    // =================================================================
    always_ff @(posedge clk) begin
        if (reset) begin
            curr_state <= S_INIT;
        end else begin
            curr_state <= next_state;
        end
    end

    always_comb begin
        case (curr_state)
            S_INIT: begin
                if (!start)  next_state = S_INIT;
                else if (fc) next_state = S_FC;
                else         next_state = S_CONV;
            end

            // ---- convolution ----
            S_CONV:              next_state = S_SET_WEIGHT_ADDR;
            S_SET_WEIGHT_ADDR:   next_state = S_LOAD_WEIGHT;
            S_LOAD_WEIGHT:       next_state = weights_loaded ? S_COMPUTE_POS : S_LOAD_WEIGHT;
            S_COMPUTE_POS:       next_state = S_SET_PREV_ADDR;
            S_SET_PREV_ADDR:     next_state = S_FETCH_PREV_SRC_CH;
            S_FETCH_PREV_SRC_CH: next_state = S_KERNEL_LOOP;
            S_KERNEL_LOOP:       next_state = kernel_taps_done ? S_CHECK_CH : S_KERNEL_LOOP;
            // Bias and clamp are applied only once all source channels have
            // been accumulated; earlier passes store the raw partial sum.
            S_CHECK_CH:          next_state = last_src_chan ? S_SET_BIAS_ADDR : S_RESULT_WRITE_BACK;
            S_SET_BIAS_ADDR:     next_state = S_LOAD_BIAS;
            S_LOAD_BIAS:         next_state = S_RELU;
            S_RELU:              next_state = S_RESULT_WRITE_BACK;
            S_RESULT_WRITE_BACK: next_state = S_UPDATE_POS_AND_CH;
            S_UPDATE_POS_AND_CH: begin
                if (!last_input_px)      next_state = S_COMPUTE_POS;      // next pixel
                else if (!last_src_chan) next_state = S_SET_WEIGHT_ADDR;  // next source channel
                else if (do_pool)        next_state = S_POOL;             // pool finished channel
                else if (last_dst_chan)  next_state = S_CONV_DONE;
                else                     next_state = S_SET_WEIGHT_ADDR;  // next destination channel
            end
            S_CONV_DONE:         next_state = do_pool ? S_POOL : S_LAYER_DONE;

            // ---- pooling ----
            S_POOL:              next_state = S_POOL_FETCH_LOOP;
            S_POOL_FETCH_LOOP:   next_state = pool_window_done ? S_POOL_WRITE_BACK : S_POOL_FETCH_LOOP;
            S_POOL_WRITE_BACK:   next_state = S_POOL_UPDATE_POS;
            S_POOL_UPDATE_POS:   next_state = (last_pool_col && last_pool_row) ? S_POOL_DONE : S_POOL_FETCH_LOOP;
            S_POOL_DONE:         next_state = (last_dst_chan && last_src_chan) ? S_LAYER_DONE : S_CONV;

            // ---- fully connected ----
            S_FC:                next_state = S_FC_LOOP;
            S_FC_LOOP:           next_state = S_FC_LOOP_WEIGHT;
            S_FC_LOOP_WEIGHT:    next_state = S_FC_LOOP_ACC;
            S_FC_LOOP_ACC:       next_state = (curr_fc_w_row == num_src_chans - 'd1) ? S_FC_SET_BIAS_ADDR : S_FC_LOOP;
            S_FC_SET_BIAS_ADDR:  next_state = S_FC_BIAS;
            S_FC_BIAS:           next_state = S_FC_WRITE_BACK;
            S_FC_WRITE_BACK:     next_state = S_FC_UPDATE_POS;
            S_FC_UPDATE_POS:     next_state = (curr_fc_w_col == num_dst_chans - 'd1) ? S_FC_DONE : S_FC_LOOP;
            S_FC_DONE:           next_state = S_LAYER_DONE;

            S_LAYER_DONE:        next_state = S_INIT;
            default:             next_state = S_INIT;
        endcase
    end

    // Completion flag: raised at the end of a layer, dropped when a new
    // start is seen in the idle state.
    always_ff @(posedge clk) begin
        if (reset) begin
            valid <= 1'b0;
        end else begin
            case (curr_state)
                S_LAYER_DONE: valid <= 1'b1;
                S_INIT:       if (start) valid <= 1'b0;
                default: ;
            endcase
        end
    end

    // Tap counter for the 3x3 kernel loop; it also runs during weight load
    // and is cleared in every other state.
    always_ff @(posedge clk) begin
        case (curr_state)
            S_LOAD_WEIGHT: cycle_cnt <= weights_loaded   ? 'd0 : cycle_cnt + 'd1;
            S_KERNEL_LOOP: cycle_cnt <= kernel_taps_done ? 'd0 : cycle_cnt + 'd1;
            default:       cycle_cnt <= 'd0;
        endcase
    end

    // =================================================================
    // Convolution: position and channel counters
    // =================================================================
    // Kernel centre walks the interior of the input image in raster order.
    always_ff @(posedge clk) begin
        if (reset) begin
            curr_input_x <= 'd1;
            curr_input_y <= 'd1;
        end else begin
            case (curr_state)
                S_LOAD_WEIGHT: begin
                    curr_input_x <= 'd1;
                    curr_input_y <= 'd1;
                end
                S_UPDATE_POS_AND_CH: begin
                    curr_input_x <= last_input_col ? 'd1 : (curr_input_x + 'd1);
                    if (last_input_col) begin
                        curr_input_y <= last_input_row ? 'd1 : (curr_input_y + 'd1);
                    end
                end
                default: ;
            endcase
        end
    end

    always_ff @(posedge clk) begin
        if (reset) begin
            curr_kernel_pos <= 'd0;
        end else if (curr_state == S_COMPUTE_POS) begin
            curr_kernel_pos <= curr_input_y * num_cols + curr_input_x;
        end
    end

    // Channel counters advance after the last pixel of an image pass.
    // When pooling is enabled, the advance that follows the last source
    // channel is deferred to S_POOL_DONE so the pool pass still sees the
    // destination channel it belongs to.
    always_ff @(posedge clk) begin
        if (reset) begin
            curr_src_chan <= 'd0;
            curr_dst_chan <= 'd0;
        end else begin
            case (curr_state)
                S_UPDATE_POS_AND_CH: begin
                    if (last_input_px && (!do_pool || !last_src_chan)) begin
                        curr_src_chan <= next_src_chan;
                        curr_dst_chan <= next_dst_chan;
                    end
                end
                S_POOL_DONE: begin
                    curr_src_chan <= next_src_chan;
                    curr_dst_chan <= next_dst_chan;
                end
                default: ;
            endcase
        end
    end

    // Linear index of the kernel tap selected by cycle_cnt (row-major 3x3
    // neighbourhood around curr_kernel_pos).
    always_comb begin
        curr_pixel_pos = 'd0;
        if (curr_state == S_KERNEL_LOOP) begin
            case (cycle_cnt)
                'd0: curr_pixel_pos = curr_kernel_pos - num_cols - 'd1;
                'd1: curr_pixel_pos = curr_kernel_pos - num_cols;
                'd2: curr_pixel_pos = curr_kernel_pos - num_cols + 'd1;
                'd3: curr_pixel_pos = curr_kernel_pos - 'd1;
                'd4: curr_pixel_pos = curr_kernel_pos;
                'd5: curr_pixel_pos = curr_kernel_pos + 'd1;
                'd6: curr_pixel_pos = curr_kernel_pos + num_cols - 'd1;
                'd7: curr_pixel_pos = curr_kernel_pos + num_cols;
                'd8: curr_pixel_pos = curr_kernel_pos + num_cols + 'd1;
                default: ;
            endcase
        end
    end

    // Output pixel counters (output image is two rows/columns smaller).
    always_ff @(posedge clk) begin
        if (reset) begin
            curr_output_x <= 'd0;
            curr_output_y <= 'd0;
        end else begin
            case (curr_state)
                S_CONV: begin
                    curr_output_x <= 'd0;
                    curr_output_y <= 'd0;
                end
                S_UPDATE_POS_AND_CH: begin
                    curr_output_x <= (curr_output_x == num_cols - 'd3) ? 'd0 : (curr_output_x + 'd1);
                    if (curr_output_x == num_cols - 'd3) begin
                        curr_output_y <= (curr_output_y == num_rows - 'd3) ? 'd0 : (curr_output_y + 'd1);
                    end
                end
                default: ;
            endcase
        end
    end

    // =================================================================
    // Memory port
    // =================================================================
    // 'hBEEF is driven on idle address outputs as a recognisable marker.
    assign mem_rd_addr_o = mem_rd ? mem_rd_addr : 'hBEEF;
    assign mem_wr_addr_o = mem_wr ? mem_wr_addr : 'hBEEF;
    assign mem_wr_data_o = mem_wr ? mem_wr_data : 'd0;

    assign mem_rd = mem_weight_rd | mem_data_rd | mem_prev_acc_rd | mem_bias_rd | mem_pool_rd | mem_fc_rd;
    assign mem_wr = mem_output_wr | mem_pool_wr | mem_fc_wr;

    // Read address: per-state offset plus the configured region base.
    always_comb begin
        mem_rd_addr = 'hBEEF;
        case (curr_state)
            S_SET_BIAS_ADDR:    mem_rd_addr = mem_bias_rd_addr + bias_starting_addr;
            S_SET_PREV_ADDR:    mem_rd_addr = mem_prev_acc_rd_addr + output_starting_addr;
            // Weights: nine words per (source, destination) channel pair.
            S_LOAD_WEIGHT:      mem_rd_addr = mem_weight_rd_addr + weight_starting_addr + 9 * curr_src_chan + 9 * num_src_chans * curr_dst_chan;
            S_KERNEL_LOOP:      mem_rd_addr = curr_pixel_pos + data_starting_addr + num_cols * num_rows * curr_src_chan;
            S_POOL_FETCH_LOOP:  mem_rd_addr = mem_pool_rd_addr + output_starting_addr;
            S_FC_LOOP:          mem_rd_addr = mem_fc_rd_addr + data_starting_addr;
            S_FC_LOOP_WEIGHT:   mem_rd_addr = mem_fc_rd_addr + weight_starting_addr;
            S_FC_SET_BIAS_ADDR: mem_rd_addr = mem_fc_rd_addr + bias_starting_addr;
            default: ;
        endcase
    end

    // Write address/data: all writers target the output region.
    always_comb begin
        mem_wr_addr = 'd0;
        mem_wr_data = 'd0;
        case (curr_state)
            S_RESULT_WRITE_BACK: begin
                mem_wr_addr = mem_output_wr_addr + output_starting_addr;
                mem_wr_data = mem_output_wr_data;
            end
            S_POOL_WRITE_BACK: begin
                mem_wr_addr = mem_pool_wr_addr + output_starting_addr;
                mem_wr_data = mem_pool_wr_data;
            end
            S_FC_WRITE_BACK: begin
                mem_wr_addr = mem_fc_wr_addr + output_starting_addr;
                mem_wr_data = mem_fc_wr_data;
            end
            default: ;
        endcase
    end

    // =================================================================
    // Convolution: weight load
    // =================================================================
    assign mem_weight_rd = (curr_state == S_LOAD_WEIGHT) && (mem_weight_rd_addr < 'd9);

    // One-cycle delayed copies line up with the returning read data.
    always_ff @(posedge clk) begin
        mem_weight_valid           <= mem_weight_rd;
        mem_weight_rd_addr_delayed <= mem_weight_rd_addr;
    end

    always_ff @(posedge clk) begin
        if (reset) begin
            mem_weight_rd_addr <= 'hBEEF;
        end else begin
            case (curr_state)
                S_SET_WEIGHT_ADDR: mem_weight_rd_addr <= 'd0;
                // After index 8 the counter parks on a value >= 9 so no
                // further weight reads are issued.
                S_LOAD_WEIGHT:     mem_weight_rd_addr <= (mem_weight_rd_addr == 'd8) ? 'hBEEF : (mem_weight_rd_addr + 'd1);
                default: ;
            endcase
        end
    end

    always_ff @(posedge clk) begin
        if (mem_weight_valid) begin
            weight_data[mem_weight_rd_addr_delayed] <= mem_rd_data_i;
        end
    end

    // =================================================================
    // Convolution: multiply-accumulate
    // =================================================================
    assign mem_data_rd = (curr_state == S_KERNEL_LOOP) && (cycle_cnt < 'd9);

    always_ff @(posedge clk) begin
        mem_data_valid           <= mem_data_rd;
        mem_data_rd_addr_delayed <= cycle_cnt;
    end

    always_ff @(posedge clk) begin
        if (reset) begin
            kernel_acc <= 'd0;
        end else begin
            case (curr_state)
                S_KERNEL_LOOP: begin
                    if (mem_data_valid) begin
                        kernel_acc <= kernel_acc + $signed(mem_rd_data_i) * $signed(weight_data[mem_data_rd_addr_delayed]);
                    end
                end
                S_FETCH_PREV_SRC_CH: begin
                    // First source channel starts from zero; later ones
                    // resume from the stored partial sum, rescaled.
                    if (curr_src_chan == 'd0) kernel_acc <= 'd0;
                    else                      kernel_acc <= $signed(mem_rd_data_i) <<< 8;
                end
                default: ;
            endcase
        end
    end

    // Offset of the current output pixel inside the output region; used for
    // both the partial-sum prefetch and the result write-back.
    logic [18:0] conv_out_offset;
    assign conv_out_offset = curr_output_y * (num_cols - 'd2) + curr_output_x + curr_dst_chan * (num_cols - 'd2) * (num_rows - 'd2);

    always_comb begin
        mem_prev_acc_rd      = 1'b0;
        mem_prev_acc_rd_addr = 'hBEEF;
        if (curr_state == S_SET_PREV_ADDR) begin
            mem_prev_acc_rd      = 1'b1;
            mem_prev_acc_rd_addr = conv_out_offset;
        end
    end

    // =================================================================
    // Convolution: bias, clamp and write-back
    // =================================================================
    always_comb begin
        mem_bias_rd      = 1'b0;
        mem_bias_rd_addr = 'hBEEF;
        if (curr_state == S_SET_BIAS_ADDR) begin
            mem_bias_rd      = 1'b1;
            mem_bias_rd_addr = curr_dst_chan;   // one bias word per destination channel
        end
    end

    always_ff @(posedge clk) begin
        if (reset) begin
            bias_data <= 'd0;
        end else if (curr_state == S_LOAD_BIAS) begin
            bias_data <= mem_rd_data_i;
        end
    end

    assign relu_result = result_to_write + bias_data;

    always_ff @(posedge clk) begin
        if (reset) begin
            result_to_write <= 'd0;
        end else begin
            case (curr_state)
                S_CHECK_CH: result_to_write <= kernel_acc >>> 8;
                // Negative sums (sign bit set) are replaced by zero.
                S_RELU:     result_to_write <= (~relu_result[31]) ? (relu_result) : 'd0;
                default: ;
            endcase
        end
    end

    always_comb begin
        mem_output_wr      = 1'b0;
        mem_output_wr_addr = 'hBEEF;
        mem_output_wr_data = 'd0;
        if (curr_state == S_RESULT_WRITE_BACK) begin
            mem_output_wr      = 1'b1;
            mem_output_wr_addr = conv_out_offset;
            // Only the low 16 bits are stored; there is no saturation.
            mem_output_wr_data = result_to_write[0 +: 16];
        end
    end

    // =================================================================
    // Pooling
    // =================================================================
    always_ff @(posedge clk) begin
        mem_pool_valid             <= mem_pool_rd;
        curr_pool_stride_x_delayed <= curr_pool_stride_x;
        curr_pool_stride_y_delayed <= curr_pool_stride_y;
    end

    // Window position, counted in windows (not pixels).
    always_ff @(posedge clk) begin
        if (reset) begin
            curr_pool_x <= 'd0;
            curr_pool_y <= 'd0;
        end else begin
            case (curr_state)
                S_POOL: begin
                    curr_pool_x <= 'd0;
                    curr_pool_y <= 'd0;
                end
                S_POOL_UPDATE_POS: begin
                    curr_pool_x <= last_pool_col ? 'd0 : (curr_pool_x + 'd1);
                    if (last_pool_col) begin
                        curr_pool_y <= last_pool_row ? 'd0 : (curr_pool_y + 'd1);
                    end
                end
                default: ;
            endcase
        end
    end

    // Position inside the current pool_size x pool_size window.
    always_ff @(posedge clk) begin
        if (reset) begin
            curr_pool_stride_x <= 'd0;
            curr_pool_stride_y <= 'd0;
        end else begin
            case (curr_state)
                S_POOL_FETCH_LOOP: begin
                    curr_pool_stride_x <= (curr_pool_stride_x == pool_size - 'd1) ? 'd0 : (curr_pool_stride_x + 'd1);
                    if (curr_pool_stride_x == pool_size - 'd1) begin
                        curr_pool_stride_y <= (curr_pool_stride_y == pool_size - 'd1) ? 'd0 : (curr_pool_stride_y + 'd1);
                    end
                end
                S_POOL_UPDATE_POS: begin
                    curr_pool_stride_x <= 'd0;
                    curr_pool_stride_y <= 'd0;
                end
                default: ;
            endcase
        end
    end

    // Number of reads issued for the current window; stops further reads
    // once pool_size * pool_size words have been requested.
    assign pool_cnt_max = pool_size * pool_size;

    always_ff @(posedge clk) begin
        if (reset) begin
            pool_cnt <= 'd0;
        end else begin
            case (curr_state)
                S_POOL_FETCH_LOOP: pool_cnt <= pool_cnt + 'd1;
                S_POOL_UPDATE_POS: pool_cnt <= 'd0;
                default: ;
            endcase
        end
    end

    always_comb begin
        mem_pool_rd_addr = 'd0;
        mem_pool_rd      = 1'b0;
        if ((curr_state == S_POOL_FETCH_LOOP) && (pool_cnt < pool_cnt_max)) begin
            mem_pool_rd_addr = (curr_pool_y * pool_size + curr_pool_stride_y) * (num_cols - 'd2) + (curr_pool_x * pool_size + curr_pool_stride_x) + curr_dst_chan * (num_cols - 'd2) * (num_rows - 'd2);
            mem_pool_rd      = 1'b1;
        end
    end

    // Running maximum; 'h8000 is the most negative 16-bit signed value.
    always_ff @(posedge clk) begin
        if (reset) begin
            curr_pool_max <= 'h8000;
        end else begin
            case (curr_state)
                S_POOL:            curr_pool_max <= 'h8000;
                S_POOL_FETCH_LOOP: begin
                    if (mem_pool_valid && $signed(mem_rd_data_i) > curr_pool_max) begin
                        curr_pool_max <= $signed(mem_rd_data_i);
                    end
                end
                S_POOL_UPDATE_POS: curr_pool_max <= 'h8000;
                default: ;
            endcase
        end
    end

    // Row length of the pooled image is (num_cols - 2) / pool_size, computed
    // as a right shift. Power-of-two window sizes 2, 4, 8 and 16 are decoded;
    // every other value falls back to a shift of 3 (the window-size-8 case).
    // Non-power-of-two window sizes are not supported by this addressing.
    always_comb begin
        case (pool_size)
            16'd2:   pool_row_shift = 3'd1;
            16'd4:   pool_row_shift = 3'd2;
            16'd16:  pool_row_shift = 3'd4;
            default: pool_row_shift = 3'd3;
        endcase
    end

    // NOTE(review): unlike mem_pool_rd_addr, this write address carries no
    // per-destination-channel offset, so the pooled output of every channel
    // lands at the start of the output region. Confirm whether the pooled
    // maps are meant to be packed per channel.
    always_comb begin
        mem_pool_wr_addr = 'd0;
        mem_pool_wr      = 1'b0;
        mem_pool_wr_data = 'd0;
        if (curr_state == S_POOL_WRITE_BACK) begin
            mem_pool_wr_addr = (curr_pool_y) * ((num_cols - 'd2) >> pool_row_shift) + (curr_pool_x);
            mem_pool_wr      = 1'b1;
            mem_pool_wr_data = curr_pool_max;
        end
    end

    // =================================================================
    // Fully connected
    // =================================================================
    // curr_fc_w_row indexes the input vector (inner loop), curr_fc_w_col the
    // output neuron (outer loop).
    always_ff @(posedge clk) begin
        if (reset) begin
            curr_fc_w_row <= 'd0;
            curr_fc_w_col <= 'd0;
        end else begin
            case (curr_state)
                S_FC_LOOP_ACC:   curr_fc_w_row <= (curr_fc_w_row == num_src_chans - 'd1) ? 'd0 : (curr_fc_w_row + 'd1);
                S_FC_UPDATE_POS: curr_fc_w_col <= (curr_fc_w_col == num_dst_chans - 'd1) ? 'd0 : (curr_fc_w_col + 'd1);
                default: ;
            endcase
        end
    end

    always_comb begin
        mem_fc_rd_addr = 'd0;
        mem_fc_rd      = 1'b0;
        case (curr_state)
            S_FC_LOOP: begin            // input element
                mem_fc_rd_addr = curr_fc_w_row;
                mem_fc_rd      = 1'b1;
            end
            S_FC_LOOP_WEIGHT: begin     // weight, stored row-major by input index
                mem_fc_rd_addr = curr_fc_w_row * num_dst_chans + curr_fc_w_col;
                mem_fc_rd      = 1'b1;
            end
            S_FC_SET_BIAS_ADDR: begin   // bias of the current output neuron
                mem_fc_rd_addr = curr_fc_w_col;
                mem_fc_rd      = 1'b1;
            end
            default: ;
        endcase
    end

    always_ff @(posedge clk) begin
        if (reset) begin
            fc_acc  <= 'd0;
            fc_data <= 'd0;
        end else begin
            case (curr_state)
                S_FC:             fc_acc  <= 'd0;
                // Input element arrives first and is held while the weight
                // is being read.
                S_FC_LOOP_WEIGHT: fc_data <= mem_rd_data_i;
                S_FC_LOOP_ACC:    fc_acc  <= fc_acc + $signed(fc_data) * $signed(mem_rd_data_i);
                S_FC_BIAS:        fc_acc  <= fc_acc + ($signed(mem_rd_data_i) <<< 8);
                S_FC_UPDATE_POS:  fc_acc  <= 'd0;
                default: ;
            endcase
        end
    end

    always_comb begin
        mem_fc_wr_addr = 'd0;
        mem_fc_wr      = 1'b0;
        mem_fc_wr_data = 'd0;
        if (curr_state == S_FC_WRITE_BACK) begin
            mem_fc_wr_addr = curr_fc_w_col;
            mem_fc_wr      = 1'b1;
            mem_fc_wr_data = fc_acc >>> 8;
        end
    end
endmodule
