package com.rapidminer.operator.learner.functions.neuralnet;

import com.rapidminer.gui.tools.SwingTools;
import com.rapidminer.report.Renderable;
import com.rapidminer.tools.Tools;
import java.awt.Color;
import java.awt.Dimension;
import java.awt.Font;
import java.awt.Graphics;
import java.awt.Graphics2D;
import java.awt.Rectangle;
import java.awt.event.MouseEvent;
import java.awt.event.MouseListener;
import java.awt.geom.Ellipse2D;
import java.awt.geom.Rectangle2D;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import javax.swing.JPanel;

/* JADX WARN: Classes with same name are omitted:
  input_file:builds/deps.jar:com/rapidminer/operator/learner/functions/neuralnet/ImprovedNeuralNetVisualizer.class
  input_file:builds/deps.jar:rapidMiner.jar:com/rapidminer/operator/learner/functions/neuralnet/ImprovedNeuralNetVisualizer.class
  input_file:com/rapidminer/operator/learner/functions/neuralnet/ImprovedNeuralNetVisualizer.class
 */
/* loaded from: input_file:rapidMiner.jar:com/rapidminer/operator/learner/functions/neuralnet/ImprovedNeuralNetVisualizer.class */
/**
 * Swing panel that draws an {@link ImprovedNeuralNetModel} as columns of nodes (one column per
 * layer) connected by synapses whose grey level reflects the absolute weight relative to the
 * largest absolute weight in the net.
 *
 * <p>Every layer except the last one is drawn with one additional grey node at the bottom, the
 * threshold node. Releasing the mouse over a node selects it (outlined in red) and shows a small
 * yellow key box next to the pointer: the attribute name for input nodes, the incoming weights for
 * inner nodes, or "Threshold Node". Releasing the mouse over the selected node again, or anywhere
 * else, clears the selection.
 */
public class ImprovedNeuralNetVisualizer extends JPanel implements MouseListener, Renderable {
    private static final long serialVersionUID = -26826681541601736L;

    /** Vertical distance between the tops of two consecutive nodes of a layer. */
    private static final int ROW_HEIGHT = 36;

    /** Horizontal distance between the left edges of two consecutive layers. */
    private static final int LAYER_WIDTH = 150;

    /** Offset of the drawing from the top-left corner; also added to the preferred size twice. */
    private static final int MARGIN = 30;

    /** Width and height of the ellipse drawn for a node (the value is used as a diameter). */
    private static final int NODE_RADIUS = 24;

    /** Offset from a node's top-left corner to its centre. */
    private static final int NODE_CENTER = NODE_RADIUS / 2;

    /** Value of {@code InnerNode.getLayerIndex()} that this class treats as "output node". */
    private static final int OUTPUT_LAYER_MARKER = -2;

    /** Horizontal padding between the key text and the border of its box. */
    private static final int KEY_PADDING = 4;

    /** Extra vertical space inserted between two lines of key text. */
    private static final int KEY_LINE_GAP = 3;

    /** Extra height added below the last line of key text. */
    private static final int KEY_BOTTOM_PADDING = 6;

    /** Fill colour of threshold nodes. */
    private static final Color THRESHOLD_NODE_COLOR = new Color(233, 233, 233);

    private static final Font LABEL_FONT = new Font("SansSerif", Font.PLAIN, 11);

    private final ImprovedNeuralNetModel neuralNet;

    /** Largest absolute weight over all inner nodes; negative infinity if there are no weights. */
    private final double maxAbsoluteWeight;

    private final String[] attributeNames;

    private int selectedLayerIndex = -1;
    private int selectedRowIndex = -1;

    /** Text of the key box, or {@code null} when no key is shown. */
    private String key = null;
    private int keyX = -1;
    private int keyY = -1;

    /**
     * Nodes per displayed layer. Key 0 holds the input nodes, hidden nodes are stored under
     * {@code getLayerIndex() + 1} and the output nodes are stored under the highest key.
     */
    private final Map<Integer, List<Node>> layers = new LinkedHashMap<Integer, List<Node>>();

