package com.rapidminer.operator.learner.rules;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.AbstractLearner;
import com.rapidminer.operator.learner.LearnerCapability;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeCategory;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.parameter.UndefinedParameterError;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Vector;
import org.apache.log4j.Priority;

/**
 * Level-wise search for the single conjunctive rule with the highest utility on a binominal label.
 *
 * <p>Starting from an initial rule, candidate rules are refined one level at a time, up to
 * {@code max_depth} levels. Every candidate carries an optimistic upper bound on the score that it or
 * any of its refinements could reach. Candidates whose bound does not exceed the current best score
 * are pruned, and refinements of pruned rules are skipped. Every evaluated rule is scored with both
 * the positive and the negative class as conclusion.
 *
 * <p>If more than {@code max_cache} candidates are open on a level, only the {@code max_cache}
 * candidates with the highest bounds are expanded. This keeps memory bounded but makes the search
 * incomplete.
 *
 * <p>Utility functions:
 * <ul>
 *   <li>weighted relative accuracy: {@code coverage * (precision - baseline)}</li>
 *   <li>binomial test function: {@code sqrt(coverage) * (precision - baseline)}</li>
 * </ul>
 * The baseline is the global class rate. When {@code relative_to_predictions} is enabled and a
 * predicted label exists, it is instead the confidence-weighted predicted class rate among the
 * examples the rule covers.
 *
 * <p>Search state is kept in instance fields while {@link #learn(ExampleSet)} runs, without any
 * synchronization, so one instance must not run several learn calls concurrently.
 */
public class BestRuleInduction extends AbstractLearner {
    private static final String PARAMETER_MAX_DEPTH = "max_depth";
    private static final String PARAMETER_UTILITY_FUNCTION = "utility_function";
    private static final String PARAMETER_MAX_CACHE = "max_cache";
    private static final String PARAMETER_RELATIVE_TO_PREDICTIONS = "relative_to_predictions";
    private static final String WRACC = "weighted relative accuracy";
    private static final String BINOMIAL = "binomial test function";
    private static final String[] UTILITY_FUNCTION_LIST = {WRACC, BINOMIAL};

    /**
     * Default for {@code max_cache}. The decompiled code referenced log4j's
     * {@code Priority.DEBUG_INT}, whose value is 10000, so the default is unchanged.
     */
    private static final int DEFAULT_MAX_CACHE = 10000;

    /** Weighted positive count covered by the initial rule built in {@link #learn(ExampleSet)}. */
    private double globalP;
    /** Weighted negative count covered by the initial rule built in {@link #learn(ExampleSet)}. */
    private double globalN;
    /** Best rule found so far; {@code null} until a rule has been scored. */
    protected ConjunctiveRuleModel bestRule;
    /** Score of {@link #bestRule}; negative infinity before any rule has been scored. */
    private double bestScore;
    /** Candidates still to be evaluated on the current/next level. */
    private final List<RuleWithScoreUpperBound> openNodes;
    /** Rules whose refinements can never beat the best score; refinements of these are skipped. */
    private final List<ConjunctiveRuleModel> prunedNodes;

    /**
     * Pairs a candidate rule with an optimistic upper bound on the score reachable by the rule or
     * any of its refinements.
     *
     * <p>The natural ordering compares score bounds only, whereas {@link #equals(Object)} compares
     * the wrapped rules. The ordering is therefore not consistent with equals.
     */
    public static class RuleWithScoreUpperBound implements Comparable<Object> {
        private final ConjunctiveRuleModel rule;
        private final double scoreUpperBound;

        /**
         * @param conjunctiveRuleModel the candidate rule
         * @param d optimistic upper bound of the score reachable from this rule
         */
        public RuleWithScoreUpperBound(ConjunctiveRuleModel conjunctiveRuleModel, double d) {
            this.rule = conjunctiveRuleModel;
            this.scoreUpperBound = d;
        }

        /** Returns the wrapped candidate rule. */
        public ConjunctiveRuleModel getRule() {
            return this.rule;
        }

        /** Returns the optimistic score bound of the wrapped rule. */
        public double getScoreBound() {
            return this.scoreUpperBound;
        }

        /**
         * Orders by ascending score bound. Objects of any other type are ordered by class name.
         *
         * <p>{@link Double#compare(double, double)} is used so that the result is a total order even
         * for NaN bounds. The previous {@code <}/{@code >} comparison treated NaN as equal to every
         * value, and {@link Arrays#sort(Object[])} may reject such an order.
         */
        @Override
        public int compareTo(Object obj) {
            if (!(obj instanceof RuleWithScoreUpperBound)) {
                return getClass().getName().compareTo(obj.getClass().getName());
            }
            return Double.compare(getScoreBound(), ((RuleWithScoreUpperBound) obj).getScoreBound());
        }

