package com.rapidminer.operator.learner.meta;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.performance.EstimatedPerformance;
import com.rapidminer.operator.performance.PerformanceVector;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.parameter.ParameterTypeList;
import com.rapidminer.tools.RandomGenerator;
import com.rapidminer.tools.Tools;
import com.rapidminer.tools.math.optimization.ec.es.ESOptimization;
import com.rapidminer.tools.math.optimization.ec.es.Individual;
import java.util.LinkedList;
import java.util.List;

/* JADX WARN: Classes with same name are omitted:
  input_file:builds/deps.jar:com/rapidminer/operator/learner/meta/CostBasedThresholdLearner.class
  input_file:builds/deps.jar:rapidMiner.jar:com/rapidminer/operator/learner/meta/CostBasedThresholdLearner.class
  input_file:com/rapidminer/operator/learner/meta/CostBasedThresholdLearner.class
 */
/* loaded from: input_file:rapidMiner.jar:com/rapidminer/operator/learner/meta/CostBasedThresholdLearner.class */
/**
 * Meta learner that trains an inner model on one part of the input data and then searches,
 * on the remaining part, one confidence threshold per class that minimizes a weighted cost.
 *
 * <p>The threshold search is an evolution strategy ({@link ESOptimization}). During the search,
 * an example whose confidence for its predicted class is less than or equal to that class's
 * threshold is treated as "unknown" (rejected). The returned {@link ThresholdModel} is built
 * from the held-out predictions, the inner model and the best thresholds found.
 */
public class CostBasedThresholdLearner extends AbstractMetaLearner {

    /** Parameter key: list of (class name, weight) pairs. */
    public static final String PARAMETER_CLASS_WEIGHTS = "class_weights";

    /** Parameter key: cost of predicting "unknown"; negative means "use the true class weight". */
    public static final String PARAMETER_PREDICT_UNKNOWN_COSTS = "predict_unknown_costs";

    /** Parameter key: fraction of the data used to train the inner model. */
    public static final String PARAMETER_TRAINING_RATIO = "training_ratio";

    /** Parameter key: number of iterations of the threshold optimization. */
    public static final String PARAMETER_NUMBER_OF_ITERATIONS = "number_of_iterations";

    /** Parameter key: random seed for splitting and optimization (-1: use global). */
    public static final String PARAMETER_LOCAL_RANDOM_SEED = "local_random_seed";

    public CostBasedThresholdLearner(OperatorDescription operatorDescription) {
        super(operatorDescription);
    }

    /**
     * Validates the label and the class weights, then learns the inner model and the thresholds.
     *
     * @param exampleSet training data; its label attribute must be nominal
     * @return a {@link ThresholdModel} wrapping the inner model and the optimized thresholds
     * @throws UserError 101 if the label is not nominal, 205 if no class weights are given
     * @throws OperatorException if a class weight names a class unknown to the label
     */
    @Override
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        Attribute label = exampleSet.getAttributes().getLabel();
        List<String[]> weightList = getParameterList(PARAMETER_CLASS_WEIGHTS);
        if (!label.isNominal()) {
            throw new UserError(this, 101, getName(), label.getName());
        }
        if (weightList.size() == 0) {
            throw new UserError(this, 205, PARAMETER_CLASS_WEIGHTS);
        }
        double unknownCosts = getParameterAsDouble(PARAMETER_PREDICT_UNKNOWN_COSTS);
        double[] classWeights = createClassWeights(label, weightList);

