/*
 * Decompiled with CFR 0.152.
 */
package marytts.machinelearning;

import Jama.Matrix;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.PrintWriter;
import marytts.machinelearning.SoP;
import marytts.util.math.MathUtils;
import marytts.util.math.Regression;

/**
 * Sequential Forward Floating Selection (SFFS) of regression features.
 *
 * <p>Starting from a set of already selected column indices {@code X} and a set of
 * candidate column indices {@code Y}, the selection alternates between a forward step
 * (add the candidate that increases the regression correlation the most) and a series
 * of backward steps (drop the selected feature whose removal costs the least, as long
 * as the model is not worse without it). Every candidate model is evaluated through
 * {@link Regression#multipleLinearRegression} on the rows of a data file, and the
 * correlation it reports is used as the selection criterion.
 *
 * <p>Progress is reported on {@code System.out}.
 */
public class SFFS {
    /** Passed to every regression call as its intercept-term flag. */
    protected boolean interceptTerm = true;
    /** Only used here to choose which informational message {@link #trainModel} prints. */
    protected boolean logSolution = false;
    /** Upper bound on the number of features the selection may keep. */
    protected int solutionSize = 1;

    /**
     * Creates a selector.
     *
     * @param solSize maximum number of features to select
     * @param b0      whether regressions are run with an intercept term
     * @param logSol  value stored in {@link #logSolution}
     */
    public SFFS(int solSize, boolean b0, boolean logSol) {
        this.interceptTerm = b0;
        this.logSolution = logSol;
        this.solutionSize = solSize;
    }

    /**
     * Runs the feature selection on the training rows of {@code featuresFile}, stores the
     * resulting coefficients in {@code sop}, and then asks a {@link Regression} to predict
     * values for the training rows and for the test rows.
     *
     * @param lingFactors    names of the candidate features; position {@code i} names column {@code i}
     * @param featuresFile   path of the data file handed to the regression routines
     * @param numFeatures    number of data points (rows) to consider
     * @param percentToTrain fraction of {@code numFeatures} used for training
     * @param sop            receives coefficients, correlation and RMSE of the final model
     * @throws Exception if any of the regression helpers fails
     */
    public void trainModel(String[] lingFactors, String featuresFile, int numFeatures, double percentToTrain, SoP sop) throws Exception {
        int d = this.solutionSize;
        int D = 0;
        int cols = lingFactors.length;
        // The column index lingFactors.length is passed to the regression as "indVariable".
        int indVariable = cols;
        int rows = numFeatures;

        int rowIniTrain = 0;
        int percentVal = (int)Math.floor((double)numFeatures * percentToTrain);
        int rowEndTrain = percentVal - 1;
        int rowIniTest = percentVal;
        // Algebraically this is numFeatures - 2.
        // NOTE(review): if row ranges are inclusive this leaves the last row unused, and the
        // "Total" figures printed below are one less than the row count - confirm against Regression.
        int rowEndTest = percentVal + (numFeatures - percentVal - 1) - 1;

        System.out.println("Number of points: " + rows + "\nNumber of points used for training from " + rowIniTrain + " to " + rowEndTrain + "(Total train=" + (rowEndTrain - rowIniTrain) + ")\nNumber of points used for testing from " + rowIniTest + " to " + rowEndTest + "(Total test=" + (rowEndTest - rowIniTest) + ")");
        System.out.println("Number of linguistic factors: " + cols);
        System.out.println("Max number of selected features in SFFS: " + (d + D));
        if (this.interceptTerm) {
            System.out.println("Using intercept Term for regression");
        } else {
            System.out.println("No intercept Term for regression");
        }
        if (this.logSolution) {
            System.out.println("Using log(val) as independent variable\n");
        } else {
            System.out.println("Using independent variable without log()\n");
        }

        // Y starts as every column index, X (the selected set) starts empty.
        int[] Y = new int[lingFactors.length];
        int[] X = new int[]{};
        for (int j = 0; j < lingFactors.length; ++j) {
            Y[j] = j;
        }

        System.out.println("Checking and removing columns with mean=0.0");
        Y = this.checkMeanColumns(featuresFile, Y, lingFactors);

        int[] selectedCols = this.sequentialForwardFloatingSelection(featuresFile, indVariable, lingFactors, X, Y, d, D, rowIniTrain, rowEndTrain, sop);
        sop.printCoefficients();
        System.out.println("Correlation original val / predicted val = " + sop.getCorrelation() + "\nRMSE (root mean square error) = " + sop.getRMSE());

        Regression reg = new Regression();
        reg.setCoeffs(sop.getCoeffs());
        System.out.println("\nNumber points used for training=" + (rowEndTrain - rowIniTrain));
        reg.predictValues(featuresFile, cols, selectedCols, this.interceptTerm, rowIniTrain, rowEndTrain);
        System.out.println("\nNumber points used for testing=" + (rowEndTest - rowIniTest));
        reg.predictValues(featuresFile, cols, selectedCols, this.interceptTerm, rowIniTest, rowEndTest);
    }

