package jml.manifold;

import jml.data.Data;
import jml.kernel.Kernel;
import jml.matlab.Matlab;
import jml.matlab.utils.FindResult;
import jml.matlab.utils.SortResult;
import jml.options.GraphOptions;
import org.apache.commons.math.linear.OpenMapRealMatrix;
import org.apache.commons.math.linear.RealMatrix;

/* loaded from: input_file:jml/manifold/Manifold.class */
public class Manifold {

    /**
     * Demo driver: loads a document-term count file, builds a k-nearest-neighbour graph
     * with cosine distances, and saves the adjacency matrix, the graph Laplacian and the
     * local learning regularization matrix to text files.
     *
     * @param strArr command-line arguments (unused)
     */
    public static void main(String[] strArr) {
        RealMatrix data = Data.loadMatrixFromDocTermCountFile("CNN - DocTermCount.txt");
        System.out.println(String.format("%d samples loaded", data.getColumnDimension()));

        GraphOptions graphOptions = new GraphOptions();
        graphOptions.graphType = "nn";
        String graphType = graphOptions.graphType;
        double graphParam = graphOptions.graphParam;
        System.out.println(String.format("Graph type: %s with NN: %d", graphType, (int) graphParam));
        graphOptions.kernelType = "cosine";
        graphOptions.graphDistanceFunction = "cosine";
        graphOptions.graphNormalize = true;
        graphOptions.graphWeightType = "heat";

        Data.saveMatrix("adjacency.txt",
                adjacency(data, graphType, graphParam, graphOptions.graphDistanceFunction));

        RealMatrix laplacian = laplacian(data, graphType, graphOptions);
        Data.saveMatrix("Laplacian.txt", laplacian);
        // Show at most the leading 10x10 corner; smaller data sets would otherwise crash here.
        int last = Math.min(9, laplacian.getRowDimension() - 1);
        Matlab.display(laplacian.getSubMatrix(0, last, 0, last));

        RealMatrix llr = calcLLR(data, graphOptions.graphParam, graphOptions.graphDistanceFunction,
                graphOptions.kernelType, graphOptions.kernelParam, 0.001d);
        Data.saveMatrix("localLearningRegularization.txt", llr);
        Matlab.display(llr);
    }

    /**
     * Computes the graph Laplacian of the symmetric adjacency graph built over the
     * columns (samples) of {@code realMatrix}.
     *
     * <p>Edge weights are derived from the stored distances according to
     * {@code graphOptions.graphWeightType}:
     * <ul>
     *   <li>{@code "distance"}: the distance itself;</li>
     *   <li>{@code "inner"}: {@code 1 - dist / 2} (intended for the cosine distance);</li>
     *   <li>{@code "binary"}: {@code 1};</li>
     *   <li>{@code "heat"}: {@code exp(-dist^2 / (2 t^2))} with {@code t = graphWeightParam}.</li>
     * </ul>
     * An unknown weight type is reported on stderr and the raw distances are used.
     *
     * @param realMatrix   data matrix whose columns are the samples
     * @param str          graph type, {@code "nn"} or {@code "epsballs"}/{@code "eps"}
     * @param graphOptions supplies graph parameter, distance function, weight type,
     *                     weight parameter and the normalization flag
     * @return {@code I - D^{-1/2} W D^{-1/2}} when {@code graphNormalize} is set,
     *         otherwise {@code D - W}, where {@code D} is the diagonal of row sums of {@code W}
     */
    public static RealMatrix laplacian(RealMatrix realMatrix, String str, GraphOptions graphOptions) {
        System.out.println("Computing Graph Laplacian...");
        double graphParam = graphOptions.graphParam;
        String distanceFunction = graphOptions.graphDistanceFunction;
        String weightType = graphOptions.graphWeightType;
        double weightParam = graphOptions.graphWeightParam;
        boolean normalize = graphOptions.graphNormalize;

        if (weightType.equals("inner") && !distanceFunction.equals("cosine")) {
            System.err.println("WEIGHTTYPE and DISTANCEFUNCTION mismatch.");
        }

        RealMatrix adjacency = adjacency(realMatrix, str, graphParam, distanceFunction);
        RealMatrix weights = adjacency.copy();
        FindResult edges = Matlab.find(adjacency);
        int[] rows = edges.rows;
        int[] cols = edges.cols;
        double[] dists = edges.vals;

        if (weightType.equals("distance")) {
            // The weights are the distances themselves, which the copy already holds.
        } else if (weightType.equals("inner")) {
            for (int e = 0; e < rows.length; e++) {
                weights.setEntry(rows[e], cols[e], 1.0d - (dists[e] / 2.0d));
            }
        } else if (weightType.equals("binary")) {
            for (int e = 0; e < rows.length; e++) {
                weights.setEntry(rows[e], cols[e], 1.0d);
            }
        } else if (weightType.equals("heat")) {
            double denominator = (-2.0d) * weightParam * weightParam;
            for (int e = 0; e < rows.length; e++) {
                weights.setEntry(rows[e], cols[e], Math.exp((dists[e] * dists[e]) / denominator));
            }
        } else {
            System.err.println("Unknown Weight Type.");
        }

        RealMatrix degree = Matlab.sum(weights, 2);
        if (normalize) {
            RealMatrix invSqrtDegree = Matlab.diag(Matlab.dotDivide(1.0d, Matlab.sqrt(degree)));
            return Matlab.eye(Matlab.size(weights, 1))
                    .subtract(invSqrtDegree.multiply(weights).multiply(invSqrtDegree));
        }
        return Matlab.diag(degree).subtract(weights);
    }

