package com.rapidminer.operator.learner.functions.kernel;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.FastExample2SparseTransform;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.learner.FormulaProvider;
import com.rapidminer.tools.Tools;
import java.util.Iterator;
import libsvm.Svm;
import libsvm.svm_model;
import libsvm.svm_node;
import net.didion.jwnl.dictionary.file.DictionaryFile;
import opennlp.tools.parser.Parse;

/* loaded from: input_file:WEB-INF/lib/rapidMiner-1.0.0.jar:com/rapidminer/operator/learner/functions/kernel/LibSVMModel.class */
/**
 * Kernel model backed by a libsvm {@code svm_model}.
 *
 * <p>Exposes the support vectors of the wrapped model through the {@link KernelModel}
 * accessors, applies the model to an example set, and renders a textual formula for
 * the linear, polynomial and sigmoid kernels.</p>
 */
public class LibSVMModel extends KernelModel implements FormulaProvider {
    private static final long serialVersionUID = -2654603017217487365L;

    /** Kernel type codes as stored in {@code model.param.kernel_type}. */
    private static final int KERNEL_LINEAR = 0;
    private static final int KERNEL_POLY = 1;
    private static final int KERNEL_RBF = 2;
    private static final int KERNEL_SIGMOID = 3;
    private static final int KERNEL_PRECOMPUTED = 4;

    /** Lower clamp applied to pairwise probabilities; the upper clamp is {@code 1 - MIN_PROBABILITY}. */
    private static final double MIN_PROBABILITY = 1.0E-7d;

    // Field names are kept stable on purpose: they are part of the serialized form.
    private svm_model model;
    private int numberOfAttributes;
    private boolean confidenceForMultiClass;

    /**
     * Creates a model wrapper.
     *
     * @param exampleSet               example set handed to the {@link KernelModel} constructor
     * @param model                    the libsvm model to wrap
     * @param numberOfAttributes       length of the dense vectors produced for support vectors
     * @param confidenceForMultiClass  if {@code true}, the predicted class of a probability-enabled
     *                                 model is taken from {@code svm_predict_probability}, otherwise
     *                                 from {@code svm_predict}
     */
    public LibSVMModel(ExampleSet exampleSet, svm_model model, int numberOfAttributes, boolean confidenceForMultiClass) {
        super(exampleSet);
        this.model = model;
        this.numberOfAttributes = numberOfAttributes;
        this.confidenceForMultiClass = confidenceForMultiClass;
    }

    /** Returns {@code true} if the label attribute is nominal. */
    @Override // com.rapidminer.operator.learner.functions.kernel.KernelModel
    public boolean isClassificationModel() {
        return getLabel().isNominal();
    }

    /** Returns the coefficient of support vector {@code i}, read from the first row of {@code sv_coef}. */
    @Override // com.rapidminer.operator.learner.functions.kernel.KernelModel
    public double getAlpha(int i) {
        return this.model.sv_coef[0][i];
    }

    /** Always returns {@code null}; no ids are available for support vectors of this model. */
    @Override // com.rapidminer.operator.learner.functions.kernel.KernelModel
    public String getId(int i) {
        return null;
    }

    /** Returns the number of support vectors stored in the wrapped model. */
    @Override // com.rapidminer.operator.learner.functions.kernel.KernelModel
    public int getNumberOfSupportVectors() {
        return this.model.SV.length;
    }

    /** Returns the attribute count passed to the constructor. */
    @Override // com.rapidminer.operator.learner.functions.kernel.KernelModel
    public int getNumberOfAttributes() {
        return this.numberOfAttributes;
    }

    /** Returns the first entry of {@code rho}, or {@code 0} if {@code rho} is empty. */
    @Override // com.rapidminer.operator.learner.functions.kernel.KernelModel
    public double getBias() {
        if (this.model.rho.length > 0) {
            return this.model.rho[0];
        }
        return 0.0d;
    }

    /**
     * Returns support vector {@code i} as a dense vector together with its label value
     * and the absolute value of its coefficient.
     */
    @Override // com.rapidminer.operator.learner.functions.kernel.KernelModel
    public SupportVector getSupportVector(int i) {
        double[] values = toDenseVector(i, getNumberOfAttributes());
        return new SupportVector(values, getRegressionLabel(i), Math.abs(getAlpha(i)));
    }