    /**
     * Performs the floating selection and fits the final model.
     *
     * <p>Each iteration runs one forward step followed by backward steps that continue while
     * removing the least significant feature does not lower the correlation (or changes it
     * by less than 1.0E-4). The loop ends when {@code d + D} features are selected or when
     * {@code Y} has at most one candidate left.
     *
     * @param dataFile    path of the data file handed to the regression routines
     * @param indVariable column index passed to the final regression
     * @param features    feature names, indexed by column
     * @param X           initially selected column indices
     * @param Y           candidate column indices
     * @param d           target number of features
     * @param D           additional slack added to {@code d}
     * @param rowIni      first row handed to the regression
     * @param rowEnd      last row handed to the regression
     * @param sop         receives coefficients, correlation and RMSE of the final model
     * @return the selected column indices
     * @throws Exception if a regression helper fails
     */
    public int[] sequentialForwardFloatingSelection(String dataFile, int indVariable, String[] features, int[] X, int[] Y, int d, int D, int rowIni, int rowEnd, SoP sop) throws Exception {
        // The selection steps use features.length as the column index given to the regression.
        int indVarColNumber = features.length;
        // Out-parameters filled by the forward/backward steps; see those methods for the layout.
        double[] forwardJ = new double[]{0.0, 0.0, 0.0};
        double[] backwardJ = new double[]{0.0, 0.0, 0.0};
        int k = X.length;
        boolean condSFS = true;
        boolean condSBS = true;
        double corX = 0.0;

        while (k < d + D && condSFS) {
            // NOTE(review): with "> 1" the last remaining candidate is never tried - confirm intent.
            if (Y.length > 1) {
                System.out.println("ForwardSelection k=" + k + " remaining features=" + Y.length);
                int ms = this.sequentialForwardSelection(dataFile, features, indVarColNumber, X, Y, forwardJ, rowIni, rowEnd);
                System.out.format("corXplusy=%.4f  corX=%.4f\n", forwardJ[2], forwardJ[1]);
                corX = forwardJ[2];
                System.out.println("Most significant new feature to add: " + features[ms]);
                // ms is a column index taken from Y; presumably removeIndex drops that value from
                // the array - verify against MathUtils.
                X = MathUtils.addIndex(X, ms);
                Y = MathUtils.removeIndex(Y, ms);
                ++k;

                condSBS = true;
                while (condSBS && k <= d + D && k > 1) {
                    if (X.length > 1) {
                        System.out.println(" BackwardSelection k=" + k);
                        int ls = this.sequentialBackwardSelection(dataFile, features, indVarColNumber, X, backwardJ, rowIni, rowEnd);
                        corX = backwardJ[1];
                        double improvement = Math.abs(backwardJ[0] - backwardJ[1]);
                        System.out.format(" corXminusx=%.4f  corX=%.4f  difference=%.4f : ", backwardJ[0], backwardJ[1], improvement);
                        System.out.println("Least significant feature to remove: " + features[ls]);
                        if (backwardJ[0] > backwardJ[1] || improvement < 1.0E-4) {
                            System.out.println(" better without least significant feature or improvement < 0.0001 : (removing feature)");
                            // The removed feature is not put back into Y.
                            X = MathUtils.removeIndex(X, ls);
                            --k;
                            corX = backwardJ[0];
                            condSBS = true;
                        } else {
                            System.out.println(" better with least significant feature (keeping feature)\n");
                            condSBS = false;
                        }
                    } else {
                        System.out.println("X has one feature, can not execute a SBS step");
                        condSBS = false;
                    }
                }

                System.out.format("k=%d corX=%.4f   ", k, corX);
                SFFS.printSelectedFeatures(X, features);
                System.out.println("-------------------------\n");
            } else {
                System.out.println("No more elements in Y for selection");
                condSFS = false;
            }
        }

        // Fit the final model on the selected columns and publish it through sop.
        Regression reg = new Regression();
        reg.multipleLinearRegression(dataFile, indVariable, X, features, this.interceptTerm, rowIni, rowEnd);
        sop.setCoeffsAndFactors(reg.getCoeffs(), X, features, this.interceptTerm);
        sop.setCorrelation(reg.getCorrelation());
        sop.setRMSE(reg.getRMSE());
        return X;
    }

