// conv_state_machine
//
// Control FSM for one convolution layer. When conv_state_machine_en is high
// in the idle state, the layer configuration on the *_in ports is latched and
// the FSM walks through:
//
//   idle -> init -> load_weight -> load_bias -> load_input -> cal_conv
//        -> (load_input | write_output) -> (load_bias | load_weight | idle)
//
// DMA handshake as implemented here: a request is issued by driving dma_en
// high together with dma_read_addr / dma_write_addr / dma_length while
// dma_done is high and no request was issued in the previous cycle
// (dma_en_pulse==0). dma_en is dropped again on the following "wait" cycle.
//
// Ports:
//   clk, rst_n                 clock and asynchronous active-low reset
//   conv_state_machine_en      start strobe, sampled only in the idle state
//   *_in                       layer configuration, latched when leaving idle
//   dma_en/dma_*_addr/dma_length/dma_done
//                              request/response interface to the DMA engine
//   addr_*_pe, *_master_en     address and enable for the input, weight and
//                              output masters during cal_conv
//   rst_n_pe                   driven high while the PE array is working and
//                              low once an output has been written
//   write_addr_buffer_out      on-chip destination address that accompanies
//                              each weight/bias DMA request
module conv_state_machine(input logic clk,
			  	input logic rst_n,
                          //input from top level
			  	input logic conv_state_machine_en,
			  	input logic [9:0]input_channel_in,
            	input logic [9:0]output_channel_in,
                input logic [31:0]weight_start_address_in,
			  	input logic [31:0]bias_start_address_in,
                input logic [31:0]input_start_address_in,
		    	input logic [9:0]input_size_in,
                input logic [9:0]output_size_in,
                //interface to dma state machine
                output logic dma_en,
                output logic [31:0]dma_read_addr,
                output logic [31:0]dma_write_addr,
                output logic [31:0]dma_length,
                input logic dma_done,
			  	//read input master control
			  	output logic [10:0]addr_input_pe,
			  	output logic input_master_en,
			  	//read weight master control
			  	output logic [16:0]addr_weight_pe,
			  	output logic weight_master_en,
			  	//write output master control
			  	output logic [14:0]addr_write_pe,
			  	output logic output_master_en,
			  	//output to pe array
			  	output logic rst_n_pe,//rst_n_pe can also be regarded as a work enable
                //output to write addr buffer
                output logic [16:0]write_addr_buffer_out
			  );
	parameter write_buffer_base=32'h09000000;  //DMA destination used for weight/bias transfers
	parameter input_onchip_base=32'h08000000;  //base of the on-chip input memory
	parameter output_onchip_base=32'h08000000; //base of the on-chip output memory

	enum {	init,
			load_weight,
			load_bias,
			load_input,
			cal_conv,
			wait_work_done,  //declared but has no case item below
			write_output,
			idle
		} conv_sub_state;

	//layer configuration, latched in the idle state
	logic [9:0]input_channel;
	logic [9:0]output_channel;
	logic [31:0]bias_start_address;
	logic [31:0]weight_start_address;
	logic [31:0]input_start_address;
	logic [9:0]input_size;
	logic [9:0]output_size;
	logic [4:0]input_loop;//input_channel/64 (1 when fewer than 64 channels)
	logic [4:0]output_loop;//output_channel/64 (1 when fewer than 64 channels)
	logic [9:0]dma_weight_loop;//number of 128-byte DMA transfers needed for one weight load
	logic [9:0]dma_input_channel;//input channels processed per pass, capped at 64
	logic [9:0]dma_bias_loop;//number of 128-byte DMA transfers for one line of bias

	//running pointers and counters
	logic [31:0]bias_ram_pointer;//next bias read address; also used to derive the partial-sum write-back address
	logic [31:0]input_ram_pointer;
	logic [10:0]input_onchip_pointer;
	logic [31:0]weight_ram_pointer;
	logic [9:0]dma_loop_counter;
	logic [4:0]input_loop_counter;
	logic [4:0]output_loop_counter;
	logic [9:0]conv_line_counter;
	logic [16:0]write_addr_buffer;//running on-chip weight/bias write address
	// NOTE(review): write_addr_buffer_bias_base is read in write_output but is
	// never assigned anywhere in this module (its only assignment is commented
	// out in load_weight) - confirm the intended source of this value.
	logic [16:0]write_addr_buffer_bias_base;
	logic [9:0]conv_timer;//cycle counter inside cal_conv
	logic [9:0]cal_conv_counter;//outputs still to be computed on the current line
	logic [16:0]bias_onchip_pointer;
	logic [10:0]addr_input_pe_in;
	logic [16:0]addr_weight_pe_in;
	logic [14:0]addr_write_pe_in;
	logic dma_en_pulse;//high for the cycle after a DMA request was issued

