/*
 * Decompiled with CFR 0.152.
 */
package edu.columbia.ob.gen.paraphraseMining;

import edu.columbia.ob.gen.paraphraseMining.ParaphraseUtils;
import edu.columbia.ob.gen.paraphraseMining.Vectors;
import java.io.FileWriter;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Pattern;
import ob.util.Pair;
import ob.util.StopWatch;
import ob.util.Utils;

/**
 * Command-line tool that clusters the templates of each group found in a
 * tab-separated groups file and writes every cluster with more than one
 * member to an output file.
 *
 * <p>Each input line has the form {@code groupId \t template1 \t template2 ...}.
 * For every group, each template starts in its own cluster; template pairs are
 * then visited in ascending order of the Euclidean distance between their
 * vectors and the clusters of the two templates are merged. Merging stops as
 * soon as the natural log of the largest per-cluster sum of squared errors is
 * no longer below {@link #MAX_SSE_THRESHOLD}; the last clustering that was
 * still below the threshold is the one written out.
 */
public class TransductiveClustering {
    /** Merging stops once the number of clusters has dropped to this value. */
    private static final int MIN_NUM_CLUSTERS = 1;
    /** Upper bound (exclusive) on the log of the maximal cluster SSE. */
    private static final double MAX_SSE_THRESHOLD = 9.5;
    /** Groups with more templates than this are skipped. */
    private static final int MAX_TEMPLATES_SIZE = 100;
    /** Location of the vectors file handed to {@code Vectors.readFromFile}. */
    private static final String VECTORS_FILE = "C:/Users/Or/Desktop/wp/vectors.culled.real10.tfidf";
    /** Matches {@code [[...]]} spans, which are stripped from templates before normalization. */
    private static final Pattern ENTITY_PATTERN = Pattern.compile("\\[\\[[^\\]]+\\]\\]");

    private static String _groupsFile;
    private static Vectors _vectors;

    /**
     * Entry point.
     *
     * @param args exactly two values: the groups (input) file and the output file
     * @throws RuntimeException if the number of arguments is not two
     * @throws Exception if the output file cannot be opened, or on any failure
     *                   raised while reading or clustering
     */
    public static void main(String[] args) throws Exception {
        if (args.length != 2) {
            throw new RuntimeException("expected two arguments: groups (input) file and output file");
        }
        _groupsFile = args[0];
        String outfile = args[1];
        PrintWriter pw = new PrintWriter(new FileWriter(outfile));
        try {
            int c = 0;
            StopWatch sw = new StopWatch();
            for (String line : Utils.readLinesDynamically(_groupsFile)) {
                if (++c % 10 == 0) {
                    System.out.println("clustering group #" + c + "  --  " + Utils.usedMemory() + "  --  " + sw.getTimeElapsedPretty());
                }
                String[] tokens = line.split("\\t");
                String groupId = tokens[0];
                List<String> templates = Utils.list(Arrays.copyOfRange(tokens, 1, tokens.length));
                if (templates.size() > MAX_TEMPLATES_SIZE) {
                    System.out.println("skipping: " + groupId);
                    continue;
                }
                Map<Pair<String>, Double> distances = calculateDistances(templates);
                if (distances == null) {
                    System.out.println("null distances: " + groupId);
                    continue;
                }
                Map<String, Set<String>> bestClusters = clusterTemplates(templates, distances);
                System.out.println("# best clusters: " + bestClusters.size() + " (" + getClusterSpread(bestClusters) + ") out of " + templates.size());
                System.out.println();
                writeClusters(pw, groupId, bestClusters);
            }
        } finally {
            // Close (and thereby flush) the writer even when a group fails,
            // so output produced so far is not lost and the handle is not leaked.
            pw.close();
        }
    }

