/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.cluster;

import com.aliasi.corpus.ObjectHandler;
import com.aliasi.stats.Statistics;
import com.aliasi.symbol.SymbolTable;
import com.aliasi.tokenizer.Tokenizer;
import com.aliasi.tokenizer.TokenizerFactory;
import com.aliasi.util.Math;
import com.aliasi.util.ObjectToCounterMap;
import com.aliasi.util.Strings;
import java.util.ArrayList;
import java.util.Random;
import java.util.Set;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
public class LatentDirichletAllocation {
    /**
     * Prior count added to every topic's count when resampling a document's topics in
     * {@code sampleTopics}; the constructor requires it to be finite and strictly positive.
     */
    private final double mDocTopicPrior;
    /**
     * Word probabilities indexed as {@code mTopicWordProbs[topic][word]}; every row has the same
     * length and every entry is checked by the constructor to lie in [0.0, 1.0].
     */
    private final double[][] mTopicWordProbs;

    /**
     * Constructs a model from a document-topic prior and fixed per-topic word probabilities.
     *
     * <p>The probability matrix is copied row by row, so later changes to the caller's array do
     * not affect this model.
     *
     * @param docTopicPrior prior count added to every topic when sampling a document's topics;
     *     must be finite and strictly positive
     * @param topicWordProbs probabilities indexed as {@code topicWordProbs[topic][word]}; must
     *     contain at least one topic, all rows must have the same length, and every entry must
     *     lie in [0.0, 1.0] (NaN is rejected)
     * @throws IllegalArgumentException if any of the above conditions is violated
     */
    public LatentDirichletAllocation(double docTopicPrior, double[][] topicWordProbs) {
        if (docTopicPrior <= 0.0 || Double.isNaN(docTopicPrior) || Double.isInfinite(docTopicPrior)) {
            String msg = "Document-topic prior must be finite and positive. Found docTopicPrior=" + docTopicPrior;
            throw new IllegalArgumentException(msg);
        }
        int numTopics = topicWordProbs.length;
        if (numTopics < 1) {
            String msg = "Require non-empty topic-word probabilities.";
            throw new IllegalArgumentException(msg);
        }
        // Copy first, then validate the copy, so the checks apply to exactly what is stored.
        double[][] probs = new double[numTopics][];
        for (int topic = 0; topic < numTopics; ++topic) {
            probs[topic] = topicWordProbs[topic].clone();
        }
        int numWords = probs[0].length;
        for (int topic = 1; topic < numTopics; ++topic) {
            if (probs[topic].length == numWords) {
                continue;
            }
            String msg = "All topics must have the same number of words. topicWordProbs[0].length=" + numWords + " topicWordProbs[" + topic + "]=" + probs[topic].length;
            throw new IllegalArgumentException(msg);
        }
        for (int topic = 0; topic < numTopics; ++topic) {
            for (int word = 0; word < numWords; ++word) {
                double p = probs[topic][word];
                // Positive range test: NaN fails both comparisons and is therefore rejected.
                if (p >= 0.0 && p <= 1.0) {
                    continue;
                }
                String msg = "All probabilities must be between 0.0 and 1.0 Found topicWordProbs[" + topic + "][" + word + "]=" + p;
                throw new IllegalArgumentException(msg);
            }
        }
        this.mDocTopicPrior = docTopicPrior;
        this.mTopicWordProbs = probs;
    }

    /**
     * Returns how many topics this model has.
     *
     * @return the number of rows in the topic-word probability matrix
     */
    public int numTopics() {
        final double[][] probs = mTopicWordProbs;
        return probs.length;
    }

    /**
     * Returns the vocabulary size of this model.
     *
     * <p>This is the length of topic 0's row; the constructor rejects rows of differing length,
     * so every topic has the same value.
     *
     * @return the number of word columns in the topic-word probability matrix
     */
    public int numWords() {
        final double[] firstTopic = mTopicWordProbs[0];
        return firstTopic.length;
    }

    /**
     * Returns the document-topic prior supplied at construction.
     *
     * @return the prior count added to each topic when sampling a document's topics
     */
    public double documentTopicPrior() {
        final double prior = mDocTopicPrior;
        return prior;
    }

    /**
     * Returns the probability of a word under a topic.
     *
     * @param topic topic index
     * @param word word identifier
     * @return the stored probability for {@code (topic, word)}
     */
    public double wordProbability(int topic, int word) {
        final double[] row = mTopicWordProbs[topic];
        return row[word];
    }

