package edu.gtts.sautrela.wfsa.models;

import edu.gtts.sautrela.ManPage;
import edu.gtts.sautrela.util.GetOpt;
import edu.gtts.sautrela.util.RandomFactory;
import edu.gtts.sautrela.util.SAXHandler;
import edu.gtts.sautrela.util.SyntaxError;
import edu.gtts.sautrela.util.XML;
import edu.gtts.sautrela.util.XMLBuilder;
import edu.gtts.sautrela.wfsa.Alphabet;
import edu.gtts.sautrela.wfsa.NdWFSA;
import edu.gtts.sautrela.wfsa.Probability;
import edu.gtts.sautrela.wfsa.Util;
import edu.gtts.sautrela.wfsa.WFSA;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Random;
import javax.xml.parsers.ParserConfigurationException;
import javax.xml.parsers.SAXParserFactory;
import org.xml.sax.Attributes;
import org.xml.sax.InputSource;
import org.xml.sax.SAXException;

/**
 * Discrete Hidden Markov Model exposed as a non-deterministic weighted finite-state automaton.
 *
 * <p>The automaton's states are the HMM states plus a synthetic initial state. An arc
 * {@code s --x--> t} carries the log probability of the HMM transition {@code s -> t} (or the
 * initial probability of {@code t} when {@code s} is the synthetic initial state) plus the log
 * probability of {@code t} emitting symbol {@code x}. Unset probabilities are stored as
 * {@link Double#NEGATIVE_INFINITY} and produce no arc.
 *
 * <p>Models are read from XML made of {@code WFSA}, {@code IniState}, {@code FinState},
 * {@code Transition} and {@code Emission} elements. {@link #toXML()} writes the same format.
 */
public class DHMM implements NdWFSA<State, Symbol, DefaultTransition<State, Symbol>> {
    /** Model name (the {@code name} attribute of the {@code WFSA} element). */
    private String name = null;
    /** Number of HMM states, excluding the synthetic initial state. */
    private int hmmSize;
    /** Declared codebook size: the maximum number of emission symbols this model accepts. */
    private int cdbkSize;
    private DefaultAlphabet<Symbol> alphabet = null;
    /** Synthetic initial state; its outgoing weights are taken from {@link #hmmPIni}. */
    private State iniState = null;
    /** HMM states in creation order; a state's position equals its {@code index}. */
    private State[] hmmStates = null;
    /** Log initial-state probabilities, indexed by state index. */
    private double[] hmmPIni = null;
    private final Random random;
    /** Expected counts for the initial probabilities, accumulated between training calls. */
    private double[] iniTrainCount;

    public static final WFSA.Factory myFactory = new WFSA.Factory() {
        @Override
        public WFSA getInstance(InputSource inputSource) throws ParserConfigurationException, SAXException, IOException {
            return new DHMM(inputSource);
        }

        @Override
        public WFSA getInstance(InputSource inputSource, Alphabet alphabet) throws ParserConfigurationException, SAXException, IOException {
            return new DHMM(inputSource, alphabet);
        }
    };

    /** An HMM state. States are created lazily, the first time a name is mentioned. */
    public class State extends DefaultState {
        /** Position in {@code hmmStates}; the synthetic initial state uses {@code hmmSize}. */
        private final int index;
        /** Log transition probabilities per destination index; slot {@code hmmSize} is the final probability. */
        private final double[] hmmPTrans;
        /** Log emission probabilities, indexed by symbol index. */
        private final double[] hmmPEmis;
        /** Lazily built, sorted outgoing transitions, cached per emission-symbol index. */
        private final DefaultTransition<State, Symbol>[][] symbolTrans;
        /** Lazily built, sorted array of all outgoing transitions. */
        private DefaultTransition<State, Symbol>[] allTrans = null;
        private double[] transTrainCount = null;
        private double[] emisTrainCount = null;

        /**
         * Creates a state with all transition, final and emission probabilities unset.
         *
         * @param stateName  state name
         * @param stateIndex position of the state in the model
         */
        @SuppressWarnings("unchecked") // generic array creation is not expressible in Java
        public State(String stateName, int stateIndex) {
            setName(stateName);
            this.index = stateIndex;
            this.hmmPTrans = new double[DHMM.this.hmmSize + 1];
            this.hmmPEmis = new double[DHMM.this.cdbkSize];
            Arrays.fill(this.hmmPTrans, Double.NEGATIVE_INFINITY);
            Arrays.fill(this.hmmPEmis, Double.NEGATIVE_INFINITY);
            this.symbolTrans = new DefaultTransition[DHMM.this.cdbkSize][];
        }

