/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.stats;

import com.aliasi.matrix.Vector;
import com.aliasi.util.Math;

/**
 * A prior distribution over the coefficients of a regression model.
 *
 * <p>Coefficients are treated independently: the log prior of a coefficient
 * vector is the sum of the per-dimension log priors (see
 * {@link #log2Prior(Vector)}). Instances are obtained from the static factory
 * methods {@code noninformative}, {@code gaussian}, {@code laplace} and
 * {@code cauchy}. The constructor is private, so the nested classes below are
 * the only implementations.
 */
public abstract class RegressionPrior {
    private static final RegressionPrior NONINFORMATIVE_PRIOR = new NoninformativeRegressionPrior();

    /** The square root of two. */
    static final double sqrt2 = java.lang.Math.sqrt(2.0);

    /** {@code log2(sqrt(2) / 2)}: the constant term of the log2 Laplace density. */
    static final double log2Sqrt2Over2 = Math.log2(sqrt2 / 2.0);

    /** {@code log2(sqrt(2 * pi))}: the constant term of the log2 Gaussian density. */
    static final double log2Sqrt2Pi = Math.log2(java.lang.Math.sqrt(java.lang.Math.PI * 2));

    /** {@code log2(1 / pi)}: the constant term of the log2 Cauchy density. */
    static final double log21OverPi = -Math.log2(java.lang.Math.PI);

    /**
     * Factor converting natural-log quantities to base 2:
     * {@code log2(x) == LOG2_E * ln(x)}. The exponent terms of the Gaussian and
     * Laplace densities are natural-log quantities and must be scaled by this
     * factor before being added to base-2 normalizers.
     */
    private static final double LOG2_E = 1.0 / java.lang.Math.log(2.0);

    private RegressionPrior() {
        // Only the nested implementations in this class may extend it.
    }

    /**
     * Returns the partial derivative, with respect to {@code beta}, of the
     * negative natural-log prior density {@code -ln p(beta)} for the given
     * dimension. Every implementation in this class follows that convention.
     *
     * @param beta coefficient value
     * @param dimension index of the coefficient within its vector
     * @return the derivative of {@code -ln p(beta)}
     */
    public abstract double gradient(double beta, int dimension);

    /**
     * Returns the base-2 log prior density of a single coefficient value.
     *
     * @param beta coefficient value
     * @param dimension index of the coefficient within its vector
     * @return {@code log2 p(beta)} for that dimension
     */
    public abstract double log2Prior(double beta, int dimension);

    /**
     * Returns the base-2 log prior density of a coefficient vector. This is the
     * sum of {@link #log2Prior(double, int)} over its dimensions.
     *
     * @param beta coefficient vector
     * @return the summed log2 prior
     * @throws IllegalArgumentException if this prior has a fixed dimensionality
     *     that differs from {@code beta.numDimensions()}
     */
    public double log2Prior(Vector beta) {
        final int numDimensions = beta.numDimensions();
        verifyNumberOfDimensions(numDimensions);
        double total = 0.0;
        for (int dim = 0; dim < numDimensions; ++dim) {
            total += log2Prior(beta.value(dim), dim);
        }
        return total;
    }

    /**
     * Returns the summed base-2 log prior of several coefficient vectors.
     *
     * @param betas coefficient vectors
     * @return the sum of {@link #log2Prior(Vector)} over {@code betas}
     * @throws IllegalArgumentException if any vector has the wrong dimensionality
     */
    public double log2Prior(Vector[] betas) {
        double total = 0.0;
        for (Vector beta : betas) {
            total += log2Prior(beta);
        }
        return total;
    }

    /**
     * Checks that a vector of the given dimensionality can be scored by this
     * prior. The default accepts any dimensionality; per-dimension priors
     * override it.
     *
     * @param numDimensions dimensionality of the vector being scored
     */
    void verifyNumberOfDimensions(int numDimensions) {
        // Priors with one shared parameter apply to any dimensionality.
    }

    /**
     * Returns the flat prior. Its log2 prior and gradient are always zero.
     *
     * @return the shared noninformative prior instance
     */
    public static RegressionPrior noninformative() {
        return NONINFORMATIVE_PRIOR;
    }

    /**
     * Returns a zero-mean Gaussian prior with the same variance in every dimension.
     *
     * @param priorVariance variance of each coefficient
     * @param noninformativeIntercept if true, dimension 0 gets a flat prior
     *     (zero log2 prior and zero gradient)
     * @return the prior
     * @throws IllegalArgumentException if {@code priorVariance} is negative or NaN
     */
    public static RegressionPrior gaussian(double priorVariance, boolean noninformativeIntercept) {
        RegressionPrior.verifyPriorVariance(priorVariance);
        return new VariableGaussianRegressionPrior(priorVariance, noninformativeIntercept);
    }

