package jml.clustering;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Random;
import java.util.TreeMap;
import jml.data.Data;
import jml.matlab.Matlab;
import jml.options.KMeansOptions;
import org.apache.commons.math.linear.BlockRealMatrix;
import org.apache.commons.math.linear.OpenMapRealMatrix;
import org.apache.commons.math.linear.RealMatrix;

/* loaded from: input_file:jml/clustering/KMeans.class */
/**
 * K-means clustering.
 *
 * <p>Samples are the columns of {@code dataMatrix}: the data matrix is multiplied by an
 * nSample-by-nClus indicator matrix whose entry (s, k) is 1 when sample s belongs to
 * cluster k. Each iteration:
 * <ol>
 *   <li>recomputes every center as the mean of its member columns;</li>
 *   <li>reassigns every sample to its nearest center, as measured by {@code Matlab.l2Distance};</li>
 *   <li>stops when the assignment no longer changes, or after {@code options.maxIter} iterations.</li>
 * </ol>
 */
public class KMeans extends Clustering {

    /** Algorithm settings (cluster count, iteration cap, verbosity). */
    KMeansOptions options;

    /**
     * Creates a K-means clusterer with an iteration cap of 100 and verbose output off.
     *
     * @param nClus number of clusters
     */
    public KMeans(int nClus) {
        this(nClus, 100, false);
    }

    /**
     * Creates a K-means clusterer with verbose output off.
     *
     * @param nClus   number of clusters
     * @param maxIter maximum number of iterations run by {@link #clustering()}
     */
    public KMeans(int nClus, int maxIter) {
        this(nClus, maxIter, false);
    }

    /**
     * Creates a K-means clusterer.
     *
     * @param nClus   number of clusters
     * @param maxIter maximum number of iterations run by {@link #clustering()}
     * @param verbose whether to print per-iteration progress
     */
    public KMeans(int nClus, int maxIter, boolean verbose) {
        super(nClus);
        // options has no initializer, so it must be created here before its fields are set.
        this.options = new KMeansOptions(nClus, maxIter, verbose);
        // Assign the fields explicitly so the result does not depend on the parameter order
        // of KMeansOptions(int, int, boolean), which is declared outside this file.
        this.options.nClus = nClus;
        this.options.maxIter = maxIter;
        this.options.verbose = verbose;
    }

    /**
     * Creates a K-means clusterer from a prepared options object.
     *
     * @param kMeansOptions settings to use; its {@code nClus} is passed to the superclass
     */
    public KMeans(KMeansOptions kMeansOptions) {
        super(kMeansOptions.nClus);
        this.options = kMeansOptions;
    }

    /**
     * Sets the starting cluster assignment.
     *
     * <p>A non-null {@code realMatrix} is used as-is. Otherwise {@code nClus} distinct samples
     * are picked at random, and each becomes the single member of one cluster. All other
     * samples start unassigned.
     *
     * @param realMatrix an nSample-by-nClus indicator matrix, or {@code null} for random seeding
     * @throws IllegalStateException if random seeding is requested but there are fewer samples
     *                               than clusters
     */
    @Override // jml.clustering.Clustering
    public void initialize(RealMatrix realMatrix) {
        if (realMatrix != null) {
            this.indicatorMatrix = realMatrix;
            return;
        }
        if (this.nClus > this.nSample) {
            throw new IllegalStateException("Cannot seed " + this.nClus
                    + " clusters from only " + this.nSample + " samples");
        }
        ArrayList<Integer> order = new ArrayList<Integer>(this.nSample);
        for (int s = 0; s < this.nSample; s++) {
            order.add(s);
        }
        Collections.shuffle(order, new Random(System.currentTimeMillis()));
        this.indicatorMatrix = new OpenMapRealMatrix(this.nSample, this.nClus);
        for (int k = 0; k < this.nClus; k++) {
            this.indicatorMatrix.setEntry(order.get(k), k, 1.0d);
        }
    }

