/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.cluster;

import com.aliasi.classify.PrecisionRecallEvaluation;
import com.aliasi.util.Collections;
import com.aliasi.util.Distance;
import com.aliasi.util.Tuple;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Compares a response partition of a set of elements against a reference
 * partition of the same elements.
 *
 * <p>Three families of scores are provided:
 * <ul>
 * <li>an equivalence-pair evaluation, in which every ordered pair of elements
 *     (including reflexive pairs) that share a cluster is a positive;</li>
 * <li>MUC link-based precision, recall and F(1); and</li>
 * <li>B-cubed precision, recall and F(1), averaged either per cluster or per
 *     element.</li>
 * </ul>
 *
 * <p>Both partitions are validated on construction. Neither may place an
 * element in more than one equivalence class, and both must cover exactly the
 * same elements. Empty equivalence classes are permitted and contribute nothing
 * to any score. The partitions are stored by reference, not copied. The MUC and
 * B-cubed scores are recomputed from them on every call, so modifying a
 * partition after construction makes those scores inconsistent with the
 * equivalence evaluation computed in the constructor.
 *
 * @param <E> the type of element being clustered
 */
public class ClusterScore<E> {
    private final PrecisionRecallEvaluation mPrEval;
    private final Set<? extends Set<? extends E>> mReferencePartition;
    private final Set<? extends Set<? extends E>> mResponsePartition;

    /**
     * Constructs a score for the specified reference and response partitions.
     * The equivalence-pair evaluation is computed eagerly here. All other
     * scores are computed on demand.
     *
     * @param referencePartition the reference partition
     * @param responsePartition the response partition
     * @throws IllegalArgumentException if either partition has an element in
     *     more than one equivalence class, or if the two partitions do not
     *     cover the same set of elements
     */
    public ClusterScore(Set<? extends Set<? extends E>> referencePartition,
                        Set<? extends Set<? extends E>> responsePartition) {
        assertPartitionSameSets(referencePartition, responsePartition);
        mReferencePartition = referencePartition;
        mResponsePartition = responsePartition;
        mPrEval = calculateConfusionMatrix();
    }

    /**
     * Returns the precision-recall evaluation over ordered element pairs.
     * A pair (x, y) is positive in a partition when x and y are in the same
     * equivalence class. Reflexive pairs (x, x) are included, and the
     * true-negative count is {@code n*n} minus the other three counts, where
     * {@code n} is the number of elements.
     *
     * @return the pairwise equivalence evaluation
     */
    public PrecisionRecallEvaluation equivalenceEvaluation() {
        return mPrEval;
    }

    /**
     * Returns MUC precision. This is MUC recall with the roles of the
     * reference and response partitions swapped.
     *
     * @return MUC precision
     */
    public double mucPrecision() {
        return mucRecall(mResponsePartition, mReferencePartition);
    }

    /**
     * Returns MUC recall of the response partition against the reference.
     *
     * @return MUC recall
     */
    public double mucRecall() {
        return mucRecall(mReferencePartition, mResponsePartition);
    }

    /**
     * Returns the harmonic mean of MUC precision and recall.
     *
     * @return MUC F(1), or 0.0 if both precision and recall are 0.0
     */
    public double mucF() {
        return f(mucPrecision(), mucRecall());
    }

    /**
     * Returns cluster-averaged B-cubed precision. This is cluster-averaged
     * recall with the roles of the two partitions swapped.
     *
     * @return cluster-averaged B-cubed precision
     */
    public double b3ClusterPrecision() {
        return b3ClusterRecall(mResponsePartition, mReferencePartition);
    }

    /**
     * Returns cluster-averaged B-cubed recall.
     *
     * @return cluster-averaged B-cubed recall
     */
    public double b3ClusterRecall() {
        return b3ClusterRecall(mReferencePartition, mResponsePartition);
    }

    /**
     * Returns the harmonic mean of cluster-averaged B-cubed precision and recall.
     *
     * @return cluster-averaged B-cubed F(1), or 0.0 if both are 0.0
     */
    public double b3ClusterF() {
        return f(b3ClusterPrecision(), b3ClusterRecall());
    }