    /**
     * Returns the value of attribute {@code i2} in support vector {@code i};
     * attributes without a sparse node are {@code 0}.
     */
    @Override // com.rapidminer.operator.learner.functions.kernel.KernelModel
    public double getAttributeValue(int i, int i2) {
        return toDenseVector(i, this.numberOfAttributes)[i2];
    }

    /**
     * Expands the sparse node list of a support vector into a dense array.
     * Each node's {@code index} is used directly as the array position.
     */
    private double[] toDenseVector(int supportVectorIndex, int length) {
        double[] dense = new double[length];
        for (svm_node node : this.model.SV[supportVectorIndex]) {
            dense[node.index] = node.value;
        }
        return dense;
    }

    /**
     * Returns the nominal label of support vector {@code i}, obtained by mapping its label
     * value through the label attribute's mapping, or {@code "?"} if no label value is known.
     */
    @Override // com.rapidminer.operator.learner.functions.kernel.KernelModel
    public String getClassificationLabel(int i) {
        double labelValue = getRegressionLabel(i);
        if (Double.isNaN(labelValue)) {
            return "?";
        }
        return getLabel().getMapping().mapIndex((int) labelValue);
    }

    /** Returns {@code labelValues[i]} of the wrapped model, or NaN if the model has no label values. */
    @Override // com.rapidminer.operator.learner.functions.kernel.KernelModel
    public double getRegressionLabel(int i) {
        if (this.model.labelValues != null) {
            return this.model.labelValues[i];
        }
        return Double.NaN;
    }

    /**
     * Evaluates the model on support vector {@code i}. For a nominal label the first entry of
     * the array passed to {@code svm_predict_probability} is returned; otherwise the result
     * of {@code svm_predict}.
     */
    @Override // com.rapidminer.operator.learner.functions.kernel.KernelModel
    public double getFunctionValue(int i) {
        svm_node[] supportVector = this.model.SV[i];
        if (!getLabel().isNominal()) {
            return Svm.svm_predict(this.model, supportVector);
        }
        double[] estimates = new double[getLabel().getMapping().size()];
        Svm.svm_predict_probability(this.model, supportVector, estimates);
        return estimates[0];
    }

    /**
     * Writes predictions (and, for nominal labels, confidences) into every example.
     *
     * <ul>
     * <li>Nominal label with exactly one class: decision values are min-max scaled over the
     *     whole example set and stored as confidence of that class.</li>
     * <li>Nominal label, model without {@code probA}/{@code probB}: the predicted class comes
     *     from {@code svm_predict}; confidences are a logistic transform of the decision value
     *     (two classes) or {@code 1.0} for the predicted class (otherwise).</li>
     * <li>Nominal label, model with {@code probA}/{@code probB}: confidences are derived from
     *     pairwise sigmoid probabilities.</li>
     * <li>Non-nominal label: the value of {@code svm_predict}.</li>
     * </ul>
     *
     * @param exampleSet the examples to score; modified in place
     * @param attribute  the attribute receiving the prediction
     * @return the same {@code exampleSet} instance
     */
    @Override // com.rapidminer.operator.learner.PredictionModel
    public ExampleSet performPrediction(ExampleSet exampleSet, Attribute attribute) throws UserError {
        FastExample2SparseTransform ripper = new FastExample2SparseTransform(exampleSet);
        Attribute label = getLabel();

        // Confidence attributes are looked up in the order of the model's class labels.
        Attribute[] confidenceAttributes = null;
        if (label.isNominal() && label.getMapping().size() >= 2) {
            confidenceAttributes = new Attribute[this.model.label.length];
            for (int c = 0; c < this.model.label.length; c++) {
                String className = label.getMapping().mapIndex(this.model.label[c]);
                confidenceAttributes[c] = exampleSet.getAttributes().getSpecial("confidence_" + className);
            }
        }

        if (label.isNominal() && label.getMapping().size() == 1) {
            predictOneClass(exampleSet, attribute, ripper);
            return exampleSet;
        }

        for (Example example : exampleSet) {
            if (!label.isNominal()) {
                example.setValue(attribute, Svm.svm_predict(this.model, LibSVMLearner.makeNodes(example, ripper)));
                continue;
            }
            svm_node[] nodes = LibSVMLearner.makeNodes(example, ripper);
            if (this.model.probA == null || this.model.probB == null) {
                predictWithoutProbabilityModel(example, attribute, label, nodes, confidenceAttributes);
            } else {
                predictWithProbabilityModel(example, attribute, nodes, confidenceAttributes);
            }
        }
        return exampleSet;
    }

