package com.rapidminer.operator.learner.subgroups;

import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.AbstractLearner;
import com.rapidminer.operator.learner.LearnerCapability;
import com.rapidminer.operator.learner.subgroups.hypothesis.Hypothesis;
import com.rapidminer.operator.learner.subgroups.hypothesis.Rule;
import com.rapidminer.operator.learner.subgroups.utility.UtilityFunction;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeCategory;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeInt;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

/**
 * Subgroup discovery learner. Performs a level-wise (breadth-first) search over hypotheses, where
 * level {@code i} holds hypotheses with {@code i + 1} literals, derives rules from every surviving
 * hypothesis and scores them with a selectable {@link UtilityFunction}.
 *
 * <p>Two discovery modes are supported: collecting every rule whose utility reaches
 * {@code min_utility}, or keeping only the {@code k_best_rules} best rules. The search space is
 * pruned by a minimum coverage, an optional cache bound on the number of hypotheses per level, and
 * the utility function's optimistic estimate.
 */
public class SubgroupDiscovery extends AbstractLearner {
    public static final String PARAMETER_DISCOVERY_MODE = "mode";
    public static final int DISCOVERY_MODE_ABOVE_MINIMUM_UTILITY = 0;
    public static final int DISCOVERY_MODE_K_BEST_RULES = 1;
    public static final String PARAMETER_UTILITY_FUNCTION = "utility_function";
    public static final String PARAMETER_RULE_GENERATION = "rule_generation";
    public static final String PARAMETER_MAX_DEPTH = "max_depth";
    public static final String PARAMETER_MIN_UTILITY = "min_utility";
    public static final String PARAMETER_K_BEST_RULES = "k_best_rules";
    public static final String PARAMETER_MIN_COVERAGE = "min_coverage";
    public static final String PARAMETER_MAX_CACHE = "max_cache";
    public static final String[] DISCOVERY_MODES = {"above minimum utility", "k best rules"};
    public static final String[] RULE_GENERATION_MODES = Hypothesis.RULE_GENERATION_MODES;

    /** Orders hypotheses by descending covered weight (best-supported first). */
    private static class HypothesisComparator implements Comparator<Hypothesis> {
        @Override
        public int compare(Hypothesis first, Hypothesis second) {
            return Double.compare(second.getCoveredWeight(), first.getCoveredWeight());
        }
    }

    /** Orders rules by descending utility as stored for the given utility function class. */
    private static class RuleComparator implements Comparator<Rule> {
        private final Class<? extends UtilityFunction> functionClass;

        public RuleComparator(Class<? extends UtilityFunction> functionClass) {
            this.functionClass = functionClass;
        }

        @Override
        public int compare(Rule first, Rule second) {
            return Double.compare(second.getUtility(functionClass), first.getUtility(functionClass));
        }
    }

    public SubgroupDiscovery(OperatorDescription operatorDescription) {
        super(operatorDescription);
    }

    /**
     * Runs the subgroup search on the given examples and returns the discovered rules.
     *
     * @param exampleSet the training examples; a label attribute is required (its positive class
     *     index is read), and the example weights are used when a weight attribute is present
     * @return a {@link RuleSet} containing the discovered rules, ordered by descending utility
     *     w.r.t. the selected utility function
     * @throws OperatorException propagated from parameter access or the underlying model classes
     */
    @Override // com.rapidminer.operator.learner.Learner
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        int discoveryMode = getParameterAsInt(PARAMETER_DISCOVERY_MODE);
        int maxDepth = getParameterAsInt(PARAMETER_MAX_DEPTH);
        double minUtility = getParameterAsDouble(PARAMETER_MIN_UTILITY);
        int kBestRules = getParameterAsInt(PARAMETER_K_BEST_RULES);
        int ruleGeneration = getParameterAsInt(PARAMETER_RULE_GENERATION);
        double minCoverage = getParameterAsDouble(PARAMETER_MIN_COVERAGE);
        int maxCache = getParameterAsInt(PARAMETER_MAX_CACHE);

