package com.rapidminer.operator.learner.meta;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.gui.tools.ExtendedJTabbedPane;
import com.rapidminer.operator.IOContainer;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.tools.LogService;
import com.rapidminer.tools.Tools;
import java.awt.Component;
import java.util.Iterator;
import java.util.List;

/* JADX WARN: Classes with same name are omitted:
  input_file:builds/deps.jar:com/rapidminer/operator/learner/meta/BayBoostModel.class
  input_file:builds/deps.jar:rapidMiner.jar:com/rapidminer/operator/learner/meta/BayBoostModel.class
  input_file:com/rapidminer/operator/learner/meta/BayBoostModel.class
 */
/* loaded from: input_file:rapidMiner.jar:com/rapidminer/operator/learner/meta/BayBoostModel.class */
/**
 * Prediction model that combines a list of base models via per-class odds.
 *
 * <p>For every example the model starts from the prior odds {@code p / (1 - p)} of each class,
 * multiplies them with the lift ratios taken from each base model's {@link ContingencyMatrix}
 * (selected by that base model's prediction), and finally turns the resulting odds into
 * normalized confidences and a crisp prediction.
 *
 * <p>Two application parameters are understood by {@link #setParameter(String, String)}:
 * {@code "iteration"} (maximum number of base models to use) and {@code "crisp"} (decision
 * threshold on the positive-class confidence for two-class problems).
 */
public class BayBoostModel extends PredictionModel {
    private static final long serialVersionUID = 5821921049035718838L;

    /** Base models together with their contingency matrices, in application order. */
    private final List<BayBoostBaseModelInfo> modelInfo;

    /** Class priors, indexed by the label's nominal mapping index. */
    private final double[] priors;

    /** Upper bound on the number of base models used; negative means "use all". */
    private int maxModelNumber;

    private static final String MAX_MODEL_NUMBER = "iteration";
    private static final String CONV_TO_CRISP = "crisp";

    /**
     * Threshold on the positive-class confidence for two-class labels. The value 0.5 (and any
     * value outside [0, 1]) selects the class with the highest confidence instead.
     */
    private double threshold;

    /**
     * Creates the model.
     *
     * @param exampleSet training example set, passed on to {@link PredictionModel}
     * @param list       base model infos; the list is stored by reference, not copied
     * @param dArr       class priors; the array is stored by reference, not copied
     */
    public BayBoostModel(ExampleSet exampleSet, List<BayBoostBaseModelInfo> list, double[] dArr) {
        super(exampleSet);
        this.maxModelNumber = -1;
        this.threshold = 0.5d;
        this.modelInfo = list;
        this.priors = dArr;
    }

    /** Returns the base model info stored at position {@code i}. */
    public BayBoostBaseModelInfo getBayBoostBaseModelInfo(int i) {
        return this.modelInfo.get(i);
    }

    /**
     * Sets an application parameter.
     *
     * <p>{@code "iteration"} is parsed as an int and stored as the maximum model number;
     * {@code "crisp"} is trimmed, parsed as a double and stored as the decision threshold.
     * Names are matched case-insensitively. Any other name, and an {@code "iteration"} value
     * that is not a valid integer, is handed to the superclass implementation.
     *
     * @param str  parameter name
     * @param str2 parameter value
     * @throws OperatorException     if the superclass rejects the parameter
     * @throws NumberFormatException if the {@code "crisp"} value is not a valid double
     */
    public void setParameter(String str, String str2) throws OperatorException {
        if (str.equalsIgnoreCase(MAX_MODEL_NUMBER)) {
            try {
                this.maxModelNumber = Integer.parseInt(str2);
                return;
            } catch (NumberFormatException ignored) {
                // Intentionally not handled here: an unparsable value falls through to the
                // superclass call below, exactly like an unknown parameter name.
            }
        } else if (str.equalsIgnoreCase(CONV_TO_CRISP)) {
            this.threshold = Double.parseDouble(str2.trim());
            return;
        }
        super.setParameter(str, (Object) str2);
    }

    /** Sets the maximum number of base models to use; a negative value means "use all". */
    public void setMaxModelNumber(int i) {
        this.maxModelNumber = i;
    }

    /** Builds a tabbed pane with one tab ("Model 1", "Model 2", ...) per active base model. */
    @Override // com.rapidminer.operator.ResultObjectAdapter, com.rapidminer.operator.ResultObject
    public Component getVisualizationComponent(IOContainer iOContainer) {
        ExtendedJTabbedPane tabs = new ExtendedJTabbedPane();
        final int modelCount = getNumberOfModels();
        for (int i = 0; i < modelCount; i++) {
            tabs.add("Model " + (i + 1), getModel(i).getVisualizationComponent(iOContainer));
        }
        return tabs;
    }

