package org.openmarkov.core.inference;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.junit.Assert;
import org.junit.Test;
import org.openmarkov.core.exception.NonProjectablePotentialException;
import org.openmarkov.core.exception.ProbNodeNotFoundException;
import org.openmarkov.core.exception.WrongCriterionException;
import org.openmarkov.core.model.network.NetsFactory;
import org.openmarkov.core.model.network.ProbNet;
import org.openmarkov.core.model.network.Variable;
import org.openmarkov.core.model.network.potential.Potential;
import org.openmarkov.core.model.network.potential.PotentialRole;
import org.openmarkov.core.model.network.potential.TablePotential;
import org.openmarkov.core.model.network.potential.TablePotentialTest;
import org.openmarkov.core.model.network.potential.operation.DiscretePotentialOperations;

/**
 * Tests the networks produced by {@link FactoryExpandedSMM#constructExpandedNetwork}.
 *
 * <p>Each test expands a small network built by {@link NetsFactory} over a number of
 * slices, projects the utility potentials of the expanded network to tables and compares
 * them against values computed by hand as terms (or partial sums) of a geometric
 * progression whose ratio is the discount factor {@code 1 / (1 + DISCOUNT_RATE)}.
 */
public class FactoryExpandedSMMTest {
    /** Tolerance handed to {@link TablePotentialTest#checkEqualPotentials}. */
    protected double maxError = 1.0E-6d;

    /** Name of the variable looked up in every expanded network. */
    private static final String TREATMENT_VARIABLE_NAME = "Treatment";

    /** Base name used to derive the per-slice state variable name. */
    private static final String STATE_VARIABLE_BASE_NAME = "State";

    // NOTE(review): the meaning of the four values below is inferred from how the tests
    // use them (first terms of the progressions / first two expected table entries);
    // confirm against the NetsFactory.createSMM* parameter documentation.
    private static final double FIRST_EFFECTIVENESS = 0.9d;
    private static final double SECOND_EFFECTIVENESS = 1.0d;
    private static final double FIRST_COST = 40000.0d;
    private static final double SECOND_COST = 0.0d;

    /** Third and fourth trailing arguments of {@code NetsFactory.createSMMWithStateVariable}. */
    private static final double STATE_PARAMETER_A = 0.7d;
    private static final double STATE_PARAMETER_B = 0.5d;

    /** Discount rate as a fraction; the factory receives it multiplied by 100. */
    private static final double DISCOUNT_RATE = 0.01d;

    /**
     * Expands the network without state variable for 1 and 2 slices and checks that the sum
     * of its utility potentials matches a table defined over the treatment variable and the
     * decision criteria, whose last two entries are partial sums of a geometric progression.
     */
    @Test
    public void testExpansionSMMWithoutStateVariable() {
        for (int numSlices = 1; numSlices <= 2; numSlices++) {
            ProbNet expandedNet = FactoryExpandedSMM.constructExpandedNetwork(
                    numSlices,
                    NetsFactory.createSMMWithoutStateVariable(
                            FIRST_EFFECTIVENESS, SECOND_EFFECTIVENESS, FIRST_COST, SECOND_COST),
                    DISCOUNT_RATE * 100.0d,
                    DISCOUNT_RATE * 100.0d,
                    true);
            TablePotential actual = DiscretePotentialOperations.sum(
                    extractUtilityPotentialsProjecToTablesAndCheckVariables(expandedNet));

            double discountFactor = 1.0d / (1.0d + DISCOUNT_RATE);
            double firstSum =
                    sumTermsGeometricProgression(FIRST_EFFECTIVENESS, discountFactor, numSlices);
            double secondSum =
                    sumTermsGeometricProgression(SECOND_EFFECTIVENESS, discountFactor, numSlices);

            List<Variable> variables = new ArrayList<>();
            variables.add(requireVariable(expandedNet, TREATMENT_VARIABLE_NAME));
            variables.add(expandedNet.decisionCriteria);

            TablePotential expected = new TablePotential(variables, PotentialRole.UTILITY);
            expected.setValues(new double[] {FIRST_COST, SECOND_COST, firstSum, secondSum});
            TablePotentialTest.checkEqualPotentials(actual, expected, this.maxError);
        }
    }

    /**
     * Expands the network with state variable for 1 to 5 slices and checks every projected
     * utility potential whose utility variable is temporal and lies in a slice greater than
     * zero.
     */
    @Test
    public void testExpansionSMMWithStateVariable() {
        for (int numSlices = 1; numSlices <= 5; numSlices++) {
            ProbNet expandedNet = FactoryExpandedSMM.constructExpandedNetwork(
                    numSlices,
                    NetsFactory.createSMMWithStateVariable(
                            FIRST_EFFECTIVENESS, SECOND_EFFECTIVENESS, FIRST_COST, SECOND_COST,
                            STATE_PARAMETER_A, STATE_PARAMETER_B),
                    DISCOUNT_RATE * 100.0d,
                    DISCOUNT_RATE * 100.0d,
                    true);
            List<TablePotential> utilityTables =
                    extractUtilityPotentialsProjecToTablesAndCheckVariables(expandedNet);
            double discountFactor = 1.0d / (1.0d + DISCOUNT_RATE);

            for (TablePotential utilityTable : utilityTables) {
                if (hasTemporalVariableRoleAndNotZeroSlice(utilityTable, PotentialRole.UTILITY)) {
                    checkUtilityPotentialQoLSMMWithState(
                            expandedNet,
                            utilityTable,
                            FIRST_EFFECTIVENESS,
                            SECOND_EFFECTIVENESS,
                            discountFactor,
                            utilityTable.getUtilityVariable().getTimeSlice());
                }
            }
        }
    }

