package weka.classifiers.lazy;

import com.ibm.icu.text.DateFormat;
import com.vividsolutions.jts.io.gml2.GMLConstants;
import java.util.Enumeration;
import java.util.Vector;
import opennlp.tools.tokenize.TokenizerME;
import org.hsqldb.Tokens;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.UpdateableClassifier;
import weka.classifiers.rules.ZeroR;
import weka.core.AdditionalMeasureProducer;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.core.neighboursearch.CoverTree;
import weka.core.neighboursearch.LinearNNSearch;
import weka.core.neighboursearch.NearestNeighbourSearch;

/* loaded from: input_file:WEB-INF/lib/weka-dev-3.7.6.jar:weka/classifiers/lazy/IBk.class */
/**
 * K-nearest-neighbours classifier (IBk).
 *
 * <p>Predictions are made from the {@code k} training instances returned by the configured
 * {@link NearestNeighbourSearch}. Neighbours can optionally be weighted by their distance,
 * {@code k} can be selected by hold-one-out cross-validation on the training data, and the
 * training pool can be capped to a window from which the oldest instances are dropped first.
 */
public class IBk extends AbstractClassifier implements OptionHandler, UpdateableClassifier, WeightedInstancesHandler, TechnicalInformationHandler, AdditionalMeasureProducer {
    static final long serialVersionUID = -3080186098777067172L;

    /** Distance weighting mode: every neighbour contributes equally. */
    public static final int WEIGHT_NONE = 1;
    /** Distance weighting mode: neighbours are weighted by 1/distance. */
    public static final int WEIGHT_INVERSE = 2;
    /** Distance weighting mode: neighbours are weighted by 1-distance. */
    public static final int WEIGHT_SIMILARITY = 4;
    /** The selectable distance weighting modes. */
    public static final Tag[] TAGS_WEIGHTING = {
        new Tag(WEIGHT_NONE, "No distance weighting"),
        new Tag(WEIGHT_INVERSE, "Weight by 1/distance"),
        new Tag(WEIGHT_SIMILARITY, "Weight by 1-distance")
    };

    /**
     * Value of {@link #m_ClassType} for a numeric class (prediction stored in slot 0 of the
     * distribution). Expected to mirror weka.core.Attribute.NUMERIC.
     */
    private static final int CLASS_TYPE_NUMERIC = 0;
    /**
     * Value of {@link #m_ClassType} for a nominal class (class value used as distribution index).
     * Expected to mirror weka.core.Attribute.NOMINAL.
     */
    private static final int CLASS_TYPE_NOMINAL = 1;
    /** Name of the additional measure that reports the currently selected k. */
    private static final String MEASURE_KNN = "measureKNN";
    /** Fixed length of the array returned by {@link #getOptions()}; unused slots hold "". */
    private static final int OPTIONS_LENGTH = 11;

    /** The training instances currently held by the classifier; null until a model is built. */
    protected Instances m_Train;
    /** Number of class values reported by the training data. */
    protected int m_NumClasses;
    /** Type code of the class attribute of the training data. */
    protected int m_ClassType;
    /** Number of neighbours used for prediction. */
    protected int m_kNN;
    /** Largest k that cross-validation will consider. */
    protected int m_kNNUpper;
    /** True once cross-validation has selected k for the current training data and settings. */
    protected boolean m_kNNValid;
    /** Maximum number of training instances kept; 0 or less means no limit. */
    protected int m_WindowSize;
    /** Active distance weighting mode (one of the WEIGHT_* constants). */
    protected int m_DistanceWeighting;
    /** Whether k is selected by hold-one-out cross-validation. */
    protected boolean m_CrossValidate;
    /** Whether cross-validation on a numeric class minimises squared rather than absolute error. */
    protected boolean m_MeanSquared;
    /** Fallback model used when there are no training instances. */
    protected ZeroR m_defaultModel;
    /** The neighbour search implementation. */
    protected NearestNeighbourSearch m_NNSearch = new LinearNNSearch();
    /** Number of non-class nominal/numeric attributes; used to normalise distances. */
    protected double m_NumAttributesUsed;

