Commit a26c9ae
michaelgtuttle authored and quic-mtuttle committed
Fix failures and incorrect configurations in aimet_onnx AMP
Signed-off-by: Michael Tuttle <quic_mtuttle@quicinc.com>
Co-authored-by: Michael Tuttle <quic_mtuttle@quicinc.com>
1 parent d10e41c commit a26c9ae

5 files changed: +103 −4 lines changed

TrainingExtensions/common/src/python/aimet_common/amp/mixed_precision_algo.py

Lines changed: 2 additions & 2 deletions
@@ -275,8 +275,8 @@ def _count_and_get_quantizers_flipped(self):
                               (candidate[CandAttr.parameter][CandParam.bitwdith],
                                candidate[CandAttr.parameter][CandParam.data_type])))

-        percentage_act_quantizers_flipped = (count_act_quantizers_flipped * 100) / total_act_quantizers
-        percentage_param_quantizers_flipped = (count_param_quantizers_flipped * 100) / total_param_quantizers
+        percentage_act_quantizers_flipped = (count_act_quantizers_flipped * 100) / total_act_quantizers if total_act_quantizers else 0
+        percentage_param_quantizers_flipped = (count_param_quantizers_flipped * 100) / total_param_quantizers if total_param_quantizers else 0
         percentage_quantizers_flipped = ((count_param_quantizers_flipped + count_act_quantizers_flipped) * 100) / \
                                         (total_param_quantizers + total_act_quantizers)
         total_quantizers = total_param_quantizers + total_act_quantizers
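For context, the guard added above covers the case where a model has no enabled activation or parameter quantizers at all, in which case the original expressions divided by zero. A minimal, stand-alone sketch of the pattern (the function name and values here are illustrative, not taken from the AIMET codebase):

def flipped_percentage(flipped_count, total_count):
    # Report 0% when there is nothing to count instead of raising ZeroDivisionError
    return (flipped_count * 100) / total_count if total_count else 0

assert flipped_percentage(3, 10) == 30.0
assert flipped_percentage(0, 0) == 0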

TrainingExtensions/onnx/src/python/aimet_onnx/amp/mixed_precision_algo.py

Lines changed: 3 additions & 0 deletions
@@ -539,6 +539,9 @@ def _optimize_mp_profile_and_evaluate_model(self):
         """
         Uses OpGraph if available to optimize the mixed precision profile in the sim object
         """
+        # Apply exception rules logic to enforce a valid quantizer configuration
+        self._sim._apply_exception_rules() # pylint: disable = protected-access
+
         # Recompute quantizer encodings
         self._sim.compute_encodings(self.algo_params.forward_pass_callback,
                                     self.algo_params.forward_pass_callback_args)
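The added call re-runs the sim's exception-rule pass before encodings are recomputed, so the per-quantizer candidates chosen by AMP are coerced back into a configuration the backend config allows. As a rough illustration only (a toy stand-in, not AIMET's _apply_exception_rules implementation), one such rule exercised by the test below is that 16-bit integer quantizers must use symmetric encodings:

def apply_toy_exception_rules(quantizers):
    # quantizers: dict of tensor name -> {"bitwidth": int, "symmetric": bool}
    for q in quantizers.values():
        if q["bitwidth"] == 16:
            q["symmetric"] = True  # enforce the rule before encodings are recomputed
    return quantizers

cfg = {"matmul.input.1": {"bitwidth": 16, "symmetric": False}}
assert apply_toy_exception_rules(cfg)["matmul.input.1"]["symmetric"]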

TrainingExtensions/onnx/src/python/aimet_onnx/amp/utils.py

Lines changed: 3 additions & 2 deletions
@@ -42,7 +42,7 @@
 # Import AIMET specific modules
 from aimet_common.amp.utils import CANDIDATE_WITH_DTYPE, get_effective_bitwidth
 from aimet_common.cost_calculator import CostCalculator
-from aimet_onnx.meta.connectedgraph import ConnectedGraph
+from aimet_onnx.meta.connectedgraph import ConnectedGraph, WEIGHT_INDEX
 from aimet_onnx.amp.quantizer_groups import QuantizerGroup
 from aimet_onnx import utils
 from aimet_onnx.quantsim import QuantizationSimModel

@@ -107,7 +107,8 @@ def _get_weight_shape(op):
         if len(layer.output_shape) == 2:
             # Append 1, 1 to Linear layer's shape
             layer.output_shape = list(layer.output_shape) + [1, 1]
-        layer.weight_shape = _get_weight_shape(ops[node.name])
+        # If _get_weight_shape returns None, weight index is an activation
+        layer.weight_shape = _get_weight_shape(ops[node.name]) or activation_shapes[node.input[WEIGHT_INDEX]]
         op_database[node.name] = layer

     return op_database
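The `or` fallback handles ops whose weight input is not a static initializer: `_get_weight_shape` returns None for them, and the shape is then taken from the recorded activation shapes instead. A small sketch of the same pattern with hypothetical data (the helper and dictionaries are illustrative only; WEIGHT_INDEX is assumed here to be 1, matching the weight position of ONNX Conv/Gemm-style ops):

WEIGHT_INDEX = 1  # assumed weight position on Conv/Gemm-style ops

def resolve_weight_shape(static_weight_shape, node_inputs, activation_shapes):
    # Prefer the shape of a static (initializer) weight; if the weight is
    # produced at runtime, fall back to the recorded activation shape.
    return static_weight_shape or activation_shapes[node_inputs[WEIGHT_INDEX]]

shapes = {"dynamic_conv.weight": [10, 10, 1, 1]}
assert resolve_weight_shape(None, ["x", "dynamic_conv.weight"], shapes) == [10, 10, 1, 1]
assert resolve_weight_shape([16, 3, 3, 3], ["x", "w"], shapes) == [16, 3, 3, 3]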

TrainingExtensions/onnx/test/python/models/models_for_tests.py

Lines changed: 29 additions & 0 deletions
@@ -2704,6 +2704,35 @@ def conv_with_weight_identity_input():
     onnx.checker.check_model(model, True)
     return model

+def dynamic_conv_model():
+    model = helper.make_model(
+        graph=helper.make_graph(
+            name="DynamicConvModel",
+            inputs=[helper.make_tensor_value_info('x', TensorProto.FLOAT, shape=[10, 10, 32, 32]),
+                    helper.make_tensor_value_info('y', TensorProto.FLOAT, shape=[10, 10, 1, 1])],
+            outputs=[helper.make_tensor_value_info('model_output', TensorProto.FLOAT, shape=[10, 10, 32, 32])],
+            initializer=[
+                numpy_helper.from_array(np.random.randn(10, 10, 1, 1).astype('float32'), name='add.input'),
+            ],
+            nodes=[
+                helper.make_node(
+                    "Add",
+                    inputs=["y", "add.input"],
+                    outputs=["dynamic_conv.weight"],
+                    name="add"
+                ),
+                helper.make_node(
+                    "Conv",
+                    inputs=["x", "dynamic_conv.weight"],
+                    outputs=["model_output"],
+                    name="conv"
+                )
+            ]
+        )
+    )
+    onnx.checker.check_model(model, True)
+    return model
+

 def squeezenet1_0(tmpdir):
     import torchvision
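The new `dynamic_conv_model` fixture builds a Conv whose weight arrives at runtime (the output of the Add node), which is exactly the case the activation-shape fallback in amp/utils.py now covers. A hedged usage sketch, assuming onnxruntime is installed and the fixture above is importable (the import path shown is illustrative):

import numpy as np
import onnxruntime as ort
from models_for_tests import dynamic_conv_model  # illustrative import path

model = dynamic_conv_model()
session = ort.InferenceSession(model.SerializeToString())
outputs = session.run(None, {"x": np.random.randn(10, 10, 32, 32).astype("float32"),
                             "y": np.random.randn(10, 10, 1, 1).astype("float32")})
# The 1x1 runtime-computed weights leave the spatial dimensions unchanged
assert outputs[0].shape == (10, 10, 32, 32)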

TrainingExtensions/onnx/test/python/test_mixed_precision.py

Lines changed: 66 additions & 0 deletions
@@ -49,12 +49,15 @@
 from aimet_onnx.quantsim import QuantizationSimModel
 from aimet_onnx.amp.mixed_precision_algo import GreedyMixedPrecisionAlgo, _compute_sqnr, EvalCallbackFactory
 from aimet_onnx.defs import DataLoader
+from aimet_onnx.utils import make_dummy_input
+from aimet_onnx.mixed_precision import choose_mixed_precision

 from aimet_common.defs import QuantizationDataType, CallbackFunc
 from aimet_common.amp.mixed_precision_algo import interpolation_search, brute_force_search, binary_search
 from aimet_common.amp.utils import calculate_starting_bit_ops, AMPSearchAlgo

 from .models.test_models import single_residual_model
+from .models import models_for_tests

 INPUT_SHAPE = (1, 3, 32, 32)

@@ -585,6 +588,69 @@ def test_respect_frozen_encodings(self, sim, forward_pass_callback, eval_callbac
         assert quantizer.bitwidth == 4


+    @pytest.mark.parametrize("model", (
+        single_residual_model().model,
+        models_for_tests.dynamic_matmul_model(10),
+        models_for_tests.matmul_with_constant_first_input(),
+        models_for_tests.weight_matmul_model(),
+        models_for_tests.dynamic_conv_model(),
+        models_for_tests.mobilenetv2().model,
+        models_for_tests.depthwise_transposed_conv_model().model,
+        models_for_tests.model_with_split_matmul(),
+        models_for_tests.hierarchical_model().model,
+    ))
+    def test_choose_mixed_precision(self, model, tmpdir):
+        np.random.seed(0)
+
+        sim = QuantizationSimModel(model, default_activation_bw=8, default_param_bw=8, config_file="htp_v73")
+        enabled_quantizers = {q for q in sim.qc_quantize_op_dict.values() if q.enabled}
+        total_bits = 16 * len(enabled_quantizers)
+
+        forward_callback = CallbackFunc(lambda sess, _: sess.run(None, make_dummy_input(model)), None)
+
+        def phase_2_callback(sess, _):
+            bits = sum(q.bitwidth if q.enabled else 16 for q in enabled_quantizers)
+            return bits / total_bits
+
+        # Define dummy eval callbacks
+        eval_callback_phase1 = CallbackFunc(lambda sess, _: np.random.rand())
+        eval_callback_phase2 = CallbackFunc(phase_2_callback, None)
+
+        candidates = [((16, QuantizationDataType.float), (16, QuantizationDataType.float)),
+                      ((16, QuantizationDataType.int), (8, QuantizationDataType.int)),
+                      ((8, QuantizationDataType.int), (8, QuantizationDataType.int))]
+
+        # Apply mixed precision
+        choose_mixed_precision(sim, candidates, eval_callback_phase1, eval_callback_phase2, 0.4, tmpdir, True,
+                               forward_callback)
+
+        # Assert that no param quantizers are in int16 (not a valid candidate)
+        for name in sim.param_names:
+            quantizer = sim.qc_quantize_op_dict[name]
+            assert not (quantizer.bitwidth == 16 and quantizer.data_type == QuantizationDataType.int)
+
+        # Assert that the final result meets the accuracy metric
+        assert sum(q.bitwidth for q in enabled_quantizers) <= total_bits
+        assert sum(q.bitwidth for q in enabled_quantizers) >= total_bits * 0.6
+
+        # Assert that the final mixed-precision profile obeys config file's exception rules
+        for op in sim.connected_graph.ordered_ops:
+            if not op.type in ("MatMul", "Gemm"):
+                continue
+
+            q1, q2 = sim._get_closest_enabled_quantizer(op.inputs[0]), sim._get_closest_enabled_quantizer(op.inputs[1])
+            if not q1 or not q2:
+                continue
+
+            # Config requires symmetric second input for 16-bit matmul
+            if q2.bitwidth == 16 and not q2.data_type == QuantizationDataType.float:
+                assert q2.use_symmetric_encodings
+
+            # 8 x 16 MatMul is not a valid combination
+            if q1.bitwidth == 8:
+                assert q2.bitwidth == 8
+
+
 class TestAMPv2:
     def test_compute_sqnr(self):
         """ Verify _compute_sqnr() method """
