/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.coref.statistical;

import edu.stanford.nlp.coref.data.Dictionaries;
import edu.stanford.nlp.coref.statistical.CompressedFeatureVector;
import edu.stanford.nlp.coref.statistical.Compressor;
import edu.stanford.nlp.coref.statistical.DocumentExamples;
import edu.stanford.nlp.coref.statistical.Example;
import edu.stanford.nlp.coref.statistical.MaxMarginMentionRanker;
import edu.stanford.nlp.coref.statistical.PairwiseModel;
import edu.stanford.nlp.coref.statistical.StatisticalCorefTrainer;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

/**
 * Training and evaluation drivers for the statistical coreference pairwise models.
 *
 * <p>Every entry point reads the serialized feature {@link Compressor} from
 * {@code StatisticalCorefTrainer.compressorFile} and a serialized {@code List} of
 * {@link DocumentExamples} from {@code StatisticalCorefTrainer.extractedFeaturesFile}.
 */
public class PairwiseModelTrainer {
    /** Redwood channel used for all progress logging in this class. */
    private static final String LOG_CHANNEL = "scoref-train";

    /**
     * Trains {@code model} as a mention ranker for {@code model.getNumEpochs()} epochs and then
     * writes it with {@link PairwiseModel#writeModel()}.
     *
     * <p>Each epoch shuffles the documents. Within a document, the examples are grouped by
     * {@code mentionId2} (the anaphor), the groups are shuffled, and one update is made per
     * non-empty group. A {@link MaxMarginMentionRanker} uses the cost-sensitive update in
     * {@link #maxMarginUpdate}; any other model uses {@link #highestScoringPairUpdate}.
     *
     * @param model the model to train
     * @throws Exception if the compressor or training data cannot be read, or the model cannot be
     *     written
     */
    public static void trainRanking(PairwiseModel model) throws Exception {
        Redwood.log(LOG_CHANNEL, "Reading compression...");
        Compressor<String> compressor = readCompressor();
        Redwood.log(LOG_CHANNEL, "Reading train data...");
        List<DocumentExamples> trainDocuments = readDocuments();
        Redwood.log(LOG_CHANNEL, "Training...");
        for (int i = 0; i < model.getNumEpochs(); ++i) {
            Collections.shuffle(trainDocuments);
            int j = 0;
            for (DocumentExamples doc : trainDocuments) {
                Redwood.log(LOG_CHANNEL, "On epoch: " + i + " / " + model.getNumEpochs()
                        + ", document: " + ++j + " / " + trainDocuments.size());

                // Group the candidate-antecedent examples by the mention they are candidates for.
                Map<Integer, List<Example>> mentionToPotentialAntecedents = new HashMap<>();
                for (Example example : doc.examples) {
                    mentionToPotentialAntecedents
                            .computeIfAbsent(example.mentionId2, k -> new ArrayList<>())
                            .add(example);
                }
                List<List<Example>> groups = new ArrayList<>(mentionToPotentialAntecedents.values());
                Collections.shuffle(groups);

                for (List<Example> candidates : groups) {
                    if (candidates.isEmpty()) {
                        continue;
                    }
                    if (model instanceof MaxMarginMentionRanker) {
                        maxMarginUpdate((MaxMarginMentionRanker) model, candidates,
                                doc.mentionFeatures, compressor);
                    } else {
                        highestScoringPairUpdate(model, candidates, doc.mentionFeatures, compressor);
                    }
                }
            }
        }
        Redwood.log(LOG_CHANNEL, "Writing models...");
        model.writeModel();
    }

