package jml.classification;

import java.io.Closeable;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import jml.matlab.Matlab;
import jml.options.Options;
import org.apache.commons.math.linear.BlockRealMatrix;
import org.apache.commons.math.linear.RealMatrix;

/* loaded from: input_file:jml/classification/LogisticRegressionMCLBFGS_Ori.class */
public class LogisticRegressionMCLBFGS_Ori extends Classifier {
    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v1, types: [double[], double[][]] */
    /* JADX WARN: Type inference failed for: r0v3, types: [double[], double[][]] */
    public static void main(String[] strArr) {
        ?? r0 = {new double[]{3.5d, 4.4d, 1.3d}, new double[]{5.3d, 2.2d, 0.5d}, new double[]{0.2d, 0.3d, 4.1d}, new double[]{-1.2d, 0.4d, 3.2d}};
        ?? r02 = {new double[]{1.0d, 0.0d, 0.0d}, new double[]{0.0d, 1.0d, 0.0d}, new double[]{0.0d, 0.0d, 1.0d}};
        BlockRealMatrix blockRealMatrix = new BlockRealMatrix(r02);
        Options options = new Options();
        options.epsilon = 1.0E-5d;
        LogisticRegressionMCLBFGS_Ori logisticRegressionMCLBFGS_Ori = new LogisticRegressionMCLBFGS_Ori(options);
        logisticRegressionMCLBFGS_Ori.feedData((double[][]) r0);
        logisticRegressionMCLBFGS_Ori.feedLabels((double[][]) r02);
        long currentTimeMillis = System.currentTimeMillis();
        logisticRegressionMCLBFGS_Ori.train();
        System.out.println("Projection matrix:");
        Matlab.printMatrix(logisticRegressionMCLBFGS_Ori.getProjectionMatrix());
        System.out.println("Ground truth:");
        Matlab.printMatrix(blockRealMatrix);
        RealMatrix predictLabelScoreMatrix = logisticRegressionMCLBFGS_Ori.predictLabelScoreMatrix((double[][]) r0);
        System.out.println("Predicted probability matrix:");
        Matlab.printMatrix(predictLabelScoreMatrix);
        RealMatrix predictLabelMatrix = logisticRegressionMCLBFGS_Ori.predictLabelMatrix((double[][]) r0);
        System.out.println("Predicted label matrix:");
        Matlab.printMatrix(predictLabelMatrix);
        System.out.format("Elapsed time: %.3f seconds\n", Double.valueOf((System.currentTimeMillis() - currentTimeMillis) / 1000.0d));
    }

    public LogisticRegressionMCLBFGS_Ori(Options options) {
        super(options);
    }

