package peggy.pb;

import eqsat.FlowValue;
import eqsat.meminfer.engine.peg.CPEGTerm;
import eqsat.meminfer.engine.peg.CPEGValue;
import eqsat.meminfer.peggy.engine.CPeggyAxiomEngine;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Map;
import llvm.bitcode.HashList;
import peggy.analysis.StackMap;
import peggy.represent.PEGInfo;
import peggy.represent.StickyPredicate;
import peggy.revert.AbstractReversionHeuristic;
import util.graph.CRecursiveExpressionGraph;

/* loaded from: input_file:peggy/pb/AverageReversionHeuristic.class */
/**
 * Reversion heuristic that picks exactly one term for every value reachable from the return
 * values of a PEG, using a backtracking depth-first search written in continuation-passing style.
 *
 * <p>For each value, the revertible candidate terms are tried in increasing order of cost. A
 * candidate's cost is the sum of the cost model over the union of (a) the candidate term, (b) every
 * term reachable from the candidate's children, and (c) the terms already chosen on the current
 * search branch. A value that has already been assigned a term is reused. If that term is an
 * ancestor on the current path (a cycle), reuse is accepted only when the closest such ancestor
 * edge is child 1 of a theta node.
 */
public abstract class AverageReversionHeuristic<L, P, R> extends AbstractReversionHeuristic<L, P, R, Integer> {
    /** When true, {@link #debug(String)} writes messages to stderr. */
    private static boolean DEBUG = false;

    /**
     * Chooses a term for child {@code childIndex} of {@code parent}, then proceeds with
     * {@code after}.
     *
     * <p>The edge (parent, childIndex) stays on {@code info.path} only while the child's subtree is
     * being explored. It is removed before {@code after} runs, so the path always holds exactly the
     * ancestors of the value being chosen. If {@code after} fails, the edge is pushed back so the
     * child's remaining alternatives are retried with the correct path.
     */
    public class ChildContinuation extends Continuation {
        protected final Continuation after;
        protected final CPEGTerm<L, P> parent;
        protected final int childIndex;

        public ChildContinuation(Continuation continuation, CPEGTerm<L, P> cPEGTerm, int i) {
            super();
            this.after = continuation;
            this.childIndex = i;
            this.parent = cPEGTerm;
        }

        @Override
        @SuppressWarnings("unchecked")
        public boolean continuate(Info info) {
            final int size = info.path.size();
            final TermIndex edge = new TermIndex(this.parent, this.childIndex);
            final Continuation next = this.after;
            CPEGValue<L, P> childValue = (CPEGValue<L, P>) this.parent.getChild(this.childIndex).getValue();

            info.path.addLast(edge);
            boolean ok = AverageReversionHeuristic.this.chooseTerm2(childValue, info, new Continuation() {
                @Override
                public boolean continuate(Info info2) {
                    // The child's subtree is complete. By the invariant that chooseTerm2 hands its
                    // continuation the same path it was entered with, the last entry is `edge`.
                    // Leave that edge before moving on to siblings/ancestors.
                    info2.path.removeLast();
                    if (next.continuate(info2)) {
                        return true;
                    }
                    // Restore the path for backtracking into the child's other term choices.
                    info2.path.addLast(edge);
                    return false;
                }
            });
            if (!ok) {
                // On failure, leave the path exactly as it was on entry.
                while (info.path.size() > size) {
                    info.path.removeLast();
                }
            }
            return ok;
        }
    }

    /**
     * One step of the continuation-passing search. On failure, an implementation must leave
     * {@code info.path} as it was on entry.
     */
    public abstract class Continuation {
        private Continuation() {
        }

        /** Kept for compatibility with the compiler-generated constructor; delegates to {@link #Continuation()}. */
        /* synthetic */ Continuation(AverageReversionHeuristic averageReversionHeuristic, Continuation continuation) {
            this();
        }

        /**
         * Runs the rest of the search.
         *
         * @return true if the whole remaining search succeeded, false to request backtracking
         */
        public abstract boolean continuate(Info info);
    }

