/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.hmm;

import com.aliasi.symbol.SymbolTable;
import com.aliasi.util.Math;
import com.aliasi.util.ScoredObject;
import java.util.Arrays;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * A forward-backward lattice over a sequence of tokens in which every token
 * may carry any tag from a tag symbol table.
 *
 * <p>The lattice is defined by three families of values given to the
 * constructor: start values (the forward values of the first token), end
 * values (the backward values of the last token), and per-token transition
 * values {@code transitProbs[i][fromTag][toTag]} linking token {@code i-1}
 * to token {@code i}.
 *
 * <p>To keep long inputs from underflowing, each column (token) of forward
 * and backward values is stored rescaled by a power of two so that its
 * maximum lies in {@code [0.5, 1]} (unless the column is all zero), next to
 * a base-2 exponent. The true value is {@code scaled * 2^exponent}. The
 * {@code log2*} accessors combine the two in log space and stay meaningful
 * when the linear-space accessors underflow to zero.
 */
public class TagWordLattice {
    /** {@code mTransitions[tokenIndex][sourceTagId][targetTagId]}; entry {@code [0]} is unused. */
    final double[][][] mTransitions;
    /** Rescaled forward values, indexed {@code [tokenIndex][tagId]}. */
    final double[][] mForwards;
    /** Base-2 exponent for each column of {@link #mForwards}. */
    final double[] mForwardExps;
    /** Rescaled backward values, indexed {@code [tokenIndex][tagId]}. */
    final double[][] mBacks;
    /** Base-2 exponent for each column of {@link #mBacks}. */
    final double[] mBackExps;
    final double[] mStarts;
    final double[] mEnds;
    final String[] mTokens;
    final SymbolTable mTagSymbolTable;
    double mTotal = Double.NaN;
    double mLog2Total = Double.NaN;

    /**
     * Constructs the lattice and immediately runs the forward, backward and
     * total computations over it.
     *
     * <p>The argument arrays are stored by reference, not copied. Every entry
     * of {@code startProbs} and {@code endProbs}, and every entry of
     * {@code transitProbs[i]} for {@code i >= 1}, must lie in {@code [0, 1]}.
     * {@code transitProbs[0]} is neither checked nor read. NaN entries are not
     * rejected by these checks. The start/end arrays need at least
     * {@code tagSymbolTable.numSymbols()} entries and {@code transitProbs} at
     * least {@code tokens.length} rows. Otherwise the computation fails with
     * an {@link ArrayIndexOutOfBoundsException}.
     *
     * @param tokens the tokens; the lattice has one column per token
     * @param tagSymbolTable symbol table whose ids index the tag dimension
     * @param startProbs values indexed by tag id, used as the forward values
     *        of the first token
     * @param endProbs values indexed by tag id, used as the backward values
     *        of the last token
     * @param transitProbs {@code transitProbs[i][from][to]} weights moving
     *        from tag {@code from} at token {@code i-1} to tag {@code to} at
     *        token {@code i}
     * @throws IllegalArgumentException if a checked value is below 0 or above 1
     */
    public TagWordLattice(String[] tokens, SymbolTable tagSymbolTable, double[] startProbs, double[] endProbs, double[][][] transitProbs) {
        for (int i = 0; i < startProbs.length; ++i) {
            if (startProbs[i] < 0.0 || startProbs[i] > 1.0) {
                String msg = "startProbs[" + i + "]=" + startProbs[i];
                throw new IllegalArgumentException(msg);
            }
        }
        for (int i = 0; i < endProbs.length; ++i) {
            if (endProbs[i] < 0.0 || endProbs[i] > 1.0) {
                String msg = "endProbs[" + i + "]=" + endProbs[i];
                throw new IllegalArgumentException(msg);
            }
        }
        // Row 0 is skipped: the first token has no incoming transition.
        for (int i = 1; i < transitProbs.length; ++i) {
            for (int j = 0; j < transitProbs[i].length; ++j) {
                for (int k = 0; k < transitProbs[i][j].length; ++k) {
                    if (transitProbs[i][j][k] < 0.0 || transitProbs[i][j][k] > 1.0) {
                        String msg = "transitProbs[" + i + "][" + j + "][" + k + "]=" + transitProbs[i][j][k];
                        throw new IllegalArgumentException(msg);
                    }
                }
            }
        }
        int numTags = tagSymbolTable.numSymbols();
        int numTokens = tokens.length;
        this.mStarts = startProbs;
        this.mEnds = endProbs;
        this.mTransitions = transitProbs;
        this.mTokens = tokens;
        this.mTagSymbolTable = tagSymbolTable;
        this.mForwards = new double[numTokens][numTags];
        this.mForwardExps = new double[numTokens];
        this.mBacks = new double[numTokens][numTags];
        this.mBackExps = new double[numTokens];
        this.computeAll();
    }

