/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.classify;

import com.aliasi.classify.Classification;
import com.aliasi.classify.JointClassification;
import com.aliasi.classify.LMClassifier;
import com.aliasi.corpus.ClassificationHandler;
import com.aliasi.corpus.Corpus;
import com.aliasi.corpus.ObjectHandler;
import com.aliasi.lm.LanguageModel;
import com.aliasi.lm.NGramBoundaryLM;
import com.aliasi.lm.NGramProcessLM;
import com.aliasi.lm.TokenizedLM;
import com.aliasi.stats.MultivariateDistribution;
import com.aliasi.stats.MultivariateEstimator;
import com.aliasi.tokenizer.TokenizerFactory;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.Compilable;
import com.aliasi.util.Factory;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
public class DynamicLMClassifier<L extends LanguageModel.Dynamic>
extends LMClassifier<L, MultivariateEstimator>
implements ClassificationHandler<CharSequence, Classification>,
Compilable {
    public DynamicLMClassifier(String[] categories, L[] languageModels) {
        super(categories, languageModels, (MultivariateDistribution)DynamicLMClassifier.createCategoryEstimator(categories));
    }

    public void train(String category, char[] cs, int start, int end) {
        this.train(category, new String(cs, start, end - start));
    }

    public void train(String category, CharSequence sampleCSeq) {
        this.train(category, sampleCSeq, 1);
    }

    public void train(String category, CharSequence sampleCSeq, int count) {
        if (count < 0) {
            String msg = "Counts must be non-negative. Found count=" + count;
            throw new IllegalArgumentException(msg);
        }
        if (count == 0) {
            return;
        }
        this.lmForCategory(category).train(sampleCSeq, count);
        this.categoryEstimator().train(category, count);
    }

    public static <L extends LanguageModel.Dynamic> DynamicLMClassifier<L> trainEm(Factory<DynamicLMClassifier<L>> classifierFactory, Corpus<ClassificationHandler<CharSequence, Classification>> labeledData, Corpus<ObjectHandler<CharSequence>> unlabeledData, int numEpochs, double trainingInstanceMultiple) throws IOException {
        DynamicLMClassifier<L> lastClassifier = classifierFactory.create();
        labeledData.visitCorpus(lastClassifier);
        for (int epoch = 0; epoch < numEpochs; ++epoch) {
            DynamicLMClassifier<L> classifier = classifierFactory.create();
            labeledData.visitCorpus(classifier);
            EmHandler emHandler = new EmHandler(classifier, lastClassifier, trainingInstanceMultiple);
            unlabeledData.visitCorpus(emHandler);
            lastClassifier = classifier;
        }
        return lastClassifier;
    }

    @Override
    public void handle(CharSequence charSequence, Classification classification) {
        this.train(classification.bestCategory(), charSequence);
    }

    public MultivariateEstimator categoryEstimator() {
        return (MultivariateEstimator)this.mCategoryDistribution;
    }

    public L lmForCategory(String category) {
        LanguageModel.Dynamic result = (LanguageModel.Dynamic)this.mCategoryToModel.get(category);
        if (result == null) {
            String msg = "Unknown category=" + category;
            throw new IllegalArgumentException(msg);
        }
        return (L)result;
    }

    @Override
    public void compileTo(ObjectOutput objOut) throws IOException {
        objOut.writeObject(new Externalizer(this));
    }

    public void resetCategory(String category, L lm, int newCount) {
        if (newCount < 0) {
            String msg = "Count must be non-negative. Found new count=" + newCount;
            throw new IllegalArgumentException(msg);
        }
        this.categoryEstimator().resetCount(category);
        this.categoryEstimator().train(category, newCount);
        L currentLM = this.lmForCategory(category);
        for (int i = 0; i < ((LanguageModel.Dynamic[])this.mLanguageModels).length; ++i) {
            if (currentLM != ((LanguageModel.Dynamic[])this.mLanguageModels)[i]) continue;
            ((LanguageModel.Dynamic[])this.mLanguageModels)[i] = lm;
            break;
        }
        this.mCategoryToModel.put(category, lm);
    }

    public static DynamicLMClassifier<NGramProcessLM> createNGramProcess(String[] categories, int maxCharNGram) {
        LanguageModel.Dynamic[] lms = new NGramProcessLM[categories.length];
        for (int i = 0; i < lms.length; ++i) {
            lms[i] = new NGramProcessLM(maxCharNGram);
        }
        return new DynamicLMClassifier(categories, lms);
    }

    public static DynamicLMClassifier<NGramBoundaryLM> createNGramBoundary(String[] categories, int maxCharNGram) {
        LanguageModel.Dynamic[] lms = new NGramBoundaryLM[categories.length];
        for (int i = 0; i < lms.length; ++i) {
            lms[i] = new NGramBoundaryLM(maxCharNGram);
        }
        return new DynamicLMClassifier(categories, lms);
    }

    public static DynamicLMClassifier<TokenizedLM> createTokenized(String[] categories, TokenizerFactory tokenizerFactory, int maxTokenNGram) {
        LanguageModel.Dynamic[] lms = new TokenizedLM[categories.length];
        for (int i = 0; i < lms.length; ++i) {
            lms[i] = new TokenizedLM(tokenizerFactory, maxTokenNGram);
        }
        return new DynamicLMClassifier(categories, lms);
    }

    static MultivariateEstimator createCategoryEstimator(String[] categories) {
        MultivariateEstimator estimator = new MultivariateEstimator();
        for (int i = 0; i < categories.length; ++i) {
            estimator.train(categories[i], 1L);
        }
        return estimator;
    }

    private static class Externalizer
    extends AbstractExternalizable {
        static final long serialVersionUID = -5411956637253735953L;
        final DynamicLMClassifier mClassifier;

        public Externalizer() {
            this.mClassifier = null;
        }

        public Externalizer(DynamicLMClassifier classifier) {
            this.mClassifier = classifier;
        }

        public void writeExternal(ObjectOutput objOut) throws IOException {
            objOut.writeObject(this.mClassifier.categories());
            this.mClassifier.categoryEstimator().compileTo(objOut);
            int numCategories = this.mClassifier.mCategories.length;
            for (int i = 0; i < numCategories; ++i) {
                ((LanguageModel.Dynamic)this.mClassifier.mLanguageModels[i]).compileTo(objOut);
            }
        }

        public Object read(ObjectInput objIn) throws ClassNotFoundException, IOException {
            String[] categories = (String[])objIn.readObject();
            MultivariateDistribution categoryEstimator = (MultivariateDistribution)objIn.readObject();
            LanguageModel[] models = new LanguageModel[categories.length];
            for (int i = 0; i < models.length; ++i) {
                models[i] = (LanguageModel)objIn.readObject();
            }
            return new LMClassifier(categories, models, categoryEstimator);
        }
    }

    /*
     * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
     */
    private static class EmHandler
    implements ObjectHandler<CharSequence> {
        private final DynamicLMClassifier mClassifier;
        private final DynamicLMClassifier mLastClassifier;
        private final double mMultiple;

        EmHandler(DynamicLMClassifier classifier, DynamicLMClassifier lastClassifier, double multiple) {
            this.mClassifier = classifier;
            this.mLastClassifier = lastClassifier;
            this.mMultiple = multiple;
        }

        @Override
        public void handle(CharSequence cs) {
            JointClassification classification = this.mLastClassifier.classify(cs);
            for (int rank = 0; rank < classification.size(); ++rank) {
                String category = classification.category(rank);
                double pCatGivenCs = classification.conditionalProbability(rank);
                int count = (int)(pCatGivenCs * this.mMultiple);
                this.mClassifier.train(category, cs, count);
            }
        }
    }
}