    /** Mutable search state shared by all continuations of one {@code chooseReversionNodes} call. */
    public class Info {
        /** Cost model used to price candidate terms. */
        CostModel<CPEGTerm<L, P>, Integer> costmodel;
        /** Dense index of every term in the e-graph (term to bit position). */
        HashList<CPEGTerm<L, P>> term2index;
        /** For each term: the term itself plus every term reachable from its children. */
        Map<CPEGTerm<L, P>, BitSet> term2termset;
        /** Bits of the terms chosen so far on the current search branch. */
        BitSet currentChosen;
        /** Value-to-chosen-term assignment; popped to an earlier height when backtracking. */
        StackMap<CPEGValue<L, P>, CPEGTerm<L, P>> result;
        /** Edges (term, child index) from the root to the value currently being chosen. */
        LinkedList<TermIndex> path;

        Info() {
        }
    }

    /** Chooses a term for one root (return) value, then proceeds with {@code after}. */
    private class RootChildContinuation extends Continuation {
        protected final CPEGValue<L, P> rootValue;
        protected final Continuation after;

        public RootChildContinuation(CPEGValue<L, P> cPEGValue, Continuation continuation) {
            super();
            this.after = continuation;
            this.rootValue = cPEGValue;
        }

        @Override
        public boolean continuate(Info info) {
            return AverageReversionHeuristic.this.chooseTerm2(this.rootValue, info, this.after);
        }
    }

    /** A candidate term for a value, ordered by ascending {@code cost}. */
    public class TermChoice implements Comparable<TermChoice> {
        CPEGTerm<L, P> term;
        int cost;
        /** The term set the cost was computed over. */
        BitSet bits;

        TermChoice() {
        }

        @Override
        public int compareTo(TermChoice termChoice) {
            // Integer.compare avoids the overflow that `this.cost - other.cost` can hit.
            return Integer.compare(this.cost, termChoice.cost);
        }
    }

    /** One edge of the search path: the search descended into child {@code index} of {@code term}. */
    public class TermIndex {
        CPEGTerm<L, P> term;
        int index;

        TermIndex(CPEGTerm<L, P> cPEGTerm, int i) {
            this.term = cPEGTerm;
            this.index = i;
        }
    }

    private static void debug(String str) {
        if (DEBUG) {
            System.err.println("AverageReversionHeuristic: " + str);
        }
    }

    /** Returns whether a term with this operator may be chosen during reversion. */
    protected abstract boolean isRevertible(FlowValue<P, L> flowValue);

    /** Returns the sticky predicate consulted when selecting terms. */
    protected abstract StickyPredicate<FlowValue<P, L>> getStickyPredicate();

