package com.ibm.wala.cast.python.ml.analysis;

import com.ibm.wala.cast.loader.AstMethod;
import com.ibm.wala.cast.lsp.AnalysisError;
import com.ibm.wala.cast.python.ml.types.TensorType;
import com.ibm.wala.cast.tree.CAstSourcePositionMap;
import com.ibm.wala.cast.util.SourceBuffer;
import com.ibm.wala.dataflow.graph.AbstractMeetOperator;
import com.ibm.wala.dataflow.graph.DataflowSolver;
import com.ibm.wala.dataflow.graph.IKilldallFramework;
import com.ibm.wala.dataflow.graph.ITransferFunctionProvider;
import com.ibm.wala.fixpoint.UnaryOperator;
import com.ibm.wala.ipa.callgraph.propagation.LocalPointerKey;
import com.ibm.wala.ipa.callgraph.propagation.PointerKey;
import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable;
import com.ibm.wala.ssa.DefUse;
import com.ibm.wala.ssa.SSAInstruction;
import com.ibm.wala.util.collections.Pair;
import com.ibm.wala.util.graph.Graph;
import java.io.IOException;
import java.util.Collections;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import org.eclipse.lsp4j.DiagnosticSeverity;

/* loaded from: input_file:com/ibm/wala/cast/python/ml/analysis/TensorTypeAnalysis.class */
public class TensorTypeAnalysis extends DataflowSolver<PointsToSetVariable, TensorVariable> {
    /**
     * Seed tensor types: before solving, initializeVariables() adds each mapped type to the
     * out-variable of its key.
     */
    private final Map<PointsToSetVariable, TensorType> init;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: com.ibm.wala.cast.python.ml.analysis.TensorTypeAnalysis$1, reason: invalid class name */
    /* loaded from: input_file:com/ibm/wala/cast/python/ml/analysis/TensorTypeAnalysis$1.class */
    public static class AnonymousClass1 implements IKilldallFramework<PointsToSetVariable, TensorVariable> {
        final /* synthetic */ Graph val$G;
        final /* synthetic */ Map val$errorLog;
        final /* synthetic */ Map val$reshapeNodes;
        final /* synthetic */ Set val$conv2ds;
        final /* synthetic */ Set val$conv3ds;
        final /* synthetic */ Map val$set_shapes;

        AnonymousClass1(Graph graph, Map map, Map map2, Set set, Set set2, Map map3) {
            this.val$G = graph;
            this.val$errorLog = map;
            this.val$reshapeNodes = map2;
            this.val$conv2ds = set;
            this.val$conv3ds = set2;
            this.val$set_shapes = map3;
        }

        @Override // com.ibm.wala.dataflow.graph.IKilldallFramework
        public Graph<PointsToSetVariable> getFlowGraph() {
            return this.val$G;
        }

