/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.classify;

import com.aliasi.classify.Classification;
import com.aliasi.classify.Classifier;
import com.aliasi.classify.ConditionalClassification;
import com.aliasi.classify.PerceptronClassifier;
import com.aliasi.corpus.ClassificationHandler;
import com.aliasi.corpus.Corpus;
import com.aliasi.matrix.SparseFloatVector;
import com.aliasi.matrix.Vector;
import com.aliasi.stats.AnnealingSchedule;
import com.aliasi.stats.LogisticRegression;
import com.aliasi.stats.RegressionPrior;
import com.aliasi.symbol.MapSymbolTable;
import com.aliasi.symbol.SymbolTable;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.Compilable;
import com.aliasi.util.FeatureExtractor;
import com.aliasi.util.ObjectToCounterMap;
import com.aliasi.util.ObjectToDoubleMap;
import com.aliasi.util.Scored;
import com.aliasi.util.ScoredObject;
import java.io.CharArrayWriter;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.PrintWriter;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
public class LogisticRegressionClassifier<E>
implements Classifier<E, ConditionalClassification>,
Compilable,
Serializable {
    // Estimated model; outcome i corresponds to mCategorySymbols[i].
    private final LogisticRegression mModel;
    // Maps an input to its feature map (feature name -> value).
    private final FeatureExtractor<? super E> mFeatureExtractor;
    // Whether the intercept feature is added when converting feature maps to vectors.
    private final boolean mAddInterceptFeature;
    // Maps feature names to vector dimensions.
    private final SymbolTable mFeatureSymbolTable;
    // Private copy of the category names, indexed by model outcome; never exposed mutably.
    private final String[] mCategorySymbols;
    // Reserved feature name registered first in the feature table when an intercept is used.
    static final String INTERCEPT_FEATURE_NAME = "*&^INTERCEPT%$^&**";

    /**
     * Constructs a classifier from an estimated model and the tables used to build its inputs.
     *
     * @param model the estimated logistic regression model
     * @param featureExtractor extractor used to turn inputs into feature maps
     * @param addInterceptFeature whether vectors are built with the intercept feature
     * @param featureSymbolTable symbol table mapping feature names to vector dimensions
     * @param categorySymbols category names; index {@code i} names model outcome {@code i}.
     *     The array is copied, so later changes by the caller do not affect this classifier.
     * @throws IllegalArgumentException if {@code model.numOutcomes()} differs from
     *     {@code categorySymbols.length}, or if a category name appears more than once
     */
    LogisticRegressionClassifier(LogisticRegression model, FeatureExtractor<? super E> featureExtractor, boolean addInterceptFeature, SymbolTable featureSymbolTable, String[] categorySymbols) {
        if (model.numOutcomes() != categorySymbols.length) {
            String msg = "Number of model outcomes must match category symbols length. Found model.numOutcomes()=" + model.numOutcomes() + " categorySymbols.length=" + categorySymbols.length;
            throw new IllegalArgumentException(msg);
        }
        HashSet<String> categorySymbolSet = new HashSet<String>();
        for (int i = 0; i < categorySymbols.length; ++i) {
            if (categorySymbolSet.add(categorySymbols[i])) continue;
            String msg = "Categories must be unique. Found duplicate category categorySymbols[" + i + "]=" + categorySymbols[i];
            throw new IllegalArgumentException(msg);
        }
        this.mModel = model;
        this.mFeatureExtractor = featureExtractor;
        this.mAddInterceptFeature = addInterceptFeature;
        this.mFeatureSymbolTable = featureSymbolTable;
        // Defensive copy so the caller cannot rename outcomes after validation.
        this.mCategorySymbols = categorySymbols.clone();
    }

    /**
     * Returns an unmodifiable view of the feature symbol table.
     *
     * @return read-only view of the feature-name to dimension mapping
     */
    public SymbolTable featureSymbolTable() {
        return MapSymbolTable.unmodifiableView(this.mFeatureSymbolTable);
    }

    /**
     * Returns the category names in model outcome order.
     *
     * <p>The list is unmodifiable, which matches {@link #featureSymbolTable()}. A bare
     * {@code Arrays.asList} view would write through to this classifier's internal array.
     *
     * @return unmodifiable list whose element {@code i} names model outcome {@code i}
     */
    public List<String> categorySymbols() {
        return Collections.unmodifiableList(Arrays.asList(this.mCategorySymbols));
    }

    /**
     * Classifies an input, returning all categories ranked by conditional probability.
     *
     * <p>The input's features are extracted and converted to a sparse vector using the
     * feature symbol table. The model's per-outcome probabilities are then sorted so that the
     * most probable category comes first.
     *
     * @param in the input to classify
     * @return the categories and their probabilities, in descending probability order
     */
    @Override
    public ConditionalClassification classify(E in) {
        Map<String, Number> featureMap = this.mFeatureExtractor.features(in);
        SparseFloatVector vector = PerceptronClassifier.toVector(featureMap, this.mFeatureSymbolTable, this.mFeatureSymbolTable.numSymbols(), this.mAddInterceptFeature);
        double[] conditionalProbs = this.mModel.classify(vector);
        // Pair each outcome's probability with its category name so the two can be sorted together.
        ScoredObject[] sos = new ScoredObject[conditionalProbs.length];
        for (int i = 0; i < conditionalProbs.length; ++i) {
            sos[i] = new ScoredObject<String>(this.mCategorySymbols[i], conditionalProbs[i]);
        }
        Arrays.sort(sos, Scored.REVERSE_SCORE_COMPARATOR);
        String[] categories = new String[conditionalProbs.length];
        // Reuse conditionalProbs to hold the probabilities in sorted order.
        for (int i = 0; i < conditionalProbs.length; ++i) {
            categories[i] = sos[i].getObject().toString();
            conditionalProbs[i] = sos[i].score();
        }
        return new ConditionalClassification(categories, conditionalProbs);
    }

    /**
     * Writes a compiled form of this classifier, which is its {@link Externalizer}.
     *
     * @param objOut the stream to write to
     * @throws IOException if the underlying stream fails
     */
    @Override
    public void compileTo(ObjectOutput objOut) throws IOException {
        objOut.writeObject(new Externalizer(this));
    }

    /**
     * Returns the outcome index of a category.
     *
     * @param category the category name
     * @return its index in mCategorySymbols, or -1 if the category is unknown
     */
    private int categoryToId(String category) {
        for (int i = 0; i < this.mCategorySymbols.length; ++i) {
            if (!this.mCategorySymbols[i].equals(category)) continue;
            return i;
        }
        return -1;
    }

    /**
     * Returns the model weight of every feature for the given category.
     *
     * <p>For the last category the returned map is empty. This appears to be because the model
     * keeps no weight vector for it (a reference category with implicit zero weights).
     * NOTE(review): confirm this against {@code LogisticRegression.weightVectors()}.
     *
     * @param category the category name
     * @return a new map from feature name to weight
     * @throws IllegalArgumentException if the category is unknown
     */
    public ObjectToDoubleMap<String> featureValues(String category) {
        int categoryId = this.categoryToId(category);
        if (categoryId < 0) {
            String msg = "Unknown category=" + category;
            throw new IllegalArgumentException(msg);
        }
        ObjectToDoubleMap<String> result = new ObjectToDoubleMap<String>();
        if (categoryId == this.mCategorySymbols.length - 1) {
            return result;
        }
        int numSymbols = this.mFeatureSymbolTable.numSymbols();
        Vector[] weightVectors = this.mModel.weightVectors();
        Vector weightVector = weightVectors[categoryId];
        // Assumes feature ids are the contiguous range [0, numSymbols).
        for (int i = 0; i < numSymbols; ++i) {
            String symbol = this.mFeatureSymbolTable.idToSymbol(i);
            result.set(symbol, weightVector.value(i));
        }
        return result;
    }

    /**
     * Returns a human-readable dump of the model.
     *
     * <p>The dump gives the number of categories and features, then the feature weights of every
     * category except the last, with weights ordered by {@code keysOrderedByValueList()}.
     *
     * @return the multi-line model description
     */
    public String toString() {
        CharArrayWriter writer = new CharArrayWriter();
        PrintWriter printWriter = new PrintWriter(writer);
        List<String> categorySymbols = this.categorySymbols();
        printWriter.println("NUMBER OF CATEGORIES=" + categorySymbols.size());
        printWriter.println("NUMBER OF FEATURES=" + this.mFeatureSymbolTable.numSymbols());
        // The last category has no weights of its own (see featureValues), so it is skipped.
        for (int i = 0; i < categorySymbols.size() - 1; ++i) {
            String category = categorySymbols.get(i);
            printWriter.println("\n  CATEGORY=" + category);
            ObjectToDoubleMap<String> parameterVector = this.featureValues(category);
            for (String feature : parameterVector.keysOrderedByValueList()) {
                printWriter.printf("%20s %15.6f\n", feature, parameterVector.get(feature));
            }
        }
        printWriter.write('\n');
        // Make sure everything written has reached the underlying buffer before reading it.
        printWriter.flush();
        return writer.toString();
    }

    /**
     * Replaces this object with its {@link Externalizer} during Java serialization.
     *
     * @return the serialization proxy
     */
    private Object writeReplace() {
        return new Externalizer(this);
    }

    /**
     * Trains a logistic regression classifier on the training section of a corpus.
     *
     * <p>The training data is visited twice. The first pass counts, for each feature, the
     * number of training instances containing it. Those counts are pruned with
     * {@code minFeatureCount}, presumably dropping rarer features, and the survivors are added
     * to the feature symbol table. The second pass converts each instance to a vector and
     * records the id of its best category. Category ids are assigned by a fresh symbol table in
     * order of first appearance.
     *
     * @param featureExtractor extractor turning inputs into feature maps
     * @param corpus corpus whose training section is visited
     * @param minFeatureCount pruning threshold passed to {@code ObjectToCounterMap.prune}
     * @param addInterceptFeature whether to reserve and use the intercept feature
     * @param prior passed through to {@code LogisticRegression.estimate}
     * @param annealingSchedule passed through to {@code LogisticRegression.estimate}
     * @param minImprovement passed through to {@code LogisticRegression.estimate}
     * @param minEpochs passed through to {@code LogisticRegression.estimate}
     * @param maxEpochs passed through to {@code LogisticRegression.estimate}
     * @param progressWriter passed through to {@code LogisticRegression.estimate}
     * @param <F> the type of the objects being classified
     * @return the trained classifier
     * @throws IOException if visiting the corpus fails
     */
    public static <F> LogisticRegressionClassifier<F> train(FeatureExtractor<? super F> featureExtractor, Corpus<ClassificationHandler<F, Classification>> corpus, int minFeatureCount, boolean addInterceptFeature, RegressionPrior prior, AnnealingSchedule annealingSchedule, double minImprovement, int minEpochs, int maxEpochs, PrintWriter progressWriter) throws IOException {
        MapSymbolTable featureSymbolTable = new MapSymbolTable();
        MapSymbolTable categorySymbolTable = new MapSymbolTable();
        if (addInterceptFeature) {
            // Registered first so the intercept gets its own fixed dimension.
            featureSymbolTable.getOrAddSymbol(INTERCEPT_FEATURE_NAME);
        }
        ObjectToCounterMap<String> featureCounter = new ObjectToCounterMap<String>();
        corpus.visitTrain(new FeatureCounter<F>(featureExtractor, featureCounter));
        featureCounter.prune(minFeatureCount);
        for (String feature : featureCounter.keySet()) {
            featureSymbolTable.getOrAddSymbol(feature);
        }
        DataExtractor<F> dataExtractor = new DataExtractor<F>(featureExtractor, featureSymbolTable, categorySymbolTable, addInterceptFeature, featureSymbolTable.numSymbols());
        corpus.visitTrain(dataExtractor);
        Vector[] inputs = dataExtractor.inputs();
        int[] categories = dataExtractor.categories();
        LogisticRegression model = LogisticRegression.estimate(inputs, categories, prior, annealingSchedule, minImprovement, minEpochs, maxEpochs, progressWriter);
        String[] categorySymbols = new String[categorySymbolTable.numSymbols()];
        for (int i = 0; i < categorySymbols.length; ++i) {
            categorySymbols[i] = categorySymbolTable.idToSymbol(i);
        }
        return new LogisticRegressionClassifier<F>(model, featureExtractor, addInterceptFeature, featureSymbolTable, categorySymbols);
    }

    /**
     * Training-data handler that converts each instance to a sparse vector and records the
     * symbol-table id of its best category.
     *
     * @param <F> the type of the objects being classified
     */
    static class DataExtractor<F>
    implements ClassificationHandler<F, Classification> {
        final FeatureExtractor<? super F> mFeatureExtractor;
        final SymbolTable mFeatureSymbolTable;
        // Grows as new categories are seen.
        final SymbolTable mCategorySymbolTable;
        final boolean mAddInterceptFeature;
        // Vector dimensionality passed to PerceptronClassifier.toVector.
        final int mNumSymbols;
        final List<Vector> mInputVectorList = new ArrayList<Vector>();
        final List<Integer> mOutputCategoryList = new ArrayList<Integer>();

        DataExtractor(FeatureExtractor<? super F> featureExtractor, SymbolTable featureSymbolTable, SymbolTable categorySymbolTable, boolean addInterceptFeature, int numSymbols) {
            this.mFeatureExtractor = featureExtractor;
            this.mFeatureSymbolTable = featureSymbolTable;
            this.mCategorySymbolTable = categorySymbolTable;
            this.mAddInterceptFeature = addInterceptFeature;
            this.mNumSymbols = numSymbols;
        }

        /**
         * Records one training instance: its feature vector and its best category's id.
         *
         * @param input the training object
         * @param output its reference classification
         */
        @Override
        public void handle(F input, Classification output) {
            String outputCategoryName = output.bestCategory();
            Integer outputCategoryId = this.mCategorySymbolTable.getOrAddSymbol(outputCategoryName);
            Map<String, Number> featureMap = this.mFeatureExtractor.features(input);
            SparseFloatVector vector = PerceptronClassifier.toVector(featureMap, this.mFeatureSymbolTable, this.mNumSymbols, this.mAddInterceptFeature);
            this.mInputVectorList.add(vector);
            this.mOutputCategoryList.add(outputCategoryId);
        }

        /**
         * Returns the recorded category ids, parallel to {@link #inputs()}.
         *
         * @return one category id per handled instance
         */
        int[] categories() {
            int[] categoryIds = new int[this.mOutputCategoryList.size()];
            for (int i = 0; i < categoryIds.length; ++i) {
                categoryIds[i] = this.mOutputCategoryList.get(i);
            }
            return categoryIds;
        }

        /**
         * Returns the recorded input vectors, in handling order.
         *
         * @return one vector per handled instance
         */
        Vector[] inputs() {
            return this.mInputVectorList.toArray(new Vector[this.mInputVectorList.size()]);
        }
    }

    /**
     * Serialization proxy for {@link LogisticRegressionClassifier}.
     *
     * <p>The serialized form is: the model, the feature extractor, the intercept flag, the
     * feature symbol table, the category count, then each category name written with
     * {@code writeUTF}. Because of {@code writeUTF}, each category name is limited to 65535
     * bytes in modified UTF-8.
     *
     * @param <G> unused type parameter kept for compatibility
     */
    static class Externalizer<G>
    extends AbstractExternalizable {
        static final long serialVersionUID = -2003123148721825458L;
        // Null when the instance was created for deserialization.
        final LogisticRegressionClassifier mClassifier;

        /** No-argument constructor required for deserialization. */
        public Externalizer() {
            this(null);
        }

        public Externalizer(LogisticRegressionClassifier classifier) {
            this.mClassifier = classifier;
        }

        @Override
        public void writeExternal(ObjectOutput objOut) throws IOException {
            objOut.writeObject(this.mClassifier.mModel);
            objOut.writeObject(this.mClassifier.mFeatureExtractor);
            objOut.writeBoolean(this.mClassifier.mAddInterceptFeature);
            objOut.writeObject(this.mClassifier.mFeatureSymbolTable);
            objOut.writeInt(this.mClassifier.mCategorySymbols.length);
            for (int i = 0; i < this.mClassifier.mCategorySymbols.length; ++i) {
                objOut.writeUTF(this.mClassifier.mCategorySymbols[i]);
            }
        }

        /**
         * Reads the fields in the order written by {@link #writeExternal} and rebuilds the
         * classifier. The classifier's constructor re-validates the fields.
         */
        @Override
        public Object read(ObjectInput objIn) throws IOException, ClassNotFoundException {
            LogisticRegression model = (LogisticRegression)objIn.readObject();
            FeatureExtractor featureExtractor = (FeatureExtractor)objIn.readObject();
            boolean addInterceptFeature = objIn.readBoolean();
            SymbolTable featureSymbolTable = (SymbolTable)objIn.readObject();
            int numCategories = objIn.readInt();
            String[] categorySymbols = new String[numCategories];
            for (int i = 0; i < categorySymbols.length; ++i) {
                categorySymbols[i] = objIn.readUTF();
            }
            return new LogisticRegressionClassifier(model, featureExtractor, addInterceptFeature, featureSymbolTable, categorySymbols);
        }
    }

    /**
     * Training-data handler that counts, for each feature, the number of instances in which it
     * occurs. Each feature is counted at most once per instance, because the handler iterates
     * the feature map's key set.
     *
     * @param <H> the type of the objects being classified
     */
    static class FeatureCounter<H>
    implements ClassificationHandler<H, Classification> {
        private final FeatureExtractor<? super H> mFeatureExtractor;
        private final ObjectToCounterMap<String> mFeatureCounter;

        FeatureCounter(FeatureExtractor<? super H> featureExtractor, ObjectToCounterMap<String> featureCounter) {
            this.mFeatureExtractor = featureExtractor;
            this.mFeatureCounter = featureCounter;
        }

        @Override
        public void handle(H h, Classification c) {
            Map<String, Number> featureMap = this.mFeatureExtractor.features(h);
            for (String feature : featureMap.keySet()) {
                this.mFeatureCounter.increment(feature);
            }
        }
    }
}