    /**
     * Builds the symmetric adjacency matrix by combining the directed graph with its
     * transpose via {@code Matlab.max} (presumably element-wise, so an edge exists when
     * either endpoint selected the other).
     *
     * @param realMatrix data matrix whose columns are the samples
     * @param str        graph type, {@code "nn"} or {@code "epsballs"}/{@code "eps"}
     * @param d          number of neighbours (nn) or radius (eps balls)
     * @param str2       distance function, {@code "euclidean"} or {@code "cosine"}
     * @return symmetric adjacency matrix holding distances (plus {@code Matlab.eps})
     * @throws IllegalArgumentException if the graph type or distance function is unknown
     */
    public static RealMatrix adjacency(RealMatrix realMatrix, String str, double d, String str2) {
        RealMatrix directed = adjacencyDirected(realMatrix, str, d, str2);
        return Matlab.max(directed, directed.transpose());
    }

    /**
     * Builds the directed adjacency graph over the columns (samples) of {@code realMatrix}.
     *
     * <p>Row {@code i} stores, for each selected neighbour {@code j}, the distance from
     * sample {@code i} to sample {@code j} plus {@code Matlab.eps}. The offset keeps
     * zero-distance edges non-zero so they are still stored in the sparse matrix.
     *
     * @param realMatrix data matrix whose columns are the samples
     * @param str        {@code "nn"} for k-nearest neighbours (at most {@code n - 1} are
     *                   available, so larger {@code d} is clamped) or
     *                   {@code "epsballs"}/{@code "eps"} for all samples within distance {@code d}
     * @param d          number of neighbours (nn) or radius (eps balls)
     * @param str2       distance function, {@code "euclidean"} or {@code "cosine"}
     * @return sparse {@code n x n} directed adjacency matrix
     * @throws IllegalArgumentException if the graph type or distance function is unknown
     */
    public static RealMatrix adjacencyDirected(RealMatrix realMatrix, String str, double d, String str2) {
        System.out.println("Computing directed adjacency graph...");
        int n = Matlab.size(realMatrix, 2);

        boolean nearestNeighbours = str.equals("nn");
        boolean epsilonBalls = str.equals("epsballs") || str.equals("eps");
        if (nearestNeighbours) {
            System.out.println(String.format(
                    "Creating the adjacency matrix. Nearest neighbors, N = %d.", (int) d));
        } else if (epsilonBalls) {
            System.out.println(String.format(
                    "Creating the adjacency matrix. Epsilon balls, eps = %f.", d));
        } else {
            // Library code must not terminate the JVM; let the caller decide.
            throw new IllegalArgumentException("type should be either \"nn\" or \"epsballs\" (\"eps\")");
        }

        boolean useEuclidean = str2.equals("euclidean");
        if (!useEuclidean && !str2.equals("cosine")) {
            // Previously this fell through with a null distance row and failed with an NPE.
            throw new IllegalArgumentException(
                    "distance function should be either \"euclidean\" or \"cosine\", got: " + str2);
        }

        OpenMapRealMatrix graph = new OpenMapRealMatrix(n, n);
        for (int i = 0; i < n; i++) {
            RealMatrix sample = realMatrix.getColumnMatrix(i);
            RealMatrix distances = useEuclidean ? euclidean(sample, realMatrix) : cosine(sample, realMatrix);
            SortResult sorted = Matlab.sort(distances, 2);
            RealMatrix sortedDistances = sorted.B;
            int[] order = sorted.IX[0];

            // Sorted position 0 is skipped: it is presumably sample i itself (distance 0).
            if (nearestNeighbours) {
                for (int j = 1; j <= d && j < n; j++) {
                    graph.setEntry(i, order[j], sortedDistances.getEntry(0, j) + Matlab.eps);
                }
            } else {
                // Bounded by n so the scan stops at the end of the row when every sample is within d.
                for (int j = 1; j < n && sortedDistances.getEntry(0, j) <= d; j++) {
                    graph.setEntry(i, order[j], sortedDistances.getEntry(0, j) + Matlab.eps);
                }
            }
        }
        return graph;
    }

