package org.openmarkov.inference.likelihoodWeighting;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.Stack;
import org.openmarkov.core.exception.IncompatibleEvidenceException;
import org.openmarkov.core.exception.InvalidStateException;
import org.openmarkov.core.exception.NormalizeNullVectorException;
import org.openmarkov.core.exception.NotEvaluableNetworkException;
import org.openmarkov.core.exception.UnexpectedInferenceException;
import org.openmarkov.core.inference.InferenceAlgorithm;
import org.openmarkov.core.inference.annotation.InferenceAnnotation;
import org.openmarkov.core.model.graph.Graph;
import org.openmarkov.core.model.graph.Link;
import org.openmarkov.core.model.graph.Node;
import org.openmarkov.core.model.network.EvidenceCase;
import org.openmarkov.core.model.network.Finding;
import org.openmarkov.core.model.network.NodeType;
import org.openmarkov.core.model.network.ProbNet;
import org.openmarkov.core.model.network.ProbNode;
import org.openmarkov.core.model.network.Variable;
import org.openmarkov.core.model.network.potential.Potential;
import org.openmarkov.core.model.network.potential.PotentialRole;
import org.openmarkov.core.model.network.potential.TablePotential;
import org.openmarkov.core.model.network.potential.operation.DiscretePotentialOperations;
import org.openmarkov.core.model.network.type.BayesianNetworkType;
import org.openmarkov.core.model.network.type.TuningNetworkType;
import sun.reflect.generics.reflectiveObjects.NotImplementedException;

/**
 * Approximate inference by likelihood weighting.
 *
 * <p>Each sample starts from the findings of the evidence case. Every unobserved variable that is
 * sampled is drawn, in topological order, from its own potential given the values already assigned.
 * The sample is then weighted by the probability that the potentials of the observed variables, and
 * any link-restriction potentials, give to the sampled configuration. The weighted samples are
 * accumulated into the requested estimates and normalized at the end.
 *
 * <p>Only Bayesian networks and tuning networks are accepted. In tuning networks, decision variables
 * without a finding are assigned a default state before sampling.
 */
@InferenceAnnotation(name = "LikelihoodWeighting")
public class LikelihoodWeighting extends InferenceAlgorithm {
    /** Number of samples drawn per inference call unless changed with {@link #setSampleSize(int)}. */
    private static final int DEFAULT_SAMPLE_SIZE = 10000;

    /** Optional network property whose integer value divides the tables of unobserved utilities. */
    private static final String UNOBSERVED_UTILITY_PENALTY_PROPERTY = "unobservedUtilityPenalty";

    /** State index given to unobserved decision variables of tuning networks. */
    private static final int DEFAULT_DECISION_STATE_INDEX = 1;

    /** Number of samples drawn by each inference call. */
    private int sampleSize;

    /**
     * Sum of the sample weights of the most recent inference call. Constant factors contributed by
     * single-variable evidence potentials are not included (see {@link #computeSampleWeight}).
     */
    private double accumulatedWeight;

    /** Number of samples of the most recent inference call whose weight was strictly positive. */
    private int positiveSampleCount;

    /**
     * Creates the algorithm for the given network with the default sample size.
     *
     * @param probNet network to evaluate
     * @throws NotEvaluableNetworkException propagated from the superclass constructor
     */
    public LikelihoodWeighting(ProbNet probNet) throws NotEvaluableNetworkException {
        super(probNet);
        this.sampleSize = DEFAULT_SAMPLE_SIZE;
        this.accumulatedWeight = 0.0d;
        this.positiveSampleCount = 0;
    }

    /**
     * Returns whether the network is a Bayesian network or a tuning network.
     *
     * @param probNet network to check
     * @return {@code true} if this algorithm accepts the network type
     */
    @Override
    public boolean isEvaluable(ProbNet probNet) {
        return isSupportedNetworkType(probNet);
    }