    /**
     * Returns a zero-mean Gaussian prior with a separate variance per dimension.
     * The array is copied.
     *
     * @param priorVariances variance of each coefficient, indexed by dimension
     * @return the prior
     * @throws IllegalArgumentException if any variance is negative or NaN
     */
    public static RegressionPrior gaussian(double[] priorVariances) {
        RegressionPrior.verifyPriorVariances(priorVariances);
        return new GaussianRegressionPrior(priorVariances);
    }

    /**
     * Returns a zero-mean Laplace prior with the same variance in every dimension.
     *
     * @param priorVariance variance of each coefficient
     * @param noninformativeIntercept if true, dimension 0 gets a flat prior
     * @return the prior
     * @throws IllegalArgumentException if {@code priorVariance} is negative or NaN
     */
    public static RegressionPrior laplace(double priorVariance, boolean noninformativeIntercept) {
        RegressionPrior.verifyPriorVariance(priorVariance);
        return new VariableLaplaceRegressionPrior(priorVariance, noninformativeIntercept);
    }

    /**
     * Returns a zero-mean Laplace prior with a separate variance per dimension.
     * The array is copied.
     *
     * @param priorVariances variance of each coefficient, indexed by dimension
     * @return the prior
     * @throws IllegalArgumentException if any variance is negative or NaN
     */
    public static RegressionPrior laplace(double[] priorVariances) {
        RegressionPrior.verifyPriorVariances(priorVariances);
        return new LaplaceRegressionPrior(priorVariances);
    }

    /**
     * Returns a zero-centered Cauchy prior with the same squared scale in every dimension.
     *
     * @param priorSquaredScale square of the Cauchy scale parameter
     * @param noninformativeIntercept if true, dimension 0 gets a flat prior
     * @return the prior
     * @throws IllegalArgumentException if {@code priorSquaredScale} is negative or NaN
     */
    public static RegressionPrior cauchy(double priorSquaredScale, boolean noninformativeIntercept) {
        RegressionPrior.verifyPriorVariance(priorSquaredScale);
        return new VariableCauchyRegressionPrior(priorSquaredScale, noninformativeIntercept);
    }

    /**
     * Returns a zero-centered Cauchy prior with a separate squared scale per
     * dimension. The array is copied.
     *
     * @param priorSquaredScales squared Cauchy scale of each coefficient
     * @return the prior
     * @throws IllegalArgumentException if any squared scale is negative or NaN
     */
    public static RegressionPrior cauchy(double[] priorSquaredScales) {
        RegressionPrior.verifyPriorVariances(priorSquaredScales);
        return new CauchyRegressionPrior(priorSquaredScales);
    }

    /**
     * Rejects negative and NaN prior parameters. Zero and positive infinity are accepted.
     *
     * @param priorVariance value to check
     * @throws IllegalArgumentException if the value is negative or NaN
     */
    static void verifyPriorVariance(double priorVariance) {
        // NaN fails every comparison, so it needs its own test; negative
        // infinity is already caught by "< 0.0".
        if (priorVariance < 0.0 || Double.isNaN(priorVariance)) {
            String msg = "Prior variance must be a non-negative number. Found priorVariance=" + priorVariance;
            throw new IllegalArgumentException(msg);
        }
    }

    /**
     * Element-wise version of {@link #verifyPriorVariance(double)}.
     *
     * @param priorVariances values to check
     * @throws IllegalArgumentException naming the first negative or NaN entry
     */
    static void verifyPriorVariances(double[] priorVariances) {
        for (int i = 0; i < priorVariances.length; ++i) {
            double value = priorVariances[i];
            if (value < 0.0 || Double.isNaN(value)) {
                String msg = "Prior variances must be non-negative numbers. Found priorVariances[" + i + "]=" + value;
                throw new IllegalArgumentException(msg);
            }
        }
    }

    // ------------------------------------------------------------------
    // Shared density formulas. Both the single-parameter and per-dimension
    // priors call these, so the two variants cannot drift apart.
    // ------------------------------------------------------------------

    /** Derivative of {@code -ln N(beta | 0, variance)}. */
    private static double gaussianGradient(double beta, double variance) {
        return beta / variance;
    }

