Skip to content

Commit aecf48b

Browse files
Lee, Kyunggeun (quic-kyunggeu)
authored and committed
Fix int4x2 packing bug in onnx QDQ export
Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com> Co-authored-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
1 parent cac788c commit aecf48b

File tree

8 files changed

+1146
-72
lines changed

8 files changed

+1146
-72
lines changed

TrainingExtensions/common/src/python/aimet_common/onnx/_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def _add_onnx_qdq_nodes(model: ModelProto,
138138
f"{input_name}_zero_point",
139139
],
140140
output=output_name,
141-
output_dtype=output_dtype,
141+
dtype=output_dtype,
142142
axis=axis,
143143
block_size=block_size,
144144
)
@@ -156,7 +156,7 @@ def _add_onnx_qdq_nodes(model: ModelProto,
156156
f"{input_name}_zero_point",
157157
],
158158
output=f"{input_name}_int",
159-
output_dtype=output_dtype,
159+
dtype=output_dtype,
160160
axis=axis,
161161
block_size=block_size,
162162
),
@@ -168,7 +168,7 @@ def _add_onnx_qdq_nodes(model: ModelProto,
168168
f"{input_name}_zero_point",
169169
],
170170
output=output_name,
171-
output_dtype=output_dtype,
171+
dtype=output_dtype,
172172
axis=axis,
173173
block_size=block_size,
174174
),

TrainingExtensions/common/src/python/aimet_common/onnx/opset10.py

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
# =============================================================================
3737
# pylint: disable=no-member
3838
from abc import ABC, abstractmethod
39+
import sys
3940
from typing import Iterable, Mapping, Optional
4041
import numpy as np
4142
from onnx import helper, numpy_helper, TensorProto
@@ -47,7 +48,7 @@ class _QdqNodeFactory(ABC):
4748
@classmethod
4849
@abstractmethod
4950
def make_node(cls, name: str, inputs: Iterable[str], output: str,
50-
output_dtype: str, axis: Optional[int] = None,
51+
dtype: str, axis: Optional[int] = None,
5152
block_size: Optional[int] = None):
5253
...
5354

@@ -65,21 +66,51 @@ def _check_dtype(cls, dtype: str):
6566
def make_zero_point(cls, zero_point: np.ndarray, dtype: str, name: str):
6667
cls._check_dtype(dtype)
6768

69+
if (dtype == "int32" or dtype.startswith("float")) and not np.all(zero_point == 0):
70+
raise RuntimeError(
71+
"DequantizeLinear with type int32 or float8 should have "
72+
"no zero point or all zero points should be 0"
73+
)
74+
6875
if dtype not in ("int4", "uint4"):
6976
zero_point = zero_point.astype(dtype)
7077
return numpy_helper.from_array(zero_point, name=name)
7178

79+
target_shape = zero_point.shape
80+
7281
# Numpy doesn't support int4/uint4.
7382
# Do bitshift operations to pack int4 array into int8 array
74-
zero_point = zero_point.astype("int8" if dtype == "int4" else "uint8")
75-
MSB = zero_point.flatten()[::2] << 4
76-
LSB = zero_point.flatten()[1::2]
77-
tensor = numpy_helper.from_array(MSB | LSB, name=name)
83+
zero_point = zero_point.astype("int8" if dtype == "int4" else "uint8").flatten()
84+
if zero_point.size % 2 == 1:
85+
# Add 0 padding to enable int4x2 packing
86+
zero_point = np.concatenate((zero_point, np.array([0], dtype=zero_point.dtype)))
87+
88+
if sys.byteorder == "little":
89+
# Little endian:
90+
#
91+
# zp[n+1] zp[n]
92+
# <-----> | <----->
93+
# bit: 7 6 5 4 3 2 1 0
94+
# (MSB) (LSB)
95+
MSB = zero_point[1::2] << 4
96+
LSB = zero_point[::2] & 0x0F
97+
else:
98+
# Big endian:
99+
#
100+
# zp[n] zp[n+1]
101+
# <-----> | <----->
102+
# bit: 7 6 5 4 3 2 1 0
103+
# (MSB) (LSB)
104+
MSB = zero_point[::2] << 4
105+
LSB = zero_point[1::2] & 0x0F
106+
107+
zero_point_int4x2 = MSB | LSB
108+
tensor = numpy_helper.from_array(zero_point_int4x2, name=name)
78109

79110
# Restore data_type to INT4/UINT4
80111
tensor.data_type = TensorProto.INT4 if dtype == "int4" else TensorProto.UINT4
81112
tensor.ClearField("dims")
82-
tensor.dims.extend(zero_point.shape)
113+
tensor.dims.extend(target_shape)
83114

84115
return tensor
85116

@@ -93,7 +124,7 @@ class QuantizeLinear(_QdqNodeFactory):
93124

