package com.rapidminer.gui.viewer;

import com.rapidminer.gui.plotter.ColorProvider;
import com.rapidminer.gui.plotter.PlotterAdapter;
import com.rapidminer.report.Renderable;
import com.rapidminer.tools.math.ROCData;
import java.awt.BasicStroke;
import java.awt.Color;
import java.awt.Graphics;
import java.awt.Graphics2D;
import java.awt.geom.Rectangle2D;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import javax.swing.JPanel;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.axis.ValueAxis;
import org.jfree.chart.block.BlockBorder;
import org.jfree.chart.plot.PlotOrientation;
import org.jfree.chart.plot.XYPlot;
import org.jfree.chart.renderer.xy.DeviationRenderer;
import org.jfree.chart.title.LegendTitle;
import org.jfree.data.xy.XYDataset;
import org.jfree.data.xy.YIntervalSeries;
import org.jfree.data.xy.YIntervalSeriesCollection;
import org.jfree.ui.HorizontalAlignment;
import org.jfree.ui.RectangleEdge;
import org.jfree.ui.RectangleInsets;

/* JADX WARN: Classes with same name are omitted:
  input_file:builds/deps.jar:com/rapidminer/gui/viewer/ROCChartPlotter.class
  input_file:builds/deps.jar:rapidMiner.jar:com/rapidminer/gui/viewer/ROCChartPlotter.class
  input_file:com/rapidminer/gui/viewer/ROCChartPlotter.class
 */
/* loaded from: input_file:rapidMiner.jar:com/rapidminer/gui/viewer/ROCChartPlotter.class */
/**
 * Swing panel that renders averaged ROC curves with JFreeChart.
 *
 * <p>Each named entry added via {@link #addROCData(String, List)} is drawn as one series whose
 * y value is the mean over all {@link ROCData} objects of that entry, surrounded by a band of one
 * standard deviation. When at most one named entry is present, the averaged threshold curve of
 * that entry is drawn as an additional series. The dataset and chart are rebuilt on every paint.
 */
public class ROCChartPlotter extends JPanel implements Renderable {
    private static final long serialVersionUID = -5819082000307077237L;

    /** Number of equally sized steps into which the x range [0, 1] is sampled. */
    private static final int NUMBER_OF_POINTS = 500;

    /** Dataset built by {@link #prepareData()}; rebuilt before every chart is drawn. */
    private YIntervalSeriesCollection dataset = null;

    /**
     * Named ROC data lists. Insertion order is preserved so that series order (and therefore the
     * series colors) follows the order in which the data was added.
     */
    private final Map<String, List<ROCData>> rocDataLists = new LinkedHashMap<String, List<ROCData>>();

    private final ColorProvider colorProvider = new ColorProvider();

    /** Creates an empty plotter with a white background. */
    public ROCChartPlotter() {
        setBackground(Color.white);
    }

    /**
     * Stores a single ROC curve under the given name, replacing any data previously stored under
     * that name.
     *
     * @param name    series name shown in the legend
     * @param rocData the ROC data to plot
     */
    public void addROCData(String name, ROCData rocData) {
        List<ROCData> singleton = new LinkedList<ROCData>();
        singleton.add(rocData);
        addROCData(name, singleton);
    }

    /**
     * Stores a list of ROC curves under the given name; they are averaged into a single series.
     * Replaces any data previously stored under that name.
     *
     * @param name        series name shown in the legend
     * @param rocDataList the ROC curves to average; empty or {@code null} lists are not plotted
     */
    public void addROCData(String name, List<ROCData> rocDataList) {
        this.rocDataLists.put(name, rocDataList);
    }

    /**
     * Creates the line chart for the given dataset: white background, light gray grid lines,
     * plotter fonts, legend at the top left, and one deviation band per series.
     */
    private JFreeChart createChart(XYDataset xyDataset) {
        // Arguments: no title/axis labels; legend and tooltips on, URLs off.
        JFreeChart chart = ChartFactory.createXYLineChart((String) null, (String) null, (String) null,
                xyDataset, PlotOrientation.VERTICAL, true, true, false);
        chart.setBackgroundPaint(Color.white);

        XYPlot plot = chart.getXYPlot();
        plot.setBackgroundPaint(Color.WHITE);
        plot.setAxisOffset(new RectangleInsets(5.0d, 5.0d, 5.0d, 5.0d));
        plot.setDomainGridlinePaint(Color.LIGHT_GRAY);
        plot.setRangeGridlinePaint(Color.LIGHT_GRAY);

        ValueAxis rangeAxis = plot.getRangeAxis();
        rangeAxis.setLabelFont(PlotterAdapter.LABEL_FONT_BOLD);
        rangeAxis.setTickLabelFont(PlotterAdapter.LABEL_FONT);
        ValueAxis domainAxis = plot.getDomainAxis();
        domainAxis.setLabelFont(PlotterAdapter.LABEL_FONT_BOLD);
        domainAxis.setTickLabelFont(PlotterAdapter.LABEL_FONT);

        // Lines visible, shapes hidden.
        DeviationRenderer renderer = new DeviationRenderer(true, false);
        BasicStroke stroke = new BasicStroke(2.0f, BasicStroke.CAP_ROUND, BasicStroke.JOIN_ROUND);
        int seriesCount = xyDataset.getSeriesCount();
        for (int i = 0; i < seriesCount; i++) {
            Color color = getSeriesColor(i, seriesCount);
            renderer.setSeriesStroke(i, stroke);
            renderer.setSeriesPaint(i, color);
            renderer.setSeriesFillPaint(i, color);
        }
        // Alpha used when filling the deviation bands.
        renderer.setAlpha(0.12f);
        plot.setRenderer(renderer);

        LegendTitle legend = chart.getLegend();
        if (legend != null) {
            legend.setPosition(RectangleEdge.TOP);
            legend.setFrame(BlockBorder.NONE);
            legend.setHorizontalAlignment(HorizontalAlignment.LEFT);
            legend.setItemFont(PlotterAdapter.LABEL_FONT);
        }
        return chart;
    }

