     save_checkpoint,
     load_checkpoint,
     check_accumulator_overflow,
-    _QuantizedModuleProtocol,
 )
 from aimet_torch.v2 import nn as aimet_nn
 from aimet_torch.v2.nn import BaseQuantizationMixin, QuantizationMixin, UnknownModuleError
 from aimet_torch.v2.quantization import DequantizedTensor
 from aimet_torch.v2.quantization.base import QuantizerBase
 from aimet_torch.v2.quantization.affine import AffineQuantizerBase
+from aimet_torch.v2.quantization.float import FloatQuantizeDequantize
 from aimet_torch.v2.quantization.encoding_analyzer import PercentileEncodingAnalyzer
 from aimet_torch.v2.utils import patch_attr
 from aimet_torch import utils
@@ -616,9 +616,11 @@ def _concretize_int32_bias_quantizers(self, args):
                 # In this case, we honor the custom bias quantizer defined by the user
                 continue
 
-            # pylint: disable=protected-access
-            handle = qmodule.register_forward_hook(type(qmodule)._create_int32_bias_quantizer)
-            handles.append(handle)
+            if "weight" in qmodule.param_quantizers and \
+                    isinstance(qmodule.param_quantizers["weight"], AffineQuantizerBase):
+                # pylint: disable=protected-access
+                handle = qmodule.register_forward_hook(type(qmodule)._create_int32_bias_quantizer)
+                handles.append(handle)
         try:
             self.model(*args)
         finally:
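For context, the hunk above relies on PyTorch's forward-hook lifecycle: register hooks, run one calibration forward pass, then always remove the handles. A minimal standalone sketch of that pattern (the hook and model below are illustrative, not from AIMET):

    import torch

    def _print_output_shape(module, inputs, output):
        # Fires once per forward call of the hooked module.
        print(f"{type(module).__name__} -> {tuple(output.shape)}")

    model = torch.nn.Linear(4, 2)
    handles = [model.register_forward_hook(_print_output_shape)]
    try:
        model(torch.randn(1, 4))      # hooks fire during this call
    finally:
        for handle in handles:
            handle.remove()           # detach hooks even if the forward pass fails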
@@ -682,6 +684,41 @@ def _temporarily_unfold_param_quantizers(sim: QuantizationSimModel):
             qmodule._fold_param_quantizers()
 
 
+@contextlib.contextmanager
+def _remove_fp16_quantizers(sim: QuantizationSimModel):
+    """
+    Temporarily remove [b]float16 quantizers for sim.onnx.export,
+    as sim.onnx.export does NOT support exporting [b]float16 quantizers.
+    """
+    original_containers = {}
+
+    try:
+        for qmodule in sim.qmodules():
+            for name, qtzr in qmodule.param_quantizers.items():
+                if isinstance(qtzr, FloatQuantizeDequantize) and \
+                        (qtzr.is_float16() or qtzr.is_bfloat16()):
+                    original_containers[(qmodule.param_quantizers, name)] = qtzr
+                    qmodule.param_quantizers[name] = None
+
+            for i, qtzr in enumerate(qmodule.input_quantizers):
+                if isinstance(qtzr, FloatQuantizeDequantize) and \
+                        (qtzr.is_float16() or qtzr.is_bfloat16()):
+                    original_containers[(qmodule.input_quantizers, i)] = qtzr
+                    qmodule.input_quantizers[i] = None
+
+            for i, qtzr in enumerate(qmodule.output_quantizers):
+                if isinstance(qtzr, FloatQuantizeDequantize) and \
+                        (qtzr.is_float16() or qtzr.is_bfloat16()):
+                    original_containers[(qmodule.output_quantizers, i)] = qtzr
+                    qmodule.output_quantizers[i] = None
+
+        yield
+
+    finally:
+        for (container, key), qtzr in original_containers.items():
+            container[key] = qtzr
+
+
 class _QuantizationSimOnnxExport:
     """
     Helper class for exporting quantized models to ONNX format.
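A minimal usage sketch of the new context manager, assuming an already-built QuantizationSimModel named `sim`: (b)float16 quantizers are swapped out inside the block and the original objects are restored on exit.

    # Hedged sketch: `sim` is assumed to be an existing QuantizationSimModel.
    with _remove_fp16_quantizers(sim):
        for qmodule in sim.qmodules():
            for qtzr in [*qmodule.input_quantizers,
                         *qmodule.output_quantizers,
                         *qmodule.param_quantizers.values()]:
                # While the context is active, no float16/bfloat16 quantizers remain.
                assert not (isinstance(qtzr, FloatQuantizeDequantize)
                            and (qtzr.is_float16() or qtzr.is_bfloat16()))
    # On exit, each removed quantizer is back in its original container slot.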
@@ -708,14 +745,13 @@ def export(self,
         :param f: file object or path where to store exported ONNX mode
 
         # pylint: disable=too-many-locals, too-many-branches, protected-access
-        if self._has_non_affine_quantizer(self.sim.model):
-            raise RuntimeError("Export using onnx only export only supports affine quantizers. "
-                               "Other quantizer types are not supported.")
+        self._check_unsupported_quantizers(self.sim.model)
 
         with tempfile.TemporaryDirectory() as tmp_dir:
             with _temporarily_unfold_param_quantizers(self.sim), \
                  self.sim._concretize_int32_bias_quantizers(args), \
-                 self.sim._apply_qdq_to_model_parameters(self.sim.model):
+                 self.sim._apply_qdq_to_model_parameters(self.sim.model), \
+                 _remove_fp16_quantizers(self.sim):
                 tmp_onnx_path = os.path.join(tmp_dir, "quantized_model.onnx")
                 export(self.sim.model, args, tmp_onnx_path, *posargs, **kwargs)
                 onnx_model = onnx.load(tmp_onnx_path)
@@ -745,30 +781,31 @@ def export(self,
                     "encodings": [
                         {"name": name, **qnn_encoding}
                         for name, qnn_encoding in qnn_encodings.items()
+                        if qnn_encoding
                     ]
                 })
             else:
                 if quantsim.encoding_version >= "1.0.0":
                     param_encodings = [
                         {"name": name, **qnn_encoding}
                         for name, qnn_encoding in qnn_encodings.items()
-                        if name in param_names
+                        if qnn_encoding and name in param_names
                     ]
                     activation_encodings = [
                         {"name": name, **qnn_encoding}
                         for name, qnn_encoding in qnn_encodings.items()
-                        if name not in param_names
+                        if qnn_encoding and name not in param_names
                     ]
                 else:
                     param_encodings = {
                         name: qnn_encoding
                         for name, qnn_encoding in qnn_encodings.items()
-                        if name in param_names
+                        if qnn_encoding and name in param_names
                     }
                     activation_encodings = {
                         name: qnn_encoding
                         for name, qnn_encoding in qnn_encodings.items()
-                        if name not in param_names
+                        if qnn_encoding and name not in param_names
                     }
 
                 encodings_dict.update({
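The added `if qnn_encoding` guards drop entries whose encoding is empty, such as, presumably, tensors whose (b)float16 quantizers were removed for export above. A small self-contained illustration with made-up encoding data (the names and fields below are hypothetical):

    # Hypothetical encodings, for illustration only.
    qnn_encodings = {
        "conv.weight": {"dtype": "INT", "bw": 8},
        "fc.weight": None,                        # no encoding was produced
    }
    param_names = {"conv.weight", "fc.weight"}

    param_encodings = [
        {"name": name, **qnn_encoding}
        for name, qnn_encoding in qnn_encodings.items()
        if qnn_encoding and name in param_names
    ]
    # Only "conv.weight" survives; the entry without an encoding is skipped.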
@@ -787,15 +824,15 @@ def export(self,
             json.dump(encodings_dict, encoding_file, indent=2)
 
     @staticmethod
-    def _has_non_affine_quantizer(module: torch.nn.Module):
-        for submodule in module.modules():
-            if isinstance(submodule, _QuantizedModuleProtocol):
-                for quantizer in itertools.chain(submodule.input_quantizers,
-                                                 submodule.output_quantizers,
-                                                 submodule.param_quantizers.values()):
-                    if quantizer and not isinstance(quantizer, AffineQuantizerBase):
-                        return True
-        return False
+    def _check_unsupported_quantizers(module: torch.nn.Module):
+        for qtzr in module.modules():
+            if isinstance(qtzr, FloatQuantizeDequantize):
+                if not qtzr.is_float16() and not qtzr.is_bfloat16():
+                    msg = " ".join([
+                        "sim.onnx.export doesn't support exporting floating point encodings",
+                        f"except [b]float16. Got {qtzr.bitwidth}-bit float encoding",
+                    ])
+                    raise RuntimeError(msg)
 
 
 @deprecated("""
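For reference, the new check can simply walk module.modules() because v2 quantizers are registered as torch.nn.Module submodules of the quantized layers, so the recursive traversal reaches them. A minimal stand-in sketch (the classes below are illustrative, not AIMET's):

    import torch

    class _StubFloatQuantizer(torch.nn.Module):
        # Illustrative stand-in for FloatQuantizeDequantize.
        def __init__(self, bitwidth):
            super().__init__()
            self.bitwidth = bitwidth

    class _StubQuantizedLinear(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(4, 4)
            self.output_quantizer = _StubFloatQuantizer(bitwidth=8)  # e.g. an fp8 quantizer

    model = _StubQuantizedLinear()
    found = [m for m in model.modules() if isinstance(m, _StubFloatQuantizer)]
    assert found and found[0].bitwidth == 8   # the recursive walk reaches the quantizer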