    /**
     * Returns the superclass description followed by the number of active base models and the
     * result string of each of them.
     */
    @Override // com.rapidminer.operator.learner.PredictionModel, com.rapidminer.report.Readable
    public String toString() {
        StringBuilder text = new StringBuilder();
        text.append(super.toString())
            .append(Tools.getLineSeparator())
            .append("Number of inner models: ")
            .append(getNumberOfModels())
            .append(Tools.getLineSeparators(2));
        for (int i = 0; i < getNumberOfModels(); i++) {
            if (i > 0) {
                text.append(Tools.getLineSeparator());
            }
            text.append("Embedded model #")
                .append(i)
                .append(Example.SPARSE_SEPARATOR)
                .append(Tools.getLineSeparator())
                .append(getModel(i).toResultString());
        }
        return text.toString();
    }

    /**
     * Returns the number of base models that are actually used: all of them when the maximum
     * model number is negative, otherwise the smaller of that maximum and the stored count.
     */
    public int getNumberOfModels() {
        final int available = this.modelInfo.size();
        if (this.maxModelNumber < 0) {
            return available;
        }
        return Math.min(this.maxModelNumber, available);
    }

    /** Returns the lift ratios of base model {@code i} for the predicted class {@code i2}. */
    private double[] getFactorsForModel(int i, int i2) {
        return this.modelInfo.get(i).getContingencyMatrix().getLiftRatiosForPrediction(i2);
    }

    /** Returns the prior of the class with mapping index {@code i}. */
    private double getPriorOfClass(int i) {
        return this.priors[i];
    }

    /** Returns a copy of the class priors, so callers cannot modify the model's array. */
    public double[] getPriors() {
        return this.priors.clone();
    }

    /** Returns the base model at position {@code i}. */
    public Model getModel(int i) {
        return this.modelInfo.get(i).getModel();
    }

    /** Returns the contingency matrix of the base model at position {@code i}. */
    public ContingencyMatrix getContingencyMatrix(int i) {
        return this.modelInfo.get(i).getContingencyMatrix();
    }

    /**
     * Applies all active base models to {@code exampleSet} and writes the combined prediction
     * and confidences into it.
     *
     * <p>Temporary special attributes holding the per-class odds are added to the example set
     * for the duration of the call. They are removed again in a {@code finally} block so that
     * they do not stay behind on the caller's example set when a base model fails.
     *
     * @param exampleSet the examples to predict; modified in place and returned
     * @param attribute  not used by this implementation; the predicted label is looked up on
     *                   each example instead
     * @return the given {@code exampleSet}
     * @throws OperatorException if a base model or an attribute operation fails
     */
    @Override // com.rapidminer.operator.learner.PredictionModel
    public ExampleSet performPrediction(ExampleSet exampleSet, Attribute attribute) throws OperatorException {
        Attribute[] oddsAttributes = createSpecialAttributes(exampleSet);
        try {
            initIntermediateResultAttributes(exampleSet, oddsAttributes);
            for (int i = 0; i < getNumberOfModels(); i++) {
                // Each base model works on a clone; only the odds attributes are updated
                // through the examples of that clone.
                ExampleSet predicted = getModel(i).apply((ExampleSet) exampleSet.clone());
                updateEstimates(predicted, getContingencyMatrix(i), oddsAttributes);
                PredictionModel.removePredictedLabel(predicted);
            }
            Attribute trainingLabel = getTrainingHeader().getAttributes().getLabel();
            for (Example example : exampleSet) {
                translateOddsIntoPredictions(example, oddsAttributes, trainingLabel);
            }
        } finally {
            cleanUpSpecialAttributes(exampleSet, oddsAttributes);
        }
        return exampleSet;
    }

    /**
     * Creates one special attribute named {@code "BayBoostModelPrediction<i>"} per label class
     * on the given example set.
     */
    private Attribute[] createSpecialAttributes(ExampleSet exampleSet) throws OperatorException {
        Attribute[] created = new Attribute[getLabel().getMapping().size()];
        for (int i = 0; i < created.length; i++) {
            // NOTE(review): value type constant 2 is passed through unchanged; presumably a
            // numerical ontology type - confirm against com.rapidminer.tools.Ontology.
            created[i] = com.rapidminer.example.Tools.createSpecialAttribute(exampleSet, "BayBoostModelPrediction" + i, 2);
        }
        return created;
    }

