package com.rapidminer.operator.postprocessing;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.IOObject;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.learner.LearnerCapability;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.tools.LogService;
import java.util.Iterator;

/* loaded from: input_file:WEB-INF/lib/rapidMiner-1.0.0.jar:com/rapidminer/operator/postprocessing/PlattScaling.class */
/**
 * Operator that calibrates the confidences of a two-class model with Platt's scaling.
 *
 * <p>The operator takes an {@link ExampleSet} and a {@link Model} as input. It applies the model to
 * a clone of the example set and fits the two sigmoid parameters (A, B) of Platt's method to the
 * model's positive-class confidences. It then delivers a {@code PlattScalingModel} built from the
 * input example set, the original model and the fitted {@code PlattParameters}.
 *
 * <p>The fitting routine follows the pseudocode of J. Platt, "Probabilistic Outputs for Support Vector
 * Machines and Comparisons to Regularized Likelihood Methods" (1999). It is a damped Newton iteration
 * (a diagonal term {@code lambda} is added to the Hessian and adapted per step) on the weighted
 * cross-entropy between the sigmoid output and smoothed class targets.
 */
public class PlattScaling extends Operator {

    /** Maximum number of outer Newton iterations. */
    private static final int MAX_ITERATIONS = 100;

    /** Lower clamp for confidences before taking log-odds, to avoid log(0). */
    private static final double MIN_CONFIDENCE = 1.0E-30d;

    /**
     * Upper clamp for confidences before taking log-odds. Writing {@code 1.0 - 1.0E-30} here would be
     * useless, because it rounds to exactly {@code 1.0} in double precision and a confidence of 1.0
     * would then produce {@code -Infinity}.
     */
    private static final double MAX_CONFIDENCE = 1.0d - 1.0E-15d;

    /** Value substituted for log(0) in the error function, as recommended in Platt's pseudocode. */
    private static final double LOG_FLOOR = -200.0d;

    public PlattScaling(OperatorDescription operatorDescription) {
        super(operatorDescription);
    }

    /**
     * Reports whether this operator can handle the given learner capability.
     *
     * @param learnerCapability the capability to check
     * @return {@code false} for numerical and polynominal class labels, {@code true} otherwise
     */
    public boolean supportsCapability(LearnerCapability learnerCapability) {
        return learnerCapability != LearnerCapability.NUMERICAL_CLASS
                && learnerCapability != LearnerCapability.POLYNOMINAL_CLASS;
    }

    /** Declares the expected inputs: an example set and a model. */
    @Override
    public Class<?>[] getInputClasses() {
        return new Class[] {ExampleSet.class, Model.class};
    }

    /** Declares the single output: the calibrated model. */
    @Override
    public Class<?>[] getOutputClasses() {
        return new Class[] {Model.class};
    }

    /**
     * Fits Platt's scaling for the input model on the input example set.
     *
     * @return a one-element array containing the {@code PlattScalingModel}
     * @throws OperatorException if an input is missing; a {@link UserError} (code 105) if the example
     *     set has no label, or (code 106) if it has no attributes
     */
    @Override
    public IOObject[] apply() throws OperatorException {
        ExampleSet exampleSet = (ExampleSet) getInput(ExampleSet.class);
        Model model = (Model) getInput(Model.class);
        if (exampleSet.getAttributes().getLabel() == null) {
            throw new UserError(this, 105, new Object[0]);
        }
        if (exampleSet.getAttributes().size() == 0) {
            throw new UserError(this, 106, new Object[0]);
        }
        Attribute label = extractLabel(model, exampleSet);
        // The model is applied to a clone, not to the input set itself; the predicted label is
        // removed from the clone again once the parameters have been fitted.
        ExampleSet predicted = model.apply((ExampleSet) exampleSet.clone());
        PlattParameters parameters = computeParameters(predicted, label);
        PredictionModel.removePredictedLabel(predicted);
        return new IOObject[] {new PlattScalingModel(exampleSet, model, parameters)};
    }

