/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.classify;

import com.aliasi.classify.Classification;
import com.aliasi.classify.Classifier;
import com.aliasi.classify.JointClassification;
import com.aliasi.corpus.ClassificationHandler;
import com.aliasi.stats.MultivariateEstimator;
import com.aliasi.util.FeatureExtractor;
import com.aliasi.util.Math;
import com.aliasi.util.ObjectToCounterMap;
import com.aliasi.util.ObjectToDoubleMap;
import com.aliasi.util.ScoredObject;
import java.io.Serializable;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * A Bernoulli (binary-feature) naive Bayes classifier.
 *
 * <p>Training instances are supplied through {@link #handle(Object, Classification)}.
 * Each input is reduced to its set of <i>active</i> features: the features returned
 * by the feature extractor whose value is strictly greater than the activation
 * threshold. For every category the classifier counts the training instances of
 * that category and, per feature, the number of those instances in which the
 * feature was active.
 *
 * <p>{@link #classify(Object)} scores each category by the base-2 log of the category
 * probability reported by the underlying {@link MultivariateEstimator}, plus, for
 * every feature observed during training, the base-2 log of the add-one (Laplace)
 * smoothed estimate of that feature being active (if it is active in the input)
 * or inactive (if it is not). Active input features that were never observed in
 * training are ignored.
 *
 * @param <E> the type of objects being classified
 */
public class BernoulliClassifier<E>
implements Classifier<E, JointClassification>,
ClassificationHandler<E, Classification>,
Serializable {
    static final long serialVersionUID = -7761909693358968780L;

    // Category label <-> dimension index, with per-category training counts.
    private final MultivariateEstimator mCategoryDistribution = new MultivariateEstimator();

    // Maps an input to its feature values. NOTE(review): kept as a raw type to
    // preserve the constructor signature; it must be Serializable for this
    // classifier to serialize successfully — confirm with callers.
    private final FeatureExtractor mFeatureExtractor;

    // category -> (feature -> number of training instances of the category in
    // which the feature was active).
    private final Map<String, ObjectToCounterMap<String>> mFeatureDistributionMap
        = new HashMap<String, ObjectToCounterMap<String>>();

    // Every feature that was active in at least one training instance.
    private final Set<String> mFeatureSet = new HashSet<String>();

    // A feature is active when its value is strictly greater than this.
    private final double mActivationThreshold;

    /**
     * Constructs a classifier with an activation threshold of {@code 0.0}, so a
     * feature is active when its extracted value is strictly positive.
     *
     * @param featureExtractor the extractor that maps inputs to feature values
     */
    public BernoulliClassifier(FeatureExtractor featureExtractor) {
        this(featureExtractor, 0.0);
    }

    /**
     * Constructs a classifier with the given feature extractor and activation
     * threshold.
     *
     * @param featureExtractor the extractor that maps inputs to feature values
     * @param featureActivationThreshold a feature is considered active for an
     *     input when its extracted value is strictly greater than this threshold
     */
    public BernoulliClassifier(FeatureExtractor featureExtractor, double featureActivationThreshold) {
        mFeatureExtractor = featureExtractor;
        mActivationThreshold = featureActivationThreshold;
    }

    /**
     * Returns the category labels seen so far, in the order of the underlying
     * estimator's dimensions.
     *
     * @return a newly allocated array of category labels
     */
    public String[] categories() {
        int numCategories = mCategoryDistribution.numDimensions();
        String[] categories = new String[numCategories];
        for (int i = 0; i < numCategories; ++i) {
            categories[i] = mCategoryDistribution.label(i);
        }
        return categories;
    }

    /**
     * Trains the classifier on one instance, using the classification's best
     * category as the instance's label.
     *
     * @param input the training input
     * @param classification the classification whose best category labels the input
     */
    @Override
    public void handle(E input, Classification classification) {
        String category = classification.bestCategory();
        mCategoryDistribution.train(category, 1L);
        ObjectToCounterMap<String> categoryCounter = mFeatureDistributionMap.get(category);
        if (categoryCounter == null) {
            categoryCounter = new ObjectToCounterMap<String>();
            mFeatureDistributionMap.put(category, categoryCounter);
        }
        for (String feature : activeFeatureSet(input)) {
            categoryCounter.increment(feature);
            mFeatureSet.add(feature);
        }
    }

    /**
     * Classifies the input, returning a joint classification whose scores are
     * base-2 log joint probabilities of category and input features.
     *
     * <p>For a category with {@code n} training instances, a training feature
     * active in {@code c} of them contributes {@code log2((c + 1) / (n + 2))} if
     * it is active in the input and {@code log2((n - c + 1) / (n + 2))} otherwise.
     * Active features never seen in training contribute nothing.
     *
     * @param input the object to classify
     * @return the joint classification over all trained categories
     */
    @Override
    public JointClassification classify(E input) {
        Set<String> activeFeatureSet = activeFeatureSet(input);

        // Only features seen in training carry evidence. A known feature must be
        // scored for every category, including those in which its count is 0;
        // skipping zero counts would reward categories that never saw it.
        Set<String> knownActiveFeatureSet = new HashSet<String>(activeFeatureSet);
        knownActiveFeatureSet.retainAll(mFeatureSet);

        Set<String> inactiveFeatureSet = new HashSet<String>(mFeatureSet);
        inactiveFeatureSet.removeAll(activeFeatureSet);

        int numCategories = mCategoryDistribution.numDimensions();
        ObjectToDoubleMap<String> categoryToLog2P = new ObjectToDoubleMap<String>();
        for (long i = 0L; i < numCategories; ++i) {
            String category = mCategoryDistribution.label(i);
            double categoryCount = mCategoryDistribution.getCount(i);
            // Every category label was registered by handle(), which also creates
            // its feature counter, so this lookup is non-null.
            ObjectToCounterMap<String> categoryFeatureCounts = mFeatureDistributionMap.get(category);
            // Add-one smoothing over the two outcomes (active / inactive).
            double smoothedTotal = categoryCount + 2.0;

            double log2P = Math.log2(mCategoryDistribution.probability(i));
            for (String activeFeature : knownActiveFeatureSet) {
                double featureCount = categoryFeatureCounts.getCount(activeFeature);
                log2P += Math.log2((featureCount + 1.0) / smoothedTotal);
            }
            for (String inactiveFeature : inactiveFeatureSet) {
                double notFeatureCount = categoryCount - categoryFeatureCounts.getCount(inactiveFeature);
                log2P += Math.log2((notFeatureCount + 1.0) / smoothedTotal);
            }
            categoryToLog2P.set(category, log2P);
        }

        List<ScoredObject<String>> scoredCategories = categoryToLog2P.scoredObjectsOrderedByValueList();
        String[] categories = new String[numCategories];
        double[] log2Ps = new double[numCategories];
        for (int i = 0; i < numCategories; ++i) {
            ScoredObject<String> scoredCategory = scoredCategories.get(i);
            categories[i] = scoredCategory.getObject();
            log2Ps[i] = scoredCategory.score();
        }
        return new JointClassification(categories, log2Ps);
    }

    /**
     * Returns the features of the input whose extracted value is strictly
     * greater than the activation threshold.
     *
     * @param input the input whose features are extracted
     * @return a new set of active feature names
     */
    @SuppressWarnings("unchecked") // raw FeatureExtractor returns a raw Map
    private Set<String> activeFeatureSet(E input) {
        Set<String> activeFeatureSet = new HashSet<String>();
        Map<String, Number> featureMap = mFeatureExtractor.features(input);
        for (Map.Entry<String, Number> entry : featureMap.entrySet()) {
            if (entry.getValue().doubleValue() > mActivationThreshold) {
                activeFeatureSet.add(entry.getKey());
            }
        }
        return activeFeatureSet;
    }
}