    /**
     * Returns a copy of one topic's word probabilities.
     *
     * <p>The returned array is freshly allocated, so callers may modify it without affecting
     * this model.
     *
     * @param topic topic index
     * @return a new array whose {@code [word]} entry is the probability of that word under
     *     {@code topic}
     */
    public double[] wordProbabilities(int topic) {
        final double[] row = mTopicWordProbs[topic];
        return row.clone();
    }

    /**
     * Draws Gibbs samples of topic assignments for one document's tokens under this model's
     * fixed topic-word probabilities.
     *
     * <p>Every token starts with a topic drawn uniformly at random. Each epoch resamples every
     * token's topic in turn, with weight {@code (count(topic) + documentTopicPrior()) *
     * wordProbability(topic, word)}, where {@code count} excludes the token being resampled.
     * A copy of the assignment is recorded after epochs {@code burnin}, {@code burnin +
     * sampleLag}, ..., {@code burnin + (numSamples - 1) * sampleLag} (epochs counted from 0). So
     * at least one pass always runs and consecutive samples are exactly {@code sampleLag}
     * epochs apart.
     *
     * @param tokens word identifiers, each in {@code [0, numWords())}
     * @param numSamples number of samples to return; at least 1
     * @param burnin number of epochs before the first recorded sample; non-negative
     * @param sampleLag number of epochs between recorded samples; at least 1
     * @param random source of randomness
     * @return {@code samples[s][i]} is the topic assigned to {@code tokens[i]} in sample
     *     {@code s}
     * @throws IllegalArgumentException if a count argument is out of range or a token
     *     identifier is outside {@code [0, numWords())}
     */
    public short[][] sampleTopics(int[] tokens, int numSamples, int burnin, int sampleLag, Random random) {
        if (burnin < 0) {
            String msg = "Burnin period must be non-negative. Found burnin=" + burnin;
            throw new IllegalArgumentException(msg);
        }
        if (numSamples < 1) {
            String msg = "Number of samples must be at least 1. Found numSamples=" + numSamples;
            throw new IllegalArgumentException(msg);
        }
        if (sampleLag < 1) {
            String msg = "Sample lag must be at least 1. Found sampleLag=" + sampleLag;
            throw new IllegalArgumentException(msg);
        }
        int numWords = this.numWords();
        for (int token = 0; token < tokens.length; ++token) {
            int word = tokens[token];
            if (word < 0 || word >= numWords) {
                String msg = "Token identifiers must be between 0 and numWords()-1=" + (numWords - 1) + ". Found tokens[" + token + "]=" + word;
                throw new IllegalArgumentException(msg);
            }
        }
        double docTopicPrior = this.documentTopicPrior();
        int numTokens = tokens.length;
        int numTopics = this.numTopics();
        int[] topicCount = new int[numTopics];
        // Working assignment; it is copied into the result only on sample epochs.
        short[] currentSample = new short[numTokens];
        for (int token = 0; token < numTokens; ++token) {
            int randomTopic = random.nextInt(numTopics);
            ++topicCount[randomTopic];
            currentSample[token] = (short) randomTopic;
        }
        short[][] samples = new short[numSamples][];
        int numSaved = 0;
        double[] topicDistro = new double[numTopics];
        // long arithmetic so huge burnin/lag/sample combinations cannot overflow into a
        // negative bound that would silently skip sampling.
        long lastEpoch = (long) burnin + (long) sampleLag * (long) (numSamples - 1);
        for (long epoch = 0L; epoch <= lastEpoch; ++epoch) {
            for (int token = 0; token < numTokens; ++token) {
                int word = tokens[token];
                // Exclude this token's own assignment so it is conditioned only on the others.
                --topicCount[currentSample[token]];
                for (int topic = 0; topic < numTopics; ++topic) {
                    double weight = ((double) topicCount[topic] + docTopicPrior) * this.wordProbability(topic, word);
                    // topicDistro holds running (cumulative, unnormalized) weights.
                    topicDistro[topic] = topic == 0 ? weight : weight + topicDistro[topic - 1];
                }
                int sampledTopic = Statistics.sample(topicDistro, random);
                ++topicCount[sampledTopic];
                currentSample[token] = (short) sampledTopic;
            }
            if (epoch >= burnin && (epoch - burnin) % sampleLag == 0) {
                samples[numSaved++] = currentSample.clone();
            }
        }
        return samples;
    }

