// convOpt: sequential 3x3 convolution step controller.
//
// For each index from 0 up to in_data_dim*in_data_dim - 1 the module
//   1. publishes the index on conv_idx and raises start_fm,
//   2. waits for finish_fm and latches the 9 x 16-bit window on feat_map_in,
//   3. multiplies the window element-wise with the 9 weights held in
//      weight_bias[143:0] (nine float_multi instances),
//   4. accumulates the 9 products one per iteration through one float_adder,
//   5. adds the bias word weight_bias[159:144] with a second float_adder,
//   6. presents the result on feat_map_out / out_idx, raises start_out and
//      waits for finish_out.
// When all indices are done, finish is raised until in_data_ready goes low.
//
// float_multi / float_adder are defined outside this file; they are used here
// as purely combinational blocks (no clock port is connected).
module convOpt(
	input	logic							clk, reset,

	// Input from pipeline
	input	logic							in_data_ready,
	input	logic	[ 5: 0]					in_data_dim,
	input	logic	[(3*3+1)*16-1: 0]		weight_bias,

	// Request FM from pipeline
	output	logic							start_fm,
	output	logic	[15: 0]					conv_idx,
	
	input	logic							finish_fm,
	input	logic	[3*3*16-1: 0]			feat_map_in,

	input	logic							finish_out,

	// Output to pipeline
	output	logic							start_out,
	output	logic	[16: 0]					out_idx,
	output	logic	[3*3*16-1: 0]			feat_map_out,

	// Output to HPS
	output	logic							finish,

	// Debug
	output	logic	[ 2: 0]					debug_state
);

	// FSM
	//   S0 idle / wait for in_data_ready     S4 feed partial sum back
	//   S1 request feature-map window        S5 publish result
	//   S2 wait for window, latch it         S6 wait for finish_out
	//   S3 load next product into adder      S7 done, wait for in_data_ready low
	logic    [ 2: 0]       state;
	assign debug_state = state;
	parameter S0 = 0, S1 = 1, S2 = 2, S3 = 3, S4 = 4, S5 = 5, S6 = 6, S7 = 7;

	// fp16 multiplier operands / products (9 lanes of 16 bits)
	logic    [15:0]        mult_idx;   // number of windows requested so far
	logic    [9*16-1:0]    mult_1;     // latched feature-map window
	logic    [9*16-1:0]    mult_2;     // kernel weights
	logic    [9*16-1:0]    mult_o;     // element-wise products

	assign mult_2 = weight_bias[3*3*16-1: 0];

	// fp16 accumulator adder
	logic    [ 3: 0]     add_idx;      // index of the next product to accumulate
	logic    [16-1:0]    add_1;        // running partial sum
	logic    [16-1:0]    add_2;        // product currently being added
	logic    [16-1:0]    add_o;        // add_1 + add_2

	// fp16 bias adder
	logic    [16-1:0]    bias;
	logic    [16-1:0]    biased_sum;   // add_o + bias, as a floating-point sum

	assign bias = weight_bias[3*3*16 +: 16];

	genvar i;
	generate
		for (i = 0; i < 9; i = i + 1) begin : gen_fpu_for_each_element_in_kernel
			float_multi fp_mult(
				.num1        (mult_1[i*16 +: 16]),
				.num2        (mult_2[i*16 +: 16]),
				.result      (mult_o[i*16 +: 16])
			);
		end
	endgenerate

	float_adder fp_add(
		.num1        (add_1),
		.num2        (add_2),
		.result      (add_o)
	);

	// The bias is an fp16 word like the weights, so it must go through the
	// floating-point adder; an integer '+' on the raw encodings is not a sum.
	float_adder fp_add_bias(
		.num1        (add_o),
		.num2        (bias),
		.result      (biased_sum)
	);


	// Datapath and handshake registers.
	// Nonblocking assignments are required here: the next-state process below
	// samples add_idx / mult_idx on the same clock edge, and must see the
	// values from before this edge regardless of process evaluation order.
	always_ff @ (posedge clk or posedge reset) begin
		if (reset)
			begin
				finish    <= 1'b0;
				start_fm  <= 1'b0;
				start_out <= 1'b0;
				mult_idx  <= '0;
				add_idx   <= '0;
			end
		else
			case (state)
				S0:    // Idle: clear counters and strobes, wait for data ready
					begin
						finish    <= 1'b0;
						start_fm  <= 1'b0;
						start_out <= 1'b0;
						mult_idx  <= '0;
						add_idx   <= '0;
					end
				S1:    // Request the window for the current index
					begin
						conv_idx <= mult_idx;
						start_fm <= 1'b1;

						mult_idx <= mult_idx + 1'b1;
						add_idx  <= '0;
					end
				S2:    // Latch the window; products appear on mult_o combinationally
					begin
						if (finish_fm)
							begin
								mult_1   <= feat_map_in;
								// NOTE(review): release the request once it is
								// acknowledged, mirroring start_out/finish_out.
								// Confirm against the pipeline-side handshake.
								start_fm <= 1'b0;
							end
					end
				S3:    // Load product[add_idx] into the adder
					begin
						if (add_idx == 0)
							add_1 <= 16'h0000;    // fp16 +0.0 starts the accumulation
						add_2   <= mult_o[add_idx*16 +: 16];
						add_idx <= add_idx + 1'b1;
					end
				S4:    // Feed the partial sum back as the next first operand
					begin
						add_1 <= add_o;
					end
				S5:    // Publish result; only the low 16 bits of feat_map_out are driven non-zero
					begin
						feat_map_out <= {{(3*3*16-16){1'b0}}, biased_sum};
						// mult_idx was already incremented in S1, hence the -1
						out_idx      <= (mult_idx-1)*2*9;
						start_out    <= 1'b1;
					end
				S6:    // Hold start_out until the consumer acknowledges
					begin
						if (finish_out)
							start_out <= 1'b0;
					end
				S7:    // Finish
					begin
						finish <= 1'b1;
					end
				default: ;
			endcase
	end
	
	// Determine the next state
	always_ff @ (posedge clk or posedge reset) begin
		if (reset)
			begin
				state <= S0;
			end
		else
			case (state)
				S0:    // Reset, wait for data ready
					if (in_data_ready)
						state <= S1;
				S1:    // Retrieve FM from OCM
					state <= S2;
				S2:    // Conv
					if (finish_fm)
						state <= S3;
				S3:    // Add 9 times (1)
					// add_idx here is the index being loaded on this edge
					// (pre-increment), so products 0..8 are all accumulated.
					if (add_idx >= 8)
						state <= S5;
					else
						state <= S4;
				S4:    // Add 9 times (2): Change input for next add
					state <= S3;
				S5:    // Save result to feat_map_out
					state <= S6;
				S6:
					if (finish_out) begin
						if (mult_idx >= in_data_dim*in_data_dim)
							state <= S7;
						else
							state <= S1;
					end
				S7:
					if (in_data_ready == 0)
						state <= S0;
				default:    // unreachable with a 3-bit state; recovers from X
					state <= S0;
			endcase
	end
endmodule
