package edu.columbia.preju.weka;

import java.util.HashMap;
import java.util.Map;

import edu.columbia.preju.prediction.ClassifierPrediction;
import edu.columbia.preju.prediction.Prediction;
import edu.columbia.preju.prediction.classifier.ClassifierModelInfo;
import edu.columbia.preju.prediction.classifier.weka.WekaLogisticModelInfo;
import ob.core.CentralFactory;
import ob.core.Feature;
import ob.core.NumericFeature;
import ob.core.Value;
import weka.classifiers.functions.Logistic;
import weka.core.Instance;
import weka.core.Instances;

/**
 * Static helpers that turn a Weka classification result into a {@link Prediction}.
 *
 * @author Or
 *
 */
public class WekaPredictionFactory {

	/**
	 * Builds a {@link ClassifierPrediction} from a Weka classification result.
	 *
	 * @param subjectDomainName domain name passed through to the prediction
	 * @param modelInfo model description passed through to the prediction
	 * @param trainingData dataset whose class attribute maps the class index to its label;
	 *        its class index must be set
	 * @param classifiedInstance instance whose non-class attribute values are recorded as features
	 * @param classification class value produced by Weka; truncated to an {@code int} index into
	 *        {@code trainingData.classAttribute()}
	 * @return the prediction built from the label, the index, the feature values and {@code modelInfo}
	 * @throws IllegalArgumentException if {@code classification} is NaN (Weka's missing value)
	 */
	public static Prediction getClassifierPrediction(String subjectDomainName, ClassifierModelInfo modelInfo, Instances trainingData, Instance classifiedInstance, double classification) {
		// Weka uses NaN as its "missing" value. Casting NaN to int yields 0, which would
		// silently report the first class label, so reject it explicitly.
		if (Double.isNaN(classification)) {
			throw new IllegalArgumentException("Classification is missing (NaN); cannot map it to a class label");
		}
		int classIndex = (int) classification;
		Value predictedClass = CentralFactory.getStringValue(trainingData.classAttribute().value(classIndex));

		Map<Feature, Value> featureValues = getFeatureValues(classifiedInstance, trainingData);

		return new ClassifierPrediction(subjectDomainName, classIndex, predictedClass, featureValues, modelInfo);
	}

	/**
	 * Builds a prediction for an already-computed classification made by a Weka {@link Logistic} model.
	 *
	 * @param subjectDomainName domain name passed through to the prediction
	 * @param logistic the trained model, wrapped into a {@link WekaLogisticModelInfo}
	 * @param trainingData dataset the model was built from; its class index must be set
	 * @param classifiedInstance instance that was classified
	 * @param classification class value returned by {@code logistic}
	 * @return the prediction
	 * @throws IllegalArgumentException if {@code classification} is NaN
	 */
	public static Prediction getLogisticPrediction(String subjectDomainName, Logistic logistic, Instances trainingData, Instance classifiedInstance, double classification) {
		ClassifierModelInfo modelInfo = new WekaLogisticModelInfo(logistic, trainingData);

		return getClassifierPrediction(subjectDomainName, modelInfo, trainingData, classifiedInstance, classification);
	}

	/**
	 * Classifies {@code instanceToClassify} with {@code logistic} and builds the resulting prediction.
	 *
	 * @param subjectDomainName domain name passed through to the prediction
	 * @param logistic the trained model used for classification
	 * @param trainingData dataset the model was built from; its class index must be set
	 * @param instanceToClassify instance to classify
	 * @return the prediction
	 * @throws RuntimeException wrapping any exception thrown by Weka during classification
	 */
	public static Prediction getLogisticPrediction(String subjectDomainName, Logistic logistic, Instances trainingData, Instance instanceToClassify) {
		double classification = classify(logistic, instanceToClassify);

		return getLogisticPrediction(subjectDomainName, logistic, trainingData, instanceToClassify, classification);
	}

	/**
	 * Runs {@link Logistic#classifyInstance(Instance)} and converts its checked exception into an
	 * unchecked one, keeping the original exception as the cause.
	 */
	private static double classify(Logistic logistic, Instance toClassify) {
		try {
			return logistic.classifyInstance(toClassify);
		} catch (Exception e) {
			throw new RuntimeException("Logistic classifier failed to classify instance", e);
		}
	}

	/**
	 * Collects every non-class attribute of {@code instance} as a {@link NumericFeature} named after
	 * the corresponding attribute in {@code instances}.
	 * NOTE(review): every attribute is recorded via {@code instance.value(i)} as a numeric value,
	 * including nominal ones; confirm this is intended.
	 */
	private static Map<Feature, Value> getFeatureValues(Instance instance, Instances instances) {
		Map<Feature, Value> map = new HashMap<Feature, Value>();
		// Loop-invariant: read once instead of on every iteration.
		int classIndex = instances.classIndex();

		for (int i = 0; i < instance.numAttributes(); i++) {
			if (i == classIndex) {
				continue;
			}

			String name = instances.attribute(i).name();
			double value = instance.value(i);
			map.put(new NumericFeature(name), CentralFactory.getNumericValue(value));
		}

		return map;
	}

}