    /**
     * Performs one cost-sensitive max-margin update for a single anaphor.
     *
     * <p>An extra candidate {@code new Example(candidates.get(0), noAntecedent)} is appended to
     * {@code candidates}; this mutates the caller's list. {@code noAntecedent} is true when every
     * original candidate has label 0.0. Judging by the assertions below, that appended example is
     * presumably the "new link" candidate, positive exactly when the mention has no antecedent.
     * TODO confirm against {@code Example}.
     *
     * <p>Two candidates are selected:
     * <ul>
     *   <li>the highest-scoring positive (label 1.0);</li>
     *   <li>the non-positive with the highest cost-adjusted score.</li>
     * </ul>
     * The error type of a non-positive candidate is:
     * <ul>
     *   <li>{@code FL} if the mention has no antecedent and the candidate is not a new link;</li>
     *   <li>{@code FN} or {@code FN_PRON} (for pronominal anaphors) if the mention has an
     *       antecedent but the candidate is a new link;</li>
     *   <li>{@code WL} otherwise.</li>
     * </ul>
     * Both selected candidates and the error type are passed to {@link MaxMarginMentionRanker#learn}.
     */
    private static void maxMarginUpdate(MaxMarginMentionRanker ranker, List<Example> candidates,
            Map<Integer, CompressedFeatureVector> mentionFeatures, Compressor<String> compressor) {
        boolean noAntecedent = candidates.stream().allMatch(e -> e.label == 0.0);
        candidates.add(new Example(candidates.get(0), noAntecedent));

        double maxPositiveScore = -Double.MAX_VALUE;
        Example maxScoringPositive = null;
        for (Example e : candidates) {
            double score = ranker.predict(e, mentionFeatures, compressor);
            if (e.label != 1.0) {
                continue;
            }
            assert (!noAntecedent ^ e.isNewLink());
            if (score > maxPositiveScore) {
                maxPositiveScore = score;
                maxScoringPositive = e;
            }
        }
        assert (maxScoringPositive != null);

        double maxNegativeScore = -Double.MAX_VALUE;
        Example maxScoringNegative = null;
        MaxMarginMentionRanker.ErrorType maxScoringEt = null;
        for (Example e : candidates) {
            double score = ranker.predict(e, mentionFeatures, compressor);
            if (e.label == 1.0) {
                continue;
            }
            assert (!noAntecedent || !e.isNewLink());
            MaxMarginMentionRanker.ErrorType et = MaxMarginMentionRanker.ErrorType.WL;
            if (noAntecedent && !e.isNewLink()) {
                et = MaxMarginMentionRanker.ErrorType.FL;
            } else if (!noAntecedent && e.isNewLink()) {
                et = e.mentionType2 == Dictionaries.MentionType.PRONOMINAL
                        ? MaxMarginMentionRanker.ErrorType.FN_PRON
                        : MaxMarginMentionRanker.ErrorType.FN;
            }
            // Apply the error-type cost either as a multiplicative slack-rescaling term or as an
            // additive margin, depending on the ranker's configuration.
            if (ranker.multiplicativeCost) {
                score = ranker.costs[et.id] * (1.0 - maxPositiveScore + score);
            } else {
                score += ranker.costs[et.id];
            }
            if (score > maxNegativeScore) {
                maxNegativeScore = score;
                maxScoringNegative = e;
                maxScoringEt = et;
            }
        }
        assert (maxScoringNegative != null);
        ranker.learn(maxScoringPositive, maxScoringNegative, mentionFeatures, compressor, maxScoringEt);
    }

    /**
     * Performs one pairwise ranking update for a single anaphor. It calls
     * {@code model.learn(bestPositive, bestNegative, mentionFeatures, compressor, 1.0)}, where:
     * <ul>
     *   <li>{@code bestPositive} is the highest-scoring candidate with label 1.0;</li>
     *   <li>{@code bestNegative} is the highest-scoring candidate with any other label.</li>
     * </ul>
     *
     * <p>NOTE(review): if the group has no positive or no non-positive candidate, the
     * corresponding argument is {@code null}. Confirm that {@link PairwiseModel#learn} tolerates
     * this.
     */
    private static void highestScoringPairUpdate(PairwiseModel model, List<Example> candidates,
            Map<Integer, CompressedFeatureVector> mentionFeatures, Compressor<String> compressor) {
        double maxPositiveScore = -Double.MAX_VALUE;
        double maxNegativeScore = -Double.MAX_VALUE;
        Example maxScoringPositive = null;
        Example maxScoringNegative = null;
        for (Example e : candidates) {
            double score = model.predict(e, mentionFeatures, compressor);
            if (e.label == 1.0) {
                if (score > maxPositiveScore) {
                    maxPositiveScore = score;
                    maxScoringPositive = e;
                }
            } else if (score > maxNegativeScore) {
                maxNegativeScore = score;
                maxScoringNegative = e;
            }
        }
        model.learn(maxScoringPositive, maxScoringNegative, mentionFeatures, compressor, 1.0);
    }