    /**
     * Throws if the network is neither a Bayesian network nor a tuning network.
     *
     * @param probNet network to check
     * @throws NotEvaluableNetworkException if the network type is not supported
     */
    public static void checkEvaluability(ProbNet probNet) throws NotEvaluableNetworkException {
        if (!isSupportedNetworkType(probNet)) {
            throw new NotEvaluableNetworkException("");
        }
    }

    /** Shared network-type test used by {@link #isEvaluable} and {@link #checkEvaluability}. */
    private static boolean isSupportedNetworkType(ProbNet probNet) {
        Object networkType = probNet.getNetworkType();
        return networkType.equals(BayesianNetworkType.getUniqueInstance())
                || networkType.equals(TuningNetworkType.getUniqueInstance());
    }

    /**
     * Estimates the posterior distribution of each requested chance variable and the expected value
     * of each requested utility variable, given the pre- and post-resolution evidence.
     *
     * <p>All unobserved variables of the network are sampled, not only the requested ones. A variable
     * can only be drawn once its parents have values, and the evidence weights read the parents of the
     * observed variables too.
     *
     * @param list variables whose estimates are returned
     * @return for each chance variable a normalized {@code JOINT_PROBABILITY} table; for each utility
     *     variable a zero-dimensional {@code UTILITY} table holding the weighted mean utility
     * @throws IncompatibleEvidenceException if the evidence cannot be fused or set, or if every sample
     *     received zero weight so a posterior cannot be normalized
     */
    @Override
    public HashMap<Variable, TablePotential> getProbsAndUtilities(List<Variable> list)
            throws IncompatibleEvidenceException {
        EvidenceCase evidence = prepareEvidence();
        List<Potential> linkRestrictions = buildLinkRestrictionList(this.probNet);
        List<Variable> sampledVariables =
                sortTopologically(this.probNet, getVariablesToSample(this.probNet.getVariables(), evidence));
        List<Potential> samplingPotentials = collectPotentials(sampledVariables);
        applyUnobservedUtilityPenalty(samplingPotentials, evidence);
        List<Potential> evidencePotentials = collectPotentials(evidence.getVariables());
        HashMap<Variable, Integer> configuration = evidenceConfiguration(evidence);

        // Record the node type explicitly. Inferring it from the accumulator length misclassified
        // single-state chance variables as utilities.
        int requested = list.size();
        boolean[] isUtility = new boolean[requested];
        double[][] accumulators = new double[requested][];
        for (int i = 0; i < requested; i++) {
            Variable variable = list.get(i);
            isUtility[i] = this.probNet.getProbNode(variable).getNodeType() == NodeType.UTILITY;
            accumulators[i] = isUtility[i] ? new double[1] : new double[variable.getNumStates()];
        }

        Random random = new Random();
        HashMap<Variable, Double> utilities = new HashMap<>();
        resetStatistics();
        for (int sample = 0; sample < this.sampleSize; sample++) {
            for (Potential potential : samplingPotentials) {
                if (potential.getPotentialRole() == PotentialRole.UTILITY) {
                    utilities.put(potential.getUtilityVariable(), potential.getUtility(configuration, utilities));
                } else {
                    configuration.put(potential.getVariable(0), potential.sample(random, configuration));
                }
            }
            double weight = computeSampleWeight(evidencePotentials, linkRestrictions, configuration);
            for (int i = 0; i < requested; i++) {
                Variable variable = list.get(i);
                if (isUtility[i]) {
                    accumulators[i][0] += utilities.get(variable) * weight;
                } else {
                    accumulators[i][configuration.get(variable)] += weight;
                }
            }
            recordSample(weight);
        }

        HashMap<Variable, TablePotential> result = new HashMap<>();
        for (int i = 0; i < requested; i++) {
            Variable variable = list.get(i);
            if (isUtility[i]) {
                TablePotential utility =
                        new TablePotential(new ArrayList<Variable>(), PotentialRole.UTILITY, variable);
                utility.values[0] = accumulators[i][0] / this.accumulatedWeight;
                result.put(variable, utility);
            } else {
                TablePotential posterior = new TablePotential(PotentialRole.JOINT_PROBABILITY, variable);
                posterior.values = accumulators[i];
                try {
                    result.put(variable, DiscretePotentialOperations.normalize(posterior));
                } catch (NormalizeNullVectorException e) {
                    throw new IncompatibleEvidenceException(e.getMessage());
                }
            }
        }
        return result;
    }

