/*
 * Decompiled with CFR 0.152.
 */
package ca.pfv.spmf.algorithms.classifiers.adt;

import ca.pfv.spmf.algorithms.ArraysAlgos;
import ca.pfv.spmf.algorithms.classifiers.adt.ADNode;
import ca.pfv.spmf.algorithms.classifiers.adt.RuleADT;
import ca.pfv.spmf.algorithms.classifiers.data.Dataset;
import ca.pfv.spmf.algorithms.classifiers.data.Instance;
import ca.pfv.spmf.algorithms.classifiers.general.Rule;
import ca.pfv.spmf.algorithms.classifiers.general.RuleClassifier;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;

/**
 * Rule-based classifier registered under the name {@code "ADT"}.
 *
 * <p>Construction turns a list of candidate rules into the final classifier rule list:
 * <ol>
 *   <li>the candidates are ordered (see {@link #compareRules});</li>
 *   <li>candidates judged redundant with respect to an earlier kept rule are dropped;</li>
 *   <li>each training instance is assigned to the first remaining rule that matches it,
 *       counting a hit or a miss depending on whether the classes agree;</li>
 *   <li>the rules are arranged into a tree of {@link ADNode}s rooted at a default rule
 *       that predicts the most frequent training class;</li>
 *   <li>the tree is pruned bottom-up using pessimistic error estimates;</li>
 *   <li>the rules whose merit is at least {@code minMerit} are collected.</li>
 * </ol>
 */