        @Override // com.ibm.wala.dataflow.graph.IKilldallFramework
        public ITransferFunctionProvider<PointsToSetVariable, TensorVariable> getTransferFunctionProvider() {
            return new ITransferFunctionProvider<PointsToSetVariable, TensorVariable>() { // from class: com.ibm.wala.cast.python.ml.analysis.TensorTypeAnalysis.1.1
                private final UnaryOperator<TensorVariable> nodeOp = new UnaryOperator<TensorVariable>() { // from class: com.ibm.wala.cast.python.ml.analysis.TensorTypeAnalysis.1.1.1
                    @Override // com.ibm.wala.fixpoint.UnaryOperator
                    public byte evaluate(TensorVariable tensorVariable, TensorVariable tensorVariable2) {
                        if (tensorVariable2 == null || tensorVariable2.state == null) {
                            return (byte) 0;
                        }
                        if (tensorVariable != null && tensorVariable.state != null) {
                            return tensorVariable.state.addAll(tensorVariable2.state) ? (byte) 1 : (byte) 0;
                        }
                        tensorVariable.copyState(tensorVariable2);
                        return (byte) 1;
                    }

                    @Override // com.ibm.wala.fixpoint.AbstractOperator
                    public int hashCode() {
                        return 817504253;
                    }

                    @Override // com.ibm.wala.fixpoint.AbstractOperator
                    public boolean equals(Object obj) {
                        return obj == this;
                    }

                    @Override // com.ibm.wala.fixpoint.AbstractOperator
                    public String toString() {
                        return "propagate node tensor types";
                    }
                };

                /* JADX INFO: Access modifiers changed from: package-private */
                /* renamed from: com.ibm.wala.cast.python.ml.analysis.TensorTypeAnalysis$1$1$ConvOp */
                /* loaded from: input_file:com/ibm/wala/cast/python/ml/analysis/TensorTypeAnalysis$1$1$ConvOp.class */
                public final class ConvOp extends UnaryOperator<TensorVariable> {
                    private final PointsToSetVariable v;
                    private final int dimensions;

                    public ConvOp(int i, PointsToSetVariable pointsToSetVariable) {
                        this.v = pointsToSetVariable;
                        this.dimensions = i;
                    }

                    @Override // com.ibm.wala.fixpoint.UnaryOperator
                    public byte evaluate(TensorVariable tensorVariable, TensorVariable tensorVariable2) {
                        boolean z = false;
                        if (tensorVariable2 != null && tensorVariable2.state != null) {
                            for (TensorType tensorType : tensorVariable2.state) {
                                int i = 0;
                                Iterator<TensorType.Dimension<?>> it = tensorType.iterator();
                                while (it.hasNext()) {
                                    it.next();
                                    i++;
                                }
                                if (i == this.dimensions + 2) {
                                    z |= tensorVariable.state.add(tensorType);
                                } else {
                                    AnonymousClass1.this.val$errorLog.put(this.v.getPointerKey(), new ConvError(tensorType, this.dimensions, getTargetPos(this.v.getPointerKey()), getTargetDef(this.v.getPointerKey())));
                                }
                            }
                        }
                        return z ? (byte) 3 : (byte) 0;
                    }

                    @Override // com.ibm.wala.fixpoint.AbstractOperator
                    public int hashCode() {
                        return this.v.hashCode();
                    }

                    @Override // com.ibm.wala.fixpoint.AbstractOperator
                    public boolean equals(Object obj) {
                        return (obj instanceof ConvOp) && ((ConvOp) obj).v.equals(this.v);
                    }

                    @Override // com.ibm.wala.fixpoint.AbstractOperator
                    public String toString() {
                        return "conv at " + this.v;
                    }
                }

                /* JADX INFO: Access modifiers changed from: package-private */
                /* renamed from: com.ibm.wala.cast.python.ml.analysis.TensorTypeAnalysis$1$1$ReshapeOp */
                /* loaded from: input_file:com/ibm/wala/cast/python/ml/analysis/TensorTypeAnalysis$1$1$ReshapeOp.class */
                public final class ReshapeOp extends UnaryOperator<TensorVariable> {
                    private final TensorType reshapeTo;
                    private final PointsToSetVariable v;
                    static final /* synthetic */ boolean $assertionsDisabled;

                    public ReshapeOp(TensorType tensorType, PointsToSetVariable pointsToSetVariable) {
                        this.v = pointsToSetVariable;
                        this.reshapeTo = tensorType;
                    }

                    @Override // com.ibm.wala.fixpoint.UnaryOperator
                    public byte evaluate(TensorVariable tensorVariable, TensorVariable tensorVariable2) {
                        boolean z = false;
                        int symbolicDims = this.reshapeTo.symbolicDims();
                        int concreteSize = this.reshapeTo.concreteSize();
                        if (tensorVariable2 != null && tensorVariable2.state != null) {
                            for (TensorType tensorType : tensorVariable2.state) {
                                if (tensorType.symbolicDims() == symbolicDims && tensorType.concreteSize() == concreteSize) {
                                    z |= tensorVariable.state.add(this.reshapeTo);
                                } else {
                                    CAstSourcePositionMap.Position targetPos = getTargetPos(this.v.getPointerKey());
                                    if (!$assertionsDisabled && targetPos == null) {
                                        throw new AssertionError();
                                    }
                                    AnonymousClass1.this.val$errorLog.put(this.v.getPointerKey(), new ReshapeError(tensorType, this.reshapeTo, targetPos, getTargetDef(this.v.getPointerKey())));
                                }
                            }
                        }
                        return z ? (byte) 3 : (byte) 0;
                    }

                    @Override // com.ibm.wala.fixpoint.AbstractOperator
                    public int hashCode() {
                        return this.reshapeTo.hashCode();
                    }

                    @Override // com.ibm.wala.fixpoint.AbstractOperator
                    public boolean equals(Object obj) {
                        return this == obj || ((obj instanceof ReshapeOp) && this.reshapeTo.equals(((ReshapeOp) obj).reshapeTo));
                    }

                    @Override // com.ibm.wala.fixpoint.AbstractOperator
                    public String toString() {
                        return "reshape to " + this.reshapeTo;
                    }

                    static {
                        $assertionsDisabled = !TensorTypeAnalysis.class.desiredAssertionStatus();
                    }
                }

                /* JADX INFO: Access modifiers changed from: package-private */
                /* renamed from: com.ibm.wala.cast.python.ml.analysis.TensorTypeAnalysis$1$1$SetShapeOp */
                /* loaded from: input_file:com/ibm/wala/cast/python/ml/analysis/TensorTypeAnalysis$1$1$SetShapeOp.class */
                public final class SetShapeOp extends UnaryOperator<TensorVariable> {
                    private final TensorType setShapeTo;

                    public SetShapeOp(TensorType tensorType) {
                        this.setShapeTo = tensorType;
                    }

                    @Override // com.ibm.wala.fixpoint.UnaryOperator
                    public byte evaluate(TensorVariable tensorVariable, TensorVariable tensorVariable2) {
                        return tensorVariable.state.add(this.setShapeTo) ? (byte) 3 : (byte) 0;
                    }

                    @Override // com.ibm.wala.fixpoint.AbstractOperator
                    public int hashCode() {
                        return this.setShapeTo.hashCode();
                    }

                    @Override // com.ibm.wala.fixpoint.AbstractOperator
                    public boolean equals(Object obj) {
                        return this == obj || ((obj instanceof ReshapeOp) && this.setShapeTo.equals(((ReshapeOp) obj).reshapeTo));
                    }

                    @Override // com.ibm.wala.fixpoint.AbstractOperator
                    public String toString() {
                        return "set shape to " + this.setShapeTo;
                    }
                }

                /* JADX INFO: Access modifiers changed from: private */
                public CAstSourcePositionMap.Position getTargetPos(PointerKey pointerKey) {
                    if (!(pointerKey instanceof LocalPointerKey)) {
                        return null;
                    }
                    LocalPointerKey localPointerKey = (LocalPointerKey) pointerKey;
                    SSAInstruction def = localPointerKey.getNode().getDU().getDef(localPointerKey.getValueNumber());
                    if (localPointerKey.getNode().getMethod() instanceof AstMethod) {
                        return ((AstMethod) localPointerKey.getNode().getMethod()).debugInfo().getOperandPosition(def.iindex, 1);
                    }
                    return null;
                }

                /* JADX INFO: Access modifiers changed from: private */
                public CAstSourcePositionMap.Position getTargetDef(PointerKey pointerKey) {
                    SSAInstruction def;
                    if (!(pointerKey instanceof LocalPointerKey)) {
                        return null;
                    }
                    LocalPointerKey localPointerKey = (LocalPointerKey) pointerKey;
                    DefUse du = localPointerKey.getNode().getDU();
                    SSAInstruction def2 = du.getDef(localPointerKey.getValueNumber());
                    if (!(localPointerKey.getNode().getMethod() instanceof AstMethod) || (def = du.getDef(def2.getUse(1))) == null) {
                        return null;
                    }
                    return ((AstMethod) localPointerKey.getNode().getMethod()).debugInfo().getInstructionPosition(def.iindex);
                }

                @Override // com.ibm.wala.dataflow.graph.ITransferFunctionProvider
                public UnaryOperator<TensorVariable> getNodeTransferFunction(PointsToSetVariable pointsToSetVariable) {
                    return AnonymousClass1.this.val$reshapeNodes.containsKey(pointsToSetVariable) ? new ReshapeOp((TensorType) AnonymousClass1.this.val$reshapeNodes.get(pointsToSetVariable), pointsToSetVariable) : AnonymousClass1.this.val$conv2ds.contains(pointsToSetVariable) ? new ConvOp(2, pointsToSetVariable) : AnonymousClass1.this.val$conv3ds.contains(pointsToSetVariable) ? new ConvOp(3, pointsToSetVariable) : this.nodeOp;
                }

                @Override // com.ibm.wala.dataflow.graph.ITransferFunctionProvider
                public boolean hasNodeTransferFunctions() {
                    return true;
                }

                @Override // com.ibm.wala.dataflow.graph.ITransferFunctionProvider
                public UnaryOperator<TensorVariable> getEdgeTransferFunction(PointsToSetVariable pointsToSetVariable, PointsToSetVariable pointsToSetVariable2) {
                    return AnonymousClass1.this.val$set_shapes.containsKey(pointsToSetVariable2) ? new SetShapeOp((TensorType) AnonymousClass1.this.val$set_shapes.get(pointsToSetVariable2)) : this.nodeOp;
                }

                @Override // com.ibm.wala.dataflow.graph.ITransferFunctionProvider
                public boolean hasEdgeTransferFunctions() {
                    return true;
                }

                @Override // com.ibm.wala.dataflow.graph.ITransferFunctionProvider
                public AbstractMeetOperator<TensorVariable> getMeetOperator() {
                    return new AbstractMeetOperator<TensorVariable>() { // from class: com.ibm.wala.cast.python.ml.analysis.TensorTypeAnalysis.1.1.2
                        @Override // com.ibm.wala.fixpoint.AbstractOperator
                        public byte evaluate(TensorVariable tensorVariable, TensorVariable[] tensorVariableArr) {
                            boolean z = false;
                            for (TensorVariable tensorVariable2 : tensorVariableArr) {
                                z |= tensorVariable.state.addAll(tensorVariable2.state);
                            }
                            return z ? (byte) 1 : (byte) 0;
                        }

                        @Override // com.ibm.wala.fixpoint.AbstractOperator
                        public int hashCode() {
                            return 413158523;
                        }

                        @Override // com.ibm.wala.fixpoint.AbstractOperator
                        public boolean equals(Object obj) {
                            return this == obj;
                        }

                        @Override // com.ibm.wala.fixpoint.AbstractOperator
                        public String toString() {
                            return "Tensor types set union";
                        }
                    };
                }
            };
        }
    }