    /**
     * Builds one anaphoricity example per mention.
     *
     * <p>For each document, a mention ({@code mentionId2}) is anaphoric if any of its examples has
     * label 1.0. The example created for a mention is {@code new Example(first, isAnaphoric)},
     * where {@code first} is that mention's first example in document order. It is paired with
     * the document's mention features. The number of positive and total mentions is logged.
     *
     * <p>This method <b>consumes</b> {@code documents}: documents are removed from the end of the
     * list as they are processed, so the list is empty on return.
     *
     * @param documents the documents to convert; emptied by this call
     * @return the anaphoricity examples, each paired with its document's mention features
     */
    public static List<Pair<Example, Map<Integer, CompressedFeatureVector>>> getAnaphoricityExamples(
            List<DocumentExamples> documents) {
        int p = 0;
        int t = 0;
        List<Pair<Example, Map<Integer, CompressedFeatureVector>>> examples = new ArrayList<>();
        while (!documents.isEmpty()) {
            DocumentExamples doc = documents.remove(documents.size() - 1);

            Map<Integer, Boolean> areAnaphoric = new HashMap<>();
            for (Example example : doc.examples) {
                areAnaphoric.putIfAbsent(example.mentionId2, false);
                if (example.label == 1.0) {
                    areAnaphoric.put(example.mentionId2, true);
                }
            }

            for (Map.Entry<Integer, Boolean> entry : areAnaphoric.entrySet()) {
                if (entry.getValue()) {
                    ++p;
                }
                ++t;
            }

            // Remove each mention after its first example so it yields exactly one anaphoricity
            // example.
            for (Example example : doc.examples) {
                Boolean isAnaphoric = areAnaphoric.remove(example.mentionId2);
                if (isAnaphoric != null) {
                    examples.add(new Pair<>(new Example(example, isAnaphoric), doc.mentionFeatures));
                }
            }
        }
        Redwood.log(LOG_CHANNEL, "Num anaphoricity examples " + p + " positive, " + t + " total");
        return examples;
    }

    /**
     * Flattens every pairwise example into a list, pairing each with its document's mention
     * features.
     *
     * <p>This method <b>consumes</b> {@code documents}: documents are removed from the end of the
     * list as they are processed, so the list is empty on return.
     *
     * @param documents the documents to flatten; emptied by this call
     * @return all pairwise examples, each paired with its document's mention features
     */
    public static List<Pair<Example, Map<Integer, CompressedFeatureVector>>> getExamples(
            List<DocumentExamples> documents) {
        List<Pair<Example, Map<Integer, CompressedFeatureVector>>> examples = new ArrayList<>();
        while (!documents.isEmpty()) {
            DocumentExamples doc = documents.remove(documents.size() - 1);
            Map<Integer, CompressedFeatureVector> mentionFeatures = doc.mentionFeatures;
            for (Example e : doc.examples) {
                examples.add(new Pair<>(e, mentionFeatures));
            }
        }
        return examples;
    }

    /**
     * Trains {@code model} as a classifier for exactly {@code model.getNumTrainingExamples()}
     * single-example updates, then writes it with {@link PairwiseModel#writeModel()}.
     *
     * <p>The training set consists of anaphoricity examples when {@code anaphoricityModel} is
     * true, and pairwise examples otherwise. The set is reshuffled with a fixed-seed
     * {@link Random} before each pass over it, and passes repeat until the required number of
     * updates has been made.
     *
     * @param model the model to train
     * @param anaphoricityModel whether to train on anaphoricity examples instead of pairwise
     *     examples
     * @throws IllegalStateException if updates are requested but no training examples exist
     *     (previously this looped forever)
     * @throws Exception if the compressor or training data cannot be read, or the model cannot be
     *     written
     */
    public static void trainClassification(PairwiseModel model, boolean anaphoricityModel)
            throws Exception {
        int numTrainingExamples = model.getNumTrainingExamples();
        Redwood.log(LOG_CHANNEL, "Reading compression...");
        Compressor<String> compressor = readCompressor();
        Redwood.log(LOG_CHANNEL, "Reading train data...");
        List<DocumentExamples> trainDocuments = readDocuments();
        Redwood.log(LOG_CHANNEL, "Building train set...");
        List<Pair<Example, Map<Integer, CompressedFeatureVector>>> allExamples = anaphoricityModel
                ? getAnaphoricityExamples(trainDocuments)
                : getExamples(trainDocuments);
        if (numTrainingExamples > 0 && allExamples.isEmpty()) {
            throw new IllegalStateException("No training examples found in "
                    + StatisticalCorefTrainer.extractedFeaturesFile + "; cannot train "
                    + numTrainingExamples + " examples");
        }
        Redwood.log(LOG_CHANNEL, "Training...");
        Random random = new Random(0L);
        int i = 0;
        while (i < numTrainingExamples) {
            Collections.shuffle(allExamples, random);
            for (Pair<Example, Map<Integer, CompressedFeatureVector>> pair : allExamples) {
                if (i >= numTrainingExamples) {
                    break;
                }
                ++i;
                if (i % 10000 == 0) {
                    Redwood.log(LOG_CHANNEL, String.format("On train example %d/%d = %.2f%%",
                            i, numTrainingExamples, 100.0 * (double) i / (double) numTrainingExamples));
                }
                model.learn(pair.first, pair.second, compressor);
            }
        }
        Redwood.log(LOG_CHANNEL, "Writing models...");
        model.writeModel();
    }

