package com.rapidminer.operator.visualization;

import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.operator.IOContainer;
import com.rapidminer.operator.IOObject;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorChain;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.condition.AllInnerOperatorCondition;
import com.rapidminer.operator.condition.InnerOperatorCondition;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeCategory;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.tools.math.ROCData;
import com.rapidminer.tools.math.ROCDataGenerator;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import marytts.signalproc.adaptation.codebook.WeightedCodebookMapperParams;

/* JADX WARN: Classes with same name are omitted:
  input_file:builds/deps.jar:com/rapidminer/operator/visualization/ROCBasedComparisonOperator.class
  input_file:builds/deps.jar:rapidMiner.jar:com/rapidminer/operator/visualization/ROCBasedComparisonOperator.class
  input_file:builds/deps.jar:tmp-src.zip:rapidMiner.jar:com/rapidminer/operator/visualization/ROCBasedComparisonOperator.class
  input_file:com/rapidminer/operator/visualization/ROCBasedComparisonOperator.class
  input_file:rapidMiner.jar:com/rapidminer/operator/visualization/ROCBasedComparisonOperator.class
  input_file:rapidMiner.jar:com/rapidminer/operator/visualization/ROCBasedComparisonOperator.class
 */
/* loaded from: input_file:tmp-src.zip:rapidMiner.jar:com/rapidminer/operator/visualization/ROCBasedComparisonOperator.class */
/**
 * Operator chain that compares the ROC curves produced by all of its inner learners.
 *
 * <p>The input {@link ExampleSet} must have a nominal label with exactly two values. Every inner
 * operator is trained and the model it delivers is evaluated. If {@value #PARAMETER_NUMBER_OF_FOLDS}
 * is negative, evaluation uses a single train/test split. Otherwise it uses a cross validation with
 * the configured number of folds. The resulting {@link ROCData} are collected per inner operator
 * name and delivered as a {@link ROCComparison}, together with the input example set. The
 * evaluation itself works on a clone of the input.
 */
public class ROCBasedComparisonOperator extends OperatorChain {
    public static final String PARAMETER_NUMBER_OF_FOLDS = "number_of_folds";
    public static final String PARAMETER_SPLIT_RATIO = "split_ratio";
    public static final String PARAMETER_SAMPLING_TYPE = "sampling_type";
    public static final String PARAMETER_LOCAL_RANDOM_SEED = "local_random_seed";
    public static final String PARAMETER_USE_EXAMPLE_WEIGHTS = "use_example_weights";

    public ROCBasedComparisonOperator(OperatorDescription operatorDescription) {
        super(operatorDescription);
    }

    /**
     * Evaluates every inner operator and builds the ROC comparison.
     *
     * @return the input example set followed by a {@link ROCComparison}. The comparison maps each
     *     inner operator's name to its ROC data: one entry for a split, or one entry per fold for
     *     cross validation.
     * @throws UserError if the label is missing, not nominal, or does not have exactly two values,
     *     or if a learned model does not produce a predicted label
     * @throws OperatorException propagated from the inner operators, the models they deliver, or
     *     {@link #inApplyLoop()}
     */
    @Override // com.rapidminer.operator.OperatorChain, com.rapidminer.operator.Operator
    public IOObject[] apply() throws OperatorException {
        ExampleSet exampleSet = (ExampleSet) getInput(ExampleSet.class);
        checkBinominalLabel(exampleSet);

        // Loop-invariant: read once instead of once per evaluation.
        boolean useExampleWeights = getParameterAsBoolean(PARAMETER_USE_EXAMPLE_WEIGHTS);
        Map<String, List<ROCData>> rocDataByOperator = new HashMap<String, List<ROCData>>();

        int numberOfFolds = getParameterAsInt(PARAMETER_NUMBER_OF_FOLDS);
        // NOTE(review): the parameter range also admits 0 and 1 folds, which cannot form a
        // meaningful cross validation. Consider rejecting them explicitly.
        if (numberOfFolds < 0) {
            evaluateWithSplit(exampleSet, useExampleWeights, rocDataByOperator);
        } else {
            evaluateWithCrossValidation(exampleSet, numberOfFolds, useExampleWeights, rocDataByOperator);
        }
        return new IOObject[] {exampleSet, new ROCComparison(rocDataByOperator)};
    }

    /**
     * Throws a {@link UserError} unless the example set has a nominal label with exactly two
     * values.
     */
    private void checkBinominalLabel(ExampleSet exampleSet) throws OperatorException {
        if (exampleSet.getAttributes().getLabel() == null) {
            throw new UserError(this, 105);
        }
        if (!exampleSet.getAttributes().getLabel().isNominal()) {
            throw new UserError(this, 101, "ROC Comparison", exampleSet.getAttributes().getLabel());
        }
        if (exampleSet.getAttributes().getLabel().getMapping().getValues().size() != 2) {
            throw new UserError(this, 114, "ROC Comparison", exampleSet.getAttributes().getLabel());
        }
    }

    /**
     * Splits a clone of the input once. Each inner operator is trained on subset 0 and its model
     * is evaluated on subset 1.
     */
    private void evaluateWithSplit(ExampleSet exampleSet, boolean useExampleWeights,
            Map<String, List<ROCData>> results) throws OperatorException {
        SplittedExampleSet split = new SplittedExampleSet((ExampleSet) exampleSet.clone(),
                getParameterAsDouble(PARAMETER_SPLIT_RATIO),
                getParameterAsInt(PARAMETER_SAMPLING_TYPE),
                getParameterAsInt(PARAMETER_LOCAL_RANDOM_SEED));
        PredictionModel.removePredictedLabel(split);
        for (int i = 0; i < getNumberOfOperators(); i++) {
            Operator learner = getOperator(i);
            split.selectSingleSubset(0);
            Model model = trainModel(learner, split);
            split.selectSingleSubset(1);
            List<ROCData> rocData = new ArrayList<ROCData>(1);
            rocData.add(evaluateModel(model, split, useExampleWeights));
            results.put(learner.getName(), rocData);
            // Same per-evaluation hook the cross-validation branch uses.
            inApplyLoop();
        }
    }