    /**
     * Error recorded when a tensor type reaching a convolution does not have the required {@code
     * dims + 2} top-level dimensions. Its message may include a suggested {@code tf.reshape} call
     * built from the offending type.
     */
    public static class ConvError implements AnalysisError {
        /** Position of the definition related to the convolved operand; may be null. */
        CAstSourcePositionMap.Position definer;
        /** The rejected tensor type. */
        TensorType from;
        /** Convolution dimensionality (2 or 3 as used by the analysis); inputs need dims + 2 dims. */
        int dims;
        /** Position of the convolved operand; may be null. */
        CAstSourcePositionMap.Position pos;

        ConvError(TensorType from, int dims, CAstSourcePositionMap.Position pos, CAstSourcePositionMap.Position definer) {
            this.definer = definer;
            this.from = from;
            this.dims = dims;
            this.pos = pos;
        }

        /** Returns the definition site as related information, or nothing if it is unknown. */
        @Override // com.ibm.wala.cast.lsp.AnalysisError
        public Iterable<Pair<CAstSourcePositionMap.Position, String>> related() {
            if (this.definer == null) {
                return Collections.emptyList();
            }
            return Collections.singleton(Pair.make(this.definer, "definition"));
        }

        @Override
        public String toString() {
            return toString(false);
        }

