/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.tree;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.learner.tree.Edge;
import com.rapidminer.operator.learner.tree.LeafCreator;
import com.rapidminer.operator.learner.tree.Pruner;
import com.rapidminer.operator.learner.tree.Tree;
import com.rapidminer.tools.math.MathFunctions;
import java.util.HashSet;
import java.util.Iterator;

/**
 * {@link Pruner} that walks a tree bottom-up and collapses an inner node into a
 * leaf whenever the pessimistic error estimate of the collapsed node is not
 * worse (within {@link #PRUNE_PREFERENCE}) than the weighted pessimistic error
 * estimate of its leaf children.
 *
 * <p>Only nodes whose children are all leaves are considered; because children
 * are processed before their parent, a collapse lower in the tree can make the
 * parent eligible within the same pass.
 */
public class PessimisticPruner
implements Pruner {
    /** Slack subtracted from the collapsed node's estimate so that near-ties are decided in favour of pruning. */
    private static final double PRUNE_PREFERENCE = 0.001;

    /** Confidence level handed to {@link #pessimisticErrors(double, double, double)}. */
    private final double confidenceLevel;

    /** Used to turn a pruned inner node into a leaf. */
    private final LeafCreator leafCreator;

    /**
     * @param confidenceLevel confidence level used for the pessimistic error estimate
     * @param leafCreator     helper that converts a node into a leaf once its children were removed
     */
    public PessimisticPruner(double confidenceLevel, LeafCreator leafCreator) {
        this.confidenceLevel = confidenceLevel;
        this.leafCreator = leafCreator;
    }

    /**
     * Prunes every subtree below {@code root}. The root itself is never
     * collapsed; only its descendants are candidates.
     *
     * @param root the root of the tree to prune in place
     */
    public void prune(Tree root) {
        Iterator<Edge> childIterator = root.childIterator();
        while (childIterator.hasNext()) {
            this.pruneChild(childIterator.next().getChild());
        }
    }

    /**
     * Recursively prunes the subtree rooted at {@code currentNode} (children
     * first) and then decides whether {@code currentNode} itself should be
     * replaced by a leaf.
     */
    private void pruneChild(Tree currentNode) {
        if (currentNode.isLeaf()) {
            return;
        }

        // Post-order: prune the children first so that this node may become
        // a "leaves only" node and thereby a pruning candidate itself.
        Iterator<Edge> childIterator = currentNode.childIterator();
        while (childIterator.hasNext()) {
            this.pruneChild(childIterator.next().getChild());
        }
        if (this.childrenHaveChildren(currentNode)) {
            return;
        }

        ExampleSet currentNodeExampleSet = currentNode.getTrainingSet();
        int nodeExamples = currentNodeExampleSet.size();

        double leafsErrorEstimate = 0.0;
        HashSet<String> classSet = new HashSet<String>();
        childIterator = currentNode.childIterator();
        while (childIterator.hasNext()) {
            Tree leafNode = childIterator.next().getChild();
            classSet.add(leafNode.getLabel());

            ExampleSet leafExampleSet = leafNode.getTrainingSet();
            int examples = leafExampleSet.size();
            if (examples == 0) {
                // An empty leaf has weight 0. Computing its error rate would be
                // 0 / 0 = NaN, which would poison the sum and make the pruning
                // comparison below always false.
                continue;
            }
            double leafLabelIndex = leafExampleSet.getAttributes().getLabel().getMapping().getIndex(leafNode.getLabel());
            double currentErrorRate = (double)this.getErrorNumber(leafExampleSet, leafLabelIndex) / (double)examples;
            // Each leaf contributes proportionally to its share of the node's examples.
            leafsErrorEstimate += this.pessimisticErrors(examples, currentErrorRate, this.confidenceLevel) * ((double)examples / (double)nodeExamples);
        }

        if (classSet.size() <= 1) {
            // All leaves predict the same class: the split is redundant.
            this.collapseToLeaf(currentNode, currentNodeExampleSet);
            return;
        }

        double currentNodeLabel = this.prunedLabel(currentNodeExampleSet);
        double nodeErrorRate = (double)this.getErrorNumber(currentNodeExampleSet, currentNodeLabel) / (double)nodeExamples;
        double nodeErrorEstimate = this.pessimisticErrors(nodeExamples, nodeErrorRate, this.confidenceLevel);
        if (nodeErrorEstimate - PRUNE_PREFERENCE <= leafsErrorEstimate) {
            this.collapseToLeaf(currentNode, currentNodeExampleSet);
        }
    }

    /** Removes all children of {@code node} and converts it into a leaf built from {@code exampleSet}. */
    private void collapseToLeaf(Tree node, ExampleSet exampleSet) {
        node.removeChildren();
        this.leafCreator.changeTreeToLeaf(node, exampleSet);
    }

    /** Returns {@code true} if at least one direct child of {@code node} is not a leaf. */
    private boolean childrenHaveChildren(Tree node) {
        Iterator<Edge> iterator = node.childIterator();
        while (iterator.hasNext()) {
            if (!iterator.next().getChild().isLeaf()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Counts the examples whose label value differs from {@code label}.
     * The comparison is an exact {@code double} comparison, as label values
     * are passed around as numeric indices here.
     */
    private int getErrorNumber(ExampleSet exampleSet, double label) {
        int errors = 0;
        for (Example example : exampleSet) {
            if (example.getLabel() != label) {
                ++errors;
            }
        }
        return errors;
    }

    /**
     * Recalculates the statistics of the label attribute of {@code exampleSet}
     * and returns its {@code "mode"} statistic, which is used as the label a
     * collapsed node would predict.
     *
     * @param exampleSet the examples reaching the node under consideration
     * @return the value of the {@code "mode"} statistic of the label attribute
     *         (presumably the mapping index of the most frequent class - confirm against ExampleSet)
     */
    public double prunedLabel(ExampleSet exampleSet) {
        Attribute labelAttribute = exampleSet.getAttributes().getLabel();
        exampleSet.recalculateAttributeStatistics(labelAttribute);
        return exampleSet.getStatistics(labelAttribute, "mode");
    }

    /**
     * Computes a pessimistic (upper-bound style) error estimate.
     *
     * <p>Three cases are distinguished:
     * <ul>
     * <li>{@code errorRate < 1e-6}: closed form
     *     {@code errorRate + n * (1 - confidenceLevel^(1/n))};</li>
     * <li>{@code errorRate + 0.5 >= n}: {@code errorRate + 0.67 * (n - errorRate)};</li>
     * <li>otherwise: a continuity-corrected normal approximation using the
     *     squared result of {@code MathFunctions.normalInverse(1 - confidenceLevel)}.</li>
     * </ul>
     *
     * <p>NOTE(review): the callers in this class pass an error <em>fraction</em>
     * (errors / examples) as {@code errorRate}, while the {@code + 0.5} correction
     * and the comparison against {@code numberOfExamples} look like they were
     * designed for an error <em>count</em>. Behaviour is kept as is - confirm intent.
     *
     * @param numberOfExamples number of examples covered by the node
     * @param errorRate        observed error of the node (see note above)
     * @param confidenceLevel  confidence level of the estimate
     * @return the pessimistic error estimate
     */
    public double pessimisticErrors(double numberOfExamples, double errorRate, double confidenceLevel) {
        if (errorRate < 1.0E-6) {
            return errorRate + numberOfExamples * (1.0 - Math.exp(Math.log(confidenceLevel) / numberOfExamples));
        }
        double correctedError = errorRate + 0.5;
        if (correctedError >= numberOfExamples) {
            return errorRate + 0.67 * (numberOfExamples - errorRate);
        }
        double coefficient = MathFunctions.normalInverse(1.0 - confidenceLevel);
        coefficient *= coefficient;
        double pessimisticRate = (correctedError + coefficient / 2.0 + Math.sqrt(coefficient * (correctedError * (1.0 - correctedError / numberOfExamples) + coefficient / 4.0))) / (numberOfExamples + coefficient);
        return numberOfExamples * pessimisticRate;
    }
}