    /**
     * Returns the token array passed to the constructor. This is the same
     * array, not a copy.
     *
     * @return the tokens of this lattice
     */
    public String[] tokens() {
        return this.mTokens;
    }

    /**
     * Returns the tag symbol table passed to the constructor.
     *
     * @return the tag symbol table
     */
    public SymbolTable tagSymbolTable() {
        return this.mTagSymbolTable;
    }

    /**
     * Returns every tag for one token, scored by its log2 forward-backward
     * value minus {@link #log2Total()}.
     *
     * <p>Scores above 0 are clamped to 0. NaN or infinite scores are replaced
     * by {@code log2(Double.MIN_VALUE)}. The array is sorted with
     * {@code ScoredObject.REVERSE_SCORE_COMPARATOR}, presumably highest score
     * first.
     *
     * @param tokenIndex index of the token (column) to score
     * @return one scored tag per tag id, sorted
     */
    @SuppressWarnings({"unchecked", "rawtypes"}) // generic arrays cannot be created directly
    public ScoredObject<String>[] log2ConditionalTags(int tokenIndex) {
        double log2Total = this.log2Total();
        double floorLog2P = Math.log2(Double.MIN_VALUE);
        SymbolTable st = this.mTagSymbolTable;
        int numTags = st.numSymbols();
        ScoredObject[] scoredTags = new ScoredObject[numTags];
        for (int tagId = 0; tagId < numTags; ++tagId) {
            String tag = st.idToSymbol(tagId);
            double condLog2P = this.log2ForwardBackward(tokenIndex, tagId) - log2Total;
            if (condLog2P > 0.0) {
                condLog2P = 0.0;
            } else if (Double.isNaN(condLog2P) || Double.isInfinite(condLog2P)) {
                condLog2P = floorLog2P;
            }
            scoredTags[tagId] = new ScoredObject<String>(tag, condLog2P);
        }
        Arrays.sort(scoredTags, ScoredObject.REVERSE_SCORE_COMPARATOR);
        return scoredTags;
    }

    /**
     * Returns, for each token, the tag with the highest forward-backward
     * product. Ties go to the lowest tag id.
     *
     * @return one tag symbol per token
     */
    public String[] bestForwardBackward() {
        String[] bestTags = new String[this.mTokens.length];
        int numTags = this.mTagSymbolTable.numSymbols();
        for (int i = 0; i < bestTags.length; ++i) {
            // Every tag in column i shares the positive factor
            // 2^(mForwardExps[i] + mBackExps[i]). Comparing the stored scaled
            // products therefore ranks tags exactly like forwardBackward(i, t),
            // without that factor driving all values to 0 on long inputs.
            double[] forwards = this.mForwards[i];
            double[] backs = this.mBacks[i];
            int bestTagId = 0;
            double bestFB = forwards[0] * backs[0];
            for (int tagId = 1; tagId < numTags; ++tagId) {
                double fb = forwards[tagId] * backs[tagId];
                if (fb > bestFB) {
                    bestFB = fb;
                    bestTagId = tagId;
                }
            }
            bestTags[i] = this.mTagSymbolTable.idToSymbol(bestTagId);
        }
        return bestTags;
    }

    /**
     * Returns the start value for a tag.
     *
     * @param tagId tag id
     * @return {@code startProbs[tagId]} as given to the constructor
     */
    public double start(int tagId) {
        return this.mStarts[tagId];
    }

    /**
     * Returns the log2 of {@link #start(int)}.
     *
     * @param tagId tag id
     * @return log2 start value
     */
    public double log2Start(int tagId) {
        return Math.log2(this.start(tagId));
    }

    /**
     * Returns the end value for a tag.
     *
     * @param tagId tag id
     * @return {@code endProbs[tagId]} as given to the constructor
     */
    public double end(int tagId) {
        return this.mEnds[tagId];
    }

    /**
     * Returns the log2 of {@link #end(int)}.
     *
     * @param tagId tag id
     * @return log2 end value
     */
    public double log2End(int tagId) {
        return Math.log2(this.end(tagId));
    }

    /**
     * Returns the transition value into the specified token.
     *
     * @param tokenIndex index of the target token; must be positive
     * @param sourceTagId tag id at token {@code tokenIndex - 1}
     * @param targetTagId tag id at token {@code tokenIndex}
     * @return {@code transitProbs[tokenIndex][sourceTagId][targetTagId]}
     * @throws IndexOutOfBoundsException if {@code tokenIndex} is 0, or if any
     *         index is out of range for the underlying array
     */
    public double transition(int tokenIndex, int sourceTagId, int targetTagId) {
        if (tokenIndex == 0) {
            String msg = "Token index must be > 0.";
            throw new IndexOutOfBoundsException(msg);
        }
        return this.mTransitions[tokenIndex][sourceTagId][targetTagId];
    }