        /** Two entries are equal when they wrap equal rules; the score bound is ignored. */
        @Override
        public boolean equals(Object obj) {
            if (obj instanceof RuleWithScoreUpperBound) {
                return this.rule.equals(((RuleWithScoreUpperBound) obj).rule);
            }
            return false;
        }

        @Override
        public int hashCode() {
            return this.rule.hashCode();
        }
    }

    public BestRuleInduction(OperatorDescription operatorDescription) {
        super(operatorDescription);
        this.openNodes = new ArrayList<>();
        this.prunedNodes = new ArrayList<>();
    }

    /** Supports nominal/binominal attributes, a binominal label, and example weights. */
    @Override // com.rapidminer.operator.learner.Learner
    public boolean supportsCapability(LearnerCapability learnerCapability) {
        return learnerCapability == LearnerCapability.POLYNOMINAL_ATTRIBUTES
                || learnerCapability == LearnerCapability.BINOMINAL_ATTRIBUTES
                || learnerCapability == LearnerCapability.BINOMINAL_CLASS
                || learnerCapability == LearnerCapability.WEIGHTED_EXAMPLES;
    }

    /** Resets the best rule and best score before a new search. */
    protected void initHighscore() {
        this.bestRule = null;
        this.bestScore = Double.NEGATIVE_INFINITY;
    }

    /**
     * Scores a rule with the positive and with the negative conclusion and records whichever beats
     * the current best score.
     *
     * <p>If the negative conclusion scores better, a new model is stored: one built from the rule
     * with the label's negative index as conclusion.
     *
     * @param conjunctiveRuleModel the rule to score
     * @param dArr the counts returned by {@link #getCounts(ConjunctiveRuleModel, ExampleSet)} for
     *     this rule
     * @return {@code true} if the rule's optimistic score does not exceed the pruning score, so the
     *     rule and its refinements can be pruned; {@code false} otherwise
     * @throws UndefinedParameterError if the utility function parameter cannot be resolved
     */
    protected boolean communicateToHighscore(ConjunctiveRuleModel conjunctiveRuleModel, double[] dArr)
            throws UndefinedParameterError {
        if (getOptimisticScore(dArr) <= getPruningScore()) {
            return true;
        }
        double positiveScore = getScore(dArr, true);
        double negativeScore = getScore(dArr, false);
        if (positiveScore > this.bestScore) {
            this.bestRule = conjunctiveRuleModel;
            this.bestScore = positiveScore;
        }
        // Compared against the possibly just-updated best score.
        if (negativeScore > this.bestScore) {
            this.bestRule = new ConjunctiveRuleModel(conjunctiveRuleModel,
                    conjunctiveRuleModel.getLabel().getMapping().getNegativeIndex());
            this.bestScore = negativeScore;
        }
        return false;
    }

    /** Returns the best rule found so far, or {@code null} if no rule has been scored yet. */
    protected ConjunctiveRuleModel getBestRule() {
        return this.bestRule;
    }

    /** Returns the score a candidate's bound must exceed to be expanded: the current best score. */
    protected double getPruningScore() {
        return this.bestScore;
    }

    /**
     * Runs the level-wise search and returns the best rule found.
     *
     * @param exampleSet training data with a binominal label
     * @return the best rule, or {@code null} if no rule could be scored
     * @throws OperatorException if a parameter is undefined or the process is stopped
     */
    @Override // com.rapidminer.operator.learner.Learner
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        final int maxDepth = getParameterAsInt(PARAMETER_MAX_DEPTH);
        final int maxCache = getParameterAsInt(PARAMETER_MAX_CACHE);

        initHighscore();
        ConjunctiveRuleModel root = new ConjunctiveRuleModel(exampleSet,
                exampleSet.getAttributes().getLabel().getMapping().getPositiveIndex());
        double[] rootCounts = getCounts(root, exampleSet);
        // Class totals of the initial rule; all coverage and baseline rates are relative to these.
        this.globalP = rootCounts[0];
        this.globalN = rootCounts[1];
        communicateToHighscore(root, rootCounts);
        double rootBound = getOptimisticScore(rootCounts);
        this.openNodes.clear();
        this.prunedNodes.clear();
        addRulesToOpenNodes(root.getAllRefinedRules(exampleSet), rootBound);

