package org.openmarkov.learning.core;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.openmarkov.core.action.PNEdit;
import org.openmarkov.core.exception.CanNotDoEditException;
import org.openmarkov.core.exception.ConstraintViolationException;
import org.openmarkov.core.exception.DoEditException;
import org.openmarkov.core.exception.NonProjectablePotentialException;
import org.openmarkov.core.exception.NormalizeNullVectorException;
import org.openmarkov.core.exception.ProbNodeNotFoundException;
import org.openmarkov.core.exception.WrongCriterionException;
import org.openmarkov.core.io.database.CaseDatabase;
import org.openmarkov.core.model.network.NodeType;
import org.openmarkov.core.model.network.ProbNet;
import org.openmarkov.core.model.network.ProbNode;
import org.openmarkov.core.model.network.State;
import org.openmarkov.core.model.network.Variable;
import org.openmarkov.learning.core.algorithm.LearningAlgorithm;
import org.openmarkov.learning.core.algorithm.annotation.LearningAlgorithmManager;
import org.openmarkov.learning.core.algorithm.annotation.LearningAlgorithmType;
import org.openmarkov.learning.core.constraint.ModelNetworkConstraint;
import org.openmarkov.learning.core.editionsgenerator.LearningEditMotivation;
import org.openmarkov.learning.core.editionsgenerator.LearningEditProposal;
import org.openmarkov.learning.core.exception.EmptyModelNetException;
import org.openmarkov.learning.core.exception.LatentVariablesException;
import org.openmarkov.learning.core.util.ModelNetUse;

/* loaded from: input_file:org/openmarkov/learning/core/LearningManager.class */
/**
 * Coordinates a learning session over a {@link CaseDatabase}.
 *
 * <p>On construction it builds the network to be learned, either from the variables of the
 * case database or from a user-supplied model network, depending on {@link ModelNetUse}. It
 * then holds the {@link LearningAlgorithm} selected through {@link #init(LearningAlgorithm)}
 * and forwards learning requests to it.
 *
 * <p>Every method that delegates to the algorithm requires {@link #init(LearningAlgorithm)} to
 * have been called first; otherwise it throws {@link IllegalStateException}.
 */
public class LearningManager {
    private static final LearningAlgorithmManager learningAlgorithmManager = new LearningAlgorithmManager();
    private LearningAlgorithm learningAlgorithm = null;
    private final ProbNet learnedNet;
    private final ModelNetUse modelNetUse;
    private final CaseDatabase caseDatabase;

    /**
     * Creates a manager and builds the initial network to be learned.
     *
     * <p>If {@code modelNetUse} is {@code null} or does not request a model network, the
     * network has one chance node per database variable. Otherwise it is derived from
     * {@code probNet} according to {@code modelNetUse}.
     *
     * @param caseDatabase  database providing the variables (and later the cases) to learn from
     * @param algorithmName name of the learning algorithm. It is only consulted here when a
     *                      model network is in use, to check whether the algorithm supports
     *                      latent variables.
     * @param probNet       model network; required when a model network is requested
     * @param modelNetUse   model-network options; may be {@code null}
     * @throws EmptyModelNetException   if a model network is requested but {@code probNet} is
     *                                  {@code null}
     * @throws LatentVariablesException if the model network has variables missing from the
     *                                  database and the algorithm does not declare support for
     *                                  latent variables
     * @throws IllegalArgumentException if a model network is requested and
     *                                  {@code algorithmName} is not a known algorithm
     */
    public LearningManager(CaseDatabase caseDatabase, String algorithmName, ProbNet probNet, ModelNetUse modelNetUse) throws NormalizeNullVectorException, EmptyModelNetException, LatentVariablesException {
        this.caseDatabase = caseDatabase;
        this.modelNetUse = modelNetUse;
        if (modelNetUse == null || !modelNetUse.isUseModelNet()) {
            this.learnedNet = createNetFromDatabase(caseDatabase);
        } else {
            if (probNet == null) {
                throw new EmptyModelNetException();
            }
            Class<? extends LearningAlgorithm> algorithmClass = learningAlgorithmManager.getByName(algorithmName);
            if (algorithmClass == null) {
                throw new IllegalArgumentException("Unknown learning algorithm: " + algorithmName);
            }
            this.learnedNet = applyModelNet(algorithmClass, caseDatabase, probNet, modelNetUse);
        }
        addElviraProperties(this.learnedNet);
    }

    /**
     * Selects the algorithm used by this manager and initializes it with this manager's
     * model-network options.
     *
     * @param learningAlgorithm the algorithm to use, e.g. one obtained from
     *                          {@link #getAlgorithmInstance(String)}
     */
    public void init(LearningAlgorithm learningAlgorithm) {
        this.learningAlgorithm = learningAlgorithm;
        learningAlgorithm.init(this.modelNetUse);
    }

