/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.functions.kernel;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.FastExample2SparseTransform;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.learner.FormulaProvider;
import com.rapidminer.operator.learner.functions.kernel.KernelModel;
import com.rapidminer.operator.learner.functions.kernel.LibSVMLearner;
import com.rapidminer.operator.learner.functions.kernel.SupportVector;
import com.rapidminer.tools.Tools;
import libsvm.Svm;
import libsvm.svm_model;
import libsvm.svm_node;

/**
 * Prediction model backed by a trained libsvm {@code svm_model}.
 *
 * <p>Supports four prediction modes, chosen from the label and the model contents:
 * one-class (nominal label with a single value), nominal with pairwise probability
 * estimates ({@code probA}/{@code probB} present), nominal without probability
 * estimates, and numerical (regression). For linear, polynomial and sigmoid kernels
 * it can also render the decision function as a human-readable formula.
 */
public class LibSVMModel extends KernelModel implements FormulaProvider {
    private static final long serialVersionUID = -2654603017217487365L;

    /** libsvm kernel type code for the plain dot-product kernel. */
    private static final int KERNEL_LINEAR = 0;
    /** libsvm kernel type code for the polynomial kernel {@code pow(gamma * <x,y> + coef0, degree)}. */
    private static final int KERNEL_POLY = 1;
    /** libsvm kernel type code for the RBF kernel (no closed formula is rendered). */
    private static final int KERNEL_RBF = 2;
    /** libsvm kernel type code for the sigmoid kernel {@code tanh(gamma * <x,y> + coef0)}. */
    private static final int KERNEL_SIGMOID = 3;
    /** libsvm kernel type code for a precomputed kernel matrix (no formula possible). */
    private static final int KERNEL_PRECOMPUTED = 4;

    /** Pairwise probabilities are clamped to [MIN, 1 - MIN] before multi-class coupling. */
    private static final double MIN_PAIRWISE_PROB = 1.0E-7;

    private svm_model model;
    private int numberOfAttributes;
    private boolean confidenceForMultiClass = true;

    /**
     * Creates a model wrapper.
     *
     * @param exampleSet the training example set (passed to {@link KernelModel})
     * @param model the trained libsvm model
     * @param numberOfAttributes dense dimensionality used to expand sparse support vectors
     * @param confidenceForMultiClass if {@code true}, and probability estimates exist,
     *        the predicted class is the one with the highest probability; otherwise
     *        the libsvm voting prediction ({@code svm_predict}) is used
     */
    public LibSVMModel(ExampleSet exampleSet, svm_model model, int numberOfAttributes, boolean confidenceForMultiClass) {
        super(exampleSet);
        this.model = model;
        this.numberOfAttributes = numberOfAttributes;
        this.confidenceForMultiClass = confidenceForMultiClass;
    }

    /** Returns {@code true} when the label attribute is nominal. */
    public boolean isClassificationModel() {
        return this.getLabel().isNominal();
    }

    /**
     * Returns the coefficient of the given support vector from the first row of
     * {@code sv_coef}. Only that first row is consulted, even for multi-class models.
     */
    public double getAlpha(int index) {
        return this.model.sv_coef[0][index];
    }

    /** Support vector ids are not tracked by this model; always returns {@code null}. */
    public String getId(int index) {
        return null;
    }

    /** Returns the number of stored support vectors ({@code model.SV.length}). */
    public int getNumberOfSupportVectors() {
        return this.model.SV.length;
    }

    /** Returns the dense dimensionality used when expanding sparse support vectors. */
    public int getNumberOfAttributes() {
        return this.numberOfAttributes;
    }

    /** Returns {@code rho[0]} of the libsvm model, or {@code 0.0} if {@code rho} is empty. */
    public double getBias() {
        if (this.model.rho.length > 0) {
            return this.model.rho[0];
        }
        return 0.0;
    }

    /**
     * Builds a dense {@link SupportVector} for the given index. Its y value is
     * {@link #getRegressionLabel(int)} and its alpha is {@code |getAlpha(index)|}.
     */
    public SupportVector getSupportVector(int index) {
        double[] x = this.toDense(this.model.SV[index]);
        return new SupportVector(x, this.getRegressionLabel(index), Math.abs(this.getAlpha(index)));
    }

    /**
     * Expands sparse libsvm nodes into a dense array of length {@link #getNumberOfAttributes()}.
     * Node indices are used directly as array positions.
     */
    private double[] toDense(svm_node[] nodes) {
        double[] dense = new double[this.getNumberOfAttributes()];
        for (svm_node node : nodes) {
            dense[node.index] = node.value;
        }
        return dense;
    }