    /** Removes the given attributes from both the example set and its example table. */
    private void cleanUpSpecialAttributes(ExampleSet exampleSet, Attribute[] attributeArr) throws OperatorException {
        for (Attribute temporary : attributeArr) {
            exampleSet.getAttributes().remove(temporary);
            exampleSet.getExampleTable().removeAttribute(temporary);
        }
    }

    /**
     * Initializes the odds attributes of every example with the prior odds
     * {@code p / (1 - p)} of the corresponding class; a prior of exactly 1 yields positive
     * infinity.
     */
    private void initIntermediateResultAttributes(ExampleSet exampleSet, Attribute[] attributeArr) {
        double[] priorOdds = new double[this.priors.length];
        for (int i = 0; i < priorOdds.length; i++) {
            double prior = this.priors[i];
            priorOdds[i] = prior == 1.0d ? Double.POSITIVE_INFINITY : prior / (1.0d - prior);
        }
        for (Example example : exampleSet) {
            for (int i = 0; i < attributeArr.length; i++) {
                example.setValue(attributeArr[i], priorOdds[i]);
            }
        }
    }

    /**
     * Converts the per-class odds stored on {@code example} into confidences and a crisp
     * prediction.
     *
     * <p>Odds {@code o} become {@code o / (1 + o)}; infinite and NaN odds become 1. The values
     * are then divided by their sum. For a nominal two-class label with a threshold in [0, 1]
     * other than 0.5, the positive class is predicted when its confidence reaches the
     * threshold, otherwise the negative class. In every other case (including thresholds that
     * are NaN or outside [0, 1]) the class with the highest confidence is predicted; ties go
     * to the lowest index.
     *
     * <p>The threshold field is only read here. It is deliberately not rewritten while
     * predicting, so every example of a run is decided by the same rule.
     *
     * @param example      the example to update
     * @param attributeArr odds attributes, indexed by class
     * @param attribute    attribute whose mapping translates the predicted class name into
     *                     the value stored in the predicted label
     */
    private void translateOddsIntoPredictions(Example example, Attribute[] attributeArr, Attribute attribute) {
        double[] confidences = new double[attributeArr.length];
        double sum = 0.0d;
        int bestIndex = 0;
        for (int i = 0; i < confidences.length; i++) {
            double odds = example.getValue(attributeArr[i]);
            if (Double.isNaN(odds)) {
                logWarning("Found NaN odd ratio estimate.");
                confidences[i] = 1.0d;
            } else if (Double.isInfinite(odds)) {
                confidences[i] = 1.0d;
            } else {
                confidences[i] = odds / (1.0d + odds);
            }
            sum += confidences[i];
            if (confidences[i] > confidences[bestIndex]) {
                bestIndex = i;
            }
        }
        if (sum != 1.0d) {
            for (int i = 0; i < confidences.length; i++) {
                confidences[i] = confidences[i] / sum;
            }
        }

        Attribute label = getLabel();
        // Written as a positive range check so that a NaN threshold also falls back to 0.5.
        final double effectiveThreshold =
                (this.threshold >= 0.0d && this.threshold <= 1.0d) ? this.threshold : 0.5d;
        String predictedClass;
        if (label.isNominal() && label.getMapping().size() == 2 && effectiveThreshold != 0.5d) {
            int positiveIndex = label.getMapping().getPositiveIndex();
            int negativeIndex = label.getMapping().getNegativeIndex();
            int chosen = confidences[positiveIndex] >= effectiveThreshold ? positiveIndex : negativeIndex;
            predictedClass = label.getMapping().mapIndex(chosen);
        } else {
            predictedClass = label.getMapping().mapIndex(bestIndex);
        }
        example.setValue(example.getAttributes().getPredictedLabel(), attribute.getMapping().mapString(predictedClass));

        for (int i = 0; i < confidences.length; i++) {
            double confidence = confidences[i];
            if (Double.isNaN(confidence) || confidence < 0.0d || confidence > 1.0d) {
                logWarning("Found illegal confidence value: " + confidence);
            }
            example.setConfidence(label.getMapping().mapIndex(i), confidence);
        }
    }