    /**
     * Returns element-averaged B-cubed precision. This is element-averaged
     * recall with the roles of the two partitions swapped.
     *
     * @return element-averaged B-cubed precision
     */
    public double b3ElementPrecision() {
        return b3ElementRecall(mResponsePartition, mReferencePartition);
    }

    /**
     * Returns element-averaged B-cubed recall.
     *
     * @return element-averaged B-cubed recall
     */
    public double b3ElementRecall() {
        return b3ElementRecall(mReferencePartition, mResponsePartition);
    }

    /**
     * Returns the harmonic mean of element-averaged B-cubed precision and recall.
     *
     * @return element-averaged B-cubed F(1), or 0.0 if both are 0.0
     */
    public double b3ElementF() {
        return f(b3ElementPrecision(), b3ElementRecall());
    }

    /**
     * Returns the ordered pairs that are equivalent in both partitions.
     *
     * @return a freshly allocated set of true-positive pairs
     */
    public Set<Tuple<E>> truePositives() {
        Set<Tuple<E>> referenceEquivalences = toEquivalences(mReferencePartition);
        Set<Tuple<E>> responseEquivalences = toEquivalences(mResponsePartition);
        referenceEquivalences.retainAll(responseEquivalences);
        return referenceEquivalences;
    }

    /**
     * Returns the ordered pairs that are equivalent in the response partition
     * but not in the reference partition.
     *
     * @return a freshly allocated set of false-positive pairs
     */
    public Set<Tuple<E>> falsePositives() {
        Set<Tuple<E>> referenceEquivalences = toEquivalences(mReferencePartition);
        Set<Tuple<E>> responseEquivalences = toEquivalences(mResponsePartition);
        responseEquivalences.removeAll(referenceEquivalences);
        return responseEquivalences;
    }

    /**
     * Returns the ordered pairs that are equivalent in the reference partition
     * but not in the response partition.
     *
     * @return a freshly allocated set of false-negative pairs
     */
    public Set<Tuple<E>> falseNegatives() {
        Set<Tuple<E>> referenceEquivalences = toEquivalences(mReferencePartition);
        Set<Tuple<E>> responseEquivalences = toEquivalences(mResponsePartition);
        referenceEquivalences.removeAll(responseEquivalences);
        return referenceEquivalences;
    }

    // Builds the pairwise confusion counts; called once from the constructor.
    private PrecisionRecallEvaluation calculateConfusionMatrix() {
        Set<Tuple<E>> referenceEquivalences = toEquivalences(mReferencePartition);
        Set<Tuple<E>> responseEquivalences = toEquivalences(mResponsePartition);
        long tp = 0L;
        long fn = 0L;
        for (Tuple<E> pair : referenceEquivalences) {
            // Removing each matched pair leaves only the false positives behind.
            if (responseEquivalences.remove(pair)) {
                ++tp;
            } else {
                ++fn;
            }
        }
        long numElements = elementsOf(mReferencePartition).size();
        // Ordered pairs over all elements, including reflexive ones.
        long totalCount = numElements * numElements;
        long fp = responseEquivalences.size();
        long tn = totalCount - tp - fn - fp;
        return new PrecisionRecallEvaluation(tp, fn, fp, tn);
    }

    /**
     * Returns a multi-line, human-readable report of all scores.
     *
     * @return the score report
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("CLUSTER SCORE");
        sb.append("\nEquivalence Evaluation\n");
        sb.append(mPrEval.toString());
        sb.append("\nMUC Evaluation");
        sb.append("\n  MUC Precision = " + mucPrecision());
        sb.append("\n  MUC Recall = " + mucRecall());
        sb.append("\n  MUC F(1) = " + mucF());
        sb.append("\nB-Cubed Evaluation");
        sb.append("\n  B3 Cluster Averaged Precision = " + b3ClusterPrecision());
        sb.append("\n  B3 Cluster Averaged Recall = " + b3ClusterRecall());
        sb.append("\n  B3 Cluster Averaged F(1) = " + b3ClusterF());
        sb.append("\n  B3 Element Averaged Precision = " + b3ElementPrecision());
        sb.append("\n  B3 Element Averaged Recall = " + b3ElementRecall());
        sb.append("\n  B3 Element Averaged F(1) = " + b3ElementF());
        return sb.toString();
    }

    /**
     * Returns the sum of {@link #scatter(Set, Distance)} over every cluster.
     *
     * @param clustering the clusters to measure
     * @param distance the distance between two elements
     * @param <E> the element type
     * @return the total within-cluster scatter
     */
    public static <E> double withinClusterScatter(Set<? extends Set<? extends E>> clustering,
                                                  Distance<? super E> distance) {
        double total = 0.0;
        for (Set<? extends E> cluster : clustering) {
            total += ClusterScore.<E>scatter(cluster, distance);
        }
        return total;
    }