    /**
     * Returns one attribute value of a stored support vector. Attributes that are
     * absent from the sparse representation are 0.0.
     *
     * @throws ArrayIndexOutOfBoundsException if {@code attributeIndex} is outside
     *         {@code [0, numberOfAttributes)}
     */
    public double getAttributeValue(int exampleIndex, int attributeIndex) {
        if (attributeIndex < 0 || attributeIndex >= this.numberOfAttributes) {
            throw new ArrayIndexOutOfBoundsException(attributeIndex);
        }
        // Scan the sparse nodes instead of allocating a dense copy for a single lookup;
        // a later node with the same index wins, as it would when filling a dense array.
        double value = 0.0;
        for (svm_node node : this.model.SV[exampleIndex]) {
            if (node.index == attributeIndex) {
                value = node.value;
            }
        }
        return value;
    }

    /**
     * Maps the stored label value of the given support vector to its nominal string,
     * or returns {@code "?"} when no label value is stored.
     */
    public String getClassificationLabel(int index) {
        double functionValue = this.getRegressionLabel(index);
        if (!Double.isNaN(functionValue)) {
            return this.getLabel().getMapping().mapIndex((int) functionValue);
        }
        return "?";
    }

    /**
     * Returns {@code model.labelValues[index]}, or NaN when the model does not carry
     * label values.
     */
    public double getRegressionLabel(int index) {
        if (this.model.labelValues != null) {
            return this.model.labelValues[index];
        }
        return Double.NaN;
    }

    /**
     * Evaluates the model on a stored support vector. For a nominal label this returns
     * the probability estimate of the first class; otherwise the libsvm prediction.
     */
    public double getFunctionValue(int index) {
        if (this.getLabel().isNominal()) {
            double[] classProbs = new double[this.getLabel().getMapping().size()];
            Svm.svm_predict_probability(this.model, this.model.SV[index], classProbs);
            return classProbs[0];
        }
        return Svm.svm_predict(this.model, this.model.SV[index]);
    }

    /**
     * Applies the model to every example. It writes the predicted label and, for
     * nominal labels, the confidence values.
     *
     * @param exampleSet the examples to score (modified in place and returned)
     * @param predictedLabel the attribute receiving the prediction
     * @return {@code exampleSet}
     * @throws UserError propagated from example-to-node conversion
     */
    public ExampleSet performPrediction(ExampleSet exampleSet, Attribute predictedLabel) throws UserError {
        FastExample2SparseTransform ripper = new FastExample2SparseTransform(exampleSet);
        Attribute label = this.getLabel();
        if (label.isNominal() && label.getMapping().size() == 1) {
            this.predictOneClass(exampleSet, predictedLabel, ripper);
            return exampleSet;
        }
        Attribute[] confidenceAttributes = null;
        if (label.isNominal() && label.getMapping().size() >= 2) {
            confidenceAttributes = this.lookupConfidenceAttributes(exampleSet, label);
        }
        for (Example e : exampleSet) {
            svm_node[] currentNodes = LibSVMLearner.makeNodes(e, ripper);
            if (!label.isNominal()) {
                e.setValue(predictedLabel, Svm.svm_predict(this.model, currentNodes));
            } else if (this.model.probA != null && this.model.probB != null) {
                this.predictWithProbabilities(e, currentNodes, predictedLabel, confidenceAttributes);
            } else {
                this.predictWithDecisionValues(e, currentNodes, predictedLabel, confidenceAttributes, label);
            }
        }
        return exampleSet;
    }

    /**
     * Resolves the special {@code confidence_<class>} attribute for each class, in
     * the order of {@code model.label}.
     */
    private Attribute[] lookupConfidenceAttributes(ExampleSet exampleSet, Attribute label) {
        Attribute[] attributes = new Attribute[this.model.label.length];
        for (int j = 0; j < attributes.length; ++j) {
            String labelName = label.getMapping().mapIndex(this.model.label[j]);
            attributes[j] = exampleSet.getAttributes().getSpecial("confidence_" + labelName);
        }
        return attributes;
    }

