#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <math.h>
#include <string.h>
#include "operations.h"
#include "stft.h"
#include "vocoder.h"

/* Pi as a double-precision literal. */
#define PI 3.14159265358979323846
/*
 * NOTE(review): pi * 2^15 is roughly 102943, which does not fit in int16_t,
 * so this cast is out of range. The macro is not used in the visible part of
 * this file; confirm the intended fixed-point format before relying on it.
 */
#define PI_FIXED ((int16_t)(3.14159265358979323846 * (1 << 15)))
/* Number of samples per analysis frame. */
#define FRAME_SIZE 1024
/* Number of samples between the starts of consecutive frames. */
#define HOP_SIZE 32
#define FIXED_POINT_SHIFT 16
/* log2(FRAME_SIZE): 2^10 == 1024. Passed to fix_fftr() as its size argument. */
#define FFT_SIZE_IN_BITS 10
/*
 * Parenthesized so the macro behaves as a single value in any expression
 * (e.g. FRACT_BITS * 2); the bare form FIXED_POINT_SHIFT-1 would bind
 * incorrectly next to higher-precedence operators.
 */
#define FRACT_BITS (FIXED_POINT_SHIFT - 1)
/* 2^FRACT_BITS. */
#define FRACT_SCALE (1 << FRACT_BITS)
#define FACTOR 2

/* Element type of the audio buffer that main() fills from the input file. */
typedef int16_t sample_t;

/* Forward declarations of the helper routines used by this translation unit. */
int16_t* linear(const int16_t* x, size_t n, int16_t factor, size_t* m);
int16_t argmax(int16_t *arr, size_t size);
void take_along_axis(double *arr_in, double *ind_in, double *arr_out, int n, int m, int k);
/* NOTE(review): no definition of clip() is visible in this file; confirm it
 * is provided elsewhere or drop the prototype. */
void clip(int *arr, int size, int min_val, int max_val);

/*
 * Entry point: load a WAV file, cut it into overlapping frames, and run each
 * frame through the Hann window, the fixed-point real FFT and encode().
 *
 * Usage: <program> input_file.wav
 *
 * Only a canonical 44-byte header followed by 16-bit mono PCM data is
 * accepted, because the sample buffer is read as a flat array of sample_t.
 * Header fields are read straight into host integers, so a little-endian
 * host is assumed.
 *
 * Returns 0 on success and 1 on any usage, I/O, format or allocation error.
 */
int main(int argc, char *argv[]) {
    if (argc < 2) {
        fprintf(stderr, "Usage: %s input_file.wav\n", argv[0]);
        return 1;
    }

    const char *input_file = argv[1];

    /* Every resource is released at the single cleanup label below. */
    int rc = 1;
    FILE *fp = NULL;
    int16_t *fixed_point_array = NULL;
    int16_t **frames = NULL;
    complex_fixed *complex_array = NULL;
    complex_fixed *encoded_array = NULL;

    // WAV header fields, in file order
    char riff[4];
    uint32_t file_size;
    char wave[4];
    char fmt[4];
    uint32_t fmt_size;
    uint16_t audio_format;
    uint16_t num_channels;
    uint32_t sample_rate;
    uint32_t byte_rate;
    uint16_t block_align;
    uint16_t bits_per_sample;
    char data[4];
    uint32_t data_size;

    fp = fopen(input_file, "rb");
    if (!fp) {
        fprintf(stderr, "Failed to open file: %s\n", input_file);
        goto cleanup;
    }

    // Read WAV header; && short-circuits, so the first short read stops the chain
    int header_ok =
        fread(riff, 1, 4, fp) == 4 &&
        fread(&file_size, 4, 1, fp) == 1 &&
        fread(wave, 1, 4, fp) == 4 &&
        fread(fmt, 1, 4, fp) == 4 &&
        fread(&fmt_size, 4, 1, fp) == 1 &&
        fread(&audio_format, 2, 1, fp) == 1 &&
        fread(&num_channels, 2, 1, fp) == 1 &&
        fread(&sample_rate, 4, 1, fp) == 1 &&
        fread(&byte_rate, 4, 1, fp) == 1 &&
        fread(&block_align, 2, 1, fp) == 1 &&
        fread(&bits_per_sample, 2, 1, fp) == 1 &&
        fread(data, 1, 4, fp) == 4 &&
        fread(&data_size, 4, 1, fp) == 1;
    if (!header_ok) {
        fprintf(stderr, "Failed to read WAV header: %s\n", input_file);
        goto cleanup;
    }

    // Reject files whose chunks are not where the fixed-layout read expects them
    if (memcmp(riff, "RIFF", 4) != 0 || memcmp(wave, "WAVE", 4) != 0 ||
        memcmp(fmt, "fmt ", 4) != 0 || memcmp(data, "data", 4) != 0) {
        fprintf(stderr, "Unsupported or malformed WAV header: %s\n", input_file);
        goto cleanup;
    }

    // The sample loop below assumes one sample_t per block; this also
    // guarantees block_align is non-zero before it is used as a divisor
    if (audio_format != 1 || num_channels != 1 || bits_per_sample != 16 ||
        block_align != sizeof(sample_t)) {
        fprintf(stderr, "Only 16-bit mono PCM WAV input is supported\n");
        goto cleanup;
    }

    //CONSTANTS
    // Calculate number of audio samples declared by the header
    uint32_t num_samples = data_size / block_align;

    printf("%lu", (unsigned long)num_samples);

    if (num_samples < FRAME_SIZE) {
        fprintf(stderr, "Input too short: %lu samples, need at least %d\n",
                (unsigned long)num_samples, FRAME_SIZE);
        goto cleanup;
    }

    // Allocate memory for fixed point array (calloc checks the size product)
    fixed_point_array = calloc(num_samples, sizeof *fixed_point_array);
    if (!fixed_point_array) {
        fprintf(stderr, "Failed to allocate memory\n");
        goto cleanup;
    }

    // Read audio data in one call; the header may declare more than the file holds
    size_t samples_read = fread(fixed_point_array, sizeof(sample_t), num_samples, fp);
    fclose(fp);
    fp = NULL;

    if (samples_read < FRAME_SIZE) {
        fprintf(stderr, "Input too short: %lu samples, need at least %d\n",
                (unsigned long)samples_read, FRAME_SIZE);
        goto cleanup;
    }
    if (samples_read < num_samples) {
        fprintf(stderr, "Warning: header declares %lu samples but only %lu were read\n",
                (unsigned long)num_samples, (unsigned long)samples_read);
        num_samples = (uint32_t)samples_read;
    }

    // Calculate number of frames for frequency domain. num_samples >= FRAME_SIZE
    // here, so the subtraction cannot wrap; 32 bits so long inputs do not truncate.
    // NOTE(review): assumes sliding_window() returns this many frames - confirm.
    uint32_t num_bins = (num_samples - FRAME_SIZE) / HOP_SIZE + 1;

    frames = sliding_window(fixed_point_array, num_samples, FRAME_SIZE, HOP_SIZE);

    // NOTE(review): assumes sliding_window() copies the samples rather than
    // returning pointers into fixed_point_array - confirm against stft.h.
    free(fixed_point_array);
    fixed_point_array = NULL;

    if (!frames) {
        fprintf(stderr, "Failed to build frames\n");
        goto cleanup;
    }

    // Zero-initialised: the copy loop below fills FRAME_SIZE/2 entries, so the
    // extra last entry would otherwise reach encode() uninitialised
    size_t spectrum_len = (FRAME_SIZE / 2) + 1;
    complex_array = calloc(spectrum_len, sizeof *complex_array);
    encoded_array = calloc(spectrum_len, sizeof *encoded_array);
    if (!complex_array || !encoded_array) {
        fprintf(stderr, "Failed to allocate memory\n");
        goto cleanup;
    }

    for (uint32_t i = 0; i < num_bins; i++) {
        apply_hann_window_forward(frames[i]);
        fix_fftr(frames[i], FFT_SIZE_IN_BITS, 0);
        // Real parts are taken from the first half of the frame and imaginary
        // parts from the second half
        for (uint16_t j = 0; j < FRAME_SIZE / 2; j++) {
            complex_array[j].real = frames[i][j];
            complex_array[j].imag = frames[i][j + (FRAME_SIZE / 2)];
        }
        encode(complex_array, FRAME_SIZE, HOP_SIZE, sample_rate, encoded_array);
    }

    rc = 0;

cleanup:
    // free(NULL) is a no-op, so every pointer can be released unconditionally.
    // NOTE(review): only the outer frames array is freed; if sliding_window()
    // allocates each frame separately those allocations leak - confirm ownership.
    free(encoded_array);
    free(complex_array);
    free(frames);
    free(fixed_point_array);
    if (fp) {
        fclose(fp);
    }
    return rc;
}

