package com.ibm.wala.cast.python.ml.types;

import com.google.gson.Gson;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.ibm.wala.cast.loader.AstMethod;
import com.ibm.wala.cast.python.ssa.PythonPropertyWrite;
import com.ibm.wala.cast.python.util.PythonUtil;
import com.ibm.wala.cast.tree.CAstSourcePositionMap;
import com.ibm.wala.cast.util.SourceBuffer;
import com.ibm.wala.ipa.callgraph.CGNode;
import com.ibm.wala.ssa.DefUse;
import com.ibm.wala.ssa.SSAInstruction;
import com.ibm.wala.ssa.SSAPutInstruction;
import com.ibm.wala.ssa.SymbolTable;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import org.python.core.PyObject;
import org.python.icu.impl.locale.BaseLocale;

/* loaded from: input_file:com/ibm/wala/cast/python/ml/types/TensorType.class */
/**
 * Describes the type of a tensor: an element ("cell") type plus an ordered list of dimensions.
 * Each dimension is a numeric size ({@link NumericDim}), a symbolic name ({@link SymbolicDim}),
 * or a product of other dimensions ({@link CompoundDim}).
 */
public class TensorType implements Iterable<TensorType.Dimension<?>> {
    private final String cellType;
    private final List<Dimension<?>> dims;

    /**
     * A dimension whose extent is the product of several sub-dimensions, rendered as
     * {@code a * b * ...}.
     */
    public static class CompoundDim extends Dimension<List<Dimension<?>>> {
        CompoundDim(List<Dimension<?>> list) {
            super(list);
        }

        @Override
        DimensionType type() {
            return DimensionType.Compound;
        }

        /**
         * Returns the product of the concrete sizes of the sub-dimensions. Symbolic sub-dimensions
         * are skipped. Returns -1 if no sub-dimension has a concrete size.
         */
        // NOTE(review): a partially symbolic compound (e.g. n * 28) reports only the concrete
        // part (28); confirm callers such as toJsonSchema expect that.
        @Override
        int concreteSize() {
            return productOfConcreteSizes(value());
        }

        /** Returns the total number of symbolic dimensions among the sub-dimensions. */
        @Override
        int symbolicDims() {
            return value().stream().mapToInt(Dimension::symbolicDims).sum();
        }

        /** Joins the sub-dimensions' Markdown forms with an escaped {@code *}. */
        @Override
        String toMDString() {
            return value().stream().map(Dimension::toMDString).collect(Collectors.joining(" \\* "));
        }

        /**
         * Joins the sub-dimensions' C-style forms with {@code *}. The separator is escaped when
         * {@code z} (Markdown-flavoured output) is set.
         */
        @Override
        public String toCString(boolean z) {
            return value().stream()
                    .map(dimension -> dimension.toCString(z))
                    .collect(Collectors.joining(z ? " \\* " : " * "));
        }
    }

    /**
     * One dimension of a tensor shape, described by a value of type {@code T}.
     *
     * @param <T> the value describing this dimension (a size, a name, or sub-dimensions)
     */
    public static abstract class Dimension<T> {
        private final T v;

        protected Dimension(T t) {
            this.v = t;
        }

        /** Returns which kind of dimension this is. */
        abstract DimensionType type();

        /** Returns how many symbolic dimensions this dimension contains. */
        abstract int symbolicDims();

        /** Returns the number of elements this dimension spans, or a negative value if unknown. */
        abstract int concreteSize();

        /** Renders this dimension for Markdown output. */
        abstract String toMDString();

        /**
         * Renders this dimension in C-array style.
         *
         * @param z whether to emit the Markdown-flavoured variant (used by {@link Format#MCString})
         * @return the rendered dimension
         */
        public abstract String toCString(boolean z);

        /**
         * Wraps {@code jsonElement} in a JSON-schema {@code array} describing this dimension. When
         * the size is concrete, {@code minItems}/{@code maxItems} are pinned to it.
         *
         * @param jsonElement schema of the array items, or {@code null} to omit {@code items}
         * @return the array schema
         */
        JsonElement toJsonSchema(JsonElement jsonElement) {
            JsonObject jsonObject = new JsonObject();
            jsonObject.addProperty("type", "array");
            if (jsonElement != null) {
                jsonObject.add("items", jsonElement);
            }
            int concreteSize = concreteSize();
            if (concreteSize >= 0) {
                jsonObject.addProperty("minItems", Integer.valueOf(concreteSize));
                jsonObject.addProperty("maxItems", Integer.valueOf(concreteSize));
                jsonObject.addProperty("description", "Array of dimension " + toCString(false));
            }
            return jsonObject;
        }

