package com.rapidminer.operator.learner.functions;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.FastExample2SparseTransform;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.learner.AbstractLearner;
import com.rapidminer.operator.learner.LearnerCapability;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeCategory;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeList;
import java.util.LinkedList;
import java.util.List;
import liblinear.FeatureNode;
import liblinear.Linear;
import liblinear.Parameter;
import liblinear.Problem;
import liblinear.SolverType;

/* JADX WARN: Classes with same name are omitted:
  input_file:builds/deps.jar:com/rapidminer/operator/learner/functions/FastLargeMargin.class
  input_file:builds/deps.jar:rapidMiner.jar:com/rapidminer/operator/learner/functions/FastLargeMargin.class
  input_file:com/rapidminer/operator/learner/functions/FastLargeMargin.class
 */
/* loaded from: input_file:rapidMiner.jar:com/rapidminer/operator/learner/functions/FastLargeMargin.class */
/**
 * Learner that trains a linear large-margin classifier with the LibLinear library.
 *
 * <p>Supports numerical attributes and a binominal label only. The solver (see {@link #SOLVER}),
 * the cost parameter C, the termination tolerance epsilon, optional per-class weights and an
 * optional bias (intercept) feature are configurable through operator parameters.
 */
public class FastLargeMargin extends AbstractLearner {
    public static final String PARAMETER_SOLVER = "solver";
    public static final String PARAMETER_C = "C";
    public static final String PARAMETER_EPSILON = "epsilon";
    public static final String PARAMETER_CLASS_WEIGHTS = "class_weights";
    public static final String PARAMETER_USE_BIAS = "use_bias";
    /** Display names of the solvers; the position in this array is the parameter value. */
    public static final String[] SOLVER = {"L2 SVM Dual", "L2 SVM Primal", "L2 Logistic Regression", "L1 SVM Dual"};
    public static final int SOLVER_L2_SVM_DUAL = 0;
    public static final int SOLVER_L2_SVM_PRIMAL = 1;
    public static final int SOLVER_L2_LR = 2;
    public static final int SOLVER_L1_SVM_DUAL = 3;

    public FastLargeMargin(OperatorDescription operatorDescription) {
        super(operatorDescription);
    }

    /** Only numerical attributes and binominal labels are supported. */
    @Override
    public boolean supportsCapability(LearnerCapability learnerCapability) {
        return learnerCapability == LearnerCapability.NUMERICAL_ATTRIBUTES || learnerCapability == LearnerCapability.BINOMINAL_CLASS;
    }

    /**
     * Converts an example into an array of LibLinear feature nodes.
     *
     * <p>Only the attribute values reported as non-default by {@code transform} are emitted; each
     * node's index is the attribute position plus one. When {@code useBias} is true, a constant
     * feature with value 1 is appended at index (number of regular attributes + 1). This is the
     * same position for every example and matches {@code Problem.n}, so its weight acts as the
     * intercept.
     *
     * @param example the example to convert
     * @param transform helper that yields the non-default attribute indices and values
     * @param useBias whether to append the constant bias feature
     * @return the feature nodes for {@code example}
     */
    public static FeatureNode[] makeNodes(Example example, FastExample2SparseTransform transform, boolean useBias) {
        int[] indices = transform.getNonDefaultAttributeIndices(example);
        double[] values = transform.getNonDefaultAttributeValues(example, indices);
        FeatureNode[] nodes = new FeatureNode[indices.length + (useBias ? 1 : 0)];
        for (int i = 0; i < indices.length; i++) {
            nodes[i] = new FeatureNode(indices[i] + 1, values[i]);
        }
        if (useBias) {
            // The bias must sit at a fixed index one past the last attribute. Using nodes.length
            // is only correct when no attribute value was omitted as default; otherwise it would
            // collide with a real attribute index.
            nodes[nodes.length - 1] = new FeatureNode(example.getAttributes().size() + 1, 1.0d);
        }
        return nodes;
    }

    /**
     * Builds the LibLinear training problem from the example set.
     *
     * <p>The label is encoded as 0 for the label mapping's negative index and 1 for every other
     * value.
     */
    private Problem getProblem(ExampleSet exampleSet) throws UserError {
        log("Creating LibLinear problem.");
        FastExample2SparseTransform transform = new FastExample2SparseTransform(exampleSet);
        int numberOfExamples = exampleSet.size();
        boolean useBias = getParameterAsBoolean(PARAMETER_USE_BIAS);

        Problem problem = new Problem();
        problem.l = numberOfExamples;
        // One extra feature dimension holds the bias (see makeNodes).
        problem.n = exampleSet.getAttributes().size() + (useBias ? 1 : 0);
        problem.y = new int[numberOfExamples];
        // x is a jagged array: one FeatureNode[] row per example.
        problem.x = new FeatureNode[numberOfExamples][];

        Attribute label = exampleSet.getAttributes().getLabel();
        int negativeIndex = label.getMapping().getNegativeIndex();
        int nodeCount = 0;
        int row = 0;
        for (Example example : exampleSet) {
            problem.x[row] = makeNodes(example, transform, useBias);
            problem.y[row] = ((int) example.getValue(label)) == negativeIndex ? 0 : 1;
            nodeCount += problem.x[row].length;
            row++;
        }
        log("Created " + nodeCount + " nodes for " + row + " examples.");
        return problem;
    }

