package jml.classification;

import java.io.Closeable;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import jml.matlab.Matlab;
import jml.optimization.LBFGS;
import org.apache.commons.math.linear.BlockRealMatrix;
import org.apache.commons.math.linear.RealMatrix;

/* loaded from: input_file:jml/classification/MaxEnt.class */
public class MaxEnt extends Classifier {
    private RealMatrix[] F;

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v2, types: [double[][], double[][][]] */
    public static void main(String[] strArr) {
        long currentTimeMillis = System.currentTimeMillis();
        ?? r0 = {new double[]{new double[]{1.0d, 0.0d, 0.0d}, new double[]{2.0d, 1.0d, -1.0d}, new double[]{0.0d, 1.0d, 2.0d}, new double[]{-1.0d, 2.0d, 1.0d}}, new double[]{new double[]{0.0d, 2.0d, 0.0d}, new double[]{1.0d, 0.0d, -1.0d}, new double[]{0.0d, 1.0d, 1.0d}, new double[]{-1.0d, 3.0d, 0.5d}}, new double[]{new double[]{0.0d, 0.0d, 0.8d}, new double[]{2.0d, 1.0d, -1.0d}, new double[]{1.0d, 3.0d, 0.0d}, new double[]{-0.5d, -1.0d, 2.0d}}, new double[]{new double[]{0.5d, 0.0d, 0.0d}, new double[]{1.0d, 1.0d, -1.0d}, new double[]{0.0d, 0.5d, 1.5d}, new double[]{-2.0d, 1.5d, 1.0d}}};
        MaxEnt maxEnt = new MaxEnt();
        maxEnt.feedData((double[][][]) r0);
        maxEnt.feedLabels(new int[]{1, 2, 3, 1});
        maxEnt.train();
        System.out.format("Elapsed time: %.3f seconds\n", Double.valueOf((System.currentTimeMillis() - currentTimeMillis) / 1000.0d));
        Matlab.fprintf("MaxEnt parameters:\n", new Object[0]);
        Matlab.display(maxEnt.W);
        maxEnt.saveModel("MaxEnt-Model.dat");
        MaxEnt maxEnt2 = new MaxEnt();
        maxEnt2.loadModel("MaxEnt-Model.dat");
        Matlab.fprintf("Predicted probability matrix:\n", new Object[0]);
        Matlab.display(maxEnt2.predictLabelScoreMatrix((double[][][]) r0));
        Matlab.fprintf("Predicted label matrix:\n", new Object[0]);
        Matlab.display(Matlab.full(maxEnt2.predictLabelMatrix((double[][][]) r0)));
        Matlab.fprintf("Predicted labels:\n", new Object[0]);
        Matlab.display(maxEnt2.predict((double[][][]) r0));
    }

    public void feedData(double[][][] dArr) {
        this.F = new RealMatrix[dArr.length];
        for (int i = 0; i < dArr.length; i++) {
            this.F[i] = new BlockRealMatrix(dArr[i]);
        }
        this.nSample = dArr.length;
        this.nFeature = dArr[0].length;
        this.nClass = dArr[0][0].length;
    }

    public void feedData(RealMatrix[] realMatrixArr) {
        this.F = realMatrixArr;
        this.nSample = realMatrixArr.length;
        this.nFeature = realMatrixArr[0].getRowDimension();
        this.nClass = realMatrixArr[0].getColumnDimension();
    }

    @Override // jml.classification.Classifier
    public void train() {
        RealMatrix realMatrix = null;
        BlockRealMatrix blockRealMatrix = new BlockRealMatrix(this.nSample, this.nClass);
        this.W = Matlab.zeros(this.nFeature, 1);
        for (int i = 0; i < this.nSample; i++) {
            blockRealMatrix.setRowMatrix(i, this.W.transpose().multiply(this.F[i]));
        }
        RealMatrix sigmoid = Matlab.sigmoid(blockRealMatrix);
        int i2 = 0;
        while (i2 < this.nSample) {
            RealMatrix rdivide = Matlab.rdivide(Matlab.minus(Matlab.mtimes(this.F[i2], sigmoid.getRowMatrix(i2).transpose()), this.F[i2].getColumnMatrix(this.labelIDs[i2])), this.nSample);
            realMatrix = i2 == 0 ? rdivide : Matlab.plus(realMatrix, rdivide);
            i2++;
        }
        double d = (-Matlab.sum(Matlab.log(Matlab.logicalIndexing(sigmoid, this.Y)), 1).getEntry(0, 0)) / this.nSample;
        while (true) {
            boolean[] run = LBFGS.run(realMatrix, d, this.epsilon, this.W);
            if (run[0]) {
                return;
            }
            for (int i3 = 0; i3 < this.nSample; i3++) {
                blockRealMatrix.setRowMatrix(i3, this.W.transpose().multiply(this.F[i3]));
            }
            RealMatrix sigmoid2 = Matlab.sigmoid(blockRealMatrix);
            d = (-Matlab.sum(Matlab.log(Matlab.logicalIndexing(sigmoid2, this.Y)), 1).getEntry(0, 0)) / this.nSample;
            if (run[1]) {
                int i4 = 0;
                while (i4 < this.nSample) {
                    RealMatrix rdivide2 = Matlab.rdivide(Matlab.minus(Matlab.mtimes(this.F[i4], sigmoid2.getRowMatrix(i4).transpose()), this.F[i4].getColumnMatrix(this.labelIDs[i4])), this.nSample);
                    realMatrix = i4 == 0 ? rdivide2 : Matlab.plus(realMatrix, rdivide2);
                    i4++;
                }
            }
        }
    }