always_ff@(posedge clk or negedge rst_n)
begin
	if (rst_n==0)
	begin
		conv_sub_state<=idle;
		dma_en<=0;
		//dma_en_pulse gates every DMA request, and the enables below are
		//module outputs: give all of them a defined value out of reset.
		dma_en_pulse<=0;
		input_master_en<=0;
		weight_master_en<=0;
		output_master_en<=0;
		rst_n_pe<=0;
	end
	else
	begin
		case(conv_sub_state)
		//----------------------------------------------------------------
		// init: copy the latched configuration into the working pointers
		// and loop counters, then start loading weights.
		//----------------------------------------------------------------
		init:
		begin
			bias_ram_pointer<=bias_start_address;
			input_ram_pointer<=input_start_address;
			weight_ram_pointer<=weight_start_address;
			dma_loop_counter<=dma_weight_loop;
			write_addr_buffer<=0;//on-chip weight memory is written starting at address 0
			output_loop_counter<=output_loop;
			input_loop_counter<=input_loop;
			conv_sub_state<=load_weight;
		end

		//----------------------------------------------------------------
		// load_weight: issue dma_weight_loop transfers of 128 bytes each.
		//----------------------------------------------------------------
		load_weight:
		begin
			if(dma_loop_counter==0&&dma_done==1)//all weight transfers done
			begin
				conv_sub_state<=load_bias;
				dma_read_addr<=bias_ram_pointer;
				dma_write_addr<=write_buffer_base;
				dma_length<=128;//non-blocking, consistent with every other dma_length write
				dma_loop_counter<=dma_bias_loop;
				conv_line_counter<=output_size;
				bias_onchip_pointer<=write_addr_buffer;//first on-chip address after the weights
			end
			else if(dma_done==1&&dma_en_pulse==0)//issue the next transfer
			begin
				dma_loop_counter<=dma_loop_counter-1;
				dma_en<=1;
				dma_en_pulse<=1;
				dma_read_addr<=weight_ram_pointer;
				dma_write_addr<=write_buffer_base;
				dma_length<=128;
				write_addr_buffer_out<=write_addr_buffer;
				weight_ram_pointer<=weight_ram_pointer+128;
				write_addr_buffer<=write_addr_buffer+128;
			end
			else//wait for the DMA to finish
			begin
				dma_en<=0;
				dma_en_pulse<=0;
			end
		end

		//----------------------------------------------------------------
		// load_bias: issue dma_bias_loop transfers of 128 bytes each.
		//----------------------------------------------------------------
		load_bias:
		begin
			if(dma_loop_counter==0&&dma_done==1)//bias line loaded
			begin
				conv_sub_state<=load_input;
				dma_write_addr<=input_onchip_base;
				dma_loop_counter<=3;//three input transfers are issued per output
				dma_read_addr<=input_ram_pointer;
				cal_conv_counter<=output_size;
				// NOTE(review): input_onchip_pointer is 11 bits wide, so the
				// 32-bit input_onchip_base is truncated here (to 0 with the
				// default parameter value) - confirm that is intended.
				input_onchip_pointer<=input_onchip_base;
				addr_write_pe_in<=0;
			end
			else if(dma_done==1&&dma_en_pulse==0)//issue the next transfer
			begin
				dma_loop_counter<=dma_loop_counter-1;
				dma_en<=1;
				dma_en_pulse<=1;
				dma_read_addr<=bias_ram_pointer;
				dma_write_addr<=write_buffer_base;
				dma_length<=128;
				write_addr_buffer_out<=write_addr_buffer;
				bias_ram_pointer<=bias_ram_pointer+128;
				write_addr_buffer<=write_addr_buffer+128;
			end
			else//wait for the DMA to finish
			begin
				dma_en<=0;
				dma_en_pulse<=0;
			end
		end

		//----------------------------------------------------------------
		// load_input: issue three transfers of 3*dma_input_channel*2 bytes,
		// one per input row, into consecutive on-chip locations.
		//----------------------------------------------------------------
		load_input:
		begin
			if(dma_loop_counter==0&&dma_done==1)//input for one output is loaded
			begin
				conv_sub_state<=cal_conv;
				input_onchip_pointer<=input_onchip_base;//see truncation note in load_bias
				conv_timer<=0;
				addr_weight_pe_in<=0;
				addr_input_pe_in<=0;
			end
			else if(dma_done==1&&dma_en_pulse==0)//issue the next transfer
			begin
				dma_loop_counter<=dma_loop_counter-1;
				dma_en<=1;
				dma_en_pulse<=1;
				dma_read_addr<=input_ram_pointer;
				dma_write_addr<=input_onchip_pointer;
				dma_length<=3*dma_input_channel*2;
				input_ram_pointer<=input_ram_pointer+input_size*dma_input_channel*2;//step to the next input row
				//Advance from the pointer itself. Deriving it from the
				//previous dma_write_addr makes the pointer lag by one
				//transfer, so two transfers would share a destination.
				input_onchip_pointer<=input_onchip_pointer+3*dma_input_channel*2;
			end
			else//wait for the DMA to finish
			begin
				dma_en<=0;
				dma_en_pulse<=0;
			end
		end

		//----------------------------------------------------------------
		// cal_conv: conv_timer sequences one output computation:
		//   0 .. 9*ch-1 : stream input and weight addresses to the PE array
		//   9*ch        : stop the input master
		//   9*ch+1      : stop the weight master
		//   9*ch+2      : idle cycle
		//   9*ch+3      : strobe the output master with the write address
		//   9*ch+4      : stop the output master, clear rst_n_pe
		//   > 9*ch+4    : choose the next state
		// where ch = dma_input_channel.
		//----------------------------------------------------------------
		cal_conv:
		begin
			if(conv_timer>(9*dma_input_channel+4))//finished one output
			begin
				if(cal_conv_counter==1)//last output of this line
				begin
					conv_sub_state<=write_output;
					dma_loop_counter<=1;
				end
				else if(cal_conv_counter!=0)//more outputs on this line
				begin
					conv_sub_state<=load_input;
					dma_loop_counter<=3;
					cal_conv_counter<=cal_conv_counter-1;
					//rewind the three rows just read and step one pixel to the right
					input_ram_pointer<=input_ram_pointer-3*dma_input_channel*2*input_size+dma_input_channel*2;
				end
				// NOTE(review): cal_conv_counter==0 (e.g. output_size==0)
				// leaves the FSM parked in this state - confirm callers
				// never configure a zero output size.
			end
			else
			begin
				//Advance the phase counter; without this none of the later
				//phases (or the exit condition above) can ever be reached.
				conv_timer<=conv_timer+1;
				if(conv_timer<9*dma_input_channel)
				begin
					input_master_en<=1;
					addr_input_pe<=addr_input_pe_in;
					addr_input_pe_in<=addr_input_pe_in+2;
					weight_master_en<=1;
					addr_weight_pe<=addr_weight_pe_in;
					addr_weight_pe_in<=addr_weight_pe_in+128;
					rst_n_pe<=1;
				end
				else if(conv_timer==9*dma_input_channel)
				begin
					input_master_en<=0;
				end
				else if(conv_timer==9*dma_input_channel+1)
				begin
					weight_master_en<=0;
				end
				else if(conv_timer==9*dma_input_channel+2)
				begin
					//intentionally empty
				end
				else if(conv_timer==9*dma_input_channel+3)
				begin
					output_master_en<=1;
					addr_write_pe<=addr_write_pe_in;
					addr_write_pe_in<=addr_write_pe_in+128;
				end
				else if(conv_timer==9*dma_input_channel+4)
				begin
					output_master_en<=0;
					rst_n_pe<=0;
				end
			end
		end

		//----------------------------------------------------------------
		// write_output: DMA one output line back to external memory, then
		// decide whether to process the next line, the next channel group,
		// or finish the layer.
		//----------------------------------------------------------------
		write_output:
		begin
			if(dma_loop_counter==1&&dma_done==1)//issue the write-back of one line
			begin
				dma_read_addr<=output_onchip_base;
				dma_write_addr<=bias_ram_pointer-128*dma_bias_loop;//address the current bias line was read from
				dma_length<=64*output_size*2;
				dma_en<=1;
				dma_en_pulse<=1;
				conv_line_counter<=conv_line_counter-1;
				dma_loop_counter<=dma_loop_counter-1;
			end
			else if(dma_done==1&&dma_en_pulse==0)//write-back finished
			begin
				if(conv_line_counter==0)//all lines of this pass are done
				begin
					if(input_loop_counter==1&&output_loop_counter==1)//every loop of the layer is finished
					begin
						conv_sub_state<=idle;
					end
					else if(output_loop_counter==1)//finished the output loops for this input group
					begin
						output_loop_counter<=output_loop;
						input_loop_counter<=input_loop_counter-1;
						bias_start_address<=bias_ram_pointer;
						input_ram_pointer<=input_start_address;
						conv_sub_state<=load_weight;
					end
					else//still accumulating partial sums for this batch of outputs
					begin
						output_loop_counter<=output_loop_counter-1;
						bias_ram_pointer<=bias_start_address;
						conv_sub_state<=load_weight;
						input_ram_pointer<=input_start_address+(input_loop-input_loop_counter)*64*input_size*input_size;
					end
				end
				else if(conv_line_counter>0)//load the next line
				begin
					conv_sub_state<=load_bias;
					dma_read_addr<=bias_ram_pointer;
					dma_write_addr<=write_addr_buffer_bias_base;//see NOTE(review) at the declaration
					input_ram_pointer<=input_ram_pointer-3*dma_input_channel*input_size*2+3*dma_input_channel*2;
					dma_length<=128;
					dma_loop_counter<=dma_bias_loop;
				end
			end
			else if(dma_done==0)//DMA busy: drop the request
			begin
				dma_en<=0;
				dma_en_pulse<=0;
			end
		end// end write_output

		//----------------------------------------------------------------
		// idle: hold the DMA and PE array quiet until enabled, then latch
		// the layer configuration and derive the loop counts.
		//----------------------------------------------------------------
		idle:
		begin
			if(conv_state_machine_en==0)
			begin
				dma_en<=0;
				rst_n_pe<=0;
			end
			else if(conv_state_machine_en==1)//conv_state_machine_en starts one conv layer computation
			begin
				input_channel<=input_channel_in;
				output_channel<=output_channel_in;
				bias_start_address<=bias_start_address_in;
				input_start_address<=input_start_address_in;
				weight_start_address<=weight_start_address_in;
				input_size<=input_size_in;
				output_size<=output_size_in;
				//64 input channels are handled per pass; more channels are processed in loops
				if(input_channel_in<64)
					input_loop<=1;
				else
					input_loop<=input_channel_in/64;
				//64 output channels are computed simultaneously; more channels are processed in loops
				if(output_channel_in<64)
					output_loop<=1;
				else
					output_loop<=output_channel_in/64;
				if(input_channel_in>64)//cap one weight load at 64 input channels
				begin
					dma_weight_loop<=9*64;
					dma_input_channel<=64;
				end
				else
				begin
					dma_weight_loop<=9*input_channel_in;
					dma_input_channel<=input_channel_in;
				end
				dma_bias_loop<=output_size_in;//one line of bias is loaded per load_bias visit
				conv_sub_state<=init;
			end
		end// end idle

		//wait_work_done and any illegal encoding fall back to idle
		default:
		begin
			conv_sub_state<=idle;
		end
		endcase
	end
end
endmodule


			
		