        /** Orders states by their index within the model. */
        public int compareTo(State state) {
            return Integer.compare(this.index, state.index);
        }

        /** Drops the cached transition arrays so they are rebuilt from the current probabilities. */
        private void clearTransCache() {
            this.allTrans = null;
            Arrays.fill(this.symbolTrans, null);
        }

        /** Clears transition caches and allocates zeroed training-count arrays. */
        public void initTrainCounts() {
            clearTransCache();
            this.transTrainCount = new double[DHMM.this.hmmSize + 1];
            this.emisTrainCount = new double[DHMM.this.cdbkSize];
        }

        /**
         * Folds the accumulated counts into this state's probability arrays via
         * {@link Util#doGopalakrishnan}.
         *
         * <p>The cached transitions are then discarded, because their weights were computed from
         * the previous probabilities.
         */
        public void dumpTrainCounts() {
            Util.doGopalakrishnan(this.hmmPTrans, this.transTrainCount);
            Util.doGopalakrishnan(this.hmmPEmis, this.emisTrainCount);
            clearTransCache();
        }
    }

    /** An emission symbol; {@code index} addresses the emission arrays of every state. */
    public class Symbol extends DefaultSymbol {
        private final int index;

        public Symbol(String symbolName, int symbolIndex) {
            super(symbolName);
            this.index = symbolIndex;
        }

        /** Orders symbols by their index within the codebook. */
        public int compareTo(Symbol symbol) {
            return Integer.compare(this.index, symbol.index);
        }
    }

    /** SAX handler that populates the enclosing model from its XML description. */
    private class XmlHandler extends SAXHandler {
        public XmlHandler() {
            super(null);
        }

        /**
         * Reads the probability of an element as a log probability.
         *
         * <p>The {@code linProb} attribute is converted from the linear domain. It takes
         * precedence over {@code logProb}.
         *
         * @throws SAXHandler.ParseException if neither attribute is present or its value is not a number
         */
        private double parseProb(Attributes attributes) throws SAXHandler.ParseException {
            String linValue = attributes.getValue("linProb");
            if (linValue != null) {
                try {
                    return Probability.lin2log(linValue);
                } catch (NumberFormatException e) {
                    throw new SAXHandler.ParseException("Wrong double format: \"" + linValue + "\" for attribute \"linProb\"");
                }
            }
            String logValue = attributes.getValue("logProb");
            if (logValue == null) {
                throw new SAXHandler.ParseException("Attributes not found:  \"linProb\" or \"logProb\"");
            }
            try {
                return Probability.log2log(logValue);
            } catch (NumberFormatException e) {
                throw new SAXHandler.ParseException("Wrong double format: \"" + logValue + "\" for attribute \"logProb\"");
            }
        }

        @Override
        public void startElement(String uri, String localName, String qName, Attributes attributes) throws SAXException {
            switch (qName) {
                case "Emission":
                    DHMM.this.setHMMEmis(parse("state", attributes), parse("symbol", attributes), parseProb(attributes));
                    break;
                case "Transition":
                    DHMM.this.setHMMTrans(parse("from", attributes), parse("to", attributes), parseProb(attributes));
                    break;
                case "IniState":
                    DHMM.this.setHMMIniState(parse("name", attributes), parseProb(attributes));
                    break;
                case "FinState":
                    DHMM.this.setHMMFinState(parse("name", attributes), parseProb(attributes));
                    break;
                case "WFSA":
                    // The root element must precede all others: it sizes the state/probability arrays.
                    DHMM.this.name = attributes.getValue("name");
                    DHMM.this._setStaticSizes(parseInt("size", attributes), parseInt("cdbkSize", attributes));
                    break;
                default:
                    throw new SAXHandler.ParseException("Unknown element name: \"" + qName + "\"");
            }
        }
    }

    private DHMM() {
        this(new DefaultAlphabet<Symbol>());
    }

