/*
 * Decompiled with CFR 0.152.
 */
package edu.columbia.ob.gen.planner;

import edu.columbia.ob.gen.core.DiscourseRelation;
import edu.columbia.ob.gen.core.SemanticUnit;
import edu.columbia.ob.gen.core.TemplateParameter;
import edu.columbia.ob.gen.env.PreGenEnv;
import edu.columbia.ob.gen.env.PreGenRuntime;
import edu.columbia.ob.gen.env.StaticVectors;
import edu.columbia.ob.gen.planner.PathFinder;
import edu.columbia.ob.gen.planner.SemanticUnitGraph;
import edu.columbia.ob.gen.planner.SemanticUnitGraphPath;
import edu.columbia.ob.gen.planner.SemanticUnitGraphPathImpl;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import ob.util.Utils;
import ob.util.collections.MultiMap;
import ob.util.collections.SortedMap3;

/**
 * {@link PathFinder} that orders the nodes of a {@link SemanticUnitGraph} by a
 * randomized walk.
 *
 * <p>The walk starts at the remaining node with the highest preference score and
 * then repeatedly follows one outgoing edge to a not-yet-visited node. The edge is
 * drawn at random with probability proportional to its weight, which is the
 * product of:
 * <ul>
 *   <li>the relation n-gram score from the model file (only when the discourse
 *       model is enabled; otherwise {@code 1.0}),</li>
 *   <li>the semantic modifier from {@link #getSemanticScoreModifier}, and</li>
 *   <li>the similarity returned by {@code StaticVectors.getSimilarity}.</li>
 * </ul>
 * When no edge can be followed, the walk restarts from the best remaining node and
 * the restart is recorded with {@code DiscourseRelation.norel}.
 *
 * <p>The n-gram model is read from {@code PreGenEnv.getRelationNgramModelFile()}
 * every time an instance is constructed.
 */