    /**
     * Estimates a document's topic proportions from Gibbs samples of its token topics.
     *
     * <p>Runs {@link #sampleTopics} with the given arguments and returns, for each topic, the
     * fraction of all sampled token-topic assignments (pooled over every sample) that chose it.
     * NOTE(review): despite the name, this is an empirical average of sampled assignments; it
     * does not add the document-topic prior.
     *
     * <p>If {@code tokens} is empty there is no evidence, and the uniform distribution
     * {@code 1 / numTopics()} is returned rather than NaN for every topic.
     *
     * @param tokens word identifiers, each in {@code [0, numWords())}
     * @param numSamples number of samples to draw; at least 1
     * @param burnin number of epochs before the first sample; non-negative
     * @param sampleLag number of epochs between samples; at least 1
     * @param random source of randomness
     * @return an array of length {@code numTopics()} whose entries sum to 1
     * @throws IllegalArgumentException if {@code sampleTopics} rejects the arguments
     */
    public double[] mapTopicEstimate(int[] tokens, int numSamples, int burnin, int sampleLag, Random random) {
        short[][] sampleTopics = this.sampleTopics(tokens, numSamples, burnin, sampleLag, random);
        int numTopics = this.numTopics();
        int[] counts = new int[numTopics];
        long totalCount = 0L;
        for (short[] topics : sampleTopics) {
            for (short topic : topics) {
                ++counts[topic];
                ++totalCount;
            }
        }
        double[] result = new double[numTopics];
        if (totalCount == 0L) {
            // Avoid 0/0: with no tokens, fall back to the uniform distribution.
            for (int topic = 0; topic < numTopics; ++topic) {
                result[topic] = 1.0 / numTopics;
            }
            return result;
        }
        for (int topic = 0; topic < numTopics; ++topic) {
            result[topic] = (double) counts[topic] / (double) totalCount;
        }
        return result;
    }

