package jml.classification;

import java.util.Iterator;
import java.util.TreeMap;
import jml.matlab.Matlab;
import jml.options.Options;
import org.apache.commons.math.linear.BlockRealMatrix;
import org.apache.commons.math.linear.OpenMapRealMatrix;
import org.apache.commons.math.linear.RealMatrix;

/* loaded from: input_file:jml/classification/Classifier.class */
/**
 * Base class for multi-class classifiers.
 *
 * <p>Training data is supplied with {@code feedData} (features) and
 * {@code feedLabels} (labels); concrete subclasses implement {@link #train()},
 * {@link #predictLabelScoreMatrix(RealMatrix)} and model persistence.
 *
 * <p>Conventions established by the code in this class:
 * <ul>
 *   <li>the data matrix holds one sample per column (rows are features);</li>
 *   <li>the label matrix holds one sample per row and one class per column;</li>
 *   <li>class IDs are zero-based indices; original integer labels are mapped to
 *       IDs in order of first appearance.</li>
 * </ul>
 */
public abstract class Classifier {

    /** Number of classes; set by the {@code feedLabels} overloads. */
    public int nClass;

    /** Number of features, i.e. the row count of {@link #X}. */
    public int nFeature;

    /** Number of training samples, i.e. the column count of {@link #X}. */
    public int nSample;

    /** Training data, {@code nFeature x nSample}. */
    public RealMatrix X;

    /** Training label matrix, one row per sample and one column per class. */
    public RealMatrix Y;

    /** Zero-based class ID of every training sample. */
    int[] labelIDs;

    /** Labels of the training samples as they were supplied to {@code feedLabels}. */
    int[] labels;

    /** Matrix returned by {@link #getProjectionMatrix()}; not assigned in this class (presumably set by subclasses). */
    public RealMatrix W;

    /** Tolerance parameter; defaults to 1e-4 (presumably a convergence threshold used by subclasses - confirm). */
    public double epsilon;

    /** Lookup table from class ID (array index) to the original label value. */
    int[] IDLabelMap;

    /** Creates a classifier with no data and {@code epsilon = 1e-4}. */
    public Classifier() {
        this.nClass = 0;
        this.nFeature = 0;
        this.nSample = 0;
        this.X = null;
        this.Y = null;
        this.W = null;
        this.epsilon = 1.0E-4d;
    }

    /**
     * Creates a classifier with no data, taking {@code epsilon} from the given options.
     *
     * @param options option holder whose {@code epsilon} field is copied
     */
    public Classifier(Options options) {
        this.nClass = 0;
        this.nFeature = 0;
        this.nSample = 0;
        this.X = null;
        this.Y = null;
        this.W = null;
        this.epsilon = options.epsilon;
    }

    /**
     * Loads a previously saved model.
     *
     * @param str location of the model (presumably a file path)
     */
    public abstract void loadModel(String str);

    /**
     * Saves the current model.
     *
     * @param str destination of the model (presumably a file path)
     */
    public abstract void saveModel(String str);

    /**
     * Sets the training data and derives {@link #nFeature} and {@link #nSample} from it.
     *
     * @param realMatrix data matrix with one feature per row and one sample per column
     */
    public void feedData(RealMatrix realMatrix) {
        this.X = realMatrix;
        this.nFeature = realMatrix.getRowDimension();
        this.nSample = realMatrix.getColumnDimension();
    }

    /**
     * Sets the training data from a 2D array.
     *
     * @param dArr data with one feature per row and one sample per column
     */
    public void feedData(double[][] dArr) {
        feedData(new BlockRealMatrix(dArr));
    }

    /**
     * Counts the distinct values in a label array.
     *
     * @param iArr labels, one per sample
     * @return number of distinct labels
     */
    public int calcNumClass(int[] iArr) {
        return indexLabels(iArr).size();
    }

    /**
     * Builds the class-ID to label lookup table.
     *
     * @param iArr labels, one per sample
     * @return array whose element {@code id} is the label that was assigned
     *         class ID {@code id}; IDs follow order of first appearance
     */
    public int[] getIDLabelMap(int[] iArr) {
        TreeMap<Integer, Integer> labelToID = indexLabels(iArr);
        int[] idToLabel = new int[labelToID.size()];
        for (Integer label : labelToID.keySet()) {
            idToLabel[labelToID.get(label).intValue()] = label.intValue();
        }
        return idToLabel;
    }