    /**
     * Merges singleton clusters pair by pair, closest pair first, and returns
     * the last clustering whose log max SSE stayed below the threshold.
     *
     * @param templates the templates of one group
     * @param distances distance for every template pair, as produced by
     *                  {@link #calculateDistances(List)}
     * @return cluster id to members; never {@code null}
     */
    private static Map<String, Set<String>> clusterTemplates(List<String> templates, final Map<Pair<String>, Double> distances) {
        List<Pair<String>> sortedPairs = new ArrayList<Pair<String>>(distances.keySet());
        Collections.sort(sortedPairs, new Comparator<Pair<String>>() {

            @Override
            public int compare(Pair<String> pair1, Pair<String> pair2) {
                return distances.get(pair1).compareTo(distances.get(pair2));
            }
        });

        Map<String, Set<String>> clusters = new HashMap<String, Set<String>>();
        Map<String, String> membership = new HashMap<String, String>();
        int tmpClusterId = 0;
        // NOTE: a template that occurs twice in the list gets two singleton
        // clusters, but membership only remembers the later one.
        for (String template : templates) {
            System.out.println(template);
            String id = "C" + tmpClusterId;
            membership.put(template, id);
            Set<String> singleton = new HashSet<String>();
            singleton.add(template);
            clusters.put(id, singleton);
            ++tmpClusterId;
        }

        Map<String, Set<String>> bestClusters = copy(clusters);
        System.out.println(describe(clusters, 0.0));
        for (Pair<String> pair : sortedPairs) {
            String clusterId1 = membership.get(pair.getFirst());
            String clusterId2 = membership.get(pair.getSecond());
            if (clusterId1.equals(clusterId2)) {
                continue; // already in the same cluster
            }
            // Re-point every member of the second cluster at the first one.
            // setValue is not a structural modification, so iterating is safe.
            for (Map.Entry<String, String> entry : membership.entrySet()) {
                if (entry.getValue().equals(clusterId2)) {
                    entry.setValue(clusterId1);
                }
            }
            clusters.get(clusterId1).addAll(clusters.get(clusterId2));
            clusters.remove(clusterId2);

            double maxSse = calculateMaxSSE(clusters);
            double logMaxSse = maxSse == 0.0 ? 0.0 : Math.log(maxSse);
            System.out.println(describe(clusters, logMaxSse));
            // Written as !(x < T) so that a NaN value also stops the merging.
            if (!(logMaxSse < MAX_SSE_THRESHOLD)) {
                break;
            }
            bestClusters = copy(clusters);
            if (clusters.size() <= MIN_NUM_CLUSTERS) {
                break;
            }
        }
        return bestClusters;
    }

    /** Builds the progress line printed after every clustering step. */
    private static String describe(Map<String, Set<String>> clusters, double logMaxSse) {
        return clusters.size() + " clusters (" + getClusterSpread(clusters) + "): " + logMaxSse + " (log max SSE)";
    }

    /**
     * Writes every cluster with more than one member: a header line
     * {@code groupId-clusterId}, one indented line per member, then a blank line.
     */
    private static void writeClusters(PrintWriter pw, String groupId, Map<String, Set<String>> bestClusters) {
        for (Map.Entry<String, Set<String>> entry : bestClusters.entrySet()) {
            Set<String> members = entry.getValue();
            if (members.size() <= 1) {
                continue;
            }
            pw.println(groupId + "-" + entry.getKey());
            for (String template : members) {
                pw.println("  " + template);
            }
            pw.println();
        }
    }

    /**
     * Returns a copy of the given clustering in which every member set is a
     * fresh {@link HashSet}, so later merges do not affect the snapshot.
     */
    private static Map<String, Set<String>> copy(Map<String, Set<String>> clusters) {
        Map<String, Set<String>> copy = new HashMap<String, Set<String>>();
        for (Map.Entry<String, Set<String>> entry : clusters.entrySet()) {
            copy.put(entry.getKey(), new HashSet<String>(entry.getValue()));
        }
        return copy;
    }

    /**
     * Returns the cluster sizes joined by {@code ", "}.
     *
     * @return e.g. {@code "3, 1, 2"}; the empty string when there are no clusters
     */
    private static String getClusterSpread(Map<String, Set<String>> bestClusters) {
        StringBuilder sizes = new StringBuilder();
        for (Set<String> cluster : bestClusters.values()) {
            if (sizes.length() > 0) {
                sizes.append(", ");
            }
            sizes.append(cluster.size());
        }
        return sizes.toString();
    }

    /** Returns the largest per-cluster SSE, or {@code 0.0} when there are no clusters. */
    private static double calculateMaxSSE(Map<String, Set<String>> clusters) {
        double max = 0.0;
        for (Set<String> members : clusters.values()) {
            max = Math.max(max, calculateClusterSSE(members));
        }
        return max;
    }