    /**
     * Compares a utility potential of the expanded network with the table expected for the
     * slice it belongs to.
     *
     * <p>The expected table is defined over the treatment variable, the decision criteria and
     * the state variable of the same slice as the potential's utility variable. Its first six
     * entries are zero and its last two are {@code d * d3^i} and {@code d2 * d3^i}.
     *
     * @param probNet network the variables are looked up in
     * @param tablePotential potential under test
     * @param d first term of the first geometric progression
     * @param d2 first term of the second geometric progression
     * @param d3 ratio of both progressions
     * @param i exponent applied to the ratio (callers in this class pass the time slice)
     * @throws AssertionError if a required variable is missing or the potentials differ
     */
    public void checkUtilityPotentialQoLSMMWithState(ProbNet probNet, TablePotential tablePotential, double d, double d2, double d3, int i) {
        List<Variable> variables = new ArrayList<>();
        variables.add(requireVariable(probNet, TREATMENT_VARIABLE_NAME));
        variables.add(probNet.decisionCriteria);
        variables.add(requireVariable(probNet, nameStateVariable(tablePotential.getUtilityVariable())));

        double[] expectedValues = {
            0.0d, 0.0d, 0.0d, 0.0d, 0.0d, 0.0d,
            termGeometricProgression(d, d3, i),
            termGeometricProgression(d2, d3, i)
        };
        TablePotential expected = new TablePotential(variables, PotentialRole.UTILITY);
        expected.setValues(expectedValues);
        TablePotentialTest.checkEqualPotentials(tablePotential, expected, this.maxError);
    }

    /**
     * Looks up a variable by name, turning a missing node into a test failure instead of
     * letting the test continue with an incomplete variable list.
     */
    private static Variable requireVariable(ProbNet probNet, String variableName) {
        try {
            return probNet.getVariable(variableName);
        } catch (ProbNodeNotFoundException e) {
            throw new AssertionError("Variable not found in network: " + variableName, e);
        }
    }

    /**
     * Returns the name that a scratch {@link Variable} with base name "State" reports after
     * being given the time slice of {@code utilityVariable}.
     */
    private String nameStateVariable(Variable utilityVariable) {
        Variable stateVariable = new Variable(STATE_VARIABLE_BASE_NAME);
        stateVariable.setBaseName(stateVariable.getName());
        // Presumably setTimeSlice makes getName() reflect the slice; confirm in Variable.
        stateVariable.setTimeSlice(utilityVariable.getTimeSlice());
        return stateVariable.getName();
    }

    /**
     * Tells whether the potential has the given role and its main variable is temporal with
     * a time slice greater than zero.
     *
     * <p>The main variable is the first variable for {@code CONDITIONAL_PROBABILITY} and the
     * utility variable for {@code UTILITY}; any other role yields {@code false}.
     */
    private boolean hasTemporalVariableRoleAndNotZeroSlice(TablePotential tablePotential, PotentialRole potentialRole) {
        if (tablePotential.getPotentialRole() != potentialRole) {
            return false;
        }
        Variable mainVariable;
        switch (potentialRole) {
            case CONDITIONAL_PROBABILITY:
                mainVariable = tablePotential.getVariables().get(0);
                break;
            case UTILITY:
                mainVariable = tablePotential.getUtilityVariable();
                break;
            default:
                mainVariable = null;
                break;
        }
        return mainVariable != null && mainVariable.isTemporal() && mainVariable.getTimeSlice() > 0;
    }

    /**
     * Projects every utility potential of the network to tables, asserting that each original
     * potential and each projected table has a utility variable.
     *
     * @param probNet network whose utility potentials are projected
     * @return all projected tables, in the order the potentials were returned by the network
     * @throws AssertionError if a utility variable is missing or a potential cannot be projected
     */
    private List<TablePotential> extractUtilityPotentialsProjecToTablesAndCheckVariables(ProbNet probNet) {
        List<Potential> utilityPotentials = probNet.getPotentialsRole(PotentialRole.UTILITY);
        InferenceOptions inferenceOptions = new InferenceOptions(probNet, null);
        List<TablePotential> projectedTables = new ArrayList<>();
        for (Potential potential : utilityPotentials) {
            Assert.assertNotNull(potential.getUtilityVariable());
            List<TablePotential> tables;
            try {
                tables = potential.tableProject(null, inferenceOptions);
            } catch (NonProjectablePotentialException | WrongCriterionException e) {
                // A potential that cannot be projected would otherwise be dropped silently,
                // letting the comparison run on an incomplete set of tables.
                throw new AssertionError("Could not project utility potential to tables", e);
            }
            for (TablePotential table : tables) {
                Assert.assertNotNull(table.getUtilityVariable());
            }
            projectedTables.addAll(tables);
        }
        return projectedTables;
    }

    /**
     * Sum of the first {@code i} terms of the geometric progression
     * {@code d, d*d2, d*d2^2, ...}.
     *
     * @param d first term
     * @param d2 ratio
     * @param i number of terms
     * @return {@code d * (1 - d2^i) / (1 - d2)}, or {@code d * i} when the ratio is exactly 1
     */
    public static double sumTermsGeometricProgression(double d, double d2, int i) {
        if (d2 == 1.0d) {
            // The closed form divides by (1 - ratio); with ratio 1 every term equals d.
            return d * i;
        }
        return (d - (d * Math.pow(d2, i))) / (1.0d - d2);
    }

    /** Returns {@code firstTerm * ratio^exponent}. */
    private double termGeometricProgression(double firstTerm, double ratio, int exponent) {
        return firstTerm * Math.pow(ratio, exponent);
    }
}