    /**
     * Forward step: evaluates every candidate in {@code Y} added to {@code X}.
     *
     * @param J out-parameter: {@code J[0]} = correlation with the least significant candidate,
     *          {@code J[1]} = correlation of {@code X} alone (0.0 when {@code X} is empty),
     *          {@code J[2]} = correlation with the most significant candidate
     * @return the column index of the candidate with the largest correlation gain
     */
    private int sequentialForwardSelection(String dataFile, String[] features, int indVarColNumber, int[] X, int[] Y, double[] J, int rowIni, int rowEnd) {
        double[] sig = new double[Y.length];
        int[] sigIndex = new int[Y.length];
        double[] corXplusy = new double[Y.length];

        double corX = 0.0;
        if (X.length > 0) {
            Regression reg = new Regression();
            reg.multipleLinearRegression(dataFile, indVarColNumber, X, features, this.interceptTerm, rowIni, rowEnd);
            corX = reg.getCorrelation();
        }

        // Significance of a candidate = correlation gained by adding it to X.
        for (int i = 0; i < Y.length; ++i) {
            corXplusy[i] = this.correlationOfNewFeature(dataFile, features, indVarColNumber, X, Y[i], rowIni, rowEnd);
            sig[i] = corXplusy[i] - corX;
            sigIndex[i] = Y[i];
        }

        int minSig = MathUtils.getMinIndex(sig);
        J[0] = corXplusy[minSig];
        J[1] = corX;
        int maxSig = MathUtils.getMaxIndex(sig);
        J[2] = corXplusy[maxSig];
        return sigIndex[maxSig];
    }

    /**
     * Backward step: evaluates {@code X} with each of its features left out in turn.
     *
     * @param J out-parameter: {@code J[0]} = correlation without the least significant feature,
     *          {@code J[1]} = correlation of the full {@code X} (0.0 when {@code X} is empty),
     *          {@code J[2]} = correlation without the most significant feature
     * @return the column index of the feature whose removal lowers the correlation the least
     */
    private int sequentialBackwardSelection(String dataFile, String[] features, int indVarColNumber, int[] X, double[] J, int rowIni, int rowEnd) {
        double[] sig = new double[X.length];
        double[] corXminusx = new double[X.length];
        int[] sigIndex = new int[X.length];

        double corX = 0.0;
        if (X.length > 0) {
            Regression reg = new Regression();
            reg.multipleLinearRegression(dataFile, indVarColNumber, X, features, this.interceptTerm, rowIni, rowEnd);
            corX = reg.getCorrelation();
        }

        // Significance of a selected feature = correlation lost by removing it from X.
        for (int i = 0; i < X.length; ++i) {
            corXminusx[i] = this.correlationOfFeature(dataFile, features, indVarColNumber, X, X[i], rowIni, rowEnd);
            sig[i] = corX - corXminusx[i];
            sigIndex[i] = X[i];
        }

        int minSig = MathUtils.getMinIndex(sig);
        J[0] = corXminusx[minSig];
        J[1] = corX;
        int maxSig = MathUtils.getMaxIndex(sig);
        J[2] = corXminusx[maxSig];
        return sigIndex[minSig];
    }

