package com.ibm.wala.cast.python.ml.client;

import com.ibm.wala.cast.lsp.AnalysisError;
import com.ibm.wala.cast.python.client.PythonAnalysisEngine;
import com.ibm.wala.cast.python.ml.analysis.TensorTypeAnalysis;
import com.ibm.wala.cast.python.ml.types.TensorType;
import com.ibm.wala.cast.python.types.PythonTypes;
import com.ibm.wala.cast.types.AstMethodReference;
import com.ibm.wala.classLoader.CallSiteReference;
import com.ibm.wala.ipa.callgraph.AnalysisOptions;
import com.ibm.wala.ipa.callgraph.CGNode;
import com.ibm.wala.ipa.callgraph.propagation.LocalPointerKey;
import com.ibm.wala.ipa.callgraph.propagation.PointerKey;
import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable;
import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder;
import com.ibm.wala.ipa.cha.IClassHierarchy;
import com.ibm.wala.ssa.SSAAbstractInvokeInstruction;
import com.ibm.wala.ssa.SSAInstruction;
import com.ibm.wala.ssa.SSAInvokeInstruction;
import com.ibm.wala.types.MethodReference;
import com.ibm.wala.types.TypeName;
import com.ibm.wala.types.TypeReference;
import com.ibm.wala.util.CancelException;
import com.ibm.wala.util.NullProgressMonitor;
import com.ibm.wala.util.collections.HashMapFactory;
import com.ibm.wala.util.collections.HashSetFactory;
import com.ibm.wala.util.graph.Graph;
import com.ibm.wala.util.graph.impl.SlowSparseNumberedGraph;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.class */
/**
 * Python analysis engine that builds and solves a {@link TensorTypeAnalysis} over the data-flow
 * graph of the pointer analysis.
 *
 * <p>The tensor analysis is seeded from calls to a handful of TensorFlow functions, whose models
 * are loaded from the {@code tensorflow.xml} summary file: results of {@code read_data}, {@code
 * placeholder}, {@code set_shape}, {@code reshape}, {@code conv2d} and {@code conv3d}.
 */
public class PythonTensorAnalysisEngine extends PythonAnalysisEngine<TensorTypeAnalysis> {

    private static final MethodReference conv2d = tensorflowFunction("conv2d");
    private static final MethodReference conv3d = tensorflowFunction("conv3d");
    private static final MethodReference reshape = tensorflowFunction("reshape");
    private static final MethodReference placeholder = tensorflowFunction("placeholder");
    private static final MethodReference set_shape = tensorflowFunction("set_shape");

    /**
     * Error log handed to every {@link TensorTypeAnalysis} this engine creates and exposed through
     * {@link #getErrors()}. It is shared across calls to {@link #performAnalysis}.
     */
    private final Map<PointerKey, AnalysisError> errorLog = HashMapFactory.make();

    /** Callback invoked for every call instruction that may invoke a given target method. */
    @FunctionalInterface
    public interface SourceCallHandler {
        /**
         * Handles one call instruction.
         *
         * @param caller the call-graph node containing the call
         * @param call a call instruction in {@code caller}'s IR whose call site may reach the target
         */
        void handleCall(CGNode caller, SSAAbstractInvokeInstruction call);
    }

    /**
     * Builds the reference to {@code Ltensorflow/functions/<name>} in the Python loader, using the
     * AST function-body selector.
     */
    private static MethodReference tensorflowFunction(String name) {
        TypeReference type =
                TypeReference.findOrCreate(
                        PythonTypes.pythonLoader,
                        TypeName.string2TypeName("Ltensorflow/functions/" + name));
        return MethodReference.findOrCreate(type, AstMethodReference.fnSelector);
    }

    /** Returns, creating it if necessary, the points-to variable of local {@code valueNumber} in {@code node}. */
    private static PointsToSetVariable pointsToSetForLocal(
            PropagationCallGraphBuilder builder, CGNode node, int valueNumber) {
        PointerKey key = builder.getPointerAnalysis().getHeapModel().getPointerKeyForLocal(node, valueNumber);
        return builder.getPropagationSystem().findOrCreatePointsToSet(key);
    }