    /**
     * Returns the sum of distances over every unordered pair of distinct
     * elements in the cluster. Each pair is counted once.
     *
     * @param cluster the cluster to measure
     * @param distance the distance between two elements
     * @param <E> the element type
     * @return the scatter of the cluster (0.0 for fewer than two elements)
     */
    public static <E> double scatter(Set<? extends E> cluster, Distance<? super E> distance) {
        // Snapshot into a list so each unordered pair can be visited once (i < j).
        List<E> elements = new ArrayList<E>(cluster);
        double total = 0.0;
        for (int i = 0; i < elements.size(); ++i) {
            E xi = elements.get(i);
            for (int j = i + 1; j < elements.size(); ++j) {
                total += distance.distance(xi, elements.get(j));
            }
        }
        return total;
    }

    /**
     * Returns every ordered pair (x, y), including x == y, such that x and y
     * are in the same equivalence class of the partition.
     *
     * @param partition the partition to expand
     * @return a freshly allocated, mutable set of pairs
     */
    Set<Tuple<E>> toEquivalences(Set<? extends Set<? extends E>> partition) {
        Set<Tuple<E>> equivalences = new HashSet<Tuple<E>>();
        for (Set<? extends E> equivalenceClass : partition) {
            for (E x : equivalenceClass) {
                for (E y : equivalenceClass) {
                    equivalences.add(Tuple.<E>create(x, y));
                }
            }
        }
        return equivalences;
    }

    // Element-averaged B-cubed recall: mean over reference elements of the
    // fraction of the element's reference class found in its response class.
    private static double b3ElementRecall(Set<? extends Set<?>> referencePartition,
                                          Set<? extends Set<?>> responsePartition) {
        Map<Object, Set<?>> responseClassOf = indexByElement(responsePartition);
        double weight = uniformElementWeight(elementsOf(referencePartition));
        double score = 0.0;
        for (Set<?> referenceEqClass : referencePartition) {
            for (Object element : referenceEqClass) {
                score += weight * b3Recall(element, referenceEqClass, responseClassOf);
            }
        }
        return score;
    }

    // Weight that gives every element the same share of the total.
    private static double uniformElementWeight(Set<?> elements) {
        return 1.0 / (double) elements.size();
    }

    // Weight that gives every non-empty cluster the same share of the total,
    // split evenly among that cluster's elements. Multiplies in double to avoid
    // int overflow.
    private static double uniformClusterWeight(Set<?> eqClass, int numClusters) {
        return 1.0 / ((double) eqClass.size() * numClusters);
    }

    // Cluster-averaged B-cubed recall: per-element B-cubed recall averaged
    // within each reference class, then averaged across non-empty classes.
    private static double b3ClusterRecall(Set<? extends Set<?>> referencePartition,
                                          Set<? extends Set<?>> responsePartition) {
        Map<Object, Set<?>> responseClassOf = indexByElement(responsePartition);
        // Empty classes contain no elements, so they must not dilute the average.
        int numClusters = countNonEmpty(referencePartition);
        double score = 0.0;
        for (Set<?> referenceEqClass : referencePartition) {
            if (referenceEqClass.isEmpty()) {
                continue;
            }
            double weight = uniformClusterWeight(referenceEqClass, numClusters);
            for (Object element : referenceEqClass) {
                score += weight * b3Recall(element, referenceEqClass, responseClassOf);
            }
        }
        return score;
    }