    public int[] predict(double[][][] dArr) {
        int length = dArr.length;
        RealMatrix[] realMatrixArr = new RealMatrix[length];
        for (int i = 0; i < length; i++) {
            realMatrixArr[i] = new BlockRealMatrix(dArr[i]);
        }
        return predict(realMatrixArr);
    }

    public int[] predict(RealMatrix[] realMatrixArr) {
        int[] labelScoreMatrix2LabelIndexArray = labelScoreMatrix2LabelIndexArray(predictLabelScoreMatrix(realMatrixArr));
        int[] iArr = new int[labelScoreMatrix2LabelIndexArray.length];
        for (int i = 0; i < labelScoreMatrix2LabelIndexArray.length; i++) {
            iArr[i] = this.IDLabelMap[labelScoreMatrix2LabelIndexArray[i]];
        }
        return iArr;
    }

    @Override // jml.classification.Classifier
    public void loadModel(String str) {
        System.out.println("Loading model...");
        try {
            ObjectInputStream objectInputStream = new ObjectInputStream(new FileInputStream(str));
            MaxEntModel maxEntModel = (MaxEntModel) objectInputStream.readObject();
            this.nClass = maxEntModel.nClass;
            this.W = maxEntModel.W;
            this.IDLabelMap = maxEntModel.IDLabelMap;
            this.nFeature = maxEntModel.nFeature;
            objectInputStream.close();
            System.out.println("Model loaded.");
        } catch (FileNotFoundException e) {
            e.printStackTrace();
            System.exit(1);
        } catch (IOException e2) {
            e2.printStackTrace();
        } catch (ClassNotFoundException e3) {
            e3.printStackTrace();
        }
    }

    @Override // jml.classification.Classifier
    public void saveModel(String str) {
        File parentFile = new File(str).getParentFile();
        if (parentFile != null && !parentFile.exists()) {
            parentFile.mkdirs();
        }
        try {
            ObjectOutputStream objectOutputStream = new ObjectOutputStream(new FileOutputStream(str));
            objectOutputStream.writeObject(new MaxEntModel(this.nClass, this.W, this.IDLabelMap));
            objectOutputStream.close();
            System.out.println("Model saved.");
        } catch (FileNotFoundException e) {
            e.printStackTrace();
            System.exit(1);
        } catch (IOException e2) {
            e2.printStackTrace();
        }
    }

    @Override // jml.classification.Classifier
    public RealMatrix predictLabelScoreMatrix(RealMatrix realMatrix) {
        return null;
    }

    public RealMatrix predictLabelScoreMatrix(RealMatrix[] realMatrixArr) {
        int length = realMatrixArr.length;
        BlockRealMatrix blockRealMatrix = new BlockRealMatrix(length, realMatrixArr[0].getColumnDimension());
        for (int i = 0; i < length; i++) {
            blockRealMatrix.setRowMatrix(i, this.W.transpose().multiply(realMatrixArr[i]));
        }
        return Matlab.sigmoid(blockRealMatrix);
    }

    public RealMatrix predictLabelScoreMatrix(double[][][] dArr) {
        int length = dArr.length;
        RealMatrix[] realMatrixArr = new RealMatrix[length];
        for (int i = 0; i < length; i++) {
            realMatrixArr[i] = new BlockRealMatrix(dArr[i]);
        }
        return predictLabelScoreMatrix(realMatrixArr);
    }

    public RealMatrix predictLabelMatrix(RealMatrix[] realMatrixArr) {
        return labelIndexArray2LabelMatrix(labelScoreMatrix2LabelIndexArray(predictLabelScoreMatrix(realMatrixArr)), this.nClass);
    }

    public RealMatrix predictLabelMatrix(double[][][] dArr) {
        int length = dArr.length;
        RealMatrix[] realMatrixArr = new RealMatrix[length];
        for (int i = 0; i < length; i++) {
            realMatrixArr[i] = new BlockRealMatrix(dArr[i]);
        }
        return predictLabelMatrix(realMatrixArr);
    }
}
