package jml.optimization;

import java.io.File;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import jml.data.Data;
import jml.matlab.Matlab;
import org.apache.commons.math.linear.RealMatrix;

/* loaded from: input_file:jml/optimization/LBFGSOnSimplex.class */
/**
 * Projected limited-memory BFGS (L-BFGS) with Armijo backtracking for minimizing a smooth
 * function over the probability simplex {@code {x : x >= 0, sum(x) = 1}}.
 *
 * <p>The solver uses reverse communication: the caller owns the objective, evaluates it (and its
 * gradient when asked) at the current iterate, and passes the results back to {@link #run}, which
 * updates the iterate in place. {@link #main} shows a complete driver loop.
 *
 * <p>At each iteration the largest coordinate {@code x_i} is eliminated through
 * {@code x_i = 1 - sum(z)}, where {@code z} holds the remaining coordinates. This turns the simplex
 * into {@code {z : z >= 0, sum(z) <= 1}}: {@code z >= 0} is enforced by clipping the trial point and
 * {@code sum(z) <= 1} by shrinking the step.
 *
 * <p>All solver state lives in static fields, so only one optimization can be in progress at a
 * time and the class must not be used from several threads concurrently.
 */
public class LBFGSOnSimplex {
    /** {@code 1 / (y_k' s_k)} for the most recently stored curvature pair. */
    private static double rou_k;
    /** Gradient at the current accepted iterate {@code X}. */
    private static RealMatrix G = null;
    /** Gradient at the previous accepted iterate {@code X_pre}. */
    private static RealMatrix G_pre = null;
    /** Current accepted iterate. */
    private static RealMatrix X = null;
    /** Previous accepted iterate. */
    private static RealMatrix X_pre = null;
    /** Objective value at {@code X}. */
    private static double fval = 0.0d;
    private static boolean gradientRequired = false;
    private static boolean converge = false;
    /**
     * Reverse-communication state: 0 = initialize, 1 = compute a search direction, 2 = Armijo line
     * search, 3 = accept the step and update the L-BFGS memory, 4 = converged (reset on next call).
     */
    private static int state = 0;
    /** Current line-search step size. */
    private static double t = 1.0d;
    /** Number of accepted iterations so far. */
    private static int k = 0;
    /** Armijo sufficient-decrease constant. */
    private static double alpha = 0.2d;
    /** Factor by which the step size is shrunk while backtracking. */
    private static double beta = 0.75d;
    /** Maximum number of curvature pairs kept in memory. */
    private static int m = 30;
    /** Scalar initial inverse-Hessian approximation used by the two-loop recursion. */
    private static double H = 0.0d;
    /** Most recently stored step {@code X - X_pre}. */
    private static RealMatrix s_k = null;
    /** Most recently stored gradient change {@code G - G_pre}. */
    private static RealMatrix y_k = null;
    /** L-BFGS memory, oldest pair first. */
    private static LinkedList<RealMatrix> s_ks = new LinkedList<>();
    private static LinkedList<RealMatrix> y_ks = new LinkedList<>();
    private static LinkedList<Double> rou_ks = new LinkedList<>();
    /** Reduced variables: every coordinate of {@code X} except {@code x_i}. */
    private static RealMatrix z = null;
    /** Line-search trial point in reduced coordinates. */
    private static RealMatrix z_t = null;
    /** Search direction in reduced coordinates. */
    private static RealMatrix p_z = null;
    /** Logical mask selecting the reduced coordinates (all but index {@code i}). */
    private static RealMatrix I_z = null;
    /** Gradient with respect to {@code z}. */
    private static RealMatrix G_z = null;
    /** {@code G_z} with the components blocked by the bound {@code z >= 0} zeroed. */
    private static RealMatrix PG_z = null;
    /** 1-based index of the eliminated (largest) coordinate of {@code X}. */
    private static int i = -1;
    /** Absolute convergence threshold on {@code norm(PG_z)}. */
    private static double tol = 1.0d;
    /** Objective value recorded at every accepted iterate. */
    private static ArrayList<Double> J = new ArrayList<>();