    /**
     * Creates an empty model bound to {@code alphabet}.
     *
     * @throws UnsupportedOperationException if {@code alphabet} is not a {@link DefaultAlphabet}
     */
    private DHMM(Alphabet<Symbol> alphabet) {
        this.random = RandomFactory.newRandom();
        if (!(alphabet instanceof DefaultAlphabet)) {
            throw new UnsupportedOperationException("Unsupported Alphabet type: " + (alphabet == null ? null : alphabet.getClass()));
        }
        this.alphabet = (DefaultAlphabet<Symbol>) alphabet;
    }

    /**
     * Allocates the per-model arrays and the synthetic initial state.
     *
     * <p>All initial probabilities start unset.
     *
     * @param size     number of HMM states
     * @param cdbkSize maximum number of emission symbols
     */
    public void _setStaticSizes(int size, int cdbkSize) {
        this.hmmSize = size;
        this.cdbkSize = cdbkSize;
        this.hmmStates = new State[size];
        this.hmmPIni = new double[size];
        Arrays.fill(this.hmmPIni, Double.NEGATIVE_INFINITY);
        this.iniState = new State("_ndWFSA_INISTATE_", size);
    }

    public void setName(String str) {
        this.name = str;
    }

    @Override
    public String toString() {
        return this.name + " (HMM," + this.hmmSize + " states," + this.alphabet.size() + " symbols)";
    }

    /**
     * Returns the state called {@code stateName}, creating it in the first free slot if needed.
     *
     * @throws UnsupportedOperationException if the state is new and all declared slots are taken
     */
    private State addStateIfNew(String stateName) {
        int slot = 0;
        for (; slot < this.hmmStates.length && this.hmmStates[slot] != null; slot++) {
            if (this.hmmStates[slot].getName().equals(stateName)) {
                return this.hmmStates[slot];
            }
        }
        if (slot >= this.hmmStates.length) {
            throw new UnsupportedOperationException("Cannot add state \"" + stateName + "\", declared size is " + this.hmmSize);
        }
        State created = new State(stateName, slot);
        this.hmmStates[slot] = created;
        return created;
    }

    /**
     * Returns the symbol called {@code symbolName}, adding it to the alphabet if needed.
     *
     * @throws UnsupportedOperationException if the symbol is new and the alphabet already holds
     *     {@code cdbkSize} symbols
     */
    private Symbol addSymbolIfNew(String symbolName) {
        Symbol symbol = this.alphabet.valueOf(symbolName);
        if (symbol == null) {
            if (this.alphabet.size() >= this.cdbkSize) {
                throw new UnsupportedOperationException("Cannot add symbol \"" + symbolName + "\", declared cdbkSize is " + this.cdbkSize);
            }
            symbol = new Symbol(symbolName, this.alphabet.size());
            this.alphabet.add(symbol);
        }
        return symbol;
    }

    /** Sets the log initial probability of state {@code str}, creating the state if needed. */
    public void setHMMIniState(String str, double d) {
        this.hmmPIni[addStateIfNew(str).index] = d;
    }

    /** Sets the log final probability of state {@code str}, creating the state if needed. */
    public void setHMMFinState(String str, double d) {
        addStateIfNew(str).hmmPTrans[this.hmmSize] = d;
    }

    /** Sets the log transition probability {@code str -> str2}, creating states if needed. */
    public void setHMMTrans(String str, String str2, double d) {
        addStateIfNew(str).hmmPTrans[addStateIfNew(str2).index] = d;
    }

    /**
     * Sets the log probability that state {@code str} emits symbol {@code str2}.
     *
     * <p>The state and the symbol are created if needed.
     */
    public void setHMMEmis(String str, String str2, double d) {
        addStateIfNew(str).hmmPEmis[addSymbolIfNew(str2).index] = d;
    }

    @Override
    public String getName() {
        return this.name;
    }

    /** Looks up an emission symbol by name; returns whatever the alphabet returns when absent. */
    public Symbol getSymbolByName(String str) {
        return this.alphabet.valueOf(str);
    }

    @Override
    public Alphabet<Symbol> getAlphabet() {
        return this.alphabet;
    }

    /** Returns the log final probability of {@code state} ({@code -Infinity} for the initial state). */
    @Override
    public double getFinProb(State state) {
        return state.hmmPTrans[this.hmmSize];
    }