    /**
     * Estimates posteriors and expected utilities for every variable of the network.
     *
     * @return see {@link #getProbsAndUtilities(List)}
     * @throws IncompatibleEvidenceException see {@link #getProbsAndUtilities(List)}
     */
    @Override
    public HashMap<Variable, TablePotential> getProbsAndUtilities() throws IncompatibleEvidenceException {
        return getProbsAndUtilities(this.probNet.getVariables());
    }

    /**
     * Returns the variables of {@code list} in a topological order of the network graph.
     *
     * <p>The order comes from Kahn's algorithm on a copy of the graph, using a LIFO work list. Nodes
     * that never become parentless, for example nodes on a cycle, are omitted.
     *
     * @param probNet network whose graph defines the order
     * @param list variables to order
     * @return the variables of {@code list} that appear in the graph, parents before children
     */
    public static List<Variable> sortTopologically(ProbNet probNet, List<Variable> list) {
        Set<Variable> wanted = new HashSet<>(list);
        Graph graph = probNet.getGraph().copy();
        Deque<Node> ready = new ArrayDeque<>();
        for (Node node : graph.getNodes()) {
            if (node.getParents().isEmpty()) {
                ready.push(node);
            }
        }
        List<Variable> sorted = new ArrayList<>(list.size());
        while (!ready.isEmpty()) {
            Node node = ready.pop();
            Variable variable = ((ProbNode) node.getObject()).getVariable();
            if (wanted.contains(variable)) {
                sorted.add(variable);
            }
            // Iterate over a snapshot because removing links mutates the graph being traversed.
            for (Node child : new ArrayList<Node>(node.getChildren())) {
                graph.removeLink(node, child, true);
                if (child.getParents().isEmpty()) {
                    ready.push(child);
                }
            }
        }
        return sorted;
    }

    /** Not supported by this algorithm. */
    @Override
    public TablePotential getGlobalUtility() throws IncompatibleEvidenceException {
        throw new NotImplementedException();
    }

    /**
     * Estimates the joint probability of {@code list} given the evidence.
     *
     * <p>Only chance variables are sampled. Each variable in {@code list} is assumed to be either a
     * chance variable or observed. TODO confirm with callers.
     *
     * @param list variables of the joint distribution, in table order
     * @return normalized {@code JOINT_PROBABILITY} table over {@code list}
     * @throws IncompatibleEvidenceException if the evidence cannot be fused or set, or if every sample
     *     received zero weight
     */
    @Override
    public TablePotential getJointProbability(List<Variable> list) throws IncompatibleEvidenceException {
        int[] offsets = TablePotential.calculateOffsets(TablePotential.calculateDimensions(list));
        EvidenceCase evidence = prepareEvidence();
        List<Potential> linkRestrictions = buildLinkRestrictionList(this.probNet);
        List<Variable> sampledVariables = sortTopologically(
                this.probNet, getVariablesToSample(this.probNet.getVariables(NodeType.CHANCE), evidence));
        List<Potential> samplingPotentials = collectPotentials(sampledVariables);
        List<Potential> evidencePotentials = collectPotentials(evidence.getVariables());
        HashMap<Variable, Integer> configuration = evidenceConfiguration(evidence);

        TablePotential joint = new TablePotential(list, PotentialRole.JOINT_PROBABILITY);
        // Start from zero explicitly, matching getFamilyJointProbabilities; weights are accumulated
        // directly into this table.
        Arrays.fill(joint.values, 0.0d);
        resetStatistics();
        Random random = new Random();
        for (int sample = 0; sample < this.sampleSize; sample++) {
            sampleVariables(samplingPotentials, random, configuration);
            double weight = computeSampleWeight(evidencePotentials, linkRestrictions, configuration);
            joint.values[getSampleIndex(list, offsets, configuration)] += weight;
            recordSample(weight);
        }
        try {
            DiscretePotentialOperations.normalize(joint);
            return joint;
        } catch (NormalizeNullVectorException e) {
            throw new IncompatibleEvidenceException(e.getMessage());
        }
    }

