#include <stdlib.h>
#include <stdint.h>
#include <math.h>
#include <string.h>
#include "stft.h"
#include "operations.h"

/*
 * pi expressed in the fixed-point phase units used by wrap() in this file
 * (3217 is pi * 2^10 rounded to the nearest integer).
 *
 * The previous definition, (int16_t)(pi * (1 << 15)), evaluates to roughly
 * 102943, which does not fit in int16_t.  Converting an out-of-range floating
 * value to an integer type is undefined behaviour, so the macro had no
 * reliable value.  A Q15 radian value also cannot represent +/-pi in 16 bits,
 * so the phase scale used by wrap() is adopted here instead.
 * NOTE(review): confirm that 2^10 is the intended phase scale project-wide.
 */
#define PI_FIXED ((int16_t)3217)

/*
 * Fold a fixed-point phase value into the interval [-3217, 3217].
 *
 * The value is shifted by half a turn, reduced modulo a full turn of 6435
 * units, brought back to a non-negative remainder, and shifted back.
 *
 * NOTE(review): the earlier comments labelled these constants "Q15", but
 * 3217 corresponds to pi scaled by 2^10 (pi * 1024 = 3216.99), and the full
 * turn is 6435 rather than 2 * 3217 = 6434 -- confirm the intended scale.
 *
 * x: phase value in the same units as the constants below.
 * Returns x plus or minus a whole multiple of 6435, within [-3217, 3217].
 */
int16_t wrap(int16_t x) {
    enum { HALF_TURN = 3217, FULL_TURN = 6435 };

    /* Work in int so the shift cannot exceed the 16-bit range. */
    int shifted = ((int)x + HALF_TURN) % FULL_TURN;

    /* C's % keeps the sign of the dividend; normalise to [0, FULL_TURN). */
    if (shifted < 0) {
        shifted += FULL_TURN;
    }

    return (int16_t)(shifted - HALF_TURN);
}

/*
 * Convert one frame of complex spectrum bins into (magnitude, frequency)
 * pairs.
 *
 * complex_data: input bins; elements 0 .. framesize-1 are read.
 * framesize:    number of bins in the frame; nothing is done if <= 0.
 * hopsize:      hop between successive frames, in samples.
 * samplerate:   sample rate; the bin spacing is samplerate / framesize using
 *               integer division.
 * encoded_data: output; for each bin, .real receives the magnitude from
 *               hypot_fixed() and .imag the estimated frequency, rounded and
 *               saturated to the int16_t range.
 *
 * Changes relative to the earlier implementation:
 *  - It read the phase of the "previous frame" from an uninitialised local
 *    array (undefined behaviour).  This interface carries no state between
 *    calls, so the previous phase is now taken as zero.
 *    NOTE(review): true frame-to-frame phase tracking needs caller-owned state.
 *  - The expected phase advance was built from an overflowing constant and
 *    truncated to int16_t, and the bin deviation was computed with integer
 *    division, discarding its fractional part.  Both are now done in double.
 *  - The phase difference was narrowed to int16_t before wrapping, which can
 *    be out of range; it is now wrapped in double with remainder().
 *  - A zero phase increment (hopsize == 0) no longer divides by zero.
 *
 * NOTE(review): phases are assumed to use the same units as wrap() in this
 * file (pi == 3217); confirm against atan2_fixed()'s output scale.
 * NOTE(review): the integer bin spacing is kept so the result stays
 * compatible with decode(), which divides by the same truncated value.
 */
void encode(complex_fixed *complex_data, int framesize, int hopsize, int samplerate, complex_fixed *encoded_data) {
    if (framesize <= 0 || complex_data == NULL || encoded_data == NULL) {
        return;
    }

    const double phase_pi = 3217.0;                            /* pi in phase units */
    const double phase_period = 2.0 * phase_pi;                /* one full turn */
    const double phaseinc = phase_period * hopsize / framesize; /* expected advance per bin */
    const int freqinc = samplerate / framesize;                /* bin spacing */

    for (int n = 0; n < framesize; n++) {
        const int16_t magnitude = hypot_fixed(complex_data[n].real, complex_data[n].imag);
        const int16_t phase = atan2_fixed(complex_data[n].imag, complex_data[n].real);

        /* Previous phase is taken as zero, so the difference is the phase itself. */
        const double delta = (double)phase;

        /* Deviation from the bin centre, in bins; zero if no advance is defined. */
        double bin_offset = 0.0;
        if (phaseinc != 0.0) {
            bin_offset = remainder(delta - n * phaseinc, phase_period) / phaseinc;
        }

        double freq = round((n + bin_offset) * freqinc);
        if (freq > INT16_MAX) {
            freq = INT16_MAX;
        } else if (freq < INT16_MIN) {
            freq = INT16_MIN;
        }

        encoded_data[n].real = magnitude;
        encoded_data[n].imag = (int16_t)freq;
    }
}

/* Round to nearest and saturate to the int16_t range. */
static int16_t decode_clamp_i16(double value) {
    const double rounded = round(value);
    if (rounded >= INT16_MAX) {
        return INT16_MAX;
    }
    if (rounded <= INT16_MIN) {
        return INT16_MIN;
    }
    return (int16_t)rounded;
}

/*
 * Convert one frame of (magnitude, frequency) pairs back into complex bins.
 *
 * data:         input; for each bin, .real is the magnitude and .imag the
 *               frequency; elements 0 .. framesize-1 are read.
 * framesize:    number of bins in the frame; nothing is done if <= 0.
 * hopsize:      hop between successive frames, in samples.
 * samplerate:   sample rate; the bin spacing is samplerate / framesize using
 *               integer division.
 * decoded_data: output; receives magnitude * cos(phase) in .real and
 *               magnitude * sin(phase) in .imag, rounded and saturated to the
 *               int16_t range.
 *
 * Changes relative to the earlier implementation:
 *  - The phase advance was scaled by a fixed-point constant and then passed
 *    to cos()/sin(), which take radians.  It is now computed in radians.
 *  - A zero bin spacing (samplerate < framesize) divided by zero; the bin now
 *    falls back to its nominal index n in that case.
 *  - Results are saturated rather than narrowed unchecked.
 *  - The local "previous frame" buffer was written but never read, so it and
 *    the other per-frame stack arrays are gone.
 *
 * NOTE(review): no phase is accumulated across calls because this interface
 * carries no state; each frame is synthesised from its own phase advance.
 */
void decode(complex_fixed *data, int framesize, int hopsize, int samplerate, complex_fixed *decoded_data) {
    if (framesize <= 0 || data == NULL || decoded_data == NULL) {
        return;
    }

    const double two_pi = 6.28318530717958647692;
    const double phaseinc = two_pi * hopsize / framesize; /* radians per bin */
    const int freqinc = samplerate / framesize;           /* bin spacing */

    for (int n = 0; n < framesize; n++) {
        const double magnitude = (double)data[n].real;
        const double freq = (double)data[n].imag;

        /* Fractional bin position implied by the encoded frequency. */
        double bin = (double)n;
        if (freqinc != 0) {
            bin = n + (freq - (double)n * freqinc) / freqinc;
        }

        const double phase = bin * phaseinc;
        decoded_data[n].real = decode_clamp_i16(magnitude * cos(phase));
        decoded_data[n].imag = decode_clamp_i16(magnitude * sin(phase));
    }
}