    @Override
    public State getIniState() {
        return this.iniState;
    }

    /**
     * Returns every outgoing transition of {@code state}, over all symbols of the alphabet, sorted.
     *
     * <p>The result is cached on the state.
     */
    @Override
    @SuppressWarnings("unchecked") // generic array creation is not expressible in Java
    public DefaultTransition<State, Symbol>[] getTrans(State state) {
        if (state.allTrans == null) {
            ArrayList<DefaultTransition<State, Symbol>> all = new ArrayList<>(this.hmmSize * this.cdbkSize);
            Iterator<Symbol> symbols = this.alphabet.iterator();
            while (symbols.hasNext()) {
                all.addAll(Arrays.asList(getTrans(state, symbols.next())));
            }
            DefaultTransition<State, Symbol>[] sorted = all.toArray(new DefaultTransition[all.size()]);
            Arrays.sort(sorted);
            state.allTrans = sorted;
        }
        return state.allTrans;
    }

    /**
     * Samples one outgoing transition of {@code state}.
     *
     * <p>First a destination (or the final event) is drawn, then the symbol emitted by the
     * destination.
     *
     * @return the sampled transition, or {@code null} when the final event is drawn
     */
    @Override
    public DefaultTransition<State, Symbol> getRandomTrans(State state) {
        // The initial state's outgoing probabilities live in hmmPIni, not in its own hmmPTrans.
        double[] outProbs = state == this.iniState ? this.hmmPIni : state.hmmPTrans;
        int next = Util.getRandomProbIndex(outProbs, this.random);
        if (next == this.hmmSize) {
            return null;
        }
        State dest = this.hmmStates[next];
        int emitted = Util.getRandomProbIndex(dest.hmmPEmis, this.random);
        Iterator<Symbol> symbols = this.alphabet.iterator();
        while (symbols.hasNext()) {
            Symbol symbol = symbols.next();
            if (symbol.index == emitted) {
                return new DefaultTransition<>(state, symbol, dest, outProbs[next] + dest.hmmPEmis[emitted]);
            }
        }
        throw new RuntimeException("No symbol found for index " + emitted);
    }

    public Iterable<DefaultTransition<State, Symbol>> getTrans2(State state) throws UnsupportedOperationException {
        throw new UnsupportedOperationException();
    }

    public Iterable<DefaultTransition<State, Symbol>> getTrans2(State state, Symbol symbol) throws UnsupportedOperationException {
        throw new UnsupportedOperationException();
    }

    /**
     * Returns the sorted transitions of {@code state} labelled with {@code symbol}.
     *
     * <p>Each transition's weight is the log transition (or initial) probability plus the
     * destination's log emission probability. Destinations where either term is unset are
     * omitted. The result is cached per symbol on the state.
     */
    @Override
    @SuppressWarnings("unchecked") // generic array creation is not expressible in Java
    public DefaultTransition<State, Symbol>[] getTrans(State state, Symbol symbol) {
        int sym = symbol.index;
        if (state.symbolTrans[sym] == null) {
            double[] outProbs = state == this.iniState ? this.hmmPIni : state.hmmPTrans;
            ArrayList<DefaultTransition<State, Symbol>> found = new ArrayList<>(this.hmmSize);
            for (State dest : this.hmmStates) {
                double pTrans = outProbs[dest.index];
                double pEmis = dest.hmmPEmis[sym];
                if (pTrans != Double.NEGATIVE_INFINITY && pEmis != Double.NEGATIVE_INFINITY) {
                    found.add(new DefaultTransition<>(state, symbol, dest, pTrans + pEmis));
                }
            }
            DefaultTransition<State, Symbol>[] sorted = found.toArray(new DefaultTransition[found.size()]);
            Arrays.sort(sorted);
            state.symbolTrans[sym] = sorted;
        }
        return state.symbolTrans[sym];
    }

    /**
     * Starts a training pass.
     *
     * <p>Zeroes all expected counts and drops cached transitions. Only a prior of
     * {@code Probability.oneLogProb} is supported.
     */
    @Override
    public void priorExpectation(double d) {
        if (d != Probability.oneLogProb) {
            throw new UnsupportedOperationException("Cannot manage noncero init train count");
        }
        this.iniTrainCount = new double[this.hmmSize];
        for (State state : this.hmmStates) {
            state.initTrainCounts();
        }
        this.iniState.initTrainCounts();
    }

