package com.rapidminer.operator.learner.bayes;

import Jama.Matrix;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.SimplePredictionModel;
import com.rapidminer.tools.Tools;

/* loaded from: input_file:com/rapidminer/operator/learner/bayes/DiscriminantModel.class */
/**
 * Prediction model produced by (regularized) discriminant analysis.
 *
 * <p>For every class {@code k} the model holds the class mean {@code mu_k} (a 1 x d row vector),
 * the inverse of the class covariance {@code S_k^-1} (d x d) and the class prior {@code p_k}. An
 * example is turned into the row vector {@code x} of its numerical attribute values and assigned
 * to the class with the largest discriminant score
 *
 * <pre>
 *   g_k(x) = x S_k^-1 mu_k^T - 0.5 mu_k S_k^-1 mu_k^T + log(p_k)      (linear part)
 *            - 0.5 x S_k^-1 x^T + 0.5 log|S_k^-1|                      (quadratic part)
 * </pre>
 *
 * <p>{@code alpha == 0} is reported as the quadratic model and {@code alpha == 1} as the linear
 * model (see {@link #getName()}). In the linear case the covariance is assumed to be shared by all
 * classes, so the quadratic part is the same for every class. It is therefore skipped there.
 * NOTE(review): this relies on the LDA learner passing one shared inverse covariance whenever
 * alpha == 1 -- confirm against the learner.
 */
public class DiscriminantModel extends SimplePredictionModel {
    private static final long serialVersionUID = 3793343069512113817L;

    /** Regularization weight; 0 = quadratic, 1 = linear, anything else = regularized. */
    private final double alpha;

    /** Class labels; the predicted value is the index into this array. */
    private final String[] labels;

    /** Per-class mean row vectors (1 x d). */
    private final Matrix[] meanVectors;

    /** Per-class inverse covariance matrices (d x d). */
    private final Matrix[] inverseCovariances;

    /** Per-class prior probabilities. */
    private final double[] aprioriProbabilities;

    /** Per-class example-independent part of the linear score: -0.5 mu S^-1 mu^T + log(prior). */
    private final double[] constClassValues;

    /**
     * Lazily computed 0.5 * log|S_k^-1| per class.
     *
     * <p>The field is transient so the serialized form of the model stays unchanged and models
     * stored by older versions still load. It is volatile so the array contents are published
     * safely if several threads predict concurrently.
     */
    private transient volatile double[] logDeterminantTerms;

    /**
     * Creates the model and precomputes the example-independent part of each class score.
     *
     * @param exampleSet the training example set, handed to the superclass
     * @param labels the class labels, in the order used by all per-class arrays
     * @param meanVectors per-class mean row vectors
     * @param inverseCovariances per-class inverse covariance matrices
     * @param aprioriProbabilities per-class prior probabilities (a zero prior yields a score of
     *     negative infinity for that class)
     * @param alpha regularization weight (0 = quadratic, 1 = linear)
     */
    public DiscriminantModel(ExampleSet exampleSet, String[] labels, Matrix[] meanVectors, Matrix[] inverseCovariances, double[] aprioriProbabilities, double alpha) {
        super(exampleSet);
        this.alpha = alpha;
        this.labels = labels;
        this.meanVectors = meanVectors;
        this.inverseCovariances = inverseCovariances;
        this.aprioriProbabilities = aprioriProbabilities;
        this.constClassValues = new double[labels.length];
        for (int k = 0; k < labels.length; k++) {
            Matrix mean = meanVectors[k];
            double meanQuadraticForm = mean.times(inverseCovariances[k]).times(mean.transpose()).get(0, 0);
            this.constClassValues[k] = -0.5d * meanQuadraticForm + Math.log(aprioriProbabilities[k]);
        }
    }