    /**
     * Computes the cosine distance {@code 1 - <a, b> / (|a| |b|)} between the column(s) of
     * {@code realMatrix} and every column of {@code realMatrix2}.
     *
     * <p>NOTE(review): a zero-norm column makes the denominator zero; the result is then
     * presumably non-finite — confirm the intended handling.
     *
     * @param realMatrix  query sample(s) as columns
     * @param realMatrix2 reference samples as columns
     * @return matrix of cosine distances (rows: query columns, cols: reference columns)
     */
    public static RealMatrix cosine(RealMatrix realMatrix, RealMatrix realMatrix2) {
        RealMatrix squaredNormsA = Matlab.sum(Matlab.times(realMatrix, realMatrix));
        RealMatrix squaredNormsB = Matlab.sum(Matlab.times(realMatrix2, realMatrix2));
        RealMatrix inverseNormProducts =
                Matlab.scalarDivide(1.0d, Matlab.sqrt(Matlab.kron(squaredNormsA.transpose(), squaredNormsB)));
        RealMatrix innerProducts = realMatrix.transpose().multiply(realMatrix2);
        return Matlab.times(inverseNormProducts, innerProducts).scalarMultiply(-1.0d).scalarAdd(1.0d);
    }

    /**
     * Computes Euclidean distances between the column(s) of {@code realMatrix} and every
     * column of {@code realMatrix2} by delegating to {@code Matlab.l2Distance}.
     *
     * @param realMatrix  query sample(s) as columns
     * @param realMatrix2 reference samples as columns
     * @return distance matrix produced by {@code Matlab.l2Distance}
     */
    public static RealMatrix euclidean(RealMatrix realMatrix, RealMatrix realMatrix2) {
        return Matlab.l2Distance(realMatrix, realMatrix2);
    }

    /**
     * Computes the local learning regularization matrix {@code (W - I)^T (W - I)}.
     *
     * <p>For each sample {@code i} with neighbour index set {@code N} (taken from the
     * directed k-NN graph) of size {@code k}, row {@code i} of {@code W} holds, at the
     * columns {@code N}, the solution of
     * {@code (k * lambda * I + K[N, N]) \ kernel(X[:, N], x_i)}, transposed. All other
     * entries of the row are zero. A sample with no neighbours keeps an empty row.
     *
     * @param realMatrix data matrix whose columns are the samples
     * @param d          number of nearest neighbours
     * @param str        distance function used to pick neighbours
     * @param str2       kernel type passed to {@code Kernel.calcKernel}
     * @param d2         kernel parameter passed to {@code Kernel.calcKernel}
     * @param d3         regularization weight {@code lambda}
     * @return {@code n x n} matrix {@code (W - I)^T (W - I)}
     * @throws IllegalArgumentException if the distance function is unknown
     */
    public static RealMatrix calcLLR(RealMatrix realMatrix, double d, String str, String str2, double d2, double d3) {
        RealMatrix neighbourGraph = adjacencyDirected(realMatrix, "nn", d, str);
        RealMatrix gram = Kernel.calcKernel(str2, d2, realMatrix);
        int n = Matlab.size(realMatrix, 2);
        int dim = Matlab.size(realMatrix, 1);
        int[] allFeatures = Matlab.colon(0, dim - 1);

        RealMatrix coefficients = neighbourGraph.copy();
        for (int i = 0; i < n; i++) {
            int[] neighbours = Matlab.find(neighbourGraph.getRowVector(i));
            // Size the system by the neighbours actually found rather than assuming (int) d,
            // which mismatched whenever a row had a different neighbour count.
            int k = neighbours.length;
            if (k == 0) {
                continue;
            }
            RealMatrix system = Matlab.eye(k).scalarMultiply(k * d3)
                    .add(gram.getSubMatrix(neighbours, neighbours));
            RealMatrix rhs = Kernel.calcKernel(str2, d2,
                    realMatrix.getSubMatrix(allFeatures, neighbours), realMatrix.getColumnMatrix(i));
            Matlab.setSubMatrix(coefficients, new int[] {i}, neighbours,
                    Matlab.mldivide(system, rhs).transpose());
        }

        RealMatrix residual = coefficients.subtract(Matlab.eye(n));
        return residual.transpose().multiply(residual);
    }
}