        for (int depth = 1; !this.openNodes.isEmpty() && depth <= maxDepth; depth++) {
            log("Evaluating " + this.openNodes.size() + " rules of length " + depth);
            if (this.openNodes.size() > maxCache) {
                log("Ignoring all but the " + maxCache + " rules with highest support.");
            }
            // Ascending by bound; the most promising candidates sit at the end of the array.
            RuleWithScoreUpperBound[] candidates =
                    this.openNodes.toArray(new RuleWithScoreUpperBound[0]);
            Arrays.sort(candidates);
            int firstKept = Math.max(0, candidates.length - maxCache);
            this.openNodes.clear();

            int ignored = 0;
            for (int k = candidates.length - 1; k >= firstKept; k--) {
                RuleWithScoreUpperBound candidate = candidates[k];
                ConjunctiveRuleModel rule = candidate.getRule();
                if (isRefinementOfPrunedRule(rule)) {
                    ignored++;
                } else if (candidate.getScoreBound() <= getPruningScore()) {
                    // The best score may have risen since this candidate was queued.
                    ignored++;
                    this.prunedNodes.add(rule);
                } else {
                    expandNode(rule, exampleSet, maxDepth);
                }
                checkForStop();
            }
            log("Could ignore " + ignored + " rules as refinements of pruned rules or by optimistic estimates.");
            log("Number of pruned rules in cache: " + this.prunedNodes.size());
            // String concatenation instead of toString(): prints "null" rather than throwing.
            log("Best rule is " + getBestRule());
            log("Score is " + getPruningScore());
        }
        this.openNodes.clear();
        this.prunedNodes.clear();
        return getBestRule();
    }

    /**
     * Queues every rule in {@code rules} with the common bound {@code scoreBound}. Nothing is queued
     * if the bound does not exceed the current pruning score.
     */
    private void addRulesToOpenNodes(Collection<?> rules, double scoreBound) {
        if (scoreBound <= getPruningScore()) {
            return;
        }
        for (Object rule : rules) {
            this.openNodes.add(new RuleWithScoreUpperBound((ConjunctiveRuleModel) rule, scoreBound));
        }
    }

    /**
     * Evaluates one candidate. If it is prunable, it is cached as pruned. Otherwise its refinements
     * are queued, provided the rule is shorter than {@code maxDepth}.
     */
    private void expandNode(ConjunctiveRuleModel rule, ExampleSet exampleSet, int maxDepth)
            throws OperatorException {
        double[] counts = getCounts(rule, exampleSet);
        if (communicateToHighscore(rule, counts)) {
            this.prunedNodes.add(rule);
        } else if (rule.getRuleLength() < maxDepth) {
            addRulesToOpenNodes(rule.getAllRefinedRules(exampleSet), getOptimisticScore(counts));
        }
    }

    /**
     * Returns {@code true} if the given rule is a refinement of any rule in the pruned-rule cache.
     */
    public boolean isRefinementOfPrunedRule(ConjunctiveRuleModel conjunctiveRuleModel) {
        for (ConjunctiveRuleModel pruned : this.prunedNodes) {
            if (conjunctiveRuleModel.isRefinementOf(pruned)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Computes the utility of a rule for one of the two possible conclusions.
     *
     * <p>Rates whose denominator is zero (for example a rule covering no weight) are treated as 0
     * instead of producing NaN. A NaN score would never be pruned, could leave the best rule
     * {@code null}, and would break the ordering of candidates.
     *
     * @param dArr covered positive and negative weight. When it has 4 entries, entries 2 and 3 are
     *     the confidence-weighted predicted-positive and predicted-negative mass of the covered
     *     examples.
     * @param z {@code true} to score the positive conclusion, {@code false} for the negative one
     * @return {@code f(coverage) * (precision - baseline)}
     * @throws UndefinedParameterError if the utility function cannot be resolved
     */
    protected double getScore(double[] dArr, boolean z) throws UndefinedParameterError {
        final double coveredPos = dArr[0];
        final double coveredNeg = dArr[1];
        final double covered = coveredPos + coveredNeg;
        final double coverage = ratio(covered, this.globalP + this.globalN);
        final double precision = ratio(z ? coveredPos : coveredNeg, covered);
        final String utilityFunction = UTILITY_FUNCTION_LIST[getParameterAsInt(PARAMETER_UTILITY_FUNCTION)];

        final double baseline;
        if (getParameterAsBoolean(PARAMETER_RELATIVE_TO_PREDICTIONS) && dArr.length == 4) {
            double predictedPos = dArr[2];
            double predictedNeg = dArr[3];
            baseline = ratio(z ? predictedPos : predictedNeg, predictedPos + predictedNeg);
        } else {
            baseline = ratio(z ? this.globalP : this.globalN, this.globalP + this.globalN);
        }

        final double gain = precision - baseline;
        if (WRACC.equals(utilityFunction)) {
            return coverage * gain;
        }
        if (BINOMIAL.equals(utilityFunction)) {
            return Math.sqrt(coverage) * gain;
        }
        // Created only on the failure path; building it eagerly captured a stack trace per call.
        throw new UndefinedParameterError("Missing parameter 'utility_function'!");
    }

    /**
     * Returns an optimistic estimate of the best score any refinement of a rule can reach.
     *
     * <p>The estimate is the maximum of two scores. The first assumes a refinement that keeps all
     * covered positives and no negatives, scored with the positive conclusion. The second is the
     * mirror case: all covered negatives and no positives, scored with the negative conclusion.
     * In relative mode, the predicted-class mass of the discarded class is kept as entry 3 or 2.
     * NOTE(review): whether this is a true upper bound in relative mode is not evident here.
     *
     * @param dArr counts as produced by {@link #getCounts(ConjunctiveRuleModel, ExampleSet)}
     * @throws UndefinedParameterError if the utility function cannot be resolved
     */
    protected double getOptimisticScore(double[] dArr) throws UndefinedParameterError {
        double coveredPos = dArr[0];
        double coveredNeg = dArr[1];
        if (!getParameterAsBoolean(PARAMETER_RELATIVE_TO_PREDICTIONS) || dArr.length != 4) {
            return Math.max(getScore(new double[] {coveredPos, 0.0d}, true),
                    getScore(new double[] {0.0d, coveredNeg}, false));
        }
        return Math.max(getScore(new double[] {coveredPos, 0.0d, 0.0d, dArr[3]}, true),
                getScore(new double[] {0.0d, coveredNeg, 0.0d, dArr[2]}, false));
    }

    /**
     * Sums example weights (1 per example when there is no weight attribute) over the examples for
     * which the rule predicts its own conclusion.
     *
     * @return {@code {positive, negative}}. When a predicted label exists and
     *     {@code relative_to_predictions} is set, the result is {@code {positive, negative,
     *     predictedPositiveMass, predictedNegativeMass}}, where each example's weight is split by
     *     its normalized confidences.
     * @throws OperatorException if a parameter cannot be read
     */
    protected double[] getCounts(ConjunctiveRuleModel conjunctiveRuleModel, ExampleSet exampleSet)
            throws OperatorException {
        Attribute weight = exampleSet.getAttributes().getWeight();
        Attribute predictedLabel = exampleSet.getAttributes().getPredictedLabel();
        boolean relative = predictedLabel != null && getParameterAsBoolean(PARAMETER_RELATIVE_TO_PREDICTIONS);
        int conclusion = conjunctiveRuleModel.getConclusion();
        int positiveIndex = exampleSet.getAttributes().getLabel().getMapping().getPositiveIndex();
        int negativeIndex = exampleSet.getAttributes().getLabel().getMapping().getNegativeIndex();
        String positiveName = null;
        String negativeName = null;
        if (relative) {
            positiveName = predictedLabel.getMapping().mapIndex(positiveIndex);
            negativeName = predictedLabel.getMapping().mapIndex(negativeIndex);
        }

        double positive = 0.0d;
        double negative = 0.0d;
        double predictedPositive = 0.0d;
        double predictedNegative = 0.0d;
        for (Example example : exampleSet) {
            double value = weight == null ? 1.0d : example.getValue(weight);
            if (conjunctiveRuleModel.predict(example) != conclusion) {
                continue;
            }
            if (example.getValue(example.getAttributes().getLabel()) == positiveIndex) {
                positive += value;
            } else {
                negative += value;
            }
            if (relative) {
                double positiveConfidence = example.getConfidence(positiveName);
                double negativeConfidence = example.getConfidence(negativeName);
                double confidenceSum = positiveConfidence + negativeConfidence;
                // A zero confidence sum previously added NaN and poisoned all later scores.
                predictedPositive += ratio(value * positiveConfidence, confidenceSum);
                predictedNegative += ratio(value * negativeConfidence, confidenceSum);
            }
        }
        return relative
                ? new double[] {positive, negative, predictedPositive, predictedNegative}
                : new double[] {positive, negative};
    }

    /** Returns {@code numerator / denominator}, or 0 when the denominator is exactly 0. */
    private static double ratio(double numerator, double denominator) {
        return denominator == 0.0d ? 0.0d : numerator / denominator;
    }

    @Override // com.rapidminer.operator.Operator, com.rapidminer.parameter.ParameterHandler
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> parameterTypes = super.getParameterTypes();
        parameterTypes.add(new ParameterTypeInt(PARAMETER_MAX_DEPTH,
                "An upper bound for the number of literals.", 1, Integer.MAX_VALUE, 2));
        parameterTypes.add(new ParameterTypeCategory(PARAMETER_UTILITY_FUNCTION,
                "The function to be optimized by the rule.", UTILITY_FUNCTION_LIST, 0));
        parameterTypes.add(new ParameterTypeInt(PARAMETER_MAX_CACHE,
                "Bounds the number of rules considered per depth to avoid high memory consumption, but leads to incomplete search.",
                1, Integer.MAX_VALUE, DEFAULT_MAX_CACHE));
        parameterTypes.add(new ParameterTypeBoolean(PARAMETER_RELATIVE_TO_PREDICTIONS,
                "Searches for rules with a maximum difference to the predicted label.", false));
        return parameterTypes;
    }
}