    /**
     * Runs a collapsed Gibbs sampler for LDA over integer-coded documents.
     *
     * <p>Every token starts with a uniformly random topic. Each epoch revisits every token,
     * removes its current assignment from the counts, and resamples its topic with weight
     * {@code (docTopic + docTopicPrior) * (wordTopic + topicWordPrior) / (topicTotal + numWords *
     * topicWordPrior)}. The vocabulary size {@code numWords} is the largest word identifier
     * plus one.
     *
     * <p>After epochs {@code burninEpochs}, {@code burninEpochs + sampleLag}, ..., {@code
     * burninEpochs + sampleLag * (numSamples - 1)} (counted from 0), a {@link GibbsSample} is
     * built and passed to {@code handler} if it is non-null. The sample from the last of these
     * epochs is returned.
     *
     * <p>NOTE(review): each {@code GibbsSample} wraps this method's working arrays without
     * copying them, so its assignments and counts keep changing as sampling continues. Read
     * what you need inside {@code handler.handle}.
     *
     * @param docWords word identifiers indexed {@code [doc][token]}; all non-negative
     * @param numTopics number of topics; at least 1
     * @param docTopicPrior prior count per topic in each document; finite and positive
     * @param topicWordPrior prior count per word in each topic; finite and positive
     * @param burninEpochs epochs before the first sample; non-negative
     * @param sampleLag epochs between samples; at least 1
     * @param numSamples number of samples produced; at least 1
     * @param random source of randomness
     * @param handler receives every sample; may be {@code null}
     * @return the sample from the final epoch
     * @throws IllegalArgumentException if {@code validateInputs} rejects the arguments, or the
     *     final epoch index would exceed {@code Integer.MAX_VALUE}
     */
    public static GibbsSample gibbsSampler(int[][] docWords, short numTopics, double docTopicPrior, double topicWordPrior, int burninEpochs, int sampleLag, int numSamples, Random random, ObjectHandler<GibbsSample> handler) {
        LatentDirichletAllocation.validateInputs(docWords, numTopics, docTopicPrior, topicWordPrior, burninEpochs, sampleLag, numSamples);
        long lastEpoch = (long) burninEpochs + (long) sampleLag * (long) (numSamples - 1);
        if (lastEpoch > Integer.MAX_VALUE) {
            // Previously this overflowed int, skipped every epoch and hit the "unreachable" throw.
            String msg = "Total number of epochs is too large. Found burninEpochs=" + burninEpochs + " sampleLag=" + sampleLag + " numSamples=" + numSamples;
            throw new IllegalArgumentException(msg);
        }
        int numEpochs = (int) lastEpoch;
        int numDocs = docWords.length;
        int numWords = LatentDirichletAllocation.max(docWords) + 1;
        int numTokens = 0;
        short[][] currentSample = new short[numDocs][];
        for (int doc = 0; doc < numDocs; ++doc) {
            numTokens += docWords[doc].length;
            currentSample[doc] = new short[docWords[doc].length];
        }
        int[][] docTopicCount = new int[numDocs][numTopics];
        int[][] wordTopicCount = new int[numWords][numTopics];
        int[] topicTotalCount = new int[numTopics];
        // Random initial assignment, drawn doc by doc and token by token.
        for (int doc = 0; doc < numDocs; ++doc) {
            int[] words = docWords[doc];
            for (int tok = 0; tok < words.length; ++tok) {
                int topic = random.nextInt(numTopics);
                currentSample[doc][tok] = (short) topic;
                ++docTopicCount[doc][topic];
                ++wordTopicCount[words[tok]][topic];
                ++topicTotalCount[topic];
            }
        }
        double numWordsTimesTopicWordPrior = (double) numWords * topicWordPrior;
        double[] topicDistro = new double[numTopics];
        for (int epoch = 0; epoch <= numEpochs; ++epoch) {
            int numChangedTopics = 0;
            for (int doc = 0; doc < numDocs; ++doc) {
                int[] docWordsDoc = docWords[doc];
                short[] currentSampleDoc = currentSample[doc];
                int[] docTopicCountDoc = docTopicCount[doc];
                for (int tok = 0; tok < docWordsDoc.length; ++tok) {
                    int[] wordTopicCountWord = wordTopicCount[docWordsDoc[tok]];
                    int currentTopic = currentSampleDoc[tok];
                    // Remove this token's assignment so it is conditioned only on all other tokens.
                    // (double)(c - 1) equals (double)c - 1.0 exactly, so the weights are unchanged
                    // from special-casing the current topic inline.
                    --docTopicCountDoc[currentTopic];
                    --wordTopicCountWord[currentTopic];
                    --topicTotalCount[currentTopic];
                    for (int topic = 0; topic < numTopics; ++topic) {
                        double weight = ((double) docTopicCountDoc[topic] + docTopicPrior) * ((double) wordTopicCountWord[topic] + topicWordPrior) / ((double) topicTotalCount[topic] + numWordsTimesTopicWordPrior);
                        // Running (cumulative, unnormalized) weights, as passed to Statistics.sample.
                        topicDistro[topic] = topic == 0 ? weight : weight + topicDistro[topic - 1];
                    }
                    int sampledTopic = Statistics.sample(topicDistro, random);
                    currentSampleDoc[tok] = (short) sampledTopic;
                    ++docTopicCountDoc[sampledTopic];
                    ++wordTopicCountWord[sampledTopic];
                    ++topicTotalCount[sampledTopic];
                    if (sampledTopic != currentTopic) {
                        ++numChangedTopics;
                    }
                }
            }
            if (epoch < burninEpochs || (epoch - burninEpochs) % sampleLag != 0) {
                continue;
            }
            GibbsSample sample = new GibbsSample(epoch, currentSample, docWords, docTopicPrior, topicWordPrior, docTopicCount, wordTopicCount, topicTotalCount, numChangedTopics, numWords, numTokens);
            if (handler != null) {
                handler.handle(sample);
            }
            if (epoch == numEpochs) {
                return sample;
            }
        }
        throw new IllegalStateException("unreachable in practice because of return if epoch==numEpochs");
    }

