package com.rapidminer.operator.learner.meta;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.table.AttributeFactory;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorCreationException;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.Learner;
import com.rapidminer.operator.learner.LearnerCapability;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.learner.lazy.DefaultLearner;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.tools.OperatorService;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:com/rapidminer/operator/learner/meta/AdditiveRegression.class */
/**
 * Meta learner that builds an additive regression model.
 *
 * <p>A default model is learned first. The inner learner is then trained repeatedly, each time on
 * the residuals left by the models learned so far. The predictions of each inner model are scaled
 * by the {@code shrinkage} parameter before being subtracted. Nominal (binominal or polynominal)
 * labels are not supported.
 */
public class AdditiveRegression extends AbstractMetaLearner {

    /** Parameter key: number of inner-learner iterations (integer, at least 1, default 10). */
    public static final String PARAMETER_ITERATIONS = "iterations";

    /** Parameter key: learning-rate factor applied to each inner model (0.0 to 1.0, default 1.0). */
    public static final String PARAMETER_SHRINKAGE = "shrinkage";

    public AdditiveRegression(OperatorDescription operatorDescription) {
        super(operatorDescription);
    }

    /**
     * Learns the additive regression model.
     *
     * <p>The label of a cloned view of {@code exampleSet} is copied into a temporary
     * {@code working_label} attribute, which then serves as the label. A {@link DefaultLearner}
     * model is learned and its predictions are subtracted from the working label. After that, the
     * inner learner is applied {@code iterations} times. Each resulting model's predictions,
     * multiplied by {@code shrinkage}, are subtracted from the working label.
     *
     * @param exampleSet the training data; its label is read but not modified
     * @return an {@link AdditiveRegressionModel} combining the default model and all inner models
     * @throws OperatorException if the default learner cannot be created or any learning or
     *     application step fails
     */
    @Override
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        ExampleSet workingSet = (ExampleSet) exampleSet.clone();
        Attribute label = workingSet.getAttributes().getLabel();
        Attribute workingLabel = AttributeFactory.createAttribute(label, "working_label");

        // The temporary column lives in the example table, which the clone presumably shares with
        // the caller's example set. It is therefore removed in the finally clause on every exit
        // path, not only on success.
        workingSet.getExampleTable().addAttribute(workingLabel);

        Model defaultModel;
        Model[] residualModels;
        try {
            workingSet.getAttributes().addRegular(workingLabel);
            for (Example example : workingSet) {
                example.setValue(workingLabel, example.getValue(label));
            }
            // Promote the copy from a regular attribute to the label role.
            workingSet.getAttributes().remove(workingLabel);
            workingSet.getAttributes().setLabel(workingLabel);

            defaultModel = ((Learner) OperatorService.createOperator(DefaultLearner.class)).learn(workingSet);
            residualReplace(workingSet, defaultModel, false);

            residualModels = new Model[getParameterAsInt(PARAMETER_ITERATIONS)];
            for (int i = 0; i < residualModels.length; i++) {
                residualModels[i] = applyInnerLearner(workingSet);
                residualReplace(workingSet, residualModels[i], true);
            }
        } catch (OperatorCreationException e) {
            throw new OperatorException(String.valueOf(getName()) + ": not able to create default classifier!", e);
        } finally {
            workingSet.getAttributes().remove(workingLabel);
            workingSet.getExampleTable().removeAttribute(workingLabel);
        }
        return new AdditiveRegressionModel(exampleSet, defaultModel, residualModels, getParameterAsDouble(PARAMETER_SHRINKAGE));
    }

    /**
     * Replaces each label value of {@code exampleSet} with its residual with respect to
     * {@code model}, i.e. {@code label - prediction}.
     *
     * <p>The predicted-label attribute created by applying {@code model} is removed again
     * afterwards, even if updating the labels fails.
     *
     * @param exampleSet the set whose label values are overwritten
     * @param model the model whose predictions are subtracted
     * @param scaleByShrinkage if {@code true}, predictions are multiplied by the
     *     {@code shrinkage} parameter before being subtracted
     * @throws OperatorException if applying the model or reading the parameter fails
     */
    private void residualReplace(ExampleSet exampleSet, Model model, boolean scaleByShrinkage) throws OperatorException {
        ExampleSet applied = model.apply(exampleSet);
        try {
            // Loop-invariant: read the parameter once instead of once per example.
            double shrinkage = scaleByShrinkage ? getParameterAsDouble(PARAMETER_SHRINKAGE) : 1.0d;
            Attribute label = exampleSet.getAttributes().getLabel();
            Iterator<Example> targets = exampleSet.iterator();
            Iterator<Example> predictions = applied.iterator();
            while (targets.hasNext() && predictions.hasNext()) {
                Example target = targets.next();
                double predicted = predictions.next().getPredictedLabel();
                if (scaleByShrinkage) {
                    predicted *= shrinkage;
                }
                target.setValue(label, target.getLabel() - predicted);
            }
        } finally {
            PredictionModel.removePredictedLabel(applied);
        }
    }

    /** Exactly one inner operator (the base learner) is required. */
    @Override
    public int getMinNumberOfInnerOperators() {
        return 1;
    }

    /** Exactly one inner operator (the base learner) is allowed. */
    @Override
    public int getMaxNumberOfInnerOperators() {
        return 1;
    }

    /**
     * Rejects binominal and polynominal labels. Residuals are computed by subtracting predictions
     * from the label, which requires a numerical label. All other capabilities are delegated to
     * the superclass.
     */
    @Override
    public boolean supportsCapability(LearnerCapability learnerCapability) {
        if (learnerCapability.equals(LearnerCapability.BINOMINAL_CLASS) || learnerCapability.equals(LearnerCapability.POLYNOMINAL_CLASS)) {
            return false;
        }
        return super.supportsCapability(learnerCapability);
    }

    /** Adds the {@code iterations} and {@code shrinkage} parameters to the inherited ones. */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> parameterTypes = super.getParameterTypes();
        parameterTypes.add(new ParameterTypeInt(PARAMETER_ITERATIONS, "The number of iterations.", 1, Integer.MAX_VALUE, 10));
        parameterTypes.add(new ParameterTypeDouble(PARAMETER_SHRINKAGE, "Reducing this learning rate prevent overfitting but increases the learning time.", 0.0d, 1.0d, 1.0d));
        return parameterTypes;
    }
}
