package com.rapidminer.operator.learner.tree;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.AbstractLearner;
import com.rapidminer.operator.learner.LearnerCapability;
import com.rapidminer.operator.learner.SimplePredictionModel;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeCategory;
import com.rapidminer.parameter.UndefinedParameterError;
import com.rapidminer.tools.Tools;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

/* JADX WARN: Classes with same name are omitted:
  input_file:builds/deps.jar:com/rapidminer/operator/learner/tree/MultiCriterionDecisionStumps.class
  input_file:builds/deps.jar:rapidMiner.jar:com/rapidminer/operator/learner/tree/MultiCriterionDecisionStumps.class
  input_file:com/rapidminer/operator/learner/tree/MultiCriterionDecisionStumps.class
 */
/* loaded from: input_file:rapidMiner.jar:com/rapidminer/operator/learner/tree/MultiCriterionDecisionStumps.class */
/**
 * Learns a decision stump (a single attribute test) for a binominal label, choosing the test
 * that maximizes the utility function selected by the {@code utility_function} parameter.
 *
 * <p>Candidates are evaluated in this order:
 * <ol>
 *   <li>a test-free baseline that predicts one class for every example;</li>
 *   <li>an equality test for every value of every nominal attribute;</li>
 *   <li>a {@code <=} threshold at the midpoint of every pair of consecutive distinct values of
 *       every non-nominal attribute.</li>
 * </ol>
 * A later candidate replaces the current best only if its score is strictly higher.
 */
public class MultiCriterionDecisionStumps extends AbstractLearner {
    private static final String PARAMETER_UTILITY_FUNCTION = "utility_function";

    /** Label mapping index that counts as the positive class. */
    private int posIndex;

    /** Total weight of positive training examples. */
    private double globalP;

    /** Total weight of non-positive training examples. */
    private double globalN;

    /** Best model found so far during the current {@link #learn(ExampleSet)} call. */
    private Model bestModel;

    /** Score of {@link #bestModel}. */
    private double bestScore;

    /** The entry of {@link #UTILITY_FUNCTION_LIST} selected for the current run. */
    private String utilityFunction;

    private static final String ENTROPY = "entropy";
    private static final String ACC = "accuracy";
    private static final String SQRT_PN = "sqrt(TP*FP) + sqrt(FN*TN)";
    private static final String GINI = "gini index";
    private static final String CHI_SQUARE = "chi square test";
    private static final String[] UTILITY_FUNCTION_LIST = {ENTROPY, ACC, SQRT_PN, GINI, CHI_SQUARE};

    /**
     * A stump that tests one attribute. It predicts {@code prediction} for examples passing the
     * test ("covered") and the other class for all remaining examples.
     *
     * <ul>
     *   <li>Non-nominal attributes are tested with {@code value <= testValue}.</li>
     *   <li>Nominal attributes are tested with {@code value == testValue}, where the test value
     *       is a mapping index.</li>
     *   <li>Examples with a missing value count as covered iff {@code includeNaNs}.</li>
     *   <li>Without a test attribute, every example is covered.</li>
     * </ul>
     */
    public static class DecisionStumpModel extends SimplePredictionModel {
        private static final long serialVersionUID = -261158567126510415L;
        private final Attribute testAttribute;
        private final double testValue;
        private final boolean prediction;
        private final boolean includeNaNs;
        private final boolean numerical;

        /**
         * @param attribute   the tested attribute, or {@code null} for a test-free stump
         * @param testValue   threshold (non-nominal attribute) or mapping index (nominal attribute)
         * @param exampleSet  training data, handed to the model base class
         * @param prediction  {@code true} to predict the positive class for covered examples
         * @param includeNaNs whether examples with a missing test value count as covered
         */
        public DecisionStumpModel(Attribute attribute, double testValue, ExampleSet exampleSet, boolean prediction, boolean includeNaNs) {
            super(exampleSet);
            this.prediction = prediction;
            this.includeNaNs = includeNaNs;
            this.testAttribute = attribute;
            this.testValue = testValue;
            this.numerical = attribute == null || !attribute.isNominal();
        }