    /**
     * Returns the sum, over all members, of the squared Euclidean distance
     * between the cluster centroid and the member's normalized text.
     */
    private static double calculateClusterSSE(Set<String> members) {
        double sum = 0.0;
        Map<String, Double> centroid = findCentroid(members);
        for (String template : members) {
            sum += Math.pow(TransductiveClustering.getVectors().euclideanDistance(centroid, normalize(template)), 2.0);
        }
        return sum;
    }

    /**
     * Returns the feature-wise mean of the members' combined vectors; a
     * feature missing from a member's vector contributes zero for that member.
     */
    private static Map<String, Double> findCentroid(Set<String> members) {
        Map<String, Double> centroid = new HashMap<String, Double>();
        for (String template : members) {
            Map<String, Double> vector = getVectors().getCombinedVector(normalize(template));
            for (Map.Entry<String, Double> feature : vector.entrySet()) {
                Double currentValue = centroid.get(feature.getKey());
                if (currentValue == null) {
                    currentValue = 0.0;
                }
                centroid.put(feature.getKey(), currentValue + feature.getValue());
            }
        }
        for (Map.Entry<String, Double> feature : centroid.entrySet()) {
            feature.setValue(feature.getValue() / (double) members.size());
        }
        return centroid;
    }

    /**
     * Computes the Euclidean distance between the combined vectors of every
     * unordered pair of templates. In each pair the lexicographically larger
     * template comes first.
     *
     * @param templates the templates of one group
     * @return pair to distance; an empty map when there are fewer than two
     *         templates; {@code null} when some template has no combined vector
     */
    private static Map<Pair<String>, Double> calculateDistances(List<String> templates) {
        Map<Pair<String>, Double> result = new HashMap<Pair<String>, Double>();
        int count = templates.size();
        if (count < 2) {
            // No pairs to measure, so no vector needs to be looked up.
            return result;
        }
        // Look up each template's vector once instead of once per pair.
        List<Map<String, Double>> vectors = new ArrayList<Map<String, Double>>(count);
        for (String template : templates) {
            Map<String, Double> vector = getVectors().getCombinedVector(normalize(template));
            if (vector == null) {
                return null;
            }
            vectors.add(vector);
        }
        for (int i = 0; i < count - 1; ++i) {
            String first = templates.get(i);
            for (int j = i + 1; j < count; ++j) {
                String second = templates.get(j);
                Pair<String> pair = first.compareTo(second) > 0 ? new Pair<String>(first, second) : new Pair<String>(second, first);
                Double distance = Vectors.euclideanDistance(vectors.get(i), vectors.get(j));
                result.put(pair, distance);
            }
        }
        return result;
    }

    /** Strips {@code [[...]]} spans from the template and normalizes what is left. */
    private static String normalize(String template) {
        return ParaphraseUtils.normalizeText(removeEntities(template));
    }

    /** Removes every {@code [[...]]} span (non-empty, containing no {@code ]}) from the string. */
    private static String removeEntities(String string) {
        return ENTITY_PATTERN.matcher(string).replaceAll("");
    }

    /**
     * Returns the shared {@link Vectors} instance, loading it on first use
     * from {@link #VECTORS_FILE} restricted to the terms of the groups file,
     * with caching switched off.
     */
    private static Vectors getVectors() {
        if (_vectors == null) {
            Set<String> corpusTerms = getCorpusTerms(_groupsFile);
            _vectors = Vectors.readFromFile(VECTORS_FILE, corpusTerms);
            _vectors.setUseCache(false);
        }
        return _vectors;
    }

    /**
     * Collects every whitespace-separated term that occurs in the normalized
     * templates (all columns after the first) of the groups file.
     */
    private static Set<String> getCorpusTerms(String groupsFile) {
        Set<String> corpusTerms = new HashSet<String>();
        for (String line : Utils.readLinesDynamically(groupsFile)) {
            String[] tokens = line.split("\\t");
            for (int i = 1; i < tokens.length; ++i) {
                for (String term : normalize(tokens[i]).split("\\s+")) {
                    corpusTerms.add(term);
                }
            }
        }
        return corpusTerms;
    }
}

