package jml.feature.selection;

import jml.matlab.Matlab;
import org.apache.commons.math.linear.RealMatrix;

/* loaded from: input_file:jml/feature/selection/JointL21NormsMinimization.class */
/**
 * Supervised feature selection that learns a projection matrix {@code W} by an
 * iteratively reweighted solve.
 *
 * <p>Each pass of {@link #run()} builds the augmented system
 * {@code A = [X' , gamma * I]}, solves for {@code U} using the current diagonal
 * reweighting matrix {@code D}, refreshes {@code D} from the row-wise L2 norms of
 * {@code U}, and stores the leading rows of {@code U} in {@code W}.
 *
 * <p>NOTE(review): this looks like the joint L2,1-norm minimization scheme of
 * Nie et al. (NIPS 2010); confirm against the library documentation.
 */
public class JointL21NormsMinimization extends SupervisedFeatureSelection {

    /**
     * Index of the last iteration. The loop in {@link #run()} counts from 0 up to
     * and including this value, i.e. it performs {@code MAX_ITERATIONS + 1} passes.
     */
    private static final int MAX_ITERATIONS = 100;

    /**
     * Compile-time switch for the extra per-iteration diagnostics (change in D and
     * objective value). It is {@code false}, so those diagnostics are never printed.
     */
    private static final boolean VERBOSE = false;

    /** Weight applied to the identity block of the augmented matrix; also divides the loss term of the objective. */
    public double gamma;

    /**
     * Creates a selector with the given regularization weight.
     *
     * @param d value stored in {@link #gamma}
     */
    public JointL21NormsMinimization(double d) {
        this.gamma = d;
    }

    /**
     * Runs the fixed number of reweighting iterations and leaves the result in
     * {@code this.W}.
     *
     * <p>Reads {@code this.X} and {@code this.Y}, which must have been supplied
     * beforehand. From the second pass on, it prints the norm of the change in
     * {@code U} between consecutive passes. There is no early exit on convergence.
     */
    @Override
    public void run() {
        // n = columns of X, d = rows of X.
        // NOTE(review): presumably X is (features x samples), so n is the sample
        // count and d the feature count; confirm against feedData's contract.
        int n = this.X.getColumnDimension();
        int d = this.X.getRowDimension();

        // A = [X', gamma * I_n], of size n x (d + n).
        RealMatrix augmented = Matlab.horzcat(this.X.transpose(), Matlab.times(this.gamma, Matlab.eye(n)));

        // D starts as the (d + n) identity and is re-estimated on every pass.
        RealMatrix weights = Matlab.eye(n + d);
        RealMatrix previousU = null;

        for (int iter = 0; iter <= MAX_ITERATIONS; iter++) {
            // D^{-1} A': the reciprocal diagonal of D is replicated across n columns
            // and combined with A' via Matlab.times (presumably element-wise).
            RealMatrix inverseWeights = Matlab.repmat(Matlab.rdivide(1.0d, Matlab.diag(weights)), 1, n);
            RealMatrix dInvAt = Matlab.times(inverseWeights, augmented.transpose());

            // U = D^{-1} A' * ((A D^{-1} A') \ Y)
            RealMatrix u = Matlab.mtimes(dInvAt, Matlab.mldivide(Matlab.mtimes(augmented, dInvAt), this.Y));

            // D_ii = 0.5 / (||u_i||_2 + eps); eps guards against division by zero.
            RealMatrix previousWeights = weights;
            weights = Matlab.diag(Matlab.rdivide(0.5d, Matlab.plus(Matlab.l2NormByRows(u), Matlab.eps)));
            if (VERBOSE) {
                Matlab.fprintf("||D_{k+1} - D_{k}||: %f\n", Double.valueOf(Matlab.norm(Matlab.minus(previousWeights, weights))));
            }

            // W is the first d rows of U, all columns.
            this.W = u.getSubMatrix(Matlab.colon(0, d - 1), Matlab.colon(0, u.getColumnDimension() - 1));
            if (VERBOSE) {
                double loss = Matlab.sum(Matlab.l2NormByRows(this.X.transpose().multiply(this.W).subtract(this.Y))).getEntry(0, 0);
                double penalty = Matlab.sum(Matlab.l2NormByRows(this.W)).getEntry(0, 0);
                Matlab.fprintf("ofv: %f\n", Double.valueOf((loss / this.gamma) + penalty));
            }

            // There is no previous U to compare against on the first pass.
            if (iter > 0) {
                Matlab.fprintf("Iter %d: ||U_{k+1} - U_{k}||: %f\n", Integer.valueOf(iter), Double.valueOf(Matlab.norm(Matlab.minus(previousU, u))));
            }
            previousU = u;
        }
    }

    /**
     * Small demo: feeds a 4 x 3 data matrix and a 3 x 3 label matrix, runs the
     * algorithm with {@code gamma = 2.0}, and prints the elapsed time and the
     * resulting projection matrix.
     *
     * @param strArr ignored
     */
    public static void main(String[] strArr) {
        JointL21NormsMinimization selector = new JointL21NormsMinimization(2.0d);

        double[][] data = new double[][]{
            {3.5d, 4.4d, 1.3d},
            {5.3d, 2.2d, 0.5d},
            {0.2d, 0.3d, 4.1d},
            {-1.2d, 0.4d, 3.2d}
        };
        double[][] labels = new double[][]{
            {1.0d, 0.0d, 0.0d},
            {0.0d, 1.0d, 0.0d},
            {0.0d, 0.0d, 1.0d}
        };
        selector.feedData(data);
        selector.feedLabels(labels);

        long startMillis = System.currentTimeMillis();
        selector.run();
        double elapsedSeconds = (System.currentTimeMillis() - startMillis) / 1000.0d;

        System.out.format("Elapsed time: %.3f seconds\n", Double.valueOf(elapsedSeconds));
        System.out.println("Projection matrix:");
        Matlab.display(selector.getW());
    }
}