    // B-cubed recall of one element: the fraction of its reference class that
    // is also in the response class containing it.
    private static double b3Recall(Object element, Set<?> referenceEqClass,
                                   Map<Object, Set<?>> responseClassOf) {
        Set<?> responseClass = responseClassOf.get(element);
        if (responseClass == null) {
            throw new IllegalArgumentException("Element must be in an equivalence class in partition.");
        }
        return recallSets(referenceEqClass, responseClass);
    }

    // Maps each element to the equivalence class containing it. Partitions are
    // validated as non-overlapping, so each element maps to exactly one class.
    private static Map<Object, Set<?>> indexByElement(Set<? extends Set<?>> partition) {
        Map<Object, Set<?>> index = new HashMap<Object, Set<?>>();
        for (Set<?> eqClass : partition) {
            for (Object element : eqClass) {
                index.put(element, eqClass);
            }
        }
        return index;
    }

    // Counts the equivalence classes that contain at least one element.
    private static int countNonEmpty(Set<? extends Set<?>> partition) {
        int count = 0;
        for (Set<?> eqClass : partition) {
            if (!eqClass.isEmpty()) {
                ++count;
            }
        }
        return count;
    }

    // Fraction of the reference set also in the response set; 1.0 if the
    // reference set is empty.
    private static double recallSets(Set<?> referenceSet, Set<?> responseSet) {
        if (referenceSet.isEmpty()) {
            return 1.0;
        }
        return (double) intersectionSize(referenceSet, responseSet) / (double) referenceSet.size();
    }

    // Number of members of set1 that set2 also contains.
    private static long intersectionSize(Set<?> set1, Set<?> set2) {
        long count = 0L;
        for (Object x : set1) {
            if (set2.contains(x)) {
                ++count;
            }
        }
        return count;
    }

    // Throws IllegalArgumentException unless both partitions are valid and
    // cover the same elements.
    private static void assertPartitionSameSets(Set<? extends Set<?>> set1,
                                                Set<? extends Set<?>> set2) {
        assertValidPartition(set1);
        assertValidPartition(set2);
        if (!elementsOf(set1).equals(elementsOf(set2))) {
            String msg = "Partitions must be of same sets.";
            throw new IllegalArgumentException(msg);
        }
    }

    // Throws IllegalArgumentException if any element appears in two classes.
    private static void assertValidPartition(Set<? extends Set<?>> partition) {
        Set<Object> eltsSoFar = new HashSet<Object>();
        for (Set<?> eqClass : partition) {
            for (Object member : eqClass) {
                if (!eltsSoFar.add(member)) {
                    throw new IllegalArgumentException("Partitions must not contain overlapping members.");
                }
            }
        }
    }

    // Union of all equivalence classes in the partition.
    private static Set<Object> elementsOf(Set<? extends Set<?>> partition) {
        Set<Object> elements = new HashSet<Object>();
        for (Set<?> eqClass : partition) {
            elements.addAll(eqClass);
        }
        return elements;
    }

    // Harmonic mean of precision and recall. Returns 0.0 instead of NaN when
    // both are zero.
    private static double f(double precision, double recall) {
        double sum = precision + recall;
        if (sum == 0.0) {
            return 0.0;
        }
        return 2.0 * precision * recall / sum;
    }

    // MUC recall: sum over reference classes of (|class| - number of response
    // classes it touches), divided by sum of (|class| - 1). Returns 1.0 when
    // there are no links to recover.
    private static double mucRecall(Set<? extends Set<?>> referencePartition,
                                    Set<? extends Set<?>> responsePartition) {
        long numerator = 0L;
        long denominator = 0L;
        for (Set<?> referenceEqClass : referencePartition) {
            // An empty class has no links; counting it would subtract 1 from the
            // denominator.
            if (referenceEqClass.isEmpty()) {
                continue;
            }
            long numPartitions = 0L;
            for (Set<?> responseEqClass : responsePartition) {
                if (!java.util.Collections.disjoint(referenceEqClass, responseEqClass)) {
                    ++numPartitions;
                }
            }
            numerator += (long) referenceEqClass.size() - numPartitions;
            denominator += (long) referenceEqClass.size() - 1L;
        }
        if (denominator == 0L) {
            return 1.0;
        }
        return (double) numerator / (double) denominator;
    }
}