    /** {@code log2 N(beta | 0, variance)}. */
    private static double gaussianLog2Density(double beta, double variance) {
        return -log2Sqrt2Pi
            - 0.5 * Math.log2(variance)
            - LOG2_E * beta * beta / (2.0 * variance);
    }

    /** Derivative of {@code -ln Laplace(beta | 0, variance)}; zero is used at the kink {@code beta == 0}. */
    private static double laplaceGradient(double beta, double variance) {
        if (beta == 0.0) {
            return 0.0;
        }
        double magnitude = java.lang.Math.sqrt(2.0 / variance);
        return beta > 0.0 ? magnitude : -magnitude;
    }

    /** {@code log2} of the zero-mean Laplace density with the given variance (scale {@code sqrt(variance / 2)}). */
    private static double laplaceLog2Density(double beta, double variance) {
        return log2Sqrt2Over2
            - 0.5 * Math.log2(variance)
            - LOG2_E * sqrt2 * java.lang.Math.abs(beta) / java.lang.Math.sqrt(variance);
    }

    /** Derivative of {@code -ln Cauchy(beta | 0, scale)}, given {@code squaredScale = scale^2}. */
    private static double cauchyGradient(double beta, double squaredScale) {
        return 2.0 * beta / (beta * beta + squaredScale);
    }

    /** {@code log2(scale / (pi * (beta^2 + scale^2)))}, given {@code squaredScale = scale^2}. */
    private static double cauchyLog2Density(double beta, double squaredScale) {
        return log21OverPi
            + 0.5 * Math.log2(squaredScale)
            - Math.log2(beta * beta + squaredScale);
    }

    /** Cauchy prior sharing one squared scale across all dimensions. */
    static class VariableCauchyRegressionPrior
    extends VariableRegressionPrior {
        VariableCauchyRegressionPrior(double priorVariance, boolean noninformativeIntercept) {
            super(priorVariance, noninformativeIntercept);
        }

        @Override
        public double gradient(double beta, int dimension) {
            return isFlatDimension(dimension) ? 0.0 : cauchyGradient(beta, this.mPriorVariance);
        }

        @Override
        public double log2Prior(double beta, int dimension) {
            return isFlatDimension(dimension) ? 0.0 : cauchyLog2Density(beta, this.mPriorVariance);
        }

        @Override
        public String toString() {
            // NOTE(review): the value printed under "Scale" is the squared scale.
            return this.toString("CauchyRegressionPrior", "Scale");
        }
    }

    /** Laplace prior sharing one variance across all dimensions. */
    static class VariableLaplaceRegressionPrior
    extends VariableRegressionPrior {
        VariableLaplaceRegressionPrior(double priorVariance, boolean noninformativeIntercept) {
            super(priorVariance, noninformativeIntercept);
        }

        @Override
        public double gradient(double beta, int dimension) {
            return isFlatDimension(dimension) ? 0.0 : laplaceGradient(beta, this.mPriorVariance);
        }

        @Override
        public double log2Prior(double beta, int dimension) {
            return isFlatDimension(dimension) ? 0.0 : laplaceLog2Density(beta, this.mPriorVariance);
        }

        @Override
        public String toString() {
            return this.toString("LaplaceRegressionPrior", "Variance");
        }
    }

    /** Gaussian prior sharing one variance across all dimensions. */
    static class VariableGaussianRegressionPrior
    extends VariableRegressionPrior {
        VariableGaussianRegressionPrior(double priorVariance, boolean noninformativeIntercept) {
            super(priorVariance, noninformativeIntercept);
        }

        @Override
        public double gradient(double beta, int dimension) {
            return isFlatDimension(dimension) ? 0.0 : gaussianGradient(beta, this.mPriorVariance);
        }

        @Override
        public double log2Prior(double beta, int dimension) {
            return isFlatDimension(dimension) ? 0.0 : gaussianLog2Density(beta, this.mPriorVariance);
        }

        @Override
        public String toString() {
            return this.toString("GaussianRegressionPrior", "Variance");
        }
    }

    /** Base for priors with a single shared parameter and an optional flat intercept. */
    static abstract class VariableRegressionPrior
    extends RegressionPrior {
        final double mPriorVariance;
        final boolean mNoninformativeIntercept;

        VariableRegressionPrior(double priorVariance, boolean noninformativeIntercept) {
            this.mPriorVariance = priorVariance;
            this.mNoninformativeIntercept = noninformativeIntercept;
        }

        /** True when {@code dimension} is the intercept (index 0) and it was requested to be noninformative. */
        final boolean isFlatDimension(int dimension) {
            return dimension == 0 && this.mNoninformativeIntercept;
        }

        /**
         * Formats this prior as {@code priorName(paramName=value, noninformativeIntercept=flag)}.
         *
         * @param priorName name printed before the parenthesis
         * @param paramName label for the shared parameter
         * @return the formatted description
         */
        public String toString(String priorName, String paramName) {
            return priorName + "(" + paramName + "=" + this.mPriorVariance + ", noninformativeIntercept=" + this.mNoninformativeIntercept + ")";
        }
    }