    /**
     * Chooses one term for every value reachable from the PEG's return values.
     *
     * @param cPeggyAxiomEngine engine whose e-graph supplies the values and terms
     * @param pEGInfo supplies the return kinds and their vertices
     * @param map maps each return vertex to the term that represents it in the e-graph
     * @return a fresh map from each chosen value to its chosen term
     * @throws RuntimeException if no consistent choice exists for the roots
     */
    @Override
    @SuppressWarnings("unchecked")
    public Map<? extends CPEGValue<L, P>, ? extends CPEGTerm<L, P>> chooseReversionNodes(CPeggyAxiomEngine<L, P> cPeggyAxiomEngine, PEGInfo<L, P, R> pEGInfo, Map<? extends CRecursiveExpressionGraph.Vertex<FlowValue<P, L>>, ? extends CPEGTerm<L, P>> map) {
        // Give every term a dense index so that term sets can be represented as BitSets.
        HashList<CPEGTerm<L, P>> termIndex = new HashList<>();
        for (CPEGValue<L, P> value : cPeggyAxiomEngine.getEGraph().getValueManager().getValues()) {
            for (CPEGTerm<L, P> term : value.getTerms()) {
                termIndex.add(term);
            }
        }

        // value -> bits of every term reachable from that value (including its own terms)
        Map<CPEGValue<L, P>, BitSet> value2reachable = new HashMap<>();
        buildTermSets(cPeggyAxiomEngine, termIndex, value2reachable);

        // term -> itself plus everything reachable through its children
        Map<CPEGTerm<L, P>, BitSet> term2termset = new HashMap<>();
        for (CPEGValue<L, P> value : cPeggyAxiomEngine.getEGraph().getValueManager().getValues()) {
            for (CPEGTerm<L, P> term : value.getTerms()) {
                BitSet set = new BitSet();
                for (int i = 0; i < term.getArity(); i++) {
                    set.or(value2reachable.get(term.getChild(i).getValue()));
                }
                set.set(termIndex.getIndex(term));
                term2termset.put(term, set);
            }
        }

        Info info = new Info();
        info.costmodel = (CostModel<CPEGTerm<L, P>, Integer>) getCostModel();
        info.currentChosen = new BitSet();
        info.term2termset = term2termset;
        info.term2index = termIndex;
        info.result = new StackMap<>();
        info.path = new LinkedList<>();

        // Terminal continuation: everything has been chosen.
        Continuation continuation = new Continuation() {
            @Override
            public boolean continuate(Info info2) {
                return true;
            }
        };
        // Chain the roots so each one is chosen in turn before reaching the terminal continuation.
        for (R ret : pEGInfo.getReturns()) {
            CPEGValue<L, P> rootValue = (CPEGValue<L, P>) map.get(pEGInfo.getReturnVertex(ret)).getValue();
            continuation = new RootChildContinuation(rootValue, continuation);
        }
        if (!continuation.continuate(info)) {
            throw new RuntimeException("Cannot find root values");
        }

        Map<CPEGValue<L, P>, CPEGTerm<L, P>> chosen = new HashMap<>();
        for (CPEGValue<L, P> value : info.result.keySet()) {
            chosen.put(value, info.result.get(value));
        }
        return chosen;
    }

    /**
     * Decides the outcome for a value that already has a chosen term.
     *
     * @return null if {@code cPEGValue} has no chosen term yet. Otherwise returns true when the
     *     chosen term is not an ancestor on the current path, or when its closest ancestor edge is
     *     child 1 of a theta. Returns false for any other cycle.
     */
    @SuppressWarnings("unchecked")
    private Boolean checkBaseCases(CPEGValue<L, P> cPEGValue, Info info) {
        if (!info.result.containsKey(cPEGValue)) {
            return null;
        }
        CPEGTerm<L, P> chosen = info.result.get(cPEGValue);
        // Scan from the innermost edge outward for the closest occurrence of the chosen term.
        Iterator<TermIndex> edges = info.path.descendingIterator();
        while (edges.hasNext()) {
            TermIndex edge = edges.next();
            if (edge.term.equals(chosen)) {
                return ((FlowValue<P, L>) edge.term.getOp()).isTheta() && edge.index == 1;
            }
        }
        return true;
    }

    /**
     * Chooses a term for {@code cPEGValue} (unless it already has one) and its children, then runs
     * {@code continuation}. Candidates are tried cheapest first, with backtracking.
     *
     * @return true if this choice and the entire remaining search succeeded. On false, the result
     *     stack, path and {@code currentChosen} are restored to their entry state.
     */
    public boolean chooseTerm2(CPEGValue<L, P> cPEGValue, Info info, Continuation continuation) {
        Boolean baseCase = checkBaseCases(cPEGValue, info);
        if (baseCase != null) {
            return baseCase.booleanValue() && continuation.continuate(info);
        }
        int height = info.result.getHeight();
        int size = info.path.size();
        BitSet chosenBefore = info.currentChosen;
        Iterator<TermChoice> choices = getTermChoices(info, cPEGValue);
        while (choices.hasNext()) {
            TermChoice choice = choices.next();
            info.result.push(cPEGValue, choice.term);
            info.currentChosen = new BitSet();
            info.currentChosen.or(chosenBefore);
            info.currentChosen.set(info.term2index.getIndex(choice.term));

            // Build the chain child0 -> child1 -> ... -> continuation, so children are chosen in order.
            Continuation chain = continuation;
            for (int child = choice.term.getArity() - 1; child >= 0; child--) {
                chain = new ChildContinuation(chain, choice.term, child);
            }
            if (chain.continuate(info)) {
                return true;
            }

            // Undo this candidate before trying the next one.
            info.result.popToHeight(height);
            while (info.path.size() > size) {
                info.path.removeLast();
            }
        }
        info.currentChosen = chosenBefore;
        return false;
    }