    /**
     * Converts a collection of texts into arrays of symbol identifiers.
     *
     * <p>Every text is tokenized with {@code tokenizerFactory} and token occurrences are counted
     * over the whole collection. The counts are pruned with {@code minCount} (presumably dropping
     * tokens seen fewer than {@code minCount} times; see {@code ObjectToCounterMap.prune}). Every
     * surviving token is added to {@code symbolTable}. Each text is then converted with
     * {@link #tokenizeDocument}, which keeps every token that has an identifier in the table,
     * including symbols that were already present before this call.
     *
     * @param texts the documents to tokenize
     * @param tokenizerFactory tokenizer used both for counting and for conversion
     * @param symbolTable table that receives the surviving tokens (mutated)
     * @param minCount pruning threshold passed to {@code ObjectToCounterMap.prune}
     * @return {@code result[i]} holds the identifiers of {@code texts[i]}'s retained tokens, in
     *     order
     */
    public static int[][] tokenizeDocuments(CharSequence[] texts, TokenizerFactory tokenizerFactory, SymbolTable symbolTable, int minCount) {
        ObjectToCounterMap<String> tokenCounter = new ObjectToCounterMap<String>();
        for (CharSequence text : texts) {
            char[] cs = Strings.toCharArray(text);
            Tokenizer tokenizer = tokenizerFactory.tokenizer(cs, 0, cs.length);
            for (String token : tokenizer) {
                tokenCounter.increment(token);
            }
        }
        tokenCounter.prune(minCount);
        // Must be parameterized: iterating a raw Set yields Object, which cannot be bound to the
        // String loop variable below.
        Set<String> tokenSet = tokenCounter.keySet();
        for (String token : tokenSet) {
            symbolTable.getOrAddSymbol(token);
        }
        int[][] docTokenId = new int[texts.length][];
        for (int i = 0; i < docTokenId.length; ++i) {
            docTokenId[i] = LatentDirichletAllocation.tokenizeDocument(texts[i], tokenizerFactory, symbolTable);
        }
        return docTokenId;
    }

    /**
     * Converts one text into the identifiers of its tokens that are known to a symbol table.
     *
     * <p>Tokens for which {@code symbolTable.symbolToID} returns a negative value are skipped;
     * the table is only read, never modified.
     *
     * @param text the document to tokenize
     * @param tokenizerFactory tokenizer to apply to the text
     * @param symbolTable table used to look up identifiers
     * @return the identifiers of the known tokens, in text order
     */
    public static int[] tokenizeDocument(CharSequence text, TokenizerFactory tokenizerFactory, SymbolTable symbolTable) {
        char[] chars = Strings.toCharArray(text);
        Tokenizer tokens = tokenizerFactory.tokenizer(chars, 0, chars.length);
        ArrayList<Integer> knownIds = new ArrayList<Integer>();
        for (String token : tokens) {
            int symbolId = symbolTable.symbolToID(token);
            if (symbolId >= 0) {
                knownIds.add(Integer.valueOf(symbolId));
            }
        }
        int[] result = new int[knownIds.size()];
        int next = 0;
        for (Integer symbolId : knownIds) {
            result[next++] = symbolId.intValue();
        }
        return result;
    }

    /**
     * Returns the largest value in a two-dimensional int array, never less than 0.
     *
     * <p>Used by {@code gibbsSampler} to size the vocabulary as {@code max(docWords) + 1}.
     *
     * @param xs the arrays to scan (rows may have any length)
     * @return the maximum element, or 0 if there are no elements or all are at most 0
     */
    static int max(int[][] xs) {
        int best = 0;
        for (int[] row : xs) {
            for (int value : row) {
                if (value > best) {
                    best = value;
                }
            }
        }
        return best;
    }

    /**
     * Returns {@code |x - y| / (|x| + |y|)}, a scale-free difference in [0, 1] for finite
     * inputs.
     *
     * <p>Equal inputs return 0.0. Without this guard, {@code x == y == 0} (and equal infinities)
     * would evaluate to 0/0 or inf-inf and yield NaN.
     *
     * @param x first value
     * @param y second value
     * @return the relative difference; NaN only if an input is NaN
     */
    static double relativeDifference(double x, double y) {
        if (x == y) {
            return 0.0;
        }
        // Fully qualified because com.aliasi.util.Math is imported as Math in this file.
        return java.lang.Math.abs(x - y) / (java.lang.Math.abs(x) + java.lang.Math.abs(y));
    }

