package jml.regression;

import java.io.Closeable;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import jml.matlab.Matlab;
import jml.options.Options;
import org.apache.commons.math.linear.RealMatrix;

/* loaded from: input_file:jml/regression/LASSO.class */
public class LASSO extends Regression {
    private double lambda;
    private boolean calc_OV;
    private boolean verbose;

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v1, types: [double[], double[][]] */
    /* JADX WARN: Type inference failed for: r0v3, types: [double[], double[][]] */
    public static void main(String[] strArr) {
        ?? r0 = {new double[]{1.0d, 2.0d, 3.0d, 2.0d}, new double[]{4.0d, 2.0d, 3.0d, 6.0d}, new double[]{5.0d, 1.0d, 2.0d, 1.0d}};
        Options options = new Options();
        options.maxIter = 600;
        options.lambda = 0.05d;
        options.verbose = false;
        options.calc_OV = false;
        options.epsilon = 1.0E-5d;
        LASSO lasso = new LASSO(options);
        lasso.feedData((double[][]) r0);
        lasso.feedDependentVariables((double[][]) new double[]{new double[]{3.0d, 2.0d}, new double[]{2.0d, 3.0d}, new double[]{1.0d, 4.0d}});
        long currentTimeMillis = System.currentTimeMillis();
        lasso.train();
        Matlab.fprintf("Elapsed time: %.3f seconds\n\n", Double.valueOf((System.currentTimeMillis() - currentTimeMillis) / 1000.0d));
        Matlab.fprintf("Projection matrix:\n", new Object[0]);
        Matlab.display(lasso.W);
        RealMatrix predict = lasso.predict((double[][]) r0);
        Matlab.fprintf("Predicted dependent variables:\n", new Object[0]);
        Matlab.display(predict);
    }

    public LASSO() {
        this.lambda = 1.0d;
        this.calc_OV = false;
        this.verbose = false;
    }

    public LASSO(double d) {
        super(d);
        this.lambda = 1.0d;
        this.calc_OV = false;
        this.verbose = false;
    }

    public LASSO(int i, double d) {
        super(i, d);
        this.lambda = 1.0d;
        this.calc_OV = false;
        this.verbose = false;
    }

    public LASSO(double d, int i, double d2) {
        super(i, d2);
        this.lambda = d;
        this.calc_OV = false;
        this.verbose = false;
    }

    public LASSO(Options options) {
        super(options);
        this.lambda = options.lambda;
        this.calc_OV = options.calc_OV;
        this.verbose = options.verbose;
    }

    @Override // jml.regression.Regression
    public void train() {
        this.W = train(this.X, this.Y);
    }

    @Override // jml.regression.Regression
    public void loadModel(String str) {
        try {
            ObjectInputStream objectInputStream = new ObjectInputStream(new FileInputStream(str));
            this.W = (RealMatrix) objectInputStream.readObject();
            objectInputStream.close();
            System.out.println("LASSO model loaded.");
        } catch (FileNotFoundException e) {
            e.printStackTrace();
            System.exit(1);
        } catch (IOException e2) {
            e2.printStackTrace();
        } catch (ClassNotFoundException e3) {
            e3.printStackTrace();
        }
    }

    @Override // jml.regression.Regression
    public void saveModel(String str) {
        File parentFile = new File(str).getParentFile();
        if (parentFile != null && !parentFile.exists()) {
            parentFile.mkdirs();
        }
        try {
            ObjectOutputStream objectOutputStream = new ObjectOutputStream(new FileOutputStream(str));
            objectOutputStream.writeObject(this.W);
            objectOutputStream.close();
            System.out.println("LASSO model saved.");
        } catch (FileNotFoundException e) {
            e.printStackTrace();
            System.exit(1);
        } catch (IOException e2) {
            e2.printStackTrace();
        }
    }

