package com.rapidminer.tools.math;

import com.rapidminer.datatable.DataTable;
import com.rapidminer.datatable.SimpleDataTable;
import com.rapidminer.datatable.SimpleDataTableRow;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.AttributeTypeException;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.Statistics;
import com.rapidminer.gui.plotter.SimplePlotterDialog;
import com.rapidminer.gui.viewer.ROCChartPlotter;
import java.awt.Component;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Iterator;
import javax.swing.JDialog;
import marytts.signalproc.adaptation.codebook.WeightedCodebookMapperParams;

/* JADX WARN: Classes with same name are omitted:
  input_file:builds/deps.jar:com/rapidminer/tools/math/ROCDataGenerator.class
  input_file:builds/deps.jar:rapidMiner.jar:com/rapidminer/tools/math/ROCDataGenerator.class
  input_file:builds/deps.jar:tmp-src.zip:rapidMiner.jar:com/rapidminer/tools/math/ROCDataGenerator.class
  input_file:com/rapidminer/tools/math/ROCDataGenerator.class
  input_file:rapidMiner.jar:com/rapidminer/tools/math/ROCDataGenerator.class
  input_file:rapidMiner.jar:com/rapidminer/tools/math/ROCDataGenerator.class
 */
/* loaded from: input_file:tmp-src.zip:rapidMiner.jar:com/rapidminer/tools/math/ROCDataGenerator.class */
/**
 * Builds ROC (receiver operating characteristic) data for a binominal (or single-class) nominal
 * label and offers helpers to plot the resulting curve and to compute the area under it.
 *
 * <p>Instances are stateful: {@link #createROCData(ExampleSet, boolean)} stores the isometric slope
 * and the best threshold in fields that {@link #getBestThreshold()} and the plotting methods read
 * afterwards. There is no synchronization, so a single instance must not be shared across threads.
 */
public class ROCDataGenerator implements Serializable {
    private static final long serialVersionUID = -4473681331604071436L;

    /**
     * Target number of rows in the plot data table. The table is thinned with a rounded stride, so
     * the actual row count can somewhat exceed this value.
     */
    public static final int MAX_ROC_POINTS = 200;

    private static final String UNSUPPORTED_LABEL_MESSAGE =
            "Cannot calculate ROC data for non-classification labels or for labels with more than 2 classes.";

    private final double misclassificationCostsPositive;
    private final double misclassificationCostsNegative;

    /** Isometric slope computed by the last {@code createROCData} call; 1.0 before any call. */
    private double slope = 1.0d;

    /** Best threshold found by the last {@code createROCData} call; NaN before any call. */
    private double bestThreshold = Double.NaN;

    /**
     * Creates a generator using the given misclassification costs.
     *
     * @param misclassificationCostsPositive cost associated with the positive class; enters the
     *     isometric slope as {@code costPositive / costNegative}
     * @param misclassificationCostsNegative cost associated with the negative class
     */
    public ROCDataGenerator(double misclassificationCostsPositive, double misclassificationCostsNegative) {
        this.misclassificationCostsPositive = misclassificationCostsPositive;
        this.misclassificationCostsNegative = misclassificationCostsNegative;
    }

    /**
     * Returns the confidence threshold selected by the last call to
     * {@link #createROCData(ExampleSet, boolean)}.
     *
     * @return the confidence of the negative example at which the cost-weighted isometric value
     *     peaked. {@code Double.POSITIVE_INFINITY} if no point beat the initial value 0,
     *     {@code Double.NEGATIVE_INFINITY} if the final point (all examples) was best, and
     *     {@code Double.NaN} if no ROC data has been created yet.
     */
    public double getBestThreshold() {
        return this.bestThreshold;
    }

