package jml.classification;

import de.bwaldvogel.liblinear.Feature;
import de.bwaldvogel.liblinear.FeatureNode;
import de.bwaldvogel.liblinear.InvalidInputDataException;
import de.bwaldvogel.liblinear.Linear;
import de.bwaldvogel.liblinear.Parameter;
import de.bwaldvogel.liblinear.Problem;
import de.bwaldvogel.liblinear.SolverType;
import java.io.Closeable;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.StringTokenizer;
import jml.data.Data;
import jml.matlab.Matlab;
import org.apache.commons.math.linear.Array2DRowRealMatrix;
import org.apache.commons.math.linear.OpenMapRealMatrix;
import org.apache.commons.math.linear.RealMatrix;

/* loaded from: input_file:jml/classification/MultiClassSVM.class */
public class MultiClassSVM extends Classifier {
    Problem problem;
    Feature[][] features;
    double C;
    double eps;
    Parameter parameter;
    private double bias;
    de.bwaldvogel.liblinear.Model model;
    static final /* synthetic */ boolean $assertionsDisabled;

    static {
        $assertionsDisabled = !MultiClassSVM.class.desiredAssertionStatus();
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v1, types: [double[], double[][]] */
    /* JADX WARN: Type inference failed for: r0v3, types: [double[], double[][]] */
    public static void main(String[] strArr) {
        ?? r0 = {new double[]{3.5d, 4.4d, 1.3d}, new double[]{5.3d, 2.2d, 0.5d}, new double[]{0.2d, 0.3d, 4.1d}, new double[]{-1.2d, 0.4d, 3.2d}};
        MultiClassSVM multiClassSVM = new MultiClassSVM(1.0d, 0.01d);
        multiClassSVM.feedData((double[][]) r0);
        multiClassSVM.feedLabels((double[][]) new double[]{new double[]{1.0d, 0.0d, 0.0d}, new double[]{0.0d, 1.0d, 0.0d}, new double[]{0.0d, 0.0d, 1.0d}});
        multiClassSVM.train();
        Matlab.display(multiClassSVM.predictLabelMatrix((double[][]) r0));
        System.out.println("Projection matrix:");
        Matlab.printMatrix(multiClassSVM.getProjectionMatrix());
        System.out.println("Predicted label score matrix:");
        Matlab.display(multiClassSVM.predictLabelScoreMatrix((double[][]) r0));
        long currentTimeMillis = System.currentTimeMillis();
        MultiClassSVM multiClassSVM2 = new MultiClassSVM(1.0d, 0.01d);
        multiClassSVM2.feedProblem("heart_scale");
        multiClassSVM2.train();
        System.out.println("Using Linear.predict method:");
        int[] predict2 = multiClassSVM2.predict2(multiClassSVM2.features, multiClassSVM2.labels);
        System.out.println("Using our predict method:");
        RealMatrix features2Matrix = features2Matrix(multiClassSVM2.features);
        getAccuracy(multiClassSVM2.predict(features2Matrix), multiClassSVM2.labels);
        System.out.println("There are " + predict2.length + " samples predicted.\n");
        System.out.println("Using predict(Feature[][]) method:");
        getAccuracy(multiClassSVM2.predict(matrix2Features(features2Matrix)), multiClassSVM2.labels);
        System.out.format("Elapsed time: %.1f seconds.\n", Float.valueOf(((float) (System.currentTimeMillis() - currentTimeMillis)) / 1000.0f));
        MultiClassSVM multiClassSVM3 = new MultiClassSVM(1.0d, 0.01d);
        RealMatrix loadMatrixFromDocTermCountFile = Data.loadMatrixFromDocTermCountFile("CNN - DocTermCount.txt");
        multiClassSVM3.feedData(loadMatrixFromDocTermCountFile);
        Data.saveSparseMatrix("GroundTruthSparse.txt", Data.loadMatrix("GroundTruth.txt"));
        RealMatrix loadMatrix = Data.loadMatrix("GroundTruthSparse.txt");
        multiClassSVM3.feedLabels(loadMatrix);
        multiClassSVM3.train();
        getAccuracy(multiClassSVM3.predict(loadMatrixFromDocTermCountFile), multiClassSVM3.labels);
        multiClassSVM3.saveModel("MCSVMModel");
        MultiClassSVM multiClassSVM4 = new MultiClassSVM();
        multiClassSVM4.loadModel("MCSVMModel");
        int[] predict = multiClassSVM4.predict(loadMatrixFromDocTermCountFile);
        int[] labelScoreMatrix2LabelIndexArray = labelScoreMatrix2LabelIndexArray(loadMatrix);
        getAccuracy(predict, labelScoreMatrix2LabelIndexArray);
        System.out.println("There are " + predict.length + " samples predicted.\n");
        System.out.println("Using predict(Feature[][]) method:");
        getAccuracy(multiClassSVM4.predict(matrix2Features(loadMatrixFromDocTermCountFile)), labelScoreMatrix2LabelIndexArray);
        System.out.println("Mission complete!");
    }

    public static void run(String[] strArr) {
        double d = 1.0d;
        double d2 = 0.001d;
        String str = "";
        String str2 = "";
        String str3 = "";
        boolean z = false;
        int i = 0;
        while (i < strArr.length) {
            if (!strArr[i].startsWith("--")) {
                if (strArr[i].charAt(0) != '-') {
                    break;
                }
                i++;
                if (i >= strArr.length) {
                    showUsage();
                    System.exit(1);
                }
                String str4 = strArr[i - 1];
                String str5 = strArr[i];
                if (str4.equals("-C")) {
                    d = Double.parseDouble(str5);
                } else if (str4.equals("-eps")) {
                    d2 = Double.parseDouble(str5);
                } else if (str4.equals("-model")) {
                    str = str5;
                } else if (str4.equals("-trainingData")) {
                    str2 = str5;
                } else if (str4.equals("-testData")) {
                    str3 = str5;
                } else if (str4.equals("-train")) {
                    z = Boolean.parseBoolean(str5);
                }
            }
            i++;
        }
        if (str2.isEmpty() && str3.isEmpty()) {
            showUsage();
            System.exit(1);
        }
        if (!z) {
            RealMatrix features2Matrix = features2Matrix(readProblem(str2).x);
            MultiClassSVM multiClassSVM = new MultiClassSVM();
            multiClassSVM.loadModel(str);
            int[] predict = multiClassSVM.predict(features2Matrix);
            int[] labels = multiClassSVM.model.getLabels();
            for (int i2 = 0; i2 < predict.length; i2++) {
                System.out.format("Doc %d: y_pred: %s\n", Integer.valueOf(i2), Integer.valueOf(labels[predict[i2]]));
            }
            System.out.println();
            return;
        }
        if (str.isEmpty()) {
            System.err.println("Model file path is empty.");
            showUsage();
            System.exit(1);
        }
        Problem readProblem = readProblem(str2);
        RealMatrix features2Matrix2 = features2Matrix(readProblem.x);
        MultiClassSVM multiClassSVM2 = new MultiClassSVM(d, d2);
        multiClassSVM2.feedData(features2Matrix2);
        multiClassSVM2.feedLabels(readProblem.y);
        multiClassSVM2.train();
        multiClassSVM2.saveModel(str);
        if (str3.isEmpty()) {
            return;
        }
        int[] predict2 = multiClassSVM2.predict(features2Matrix(readProblem(str2).x));
        for (int i3 = 0; i3 < predict2.length; i3++) {
            System.out.format("Doc %d: y_pred: %s\n", Integer.valueOf(i3), Integer.valueOf(predict2[i3]));
        }
        System.out.println();
    }

    private static void showUsage() {
    }

    public MultiClassSVM(double d, double d2) {
        this.bias = 1.0d;
        this.C = d;
        this.eps = d2;
        this.features = null;
        this.parameter = new Parameter(SolverType.MCSVM_CS, d, d2);
        this.model = null;
    }

    public MultiClassSVM() {
    }

    public void feedProblem(ArrayList<String> arrayList) {
        feedProblem(readProblemFromStringArray(arrayList));
    }

    public void feedProblem(String str) {
        feedProblem(readProblemFromFile(new File(str), this.bias));
        System.out.println("Problem loaded from file " + str);
    }

    public void feedProblem(Problem problem) {
        this.problem = problem;
        this.features = problem.x;
        super.feedData(features2Matrix(this.features));
        this.labels = problem.y;
        feedLabels(this.labels);
    }

    @Override // jml.classification.Classifier
    public void feedData(RealMatrix realMatrix) {
        this.features = matrix2Features(realMatrix, this.bias);
        super.feedData(realMatrix);
    }

    public static int getMaxRawFeatureIndex(Feature[][] featureArr, double d) {
        int i = 0;
        for (int i2 = 0; i2 < featureArr.length; i2++) {
            i = Math.max(i, featureArr[i2][(d >= 0.0d ? featureArr[i2].length - 1 : featureArr[i2].length) - 1].getIndex());
        }
        return i;
    }

    @Override // jml.classification.Classifier
    public void train() {
        if (this.problem == null) {
            this.problem = new Problem();
            this.problem.bias = this.bias;
            this.problem.l = this.features.length;
            this.problem.x = this.features;
            this.problem.y = this.labels;
            this.problem.n = getMaxRawFeatureIndex(this.features, this.bias);
            this.nFeature = this.problem.n;
            if (this.bias >= 0.0d) {
                this.problem.n++;
            }
        }
        this.model = Linear.train(this.problem, this.parameter);
        reconstructProjectionMatrix();
    }

    @Deprecated
    public static int[] predict2(de.bwaldvogel.liblinear.Model model, Feature[][] featureArr, int[] iArr) {
        int length = featureArr.length;
        int[] iArr2 = new int[length];
        for (int i = 0; i < length; i++) {
            iArr2[i] = Linear.predict(model, featureArr[i]);
        }
        if (iArr != null) {
            int i2 = 0;
            for (int i3 = 0; i3 < length; i3++) {
                if (iArr2[i3] == iArr[i3]) {
                    i2++;
                }
            }
            System.out.println(String.format("Accuracy: %.2f%%\n", Double.valueOf((i2 / length) * 100.0d)));
        }
        return iArr2;
    }

    @Deprecated
    public int[] predict2(Feature[][] featureArr, int[] iArr) {
        int length = featureArr.length;
        int[] iArr2 = new int[length];
        double[] dArr = new double[this.nClass];
        for (int i = 0; i < length; i++) {
            iArr2[i] = Linear.predictValues(this.model, featureArr[i], dArr);
        }
        if (iArr != null) {
            int i2 = 0;
            for (int i3 = 0; i3 < length; i3++) {
                if (iArr2[i3] == iArr[i3]) {
                    i2++;
                }
            }
            System.out.println(String.format("Accuracy: %.2f%%\n", Double.valueOf((i2 / length) * 100.0d)));
        }
        return iArr2;
    }

    @Deprecated
    public int[] predict2(RealMatrix realMatrix) {
        return predict2(realMatrix, (int[]) null);
    }

    @Deprecated
    public int[] predict2(RealMatrix realMatrix, int[] iArr) {
        return predict2(matrix2Features(realMatrix, this.bias), iArr);
    }

    public static Problem readProblemFromFile(File file, double d) {
        try {
            return Problem.readFromFile(file, d);
        } catch (InvalidInputDataException e) {
            e.printStackTrace();
            return null;
        } catch (IOException e2) {
            e2.printStackTrace();
            return null;
        }
    }

    public static Problem readProblem(String str) {
        return readProblemFromFile(new File(str), 1.0d);
    }

    public static Feature[][] matrix2Features(RealMatrix realMatrix) {
        return matrix2Features(realMatrix, 1.0d);
    }

    /* JADX WARN: Type inference failed for: r1v10, types: [de.bwaldvogel.liblinear.Feature[], de.bwaldvogel.liblinear.Feature[][]] */
    public static Feature[][] matrix2Features(RealMatrix realMatrix, double d) {
        int rowDimension = realMatrix.getRowDimension();
        int columnDimension = realMatrix.getColumnDimension();
        ArrayList arrayList = new ArrayList();
        int i = 0;
        for (int i2 = 0; i2 < columnDimension; i2++) {
            ArrayList arrayList2 = new ArrayList();
            for (int i3 = 0; i3 < rowDimension; i3++) {
                int i4 = i3 + 1;
                double entry = realMatrix.getEntry(i3, i2);
                if (entry != 0.0d) {
                    arrayList2.add(new FeatureNode(i4, entry));
                }
            }
            int size = arrayList2.size();
            Feature[] featureArr = d >= 0.0d ? new Feature[size + 1] : new Feature[size];
            Iterator it = arrayList2.iterator();
            int i5 = 0;
            while (it.hasNext()) {
                featureArr[i5] = (Feature) it.next();
                i5++;
            }
            arrayList.add(featureArr);
            if (size > 0) {
                i = Math.max(i, featureArr[size - 1].getIndex());
            }
        }
        Problem problem = new Problem();
        problem.bias = d;
        problem.l = arrayList.size();
        problem.n = i;
        if (d >= 0.0d) {
            problem.n++;
        }
        problem.x = new Feature[problem.l];
        for (int i6 = 0; i6 < problem.l; i6++) {
            problem.x[i6] = (Feature[]) arrayList.get(i6);
            if (d >= 0.0d) {
                if (!$assertionsDisabled && problem.x[i6][problem.x[i6].length - 1] != null) {
                    throw new AssertionError();
                }
                problem.x[i6][problem.x[i6].length - 1] = new FeatureNode(i + 1, d);
            }
        }
        return problem.x;
    }

    public static Problem readProblemFromStringArray(ArrayList<String> arrayList) {
        return readProblemFromStringArray(arrayList, 1.0d);
    }

    public static Problem readProblemFromStringArray(ArrayList<String> arrayList, double d) {
        String next;
        ArrayList arrayList2 = new ArrayList();
        ArrayList arrayList3 = new ArrayList();
        int i = 0;
        int i2 = 0;
        Iterator<String> it = arrayList.iterator();
        while (it.hasNext() && (next = it.next()) != null) {
            i2++;
            StringTokenizer stringTokenizer = new StringTokenizer(next, " \t\n\r\f:");
            String nextToken = stringTokenizer.nextToken();
            try {
                arrayList2.add(Integer.valueOf(atoi(nextToken)));
            } catch (NumberFormatException e) {
                System.err.println(String.format("invalid label: %s at line %d", nextToken, Integer.valueOf(i2)));
            }
            int countTokens = stringTokenizer.countTokens() / 2;
            Feature[] featureArr = d >= 0.0d ? new Feature[countTokens + 1] : new Feature[countTokens];
            int i3 = 0;
            for (int i4 = 0; i4 < countTokens; i4++) {
                int i5 = 0;
                try {
                    i5 = atoi(stringTokenizer.nextToken());
                } catch (NumberFormatException e2) {
                    System.err.println(String.format("invalid index: %d at line %d", Integer.valueOf(i5), Integer.valueOf(i2)));
                    System.exit(1);
                }
                if (i5 < 0) {
                    System.err.println(String.format("invalid index: %d at line %d", Integer.valueOf(i5), Integer.valueOf(i2)));
                    System.exit(1);
                }
                if (i5 <= i3) {
                    System.err.println(String.format("indices must be sorted in ascending order at line %d", Integer.valueOf(i2)));
                    System.exit(1);
                }
                i3 = i5;
                String nextToken2 = stringTokenizer.nextToken();
                try {
                    featureArr[i4] = new FeatureNode(i5, atof(nextToken2));
                } catch (NumberFormatException e3) {
                    System.err.println(String.format("invalid value: %f at line ", nextToken2, Integer.valueOf(i2)));
                    System.exit(1);
                }
            }
            if (countTokens > 0) {
                i = Math.max(i, featureArr[countTokens - 1].getIndex());
            }
            arrayList3.add(featureArr);
        }
        return constructProblem(arrayList2, arrayList3, i, d);
    }

    /* JADX WARN: Type inference failed for: r1v8, types: [de.bwaldvogel.liblinear.Feature[], de.bwaldvogel.liblinear.Feature[][]] */
    private static Problem constructProblem(List<Integer> list, List<Feature[]> list2, int i, double d) {
        Problem problem = new Problem();
        problem.bias = d;
        problem.l = list.size();
        problem.n = i;
        if (d >= 0.0d) {
            problem.n++;
        }
        problem.x = new Feature[problem.l];
        for (int i2 = 0; i2 < problem.l; i2++) {
            problem.x[i2] = list2.get(i2);
            if (d >= 0.0d) {
                if (!$assertionsDisabled && problem.x[i2][problem.x[i2].length - 1] != null) {
                    throw new AssertionError();
                }
                problem.x[i2][problem.x[i2].length - 1] = new FeatureNode(i + 1, d);
            }
        }
        problem.y = new int[problem.l];
        for (int i3 = 0; i3 < problem.l; i3++) {
            problem.y[i3] = list.get(i3).intValue();
        }
        return problem;
    }

    static double atof(String str) {
        if (str == null || str.length() < 1) {
            throw new IllegalArgumentException("Can't convert empty string to integer");
        }
        double parseDouble = Double.parseDouble(str);
        if (Double.isNaN(parseDouble) || Double.isInfinite(parseDouble)) {
            throw new IllegalArgumentException("NaN or Infinity in input: " + str);
        }
        return parseDouble;
    }

    static int atoi(String str) throws NumberFormatException {
        if (str == null || str.length() < 1) {
            throw new IllegalArgumentException("Can't convert empty string to integer");
        }
        if (str.charAt(0) == '+') {
            str = str.substring(1);
        }
        return Integer.parseInt(str);
    }

    public static RealMatrix features2Matrix(Feature[][] featureArr) {
        return features2MatrixWithoutBias(featureArr);
    }

    public static RealMatrix features2Matrix(Feature[][] featureArr, double d) {
        int i = 0;
        for (int i2 = 0; i2 < featureArr.length; i2++) {
            i = Math.max(i, featureArr[i2][(d >= 0.0d ? featureArr[i2].length - 1 : featureArr[i2].length) - 1].getIndex());
        }
        if (d >= 0.0d) {
            i++;
        }
        OpenMapRealMatrix openMapRealMatrix = new OpenMapRealMatrix(i, featureArr.length);
        for (int i3 = 0; i3 < featureArr.length; i3++) {
            for (int i4 = 0; i4 < featureArr[i3].length; i4++) {
                openMapRealMatrix.setEntry(featureArr[i3][i4].getIndex() - 1, i3, featureArr[i3][i4].getValue());
            }
        }
        return openMapRealMatrix;
    }

    public static RealMatrix features2MatrixWithoutBias(Feature[][] featureArr) {
        return features2MatrixWithoutBias(featureArr, 1.0d);
    }

    public static RealMatrix features2MatrixWithoutBias(Feature[][] featureArr, double d) {
        RealMatrix features2Matrix = features2Matrix(featureArr, d);
        if (d >= 0.0d) {
            features2Matrix = features2Matrix.getSubMatrix(0, features2Matrix.getRowDimension() - 2, 0, features2Matrix.getColumnDimension() - 1);
        }
        return features2Matrix;
    }

    @Override // jml.classification.Classifier
    public RealMatrix predictLabelScoreMatrix(RealMatrix realMatrix) {
        return this.bias >= 0.0d ? realMatrix.transpose().multiply(this.W.getSubMatrix(0, this.W.getRowDimension() - 2, 0, this.W.getColumnDimension() - 1)).add(Matlab.repmat(this.W.getRowMatrix(this.W.getRowDimension() - 1).scalarMultiply(this.bias), new int[]{realMatrix.getColumnDimension(), 1})) : realMatrix.transpose().multiply(this.W);
    }

    public int[] predict(Feature[][] featureArr) {
        return predict(features2MatrixWithoutBias(featureArr, this.bias));
    }

    public void loadSVMModel(String str) {
        System.out.println("Loading SVM model...");
        try {
            ObjectInputStream objectInputStream = new ObjectInputStream(new FileInputStream(str));
            this.model = (de.bwaldvogel.liblinear.Model) objectInputStream.readObject();
            this.IDLabelMap = this.model.getLabels();
            objectInputStream.close();
            System.out.println("SVM model loaded.");
        } catch (FileNotFoundException e) {
            e.printStackTrace();
            System.exit(1);
        } catch (IOException e2) {
            e2.printStackTrace();
        } catch (ClassNotFoundException e3) {
            e3.printStackTrace();
        }
        reconstructProjectionMatrix();
    }

    private void reconstructProjectionMatrix() {
        Array2DRowRealMatrix array2DRowRealMatrix = new Array2DRowRealMatrix(this.model.getFeatureWeights());
        this.nClass = this.model.getNrClass();
        this.W = Matlab.reshape(array2DRowRealMatrix.transpose(), new int[]{this.nClass, this.model.getNrFeature() + 1}).transpose();
    }

    public void saveSVMModel(String str) {
        File parentFile = new File(str).getParentFile();
        if (parentFile != null && !parentFile.exists()) {
            parentFile.mkdirs();
        }
        try {
            ObjectOutputStream objectOutputStream = new ObjectOutputStream(new FileOutputStream(str));
            objectOutputStream.writeObject(this.model);
            objectOutputStream.close();
            System.out.println("SVM model saved.");
        } catch (FileNotFoundException e) {
            e.printStackTrace();
            System.exit(1);
        } catch (IOException e2) {
            e2.printStackTrace();
        }
    }

    @Override // jml.classification.Classifier
    public void loadModel(String str) {
        loadSVMModel(str);
    }

    @Override // jml.classification.Classifier
    public void saveModel(String str) {
        saveSVMModel(str);
    }
}