    /**
     * Returns the log2 of {@link #transition(int, int, int)}.
     *
     * @param tokenIndex index of the target token; must be positive
     * @param sourceTagId tag id at token {@code tokenIndex - 1}
     * @param targetTagId tag id at token {@code tokenIndex}
     * @return log2 transition value
     */
    public double log2Transitions(int tokenIndex, int sourceTagId, int targetTagId) {
        return Math.log2(this.transition(tokenIndex, sourceTagId, targetTagId));
    }

    /**
     * Returns the forward value in linear space. This may underflow to 0 on
     * long inputs; see {@link #log2Forward(int, int)}.
     *
     * @param tokenIndex token index
     * @param tagId tag id
     * @return scaled forward value times {@code 2^exponent}
     */
    public double forward(int tokenIndex, int tagId) {
        return this.mForwards[tokenIndex][tagId] * java.lang.Math.pow(2.0, this.mForwardExps[tokenIndex]);
    }

    /**
     * Returns the log2 of the forward value, computed without leaving log space.
     *
     * @param tokenIndex token index
     * @param tagId tag id
     * @return log2 forward value
     */
    public double log2Forward(int tokenIndex, int tagId) {
        return Math.log2(this.mForwards[tokenIndex][tagId]) + this.mForwardExps[tokenIndex];
    }

    /**
     * Returns the backward value in linear space. This may underflow to 0 on
     * long inputs; see {@link #log2Backward(int, int)}.
     *
     * @param tokenIndex token index
     * @param tagId tag id
     * @return scaled backward value times {@code 2^exponent}
     */
    public double backward(int tokenIndex, int tagId) {
        return this.mBacks[tokenIndex][tagId] * java.lang.Math.pow(2.0, this.mBackExps[tokenIndex]);
    }

    /**
     * Returns the log2 of the backward value, computed without leaving log space.
     *
     * @param tokenIndex token index
     * @param tagId tag id
     * @return log2 backward value
     */
    public double log2Backward(int tokenIndex, int tagId) {
        return Math.log2(this.mBacks[tokenIndex][tagId]) + this.mBackExps[tokenIndex];
    }

    /**
     * Returns {@code forward(tokenIndex, tagId) * backward(tokenIndex, tagId)}
     * in linear space. This may underflow to 0 on long inputs.
     *
     * @param tokenIndex token index
     * @param tagId tag id
     * @return forward-backward product
     */
    public double forwardBackward(int tokenIndex, int tagId) {
        return this.forward(tokenIndex, tagId) * this.backward(tokenIndex, tagId);
    }

    /**
     * Returns the log2 of the forward-backward product.
     *
     * @param tokenIndex token index
     * @param tagId tag id
     * @return log2 forward plus log2 backward
     */
    public double log2ForwardBackward(int tokenIndex, int tagId) {
        return this.log2Forward(tokenIndex, tagId) + this.log2Backward(tokenIndex, tagId);
    }

    /**
     * Returns the lattice total in linear space (1.0 for an empty token
     * sequence). This may underflow to 0 on long inputs.
     *
     * @return total of the lattice
     */
    public double total() {
        return this.mTotal;
    }

    /**
     * Returns the log2 of the lattice total (0.0 for an empty token sequence).
     *
     * @return log2 total of the lattice
     */
    public double log2Total() {
        return this.mLog2Total;
    }

    /** Runs the forward pass, then the backward pass, then computes the total. */
    final void computeAll() {
        this.computeForward();
        this.computeBackward();
        this.computeTotal();
    }

    /**
     * Sets the total to the sum over tags of forward times backward at token 0.
     * By convention the total is 1 when there are no tokens.
     */
    private void computeTotal() {
        if (this.mForwards.length == 0) {
            this.mTotal = 1.0;
            this.mLog2Total = 0.0;
            return;
        }
        // Both columns share one exponent per side, so sum the scaled products
        // and apply the combined exponent once.
        double scaledTotal = 0.0;
        int numSymbols = this.tagSymbolTable().numSymbols();
        for (int tagId = 0; tagId < numSymbols; ++tagId) {
            scaledTotal += this.mForwards[0][tagId] * this.mBacks[0][tagId];
        }
        double exp = this.mForwardExps[0] + this.mBackExps[0];
        this.mLog2Total = Math.log2(scaledTotal) + exp;
        this.mTotal = scaledTotal * java.lang.Math.pow(2.0, exp);
    }