        @Override // com.ibm.wala.cast.lsp.AnalysisError
        public CAstSourcePositionMap.Position position() {
            return this.pos;
        }

        /**
         * Builds a {@code tf.reshape(<operand source>, [...])} suggestion that gives {@link #from}
         * exactly {@code dims + 2} dimensions. Compound dimensions are flattened into their parts.
         * Top-level symbolic dimensions are written as {@code -1}. If exactly one dimension is
         * missing, a trailing {@code 1} is appended.
         *
         * @return the suggestion, or null when the shape cannot be fixed this way, the operand
         *     position is unknown, or its source text cannot be read
         */
        private String checkReshape() {
            if (this.pos == null) {
                // No operand position means no source text to build the suggestion from.
                return null;
            }
            StringBuilder shape = new StringBuilder();
            int count = 0;
            Iterator<TensorType.Dimension<?>> dimensions = this.from.iterator();
            while (dimensions.hasNext()) {
                TensorType.Dimension<?> dim = dimensions.next();
                if (dim instanceof TensorType.CompoundDim) {
                    TensorType.CompoundDim compound = (TensorType.CompoundDim) dim;
                    Iterator<TensorType.Dimension<?>> parts = compound.value().iterator();
                    while (parts.hasNext()) {
                        appendDim(shape, parts.next().value());
                    }
                    count += compound.value().size();
                } else {
                    appendDim(shape, dim instanceof TensorType.SymbolicDim ? "-1" : dim.value());
                    count++;
                }
            }
            if (count == this.dims + 1) {
                count++;
                shape.append(", 1");
            }
            if (count != this.dims + 2) {
                return null;
            }
            try {
                return "tf.reshape(" + new SourceBuffer(this.pos).toString() + ", [" + shape + "])";
            } catch (IOException e) {
                // Best effort: the suggestion is optional, so report the failure and omit it.
                e.printStackTrace();
                return null;
            }
        }

