package org.openmarkov.io.database.weka;

import antlr.RecognitionException;
import antlr.TokenStreamException;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.io.IOUtils;
import org.apache.log4j.spi.LocationInfo;
import org.openmarkov.core.io.database.CaseDatabase;
import org.openmarkov.core.io.database.CaseDatabaseReader;
import org.openmarkov.core.io.database.CaseDatabaseWriter;
import org.openmarkov.core.io.database.plugin.CaseDatabaseFormat;
import org.openmarkov.core.model.network.ProbNet;
import org.openmarkov.core.model.network.State;
import org.openmarkov.core.model.network.Variable;

/**
 * Case database reader/writer for Weka ARFF files ({@code .arff}).
 *
 * <p>Loading delegates to {@code ArffParser}/{@code ArffLexer}. Saving emits an {@code @RELATION}
 * header (named after the output path), one {@code @ATTRIBUTE} line per variable and an
 * {@code @DATA} section with one comma-separated line per case.
 */
@CaseDatabaseFormat(extension = "arff", name = "WekaDB")
public class ArffDataBaseIO implements CaseDatabaseReader, CaseDatabaseWriter {

    /**
     * State name that stands for a missing value ({@code ?} in ARFF). Such states are left out of
     * attribute declarations but written as-is in the data section. This is the same value the
     * code previously borrowed from log4j's {@code LocationInfo.NA}.
     */
    private static final String MISSING_VALUE = "?";

    /** Raw result map of the last successful {@code ArffParser.relation()} call. */
    private HashMap<String, Object> ioNet;

    /**
     * Parses an ARFF file into a {@link CaseDatabase}.
     *
     * <p>Every entry of the parser's result map is also copied, as a string, into the parsed
     * network's {@code additionalProperties}.
     *
     * @param str path of the ARFF file to read
     * @return the cases read, over the variables of the parsed network
     * @throws IOException if the file does not exist, cannot be read or parsed, or yields no
     *     network; the underlying exception, if any, is attached as the cause
     */
    @Override // org.openmarkov.core.io.database.CaseDatabaseReader
    public CaseDatabase load(String str) throws IOException {
        try (FileInputStream fileInputStream = new FileInputStream(str)) {
            ArffParser arffParser = new ArffParser(new ArffLexer(fileInputStream));
            this.ioNet = arffParser.relation();
            ProbNet probNet = this.ioNet == null ? null : (ProbNet) this.ioNet.get("ProbNet");
            if (probNet == null) {
                // Previously surfaced as a bare NullPointerException further down.
                throw new IOException("No network found in " + str + ".");
            }
            HashMap<String, String> properties = new HashMap<>();
            for (Map.Entry<String, Object> entry : this.ioNet.entrySet()) {
                // String.valueOf tolerates null values instead of throwing an NPE.
                properties.put(entry.getKey(), String.valueOf(entry.getValue()));
            }
            probNet.additionalProperties = properties;
            return new CaseDatabase(probNet.getVariables(), arffParser.getCases());
        } catch (FileNotFoundException e) {
            throw new IOException("File " + str + " not found.", e);
        } catch (RecognitionException e) {
            throw new IOException("RecognitionException in " + str + ".", e);
        } catch (TokenStreamException e) {
            throw new IOException("TokenStreamException in " + str + ".", e);
        }
    }

    /**
     * Writes a case database to {@code str} in ARFF format.
     *
     * <p>The writer is always closed, even if a write fails part-way.
     *
     * @param str path of the file to create or overwrite; also used as the relation name
     * @param caseDatabase variables and cases to write
     * @throws IOException if the file cannot be opened or written
     */
    @Override // org.openmarkov.core.io.database.CaseDatabaseWriter
    public void save(String str, CaseDatabase caseDatabase) throws IOException {
        // NOTE(review): uses the platform default charset, exactly as before. Confirm which
        // encoding ArffLexer expects before switching to an explicit charset.
        try (OutputStreamWriter writer = new OutputStreamWriter(new FileOutputStream(str))) {
            writer.write("\n@RELATION \"" + str + "\"\n");
            for (Variable variable : caseDatabase.getVariables()) {
                writeAttribute(writer, variable);
            }
            writer.write("\n@DATA\n");
            writeData(writer, caseDatabase);
        }
    }

    /**
     * Writes one {@code @ATTRIBUTE} declaration listing the variable's non-missing states.
     * The list is prefixed with {@code numeric} when all of those states parse as integers.
     */
    private static void writeAttribute(OutputStreamWriter writer, Variable variable)
            throws IOException {
        writer.write("\n@ATTRIBUTE " + quoteIfNeeded(variable.getName()) + " ");
        State[] states = variable.getStates();
        writer.write(allIntegerStates(states) ? "numeric {" : "{");
        boolean first = true;
        for (State state : states) {
            String stateName = state.getName();
            if (MISSING_VALUE.equals(stateName)) {
                continue; // the missing-value marker is not a declared value
            }
            if (!first) {
                writer.write(",");
            }
            writer.write(quoteIfNeeded(stateName));
            first = false;
        }
        writer.write("}\n");
    }

    /** Returns true if every non-missing state name parses with {@link Integer#parseInt}. */
    private static boolean allIntegerStates(State[] states) {
        for (State state : states) {
            String stateName = state.getName();
            if (MISSING_VALUE.equals(stateName)) {
                continue;
            }
            try {
                Integer.parseInt(stateName);
            } catch (NumberFormatException e) {
                return false;
            }
        }
        return true;
    }

    /** Writes the {@code @DATA} rows: one line per case, state names separated by commas. */
    private static void writeData(OutputStreamWriter writer, CaseDatabase caseDatabase)
            throws IOException {
        List<Variable> variables = caseDatabase.getVariables();
        for (int[] row : caseDatabase.getCases()) {
            for (int column = 0; column < row.length; column++) {
                // row[column] is used as an index into that variable's state array.
                State[] states = variables.get(column).getStates();
                writer.write(quoteIfNeeded(states[row[column]].getName()));
                if (column != row.length - 1) {
                    writer.write(",");
                }
            }
            writer.write(IOUtils.LINE_SEPARATOR_UNIX);
        }
    }

    /** Wraps {@code value} in double quotes if it contains a space; otherwise returns it as-is. */
    private static String quoteIfNeeded(String value) {
        return value.contains(" ") ? "\"" + value + "\"" : value;
    }
}