        @Override
        public double predict(Example example) {
            boolean covered;
            if (this.testAttribute == null) {
                covered = true;
            } else {
                double value = example.getValue(this.testAttribute);
                if (Double.isNaN(value)) {
                    covered = this.includeNaNs;
                } else if (this.numerical) {
                    covered = value <= this.testValue;
                } else {
                    covered = value == this.testValue;
                }
            }
            return covered == this.prediction ? getLabel().getMapping().getPositiveIndex() : getLabel().getMapping().getNegativeIndex();
        }

        @Override
        public String toString() {
            String positiveString = getLabel().getMapping().getPositiveString();
            String negativeString = getLabel().getMapping().getNegativeString();
            StringBuilder result = new StringBuilder(super.toString());
            result.append(Tools.getLineSeparator()).append(" (").append(getLabel().getName()).append("=");
            result.append(this.prediction ? positiveString : negativeString).append(") <-- ");
            if (this.testAttribute != null) {
                result.append(this.testAttribute.getName());
                if (this.numerical) {
                    result.append(" <= ").append(this.testValue);
                } else {
                    result.append(" = ").append(this.testAttribute.getMapping().mapIndex((int) this.testValue));
                }
            }
            // Mirror predict(): an unknown value is covered iff there is no test or includeNaNs
            // is set. Covered examples receive 'prediction'; uncovered ones receive the other
            // class. Printing includeNaNs alone would be wrong whenever 'prediction' is false.
            boolean unknownCovered = this.testAttribute == null || this.includeNaNs;
            boolean unknownPredictedPositive = unknownCovered == this.prediction;
            result.append(Tools.getLineSeparator()).append(" unknown: predict '")
                    .append(unknownPredictedPositive ? positiveString : negativeString).append("'.");
            return result.toString();
        }
    }

    /**
     * Score of one candidate nominal test. It also records whether missing values are counted
     * as covered and which class the covered examples are predicted to be.
     */
    public static class ScoreNaNInfo {
        public double score;
        public boolean includeNaNs;
        public boolean predicted;

        ScoreNaNInfo(double score, boolean includeNaNs, boolean predicted) {
            this.score = score;
            this.includeNaNs = includeNaNs;
            this.predicted = predicted;
        }

        /** Returns the candidate with the higher score; ties favor {@code this}. */
        public ScoreNaNInfo max(ScoreNaNInfo other) {
            return this.score >= other.score ? this : other;
        }
    }

    public MultiCriterionDecisionStumps(OperatorDescription operatorDescription) {
        super(operatorDescription);
    }

    @Override
    public boolean supportsCapability(LearnerCapability learnerCapability) {
        return learnerCapability == LearnerCapability.NUMERICAL_ATTRIBUTES || learnerCapability == LearnerCapability.POLYNOMINAL_ATTRIBUTES || learnerCapability == LearnerCapability.BINOMINAL_ATTRIBUTES || learnerCapability == LearnerCapability.BINOMINAL_CLASS || learnerCapability == LearnerCapability.WEIGHTED_EXAMPLES;
    }

    /** Resets the best-model tracking before a new learning run. */
    protected void initHighscore() {
        this.bestModel = null;
        this.bestScore = Double.NEGATIVE_INFINITY;
    }

    protected Model getBestModel() {
        return this.bestModel;
    }

    private void setBestModel(DecisionStumpModel decisionStumpModel, double score) {
        this.bestModel = decisionStumpModel;
        this.bestScore = score;
    }

