/*
 * Decompiled with CFR 0.152.
 */
package ob.ml.classifier.weka;

import java.util.Collection;
import ob.core.CentralFactory;
import ob.core.Feature;
import ob.core.FeatureType;
import ob.core.NumericFeature;
import ob.core.StringFeature;
import ob.core.Value;
import ob.ml.classifier.Classifier;
import ob.ml.classifier.DataPoint;
import ob.ml.classifier.DataSet;
import ob.ml.util.WekaUtils;
import weka.core.Attribute;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.SerializationHelper;
import weka.core.SparseInstance;

/**
 * Adapter that exposes a Weka {@link weka.classifiers.Classifier} through the project's
 * {@link Classifier} interface.
 *
 * <p>Each {@link DataPoint} is converted to a Weka {@link Instance} whose attribute 0 is the class
 * attribute, followed by one attribute per non-class feature of the schema. The schema is
 * recorded by {@link #train(DataSet)} or by one of the {@code setFeatures} overloads. A model
 * restored from disk with {@link #WekaClassifier(String)} has no schema until
 * {@code setFeatures} is called.
 */
public class WekaClassifier implements Classifier {
    /** The wrapped Weka model. */
    private weka.classifiers.Classifier _classifier;
    /** Non-class features; their iteration order defines Weka attribute indices 1..n. */
    private Collection<Feature> _features;
    /** The feature being predicted; always mapped to Weka attribute index 0. */
    private Feature _classFeature;

    /**
     * Wraps an existing (trained or untrained) Weka classifier.
     *
     * @param classifier the Weka classifier to delegate to
     */
    public WekaClassifier(weka.classifiers.Classifier classifier) {
        this._classifier = classifier;
    }

