package edu.columbia.preju.generator;

import java.io.File;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import org.jdom2.Element;

import edu.columbia.preju.core.NarrativeRole;
import edu.columbia.preju.generator.htc.ChartGenerator;
import edu.columbia.preju.generator.htc.CurrentPredictionProperties;
import edu.columbia.preju.generator.htc.JFreeChartChartGenerator;
import edu.columbia.preju.prediction.Prediction;
import edu.columbia.preju.xml.XmlObjectFactory;
import ob.core.Feature;
import ob.util.Utils;

/**
 * Justification generator that writes two files into an output directory:
 * a bar chart image ({@code justification.jpg}) showing the effect of the
 * features, and an HTML page ({@code justification.html}) that combines the
 * prediction line, the narrative paragraphs, the key feature list and the chart.
 * <p>
 * The class also acts as its own XML factory (see
 * {@link #createFromXml(Element, Prediction)}), which is why a no-argument
 * constructor exists.
 *
 * @author Or
 */
public class HtcJustificationGenerator implements JustificationGenerator, FileOutputter, XmlObjectFactory<JustificationGenerator> {

	private static final String CHART_FILE_NAME = "justification.jpg";
	private static final String HTML_FILE_NAME = "justification.html";

	/** Color used for feature names and for the predicted object name in the HTML. */
	private static final String FEATURE_COLOR = "52699E";

	/** Color used for the predicted class when no class color is configured. */
	private static final String DEFAULT_CLASS_COLOR = "green";

	private String _outputDir;

	private final ChartGenerator _chartGenerator = new JFreeChartChartGenerator();

	// Initialized here (not in a constructor) so that instances created with the
	// no-argument constructor and configured through setOutputDir() are usable too.
	private final ParagraphGenerator _paragraphGenerator = new CoreSttParagraphGenerator();

	/**
	 * Creates a generator without an output directory. Set one with
	 * {@link #setOutputDir(String)} before calling
	 * {@link #generate(JustificationNarrative)}, or use the instance only as an XML factory.
	 */
	public HtcJustificationGenerator() {
	}

	/**
	 * Creates a generator that writes its files into the given directory.
	 *
	 * @param outputDir directory in which the chart and HTML files are created
	 */
	public HtcJustificationGenerator(String outputDir) {
		_outputDir = outputDir;
	}

	/**
	 * Generates the chart image first and then the HTML page that references it.
	 *
	 * @param narrative the narrative to render
	 */
	@Override
	public void generate(JustificationNarrative narrative) {
		String chartFile = makeChartFile();
		makeChart(narrative, chartFile);

		String htmlFile = makeHtmlFile();
		makeHtml(chartFile, narrative, htmlFile);
	}

	/** Returns the absolute path of the HTML file inside the output directory. */
	private String makeHtmlFile() {
		return new File(_outputDir, HTML_FILE_NAME).getAbsolutePath();
	}

	/** Returns the absolute path of the chart file inside the output directory. */
	private String makeChartFile() {
		return new File(_outputDir, CHART_FILE_NAME).getAbsolutePath();
	}

	/**
	 * Draws the bar chart of feature effects into the given file.
	 *
	 * @param narrative the narrative supplying features and effects
	 * @param chartFile path of the image file to create
	 */
	private void makeChart(JustificationNarrative narrative, String chartFile) {
		Map<String, Double> featureEffectMap = getNamesAndEffects(narrative);
		_chartGenerator.generateBarChart(featureEffectMap, "Feature", "Effect", chartFile);
	}