    /**
     * Trains a stump on the given data.
     *
     * @param exampleSet training data with a binominal label
     * @return the best-scoring stump; the test-free baseline if no attribute test beats it
     * @throws OperatorException if the utility function parameter cannot be read or a candidate
     *         model cannot be built
     */
    @Override
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        this.utilityFunction = UTILITY_FUNCTION_LIST[getParameterAsInt(PARAMETER_UTILITY_FUNCTION)];
        initHighscore();
        // posIndex must be set before computePriors(), which reads it.
        this.posIndex = exampleSet.getAttributes().getLabel().getMapping().getPositiveIndex();
        double[] priors = computePriors(exampleSet);
        this.globalP = priors[0];
        this.globalN = priors[1];
        // Baseline: no test; every example receives whichever class scores better.
        boolean predictPositive = getScore(priors, true) >= getScore(priors, false);
        setBestModel(new DecisionStumpModel(null, 0.0d, exampleSet, predictPositive, true), getScore(priors, predictPositive));
        evaluateNominalAttributes(exampleSet);
        evaluateNumericalAttributes(exampleSet);
        return getBestModel();
    }

    /**
     * Returns the attributes of {@code exampleSet.getAttributes()} whose {@code isNominal()}
     * equals {@code nominal}, in iteration order.
     */
    private static Attribute[] collectAttributes(ExampleSet exampleSet, boolean nominal) {
        Attribute[] selected = new Attribute[exampleSet.getAttributes().size()];
        int count = 0;
        for (Attribute attribute : exampleSet.getAttributes()) {
            if (attribute.isNominal() == nominal) {
                selected[count++] = attribute;
            }
        }
        return Arrays.copyOf(selected, count);
    }

    /**
     * Tries a {@code <=} threshold between each pair of consecutive distinct values of every
     * non-nominal attribute. Any threshold that beats the current best becomes the new best model.
     *
     * <p>Missing values are routed so that they are predicted as the overall weighted majority
     * class; ties go to the positive class.
     */
    private void evaluateNumericalAttributes(ExampleSet exampleSet) throws OperatorException {
        Attribute[] numericalAttributes = collectAttributes(exampleSet, false);
        int numAttributes = numericalAttributes.length;
        if (numAttributes == 0) {
            return;
        }
        boolean weighted = exampleSet.getAttributes().getWeight() != null;
        int numExamples = exampleSet.size();
        // labelAndWeight[e] = {class index (0 = positive, 1 = other), weight} of example e
        double[][] labelAndWeight = new double[numExamples][];
        // valueAndIndex[a][e] = {value of attribute a, e}; each row is sorted by value below
        double[][][] valueAndIndex = new double[numAttributes][numExamples][];
        double[] classTotals = new double[2];
        int exampleIndex = 0;
        for (Example example : exampleSet) {
            int classIndex = example.getLabel() == this.posIndex ? 0 : 1;
            double weight = weighted ? example.getWeight() : 1.0d;
            classTotals[classIndex] += weight;
            labelAndWeight[exampleIndex] = new double[] {classIndex, weight};
            for (int a = 0; a < numAttributes; a++) {
                valueAndIndex[a][exampleIndex] = new double[] {example.getValue(numericalAttributes[a]), exampleIndex};
            }
            exampleIndex++;
        }
        boolean positiveMajority = classTotals[0] >= classTotals[1];
        Comparator<double[]> byValue = new Comparator<double[]>() {
            @Override
            public int compare(double[] first, double[] second) {
                return Double.compare(first[0], second[0]);
            }
        };
        for (int a = 0; a < numAttributes; a++) {
            Attribute attribute = numericalAttributes[a];
            double[][] sorted = valueAndIndex[a];
            Arrays.sort(sorted, byValue);
            // Cumulative {positive, other} weight of all examples with value <= previousValue.
            // Only indices 0 and 1 are used, so the size is fixed at 2.
            double[] covered = new double[2];
            double previousValue = Double.NEGATIVE_INFINITY;
            double previousScore = Double.NEGATIVE_INFINITY;
            boolean previousPrediction = false;
            for (double[] entry : sorted) {
                double value = entry[0];
                // Missing and +infinity values are never counted as covered (+inf can never
                // satisfy "<= threshold").
                if (Double.isNaN(value) || value == Double.POSITIVE_INFINITY) {
                    continue;
                }
                int e = (int) entry[1];
                // A threshold between two distinct consecutive values is a valid split.
                // Its score was computed after absorbing every example up to previousValue.
                // NOTE(review): that score treats missing values as uncovered, but the model
                // built here covers them whenever positiveMajority == previousPrediction.
                // Confirm whether this mismatch is intended.
                if (value != previousValue && previousScore > this.bestScore) {
                    setBestModel(new DecisionStumpModel(attribute, (value + previousValue) / 2.0d, exampleSet, previousPrediction, positiveMajority == previousPrediction), previousScore);
                }
                covered[(int) labelAndWeight[e][0]] += labelAndWeight[e][1];
                double positiveScore = getScore(covered, true);
                double negativeScore = getScore(covered, false);
                previousScore = Math.max(positiveScore, negativeScore);
                previousPrediction = positiveScore >= negativeScore;
                previousValue = value;
            }
        }
    }

    /**
     * Tries an equality test for every value of every nominal attribute, with missing values
     * either included in or excluded from the covered side. Any test that beats the current best
     * becomes the new best model.
     */
    private void evaluateNominalAttributes(ExampleSet exampleSet) throws OperatorException {
        Attribute[] nominalAttributes = collectAttributes(exampleSet, true);
        int numAttributes = nominalAttributes.length;
        if (numAttributes == 0) {
            return;
        }
        // counts[a][v] = {positive, other} weight of examples whose attribute a has mapping index v
        double[][][] counts = new double[numAttributes][][];
        // missingCounts[a] = {positive, other} weight of examples where attribute a is missing
        double[][] missingCounts = new double[numAttributes][2];
        for (int a = 0; a < numAttributes; a++) {
            counts[a] = new double[nominalAttributes[a].getMapping().size()][2];
        }
        boolean weighted = exampleSet.getAttributes().getWeight() != null;
        for (Example example : exampleSet) {
            double weight = weighted ? example.getWeight() : 1.0d;
            int classIndex = example.getLabel() == this.posIndex ? 0 : 1;
            for (int a = 0; a < numAttributes; a++) {
                double value = example.getValue(nominalAttributes[a]);
                double[] target = Double.isNaN(value) ? missingCounts[a] : counts[a][(int) value];
                target[classIndex] += weight;
            }
        }
        for (int a = 0; a < numAttributes; a++) {
            double[][] valueCounts = counts[a];
            for (int v = 0; v < valueCounts.length; v++) {
                ScoreNaNInfo best = getScore(valueCounts[v], missingCounts[a]);
                if (best.score > this.bestScore) {
                    setBestModel(new DecisionStumpModel(nominalAttributes[a], v, exampleSet, best.predicted, best.includeNaNs), best.score);
                }
            }
        }
    }

    /**
     * Picks the best of up to four variants of one nominal test: predict positive or negative
     * for covered examples, and (only if any values are missing) include or exclude the
     * missing-value examples on the covered side.
     *
     * @param valueCounts   {positive, other} weight of examples with the tested value; not modified
     * @param missingCounts {positive, other} weight of examples with a missing value
     */
    private ScoreNaNInfo getScore(double[] valueCounts, double[] missingCounts) throws UndefinedParameterError {
        ScoreNaNInfo best = new ScoreNaNInfo(getScore(valueCounts, true), false, true).max(new ScoreNaNInfo(getScore(valueCounts, false), false, false));
        if (missingCounts[0] > 0.0d || missingCounts[1] > 0.0d) {
            // Use a separate array so the caller's per-value counts are left untouched.
            double[] withMissing = {valueCounts[0] + missingCounts[0], valueCounts[1] + missingCounts[1]};
            best = best.max(new ScoreNaNInfo(getScore(withMissing, true), true, true)).max(new ScoreNaNInfo(getScore(withMissing, false), true, false));
        }
        return best;
    }

    /**
     * Scores a stump with the selected utility function; higher is better.
     *
     * @param counts          {@code {p, n}}: weight of positive and other examples covered by the
     *                        test; uncovered weights are {@code globalP - p} and {@code globalN - n}
     * @param predictPositive whether covered examples are predicted positive (uncovered ones get
     *                        the other class)
     * @return the score. The entropy, sqrt and gini criteria return
     *         {@link Double#NEGATIVE_INFINITY} when {@code predictPositive} contradicts the
     *         covered majority (a tie counts as positive). An unknown utility function yields
     *         {@link Double#NaN}.
     */
    protected double getScore(double[] counts, boolean predictPositive) {
        double score;
        double p = counts[0];
        double n = counts[1];
        if (this.utilityFunction.equals(ACC)) {
            // Weighted number of correctly classified examples of the whole stump: covered
            // examples of the predicted class plus uncovered examples of the other class.
            // Scoring only the covered part (p - n or n - p) would shift candidates with
            // different predictions by different constants (N vs. P), so a perfect split
            // predicting the minority class on the covered side could lose to the baseline.
            score = predictPositive ? p + (this.globalN - n) : n + (this.globalP - p);
        } else if (this.utilityFunction.equals(ENTROPY)) {
            if ((p - n >= 0.0d) ^ predictPositive) {
                return Double.NEGATIVE_INFINITY;
            }
            double covered = p + n;
            double uncovered = (this.globalP + this.globalN) - covered;
            double coveredEntropy = covered == 0.0d ? 0.0d : entropyLog2(p / covered) + entropyLog2(n / covered);
            double uncoveredEntropy = uncovered == 0.0d ? 0.0d : entropyLog2((this.globalP - p) / uncovered) + entropyLog2((this.globalN - n) / uncovered);
            // Negated weighted average entropy of the covered and uncovered partitions.
            score = -(((covered * coveredEntropy) + (uncovered * uncoveredEntropy)) / (covered + uncovered));
        } else if (this.utilityFunction.equals(SQRT_PN)) {
            if ((p - n >= 0.0d) ^ predictPositive) {
                return Double.NEGATIVE_INFINITY;
            }
            score = -(Math.sqrt(p * n) + Math.sqrt((this.globalP - p) * (this.globalN - n)));
        } else if (this.utilityFunction.equals(GINI)) {
            if ((p - n >= 0.0d) ^ predictPositive) {
                return Double.NEGATIVE_INFINITY;
            }
            double covered = p + n;
            double uncovered = (this.globalP + this.globalN) - covered;
            double coveredImpurity = covered == 0.0d ? 0.0d : (p / covered) * (n / covered);
            double uncoveredImpurity = uncovered == 0.0d ? 0.0d : ((this.globalP - p) / uncovered) * ((this.globalN - n) / uncovered);
            score = -(((covered * coveredImpurity) + (uncovered * uncoveredImpurity)) / (covered + uncovered));
        } else if (this.utilityFunction.equals(CHI_SQUARE)) {
            // Chi-square statistic of the 2x2 table (covered / uncovered) x (positive / other).
            double pOut = this.globalP - p;
            double nOut = this.globalN - n;
            double covered = p + n;
            double uncovered = pOut + nOut;
            double total = covered + uncovered;
            double expectedPIn = (covered * this.globalP) / total;
            double expectedNIn = (covered * this.globalN) / total;
            double expectedPOut = (uncovered * this.globalP) / total;
            double expectedNOut = (uncovered * this.globalN) / total;
            score = (covered <= 0.0d || uncovered <= 0.0d) ? 0.0d : (((p - expectedPIn) * (p - expectedPIn)) / expectedPIn) + (((n - expectedNIn) * (n - expectedNIn)) / expectedNIn) + (((pOut - expectedPOut) * (pOut - expectedPOut)) / expectedPOut) + (((nOut - expectedNOut) * (nOut - expectedNOut)) / expectedNOut);
        } else {
            score = Double.NaN;
            logWarning("Found unknown utility function: " + this.utilityFunction);
        }
        return score;
    }

    /**
     * Returns one entropy term {@code -x * log2(x)}. NaN and 0 map to 0 (the limit as
     * {@code x -> 0}).
     */
    private double entropyLog2(double d) {
        if (Double.isNaN(d) || d == 0.0d) {
            return 0.0d;
        }
        return ((-d) * Math.log(d)) / Math.log(2.0d);
    }

    /**
     * Sums example weights (1 per example if there is no weight attribute) per class.
     * Requires {@link #posIndex} to be set.
     *
     * @return {@code {positive weight, other weight}}
     */
    protected double[] computePriors(ExampleSet exampleSet) {
        Attribute weightAttribute = exampleSet.getAttributes().getWeight();
        double positiveWeight = 0.0d;
        double otherWeight = 0.0d;
        for (Example example : exampleSet) {
            double weight = weightAttribute == null ? 1.0d : example.getValue(weightAttribute);
            if (example.getLabel() == this.posIndex) {
                positiveWeight += weight;
            } else {
                otherWeight += weight;
            }
        }
        return new double[] {positiveWeight, otherWeight};
    }

    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> parameterTypes = super.getParameterTypes();
        parameterTypes.add(new ParameterTypeCategory(PARAMETER_UTILITY_FUNCTION, "The function to be optimized by the rule.", UTILITY_FUNCTION_LIST, 0));
        return parameterTypes;
    }
}
