package com.rapidminer.operator.learner.functions.neuralnet;

import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.AbstractLearner;
import com.rapidminer.operator.learner.LearnerCapability;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.parameter.ParameterTypeList;
import com.rapidminer.tools.RandomGenerator;
import java.util.List;

/* loaded from: input_file:com/rapidminer/operator/learner/functions/neuralnet/ImprovedNeuralNetLearner.class */
/**
 * Learner operator that builds an {@link ImprovedNeuralNetModel} from an example set.
 *
 * <p>The operator itself only collects the user parameters (hidden layer layout, number of
 * training cycles, learning rate, momentum, ...) and hands them to
 * {@code ImprovedNeuralNetModel.train(...)}; the training logic lives in the model class.
 */
public class ImprovedNeuralNetLearner extends AbstractLearner {
    /** Key of the list parameter describing the name and size of each hidden layer. */
    public static final String PARAMETER_HIDDEN_LAYERS = "hidden_layers";
    /** Key of the parameter holding the number of training cycles. */
    public static final String PARAMETER_TRAINING_CYCLES = "training_cycles";
    /** Key of the parameter holding the training-error threshold that stops the optimization. */
    public static final String PARAMETER_ERROR_EPSILON = "error_epsilon";
    /** Key of the parameter holding the learning rate. */
    public static final String PARAMETER_LEARNING_RATE = "learning_rate";
    /** Key of the parameter holding the momentum. */
    public static final String PARAMETER_MOMENTUM = "momentum";
    /** Key of the flag that enables decreasing the learning rate during learning. */
    public static final String PARAMETER_DECAY = "decay";
    /** Key of the flag that enables shuffling of the input data before learning. */
    public static final String PARAMETER_SHUFFLE = "shuffle";
    /** Key of the flag that enables normalization of the input data before learning. */
    public static final String PARAMETER_NORMALIZE = "normalize";
    /** Key of the parameter holding the local random seed (-1 means "use global"). */
    public static final String PARAMETER_LOCAL_RANDOM_SEED = "local_random_seed";

    /**
     * Creates the operator.
     *
     * @param operatorDescription description forwarded unchanged to {@link AbstractLearner}
     */
    public ImprovedNeuralNetLearner(OperatorDescription operatorDescription) {
        super(operatorDescription);
    }

    /**
     * Creates a new {@link ImprovedNeuralNetModel} for the given data and trains it with the
     * current parameter values of this operator.
     *
     * @param exampleSet the training data; used both to construct and to train the model
     * @return the trained model
     * @throws OperatorException if a parameter lookup or the model training fails
     */
    @Override // com.rapidminer.operator.learner.Learner
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        ImprovedNeuralNetModel model = new ImprovedNeuralNetModel(exampleSet);
        // Arguments are read in the order train(...) expects them; every lookup goes through
        // the PARAMETER_* constants so the keys cannot drift from getParameterTypes().
        model.train(
                exampleSet,
                getParameterList(PARAMETER_HIDDEN_LAYERS),
                getParameterAsInt(PARAMETER_TRAINING_CYCLES),
                getParameterAsDouble(PARAMETER_ERROR_EPSILON),
                getParameterAsDouble(PARAMETER_LEARNING_RATE),
                getParameterAsDouble(PARAMETER_MOMENTUM),
                getParameterAsBoolean(PARAMETER_DECAY),
                getParameterAsBoolean(PARAMETER_SHUFFLE),
                getParameterAsBoolean(PARAMETER_NORMALIZE),
                RandomGenerator.getRandomGenerator(getParameterAsInt(PARAMETER_LOCAL_RANDOM_SEED)));
        return model;
    }

    /**
     * Reports which learner capabilities this operator declares.
     *
     * @param learnerCapability the capability to check
     * @return {@code true} for numerical attributes, polynominal, binominal and numerical
     *         classes, and weighted examples; {@code false} for everything else
     */
    @Override // com.rapidminer.operator.learner.Learner
    public boolean supportsCapability(LearnerCapability learnerCapability) {
        return learnerCapability == LearnerCapability.NUMERICAL_ATTRIBUTES
                || learnerCapability == LearnerCapability.POLYNOMINAL_CLASS
                || learnerCapability == LearnerCapability.BINOMINAL_CLASS
                || learnerCapability == LearnerCapability.NUMERICAL_CLASS
                || learnerCapability == LearnerCapability.WEIGHTED_EXAMPLES;
    }

    /**
     * Extends the inherited parameter list with the neural-net specific parameters.
     *
     * <p>Hidden layers, training cycles and learning rate are marked as non-expert; the
     * remaining parameters keep the expert flag their constructors assign.
     *
     * @return the parameter types of the superclass followed by the ones declared here
     */
    @Override // com.rapidminer.operator.Operator, com.rapidminer.parameter.ParameterHandler
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();

        // The per-entry value type reuses the key declared by SimpleNeuralNetLearner.
        ParameterTypeList hiddenLayers = new ParameterTypeList(
                PARAMETER_HIDDEN_LAYERS,
                "Describes the name and the size of all hidden layers.",
                new ParameterTypeInt(
                        SimpleNeuralNetLearner.PARAMETER_HIDDEN_LAYER_SIZES,
                        "The name and the size of the hidden layers, e.g. 'First Hidden Layer' and 5. A size of < 0 leads to a layer size of (number_of_attributes + number of classes) / 2 + 1.",
                        -1, Integer.MAX_VALUE, -1));
        hiddenLayers.setExpert(false);
        types.add(hiddenLayers);

        ParameterTypeInt trainingCycles = new ParameterTypeInt(
                PARAMETER_TRAINING_CYCLES,
                "The number of training cycles used for the neural network training.",
                1, Integer.MAX_VALUE, 500);
        trainingCycles.setExpert(false);
        types.add(trainingCycles);

        ParameterTypeDouble learningRate = new ParameterTypeDouble(
                PARAMETER_LEARNING_RATE,
                "The learning rate determines by how much we change the weights at each step.",
                0.0d, 1.0d, 0.3d);
        learningRate.setExpert(false);
        types.add(learningRate);

        types.add(new ParameterTypeDouble(
                PARAMETER_MOMENTUM,
                "The momentum simply adds a fraction of the previous weight update to the current one (prevent local maxima and smoothes optimization directions).",
                0.0d, 1.0d, 0.2d));
        types.add(new ParameterTypeBoolean(
                PARAMETER_DECAY,
                "Indicates if the learning rate should be decreased during learning",
                false));
        types.add(new ParameterTypeBoolean(
                PARAMETER_SHUFFLE,
                "Indicates if the input data should be shuffled before learning (increases memory usage but is recommended if data is sorted before)",
                true));
        types.add(new ParameterTypeBoolean(
                PARAMETER_NORMALIZE,
                "Indicates if the input data should be normalized between -1 and +1 before learning (increases runtime but is in most cases necessary)",
                true));
        types.add(new ParameterTypeDouble(
                PARAMETER_ERROR_EPSILON,
                "The optimization is stopped if the training error gets below this epsilon value.",
                0.0d, Double.POSITIVE_INFINITY, 1.0E-5d));
        types.add(new ParameterTypeInt(
                PARAMETER_LOCAL_RANDOM_SEED,
                "Use the given random seed instead of global random numbers (-1: use global)",
                -1, Integer.MAX_VALUE, -1));
        return types;
    }
}