    /** One-class case: two passes, first collecting decision values, then writing scaled confidences. */
    private void predictOneClass(ExampleSet exampleSet, Attribute predictedLabel, FastExample2SparseTransform ripper) throws UserError {
        double[] decisionValues = new double[exampleSet.size()];
        double max = Double.NEGATIVE_INFINITY;
        double min = Double.POSITIVE_INFINITY;
        int position = 0;
        for (Example example : exampleSet) {
            svm_node[] nodes = LibSVMLearner.makeNodes(example, ripper);
            double[] decision = new double[1];
            Svm.svm_predict_values(this.model, nodes, decision);
            decisionValues[position++] = decision[0];
            min = Math.min(min, decision[0]);
            max = Math.max(max, decision[0]);
        }

        String className = predictedLabel.getMapping().mapIndex(0);
        position = 0;
        for (Example example : exampleSet) {
            example.setValue(predictedLabel, 0.0d);
            // NOTE(review): if all decision values are equal (e.g. a single example) the range
            // is 0 and the confidence evaluates to NaN. Left unchanged; decide on a fallback.
            example.setConfidence(className, (decisionValues[position++] - min) / (max - min));
        }
    }

    /** Nominal label, model has no probability parameters. */
    private void predictWithoutProbabilityModel(Example example, Attribute predictedLabel, Attribute label,
            svm_node[] nodes, Attribute[] confidenceAttributes) throws UserError {
        double predicted = Svm.svm_predict(this.model, nodes);
        example.setValue(predictedLabel, predicted);

        if (label.getMapping().size() != 2) {
            example.setConfidence(getLabel().getMapping().mapIndex((int) predicted), 1.0d);
            return;
        }

        double[] decision = new double[this.model.nr_class];
        Svm.svm_predict_values(this.model, nodes, decision);
        double value = decision[0];
        if (confidenceAttributes != null && confidenceAttributes.length > 0) {
            // Logistic transform of the decision value; the second class gets the complement.
            example.setValue(confidenceAttributes[0], 1.0d / (1.0d + Math.exp(-value)));
            if (confidenceAttributes.length > 1) {
                example.setValue(confidenceAttributes[1], 1.0d / (1.0d + Math.exp(value)));
            }
        }
    }

    /** Nominal label, model has {@code probA} and {@code probB}. */
    private void predictWithProbabilityModel(Example example, Attribute predictedLabel, svm_node[] nodes,
            Attribute[] confidenceAttributes) throws UserError {
        double[] classProbabilities = new double[this.model.nr_class];
        int classCount = this.model.nr_class;

        // One decision value per unordered pair of classes.
        double[] decisionValues = new double[(classCount * (classCount - 1)) / 2];
        Svm.svm_predict_values(this.model, nodes, decisionValues);

        double[][] pairwise = new double[classCount][classCount];
        int pair = 0;
        for (int a = 0; a < classCount; a++) {
            for (int b = a + 1; b < classCount; b++) {
                double p = Svm.sigmoid_predict(decisionValues[pair], this.model.probA[pair], this.model.probB[pair]);
                // Clamp away from 0 and 1 before coupling the pairwise estimates.
                pairwise[a][b] = Math.min(Math.max(p, MIN_PROBABILITY), 1.0d - MIN_PROBABILITY);
                pairwise[b][a] = 1.0d - pairwise[a][b];
                pair++;
            }
        }
        Svm.multiclass_probability(classCount, pairwise, classProbabilities);
        for (int c = 0; c < classCount; c++) {
            example.setValue(confidenceAttributes[c], classProbabilities[c]);
        }

        if (this.confidenceForMultiClass) {
            example.setValue(predictedLabel, Svm.svm_predict_probability(this.model, nodes, classProbabilities));
        } else {
            example.setValue(predictedLabel, Svm.svm_predict(this.model, nodes));
        }
    }