    public static RealMatrix train(RealMatrix realMatrix, RealMatrix realMatrix2, Options options) {
        int size = Matlab.size(realMatrix, 2);
        int size2 = Matlab.size(realMatrix2, 2);
        double d = options.epsilon;
        int i = options.maxIter;
        double d2 = options.lambda;
        boolean z = options.calc_OV;
        boolean z2 = options.verbose;
        RealMatrix horzcat = Matlab.horzcat(realMatrix, Matlab.uminus(realMatrix));
        RealMatrix multiply = horzcat.transpose().multiply(horzcat);
        RealMatrix repmat = Matlab.repmat(Matlab.diag(multiply), 1, size2);
        RealMatrix multiply2 = horzcat.transpose().multiply(realMatrix2);
        RealMatrix mldivide = Matlab.mldivide(Matlab.plus(realMatrix.transpose().multiply(realMatrix), Matlab.times(d2, Matlab.eye(size))), realMatrix.transpose().multiply(realMatrix2));
        RealMatrix vertcat = Matlab.vertcat(Matlab.subplus(mldivide), Matlab.subplus(Matlab.uminus(mldivide)));
        RealMatrix plus = Matlab.plus(Matlab.uminus(multiply2), d2);
        RealMatrix plus2 = Matlab.plus(plus, Matlab.mtimes(multiply, vertcat));
        double norm = d * Matlab.norm(plus2);
        Matlab.zeros(Matlab.size(plus2));
        ArrayList arrayList = new ArrayList();
        if (z) {
            arrayList.add(Double.valueOf((Matlab.sum(Matlab.sum(Matlab.power(Matlab.minus(realMatrix2, Matlab.mtimes(realMatrix, mldivide)), 2.0d))).getEntry(0, 0) / 2.0d) + (d2 * Matlab.sum(Matlab.sum(Matlab.abs(mldivide))).getEntry(0, 0))));
        }
        int i2 = 0;
        while (true) {
            RealMatrix not = Matlab.not(Matlab.or(Matlab.lt(plus2, 0.0d), Matlab.gt(vertcat, 0.0d)));
            RealMatrix copy = plus2.copy();
            Matlab.logicalIndexingAssignment(copy, not, 0.0d);
            double norm2 = Matlab.norm(copy);
            if (norm2 >= norm) {
                for (int i3 = 0; i3 < 2 * size; i3++) {
                    vertcat.setRowMatrix(i3, Matlab.max(Matlab.minus(vertcat.getRowMatrix(i3), Matlab.rdivide(Matlab.plus(plus.getRowMatrix(i3), Matlab.mtimes(multiply.getRowMatrix(i3), vertcat)), repmat.getRowMatrix(i3))), 0.0d));
                }
                plus2 = Matlab.plus(plus, Matlab.mtimes(multiply, vertcat));
                i2++;
                if (i2 <= i) {
                    if (z) {
                        arrayList.add(Double.valueOf((Matlab.sum(Matlab.sum(Matlab.power(Matlab.minus(realMatrix2, Matlab.mtimes(horzcat, vertcat)), 2.0d))).getEntry(0, 0) / 2.0d) + (d2 * Matlab.sum(Matlab.sum(Matlab.abs(vertcat))).getEntry(0, 0))));
                    }
                    if (i2 % 10 == 0 && z2) {
                        if (z) {
                            System.out.format("Iter %d - ||PGrad||: %f, ofv: %f\n", Integer.valueOf(i2), Double.valueOf(norm2), arrayList.get(arrayList.size() - 1));
                        } else {
                            System.out.format("Iter %d - ||PGrad||: %f\n", Integer.valueOf(i2), Double.valueOf(norm2));
                        }
                    }
                } else if (z2) {
                    System.out.println("Maximal iterations");
                }
            } else if (z2) {
                System.out.println("Converge successfully!");
            }
        }
        return Matlab.minus(vertcat.getSubMatrix(0, size - 1, 0, size2 - 1), vertcat.getSubMatrix(size, (2 * size) - 1, 0, size2 - 1));
    }

