package com.rapidminer.operator.learner.tree;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.tools.math.MathFunctions;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;

/* JADX WARN: Classes with same name are omitted:
  input_file:builds/deps.jar:com/rapidminer/operator/learner/tree/PessimisticPruner.class
  input_file:builds/deps.jar:rapidMiner.jar:com/rapidminer/operator/learner/tree/PessimisticPruner.class
  input_file:com/rapidminer/operator/learner/tree/PessimisticPruner.class
 */
/* loaded from: input_file:rapidMiner.jar:com/rapidminer/operator/learner/tree/PessimisticPruner.class */
/**
 * Bottom-up pessimistic-error pruner for decision trees.
 *
 * <p>Subtrees are pruned from the leaves upwards. An inner node whose children are all leaves is
 * collapsed into a single leaf when either
 * <ul>
 *   <li>all of its leaves carry the same label, or</li>
 *   <li>the pessimistic error estimate of the node as a single leaf (predicting the mode of its
 *       training labels), reduced by {@link #PRUNE_PREFERENCE}, is not larger than the
 *       size-weighted sum of the pessimistic error estimates of its leaves.</li>
 * </ul>
 * The root passed to {@link #prune(Tree)} is never collapsed itself; only its descendants are.
 */
public class PessimisticPruner implements Pruner {

    /** Tolerance subtracted from the collapsed-node estimate so that near-ties favour pruning. */
    private static final double PRUNE_PREFERENCE = 0.001d;

    /** Confidence level handed to {@link #pessimisticErrors(double, double, double)}. */
    private final double confidenceLevel;

    /** Turns a collapsed inner node into a leaf. */
    private final LeafCreator leafCreator;

    /**
     * Creates a pruner.
     *
     * @param confidenceLevel confidence level used for the pessimistic error estimates
     * @param leafCreator     strategy used to convert a collapsed inner node into a leaf
     */
    public PessimisticPruner(double confidenceLevel, LeafCreator leafCreator) {
        this.confidenceLevel = confidenceLevel;
        this.leafCreator = leafCreator;
    }

    /**
     * Prunes every subtree below {@code root}. The root node itself is left as it is.
     *
     * @param root the tree to prune in place
     */
    @Override
    public void prune(Tree root) {
        Iterator<Edge> edges = root.childIterator();
        while (edges.hasNext()) {
            pruneSubtree(edges.next().getChild());
        }
    }

    /**
     * Recursively prunes the children of {@code node} first, then decides whether {@code node}
     * itself should be collapsed into a leaf.
     */
    private void pruneSubtree(Tree node) {
        if (node.isLeaf()) {
            return;
        }
        Iterator<Edge> edges = node.childIterator();
        while (edges.hasNext()) {
            pruneSubtree(edges.next().getChild());
        }
        // Only nodes whose children are all leaves are candidates for collapsing.
        if (childrenHaveChildren(node)) {
            return;
        }

        ExampleSet nodeSet = node.getTrainingSet();
        double nodeSize = nodeSet.size();

        double leavesErrorEstimate = 0.0d;
        Set<String> leafLabels = new HashSet<String>();
        Iterator<Edge> leafEdges = node.childIterator();
        while (leafEdges.hasNext()) {
            Tree leaf = leafEdges.next().getChild();
            leafLabels.add(leaf.getLabel());
            ExampleSet leafSet = leaf.getTrainingSet();
            int leafSize = leafSet.size();
            if (leafSize == 0) {
                // An empty leaf has zero weight; skipping it avoids 0/0 (NaN) poisoning the sum.
                continue;
            }
            double labelIndex =
                    leafSet.getAttributes().getLabel().getMapping().getIndex(leaf.getLabel());
            double leafErrorRate = errorRate(leafSet, labelIndex);
            // Floating-point weight: integer division here would truncate to 0.
            double leafWeight = leafSize / nodeSize;
            leavesErrorEstimate +=
                    pessimisticErrors(leafSize, leafErrorRate, this.confidenceLevel) * leafWeight;
        }

        // All leaves predict the same label: the split is pointless, collapse unconditionally.
        if (leafLabels.size() <= 1) {
            collapseToLeaf(node, nodeSet);
            return;
        }

        double majorityLabel = prunedLabel(nodeSet);
        double nodeErrorEstimate =
                pessimisticErrors(nodeSize, errorRate(nodeSet, majorityLabel), this.confidenceLevel);
        if (nodeErrorEstimate - PRUNE_PREFERENCE <= leavesErrorEstimate) {
            collapseToLeaf(node, nodeSet);
        }
    }