94125
@classmethod
95126
def make_node(cls, name: str, inputs: Iterable[str], output: str,
96-
output_dtype: str, axis: Optional[int] = None,
127+
dtype: str, axis: Optional[int] = None,
97128
block_size: Optional[int] = None):
98129
if axis is not None:
99130
raise RuntimeError(
@@ -105,7 +136,7 @@ def make_node(cls, name: str, inputs: Iterable[str], output: str,
105136
f"Blockwise quantization is not supported in opset {cls.OPSET}"
106137
)
107138

108-
cls._check_dtype(output_dtype)
139+
cls._check_dtype(dtype)
109140

110141
return helper.make_node("QuantizeLinear",
111142
name=name,
@@ -123,7 +154,7 @@ class DequantizeLinear(_QdqNodeFactory):
123154

124155
@classmethod
125156
def make_node(cls, name: str, inputs: Iterable[str], output: str,
126-
output_dtype: str, axis: Optional[int] = None,
157+
dtype: str, axis: Optional[int] = None,
127158
block_size: Optional[int] = None):
128159
if axis is not None:
129160
raise RuntimeError(
@@ -135,7 +166,7 @@ def make_node(cls, name: str, inputs: Iterable[str], output: str,
135166
f"Blockwise quantization is not supported in opset {cls.OPSET}"
136167
)
137168

138-
cls._check_dtype(output_dtype)
169+
cls._check_dtype(dtype)
139170

140171
return helper.make_node("DequantizeLinear",
141172
name=name,

TrainingExtensions/common/src/python/aimet_common/onnx/opset13.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,14 @@ class QuantizeLinear(opset10.QuantizeLinear):
4545

4646
@classmethod
4747
def make_node(cls, name: str, inputs: Iterable[str], output: str,
48-
output_dtype: str, axis: Optional[int] = None,
48+
dtype: str, axis: Optional[int] = None,
4949
block_size: Optional[int] = None):
5050
if block_size is not None:
5151
raise RuntimeError(
5252
f"Blockwise quantization is not supported in opset {cls.OPSET}"
5353
)
5454

55-
cls._check_dtype(output_dtype)
55+
cls._check_dtype(dtype)
5656

5757
return helper.make_node("QuantizeLinear",
5858
name=name,
@@ -66,14 +66,14 @@ class DequantizeLinear(opset10.DequantizeLinear):
6666

6767
@classmethod
6868
def make_node(cls, name: str, inputs: Iterable[str], output: str,
69-
output_dtype: str, axis: Optional[int] = None,
69+
dtype: str, axis: Optional[int] = None,
7070
block_size: Optional[int] = None):
7171
if block_size is not None:
7272
raise RuntimeError(
7373
f"Blockwise quantization is not supported in opset {cls.OPSET}"
7474
)
7575

76-
cls._check_dtype(output_dtype)
76+
cls._check_dtype(dtype)
7777

7878
return helper.make_node("DequantizeLinear",
7979
name=name,

TrainingExtensions/common/src/python/aimet_common/onnx/opset21.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,15 +52,24 @@ class QuantizeLinear(opset13.QuantizeLinear):
5252

5353
@classmethod
5454
def make_node(cls, name: str, inputs: Iterable[str], output: str,
55-
output_dtype: str, axis: Optional[int] = None,
55+
dtype: str, axis: Optional[int] = None,
5656
block_size: Optional[int] = None):
57-
cls._check_dtype(output_dtype)
57+
cls._check_dtype(dtype)
58+
59+
if axis is None and block_size is not None:
60+
raise RuntimeError(
61+
"axis must be specified if block_size is not None; "
62+
f"got axis={axis}, block_size={block_size}"
63+
)
5864

5965
return helper.make_node("QuantizeLinear",
6066
name=name,
6167
inputs=list(inputs),
6268
outputs=[output],
63-
output_dtype=cls.SUPPORTED_DTYPES[output_dtype],
69+
# NOTE: Don't pass output_dtype explicitly; ORT has a bug
70+
# where per-tensor int8 QuantizeLinear
71+
# fails with output_dtype explicitly specified as INT8
72+
# output_dtype=cls.SUPPORTED_DTYPES[dtype],
6473
axis=axis,
6574
block_size=block_size)
6675

@@ -77,9 +86,15 @@ class DequantizeLinear(opset13.DequantizeLinear):
7786

7887
@classmethod
7988
def make_node(cls, name: str, inputs: Iterable[str], output: str,
80-
output_dtype: str, axis: Optional[int] = None,
89+
dtype: str, axis: Optional[int] = None,
8190
block_size: Optional[int] = None):
82-
cls._check_dtype(output_dtype)
91+
cls._check_dtype(dtype)
92+
93+
if axis is None and block_size is not None:
94+
raise RuntimeError(
95+
"axis must be specified if block_size is not None; "
96+
f"got axis={axis}, block_size={block_size}"
97+
)
8398

8499
return helper.make_node("DequantizeLinear",
85100
name=name,

0 commit comments

Comments (0)