    /**
     * Returns the revertible terms of {@code cPEGValue}, cheapest first. Each cost is the cost-model
     * sum over the term's reachable term set united with {@code info.currentChosen}.
     */
    @SuppressWarnings("unchecked")
    private Iterator<TermChoice> getTermChoices(Info info, CPEGValue<L, P> cPEGValue) {
        ArrayList<TermChoice> choices = new ArrayList<>(cPEGValue.getTerms().size());
        StickyPredicate<FlowValue<P, L>> stickyPredicate = getStickyPredicate();
        if (!info.path.isEmpty()) {
            TermIndex last = info.path.getLast();
            // NOTE(review): the result of isSticky is discarded, so it has no effect on which terms
            // are offered. It is kept only in case the predicate has side effects. Presumably it was
            // meant to restrict the candidates — confirm the intended use.
            stickyPredicate.isSticky((FlowValue<P, L>) last.term.getOp(), last.index);
        }
        for (CPEGTerm<L, P> term : cPEGValue.getTerms()) {
            if (!isRevertible((FlowValue<P, L>) term.getOp())) {
                continue;
            }
            BitSet bits = new BitSet();
            bits.or(info.term2termset.get(term));
            bits.or(info.currentChosen);
            int cost = 0;
            for (int i = bits.nextSetBit(0); i >= 0; i = bits.nextSetBit(i + 1)) {
                cost += info.costmodel.cost(info.term2index.getValue(i)).intValue();
            }
            TermChoice choice = new TermChoice();
            choice.term = term;
            choice.cost = cost;
            choice.bits = bits;
            choices.add(choice);
        }
        Collections.sort(choices);
        return choices.iterator();
    }

    /**
     * Fills {@code map} with, for every value in the e-graph, the bits (per {@code hashList}) of
     * every term reachable from that value, including the value's own terms.
     */
    @SuppressWarnings("unchecked")
    private void buildTermSets(CPeggyAxiomEngine<L, P> cPeggyAxiomEngine, HashList<CPEGTerm<L, P>> hashList, Map<CPEGValue<L, P>, BitSet> map) {
        // Bits of each value's own terms.
        Map<CPEGValue<L, P>, BitSet> direct = new HashMap<>();
        for (CPEGValue<L, P> value : cPeggyAxiomEngine.getEGraph().getValueManager().getValues()) {
            BitSet bits = new BitSet();
            for (CPEGTerm<L, P> term : value.getTerms()) {
                bits.set(hashList.getIndex(term));
            }
            direct.put(value, bits);
        }
        // Breadth-first search from each value. The visited set handles cycles.
        for (CPEGValue<L, P> root : cPeggyAxiomEngine.getEGraph().getValueManager().getValues()) {
            HashSet<CPEGValue<L, P>> visited = new HashSet<>();
            LinkedList<CPEGValue<L, P>> worklist = new LinkedList<>();
            BitSet reachable = new BitSet();
            worklist.addLast(root);
            while (!worklist.isEmpty()) {
                CPEGValue<L, P> value = worklist.removeFirst();
                if (!visited.add(value)) {
                    continue;
                }
                reachable.or(direct.get(value));
                for (CPEGTerm<L, P> term : value.getTerms()) {
                    for (int i = 0; i < term.getArity(); i++) {
                        worklist.addLast((CPEGValue<L, P>) term.getChild(i).getValue());
                    }
                }
            }
            map.put(root, reachable);
        }
    }
}