    /**
     * Builds the label to class-ID map.
     *
     * @param iArr labels, one per sample
     * @return map from label to zero-based class ID; IDs follow order of first appearance
     */
    public TreeMap<Integer, Integer> getLabelIDMap(int[] iArr) {
        return indexLabels(iArr);
    }

    /**
     * Sets the training labels from an array of arbitrary integer labels.
     *
     * <p>Derives {@link #nClass}, the ID/label lookup table, the per-sample
     * class IDs and the indicator label matrix {@link #Y}.
     *
     * @param iArr labels, one per sample
     */
    public void feedLabels(int[] iArr) {
        this.nClass = calcNumClass(iArr);
        this.IDLabelMap = getIDLabelMap(iArr);
        TreeMap<Integer, Integer> labelIDMap = getLabelIDMap(iArr);
        int[] ids = new int[iArr.length];
        for (int i = 0; i < iArr.length; i++) {
            ids[i] = labelIDMap.get(Integer.valueOf(iArr[i])).intValue();
        }
        this.Y = labelIndexArray2LabelMatrix(ids, this.nClass);
        this.labels = iArr;
        this.labelIDs = ids;
    }

    /**
     * Sets the training labels from a label (score) matrix.
     *
     * <p>The number of classes is the column count. The process is terminated
     * with exit status 1 when the row count differs from {@link #nSample}, so
     * {@code feedData} has to be called first.
     *
     * @param realMatrix label matrix with one row per sample and one column per class
     */
    public void feedLabels(RealMatrix realMatrix) {
        this.Y = realMatrix;
        this.nClass = realMatrix.getColumnDimension();
        if (this.nSample != realMatrix.getRowDimension()) {
            System.err.println("Number of labels error!");
            System.exit(1);
        }
        int[] ids = labelScoreMatrix2LabelIndexArray(realMatrix);
        this.labels = ids;
        // The labels here already ARE column indices, so the ID -> label table
        // must be the identity over all nClass columns. Deriving it from order
        // of first appearance would permute predictions whenever the samples
        // are not sorted by class, and would be too short when some column
        // never wins for any training sample.
        int[] identity = new int[this.nClass];
        for (int id = 0; id < identity.length; id++) {
            identity[id] = id;
        }
        this.IDLabelMap = identity;
        this.labelIDs = ids;
    }

    /**
     * Sets the training labels from a 2D label (score) array.
     *
     * @param dArr label matrix with one row per sample and one column per class
     */
    public void feedLabels(double[][] dArr) {
        feedLabels(new BlockRealMatrix(dArr));
    }

    /** Trains the classifier on the data and labels that were fed in. */
    public abstract void train();

    /**
     * Predicts the label of every sample, expressed in the original label values.
     *
     * <p>Requires the ID/label lookup table, which the {@code feedLabels}
     * overloads populate.
     *
     * @param realMatrix data to classify, passed on to {@link #predictLabelScoreMatrix(RealMatrix)}
     * @return predicted label per sample
     */
    public int[] predict(RealMatrix realMatrix) {
        int[] ids = labelScoreMatrix2LabelIndexArray(predictLabelScoreMatrix(realMatrix));
        int[] predicted = new int[ids.length];
        for (int i = 0; i < ids.length; i++) {
            predicted[i] = this.IDLabelMap[ids[i]];
        }
        return predicted;
    }

    /**
     * Predicts the label of every sample in a 2D data array.
     *
     * @param dArr data to classify
     * @return predicted label per sample
     */
    public int[] predict(double[][] dArr) {
        return predict(new BlockRealMatrix(dArr));
    }

    /**
     * Predicts an indicator label matrix: one row per sample with a single 1.0
     * in the column of the predicted class.
     *
     * @param realMatrix data to classify
     * @return indicator matrix with {@link #nClass} columns
     */
    public RealMatrix predictLabelMatrix(RealMatrix realMatrix) {
        int[] ids = labelScoreMatrix2LabelIndexArray(predictLabelScoreMatrix(realMatrix));
        return labelIndexArray2LabelMatrix(ids, this.nClass);
    }

