package edu.columbia.preju.xml;

import java.util.HashMap;
import java.util.Map;

import ob.core.CentralFactory;
import ob.core.Feature;
import ob.core.NumericFeature;
import ob.core.Value;

import org.jdom2.Document;
import org.jdom2.Element;
import org.jdom2.input.SAXBuilder;

import edu.columbia.preju.core.Effect;
import edu.columbia.preju.core.Importance;
import edu.columbia.preju.core.SimpleEffect;
import edu.columbia.preju.core.SimpleImportance;
import edu.columbia.preju.prediction.ClassifierPrediction;
import edu.columbia.preju.prediction.Prediction;
import edu.columbia.preju.prediction.classifier.ClassifierModelInfo;
import edu.columbia.preju.prediction.classifier.SimpleClassifierModelInfo;

/**
 * Builds {@link Prediction} objects from prediction XML files.
 *
 * <p>The reader expects a document laid out as follows (the root element's own name is not checked):
 * <pre>
 * &lt;root domain="..."&gt;
 *   &lt;classification classIndex="..." className="..."&gt;
 *     &lt;feature name="..." value="..."/&gt; ...
 *   &lt;/classification&gt;
 *   &lt;model name="..."&gt;
 *     &lt;priors&gt;      &lt;prior classIndex="..." value="..."/&gt; ...                   &lt;/priors&gt;
 *     &lt;importances&gt; &lt;importance classIndex="..." feature="..." value="..."/&gt; ... &lt;/importances&gt;
 *     &lt;effects&gt;     &lt;effect classIndex="..." feature="..." value="..."/&gt; ...     &lt;/effects&gt;
 *   &lt;/model&gt;
 * &lt;/root&gt;
 * </pre>
 *
 * <p>All members are static.
 */
public class XmlPredictionFactory {

	/**
	 * Reads a classifier prediction from an XML file.
	 *
	 * @param predictionXmlFile location of the XML document; it is passed unchanged to
	 *        {@link SAXBuilder#build(String)}, which treats it as a system ID
	 * @return a {@link ClassifierPrediction} holding the subject domain, predicted class index and
	 *         name, the numeric feature values, and the model information
	 * @throws IllegalArgumentException if a required element ({@code classification}, {@code model},
	 *         {@code priors}, {@code importances}, {@code effects}) or a required numeric attribute is missing
	 * @throws NumberFormatException if a numeric attribute is present but cannot be parsed
	 * @throws RuntimeException if the document cannot be read or parsed (the cause is preserved)
	 */
	public static Prediction getClassifierPrediction(String predictionXmlFile) {
		Document doc = build(predictionXmlFile);
		Element root = doc.getRootElement();

		// "domain" is read as-is; a missing attribute yields null, as before.
		String subjectDomainName = root.getAttributeValue("domain");

		Element classification = requireChild(root, "classification");
		int classIndex = getClassIndex(classification);
		Value predictedClass = getPredictedClass(classification);
		Map<Feature, Value> featureValues = getFeatureValues(classification);

		ClassifierModelInfo modelInfo = getModelInfo(requireChild(root, "model"));

		return new ClassifierPrediction(subjectDomainName, classIndex, predictedClass, featureValues, modelInfo);
	}

	/**
	 * Reads the model name plus its per-class priors, feature importances and feature effects.
	 */
	private static ClassifierModelInfo getModelInfo(Element model) {
		String modelName = model.getAttributeValue("name");

		Map<Integer, Value> priorMap = getPriorMap(model);
		Map<Integer, Map<Feature, Importance>> importanceMap = getImportanceMap(model);
		Map<Integer, Map<Feature, Effect>> effectMap = getEffectMap(model);

		return new SimpleClassifierModelInfo(modelName, priorMap, importanceMap, effectMap);
	}

	/**
	 * Groups every {@code <effect>} under {@code <effects>} by its {@code classIndex}, keyed by feature.
	 * A later entry for the same class and feature overwrites an earlier one.
	 */
	private static Map<Integer, Map<Feature, Effect>> getEffectMap(Element model) {
		Element effects = requireChild(model, "effects");

		Map<Integer, Map<Feature, Effect>> effectMap = new HashMap<Integer, Map<Feature, Effect>>();
		for (Element effect : effects.getChildren("effect")) {
			Integer classIndex = requireInt(effect, "classIndex");
			String featureName = effect.getAttributeValue("feature");
			Double value = requireDouble(effect, "value");

			classMap(effectMap, classIndex).put(new NumericFeature(featureName), new SimpleEffect(value));
		}
		return effectMap;
	}

