/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.matrix;

import com.aliasi.matrix.AbstractMatrix;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * A matrix stored in factored form: a set of row factor vectors and a set of
 * column factor vectors, all of length {@link #order()}.
 *
 * <p>The value at {@code (row, column)} is the dot product of the row's factor
 * vector with the column's factor vector. Instances can be built directly from
 * factor vectors, from singular vectors plus singular values, or fit to data
 * with {@link #svd} / {@link #partialSvd}, which learn one factor at a time by
 * stochastic gradient descent.
 */
public class SvdMatrix
extends AbstractMatrix {
    /** Row factors; {@code mRowVectors[row]} has length {@code mOrder}. */
    private final double[][] mRowVectors;
    /** Column factors; {@code mColumnVectors[column]} has length {@code mOrder}. */
    private final double[][] mColumnVectors;
    /** Number of factors, i.e. the length of every row and column vector. */
    private final int mOrder;

    /**
     * Constructs a matrix directly from row and column factor vectors.
     *
     * <p>The arrays are stored by reference, not copied; later changes made by
     * the caller are visible through this matrix.
     *
     * @param rowVectors one factor vector of length {@code order} per matrix row
     * @param columnVectors one factor vector of length {@code order} per matrix column
     * @param order the number of factors
     * @throws IllegalArgumentException if any vector's length differs from {@code order}
     */
    public SvdMatrix(double[][] rowVectors, double[][] columnVectors, int order) {
        SvdMatrix.verifyDimensions("row", order, rowVectors);
        SvdMatrix.verifyDimensions("column", order, columnVectors);
        this.mRowVectors = rowVectors;
        this.mColumnVectors = columnVectors;
        this.mOrder = order;
    }

    /**
     * Constructs a matrix from left/right singular vectors and singular values.
     *
     * <p>Component {@code k} of every row and column singular vector is copied
     * and multiplied by {@code sqrt(singularValues[k])}, so that
     * {@code value(i, j) == sum_k left[i][k] * singularValues[k] * right[j][k]}.
     * Negative singular values produce NaN factors because of the square root.
     *
     * @param rowSingularVectors one vector of length {@code singularValues.length} per row
     * @param columnSingularVectors one vector of length {@code singularValues.length} per column
     * @param singularValues the singular values; their count defines the order
     * @throws IllegalArgumentException if any vector's length differs from the order
     */
    public SvdMatrix(double[][] rowSingularVectors, double[][] columnSingularVectors, double[] singularValues) {
        this.mOrder = singularValues.length;
        SvdMatrix.verifyDimensions("row", this.mOrder, rowSingularVectors);
        SvdMatrix.verifyDimensions("column", this.mOrder, columnSingularVectors);
        this.mRowVectors = new double[rowSingularVectors.length][this.mOrder];
        this.mColumnVectors = new double[columnSingularVectors.length][this.mOrder];
        double[] sqrtSingularValues = new double[singularValues.length];
        for (int i = 0; i < sqrtSingularValues.length; ++i) {
            sqrtSingularValues[i] = Math.sqrt(singularValues[i]);
        }
        SvdMatrix.scale(this.mRowVectors, rowSingularVectors, sqrtSingularValues);
        SvdMatrix.scale(this.mColumnVectors, columnSingularVectors, sqrtSingularValues);
    }

    /** Returns the number of rows, i.e. the number of row factor vectors. */
    @Override
    public int numRows() {
        return this.mRowVectors.length;
    }

    /** Returns the number of columns, i.e. the number of column factor vectors. */
    @Override
    public int numColumns() {
        return this.mColumnVectors.length;
    }

    /**
     * Returns the number of factors. Every row and column vector has this length
     * (checked at construction), so this is also safe on a matrix with no rows.
     */
    public int order() {
        return this.mOrder;
    }

    /**
     * Returns the dot product of the row's factor vector with the column's
     * factor vector.
     *
     * @param row the row index
     * @param column the column index
     * @return the reconstructed value at {@code (row, column)}
     */
    @Override
    public double value(int row, int column) {
        double[] rowVec = this.mRowVectors[row];
        double[] colVec = this.mColumnVectors[column];
        double result = 0.0;
        for (int i = 0; i < rowVec.length; ++i) {
            result += rowVec[i] * colVec[i];
        }
        return result;
    }

    /**
     * Returns {@link #singularValue(int)} for every order {@code 0..order()-1}.
     *
     * @return a freshly allocated array of length {@link #order()}
     */
    public double[] singularValues() {
        double[] singularValues = new double[this.mOrder];
        for (int i = 0; i < singularValues.length; ++i) {
            singularValues[i] = this.singularValue(i);
        }
        return singularValues;
    }

    /**
     * Returns the product of the Euclidean norms of factor column {@code order}
     * in the row factors and in the column factors.
     *
     * @param order the factor index, in {@code [0, order())}
     * @return the singular value for that factor
     * @throws IllegalArgumentException if {@code order} is negative or not less than {@link #order()}
     */
    public double singularValue(int order) {
        if (order < 0 || order >= this.mOrder) {
            String msg = "Maximum order=" + (this.mOrder - 1) + " found order=" + order;
            throw new IllegalArgumentException(msg);
        }
        return SvdMatrix.columnLength(this.mRowVectors, order) * SvdMatrix.columnLength(this.mColumnVectors, order);
    }

    /**
     * Returns a copy of the row factors with each factor column scaled to unit
     * Euclidean length (zero-length columns are left as zeros).
     */
    public double[][] leftSingularVectors() {
        return SvdMatrix.normalizeColumns(this.mRowVectors);
    }

    /**
     * Returns a copy of the column factors with each factor column scaled to
     * unit Euclidean length (zero-length columns are left as zeros).
     */
    public double[][] rightSingularVectors() {
        return SvdMatrix.normalizeColumns(this.mColumnVectors);
    }

    /**
     * Fits a factored approximation to a dense rectangular matrix.
     *
     * <p>Every entry of {@code values} is treated as observed; this delegates to
     * {@link #partialSvd} with column ids {@code 0..n-1} for each row. See that
     * method for the meaning of the training parameters.
     *
     * @param values a non-empty rectangular matrix, {@code values[row][column]}
     * @return the fitted matrix
     * @throws IllegalArgumentException if {@code values} has no rows, rows of
     *     differing lengths, or any training parameter is invalid
     */
    public static SvdMatrix svd(double[][] values, int maxOrder, double featureInit, double initialLearningRate, double annealingRate, double regularization, double minImprovement, int minEpochs, int maxEpochs, PrintWriter writer) {
        if (values == null || values.length == 0) {
            throw new IllegalArgumentException("Require at least one row of values.");
        }
        int m = values.length;
        int n = values[0].length;
        for (int i = 1; i < m; ++i) {
            if (values[i].length == n) continue;
            String msg = "All rows must be of same length. Found row[0].length=" + n + " row[" + i + "]=" + values[i].length;
            throw new IllegalArgumentException(msg);
        }
        // All rows observe every column, so a single id array can be shared.
        int[] sharedRow = new int[n];
        for (int j = 0; j < n; ++j) {
            sharedRow[j] = j;
        }
        int[][] columnIds = new int[m][];
        for (int i = 0; i < m; ++i) {
            columnIds[i] = sharedRow;
        }
        return SvdMatrix.partialSvd(columnIds, values, maxOrder, featureInit, initialLearningRate, annealingRate, regularization, minImprovement, minEpochs, maxEpochs, writer);
    }

    /**
     * Fits a factored approximation to a partially observed matrix by
     * stochastic gradient descent, learning one factor at a time.
     *
     * <p>Each factor is fit to the residual left by the previously learned
     * factors. Within a factor, epoch {@code e} uses learning rate
     * {@code initialLearningRate / (1 + e / annealingRate)}, and each observed
     * entry updates both factor components with an L2 penalty of
     * {@code regularization}. Training of a factor stops after
     * {@code maxEpochs}, or earlier once {@code epoch >= minEpochs} and the
     * relative change in RMSE drops below {@code minImprovement}.
     *
     * @param columnIds for each row, the strictly increasing non-negative
     *     column indices of its observed entries
     * @param values {@code values[row][i]} is the observed value at
     *     {@code (row, columnIds[row][i])}
     * @param maxOrder maximum number of factors (at least 1); capped at
     *     {@code min(numRows, maxColumnId + 1)}
     * @param featureInit scale of the Gaussian initial factor values; finite, non-zero
     * @param initialLearningRate learning rate at epoch 0; finite, non-negative
     * @param annealingRate learning-rate decay constant; finite, positive
     * @param regularization L2 penalty coefficient; finite, non-negative
     * @param minImprovement relative RMSE change below which a factor stops; finite, non-negative
     * @param minEpochs minimum epochs per factor before early stopping; positive
     * @param maxEpochs maximum epochs per factor; at least {@code minEpochs}
     * @param writer optional progress sink; may be {@code null}
     * @return the fitted matrix
     * @throws IllegalArgumentException if any argument violates the constraints above
     */
    public static SvdMatrix partialSvd(int[][] columnIds, double[][] values, int maxOrder, double featureInit, double initialLearningRate, double annealingRate, double regularization, double minImprovement, int minEpochs, int maxEpochs, PrintWriter writer) {
        SvdMatrix.printIfWriter(writer, "Start");
        SvdMatrix.validatePartialSvdArguments(columnIds, values, maxOrder, featureInit, initialLearningRate, annealingRate, regularization, minImprovement, minEpochs, maxEpochs);

        int numRows = columnIds.length;
        int numEntries = 0;
        for (double[] xs : values) {
            numEntries += xs.length;
        }
        int maxColumnIndex = 0;
        for (int[] xs : columnIds) {
            for (int id : xs) {
                if (id > maxColumnIndex) {
                    maxColumnIndex = id;
                }
            }
        }
        int numColumns = maxColumnIndex + 1;
        int numFactors = Math.min(maxOrder, Math.min(numRows, numColumns));

        // cache[row][i] holds the summed prediction of all factors learned so far.
        double[][] cache = new double[numRows][];
        for (int row = 0; row < numRows; ++row) {
            cache[row] = new double[values[row].length];
        }

        List<double[]> rowVectorList = new ArrayList<double[]>(numFactors);
        List<double[]> columnVectorList = new ArrayList<double[]>(numFactors);
        for (int order = 0; order < numFactors; ++order) {
            SvdMatrix.printIfWriter(writer, "  Factor=" + order);
            double[] rowVector = SvdMatrix.initArray(numRows, featureInit);
            double[] columnVector = SvdMatrix.initArray(numColumns, featureInit);
            double rmseLast = Double.POSITIVE_INFINITY;
            for (int epoch = 0; epoch < maxEpochs; ++epoch) {
                double learningRateForEpoch = initialLearningRate / (1.0 + (double) epoch / annealingRate);
                double sumOfSquareErrors = SvdMatrix.sgdEpoch(columnIds, values, cache, rowVector, columnVector, learningRateForEpoch, regularization);
                double rmse = Math.sqrt(sumOfSquareErrors / (double) numEntries);
                SvdMatrix.printIfWriter(writer, "    epoch=" + epoch + " rmse=" + rmse);
                double previousRmse = rmseLast;
                rmseLast = rmse;
                if (epoch >= minEpochs) {
                    double relDiff = SvdMatrix.relativeDifference(rmse, previousRmse);
                    if (relDiff < minImprovement) {
                        SvdMatrix.printIfWriter(writer, "     exiting in epoch=" + epoch + " rmse=" + rmse + " relDiff=" + relDiff);
                        break;
                    }
                }
            }
            SvdMatrix.printIfWriter(writer, "Order=" + order + " RMSE=" + rmseLast);
            rowVectorList.add(rowVector);
            columnVectorList.add(columnVector);
            // Fold the new factor into the cache so the next factor fits the residual.
            for (int row = 0; row < numRows; ++row) {
                double[] cacheRow = cache[row];
                for (int i = 0; i < cacheRow.length; ++i) {
                    cacheRow[i] = SvdMatrix.predict(row, columnIds[row][i], rowVector, columnVector, cacheRow[i]);
                }
            }
        }
        double[][] rowVectors = SvdMatrix.toArray(rowVectorList);
        double[][] columnVectors = SvdMatrix.toArray(columnVectorList);
        return new SvdMatrix(SvdMatrix.transpose(rowVectors), SvdMatrix.transpose(columnVectors), numFactors);
    }

    /**
     * Checks every argument of {@link #partialSvd}, throwing on the first violation.
     */
    private static void validatePartialSvdArguments(int[][] columnIds, double[][] values, int maxOrder, double featureInit, double initialLearningRate, double annealingRate, double regularization, double minImprovement, int minEpochs, int maxEpochs) {
        if (maxOrder < 1) {
            throw new IllegalArgumentException("Maximum order must be at least 1. Found maxOrder=" + maxOrder);
        }
        if (minImprovement < 0.0 || SvdMatrix.notFinite(minImprovement)) {
            throw new IllegalArgumentException("Minimum improvement must be finite and non-negative. Found minImprovement=" + minImprovement);
        }
        if (minEpochs <= 0 || maxEpochs < minEpochs) {
            throw new IllegalArgumentException("Require 0 < minEpochs <= maxEpochs. Found minEpochs=" + minEpochs + " maxEpochs=" + maxEpochs);
        }
        if (SvdMatrix.notFinite(featureInit) || featureInit == 0.0) {
            throw new IllegalArgumentException("Feature initialization must be finite and non-zero. Found featureInit=" + featureInit);
        }
        if (SvdMatrix.notFinite(initialLearningRate) || initialLearningRate < 0.0) {
            throw new IllegalArgumentException("Initial learning rate must be finite and non-negative. Found initialLearningRate=" + initialLearningRate);
        }
        if (SvdMatrix.notFinite(regularization) || regularization < 0.0) {
            throw new IllegalArgumentException("Regularization must be finite and non-negative. Found regularization=" + regularization);
        }
        // Zero would make the epoch-0 learning rate 0/0 = NaN.
        if (annealingRate <= 0.0 || SvdMatrix.notFinite(annealingRate)) {
            throw new IllegalArgumentException("Annealing rate must be finite and positive. Found annealingRate=" + annealingRate);
        }
        if (columnIds == null) {
            throw new IllegalArgumentException("columnIds must not be null");
        }
        if (values == null) {
            throw new IllegalArgumentException("values must not be null");
        }
        if (columnIds.length != values.length) {
            throw new IllegalArgumentException("columnIds and values must have the same number of rows. Found columnIds.length=" + columnIds.length + " values.length=" + values.length);
        }
        if (columnIds.length == 0) {
            throw new IllegalArgumentException("Require at least one row.");
        }
        for (int row = 0; row < columnIds.length; ++row) {
            int[] idsForRow = columnIds[row];
            double[] valuesForRow = values[row];
            if (idsForRow == null) {
                throw new IllegalArgumentException("columnIds[" + row + "] must not be null");
            }
            if (valuesForRow == null) {
                throw new IllegalArgumentException("values[" + row + "] must not be null");
            }
            if (idsForRow.length != valuesForRow.length) {
                throw new IllegalArgumentException("columnIds[" + row + "] and values[" + row + "] must have the same length. Found " + idsForRow.length + " and " + valuesForRow.length);
            }
            for (int i = 0; i < idsForRow.length; ++i) {
                if (idsForRow[i] < 0) {
                    throw new IllegalArgumentException("Column ids must be non-negative. Found columnIds[" + row + "][" + i + "]=" + idsForRow[i]);
                }
                if (i > 0 && idsForRow[i - 1] >= idsForRow[i]) {
                    throw new IllegalArgumentException("Column ids must be strictly increasing within a row. Found columnIds[" + row + "][" + (i - 1) + "]=" + idsForRow[i - 1] + " columnIds[" + row + "][" + i + "]=" + idsForRow[i]);
                }
            }
        }
    }

    /**
     * Runs one stochastic-gradient pass over every observed entry, updating
     * {@code rowVector} and {@code columnVector} in place.
     *
     * @return the sum of squared prediction errors measured during the pass
     */
    private static double sgdEpoch(int[][] columnIds, double[][] values, double[][] cache, double[] rowVector, double[] columnVector, double learningRate, double regularization) {
        double sumOfSquareErrors = 0.0;
        for (int row = 0; row < columnIds.length; ++row) {
            int[] columnIdsForRow = columnIds[row];
            double[] valuesForRow = values[row];
            double[] cacheForRow = cache[row];
            for (int i = 0; i < columnIdsForRow.length; ++i) {
                int column = columnIdsForRow[i];
                double prediction = SvdMatrix.predict(row, column, rowVector, columnVector, cacheForRow[i]);
                double error = valuesForRow[i] - prediction;
                sumOfSquareErrors += error * error;
                // Both updates use the pre-update components so neither sees the other's step.
                double rowCurrent = rowVector[row];
                double columnCurrent = columnVector[column];
                rowVector[row] += learningRate * (error * columnCurrent - regularization * rowCurrent);
                columnVector[column] += learningRate * (error * rowCurrent - regularization * columnCurrent);
            }
        }
        return sumOfSquareErrors;
    }

    /**
     * Prints {@code msg} prefixed with {@code "partialSvd| "} and flushes,
     * or does nothing if {@code writer} is {@code null}.
     */
    static void printIfWriter(PrintWriter writer, String msg) {
        if (writer == null) {
            return;
        }
        writer.print("partialSvd| ");
        writer.println(msg);
        writer.flush();
    }

    /**
     * Returns the sum over factors {@code 0..order} (inclusive) of
     * {@code rowVectorList.get(i)[row] * columnVectorList.get(i)[column]}.
     * The {@code lowerBound}, {@code upperBound} and {@code init} parameters
     * are not used.
     */
    static double predictRaw(int row, int column, int order, List<double[]> rowVectorList, List<double[]> columnVectorList, double lowerBound, double upperBound, double init) {
        double val = 0.0;
        for (int i = 0; i <= order; ++i) {
            val += rowVectorList.get(i)[row] * columnVectorList.get(i)[column];
        }
        return val;
    }

    /**
     * Returns {@code |x - y| / (|x| + |y|)}, or 0 when both arguments are zero
     * (which would otherwise be 0/0 = NaN).
     */
    static double relativeDifference(double x, double y) {
        double magnitude = Math.abs(x) + Math.abs(y);
        if (magnitude == 0.0) {
            return 0.0;
        }
        return Math.abs(x - y) / magnitude;
    }

    /**
     * Returns the transpose of a rectangular, non-empty matrix; the result's
     * row length is taken from {@code xs[0]}.
     */
    static double[][] transpose(double[][] xs) {
        double[][] ys = new double[xs[0].length][xs.length];
        for (int i = 0; i < xs.length; ++i) {
            for (int j = 0; j < xs[i].length; ++j) {
                ys[j][i] = xs[i][j];
            }
        }
        return ys;
    }

    /** Copies the list's elements, in order, into a new array. */
    static double[][] toArray(List<double[]> list) {
        return list.toArray(new double[list.size()][]);
    }

    /** Returns {@code cache + rowVector[row] * columnVector[column]}. */
    static double predict(int row, int column, double[] rowVector, double[] columnVector, double cache) {
        return cache + rowVector[row] * columnVector[column];
    }

    /**
     * Returns an array of {@code size} independent standard-Gaussian samples,
     * each multiplied by {@code val}.
     */
    static double[] initArray(int size, double val) {
        double[] xs = new double[size];
        Random random = new Random();
        for (int i = 0; i < xs.length; ++i) {
            xs[i] = random.nextGaussian() * val;
        }
        return xs;
    }

    /** Returns {@code true} if {@code x} is NaN or infinite. */
    static boolean notFinite(double x) {
        return Double.isNaN(x) || Double.isInfinite(x);
    }

    /** Returns the Euclidean norm of column {@code col} of {@code xs}. */
    static double columnLength(double[][] xs, int col) {
        double sumOfSquares = 0.0;
        for (int i = 0; i < xs.length; ++i) {
            sumOfSquares += xs[i][col] * xs[i][col];
        }
        return Math.sqrt(sumOfSquares);
    }

    /**
     * Writes {@code singularVecs[i][k] * singularVals[k]} into {@code vecs[i][k]}
     * for every cell of {@code vecs}.
     */
    static void scale(double[][] vecs, double[][] singularVecs, double[] singularVals) {
        for (int i = 0; i < vecs.length; ++i) {
            for (int k = 0; k < vecs[i].length; ++k) {
                vecs[i][k] = singularVecs[i][k] * singularVals[k];
            }
        }
    }

    /**
     * Throws if any vector's length differs from {@code order}; {@code prefix}
     * names the vector set in the message.
     */
    static void verifyDimensions(String prefix, int order, double[][] vectors) {
        for (int i = 0; i < vectors.length; ++i) {
            if (vectors[i].length == order) continue;
            String msg = "All vectors must have length equal to order. order=" + order + " " + prefix + "Vectors[" + i + "].length=" + vectors[i].length;
            throw new IllegalArgumentException(msg);
        }
    }

    /**
     * Returns a copy of {@code xs} with each column divided by its Euclidean
     * norm. Columns of norm zero are copied unchanged (all zeros) instead of
     * becoming NaN; an input with no rows yields an empty array.
     */
    static double[][] normalizeColumns(double[][] xs) {
        int numDims = xs.length;
        if (numDims == 0) {
            return new double[0][];
        }
        int order = xs[0].length;
        double[][] result = new double[numDims][order];
        for (int j = 0; j < order; ++j) {
            double sumOfSquares = 0.0;
            for (int i = 0; i < numDims; ++i) {
                double valIJ = xs[i][j];
                result[i][j] = valIJ;
                sumOfSquares += valIJ * valIJ;
            }
            double length = Math.sqrt(sumOfSquares);
            if (length == 0.0) {
                continue;
            }
            for (int i = 0; i < numDims; ++i) {
                result[i][j] /= length;
            }
        }
        return result;
    }
}