    /**
     * Multiplies the odds stored on each example with the lift ratios that the contingency
     * matrix assigns to the example's predicted label.
     *
     * <p>A NaN lift ratio is skipped with a warning. A finite lift ratio is multiplied into
     * the stored odds unless those are already infinite. An infinite lift ratio, when the
     * stored odds of that class are non-zero, sets all odds of the example to zero and the
     * odds of that class to the infinite ratio.
     */
    private void updateEstimates(ExampleSet exampleSet, ContingencyMatrix contingencyMatrix, Attribute[] attributeArr) {
        for (Example example : exampleSet) {
            int predictedLabel = (int) example.getPredictedLabel();
            for (int classIndex = 0; classIndex < contingencyMatrix.getNumberOfClasses(); classIndex++) {
                double liftRatio = contingencyMatrix.getLiftRatio(classIndex, predictedLabel);
                if (Double.isNaN(liftRatio)) {
                    logWarning("Ignoring non-applicable model.");
                } else if (!Double.isInfinite(liftRatio)) {
                    double currentOdds = example.getValue(attributeArr[classIndex]);
                    if (Double.isNaN(currentOdds)) {
                        logWarning("Found NaN value in intermediate odds ratio estimates!");
                    }
                    if (!Double.isInfinite(currentOdds)) {
                        example.setValue(attributeArr[classIndex], currentOdds * liftRatio);
                    }
                } else if (example.getValue(attributeArr[classIndex]) != 0.0d) {
                    for (Attribute oddsAttribute : attributeArr) {
                        example.setValue(oddsAttribute, 0.0d);
                    }
                    example.setValue(attributeArr[classIndex], liftRatio);
                }
            }
        }
    }

    /**
     * Array-based counterpart of the per-example odds update: multiplies {@code dArr} in place
     * with the lift ratios in {@code dArr2}.
     *
     * <p>NaN ratios are skipped with a log message. On the first infinite ratio whose current
     * product is non-zero, all products are set to zero, that entry is set to the infinite
     * ratio, and the method returns immediately without looking at the remaining ratios.
     *
     * @param dArr  intermediate products, modified in place
     * @param dArr2 lift ratios, one per entry of {@code dArr}
     * @return {@code true} if an infinite ratio replaced the products, {@code false} otherwise
     */
    public static boolean adjustIntermediateProducts(double[] dArr, double[] dArr2) {
        for (int i = 0; i < dArr2.length; i++) {
            double ratio = dArr2[i];
            if (Double.isNaN(ratio)) {
                // NOTE(review): log level 5 is passed through unchanged - confirm its meaning
                // against LogService.
                LogService.getGlobal().log("Ignoring non-applicable model.", 5);
            } else if (!Double.isInfinite(ratio)) {
                dArr[i] = dArr[i] * ratio;
                if (Double.isNaN(dArr[i])) {
                    LogService.getGlobal().log("Found NaN value in intermediate odds ratio estimates!", 5);
                }
            } else if (dArr[i] != 0.0d) {
                for (int j = 0; j < dArr.length; j++) {
                    dArr[j] = 0.0d;
                }
                dArr[i] = ratio;
                return true;
            }
        }
        return false;
    }

    /**
     * Computes one weight per active base model plus a bias term for two-class problems.
     *
     * <p>Element 0 starts as the log of the ratio of positive to negative prior. For each base
     * model the logs of its positive-class lift ratios (for a positive and for a negative
     * prediction) are clamped to [-10, 10]; their mean is added to element 0 and the
     * difference between the positive-prediction value and that mean becomes the model's
     * weight. If the mean itself equals 10 or -10, the model weight is set to ten times the
     * mean and nothing is added to element 0.
     *
     * @return array of length {@code getNumberOfModels() + 1}; index 0 is the bias term
     * @throws OperatorException if the label does not have exactly two classes
     */
    public double[] getModelWeights() throws OperatorException {
        if (getLabel().getMapping().size() != 2) {
            throw new UserError((Operator) null, 114, "BayBoostModel", getLabel());
        }
        int positiveIndex = getLabel().getMapping().getPositiveIndex();
        int negativeIndex = getLabel().getMapping().getNegativeIndex();
        double[] weights = new double[getNumberOfModels() + 1];
        weights[0] = Math.log(getPriorOfClass(positiveIndex) / getPriorOfClass(negativeIndex));
        for (int i = 1; i < weights.length; i++) {
            double logPosRatio = clampedLog(getFactorsForModel(i - 1, positiveIndex)[positiveIndex]);
            double logNegRatio = clampedLog(getFactorsForModel(i - 1, negativeIndex)[positiveIndex]);
            double mean = (logPosRatio + logNegRatio) / 2.0d;
            if (Tools.isEqual(mean, 10) || Tools.isEqual(mean, -10)) {
                logPosRatio = 10.0d * mean;
                mean = 0.0d;
            }
            weights[0] = weights[0] + mean;
            weights[i] = logPosRatio - mean;
        }
        return weights;
    }

    /** Returns the natural log of {@code value}, clamped to the range [-10, 10]. */
    private static double clampedLog(double value) {
        return Math.min(10.0d, Math.max(-10.0d, Math.log(value)));
    }
}