    /**
     * Demo driver: minimizes {@code ||C x - y|| + 0.01 ||x||} over the simplex, with {@code C} and
     * {@code y} loaded from {@code C.txt} and {@code y.txt}.
     *
     * @param strArr optional; {@code strArr[0]} overrides the directory containing the data files
     *     (defaults to the original author's directory)
     */
    public static void main(String[] strArr) {
        String dataDir = strArr.length > 0 ? strArr[0] : "C:/Aaron/My Codes/Matlab/Convex Optimization";
        RealMatrix cMatrix = Data.loadMatrix(dataDir + File.separator + "C.txt");
        RealMatrix yVector = Data.loadMatrix(dataDir + File.separator + "y.txt");
        double lambda = 0.01d;
        double epsilon = 1.0E-6d;
        int maxIter = 1000;

        long startTime = System.currentTimeMillis();
        // Start from the barycenter of the simplex, sized to match C.
        int n = cMatrix.getColumnDimension();
        RealMatrix x = Matlab.rdivide(Matlab.ones(n, 1), n).copy();
        RealMatrix residual = cMatrix.multiply(x).subtract(yVector);
        double residualNorm = Matlab.norm(residual);
        double xNorm = Matlab.norm(x);
        double objective = residualNorm + (lambda * xNorm);
        RealMatrix grad = Matlab.plus(
                Matlab.rdivide(cMatrix.transpose().multiply(residual), residualNorm),
                Matlab.times(lambda, Matlab.rdivide(x, xNorm)));

        int iter = 0;
        while (true) {
            boolean[] flags = run(grad, objective, epsilon, x);
            if (flags[0]) {
                break;
            }
            // run() may have moved x in place; the objective is needed at every call.
            residual = cMatrix.multiply(x).subtract(yVector);
            residualNorm = Matlab.norm(residual);
            xNorm = Matlab.norm(x);
            objective = residualNorm + (lambda * xNorm);
            if (flags[1]) {
                // A step was accepted: supply the gradient at the new point.
                iter++;
                if (iter > maxIter) {
                    break;
                }
                grad = Matlab.plus(
                        Matlab.rdivide(cMatrix.transpose().multiply(residual), residualNorm),
                        Matlab.times(lambda, Matlab.rdivide(x, xNorm)));
            }
        }
        Matlab.fprintf("fval_projected_LBFGS_Armijo: %g\n\n", Double.valueOf(objective));
        Matlab.fprintf("x_projected_LBFGS_Armijo:\n", new Object[0]);
        Matlab.display(x.transpose());
        Matlab.fprintf("Elapsed time: %.3f seconds\n", Double.valueOf((System.currentTimeMillis() - startTime) / 1000.0d));
    }

    /**
     * Advances the solver by one reverse-communication step.
     *
     * <p>On the first call pass the starting point, its objective value and its gradient. After
     * every call where {@code converge} is false, re-evaluate the objective at {@code realMatrix2}
     * (which this method may have modified in place). When {@code gradientRequired} is true, also
     * re-evaluate the gradient. Then call again.
     *
     * <p>After an accepted step, one call returns {@code {false, false}} without moving the
     * iterate; the following call computes the next direction. After convergence, the next call
     * starts a fresh optimization.
     *
     * @param realMatrix gradient of the objective at {@code realMatrix2}. Required on the first
     *     call and whenever the previous call returned {@code gradientRequired == true}; otherwise
     *     ignored.
     * @param d objective value at {@code realMatrix2}
     * @param d2 relative tolerance, read on the first call only. Convergence is declared once the
     *     projected-gradient norm falls below {@code d2 * norm(initial gradient)}.
     * @param realMatrix2 current iterate, a column vector on the simplex; overwritten in place with
     *     the next trial point
     * @return {@code {converge, gradientRequired}}
     * @throws IllegalArgumentException on the first call, if the gradient is null or the objective
     *     value is NaN
     * @throws IllegalStateException if the projected gradient becomes NaN or infinite
     */
    public static boolean[] run(RealMatrix realMatrix, double d, double d2, RealMatrix realMatrix2) {
        if (state == 4) {
            // The previous optimization converged: forget its memory and start over.
            s_ks.clear();
            y_ks.clear();
            rou_ks.clear();
            J.clear();
            state = 0;
        }
        if (state == 0) {
            initialize(realMatrix, d, d2, realMatrix2);
        }
        if (state == 2) {
            return lineSearchStep(d, realMatrix2);
        }
        if (state == 3) {
            return acceptStep(realMatrix, d, realMatrix2);
        }
        return computeStep(realMatrix2);
    }