    /**
     * Predicts an indicator label matrix for a 2D data array.
     *
     * @param dArr data to classify
     * @return indicator matrix with {@link #nClass} columns
     */
    public RealMatrix predictLabelMatrix(double[][] dArr) {
        return predictLabelMatrix(new BlockRealMatrix(dArr));
    }

    /**
     * Computes the class scores of every sample.
     *
     * @param realMatrix data to classify
     * @return score matrix; this class treats it as one row per sample and one
     *         column per class when converting scores to class IDs
     */
    public abstract RealMatrix predictLabelScoreMatrix(RealMatrix realMatrix);

    /**
     * Computes the class scores of every sample in a 2D data array.
     *
     * @param dArr data to classify
     * @return score matrix
     */
    public RealMatrix predictLabelScoreMatrix(double[][] dArr) {
        return predictLabelScoreMatrix(new BlockRealMatrix(dArr));
    }

    /**
     * Computes and prints the fraction of positions at which two label arrays agree.
     *
     * <p>The process is terminated with exit status 1 when the arrays differ in
     * length. For two empty arrays the accuracy is undefined and NaN is returned.
     *
     * @param iArr  one label array (predicted or true labels)
     * @param iArr2 the other label array, same length as {@code iArr}
     * @return accuracy in the range [0, 1], or NaN for empty input
     */
    public static double getAccuracy(int[] iArr, int[] iArr2) {
        if (iArr.length != iArr2.length) {
            System.err.println("Number of predicted labels and number of true labels mismatch.");
            System.exit(1);
        }
        int total = iArr2.length;
        int correct = 0;
        for (int k = 0; k < total; k++) {
            if (iArr[k] == iArr2[k]) {
                correct++;
            }
        }
        // Divide in floating point: int / int would truncate every accuracy
        // below 100% down to 0.
        double accuracy = (double) correct / total;
        System.out.println(String.format("Accuracy: %.2f%%\n", Double.valueOf(accuracy * 100.0d)));
        return accuracy;
    }

    /**
     * Returns the projection matrix {@link #W}.
     *
     * @return the current value of {@code W}, possibly {@code null}
     */
    public RealMatrix getProjectionMatrix() {
        return this.W;
    }

    /**
     * Returns the training label matrix {@link #Y}.
     *
     * @return the current value of {@code Y}, possibly {@code null}
     */
    public RealMatrix getTrainingLabelMatrix() {
        return this.Y;
    }

    /**
     * Converts a score matrix into one class index per sample.
     *
     * <p>Relies on {@code Matlab.max(matrix, 2)} and reads the first column of
     * its {@code "idx"} result (presumably the zero-based index of the largest
     * entry of each row - confirm against the Matlab helper).
     *
     * @param realMatrix score matrix
     * @return class index per sample
     */
    public static int[] labelScoreMatrix2LabelIndexArray(RealMatrix realMatrix) {
        double[] idx = Matlab.max(realMatrix, 2).get("idx").getColumn(0);
        int[] ids = new int[idx.length];
        for (int i = 0; i < idx.length; i++) {
            ids[i] = (int) idx[i];
        }
        return ids;
    }

    /**
     * Converts class indices into a sparse indicator matrix.
     *
     * @param iArr zero-based class index per sample; each must be less than {@code i}
     * @param i    number of classes (columns)
     * @return {@code iArr.length x i} matrix with 1.0 at {@code (row, iArr[row])}
     */
    public static RealMatrix labelIndexArray2LabelMatrix(int[] iArr, int i) {
        OpenMapRealMatrix indicator = new OpenMapRealMatrix(iArr.length, i);
        for (int row = 0; row < iArr.length; row++) {
            indicator.setEntry(row, iArr[row], 1.0d);
        }
        return indicator;
    }

    /**
     * Assigns zero-based IDs to the distinct labels in order of first appearance.
     * Keyed by label so the membership test is a tree lookup rather than a
     * linear scan over the values.
     */
    private static TreeMap<Integer, Integer> indexLabels(int[] labelArray) {
        TreeMap<Integer, Integer> labelToID = new TreeMap<Integer, Integer>();
        for (int label : labelArray) {
            Integer key = Integer.valueOf(label);
            if (!labelToID.containsKey(key)) {
                // The next free ID equals the number of labels seen so far.
                labelToID.put(key, Integer.valueOf(labelToID.size()));
            }
        }
        return labelToID;
    }
}