    /** Cauchy prior with a squared scale per dimension. */
    static class CauchyRegressionPrior
    extends ArrayRegressionPrior {
        CauchyRegressionPrior(double[] priorSquaredScales) {
            super(priorSquaredScales);
        }

        @Override
        public double gradient(double beta, int dimension) {
            return cauchyGradient(beta, this.mValues[dimension]);
        }

        @Override
        public double log2Prior(double beta, int dimension) {
            // mValues holds squared scales; they must not be squared again.
            return cauchyLog2Density(beta, this.mValues[dimension]);
        }

        @Override
        public String toString() {
            // NOTE(review): the values printed under "Scale" are squared scales.
            return this.toString("CauchyRegressionPrior", "Scale");
        }
    }

    /** Laplace prior with a variance per dimension. */
    static class LaplaceRegressionPrior
    extends ArrayRegressionPrior {
        LaplaceRegressionPrior(double[] priorVariances) {
            super(priorVariances);
        }

        @Override
        public double gradient(double beta, int dimension) {
            return laplaceGradient(beta, this.mValues[dimension]);
        }

        @Override
        public double log2Prior(double beta, int dimension) {
            return laplaceLog2Density(beta, this.mValues[dimension]);
        }

        @Override
        public String toString() {
            return this.toString("LaplaceRegressionPrior", "Variance");
        }
    }

    /** Gaussian prior with a variance per dimension. */
    static class GaussianRegressionPrior
    extends ArrayRegressionPrior {
        GaussianRegressionPrior(double[] priorVariances) {
            super(priorVariances);
        }

        @Override
        public double gradient(double beta, int dimension) {
            return gaussianGradient(beta, this.mValues[dimension]);
        }

        @Override
        public double log2Prior(double beta, int dimension) {
            return gaussianLog2Density(beta, this.mValues[dimension]);
        }

        @Override
        public String toString() {
            return this.toString("GaussianRegressionPrior", "Variance");
        }
    }

    /** Base for priors whose parameter is given separately for each dimension. */
    static abstract class ArrayRegressionPrior
    extends RegressionPrior {
        final double[] mValues;

        ArrayRegressionPrior(double[] values) {
            // Defensive copy: the factories validate the caller's array, and
            // later mutation by the caller must not change this prior.
            this.mValues = values.clone();
        }

        @Override
        void verifyNumberOfDimensions(int numDimensions) {
            if (this.mValues.length != numDimensions) {
                String msg = "Prior and instances must match in number of dimensions. Found prior numDimensions=" + this.mValues.length + " instance numDimensions=" + numDimensions;
                throw new IllegalArgumentException(msg);
            }
        }

        /**
         * Formats this prior as a header line, a dimensionality line, and one
         * {@code paramName[i]=value} line per dimension.
         *
         * @param priorName name printed on the first line
         * @param paramName label for each per-dimension parameter
         * @return the formatted description
         */
        public String toString(String priorName, String paramName) {
            StringBuilder sb = new StringBuilder();
            sb.append(priorName).append('\n');
            sb.append("     dimensionality=").append(this.mValues.length).append('\n');
            for (int i = 0; i < this.mValues.length; ++i) {
                sb.append("     ").append(paramName).append('[').append(i).append("]=").append(this.mValues[i]).append('\n');
            }
            return sb.toString();
        }
    }

    /** Flat prior: zero log2 prior and zero gradient everywhere, for any dimensionality. */
    static class NoninformativeRegressionPrior
    extends RegressionPrior {
        NoninformativeRegressionPrior() {
        }

        @Override
        public double gradient(double beta, int dimension) {
            return 0.0;
        }

        @Override
        public double log2Prior(double beta, int dimension) {
            return 0.0;
        }

        @Override
        public double log2Prior(Vector beta) {
            return 0.0;
        }

        @Override
        public double log2Prior(Vector[] betas) {
            return 0.0;
        }

        @Override
        public String toString() {
            return "NoninformativeRegressionPrior";
        }
    }
}