    /** Validates the first call's inputs, stores the starting point, and enters state 1. */
    private static void initialize(RealMatrix grad, double objective, double relTol, RealMatrix x) {
        if (grad == null) {
            throw new IllegalArgumentException("Gradient is required on the first call!");
        }
        if (Double.isNaN(objective)) {
            throw new IllegalArgumentException("Object function value is nan!");
        }
        X = x.copy();
        G = grad.copy();
        fval = objective;
        System.out.format("Initial ofv: %g\n", Double.valueOf(fval));
        tol = relTol * Matlab.norm(G);
        k = 0;
        state = 1;
    }

    /**
     * Armijo backtracking. Accepts {@code z_t} if it gives sufficient decrease (and then asks for
     * the gradient there); otherwise shrinks the step and writes the new trial point into
     * {@code x}.
     */
    private static boolean[] lineSearchStep(double trialValue, RealMatrix x) {
        converge = false;
        if (trialValue <= fval + (alpha * Matlab.innerProduct(G_z, Matlab.minus(z_t, z)))) {
            gradientRequired = true;
            state = 3;
        } else {
            t = beta * t;
            z_t = Matlab.subplus(Matlab.plus(z, Matlab.times(t, p_z)));
            writeBack(x);
            gradientRequired = false;
        }
        return new boolean[]{converge, gradientRequired};
    }

    /** Accepts the line-search point, records its objective, and updates the L-BFGS memory. */
    private static boolean[] acceptStep(RealMatrix grad, double trialValue, RealMatrix x) {
        X_pre = X.copy();
        G_pre = G.copy();
        fval = trialValue;
        J.add(Double.valueOf(fval));
        System.out.format("Iter %d, ofv: %g, norm(PGrad_z): %g\n", Integer.valueOf(k + 1), Double.valueOf(fval), Double.valueOf(Matlab.norm(PG_z)));
        X = x.copy();
        G = grad.copy();
        RealMatrix step = X.subtract(X_pre);
        RealMatrix gradChange = Matlab.minus(G, G_pre);
        double curvature = Matlab.innerProduct(gradChange, step);
        // Store the pair only if it satisfies the curvature condition y's > 0. Otherwise 1/(y's)
        // is infinite or negative, and the two-loop recursion yields NaN or non-descent directions.
        if (curvature > 0.0d && !Double.isInfinite(curvature)) {
            s_k = step;
            y_k = gradChange;
            rou_k = 1.0d / curvature;
            if (s_ks.size() >= m) {
                s_ks.removeFirst();
                y_ks.removeFirst();
                rou_ks.removeFirst();
            }
            s_ks.add(s_k);
            y_ks.add(y_k);
            rou_ks.add(Double.valueOf(rou_k));
        }
        k++;
        state = 1;
        converge = false;
        gradientRequired = false;
        return new boolean[]{converge, gradientRequired};
    }

