/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.stats;

import com.aliasi.matrix.DenseVector;
import com.aliasi.matrix.Matrices;
import com.aliasi.matrix.SparseFloatVector;
import com.aliasi.matrix.Vector;
import com.aliasi.stats.AnnealingSchedule;
import com.aliasi.stats.RegressionPrior;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.Compilable;
import com.aliasi.util.Math;
import com.aliasi.util.Strings;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.PrintWriter;
import java.io.Serializable;
import java.util.Arrays;

/**
 * Multinomial logistic regression model.
 *
 * <p>A model with {@code K} outcomes stores {@code K - 1} weight vectors. The
 * last outcome ({@code K - 1}) has no weight vector; its linear score is fixed
 * at zero, so it serves as the reference category.
 *
 * <p>Serialization and compilation both go through {@link Externalizer}, which
 * writes every weight as a dense array of doubles.
 */
public class LogisticRegression
implements Compilable,
Serializable {
    /** One weight vector per outcome except the last; all share one dimensionality. */
    private final Vector[] mWeightVectors;

    /**
     * Constructs a model from weight vectors for outcomes {@code 0} through
     * {@code weightVectors.length - 1}. The model has
     * {@code weightVectors.length + 1} outcomes.
     *
     * <p>The array is copied, but the vectors are shared. Later changes to a
     * vector's values are seen by this model.
     *
     * @param weightVectors the non-empty array of weight vectors
     * @throws IllegalArgumentException if the array is empty or the vectors do
     *     not all have the same number of dimensions
     */
    public LogisticRegression(Vector[] weightVectors) {
        if (weightVectors.length < 1) {
            String msg = "Require at least one weight vector.";
            throw new IllegalArgumentException(msg);
        }
        int numDimensions = weightVectors[0].numDimensions();
        for (int k = 1; k < weightVectors.length; ++k) {
            if (numDimensions == weightVectors[k].numDimensions()) continue;
            String msg = "All weight vectors must be same dimensionality. Found weightVectors[0].numDimensions()=" + numDimensions + " weightVectors[" + k + "]=" + weightVectors[k].numDimensions();
            throw new IllegalArgumentException(msg);
        }
        // Copy the array so a caller cannot later swap vectors out of this
        // model. estimate() relies only on the vectors being shared.
        this.mWeightVectors = weightVectors.clone();
    }

    /**
     * Constructs a two-outcome (binary) model from a single weight vector.
     *
     * @param weightVector the weight vector for outcome {@code 0}
     */
    public LogisticRegression(Vector weightVector) {
        this.mWeightVectors = new Vector[]{weightVector};
    }

    /**
     * Returns the dimensionality of the inputs this model accepts.
     *
     * @return the number of dimensions of the weight vectors
     */
    public int numInputDimensions() {
        return this.mWeightVectors[0].numDimensions();
    }

    /**
     * Returns the number of outcomes. This is one more than the number of
     * weight vectors.
     *
     * @return the number of outcomes
     */
    public int numOutcomes() {
        return this.mWeightVectors.length + 1;
    }

    /**
     * Returns unmodifiable views of this model's weight vectors.
     *
     * @return a fresh array of unmodifiable weight-vector views
     */
    public Vector[] weightVectors() {
        Vector[] immutables = new Vector[this.mWeightVectors.length];
        for (int i = 0; i < immutables.length; ++i) {
            immutables[i] = Matrices.unmodifiableVector(this.mWeightVectors[i]);
        }
        return immutables;
    }

    /**
     * Returns the conditional probability of each outcome given the input.
     *
     * <p>Outcome {@code k < numOutcomes() - 1} has linear score
     * {@code x . w[k]}. The last outcome has score {@code 0}. Probabilities
     * are the softmax of these scores. The maximum score is subtracted before
     * exponentiating, so large scores do not overflow to
     * {@code Infinity / Infinity = NaN}.
     *
     * @param x the input vector
     * @return an array of length {@code numOutcomes()} holding the
     *     probabilities, indexed by outcome
     * @throws IllegalArgumentException if {@code x} does not have
     *     {@code numInputDimensions()} dimensions
     */
    public double[] classify(Vector x) {
        if (this.numInputDimensions() != x.numDimensions()) {
            String msg = "Vector and classifer must be of same dimensionality. Regression model this.numInputDimensions()=" + this.numInputDimensions() + " Vector x.numDimensions()=" + x.numDimensions();
            throw new IllegalArgumentException(msg);
        }
        int numOutcomesMinus1 = this.mWeightVectors.length;
        double[] ysHat = new double[numOutcomesMinus1 + 1];
        // Start the max at 0.0, which is the reference outcome's score.
        double maxScore = 0.0;
        for (int k = 0; k < numOutcomesMinus1; ++k) {
            ysHat[k] = x.dotProduct(this.mWeightVectors[k]);
            if (ysHat[k] > maxScore) {
                maxScore = ysHat[k];
            }
        }
        // When every score is <= 0, maxScore is 0.0 and this reproduces the
        // naive exp/normalize computation bit for bit.
        double sum = java.lang.Math.exp(-maxScore);
        ysHat[numOutcomesMinus1] = sum;
        for (int k = 0; k < numOutcomesMinus1; ++k) {
            ysHat[k] = java.lang.Math.exp(ysHat[k] - maxScore);
            sum += ysHat[k];
        }
        for (int k = 0; k <= numOutcomesMinus1; ++k) {
            ysHat[k] /= sum;
        }
        return ysHat;
    }

    /** Serializes this model through its {@link Externalizer} proxy. */
    private Object writeReplace() {
        return new Externalizer(this);
    }

    /**
     * Writes a compiled form of this model (its {@link Externalizer}) to the
     * specified output.
     *
     * @param out the object output to write to
     * @throws IOException if writing fails
     */
    public void compileTo(ObjectOutput out) throws IOException {
        out.writeObject(new Externalizer(this));
    }

    /**
     * Estimates a model by stochastic gradient descent over the training data.
     *
     * <p>Each epoch visits every instance in order and takes a gradient step on
     * the log likelihood. When the prior is informative, the weights are also
     * shrunk toward zero by the prior's gradient, and never pushed past zero.
     * With sparse inputs the prior is applied lazily: only to dimensions that
     * are non-zero in the current instance, with a catch-up pass for every
     * dimension at the end of the epoch.
     *
     * <p>After each epoch the annealing schedule is given the negated log2
     * likelihood plus log2 prior. If it rejects the update, the weights revert
     * to their values at the start of the epoch. Training stops after
     * {@code maxEpochs} epochs, or earlier once at least {@code minEpochs}
     * epochs have run and the rolling average of the relative improvement
     * falls below {@code minImprovement}.
     *
     * @param xs the training inputs; all must have the same dimensionality
     * @param cs the outcome of each input; non-negative, at least one positive.
     *     The number of outcomes is {@code max(cs) + 1}
     * @param prior the prior on the weights; {@code null} is treated as
     *     noninformative
     * @param annealingSchedule supplies per-epoch learning rates and
     *     accepts or rejects each epoch's update
     * @param minImprovement the convergence threshold on the rolling average of
     *     the relative change in log2 likelihood plus log2 prior
     * @param minEpochs the minimum number of epochs before convergence can stop
     *     training
     * @param maxEpochs the maximum number of epochs
     * @param progressWriter where progress reports are printed; may be
     *     {@code null} for silent training
     * @return the estimated model
     * @throws IllegalArgumentException if there are no instances, if
     *     {@code xs} and {@code cs} differ in length, if an outcome is
     *     negative, if every outcome is {@code 0}, or if the inputs differ in
     *     dimensionality
     */
    public static LogisticRegression estimate(Vector[] xs, int[] cs, RegressionPrior prior, AnnealingSchedule annealingSchedule, double minImprovement, int minEpochs, int maxEpochs, PrintWriter progressWriter) {
        if (xs.length < 1) {
            String msg = "Require at least one training instance.";
            throw new IllegalArgumentException(msg);
        }
        if (xs.length != cs.length) {
            String msg = "Require same number of training instances as outcomes. Found xs.length=" + xs.length + " cs.length=" + cs.length;
            throw new IllegalArgumentException(msg);
        }
        int numTrainingInstances = xs.length;
        for (int j = 0; j < numTrainingInstances; ++j) {
            if (cs[j] >= 0) continue;
            String msg = "Outcomes must be non-negative. Found cs[" + j + "]=" + cs[j];
            throw new IllegalArgumentException(msg);
        }
        int numOutcomesMinus1 = LogisticRegression.max(cs);
        if (numOutcomesMinus1 < 1) {
            String msg = "Require at least two outcomes (some cs[j] > 0). Found max outcome=" + numOutcomesMinus1;
            throw new IllegalArgumentException(msg);
        }
        int numOutcomes = numOutcomesMinus1 + 1;
        int numDimensions = xs[0].numDimensions();
        if (prior != null) {
            prior.verifyNumberOfDimensions(numDimensions);
        }
        for (int i = 1; i < xs.length; ++i) {
            if (xs[i].numDimensions() == numDimensions) continue;
            String msg = "Number of dimensions must match for all input vectors. Found xs[0].numDimensions()=" + numDimensions + " xs[" + i + "].numDimensions()=" + xs[i].numDimensions();
            throw new IllegalArgumentException(msg);
        }
        // Allocated as DenseVector[] (and replaced only by DenseVector[] copies),
        // so the (DenseVector[]) cast before copy() below always succeeds.
        Vector[] weightVectors = new DenseVector[numOutcomesMinus1];
        for (int k = 0; k < numOutcomesMinus1; ++k) {
            weightVectors[k] = new DenseVector(numDimensions);
        }
        boolean hasSparseInputs = LogisticRegression.isSparse(xs);
        boolean hasPrior = prior != null && !(prior instanceof RegressionPrior.NoninformativeRegressionPrior);
        if (progressWriter != null) {
            progressWriter.println("Logistic Regression Progress Report");
            progressWriter.println("Number of dimensions=" + numDimensions);
            progressWriter.println("Number of Outcomes=" + numOutcomes);
            progressWriter.println("Number of Parameters=" + (long)(numOutcomes - 1) * (long)numDimensions);
            progressWriter.println("Prior:\n" + prior);
            progressWriter.println("Annealing Schedule=" + annealingSchedule);
            progressWriter.println("Minimum Epochs=" + minEpochs);
            progressWriter.println("Maximum Epochs=" + maxEpochs);
            progressWriter.println("Minimum Improvement Per Period=" + minImprovement);
            progressWriter.println("Has Sparse Inputs=" + hasSparseInputs);
            progressWriter.println("Has Informative Prior=" + hasPrior);
        }
        long startTime = System.currentTimeMillis();
        // For each dimension, the step at which the prior was last applied
        // (lazy regularization). Only needed for sparse inputs with a prior.
        long[] lastRegularizations = hasSparseInputs && hasPrior ? new long[numDimensions] : null;
        // A very negative but finite start, so the first relative difference is finite.
        double lastLog2LikelihoodAndPrior = -Double.MAX_VALUE / 2.0;
        LogisticRegression regression = new LogisticRegression(weightVectors);
        double rollingAverageRelativeDiff = 1.0;
        double bestLog2LikelihoodAndPrior = Double.NEGATIVE_INFINITY;
        for (int epoch = 0; epoch < maxEpochs; ++epoch) {
            DenseVector[] weightVectorCopies = LogisticRegression.copy((DenseVector[])weightVectors);
            long step = 0L;
            if (lastRegularizations != null) {
                Arrays.fill(lastRegularizations, 0L);
            }
            double learningRate = annealingSchedule.learningRate(epoch);
            for (int j = 0; j < numTrainingInstances; ++j) {
                Vector xsJ = xs[j];
                int csJ = cs[j];
                // Probabilities come from the weights before this instance's update.
                double[] conditionalProbs = regression.classify(xsJ);
                if (hasSparseInputs) {
                    if (hasPrior) {
                        LogisticRegression.applyLazyPrior(weightVectors, xsJ.nonZeroDimensions(), prior, lastRegularizations, step, learningRate, numTrainingInstances);
                        ++step;
                    }
                    LogisticRegression.applySparseLikelihoodGradient(weightVectors, xsJ, csJ, conditionalProbs, learningRate);
                } else {
                    LogisticRegression.applyDenseGradient(weightVectors, xsJ, csJ, conditionalProbs, hasPrior, prior, learningRate, numTrainingInstances, numDimensions);
                }
            }
            // Dense inputs already got one 1/N share of the prior per instance.
            // Only the lazily regularized sparse path still owes each
            // dimension its pending steps.
            if (lastRegularizations != null) {
                LogisticRegression.flushLazyPrior(weightVectors, prior, lastRegularizations, step, learningRate, numTrainingInstances, numDimensions);
            }
            double log2Likelihood = LogisticRegression.log2Likelihood(xs, cs, regression);
            double log2Prior = prior == null ? 0.0 : prior.log2Prior(weightVectors);
            double log2LikelihoodAndPrior = log2Likelihood + log2Prior;
            if (log2LikelihoodAndPrior > bestLog2LikelihoodAndPrior) {
                bestLog2LikelihoodAndPrior = log2LikelihoodAndPrior;
            }
            boolean acceptUpdate = annealingSchedule.receivedError(epoch, learningRate, -log2LikelihoodAndPrior);
            if (!acceptUpdate) {
                // Revert to the weights from the start of this epoch.
                weightVectors = weightVectorCopies;
                regression = new LogisticRegression(weightVectors);
            }
            double relativeDiff = LogisticRegression.relativeDifference(lastLog2LikelihoodAndPrior, log2LikelihoodAndPrior);
            rollingAverageRelativeDiff = (9.0 * rollingAverageRelativeDiff + relativeDiff) / 10.0;
            lastLog2LikelihoodAndPrior = log2LikelihoodAndPrior;
            if (progressWriter != null) {
                progressWriter.printf("epoch=%5d lr=%11.9f ll=%11.4f lp=%11.4f llp=%11.4f llp*=%11.4f %9s\n", epoch, learningRate, log2Likelihood, log2Prior, log2LikelihoodAndPrior, bestLog2LikelihoodAndPrior, Strings.msToString(System.currentTimeMillis() - startTime));
            }
            // epoch + 1 epochs have completed at this point.
            if (epoch + 1 >= minEpochs && rollingAverageRelativeDiff < minImprovement) break;
        }
        return regression;
    }

    /**
     * Sparse path, per instance: for each dimension that is non-zero in the
     * current instance, applies the prior for every step since that dimension
     * was last regularized. Then records {@code step} as its last step.
     */
    private static void applyLazyPrior(Vector[] weightVectors, int[] dimensions, RegressionPrior prior, long[] lastRegularizations, long step, double learningRate, int numTrainingInstances) {
        for (int dim : dimensions) {
            for (Vector weightVectorsK : weightVectors) {
                double weightVectorsKDim = weightVectorsK.value(dim);
                double priorGrad = prior.gradient(weightVectorsKDim, dim);
                double delta = priorGrad * (learningRate * (double)(step - lastRegularizations[dim])) / (double)numTrainingInstances;
                weightVectorsK.setValue(dim, LogisticRegression.shrinkTowardZero(weightVectorsKDim, delta));
            }
            lastRegularizations[dim] = step;
        }
    }

    /**
     * Sparse path, per instance: likelihood gradient step
     * {@code w[k] -= learningRate * (p[k] - [k == c]) * x}.
     */
    private static void applySparseLikelihoodGradient(Vector[] weightVectors, Vector xsJ, int csJ, double[] conditionalProbs, double learningRate) {
        for (int k = 0; k < weightVectors.length; ++k) {
            double conditionalProbMinusTruth = conditionalProbs[k];
            if (k == csJ) {
                conditionalProbMinusTruth -= 1.0;
            }
            weightVectors[k].increment(-learningRate * conditionalProbMinusTruth, xsJ);
        }
    }

    /**
     * Dense path, per instance: for every weight, takes the likelihood gradient
     * step. When there is an informative prior, it then shrinks any non-zero
     * result by one instance's share (1/N) of the prior gradient.
     */
    private static void applyDenseGradient(Vector[] weightVectors, Vector xsJ, int csJ, double[] conditionalProbs, boolean hasPrior, RegressionPrior prior, double learningRate, int numTrainingInstances, int numDimensions) {
        for (int k = 0; k < weightVectors.length; ++k) {
            Vector weightVectorsK = weightVectors[k];
            double conditionalProbMinusTruth = conditionalProbs[k];
            if (k == csJ) {
                conditionalProbMinusTruth -= 1.0;
            }
            for (int i = 0; i < numDimensions; ++i) {
                double weightVectorsKI = weightVectorsK.value(i);
                double gradient = xsJ.value(i) * conditionalProbMinusTruth;
                // The likelihood step is unconditional. It must not be folded
                // into the hasPrior test, or models without a prior never learn.
                weightVectorsKI -= learningRate * gradient;
                if (hasPrior && weightVectorsKI != 0.0) {
                    double priorGradient = prior.gradient(weightVectorsKI, i);
                    double delta = learningRate * priorGradient / (double)numTrainingInstances;
                    weightVectorsKI = LogisticRegression.shrinkTowardZero(weightVectorsKI, delta);
                }
                weightVectorsK.setValue(i, weightVectorsKI);
            }
        }
    }

    /**
     * Sparse path, end of epoch: applies the pending prior steps
     * ({@code step - lastRegularizations[i]}) to every non-zero weight and
     * writes the result back.
     */
    private static void flushLazyPrior(Vector[] weightVectors, RegressionPrior prior, long[] lastRegularizations, long step, double learningRate, int numTrainingInstances, int numDimensions) {
        for (Vector weightVectorsK : weightVectors) {
            for (int i = 0; i < numDimensions; ++i) {
                double weightVectorsKI = weightVectorsK.value(i);
                if (weightVectorsKI == 0.0) continue;
                double priorGradient = prior.gradient(weightVectorsKI, i);
                double delta = (double)(step - lastRegularizations[i]) * learningRate * priorGradient / (double)numTrainingInstances;
                weightVectorsK.setValue(i, LogisticRegression.shrinkTowardZero(weightVectorsKI, delta));
            }
        }
    }

    /**
     * Subtracts {@code delta} from {@code value}, clipped at zero. A positive
     * value never becomes negative, and a non-positive value never becomes
     * positive.
     */
    private static double shrinkTowardZero(double value, double delta) {
        return value > 0.0 ? java.lang.Math.max(0.0, value - delta) : java.lang.Math.min(0.0, value - delta);
    }

    /**
     * Returns the log likelihood of the outcomes given the inputs under the
     * specified model. This is the sum over instances of {@code Math.log2} of
     * the probability the model assigns to that instance's outcome.
     *
     * @param inputs the input vectors
     * @param cats the outcome of each input
     * @param regression the model to evaluate
     * @return the summed log2 likelihood
     * @throws IllegalArgumentException if {@code inputs} and {@code cats} differ
     *     in length
     */
    public static double log2Likelihood(Vector[] inputs, int[] cats, LogisticRegression regression) {
        if (inputs.length != cats.length) {
            String msg = "Inputs and categories must be same length. Found inputs.length=" + inputs.length + " cats.length=" + cats.length;
            throw new IllegalArgumentException(msg);
        }
        int numTrainingInstances = inputs.length;
        double log2Likelihood = 0.0;
        for (int j = 0; j < numTrainingInstances; ++j) {
            double[] conditionalProbs = regression.classify(inputs[j]);
            log2Likelihood += Math.log2(conditionalProbs[cats[j]]);
        }
        return log2Likelihood;
    }

    /**
     * Returns true if at least half of the inputs are {@link SparseFloatVector}
     * instances. The comparison avoids the integer-division rounding that
     * made a single dense input count as sparse.
     */
    private static boolean isSparse(Vector[] xs) {
        int sparseCount = 0;
        for (Vector x : xs) {
            if (x instanceof SparseFloatVector) {
                ++sparseCount;
            }
        }
        return sparseCount >= xs.length - sparseCount;
    }

    /** Returns the largest element of a non-empty array. */
    private static int max(int[] xs) {
        int max = xs[0];
        for (int i = 1; i < xs.length; ++i) {
            if (xs[i] <= max) continue;
            max = xs[i];
        }
        return max;
    }

    /**
     * Returns {@code |x - y| / (|x| + |y|)}. Returns positive infinity if
     * either argument is infinite, and {@code 0.0} if both are zero (instead of
     * NaN, which would permanently poison the rolling average).
     */
    private static double relativeDifference(double x, double y) {
        if (Double.isInfinite(x) || Double.isInfinite(y)) {
            return Double.POSITIVE_INFINITY;
        }
        double denominator = java.lang.Math.abs(x) + java.lang.Math.abs(y);
        if (denominator == 0.0) {
            return 0.0;
        }
        return java.lang.Math.abs(x - y) / denominator;
    }

    /** Returns an array of new dense copies of the specified vectors. */
    private static DenseVector[] copy(DenseVector[] xs) {
        DenseVector[] result = new DenseVector[xs.length];
        for (int k = 0; k < xs.length; ++k) {
            result[k] = new DenseVector(xs[k]);
        }
        return result;
    }

    /**
     * Serialization proxy for {@link LogisticRegression}. It writes the number
     * of outcomes, then the number of dimensions, then every weight as a
     * double, one weight vector after another.
     */
    static class Externalizer
    extends AbstractExternalizable {
        static final long serialVersionUID = -2256261505231943102L;
        final LogisticRegression mRegression;

        /** No-argument constructor, used when reading a serialized model. */
        public Externalizer() {
            this(null);
        }

        /**
         * Constructs a proxy for writing the specified model.
         *
         * @param regression the model to write
         */
        public Externalizer(LogisticRegression regression) {
            this.mRegression = regression;
        }

        /**
         * Writes the model's outcome count, dimension count, and all weights.
         *
         * @param out the output to write to
         * @throws IOException if writing fails
         */
        public void writeExternal(ObjectOutput out) throws IOException {
            int numOutcomes = this.mRegression.mWeightVectors.length + 1;
            out.writeInt(numOutcomes);
            int numDimensions = this.mRegression.mWeightVectors[0].numDimensions();
            out.writeInt(numDimensions);
            for (int c = 0; c < numOutcomes - 1; ++c) {
                Vector vC = this.mRegression.mWeightVectors[c];
                for (int i = 0; i < numDimensions; ++i) {
                    out.writeDouble(vC.value(i));
                }
            }
        }

        /**
         * Reads a model in the format written by {@link #writeExternal}. The
         * weights are rebuilt as dense vectors.
         *
         * @param in the input to read from
         * @return the reconstructed model
         * @throws IOException if reading fails
         */
        public Object read(ObjectInput in) throws IOException {
            int numOutcomes = in.readInt();
            int numDimensions = in.readInt();
            Vector[] weightVectors = new Vector[numOutcomes - 1];
            for (int c = 0; c < weightVectors.length; ++c) {
                DenseVector weightVectorsC = new DenseVector(numDimensions);
                weightVectors[c] = weightVectorsC;
                for (int i = 0; i < numDimensions; ++i) {
                    weightVectorsC.setValue(i, in.readDouble());
                }
            }
            return new LogisticRegression(weightVectors);
        }
    }
}