    /**
     * Scores examples with {@code model} and writes the predictions.
     *
     * <p>The examples are anaphoricity examples when {@code anaphoricityModel} is true, and
     * pairwise examples otherwise. One line per example is written to
     * {@code model.getDefaultOutputPath() + predictionsName}. For a
     * {@link MaxMarginMentionRanker}, anaphoricity scores are also written to
     * {@code ... + "_anaphoricity"}; the data is re-read for this because building the first
     * example list consumes it. All scores, keyed by document id and (mentionId1, mentionId2),
     * are then serialized to {@code ... + ".ser"}.
     *
     * <p>NOTE(review): this reads {@code StatisticalCorefTrainer.extractedFeaturesFile}, the same
     * path used for training. Presumably the caller points it at test data first; verify.
     *
     * @param model the model to evaluate
     * @param predictionsName suffix appended to the model's default output path
     * @param anaphoricityModel whether to score anaphoricity examples instead of pairwise examples
     * @throws Exception if data cannot be read or outputs cannot be written
     */
    public static void test(PairwiseModel model, String predictionsName, boolean anaphoricityModel)
            throws Exception {
        Redwood.log(LOG_CHANNEL, "Reading compression...");
        Compressor<String> compressor = readCompressor();
        Redwood.log(LOG_CHANNEL, "Reading test data...");
        List<DocumentExamples> testDocuments = readDocuments();
        Redwood.log(LOG_CHANNEL, "Building test set...");
        List<Pair<Example, Map<Integer, CompressedFeatureVector>>> allExamples = anaphoricityModel
                ? getAnaphoricityExamples(testDocuments)
                : getExamples(testDocuments);
        Redwood.log(LOG_CHANNEL, "Testing...");
        Map<Integer, Counter<Pair<Integer, Integer>>> scores = new HashMap<>();
        try (PrintWriter writer = new PrintWriter(model.getDefaultOutputPath() + predictionsName)) {
            writeScores(allExamples, compressor, model, writer, scores);
        }
        if (model instanceof MaxMarginMentionRanker) {
            try (PrintWriter writer = new PrintWriter(
                    model.getDefaultOutputPath() + predictionsName + "_anaphoricity")) {
                // The first pass consumed the document list, so read it again.
                List<DocumentExamples> anaphoricityDocuments = readDocuments();
                writeScores(getAnaphoricityExamples(anaphoricityDocuments), compressor, model,
                        writer, scores);
            }
        }
        IOUtils.writeObjectToFile(scores, model.getDefaultOutputPath() + predictionsName + ".ser");
    }

    /**
     * Scores each example with {@code model} and records the results.
     *
     * <p>For each example, the line {@code "docId mentionId1,mentionId2 score label"} is written
     * to {@code writer}. The score is also added to {@code scores}, under the example's
     * {@code docId} and the key {@code (mentionId1, mentionId2)}; a counter is created for a
     * document the first time it is seen. Progress is logged every 10000 examples.
     *
     * @param examples the examples to score, each paired with its document's mention features
     * @param compressor the feature compressor passed to {@code model.predict}
     * @param model the model used for scoring
     * @param writer destination for one prediction line per example
     * @param scores accumulator of per-document scores; updated in place
     */
    public static void writeScores(List<Pair<Example, Map<Integer, CompressedFeatureVector>>> examples,
            Compressor<String> compressor, PairwiseModel model, PrintWriter writer,
            Map<Integer, Counter<Pair<Integer, Integer>>> scores) {
        int i = 0;
        for (Pair<Example, Map<Integer, CompressedFeatureVector>> pair : examples) {
            if (i++ % 10000 == 0) {
                Redwood.log(LOG_CHANNEL, String.format("On test example %d/%d = %.2f%%",
                        i, examples.size(), 100.0 * (double) i / (double) examples.size()));
            }
            Example example = pair.first;
            Map<Integer, CompressedFeatureVector> mentionFeatures = pair.second;
            double p = model.predict(example, mentionFeatures, compressor);
            writer.println(example.docId + " " + example.mentionId1 + "," + example.mentionId2
                    + " " + p + " " + example.label);
            scores.computeIfAbsent(example.docId, k -> new ClassicCounter<>())
                    .incrementCount(new Pair<>(example.mentionId1, example.mentionId2), p);
        }
    }

    /** Deserializes the feature compressor from {@code StatisticalCorefTrainer.compressorFile}. */
    @SuppressWarnings("unchecked")
    private static Compressor<String> readCompressor() throws Exception {
        return (Compressor<String>) IOUtils.readObjectFromFile(StatisticalCorefTrainer.compressorFile);
    }

    /**
     * Deserializes the document list from {@code StatisticalCorefTrainer.extractedFeaturesFile}.
     */
    @SuppressWarnings("unchecked")
    private static List<DocumentExamples> readDocuments() throws Exception {
        return (List<DocumentExamples>) IOUtils.readObjectFromFile(
                StatisticalCorefTrainer.extractedFeaturesFile);
    }
}