    /**
     * Builds the reduced problem at {@code X} and tests convergence. If not converged, picks a
     * search direction and writes the first feasible trial point into {@code x}.
     */
    private static boolean[] computeStep(RealMatrix x) {
        // Eliminate the largest coordinate x_i (1-based) so that x = dxdz * z + e_i.
        i = ((int) Matlab.max(X, 1).get("idx").getEntry(0, 0)) + 1;
        int n = Matlab.size(X, 1);
        RealMatrix dxdz = Matlab.vertcat(
                Matlab.horzcat(Matlab.eye(i - 1), Matlab.zeros(i - 1, n - i)),
                Matlab.horzcat(Matlab.uminus(Matlab.ones(1, i - 1)), Matlab.uminus(Matlab.ones(1, n - i))),
                Matlab.horzcat(Matlab.zeros(n - i, i - 1), Matlab.eye(n - i)));
        I_z = Matlab.not(Matlab.vertcat(Matlab.zeros(i - 1, 1), Matlab.ones(1, 1), Matlab.zeros(n - i, 1)));
        z = Matlab.logicalIndexing(X, I_z);
        // Chain rule: gradient with respect to the reduced variables.
        G_z = dxdz.transpose().multiply(G);
        // Zero the components that the bound z >= 0 blocks (z_j <= 0 and G_z_j >= 0).
        RealMatrix blocked = Matlab.not(Matlab.or(Matlab.lt(G_z, 0.0d), Matlab.gt(z, 0.0d)));
        PG_z = G_z.copy();
        Matlab.logicalIndexingAssignment(PG_z, blocked, 0.0d);
        double norm = Matlab.norm(PG_z);
        if (Double.isNaN(norm) || Double.isInfinite(norm)) {
            // A non-finite gradient would make the feasibility loop below spin forever.
            throw new IllegalStateException("Projected gradient is not finite!");
        }
        if (norm < tol) {
            converge = true;
            gradientRequired = false;
            state = 4;
            System.out.printf("PLBFGS on simplex converges with norm(PGrad_z) %f\n", Double.valueOf(norm));
            return new boolean[]{converge, gradientRequired};
        }

        // Quasi-Newton direction mapped into z-space, with the same bound projection applied.
        RealMatrix qn = dxdz.transpose().multiply(twoLoopRecursion());
        RealMatrix qnBlocked = Matlab.not(Matlab.or(Matlab.lt(qn, 0.0d), Matlab.gt(z, 0.0d)));
        RealMatrix projectedQn = qn.copy();
        Matlab.logicalIndexingAssignment(projectedQn, qnBlocked, 0.0d);
        double slope = Matlab.innerProduct(projectedQn, G_z);
        // Use the quasi-Newton direction only if it is a finite descent direction; NaN or Inf
        // slopes must also take the projected-gradient fallback.
        if (slope > 0.0d && !Double.isInfinite(slope)) {
            p_z = Matlab.uminus(projectedQn);
        } else {
            p_z = Matlab.uminus(PG_z);
        }

        // Shrink the step until the clipped trial point also satisfies sum(z) <= 1.
        t = 1.0d;
        while (true) {
            z_t = Matlab.subplus(Matlab.plus(z, Matlab.times(t, p_z)));
            if (Matlab.sumAll(z_t) <= 1.0d) {
                state = 2;
                writeBack(x);
                converge = false;
                gradientRequired = false;
                return new boolean[]{converge, gradientRequired};
            }
            t = beta * t;
        }
    }

    /** Applies the L-BFGS two-loop recursion to {@code G}, approximating the inverse Hessian times G. */
    private static RealMatrix twoLoopRecursion() {
        if (s_ks.isEmpty()) {
            H = 1.0d;
        } else {
            RealMatrix sLast = s_ks.getLast();
            RealMatrix yLast = y_ks.getLast();
            H = Matlab.innerProduct(sLast, yLast) / Matlab.innerProduct(yLast, yLast);
        }
        int size = s_ks.size();
        double[] coeffs = new double[size];
        RealMatrix q = G;
        Iterator<RealMatrix> sDesc = s_ks.descendingIterator();
        Iterator<RealMatrix> yDesc = y_ks.descendingIterator();
        Iterator<Double> rhoDesc = rou_ks.descendingIterator();
        // First loop: newest pair to oldest.
        for (int j = size - 1; j >= 0; j--) {
            RealMatrix s = sDesc.next();
            RealMatrix y = yDesc.next();
            coeffs[j] = rhoDesc.next().doubleValue() * Matlab.innerProduct(s, q);
            q = q.subtract(Matlab.times(coeffs[j], y));
        }
        RealMatrix r = Matlab.times(H, q);
        Iterator<RealMatrix> sAsc = s_ks.iterator();
        Iterator<RealMatrix> yAsc = y_ks.iterator();
        Iterator<Double> rhoAsc = rou_ks.iterator();
        // Second loop: oldest pair to newest.
        for (int j = 0; j < size; j++) {
            RealMatrix s = sAsc.next();
            double b = rhoAsc.next().doubleValue() * Matlab.innerProduct(yAsc.next(), r);
            r = r.add(Matlab.times(coeffs[j] - b, s));
        }
        return r;
    }

    /** Writes the reduced trial point {@code z_t} into {@code x}, restoring sum(x) = 1 via x_i. */
    private static void writeBack(RealMatrix x) {
        Matlab.logicalIndexingAssignment(x, I_z, z_t);
        x.setEntry(i - 1, 0, 1.0d - Matlab.sumAll(z_t));
    }
}