    /** Adds {@code d} to the expected count of {@code state}'s final event. */
    @Override
    public void addFinalExpectation(State state, double d) {
        state.transTrainCount[this.hmmSize] += d;
    }

    /**
     * Adds {@code d} to the expected counts of the transition and of the destination's emission.
     *
     * <p>The transition count goes to the initial-probability counts when the source is the
     * initial state.
     */
    @Override
    public void addExpectation(DefaultTransition<State, Symbol> defaultTransition, double d) {
        int destIndex = defaultTransition.destination.index;
        if (defaultTransition.source == this.iniState) {
            this.iniTrainCount[destIndex] += d;
        } else {
            defaultTransition.source.transTrainCount[destIndex] += d;
        }
        defaultTransition.destination.emisTrainCount[defaultTransition.symbol.index] += d;
    }

    /**
     * Ends a training pass by folding all expected counts into the model probabilities.
     *
     * <p>Every transition cache is invalidated, including the initial state's, because it depends
     * on {@code hmmPIni}.
     */
    @Override
    public void dumpSuffStats() {
        Util.doGopalakrishnan(this.hmmPIni, this.iniTrainCount);
        for (State state : this.hmmStates) {
            state.dumpTrainCounts();
        }
        this.iniState.clearTransCache();
    }

    public DHMM(InputSource inputSource) throws ParserConfigurationException, SAXException, IOException {
        this();
        SAXParserFactory.newInstance().newSAXParser().parse(inputSource, new XmlHandler());
    }

    public DHMM(InputSource inputSource, Alphabet<Symbol> alphabet) throws ParserConfigurationException, SAXException, IOException {
        this(alphabet);
        SAXParserFactory.newInstance().newSAXParser().parse(inputSource, new XmlHandler());
    }

    /**
     * Serializes the model in the XML format accepted by the constructors.
     *
     * <p>Only probabilities that are set are written, as {@code linProb} attributes.
     */
    @Override
    public String toXML() {
        XMLBuilder xml = new XMLBuilder();
        xml.openElement("WFSA", "name", getName(), "size", Integer.valueOf(this.hmmSize), "cdbkSize", Integer.valueOf(this.alphabet.size()), "className", getClass().getName());
        for (State state : this.hmmStates) {
            double pIni = this.hmmPIni[state.index];
            if (pIni != Double.NEGATIVE_INFINITY) {
                xml.appendElement("IniState", "name", state.getName(), "linProb", Double.valueOf(Probability.log2lin(pIni)));
            }
        }
        for (State state : this.hmmStates) {
            double pFin = state.hmmPTrans[this.hmmSize];
            if (pFin != Double.NEGATIVE_INFINITY) {
                xml.appendElement("FinState", "name", state.getName(), "linProb", Double.valueOf(Probability.log2lin(pFin)));
            }
        }
        for (State from : this.hmmStates) {
            for (State to : this.hmmStates) {
                double pTrans = from.hmmPTrans[to.index];
                if (pTrans != Double.NEGATIVE_INFINITY) {
                    xml.appendElement("Transition", "from", from.getName(), "to", to.getName(), "linProb", Double.valueOf(Probability.log2lin(pTrans)));
                }
            }
        }
        for (State state : this.hmmStates) {
            Iterator<Symbol> symbols = this.alphabet.iterator();
            while (symbols.hasNext()) {
                Symbol symbol = symbols.next();
                double pEmis = state.hmmPEmis[symbol.index];
                if (pEmis != Double.NEGATIVE_INFINITY) {
                    xml.appendElement("Emission", "state", state.getName(), "symbol", symbol.getName(), "linProb", Double.valueOf(Probability.log2lin(pEmis)));
                }
            }
        }
        xml.closeElement();
        return xml.toString();
    }

    public static String getManPage() {
        return new ManPage(DHMM.class, "create a discrete Hidden Markov Model", "[-E] [-n name] [-r from,to] [-s symbol[,symbol ...]] size", "-E", "Use an ergodic topology. Default topology is linear with loops.", "-n name", "The name of the DHMM", "-r from,to", "Integer range of emission symbols", "-s symbol[,symbol ...]", "List of emission symbols", "size", "Number of states").toString();
    }

