// conv: 3x3 "same"-size convolution of one feature-map channel, computed one
// output pixel at a time with 9 parallel multipliers and one serial adder.
//
// Data layout (as used by the indexing below):
//   * feat_map_in / feat_map_out: up to 32*32 16-bit elements, row-major;
//     pixel p = row*in_data_dim + col lives at bits [p*16 +: 16].
//   * weight_bias: 9 kernel weights, weight k at [k*16 +: 16], followed by the
//     bias at [9*16 +: 16]. Weight k = 0 multiplies the pixel at (row-1, col-1),
//     k = 4 the centre pixel, k = 8 the pixel at (row+1, col+1).
//   * Out-of-image neighbours are replaced by 0 (zero padding), so the output
//     has the same dimension as the input.
//   * Supported in_data_dim values: 2, 4, 8, 16, 32. The column index is taken
//     as a bit slice of the pixel index, which only works for powers of two.
// Handshake: start when in_data_ready is seen in S0; out_data_ready is raised
// in S6 and held until in_data_ready drops.
// NOTE(review): elements are assumed to be IEEE fp16 (per the "fp16" comments),
// and float_multi / float_adder are assumed to be purely combinational. The
// one-cycle spacing between S2 -> S3 and S3 -> S4 relies on that. Confirm
// against their definitions.
module conv(
    input     logic                         clk, reset,

    // Input from pipeline
    input    logic                          in_data_ready,
    input    logic    [ 5: 0]               in_data_dim,
    input    logic    [32*32*16-1: 0]       feat_map_in,
    input    logic    [(3*3+1)*16-1: 0]     weight_bias,

    // Output to pipeline
    output   logic                          out_data_ready,
    output   logic    [32*32*16-1: 0]       feat_map_out,

    // Debug
    output   logic    [ 2: 0]               debug_state
);

    // ------------------------------------------------------------------
    // FSM states
    //   S0  idle, wait for in_data_ready
    //   S1  prepare: clear the pixel and tap counters
    //   S2  latch the 3x3 window around pixel mult_idx into the multipliers
    //   S3  feed product add_idx into the adder   (S3/S4 loop runs 9 times)
    //   S4  feed the adder result back as the running sum
    //   S5  store the finished pixel
    //   S6  done, hold out_data_ready until in_data_ready is released
    // ------------------------------------------------------------------
    logic    [ 2: 0]       state;
    assign debug_state = state;
    parameter S0 = 0, S1 = 1, S2 = 2, S3 = 3, S4 = 4, S5 = 5, S6 = 6;

    // fp16 multipliers, one per kernel tap
    logic    [15:0]        mult_idx;   // pixel being processed; already +1 after S2
    logic    [9*16-1:0]    mult_1;     // registered 3x3 input window
    logic    [9*16-1:0]    mult_2;     // kernel weights
    logic    [9*16-1:0]    mult_o;     // per-tap products

    assign mult_2 = weight_bias[3*3*16-1: 0];

    // fp16 adder, used as a serial accumulator
    logic    [ 3: 0]     add_idx;      // index of the next product to accumulate
    logic    [16-1:0]    add_1;        // running sum
    logic    [16-1:0]    add_2;        // product being added this round
    logic    [16-1:0]    add_o;

    // Bias element that follows the 9 weights
    logic    [16-1:0]    bias;
    assign bias = weight_bias[3*3*16 +: 16];

    genvar i;
    generate
        for (i = 0; i < 9; i = i + 1) begin : gen_fpu_for_each_element_in_kernel
            float_multi fp_mult(
                .num1        (mult_1[i*16 +: 16]),
                .num2        (mult_2[i*16 +: 16]),
                .result      (mult_o[i*16 +: 16])
            );
        end
    endgenerate

    float_adder fp_add(
        .num1        (add_1),
        .num2        (add_2),
        .result      (add_o)
    );

    // ------------------------------------------------------------------
    // Zero-padding checks for the pixel addressed by mult_idx
    // ------------------------------------------------------------------
    logic left_col_padding;     // pixel is in the first column
    logic right_col_padding;    // pixel is in the last column
    logic top_row_padding;      // pixel is in the first row
    logic bottom_row_padding;   // pixel is in the last row

    // Column = low log2(in_data_dim) bits of the pixel index (power-of-two sizes).
    always_comb begin
        case (in_data_dim)
            32: begin
                left_col_padding  = (mult_idx[0 +: 5] == 0);
                right_col_padding = (mult_idx[0 +: 5] == in_data_dim - 1);
            end
            16: begin
                left_col_padding  = (mult_idx[0 +: 4] == 0);
                right_col_padding = (mult_idx[0 +: 4] == in_data_dim - 1);
            end
            8: begin
                left_col_padding  = (mult_idx[0 +: 3] == 0);
                right_col_padding = (mult_idx[0 +: 3] == in_data_dim - 1);
            end
            4: begin
                left_col_padding  = (mult_idx[0 +: 2] == 0);
                right_col_padding = (mult_idx[0 +: 2] == in_data_dim - 1);
            end
            2: begin
                left_col_padding  = (mult_idx[0 +: 1] == 0);
                right_col_padding = (mult_idx[0 +: 1] == in_data_dim - 1);
            end
            default: begin
                // ERROR: unsupported dimension
                left_col_padding  = 1'b0;
                right_col_padding = 1'b0;
            end
        endcase
    end

    assign top_row_padding    = (mult_idx <  in_data_dim);
    assign bottom_row_padding = (mult_idx >= in_data_dim * (in_data_dim - 1));

    // ------------------------------------------------------------------
    // 3x3 input window around mult_idx. Tap k pairs with weight k. A tap whose
    // neighbour lies outside the image is forced to 0. The out-of-range
    // part-select on that path is never selected by the mux.
    // ------------------------------------------------------------------
    logic    [9*16-1:0]    window;

    always_comb begin
        // Row above: (row-1, col-1), (row-1, col), (row-1, col+1)
        window[0*16 +: 16] = (top_row_padding || left_col_padding)  ? 16'b0
                           : feat_map_in[mult_idx*16 - in_data_dim*16 - 16 +: 16];
        window[1*16 +: 16] = top_row_padding                        ? 16'b0
                           : feat_map_in[mult_idx*16 - in_data_dim*16 +: 16];
        window[2*16 +: 16] = (top_row_padding || right_col_padding) ? 16'b0
                           : feat_map_in[mult_idx*16 - in_data_dim*16 + 16 +: 16];
        // Current row: (row, col-1), (row, col), (row, col+1)
        window[3*16 +: 16] = left_col_padding                       ? 16'b0
                           : feat_map_in[mult_idx*16 - 16 +: 16];
        window[4*16 +: 16] = feat_map_in[mult_idx*16 +: 16];
        window[5*16 +: 16] = right_col_padding                      ? 16'b0
                           : feat_map_in[mult_idx*16 + 16 +: 16];
        // Row below: (row+1, col-1), (row+1, col), (row+1, col+1)
        window[6*16 +: 16] = (bottom_row_padding || left_col_padding)  ? 16'b0
                           : feat_map_in[mult_idx*16 + in_data_dim*16 - 16 +: 16];
        window[7*16 +: 16] = bottom_row_padding                        ? 16'b0
                           : feat_map_in[mult_idx*16 + in_data_dim*16 +: 16];
        window[8*16 +: 16] = (bottom_row_padding || right_col_padding) ? 16'b0
                           : feat_map_in[mult_idx*16 + in_data_dim*16 + 16 +: 16];
    end

    // ------------------------------------------------------------------
    // Datapath registers. Nonblocking assignments make sure the FSM block
    // below samples the pre-edge values of add_idx / mult_idx. The old
    // blocking form raced with it and could drop the last product.
    // ------------------------------------------------------------------
    always_ff @(posedge clk) begin
        case (state)
            S0: begin   // Reset, wait for data ready
                out_data_ready <= 1'b0;
            end
            S1: begin   // Prepare
                out_data_ready <= 1'b0;
                mult_idx       <= '0;
                add_idx        <= '0;
            end
            S2: begin   // Load the window for pixel mult_idx
                mult_1   <= window;
                mult_idx <= mult_idx + 1'b1;
                add_idx  <= '0;
            end
            S3: begin   // Feed product add_idx into the adder
                // Seed the running sum with the bias (through the fp adder)
                // instead of adding it afterwards with an integer '+', which
                // is not valid on fp16 bit patterns.
                if (add_idx == 0)
                    add_1 <= bias;
                add_2   <= mult_o[add_idx*16 +: 16];
                add_idx <= add_idx + 1'b1;
            end
            S4: begin   // Running sum <- adder result
                add_1 <= add_o;
            end
            S5: begin   // add_o = bias + all 9 products; mult_idx was bumped in S2
                feat_map_out[(mult_idx-1)*16 +: 16] <= add_o;
            end
            S6: begin   // Finish
                out_data_ready <= 1'b1;
            end
            default: ;
        endcase
    end

    // ------------------------------------------------------------------
    // Next-state logic
    // ------------------------------------------------------------------
    always_ff @(posedge clk or posedge reset) begin
        if (reset) begin
            state <= S0;
        end
        else begin
            case (state)
                S0:    // Reset, wait for data ready
                    if (in_data_ready)
                        state <= S1;
                S1:    // Prepare
                    state <= S2;
                S2:    // Load window
                    state <= S3;
                S3:    // add_idx is the value before this S3's increment, so
                       // leaving at 8 means products 0..8 have all been fed in.
                    if (add_idx >= 8)
                        state <= S5;
                    else
                        state <= S4;
                S4:    // Change input for next add
                    state <= S3;
                S5:    // Save result; stop after the last pixel
                    if (mult_idx >= in_data_dim*in_data_dim)
                        state <= S6;
                    else
                        state <= S2;
                S6:
                    if (in_data_ready == 0)
                        state <= S0;
                default:   // Unused encoding: recover to idle
                    state <= S0;
            endcase
        end
    end
endmodule
