package marytts.machinelearning;

import Jama.Matrix;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.PrintWriter;
import marytts.util.math.MathUtils;
import marytts.util.math.Regression;
import opennlp.tools.parser.Parse;

/* loaded from: input_file:marytts/machinelearning/SFFS.class */
/**
 * Sequential Forward Floating Selection (SFFS) of linguistic factors for a multiple linear regression.
 *
 * <p>Each forward step adds the candidate feature whose inclusion yields the highest correlation, as
 * reported by {@link Regression#getCorrelation()}, for a regression fitted on the selected features.
 * After every addition, backward steps repeatedly drop the least significant selected feature unless
 * removing it lowers that correlation by at least 1e-4. A dropped feature is not put back into the pool
 * of candidates, so every forward step shrinks the pool by one.
 */
public class SFFS {
    /** Whether the regressions are fitted with an intercept term. */
    protected boolean interceptTerm = true;
    /** Only reported in the log output of {@link #trainModel}; no log transform is applied in this class. */
    protected boolean logSolution = false;
    /** Desired number of selected features. */
    protected int solutionSize = 1;

    /**
     * @param solutionSize  desired number of selected features
     * @param interceptTerm whether the regressions are fitted with an intercept term
     * @param logSolution   only reported in the log output of {@link #trainModel}
     */
    public SFFS(int solutionSize, boolean interceptTerm, boolean logSolution) {
        this.interceptTerm = interceptTerm;
        this.logSolution = logSolution;
        this.solutionSize = solutionSize;
    }

    /**
     * Selects features with SFFS on the training rows, prints the resulting model and prints predictions
     * for the training rows and for the test rows that follow them.
     *
     * @param lingFactors    names of the linguistic factors; factor k is column k of the data file and
     *                       column {@code lingFactors.length} is passed to {@link Regression} as the
     *                       variable to predict
     * @param featuresFile   data file, read with {@code Jama.Matrix.read}
     * @param numData        number of data points
     * @param percentToTrain fraction of the data points (presumably in [0, 1]), counted from the first
     *                       one, used for training
     * @param sop            receives coefficients, correlation and RMSE of the selected model
     * @throws Exception     propagated from the selection
     */
    public void trainModel(String[] lingFactors, String featuresFile, int numData, double percentToTrain, SoP sop)
            throws Exception {
        int d = solutionSize; // desired size of the solution
        int deviation = 0;    // allowed deviation from d
        int indVariable = lingFactors.length;

        // Row ranges are inclusive: training rows first, test rows right after them.
        int rowIniTrain = 0;
        int numTrainData = (int) Math.floor(numData * percentToTrain);
        int rowEndTrain = rowIniTrain + numTrainData - 1;
        int rowIniTest = rowEndTrain + 1;
        // NOTE(review): the test range ends at numData - 2, so the last data point is never used; kept
        // unchanged because callers may count numData differently -- confirm intent.
        int numTestData = numData - numTrainData - 1;
        int rowEndTest = rowIniTest + numTestData - 1;

        // Totals are the number of rows in each inclusive range (previously reported one too low).
        System.out.println("Number of points: " + numData + "\nNumber of points used for training from " + rowIniTrain
                + " to " + rowEndTrain + "(Total train=" + numTrainData + ")\nNumber of points used for testing from "
                + rowIniTest + " to " + rowEndTest + "(Total test=" + numTestData + ")");
        System.out.println("Number of linguistic factors: " + lingFactors.length);
        System.out.println("Max number of selected features in SFFS: " + (d + deviation));
        if (interceptTerm) {
            System.out.println("Using intercept Term for regression");
        } else {
            System.out.println("No intercept Term for regression");
        }
        if (logSolution) {
            System.out.println("Using log(val) as independent variable\n");
        } else {
            System.out.println("Using independent variable without log()\n");
        }

        int[] allFactors = new int[lingFactors.length];
        for (int k = 0; k < allFactors.length; k++) {
            allFactors[k] = k;
        }
        System.out.println("Checking and removing columns with mean=0.0");
        int[] candidates = checkMeanColumns(featuresFile, allFactors, lingFactors);
        int[] selected = sequentialForwardFloatingSelection(featuresFile, indVariable, lingFactors, new int[0],
                candidates, d, deviation, rowIniTrain, rowEndTrain, sop);

        sop.printCoefficients();
        System.out.println("Correlation original val / predicted val = " + sop.getCorrelation()
                + "\nRMSE (root mean square error) = " + sop.getRMSE());

        Regression regression = new Regression();
        regression.setCoeffs(sop.getCoeffs());
        System.out.println("\nNumber points used for training=" + numTrainData);
        regression.predictValues(featuresFile, indVariable, selected, interceptTerm, rowIniTrain, rowEndTrain);
        System.out.println("\nNumber points used for testing=" + numTestData);
        regression.predictValues(featuresFile, indVariable, selected, interceptTerm, rowIniTest, rowEndTest);
    }