        public T value() {
            return this.v;
        }

        @Override
        public String toString() {
            return "D:" + type() + "," + value();
        }

        @Override
        public int hashCode() {
            return 31 + Objects.hashCode(this.v);
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            return Objects.equals(this.v, ((Dimension<?>) obj).v);
        }
    }

    /** The kinds of {@link Dimension}. */
    enum DimensionType {
        Constant,
        Symbolic,
        Compound
    }

    /** Output formats accepted by {@link #toFormattedString(Format)}. */
    public enum Format {
        CString,
        MCString,
        MDString,
        JsonSchema
    }

    /** A dimension with a fixed integer size. */
    static class NumericDim extends Dimension<Integer> {
        NumericDim(Integer num) {
            super(num);
        }

        @Override
        DimensionType type() {
            return DimensionType.Constant;
        }

        @Override
        int concreteSize() {
            return value().intValue();
        }

        @Override
        int symbolicDims() {
            return 0;
        }

        @Override
        String toMDString() {
            return value().toString();
        }

        @Override
        public String toCString(boolean z) {
            return value().toString();
        }
    }

    /** A dimension known only by name (for example {@code n}); its size is unknown. */
    public static class SymbolicDim extends Dimension<String> {
        SymbolicDim(String str) {
            super(str);
        }

        @Override
        DimensionType type() {
            return DimensionType.Symbolic;
        }

        @Override
        int concreteSize() {
            return -1;
        }

        @Override
        int symbolicDims() {
            return 1;
        }

        /** Renders the name in Markdown emphasis. */
        @Override
        String toMDString() {
            return "*" + value() + "*";
        }

        /** Renders the name, emphasised with {@code *...*} when {@code z} is set. */
        @Override
        public String toCString(boolean z) {
            return z ? "*" + value() + "*" : value();
        }
    }

    /**
     * Creates a tensor type.
     *
     * @param str the element type name
     * @param list the dimensions, in the order they are rendered; stored without copying
     */
    public TensorType(String str, List<Dimension<?>> list) {
        this.cellType = str;
        this.dims = list;
    }

    /**
     * Renders this type in the requested format.
     *
     * @param format the output format
     * @return the rendered type
     * @throws IllegalArgumentException if {@code format} is not a known constant
     */
    String toFormattedString(Format format) {
        switch (format) {
            case CString:
                return toCString(false);
            case MCString:
                return toCString(true);
            case MDString:
                return toMDString();
            case JsonSchema:
                return new Gson().toJson(toJsonSchema());
            default:
                throw new IllegalArgumentException("unknown format type: " + format);
        }
    }

    /**
     * Builds a JSON schema for this type. It is a nest of {@code array} schemas, one per
     * dimension, around an element schema that carries only a description.
     *
     * @return the schema, or {@code null} when there is no cell type and no dimensions
     */
    public JsonElement toJsonSchema() {
        JsonElement schema = null;
        if (this.cellType != null) {
            JsonObject cell = new JsonObject();
            cell.addProperty("description", "Elements of type " + this.cellType);
            schema = cell;
        }
        // Each dimension wraps the schema built so far, so the LAST entry of dims is outermost.
        // NOTE(review): toCString lists dims[0] first (C order, outermost); confirm this nesting
        // order is the intended one.
        for (Dimension<?> dimension : this.dims) {
            schema = dimension.toJsonSchema(schema);
        }
        return schema;
    }

    /** Renders this type as Markdown, e.g. {@code [ *n* ; 28 \* 28 **of** _pixel_ ]}. */
    public String toMDString() {
        String body = this.dims.stream().map(Dimension::toMDString).collect(Collectors.joining(" ; "));
        return "[ " + body + " **of** _" + this.cellType + "_ ]";
    }

    /**
     * Renders this type C-array style, e.g. {@code pixel[n][28 * 28]}.
     *
     * @param z whether to emit the Markdown-flavoured variant (cell type wrapped in {@code _})
     * @return the rendered type
     */
    public String toCString(boolean z) {
        String cell = z ? "_" + this.cellType + "_" : this.cellType;
        return cell + this.dims.stream()
                .map(dimension -> "[" + dimension.toCString(z) + "]")
                .collect(Collectors.joining());
    }

    @Override
    public String toString() {
        return "{" + this.dims.toString() + " of " + this.cellType + "}";
    }

    @Override
    public int hashCode() {
        return 31 * (31 + Objects.hashCode(this.cellType)) + Objects.hashCode(this.dims);
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        TensorType other = (TensorType) obj;
        return Objects.equals(this.cellType, other.cellType) && Objects.equals(this.dims, other.dims);
    }