        /** Appends one comma-separated shape entry to {@code shape}. */
        private static void appendDim(StringBuilder shape, Object value) {
            if (shape.length() > 0) {
                shape.append(", ");
            }
            shape.append(value);
        }

        /**
         * @param useMarkdown formatting flag, forwarded unchanged to {@link TensorType#toCString}
         */
        @Override // com.ibm.wala.cast.lsp.AnalysisError
        public String toString(boolean useMarkdown) {
            String message = "Bad type to convolve " + this.from.toCString(useMarkdown) + ", needs " + (this.dims + 2) + " dimensions";
            String fix = checkReshape();
            return fix == null ? message : message + " (possible fix: " + fix + ")";
        }

        @Override // com.ibm.wala.cast.lsp.AnalysisError
        public DiagnosticSeverity severity() {
            return DiagnosticSeverity.Error;
        }

        @Override // com.ibm.wala.cast.lsp.AnalysisError
        public String source() {
            return "Ariadne";
        }
    }

    /** Warning recorded when a tensor type cannot be reshaped to the requested target type. */
    public static class ReshapeError implements AnalysisError {
        /** The incompatible input type. */
        TensorType from;
        /** The requested reshape target. */
        TensorType to;
        /** Position of the reshaped operand. */
        CAstSourcePositionMap.Position pos;
        /** Position of the definition related to the operand; may be null. */
        CAstSourcePositionMap.Position definer;

        ReshapeError(TensorType from, TensorType to, CAstSourcePositionMap.Position pos, CAstSourcePositionMap.Position definer) {
            this.definer = definer;
            this.from = from;
            this.to = to;
            this.pos = pos;
        }

        /** Returns the definition site as related information, or nothing if it is unknown. */
        @Override // com.ibm.wala.cast.lsp.AnalysisError
        public Iterable<Pair<CAstSourcePositionMap.Position, String>> related() {
            if (this.definer == null) {
                return Collections.emptyList();
            }
            return Collections.singleton(Pair.make(this.definer, "definition"));
        }

        @Override // com.ibm.wala.cast.lsp.AnalysisError
        public CAstSourcePositionMap.Position position() {
            return this.pos;
        }

        @Override
        public String toString() {
            return toString(false);
        }

        /**
         * @param useMarkdown formatting flag, forwarded unchanged to {@link TensorType#toCString}
         */
        @Override // com.ibm.wala.cast.lsp.AnalysisError
        public String toString(boolean useMarkdown) {
            return "Cannot reshape " + this.from.toCString(useMarkdown) + " to " + this.to.toCString(useMarkdown);
        }

        @Override // com.ibm.wala.cast.lsp.AnalysisError
        public DiagnosticSeverity severity() {
            return DiagnosticSeverity.Warning;
        }

        @Override // com.ibm.wala.cast.lsp.AnalysisError
        public String source() {
            return "Ariadne";
        }
    }

    /**
     * Bundles the flow graph and the per-node shape constraints into the dataflow problem this
     * analysis solves.
     *
     * @param graph flow graph over points-to set variables
     * @param reshapeNodes nodes whose output is reshaped to the mapped type
     * @param setShapes nodes whose incoming edges yield the mapped type
     * @param conv2ds nodes checked as 2-D convolution inputs
     * @param conv3ds nodes checked as 3-D convolution inputs
     * @param errorLog receives shape errors found while solving
     * @return the dataflow problem
     */
    private static IKilldallFramework<PointsToSetVariable, TensorVariable> createProblem(
            Graph<PointsToSetVariable> graph,
            Map<PointsToSetVariable, TensorType> reshapeNodes,
            Map<PointsToSetVariable, TensorType> setShapes,
            Set<PointsToSetVariable> conv2ds,
            Set<PointsToSetVariable> conv3ds,
            Map<PointerKey, AnalysisError> errorLog) {
        // AnonymousClass1 expects the error log second and the set_shape map last.
        IKilldallFramework<PointsToSetVariable, TensorVariable> problem =
                new AnonymousClass1(graph, errorLog, reshapeNodes, conv2ds, conv3ds, setShapes);
        return problem;
    }