public class RelationNgramPathFinder
implements PathFinder {
    /** Separator between the relations of one n-gram, in the model file and in lookup keys. */
    private static final String NGRAM_SEPARATOR = ":::";
    /** Separator between an n-gram and its weight on a model file line. */
    private static final String WEIGHT_SEPARATOR = ": ";
    /** Score used for an n-gram that is not present in the model. */
    private static final double UNSEEN_NGRAM_SCORE = 1.0E-7;

    /** Model entries keyed first by n-gram order, then by the n-gram string. */
    private final SortedMap3<Integer, String, Double> _model = this.loadModel();
    /** Whether edge weights include the relation n-gram score. */
    private final boolean _useDiscourseModel;
    /** Shared random source for edge selection, instead of a new generator per draw. */
    private final Random _random = new Random();

    /** Creates a path finder with the discourse (relation n-gram) model enabled. */
    public RelationNgramPathFinder() {
        this(true);
    }

    /**
     * Creates a path finder.
     *
     * @param useDiscourseModel {@code true} to weight edges by the relation n-gram
     *        model, {@code false} to use only the two node-pair modifiers
     */
    public RelationNgramPathFinder(boolean useDiscourseModel) {
        this._useDiscourseModel = useDiscourseModel;
    }

    /**
     * Builds a path that visits every node of {@code graph} exactly once.
     *
     * <p>The result is randomized: two calls on the same graph may return
     * different paths.
     *
     * @param graph the graph whose nodes are to be ordered
     * @return the path, containing each node together with the relation that led to it
     *         ({@code null} for the first node, {@code DiscourseRelation.norel} after
     *         a restart)
     */
    @Override
    public SemanticUnitGraphPath bestPath(SemanticUnitGraph graph) {
        SemanticUnitGraphPathImpl path = new SemanticUnitGraphPathImpl();
        List<SemanticUnit> remaining = new ArrayList<SemanticUnit>(graph.getNodes());
        // -1 means "no successor chosen yet": pick the next node by preference.
        int index = -1;
        DiscourseRelation relation = null;
        while (!remaining.isEmpty()) {
            if (index == -1) {
                index = this.findNextIndexByPreference(remaining);
            }
            SemanticUnit node = remaining.get(index);
            path.addNode(relation, node);
            remaining.remove(index);

            List<Edge> edges = RelationNgramPathFinder.getRemainingEdges(graph.getEdges(node), remaining);
            int edgeIndex = edges.isEmpty()
                    ? -1
                    : this.chooseEdgeStochastically(node, path.getRelationSequence(), edges, remaining);
            if (edgeIndex == -1) {
                // Dead end (or no edge with usable weight): restart from the best remaining node.
                index = -1;
                relation = DiscourseRelation.norel;
                continue;
            }
            Edge chosen = edges.get(edgeIndex);
            index = remaining.indexOf(chosen.getNode2());
            relation = chosen.getRelation();
        }
        return path;
    }

    /**
     * Returns the index of the node with the highest preference score; the first
     * such node wins on ties.
     *
     * @param nodes candidate nodes
     * @return index into {@code nodes}
     * @throws RuntimeException if {@code nodes} is empty or no node has a score
     *         greater than negative infinity
     */
    private int findNextIndexByPreference(List<SemanticUnit> nodes) {
        int best = -1;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < nodes.size(); ++i) {
            double score = nodes.get(i).getPreferenceScore();
            if (score > bestScore) {
                bestScore = score;
                best = i;
            }
        }
        if (best == -1) {
            throw new RuntimeException("WTF?!?");
        }
        return best;
    }

    /**
     * Draws one edge at random, with probability proportional to its weight.
     *
     * @param currentNode the node the walk is currently at
     * @param pastRelations relations recorded on the path so far
     * @param edges candidate edges leaving {@code currentNode}
     * @param nodes nodes not yet on the path; edges to other nodes are skipped
     * @return index into {@code edges} of the chosen edge, or {@code -1} when no edge
     *         can be drawn (no eligible edge, or the total weight is not a positive
     *         finite number)
     */
    private int chooseEdgeStochastically(SemanticUnit currentNode, List<DiscourseRelation> pastRelations, List<Edge> edges, List<SemanticUnit> nodes) {
        // The n-gram context is only needed, and the model only consulted, when the
        // discourse model is enabled.
        List<DiscourseRelation> context = this._useDiscourseModel ? this.trimContext(pastRelations) : pastRelations;

        List<Integer> indices = new ArrayList<Integer>();
        List<Double> cumulativeScores = new ArrayList<Double>();
        double totalScore = 0.0;
        for (int i = 0; i < edges.size(); ++i) {
            Edge edge = edges.get(i);
            if (!nodes.contains(edge.getNode2())) {
                continue;
            }
            double score = 1.0;
            if (this._useDiscourseModel) {
                String ngram = this.buildNgram(context, edge.getRelation());
                score = this.lookupNgramScore(context.size() + 1, ngram);
            }
            score *= this.getSemanticScoreModifier(currentNode, edge.getNode2());
            score *= this.getDistributionalScoreModifier(currentNode, edge.getNode2());
            totalScore += score;
            cumulativeScores.add(totalScore);
            indices.add(i);
        }

        // Roulette-wheel selection over the cumulative weights.
        double diceRoll = this._random.nextDouble() * totalScore;
        for (int k = 0; k < cumulativeScores.size(); ++k) {
            if (cumulativeScores.get(k) > diceRoll) {
                return indices.get(k);
            }
        }
        // Nothing exceeded the roll (zero, NaN or infinite total weight, or no
        // eligible edge). Report "no choice" so the caller restarts by preference
        // instead of aborting the whole path.
        return -1;
    }

    /**
     * Reduces the relation history to the context usable for an n-gram lookup: at
     * most (largest model key - 1) trailing relations, and nothing at or before the
     * most recent break in the path ({@code null} or {@code DiscourseRelation.norel}).
     *
     * @param pastRelations relations recorded on the path so far
     * @return a view of the trailing part of {@code pastRelations}
     */
    private List<DiscourseRelation> trimContext(List<DiscourseRelation> pastRelations) {
        List<DiscourseRelation> context = pastRelations;
        // NOTE(review): lastKey() is presumably the largest n-gram order in the model - confirm.
        int maxContext = this._model.lastKey() - 1;
        if (context.size() > maxContext) {
            context = context.subList(context.size() - maxContext, context.size());
        }
        int lastBreak = Math.max(context.lastIndexOf(null), context.lastIndexOf(DiscourseRelation.norel));
        if (lastBreak >= 0) {
            context = context.subList(lastBreak + 1, context.size());
        }
        return context;
    }

    /**
     * Looks up the model score of an n-gram.
     *
     * @param order number of relations in the n-gram
     * @param ngram the n-gram key as produced by {@link #buildNgram}
     * @return the stored score, or {@link #UNSEEN_NGRAM_SCORE} if the model has no
     *         entry for it
     */
    private double lookupNgramScore(int order, String ngram) {
        if (this._model.get(order) == null) {
            return UNSEEN_NGRAM_SCORE;
        }
        Double score = (Double)this._model.get(order).get(ngram);
        return score == null ? UNSEEN_NGRAM_SCORE : score.doubleValue();
    }

    /**
     * Returns the similarity of the two nodes as reported by
     * {@code StaticVectors.getSimilarity}.
     */
    private Double getDistributionalScoreModifier(SemanticUnit currentNode, SemanticUnit node2) {
        return StaticVectors.getSimilarity(currentNode, node2);
    }

    /**
     * Computes a weight that grows with the number of parameter types the two
     * nodes share. Parameters whose type is the core type or {@code "type"} are
     * ignored. The weight starts at 2 and is cubed for every pair of remaining
     * parameters with equal types.
     *
     * <p>NOTE(review): the repeated cubing reaches {@code Double.POSITIVE_INFINITY}
     * after a handful of matching pairs - confirm this growth rate is intended.
     *
     * @return the modifier, at least {@code 2.0}
     */
    private Double getSemanticScoreModifier(SemanticUnit node1, SemanticUnit node2) {
        double modifier = 2.0;
        for (TemplateParameter p1 : node1.getParameters()) {
            if (p1.getType().equals(PreGenRuntime.getCoreType()) || p1.getType().equals("type")) {
                continue;
            }
            for (TemplateParameter p2 : node2.getParameters()) {
                if (p2.getType().equals(PreGenRuntime.getCoreType()) || p2.getType().equals("type")) {
                    continue;
                }
                if (p1.getType().equals(p2.getType())) {
                    modifier = Math.pow(modifier, 3.0);
                }
            }
        }
        return modifier;
    }

    /**
     * Builds the model lookup key for the given context followed by {@code relation}:
     * the relations' string forms joined by {@code ":::"}.
     */
    private String buildNgram(List<DiscourseRelation> pastRelations, DiscourseRelation relation) {
        StringBuilder ngram = new StringBuilder();
        for (DiscourseRelation rel : pastRelations) {
            ngram.append(rel).append(NGRAM_SEPARATOR);
        }
        ngram.append(relation);
        return ngram.toString();
    }

    /**
     * Flattens the outgoing edges of a node into a list, keeping only edges whose
     * target is still in {@code nodes}.
     *
     * @param multiMap target node to relations, as returned by the graph; may be {@code null}
     * @param nodes nodes not yet on the path
     * @return one {@link Edge} per (target, relation) pair; empty if {@code multiMap} is {@code null}
     */
    private static List<Edge> getRemainingEdges(MultiMap<SemanticUnit, DiscourseRelation> multiMap, List<SemanticUnit> nodes) {
        List<Edge> edges = new ArrayList<Edge>();
        if (multiMap == null) {
            return edges;
        }
        for (SemanticUnit target : multiMap.keySet()) {
            if (!nodes.contains(target)) {
                continue;
            }
            for (DiscourseRelation relation : multiMap.get(target)) {
                edges.add(new Edge(target, relation));
            }
        }
        return edges;
    }

    /**
     * Reads the relation n-gram model file. Each line containing {@code ": "} is
     * split at the last occurrence into an n-gram and a numeric weight; other lines
     * are skipped. The n-gram order is the number of {@code ":::"}-separated parts.
     *
     * @return the model keyed by n-gram order and n-gram string
     * @throws NumberFormatException if a weight cannot be parsed as a double
     */
    private SortedMap3<Integer, String, Double> loadModel() {
        SortedMap3<Integer, String, Double> map = new SortedMap3<Integer, String, Double>();
        for (String line : Utils.readLines(PreGenEnv.getRelationNgramModelFile())) {
            int separator = line.lastIndexOf(WEIGHT_SEPARATOR);
            if (separator < 0) {
                continue;
            }
            String ngramStr = line.substring(0, separator);
            Double weight = Double.parseDouble(line.substring(separator + WEIGHT_SEPARATOR.length()));
            int order = ngramStr.split(NGRAM_SEPARATOR).length;
            map.add(order, ngramStr, weight);
        }
        return map;
    }

    /** An outgoing edge: the target node and the relation labelling the edge. */
    private static class Edge {
        private final SemanticUnit _node2;
        private final DiscourseRelation _relation;

        public Edge(SemanticUnit node2, DiscourseRelation relation) {
            this._node2 = node2;
            this._relation = relation;
        }

        /** Returns the node this edge leads to. */
        public SemanticUnit getNode2() {
            return this._node2;
        }

        /** Returns the discourse relation labelling this edge. */
        public DiscourseRelation getRelation() {
            return this._relation;
        }
    }
}

