package jml.sequence;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.Random;
import java.util.TreeMap;
import jml.matlab.Matlab;
import jml.operation.ArrayOperation;
import jml.optimization.LBFGS;
import org.apache.commons.math.linear.ArrayRealVector;
import org.apache.commons.math.linear.BlockRealMatrix;
import org.apache.commons.math.linear.OpenMapRealMatrix;
import org.apache.commons.math.linear.RealMatrix;

/* loaded from: input_file:jml/sequence/CRF.class */
/**
 * Linear-chain conditional random field whose log-potentials at each position are a
 * weighted sum of feature matrices.
 *
 * <p>Training minimises {@code (sum_n log Z_n - W'F_emp + sigma * ||W||^2) / D} with
 * L-BFGS, where {@code F_emp} are the empirical feature counts. Prediction uses Viterbi
 * decoding. Models are persisted with Java serialization through {@code CRFModel}.
 */
public class CRF {

    /** Default termination tolerance handed to L-BFGS. */
    private static final double DEFAULT_EPSILON = 1.0E-4d;

    /** Floor used for a transition-matrix row whose exponentiated entries all underflow to 0. */
    private static final double ZERO_ROW_FLOOR = 1.0E-10d;

    /**
     * Training features. {@code Fs[n][t][k]} is the k-th {@code numStates x numStates}
     * feature matrix at position t of sequence n. Entry (i, j) is read for a transition
     * from label i to label j.
     */
    RealMatrix[][][] Fs;
    /** Number of training sequences. */
    int D;
    /** Number of features, i.e. the length of {@link #W}. */
    int d;
    /** Number of label states. */
    int numStates;
    /** Label used as the virtual predecessor of position 0. */
    int startIdx = 0;
    /** Training labels; {@code Ys[n][t]} is the label at position t of sequence n. */
    int[][] Ys;
    /** Weight vector ({@code d x 1}). */
    RealMatrix W;
    /** Termination tolerance passed to L-BFGS. */
    double epsilon = DEFAULT_EPSILON;
    /** Coefficient of the {@code ||W||^2} regularisation term. */
    double sigma = 1.0d;
    /** Cap on the number of gradient-evaluating iterations in {@link #train()}. */
    int maxIter = 50;

    /** Demo: trains on random data, saves and reloads the model, and decodes one sequence. */
    public static void main(String[] args) {
        int numSequences = 50;
        Object[] data = generateDataSequences(numSequences, 5, 10, 5, 3, 0.2d);
        RealMatrix[][][] features = (RealMatrix[][][]) data[0];
        int[][] labels = (int[][]) data[1];

        CRF crf = new CRF(1.0E-4d);
        crf.feedData(features);
        crf.feedLabels(labels);
        crf.train();
        crf.saveModel("CRF-Model.dat");
        Matlab.fprintf("CRF Parameters:\n", new Object[0]);
        Matlab.display(crf.W);

        CRF loaded = new CRF();
        loaded.loadModel("CRF-Model.dat");
        int pick = new Random().nextInt(numSequences);
        Matlab.fprintf("True label sequence:\n", new Object[0]);
        Matlab.display(labels[pick]);
        Matlab.fprintf("Predicted label sequence:\n", new Object[0]);
        Matlab.display(loaded.predict(features[pick]));
    }

    /** Creates a CRF with default settings; the feature count is set by {@link #feedData}. */
    public CRF() {
    }

    /**
     * Creates a CRF with a custom L-BFGS tolerance.
     *
     * @param epsilon termination tolerance passed to L-BFGS
     */
    public CRF(double epsilon) {
        this.epsilon = epsilon;
    }

    /**
     * Creates a CRF with a preset feature count.
     *
     * @param d number of features (overwritten by {@link #feedData})
     */
    public CRF(int d) {
        this.d = d;
    }

    /**
     * Creates a CRF with a preset feature count and a custom L-BFGS tolerance.
     *
     * @param d       number of features (overwritten by {@link #feedData})
     * @param epsilon termination tolerance passed to L-BFGS
     */
    public CRF(int d, double epsilon) {
        this.d = d;
        this.epsilon = epsilon;
    }