    /**
     * One-class scoring: every example is predicted as mapping index 0. Its confidence
     * is the decision value min-max normalized over the whole example set.
     */
    private void predictOneClass(ExampleSet exampleSet, Attribute predictedLabel, FastExample2SparseTransform ripper) throws UserError {
        double[] decisionValues = new double[exampleSet.size()];
        double minValue = Double.POSITIVE_INFINITY;
        double maxValue = Double.NEGATIVE_INFINITY;
        int counter = 0;
        for (Example e : exampleSet) {
            svm_node[] nodes = LibSVMLearner.makeNodes(e, ripper);
            double[] value = new double[1];
            Svm.svm_predict_values(this.model, nodes, value);
            decisionValues[counter++] = value[0];
            minValue = Math.min(minValue, value[0]);
            maxValue = Math.max(maxValue, value[0]);
        }
        double range = maxValue - minValue;
        String className = predictedLabel.getMapping().mapIndex(0);
        counter = 0;
        for (Example e : exampleSet) {
            e.setValue(predictedLabel, 0.0);
            double value = decisionValues[counter++];
            // With identical decision values (e.g. a single example), normalizing would be
            // 0/0 = NaN. The examples cannot be ranked, so give full confidence to the only class.
            double confidence = range == 0.0 ? 1.0 : (value - minValue) / range;
            e.setConfidence(className, confidence);
        }
    }

    /**
     * Nominal prediction using pairwise probability estimates. All class confidences
     * are written first; the label then comes from {@code svm_predict_probability}
     * when {@code confidenceForMultiClass} is set, else from {@code svm_predict}.
     */
    private void predictWithProbabilities(Example e, svm_node[] nodes, Attribute predictedLabel, Attribute[] confidenceAttributes) {
        double[] classProbs = this.estimateClassProbabilities(nodes);
        for (int k = 0; k < classProbs.length; ++k) {
            e.setValue(confidenceAttributes[k], classProbs[k]);
        }
        double predictedClass = this.confidenceForMultiClass
                ? Svm.svm_predict_probability(this.model, nodes, classProbs)
                : Svm.svm_predict(this.model, nodes);
        e.setValue(predictedLabel, predictedClass);
    }

    /**
     * Computes per-class probabilities. The steps are: take one decision value per class
     * pair, map it through {@code Svm.sigmoid_predict} with {@code probA}/{@code probB},
     * clamp the result, then couple the pairs with {@code Svm.multiclass_probability}.
     *
     * @return an array of length {@code model.nr_class}
     */
    private double[] estimateClassProbabilities(svm_node[] nodes) {
        int nrClass = this.model.nr_class;
        double[] decisionValues = new double[nrClass * (nrClass - 1) / 2];
        Svm.svm_predict_values(this.model, nodes, decisionValues);
        double[][] pairwiseProb = new double[nrClass][nrClass];
        int pair = 0;
        for (int a = 0; a < nrClass; ++a) {
            for (int b = a + 1; b < nrClass; ++b) {
                double p = Svm.sigmoid_predict(decisionValues[pair], this.model.probA[pair], this.model.probB[pair]);
                pairwiseProb[a][b] = Math.min(Math.max(p, MIN_PAIRWISE_PROB), 1.0 - MIN_PAIRWISE_PROB);
                pairwiseProb[b][a] = 1.0 - pairwiseProb[a][b];
                ++pair;
            }
        }
        double[] classProbs = new double[nrClass];
        Svm.multiclass_probability(nrClass, pairwiseProb, classProbs);
        return classProbs;
    }

    /**
     * Nominal prediction without probability estimates. For a two-value label, the
     * confidences are the logistic of the first decision value and of its negation.
     * Otherwise the predicted class gets confidence 1.0.
     */
    private void predictWithDecisionValues(Example e, svm_node[] nodes, Attribute predictedLabel, Attribute[] confidenceAttributes, Attribute label) {
        double predictedClass = Svm.svm_predict(this.model, nodes);
        e.setValue(predictedLabel, predictedClass);
        if (label.getMapping().size() != 2) {
            e.setConfidence(label.getMapping().mapIndex((int) predictedClass), 1.0);
            return;
        }
        double[] functionValues = new double[this.model.nr_class];
        Svm.svm_predict_values(this.model, nodes, functionValues);
        double prediction = functionValues[0];
        if (confidenceAttributes == null || confidenceAttributes.length <= 0) {
            return;
        }
        e.setValue(confidenceAttributes[0], 1.0 / (1.0 + Math.exp(-prediction)));
        if (confidenceAttributes.length > 1) {
            e.setValue(confidenceAttributes[1], 1.0 / (1.0 + Math.exp(prediction)));
        }
    }

