/*
 * Decompiled with CFR 0.152.
 */
package ob.ml.classifier;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeSet;
import ob.core.Feature;
import ob.ml.classifier.DataPoint;
import ob.ml.classifier.DataPointImpl;
import ob.ml.classifier.DataSet;
import ob.ml.classifier.DataSetPair;

/**
 * In-memory {@link DataSet} that keeps its data points in insertion order.
 *
 * <p>The class feature, the sparse flag and the initial feature set are taken
 * from the first point passed to {@link #add(DataPoint)}; until then
 * {@link #getFeatures()} and {@link #getClassFeature()} return {@code null}.
 */
public class DataSetImpl
implements DataSet {
    /** Data points in the order they were added. */
    private final Collection<DataPoint> _dataPoints = new ArrayList<DataPoint>();
    /** Features of this set (a {@link TreeSet}); {@code null} until the first point is added. */
    private Collection<Feature> _features;
    /** Class feature of the first added point; {@code null} until then. */
    private Feature _classFeature;
    /** Sparse flag of the first added point. */
    private boolean _sparse;

    /**
     * Returns the backing collection of data points (not a copy).
     */
    @Override
    public Collection<DataPoint> getDataPoints() {
        return this._dataPoints;
    }

    /**
     * Appends a data point.
     *
     * <p>The first point added fixes the class feature and the sparse flag and
     * seeds the feature set with a copy of its features. For later points the
     * features are merged into the feature set only when the set is sparse;
     * otherwise the feature set is left untouched.
     *
     * @param dataPoint the point to add
     */
    @Override
    public void add(DataPoint dataPoint) {
        if (this._features == null) {
            this._features = new TreeSet<Feature>(dataPoint.getFeatures());
            this._classFeature = dataPoint.getClassFeature();
            this._sparse = dataPoint.isSparse();
        } else if (this._sparse) {
            this._features.addAll(dataPoint.getFeatures());
        }
        this._dataPoints.add(dataPoint);
    }

    /**
     * Returns the number of data points in this set.
     */
    @Override
    public int size() {
        return this._dataPoints.size();
    }

    /**
     * Returns the class feature, or {@code null} if no point has been added yet.
     */
    @Override
    public Feature getClassFeature() {
        return this._classFeature;
    }

    /**
     * Returns the feature set, or {@code null} if no point has been added yet.
     */
    @Override
    public Collection<Feature> getFeatures() {
        return this._features;
    }

    /**
     * Splits this data set into {@code folds} train/test pairs.
     *
     * <p>Points are visited in the iteration order of the backing collection and
     * each one goes to the first fold that still has room for it. A fold holds at
     * most {@code ceil(size / folds)} points and, when {@code stratified}, at most
     * {@code ceil(classCount / folds)} points of any single class value. If no
     * fold has room (the two caps can conflict when stratifying), the point goes
     * to the currently smallest fold so that it is never left out of every test
     * set.
     *
     * <p>Because all folds share the same cap and are filled first-fit, trailing
     * folds may end up smaller than the others, or empty, when the number of
     * points is not a multiple of {@code folds}.
     *
     * <p>For each fold the test set is the fold itself, and the training set is
     * every point of this data set for which {@code fold.contains(point)} is
     * false (membership is therefore decided by {@code equals}).
     *
     * @param folds      number of folds; must be at least 1
     * @param stratified whether to also cap each class value per fold
     * @return one {@link DataSetPair} per fold, in fold order
     * @throws RuntimeException         if {@code stratified} is set and the class
     *                                  feature is not nominal
     * @throws IllegalArgumentException if {@code folds} is less than 1
     * @throws NullPointerException     if no point has been added yet, because the
     *                                  class feature is still {@code null}
     */
    @Override
    public Collection<DataSetPair> getCrossValidationDataSets(int folds, boolean stratified) {
        if (stratified && !this._classFeature.isNominal()) {
            throw new RuntimeException("cannot stratify data sets with non-nominal class features");
        }
        if (folds < 1) {
            throw new IllegalArgumentException("folds must be at least 1, got " + folds);
        }
        List<List<DataPoint>> testSets = new ArrayList<List<DataPoint>>(folds);
        for (int i = 0; i < folds; ++i) {
            testSets.add(new ArrayList<DataPoint>());
        }
        Map<String, Integer> classDistribution = this.getClassDistribution();
        for (DataPoint point : this._dataPoints) {
            List<DataPoint> target = null;
            for (List<DataPoint> testSet : testSets) {
                if (this.isGoodSet(testSet, point, classDistribution, stratified, folds)) {
                    target = testSet;
                    break;
                }
            }
            if (target == null) {
                // Only reachable when stratifying: every fold is either full or
                // already holds its share of this class. The smallest fold is
                // always below the size cap, so the point is kept without
                // overfilling any fold.
                target = this.getSmallestSet(testSets);
            }
            target.add(point);
        }
        Collection<DataSetPair> pairs = new ArrayList<DataSetPair>(folds);
        for (List<DataPoint> testSet : testSets) {
            DataSetImpl testData = new DataSetImpl();
            for (DataPoint point : testSet) {
                testData.add(point);
            }
            DataSetImpl trainData = new DataSetImpl();
            for (DataPoint point : this._dataPoints) {
                if (testSet.contains(point)) {
                    continue;
                }
                trainData.add(point);
            }
            pairs.add(new DataSetPair(trainData, testData));
        }
        return pairs;
    }

    /**
     * Returns the first fold with the fewest points.
     *
     * @param testSets the folds; must not be empty
     */
    private List<DataPoint> getSmallestSet(List<List<DataPoint>> testSets) {
        List<DataPoint> smallest = testSets.get(0);
        for (List<DataPoint> candidate : testSets) {
            if (candidate.size() < smallest.size()) {
                smallest = candidate;
            }
        }
        return smallest;
    }

    /**
     * Tells whether {@code point} may still be placed in {@code testSet}.
     *
     * @param testSet           the fold under consideration
     * @param point             the point to place
     * @param classDistribution class value to count for the whole data set, or
     *                          {@code null} when the class feature is not nominal
     * @param stratified        whether the per-class cap applies
     * @param folds             total number of folds
     * @return {@code true} if the fold is below its size cap and, when
     *         stratifying, below its cap for the point's class value
     */
    private boolean isGoodSet(Collection<DataPoint> testSet, DataPoint point, Map<String, Integer> classDistribution, boolean stratified, int folds) {
        // ">=" (not ">"): a fold that already holds the maximum must reject the
        // point, otherwise every fold grows to max + 1 and later folds stay empty.
        if (testSet.size() >= this.getMaxCountForFold(this._dataPoints.size(), folds)) {
            return false;
        }
        if (!stratified || classDistribution == null) {
            return true;
        }
        String value = point.getClassValue().getString();
        int maxCount = this.getMaxCountForFold(classDistribution.get(value), folds);
        return this.getCount(value, testSet) < maxCount;
    }

    /**
     * Returns {@code ceil(original / folds)}, the cap applied to a single fold.
     */
    private int getMaxCountForFold(int original, int folds) {
        return (int)Math.ceil((double)original / (double)folds);
    }

    /**
     * Counts the points in {@code points} whose class value string equals {@code value}.
     */
    private int getCount(String value, Collection<DataPoint> points) {
        int count = 0;
        for (DataPoint point : points) {
            if (value.equals(point.getClassValue().getString())) {
                ++count;
            }
        }
        return count;
    }

    /**
     * Counts the data points per class value string.
     *
     * @return class value to number of points, or {@code null} when the class
     *         feature is not nominal
     */
    private Map<String, Integer> getClassDistribution() {
        if (!this._classFeature.isNominal()) {
            return null;
        }
        Map<String, Integer> distribution = new HashMap<String, Integer>();
        for (DataPoint point : this._dataPoints) {
            String nominal = point.getClassValue().getString();
            Integer count = distribution.get(nominal);
            distribution.put(nominal, count == null ? 1 : count + 1);
        }
        return distribution;
    }

    /**
     * Re-weights every data point according to its class value.
     *
     * <p>For a nominal class feature the weight is {@code 1 / ratio}, where
     * {@code ratio} is the fraction of points sharing the point's class value.
     * For a numeric class feature the weight is the absolute distance between
     * the point's class value and the mean class value.
     *
     * <p>Each point is cast to {@link DataPointImpl} to set the weight, so a
     * point of any other implementation causes a {@link ClassCastException}.
     *
     * @throws UnsupportedOperationException if the class feature is neither
     *                                       nominal nor numeric
     * @throws NullPointerException          if no point has been added yet,
     *                                       because the class feature is still
     *                                       {@code null}
     */
    @Override
    public void pseudoBalance() {
        if (this._classFeature.isNominal()) {
            Map<String, Integer> distribution = this.getClassDistribution();
            Map<String, Double> ratios = new HashMap<String, Double>();
            for (Map.Entry<String, Integer> entry : distribution.entrySet()) {
                double ratio = (double)entry.getValue().intValue() / (double)this.size();
                ratios.put(entry.getKey(), ratio);
            }
            for (DataPoint dataPoint : this.getDataPoints()) {
                String nominal = dataPoint.getClassValue().getString();
                double weight = 1.0 / ratios.get(nominal);
                ((DataPointImpl)dataPoint).setWeight(weight);
            }
        } else if (this._classFeature.isNumeric()) {
            double mean = this.getClassMean();
            for (DataPoint dataPoint : this.getDataPoints()) {
                double value = dataPoint.getClassValue().getNumber().doubleValue();
                double weight = Math.abs(value - mean);
                ((DataPointImpl)dataPoint).setWeight(weight);
            }
        } else {
            throw new UnsupportedOperationException("Cannot pseudo-balance data sets with non-nominal and non-numeric class features");
        }
    }

    /**
     * Returns the arithmetic mean of the numeric class values of all data points.
     */
    private double getClassMean() {
        double sum = 0.0;
        for (DataPoint dataPoint : this.getDataPoints()) {
            sum += dataPoint.getClassValue().getNumber().doubleValue();
        }
        return sum / (double)this.getDataPoints().size();
    }
}