    /**
     * Partitions a clone of the input into {@code numberOfFolds} subsets. For each inner
     * operator, every fold in turn is held out for testing while the remaining folds are used for
     * training.
     */
    private void evaluateWithCrossValidation(ExampleSet exampleSet, int numberOfFolds,
            boolean useExampleWeights, Map<String, List<ROCData>> results) throws OperatorException {
        SplittedExampleSet folds = new SplittedExampleSet((ExampleSet) exampleSet.clone(),
                numberOfFolds,
                getParameterAsInt(PARAMETER_SAMPLING_TYPE),
                getParameterAsInt(PARAMETER_LOCAL_RANDOM_SEED));
        PredictionModel.removePredictedLabel(folds);
        for (int i = 0; i < getNumberOfOperators(); i++) {
            Operator learner = getOperator(i);
            List<ROCData> rocData = new ArrayList<ROCData>(numberOfFolds);
            for (int fold = 0; fold < numberOfFolds; fold++) {
                folds.selectAllSubsetsBut(fold);
                Model model = trainModel(learner, folds);
                folds.selectSingleSubset(fold);
                rocData.add(evaluateModel(model, folds, useExampleWeights));
                inApplyLoop();
            }
            results.put(learner.getName(), rocData);
        }
    }

    /**
     * Applies the inner operator to the currently selected subset and returns the {@link Model}
     * taken from its output.
     */
    private Model trainModel(Operator learner, SplittedExampleSet trainingSet) throws OperatorException {
        return (Model) learner.apply(new IOContainer(trainingSet)).remove(Model.class);
    }

    /**
     * Applies the model to the currently selected test subset and computes the ROC data. The
     * predicted label is removed again before returning.
     *
     * @throws UserError if applying the model did not create a predicted label
     */
    private ROCData evaluateModel(Model model, SplittedExampleSet testSet, boolean useExampleWeights)
            throws OperatorException {
        ExampleSet predicted = model.apply(testSet);
        if (predicted.getAttributes().getPredictedLabel() == null) {
            throw new UserError(this, 107);
        }
        ROCData rocData = new ROCDataGenerator(1.0d, 1.0d).createROCData(predicted, useExampleWeights);
        PredictionModel.removePredictedLabel(predicted);
        return rocData;
    }

    /** This operator consumes an {@link ExampleSet}. */
    @Override // com.rapidminer.operator.Operator
    public Class<?>[] getInputClasses() {
        return new Class[] {ExampleSet.class};
    }

    /** This operator delivers the input {@link ExampleSet} and a {@link ROCComparison}. */
    @Override // com.rapidminer.operator.Operator
    public Class<?>[] getOutputClasses() {
        return new Class[] {ExampleSet.class, ROCComparison.class};
    }

    /** Inner operators receive an {@link ExampleSet} and must deliver a {@link Model}. */
    @Override // com.rapidminer.operator.OperatorChain
    public InnerOperatorCondition getInnerOperatorCondition() {
        return new AllInnerOperatorCondition(new Class[] {ExampleSet.class}, new Class[] {Model.class});
    }

    /** At least one inner learner is required for a comparison. */
    @Override // com.rapidminer.operator.OperatorChain
    public int getMinNumberOfInnerOperators() {
        return 1;
    }

    /** Any number of inner learners may be compared. */
    @Override // com.rapidminer.operator.OperatorChain
    public int getMaxNumberOfInnerOperators() {
        return Integer.MAX_VALUE;
    }

    /**
     * Declares the evaluation parameters: number of folds (negative means a single split), split
     * ratio, sampling type, local random seed, and whether example weights are used.
     */
    @Override // com.rapidminer.operator.Operator, com.rapidminer.parameter.ParameterHandler
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> parameterTypes = super.getParameterTypes();
        ParameterTypeInt folds = new ParameterTypeInt(PARAMETER_NUMBER_OF_FOLDS,
                "The number of folds used for a cross validation evaluation (-1: use simple split ratio).",
                -1, Integer.MAX_VALUE, 10);
        folds.setExpert(false);
        parameterTypes.add(folds);
        // Lower bound is the literal 0.0. The decompiler had replaced it with an unrelated marytts
        // constant that happens to share that value.
        parameterTypes.add(new ParameterTypeDouble(PARAMETER_SPLIT_RATIO,
                "Relative size of the training set", 0.0d, 1.0d, 0.7d));
        // Default index 2; presumably "stratified" given the order in the description below.
        parameterTypes.add(new ParameterTypeCategory(PARAMETER_SAMPLING_TYPE,
                "Defines the sampling type of the cross validation (linear = consecutive subsets, shuffled = random subsets, stratified = random subsets with class distribution kept constant)",
                SplittedExampleSet.SAMPLING_NAMES, 2));
        parameterTypes.add(new ParameterTypeInt(PARAMETER_LOCAL_RANDOM_SEED,
                "Use the given random seed instead of global random numbers (-1: use global)",
                -1, Integer.MAX_VALUE, -1));
        parameterTypes.add(new ParameterTypeBoolean(PARAMETER_USE_EXAMPLE_WEIGHTS,
                "Indicates if example weights should be regarded (use weight 1 for each example otherwise).",
                true));
        return parameterTypes;
    }
}
