#include "lin_alg.h"
#include <stdio.h>
#include <string.h>
#include <math.h>

/* Trace of a row-major 2x2 matrix {a, b, c, d}: the sum of the diagonal
 * entries a (arr[0]) and d (arr[3]). */
double get_trace_2x2(double *arr) {
    const double top_left = arr[0];
    const double bottom_right = arr[3];
    return top_left + bottom_right;
}

/* Determinant of a row-major 2x2 matrix {a, b, c, d}: ad - bc, i.e. the
 * main-diagonal product minus the anti-diagonal product. */
double get_det_2x2(double *arr) {
    const double main_diag = arr[0] * arr[3];
    const double anti_diag = arr[1] * arr[2];
    return main_diag - anti_diag;
}

// λ = (-b ± sqrt(b^2 - 4ac)) / 2a, where a=1, b=-(a+d), and c = ad-bc
void get_eigvals_2x2(double *arr, double *output) {
    double a = 1;
    double b = -(arr[0] + arr[3]);
    double c = arr[0] * arr[3] - arr[1] * arr[2];

    output[0] = (-b - sqrt(b*b - 4*a*c)) / (2 * a);
    output[1] = (-b + sqrt(b*b - 4*a*c)) / (2 * a);
}

/* Print a row-major 3x3 matrix to stdout: one row per line, entries
 * formatted with %f and separated by ", ". */
void print_3x3(double *arr) {
    for (int row = 0; row < 3; row++) {
        const double *r = arr + 3 * row;
        printf("%f, %f, %f\n", r[0], r[1], r[2]);
    }
}

/*
 * Inverse of a row-major 3x3 matrix via the adjugate: inv = adj(arr) / det(arr).
 *
 * arr    : 9 doubles, row-major.
 * output : 9 doubles that receive the inverse, row-major. It may be the same
 *          buffer as arr (in-place inversion), because the adjugate is built
 *          in a local array before anything is written to output.
 *
 * Returns 0 on success. Returns -1 if the determinant is exactly zero, and
 * in that case output is left unmodified.
 *
 * NOTE: only an exactly-zero determinant is rejected; nearly singular
 * matrices still produce a (numerically poor) result.
 */
int inv_3x3(double *arr, double *output) {
    double adj[9];

    /* Adjugate = transpose of the cofactor matrix. */
    adj[0] = arr[4] * arr[8] - arr[7] * arr[5];
    adj[1] = arr[7] * arr[2] - arr[1] * arr[8];
    adj[2] = arr[1] * arr[5] - arr[4] * arr[2];
    adj[3] = arr[5] * arr[6] - arr[8] * arr[3];
    adj[4] = arr[8] * arr[0] - arr[2] * arr[6];
    adj[5] = arr[2] * arr[3] - arr[5] * arr[0];
    adj[6] = arr[3] * arr[7] - arr[4] * arr[6];
    adj[7] = arr[6] * arr[1] - arr[0] * arr[7];
    adj[8] = arr[0] * arr[4] - arr[3] * arr[1];

    /* Cofactor expansion along the first row: adj[0], adj[3], adj[6] are
     * the cofactors of arr[0], arr[1], arr[2]. */
    const double det = arr[0] * adj[0] + arr[1] * adj[3] + arr[2] * adj[6];
    if (det == 0) return -1;

    for (int idx = 0; idx < 9; idx++) {
        output[idx] = adj[idx] / det;
    }
    return 0;
}

/*
 * Matrix product out = arr1 · arr2 for dense, row-major matrices.
 *
 * arr1 : i x j matrix
 * arr2 : j x k matrix
 * out  : i x k result. It is zeroed first and then accumulated into, so it
 *        must not overlap arr1 or arr2.
 *
 * If i or k is not positive, the result is empty and out is not touched.
 * If j is not positive, out is filled with zeros.
 */
void dot(double *arr1, double *arr2, double *out, int i, int j, int k) {
    if (i <= 0 || k <= 0) return;

    /* size_t arithmetic: i * k (and the index products) can overflow int
     * for large matrices, which is undefined behavior. */
    memset(out, 0, (size_t)i * (size_t)k * sizeof *out);
    for (int row = 0; row < i; row++) {
        double *out_row = out + (size_t)row * (size_t)k;
        /* i-j-k order: the innermost loop walks out_row and a row of arr2
         * contiguously. */
        for (int inner = 0; inner < j; inner++) {
            const double scale = arr1[(size_t)row * (size_t)j + (size_t)inner];
            const double *b_row = arr2 + (size_t)inner * (size_t)k;
            for (int col = 0; col < k; col++) {
                out_row[col] += scale * b_row[col];
            }
        }
    }
}

// int main() {
//     double arr[9] = {1,2,3,4,5,6,7,8,9};
//     double arr2[9] = {2,3,4,5,6,7,8,9,10};
//     double out[9] = {0};
//     dot(arr, arr2, out, 3, 3, 3);
//     print_3x3(out);
// }

// int main() {
//     double arr[9] = {1,1,0,1,0,1,1,1,1};
//     // output should be {1,1,-1,0,-1,1,-1,0,1};
//     double output[9] = {0};
//     inv_3x3(arr, output);
//     print_3x3(output);
// }