    /**
     * Runs the selected algorithm with this manager's model-network options.
     *
     * @throws IllegalStateException if {@link #init(LearningAlgorithm)} has not been called
     */
    public void learn() throws NormalizeNullVectorException {
        requireAlgorithm().run(this.modelNetUse);
    }

    /** Returns the network being learned (built in the constructor, never {@code null}). */
    public ProbNet getLearnedNet() {
        return this.learnedNet;
    }

    /** Returns the algorithm set by {@link #init(LearningAlgorithm)}, or {@code null} if none. */
    public LearningAlgorithm getLearningAlgorithm() {
        return this.learningAlgorithm;
    }

    /**
     * Returns the algorithm's score for the learned network on the database cases.
     *
     * @throws IllegalStateException if {@link #init(LearningAlgorithm)} has not been called
     */
    public double getScore() {
        return requireAlgorithm().getScore(this.learnedNet, this.caseDatabase.getCases());
    }

    /**
     * Returns the algorithm's motivation for applying {@code edit} to the learned network,
     * evaluated on the database cases.
     *
     * @throws IllegalStateException if {@link #init(LearningAlgorithm)} has not been called
     */
    public LearningEditMotivation getMotivation(PNEdit edit) {
        return requireAlgorithm().getMotivation(this.learnedNet, this.caseDatabase.getCases(), edit);
    }

    /**
     * Returns the algorithm's best edit proposal. Both flags are passed through unchanged;
     * their meaning is defined by {@link LearningAlgorithm#getBestEdition(boolean, boolean)}.
     *
     * @throws IllegalStateException if {@link #init(LearningAlgorithm)} has not been called
     */
    public LearningEditProposal getBestEdition(boolean z, boolean z2) {
        return requireAlgorithm().getBestEdition(z, z2);
    }

    /**
     * Returns the algorithm's next edit proposal. Both flags are passed through unchanged;
     * their meaning is defined by {@link LearningAlgorithm#getNextEdition(boolean, boolean)}.
     *
     * @throws IllegalStateException if {@link #init(LearningAlgorithm)} has not been called
     */
    public LearningEditProposal getNextEdition(boolean z, boolean z2) {
        return requireAlgorithm().getNextEdition(z, z2);
    }

    /**
     * Advances the algorithm to its next phase.
     *
     * @throws IllegalStateException if {@link #init(LearningAlgorithm)} has not been called
     */
    public void goToNextPhase() throws NormalizeNullVectorException {
        requireAlgorithm().goToNextPhase();
    }

    /**
     * Reports whether the algorithm is in its last phase.
     *
     * @throws IllegalStateException if {@link #init(LearningAlgorithm)} has not been called
     */
    public boolean isLastPhase() {
        return requireAlgorithm().isLastPhase();
    }

    /**
     * Applies {@code edit} to the learned network, then re-runs the algorithm's parametric
     * learning.
     *
     * @throws IllegalStateException if {@link #init(LearningAlgorithm)} has not been called.
     *                               This is checked before the edit is applied, so the network
     *                               is left untouched.
     */
    public void applyEdit(PNEdit edit) throws ConstraintViolationException, CanNotDoEditException, NonProjectablePotentialException, WrongCriterionException, DoEditException, NormalizeNullVectorException {
        // Resolve the algorithm first so a missing init() cannot leave a half-applied edit.
        LearningAlgorithm algorithm = requireAlgorithm();
        this.learnedNet.doEdit(edit);
        algorithm.parametricLearning();
    }

    /** Sets the default states of {@code probNet} and marks it as carrying Elvira properties. */
    private static void addElviraProperties(ProbNet probNet) {
        probNet.setDefaultStates(new State[]{new State("present"), new State("absent")});
        probNet.additionalProperties.put("hasElviraProperties", "yes");
    }

    /** Builds a fresh network with one chance node per variable of {@code caseDatabase}. */
    private static ProbNet createNetFromDatabase(CaseDatabase caseDatabase) {
        ProbNet net = new ProbNet();
        for (Variable variable : caseDatabase.getVariables()) {
            net.addProbNode(variable, NodeType.CHANCE);
        }
        return net;
    }