    /**
     * Creates a classifier with default settings and the given number of neighbours.
     *
     * @param i the number of neighbours to use
     */
    public IBk(int i) {
        init();
        setKNN(i);
    }

    /** Creates a classifier with default settings (k = 1, no window, no weighting). */
    public IBk() {
        init();
    }

    /**
     * Returns a description of this classifier, including its literature reference.
     *
     * @return the description text
     */
    public String globalInfo() {
        return "K-nearest neighbours classifier. Can select appropriate value of K based on cross-validation. Can also do distance weighting.\n\nFor more information, see\n\n" + getTechnicalInformation().toString();
    }

    /**
     * Returns the bibliographic reference for the algorithm.
     *
     * @return the technical information entry
     */
    @Override // weka.core.TechnicalInformationHandler
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation reference = new TechnicalInformation(TechnicalInformation.Type.ARTICLE);
        reference.setValue(TechnicalInformation.Field.AUTHOR, "D. Aha and D. Kibler");
        reference.setValue(TechnicalInformation.Field.YEAR, "1991");
        reference.setValue(TechnicalInformation.Field.TITLE, "Instance-based learning algorithms");
        reference.setValue(TechnicalInformation.Field.JOURNAL, "Machine Learning");
        reference.setValue(TechnicalInformation.Field.VOLUME, "6");
        reference.setValue(TechnicalInformation.Field.PAGES, "37-66");
        return reference;
    }

    /** @return tip text for the KNN property */
    public String KNNTipText() {
        return "The number of neighbours to use.";
    }

    /**
     * Sets the number of neighbours. Also resets the cross-validation upper bound to the same
     * value and invalidates any previously selected k.
     *
     * @param i the number of neighbours
     */
    public void setKNN(int i) {
        m_kNN = i;
        m_kNNUpper = i;
        m_kNNValid = false;
    }

    /** @return the number of neighbours currently used for prediction */
    public int getKNN() {
        return m_kNN;
    }

    /** @return tip text for the windowSize property */
    public String windowSizeTipText() {
        return "Gets the maximum number of instances allowed in the training pool. The addition of new instances above this value will result in old instances being removed. A value of 0 signifies no limit to the number of training instances.";
    }

    /** @return the maximum number of training instances kept (0 = unlimited) */
    public int getWindowSize() {
        return m_WindowSize;
    }

    /**
     * Sets the maximum number of training instances kept.
     *
     * @param i the window size; 0 or less disables the limit
     */
    public void setWindowSize(int i) {
        m_WindowSize = i;
    }

    /** @return tip text for the distanceWeighting property */
    public String distanceWeightingTipText() {
        return "Gets the distance weighting method used.";
    }

    /** @return the active distance weighting mode wrapped as a tag selection */
    public SelectedTag getDistanceWeighting() {
        return new SelectedTag(m_DistanceWeighting, TAGS_WEIGHTING);
    }

    /**
     * Sets the distance weighting mode. Selections that were not created from
     * {@link #TAGS_WEIGHTING} (identity comparison) are silently ignored.
     *
     * @param selectedTag the selected weighting mode
     */
    public void setDistanceWeighting(SelectedTag selectedTag) {
        if (selectedTag.getTags() == TAGS_WEIGHTING) {
            m_DistanceWeighting = selectedTag.getSelectedTag().getID();
        }
    }

    /** @return tip text for the meanSquared property */
    public String meanSquaredTipText() {
        return "Whether the mean squared error is used rather than mean absolute error when doing cross-validation for regression problems.";
    }

    /** @return true if cross-validation minimises squared error for numeric classes */
    public boolean getMeanSquared() {
        return m_MeanSquared;
    }

    /**
     * Chooses the error measure minimised by cross-validation for numeric classes.
     *
     * @param z true for squared error, false for absolute error
     */
    public void setMeanSquared(boolean z) {
        m_MeanSquared = z;
    }

    /** @return tip text for the crossValidate property */
    public String crossValidateTipText() {
        return "Whether hold-one-out cross-validation will be used to select the best k value.";
    }

    /** @return true if k is selected by hold-one-out cross-validation */
    public boolean getCrossValidate() {
        return m_CrossValidate;
    }

    /**
     * Enables or disables selection of k by hold-one-out cross-validation.
     *
     * @param z true to enable cross-validation
     */
    public void setCrossValidate(boolean z) {
        m_CrossValidate = z;
    }

    /** @return tip text for the nearestNeighbourSearchAlgorithm property */
    public String nearestNeighbourSearchAlgorithmTipText() {
        return "The nearest neighbour search algorithm to use (Default: weka.core.neighboursearch.LinearNNSearch).";
    }

    /** @return the neighbour search implementation in use */
    public NearestNeighbourSearch getNearestNeighbourSearchAlgorithm() {
        return m_NNSearch;
    }

    /**
     * Replaces the neighbour search implementation. The new search is not handed the training
     * data here; that happens when the classifier is (re)built.
     *
     * @param nearestNeighbourSearch the search implementation to use
     */
    public void setNearestNeighbourSearchAlgorithm(NearestNeighbourSearch nearestNeighbourSearch) {
        m_NNSearch = nearestNeighbourSearch;
    }

    /** @return the number of training instances currently held */
    public int getNumTraining() {
        return m_Train.numInstances();
    }

    /**
     * Describes the data this classifier accepts: nominal, numeric and date attributes and
     * classes, missing values, missing class values, and empty data sets.
     *
     * @return the capabilities
     */
    @Override // weka.classifiers.AbstractClassifier, weka.classifiers.Classifier, weka.core.CapabilitiesHandler
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();

        // attributes
        result.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capabilities.Capability.DATE_ATTRIBUTES);
        result.enable(Capabilities.Capability.MISSING_VALUES);

        // class
        result.enable(Capabilities.Capability.NOMINAL_CLASS);
        result.enable(Capabilities.Capability.NUMERIC_CLASS);
        result.enable(Capabilities.Capability.DATE_CLASS);
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);

        result.setMinimumNumberInstances(0);
        return result;
    }

    /**
     * Builds the classifier from the given data. Instances with a missing class are dropped, the
     * remaining data is cut to the last {@code windowSize} instances when a window is set, and
     * the neighbour search is initialised with the result. A ZeroR model is built on the
     * (unwindowed) data as a fallback for an empty training set.
     *
     * @param instances the training data
     * @throws Exception if the data fails the capabilities test or the search cannot be set up
     */
    @Override // weka.classifiers.Classifier
    public void buildClassifier(Instances instances) throws Exception {
        getCapabilities().testWithFail(instances);

        // Work on a copy so that the caller's data is left untouched.
        Instances cleaned = new Instances(instances);
        cleaned.deleteWithMissingClass();

        m_NumClasses = cleaned.numClasses();
        m_ClassType = cleaned.classAttribute().type();
        m_Train = new Instances(cleaned, 0, cleaned.numInstances());

        if (m_WindowSize > 0 && cleaned.numInstances() > m_WindowSize) {
            // Keep only the trailing windowSize instances.
            m_Train = new Instances(m_Train, m_Train.numInstances() - m_WindowSize, m_WindowSize);
        }

        m_NumAttributesUsed = 0.0d;
        for (int att = 0; att < m_Train.numAttributes(); att++) {
            boolean isClass = att == m_Train.classIndex();
            if (!isClass && (m_Train.attribute(att).isNominal() || m_Train.attribute(att).isNumeric())) {
                m_NumAttributesUsed += 1.0d;
            }
        }

        m_NNSearch.setInstances(m_Train);
        m_kNNValid = false;

        m_defaultModel = new ZeroR();
        m_defaultModel.buildClassifier(cleaned);
    }

    /**
     * Adds one instance to the training pool. Instances with a missing class are ignored. If the
     * pool then exceeds the window size, the oldest instances are removed and the neighbour
     * search is rebuilt.
     *
     * @param instance the instance to add
     * @throws Exception if the instance's header differs from the training data's
     */
    @Override // weka.classifiers.UpdateableClassifier
    public void updateClassifier(Instance instance) throws Exception {
        if (!m_Train.equalHeaders(instance.dataset())) {
            throw new Exception("Incompatible instance types\n" + m_Train.equalHeadersMsg(instance.dataset()));
        }
        if (instance.classIsMissing()) {
            return;
        }
        m_Train.add(instance);
        m_NNSearch.update(instance);
        m_kNNValid = false;
        trimToWindow();
    }

    /**
     * Computes the class distribution for the given instance from its nearest neighbours. Falls
     * back to the ZeroR model when there is no training data. If the training pool exceeds the
     * window size (e.g. because the window was reduced after building), it is trimmed first and
     * the neighbour search is rebuilt so that it does not keep serving removed instances.
     *
     * @param instance the instance to classify
     * @return the class distribution (a single predicted value for a numeric class)
     * @throws Exception if the neighbour search fails
     */
    @Override // weka.classifiers.AbstractClassifier, weka.classifiers.Classifier
    public double[] distributionForInstance(Instance instance) throws Exception {
        if (m_Train.numInstances() == 0) {
            return m_defaultModel.distributionForInstance(instance);
        }
        trimToWindow();

        if (!m_kNNValid && m_CrossValidate && m_kNNUpper >= 1) {
            crossValidate();
        }

        m_NNSearch.addInstanceInfo(instance);
        Instances neighbours = m_NNSearch.kNearestNeighbours(instance, m_kNN);
        double[] distances = m_NNSearch.getDistances();
        return makeDistribution(neighbours, distances);
    }

    /**
     * Removes the oldest training instances until the pool fits the window, then rebuilds the
     * neighbour search on the trimmed pool. Does nothing when no window is set or the pool
     * already fits.
     *
     * @return true if instances were removed
     * @throws Exception if the neighbour search cannot be rebuilt
     */
    private boolean trimToWindow() throws Exception {
        if (m_WindowSize <= 0 || m_Train.numInstances() <= m_WindowSize) {
            return false;
        }
        // The selected k was tuned on the untrimmed pool.
        m_kNNValid = false;
        while (m_Train.numInstances() > m_WindowSize) {
            m_Train.delete(0);
        }
        m_NNSearch.setInstances(m_Train);
        return true;
    }

    /**
     * Lists the command-line options understood by this classifier.
     *
     * @return an enumeration of {@link Option} objects
     */
    @Override // weka.classifiers.AbstractClassifier, weka.core.OptionHandler
    public Enumeration listOptions() {
        Vector<Option> options = new Vector<Option>(8);
        options.addElement(new Option("\tWeight neighbours by the inverse of their distance\n\t(use when k > 1)", "I", 0, "-I"));
        options.addElement(new Option("\tWeight neighbours by 1 - their distance\n\t(use when k > 1)", "F", 0, "-F"));
        options.addElement(new Option("\tNumber of nearest neighbours (k) used in classification.\n\t(Default = 1)", "K", 1, "-K <number of neighbors>"));
        options.addElement(new Option("\tMinimise mean squared error rather than mean absolute\n\terror when using -X option with numeric prediction.", "E", 0, "-E"));
        options.addElement(new Option("\tMaximum number of training instances maintained.\n\tTraining instances are dropped FIFO. (Default = no window)", "W", 1, "-W <window size>"));
        options.addElement(new Option("\tSelect the number of nearest neighbours between 1\n\tand the k value specified using hold-one-out evaluation\n\ton the training data (use when k > 1)", "X", 0, "-X"));
        options.addElement(new Option("\tThe nearest neighbour search algorithm to use (default: weka.core.neighboursearch.LinearNNSearch).\n", "A", 0, "-A"));
        return options.elements();
    }

    /**
     * Parses the given options. Recognised: {@code -K <k>}, {@code -W <window>}, {@code -I},
     * {@code -F}, {@code -X}, {@code -E} and {@code -A <search class and options>}. Anything not
     * given is reset to its default. {@code -I} takes precedence over {@code -F}.
     *
     * @param strArr the options
     * @throws Exception if an option is invalid or unrecognised options remain
     */
    @Override // weka.classifiers.AbstractClassifier, weka.core.OptionHandler
    public void setOptions(String[] strArr) throws Exception {
        String knnValue = Utils.getOption('K', strArr);
        setKNN(knnValue.length() != 0 ? Integer.parseInt(knnValue) : 1);

        String windowValue = Utils.getOption('W', strArr);
        setWindowSize(windowValue.length() != 0 ? Integer.parseInt(windowValue) : 0);

        // Flags are consumed lazily: -F is only looked for when -I is absent.
        int weighting;
        if (Utils.getFlag('I', strArr)) {
            weighting = WEIGHT_INVERSE;
        } else if (Utils.getFlag('F', strArr)) {
            weighting = WEIGHT_SIMILARITY;
        } else {
            weighting = WEIGHT_NONE;
        }
        setDistanceWeighting(new SelectedTag(weighting, TAGS_WEIGHTING));

        setCrossValidate(Utils.getFlag('X', strArr));
        setMeanSquared(Utils.getFlag('E', strArr));

        String searchSpec = Utils.getOption('A', strArr);
        if (searchSpec.length() == 0) {
            setNearestNeighbourSearchAlgorithm(new LinearNNSearch());
        } else {
            String[] searchOptions = Utils.splitOptions(searchSpec);
            if (searchOptions.length == 0) {
                throw new Exception("Invalid NearestNeighbourSearch algorithm specification string.");
            }
            // First token is the class name; the remainder are that class's own options.
            String className = searchOptions[0];
            searchOptions[0] = "";
            setNearestNeighbourSearchAlgorithm((NearestNeighbourSearch) Utils.forName(NearestNeighbourSearch.class, className, searchOptions));
        }

        Utils.checkForRemainingOptions(strArr);
    }

    /**
     * Returns the current settings as an option array. The array always has a fixed length;
     * trailing unused slots contain empty strings.
     *
     * @return the options
     */
    @Override // weka.classifiers.AbstractClassifier, weka.core.OptionHandler
    public String[] getOptions() {
        String[] options = new String[OPTIONS_LENGTH];
        int next = 0;

        options[next++] = "-K";
        options[next++] = "" + getKNN();
        options[next++] = "-W";
        options[next++] = "" + m_WindowSize;
        if (getCrossValidate()) {
            options[next++] = "-X";
        }
        if (getMeanSquared()) {
            options[next++] = "-E";
        }
        if (m_DistanceWeighting == WEIGHT_INVERSE) {
            options[next++] = "-I";
        } else if (m_DistanceWeighting == WEIGHT_SIMILARITY) {
            options[next++] = "-F";
        }
        options[next++] = "-A";
        options[next++] = m_NNSearch.getClass().getName() + " " + Utils.joinOptions(m_NNSearch.getOptions());

        while (next < options.length) {
            options[next++] = "";
        }
        return options;
    }

    /**
     * Lists the available additional measures: those of the neighbour search, plus
     * "measureKNN" when cross-validation is enabled.
     *
     * @return an enumeration of measure names
     */
    @Override // weka.core.AdditionalMeasureProducer
    public Enumeration enumerateMeasures() {
        if (!m_CrossValidate) {
            return m_NNSearch.enumerateMeasures();
        }
        Vector<Object> measures = new Vector<Object>();
        for (Enumeration searchMeasures = m_NNSearch.enumerateMeasures(); searchMeasures.hasMoreElements();) {
            measures.add(searchMeasures.nextElement());
        }
        measures.add(MEASURE_KNN);
        return measures.elements();
    }

    /**
     * Returns the value of the named measure; "measureKNN" yields the current k, every other
     * name is delegated to the neighbour search.
     *
     * @param str the measure name
     * @return the measure's value
     */
    @Override // weka.core.AdditionalMeasureProducer
    public double getMeasure(String str) {
        if (str.equals(MEASURE_KNN)) {
            return m_kNN;
        }
        return m_NNSearch.getMeasure(str);
    }

    /**
     * Returns a textual description of the classifier. Note that this runs cross-validation
     * first if it is enabled and k has not yet been selected.
     *
     * @return the description
     */
    public String toString() {
        if (m_Train == null) {
            return "IBk: No model built yet.";
        }
        if (m_Train.numInstances() == 0) {
            return "Warning: no training instances - ZeroR model used.";
        }
        if (!m_kNNValid && m_CrossValidate) {
            crossValidate();
        }

        StringBuilder text = new StringBuilder("IB1 instance-based classifier\nusing ").append(m_kNN);
        if (m_DistanceWeighting == WEIGHT_INVERSE) {
            text.append(" inverse-distance-weighted");
        } else if (m_DistanceWeighting == WEIGHT_SIMILARITY) {
            text.append(" similarity-weighted");
        }
        text.append(" nearest neighbour(s) for classification\n");
        if (m_WindowSize != 0) {
            text.append("using a maximum of ").append(m_WindowSize).append(" (windowed) training instances\n");
        }
        return text.toString();
    }

    /** Resets all settings to their defaults: k = 1, no window, no weighting, no cross-validation. */
    protected void init() {
        setKNN(1);
        m_WindowSize = 0;
        m_DistanceWeighting = WEIGHT_NONE;
        m_CrossValidate = false;
        m_MeanSquared = false;
    }

    /**
     * Turns a set of neighbours and their distances into a class distribution.
     *
     * <p>Each distance is normalised in place to {@code sqrt(d * d / numAttributesUsed)} and
     * converted to a weight according to the distance weighting mode, multiplied by the
     * neighbour's own instance weight. For a nominal class the weight is added to the
     * neighbour's class slot; for a numeric class the weighted class value is accumulated in
     * slot 0. The result is normalised by the total weight when that total is positive.
     *
     * @param instances the neighbours
     * @param dArr the distances of the neighbours, in the same order; modified in place
     * @return the class distribution
     * @throws Exception if normalisation fails
     */
    protected double[] makeDistribution(Instances instances, double[] dArr) throws Exception {
        double total = 0.0d;
        double[] distribution = new double[m_NumClasses];

        if (m_ClassType == CLASS_TYPE_NOMINAL) {
            // Seed every class with 1 / max(1, #training instances); the total starts at the
            // sum of those seeds. The division must be floating point.
            int poolSize = Math.max(1, m_Train.numInstances());
            for (int cls = 0; cls < m_NumClasses; cls++) {
                distribution[cls] = 1.0d / poolSize;
            }
            total = (double) m_NumClasses / poolSize;
        }

        for (int n = 0; n < instances.numInstances(); n++) {
            Instance neighbour = instances.instance(n);
            dArr[n] = Math.sqrt(dArr[n] * dArr[n] / m_NumAttributesUsed);

            double weight;
            switch (m_DistanceWeighting) {
                case WEIGHT_INVERSE:
                    // The small offset avoids division by zero for identical instances.
                    weight = 1.0d / (dArr[n] + 0.001d);
                    break;
                case WEIGHT_SIMILARITY:
                    weight = 1.0d - dArr[n];
                    break;
                default:
                    weight = 1.0d;
                    break;
            }
            weight *= neighbour.weight();

            try {
                switch (m_ClassType) {
                    case CLASS_TYPE_NUMERIC:
                        distribution[0] += neighbour.classValue() * weight;
                        break;
                    case CLASS_TYPE_NOMINAL:
                        distribution[(int) neighbour.classValue()] += weight;
                        break;
                    default:
                        break;
                }
            } catch (Exception e) {
                throw new Error("Data has no class attribute!", e);
            }
            // Each neighbour contributes its weight to the total exactly once.
            total += weight;
        }

        if (total > 0.0d) {
            Utils.normalize(distribution, total);
        }
        return distribution;
    }

    /**
     * Selects k by hold-one-out cross-validation over the training data, trying every value from
     * 1 to the configured upper bound. For a nominal class the number of misclassifications is
     * minimised; for a numeric class the absolute error, or the squared error when
     * {@code meanSquared} is set. Ties are resolved in favour of the smaller k. On success
     * {@link #m_kNN} holds the selected value and {@link #m_kNNValid} is set.
     *
     * @throws Error if cross-validation fails for any reason (e.g. a CoverTree search is in use)
     */
    protected void crossValidate() {
        try {
            if (m_NNSearch instanceof CoverTree) {
                throw new Exception("CoverTree doesn't support hold-one-out cross-validation. Use some other NN method.");
            }

            // Index j holds the accumulated error for k = j + 1. For a nominal class
            // absErrors counts misclassifications.
            double[] absErrors = new double[m_kNNUpper];
            double[] sqErrors = new double[m_kNNUpper];

            m_kNN = m_kNNUpper;
            final boolean numericClass = m_Train.classAttribute().isNumeric();
            final int numTrain = m_Train.numInstances();

            for (int n = 0; n < numTrain; n++) {
                if (m_Debug && n % 50 == 0) {
                    System.err.print("Cross validating " + n + "/" + numTrain + "\r");
                }
                Instance heldOut = m_Train.instance(n);
                Instances neighbours = m_NNSearch.kNearestNeighbours(heldOut, m_kNN);
                double[] rawDistances = m_NNSearch.getDistances();

                // Evaluate from the largest k downwards, shrinking the neighbour set each step.
                for (int j = m_kNNUpper - 1; j >= 0; j--) {
                    // makeDistribution rewrites the distances, so hand it a fresh copy.
                    double[] distances = new double[rawDistances.length];
                    System.arraycopy(rawDistances, 0, distances, 0, rawDistances.length);
                    double[] distribution = makeDistribution(neighbours, distances);

                    if (numericClass) {
                        double error = distribution[0] - heldOut.classValue();
                        sqErrors[j] += error * error;
                        absErrors[j] += Math.abs(error);
                    } else if (Utils.maxIndex(distribution) != heldOut.classValue()) {
                        absErrors[j] += 1.0d;
                    }

                    if (j >= 1) {
                        neighbours = pruneToK(neighbours, distances, j);
                    }
                }
            }

            if (m_Debug) {
                for (int j = 0; j < m_kNNUpper; j++) {
                    System.err.print("Hold-one-out performance of " + (j + 1) + " neighbors ");
                    if (!numericClass) {
                        System.err.println("(%ERR) = " + ((100.0d * absErrors[j]) / numTrain));
                    } else if (m_MeanSquared) {
                        System.err.println("(RMSE) = " + Math.sqrt(sqErrors[j] / numTrain));
                    } else {
                        System.err.println("(MAE) = " + (absErrors[j] / numTrain));
                    }
                }
            }

            double[] criterion = (numericClass && m_MeanSquared) ? sqErrors : absErrors;
            double bestError = Double.NaN;
            int bestK = 1;
            for (int j = 0; j < m_kNNUpper; j++) {
                if (Double.isNaN(bestError) || bestError > criterion[j]) {
                    bestError = criterion[j];
                    bestK = j + 1;
                }
            }
            m_kNN = bestK;
            if (m_Debug) {
                System.err.println("Selected k = " + bestK);
            }
            m_kNNValid = true;
        } catch (Exception e) {
            // Keep the original exception as the cause so the failure can be diagnosed.
            throw new Error("Couldn't optimize by cross-validation: " + e.getMessage(), e);
        }
    }

    /**
     * Cuts a neighbour list back to its first {@code i} entries, additionally keeping any
     * following entries whose distance equals that of the entry before them (ties).
     *
     * @param instances the neighbours, in the same order as {@code dArr}
     * @param dArr the distances of the neighbours
     * @param i the number of neighbours to keep; values below 1 are treated as 1
     * @return the pruned neighbours, the unchanged input if nothing had to be removed, or null
     *     if either argument is null or there are no neighbours
     */
    public Instances pruneToK(Instances instances, double[] dArr, int i) {
        if (instances == null || dArr == null || instances.numInstances() == 0) {
            return null;
        }
        final int keep = Math.max(i, 1);
        for (int idx = 0; idx < instances.numInstances(); idx++) {
            double distance = dArr[idx];
            // Past the first `keep` entries, stop at the first distance that breaks the tie.
            if (idx >= keep && distance != dArr[idx - 1]) {
                return new Instances(instances, 0, idx);
            }
        }
        return instances;
    }

    /** @return the revision string */
    @Override // weka.classifiers.AbstractClassifier, weka.core.RevisionHandler
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 8034 $");
    }

    /**
     * Runs this classifier from the command line.
     *
     * @param strArr the command-line options
     */
    public static void main(String[] strArr) {
        runClassifier(new IBk(), strArr);
    }
}