    /**
     * Stores the training features and derives {@link #D}, {@link #d} and {@link #numStates}
     * from the shape of the first sequence's first position.
     *
     * @param Fs {@code Fs[n][t][k]} feature matrices; must be non-empty
     */
    public void feedData(RealMatrix[][][] Fs) {
        this.Fs = Fs;
        this.D = Fs.length;
        this.d = Fs[0][0].length;
        this.numStates = Fs[0][0][0].getRowDimension();
    }

    /**
     * Stores the training label sequences.
     *
     * @param Ys {@code Ys[n][t]} label at position t of sequence n
     */
    public void feedLabels(int[][] Ys) {
        this.Ys = Ys;
    }

    /**
     * Generates random sparse 0/1 feature matrices and uniformly random labels.
     *
     * @param numSequences number of sequences to generate
     * @param minLength    minimum sequence length (inclusive)
     * @param maxLength    maximum sequence length (inclusive)
     * @param numFeatures  number of feature matrices per position
     * @param numStates    number of label states; every feature matrix is
     *                     {@code numStates x numStates}
     * @param density      probability that any given matrix entry is set to 1
     * @return {@code {RealMatrix[][][] features, int[][] labels}}
     */
    public static Object[] generateDataSequences(int numSequences, int minLength, int maxLength,
            int numFeatures, int numStates, double density) {
        RealMatrix[][][] features = new RealMatrix[numSequences][][];
        int[][] labels = new int[numSequences][];
        Random random = new Random();
        for (int n = 0; n < numSequences; n++) {
            int length = random.nextInt((maxLength - minLength) + 1) + minLength;
            features[n] = new RealMatrix[length][];
            labels[n] = new int[length];
            for (int t = 0; t < length; t++) {
                labels[n][t] = random.nextInt(numStates);
                features[n][t] = new RealMatrix[numFeatures];
                for (int k = 0; k < numFeatures; k++) {
                    RealMatrix f = new OpenMapRealMatrix(numStates, numStates);
                    for (int i = 0; i < numStates; i++) {
                        for (int j = 0; j < numStates; j++) {
                            if (random.nextDouble() < density) {
                                f.setEntry(i, j, 1.0d);
                            }
                        }
                    }
                    features[n][t][k] = f;
                }
            }
        }
        return new Object[] {features, labels};
    }

    /**
     * Fits {@link #W} by L-BFGS, starting from a vector of 10s. Training stops when L-BFGS
     * signals termination or after more than {@link #maxIter} gradient-evaluating iterations.
     */
    public void train() {
        this.W = Matlab.times(10.0d, Matlab.ones(this.d, 1));
        BlockRealMatrix grad = new BlockRealMatrix(this.d, 1);
        double objective = computeObjectiveFunctionValue(true, grad, this.W);
        int iter = 0;
        while (true) {
            // LBFGS.run presumably updates this.W in place (the objective is re-evaluated at
            // this.W below). flags[0] ends training; flags[1] is forwarded as the
            // gradient-request flag, and only those steps count as iterations.
            boolean[] flags = LBFGS.run(grad, objective, this.epsilon, this.W);
            if (flags[0]) {
                return;
            }
            objective = computeObjectiveFunctionValue(flags[1], grad, this.W);
            if (flags[1]) {
                iter++;
                if (iter > this.maxIter) {
                    return;
                }
            }
        }
    }

