/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.functions;

import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.learner.AbstractLearner;
import com.rapidminer.operator.learner.LearnerCapability;
import com.rapidminer.operator.learner.functions.LogisticRegressionModel;
import com.rapidminer.operator.learner.functions.LogisticRegressionOptimization;
import com.rapidminer.operator.performance.PerformanceVector;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeCategory;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.tools.RandomGenerator;
import com.rapidminer.tools.math.optimization.ec.es.ESOptimization;
import java.util.List;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Learner operator that trains a {@link LogisticRegressionModel}.
 * <p>
 * The model is produced by a {@link LogisticRegressionOptimization}. Most parameters of this
 * operator (start population type, population size, generation limits, selection, mutation,
 * crossover, random seed, convergence plot) are forwarded unchanged to that optimization. The
 * category-valued parameters take their option lists from {@link ESOptimization}.
 * <p>
 * Declared capabilities (see {@link #supportsCapability(LearnerCapability)}): numerical
 * attributes, binominal labels and weighted examples.
 */
public class LogisticRegression extends AbstractLearner {

    public static final String PARAMETER_ADD_INTERCEPT = "add_intercept";
    public static final String PARAMETER_RETURN_PERFORMANCE = "return_model_performance";
    public static final String PARAMETER_START_POPULATION_TYPE = "start_population_type";
    public static final String PARAMETER_MAX_GENERATIONS = "max_generations";
    public static final String PARAMETER_GENERATIONS_WITHOUT_IMPROVAL = "generations_without_improval";
    public static final String PARAMETER_POPULATION_SIZE = "population_size";
    public static final String PARAMETER_TOURNAMENT_FRACTION = "tournament_fraction";
    public static final String PARAMETER_KEEP_BEST = "keep_best";
    public static final String PARAMETER_MUTATION_TYPE = "mutation_type";
    public static final String PARAMETER_SELECTION_TYPE = "selection_type";
    public static final String PARAMETER_CROSSOVER_PROB = "crossover_prob";
    public static final String PARAMETER_LOCAL_RANDOM_SEED = "local_random_seed";
    public static final String PARAMETER_SHOW_CONVERGENCE_PLOT = "show_convergence_plot";

    /**
     * Performance reported by the optimization during the most recent call to
     * {@link #learn(ExampleSet)}; {@code null} until {@code learn} has run.
     */
    private PerformanceVector estimatedPerformance;

    /**
     * Creates the operator.
     *
     * @param description the operator description, passed unchanged to {@link AbstractLearner}
     */
    public LogisticRegression(OperatorDescription description) {
        super(description);
    }

    /**
     * Trains a logistic regression model on {@code exampleSet}.
     * <p>
     * Reads all optimization parameters, runs {@link LogisticRegressionOptimization#train()} and
     * remembers the performance the optimization reports so that
     * {@link #getEstimatedPerformance()} can return it afterwards.
     *
     * @param exampleSet the training data
     * @return the trained {@link LogisticRegressionModel}
     * @throws OperatorException if a parameter cannot be read or the optimization fails
     */
    @Override
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        // Parameters are read in the same order the optimization constructor consumes them.
        RandomGenerator random = RandomGenerator.getRandomGenerator(getParameterAsInt(PARAMETER_LOCAL_RANDOM_SEED));
        boolean addIntercept = getParameterAsBoolean(PARAMETER_ADD_INTERCEPT);
        int startPopulationType = getParameterAsInt(PARAMETER_START_POPULATION_TYPE);
        int maxGenerations = getParameterAsInt(PARAMETER_MAX_GENERATIONS);
        int generationsWithoutImprovement = getParameterAsInt(PARAMETER_GENERATIONS_WITHOUT_IMPROVAL);
        int populationSize = getParameterAsInt(PARAMETER_POPULATION_SIZE);
        int selectionType = getParameterAsInt(PARAMETER_SELECTION_TYPE);
        double tournamentFraction = getParameterAsDouble(PARAMETER_TOURNAMENT_FRACTION);
        boolean keepBest = getParameterAsBoolean(PARAMETER_KEEP_BEST);
        int mutationType = getParameterAsInt(PARAMETER_MUTATION_TYPE);
        double crossoverProbability = getParameterAsDouble(PARAMETER_CROSSOVER_PROB);
        boolean showConvergencePlot = getParameterAsBoolean(PARAMETER_SHOW_CONVERGENCE_PLOT);

        LogisticRegressionOptimization optimization = new LogisticRegressionOptimization(
                exampleSet,
                addIntercept,
                startPopulationType,
                maxGenerations,
                generationsWithoutImprovement,
                populationSize,
                selectionType,
                tournamentFraction,
                keepBest,
                mutationType,
                crossoverProbability,
                showConvergencePlot,
                random,
                this);
        LogisticRegressionModel model = optimization.train();
        this.estimatedPerformance = optimization.getPerformance();
        return model;
    }

    /**
     * Reports whether the user asked for the model performance to be returned.
     *
     * @return the value of the {@value #PARAMETER_RETURN_PERFORMANCE} parameter
     */
    @Override
    public boolean shouldEstimatePerformance() {
        return getParameterAsBoolean(PARAMETER_RETURN_PERFORMANCE);
    }

    /**
     * Returns the performance recorded by the last {@link #learn(ExampleSet)} call.
     *
     * @return the recorded performance vector
     * @throws OperatorException a {@link UserError} with code 912 if returning the performance
     *         is disabled or no performance has been recorded yet
     */
    @Override
    public PerformanceVector getEstimatedPerformance() throws OperatorException {
        if (getParameterAsBoolean(PARAMETER_RETURN_PERFORMANCE) && this.estimatedPerformance != null) {
            return this.estimatedPerformance;
        }
        throw new UserError((Operator) this, 912, getName(), "could not deliver optimization performance.");
    }

    /**
     * Declares the capabilities of this learner.
     *
     * @param lc the capability to check
     * @return {@code true} only for numerical attributes, binominal labels and weighted examples
     */
    @Override
    public boolean supportsCapability(LearnerCapability lc) {
        return lc == LearnerCapability.NUMERICAL_ATTRIBUTES
                || lc == LearnerCapability.BINOMINAL_CLASS
                || lc == LearnerCapability.WEIGHTED_EXAMPLES;
    }

    /**
     * Adds this operator's parameters to those of {@link AbstractLearner}.
     *
     * @return the parameter types, in the order they are presented
     */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        types.add(new ParameterTypeBoolean(PARAMETER_ADD_INTERCEPT, "Determines whether to include an intercept.", true));
        types.add(new ParameterTypeBoolean(PARAMETER_RETURN_PERFORMANCE, "Determines whether to return the performance.", false));
        types.add(new ParameterTypeCategory(PARAMETER_START_POPULATION_TYPE, "The type of start population initialization.", ESOptimization.POPULATION_INIT_TYPES, 0));
        types.add(new ParameterTypeInt(PARAMETER_MAX_GENERATIONS, "Stop after this many evaluations", 1, Integer.MAX_VALUE, 10000));
        // Description previously referred to a non-existent "max_iterations" parameter.
        types.add(new ParameterTypeInt(PARAMETER_GENERATIONS_WITHOUT_IMPROVAL, "Stop after this number of generations without improvement (-1: optimize until max_generations).", -1, Integer.MAX_VALUE, 300));
        types.add(new ParameterTypeInt(PARAMETER_POPULATION_SIZE, "The population size (-1: number of examples)", -1, Integer.MAX_VALUE, 3));
        types.add(new ParameterTypeDouble(PARAMETER_TOURNAMENT_FRACTION, "The fraction of the population used for tournament selection.", 0.0, Double.POSITIVE_INFINITY, 0.75));
        types.add(new ParameterTypeBoolean(PARAMETER_KEEP_BEST, "Indicates if the best individual should survive (elitist selection).", true));
        types.add(new ParameterTypeCategory(PARAMETER_MUTATION_TYPE, "The type of the mutation operator.", ESOptimization.MUTATION_TYPES, 1));
        types.add(new ParameterTypeCategory(PARAMETER_SELECTION_TYPE, "The type of the selection operator.", ESOptimization.SELECTION_TYPES, 6));
        types.add(new ParameterTypeDouble(PARAMETER_CROSSOVER_PROB, "The probability for crossovers.", 0.0, 1.0, 1.0));
        types.add(new ParameterTypeInt(PARAMETER_LOCAL_RANDOM_SEED, "Use the given random seed instead of global random numbers (-1: use global).", -1, Integer.MAX_VALUE, -1));
        types.add(new ParameterTypeBoolean(PARAMETER_SHOW_CONVERGENCE_PLOT, "Indicates if a dialog with a convergence plot should be drawn.", false));
        return types;
    }
}