    /**
     * Creates a tensor type analysis over {@code graph}.
     *
     * @param graph flow graph over points-to set variables
     * @param init initial tensor type for selected variables, applied when variables are initialized
     * @param reshapeNodes variables whose output is reshaped to the mapped type
     * @param setShapes variables whose incoming edges yield the mapped type
     * @param conv2ds variables checked as 2-D convolution inputs
     * @param conv3ds variables checked as 3-D convolution inputs
     * @param errorLog receives shape errors found while solving
     */
    public TensorTypeAnalysis(
            Graph<PointsToSetVariable> graph,
            Map<PointsToSetVariable, TensorType> init,
            Map<PointsToSetVariable, TensorType> reshapeNodes,
            Map<PointsToSetVariable, TensorType> setShapes,
            Set<PointsToSetVariable> conv2ds,
            Set<PointsToSetVariable> conv3ds,
            Map<PointerKey, AnalysisError> errorLog) {
        super(createProblem(graph, reshapeNodes, setShapes, conv2ds, conv3ds, errorLog));
        this.init = init;
    }

    /**
     * Creates a fresh variable for a flow-graph node. The same kind of variable is used for both
     * roles, so {@code in} is ignored. (Access was widened from protected by the decompiler.)
     */
    @Override // com.ibm.wala.dataflow.graph.DataflowSolver
    public TensorVariable makeNodeVariable(PointsToSetVariable node, boolean in) {
        TensorVariable fresh = new TensorVariable();
        return fresh;
    }

    /**
     * Creates a fresh variable for the edge {@code src -> dst}. (Access was widened from protected
     * by the decompiler.)
     */
    @Override // com.ibm.wala.dataflow.graph.DataflowSolver
    public TensorVariable makeEdgeVariable(PointsToSetVariable src, PointsToSetVariable dst) {
        TensorVariable fresh = new TensorVariable();
        return fresh;
    }

    /**
     * Allocates an empty right-hand-side array with {@code size} slots. (Access was widened from
     * protected by the decompiler.)
     */
    @Override // com.ibm.wala.fixedpoint.impl.AbstractFixedPointSolver
    public TensorVariable[] makeStmtRHS(int size) {
        TensorVariable[] slots = new TensorVariable[size];
        return slots;
    }

    /**
     * Creates the solver variables, then adds each seed type from {@link #init} to the
     * out-variable of its node. {@code getOut} can return null (for example, for a seed node that
     * is not in the flow graph); such seeds are skipped instead of throwing NullPointerException.
     * (Access was widened from protected by the decompiler.)
     */
    @Override // com.ibm.wala.dataflow.graph.DataflowSolver, com.ibm.wala.fixedpoint.impl.AbstractFixedPointSolver
    public void initializeVariables() {
        super.initializeVariables();
        for (Map.Entry<PointsToSetVariable, TensorType> seed : this.init.entrySet()) {
            TensorVariable out = getOut(seed.getKey());
            if (out != null) {
                out.state.add(seed.getValue());
            }
        }
    }

    /**
     * Renders the solution. Output starts with {@code "answer:\n"}, followed by one line per
     * flow-graph node whose out-variable has a non-empty type set: the node's pointer key,
     * immediately followed by that variable.
     */
    @Override // com.ibm.wala.fixedpoint.impl.AbstractFixedPointSolver
    public String toString() {
        // Unsynchronized builder; look up each node's out-variable once.
        StringBuilder result = new StringBuilder("answer:\n");
        for (PointsToSetVariable node : getProblem().getFlowGraph()) {
            TensorVariable out = getOut(node);
            if (out != null && out.state != null && !out.state.isEmpty()) {
                result.append(node.getPointerKey()).append(out).append('\n');
            }
        }
        return result.toString();
    }
}
