package edu.columbia.preju.prediction.classifier.weka;

import edu.columbia.preju.prediction.classifier.AbstractLinearClassifierModelInfo;
import edu.columbia.preju.prediction.classifier.ClassInfo;
import ob.core.Feature;
import ob.core.NumericFeature;
import weka.classifiers.functions.Logistic;
import weka.core.Instances;

/**
 * Exposes the parameters of a trained Weka {@link Logistic} classifier as
 * per-class {@link ClassInfo} entries (feature weights, feature means and a prior).
 *
 * <p>The coefficient matrix returned by {@link Logistic#coefficients()} is read as
 * {@code [row][class]}: row 0 is used for the intercept and row {@code k + 1} for the
 * k-th non-class attribute of the training data.
 *
 * <p>NOTE(review): this assumes the coefficient rows follow the order of the non-class
 * attributes of {@code train} one-to-one. Weka's Logistic may preprocess attributes
 * internally (e.g. binarizing nominal attributes or dropping useless ones), in which
 * case rows and attributes would no longer line up -- confirm against the callers' data.
 */
public class WekaLogisticModelInfo extends AbstractLinearClassifierModelInfo {

	/**
	 * Registers one {@link ClassInfo} per class value of {@code train}.
	 *
	 * <p>For every attribute other than the class attribute, a {@link NumericFeature}
	 * named after the attribute receives the class-specific weight and the mean of the
	 * attribute over the training instances of that class. The class-specific intercept
	 * is stored as the prior.
	 *
	 * @param logistic the trained model whose coefficients are read
	 * @param train    the data set whose attributes and instances are described
	 */
	public WekaLogisticModelInfo(Logistic logistic, Instances train) {
		// Fetched once: every weight and intercept below is read from this matrix.
		final double[][] coefficients = logistic.coefficients();
		final int classAttribute = train.classIndex();
		final int numAttributes = train.numAttributes();
		final int numClasses = train.numClasses();

		for (int classValue = 0; classValue < numClasses; classValue++) {
			ClassInfo classInfo = new ClassInfo();

			// Row 0 is the intercept row, so predictor rows start at 1. The row counter
			// is separate from the attribute index because the class attribute is skipped.
			int row = 1;
			for (int attr = 0; attr < numAttributes; attr++) {
				if (attr == classAttribute) {
					continue;
				}

				Feature feature = new NumericFeature(train.attribute(attr).name());
				classInfo.setWeight(feature, coefficientFor(coefficients[row], classValue));
				classInfo.setMean(feature, computeMean(train, attr, classValue));

				row++;
			}

			classInfo.setPrior(coefficientFor(coefficients[0], classValue));

			super.addClassInfo(classValue, classInfo);
		}
	}

	/**
	 * Returns the coefficient of one matrix row for the given class.
	 *
	 * <p>Classes that have their own column get the stored value. The class whose index
	 * equals the row length (one past the last stored column) has no column of its own;
	 * for it the negated sum of all stored values in the row is returned.
	 *
	 * @param row        one row of the coefficient matrix (intercept or predictor row)
	 * @param classIndex index of the class value
	 * @return the stored coefficient, or the negated row sum for the class without a column
	 */
	private static double coefficientFor(double[] row, int classIndex) {
		if (classIndex != row.length) {
			return row[classIndex];
		}

		double negatedSum = 0;
		for (int i = 0; i < classIndex; i++) {
			negatedSum -= row[i];
		}
		return negatedSum;
	}

	/**
	 * Averages one attribute over the training instances of a single class.
	 *
	 * <p>Instances whose value for the attribute is NaN are ignored. If no instance of
	 * the class contributes a value, the result is {@code 0.0 / 0.0}, i.e. NaN.
	 *
	 * @param train        the training data
	 * @param featureIndex index of the attribute to average
	 * @param classIndex   index of the class value whose instances are considered
	 * @return the mean of the attribute within the class, or NaN if there is nothing to average
	 */
	private static double computeMean(Instances train, int featureIndex, int classIndex) {
		double sum = 0;
		double count = 0;

		final int numInstances = train.numInstances();
		for (int i = 0; i < numInstances; i++) {
			double value = train.instance(i).value(featureIndex);

			if (Double.isNaN(value)) {
				continue;
			}
			if (train.instance(i).classValue() == classIndex) {
				sum += value;
				count++;
			}
		}

		return sum / count;
	}

	/** {@inheritDoc} */
	@Override
	public String getModelName() {
		return "Logistic Regression";
	}

}