public class ClassifierADT
extends RuleClassifier
implements Serializable {
    private static final long serialVersionUID = 8240202223112688265L;
    /** Training data used to assign covered instances and to choose the default class. */
    private final Dataset training;
    /** Minimum merit a rule must reach to be kept in the final rule list. */
    private final double minMerit;

    /**
     * Builds the classifier from candidate rules and the training data.
     *
     * @param rules    candidate rules. The list itself is copied before sorting and is not
     *                 modified. The {@link RuleADT} objects it contains are updated: covered
     *                 instances, hits and misses are recorded on them.
     * @param minMerit minimum merit a rule needs to appear in the final rule list
     * @param training training dataset used for coverage counts and the default rule
     */
    public ClassifierADT(List<RuleADT> rules, double minMerit, Dataset training) {
        super("ADT");
        this.training = training;
        this.minMerit = minMerit;
        // Sort a private copy so the caller's list is not reordered as a side effect.
        List<RuleADT> ordered = new ArrayList<RuleADT>(rules);
        Collections.sort(ordered, ClassifierADT::compareRules);
        List<RuleADT> candidates = this.removeRedundant(ordered);
        this.assignTrainingInstances(candidates);
        ADNode root = this.buildTree(candidates);
        this.prune(root);
        this.rules = this.transformTreeToRules(root);
    }

    /**
     * Total order used to rank candidate rules.
     *
     * <p>The keys, in order: higher confidence first, then higher support, then shorter
     * antecedent, then the antecedent items compared position by position in ascending
     * order, and finally the class value in ascending order.
     *
     * @param first  first rule
     * @param second second rule
     * @return a negative value if {@code first} ranks before {@code second}, a positive
     *         value if it ranks after, and 0 if they tie on every key
     */
    private static int compareRules(RuleADT first, RuleADT second) {
        // Arguments are swapped so that larger values come first (descending order).
        int byConfidence = Double.compare(second.getConfidence(), first.getConfidence());
        if (byConfidence != 0) {
            return byConfidence;
        }
        int bySupport = Double.compare(second.getSupportRule(), first.getSupportRule());
        if (bySupport != 0) {
            return bySupport;
        }
        int bySize = Integer.compare(first.size(), second.size());
        if (bySize != 0) {
            return bySize;
        }
        // Sizes are equal here, so indexing both antecedents up to first.size() is safe.
        for (int i = 0; i < first.size(); i++) {
            short x = first.getAntecedent().get(i);
            short y = second.getAntecedent().get(i).shortValue();
            int byItem = Integer.compare(x, y);
            if (byItem != 0) {
                return byItem;
            }
        }
        return Integer.compare(first.getKlass(), second.getKlass());
    }

    /**
     * Assigns each training instance to the first rule (in list order) that matches it.
     *
     * <p>The instance index is recorded as covered by that rule, and the rule's hit count
     * is incremented if its class equals the instance's class. Otherwise the rule's miss
     * count is incremented. Instances that match no rule are left unassigned.
     *
     * @param rules rules in priority order
     */
    private void assignTrainingInstances(List<RuleADT> rules) {
        for (int index = 0; index < this.training.getInstances().size(); index++) {
            Instance instance = this.training.getInstances().get(index);
            Short[] items = instance.getItems();
            for (RuleADT rule : rules) {
                if (!rule.matching(items)) {
                    continue;
                }
                rule.addCoveredInstance(index);
                if (rule.getKlass() == instance.getKlass().shortValue()) {
                    rule.incrementHits();
                } else {
                    rule.incrementMisses();
                }
                // Only the first matching rule receives the instance.
                break;
            }
        }
    }

    /**
     * Builds the rule tree rooted at the default (majority-class) rule.
     *
     * <p>Rules are inserted from the last to the first. Each rule descends from the root
     * through {@link ADNode#isChild} for as long as that call returns a node. It is then
     * attached as a new child of the deepest node reached.
     *
     * @param rules rules in priority order
     * @return the root node, which holds the default rule
     */
    private ADNode buildTree(List<RuleADT> rules) {
        ADNode root = new ADNode(this.extractDefaultRule());
        for (int m = rules.size() - 1; m >= 0; m--) {
            RuleADT rule = rules.get(m);
            ADNode attachPoint = root;
            ADNode deeper;
            while ((deeper = attachPoint.isChild(rule)) != null) {
                attachPoint = deeper;
            }
            ADNode node = new ADNode(rule);
            node.parent = attachPoint;
            attachPoint.childs.add(node);
        }
        return root;
    }

    /**
     * Flattens the tree into a rule list in post-order.
     *
     * <p>Children are visited in reverse insertion order. Each node's rule is appended
     * after the rules of its descendants, and only if its merit is at least
     * {@code minMerit}.
     *
     * @param node subtree root
     * @return the retained rules of the subtree
     */
    private List<Rule> transformTreeToRules(ADNode node) {
        ArrayList<Rule> collected = new ArrayList<Rule>();
        for (int i = node.childs.size() - 1; i >= 0; i--) {
            collected.addAll(this.transformTreeToRules(node.childs.get(i)));
        }
        if (node.rule.getMerit() >= this.minMerit) {
            collected.add(node.rule);
        }
        return collected;
    }

    /**
     * Prunes the subtree rooted at {@code node} bottom-up.
     *
     * <p>After the children have been pruned, a leaf candidate built from {@code node} is
     * evaluated with {@link #calculatePessimisticErrorEstimate}. Its estimate is compared
     * with the sum of the node's own estimate and the estimates of its direct children. If
     * the leaf estimate is strictly smaller, the children are dropped and the node takes
     * the leaf candidate's rule.
     *
     * @param node subtree root; {@code null} and leaf nodes are left unchanged
     */
    private void prune(ADNode node) {
        if (node == null || node.childs.isEmpty()) {
            return;
        }
        for (ADNode child : node.childs) {
            this.prune(child);
        }
        // NOTE(review): whether new ADNode(node) copies node.rule or shares it is not visible
        // here. If the rule is shared, the estimate below also mutates node.rule before
        // treeErrors is read. Confirm against ADNode.
        ADNode leafNode = new ADNode(node);
        double leafErrors = this.calculatePessimisticErrorEstimate(leafNode);
        double treeErrors = node.rule.getPessimisticErrorEstimate();
        for (ADNode child : node.childs) {
            treeErrors += child.rule.getPessimisticErrorEstimate();
        }
        if (leafErrors < treeErrors) {
            node.childs.clear();
            node.rule = leafNode.rule;
        }
    }

    /**
     * Folds the training instances covered by {@code node}'s children into {@code node}'s
     * own rule and returns that rule's pessimistic error estimate.
     *
     * <p>For every instance covered by a child whose items also match {@code node.rule},
     * the instance is added to the coverage of {@code node.rule}. It then counts as a hit
     * if the classes agree, and as a miss otherwise. This mutates {@code node.rule}.
     *
     * @param node node whose rule is evaluated as if it replaced its children
     * @return the pessimistic error estimate of {@code node.rule} after the update
     */
    private double calculatePessimisticErrorEstimate(ADNode node) {
        for (ADNode child : node.childs) {
            List<Integer> instances = child.rule.getCoveredInstances();
            for (Integer tid : instances) {
                Instance instance = this.training.getInstances().get(tid);
                Short[] items = instance.getItems();
                if (!node.rule.matching(items)) {
                    continue;
                }
                node.rule.addCoveredInstance(tid);
                if (node.rule.getKlass() == instance.getKlass().shortValue()) {
                    node.rule.incrementHits();
                } else {
                    node.rule.incrementMisses();
                }
            }
        }
        return node.rule.getPessimisticErrorEstimate();
    }

    /**
     * Creates the default rule, which predicts the most frequent class in the training data.
     *
     * <p>{@link Collections#max} throws {@link java.util.NoSuchElementException} when the
     * class-frequency map is empty.
     *
     * @return a rule with an empty antecedent predicting the majority class
     */
    private RuleADT extractDefaultRule() {
        short majorityKlass = (Short)Collections.max(this.training.getMapClassToFrequency().entrySet(), Comparator.comparingLong(Map.Entry::getValue)).getKey();
        return new RuleADT(majorityKlass);
    }

    /**
     * Keeps, in order, each rule that is not redundant with an already-kept rule.
     *
     * <p>A rule is dropped when
     * {@code ArraysAlgos.containsOrEquals(rule.getAntecedent(), kept.getAntecedent())}
     * holds for some previously kept rule. Presumably this means the rule's antecedent
     * contains or equals the kept rule's antecedent; confirm against {@code ArraysAlgos}.
     * The class value is not considered.
     *
     * @param rules rules in priority order
     * @return a new list holding the kept rules, in their original relative order
     */
    private List<RuleADT> removeRedundant(List<RuleADT> rules) {
        ArrayList<RuleADT> finalRules = new ArrayList<RuleADT>();
        for (RuleADT candidate : rules) {
            boolean redundant = false;
            for (RuleADT kept : finalRules) {
                if (ArraysAlgos.containsOrEquals(candidate.getAntecedent(), kept.getAntecedent())) {
                    redundant = true;
                    break;
                }
            }
            if (!redundant) {
                finalRules.add(candidate);
            }
        }
        return finalRules;
    }
}

