package org.openmarkov.learning.algorithm.pc.independencetester;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import javax.swing.event.UndoableEditEvent;
import org.openmarkov.core.exception.CanNotDoEditException;
import org.openmarkov.core.exception.ConstraintViolationException;
import org.openmarkov.core.exception.NonProjectablePotentialException;
import org.openmarkov.core.exception.NormalizeNullVectorException;
import org.openmarkov.core.exception.ProbNodeNotFoundException;
import org.openmarkov.core.exception.WrongCriterionException;
import org.openmarkov.core.model.graph.Node;
import org.openmarkov.core.model.network.ProbNet;
import org.openmarkov.core.model.network.ProbNode;
import org.openmarkov.core.model.network.Variable;
import org.openmarkov.core.model.network.potential.PotentialRole;
import org.openmarkov.core.model.network.potential.TablePotential;

/* loaded from: input_file:org/openmarkov/learning/algorithm/pc/independencetester/CrossEntropyIndependenceTester.class */
/**
 * Conditional-independence tester for the PC algorithm based on the G (likelihood-ratio /
 * cross-entropy) statistic.
 *
 * <p>For variables X ({@code node}), Y ({@code node2}) and a conditioning set Z ({@code list}),
 * the statistic is {@code G = 2 * N * (H(X|Z) - H(X|Y,Z))}, where the entropies are estimated
 * from empirical frequencies in the case database and N is the number of cases. G is then
 * passed to {@code StatisticalUtilities.chiSquare} together with the degrees of freedom
 * {@code |Z| * (|X| - 1) * (|Y| - 1)}.
 *
 * <p>The case database {@code iArr} is indexed as {@code iArr[case][column]}. The column of a
 * variable is the position of its node in {@code probNet.getProbNodes()}, and each cell is a
 * state index of that variable.
 */
public class CrossEntropyIndependenceTester implements IndependenceTester {

    /** Absolute values of the G statistic below this threshold are treated as exactly zero. */
    private static final double ZERO_TOLERANCE = 1.0E-10d;

    /**
     * Runs the independence test of {@code node} and {@code node2} given {@code list}.
     *
     * <p>This is deliberately best-effort: any failure while computing the test is reported on
     * {@code System.err} and {@code -1.0} is returned as an error sentinel instead of propagating.
     *
     * @param probNet network whose node order defines the columns of {@code iArr}
     * @param iArr    case database, one row per case, cells holding state indices
     * @param node    first variable (X)
     * @param node2   second variable (Y)
     * @param list    conditioning set (Z); may be empty
     * @return the value returned by {@code StatisticalUtilities.chiSquare} (NOTE(review):
     *         presumably a p-value — confirm), or {@code -1.0} if an error occurred
     */
    @Override // org.openmarkov.learning.algorithm.pc.independencetester.IndependenceTester
    public double test(ProbNet probNet, int[][] iArr, Node node, Node node2, List<Node> list) throws ProbNodeNotFoundException {
        try {
            return testValue(probNet, iArr, node, node2, list);
        } catch (Exception e) {
            // Kept from the original contract: callers receive -1.0 rather than an exception.
            System.err.println("Error while computing the independence test.");
            e.printStackTrace();
            return -1.0d;
        }
    }

    /**
     * Computes the G statistic and its degrees of freedom, and evaluates the chi-square
     * distribution for them.
     */
    private double testValue(ProbNet probNet, int[][] iArr, Node node, Node node2, List<Node> list) throws ProbNodeNotFoundException, NormalizeNullVectorException {
        ProbNode xNode = probNet.getProbNode(node);
        ProbNode yNode = probNet.getProbNode(node2);

        // yAndZ = [Y, Z...] and zOnly = [Z...] are the two conditioning sets of the entropy difference.
        ArrayList<ProbNode> yAndZ = new ArrayList<>();
        ArrayList<ProbNode> zOnly = new ArrayList<>();
        yAndZ.add(yNode);
        long zConfigurations = 1;
        for (Node conditioningNode : list) {
            ProbNode zNode = probNet.getProbNode(conditioningNode);
            yAndZ.add(zNode);
            zOnly.add(zNode);
            zConfigurations *= zNode.getVariable().getNumStates();
        }
        long xStates = xNode.getVariable().getNumStates();
        long yStates = yNode.getVariable().getNumStates();

        double g = 2.0d * iArr.length * crossEntropy(probNet, iArr, node, node2, yAndZ, zOnly);
        if (Math.abs(g) < ZERO_TOLERANCE) {
            // Suppress floating-point noise from the entropy subtraction.
            g = 0.0d;
        }

        // Degrees of freedom, computed in long arithmetic so the product cannot overflow an int.
        // The result is clamped into [1, joint table size].
        long tableSize = zConfigurations * xStates * yStates;
        long degreesOfFreedom = zConfigurations * (xStates - 1) * (yStates - 1);
        degreesOfFreedom = Math.max(1, Math.min(degreesOfFreedom, tableSize));
        return StatisticalUtilities.chiSquare(g, degreesOfFreedom);
    }