    /**
     * Scores every class for the given example and returns the index of the best one.
     *
     * <p>If several classes share the maximal score, the one with the highest index wins. If every
     * score is NaN, index 0 is returned.
     *
     * @param example the example to classify; its numerical attributes form the feature vector
     * @return the index (into the label array) of the predicted class, as a double
     * @throws OperatorException declared by the superclass contract
     */
    @Override
    public double predict(Example example) throws OperatorException {
        Matrix x = new Matrix(extractNumericalValues(example), 1);
        // With a shared covariance (linear case) the quadratic part is identical for all classes
        // and cannot change the arg-max, so the original linear score is used unchanged.
        boolean sharedCovariance = this.alpha == 1.0d;
        double[] logDetTerms = sharedCovariance ? null : resolveLogDeterminantTerms();

        int bestClass = 0;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int k = 0; k < this.labels.length; k++) {
            Matrix xTimesInverse = x.times(this.inverseCovariances[k]);
            double score = xTimesInverse.times(this.meanVectors[k].transpose()).get(0, 0) + this.constClassValues[k];
            if (!sharedCovariance) {
                // Class-dependent quadratic terms: -0.5 x S_k^-1 x^T - 0.5 log|S_k|.
                score += -0.5d * xTimesInverse.times(x.transpose()).get(0, 0) + logDetTerms[k];
            }
            // ">=" keeps the original tie-breaking: the last class with the maximal score wins.
            if (score >= bestScore) {
                bestScore = score;
                bestClass = k;
            }
        }
        return bestClass;
    }

    /**
     * Collects the example's numerical attribute values, in attribute iteration order, into an
     * array sized to the model's dimensionality.
     *
     * <p>If the example has fewer numerical attributes than the model, the remaining entries stay
     * 0. NOTE(review): if it has more, an ArrayIndexOutOfBoundsException is thrown.
     */
    private double[] extractNumericalValues(Example example) {
        double[] values = new double[this.meanVectors[0].getColumnDimension()];
        int next = 0;
        for (Attribute attribute : example.getAttributes()) {
            if (attribute.isNumerical()) {
                values[next++] = example.getValue(attribute);
            }
        }
        return values;
    }

    /**
     * Returns 0.5 * log|S_k^-1| for every class, computing and caching it on first use.
     *
     * <p>The log-determinant is summed from the diagonal of the LU factor U (L has a unit
     * diagonal). This avoids the overflow and underflow that forming the determinant directly can
     * cause in higher dimensions. A benign race may compute the array twice, with identical
     * results.
     */
    private double[] resolveLogDeterminantTerms() {
        double[] terms = this.logDeterminantTerms;
        if (terms == null) {
            terms = new double[this.inverseCovariances.length];
            for (int k = 0; k < terms.length; k++) {
                Matrix upper = this.inverseCovariances[k].lu().getU();
                double logAbsDet = 0.0d;
                for (int j = 0; j < upper.getRowDimension(); j++) {
                    logAbsDet += Math.log(Math.abs(upper.get(j, j)));
                }
                terms[k] = 0.5d * logAbsDet;
            }
            this.logDeterminantTerms = terms;
        }
        return terms;
    }

    /**
     * Names the model after its regularization weight.
     *
     * @return "Quadratic Discriminant Model" for alpha 0, "Linear Discriminant Model" for alpha 1,
     *     otherwise "Regularized Discriminant Model"
     */
    @Override
    public String getName() {
        if (this.alpha == 0.0d) {
            return "Quadratic Discriminant Model";
        }
        if (this.alpha == 1.0d) {
            return "Linear Discriminant Model";
        }
        return "Regularized Discriminant Model";
    }

    /**
     * Renders the class priors, one "label TAB probability" line per class, with probabilities
     * formatted to 4 digits.
     */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder();
        text.append("Apriori probabilities:\n");
        for (int k = 0; k < this.labels.length; k++) {
            text.append(this.labels[k]).append('\t');
            text.append(Tools.formatNumber(this.aprioriProbabilities[k], 4)).append('\n');
        }
        return text.toString();
    }
}