    /**
     * Creates the visualizer, groups the nodes of the model into layers and determines the
     * largest absolute weight used for shading the synapses.
     *
     * @param improvedNeuralNetModel the model to display
     * @param strArr the attribute names shown as key text for the input nodes, indexed by row
     */
    public ImprovedNeuralNetVisualizer(ImprovedNeuralNetModel improvedNeuralNetModel, String[] strArr) {
        this.neuralNet = improvedNeuralNetModel;
        this.attributeNames = strArr;
        addMouseListener(this);

        List<Node> inputLayer = new ArrayList<Node>();
        for (InputNode inputNode : this.neuralNet.getInputNodes()) {
            inputLayer.add(inputNode);
        }
        this.layers.put(Integer.valueOf(0), inputLayer);

        double largestWeight = Double.NEGATIVE_INFINITY;
        List<Node> outputLayer = new ArrayList<Node>();
        for (InnerNode innerNode : this.neuralNet.getInnerNodes()) {
            for (double weight : innerNode.getWeights()) {
                largestWeight = Math.max(largestWeight, Math.abs(weight));
            }
            int layerIndex = innerNode.getLayerIndex();
            if (layerIndex == OUTPUT_LAYER_MARKER) {
                // Output nodes are collected separately and appended after all hidden layers.
                outputLayer.add(innerNode);
                continue;
            }
            Integer layerKey = Integer.valueOf(layerIndex + 1);
            List<Node> hiddenLayer = this.layers.get(layerKey);
            if (hiddenLayer == null) {
                hiddenLayer = new ArrayList<Node>();
                this.layers.put(layerKey, hiddenLayer);
            }
            hiddenLayer.add(innerNode);
        }
        this.maxAbsoluteWeight = largestWeight;
        this.layers.put(Integer.valueOf(this.layers.size()), outputLayer);
    }

    /**
     * Returns the size needed to draw all layers: one {@code LAYER_WIDTH} per layer and one
     * {@code ROW_HEIGHT} per row of the tallest layer, plus the margin on every side.
     *
     * @return the preferred size of the panel
     */
    @Override
    public Dimension getPreferredSize() {
        int maxRows = -1;
        for (Map.Entry<Integer, List<Node>> entry : this.layers.entrySet()) {
            // Use the same row count as painting and hit testing so that the last layer does not
            // reserve space for a threshold node it never draws.
            maxRows = Math.max(maxRows, rowCount(entry.getKey().intValue(), entry.getValue().size()));
        }
        return new Dimension((this.layers.size() * LAYER_WIDTH) + (2 * MARGIN), (maxRows * ROW_HEIGHT) + (2 * MARGIN));
    }

    /**
     * Paints a white background, then the synapses, the nodes on top of them and finally the key
     * box if a key is set.
     *
     * @param graphics the graphics context to paint on; cast to {@link Graphics2D}
     */
    @Override
    public void paint(Graphics graphics) {
        graphics.clearRect(0, 0, getWidth(), getHeight());
        graphics.setColor(Color.WHITE);
        graphics.fillRect(0, 0, getWidth(), getHeight());

        int totalHeight = getPreferredSize().height;
        Graphics2D graphics2D = (Graphics2D) graphics;

        // Graphics.create() is declared to return Graphics, so the copy has to be cast.
        Graphics2D canvas = (Graphics2D) graphics2D.create();
        try {
            canvas.translate(MARGIN, MARGIN);
            canvas.setFont(LABEL_FONT);

            // Both passes translate their context layer by layer, so each gets its own copy.
            Graphics2D synapseGraphics = (Graphics2D) canvas.create();
            try {
                paintSynapses(synapseGraphics, totalHeight);
            } finally {
                synapseGraphics.dispose();
            }

            Graphics2D nodeGraphics = (Graphics2D) canvas.create();
            try {
                paintNodes(nodeGraphics, totalHeight);
            } finally {
                nodeGraphics.dispose();
            }
        } finally {
            canvas.dispose();
        }

        paintKey(graphics2D);
    }