    /**
     * Runs SFFS and fits the final regression on the selected features.
     *
     * @param dataFile    data file passed to {@link Regression}
     * @param indVariable column of the variable to predict; used for every regression fitted here
     * @param features    names of the factors, indexed by column
     * @param X           indices of the initially selected features
     * @param Y           indices of the candidate features
     * @param d           desired number of selected features
     * @param D           allowed deviation from {@code d}; forward steps stop once {@code d + D} features
     *                    are selected
     * @param rowIni      first data row used for fitting
     * @param rowEnd      last data row used for fitting ({@link #trainModel} passes an inclusive bound)
     * @param sop         receives coefficients, correlation and RMSE of the final model
     * @return indices of the selected features
     * @throws Exception  declared for callers; nothing in this class throws a checked exception
     */
    public int[] sequentialForwardFloatingSelection(String dataFile, int indVariable, String[] features, int[] X,
            int[] Y, int d, int D, int rowIni, int rowEnd, SoP sop) throws Exception {
        double[] forwardCorr = new double[3];  // [worst candidate added, current set, best candidate added]
        double[] backwardCorr = new double[3]; // [least significant removed, current set, most significant removed]
        int k = X.length;
        boolean candidatesLeft = true;

        while (k < d + D && candidatesLeft) {
            // NOTE(review): a single remaining candidate is never tried (check is > 1, not > 0) -- confirm intent.
            if (Y.length > 1) {
                System.out.println("ForwardSelection k=" + k + " remaining features=" + Y.length);
                int best = sequentialForwardSelection(dataFile, features, indVariable, X, Y, forwardCorr, rowIni,
                        rowEnd);
                System.out.format("corXplusy=%.4f  corX=%.4f\n", forwardCorr[2], forwardCorr[1]);
                double corX = forwardCorr[2];
                System.out.println("Most significant new feature to add: " + features[best]);
                X = MathUtils.addIndex(X, best);
                Y = MathUtils.removeIndex(Y, best);
                k++;

                // Conditional backward steps: drop the least significant feature while that does not
                // lower the correlation by 1e-4 or more.
                boolean tryRemoval = true;
                while (tryRemoval && k <= d + D && k > 1) {
                    if (X.length > 1) {
                        System.out.println(" BackwardSelection k=" + k);
                        int worst = sequentialBackwardSelection(dataFile, features, indVariable, X, backwardCorr,
                                rowIni, rowEnd);
                        corX = backwardCorr[1];
                        double difference = Math.abs(backwardCorr[0] - backwardCorr[1]);
                        System.out.format(" corXminusx=%.4f  corX=%.4f  difference=%.4f : ", backwardCorr[0],
                                backwardCorr[1], difference);
                        System.out.println("Least significant feature to remove: " + features[worst]);
                        if (backwardCorr[0] > backwardCorr[1] || difference < 1.0E-4d) {
                            System.out.println(" better without least significant feature or improvement < 0.0001 : (removing feature)");
                            X = MathUtils.removeIndex(X, worst);
                            k--;
                            corX = backwardCorr[0];
                        } else {
                            System.out.println(" better with least significant feature (keeping feature)\n");
                            tryRemoval = false;
                        }
                    } else {
                        System.out.println("X has one feature, can not execute a SBS step");
                        tryRemoval = false;
                    }
                }
                System.out.format("k=%d corX=%.4f   ", k, corX);
                printSelectedFeatures(X, features);
                System.out.println("-------------------------\n");
            } else {
                System.out.println("No more elements in Y for selection");
                candidatesLeft = false;
            }
        }

        Regression regression = new Regression();
        regression.multipleLinearRegression(dataFile, indVariable, X, features, interceptTerm, rowIni, rowEnd);
        sop.setCoeffsAndFactors(regression.getCoeffs(), X, features, interceptTerm);
        sop.setCorrelation(regression.getCorrelation());
        sop.setRMSE(regression.getRMSE());
        return X;
    }

    /**
     * One forward step: finds the candidate in {@code Y} whose addition to {@code X} gives the largest
     * correlation gain.
     *
     * @param corr filled with [correlation after adding the worst candidate, correlation of X (0 when X is
     *             empty), correlation after adding the best candidate]
     * @return index of the best candidate
     */
    private int sequentialForwardSelection(String dataFile, String[] features, int indVariable, int[] X, int[] Y,
            double[] corr, int rowIni, int rowEnd) {
        double corX = X.length > 0 ? regressionCorrelation(dataFile, indVariable, X, features, rowIni, rowEnd) : 0.0d;
        double[] corXplusY = new double[Y.length];
        double[] gain = new double[Y.length];
        for (int j = 0; j < Y.length; j++) {
            corXplusY[j] = correlationOfNewFeature(dataFile, features, indVariable, X, Y[j], rowIni, rowEnd);
            gain[j] = corXplusY[j] - corX;
        }
        int bestPos = MathUtils.getMaxIndex(gain);
        corr[0] = corXplusY[MathUtils.getMinIndex(gain)];
        corr[1] = corX;
        corr[2] = corXplusY[bestPos];
        return Y[bestPos];
    }