    /**
     * Validates the arguments of {@code gibbsSampler}, throwing on the first violation.
     *
     * <p>Both priors must be strictly positive. A zero prior lets every sampling weight become
     * zero, or 0/0 for a topic with no other tokens, and {@code GibbsSample.lda()} would reject a
     * zero document-topic prior.
     *
     * @param docWords word identifiers indexed {@code [doc][token]}; every entry must be
     *     non-negative
     * @param numTopics must be at least 1
     * @param docTopicPrior must be finite and strictly positive
     * @param topicWordPrior must be finite and strictly positive
     * @param burninEpochs must be non-negative
     * @param sampleLag must be at least 1
     * @param numSamples must be at least 1
     * @throws IllegalArgumentException describing the first violated condition
     */
    static void validateInputs(int[][] docWords, short numTopics, double docTopicPrior, double topicWordPrior, int burninEpochs, int sampleLag, int numSamples) {
        for (int doc = 0; doc < docWords.length; ++doc) {
            for (int tok = 0; tok < docWords[doc].length; ++tok) {
                if (docWords[doc][tok] >= 0) {
                    continue;
                }
                String msg = "All tokens must have non-negative IDs. Found docWords[" + doc + "][" + tok + "]=" + docWords[doc][tok];
                throw new IllegalArgumentException(msg);
            }
        }
        if (numTopics < 1) {
            String msg = "Num topics must be positive. Found numTopics=" + numTopics;
            throw new IllegalArgumentException(msg);
        }
        // !(p > 0.0) rejects zero, negatives and NaN in one test.
        if (!(docTopicPrior > 0.0) || Double.isInfinite(docTopicPrior)) {
            String msg = "Document-topic prior must be finite and positive. Found docTopicPrior=" + docTopicPrior;
            throw new IllegalArgumentException(msg);
        }
        if (!(topicWordPrior > 0.0) || Double.isInfinite(topicWordPrior)) {
            String msg = "Topic-word prior must be finite and positive. Found topicWordPrior=" + topicWordPrior;
            throw new IllegalArgumentException(msg);
        }
        if (burninEpochs < 0) {
            String msg = "Number of burnin epochs must be non-negative. Found burninEpochs=" + burninEpochs;
            throw new IllegalArgumentException(msg);
        }
        if (sampleLag < 1) {
            String msg = "Sample lag must be positive. Found sampleLag=" + sampleLag;
            throw new IllegalArgumentException(msg);
        }
        if (numSamples < 1) {
            String msg = "Number of samples must be positive. Found numSamples=" + numSamples;
            throw new IllegalArgumentException(msg);
        }
    }

    /**
     * The state of the collapsed Gibbs sampler in {@code gibbsSampler} after one epoch, plus
     * smoothed estimates derived from its counts.
     *
     * <p>The constructor stores the arrays it receives by reference and does not copy them.
     * NOTE(review): {@code gibbsSampler} passes its working arrays here and keeps updating
     * them, so a retained instance reflects later epochs rather than a frozen snapshot.
     */
    public static class GibbsSample {
        /** Zero-based epoch index at which this sample was built. */
        private final int mEpoch;
        /** Sampled topic of every token, indexed {@code [doc][token]}. */
        private final short[][] mTopicSample;
        /** Word identifier of every token, indexed {@code [doc][token]}. */
        private final int[][] mDocWords;
        /** Prior count added per topic in document-topic estimates. */
        private final double mDocTopicPrior;
        /** Prior count added per word in topic-word estimates. */
        private final double mTopicWordPrior;
        /** Token counts indexed {@code [doc][topic]}. */
        private final int[][] mDocTopicCount;
        /** Token counts indexed {@code [word][topic]} (word-major layout). */
        private final int[][] mWordTopicCount;
        /** Total token count of each topic; its length is the number of topics. */
        private final int[] mTopicCount;
        /** Tokens whose topic changed during the epoch that produced this sample. */
        private final int mNumChangedTopics;
        /** Vocabulary size. */
        private final int mNumWords;
        /** Total number of tokens across all documents. */
        private final int mNumTokens;

        GibbsSample(int epoch, short[][] topicSample, int[][] docWords, double docTopicPrior, double topicWordPrior, int[][] docTopicCount, int[][] wordTopicCount, int[] topicCount, int numChangedTopics, int numWords, int numTokens) {
            mEpoch = epoch;
            mNumChangedTopics = numChangedTopics;
            mNumWords = numWords;
            mNumTokens = numTokens;
            mDocTopicPrior = docTopicPrior;
            mTopicWordPrior = topicWordPrior;
            mTopicSample = topicSample;
            mDocWords = docWords;
            mDocTopicCount = docTopicCount;
            mWordTopicCount = wordTopicCount;
            mTopicCount = topicCount;
        }

        /** @return the zero-based epoch index at which this sample was built */
        public int epoch() {
            return mEpoch;
        }

        /** @return the number of documents */
        public int numDocuments() {
            return mDocWords.length;
        }