    /**
     * Describes the model: the base description, the class count, and the support
     * vector count per class (nominal labels with {@code nSV}) or in total.
     */
    @Override
    public String toString() {
        String lineSeparator = Tools.getLineSeparator();
        StringBuilder result = new StringBuilder().append(super.toString()).append(lineSeparator);
        result.append("number of classes: ").append(this.model.nr_class).append(lineSeparator);
        if (this.getLabel().isNominal() && this.getLabel().getMapping().size() >= 2 && this.model.nSV != null) {
            for (int i = 0; i < this.model.nSV.length; ++i) {
                result.append("number of support vectors for class ")
                        .append(this.getLabel().getMapping().mapIndex(this.model.label[i]))
                        .append(": ")
                        .append(this.model.nSV[i])
                        .append(lineSeparator);
            }
        } else {
            result.append("number of support vectors: ").append(this.model.l).append(lineSeparator);
        }
        return result.toString();
    }

    /**
     * Renders the decision function as a sum of {@code y * alpha * K(sv, x)} terms plus
     * the bias. Precomputed and RBF kernels yield a fixed explanatory message instead.
     *
     * <p>NOTE(review): stock libsvm evaluates the decision function as
     * {@code sum - rho}, whereas this prints {@link #getBias()} ({@code rho[0]}) with
     * its own sign. Confirm against the bundled {@code libsvm.Svm} before relying on
     * the sign of the constant term.
     */
    public String getFormula() {
        int kernelType = this.model.param.kernel_type;
        if (kernelType == KERNEL_PRECOMPUTED) {
            return "Precomputed kernel, no formula possible.";
        }
        if (kernelType == KERNEL_RBF) {
            return "RBF kernel, no formula possible.";
        }
        StringBuilder result = new StringBuilder();
        boolean first = true;
        for (int i = 0; i < this.getNumberOfSupportVectors(); ++i) {
            SupportVector sv = this.getSupportVector(i);
            if (sv == null) {
                continue;
            }
            double alpha = sv.getAlpha();
            if (Tools.isZero(alpha)) {
                continue;
            }
            result.append(Tools.getLineSeparator());
            double[] x = sv.getX();
            double factor = sv.getY() * alpha;
            if (factor < 0.0) {
                result.append("- ").append(Math.abs(factor));
            } else if (first) {
                result.append("  ").append(factor);
            } else {
                result.append("+ ").append(factor);
            }
            result.append(" * (").append(this.getDistanceFormula(x, this.getAttributeConstructions())).append(")");
            first = false;
        }
        double bias = this.getBias();
        if (!Tools.isZero(bias)) {
            result.append(Tools.getLineSeparator());
            if (bias < 0.0) {
                result.append("- ").append(Math.abs(bias));
            } else if (first) {
                result.append(bias);
            } else {
                result.append("+ ").append(bias);
            }
        }
        return result.toString();
    }

    /**
     * Renders the kernel between support vector {@code x} and the symbolic input for
     * the linear, polynomial and sigmoid kernels. Any other kernel type yields "".
     */
    private String getDistanceFormula(double[] x, String[] attributeConstructions) {
        switch (this.model.param.kernel_type) {
            case KERNEL_LINEAR:
                return formatDotProduct(x, attributeConstructions);
            case KERNEL_POLY:
                return "pow((" + this.model.param.gamma + " * (" + formatDotProduct(x, attributeConstructions) + ") + "
                        + this.model.param.coef0 + "), " + this.model.param.degree + ")";
            case KERNEL_SIGMOID:
                return "tanh(" + this.model.param.gamma + " * (" + formatDotProduct(x, attributeConstructions) + ") + "
                        + this.model.param.coef0 + ")";
            default:
                return "";
        }
    }

    /**
     * Formats {@code sum_i x[i] * attributeConstructions[i]}. Zero coefficients are
     * skipped, and negative coefficients are written as a subtraction.
     */
    private static String formatDotProduct(double[] x, String[] attributeConstructions) {
        StringBuilder dot = new StringBuilder();
        boolean first = true;
        for (int i = 0; i < x.length; ++i) {
            double value = x[i];
            if (Tools.isZero(value)) {
                continue;
            }
            if (value < 0.0) {
                dot.append(first ? "-" : " - ").append(Math.abs(value));
            } else {
                dot.append(first ? "" : " + ").append(value);
            }
            dot.append(" * ").append(attributeConstructions[i]);
            first = false;
        }
        return dot.toString();
    }
}