    /**
     * Collects the graph variables holding the normal result of a call whose declared target is
     * named {@code read_data}.
     *
     * @param graph the data-flow graph to scan
     * @return the variables that should be seeded as data sources
     */
    private static Set<PointsToSetVariable> getDataflowSources(Graph<PointsToSetVariable> graph) {
        Set<PointsToSetVariable> sources = HashSetFactory.make();
        for (PointsToSetVariable variable : graph) {
            PointerKey key = variable.getPointerKey();
            if (!(key instanceof LocalPointerKey)) {
                continue;
            }
            LocalPointerKey local = (LocalPointerKey) key;
            int valueNumber = local.getValueNumber();
            SSAInstruction def = local.getNode().getDU().getDef(valueNumber);
            // Match any call instruction, consistent with the rest of this class; the concrete
            // SSAInvokeInstruction subclass is just one kind of call.
            if (def instanceof SSAAbstractInvokeInstruction) {
                SSAAbstractInvokeInstruction call = (SSAAbstractInvokeInstruction) def;
                // Only the call's normal result is a source, not its exceptional value.
                if (call.getCallSite().getDeclaredTarget().getName().toString().equals("read_data")
                        && call.getException() != valueNumber) {
                    sources.add(variable);
                }
            }
        }
        return sources;
    }

    /**
     * Invokes {@code handler} for every call instruction in the call graph that may call the node(s)
     * whose method is {@code target}.
     */
    private void getSourceCalls(
            MethodReference target, PropagationCallGraphBuilder builder, SourceCallHandler handler) {
        for (CGNode callee : builder.getCallGraph()) {
            if (!callee.getMethod().getReference().equals(target)) {
                continue;
            }
            Iterator<CGNode> callers = builder.getCallGraph().getPredNodes(callee);
            while (callers.hasNext()) {
                CGNode caller = callers.next();
                Iterator<CallSiteReference> sites = builder.getCallGraph().getPossibleSites(caller, callee);
                while (sites.hasNext()) {
                    for (SSAAbstractInvokeInstruction call : caller.getIR().getCalls(sites.next())) {
                        handler.handleCall(caller, call);
                    }
                }
            }
        }
    }

    /**
     * Maps the result variable of each call to {@code target} to the tensor type built by {@link
     * TensorType#shapeArg} from the call's use number {@code shapeUse}. Calls with too few uses are
     * skipped.
     */
    private Map<PointsToSetVariable, TensorType> getShapeSourceCalls(
            MethodReference target, PropagationCallGraphBuilder builder, int shapeUse) {
        Map<PointsToSetVariable, TensorType> shapes = HashMapFactory.make();
        getSourceCalls(target, builder, (caller, call) -> {
            if (call.getNumberOfUses() > shapeUse) {
                shapes.put(
                        pointsToSetForLocal(builder, caller, call.getDef()),
                        TensorType.shapeArg(caller, call.getUse(shapeUse)));
            }
        });
        return shapes;
    }

    /** Returns the result variables of every call to {@code target}. */
    private Set<PointsToSetVariable> getKeysDefinedByCall(
            MethodReference target, PropagationCallGraphBuilder builder) {
        Set<PointsToSetVariable> defined = HashSetFactory.make();
        getSourceCalls(target, builder, (caller, call) ->
                defined.add(pointsToSetForLocal(builder, caller, call.getDef())));
        return defined;
    }