    /**
     * Returns the label the model was trained on if the model is a {@link PredictionModel}.
     * Otherwise it logs a warning and falls back to the label of the given example set.
     */
    private Attribute extractLabel(Model model, ExampleSet exampleSet) {
        if (model instanceof PredictionModel) {
            return ((PredictionModel) model).getLabel();
        }
        logWarning("Could not find label in model for Platt's Scaling, using Label of provided ExampleSet instead.");
        return exampleSet.getAttributes().getLabel();
    }

    /**
     * Fits Platt's sigmoid parameters to the positive-class confidences stored in an example set.
     *
     * <p>Class priors are accumulated into a two-slot array indexed by the internal label value. The
     * label of {@code exampleSet} is therefore expected to be two-class with internal values 0 and 1.
     *
     * @param exampleSet an example set on which the model has already been applied, so that
     *     {@code getConfidence(positiveClassName)} is available on every example; its size must
     *     match the number of examples it iterates over
     * @param attribute the label attribute whose mapping supplies the positive and negative class names
     * @return the fitted parameters; the identity parameters (A = 1, B = 0) if the fit produced NaN
     */
    public static PlattParameters computeParameters(ExampleSet exampleSet, Attribute attribute) {
        String positiveString = attribute.getMapping().getPositiveString();
        Attribute dataLabel = exampleSet.getAttributes().getLabel();
        int positiveIndex = dataLabel.getMapping().mapString(positiveString);
        int negativeIndex = dataLabel.getMapping().mapString(attribute.getMapping().getNegativeString());
        boolean weighted = exampleSet.getAttributes().getWeight() != null;

        // Extract everything the optimisation needs in a single pass, so the numeric loops below
        // work on plain arrays and never depend on re-iterating the example set.
        int size = exampleSet.size();
        double[] logOdds = new double[size];
        boolean[] positive = new boolean[size];
        double[] weights = new double[size];
        double[] priors = new double[2];
        int index = 0;
        for (Example example : exampleSet) {
            double weight = weighted ? example.getWeight() : 1.0d;
            priors[(int) example.getLabel()] += weight;
            logOdds[index] = getLogOddsPosConfidence(example.getConfidence(positiveString));
            positive[index] = example.getLabel() == positiveIndex;
            weights[index] = weight;
            index++;
        }

        double[] fitted = fitSigmoid(logOdds, positive, weights, priors[positiveIndex], priors[negativeIndex]);
        double a = fitted[0];
        double b = fitted[1];
        if (Double.isNaN(a) || Double.isNaN(b)) {
            a = 1.0d;
            b = 0.0d;
            exampleSet.getLog().logWarning("Discarding invalid result of Platt's scaling, using identity instead.");
        }
        return new PlattParameters(a, b);
    }