    @Override // jml.classification.Classifier
    public void train() {
        RealMatrix plus;
        RealMatrix sigmoid;
        double d;
        this.W = Matlab.repmat(Matlab.zeros(this.nFeature, 1), new int[]{1, this.nClass});
        RealMatrix sigmoid2 = Matlab.sigmoid(this.X.transpose().multiply(this.W));
        RealMatrix scalarMultiply = this.X.multiply(sigmoid2.subtract(this.Y)).scalarMultiply(1.0d / this.nSample);
        ArrayList arrayList = new ArrayList();
        double d2 = (-Matlab.sum(Matlab.sum(Matlab.times(this.Y, Matlab.log(sigmoid2.scalarAdd(Double.MIN_VALUE))))).getEntry(0, 0)) / this.nSample;
        arrayList.add(Double.valueOf(d2));
        System.out.format("Initial ofv: %g\n", Double.valueOf(d2));
        RealMatrix realMatrix = null;
        RealMatrix realMatrix2 = null;
        LinkedList linkedList = new LinkedList();
        LinkedList linkedList2 = new LinkedList();
        LinkedList linkedList3 = new LinkedList();
        double[] dArr = new double[10];
        int i = 0;
        while (Matlab.norm(scalarMultiply) >= this.epsilon) {
            double innerProduct = i == 0 ? 1.0d : Matlab.innerProduct(realMatrix, realMatrix2) / Matlab.innerProduct(realMatrix2, realMatrix2);
            RealMatrix realMatrix3 = scalarMultiply;
            Iterator descendingIterator = linkedList.descendingIterator();
            Iterator descendingIterator2 = linkedList2.descendingIterator();
            Iterator descendingIterator3 = linkedList3.descendingIterator();
            for (int size = linkedList.size() - 1; size >= 0; size--) {
                RealMatrix realMatrix4 = (RealMatrix) descendingIterator.next();
                RealMatrix realMatrix5 = (RealMatrix) descendingIterator2.next();
                dArr[size] = ((Double) descendingIterator3.next()).doubleValue() * Matlab.innerProduct(realMatrix4, realMatrix3);
                realMatrix3 = realMatrix3.subtract(Matlab.times(dArr[size], realMatrix5));
            }
            RealMatrix times = Matlab.times(innerProduct, realMatrix3);
            Iterator it = linkedList.iterator();
            Iterator it2 = linkedList2.iterator();
            Iterator it3 = linkedList3.iterator();
            for (int i2 = 0; i2 < linkedList.size(); i2++) {
                times = times.add(Matlab.times(dArr[i2] - (((Double) it3.next()).doubleValue() * Matlab.innerProduct((RealMatrix) it2.next(), times)), (RealMatrix) it.next()));
            }
            RealMatrix uminus = Matlab.uminus(times);
            double d3 = 1.0d;
            double innerProduct2 = Matlab.innerProduct(scalarMultiply, uminus);
            while (true) {
                plus = Matlab.plus(this.W, Matlab.times(d3, uminus));
                sigmoid = Matlab.sigmoid(this.X.transpose().multiply(plus));
                d = (-Matlab.sum(Matlab.sum(Matlab.times(this.Y, Matlab.log(Matlab.plus(sigmoid, Matlab.eps))))).getEntry(0, 0)) / this.nSample;
                if (d <= d2 + (0.2d * d3 * innerProduct2)) {
                    break;
                } else {
                    d3 = 0.5d * d3;
                }
            }
            RealMatrix realMatrix6 = this.W;
            RealMatrix realMatrix7 = scalarMultiply;
            if (Math.abs(d - d2) < Matlab.eps) {
                return;
            }
            d2 = d;
            arrayList.add(Double.valueOf(d2));
            System.out.format("Iter %d, ofv: %g\n", Integer.valueOf(i + 1), Double.valueOf(d2));
            this.W = plus;
            scalarMultiply = Matlab.rdivide(this.X.multiply(sigmoid.subtract(this.Y)), this.nSample);
            realMatrix = this.W.subtract(realMatrix6);
            realMatrix2 = Matlab.minus(scalarMultiply, realMatrix7);
            double innerProduct3 = 1.0d / Matlab.innerProduct(realMatrix2, realMatrix);
            if (i >= 10) {
                linkedList.removeFirst();
                linkedList2.removeFirst();
                linkedList3.removeFirst();
            }
            linkedList.add(realMatrix);
            linkedList2.add(realMatrix2);
            linkedList3.add(Double.valueOf(innerProduct3));
            i++;
        }
        System.out.println("Converge.");
    }

    @Override // jml.classification.Classifier
    public RealMatrix predictLabelScoreMatrix(RealMatrix realMatrix) {
        return Matlab.sigmoid(realMatrix.transpose().multiply(this.W));
    }

    @Override // jml.classification.Classifier
    public void loadModel(String str) {
        System.out.println("Loading model...");
        try {
            ObjectInputStream objectInputStream = new ObjectInputStream(new FileInputStream(str));
            this.W = (RealMatrix) objectInputStream.readObject();
            objectInputStream.close();
            System.out.println("Model loaded.");
        } catch (FileNotFoundException e) {
            e.printStackTrace();
            System.exit(1);
        } catch (IOException e2) {
            e2.printStackTrace();
        } catch (ClassNotFoundException e3) {
            e3.printStackTrace();
        }
    }

    @Override // jml.classification.Classifier
    public void saveModel(String str) {
        File parentFile = new File(str).getParentFile();
        if (parentFile != null && !parentFile.exists()) {
            parentFile.mkdirs();
        }
        try {
            ObjectOutputStream objectOutputStream = new ObjectOutputStream(new FileOutputStream(str));
            objectOutputStream.writeObject(this.W);
            objectOutputStream.close();
            System.out.println("Model saved.");
        } catch (FileNotFoundException e) {
            e.printStackTrace();
            System.exit(1);
        } catch (IOException e2) {
            e2.printStackTrace();
        }
    }
}