	/**
	 * Groups every {@code <importance>} under {@code <importances>} by its {@code classIndex}, keyed by feature.
	 * A later entry for the same class and feature overwrites an earlier one.
	 */
	private static Map<Integer, Map<Feature, Importance>> getImportanceMap(Element model) {
		Element importances = requireChild(model, "importances");

		Map<Integer, Map<Feature, Importance>> importanceMap = new HashMap<Integer, Map<Feature, Importance>>();
		for (Element importance : importances.getChildren("importance")) {
			Integer classIndex = requireInt(importance, "classIndex");
			String featureName = importance.getAttributeValue("feature");
			Double value = requireDouble(importance, "value");

			classMap(importanceMap, classIndex).put(new NumericFeature(featureName), new SimpleImportance(value));
		}
		return importanceMap;
	}

	/**
	 * Maps each {@code <prior>} under {@code <priors>} from its {@code classIndex} to its numeric value.
	 */
	private static Map<Integer, Value> getPriorMap(Element model) {
		Element priors = requireChild(model, "priors");

		Map<Integer, Value> priorMap = new HashMap<Integer, Value>();
		for (Element prior : priors.getChildren("prior")) {
			Integer classIndex = requireInt(prior, "classIndex");
			Double value = requireDouble(prior, "value");
			priorMap.put(classIndex, CentralFactory.getNumericValue(value));
		}
		return priorMap;
	}

	/**
	 * Reads the {@code <feature>} children of {@code <classification>} into a feature-to-value map.
	 */
	private static Map<Feature, Value> getFeatureValues(Element classification) {
		Map<Feature, Value> featureValues = new HashMap<Feature, Value>();
		for (Element feature : classification.getChildren("feature")) {
			String name = feature.getAttributeValue("name");
			Double value = requireDouble(feature, "value");
			featureValues.put(new NumericFeature(name), CentralFactory.getNumericValue(value));
		}
		return featureValues;
	}

	/**
	 * Wraps the {@code className} attribute of {@code <classification>} as a string value.
	 */
	private static Value getPredictedClass(Element classification) {
		String className = classification.getAttributeValue("className");
		return CentralFactory.getStringValue(className);
	}

	/**
	 * Reads the required integer {@code classIndex} attribute of {@code <classification>}.
	 */
	private static int getClassIndex(Element classification) {
		return requireInt(classification, "classIndex");
	}

	/**
	 * Returns the per-feature map for {@code classIndex}, creating and registering an empty one
	 * on first use.
	 */
	private static <V> Map<Feature, V> classMap(Map<Integer, Map<Feature, V>> byClass, Integer classIndex) {
		Map<Feature, V> classMap = byClass.get(classIndex);
		if (classMap == null) {
			classMap = new HashMap<Feature, V>();
			byClass.put(classIndex, classMap);
		}
		return classMap;
	}

	/**
	 * Returns the named child element, failing with a descriptive message instead of a later
	 * NullPointerException when it is absent.
	 */
	private static Element requireChild(Element parent, String name) {
		Element child = parent.getChild(name);
		if (child == null) {
			throw new IllegalArgumentException(
					"Missing <" + name + "> element under <" + parent.getName() + ">");
		}
		return child;
	}

	/**
	 * Returns the named attribute's value, failing with a descriptive message when it is absent.
	 */
	private static String requireAttribute(Element element, String name) {
		String value = element.getAttributeValue(name);
		if (value == null) {
			throw new IllegalArgumentException(
					"Missing '" + name + "' attribute on <" + element.getName() + ">");
		}
		return value;
	}

	/** Parses a required integer attribute; malformed text still throws NumberFormatException. */
	private static int requireInt(Element element, String name) {
		return Integer.parseInt(requireAttribute(element, name));
	}

	/** Parses a required floating-point attribute; malformed text still throws NumberFormatException. */
	private static Double requireDouble(Element element, String name) {
		return Double.valueOf(requireAttribute(element, name));
	}

	/**
	 * Parses the document at the given system ID.
	 *
	 * NOTE(review): SAXBuilder is used with its default configuration. If prediction files can come
	 * from untrusted sources, consider disabling DOCTYPE / external-entity processing (XXE).
	 */
	private static Document build(String predictionXmlFile) {
		try {
			return new SAXBuilder().build(predictionXmlFile);
		} catch (Exception e) {
			// SAXBuilder.build declares checked exceptions. Rethrow unchecked, keeping the cause
			// and naming the input that failed.
			throw new RuntimeException("Failed to read prediction XML: " + predictionXmlFile, e);
		}
	}

}