    /**
     * Computes the ROC curve of the predictions stored in {@code exampleSet}.
     *
     * <p>Examples are sorted by their confidence for the positive class and swept once.
     * <ul>
     *   <li>A new curve point is emitted whenever the confidence or the label changes between
     *       consecutive sorted examples, starting from the origin {@code (0, 0)} at confidence 1.0
     *       and ending with the cumulative totals.</li>
     *   <li>During the sweep the cost-weighted isometric value {@code TP - FP * slope} is tracked at
     *       every negative example, and the confidence where it peaks becomes the best threshold.</li>
     * </ul>
     * Side effects: recalculates the label statistics of {@code exampleSet} and overwrites this
     * generator's slope and best threshold.
     *
     * @param exampleSet examples carrying a label, a predicted label and confidences
     * @param useExampleWeights if {@code true} and the set has a weight attribute, each example
     *     contributes its weight instead of the default weight of {@code WeightedConfidenceAndLabel}
     * @return the ROC points together with total positives/negatives and the best isometric value
     *     normalized by the total positive weight (NaN when there are no positives)
     * @throws AttributeTypeException if the label is not nominal or has more than two classes
     */
    public ROCData createROCData(ExampleSet exampleSet, boolean useExampleWeights) {
        Attribute label = exampleSet.getAttributes().getLabel();
        exampleSet.recalculateAttributeStatistics(label);
        Attribute predictedLabel = exampleSet.getAttributes().getPredictedLabel();
        Attribute weightAttribute = useExampleWeights ? exampleSet.getAttributes().getWeight() : null;

        // Check the label type before touching getMapping(), so non-nominal labels are rejected with
        // the descriptive exception below rather than depending on how getMapping() behaves for them.
        if (!label.isNominal()) {
            throw new AttributeTypeException(UNSUPPORTED_LABEL_MESSAGE);
        }
        int positiveIndex = label.getMapping().getPositiveIndex();
        int mappingSize = label.getMapping().size();
        String positiveClassName;
        if (mappingSize == 2) {
            positiveClassName = label.getMapping().mapIndex(positiveIndex);
        } else if (mappingSize == 1) {
            positiveClassName = label.getMapping().mapIndex(0);
        } else {
            throw new AttributeTypeException(UNSUPPORTED_LABEL_MESSAGE);
        }

        WeightedConfidenceAndLabel[] entries =
                collectEntries(exampleSet, label, predictedLabel, weightAttribute, positiveClassName);
        // NOTE(review): relies on the natural ordering of WeightedConfidenceAndLabel; the curve starts
        // at confidence 1.0, so this presumably sorts by decreasing confidence -- confirm.
        Arrays.sort(entries);

        // Slope = (#positives / #negatives) / (costNegative / costPositive), using unweighted counts.
        double positiveCount = exampleSet.getStatistics(label, Statistics.COUNT, positiveClassName);
        double negativeCount = exampleSet.getStatistics(label, Statistics.COUNT,
                label.getMapping().mapIndex(label.getMapping().getNegativeIndex()));
        double classRatio = positiveCount / negativeCount;
        double costRatio = this.misclassificationCostsNegative / this.misclassificationCostsPositive;
        this.slope = classRatio / costRatio;

        double truePositives = 0.0d;
        double processedWeight = 0.0d;
        double bestIsometricValue = 0.0d;
        this.bestThreshold = Double.POSITIVE_INFINITY;
        double previousConfidence = 1.0d;
        double previousLabel = -1.0d; // sentinel: guarantees the origin is emitted on the first entry
        ROCData rocData = new ROCData();
        ROCPoint pendingPoint = new ROCPoint(0.0d, 0.0d, 1.0d);

        for (WeightedConfidenceAndLabel entry : entries) {
            double confidence = entry.getConfidence();
            // Runs of equal confidence AND equal label collapse into a single curve point.
            if (confidence != previousConfidence || previousLabel != entry.getLabel()) {
                rocData.addPoint(pendingPoint);
                previousConfidence = confidence;
                previousLabel = entry.getLabel();
            }
            double entryWeight = entry.getWeight();
            double falsePositivesSoFar = processedWeight - truePositives;
            if (entry.getLabel() == positiveIndex) {
                truePositives += entryWeight;
            } else {
                double isometricValue = truePositives - (falsePositivesSoFar * this.slope);
                if (isometricValue > bestIsometricValue) {
                    bestIsometricValue = isometricValue;
                    this.bestThreshold = entry.getConfidence();
                }
            }
            processedWeight += entryWeight;
            pendingPoint = new ROCPoint(processedWeight - truePositives, truePositives, confidence);
        }
        rocData.addPoint(pendingPoint);

        // The final point (everything classified positive) may still be the best operating point.
        double finalIsometricValue = truePositives - ((processedWeight - truePositives) * this.slope);
        if (finalIsometricValue > bestIsometricValue) {
            this.bestThreshold = Double.NEGATIVE_INFINITY;
            bestIsometricValue = finalIsometricValue;
        }
        rocData.setTotalPositives(truePositives);
        rocData.setTotalNegatives(processedWeight - truePositives);
        rocData.setBestIsometricsTPValue(bestIsometricValue / truePositives);
        return rocData;
    }

    /** Wraps every example into a (confidence, label, prediction[, weight]) entry, in set order. */
    private static WeightedConfidenceAndLabel[] collectEntries(ExampleSet exampleSet, Attribute label,
            Attribute predictedLabel, Attribute weightAttribute, String positiveClassName) {
        WeightedConfidenceAndLabel[] entries = new WeightedConfidenceAndLabel[exampleSet.size()];
        int next = 0;
        for (Example example : exampleSet) {
            double confidence = example.getConfidence(positiveClassName);
            double labelValue = example.getValue(label);
            double predictedValue = example.getValue(predictedLabel);
            entries[next++] = weightAttribute == null
                    ? new WeightedConfidenceAndLabel(confidence, labelValue, predictedValue)
                    : new WeightedConfidenceAndLabel(confidence, labelValue, predictedValue,
                            example.getValue(weightAttribute));
        }
        return entries;
    }