    /**
     * Appends the number of classes and the number of support vectors (per class where the
     * model provides {@code nSV} and the label is nominal with at least two values) to the
     * superclass description.
     */
    @Override // com.rapidminer.operator.learner.functions.kernel.KernelModel, com.rapidminer.operator.learner.PredictionModel, com.rapidminer.report.Readable
    public String toString() {
        String newline = Tools.getLineSeparator();
        StringBuilder text = new StringBuilder();
        text.append(super.toString()).append(newline);
        text.append("number of classes: " + this.model.nr_class + newline);

        boolean perClass = getLabel().isNominal() && getLabel().getMapping().size() >= 2 && this.model.nSV != null;
        if (!perClass) {
            text.append("number of support vectors: " + this.model.l + newline);
        } else {
            for (int c = 0; c < this.model.nSV.length; c++) {
                String className = getLabel().getMapping().mapIndex(this.model.label[c]);
                text.append("number of support vectors for class " + className + ": " + this.model.nSV[c] + newline);
            }
        }
        return text.toString();
    }

    /**
     * Renders the decision function as text: one line per support vector with a non-zero
     * coefficient, followed by the bias if it is non-zero. For precomputed and RBF kernels
     * a fixed explanatory message is returned instead.
     */
    @Override // com.rapidminer.operator.learner.FormulaProvider
    public String getFormula() {
        int kernelType = this.model.param.kernel_type;
        if (kernelType == KERNEL_PRECOMPUTED) {
            return "Precomputed kernel, no formula possible.";
        }
        if (kernelType == KERNEL_RBF) {
            return "RBF kernel, no formula possible.";
        }

        StringBuilder formula = new StringBuilder();
        boolean first = true;
        for (int s = 0; s < getNumberOfSupportVectors(); s++) {
            SupportVector supportVector = getSupportVector(s);
            if (supportVector == null) {
                continue;
            }
            double alpha = supportVector.getAlpha();
            if (Tools.isZero(alpha)) {
                continue;
            }
            formula.append(Tools.getLineSeparator());
            double[] x = supportVector.getX();
            double factor = supportVector.getY() * alpha;
            if (factor < 0.0d) {
                formula.append("- " + Math.abs(factor));
            } else if (first) {
                // Two-space pad in place of a sign on the first positive term.
                formula.append("  " + factor);
            } else {
                formula.append("+ " + factor);
            }
            formula.append(" * (" + getDistanceFormula(x, getAttributeConstructions()) + ")");
            first = false;
        }

        double bias = getBias();
        if (!Tools.isZero(bias)) {
            formula.append(Tools.getLineSeparator());
            if (bias < 0.0d) {
                formula.append("- " + Math.abs(bias));
            } else if (first) {
                formula.append(bias);
            } else {
                formula.append("+ " + bias);
            }
        }
        return formula.toString();
    }

    /**
     * Renders the kernel expression between the support vector {@code dArr} and the attributes
     * named by {@code strArr}. Returns an empty string for kernel types other than linear,
     * polynomial and sigmoid.
     */
    private String getDistanceFormula(double[] dArr, String[] strArr) {
        switch (this.model.param.kernel_type) {
            case KERNEL_LINEAR:
                return buildLinearCombination(dArr, strArr);
            case KERNEL_POLY: {
                String inner = buildLinearCombination(dArr, strArr);
                return "pow((" + this.model.param.gamma + " * (" + inner + ") + " + this.model.param.coef0 + "), " + this.model.param.degree + ")";
            }
            case KERNEL_SIGMOID: {
                String inner = buildLinearCombination(dArr, strArr);
                return "tanh(" + this.model.param.gamma + " * (" + inner + ") + " + this.model.param.coef0 + ")";
            }
            default:
                return "";
        }
    }

    /**
     * Builds {@code w0 * name0 + w1 * name1 - ...}, skipping weights that {@link Tools#isZero}
     * reports as zero. The first term carries no leading space or plus sign.
     */
    private String buildLinearCombination(double[] weights, String[] names) {
        StringBuilder sum = new StringBuilder();
        boolean first = true;
        for (int a = 0; a < weights.length; a++) {
            double weight = weights[a];
            if (Tools.isZero(weight)) {
                continue;
            }
            String prefix;
            if (weight < 0.0d) {
                prefix = (first ? "-" : " - ") + Math.abs(weight);
            } else {
                prefix = (first ? "" : " + ") + weight;
            }
            sum.append(prefix + " * " + names[a]);
            first = false;
        }
        return sum.toString();
    }
}