    /**
     * Estimates, for every chance variable, the joint probability of the variables of its potential.
     *
     * <p>Sampled variables get the weighted sample counts over their potential's variables. Observed
     * variables get the product of their potential with the evidence indicators of the observed
     * variables it mentions.
     *
     * @return map from each family's head variable (variable 0 of its potential) to a normalized
     *     {@code JOINT_PROBABILITY} table
     * @throws IncompatibleEvidenceException if the evidence cannot be fused or set, or if a family
     *     table cannot be normalized
     */
    public Map<Variable, TablePotential> getFamilyJointProbabilities() throws IncompatibleEvidenceException {
        EvidenceCase evidence = prepareEvidence();
        List<Potential> linkRestrictions = buildLinkRestrictionList(this.probNet);
        List<Variable> sampledVariables = sortTopologically(
                this.probNet, getVariablesToSample(this.probNet.getVariables(NodeType.CHANCE), evidence));
        List<Potential> samplingPotentials = collectPotentials(sampledVariables);
        List<Potential> evidencePotentials = collectPotentials(evidence.getVariables());
        HashMap<Variable, Integer> configuration = evidenceConfiguration(evidence);

        Map<Variable, TablePotential> familyJoints = new HashMap<>();
        Map<Potential, int[]> familyOffsets = new HashMap<>();
        for (Potential potential : samplingPotentials) {
            TablePotential accumulator =
                    new TablePotential(potential.getVariables(), PotentialRole.JOINT_PROBABILITY);
            Arrays.fill(accumulator.values, 0.0d);
            familyJoints.put(potential.getVariable(0), accumulator);
            familyOffsets.put(potential,
                    TablePotential.calculateOffsets(TablePotential.calculateDimensions(potential.getVariables())));
        }

        resetStatistics();
        Random random = new Random();
        for (int sample = 0; sample < this.sampleSize; sample++) {
            sampleVariables(samplingPotentials, random, configuration);
            double weight = computeSampleWeight(evidencePotentials, linkRestrictions, configuration);
            for (Potential potential : samplingPotentials) {
                double[] values = familyJoints.get(potential.getVariable(0)).values;
                values[getSampleIndex(potential.getVariables(), familyOffsets.get(potential), configuration)] += weight;
            }
            recordSample(weight);
        }

        // Indicator table for each finding: 1 at the observed state, 0 elsewhere.
        Map<Variable, TablePotential> indicators = new HashMap<>();
        for (Finding finding : evidence.getFindings()) {
            Variable variable = finding.getVariable();
            TablePotential indicator = new TablePotential(PotentialRole.JOINT_PROBABILITY, variable);
            for (int state = 0; state < variable.getNumStates(); state++) {
                indicator.values[state] = finding.getStateIndex() == state ? 1.0d : 0.0d;
            }
            indicators.put(variable, indicator);
        }
        for (Potential potential : evidencePotentials) {
            List<TablePotential> factors = new ArrayList<>();
            factors.add((TablePotential) potential);
            for (Variable variable : potential.getVariables()) {
                TablePotential indicator = indicators.get(variable);
                if (indicator != null) {
                    factors.add(indicator);
                }
            }
            familyJoints.put(potential.getVariable(0), DiscretePotentialOperations.multiply(factors));
        }

        try {
            for (TablePotential familyJoint : familyJoints.values()) {
                familyJoint.setPotentialRole(PotentialRole.JOINT_PROBABILITY);
                DiscretePotentialOperations.normalize(familyJoint);
            }
            return familyJoints;
        } catch (NormalizeNullVectorException e) {
            throw new IncompatibleEvidenceException(e.getMessage());
        }
    }