        List<String> formattedWeights = new LinkedList<String>();
        for (double weight : classWeights) {
            formattedWeights.add(Tools.formatIntegerIfPossible(weight));
        }
        log("Used class weights --> " + formattedWeights + ", unknown weight: " + Tools.formatIntegerIfPossible(unknownCosts));
        return calculateThresholdModel(exampleSet, classWeights, unknownCosts);
    }

    /**
     * Builds the weight array indexed by the label's class index. Classes not listed get 1.0.
     *
     * @throws OperatorException if a listed class name is not part of the label's mapping
     */
    private double[] createClassWeights(Attribute label, List<String[]> weightList) throws OperatorException {
        double[] classWeights = new double[label.getMapping().size()];
        for (int i = 0; i < classWeights.length; i++) {
            classWeights[i] = 1.0d;
        }
        for (String[] pair : weightList) {
            int classIndex = label.getMapping().getIndex(pair[0]);
            // An unmapped name used to surface as a bare ArrayIndexOutOfBoundsException.
            if (classIndex < 0 || classIndex >= classWeights.length) {
                throw new OperatorException("Parameter '" + PARAMETER_CLASS_WEIGHTS + "': unknown class '" + pair[0]
                        + "' (use parameter '" + PARAMETER_PREDICT_UNKNOWN_COSTS + "' for the costs of unknown predictions).");
            }
            classWeights[classIndex] = Double.parseDouble(pair[1]);
        }
        return classWeights;
    }

    /**
     * Splits the data, trains the inner learner on subset 0, applies it to subset 1 and
     * optimizes the per-class thresholds on those held-out predictions.
     *
     * @param classWeights weight per class index, charged for misclassified examples
     * @param unknownCosts cost of a rejected example; negative means "use the true class weight"
     */
    private Model calculateThresholdModel(ExampleSet exampleSet, final double[] classWeights, final double unknownCosts) throws OperatorException {
        double trainingRatio = getParameterAsDouble(PARAMETER_TRAINING_RATIO);
        int seed = getParameterAsInt(PARAMETER_LOCAL_RANDOM_SEED);
        // NOTE(review): 2 is the sampling type passed to SplittedExampleSet; presumably
        // stratified sampling — confirm against the SplittedExampleSet constants.
        SplittedExampleSet splittedSet = new SplittedExampleSet(exampleSet, trainingRatio, 2, seed);

        splittedSet.selectSingleSubset(0);
        Model innerModel = applyInnerLearner(splittedSet);

        splittedSet.selectSingleSubset(1);
        final ExampleSet appliedSet = innerModel.apply(splittedSet);
        final Attribute label = appliedSet.getAttributes().getLabel();

        int iterations = getParameterAsInt(PARAMETER_NUMBER_OF_ITERATIONS);
        // NOTE(review): constructor arguments are positional; presumably value range [0, 1],
        // population size 5 and one dimension per class (thresholds are indexed by class index
        // below) — confirm the meaning of the remaining constants against ESOptimization.
        ESOptimization optimization = new ESOptimization(0.0d, 1.0d, 5, classWeights.length, 0, iterations,
                Math.max(1, iterations / 10), 6, 0.4d, true, 1, 0.9d, false, false,
                RandomGenerator.getRandomGenerator(seed), this) {

            /**
             * Computes the total cost of one threshold vector on the held-out predictions.
             * A misclassified, non-rejected example costs the weight of its true class. A
             * rejected example is charged the unknown costs only if the inner model's
             * prediction was correct; rejecting a wrong prediction is free.
             */
            @Override
            public PerformanceVector evaluateIndividual(Individual individual) throws OperatorException {
                double[] thresholds = individual.getValues();
                double costs = 0.0d;
                for (Example example : appliedSet) {
                    int predictedIndex = (int) example.getPredictedLabel();
                    double confidence = example.getConfidence(label.getMapping().mapIndex(predictedIndex));
                    boolean correct = example.getLabel() == example.getPredictedLabel();
                    if (confidence <= thresholds[predictedIndex]) {
                        if (correct) {
                            costs += unknownCosts < 0.0d ? classWeights[(int) example.getLabel()] : unknownCosts;
                        }
                    } else if (!correct) {
                        costs += classWeights[(int) example.getLabel()];
                    }
                }
                PerformanceVector performance = new PerformanceVector();
                // The final 'true' presumably marks the criterion as one to minimize — confirm.
                performance.addCriterion(new EstimatedPerformance("Costs", costs, 1, true));
                return performance;
            }
        };
        optimization.optimize();
        PredictionModel.removePredictedLabel(appliedSet);
        return new ThresholdModel(appliedSet, innerModel, optimization.getBestValuesEver());
    }

    /** Declares the class weight, unknown cost, split, iteration and seed parameters. */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();

        // NOTE(review): this description promises "empty: using 1 for all classes" and a '?'
        // class name, but learn() rejects an empty list (UserError 205) and treats '?' like
        // any other class name.
        ParameterTypeList classWeights = new ParameterTypeList(PARAMETER_CLASS_WEIGHTS, "The weights for all classes (first column: class names, second column: weight), empty: using 1 for all classes. The costs for not classifying at all are defined with class name '?'.", new ParameterTypeDouble("weight", "The weight for the specified class.", 0.0d, Double.POSITIVE_INFINITY, 1.0d));
        classWeights.setExpert(false);
        types.add(classWeights);

        ParameterTypeDouble unknownCosts = new ParameterTypeDouble(PARAMETER_PREDICT_UNKNOWN_COSTS, "Use this cost value for predicting an example as unknown (-1: use same costs as for correct class).", -1.0d, Double.POSITIVE_INFINITY, -1.0d);
        unknownCosts.setExpert(false);
        types.add(unknownCosts);

        types.add(new ParameterTypeDouble(PARAMETER_TRAINING_RATIO, "Use this amount of input data for model learning and the rest for threshold optimization.", 0.0d, 1.0d, 0.7d));
        types.add(new ParameterTypeInt(PARAMETER_NUMBER_OF_ITERATIONS, "Defines the number of optimization iterations.", 1, Integer.MAX_VALUE, 200));
        types.add(new ParameterTypeInt(PARAMETER_LOCAL_RANDOM_SEED, "Use the given random seed instead of global random numbers (-1: use global)", -1, Integer.MAX_VALUE, -1));
        return types;
    }
}