    /**
     * Fills {@link #mForwards} left to right. The first column is the start
     * values. Each later column sums the previous column times the incoming
     * transitions. Every column is rescaled as it is produced.
     */
    private void computeForward() {
        if (this.mForwards.length == 0) {
            return;
        }
        int numSymbols = this.tagSymbolTable().numSymbols();
        double[] firstColumn = this.mForwards[0];
        for (int tagId = 0; tagId < numSymbols; ++tagId) {
            // Defensive clamp; the constructor already rejects negative starts.
            if (this.mStarts[tagId] < 0.0) {
                this.mStarts[tagId] = 0.0;
            }
            firstColumn[tagId] = this.mStarts[tagId];
        }
        this.mForwardExps[0] = log2ScaleExp(firstColumn);
        int numToks = this.mTokens.length;
        for (int tokenId = 1; tokenId < numToks; ++tokenId) {
            double[] prevForwards = this.mForwards[tokenId - 1];
            double[] forwards = this.mForwards[tokenId];
            double[][] transits = this.mTransitions[tokenId];
            for (int tagId = 0; tagId < numSymbols; ++tagId) {
                double f = 0.0;
                for (int prevTagId = 0; prevTagId < numSymbols; ++prevTagId) {
                    f += prevForwards[prevTagId] * transits[prevTagId][tagId];
                }
                forwards[tagId] = f;
            }
            // The previous column was scaled by 2^-prevExp, so the exponents accumulate.
            this.mForwardExps[tokenId] = log2ScaleExp(forwards) + this.mForwardExps[tokenId - 1];
        }
    }

    /**
     * Fills {@link #mBacks} right to left. The last column is the end values.
     * Each earlier column sums the next column times the outgoing transitions.
     * Every column is rescaled as it is produced.
     */
    private void computeBackward() {
        if (this.mBacks.length == 0) {
            return;
        }
        int numSymbols = this.tagSymbolTable().numSymbols();
        int lastTok = this.mTokens.length - 1;
        double[] lastColumn = this.mBacks[lastTok];
        for (int tagId = 0; tagId < numSymbols; ++tagId) {
            lastColumn[tagId] = this.mEnds[tagId];
        }
        this.mBackExps[lastTok] = log2ScaleExp(lastColumn);
        for (int tokenId = lastTok - 1; tokenId >= 0; --tokenId) {
            double[] nextBacks = this.mBacks[tokenId + 1];
            double[] backs = this.mBacks[tokenId];
            double[][] transits = this.mTransitions[tokenId + 1];
            for (int tagId = 0; tagId < numSymbols; ++tagId) {
                double b = 0.0;
                for (int nextTagId = 0; nextTagId < numSymbols; ++nextTagId) {
                    b += nextBacks[nextTagId] * transits[tagId][nextTagId];
                }
                backs[tagId] = b;
            }
            this.mBackExps[tokenId] = log2ScaleExp(backs) + this.mBackExps[tokenId + 1];
        }
    }

    /**
     * Rescales {@code xs} in place by a power of two so that its maximum lies
     * in {@code [0.5, 1]}. Returns the base-2 exponent {@code e} such that each
     * original value equals its rescaled value times {@code 2^e}. This is exact
     * except for rounding when scaling down pushes a value into the subnormal
     * range.
     *
     * <p>Maxima below 0.5 are scaled up ({@code e < 0}). Maxima above 1, which
     * the forward pass can produce because it sums several products, are
     * scaled down ({@code e > 0}). An empty array, an all-zero array, or an
     * array whose first element is NaN (the max scan then stays NaN) is left
     * unchanged and 0 is returned.
     *
     * @param xs values to rescale in place
     * @return the base-2 exponent removed from {@code xs}
     * @throws IllegalArgumentException if the maximum is negative or infinite
     */
    static double log2ScaleExp(double[] xs) {
        if (xs.length == 0) {
            return 0.0;
        }
        double max = xs[0];
        for (int i = 1; i < xs.length; ++i) {
            if (max < xs[i]) {
                max = xs[i];
            }
        }
        if (max < 0.0 || Double.isInfinite(max)) {
            String msg = "Max must be >= 0 and finite. max=" + max;
            throw new IllegalArgumentException(msg);
        }
        // Number of doublings applied to xs (negative means halvings).
        int shift = 0;
        if (max > 0.0) {
            while (max < 0.5) {
                max *= 2.0;
                ++shift;
            }
            while (max > 1.0) {
                max *= 0.5;
                --shift;
            }
        }
        if (shift != 0) {
            // scalb is used instead of accumulating a 2^shift multiplier. That
            // multiplier overflows to Infinity once shift reaches 1024 (subnormal
            // maxima), which would turn zero entries into NaN (0 * Infinity).
            for (int j = 0; j < xs.length; ++j) {
                xs[j] = java.lang.Math.scalb(xs[j], shift);
            }
        }
        return -shift;
    }
}