    /** Adds {@code symbol} to {@code symbols}, printing a warning on stderr if it is already there. */
    private static void addSymbolOption(HashSet<String> symbols, String symbol) {
        if (!symbols.add(symbol)) {
            System.err.println("Warning: duplicated symbol \"" + symbol + "\".");
        }
    }

    /**
     * Command-line entry point: builds a DHMM and writes its XML via {@code XML.write}.
     *
     * <p>States are named {@code "0".."size-1"}. Emissions are uniform over the given symbols.
     * The topology is either linear with self-loops (default) or ergodic with random
     * probabilities ({@code -E}).
     */
    public static void main(String[] strArr) throws ParserConfigurationException, SAXException, IOException {
        Random rnd = RandomFactory.newRandom();
        String modelName = "noName";
        String rangeArg = null;
        String symbolsArg = null;
        boolean ergodic = false;
        HashSet<String> symbols = new HashSet<>();
        GetOpt getOpt = new GetOpt(strArr, "En:r:s:");
        String numberText = null; // text currently being parsed as an int, reported on failure
        try {
            int opt;
            while ((opt = getOpt.getOpt()) != -1) {
                switch (opt) {
                    case 'E':
                        ergodic = true;
                        break;
                    case 'n':
                        modelName = getOpt.getOptArg();
                        break;
                    case 'r':
                        rangeArg = getOpt.getOptArg();
                        break;
                    case 's':
                        symbolsArg = getOpt.getOptArg();
                        break;
                    default:
                        break;
                }
            }
            if (rangeArg != null) {
                String[] bounds = rangeArg.split(",");
                if (bounds.length != 2) {
                    throw new SyntaxError("Wrong syntax for option \"-r\" : " + rangeArg);
                }
                numberText = bounds[0];
                int from = Integer.parseInt(numberText);
                numberText = bounds[1];
                int to = Integer.parseInt(numberText);
                // long counter: an int would overflow and loop forever when to == Integer.MAX_VALUE
                for (long v = from; v <= to; v++) {
                    addSymbolOption(symbols, Long.toString(v));
                }
            }
            if (symbolsArg != null) {
                for (String symbol : symbolsArg.split(",")) {
                    addSymbolOption(symbols, symbol);
                }
            }
            numberText = getOpt.getArgs(1)[0];
            int size = Integer.parseInt(numberText);
            if (symbols.isEmpty()) {
                throw new SyntaxError("There is no emission symbol");
            }
            if (size < 1) {
                throw new SyntaxError("Number of states must be positive: " + size);
            }
            String[] states = new String[size];
            for (int s = 0; s < size; s++) {
                states[s] = Integer.toString(s);
            }
            DHMM dhmm = new DHMM();
            dhmm.name = modelName;
            dhmm._setStaticSizes(size, symbols.size());
            if (ergodic) {
                double[] pIni = Probability.randomEQLogProbs(size, rnd);
                for (int s = 0; s < size; s++) {
                    // size transition probabilities followed by the final probability
                    double[] pOut = Probability.randomEQLogProbs(size + 1, rnd);
                    dhmm.setHMMIniState(states[s], pIni[s]);
                    for (int t = 0; t < size; t++) {
                        dhmm.setHMMTrans(states[s], states[t], pOut[t]);
                    }
                    dhmm.setHMMFinState(states[s], pOut[size]);
                }
            } else {
                double half = Probability.lin2log(0.5d);
                int last = size - 1;
                dhmm.setHMMIniState(states[0], Probability.oneLogProb);
                for (int s = 0; s < last; s++) {
                    dhmm.setHMMTrans(states[s], states[s], half);
                    dhmm.setHMMTrans(states[s], states[s + 1], half);
                }
                dhmm.setHMMTrans(states[last], states[last], half);
                dhmm.setHMMFinState(states[last], Probability.oneLogProb);
            }
            double uniform = Probability.lin2log(1.0d / symbols.size());
            for (String state : states) {
                for (String symbol : symbols) {
                    dhmm.setHMMEmis(state, symbol, uniform);
                }
            }
            XML.write(dhmm.toXML());
        } catch (NumberFormatException e) {
            throw new SyntaxError("Wrong number format \"" + numberText + "\"");
        }
    }
}