    /** Draws the key box with its text at the stored key position; does nothing without a key. */
    private void paintKey(Graphics2D graphics2D) {
        if (this.key == null) {
            return;
        }
        String[] lines = Tools.transformAllLineSeparators(this.key).split("\n");
        if (lines.length == 0) {
            // A key consisting only of line breaks has nothing to show.
            return;
        }

        Rectangle2D[] lineBounds = new Rectangle2D[lines.length];
        double maxWidth = 0.0d;
        double textHeight = 0.0d;
        for (int line = 0; line < lines.length; line++) {
            lineBounds[line] = graphics2D.getFontMetrics().getStringBounds(lines[line], graphics2D);
            maxWidth = Math.max(maxWidth, lineBounds[line].getWidth());
            textHeight += lineBounds[line].getHeight();
        }

        int boxHeight = ((int) (textHeight + ((lines.length - 1) * KEY_LINE_GAP))) + KEY_BOTTOM_PADDING;
        Rectangle box = new Rectangle(this.keyX - KEY_PADDING, this.keyY, ((int) maxWidth) + (2 * KEY_PADDING), boxHeight);
        graphics2D.setColor(SwingTools.LIGHTEST_YELLOW);
        graphics2D.fill(box);
        graphics2D.setColor(SwingTools.DARK_YELLOW);
        graphics2D.draw(box);

        graphics2D.setColor(Color.BLACK);
        int baseline = this.keyY;
        for (int line = 0; line < lines.length; line++) {
            baseline += (int) lineBounds[line].getHeight();
            graphics2D.drawString(lines[line], this.keyX, baseline);
            baseline += KEY_LINE_GAP;
        }
    }

    /**
     * Draws, for every layer after the first, one line from each input node of each inner node to
     * that inner node, followed by one line from the row below the last input node (the threshold
     * row) using {@code weights[0]}.
     *
     * @param graphics2D context positioned at the first layer; it is translated by one layer width
     *        per processed layer
     * @param i total height used to centre the layers vertically
     */
    private void paintSynapses(Graphics2D graphics2D, int i) {
        int layerCount = this.layers.size();
        for (int layerIndex = 1; layerIndex < layerCount; layerIndex++) {
            List<Node> layer = this.layers.get(Integer.valueOf(layerIndex));
            int targetTop = layerTop(i, rowCount(layerIndex, layer.size()));
            for (Node node : layer) {
                if (node instanceof InnerNode) {
                    Node[] sources = node.getInputNodes();
                    double[] weights = ((InnerNode) node).getWeights();
                    // The source column is laid out with one extra row for its threshold node.
                    int sourceTop = layerTop(i, sources.length + 1);
                    for (int source = 0; source < sources.length; source++) {
                        drawSynapse(graphics2D, weights[source + 1], sourceTop, targetTop);
                        sourceTop += ROW_HEIGHT;
                    }
                    drawSynapse(graphics2D, weights[0], sourceTop, targetTop);
                }
                targetTop += ROW_HEIGHT;
            }
            graphics2D.translate(LAYER_WIDTH, 0);
        }
    }

    /** Draws one synapse between the centres of a node in the current column and one in the next. */
    private void drawSynapse(Graphics2D graphics2D, double weight, int sourceTop, int targetTop) {
        graphics2D.setColor(synapseColor(weight));
        graphics2D.drawLine(NODE_CENTER, sourceTop + NODE_CENTER, LAYER_WIDTH + NODE_CENTER, targetTop + NODE_CENTER);
    }

    /**
     * Maps a weight to a grey level: black for the largest absolute weight, white for zero.
     * A ratio that is undefined (for example when every weight is zero) is treated as zero
     * strength instead of being passed on to {@link Color} as NaN.
     */
    private Color synapseColor(double weight) {
        double ratio = Math.abs(weight) / this.maxAbsoluteWeight;
        if (Double.isNaN(ratio) || ratio < 0.0d) {
            ratio = 0.0d;
        } else if (ratio > 1.0d) {
            ratio = 1.0d;
        }
        float shade = 1.0f - ((float) ratio);
        return new Color(shade, shade, shade);
    }