    /**
     * Builds the plot table with columns FP/N, TP/P, the isometric "Slope" line and "Threshold".
     * Rows are thinned to roughly {@link #MAX_ROC_POINTS}; the first and the last point are always
     * kept. Rates are non-finite if the ROC data has zero total positives or negatives.
     */
    private DataTable createDataTable(ROCData rocData) {
        SimpleDataTable table = new SimpleDataTable("ROC Plot", new String[] {"FP/N", "TP/P", "Slope", "Threshold"});
        double totalNegatives = rocData.getTotalNegatives();
        double totalPositives = rocData.getTotalPositives();
        int stride = Math.max(1, (int) Math.round(rocData.getNumberOfPoints() / (double) MAX_ROC_POINTS));
        Iterator<ROCPoint> points = rocData.iterator();
        int index = 0;
        while (points.hasNext()) {
            ROCPoint point = points.next();
            if (index % stride == 0 || !points.hasNext()) {
                double fpRate = point.getFalsePositives() / totalNegatives;
                double tpRate = point.getTruePositives() / totalPositives;
                // Isometric line through the best point, converted from count space to rate space.
                double isometric = rocData.getBestIsometricsTPValue()
                        + (fpRate * this.slope * (totalNegatives / totalPositives));
                table.add(new SimpleDataTableRow(new double[] {fpRate, tpRate, isometric, point.getConfidence()}));
            }
            index++;
        }
        return table;
    }

    /**
     * Shows the ROC curve in a {@link SimplePlotterDialog} over the unit square.
     *
     * @param rocData data created by {@link #createROCData(ExampleSet, boolean)}
     * @param showSlope whether to also plot the isometric ("Slope") column
     * @param showThresholds whether to also plot the "Threshold" column
     */
    public void createROCPlotDialog(ROCData rocData, boolean showSlope, boolean showThresholds) {
        SimplePlotterDialog dialog = new SimplePlotterDialog(createDataTable(rocData));
        dialog.setXAxis(0);
        dialog.plotColumn(1, true);
        if (showSlope) {
            dialog.plotColumn(2, true);
        }
        if (showThresholds) {
            dialog.plotColumn(3, true);
        }
        dialog.setDrawRange(0.0d, 1.0d, 0.0d, 1.0d);
        dialog.setPointType(1);
        dialog.setSize(500, 500);
        dialog.setLocationRelativeTo(dialog.getOwner());
        dialog.setVisible(true);
    }

    /**
     * Shows the ROC curve in a new {@link ROCChartPlotter} dialog. The dialog disposes itself when
     * closed; with the JDialog default (hide on close) every call would leak an undisposed window.
     *
     * @param rocData data created by {@link #createROCData(ExampleSet, boolean)}
     */
    public void createROCPlotDialog(ROCData rocData) {
        ROCChartPlotter plotter = new ROCChartPlotter();
        plotter.addROCData("ROC", rocData);
        JDialog dialog = new JDialog();
        dialog.setTitle("ROC Plot");
        dialog.setDefaultCloseOperation(JDialog.DISPOSE_ON_CLOSE);
        dialog.add(plotter);
        dialog.setSize(500, 500);
        dialog.setLocationRelativeTo(null);
        dialog.setVisible(true);
    }

    /**
     * Computes the area under the ROC curve as a step-wise sum. Each segment contributes the
     * previous point's TP rate times the FP-rate increase.
     *
     * @param rocData data created by {@link #createROCData(ExampleSet, boolean)}
     * @return the AUC; 0.5 when the curve has exactly two points. The result is non-finite if the
     *     totals used as denominators are zero.
     */
    public double calculateAUC(ROCData rocData) {
        if (rocData.getNumberOfPoints() == 2) {
            return 0.5d;
        }
        double totalNegatives = rocData.getTotalNegatives();
        double totalPositives = rocData.getTotalPositives();
        double area = 0.0d;
        boolean havePrevious = false;
        double previousFpRate = 0.0d;
        double previousTpRate = 0.0d;
        Iterator<ROCPoint> points = rocData.iterator();
        while (points.hasNext()) {
            ROCPoint point = points.next();
            double fpRate = point.getFalsePositives() / totalNegatives;
            double tpRate = point.getTruePositives() / totalPositives;
            if (havePrevious) {
                area += previousTpRate * (fpRate - previousFpRate);
            }
            previousFpRate = fpRate;
            previousTpRate = tpRate;
            havePrevious = true;
        }
        return area;
    }
}
