Skip to content

Commit 9d0226c

Browse files
Lee, Kyunggeun (quic-kyunggeu)
authored and committed
Don't export disabled quantizers to onnx QDQ
Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com> Co-authored-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
1 parent 93ad0e2 commit 9d0226c

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1255,7 +1255,11 @@ def _to_onnx_qdq(self) -> onnx.ModelProto:
12551255

12561256
for aimet_node in aimet_qc_quantize_nodes:
12571257
qtzr = self.qc_quantize_op_dict[aimet_node.input[0]]
1258-
encodings = qtzr._export_2_0_0_encodings() # pylint: disable=protected-access
1258+
1259+
if qtzr.enabled:
1260+
encodings = qtzr._export_2_0_0_encodings() # pylint: disable=protected-access
1261+
else:
1262+
encodings = None
12591263

12601264
if encodings:
12611265
# Affine quantizer

TrainingExtensions/onnx/test/python/test_quantsim.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2361,8 +2361,8 @@ def test_bias_export(model_factory, input_shape, block_size, lpbq, tmp_path):
23612361
@pytest.mark.parametrize("export_int32_bias_encodings", [False, True])
23622362
@pytest.mark.parametrize(
23632363
"model_factory, input_shape, tolerance", [
2364-
(lambda: single_residual_model(opset_version=13), (1, 3, 32, 32), 2),
2365-
(lambda: transposed_conv_model(opset_version=13), (10, 10, 4, 4), 2),
2364+
(lambda: single_residual_model(opset_version=13), (1, 3, 32, 32), 1),
2365+
(lambda: transposed_conv_model(opset_version=13), (10, 10, 4, 4), 1),
23662366
(batchnorm_model, (10, 10, 8, 8), 1),
23672367
(batchnorm_model_constants, (10, 10, 8, 8), 1),
23682368
(lambda: instance_norm_model(opset_version=13), (2, 10, 24, 24), 3),
@@ -2383,7 +2383,6 @@ def test_onnx_qdq(model_factory,
23832383

23842384
"""
23852385
When: Create a pure onnx model with sim._to_onnx_qdq()
2386-
Then: Output of the pure onnx model should be equal to that of sim.session
23872386
"""
23882387
sim.compute_encodings(lambda sess, _: sess.run(None, {"input": input}), None)
23892388

@@ -2407,6 +2406,12 @@ def test_onnx_qdq(model_factory,
24072406

24082407
onnx_qdq_model = sim._to_onnx_qdq()
24092408

2409+
"""
2410+
Then: Onnx QDQ model should contain as many DequantizeLinear as the number of ENABLED QcQuantizers
2411+
"""
2412+
assert len([node for node in onnx_qdq_model.graph.node if node.op_type == "DequantizeLinear"]) \
2413+
== len([qtzr for qtzr in sim.qc_quantize_op_dict.values() if qtzr.enabled])
2414+
24102415
# NOTE: Should disable all ORT graph optimization to circumvent known bugs
24112416
# in CPUExecutionProvider operator fusing.
24122417
# ORT CPUExecutionProvider produces corrupted output after fusing pattern A to B:
@@ -2423,6 +2428,9 @@ def test_onnx_qdq(model_factory,
24232428
sess_options = ort.SessionOptions()
24242429
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
24252430

2431+
"""
2432+
Then: Output of the pure onnx model should be equal to that of sim.session
2433+
"""
24262434
sess = ort.InferenceSession(onnx_qdq_model.SerializeToString(), sess_options=sess_options)
24272435
out_onnx_qdq, = sess.run(None, {"input": input})
24282436
assert np.allclose(out_sim, out_onnx_qdq, atol=atol)

0 commit comments

Comments (0)