    /**
     * One backward step: finds the feature in {@code X} whose removal costs the least correlation.
     *
     * @param corr filled with [correlation without the least significant feature, correlation of X,
     *             correlation without the most significant feature]
     * @return index of the least significant feature
     */
    private int sequentialBackwardSelection(String dataFile, String[] features, int indVariable, int[] X,
            double[] corr, int rowIni, int rowEnd) {
        double corX = X.length > 0 ? regressionCorrelation(dataFile, indVariable, X, features, rowIni, rowEnd) : 0.0d;
        double[] corXminusX = new double[X.length];
        double[] loss = new double[X.length];
        for (int j = 0; j < X.length; j++) {
            corXminusX[j] = correlationOfFeature(dataFile, features, indVariable, X, X[j], rowIni, rowEnd);
            loss[j] = corX - corXminusX[j];
        }
        int leastPos = MathUtils.getMinIndex(loss);
        corr[0] = corXminusX[leastPos];
        corr[1] = corX;
        corr[2] = corXminusX[MathUtils.getMaxIndex(loss)];
        return X[leastPos];
    }

    /** Correlation of the regression on {@code X} without {@code feature}, which must occur exactly once in X. */
    private double correlationOfFeature(String dataFile, String[] features, int indVariable, int[] X, int feature,
            int rowIni, int rowEnd) {
        int[] reduced = new int[X.length - 1];
        int n = 0;
        for (int f : X) {
            if (f != feature) {
                reduced[n++] = f;
            }
        }
        return regressionCorrelation(dataFile, indVariable, reduced, features, rowIni, rowEnd);
    }

    /** Correlation of the regression on {@code X} with {@code feature} appended. */
    private double correlationOfNewFeature(String dataFile, String[] features, int indVariable, int[] X, int feature,
            int rowIni, int rowEnd) {
        int[] extended = new int[X.length + 1];
        System.arraycopy(X, 0, extended, 0, X.length);
        extended[X.length] = feature;
        return regressionCorrelation(dataFile, indVariable, extended, features, rowIni, rowEnd);
    }

    /** Fits a regression on the given feature indices and returns its correlation. */
    private double regressionCorrelation(String dataFile, int indVariable, int[] X, String[] features, int rowIni,
            int rowEnd) {
        Regression regression = new Regression();
        regression.multipleLinearRegression(dataFile, indVariable, X, features, interceptTerm, rowIni, rowEnd);
        return regression.getCorrelation();
    }

    /** Prints the names of the selected features on one line to standard output. */
    private static void printSelectedFeatures(int[] X, String[] features) {
        System.out.print("Features: ");
        for (int f : X) {
            System.out.print(features[f] + "  ");
        }
        System.out.println();
    }

    /** Prints the names of the selected features on one line to {@code out}. */
    private static void printSelectedFeatures(int[] X, String[] features, PrintWriter out) {
        out.print("Features: ");
        for (int f : X) {
            out.print(features[f] + "  ");
        }
        out.println();
    }

    /**
     * Removes from {@code candidates} every factor whose column in the data file has mean 0.0 over all
     * data points.
     *
     * @return the filtered candidate indices
     * @throws RuntimeException if the data file cannot be read or parsed
     */
    private int[] checkMeanColumns(String dataFile, int[] candidates, String[] features) {
        Matrix data;
        try {
            BufferedReader reader = new BufferedReader(new FileReader(dataFile));
            try {
                data = Matrix.read(reader);
            } finally {
                reader.close(); // close even when parsing fails
            }
        } catch (Exception e) {
            throw new RuntimeException("Problem reading file " + dataFile, e);
        }
        // Row k of the transpose holds every value of file column k, including the first data point.
        double[][] columns = data.transpose().getArray();
        // Only factor columns are checked; any further column (presumably the variable to predict) has no
        // entry in features[] and indexing it would throw.
        int numFactors = Math.min(features.length, columns.length);
        for (int k = 0; k < numFactors; k++) {
            if (MathUtils.mean(columns[k]) == 0.0d) {
                System.out.println("Removing feature: " + features[k] + " from list of features because it has mean=0.0");
                candidates = MathUtils.removeIndex(candidates, k);
            }
        }
        System.out.println();
        return candidates;
    }
}