    /**
     * Returns the regression correlation obtained with the columns of {@code X} except {@code x}.
     *
     * <p>The reduced array has room for exactly {@code X.length - 1} entries, so {@code x}
     * must occur in {@code X} exactly once.
     */
    private double correlationOfFeature(String dataFile, String[] features, int indVarColNumber, int[] X, int x, int rowIni, int rowEnd) {
        int[] Xminusx = new int[X.length - 1];
        int j = 0;
        for (int i = 0; i < X.length; ++i) {
            if (X[i] != x) {
                Xminusx[j++] = X[i];
            }
        }
        Regression reg = new Regression();
        reg.multipleLinearRegression(dataFile, indVarColNumber, Xminusx, features, this.interceptTerm, rowIni, rowEnd);
        return reg.getCorrelation();
    }

    /**
     * Returns the regression correlation obtained with the columns of {@code X} plus column {@code y}.
     */
    private double correlationOfNewFeature(String dataFile, String[] features, int indVarColNumber, int[] X, int y, int rowIni, int rowEnd) {
        int[] Xplusf = new int[X.length + 1];
        System.arraycopy(X, 0, Xplusf, 0, X.length);
        Xplusf[X.length] = y;
        Regression reg = new Regression();
        reg.multipleLinearRegression(dataFile, indVarColNumber, Xplusf, features, this.interceptTerm, rowIni, rowEnd);
        return reg.getCorrelation();
    }

    /** Prints the names of the selected columns on one line of {@code System.out}. */
    private static void printSelectedFeatures(int[] X, String[] features) {
        System.out.print("Features: ");
        for (int i = 0; i < X.length; ++i) {
            System.out.print(features[X[i]] + "  ");
        }
        System.out.println();
    }

    /** Prints the names of the selected columns on one line of {@code file}. */
    private static void printSelectedFeatures(int[] X, String[] features, PrintWriter file) {
        file.print("Features: ");
        for (int i = 0; i < X.length; ++i) {
            file.print(features[X[i]] + "  ");
        }
        file.println();
    }

    /**
     * Drops from {@code Y} every column whose mean over the data file is exactly 0.0.
     *
     * @param dataFile path of the matrix file to read
     * @param Y        candidate column indices
     * @param features feature names, indexed by column (used for the log message)
     * @return {@code Y} without the zero-mean columns
     * @throws RuntimeException wrapping any failure while reading or processing the file
     */
    private int[] checkMeanColumns(String dataFile, int[] Y, String[] features) {
        try {
            Matrix data;
            // try-with-resources closes the reader even when Matrix.read fails.
            try (BufferedReader reader = new BufferedReader(new FileReader(dataFile))) {
                data = Matrix.read(reader);
            }
            // After transposing, each row of the matrix holds one column of the file.
            data = data.transpose();
            int rows = data.getRowDimension() - 1;
            int cols = data.getColumnDimension() - 1;
            // Keeps all rows but skips column 0 of the transposed matrix (the file's first row).
            data = data.getMatrix(0, rows, 1, cols);
            int M = data.getRowDimension();
            double[][] values = data.getArray();
            // NOTE(review): features[i] assumes the file has no more columns than features has
            // entries; otherwise the index error surfaces as "Problem reading file" - confirm.
            for (int i = 0; i < M; ++i) {
                double mn = MathUtils.mean(values[i]);
                if (mn == 0.0) {
                    System.out.println("Removing feature: " + features[i] + " from list of features because it has mean=0.0");
                    Y = MathUtils.removeIndex(Y, i);
                }
            }
        }
        catch (Exception e) {
            throw new RuntimeException("Problem reading file " + dataFile, e);
        }
        System.out.println();
        return Y;
    }
}