        // Total weight and weight of the positive class; the utility functions are built from these.
        boolean weighted = exampleSet.getAttributes().getWeight() != null;
        double totalWeight = 0.0d;
        double positiveWeight = 0.0d;
        for (Example example : exampleSet) {
            double weight = weighted ? example.getWeight() : 1.0d;
            totalWeight += weight;
            if (example.getLabel() == example.getAttributes().getLabel().getMapping().getPositiveIndex()) {
                positiveWeight += weight;
            }
        }

        UtilityFunction[] utilityFunctions = UtilityFunction.getUtilityFunctions(totalWeight, positiveWeight);
        UtilityFunction utilityFunction = utilityFunctions[getParameterAsInt(PARAMETER_UTILITY_FUNCTION)];
        Class<? extends UtilityFunction> utilityClass = utilityFunction.getClass();
        RuleComparator ruleComparator = new RuleComparator(utilityClass);

        // Rules accepted so far. In k-best mode this list is kept sorted and holds at most
        // kBestRules entries. It is deliberately not presized with k, which may be Integer.MAX_VALUE.
        List<Rule> discoveredRules = new ArrayList<Rule>();

        LinkedList<Hypothesis> hypotheses = new LinkedList<Hypothesis>();
        hypotheses.addAll(new Hypothesis().restrictedRefine(exampleSet.getAttributes()));

        // Depth is capped at the attribute count (presumably no hypothesis can use more literals).
        int depthLimit = Math.min(maxDepth, exampleSet.getAttributes().size());
        for (int depth = 0; depth < depthLimit && !hypotheses.isEmpty(); depth++) {
            log("evaluating " + hypotheses.size() + " hypotheses with " + (depth + 1) + " literals");
            for (Example example : exampleSet) {
                for (Hypothesis hypothesis : hypotheses) {
                    hypothesis.apply(example);
                }
            }

            removeInsufficientlyCovered(hypotheses, totalWeight, minCoverage);
            if (maxCache != -1) {
                trimToCache(hypotheses, maxCache);
            }

            log("generating rules from " + hypotheses.size() + " hypotheses");
            LinkedList<Hypothesis> nextLevel = new LinkedList<Hypothesis>();
            for (Hypothesis hypothesis : hypotheses) {
                for (Rule rule : hypothesis.generateRules(ruleGeneration, exampleSet.getAttributes().getLabel())) {
                    // Store the utility under every function so the final rules carry all scores.
                    for (UtilityFunction function : utilityFunctions) {
                        rule.setUtility(function, function.utility(rule));
                    }
                    double utility = utilityFunction.utility(rule);

                    if (discoveryMode == DISCOVERY_MODE_ABOVE_MINIMUM_UTILITY) {
                        if (utility >= minUtility) {
                            discoveredRules.add(rule);
                            log("scored: " + rule);
                        }
                    } else if (discoveryMode == DISCOVERY_MODE_K_BEST_RULES) {
                        if (discoveredRules.size() < kBestRules) {
                            discoveredRules.add(rule);
                            log("scored: " + rule + " [q(h)=" + utility + "]");
                            Collections.sort(discoveredRules, ruleComparator);
                        } else if (utility > discoveredRules.get(kBestRules - 1).getUtility(utilityClass)) {
                            discoveredRules.set(kBestRules - 1, rule);
                            log("scored: " + rule + " [q(h)=" + utility + "]");
                            Collections.sort(discoveredRules, ruleComparator);
                            // The pruning bound is the utility of the current k-th best rule, which
                            // after re-sorting is not necessarily the rule just inserted.
                            minUtility = discoveredRules.get(kBestRules - 1).getUtility(utilityClass);
                        }
                    }
                }
                // Refine only hypotheses whose optimistic estimate can still reach the bound.
                if (utilityFunction.optimisticEstimate(hypothesis) >= minUtility) {
                    nextLevel.addAll(hypothesis.restrictedRefine());
                }
            }
            hypotheses = nextLevel;
        }