    /**
     * Evaluates the regularised, averaged negative log-likelihood
     * {@code (sum_n log Z_n - W'F_emp + sigma * ||W||^2) / D} and optionally its gradient
     * {@code (E[F] - F_emp + 2 * sigma * W) / D}.
     *
     * @param calcGrad whether to compute the gradient
     * @param grad     {@code d x 1} matrix overwritten with the gradient when {@code calcGrad}
     * @param weights  {@code d x 1} weight vector at which to evaluate
     * @return the objective value
     */
    public double computeObjectiveFunctionValue(boolean calcGrad, RealMatrix grad, RealMatrix weights) {
        RealMatrix empiricalCounts = computeEmpiricalFeatureCounts();
        double[] expectedCounts = ArrayOperation.allocateVector(this.d);

        // One-hot column vector selecting the virtual start state; it is read-only below.
        RealMatrix start = new BlockRealMatrix(this.numStates, 1);
        start.setEntry(this.startIdx, 0, 1.0d);

        double logZSum = 0.0d;
        for (int n = 0; n < this.D; n++) {
            RealMatrix[][] features = this.Fs[n];
            int length = features.length;
            RealMatrix[] transitions = computeTransitionMatrices(features, weights);

            double[] scales = ArrayOperation.allocateVector(length);
            RealMatrix alphas = forwardRecursion(transitions, start, scales);
            RealMatrix betas = backwardRecursion(transitions, scales);

            // With per-step normalisation, Z = 1 / prod(scales), so log Z = -sum(log scales).
            for (int t = 0; t < length; t++) {
                logZSum -= Math.log(scales[t]);
            }

            if (calcGrad) {
                for (int k = 0; k < this.d; k++) {
                    for (int t = 0; t < length; t++) {
                        RealMatrix prev = (t == 0) ? start : alphas.getColumnMatrix(t - 1);
                        // Expected value of feature k at position t under the pairwise
                        // marginal of (y_{t-1}, y_t): prev' (M_t .* F_t^k) beta_t.
                        expectedCounts[k] += prev.transpose()
                                .multiply(Matlab.times(transitions[t], features[t][k]))
                                .multiply(betas.getColumnMatrix(t))
                                .getEntry(0, 0);
                    }
                }
            }
        }

        double objective = ((logZSum - Matlab.innerProduct(weights, empiricalCounts))
                + (this.sigma * Matlab.innerProduct(weights, weights))) / this.D;
        if (calcGrad) {
            RealMatrix expected = new BlockRealMatrix(this.d, 1);
            expected.setColumn(0, expectedCounts);
            Matlab.setMatrix(grad, Matlab.rdivide(
                    Matlab.plus(Matlab.minus(expected, empiricalCounts),
                            Matlab.times(2.0d * this.sigma, weights)),
                    this.D));
        }
        return objective;
    }

    /** Sums the feature values along every training label path ({@code d x 1}). */
    private RealMatrix computeEmpiricalFeatureCounts() {
        RealMatrix counts = new BlockRealMatrix(this.d, 1);
        for (int k = 0; k < this.d; k++) {
            double sum = 0.0d;
            for (int n = 0; n < this.D; n++) {
                int[] labels = this.Ys[n];
                RealMatrix[][] features = this.Fs[n];
                for (int t = 0; t < features.length; t++) {
                    int prev = (t == 0) ? this.startIdx : labels[t - 1];
                    sum += features[t][k].getEntry(prev, labels[t]);
                }
            }
            counts.setEntry(k, 0, sum);
        }
        return counts;
    }

    /**
     * Scaled forward recursion. Column t of the result is
     * {@code alpha_t = scales[t] * M_t' * alpha_{t-1}} with {@code alpha_{-1} = start},
     * and {@code scales[t]} is chosen so that every column sums to 1.
     */
    private RealMatrix forwardRecursion(RealMatrix[] transitions, RealMatrix start, double[] scales) {
        int length = transitions.length;
        RealMatrix alphas = new BlockRealMatrix(this.numStates, length);
        RealMatrix prev = start;
        for (int t = 0; t < length; t++) {
            RealMatrix alpha = transitions[t].transpose().multiply(prev);
            scales[t] = 1.0d / Matlab.sum(alpha.getColumnVector(0));
            alpha = Matlab.times(scales[t], alpha);
            alphas.setColumnMatrix(t, alpha);
            prev = alpha;
        }
        return alphas;
    }

    /**
     * Backward recursion reusing the forward scales. Column t of the result is
     * {@code beta_t = scales[t] * M_{t+1} * beta_{t+1}}, with {@code beta_{T-1} = scales[T-1] * 1}.
     */
    private RealMatrix backwardRecursion(RealMatrix[] transitions, double[] scales) {
        int length = transitions.length;
        RealMatrix betas = new BlockRealMatrix(this.numStates, length);
        RealMatrix next = null;
        for (int t = length - 1; t >= 0; t--) {
            RealMatrix beta = (t == length - 1)
                    ? Matlab.ones(this.numStates, 1)
                    : Matlab.mtimes(transitions[t + 1], next);
            beta = Matlab.times(scales[t], beta);
            betas.setColumnMatrix(t, beta);
            next = beta;
        }
        return betas;
    }