    /**
     * Builds the initial network when a model network is in use.
     *
     * <p>The three cases are:
     * <ul>
     *   <li>If {@code isStartFromModelNet()}: a copy of the model network, extended with any
     *       missing database variables, with a {@link ModelNetworkConstraint} attached.</li>
     *   <li>Otherwise: a network of the database variables, taking node coordinates from the
     *       model network when {@code isUseNodePositions()} is set.</li>
     *   <li>When neither option is set, the plain database network is returned. Previously
     *       this case returned {@code null}.</li>
     * </ul>
     *
     * @throws LatentVariablesException if the model network has variables missing from the
     *                                  database and the algorithm does not support them
     */
    private ProbNet applyModelNet(Class<? extends LearningAlgorithm> cls, CaseDatabase caseDatabase, ProbNet probNet, ModelNetUse modelNetUse) throws LatentVariablesException {
        checkLatentVariables(cls, caseDatabase, probNet);
        if (modelNetUse.isStartFromModelNet()) {
            // The copy keeps the model net's own node positions, so a separately positioned
            // database-only network would be discarded anyway.
            ProbNet startNet = probNet.copy();
            for (Variable variable : caseDatabase.getVariables()) {
                if (!startNet.containsVariable(variable.getName())) {
                    startNet.addProbNode(variable, NodeType.CHANCE);
                }
            }
            try {
                startNet.addConstraint(new ModelNetworkConstraint(modelNetUse, probNet), false);
            } catch (ConstraintViolationException ignored) {
                // Deliberately best-effort, as in the original code.
                // NOTE(review): presumably the `false` argument disables an immediate check, so
                // this should not fire; confirm against ProbNet.addConstraint.
            }
            return startNet;
        }
        ProbNet databaseNet = createNetFromDatabase(caseDatabase);
        if (modelNetUse.isUseNodePositions()) {
            copyNodePositionsFromModelNet(probNet, databaseNet);
        }
        return databaseNet;
    }

    /**
     * Throws if {@code probNet} contains variables absent from {@code caseDatabase} and the
     * algorithm class does not declare support for latent variables. An algorithm without a
     * {@link LearningAlgorithmType} annotation is treated as not supporting them.
     */
    private static void checkLatentVariables(Class<? extends LearningAlgorithm> cls, CaseDatabase caseDatabase, ProbNet probNet) throws LatentVariablesException {
        LearningAlgorithmType type = cls.getAnnotation(LearningAlgorithmType.class);
        boolean supportsLatentVariables = type != null && type.supportsLatentVariables();
        if (!supportsLatentVariables && !caseDatabase.getVariables().containsAll(probNet.getVariables())) {
            ArrayList<Variable> latentVariables = new ArrayList<>(probNet.getVariables());
            latentVariables.removeAll(caseDatabase.getVariables());
            throw new LatentVariablesException(latentVariables);
        }
    }

    /** Returns the names of all registered learning algorithms. */
    public static Set<String> getAlgorithmNames() {
        return learningAlgorithmManager.getLearningAlgorithmNames();
    }

    /**
     * Instantiates the named algorithm. The learned network and the case database are passed
     * to it, in that order, as construction parameters.
     */
    public LearningAlgorithm getAlgorithmInstance(String algorithmName) {
        ArrayList<Object> parameters = new ArrayList<>();
        parameters.add(this.learnedNet);
        parameters.add(this.caseDatabase);
        return learningAlgorithmManager.getByName(algorithmName, parameters);
    }

    /**
     * Asks the algorithm to block {@code edit}.
     *
     * @throws IllegalStateException if {@link #init(LearningAlgorithm)} has not been called
     */
    public void blockEdit(PNEdit edit) {
        requireAlgorithm().blockEdit(edit);
    }

    /**
     * Asks the algorithm to unblock {@code edit}.
     *
     * @throws IllegalStateException if {@link #init(LearningAlgorithm)} has not been called
     */
    public void unblockEdit(PNEdit edit) {
        requireAlgorithm().unblockEdit(edit);
    }

    /**
     * Returns the edits currently blocked in the algorithm.
     *
     * @throws IllegalStateException if {@link #init(LearningAlgorithm)} has not been called
     */
    public List<PNEdit> getBlockedEdits() {
        return requireAlgorithm().getBlockedEdits();
    }

    /** Returns the selected algorithm, failing fast with a clear message if none is set. */
    private LearningAlgorithm requireAlgorithm() {
        if (this.learningAlgorithm == null) {
            throw new IllegalStateException("No learning algorithm set; call init(LearningAlgorithm) first");
        }
        return this.learningAlgorithm;
    }

    /**
     * Copies node coordinates from {@code modelNet} to the nodes of {@code targetNet} that share
     * a variable name. Model nodes without a counterpart in {@code targetNet} are skipped.
     * Nothing happens when {@code modelNet} is {@code null}.
     */
    private static void copyNodePositionsFromModelNet(ProbNet modelNet, ProbNet targetNet) {
        if (modelNet == null) {
            return;
        }
        for (ProbNode modelNode : modelNet.getProbNodes()) {
            try {
                ProbNode targetNode = targetNet.getProbNode(modelNode.getVariable().getName());
                if (targetNode != null) {
                    targetNode.getNode().setCoordinateX(modelNode.getNode().getCoordinateX());
                    targetNode.getNode().setCoordinateY(modelNode.getNode().getCoordinateY());
                }
            } catch (ProbNodeNotFoundException ignored) {
                // Variable exists only in the model net (e.g. latent): nothing to position.
            }
        }
    }
}