	/**
	 * Builds the HTML page and writes it to {@code htmlFile}. The page has the text
	 * (prediction line, narrative paragraphs, key feature list) in the left cell and
	 * the chart image with its color legend in the right cell.
	 *
	 * @param chartFile path of the already generated chart; only its file name is
	 *                  referenced, so the image must be in the same directory as the page
	 * @param narrative the narrative to render
	 * @param htmlFile  path of the HTML file to write
	 */
	private void makeHtml(String chartFile, JustificationNarrative narrative, String htmlFile) {
		StringBuilder body = new StringBuilder();
		body.append(htmlParagraph(getPredictionLine(narrative)));

		body.append(line("<hr>"));

		Collection<String> textParagraphs = _paragraphGenerator.getParagraphs(narrative);
		for (String paragraph : textParagraphs) {
			body.append(htmlParagraph(paragraph));
		}

		body.append(line("<hr>"));

		body.append(htmlParagraph(getKeyFeatureList(narrative)));

		String labels = "<center><table><tr><td>" +
						"<font color=black>&#x25CF; Total effect of all features on predicting the class</font><br>" +
						"<font color=blue>&#x25CF; Positive effect of a feature towards predicting the class</font><br>" +
						"<font color=red>&#x25CF; Negative effect of a feature against predicting the class</font><br>" +
						"</td></tr></table></center>";

		// Leave the object name out of the title when it is unknown, instead of
		// printing the literal text "null" (same rule as in getPredictionLine).
		String predictedObjectName = CurrentPredictionProperties.getPredictedObjectName();
		String title = "Prediction and justification" + (predictedObjectName == null ? "" : " for " + predictedObjectName);

		String html =
				"<html>" +
				"<head><title>" + title + "</title></head>" +
				"<font face=arial size=+1 color=AAAAAA><table width=100% height=100%>" +
				"<tr>" +
				"<td>" + body + "</td>" +
				"<td><img src=" + new File(chartFile).getName() + "><br><hr>" + labels + "<hr></td>" +
				"</tr>" +
				"</table></font>" +
				"</html>";

		Utils.writeToFile(htmlFile, html);
	}

	/**
	 * Builds the sentence that states the predicted class and the model that produced
	 * it. The predicted object name is mentioned only when one is available, and the
	 * class is colored with the configured class color or a default one.
	 */
	private String getPredictionLine(JustificationNarrative narrative) {
		String predictedObjectName = CurrentPredictionProperties.getPredictedObjectName();
		String classColor = CurrentPredictionProperties.getClassColor();
		if (classColor == null) {
			classColor = DEFAULT_CLASS_COLOR;
		}

		String objectPart = predictedObjectName == null
				? ""
				: " for <b><font color=" + FEATURE_COLOR + ">" + predictedObjectName + "</font></b>";

		return "The prediction" + objectPart
				+ ", given by " + narrative.getPrediction().getModelName()
				+ ", is <b><font color=" + classColor + ">" + narrative.getPrediction().getPredictedClass() + "</font></b>.";
	}

	/**
	 * Builds the HTML list of key features, one per line with its narrative role in
	 * parentheses. A spacer line is inserted before the first feature and whenever the
	 * role differs from the role of the previous feature.
	 */
	private String getKeyFeatureList(JustificationNarrative narrative) {
		StringBuilder list = new StringBuilder(line("Key feature list:<br>"));
		NarrativeRole lastRole = null;
		boolean first = true;
		for (Feature feature : narrative.getKeyFeatures()) {
			NarrativeRole role = narrative.getNarrativeRole(feature);
			// Null-safe comparison: a feature without a role must not break the page.
			boolean sameRole = role == null ? lastRole == null : role.equals(lastRole);
			if (first || !sameRole) {
				// white "-" acts as an invisible blank line between groups
				list.append(line("<font color=white>-</font><br>"));
			}
			list.append(line("- <font color=" + FEATURE_COLOR + ">" + Utils.ucFirst(feature.getName()) + "</font> (" + role + ")<br>"));
			lastRole = role;
			first = false;
		}

		return list.toString();
	}

	/** Wraps the text in a {@code <p>} element followed by a newline. */
	private String htmlParagraph(String paragraph) {
		return line("<p>" + paragraph + "</p>");
	}

	/** Appends a newline to the given string. */
	private String line(String string) {
		return string + "\n";
	}

