package jml.optimization;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import jml.matlab.Matlab;
import org.apache.commons.math.linear.RealMatrix;

/* loaded from: input_file:jml/optimization/LBFGS.class */
/**
 * Limited-memory BFGS minimizer driven by reverse communication.
 *
 * <p>The caller owns the objective function. It repeatedly calls
 * {@link #run(RealMatrix, double, double, RealMatrix)}, passing the objective
 * value (and, when requested, the gradient) at the current point. Each call
 * either overwrites the point with the next trial point or reports convergence.
 *
 * <p>All solver state is kept in static fields, so only one optimization can
 * be in progress at a time. Internal states:
 * <ul>
 *   <li>0 - initialise from the first point, gradient and objective value;</li>
 *   <li>1 - test the gradient norm, build the search direction, emit the first
 *       trial point;</li>
 *   <li>2 - backtracking line search on the sufficient-decrease condition;</li>
 *   <li>3 - accept the step and update the curvature history;</li>
 *   <li>4 - finished; the next call clears the history and starts over.</li>
 * </ul>
 */
public class LBFGS {
    /** Most recent curvature scalar {@code 1 / <y_k, s_k>}. */
    private static double rou_k;
    /** Gradient at the current accepted point. */
    private static RealMatrix G = null;
    /** Gradient at the previous accepted point. */
    private static RealMatrix G_pre = null;
    /** Current accepted point. */
    private static RealMatrix X = null;
    /** Previous accepted point. */
    private static RealMatrix X_pre = null;
    /** Current search direction. */
    private static RealMatrix p = null;
    /** Objective value at the current accepted point. */
    private static double fval = 0.0d;
    private static boolean gradientRequired = false;
    private static boolean converge = false;
    /** Current state of the solver state machine (see class comment). */
    private static int state = 0;
    /** Current line-search step length. */
    private static double t = 1.0d;
    /** Directional derivative {@code <G, p>} along the search direction. */
    private static double z = 0.0d;
    /** Number of accepted iterations so far. */
    private static int k = 0;
    /** Sufficient-decrease factor used by the line-search test. */
    private static double alpha = 0.2d;
    /** Factor by which the step length is shrunk after a rejected trial. */
    private static double beta = 0.75d;
    /** Maximum number of (s, y) pairs kept in the history. */
    private static int m = 30;
    /** Scalar used as the initial inverse-Hessian approximation. */
    private static double H = 0.0d;
    /** Latest step {@code X - X_pre}. */
    private static RealMatrix s_k = null;
    /** Latest gradient change {@code G - G_pre}. */
    private static RealMatrix y_k = null;
    /** History of steps, oldest first. */
    private static LinkedList<RealMatrix> s_ks = new LinkedList<>();
    /** History of gradient changes, oldest first. */
    private static LinkedList<RealMatrix> y_ks = new LinkedList<>();
    /** History of curvature scalars, oldest first. */
    private static LinkedList<Double> rou_ks = new LinkedList<>();
    /** Objective value recorded after each accepted iteration. */
    private static ArrayList<Double> J = new ArrayList<>();

