/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.stats;

import com.aliasi.stats.MultivariateConstant;
import com.aliasi.stats.MultivariateDistribution;
import com.aliasi.util.AbstractExternalizable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;

/**
 * A multivariate (categorical) distribution whose outcome probabilities are
 * estimated from training counts over string-labeled outcomes.
 *
 * <p>Each distinct outcome label receives the next consecutive index, starting
 * at 0, the first time it is trained. The probability of an outcome is its
 * count divided by the total count over all outcomes.
 *
 * <p>This class performs no synchronization.
 */
public class MultivariateEstimator
extends MultivariateDistribution
implements Serializable {
    static final long serialVersionUID = 1171641384366463097L;

    /** Maps each known outcome label to its outcome index. */
    final HashMap<String, Integer> mLabelToIndex;
    /** Outcome labels, indexed by outcome index. */
    final ArrayList<String> mIndexToLabel;
    /** Training counts, indexed by outcome index; every entry is non-negative. */
    final ArrayList<Long> mIndexToCount;
    /** Sum of all entries in {@code mIndexToCount}. */
    long mTotalCount = 0L;
    /** Index assigned to the next previously unseen outcome label. */
    int mNextIndex = 0;

    /**
     * Constructs an estimator with no outcomes and no training data.
     */
    public MultivariateEstimator() {
        this(new HashMap<String, Integer>(), new ArrayList<String>(), new ArrayList<Long>());
    }

    private MultivariateEstimator(HashMap<String, Integer> labelToIndex,
                                  ArrayList<String> indexToLabel,
                                  ArrayList<Long> indexToCount) {
        this.mLabelToIndex = labelToIndex;
        this.mIndexToLabel = indexToLabel;
        this.mIndexToCount = indexToCount;
    }

    /**
     * Throws if {@code a + b} would exceed {@link Long#MAX_VALUE}.
     *
     * <p>Only overflow past {@code Long.MAX_VALUE} is detected, so this check
     * is meaningful for non-negative {@code b}; callers in this class always
     * pass a positive {@code b}.
     *
     * @param a first addend
     * @param b second addend
     * @throws IllegalArgumentException if the sum would overflow
     */
    static void checkLongAddInRange(long a, long b) {
        if (Long.MAX_VALUE - b < a) {
            String msg = "Long addition overflow. a=" + a + " b=" + b;
            throw new IllegalArgumentException(msg);
        }
    }

    /**
     * Sets the count of a known outcome to zero, subtracting its previous
     * count from the total. The outcome keeps its index and label.
     *
     * @param outcomeLabel label of a previously trained outcome
     * @throws IllegalArgumentException if the label has never been trained
     */
    public void resetCount(String outcomeLabel) {
        Integer index = this.mLabelToIndex.get(outcomeLabel);
        if (index == null) {
            String msg = "May only reset known outcomes. Found outcome=" + outcomeLabel;
            throw new IllegalArgumentException(msg);
        }
        int i = index.intValue();
        long currentCount = this.mIndexToCount.get(i).longValue();
        this.mTotalCount -= currentCount;
        this.mIndexToCount.set(i, Long.valueOf(0L));
    }

    /**
     * Adds {@code increment} to the count of the given outcome, registering
     * the label with the next free index if it has not been seen before.
     *
     * <p>If an exception is thrown, the estimator is left unchanged.
     *
     * @param outcomeLabel outcome label to train
     * @param increment amount to add; must be at least 1
     * @throws IllegalArgumentException if {@code increment} is less than 1, or
     *     if adding it would overflow a {@code long} count
     */
    public void train(String outcomeLabel, long increment) {
        if (increment < 1L) {
            String msg = "Increment must be positive. Found increment=" + increment;
            throw new IllegalArgumentException(msg);
        }
        // Validate before mutating anything. The total is the sum of all
        // non-negative per-outcome counts, so if it cannot overflow neither can
        // any single count.
        checkLongAddInRange(this.mTotalCount, increment);
        Integer indexInteger = this.mLabelToIndex.get(outcomeLabel);
        if (indexInteger == null) {
            int index = this.mNextIndex++;
            this.mLabelToIndex.put(outcomeLabel, Integer.valueOf(index));
            this.mIndexToLabel.add(index, outcomeLabel);
            this.mIndexToCount.add(index, Long.valueOf(increment));
        } else {
            int index = indexInteger.intValue();
            // Read the full long count; narrowing to int would truncate
            // counts above Integer.MAX_VALUE.
            long currentCount = this.mIndexToCount.get(index).longValue();
            checkLongAddInRange(currentCount, increment);
            this.mIndexToCount.set(index, Long.valueOf(currentCount + increment));
        }
        this.mTotalCount += increment;
    }

    /**
     * Returns the outcome index for a label.
     *
     * @param outcomeLabel outcome label
     * @return the label's index, or {@code -1} if the label is unknown
     */
    public long outcome(String outcomeLabel) {
        Integer outcome = this.mLabelToIndex.get(outcomeLabel);
        return outcome == null ? -1L : outcome.longValue();
    }

    /**
     * Returns the label of an outcome index.
     *
     * @param outcome outcome index, in {@code [0, number of known outcomes)}
     * @return the label's string form
     * @throws IllegalArgumentException if {@code outcome} is out of range
     */
    public String label(long outcome) {
        if (outcome < 0L || outcome >= (long)this.mNextIndex) {
            String msg = "Outcome must be between 0 and max. Max outcome=" + this.maxOutcome() + " Argument outcome=" + outcome;
            throw new IllegalArgumentException(msg);
        }
        return this.mIndexToLabel.get((int)outcome).toString();
    }

    /**
     * Returns the number of distinct outcomes seen in training.
     *
     * @return number of known outcome labels
     */
    public int numDimensions() {
        return this.mIndexToLabel.size();
    }

    /**
     * Returns the count-based probability estimate of an outcome.
     *
     * <p>Outcomes outside {@code [minOutcome(), maxOutcome()]} get 0.0. For an
     * in-range outcome the result is {@code getCount(outcome) /
     * trainingSampleCount()}; if the total count is zero (for example after
     * every outcome has been reset) this evaluates to NaN.
     *
     * @param outcome outcome index
     * @return estimated probability of the outcome
     */
    public double probability(long outcome) {
        if (outcome < this.minOutcome() || outcome > this.maxOutcome()) {
            return 0.0;
        }
        return (double)this.getCount(outcome) / (double)this.trainingSampleCount();
    }

    /**
     * Returns the training count of an outcome index.
     *
     * @param outcome outcome index, validated by {@code checkOutcome}
     * @return the outcome's count, or 0 if no count is stored
     */
    public long getCount(long outcome) {
        this.checkOutcome(outcome);
        Long count = this.mIndexToCount.get((int)outcome);
        return count == null ? 0L : count.longValue();
    }

    /**
     * Returns the training count of an outcome label.
     *
     * @param outcomeLabel label of a previously trained outcome
     * @return the outcome's count
     * @throws IllegalArgumentException if the label has never been trained
     */
    public long getCount(String outcomeLabel) {
        Integer index = this.mLabelToIndex.get(outcomeLabel);
        if (index == null) {
            String msg = "May only count known outcomes by label. Found outcome=" + outcomeLabel;
            throw new IllegalArgumentException(msg);
        }
        return this.getCount(index.longValue());
    }

    /**
     * Returns the total of all outcome counts.
     *
     * @return sum of the counts of every outcome
     */
    public long trainingSampleCount() {
        return this.mTotalCount;
    }

    /**
     * Writes a compiled form of this estimator to the stream.
     *
     * <p>When read back, the compiled form is a {@code MultivariateConstant}
     * with the same labels, in index order. Each outcome's probability is fixed
     * at its count divided by the total count at compile time.
     *
     * @param objOut stream to write to
     * @throws IOException if the underlying stream fails
     */
    public void compileTo(ObjectOutput objOut) throws IOException {
        objOut.writeObject(new Externalizer(this));
    }

    /**
     * Serialization proxy that writes an estimator's labels and count ratios
     * and reads them back as a {@code MultivariateConstant}.
     */
    static class Externalizer
    extends AbstractExternalizable {
        private static final long serialVersionUID = 2913496935213914118L;
        final MultivariateEstimator mEstimator;

        /** No-arg constructor used when reading; leaves the estimator unset. */
        public Externalizer() {
            this.mEstimator = null;
        }

        /**
         * Constructs a proxy for writing the given estimator.
         *
         * @param estimator estimator to compile
         */
        public Externalizer(MultivariateEstimator estimator) {
            this.mEstimator = estimator;
        }

        /**
         * Writes the label array, then an array of count / total ratios in
         * the same index order. A zero total yields NaN ratios.
         *
         * @param out stream to write to
         * @throws IOException if the underlying stream fails
         */
        public void writeExternal(ObjectOutput out) throws IOException {
            String[] labels = this.mEstimator.mIndexToLabel.toArray(
                new String[this.mEstimator.mIndexToLabel.size()]);
            out.writeObject(labels);
            ArrayList<Long> counts = this.mEstimator.mIndexToCount;
            double totalCount = this.mEstimator.mTotalCount;
            double[] ratios = new double[counts.size()];
            for (int i = 0; i < ratios.length; ++i) {
                ratios[i] = counts.get(i).doubleValue() / totalCount;
            }
            out.writeObject(ratios);
        }

        /**
         * Reads the label and ratio arrays written by
         * {@link #writeExternal(ObjectOutput)}.
         *
         * @param in stream to read from
         * @return a {@code MultivariateConstant} over the ratios and labels
         * @throws ClassNotFoundException if a serialized class is unavailable
         * @throws IOException if the underlying stream fails
         */
        public Object read(ObjectInput in) throws ClassNotFoundException, IOException {
            String[] labels = (String[])in.readObject();
            double[] ratios = (double[])in.readObject();
            return new MultivariateConstant(ratios, labels);
        }
    }
}

