
Issue with performing shape inference using symbolic_shape_infer.py with Phi-3 ONNX Models #21194

Open
shamith2 opened this issue Jun 27, 2024 · 4 comments
Labels
model:transformer (issues related to a transformer model: BERT, GPT2, Hugging Face, Longformer, T5, etc.), platform:windows (issues related to the Windows platform)

Comments

@shamith2

shamith2 commented Jun 27, 2024

Describe the issue

I am running into this error when I run the SymbolicShapeInference.infer_shapes() function on the Phi-3 mini ONNX model optimized for CPU. From my understanding, it looks like infer_shapes() is not able to infer the shape of the MatMulNBits op (https://github.com/microsoft/onnxruntime/blob/rel-1.18.0/docs/ContribOperators.md#com.microsoft.MatMulNBits).

This issue might not be limited to Phi-3; I suspect it has to do with the operator domain. Does infer_shapes() automatically work on operators from the com.microsoft domain?

I made the input and output dims of the ONNX model static before performing shape inference, using update_inputs_outputs_dims() from https://github.com/onnx/onnx/blob/main/onnx/tools/update_model_dims.py
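As a sanity check, nodes without shape annotations can be listed with a rough sketch like the following (onnx_model_path points at the downloaded model from the repro steps below):

import onnx

# Sketch: list node outputs that carry no shape annotation in the
# model's value_info (outputs of the graph count as annotated too).
onnx_model_path = "phi3-mini-128k-instruct-cpu-int4-rtn-block-32-acc-level-4.onnx"
model = onnx.load(onnx_model_path)
annotated = {vi.name for vi in model.graph.value_info} | {vi.name for vi in model.graph.output}

for node in model.graph.node:
    for output in node.output:
        if output not in annotated:
            print(f"{node.op_type} ({node.domain or 'ai.onnx'}): {output}")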

Error:

DEBUG:onnxruntime.tools.symbolic_shape_infer:Stopping at incomplete shape inference at MatMulNBits: /model/layers.0/attn/qkv_proj/MatMul_Q4
DEBUG:onnxruntime.tools.symbolic_shape_infer:node inputs:
DEBUG:onnxruntime.tools.symbolic_shape_infer:name: "/model/layers.0/input_layernorm/output_0"
type {
  tensor_type {
    elem_type: 1
    shape {
      dim {
        dim_value: 1
      }
      dim {
        dim_value: 1
      }
      dim {
        dim_value: 3072
      }
    }
  }
}

DEBUG:onnxruntime.tools.symbolic_shape_infer:name: "model.layers.0.attn.qkv_proj.MatMul.weight_Q4"
type {
  tensor_type {
    elem_type: 2
    shape {
      dim {
        dim_value: 9216
      }
      dim {
        dim_value: 96
      }
      dim {
        dim_value: 16
      }
    }
  }
}

DEBUG:onnxruntime.tools.symbolic_shape_infer:name: "model.layers.0.attn.qkv_proj.MatMul.weight_scales"
type {
  tensor_type {
    elem_type: 1
    shape {
      dim {
        dim_value: 884736
      }
    }
  }
}

DEBUG:onnxruntime.tools.symbolic_shape_infer:node outputs:
DEBUG:onnxruntime.tools.symbolic_shape_infer:name: "/model/layers.0/attn/qkv_proj/MatMul/output_0"
type {
}

Traceback (most recent call last):
  File "C:\Users\Administrator\Documents\onnxInsights\scripts\onnxProfile\onnx_profiling.py", line 61, in <module>
    inferred_onnx_model_path = onnx_t.shapeInfer(
  File "C:\Users\Administrator\Documents\onnxInsights\src\onnxInsights\onnxHelpers\onnxTransformer.py", line 245, in shapeInfer
    inferred_model = SymbolicShapeInference.infer_shapes(
  File "C:\Users\Administrator\miniconda3\envs\onnx_test\lib\site-packages\onnxruntime\tools\symbolic_shape_infer.py", line 2912, in infer_shapes
    raise Exception("Incomplete symbolic shape inference")
Exception: Incomplete symbolic shape inference

To reproduce

Download Model:

huggingface-cli download microsoft/Phi-3-mini-128k-instruct-onnx cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/phi3-mini-128k-instruct-cpu-int4-rtn-block-32-acc-level-4.onnx
huggingface-cli download microsoft/Phi-3-mini-128k-instruct-onnx cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/phi3-mini-128k-instruct-cpu-int4-rtn-block-32-acc-level-4.onnx.data

Packages:

python -> 3.10.14
onnx -> 1.16.1
onnxruntime -> 1.18.0

Code:

import onnx
from onnx.tools.update_model_dims import update_inputs_outputs_dims
import onnxruntime
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference

# path to the downloaded model (the .onnx.data file must sit next to it)
onnx_model_path = "phi3-mini-128k-instruct-cpu-int4-rtn-block-32-acc-level-4.onnx"

dummy_session = onnxruntime.InferenceSession(
    onnx_model_path,
    providers=["CPUExecutionProvider"]
)

model_inputs = dummy_session.get_inputs()
model_outputs = dummy_session.get_outputs()

# static input and output dims
static_input_dims = {
    'input_ids': [1, 1],
    'attention_mask': [1, 2048]
}

# past key/value inputs start at index 2 (32 layers x key and value)
for i in range(32 * 2):
    static_input_dims[model_inputs[i + 2].name] = [1, 32, 2047, 96]

static_output_dims = {
    'logits': [1, 1, 32064]
}

# present key/value outputs start at index 1
for i in range(32 * 2):
    static_output_dims[model_outputs[i + 1].name] = [1, 32, 2048, 96]

# make the input and output dims static in the onnx model
onnx_model = onnx.load(onnx_model_path)

static_dim_model = update_inputs_outputs_dims(onnx_model, static_input_dims, static_output_dims)

# perform shape inference
inferred_model = SymbolicShapeInference.infer_shapes(
    static_dim_model,
    int_max=2**31 - 1,
    auto_merge=False,
    guess_output_rank=False,
    verbose=0
)

Urgency

No response

Platform

Windows

OS Version

11

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.18.0

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Default CPU

Execution Provider Library Version

No response

@github-actions github-actions bot added the model:transformer and platform:windows labels Jun 27, 2024
@shamith2 shamith2 changed the title Issue with performing shape inference using symbolic_shape_infer.py with Phi 3 Family ONNX Models Issue with performing shape inference using symbolic_shape_infer.py with Phi-3 ONNX Models Jun 27, 2024
@sophies927 sophies927 removed the platform:windows label Jun 27, 2024
@github-actions github-actions bot added the platform:windows label Jun 27, 2024
@tianleiwu
Contributor

@kunal-vaishnavi, could you take a look at whether symbolic shape inference works on Phi-3 models?

@kunal-vaishnavi
Contributor

The uploaded Phi-3 ONNX models have already been shape-inferenced with dynamic axes.

The symbolic shape inference for most quantization operators is defined in each operator's spec.

ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulNBits)
    .SetDomain(kMSDomain)
    .SinceVersion(1)
    .SetDoc(MatMulNBits_ver1_doc)
    .Attr("K", "size of each input feature", AttributeProto::INT)
    .Attr("N", "size of each output feature", AttributeProto::INT)
    .Attr("bits", "number of bits used for weight quantization (default 4)", AttributeProto::INT)
    .Attr("block_size", "number of groupsize used for weight quantization,(default 128). It needs to be a power of 2 and not smaller than 16.", AttributeProto::INT)
    .Attr("accuracy_level",
          "The minimum accuracy level of input A, can be: 0(unset), 1(fp32), 2(fp16), 3(bf16), or 4(int8) "
          "(default unset). It is used to control how input A is quantized or downcast internally while "
          "doing computation, for example: 0 means input A will not be quantized or downcast while doing "
          "computation. 4 means input A can be quantized with the same block_size to int8 internally from "
          "type T1.",
          AttributeProto::INT, static_cast<int64_t>(0))
    .Input(0, "A", "The input tensor, not quantized", "T1")
    .Input(1, "B", "1 or 2 dimensional data blob", "T2")
    .Input(2, "scales", "quantization scale", "T1")
    .Input(3, "zero_points", "quantization zero points", "T3", OpSchema::Optional)
    .Input(4, "g_idx", "group_idx", "T4", OpSchema::Optional)
    .Input(5, "bias", "Bias to add to result. It should have shape [N].", "T1", OpSchema::Optional)
    .Output(0, "Y", "tensor. The output tensor has the same rank as the input. ", "T1")
    .TypeConstraint("T1", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float/half_float tensors.")
    .TypeConstraint("T2", {"tensor(uint8)", "tensor(int32)"}, "Constrain quantized weight types to uint8/int32.")
    .TypeConstraint("T3", {"tensor(uint8)", "tensor(int32)", "tensor(float16)", "tensor(float)"}, "Constrain quantized zero point types to uint8/int32/float16/float.")
    .TypeConstraint("T4", {"tensor(int32)"}, "the index tensor.")
    .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
      // Type inference
      propagateElemTypeFromInputToOutput(ctx, 0, 0);
      // Shape inference
      int64_t in_features = getAttribute(ctx, "K", -1);
      int64_t out_features = getAttribute(ctx, "N", -1);
      MatmulWithQuantWeightShapeInference(ctx, in_features, out_features, true);
      // validate bias shape
      if (ctx.hasInput(5)) {
        if (!hasInputShape(ctx, 5)) {
          fail_shape_inference("bias shape must be known");
        }
        const auto& bias_shape = getInputShape(ctx, 5);
        if (bias_shape.dim_size() != 1 ||
            !bias_shape.dim(0).has_dim_value() ||
            bias_shape.dim(0).dim_value() != out_features) {
          fail_shape_inference("bias shape must be [N] where N = ", out_features);
        }
      }
    });

Here is the list of supported operators whose shapes can be symbolically inferred in the SymbolicShapeInference.infer_shapes tool.

self.dispatcher_ = {
    "Add": self._infer_symbolic_compute_ops,
    "ArrayFeatureExtractor": self._infer_ArrayFeatureExtractor,
    "AveragePool": self._infer_Pool,
    "BatchNormalization": self._infer_BatchNormalization,
    "Cast": self._infer_Cast,
    "CategoryMapper": self._infer_CategoryMapper,
    "Compress": self._infer_Compress,
    "Concat": self._infer_Concat,
    "ConcatFromSequence": self._infer_ConcatFromSequence,
    "Constant": self._infer_Constant,
    "ConstantOfShape": self._infer_ConstantOfShape,
    "Conv": self._infer_Conv,
    "CumSum": self._pass_on_shape_and_type,
    "Div": self._infer_symbolic_compute_ops,
    "Einsum": self._infer_Einsum,
    "Expand": self._infer_Expand,
    "Equal": self._infer_symbolic_compute_ops,
    "Floor": self._infer_symbolic_compute_ops,
    "Gather": self._infer_Gather,
    "GatherElements": self._infer_GatherElements,
    "GatherND": self._infer_GatherND,
    "Identity": self._pass_on_shape_and_type,
    "AllReduce": self._pass_on_shape_and_type,
    "If": self._infer_If,
    "Loop": self._infer_Loop,
    "MatMul": self._infer_MatMul,
    "MatMulInteger16": self._infer_MatMulInteger,
    "MaxPool": self._infer_Pool,
    "Max": self._infer_symbolic_compute_ops,
    "MemcpyFromHost": self._pass_on_shape_and_type,
    "MemcpyToHost": self._pass_on_shape_and_type,
    "Min": self._infer_symbolic_compute_ops,
    "MoE": self._pass_on_shape_and_type,
    "Mul": self._infer_symbolic_compute_ops,
    "NonMaxSuppression": self._infer_NonMaxSuppression,
    "NonZero": self._infer_NonZero,
    "OneHot": self._infer_OneHot,
    "Pad": self._infer_Pad,
    "Range": self._infer_Range,
    "Reciprocal": self._pass_on_shape_and_type,
    "ReduceSum": self._infer_ReduceSum,
    "ReduceProd": self._infer_ReduceProd,
    "Reshape": self._infer_Reshape,
    "Resize": self._infer_Resize,
    "Round": self._pass_on_shape_and_type,
    "Scan": self._infer_Scan,
    "ScatterElements": self._infer_ScatterElements,
    "SequenceAt": self._infer_SequenceAt,
    "SequenceInsert": self._infer_SequenceInsert,
    "Shape": self._infer_Shape,
    "Size": self._infer_Size,
    "Slice": self._infer_Slice,
    "SoftmaxCrossEntropyLoss": self._infer_SoftmaxCrossEntropyLoss,
    "SoftmaxCrossEntropyLossInternal": self._infer_SoftmaxCrossEntropyLoss,
    "NegativeLogLikelihoodLossInternal": self._infer_SoftmaxCrossEntropyLoss,
    "Split": self._infer_Split,
    "SplitToSequence": self._infer_SplitToSequence,
    "Squeeze": self._infer_Squeeze,
    "Sub": self._infer_symbolic_compute_ops,
    "Tile": self._infer_Tile,
    "TopK": self._infer_TopK,
    "Transpose": self._infer_Transpose,
    "Unsqueeze": self._infer_Unsqueeze,
    "Where": self._infer_symbolic_compute_ops,
    "ZipMap": self._infer_ZipMap,
    "Neg": self._infer_symbolic_compute_ops,
    # contrib ops:
    "Attention": self._infer_Attention,
    "BiasAdd": self._infer_BiasAdd,
    "BiasGelu": self._infer_BiasGelu,
    "BiasSplitGelu": self._infer_BiasSplitGelu,
    "DecoderMaskedMultiHeadAttention": self._infer_DecoderMaskedMultiHeadAttention,
    "DequantizeLinear": self._infer_DequantizeLinear,
    "EmbedLayerNormalization": self._infer_EmbedLayerNormalization,
    "FastGelu": self._infer_FastGelu,
    "GatedRelativePositionBias": self._infer_GatedRelativePositionBias,
    "Gelu": self._infer_Gelu,
    "GemmFastGelu": self._infer_GemmFastGelu,
    "GemmFloat8": self._infer_GemmFloat8,
    "GroupNorm": self._infer_GroupNorm,
    "GroupQueryAttention": self._infer_GroupQueryAttention,
    "SparseAttention": self._infer_SparseAttention,
    "SkipGroupNorm": self._infer_SkipGroupNorm,
    "LayerNormalization": self._infer_LayerNormalization,
    "LongformerAttention": self._infer_LongformerAttention,
    "MultiHeadAttention": self._infer_MultiHeadAttention,
    "NhwcConv": self._infer_NhwcConv,
    "PackedAttention": self._infer_PackedAttention,
    "PackedMultiHeadAttention": self._infer_PackedMultiHeadAttention,
    "PagedAttention": self._infer_PagedAttention,
    "PythonOp": self._infer_PythonOp,
    "QuantizeLinear": self._infer_QuantizeLinear,
    "QuickGelu": self._infer_FastGelu,
    "RelativePositionBias": self._infer_RelativePositionBias,
    "RemovePadding": self._infer_RemovePadding,
    "RestorePadding": self._infer_RestorePadding,
    "RotaryEmbedding": self._infer_RotaryEmbedding,
    "SimplifiedLayerNormalization": self._infer_LayerNormalization,
    "SkipLayerNormalization": self._infer_SkipLayerNormalization,
    "SkipSimplifiedLayerNormalization": self._infer_SkipLayerNormalization,
}
self.aten_op_dispatcher_ = {
    "embedding": self._infer_Gather,
    "bitwise_or": self._infer_aten_bitwise_or,
    "diagonal": self._infer_aten_diagonal,
    "max_pool2d_with_indices": self._infer_aten_pool2d,
    "max": self._infer_aten_minmax,
    "min": self._infer_aten_minmax,
    "multinomial": self._infer_aten_multinomial,
    "unfold": self._infer_aten_unfold,
    "argmax": self._infer_aten_argmax,
    "avg_pool2d": self._infer_aten_pool2d,
    "_adaptive_avg_pool2d": self._infer_aten_pool2d,
    "numpy_T": self._infer_Transpose,
    "native_group_norm": self._infer_aten_group_norm,
    "upsample_nearest1d": self._infer_aten_upsample,
    "upsample_nearest2d": self._infer_aten_upsample,
    "upsample_nearest3d": self._infer_aten_upsample,
    "upsample_bicubic2d": self._infer_aten_upsample,
}

@shamith2
Author

shamith2 commented Jul 1, 2024

@kunal-vaishnavi, thanks for the response. I have a few questions and comments from my side:

  1. The attached Phi-3 ONNX model is not shape-inferenced for all of its operators. A couple of operators might have been shape-inferenced with dynamic axes, but the vast majority have not. For example, here is one of the subgraphs of the model that is not shape-inferenced, visualized in Netron:

[screenshot: Netron view of a Phi-3 subgraph without inferred shape annotations]

From my understanding, this is how Netron visualizes shape-inferenced operators after running a model through the SymbolicShapeInference.infer_shapes tool, which I was not able to do for the Phi-3 model (this subgraph is from a different ONNX model):

[screenshot: Netron view of a subgraph with inferred shapes annotated on each edge]

  2. I do not see the MatMulNBits operator in the list of supported operators you shared for the SymbolicShapeInference.infer_shapes tool, which is likely why the tool raises the error.

  3. Were you able to successfully shape-infer the Phi-3 model for all operators? I am not able to do it with the 1.18.0 release of onnxruntime. Which version of onnxruntime are you using?

@kunal-vaishnavi
Contributor

The attached Phi-3 ONNX model is not shape-inferenced for all of its operators. A couple of operators might have been shape-inferenced with dynamic axes, but the vast majority have not. For example, here is one of the subgraphs of the model that is not shape-inferenced, visualized in Netron. From my understanding, this is how Netron visualizes shape-inferenced operators after running a model through the SymbolicShapeInference.infer_shapes tool, which I was not able to do for the Phi-3 model (this subgraph is from a different ONNX model).

You can find the shape inference by clicking on an operator and then clicking the '+' icon to the right of each input name and output name. Here is an example.

[screenshot: Netron node properties panel with the '+' icon expanded to show a tensor's inferred shape]

I do not see the MatMulNBits operator in the list of supported operators you shared for the SymbolicShapeInference.infer_shapes tool, which is likely why the tool raises the error.

Yes, your error occurs because symbolic shape inference for MatMulNBits isn't implemented in SymbolicShapeInference.infer_shapes. We can add MatMulNBits to fix this.
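For illustration, such a handler would likely follow the pattern of the existing _infer_* methods: take input A's shape and replace its last dimension with the N attribute (the number of output features). A minimal sketch, assuming the tool's internal helpers (_get_shape, get_attribute, known_vi_); this is not the shipped fix:

from onnx import helper

def _infer_MatMulNBits(self, node):
    # Sketch: the output shape is input A's shape with the last dimension
    # replaced by the "N" attribute; the output dtype matches input A.
    shape = self._get_shape(node, 0)
    shape[-1] = get_attribute(node, "N")
    output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    vi = self.known_vi_[node.output[0]]
    vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, shape))

# ...plus a corresponding dispatcher entry in __init__:
#     "MatMulNBits": self._infer_MatMulNBits,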

Were you able to successfully shape infer the phi-3 model for all operators? I am not able to do it with release version of onnxruntime 1.18.0. Which version of onnxruntime are you using?

The uploaded Phi-3 ONNX models are created via ONNX Runtime GenAI's model builder. The shape inferences for their operators are created in the model builder using onnx.helper.make_tensor_value_info and then added to the ModelProto.
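The underlying mechanism is the standard ONNX API. As an illustrative sketch (not the model builder's actual code), a value_info with symbolic axes can be built and appended to the graph; the tensor name and the 9216 output dimension below come from the debug log above, while the symbolic dim names are placeholders:

import onnx
from onnx import TensorProto, helper

model = onnx.load(onnx_model_path)  # onnx_model_path as in the repro above

# Record a (partially symbolic) shape for one intermediate tensor;
# strings denote dynamic axes, ints denote static ones.
value_info = helper.make_tensor_value_info(
    "/model/layers.0/attn/qkv_proj/MatMul/output_0",
    TensorProto.FLOAT,
    ["batch_size", "sequence_length", 9216],
)
model.graph.value_info.append(value_info)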