    /**
     * Returns the color of series {@code index} out of {@code seriesCount}. One or two series get
     * red and blue. More series get colors from the color provider for values spread evenly
     * over 0..1.
     */
    private Color getSeriesColor(int index, int seriesCount) {
        if (seriesCount <= 2) {
            return index == 0 ? Color.RED : Color.BLUE;
        }
        // Floating-point division is required here. Integer division would pass 0 for every
        // series except the last one and thus paint them all in the same color.
        return this.colorProvider.getPointColor((double) index / (seriesCount - 1));
    }

    /**
     * Rebuilds {@link #dataset} from {@link #rocDataLists}. For every named entry it adds one
     * series with the averaged ROC curve, and it adds the averaged threshold series as well when
     * there is at most one named entry.
     */
    private void prepareData() {
        this.dataset = new YIntervalSeriesCollection();
        boolean showThresholds = this.rocDataLists.size() <= 1;
        for (Map.Entry<String, List<ROCData>> entry : this.rocDataLists.entrySet()) {
            List<ROCData> rocDataList = entry.getValue();
            if (rocDataList == null || rocDataList.isEmpty()) {
                // Averaging over zero curves would only produce NaN values.
                continue;
            }
            int curveCount = rocDataList.size();
            YIntervalSeries rocSeries = new YIntervalSeries(entry.getKey());
            YIntervalSeries thresholdSeries = new YIntervalSeries(String.valueOf(entry.getKey()) + " (Thresholds)");
            for (int i = 0; i <= NUMBER_OF_POINTS; i++) {
                double x = i / (double) NUMBER_OF_POINTS;
                double tpSum = 0.0d;
                double tpSumOfSquares = 0.0d;
                double thresholdSum = 0.0d;
                double thresholdSumOfSquares = 0.0d;
                for (ROCData rocData : rocDataList) {
                    // Interpolated true positives normalized by the total number of positives.
                    double truePositiveRate = rocData.getInterpolatedTruePositives(x) / rocData.getTotalPositives();
                    tpSum += truePositiveRate;
                    tpSumOfSquares += truePositiveRate * truePositiveRate;
                    double threshold = rocData.getInterpolatedThreshold(x);
                    thresholdSum += threshold;
                    thresholdSumOfSquares += threshold * threshold;
                }
                addMeanWithDeviation(rocSeries, x, tpSum, tpSumOfSquares, curveCount);
                addMeanWithDeviation(thresholdSeries, x, thresholdSum, thresholdSumOfSquares, curveCount);
            }
            this.dataset.addSeries(rocSeries);
            if (showThresholds) {
                this.dataset.addSeries(thresholdSeries);
            }
        }
    }

    /**
     * Adds a point at {@code x} to {@code series}. The point's y value is the mean of
     * {@code count} samples and its interval spans one standard deviation on either side.
     *
     * @param sum          sum of the samples
     * @param sumOfSquares sum of the squared samples
     * @param count        number of samples, must be positive
     */
    private static void addMeanWithDeviation(YIntervalSeries series, double x, double sum,
            double sumOfSquares, int count) {
        double mean = sum / count;
        // E[X^2] - E[X]^2 can come out slightly negative due to floating-point rounding;
        // clamp it so that Math.sqrt does not return NaN.
        double variance = Math.max(0.0d, (sumOfSquares / count) - (mean * mean));
        double deviation = Math.sqrt(variance);
        series.add(x, mean, mean - deviation, mean + deviation);
    }

    /** Paints the background and then the chart, sized to the whole component. */
    @Override
    public void paintComponent(Graphics graphics) {
        super.paintComponent(graphics);
        paintDeviationChart(graphics, getWidth(), getHeight());
    }

    /**
     * Rebuilds the dataset and chart and draws the chart into the rectangle
     * (0, 0, width, height).
     *
     * @param graphics target graphics; it is cast to {@link Graphics2D}
     * @param width    width of the drawing area
     * @param height   height of the drawing area
     */
    public void paintDeviationChart(Graphics graphics, int width, int height) {
        prepareData();
        // createChart already configures the background and the legend appearance.
        JFreeChart chart = createChart(this.dataset);
        chart.draw((Graphics2D) graphics, new Rectangle2D.Double(0.0d, 0.0d, width, height));
    }

    /** Does nothing; the chart is rebuilt on every draw. */
    @Override // com.rapidminer.report.Renderable
    public void prepareRendering() {
    }

    /** Does nothing; no rendering resources are held. */
    @Override // com.rapidminer.report.Renderable
    public void finishRendering() {
    }

    /** Returns the component height, or {@code preferredHeight} if the height is less than 1. */
    @Override // com.rapidminer.report.Renderable
    public int getRenderHeight(int preferredHeight) {
        int height = getHeight();
        return height < 1 ? preferredHeight : height;
    }

    /** Returns the component width, or {@code preferredWidth} if the width is less than 1. */
    @Override // com.rapidminer.report.Renderable
    public int getRenderWidth(int preferredWidth) {
        int width = getWidth();
        return width < 1 ? preferredWidth : width;
    }

    /** Resizes this component to the given size and draws the chart into {@code graphics}. */
    @Override // com.rapidminer.report.Renderable
    public void render(Graphics graphics, int width, int height) {
        setSize(width, height);
        paintDeviationChart(graphics, width, height);
    }
}