    /** Removes all children of {@code node} and turns it into a leaf built from {@code trainingSet}. */
    private void collapseToLeaf(Tree node, ExampleSet trainingSet) {
        node.removeChildren();
        this.leafCreator.changeTreeToLeaf(node, trainingSet);
    }

    /** Returns {@code true} if at least one direct child of {@code tree} is not a leaf. */
    private boolean childrenHaveChildren(Tree tree) {
        Iterator<Edge> edges = tree.childIterator();
        while (edges.hasNext()) {
            if (!edges.next().getChild().isLeaf()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the fraction of examples whose label value differs from {@code labelValue}, or 0 for
     * an empty example set.
     */
    private double errorRate(ExampleSet exampleSet, double labelValue) {
        int size = exampleSet.size();
        if (size == 0) {
            return 0.0d;
        }
        return (double) getErrorNumber(exampleSet, labelValue) / size;
    }

    /** Counts the examples whose label value differs from {@code labelValue}. */
    private int getErrorNumber(ExampleSet exampleSet, double labelValue) {
        int errors = 0;
        Iterator<Example> examples = exampleSet.iterator();
        while (examples.hasNext()) {
            if (examples.next().getLabel() != labelValue) {
                errors++;
            }
        }
        return errors;
    }

    /**
     * Returns the label value a leaf replacing the node would predict: the {@code "mode"} statistic
     * of the label attribute. Note that this recalculates the label attribute's statistics on
     * {@code exampleSet} as a side effect.
     *
     * @param exampleSet the training examples of the node
     * @return the mode of the label attribute
     */
    public double prunedLabel(ExampleSet exampleSet) {
        Attribute label = exampleSet.getAttributes().getLabel();
        exampleSet.recalculateAttributeStatistics(label);
        return exampleSet.getStatistics(label, "mode");
    }

    /**
     * Computes a pessimistic error estimate for a node.
     *
     * <p>Three regimes are distinguished:
     * <ul>
     *   <li>{@code errorRate < 1e-6}: {@code errorRate + n * (1 - confidence^(1/n))};</li>
     *   <li>{@code errorRate + 0.5 >= n}: {@code errorRate + 0.67 * (n - errorRate)};</li>
     *   <li>otherwise {@code n} times an upper bound derived from the normal quantile
     *       {@code z = normalInverse(1 - confidence)} with a {@code +0.5} correction.</li>
     * </ul>
     *
     * <p>NOTE(review): callers in this class pass the observed error <em>rate</em> as the second
     * argument, while the {@code +0.5} correction and the {@code errorRate + 0.5 >= n} test read as
     * if an error <em>count</em> was intended — confirm the intended semantics before changing.
     *
     * @param numberOfExamples number of examples {@code n} covered by the node
     * @param errorRate        observed error value as supplied by the caller (see note above)
     * @param confidence       confidence level
     * @return the pessimistic error estimate
     */
    public double pessimisticErrors(double numberOfExamples, double errorRate, double confidence) {
        if (errorRate < 1.0E-6d) {
            return errorRate
                    + numberOfExamples * (1.0d - Math.exp(Math.log(confidence) / numberOfExamples));
        }
        if (errorRate + 0.5d >= numberOfExamples) {
            return errorRate + 0.67d * (numberOfExamples - errorRate);
        }
        double z = MathFunctions.normalInverse(1.0d - confidence);
        double zSquared = z * z;
        double corrected = errorRate + 0.5d;
        double upperBound =
                (corrected + zSquared / 2.0d
                        + Math.sqrt(zSquared
                                * (corrected * (1.0d - corrected / numberOfExamples)
                                        + zSquared / 4.0d)))
                        / (numberOfExamples + zSquared);
        return numberOfExamples * upperBound;
    }
}