        Collections.sort(discoveredRules, ruleComparator);
        RuleSet ruleSet = new RuleSet(exampleSet);
        for (Rule rule : discoveredRules) {
            ruleSet.addRule(rule);
        }
        return ruleSet;
    }

    /**
     * Removes every hypothesis whose covered fraction of {@code totalWeight} does not exceed
     * {@code minCoverage}, and logs how many were removed.
     */
    private void removeInsufficientlyCovered(List<Hypothesis> hypotheses, double totalWeight, double minCoverage) {
        int removed = 0;
        for (Iterator<Hypothesis> iterator = hypotheses.iterator(); iterator.hasNext();) {
            if (iterator.next().getCoveredWeight() / totalWeight <= minCoverage) {
                iterator.remove();
                removed++;
            }
        }
        if (removed > 0) {
            log("removed " + removed + " hypotheses not exceeding min coverage");
        }
    }

    /**
     * Keeps only the {@code maxCache} hypotheses with the highest covered weight, and logs how many
     * were dropped.
     */
    private void trimToCache(LinkedList<Hypothesis> hypotheses, int maxCache) {
        Collections.sort(hypotheses, new HypothesisComparator());
        int excess = hypotheses.size() - maxCache;
        for (int i = 0; i < excess; i++) {
            hypotheses.removeLast();
        }
        if (excess > 0) {
            log("removed " + excess + " hypotheses with the lowest coverage");
        }
    }

    /**
     * Reports the supported capabilities: nominal (polynominal and binominal) attributes, a binominal
     * class, and weighted examples.
     */
    @Override // com.rapidminer.operator.learner.Learner
    public boolean supportsCapability(LearnerCapability learnerCapability) {
        return learnerCapability == LearnerCapability.POLYNOMINAL_ATTRIBUTES
                || learnerCapability == LearnerCapability.BINOMINAL_ATTRIBUTES
                || learnerCapability == LearnerCapability.BINOMINAL_CLASS
                || learnerCapability == LearnerCapability.WEIGHTED_EXAMPLES;
    }

    /** Declares the operator parameters on top of those inherited from the superclass. */
    @Override // com.rapidminer.operator.Operator, com.rapidminer.parameter.ParameterHandler
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> parameterTypes = super.getParameterTypes();
        parameterTypes.add(new ParameterTypeCategory(PARAMETER_DISCOVERY_MODE, "Discovery mode.", DISCOVERY_MODES, DISCOVERY_MODE_K_BEST_RULES));
        parameterTypes.add(new ParameterTypeCategory(PARAMETER_UTILITY_FUNCTION, "Utility function.", UtilityFunction.FUNCTIONS, 6));
        parameterTypes.add(new ParameterTypeDouble(PARAMETER_MIN_UTILITY, "Minimum quality which has to be reached.", Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, 0.0d));
        parameterTypes.add(new ParameterTypeInt(PARAMETER_K_BEST_RULES, "Report the k best rules.", 1, Integer.MAX_VALUE, 10));
        parameterTypes.add(new ParameterTypeCategory(PARAMETER_RULE_GENERATION, "Determines which rules are generated.", RULE_GENERATION_MODES, 3));
        parameterTypes.add(new ParameterTypeInt(PARAMETER_MAX_DEPTH, "Maximum depth of BFS.", 0, Integer.MAX_VALUE, 5));
        parameterTypes.add(new ParameterTypeDouble(PARAMETER_MIN_COVERAGE, "Only consider rules which exceed the given coverage threshold.", 0.0d, 1.0d, 0.0d));
        parameterTypes.add(new ParameterTypeInt(PARAMETER_MAX_CACHE, "Bounds the number of rules which are evaluated (only the most supported rules are used).", -1, Integer.MAX_VALUE, -1));
        return parameterTypes;
    }
}