    /**
     * Predicts a label sequence by Viterbi decoding on the unnormalised linear scores
     * {@code sum_k W_k F_t^k}. The printed value is exp of the best path score, which is an
     * unnormalised potential rather than a probability.
     *
     * @param x {@code x[t][k]} feature matrices of one sequence
     * @return the decoded labels (empty for an empty sequence)
     */
    public int[] predict2(RealMatrix[][] x) {
        if (x.length == 0) {
            return new int[0];
        }
        return viterbiDecode(computeLinearScores(x, this.W));
    }

    /**
     * Predicts a label sequence by Viterbi decoding on the log of the conditional transition
     * probabilities {@code P(y_t = j | y_{t-1} = i, x)}, which are obtained from the
     * transition matrices and the backward recursion. The printed value is the probability
     * of the decoded path.
     *
     * @param x {@code x[t][k]} feature matrices of one sequence
     * @return the decoded labels (empty for an empty sequence)
     */
    public int[] predict(RealMatrix[][] x) {
        int length = x.length;
        if (length == 0) {
            return new int[0];
        }
        RealMatrix[] transitions = computeTransitionMatrix(x);
        RealMatrix betas = backwardRecursion4Viterbi(transitions);
        RealMatrix[] logScores = new RealMatrix[length];
        for (int t = 0; t < length; t++) {
            // Row i of M_t .* beta_t', normalised to sum to 1, is P(y_t | y_{t-1} = i, x).
            RealMatrix conditional = Matlab.times(transitions[t],
                    Matlab.repmat(betas.getColumnMatrix(t).transpose(), this.numStates, 1));
            conditional = Matlab.rdivide(conditional,
                    Matlab.repmat(Matlab.sum(conditional, 2), 1, this.numStates));
            logScores[t] = Matlab.log(conditional);
        }
        return viterbiDecode(logScores);
    }

    /**
     * Max-sum Viterbi over additive scores. Position 0 uses row {@link #startIdx} of
     * {@code scores[0]}; position t adds {@code scores[t](i, j)} to the best score of
     * predecessor i. Prints exp of the best final score and returns the argmax path.
     */
    private int[] viterbiDecode(RealMatrix[] scores) {
        int length = scores.length;
        RealMatrix delta = new BlockRealMatrix(length, this.numStates);
        RealMatrix backPointers = new BlockRealMatrix(length, this.numStates);
        delta.setRowMatrix(0, scores[0].getRowMatrix(this.startIdx));
        backPointers.setRowMatrix(0, Matlab.uminus(Matlab.ones(1, this.numStates)));
        for (int t = 1; t < length; t++) {
            RealMatrix candidates = Matlab.plus(
                    Matlab.repmat(delta.getRowMatrix(t - 1).transpose(), 1, this.numStates),
                    scores[t]);
            // Column-wise max: best predecessor for each current state.
            TreeMap<String, RealMatrix> best = Matlab.max(candidates, 1);
            delta.setRowMatrix(t, best.get("val"));
            backPointers.setRowMatrix(t, best.get("idx"));
        }
        double[] finalScores = delta.getRow(length - 1);
        int[] path = ArrayOperation.allocateIntegerVector(length);
        path[length - 1] = ArrayOperation.argmax(finalScores);
        for (int t = length - 2; t >= 0; t--) {
            path[t] = (int) backPointers.getEntry(t + 1, path[t + 1]);
        }
        Matlab.fprintf("P*(YPred|x) = %g\n", Double.valueOf(Math.exp(finalScores[path[length - 1]])));
        return path;
    }

    /**
     * Builds the transition matrices {@code M_t = exp(sum_k W_k x[t][k])} using the current
     * {@link #W}. A row whose exponentials all underflow to 0 is replaced by a small constant.
     *
     * @param x {@code x[t][k]} feature matrices of one sequence
     * @return one {@code numStates x numStates} matrix per position
     */
    public RealMatrix[] computeTransitionMatrix(RealMatrix[][] x) {
        return computeTransitionMatrices(x, this.W);
    }

    /** Same as {@link #computeTransitionMatrix} but with an explicit weight vector. */
    private RealMatrix[] computeTransitionMatrices(RealMatrix[][] x, RealMatrix weights) {
        RealMatrix[] transitions = computeLinearScores(x, weights);
        for (int t = 0; t < transitions.length; t++) {
            RealMatrix m = Matlab.exp(transitions[t]);
            // An all-zero row is floored to a tiny constant, presumably to keep the
            // normalisers of the recursions finite.
            for (int i = 0; i < this.numStates; i++) {
                if (ArrayOperation.sum(m.getRow(i)) == 0.0d) {
                    m.setRow(i, ArrayOperation.allocateVector(this.numStates, ZERO_ROW_FLOOR));
                }
            }
            transitions[t] = m;
        }
        return transitions;
    }