    /**
     * Runs the tensor type analysis on the flow graph computed by {@code builder}.
     *
     * <p>The flow graph (including implicit constraints) is duplicated so that edges added here do
     * not touch the propagation system's own graph. The analysis is then seeded as follows:
     *
     * <ul>
     *   <li>{@code read_data} results get {@link TensorType#mnistInput()}.
     *   <li>{@code placeholder} results get the shape from use 2, overriding any earlier seed.
     *   <li>{@code set_shape} shapes (use 1) are keyed by the object the call was made on.
     *   <li>{@code reshape} results get the shape from use 2.
     *   <li>{@code conv2d}/{@code conv3d} result variables are passed separately.
     * </ul>
     *
     * <p>The analysis is then solved.
     *
     * @param builder the call-graph builder whose propagation system is analysed
     * @return the solved tensor type analysis
     * @throws CancelException if solving is cancelled
     */
    @Override
    public TensorTypeAnalysis performAnalysis(PropagationCallGraphBuilder builder) throws CancelException {
        SlowSparseNumberedGraph<PointsToSetVariable> dataflow =
                SlowSparseNumberedGraph.duplicate(
                        builder.getPropagationSystem().getFlowGraphIncludingImplicitConstraints());

        Set<PointsToSetVariable> sources = getDataflowSources(dataflow);
        TensorType mnistInput = TensorType.mnistInput();
        Map<PointsToSetVariable, TensorType> init = HashMapFactory.make();
        for (PointsToSetVariable source : sources) {
            init.put(source, mnistInput);
        }
        init.putAll(handleShapeSourceOp(builder, dataflow, placeholder, 2));

        Map<PointsToSetVariable, TensorType> setShapes = HashMapFactory.make();
        for (Map.Entry<PointsToSetVariable, TensorType> entry :
                getShapeSourceCalls(set_shape, builder, 1).entrySet()) {
            setShapes.put(setShapeReceiver(builder, entry.getKey()), entry.getValue());
        }

        Map<PointsToSetVariable, TensorType> reshapeTypes = handleShapeSourceOp(builder, dataflow, reshape, 2);

        TensorTypeAnalysis analysis =
                new TensorTypeAnalysis(
                        dataflow,
                        init,
                        reshapeTypes,
                        setShapes,
                        getKeysDefinedByCall(conv2d, builder),
                        getKeysDefinedByCall(conv3d, builder),
                        this.errorLog);
        analysis.solve(new NullProgressMonitor());
        return analysis;
    }

    /**
     * Resolves the variable that a {@code set_shape} call applies to.
     *
     * <p>Starting from the call's result variable, it takes the call's use 0, then the first use of
     * the instruction defining that value. NOTE(review): this presumes {@code x.set_shape(...)} is
     * lowered to a property read of {@code x} feeding the call, so the property read's use 0 is
     * {@code x}. It throws if that def chain is absent; confirm.
     */
    private static PointsToSetVariable setShapeReceiver(
            PropagationCallGraphBuilder builder, PointsToSetVariable callResult) {
        LocalPointerKey key = (LocalPointerKey) callResult.getPointerKey();
        CGNode node = key.getNode();
        SSAInstruction call = node.getDU().getDef(key.getValueNumber());
        SSAInstruction calleeDef = node.getDU().getDef(call.getUse(0));
        return pointsToSetForLocal(builder, node, calleeDef.getUse(0));
    }

    /**
     * Collects shape seeds for calls to {@code op} (see {@link #getShapeSourceCalls}). For each such
     * call it also adds a {@code graph} edge from the variable of the call's use 1 to the call's
     * result.
     *
     * @return the result-variable-to-shape map of the calls to {@code op}
     */
    private Map<PointsToSetVariable, TensorType> handleShapeSourceOp(
            PropagationCallGraphBuilder builder,
            Graph<PointsToSetVariable> graph,
            MethodReference op,
            int shapeUse) {
        Map<PointsToSetVariable, TensorType> shapes = getShapeSourceCalls(op, builder, shapeUse);
        for (PointsToSetVariable result : shapes.keySet()) {
            assert result.getPointerKey() instanceof LocalPointerKey;
            LocalPointerKey key = (LocalPointerKey) result.getPointerKey();
            CGNode node = key.getNode();
            SSAInstruction call = node.getDU().getDef(key.getValueNumber());
            // Use 1 is presumably the first explicit argument (use 0 being the callee); confirm.
            graph.addEdge(pointsToSetForLocal(builder, node, call.getUse(1)), result);
        }
        return shapes;
    }

    /** Returns the live error log shared with the analyses created by this engine. */
    public Map<PointerKey, AnalysisError> getErrors() {
        return this.errorLog;
    }

    /** Installs the inherited bypass logic plus the TensorFlow summaries from {@code tensorflow.xml}. */
    @Override
    public void addBypassLogic(IClassHierarchy iClassHierarchy, AnalysisOptions analysisOptions) {
        super.addBypassLogic(iClassHierarchy, analysisOptions);
        addSummaryBypassLogic(analysisOptions, "tensorflow.xml");
    }
}