int16_t* linear(const int16_t* x, size_t n, int16_t factor, size_t* m) {
    if (factor == 1.0) {
        *m = n;
        int16_t* y = malloc(n * sizeof(int16_t));
        if (!y) {
            return NULL;
        }
        for (size_t i = 0; i < n; i++) {
            y[i] = x[i];
        }
        return y;
    }

    *m = (size_t) (n * factor);
    int16_t* y = malloc((*m) * sizeof(int16_t));
    if (!y) {
        return NULL;
    }

    float q = (float) n / (*m);
    for (size_t i = 0; i < fmin(n, *m); i++) {
        float k = i * q;
        size_t j = (size_t) floorf(k);
        k -= j;

        if (j < n - 1) {
            y[i] = (int16_t) roundf((1.0f - k) * x[j] + k * x[j+1]);
        } else {
            y[i] = x[j];
        }
    }
    return y;
}

/*
 * Return the index of the largest element of arr.
 *
 * arr  - array of size elements
 * size - number of elements in arr
 *
 * When several elements share the maximum, the lowest index is returned.
 * Returns -1 when arr is NULL or size is 0, so no element is read from an
 * empty array. The result is narrowed to int16_t, so indices above INT16_MAX
 * cannot be represented.
 */
int16_t argmax(int16_t *arr, size_t size) {
    if (arr == NULL || size == 0) {
        return -1;
    }

    size_t best = 0;
    for (size_t i = 1; i < size; i++) {
        /* Strict comparison keeps the first occurrence of the maximum. */
        if (arr[i] > arr[best]) {
            best = i;
        }
    }
    return (int16_t) best;
}

/*
 * For every (row, col) cell of an n-by-m grid, pick one of the k values
 * stored for that cell in arr_in and write it to arr_out.
 *
 * arr_in  - n * m * k values; the k candidates of a cell are contiguous
 * ind_in  - n * m selectors; each is truncated to int before use
 * arr_out - n * m results
 *
 * A selector outside [0, k) is reported on stdout and the matching arr_out
 * element is left untouched.
 */
void take_along_axis(double *arr_in, double *ind_in, double *arr_out, int n, int m, int k) {
    for (int row = 0; row < n; row++) {
        for (int col = 0; col < m; col++) {
            const int cell = row * m + col;
            const int pick = (int) ind_in[cell];

            if (pick < 0 || pick >= k) {
                printf("Index out of bounds: %d\n", pick);
                continue;
            }
            arr_out[cell] = arr_in[cell * k + pick];
        }
    }
}

// TODO: implement shiftpitch.
// TODO: do the STFT the way it is done in the Python version (the Hann window
//       probably shouldn't be a single standalone function).


