package org.metaqtl.algo;

import java.util.Random;
import org.metaqtl.EMResult;
import org.metaqtl.numrec.NumericalUtilities;

/* loaded from: input_file:org/metaqtl/algo/EMAlgorithm.class */
/**
 * Expectation-Maximization (EM) fitting of a Gaussian mixture to observations that each
 * carry their own known standard deviation.
 *
 * <p>Only the component means ({@code mu}) and mixing proportions ({@code pi}) are estimated.
 * Observation {@code j} contributes to component {@code c} with weight
 * {@code pi[c] / sd[j] * gauss((x[j] - mu[c]) / sd[j])}. Here {@code gauss} is presumably the
 * standard normal density (see {@code NumericalUtilities.gauss}). When {@link #DO_SEM} is set,
 * an extra replay pass (apparently the Supplemented EM scheme) estimates a rate matrix
 * {@code dm} for the EM update of the means. {@link #computeCOV} then uses {@code dm} to
 * adjust the covariance of the estimated means.
 *
 * <p>The tuning parameters below are public mutable statics read on every call. Changing them
 * affects all callers.
 */
public final class EMAlgorithm {
    /** Number of random starts used by {@link #doEM} when no start is supplied and 1 &lt; k &lt; n. */
    public static int EM_START = 50;
    /** Iteration cap: a single EM run makes at most {@code EM_ITER_MAX + 1} calls to {@link #iterate}. */
    public static int EM_ITER_MAX = 1000;
    /** Convergence tolerance on the per-step increase of the observed log-likelihood. */
    public static double EM_ERR = 1.0E-8d;
    /**
     * In multi-start mode, a fit is rejected if any value returned by
     * {@code EMResult.getEuclidean()} is below this (presumably distances between component
     * means; TODO confirm).
     */
    public static double EM_MIN_DISTANCE = 1.0E-25d;
    /** Whether {@link #doEM} runs the SEM replay and {@link #computeCOV} applies its correction. */
    public static boolean DO_SEM = true;
    /** Threshold below which responsibilities and their normalizers are treated as zero. */
    public static final double TINY_Z_PROBA = Double.MIN_VALUE;
    /** Step status: converged (log-likelihood gain below tolerance, or nothing to iterate). */
    public static final int EM_OK = 0;
    /** Step status: log-likelihood improved by at least the tolerance; keep iterating. */
    public static final int EM_CONTINUE = 1;
    /** Step status: log-likelihood decreased; the M-step was rolled back. */
    public static final int EM_FAILURE = 2;

    /**
     * Fits a {@code k}-component mixture to {@code x}.
     *
     * <p>When {@code init} is {@code null} and {@code 1 < k < x.length}, EM is restarted
     * {@link #EM_START} times from random hard assignments. The run with the highest observed
     * log-likelihood whose components pass the {@link #EM_MIN_DISTANCE} check is kept.
     * Otherwise a single run is made, starting from {@code init} when given.
     *
     * @param x    observed values
     * @param sd   per-observation standard deviations (expected to match {@code x} in length)
     * @param k    number of mixture components; must be at least 1
     * @param init optional starting model, or {@code null} for random initialisation
     * @return the selected fit with its covariance filled in, or {@code null} if the arguments
     *         are invalid or no acceptable fit was found
     */
    public static EMResult doEM(double[] x, double[] sd, int k, EMResult init) {
        if (k < 1 || x == null || sd == null) {
            return null;
        }
        final int n = x.length;
        // Random restarts are only used when no start is given and the model is neither
        // trivial (k == 1) nor saturated (k >= n).
        final boolean multiStart = init == null && k > 1 && k < n;
        final int nStarts = (multiStart && EM_START > 0) ? EM_START : 1;

        EMResult[] fits = new EMResult[nStarts];
        EMResult[] starts = new EMResult[nStarts];
        Random random = new Random();
        for (int s = 0; s < nStarts; s++) {
            EMResult fit = new EMResult(n, k);
            if (init == null) {
                initEM(x, sd, fit, random);
            } else {
                EMResult.copy(fit, init);
            }
            // Remember where this run started so the SEM pass can replay it. This is also
            // needed for a caller-supplied start; otherwise SEM would replay from an empty model.
            starts[s] = new EMResult(n, k);
            EMResult.copy(starts[s], fit);
            runToConvergence(x, sd, fit);
            fits[s] = fit;
        }

        double bestLog = Double.NEGATIVE_INFINITY;
        EMResult best = null;
        EMResult bestStart = null;
        for (int s = 0; s < nStarts; s++) {
            EMResult fit = fits[s];
            if (fit.olog > bestLog && (!multiStart || isWellSeparated(fit))) {
                bestLog = fit.olog;
                best = fit;
                bestStart = starts[s];
            }
        }
        if (best == null) {
            System.err.println("Unable to find valid mixture model for K = " + k);
            return null;
        }

        if (DO_SEM) {
            runSem(x, sd, bestStart);
            best.dm = bestStart.dm;
        }
        computeCOV(sd, best);
        return best;
    }