    /**
     * Draws the caption and the nodes of every layer in the iteration order of the layer map.
     *
     * @param graphics2D context positioned at the first layer; it is translated by one layer width
     *        per processed layer
     * @param i total height used to centre the layers vertically
     */
    private void paintNodes(Graphics2D graphics2D, int i) {
        for (Map.Entry<Integer, List<Node>> entry : this.layers.entrySet()) {
            int layerIndex = entry.getKey().intValue();
            int rows = rowCount(layerIndex, entry.getValue().size());
            boolean inputLayer = layerIndex == 0;
            boolean outputLayer = layerIndex == this.layers.size() - 1;

            String caption;
            if (inputLayer) {
                caption = "Input";
            } else if (outputLayer) {
                caption = "Output";
            } else {
                caption = "Hidden " + layerIndex;
            }
            Rectangle2D captionBounds = LABEL_FONT.getStringBounds(caption, graphics2D.getFontRenderContext());
            graphics2D.setColor(Color.BLACK);
            // Centre the caption horizontally over the node column.
            graphics2D.drawString(caption, (int) (((-1.0d) * captionBounds.getWidth()) / 2.0d + NODE_CENTER), 0);

            int top = layerTop(i, rows);
            for (int row = 0; row < rows; row++) {
                Ellipse2D.Double shape = new Ellipse2D.Double(0.0d, top, NODE_RADIUS, NODE_RADIUS);
                boolean lastRow = row == rows - 1;
                if (lastRow && !outputLayer) {
                    // The bottom row of every layer but the last is the threshold node.
                    graphics2D.setPaint(THRESHOLD_NODE_COLOR);
                } else if (inputLayer || outputLayer) {
                    graphics2D.setPaint(SwingTools.makeYellowPaint((double) NODE_RADIUS, (double) NODE_RADIUS));
                } else {
                    graphics2D.setPaint(SwingTools.makeBluePaint((double) NODE_RADIUS, (double) NODE_RADIUS));
                }
                graphics2D.fill(shape);

                boolean selected = layerIndex == this.selectedLayerIndex && row == this.selectedRowIndex;
                graphics2D.setColor(selected ? Color.RED : Color.BLACK);
                graphics2D.draw(shape);
                top += ROW_HEIGHT;
            }
            graphics2D.translate(LAYER_WIDTH, 0);
        }
    }

    /**
     * Returns the number of rows drawn for a layer: its nodes plus one threshold row for every
     * layer in front of the last one.
     */
    private int rowCount(int layerIndex, int nodeCount) {
        return layerIndex < this.layers.size() - 1 ? nodeCount + 1 : nodeCount;
    }

    /** Returns the y coordinate of the first row of a layer that is centred in the given height. */
    private static int layerTop(int totalHeight, int rows) {
        return (totalHeight / 2) - ((rows * ROW_HEIGHT) / 2);
    }

    /** Stores the key text and its position and requests a repaint. */
    private void setKey(String str, int i, int i2) {
        this.key = str;
        this.keyX = i;
        this.keyY = i2;
        repaint();
    }

    /**
     * Stores the selection and derives the key text from it. Every path ends in
     * {@link #setKey(String, int, int)}, which also triggers the repaint.
     *
     * @param layerIndex selected layer, or a negative value to clear the selection
     * @param rowIndex selected row within the layer, or a negative value to clear the selection
     * @param x x position at which the key box is shown
     * @param y y position at which the key box is shown
     */
    private void setSelectedNode(int layerIndex, int rowIndex, int x, int y) {
        this.selectedLayerIndex = layerIndex;
        this.selectedRowIndex = rowIndex;
        if (layerIndex < 0 || rowIndex < 0) {
            setKey(null, -1, -1);
            return;
        }

        if (layerIndex == 0) {
            if (rowIndex < this.attributeNames.length) {
                setKey(this.attributeNames[rowIndex], x, y);
            } else if (rowIndex == this.attributeNames.length) {
                setKey("Threshold Node", x, y);
            } else {
                setKey(null, -1, -1);
            }
            return;
        }

        List<Node> layer = this.layers.get(Integer.valueOf(layerIndex));
        if (rowIndex >= layer.size()) {
            // The only row beyond the real nodes of a layer is its threshold node.
            setKey("Threshold Node", x, y);
            return;
        }

        StringBuilder text = new StringBuilder("Weights:").append(Tools.getLineSeparator());
        Node node = layer.get(rowIndex);
        if (node instanceof InnerNode) {
            double[] weights = ((InnerNode) node).getWeights();
            for (int index = 1; index < weights.length; index++) {
                text.append(Tools.formatNumber(weights[index])).append(Tools.getLineSeparator());
            }
            // weights[0] is listed last and labelled as the threshold weight.
            text.append(Tools.formatNumber(weights[0])).append(" (Threshold)");
        }
        setKey(text.toString(), x, y);
    }