    /**
     * Returns the flat table index of the sampled configuration of {@code variables}.
     *
     * @param variables variables of the table, in table order
     * @param offsets table offsets, aligned with {@code variables}
     * @param configuration current state index of every variable in {@code variables}
     */
    private static int getSampleIndex(List<Variable> variables, int[] offsets, Map<Variable, Integer> configuration) {
        int index = 0;
        for (int i = 0; i < variables.size(); i++) {
            index += offsets[i] * configuration.get(variables.get(i));
        }
        return index;
    }

    /** Returns the number of samples drawn by each inference call. */
    public int getSampleSize() {
        return this.sampleSize;
    }

    /**
     * Sets the number of samples drawn by each inference call.
     *
     * @param sampleSize number of samples; must be positive
     * @throws IllegalArgumentException if {@code sampleSize <= 0}
     */
    public void setSampleSize(int sampleSize) {
        if (sampleSize <= 0) {
            throw new IllegalArgumentException("Sample size must be positive: " + sampleSize);
        }
        this.sampleSize = sampleSize;
    }

    /**
     * Returns the evidence used by an inference call. This is the post-resolution evidence with the
     * pre-resolution evidence fused into it. For tuning networks, every decision without a finding is
     * also given a default one.
     */
    private EvidenceCase prepareEvidence() throws IncompatibleEvidenceException {
        EvidenceCase evidence = getPostResolutionEvidence();
        evidence.fuse(getPreResolutionEvidence(), true);
        if (this.probNet.getNetworkType().equals(TuningNetworkType.getUniqueInstance())) {
            setEvidenceInDecisionNodes(this.probNet, evidence);
        }
        return evidence;
    }

    /**
     * Adds a finding at {@link #DEFAULT_DECISION_STATE_INDEX} for every decision variable that has no
     * finding yet.
     */
    private static void setEvidenceInDecisionNodes(ProbNet probNet, EvidenceCase evidenceCase)
            throws IncompatibleEvidenceException {
        for (Variable decision : probNet.getVariables(NodeType.DECISION)) {
            if (evidenceCase.contains(decision)) {
                continue;
            }
            try {
                evidenceCase.addFinding(new Finding(decision, DEFAULT_DECISION_STATE_INDEX));
            } catch (InvalidStateException e) {
                // Best effort, kept from the original: the problem is reported and the decision skipped.
                e.printStackTrace();
            }
        }
    }

    /** Returns the potentials of all network links that carry restrictions. */
    private static List<Potential> buildLinkRestrictionList(ProbNet probNet) {
        List<Potential> restrictions = new ArrayList<>();
        for (Link link : probNet.getGraph().getLinks()) {
            if (link.hasRestrictions()) {
                restrictions.add(link.getRestrictionsPotential());
            }
        }
        return restrictions;
    }

    /** Returns the variables of {@code list} that have no finding in {@code evidenceCase}, in order. */
    private static List<Variable> getVariablesToSample(List<Variable> list, EvidenceCase evidenceCase) {
        List<Variable> unobserved = new ArrayList<>();
        for (Variable variable : list) {
            if (!evidenceCase.contains(variable)) {
                unobserved.add(variable);
            }
        }
        return unobserved;
    }

    /**
     * Concatenates the potentials of the nodes of {@code variables}, in order.
     *
     * @throws NullPointerException if a variable has no node or its node has no potential list
     */
    private List<Potential> collectPotentials(List<Variable> variables) throws IncompatibleEvidenceException {
        List<Potential> potentials = new ArrayList<>();
        for (Variable variable : variables) {
            ProbNode node = this.probNet.getProbNode(variable);
            Collection<? extends Potential> nodePotentials = node == null ? null : node.getPotentials();
            if (nodePotentials == null) {
                throw new NullPointerException("Variable " + variable.getName() + " has no Potential");
            }
            potentials.addAll(nodePotentials);
        }
        return potentials;
    }