    /**
     * Returns {@code H(X | arrayList2) - H(X | arrayList)} for X = {@code node}.
     * With the lists built by {@code testValue} this is the conditional mutual information
     * {@code I(X; Y | Z)}.
     */
    private double crossEntropy(ProbNet probNet, int[][] iArr, Node node, Node node2, ArrayList<ProbNode> arrayList, ArrayList<ProbNode> arrayList2) throws ProbNodeNotFoundException, NormalizeNullVectorException {
        ProbNode probNode = probNet.getProbNode(node);
        return conditionedEntropy(probNet, iArr, probNode, arrayList2) - conditionedEntropy(probNet, iArr, probNode, arrayList);
    }

    /**
     * Estimates the conditional entropy {@code H(probNode | arrayList)} from the empirical joint
     * distribution, in nats. Zero-probability cells contribute nothing.
     *
     * @param arrayList conditioning nodes; it is read but not modified
     */
    private double conditionedEntropy(ProbNet probNet, int[][] iArr, ProbNode probNode, ArrayList<ProbNode> arrayList) throws NormalizeNullVectorException, ProbNodeNotFoundException {
        int childStates = probNode.getVariable().getNumStates();
        int parentConfigurations = 1;

        // The child goes first so that its states are contiguous within each parent configuration.
        ArrayList<ProbNode> family = new ArrayList<>(arrayList.size() + 1);
        family.add(probNode);
        for (ProbNode conditioningNode : arrayList) {
            family.add(conditioningNode);
            parentConfigurations *= conditioningNode.getVariable().getNumStates();
        }

        double[] joint = absoluteNormalization(absoluteFrequencies(probNet, iArr, family), iArr);
        double sum = 0.0d;
        int cell = 0;
        for (int configuration = 0; configuration < parentConfigurations; configuration++) {
            // Marginal probability of this parent configuration.
            double marginal = 0.0d;
            for (int state = 0; state < childStates; state++) {
                marginal += joint[cell + state];
            }
            for (int state = 0; state < childStates; state++, cell++) {
                double p = joint[cell];
                if (p > 0.0d) {
                    sum += p * Math.log(p / marginal);
                }
            }
        }
        return -sum;
    }

    /**
     * Builds a table of absolute case counts over the variables of {@code arrayList}, with the
     * first node of the list as the fastest-varying index.
     *
     * @param arrayList nodes whose joint configurations are counted; it is not modified
     */
    private TablePotential absoluteFrequencies(ProbNet probNet, int[][] iArr, ArrayList<ProbNode> arrayList) throws ProbNodeNotFoundException {
        int variableCount = arrayList.size();
        ArrayList<Variable> variables = new ArrayList<>(variableCount);
        int[] columns = new int[variableCount];
        int[] strides = new int[variableCount];
        int tableSize = 1;
        for (int k = 0; k < variableCount; k++) {
            Variable variable = arrayList.get(k).getVariable();
            variables.add(variable);
            // The column of a variable in the case database is the position of its node in the network.
            columns[k] = probNet.getProbNodes().indexOf(probNet.getProbNode(variable));
            strides[k] = tableSize;
            tableSize *= variable.getNumStates();
        }
        return absoluteFreqPotential(iArr, variables, columns, strides, tableSize);
    }

    /**
     * Counts how many cases fall into each joint configuration.
     *
     * <p>The cell of a case is the mixed-radix index {@code sum(state[k] * strides[k])}, with the
     * first variable varying fastest. This is assumed to match {@code TablePotential}'s own layout
     * (NOTE(review): confirm). The entropy computation only relies on the first (child) variable
     * being contiguous within each configuration, which holds either way.
     *
     * <p>Unlike the decompiled original, this does not mutate {@code variables} after handing it
     * to the {@code TablePotential} constructor.
     */
    private TablePotential absoluteFreqPotential(int[][] iArr, ArrayList<Variable> variables, int[] columns, int[] strides, int tableSize) {
        TablePotential tablePotential = new TablePotential(variables, PotentialRole.CONDITIONAL_PROBABILITY);
        double[] values = tablePotential.getValues();
        for (int cell = 0; cell < tableSize; cell++) {
            values[cell] = 0.0d;
        }
        for (int[] caseRow : iArr) {
            int cell = 0;
            for (int k = 0; k < columns.length; k++) {
                cell += caseRow[columns[k]] * strides[k];
            }
            values[cell] += 1.0d;
        }
        return tablePotential;
    }

    /** Converts absolute counts into relative frequencies by dividing by the number of cases. */
    private double[] absoluteNormalization(TablePotential tablePotential, int[][] iArr) {
        double[] counts = tablePotential.getValues();
        int size = tablePotential.getTableSize();
        double[] dArr = new double[size];
        for (int i = 0; i < size; i++) {
            dArr[i] = counts[i] / iArr.length;
        }
        return dArr;
    }

    /** Intentionally a no-op: this tester does not react to network edits. */
    @Override // org.openmarkov.core.action.PNUndoableEditListener
    public void undoableEditWillHappen(UndoableEditEvent undoableEditEvent) throws ConstraintViolationException, CanNotDoEditException, NonProjectablePotentialException, WrongCriterionException {
    }

    /** Intentionally a no-op: this tester does not react to undo operations. */
    @Override // org.openmarkov.core.action.PNUndoableEditListener
    public void undoEditHappened(UndoableEditEvent undoableEditEvent) {
    }

    /** Intentionally a no-op: this tester does not react to network edits. */
    public void undoableEditHappened(UndoableEditEvent undoableEditEvent) {
    }
}