    /**
     * Runs Platt's damped Newton iteration for the sigmoid p = 1 / (1 + exp(A * f + B)).
     *
     * @param logOdds the decision value f of each example
     * @param positive whether each example belongs to the positive class
     * @param weights the weight of each example
     * @param positivePrior total weight of the positive examples
     * @param negativePrior total weight of the negative examples
     * @return a two-element array {@code {A, B}}
     */
    private static double[] fitSigmoid(double[] logOdds, boolean[] positive, double[] weights,
            double positivePrior, double negativePrior) {
        int size = logOdds.length;

        // Smoothed targets instead of hard 0/1 labels, to reduce overfitting (Platt, 1999).
        double hiTarget = (positivePrior + 1.0d) / (positivePrior + 2.0d);
        double loTarget = 1.0d / (negativePrior + 2.0d);
        double initialProb = (positivePrior + 1.0d) / (negativePrior + positivePrior + 2.0d);
        double[] target = new double[size];
        double[] prob = new double[size];
        for (int i = 0; i < size; i++) {
            target[i] = positive[i] ? hiTarget : loTarget;
            prob[i] = initialProb;
        }

        double a = 0.0d;
        double b = Math.log((negativePrior + 1.0d) / (positivePrior + 1.0d));
        double lambda = 1.0E-3d;
        double oldErr = 1.0E300d;
        int stallCount = 0;

        for (int iteration = 1; iteration <= MAX_ITERATIONS; iteration++) {
            // h11, h22, h21: Hessian of the error w.r.t. (A, B).
            // g1, g2: negated gradient of the error w.r.t. (A, B).
            double h11 = 0.0d;
            double h22 = 0.0d;
            double h21 = 0.0d;
            double g1 = 0.0d;
            double g2 = 0.0d;
            for (int i = 0; i < size; i++) {
                double d1 = weights[i] * (prob[i] - target[i]);
                double d2 = weights[i] * prob[i] * (1.0d - prob[i]);
                h11 += logOdds[i] * logOdds[i] * d2;
                h22 += d2;
                h21 += logOdds[i] * d2;
                g1 += logOdds[i] * d1;
                g2 += d1;
            }
            if (Math.abs(g1) < 1.0E-9d && Math.abs(g2) < 1.0E-9d) {
                break; // gradient has vanished
            }

            double oldA = a;
            double oldB = b;
            double err = 0.0d;
            while (true) {
                double det = (h11 + lambda) * (h22 + lambda) - h21 * h21;
                if (det == 0.0d) {
                    lambda *= 10.0d;
                    continue;
                }
                a = oldA + ((h22 + lambda) * g1 - h21 * g2) / det;
                b = oldB + ((h11 + lambda) * g2 - h21 * g1) / det;

                // Evaluate the weighted cross-entropy at the candidate (A, B), using each
                // example's own target.
                err = 0.0d;
                for (int i = 0; i < size; i++) {
                    double p = Math.min(1.0d, 1.0d / (1.0d + Math.min(1.0E30d, Math.exp(logOdds[i] * a + b))));
                    prob[i] = p;
                    err -= weights[i] * (target[i] * safeLog(p) + (1.0d - target[i]) * safeLog(1.0d - p));
                }
                if (err < oldErr * 1.0000001d) {
                    lambda *= 0.1d; // step accepted: reduce damping
                    break;
                }
                lambda *= 10.0d; // step rejected: increase damping and retry
                if (lambda >= 1000000.0d) {
                    break; // damping exhausted, give up on this iteration
                }
            }

            // Stop after three consecutive iterations with negligible relative change in error.
            double diff = err - oldErr;
            double scale = 0.5d * (err + oldErr + 1.0d);
            if (diff <= -1.0E-3d * scale || diff >= 1.0E-7d * scale) {
                stallCount = 0;
            } else {
                stallCount++;
            }
            oldErr = err;
            if (stallCount == 3) {
                break;
            }
        }
        return new double[] {a, b};
    }

    /** Natural logarithm that returns {@link #LOG_FLOOR} instead of -Infinity for an argument of 0. */
    private static double safeLog(double x) {
        return Math.max(Math.log(x), LOG_FLOOR);
    }

    /**
     * Converts a positive-class confidence into the log-odds value log((1 - p) / p).
     *
     * <p>The confidence is first clamped to [{@code MIN_CONFIDENCE}, {@code MAX_CONFIDENCE}], so the
     * result is always finite. A NaN confidence is logged and treated as 0.5 (log-odds 0).
     *
     * @param d the confidence of the positive class
     * @return the finite log-odds of the negative versus the positive class
     */
    public static double getLogOddsPosConfidence(double d) {
        double p;
        if (Double.isNaN(d)) {
            p = 0.5d;
            LogService.getGlobal().log("Found a NaN confidence during Platt's Scaling.", 5);
        } else {
            p = Math.min(Math.max(MIN_CONFIDENCE, d), MAX_CONFIDENCE);
        }
        return Math.log((1.0d - p) / p);
    }
}