    /**
     * If the network defines {@link #UNOBSERVED_UTILITY_PENALTY_PROPERTY}, replaces each unobserved
     * utility table potential in {@code potentials} with a copy whose values are divided by that
     * integer. The original potentials are not modified.
     */
    private void applyUnobservedUtilityPenalty(List<Potential> potentials, EvidenceCase evidence)
            throws IncompatibleEvidenceException {
        if (!this.probNet.additionalProperties.containsKey(UNOBSERVED_UTILITY_PENALTY_PROPERTY)) {
            return;
        }
        int penalty = Integer.parseInt(
                this.probNet.additionalProperties.get(UNOBSERVED_UTILITY_PENALTY_PROPERTY).toString());
        for (int i = 0; i < potentials.size(); i++) {
            Potential potential = potentials.get(i);
            if (potential.getPotentialRole() == PotentialRole.UTILITY
                    && !evidence.existsEvidence(potential.getVariables())
                    && potential instanceof TablePotential) {
                TablePotential penalized = (TablePotential) potential.copy();
                for (int j = 0; j < penalized.values.length; j++) {
                    penalized.values[j] /= penalty;
                }
                potentials.set(i, penalized);
            }
        }
    }

    /** Returns a mutable configuration that maps each observed variable to its observed state index. */
    private static HashMap<Variable, Integer> evidenceConfiguration(EvidenceCase evidence)
            throws IncompatibleEvidenceException {
        HashMap<Variable, Integer> configuration = new HashMap<>();
        for (Variable variable : evidence.getVariables()) {
            configuration.put(variable, evidence.getFinding(variable).getStateIndex());
        }
        return configuration;
    }

    /** Draws a value for the head variable (variable 0) of each potential, in list order. */
    private static void sampleVariables(List<Potential> potentials, Random random,
            HashMap<Variable, Integer> configuration) throws IncompatibleEvidenceException {
        for (Potential potential : potentials) {
            configuration.put(potential.getVariable(0), potential.sample(random, configuration));
        }
    }

    /**
     * Returns the likelihood weight of the current configuration.
     *
     * <p>Evidence potentials over more than one variable always contribute their probability.
     * Single-variable evidence potentials contribute a factor that is the same for every sample, which
     * cancels when the estimates are normalized. They are therefore applied only when that factor is
     * zero. Every link restriction always contributes.
     */
    private static double computeSampleWeight(List<Potential> evidencePotentials, List<Potential> linkRestrictions,
            HashMap<Variable, Integer> configuration) throws IncompatibleEvidenceException {
        double weight = 1.0d;
        for (Potential potential : evidencePotentials) {
            double probability = potential.getProbability(configuration);
            if (potential.getVariables().size() > 1 || probability == 0.0d) {
                weight *= probability;
            }
        }
        for (Potential restriction : linkRestrictions) {
            weight *= restriction.getProbability(configuration);
        }
        return weight;
    }

    /** Clears the per-run sample statistics. */
    private void resetStatistics() {
        this.accumulatedWeight = 0.0d;
        this.positiveSampleCount = 0;
    }

    /** Adds one sample's weight to the per-run statistics. */
    private void recordSample(double weight) {
        this.accumulatedWeight += weight;
        if (weight > 0.0d) {
            this.positiveSampleCount++;
        }
    }

    /** Returns the sum of sample weights of the most recent inference call. */
    public double getAccumulatedWeight() {
        return this.accumulatedWeight;
    }

    /**
     * Returns the number of positive-weight samples of the most recent inference call divided by the
     * current sample size.
     */
    public double getPositiveSampleRatio() {
        // Cast before dividing: int / int truncated the ratio to 0 (or 1 when every sample was positive).
        return (double) this.positiveSampleCount / this.sampleSize;
    }

    /** Not supported by this algorithm. */
    @Override
    public Potential getOptimizedPolicy(Variable variable) throws IncompatibleEvidenceException, UnexpectedInferenceException {
        throw new NotImplementedException();
    }

    /** Not supported by this algorithm. */
    @Override
    public Potential getExpectedUtilities(Variable variable) throws IncompatibleEvidenceException, UnexpectedInferenceException {
        throw new NotImplementedException();
    }
}