    /** Maps the solver parameter value to a LibLinear solver; unknown values fall back to the L2 SVM dual solver. */
    private static SolverType getSolverType(int solver) {
        switch (solver) {
            case SOLVER_L2_SVM_PRIMAL:
                return SolverType.L2LOSS_SVM;
            case SOLVER_L2_LR:
                return SolverType.L2_LR;
            case SOLVER_L1_SVM_DUAL:
                return SolverType.L1LOSS_SVM_DUAL;
            case SOLVER_L2_SVM_DUAL:
            default:
                return SolverType.L2LOSS_SVM_DUAL;
        }
    }

    /**
     * Builds the LibLinear training parameters: solver, C, epsilon and, if configured, class
     * weights.
     *
     * <p>Weights are attached to the LibLinear labels produced by {@code getProblem}: 0 for the
     * negative class, 1 for the other class. Classes without a configured weight keep 1.0, and
     * class names not found in the label mapping are ignored.
     */
    private Parameter getParameters(ExampleSet exampleSet) throws OperatorException {
        SolverType solverType = getSolverType(getParameterAsInt(PARAMETER_SOLVER));
        Parameter parameter = new Parameter(solverType, getParameterAsDouble(PARAMETER_C), getParameterAsDouble(PARAMETER_EPSILON));
        if (isParameterSet(PARAMETER_CLASS_WEIGHTS)) {
            double[] weights = {1.0d, 1.0d};
            int[] weightLabels = {0, 1};
            Attribute label = exampleSet.getAttributes().getLabel();
            int negativeIndex = label.getMapping().getNegativeIndex();
            for (String[] entry : getParameterList(PARAMETER_CLASS_WEIGHTS)) {
                double weight = Double.parseDouble(entry[1]);
                int index = label.getMapping().getIndex(entry[0]);
                if (index >= 0) {
                    // Same encoding as getProblem: negative class -> label 0, otherwise -> label 1.
                    weights[index == negativeIndex ? 0 : 1] = weight;
                }
            }
            List<Double> weightList = new LinkedList<Double>();
            for (double weight : weights) {
                weightList.add(weight);
            }
            log(getName() + ": used class weights --> " + weightList);
            parameter.setWeights(weights, weightLabels);
        }
        return parameter;
    }

    /**
     * Trains a LibLinear model on {@code exampleSet}.
     *
     * @throws UserError (code 110) if the example set holds fewer than 2 examples
     */
    @Override
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        Parameter parameters = getParameters(exampleSet);
        if (exampleSet.size() < 2) {
            throw new UserError(this, 110, 2);
        }
        Linear.resetRandom();
        Linear.disableDebugOutput();
        Problem problem = getProblem(exampleSet);
        return new FastMarginModel(exampleSet, Linear.train(problem, parameters), getParameterAsBoolean(PARAMETER_USE_BIAS));
    }

    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> parameterTypes = super.getParameterTypes();
        ParameterTypeCategory solverType = new ParameterTypeCategory(PARAMETER_SOLVER, "The solver type for this fast margin method.", SOLVER, SOLVER_L2_SVM_DUAL);
        solverType.setExpert(false);
        parameterTypes.add(solverType);
        ParameterTypeDouble cType = new ParameterTypeDouble(PARAMETER_C, "The cost parameter C for c_svc, epsilon_svr, and nu_svr.", 0.0d, Double.POSITIVE_INFINITY, 1.0d);
        cType.setExpert(false);
        parameterTypes.add(cType);
        parameterTypes.add(new ParameterTypeDouble(PARAMETER_EPSILON, "Tolerance of termination criterion.", Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, 0.01d));
        parameterTypes.add(new ParameterTypeList(PARAMETER_CLASS_WEIGHTS, "The weights w for all classes (first column: class name, second column: weight), i.e. set the parameters C of each class w * C (empty: using 1 for all classes where the weight was not defined).", new ParameterTypeDouble("weight", "The weight for the specified class.", 0.0d, Double.POSITIVE_INFINITY, 1.0d)));
        parameterTypes.add(new ParameterTypeBoolean(PARAMETER_USE_BIAS, "Indicates if an intercept value should be calculated.", true));
        return parameterTypes;
    }
}