    /**
     * Finds the node under the given panel coordinates and toggles its selection; clears the
     * selection if no node is hit.
     *
     * @param i x coordinate in panel space
     * @param i2 y coordinate in panel space
     */
    private void checkMousePos(int i, int i2) {
        int x = i - MARGIN;
        int y = i2 - MARGIN;
        int layerIndex = x / LAYER_WIDTH;
        int xInLayer = x % LAYER_WIDTH;

        boolean overNodeColumn = xInLayer > 0 && xInLayer < NODE_RADIUS
                && layerIndex >= 0 && layerIndex < this.layers.size();
        if (overNodeColumn) {
            int rows = rowCount(layerIndex, this.layers.get(Integer.valueOf(layerIndex)).size());
            int top = layerTop(getPreferredSize().height, rows);
            for (int row = 0; row < rows; row++) {
                if (y > top && y < top + NODE_RADIUS) {
                    if (this.selectedLayerIndex == layerIndex && this.selectedRowIndex == row) {
                        // Hitting the selected node again deselects it.
                        setSelectedNode(-1, -1, -1, -1);
                    } else {
                        setSelectedNode(layerIndex, row, i, i2);
                    }
                    return;
                }
                top += ROW_HEIGHT;
            }
        }
        setSelectedNode(-1, -1, -1, -1);
    }

    /** Not used; selection reacts to {@link #mouseReleased(MouseEvent)} only. */
    @Override
    public void mouseClicked(MouseEvent mouseEvent) {
    }

    /** Not used. */
    @Override
    public void mouseEntered(MouseEvent mouseEvent) {
    }

    /** Not used. */
    @Override
    public void mouseExited(MouseEvent mouseEvent) {
    }

    /** Not used. */
    @Override
    public void mousePressed(MouseEvent mouseEvent) {
    }

    /** Updates the selection for the position at which the mouse button was released. */
    @Override
    public void mouseReleased(MouseEvent mouseEvent) {
        checkMousePos(mouseEvent.getX(), mouseEvent.getY());
    }

    /** No preparation is required before rendering. */
    @Override // com.rapidminer.report.Renderable
    public void prepareRendering() {
    }

    /** No clean-up is required after rendering. */
    @Override // com.rapidminer.report.Renderable
    public void finishRendering() {
    }

    /**
     * Returns the height to render with.
     *
     * @param i the requested height
     * @return the larger of the preferred and the requested height, or the requested height if
     *         the preferred height is smaller than 1
     */
    @Override // com.rapidminer.report.Renderable
    public int getRenderHeight(int i) {
        return renderExtent(getPreferredSize().height, i);
    }

    /**
     * Returns the width to render with.
     *
     * @param i the requested width
     * @return the larger of the preferred and the requested width, or the requested width if the
     *         preferred width is smaller than 1
     */
    @Override // com.rapidminer.report.Renderable
    public int getRenderWidth(int i) {
        return renderExtent(getPreferredSize().width, i);
    }

    /** Chooses between a preferred and a requested extent as described for the render getters. */
    private static int renderExtent(int preferred, int requested) {
        if (preferred < 1) {
            return requested;
        }
        return Math.max(preferred, requested);
    }

    /**
     * Resizes the panel to the given dimensions and paints it onto the given context.
     *
     * @param graphics the graphics context to paint on
     * @param i the width to render with
     * @param i2 the height to render with
     */
    @Override // com.rapidminer.report.Renderable
    public void render(Graphics graphics, int i, int i2) {
        setSize(i, i2);
        paint(graphics);
    }
}