	/**
	 * Builds the label-to-value map for the bar chart, in the order in which the bars
	 * should appear:
	 * <ol>
	 * <li>the total effect of all features,</li>
	 * <li>positive key features, largest effect first,</li>
	 * <li>one bar collapsing all positive non-key features,</li>
	 * <li>one bar (value 0) collapsing all zero-effect features,</li>
	 * <li>negative key features, most negative first,</li>
	 * <li>one bar collapsing all negative non-key features.</li>
	 * </ol>
	 * Features whose effect is NaN are skipped entirely and do not contribute to the total.
	 *
	 * @param narrative the narrative supplying features and effects
	 * @return an insertion-ordered map from chart label to effect value
	 */
	private static Map<String, Double> getNamesAndEffects(JustificationNarrative narrative) {
		Map<String, Double> positives = new HashMap<String, Double>();
		double otherPositiveTotal = 0.0;
		int otherPositiveNum = 0;

		Map<String, Double> negatives = new HashMap<String, Double>();
		double otherNegativeTotal = 0.0;
		int otherNegativeNum = 0;

		int zeroNum = 0;

		double totalEffect = 0.0;

		for (Feature feature : narrative.getFeatures()) {
			double effect = narrative.getEffect(feature).getDouble();

			if (Double.isNaN(effect)) continue;

			totalEffect += effect;

			if (effect == 0) {
				zeroNum++;
			}
			else if (narrative.isKeyFeature(feature)) {
				if (effect > 0) {
					positives.put(feature.getName(), effect);
				}
				else {
					negatives.put(feature.getName(), effect);
				}
			}
			else {
				if (effect > 0) {
					otherPositiveNum++;
					otherPositiveTotal += effect;
				}
				else {
					otherNegativeNum++;
					otherNegativeTotal += effect;
				}
			}
		}

		List<String> positiveNames = namesSortedByEffect(positives, true);
		List<String> negativeNames = namesSortedByEffect(negatives, false);

		// build the map of labels and values, in the order it is going to appear in the chart
		Map<String, Double> map = new LinkedHashMap<String, Double>();

		// first, the total effect
		map.put("total effect", totalEffect);

		// then, the positive key features
		for (String name : positiveNames) {
			map.put(name, positives.get(name));
		}

		// if there are positive non-key features, collapse them into one label and value
		if (otherPositiveNum > 0) {
			map.put(
					otherPositiveNum + (positiveNames.isEmpty() ? "" : " other") + " positive feature" + (otherPositiveNum==1 ? "" : "s"),
					otherPositiveTotal
					);
		}

		// if there are zero-effect features, collapse them into one label and value
		if (zeroNum > 0) {
			map.put(
					zeroNum + " zero-effect feature" + (zeroNum==1 ? "" : "s"),
					0.0
					);
		}

		// then, the negative key features
		for (String name : negativeNames) {
			map.put(name, negatives.get(name));
		}

		// finally, if there are negative non-key features, collapse them into one label and value
		if (otherNegativeNum > 0) {
			map.put(
					otherNegativeNum + (negativeNames.isEmpty() ? "" : " other") + " negative feature" + (otherNegativeNum==1 ? "" : "s"),
					otherNegativeTotal
					);
		}

		return map;
	}

	/**
	 * Returns the keys of {@code effects} sorted by their mapped value.
	 *
	 * @param effects    map from feature name to effect
	 * @param descending {@code true} for largest value first, {@code false} for smallest first
	 * @return a new list of the map's keys in the requested order
	 */
	private static List<String> namesSortedByEffect(final Map<String, Double> effects, final boolean descending) {
		List<String> names = new ArrayList<String>(effects.keySet());
		Collections.sort(names, new Comparator<String>() {
			@Override
			public int compare(String o1, String o2) {
				return descending
						? effects.get(o2).compareTo(effects.get(o1))
						: effects.get(o1).compareTo(effects.get(o2));
			}
		});
		return names;
	}

	/** Returns the directory into which the chart and HTML files are written. */
	public String getOutputDir() {
		return _outputDir;
	}

	/** Sets the directory into which the chart and HTML files are written. */
	public void setOutputDir(String outputDir) {
		_outputDir = outputDir;
	}

	/**
	 * Creates a new generator configured from the given XML element. The output
	 * directory is read from the normalized text of the {@code outputDir} child element.
	 *
	 * @param element    the configuration element
	 * @param prediction not used by this factory
	 * @return a new generator writing into the configured directory
	 */
	@Override
	public JustificationGenerator createFromXml(Element element, Prediction prediction) {
		String outputDir = element.getChildTextNormalize("outputDir");
		return new HtcJustificationGenerator(outputDir);
	}

}