    /**
     * Runs K-means iterations until the assignment is unchanged or {@code options.maxIter}
     * iterations have run.
     *
     * <p>When it finishes, {@code centers} holds the centers computed in the last iteration,
     * and {@code indicatorMatrix} holds the assignment to those centers.
     */
    @Override // jml.clustering.Clustering
    public void clustering() {
        for (int iter = 1; iter <= this.options.maxIter; iter++) {
            long start = System.currentTimeMillis();
            RealMatrix previous = this.indicatorMatrix;

            RealMatrix centers = computeCenters(previous);
            this.centers = centers;

            // "idx" is the nearest-center index per sample (column 0); "val" is the matching distance.
            TreeMap<String, RealMatrix> nearest =
                    Matlab.min(Matlab.l2Distance(this.dataMatrix, centers), 2);
            RealMatrix minDistances = nearest.get("val");
            RealMatrix labels = nearest.get("idx");

            this.indicatorMatrix = new OpenMapRealMatrix(this.nSample, this.nClus);
            for (int s = 0; s < this.nSample; s++) {
                this.indicatorMatrix.setEntry(s, (int) labels.getEntry(s, 0), 1.0d);
            }

            // Sum the per-sample minimum distances; the column sum is 1x1, so its trace is the total.
            double sse = Matlab.sum(minDistances, 1).getTrace();
            double elapsedSecs = (System.currentTimeMillis() - start) / 1000.0d;
            if (this.options.verbose) {
                System.out.format("Iter %d: sse = %.3f (%.3f secs)\n", iter, sse, elapsedSecs);
            }

            if (previous.subtract(this.indicatorMatrix).getFrobeniusNorm() == 0.0d) {
                System.out.println("KMeans complete.");
                return;
            }
        }
        if (this.options.verbose) {
            System.out.println("KMeans stopped: maxIter reached before convergence.");
        }
    }

    /**
     * Computes each cluster's center as the mean of its member columns.
     *
     * <p>This equals {@code dataMatrix * indicator * diag(1 ./ diag(indicator' * indicator))}.
     * The exception is a cluster with zero size: its center is left as the all-zero column,
     * instead of the NaN column that would come from multiplying a zero sum by 1/0.
     *
     * @param indicator current nSample-by-nClus assignment matrix
     * @return matrix whose column k is the center of cluster k
     */
    private RealMatrix computeCenters(RealMatrix indicator) {
        RealMatrix centers = this.dataMatrix.multiply(indicator);
        RealMatrix sizes = Matlab.diag(indicator.transpose().multiply(indicator));
        for (int k = 0; k < centers.getColumnDimension(); k++) {
            double size = sizes.getEntry(k, 0);
            if (size != 0.0d) {
                centers.setColumnVector(k, centers.getColumnVector(k).mapDivide(size));
            }
        }
        return centers;
    }

    /**
     * Demo entry point.
     *
     * <p>Prints a small hard-coded sample matrix. It then loads
     * {@code CNNTest-TrainingData.txt}, passes it through {@code Matlab.getTFIDF} and
     * {@code Matlab.normalizeByColumns}, clusters it with {@code KMeansOptions(3, 100, true)}
     * and prints the resulting indicator matrix.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        KMeans kMeans = new KMeans(new KMeansOptions(3, 100, true));

        double[][] sample = {
            {1.0d, 0.0d, 3.0d, 2.0d, 0.0d},
            {2.0d, 5.0d, 3.0d, 1.0d, 0.0d},
            {4.0d, 1.0d, 0.0d, 0.0d, 1.0d},
            {3.0d, 0.0d, 1.0d, 0.0d, 2.0d},
            {2.0d, 5.0d, 3.0d, 1.0d, 6.0d}
        };
        Matlab.printMatrix(new BlockRealMatrix(sample));

        kMeans.feedData(Matlab.normalizeByColumns(
                Matlab.getTFIDF(Data.loadMatrix("CNNTest-TrainingData.txt"))));
        kMeans.initialize(null);
        kMeans.clustering();
        System.out.println("Indicator Matrix:");
        Matlab.printMatrix(Matlab.full(kMeans.getIndicatorMatrix()));
    }
}