        /** @return the vocabulary size */
        public int numWords() {
            return mNumWords;
        }

        /** @return the total number of tokens across all documents */
        public int numTokens() {
            return mNumTokens;
        }

        /** @return the number of topics */
        public int numTopics() {
            return mTopicCount.length;
        }

        /** @return the topic currently assigned to token {@code token} of document {@code doc} */
        public short topicSample(int doc, int token) {
            final short[] docTopics = mTopicSample[doc];
            return docTopics[token];
        }

        /** @return the word identifier of token {@code token} of document {@code doc} */
        public int word(int doc, int token) {
            final int[] docWords = mDocWords[doc];
            return docWords[token];
        }

        /** @return the document-topic prior */
        public double documentTopicPrior() {
            return mDocTopicPrior;
        }

        /** @return the topic-word prior */
        public double topicWordPrior() {
            return mTopicWordPrior;
        }

        /** @return the number of tokens in document {@code doc} assigned to {@code topic} */
        public int documentTopicCount(int doc, int topic) {
            final int[] docCounts = mDocTopicCount[doc];
            return docCounts[topic];
        }

        /** @return the number of tokens in document {@code doc} */
        public int documentLength(int doc) {
            final int[] docWords = mDocWords[doc];
            return docWords.length;
        }

        /** @return the number of occurrences of {@code word} assigned to {@code topic} */
        public int topicWordCount(int topic, int word) {
            final int[] wordCounts = mWordTopicCount[word];
            return wordCounts[topic];
        }

        /** @return the total number of tokens assigned to {@code topic} */
        public int topicCount(int topic) {
            return mTopicCount[topic];
        }

        /** @return how many tokens changed topic during the epoch that produced this sample */
        public int numChangedTopics() {
            return mNumChangedTopics;
        }

        /**
         * Returns the smoothed estimate {@code (count(topic, word) + topicWordPrior) /
         * (count(topic) + numWords * topicWordPrior)}.
         */
        public double topicWordProb(int topic, int word) {
            double smoothedCount = mWordTopicCount[word][topic] + mTopicWordPrior;
            double smoothedTotal = mTopicCount[topic] + mNumWords * mTopicWordPrior;
            return smoothedCount / smoothedTotal;
        }

        /** @return the total number of occurrences of {@code word}, summed over all topics */
        public int wordCount(int word) {
            int total = 0;
            for (int topic = 0; topic < mTopicCount.length; ++topic) {
                total += mWordTopicCount[word][topic];
            }
            return total;
        }

        /**
         * Returns the smoothed estimate {@code (count(doc, topic) + documentTopicPrior) /
         * (length(doc) + numTopics * documentTopicPrior)}.
         */
        public double documentTopicProb(int doc, int topic) {
            double smoothedCount = mDocTopicCount[doc][topic] + mDocTopicPrior;
            double smoothedLength = mDocWords[doc].length + mTopicCount.length * mDocTopicPrior;
            return smoothedCount / smoothedLength;
        }

        /**
         * Returns the base-2 log probability of every token, where each token's probability is
         * {@code sum over topics of topicWordProb(topic, word) * documentTopicProb(doc, topic)}.
         */
        public double corpusLog2Probability() {
            double total = 0.0;
            for (int doc = 0; doc < mDocWords.length; ++doc) {
                for (int word : mDocWords[doc]) {
                    double marginal = 0.0;
                    for (int topic = 0; topic < mTopicCount.length; ++topic) {
                        marginal += topicWordProb(topic, word) * documentTopicProb(doc, topic);
                    }
                    total += Math.log2(marginal);
                }
            }
            return total;
        }

        /**
         * Builds a model whose topic-word probabilities are the smoothed estimates from
         * {@link #topicWordProb}, with this sample's document-topic prior.
         */
        public LatentDirichletAllocation lda() {
            int topics = mTopicCount.length;
            double[][] probs = new double[topics][mNumWords];
            for (int topic = 0; topic < topics; ++topic) {
                double denominator = (double) mTopicCount[topic] + (double) mNumWords * mTopicWordPrior;
                double[] row = probs[topic];
                for (int word = 0; word < mNumWords; ++word) {
                    row[word] = ((double) mWordTopicCount[word][topic] + mTopicWordPrior) / denominator;
                }
            }
            return new LatentDirichletAllocation(mDocTopicPrior, probs);
        }
    }
}