    /**
     * Advances the solver by one reverse-communication step.
     *
     * @param realMatrix gradient at {@code realMatrix2}; must be non-null on the
     *     first call of an optimization and on the call that follows a result
     *     whose second flag was {@code true}
     * @param d objective value at {@code realMatrix2}
     * @param d2 convergence tolerance on the gradient norm
     * @param realMatrix2 current point; overwritten in place with the next trial
     *     point whenever the solver proposes one
     * @return {@code {converge, gradientRequired}}: the first flag is
     *     {@code true} once the optimization has stopped; the second is
     *     {@code true} when the caller must supply the gradient at
     *     {@code realMatrix2} on the next call
     */
    public static boolean[] run(RealMatrix realMatrix, double d, double d2, RealMatrix realMatrix2) {
        if (state == 4) {
            // A previous optimization finished: drop its history and start over.
            s_ks.clear();
            y_ks.clear();
            rou_ks.clear();
            J.clear();
            state = 0;
        }

        if (state == 0) {
            X = realMatrix2.copy();
            if (realMatrix == null) {
                System.err.println("Gradient is required on the first call!");
                System.exit(1);
            }
            G = realMatrix.copy();
            fval = d;
            if (Double.isNaN(fval)) {
                System.err.println("Object function value is nan!");
                System.exit(1);
            }
            System.out.format("Initial ofv: %g\n", fval);
            k = 0;
            state = 1;
        }

        if (state == 1) {
            double gradNorm = Matlab.norm(G);
            if (gradNorm < d2) {
                converge = true;
                gradientRequired = false;
                state = 4;
                System.out.printf("L-BFGS converges with norm(Grad) %f\n", gradNorm);
                return new boolean[]{converge, gradientRequired};
            }

            // Initial inverse-Hessian scaling: identity on the first iteration,
            // <s, y> / <y, y> from the latest pair afterwards.
            if (k == 0) {
                H = 1.0d;
            } else {
                H = Matlab.innerProduct(s_k, y_k) / Matlab.innerProduct(y_k, y_k);
            }

            // Two-loop recursion, backward pass: walk the history newest to oldest.
            double[] coeffs = new double[m];
            RealMatrix q = G;
            Iterator<RealMatrix> sNewestFirst = s_ks.descendingIterator();
            Iterator<RealMatrix> yNewestFirst = y_ks.descendingIterator();
            Iterator<Double> rhoNewestFirst = rou_ks.descendingIterator();
            for (int idx = s_ks.size() - 1; idx >= 0; idx--) {
                RealMatrix s = sNewestFirst.next();
                RealMatrix y = yNewestFirst.next();
                double rho = rhoNewestFirst.next();
                coeffs[idx] = rho * Matlab.innerProduct(s, q);
                q = q.subtract(Matlab.times(coeffs[idx], y));
            }

            // Forward pass: walk the history oldest to newest.
            RealMatrix r = Matlab.times(H, q);
            Iterator<RealMatrix> sOldestFirst = s_ks.iterator();
            Iterator<RealMatrix> yOldestFirst = y_ks.iterator();
            Iterator<Double> rhoOldestFirst = rou_ks.iterator();
            for (int idx = 0; idx < s_ks.size(); idx++) {
                double rho = rhoOldestFirst.next();
                RealMatrix y = yOldestFirst.next();
                RealMatrix s = sOldestFirst.next();
                double correction = coeffs[idx] - rho * Matlab.innerProduct(y, r);
                r = r.add(Matlab.times(correction, s));
            }

            p = Matlab.uminus(r);
            t = 1.0d;
            z = Matlab.innerProduct(G, p);
            state = 2;
            // Hand the first trial point X + t * p back to the caller.
            Matlab.setMatrix(realMatrix2, Matlab.plus(X, Matlab.times(t, p)));
            converge = false;
            gradientRequired = false;
            return new boolean[]{converge, gradientRequired};
        }

        if (state == 2) {
            converge = false;
            if (d <= fval + alpha * t * z) {
                // Sufficient decrease reached: ask for the gradient at this point.
                gradientRequired = true;
                state = 3;
            } else {
                // Trial rejected: shrink the step and propose a new trial point.
                t = beta * t;
                gradientRequired = false;
                Matlab.setMatrix(realMatrix2, Matlab.plus(X, Matlab.times(t, p)));
            }
            return new boolean[]{converge, gradientRequired};
        }

        if (state == 3) {
            if (Math.abs(d - fval) < 1.0E-32d) {
                converge = true;
                gradientRequired = false;
                // Mark the run as finished so the next call starts a fresh
                // optimization instead of resuming this one with stale X, G
                // and curvature history.
                state = 4;
                System.out.printf("Objective function value doesn't decrease, iteration stopped!\n");
                System.out.format("Iter %d, ofv: %g, norm(Grad): %g\n", k + 1, fval, Matlab.norm(G));
                return new boolean[]{converge, gradientRequired};
            }

            fval = d;
            J.add(fval);
            // G still holds the gradient of the previous accepted point here.
            System.out.format("Iter %d, ofv: %g, norm(Grad): %g\n", k + 1, fval, Matlab.norm(G));

            X_pre = X.copy();
            G_pre = G.copy();
            X = realMatrix2.copy();
            G = realMatrix.copy();
            s_k = X.subtract(X_pre);
            y_k = Matlab.minus(G, G_pre);
            // NOTE(review): nothing here guards against <y_k, s_k> being zero
            // or negative - confirm whether callers can reach that case.
            rou_k = 1.0d / Matlab.innerProduct(y_k, s_k);

            if (k >= m) {
                // History is full: discard the oldest pair.
                s_ks.removeFirst();
                y_ks.removeFirst();
                rou_ks.removeFirst();
            }
            s_ks.add(s_k);
            y_ks.add(y_k);
            rou_ks.add(rou_k);
            k++;
            state = 1;
        }

        converge = false;
        gradientRequired = false;
        return new boolean[]{converge, gradientRequired};
    }
}