    @Override // jml.regression.Regression
    public RealMatrix train(RealMatrix realMatrix, RealMatrix realMatrix2) {
        RealMatrix horzcat = Matlab.horzcat(realMatrix, Matlab.uminus(realMatrix));
        RealMatrix multiply = horzcat.transpose().multiply(horzcat);
        RealMatrix repmat = Matlab.repmat(Matlab.diag(multiply), 1, this.ny);
        RealMatrix multiply2 = horzcat.transpose().multiply(realMatrix2);
        RealMatrix mldivide = Matlab.mldivide(Matlab.plus(realMatrix.transpose().multiply(realMatrix), Matlab.times(this.lambda, Matlab.eye(this.p))), realMatrix.transpose().multiply(realMatrix2));
        RealMatrix vertcat = Matlab.vertcat(Matlab.subplus(mldivide), Matlab.subplus(Matlab.uminus(mldivide)));
        RealMatrix plus = Matlab.plus(Matlab.uminus(multiply2), this.lambda);
        RealMatrix plus2 = Matlab.plus(plus, Matlab.mtimes(multiply, vertcat));
        double norm = this.epsilon * Matlab.norm(plus2);
        Matlab.zeros(Matlab.size(plus2));
        ArrayList arrayList = new ArrayList();
        if (this.calc_OV) {
            arrayList.add(Double.valueOf((Matlab.sum(Matlab.sum(Matlab.power(Matlab.minus(realMatrix2, Matlab.mtimes(realMatrix, mldivide)), 2.0d))).getEntry(0, 0) / 2.0d) + (this.lambda * Matlab.sum(Matlab.sum(Matlab.abs(mldivide))).getEntry(0, 0))));
        }
        int i = 0;
        while (true) {
            RealMatrix not = Matlab.not(Matlab.or(Matlab.lt(plus2, 0.0d), Matlab.gt(vertcat, 0.0d)));
            RealMatrix copy = plus2.copy();
            Matlab.logicalIndexingAssignment(copy, not, 0.0d);
            double norm2 = Matlab.norm(copy);
            if (norm2 >= norm) {
                for (int i2 = 0; i2 < 2 * this.p; i2++) {
                    vertcat.setRowMatrix(i2, Matlab.max(Matlab.minus(vertcat.getRowMatrix(i2), Matlab.rdivide(Matlab.plus(plus.getRowMatrix(i2), Matlab.mtimes(multiply.getRowMatrix(i2), vertcat)), repmat.getRowMatrix(i2))), 0.0d));
                }
                plus2 = Matlab.plus(plus, Matlab.mtimes(multiply, vertcat));
                i++;
                if (i <= this.maxIter) {
                    if (this.calc_OV) {
                        arrayList.add(Double.valueOf((Matlab.sum(Matlab.sum(Matlab.power(Matlab.minus(realMatrix2, Matlab.mtimes(horzcat, vertcat)), 2.0d))).getEntry(0, 0) / 2.0d) + (this.lambda * Matlab.sum(Matlab.sum(Matlab.abs(vertcat))).getEntry(0, 0))));
                    }
                    if (i % 10 == 0 && this.verbose) {
                        if (this.calc_OV) {
                            System.out.format("Iter %d - ||PGrad||: %f, ofv: %f\n", Integer.valueOf(i), Double.valueOf(norm2), arrayList.get(arrayList.size() - 1));
                        } else {
                            System.out.format("Iter %d - ||PGrad||: %f\n", Integer.valueOf(i), Double.valueOf(norm2));
                        }
                    }
                } else if (this.verbose) {
                    System.out.println("Maximal iterations");
                }
            } else if (this.verbose) {
                System.out.println("Converge successfully!");
            }
        }
        return Matlab.minus(vertcat.getSubMatrix(0, this.p - 1, 0, this.ny - 1), vertcat.getSubMatrix(this.p, (2 * this.p) - 1, 0, this.ny - 1));
    }
}