    /**
     * Loads a Weka classifier previously written with {@link #save(String)}.
     *
     * <p>Only the Weka model is persisted, not the feature schema. Call one of the
     * {@code setFeatures} overloads before classifying with a loaded model.
     * NOTE(review): {@code SerializationHelper.read} uses Java deserialization; only load files
     * from trusted sources.
     *
     * @param file path of the serialized model
     * @throws RuntimeException wrapping any error raised while reading the file
     */
    public WekaClassifier(String file) {
        try {
            this._classifier = (weka.classifiers.Classifier) SerializationHelper.read(file);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Predicts the class value of {@code dataPoint}.
     *
     * @param dataPoint the point to classify; its class value may be {@code null} (unlabeled)
     * @return the predicted value, or {@code null} when Weka reports a missing prediction
     * @throws IllegalStateException if no feature schema has been set
     * @throws RuntimeException wrapping any exception thrown by the Weka classifier, or if the
     *     class attribute is neither numeric nor nominal/string
     */
    @Override
    public Value classify(DataPoint dataPoint) {
        Instances header = this.makeInstances(1);
        Instance instance = this.convertToInstance(dataPoint, header);
        double result = this.classify(instance);
        if (Instance.isMissingValue(result)) {
            return null;
        }
        Attribute classAttribute = header.classAttribute();
        if (classAttribute.isNumeric()) {
            return CentralFactory.getNumericValue(result);
        }
        if (classAttribute.isString() || classAttribute.isNominal()) {
            // For nominal/string class attributes Weka returns the index of the predicted label.
            return CentralFactory.getStringValue(classAttribute.value((int) result));
        }
        throw new RuntimeException("unhandled class attribute type");
    }

    /**
     * Returns the class distribution that the Weka model computes for {@code dataPoint}.
     *
     * @param dataPoint the point to score; its class value may be {@code null}
     * @return the array produced by {@code distributionForInstance}
     * @throws IllegalStateException if no feature schema has been set
     * @throws RuntimeException wrapping any exception thrown by the Weka classifier
     */
    @Override
    public double[] distribution(DataPoint dataPoint) {
        Instances header = this.makeInstances(1);
        return this.distribution(this.convertToInstance(dataPoint, header));
    }

    /** Delegates to Weka's {@code distributionForInstance}, rethrowing failures unchecked. */
    private double[] distribution(Instance instance) {
        try {
            return this._classifier.distributionForInstance(instance);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /** Delegates to Weka's {@code classifyInstance}, rethrowing failures unchecked. */
    private double classify(Instance instance) {
        try {
            return this._classifier.classifyInstance(instance);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Records the schema of {@code dataSet}, converts all of its points and builds the Weka model.
     *
     * @param dataSet the training data
     * @throws RuntimeException wrapping any exception thrown by {@code buildClassifier}
     */
    @Override
    public void train(DataSet dataSet) {
        this._classFeature = dataSet.getClassFeature();
        this._features = dataSet.getFeatures();
        Instances instances = this.makeInstances(dataSet.size());
        for (DataPoint dataPoint : dataSet.getDataPoints()) {
            instances.add(this.convertToInstance(dataPoint, instances));
        }
        try {
            this._classifier.buildClassifier(instances);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        System.out.println("done building classifier");
    }

    /**
     * Builds an empty Weka dataset header for the current schema. The class attribute is placed at
     * index 0 and marked as the class; each feature follows in iteration order.
     *
     * @param size initial capacity of the dataset
     * @return an empty {@link Instances} with its class index set to 0
     * @throws IllegalStateException if neither {@link #train(DataSet)} nor {@code setFeatures}
     *     has been called (e.g. for a model loaded from a file)
     */
    private Instances makeInstances(int size) {
        if (this._classFeature == null || this._features == null) {
            throw new IllegalStateException(
                    "feature schema not set; call train(...) or setFeatures(...) first");
        }
        FastVector attInfo = new FastVector(this._features.size() + 1);
        attInfo.addElement(this.createAttribute(this._classFeature));
        for (Feature feature : this._features) {
            attInfo.addElement(this.createAttribute(feature));
        }
        Instances instances = new Instances("dataset", attInfo, size);
        instances.setClassIndex(0);
        return instances;
    }

    /**
     * Maps a project feature to a Weka attribute. Numeric features become numeric attributes.
     * String features become nominal attributes whose values are the feature's nominals.
     *
     * @throws RuntimeException for any other feature type
     */
    private Attribute createAttribute(Feature feature) {
        if (FeatureType.NUMERIC.equals(feature.getType())) {
            return new Attribute(feature.getName());
        }
        if (FeatureType.STRING.equals(feature.getType())) {
            FastVector nominalValues = new FastVector();
            for (String nominal : ((StringFeature) feature).getNominals()) {
                nominalValues.addElement(nominal);
            }
            return new Attribute(feature.getName(), nominalValues);
        }
        throw new RuntimeException("cannot handle feature type " + feature.getType());
    }

    /**
     * Converts {@code dataPoint} into a Weka instance attached to {@code instances}. A
     * {@code null} class value (an unlabeled point) is encoded as Weka's missing value.
     */
    private Instance convertToInstance(DataPoint dataPoint, Instances instances) {
        return dataPoint.isSparse()
                ? this.convertToSparseInstance(dataPoint, instances)
                : this.convertToDenseInstance(dataPoint, instances);
    }

    /**
     * Builds a {@link SparseInstance} by looking up every schema attribute on the data point.
     * Nominal/string labels that the attribute does not declare are encoded as missing; writing
     * Weka's "not found" index (-1) would corrupt the instance.
     */
    private Instance convertToSparseInstance(DataPoint dataPoint, Instances instances) {
        int numAttributes = instances.numAttributes();
        double[] attValues = new double[numAttributes];

        // Index 0 is the class attribute (see makeInstances).
        Value classValue = dataPoint.getClassValue();
        Attribute classAttribute = instances.classAttribute();
        if (classValue == null) {
            attValues[0] = Instance.missingValue();
        } else if (classAttribute.isNumeric()) {
            attValues[0] = classValue.getNumber();
        } else if (classAttribute.isString() || classAttribute.isNominal()) {
            attValues[0] = nominalIndexOrMissing(classAttribute, classValue.getString());
        }

        for (int i = 1; i < numAttributes; i++) {
            Attribute attribute = instances.attribute(i);
            if (attribute.isNumeric()) {
                Value value = dataPoint.getValue(new NumericFeature(attribute.name()));
                attValues[i] = value.getNumber().doubleValue();
            } else if (attribute.isString() || attribute.isNominal()) {
                Value value = dataPoint.getValue(new StringFeature(attribute.name()));
                attValues[i] = nominalIndexOrMissing(attribute, value.getString());
            }
        }

        SparseInstance instance = new SparseInstance(dataPoint.getWeight(), attValues);
        instance.setDataset(instances);
        return instance;
    }

    /**
     * Builds a dense instance by setting the class value and each feature present on the point.
     * Nominal labels are passed to Weka's {@code setValue}/{@code setClassValue}, which reject
     * undeclared labels.
     *
     * @throws IllegalArgumentException if the point has a feature that is not part of the schema
     */
    private Instance convertToDenseInstance(DataPoint dataPoint, Instances instances) {
        // This path is only taken for non-sparse points.
        Instance instance = WekaUtils.getNewInstance(instances, false);

        Value classValue = dataPoint.getClassValue();
        Attribute classAttribute = instances.classAttribute();
        if (classValue == null) {
            instance.setClassMissing();
        } else if (classAttribute.isNumeric()) {
            instance.setClassValue(classValue.getNumber().doubleValue());
        } else if (classAttribute.isString() || classAttribute.isNominal()) {
            instance.setClassValue(classValue.getString());
        }

        for (Feature feature : dataPoint.getFeatures()) {
            Attribute attribute = instances.attribute(feature.getName());
            if (attribute == null) {
                throw new IllegalArgumentException(
                        "feature '" + feature.getName() + "' is not part of the classifier's schema");
            }
            Value value = dataPoint.getValue(feature);
            if (attribute.isNumeric()) {
                instance.setValue(attribute, value.getNumber().doubleValue());
            } else if (attribute.isString() || attribute.isNominal()) {
                instance.setValue(attribute, value.getString());
            }
        }

        instance.setWeight(dataPoint.getWeight());
        return instance;
    }

    /** Returns the index of {@code label} in {@code attribute}, or Weka's missing value if absent. */
    private static double nominalIndexOrMissing(Attribute attribute, String label) {
        int index = attribute.indexOfValue(label);
        return index < 0 ? Instance.missingValue() : index;
    }

    /**
     * Serializes the Weka model (not the feature schema) to {@code file}.
     *
     * @throws RuntimeException wrapping any error raised while writing
     */
    @Override
    public void save(String file) {
        try {
            SerializationHelper.write(file, this._classifier);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /** Returns the Weka model's {@code toString()} description. */
    @Override
    public String getInfo() {
        return this._classifier.toString();
    }

    /**
     * Returns a copy of this classifier. The Weka model is deep-copied with
     * {@code Classifier.makeCopy}. The feature schema is carried over so that a copy of a trained
     * classifier can classify without retraining. The schema collection is shared, not cloned.
     *
     * @throws RuntimeException wrapping any error raised while copying the model
     */
    @Override
    public Classifier copy() {
        weka.classifiers.Classifier modelCopy;
        try {
            modelCopy = weka.classifiers.Classifier.makeCopy(this._classifier);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        WekaClassifier copy = new WekaClassifier(modelCopy);
        copy._features = this._features;
        copy._classFeature = this._classFeature;
        return copy;
    }

    /** Same as {@link #train(DataSet)}. */
    @Override
    public void learn(DataSet dataSet) {
        this.train(dataSet);
    }

    /** Same as {@link #classify(DataPoint)}. */
    @Override
    public Value predict(DataPoint dataPoint) {
        return this.classify(dataPoint);
    }

    /** Sets the schema (class feature and features) from {@code dataSet} without training. */
    public void setFeatures(DataSet dataSet) {
        this._classFeature = dataSet.getClassFeature();
        this._features = dataSet.getFeatures();
    }

    /** Sets the schema (class feature and features) from a single {@code dataPoint}. */
    public void setFeatures(DataPoint dataPoint) {
        this._classFeature = dataPoint.getClassFeature();
        this._features = dataPoint.getFeatures();
    }
}