    /** Returns {@code sum_k weights_k * x[t][k]} for every position t. */
    private RealMatrix[] computeLinearScores(RealMatrix[][] x, RealMatrix weights) {
        RealMatrix[] scores = new RealMatrix[x.length];
        for (int t = 0; t < x.length; t++) {
            for (int k = 0; k < this.d; k++) {
                RealMatrix term = Matlab.times(weights.getEntry(k, 0), x[t][k]);
                scores[t] = (k == 0) ? term : Matlab.plus(scores[t], term);
            }
        }
        return scores;
    }

    /**
     * Sums each feature along a given label path. Position 0 reads row {@link #startIdx}.
     *
     * @param x {@code x[t][k]} feature matrices of one sequence
     * @param y label at each position
     * @return {@code d x 1} vector of feature totals
     */
    public RealMatrix computeFeatureVector(RealMatrix[][] x, int[] y) {
        RealMatrix featureVector = new BlockRealMatrix(this.d, 1);
        for (int k = 0; k < this.d; k++) {
            double sum = 0.0d;
            for (int t = 0; t < x.length; t++) {
                int prev = (t == 0) ? this.startIdx : y[t - 1];
                sum += x[t][k].getEntry(prev, y[t]);
            }
            featureVector.setEntry(k, 0, sum);
        }
        return featureVector;
    }

    /**
     * Self-normalised backward recursion over the given transition matrices. Column t is
     * {@code M_{t+1} * beta_{t+1}} (all ones at the last position), rescaled to sum to 1.
     *
     * @param transitions one transition matrix per position of the sequence
     * @return {@code numStates x transitions.length} matrix of scaled backward vectors
     */
    public RealMatrix backwardRecursion4Viterbi(RealMatrix[] transitions) {
        int length = transitions.length;
        RealMatrix betas = new BlockRealMatrix(this.numStates, length);
        for (int t = length - 1; t >= 0; t--) {
            RealMatrix beta = (t == length - 1)
                    ? Matlab.ones(this.numStates, 1)
                    : Matlab.mtimes(transitions[t + 1], betas.getColumnMatrix(t + 1));
            double scale = 1.0d / Matlab.sum(beta.getColumnVector(0));
            betas.setColumnMatrix(t, Matlab.times(scale, beta));
        }
        return betas;
    }

    /**
     * Serialises {@link #numStates}, {@link #startIdx} and {@link #W} as a {@code CRFModel},
     * creating missing parent directories. Exits the JVM if the file cannot be opened;
     * other I/O errors are printed.
     *
     * @param filePath destination file
     */
    public void saveModel(String filePath) {
        File parentFile = new File(filePath).getParentFile();
        if (parentFile != null && !parentFile.exists()) {
            parentFile.mkdirs();
        }
        try {
            try (ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream(filePath))) {
                out.writeObject(new CRFModel(this.numStates, this.startIdx, this.W));
            }
            // Reported only after the stream has been flushed and closed successfully.
            System.out.println("Model saved.");
        } catch (FileNotFoundException e) {
            e.printStackTrace();
            System.exit(1);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Restores {@link #W}, {@link #d}, {@link #startIdx} and {@link #numStates} from a file
     * written by {@link #saveModel}. Exits the JVM if the file is missing; other errors
     * are printed.
     *
     * <p>NOTE: uses Java native deserialization; only load model files from trusted sources.
     *
     * @param filePath model file
     */
    public void loadModel(String filePath) {
        System.out.println("Loading model...");
        try {
            try (ObjectInputStream in = new ObjectInputStream(new FileInputStream(filePath))) {
                CRFModel model = (CRFModel) in.readObject();
                this.W = model.W;
                this.d = model.d;
                this.startIdx = model.startIdx;
                this.numStates = model.numStates;
            }
            System.out.println("Model loaded.");
        } catch (FileNotFoundException e) {
            e.printStackTrace();
            System.exit(1);
        } catch (IOException e) {
            e.printStackTrace();
        } catch (ClassNotFoundException e) {
            e.printStackTrace();
        }
    }
}