    /**
     * Runs EM steps on {@code fit} until a step reports something other than
     * {@link #EM_CONTINUE}, or more than {@link #EM_ITER_MAX} steps have been made.
     */
    private static void runToConvergence(double[] x, double[] sd, EMResult fit) {
        int iter = 0;
        int status;
        do {
            iter++;
            status = iterate(x, sd, fit);
            if (iter > EM_ITER_MAX) {
                break;
            }
        } while (status == EM_CONTINUE);
    }

    /**
     * Returns whether every value from {@code fit.getEuclidean()} is at least
     * {@link #EM_MIN_DISTANCE}. A NaN value counts as too close.
     */
    private static boolean isWellSeparated(EMResult fit) {
        for (double dist : fit.getEuclidean()) {
            if (!(dist >= EM_MIN_DISTANCE)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Replays EM from {@code replay}'s starting state. After each step it calls
     * {@link #computeDM}, which refreshes {@code replay.dm} from the previous and current means.
     */
    private static void runSem(double[] x, double[] sd, EMResult replay) {
        double[] previousMu = new double[replay.k];
        int iter = 0;
        int status;
        do {
            iter++;
            System.arraycopy(replay.mu, 0, previousMu, 0, replay.k);
            status = iterate(x, sd, replay);
            if (iter > EM_ITER_MAX) {
                break;
            }
            computeDM(x, sd, previousMu, replay);
        } while (status == EM_CONTINUE);
    }

    /**
     * Initialises {@code fit} and computes its log-likelihood.
     *
     * <ul>
     *   <li>{@code k == 1}: every observation is assigned to the single component, followed by
     *       one M-step.</li>
     *   <li>{@code k == n}: component {@code c} is centred on observation {@code c} with
     *       {@code pi = 1/k}.</li>
     *   <li>otherwise: the first {@code k} observations are each given a distinct random
     *       component, so no component starts empty. The remaining observations get uniformly
     *       random components, followed by one M-step.</li>
     * </ul>
     *
     * @param x      observed values
     * @param sd     per-observation standard deviations
     * @param fit    model to initialise in place
     * @param random source of the random assignments
     */
    public static void initEM(double[] x, double[] sd, EMResult fit, Random random) {
        if (fit.k == 1) {
            for (int j = 0; j < fit.n; j++) {
                fit.z[0][j] = 1.0d;
            }
            fit.pi[0] = 1.0d;
            mStep(x, sd, fit, EM_ERR);
        } else if (fit.k == fit.n) {
            for (int c = 0; c < fit.k; c++) {
                fit.mu[c] = x[c];
                fit.pi[c] = 1.0d / fit.k;
                for (int j = 0; j < fit.k; j++) {
                    fit.z[c][j] = (j == c) ? 1.0d : 0.0d;
                }
            }
            fit.sortCluster();
        } else {
            // Draw a random permutation of the components for the first k observations.
            int remaining = fit.k;
            int[] pool = new int[remaining];
            for (int c = 0; c < fit.k; c++) {
                pool[c] = c;
            }
            for (int j = 0; j < fit.k; j++) {
                int pick = random.nextInt(remaining);
                assignTo(fit, j, pool[pick]);
                for (int m = pick; m < remaining - 1; m++) {
                    pool[m] = pool[m + 1];
                }
                remaining--;
            }
            for (int j = fit.k; j < fit.n; j++) {
                assignTo(fit, j, random.nextInt(fit.k));
            }
            mStep(x, sd, fit, EM_ERR);
        }
        updateLoglikelihood(x, sd, fit);
    }

    /** Sets observation {@code j}'s responsibilities to a hard assignment to component {@code c}. */
    private static void assignTo(EMResult fit, int j, int c) {
        for (int m = 0; m < fit.k; m++) {
            fit.z[m][j] = (m == c) ? 1.0d : 0.0d;
        }
    }

    /**
     * Performs one EM step (E-step then M-step) on {@code fit}.
     *
     * @return {@link #EM_OK} immediately, without changing {@code fit}, when {@code k <= 1} or
     *         {@code k >= n}; otherwise the status returned by {@link #mStep}
     */
    public static int iterate(double[] x, double[] sd, EMResult fit) {
        if (fit.k <= 1 || fit.k >= fit.n) {
            return EM_OK;
        }
        eStep(x, sd, fit);
        return mStep(x, sd, fit, EM_ERR);
    }

    /** E-step: recomputes the responsibility matrix {@code fit.z} from the current mu/pi. */
    public static void eStep(double[] x, double[] sd, EMResult fit) {
        updateZMatrix(x, sd, fit);
    }

    /**
     * M-step: re-estimates mu and pi from {@code fit.z}, re-sorts the components and
     * recomputes the log-likelihood.
     *
     * <p>If the observed log-likelihood decreased, mu and pi are restored to their previous
     * arrays and the log-likelihood is recomputed. {@code z} is not restored here.
     *
     * @param tol convergence tolerance on the log-likelihood gain
     * @return {@link #EM_OK} if {@code 0 <= gain < tol}, {@link #EM_CONTINUE} if
     *         {@code gain >= tol}, otherwise {@link #EM_FAILURE} (including a NaN gain)
     */
    public static int mStep(double[] x, double[] sd, EMResult fit, double tol) {
        // Both updates read the current state (z, and mu as a fallback), so compute them
        // before swapping in the new arrays.
        double[] newMu = updateMuVector(x, sd, fit);
        double[] newPi = updatePiVector(fit);
        double[] oldMu = fit.mu;
        double[] oldPi = fit.pi;
        double previousLog = fit.olog;

        fit.mu = newMu;
        fit.pi = newPi;
        fit.sortCluster();
        updateLoglikelihood(x, sd, fit);
        double gain = fit.olog - previousLog;

        if (gain >= 0.0d) {
            emRate(fit, oldMu, oldPi);
            return gain < tol ? EM_OK : EM_CONTINUE;
        }
        fit.mu = oldMu;
        fit.pi = oldPi;
        fit.sortCluster();
        updateLoglikelihood(x, sd, fit);
        return EM_FAILURE;
    }

    /**
     * Updates {@code fit.edist} and {@code fit.rate} from the step just taken.
     *
     * <p>{@code edist} becomes the Euclidean length of the (mu, pi) change. {@code rate} becomes
     * the ratio of that length to the previous {@code edist}, or 0 if the previous one was
     * (near) zero. Both are 0 for a single component.
     *
     * @param oldMu means before the step
     * @param oldPi proportions before the step
     */
    public static void emRate(EMResult fit, double[] oldMu, double[] oldPi) {
        if (fit.k <= 1) {
            fit.edist = 0.0d;
            fit.rate = 0.0d;
            return;
        }
        double squared = 0.0d;
        for (int c = 0; c < fit.k; c++) {
            double dMu = oldMu[c] - fit.mu[c];
            double dPi = oldPi[c] - fit.pi[c];
            squared = squared + dMu * dMu + dPi * dPi;
        }
        double step = Math.sqrt(squared);
        fit.rate = fit.edist > Double.MIN_VALUE ? step / fit.edist : 0.0d;
        fit.edist = step;
    }

    /**
     * Computes the new component means as {@code z}-weighted, inverse-variance-weighted
     * averages of {@code x}. A component with no responsibility mass keeps its current mean
     * rather than becoming 0/0 = NaN.
     *
     * @return a fresh array of {@code fit.k} means ({@code fit} is not modified)
     */
    public static double[] updateMuVector(double[] x, double[] sd, EMResult fit) {
        double[] mu = new double[fit.k];
        for (int c = 0; c < fit.k; c++) {
            double weightedSum = 0.0d;
            double totalWeight = 0.0d;
            for (int j = 0; j < x.length; j++) {
                double variance = sd[j] * sd[j];
                weightedSum += (x[j] * fit.z[c][j]) / variance;
                totalWeight += fit.z[c][j] / variance;
            }
            mu[c] = totalWeight > 0.0d ? weightedSum / totalWeight : fit.mu[c];
        }
        return mu;
    }

    /**
     * Computes the new mixing proportions as the mean responsibility of each component.
     *
     * @return a fresh array of {@code fit.k} proportions ({@code fit} is not modified)
     */
    public static double[] updatePiVector(EMResult fit) {
        double[] pi = new double[fit.k];
        for (int c = 0; c < fit.k; c++) {
            double total = 0.0d;
            for (int j = 0; j < fit.n; j++) {
                total = total + fit.z[c][j];
            }
            pi[c] = total / fit.n;
        }
        return pi;
    }

    /**
     * Recomputes the responsibilities {@code fit.z[c][j]}, normalised over components for
     * each observation. Values below {@link #TINY_Z_PROBA} are flushed to 0. An observation
     * whose total weight is at most {@link #TINY_Z_PROBA} gets all-zero responsibilities.
     */
    public static void updateZMatrix(double[] x, double[] sd, EMResult fit) {
        double[] columnTotal = new double[fit.n];
        for (int j = 0; j < fit.n; j++) {
            double total = 0.0d;
            for (int c = 0; c < fit.k; c++) {
                double weight = (fit.pi[c] / sd[j]) * NumericalUtilities.gauss((x[j] - fit.mu[c]) / sd[j]);
                fit.z[c][j] = weight;
                total = total + weight;
            }
            columnTotal[j] = total;
        }
        for (int c = 0; c < fit.k; c++) {
            for (int j = 0; j < fit.n; j++) {
                if (columnTotal[j] > TINY_Z_PROBA) {
                    double p = fit.z[c][j] / columnTotal[j];
                    fit.z[c][j] = p < TINY_Z_PROBA ? 0.0d : p;
                } else {
                    fit.z[c][j] = 0.0d;
                }
            }
        }
    }

    /**
     * Recomputes the observed-data log-likelihood ({@code fit.olog}) and the complete-data
     * term ({@code fit.clog}, weighted by {@code z}).
     *
     * <p>Density values at or below {@code Double.MIN_VALUE} are skipped. Observations whose
     * mixture density is not positive contribute nothing.
     */
    public static void updateLoglikelihood(double[] x, double[] sd, EMResult fit) {
        double completeLog = 0.0d;
        double observedLog = 0.0d;
        for (int j = 0; j < fit.n; j++) {
            double mixture = 0.0d;
            for (int c = 0; c < fit.k; c++) {
                double g = NumericalUtilities.gauss((x[j] - fit.mu[c]) / sd[j]);
                if (g > Double.MIN_VALUE) {
                    mixture += (fit.pi[c] / sd[j]) * g;
                    completeLog += fit.z[c][j] * Math.log(g / sd[j]);
                }
            }
            if (mixture > 0.0d) {
                observedLog += Math.log(mixture);
            }
        }
        fit.olog = observedLog;
        fit.clog = completeLog;
    }

    /**
     * Estimates row {@code i} of {@code fit.dm} for every component {@code i}.
     *
     * <p>Each row is estimated by perturbing only mean {@code i}. The probe model uses
     * {@code previousMu} everywhere except component {@code i}, which takes {@code fit.mu[i]}.
     * One EM step is applied to the probe, and the resulting shift of each mean away from
     * {@code fit.mu} is divided by {@code fit.mu[i] - previousMu[i]}. If mean {@code i} did not
     * move, or a shift is negligible, the corresponding entries are 0.
     *
     * @param previousMu means before the step that produced {@code fit.mu}
     * @param fit        model whose {@code dm} matrix is overwritten
     */
    public static void computeDM(double[] x, double[] sd, double[] previousMu, EMResult fit) {
        for (int i = 0; i < fit.k; i++) {
            double delta = fit.mu[i] - previousMu[i];
            if (!(Math.abs(delta) > Double.MIN_VALUE)) {
                for (int c = 0; c < fit.k; c++) {
                    fit.dm[i][c] = 0.0d;
                }
                continue;
            }
            EMResult probe = new EMResult(fit.n, fit.k);
            for (int c = 0; c < fit.k; c++) {
                probe.mu[c] = (c == i) ? fit.mu[c] : previousMu[c];
                probe.pi[c] = fit.pi[c];
            }
            // mStep accepts or rejects the update by comparing against probe.olog. A freshly
            // built EMResult has never had its log-likelihood computed, so establish the baseline
            // first. Otherwise a normal (negative) log-likelihood could look like a decrease, and
            // the step would be rolled back.
            updateLoglikelihood(x, sd, probe);
            eStep(x, sd, probe);
            mStep(x, sd, probe, EM_ERR);
            for (int c = 0; c < fit.k; c++) {
                double shift = probe.mu[c] - fit.mu[c];
                fit.dm[i][c] = Math.abs(shift) > Double.MIN_VALUE ? shift / delta : 0.0d;
            }
        }
    }

    /**
     * Fills {@code fit.ccov} (complete-data variances of the means) and {@code fit.ocov}
     * (covariance matrix of the means).
     *
     * <p>For {@code k < n}, {@code ccov[c] = 1 / sum_j z[c][j] / sd[j]^2}, or 0 if that sum is 0.
     * Otherwise {@code ccov[c] = sd[c]^2}. {@code ocov} starts as {@code diag(ccov)}. When
     * {@link #DO_SEM} is set and {@code 1 < k < n}, a correction derived from {@code fit.dm}
     * is added through an SVD-based inverse of {@code I - dm}. The correction is skipped if
     * that matrix is numerically singular (a singular value below {@code 1e-8} of the largest).
     *
     * @param sd  per-observation standard deviations
     * @param fit fitted model, updated in place
     */
    public static void computeCOV(double[] sd, EMResult fit) {
        final int k = fit.k;
        if (k < fit.n) {
            for (int c = 0; c < k; c++) {
                double precision = 0.0d;
                for (int j = 0; j < fit.n; j++) {
                    precision = precision + (fit.z[c][j] / (sd[j] * sd[j]));
                }
                fit.ccov[c] = precision != 0.0d ? 1.0d / precision : 0.0d;
            }
        } else {
            for (int c = 0; c < k; c++) {
                fit.ccov[c] = sd[c] * sd[c];
            }
        }
        for (int a = 0; a < k; a++) {
            fit.ocov[a][a] = fit.ccov[a];
            for (int b = a + 1; b < k; b++) {
                fit.ocov[a][b] = 0.0d;
                fit.ocov[b][a] = 0.0d;
            }
        }
        if (k <= 1 || k >= fit.n || !DO_SEM) {
            return;
        }

        // a = I - dm. SVDecomposition overwrites it (presumably with U of a = U W V^T) and
        // fills v and the singular values w.
        double[][] a = new double[k][k];
        double[][] v = new double[k][k];
        double[] w = new double[k];
        for (int r = 0; r < k; r++) {
            a[r][r] = 1.0d - fit.dm[r][r];
            for (int s = r + 1; s < k; s++) {
                a[r][s] = -fit.dm[r][s];
                a[s][r] = -fit.dm[s][r];
            }
        }
        SVDAlgorithm.SVDecomposition(a, v, w, k, k);

        double wMax = 0.0d;
        for (int m = 0; m < k; m++) {
            if (w[m] > wMax) {
                wMax = w[m];
            }
        }
        for (int m = 0; m < k; m++) {
            if (w[m] < wMax * 1.0E-8d) {
                // Numerically singular: keep the diagonal covariance set above.
                return;
            }
        }

        // inverse = V W^-1 U^T
        double[][] inverse = new double[k][k];
        for (int r = 0; r < k; r++) {
            for (int s = 0; s < k; s++) {
                double sum = 0.0d;
                for (int m = 0; m < k; m++) {
                    if (w[m] != 0.0d) {
                        sum = sum + ((v[r][m] * a[s][m]) / w[m]);
                    }
                }
                inverse[r][s] = sum;
            }
        }

        // correction = dm * (I - dm)^-1
        // NOTE(review): the SEM formula of Meng & Rubin is V = Vc + dm (I - dm)^-1 Vc. Here the
        // product is added to ccov without being multiplied by Vc (= diag(ccov)), which mixes a
        // dimensionless ratio with a variance. Left unchanged pending confirmation against the
        // method's intended formulation.
        double[][] correction = new double[k][k];
        for (int r = 0; r < k; r++) {
            for (int s = 0; s < k; s++) {
                double sum = 0.0d;
                for (int m = 0; m < k; m++) {
                    sum = sum + (fit.dm[r][m] * inverse[m][s]);
                }
                correction[r][s] = sum;
            }
        }

        for (int r = 0; r < k; r++) {
            fit.ocov[r][r] = Math.max(fit.ccov[r], fit.ccov[r] + correction[r][r]);
            for (int s = r + 1; s < k; s++) {
                double symmetric = 0.5d * (correction[r][s] + correction[s][r]);
                fit.ocov[r][s] = symmetric;
                fit.ocov[s][r] = symmetric;
            }
        }
    }
}
