
Commit 7236fb2

Lee, Kyunggeun (quic-kyunggeu) authored and committed
Allow [b]float16 quantizers in sim.onnx.export
Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
Co-authored-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
1 parent b972a4b commit 7236fb2
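
For context, a minimal usage sketch of what this commit enables. Import paths and the sim.onnx.export call signature are assumed from this diff and common AIMET v2 conventions, not taken from a released API:

    import torch
    from torchvision.models import resnet18

    import aimet_torch.v2.quantization as Q
    from aimet_torch.v2.quantsim import QuantizationSimModel

    x = torch.randn(1, 3, 224, 224)
    sim = QuantizationSimModel(resnet18().eval(), x,
                               default_param_bw=16, default_output_bw=16)

    # Swap the default affine quantizers for float16 quantize-dequantize ops.
    for qmodule in sim.qmodules():
        for name, qtzr in qmodule.param_quantizers.items():
            if qtzr:
                qmodule.param_quantizers[name] = Q.float.FloatQuantizeDequantize(dtype=torch.float16)
        for i, qtzr in enumerate(qmodule.output_quantizers):
            if qtzr:
                qmodule.output_quantizers[i] = Q.float.FloatQuantizeDequantize(dtype=torch.float16)

    sim.compute_encodings(lambda model: model(x))

    # Before this commit the ONNX export path rejected non-affine quantizers;
    # with it, [b]float16 quantizers are stripped from the exported graph and
    # omitted from the encodings file instead.
    sim.onnx.export((x,), "resnet18_fp16.onnx")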

File tree

4 files changed, +110 -29 lines


TrainingExtensions/torch/src/python/aimet_torch/v2/nn/base.py

Lines changed: 1 addition & 1 deletion
@@ -713,7 +713,7 @@ def _create_int32_bias_quantizer(self, input, _): # pylint: disable=redefined-builtin
 
         if len(input) == 1:
             input, = input
-            if self.input_quantizers[0]:
+            if isinstance(self.input_quantizers[0], AffineQuantizerBase):
                 input_scale = self.input_quantizers[0].get_scale()
             elif isinstance(input, QuantizedTensorBase) and isinstance(input.encoding, AffineEncoding):
                 input_scale = input.encoding.scale

TrainingExtensions/torch/src/python/aimet_torch/v2/quantization/float/encoding.py

Lines changed: 14 additions & 0 deletions
@@ -147,4 +147,18 @@ def to_qnn_encoding_dict(self, encoding_version=None) -> Union[List, Dict]:
         if encoding_version == '1.0.0':
             return {'dtype': 'FLOAT', 'bw': self.bitwidth, 'enc_type': EncodingType.PER_TENSOR.name}
 
+        if encoding_version == "2.0.0.beta":
+            if self.exponent_bits == 5 and self.mantissa_bits == 10:
+                # float16
+                return {}
+
+            if self.exponent_bits == 8 and self.mantissa_bits == 7:
+                # bfloat16
+                return {}
+
+            raise NotImplementedError(
+                "Floating point encoding export only supports [b]float16; "
+                f"got exponent_bits={self.exponent_bits}, mantissa_bits={self.mantissa_bits}"
+            )
+
         raise AssertionError(f'Export encoding version {encoding_version} not supported.')
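
For reference, float16 carries 5 exponent and 10 mantissa bits and bfloat16 carries 8 and 7, which is exactly what the new "2.0.0.beta" branch matches on. A minimal sketch of the resulting behavior, assuming the FloatQuantizeDequantize constructor arguments used in the test changes below:

    import torch
    from aimet_torch.v2.quantization.float import FloatQuantizeDequantize

    fp16 = FloatQuantizeDequantize(dtype=torch.float16)    # 5 exponent / 10 mantissa bits
    bf16 = FloatQuantizeDequantize(dtype=torch.bfloat16)   # 8 exponent / 7 mantissa bits
    assert fp16.is_float16() and bf16.is_bfloat16()

    # After calibration, their encodings export as empty dicts under version
    # "2.0.0.beta" (and are later dropped by sim.onnx.export), whereas any other
    # minifloat layout, e.g. exponent_bits=4 / mantissa_bits=3, now raises
    # NotImplementedError instead of exporting a bogus encoding.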

TrainingExtensions/torch/src/python/aimet_torch/v2/quantsim/quantsim.py

Lines changed: 58 additions & 21 deletions
@@ -61,7 +61,6 @@
     save_checkpoint,
     load_checkpoint,
     check_accumulator_overflow,
-    _QuantizedModuleProtocol,
 )
 from aimet_torch.v2 import nn as aimet_nn
 from aimet_torch.v2.nn import BaseQuantizationMixin, QuantizationMixin, UnknownModuleError
@@ -70,6 +69,7 @@
 from aimet_torch.v2.quantization import DequantizedTensor
 from aimet_torch.v2.quantization.base import QuantizerBase
 from aimet_torch.v2.quantization.affine import AffineQuantizerBase
+from aimet_torch.v2.quantization.float import FloatQuantizeDequantize
 from aimet_torch.v2.quantization.encoding_analyzer import PercentileEncodingAnalyzer
 from aimet_torch.v2.utils import patch_attr
 from aimet_torch import utils
@@ -616,9 +616,11 @@ def _concretize_int32_bias_quantizers(self, args):
                 # In this case, we honor the custom bias quantizer defined by the user
                 continue
 
-            # pylint: disable=protected-access
-            handle = qmodule.register_forward_hook(type(qmodule)._create_int32_bias_quantizer)
-            handles.append(handle)
+            if "weight" in qmodule.param_quantizers and \
+                    isinstance(qmodule.param_quantizers["weight"], AffineQuantizerBase):
+                # pylint: disable=protected-access
+                handle = qmodule.register_forward_hook(type(qmodule)._create_int32_bias_quantizer)
+                handles.append(handle)
         try:
             self.model(*args)
         finally:
@@ -682,6 +684,41 @@ def _temporarily_unfold_param_quantizers(sim: QuantizationSimModel):
         qmodule._fold_param_quantizers()
 
 
+@contextlib.contextmanager
+def _remove_fp16_quantizers(sim: QuantizationSimModel):
+    """
+    Temporarily remove [b]float16 quantizers for sim.onnx.export,
+    as sim.onnx.export does NOT support exporting [b]float16 quantizers.
+    """
+    original_containers = {}
+
+    try:
+        for qmodule in sim.qmodules():
+            for name, qtzr in qmodule.param_quantizers.items():
+                if isinstance(qtzr, FloatQuantizeDequantize) and \
+                        (qtzr.is_float16() or qtzr.is_bfloat16()):
+                    original_containers[(qmodule.param_quantizers, name)] = qtzr
+                    qmodule.param_quantizers[name] = None
+
+            for i, qtzr in enumerate(qmodule.input_quantizers):
+                if isinstance(qtzr, FloatQuantizeDequantize) and \
+                        (qtzr.is_float16() or qtzr.is_bfloat16()):
+                    original_containers[(qmodule.input_quantizers, i)] = qtzr
+                    qmodule.input_quantizers[i] = None
+
+            for i, qtzr in enumerate(qmodule.output_quantizers):
+                if isinstance(qtzr, FloatQuantizeDequantize) and \
+                        (qtzr.is_float16() or qtzr.is_bfloat16()):
+                    original_containers[(qmodule.output_quantizers, i)] = qtzr
+                    qmodule.output_quantizers[i] = None
+
+        yield
+
+    finally:
+        for (container, key), qtzr in original_containers.items():
+            container[key] = qtzr
+
+
 class _QuantizationSimOnnxExport:
     """
     Helper class for exporting quantized models to ONNX format.
@@ -708,14 +745,13 @@ def export(self,
         :param f: file object or path where to store exported ONNX mode
         """
         # pylint: disable=too-many-locals, too-many-branches, protected-access
-        if self._has_non_affine_quantizer(self.sim.model):
-            raise RuntimeError("Export using onnx only export only supports affine quantizers. "
-                               "Other quantizer types are not supported.")
+        self._check_unsupported_quantizers(self.sim.model)
 
         with tempfile.TemporaryDirectory() as tmp_dir:
             with _temporarily_unfold_param_quantizers(self.sim), \
                     self.sim._concretize_int32_bias_quantizers(args), \
-                    self.sim._apply_qdq_to_model_parameters(self.sim.model):
+                    self.sim._apply_qdq_to_model_parameters(self.sim.model), \
+                    _remove_fp16_quantizers(self.sim):
                 tmp_onnx_path = os.path.join(tmp_dir, "quantized_model.onnx")
                 export(self.sim.model, args, tmp_onnx_path, *posargs, **kwargs)
                 onnx_model = onnx.load(tmp_onnx_path)
@@ -745,30 +781,31 @@ def export(self,
                     "encodings": [
                         {"name": name, **qnn_encoding}
                         for name, qnn_encoding in qnn_encodings.items()
+                        if qnn_encoding
                     ]
                 })
             else:
                 if quantsim.encoding_version >= "1.0.0":
                     param_encodings = [
                         {"name": name, **qnn_encoding}
                         for name, qnn_encoding in qnn_encodings.items()
-                        if name in param_names
+                        if qnn_encoding and name in param_names
                     ]
                     activation_encodings = [
                         {"name": name, **qnn_encoding}
                         for name, qnn_encoding in qnn_encodings.items()
-                        if name not in param_names
+                        if qnn_encoding and name not in param_names
                     ]
                 else:
                     param_encodings = {
                         name: qnn_encoding
                         for name, qnn_encoding in qnn_encodings.items()
-                        if name in param_names
+                        if qnn_encoding and name in param_names
                     }
                     activation_encodings = {
                         name: qnn_encoding
                        for name, qnn_encoding in qnn_encodings.items()
-                        if name not in param_names
+                        if qnn_encoding and name not in param_names
                     }
 
                 encodings_dict.update({
@@ -787,15 +824,15 @@ def export(self,
             json.dump(encodings_dict, encoding_file, indent=2)
 
     @staticmethod
-    def _has_non_affine_quantizer(module: torch.nn.Module):
-        for submodule in module.modules():
-            if isinstance(submodule, _QuantizedModuleProtocol):
-                for quantizer in itertools.chain(submodule.input_quantizers,
-                                                 submodule.output_quantizers,
-                                                 submodule.param_quantizers.values()):
-                    if quantizer and not isinstance(quantizer, AffineQuantizerBase):
-                        return True
-        return False
+    def _check_unsupported_quantizers(module: torch.nn.Module):
+        for qtzr in module.modules():
+            if isinstance(qtzr, FloatQuantizeDequantize):
+                if not qtzr.is_float16() and not qtzr.is_bfloat16():
+                    msg = " ".join([
+                        "sim.onnx.export doesn't support exporting floating point encodings",
+                        f"except [b]float16. Got {qtzr.bitwidth}-bit float encoding",
+                    ])
+                    raise RuntimeError(msg)
 
 
 @deprecated("""

TrainingExtensions/torch/test/python/v2/experimental/test_onnx.py

Lines changed: 37 additions & 7 deletions
@@ -274,17 +274,26 @@ def test_export_torchvision_models(model_factory, input_shape):
 
 @torch.no_grad()
 @pytest.mark.parametrize("encoding_version", ["0.6.1", "1.0.0", "2.0.0.beta"])
-@pytest.mark.parametrize("lpbq", (False, True))
-@pytest.mark.parametrize("fold_param_quantizers", (False, True))
-def test_quantsim_export_resnet18(encoding_version, lpbq: bool, fold_param_quantizers: bool):
+@pytest.mark.parametrize("lpbq", [False, True])
+@pytest.mark.parametrize("fold_param_quantizers", [False, True])
+@pytest.mark.parametrize(
+    "weight_dtype, activation_dtype", [
+        (torch.int8, torch.uint8),
+        (torch.int8, torch.float16),
+        (torch.float16, torch.float16),
+    ])
+def test_quantsim_export_resnet18(encoding_version, lpbq: bool, fold_param_quantizers: bool,
+                                  weight_dtype: torch.dtype, activation_dtype: torch.dtype):
     """
     When: Export quantized torchvision model using quantsim.export
     """
     x = torch.randn(1, 3, 224, 224)
     model = resnet18().eval()
     model = prepare_model(model)
     fold_all_batch_norms(model, None, x)
-    sim = QuantizationSimModel(model, x, config_file=get_path_for_per_channel_config())
+    sim = QuantizationSimModel(model, x,
+                               default_param_bw=weight_dtype.itemsize * 8,
+                               default_output_bw=activation_dtype.itemsize * 8)
 
     if lpbq:
         set_grouped_blockwise_quantization_for_weights(sim,
@@ -294,6 +303,26 @@ def test_quantsim_export_resnet18(encoding_version, lpbq: bool, fold_param_quantizers: bool,
                                                        decompressed_bw=8,
                                                        block_size=64)
 
+    if weight_dtype.is_floating_point:
+        for qmodule in sim.qmodules():
+            for name, qtzr in qmodule.param_quantizers.items():
+                if not qtzr:
+                    continue
+                qmodule.param_quantizers[name] = Q.float.FloatQuantizeDequantize(dtype=weight_dtype)
+
+    if activation_dtype.is_floating_point:
+        for qmodule in sim.qmodules():
+            for i, qtzr in enumerate(qmodule.input_quantizers):
+                if not qtzr:
+                    continue
+                qmodule.input_quantizers[i] = Q.float.FloatQuantizeDequantize(dtype=activation_dtype)
+
+        for qmodule in sim.qmodules():
+            for i, qtzr in enumerate(qmodule.output_quantizers):
+                if not qtzr:
+                    continue
+                qmodule.output_quantizers[i] = Q.float.FloatQuantizeDequantize(dtype=activation_dtype)
+
     sim.compute_encodings(lambda model: model(x))
 
     # Compute original pytorch model output with qdq weights
@@ -302,19 +331,20 @@ def test_quantsim_export_resnet18(encoding_version, lpbq: bool, fold_param_quantizers: bool,
         f"{module_name}.{param_name}": qtzr.get_encodings().to_qnn_encoding_dict(encoding_version)
         for module_name, qmodule in sim.named_qmodules()
         for param_name, qtzr in qmodule.param_quantizers.items()
+        if isinstance(qtzr, Q.affine.AffineQuantizerBase)
     }
     expected_activation_encodings = {}
     expected_activation_encodings.update({
         f"{module_name}.input_quantizers.{i}": qtzr.get_encodings().to_qnn_encoding_dict(encoding_version)
         for module_name, qmodule in sim.named_qmodules()
         for i, qtzr in enumerate(qmodule.input_quantizers)
-        if qtzr is not None
+        if isinstance(qtzr, Q.affine.AffineQuantizerBase)
     })
     expected_activation_encodings.update({
         f"{module_name}.output_quantizers.{i}": qtzr.get_encodings().to_qnn_encoding_dict(encoding_version)
         for module_name, qmodule in sim.named_qmodules()
         for i, qtzr in enumerate(qmodule.output_quantizers)
-        if qtzr is not None
+        if isinstance(qtzr, Q.affine.AffineQuantizerBase)
     })
 
     with remove_activation_quantizers(sim.model):
@@ -407,6 +437,6 @@ def test_quantsim_export_resnet18(encoding_version, lpbq: bool, fold_param_quantizers: bool,
     the original pytorch model with qdq weights
     """
     sess = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])
-    out, = sess.run(None, {onnx_model.graph.input[0].name: x.numpy()})
+    out, = sess.run(None, {"input": x.numpy()})
 
     assert torch.allclose(torch.from_numpy(out), expected_out, atol=1e-5)
