/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.classify;

import com.aliasi.classify.Classification;
import com.aliasi.classify.Classifier;
import com.aliasi.classify.ConditionalClassification;
import com.aliasi.classify.ConfusionMatrix;
import com.aliasi.classify.JointClassification;
import com.aliasi.classify.PrecisionRecallEvaluation;
import com.aliasi.classify.RankedClassification;
import com.aliasi.classify.ScoredClassification;
import com.aliasi.classify.ScoredPrecisionRecallEvaluation;
import com.aliasi.corpus.ClassificationHandler;
import com.aliasi.util.Collections;
import com.aliasi.util.Scored;
import java.util.ArrayList;
import java.util.HashSet;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
public class ClassifierEvaluator<E, C extends Classification>
implements ClassificationHandler<E, Classification> {
    boolean mDefectiveRanking = false;
    boolean mDefectiveScoring = false;
    boolean mDefectiveConditioning = false;
    final Classifier<E, C> mClassifier;
    private final ConfusionMatrix mConfusionMatrix;
    private int mNumCases = 0;
    final String[] mCategories;
    final HashSet mCategorySet;
    final ArrayList mReferenceCategories = new ArrayList();
    final ArrayList mClassifications = new ArrayList();
    private boolean mHasRanked = false;
    private final int[][] mRankCounts;
    private boolean mHasScored = false;
    private final ArrayList[] mScoreOutcomeLists;
    private boolean mHasConditional = false;
    private final ArrayList[] mConditionalOutcomeLists;
    private boolean mHasJoint = false;

    public ClassifierEvaluator(Classifier<E, C> classifier, String[] categories) {
        int i;
        this.mClassifier = classifier;
        this.mCategories = categories;
        this.mCategorySet = new HashSet();
        Collections.addAll(this.mCategorySet, categories);
        this.mConfusionMatrix = new ConfusionMatrix(categories);
        int len = categories.length;
        this.mRankCounts = new int[len][len];
        for (i = 0; i < len; ++i) {
            for (int j = 0; j < len; ++j) {
                this.mRankCounts[i][j] = 0;
            }
        }
        this.mScoreOutcomeLists = new ArrayList[this.numCategories()];
        for (i = 0; i < this.mScoreOutcomeLists.length; ++i) {
            this.mScoreOutcomeLists[i] = new ArrayList();
        }
        this.mConditionalOutcomeLists = new ArrayList[this.numCategories()];
        for (i = 0; i < this.mConditionalOutcomeLists.length; ++i) {
            this.mConditionalOutcomeLists[i] = new ArrayList();
        }
    }

    public Classifier<E, C> classifier() {
        return this.mClassifier;
    }

    public String[] categories() {
        return this.mCategories;
    }

    public void addCase(String referenceCategory, E input) {
        this.validateCategory(referenceCategory);
        C classification = this.mClassifier.classify(input);
        this.addClassification(referenceCategory, (Classification)classification);
    }

    @Override
    public void handle(E input, Classification classification) {
        this.addCase(classification.bestCategory(), input);
    }

    public int numCases() {
        return this.mNumCases;
    }

    public ConfusionMatrix confusionMatrix() {
        return this.mConfusionMatrix;
    }

    public boolean missingRankings() {
        return this.mDefectiveRanking;
    }

    public boolean missingScorings() {
        return this.mDefectiveScoring;
    }

    public boolean missingConditionals() {
        return this.mDefectiveScoring;
    }

    public int rankCount(String referenceCategory, int rank) {
        this.validateCategory(referenceCategory);
        int i = this.categoryToIndex(referenceCategory);
        return this.rankCount(i, rank);
    }

    public double averageRankReference() {
        double sum = 0.0;
        int count = 0;
        for (int i = 0; i < this.numCategories(); ++i) {
            for (int rank = 0; rank < this.numCategories(); ++rank) {
                int rankCount = this.mRankCounts[i][rank];
                if (rankCount == 0) continue;
                count += rankCount;
                sum += (double)(rank * rankCount);
            }
        }
        return sum / (double)count;
    }

    public double meanReciprocalRank() {
        double sum = 0.0;
        int numCases = 0;
        for (int i = 0; i < this.numCategories(); ++i) {
            for (int rank = 0; rank < this.numCategories(); ++rank) {
                int rankCount = this.mRankCounts[i][rank];
                if (rankCount == 0) continue;
                numCases += rankCount;
                sum += (double)rankCount / (1.0 + (double)rank);
            }
        }
        return sum / (double)numCases;
    }

    public double averageConditionalProbability(String refCategory, String responseCategory) {
        this.validateCategory(refCategory);
        this.validateCategory(responseCategory);
        double sum = 0.0;
        int count = 0;
        block0: for (int i = 0; i < this.mReferenceCategories.size(); ++i) {
            if (!this.mReferenceCategories.get(i).equals(refCategory)) continue;
            ConditionalClassification c = (ConditionalClassification)this.mClassifications.get(i);
            for (int rank = 0; rank < c.size(); ++rank) {
                if (!c.category(rank).equals(responseCategory)) continue;
                sum += c.conditionalProbability(rank);
                ++count;
                continue block0;
            }
        }
        return sum / (double)count;
    }

    public double averageLog2JointProbability(String refCategory, String responseCategory) {
        this.validateCategory(refCategory);
        this.validateCategory(responseCategory);
        double sum = 0.0;
        int count = 0;
        block0: for (int i = 0; i < this.mReferenceCategories.size(); ++i) {
            if (!this.mReferenceCategories.get(i).equals(refCategory)) continue;
            JointClassification c = (JointClassification)this.mClassifications.get(i);
            for (int rank = 0; rank < c.size(); ++rank) {
                if (!c.category(rank).equals(responseCategory)) continue;
                sum += c.jointLog2Probability(rank);
                ++count;
                continue block0;
            }
        }
        return sum / (double)count;
    }

    public double averageScoreReference() {
        double sum = 0.0;
        block0: for (int i = 0; i < this.mReferenceCategories.size(); ++i) {
            String refCategory = this.mReferenceCategories.get(i).toString();
            ScoredClassification c = (ScoredClassification)this.mClassifications.get(i);
            for (int rank = 0; rank < c.size(); ++rank) {
                if (!c.category(rank).equals(refCategory)) continue;
                sum += c.score(rank);
                continue block0;
            }
        }
        return sum / (double)this.mReferenceCategories.size();
    }

    public double averageConditionalProbabilityReference() {
        double sum = 0.0;
        block0: for (int i = 0; i < this.mReferenceCategories.size(); ++i) {
            String refCategory = this.mReferenceCategories.get(i).toString();
            ConditionalClassification c = (ConditionalClassification)this.mClassifications.get(i);
            for (int rank = 0; rank < c.size(); ++rank) {
                if (!c.category(rank).equals(refCategory)) continue;
                sum += c.conditionalProbability(rank);
                continue block0;
            }
        }
        return sum / (double)this.mReferenceCategories.size();
    }

    public double averageLog2JointProbabilityReference() {
        double sum = 0.0;
        block0: for (int i = 0; i < this.mReferenceCategories.size(); ++i) {
            String refCategory = this.mReferenceCategories.get(i).toString();
            JointClassification c = (JointClassification)this.mClassifications.get(i);
            for (int rank = 0; rank < c.size(); ++rank) {
                if (!c.category(rank).equals(refCategory)) continue;
                sum += c.jointLog2Probability(rank);
                continue block0;
            }
        }
        return sum / (double)this.mReferenceCategories.size();
    }

    public double averageScore(String refCategory, String responseCategory) {
        this.validateCategory(refCategory);
        this.validateCategory(responseCategory);
        double sum = 0.0;
        int count = 0;
        block0: for (int i = 0; i < this.mReferenceCategories.size(); ++i) {
            if (!this.mReferenceCategories.get(i).equals(refCategory)) continue;
            ScoredClassification c = (ScoredClassification)this.mClassifications.get(i);
            for (int rank = 0; rank < c.size(); ++rank) {
                if (!c.category(rank).equals(responseCategory)) continue;
                sum += c.score(rank);
                ++count;
                continue block0;
            }
        }
        return sum / (double)count;
    }

    public double averageRank(String refCategory, String responseCategory) {
        this.validateCategory(refCategory);
        this.validateCategory(responseCategory);
        double sum = 0.0;
        int count = 0;
        for (int i = 0; i < this.mReferenceCategories.size(); ++i) {
            if (!this.mReferenceCategories.get(i).equals(refCategory)) continue;
            RankedClassification rankedClassification = (RankedClassification)this.mClassifications.get(i);
            int rank = this.getRank(rankedClassification, responseCategory);
            sum += (double)rank;
            ++count;
        }
        return sum / (double)count;
    }

    int getRank(RankedClassification classification, String responseCategory) {
        for (int rank = 0; rank < classification.size(); ++rank) {
            if (!classification.category(rank).equals(responseCategory)) continue;
            return rank;
        }
        return this.mCategories.length - 1;
    }

    public ScoredPrecisionRecallEvaluation scoredOneVersusAll(String refCategory) {
        this.validateCategory(refCategory);
        return this.scoredOneVersusAll(this.mScoreOutcomeLists, this.categoryToIndex(refCategory));
    }

    public ScoredPrecisionRecallEvaluation conditionalOneVersusAll(String refCategory) {
        this.validateCategory(refCategory);
        return this.scoredOneVersusAll(this.mConditionalOutcomeLists, this.categoryToIndex(refCategory));
    }

    public PrecisionRecallEvaluation oneVersusAll(String refCategory) {
        this.validateCategory(refCategory);
        PrecisionRecallEvaluation prEval = new PrecisionRecallEvaluation();
        int numCases = this.mReferenceCategories.size();
        for (int i = 0; i < numCases; ++i) {
            Object caseRefCategory = this.mReferenceCategories.get(i);
            Classification response = (Classification)this.mClassifications.get(i);
            String caseResponseCategory = response.bestCategory();
            boolean inRef = caseRefCategory.equals(refCategory);
            boolean inResp = caseResponseCategory.equals(refCategory);
            prEval.addCase(inRef, inResp);
        }
        return prEval;
    }

    private ScoredPrecisionRecallEvaluation scoredOneVersusAll(ArrayList[] outcomeLists, int categoryIndex) {
        ScoredPrecisionRecallEvaluation eval = new ScoredPrecisionRecallEvaluation();
        ArrayList responseList = outcomeLists[categoryIndex];
        for (int i = 0; i < responseList.size(); ++i) {
            ScoreOutcome outcome = (ScoreOutcome)responseList.get(i);
            eval.addCase(outcome.mOutcome, outcome.mScore);
        }
        return eval;
    }

    public String toString() {
        StringBuffer sb = new StringBuffer();
        sb.append("CLASSIFIER EVALUATION\n");
        this.mConfusionMatrix.toStringGlobal(sb);
        if (this.mHasRanked) {
            sb.append("Average Reference Rank=" + this.averageRankReference() + "\n");
        }
        if (this.mHasScored) {
            sb.append("Average Score Reference=" + this.averageScoreReference() + "\n");
        }
        if (this.mHasConditional) {
            sb.append("Average Conditional Probability Reference=" + this.averageConditionalProbabilityReference() + "\n");
        }
        if (this.mHasJoint) {
            sb.append("Average Log2 Joint Probability Reference=" + this.averageLog2JointProbabilityReference() + "\n");
        }
        sb.append("ONE VERSUS ALL EVALUATIONS BY CATEGORY\n");
        for (int i = 0; i < this.categories().length; ++i) {
            int j;
            String category = this.categories()[i];
            sb.append("\nCATEGORY[" + i + "]=" + category + "\n");
            sb.append("First-Best Precision/Recall Evaluation\n");
            sb.append(this.oneVersusAll(category));
            sb.append("\n");
            if (this.mHasRanked) {
                sb.append("Rank Histogram=\n");
                this.appendCategoryLine(sb);
                for (int rank = 0; rank < this.numCategories(); ++rank) {
                    if (rank > 0) {
                        sb.append(',');
                    }
                    sb.append(this.mRankCounts[i][rank]);
                }
                sb.append("\n");
                sb.append("Average Rank Histogram=\n");
                this.appendCategoryLine(sb);
                for (j = 0; j < this.numCategories(); ++j) {
                    if (j > 0) {
                        sb.append(',');
                    }
                    sb.append(this.averageRank(category, this.categories()[j]));
                }
                sb.append("\n");
            }
            if (this.mHasScored) {
                sb.append("Scored One Versus All\n");
                sb.append(this.scoredOneVersusAll(category).toString() + "\n");
                sb.append("Average Score Histogram=\n");
                this.appendCategoryLine(sb);
                for (j = 0; j < this.numCategories(); ++j) {
                    if (j > 0) {
                        sb.append(',');
                    }
                    sb.append(this.averageScore(category, this.categories()[j]));
                }
                sb.append("\n");
            }
            if (this.mHasConditional) {
                sb.append("Conditional One Versus All\n");
                sb.append(this.conditionalOneVersusAll(category).toString() + "\n");
                sb.append("Average Conditional Probability Histogram=\n");
                this.appendCategoryLine(sb);
                for (j = 0; j < this.numCategories(); ++j) {
                    if (j > 0) {
                        sb.append(',');
                    }
                    sb.append(this.averageConditionalProbability(category, this.categories()[j]));
                }
                sb.append("\n");
            }
            if (!this.mHasJoint) continue;
            sb.append("Average Joint Probability Histogram=\n");
            this.appendCategoryLine(sb);
            for (j = 0; j < this.numCategories(); ++j) {
                if (j > 0) {
                    sb.append(',');
                }
                sb.append(this.averageLog2JointProbability(category, this.categories()[j]));
            }
            sb.append("\n");
        }
        return sb.toString();
    }

    void appendCategoryLine(StringBuffer sb) {
        sb.append("  ");
        for (int i = 0; i < this.numCategories(); ++i) {
            if (i > 0) {
                sb.append(',');
            }
            sb.append(this.categories()[i]);
        }
        sb.append("\n  ");
    }

    private void validateCategory(String category) {
        if (this.mCategorySet.contains(category)) {
            return;
        }
        String msg = "Unknown category=" + category;
        throw new IllegalArgumentException(msg);
    }

    void rankHistogramToCSV(StringBuffer sb) {
        for (int i = 0; i < this.numCategories(); ++i) {
            if (i > 0) {
                sb.append('\n');
            }
            for (int rank = 0; rank < this.numCategories(); ++rank) {
                if (rank > 0) {
                    sb.append(',');
                }
                sb.append(this.mRankCounts[i][rank]);
            }
        }
    }

    double averageRankReference(int i) {
        double sum = 0.0;
        int count = 0;
        for (int rank = 0; rank < this.numCategories(); ++rank) {
            int rankCount = this.mRankCounts[i][rank];
            if (rankCount == 0) continue;
            count += rankCount;
            sum += (double)(rank * rankCount);
        }
        return sum / (double)count;
    }

    int categoryToIndex(String category) {
        int result = this.mConfusionMatrix.getIndex(category);
        if (result < 0) {
            String msg = "Unknown category=" + category;
            throw new IllegalArgumentException(msg);
        }
        return result;
    }

    int rankCount(int categoryIndex, int rank) {
        return this.mRankCounts[categoryIndex][rank];
    }

    public void addClassification(String referenceCategory, Classification classification) {
        this.mConfusionMatrix.increment(referenceCategory, classification.bestCategory());
        this.mReferenceCategories.add(referenceCategory);
        this.mClassifications.add(classification);
        ++this.mNumCases;
        if (classification instanceof RankedClassification) {
            this.mHasRanked = true;
            this.addRanking(referenceCategory, (RankedClassification)classification);
        }
        if (classification instanceof ScoredClassification) {
            this.mHasScored = true;
            this.addScoring(referenceCategory, (ScoredClassification)classification);
        }
        if (classification instanceof ConditionalClassification) {
            this.mHasConditional = true;
            this.addConditioning(referenceCategory, (ConditionalClassification)classification);
        }
        if (classification instanceof JointClassification) {
            this.mHasJoint = true;
        }
    }

    final int numCategories() {
        return this.mConfusionMatrix.numCategories();
    }

    void addRanking(String refCategory, RankedClassification ranking) {
        this.updateRankHistogram(refCategory, ranking);
    }

    private void updateRankHistogram(String refCategory, RankedClassification ranking) {
        int refCategoryIndex = this.categoryToIndex(refCategory);
        if (ranking.size() < this.numCategories()) {
            this.mDefectiveRanking = true;
        }
        for (int rank = 0; rank < this.numCategories() && rank < ranking.size(); ++rank) {
            String category = ranking.category(rank);
            if (!category.equals(refCategory)) continue;
            int[] nArray = this.mRankCounts[refCategoryIndex];
            int n = rank;
            nArray[n] = nArray[n] + 1;
            return;
        }
        int[] nArray = this.mRankCounts[refCategoryIndex];
        int n = this.mCategories.length - 1;
        nArray[n] = nArray[n] + 1;
    }

    private void addScoring(String refCategory, ScoredClassification scoring) {
        if (scoring.size() < this.numCategories()) {
            this.mDefectiveScoring = true;
        }
        for (int rank = 0; rank < this.numCategories() && rank < scoring.size(); ++rank) {
            double score = scoring.score(rank);
            String category = scoring.category(rank);
            int categoryIndex = this.categoryToIndex(category);
            boolean match = category.equals(refCategory);
            ScoreOutcome outcome = new ScoreOutcome(score, match, rank == 0);
            this.mScoreOutcomeLists[categoryIndex].add(outcome);
        }
    }

    private void addConditioning(String refCategory, ConditionalClassification scoring) {
        if (scoring.size() < this.numCategories()) {
            this.mDefectiveConditioning = true;
        }
        for (int rank = 0; rank < this.numCategories() && rank < scoring.size(); ++rank) {
            double score = scoring.conditionalProbability(rank);
            String category = scoring.category(rank);
            int categoryIndex = this.categoryToIndex(category);
            boolean match = category.equals(refCategory);
            ScoreOutcome outcome = new ScoreOutcome(score, match, rank == 0);
            this.mConditionalOutcomeLists[categoryIndex].add(outcome);
        }
    }

    static class ScoreOutcome
    implements Scored {
        private final double mScore;
        private final boolean mOutcome;
        private final boolean mFirstBest;

        public ScoreOutcome(double score, boolean outcome, boolean firstBest) {
            this.mOutcome = outcome;
            this.mScore = score;
            this.mFirstBest = firstBest;
        }

        public double score() {
            return this.mScore;
        }

        public String toString() {
            return "(" + this.mScore + ": " + this.mOutcome + "firstBest=" + this.mFirstBest + ")";
        }
    }
}