    /** Returns the MNIST input type: {@code pixel[n][28 * 28]}. */
    public static TensorType mnistInput() {
        return new TensorType("pixel", Arrays.asList(new SymbolicDim("n"), new CompoundDim(Arrays.asList(new NumericDim(28), new NumericDim(28)))));
    }

    /**
     * Derives a tensor type from the stores into a shape literal. Each store whose target object is
     * value number {@code i} contributes one dimension. A non-negative integer constant becomes a
     * {@link NumericDim}. Otherwise the stored value's defining expression is evaluated with Jython
     * when possible. The fallback is {@code SymbolicDim("?")}. Debug output is written to
     * {@code System.err}.
     *
     * @param cGNode the call-graph node whose IR contains the literal
     * @param i the value number of the shape literal
     * @return a {@code "pixel"} tensor type with one dimension per store into {@code i}
     */
    // NOTE(review): dimensions are appended in the order DefUse reports the uses; confirm this
    // matches the element order of the literal.
    public static TensorType shapeArg(CGNode cGNode, int i) {
        System.err.println(cGNode.getIR());
        List<Dimension<?>> dimensions = new ArrayList<>();
        DefUse du = cGNode.getDU();
        SymbolTable symbolTable = cGNode.getIR().getSymbolTable();
        for (Iterator<SSAInstruction> uses = du.getUses(i); uses.hasNext(); ) {
            SSAInstruction use = uses.next();
            int val;
            int ref;
            if (use instanceof SSAPutInstruction) {
                SSAPutInstruction put = (SSAPutInstruction) use;
                val = put.getVal();
                ref = put.getRef();
            } else if (use instanceof PythonPropertyWrite) {
                PythonPropertyWrite write = (PythonPropertyWrite) use;
                val = write.getValue();
                ref = write.getObjectRef();
            } else {
                // Not a store, so it cannot contribute a dimension (previously left val/ref unassigned).
                continue;
            }
            if (ref != i) {
                // i is being stored elsewhere rather than written into.
                continue;
            }
            dimensions.add(dimensionForStoredValue(cGNode, du, symbolTable, val));
        }
        return new TensorType("pixel", dimensions);
    }

    /**
     * Converts one value stored into a shape literal into exactly one dimension.
     *
     * @return a {@link NumericDim} when a non-negative integer can be determined, else {@code SymbolicDim("?")}
     */
    private static Dimension<?> dimensionForStoredValue(CGNode node, DefUse du, SymbolTable symbolTable, int val) {
        if (symbolTable.isNumberConstant(val)) {
            int intValue = ((Number) symbolTable.getConstantValue(val)).intValue();
            System.err.println("value: " + intValue);
            if (intValue >= 0) {
                return new NumericDim(Integer.valueOf(intValue));
            }
            return new SymbolicDim("?");
        }
        SSAInstruction def = du.getDef(val);
        if (def != null && node.getMethod() instanceof AstMethod) {
            CAstSourcePositionMap.Position position = ((AstMethod) node.getMethod()).debugInfo().getInstructionPosition(def.iindex);
            System.err.println(position);
            try {
                String expression = new SourceBuffer(position).toString();
                System.err.println(expression);
                // SECURITY(review): this evaluates text taken from the analysed program in a live
                // Jython interpreter; if that source is untrusted, this executes arbitrary code.
                PyObject result = PythonUtil.getInterp().eval(expression);
                System.err.println(result);
                if (result.isInteger()) {
                    // Previously a SymbolicDim("?") was also appended after this, yielding two dims.
                    return new NumericDim(Integer.valueOf(result.asInt()));
                }
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
        return new SymbolicDim("?");
    }

    /**
     * Multiplies the concrete sizes of the given dimensions, skipping those with an unknown
     * (negative) size.
     *
     * @return the product, or -1 if no dimension has a concrete size
     */
    private static int productOfConcreteSizes(Iterable<? extends Dimension<?>> dimensions) {
        int product = -1;
        for (Dimension<?> dimension : dimensions) {
            int size = dimension.concreteSize();
            if (size >= 0) {
                product = product >= 0 ? product * size : size;
            }
        }
        return product;
    }

    @Override
    public Iterator<Dimension<?>> iterator() {
        return this.dims.iterator();
    }

    /** Returns the total number of symbolic dimensions in this type. */
    public int symbolicDims() {
        int total = 0;
        for (Dimension<?> dimension : this) {
            total += dimension.symbolicDims();
        }
        return total;
    }

    /**
     * Returns the product of all concrete dimension sizes, skipping symbolic ones. Returns -1 if
     * none is concrete.
     */
    public int concreteSize() {
        return productOfConcreteSizes(this);
    }
}
