 import tempfile
 from pathlib import Path
 import os
-from typing import Any, Callable, Dict, List, Optional, overload, Tuple, TypeVar, Union
+from typing import Any, Callable, Dict, List, Optional, overload, Tuple, TypeVar, Union, Sequence
 import itertools
 import json
 import warnings
@@ -160,39 +160,92 @@ class QuantizationSimModel:
     :param model: ONNX model
     :param dummy_input: Dummy input to the model. If None, will attempt to auto-generate a dummy input
     :param quant_scheme: Quantization scheme (e.g. QuantScheme.post_training_tf)
-    :param rounding_mode: Rounding mode (e.g. nearest)
+    :param rounding_mode: Deprecated
     :param default_param_bw: Quantization bitwidth for parameter
     :param default_activation_bw: Quantization bitwidth for activation
-    :param use_symmetric_encodings: True if symmetric encoding is used. False otherwise.
-    :param use_cuda: True if using CUDA to run quantization op. False otherwise.
+    :param use_symmetric_encodings: Deprecated; symmetry is controlled by the config_file
+    :param use_cuda: Deprecated; use `providers` instead
     :param config_file: File path or alias of the configuration file.
                         Alias can be one of {{ {', '.join(_config_file_aliases.keys())} }} (Default: `"default"`)
     :param default_data_type: Default data type to use for quantizing all layer inputs, outputs and parameters.
                               Possible options are QuantizationDataType.int and QuantizationDataType.float.
                               Note that the mode default_data_type=QuantizationDataType.float is only supported with
                               default_output_bw=16 and default_param_bw=16
     :param user_onnx_libs: List of paths to all compiled ONNX custom ops libraries
+    :param providers: ONNX Runtime execution providers to use when building the InferenceSession.
+                      If `None`, falls back to `onnxruntime.get_available_providers()`
     :param path: Directory to save the artifacts.
     """
 
     def __init__(self,
                  model: Union[ModelProto, ONNXModel],
-                 dummy_input: Dict[str, np.ndarray] = None,
+                 dummy_input: Optional[Dict[str, np.ndarray]] = None,
                  quant_scheme: QuantScheme = QuantScheme.min_max,
-                 rounding_mode: str = 'nearest',
+                 rounding_mode: str = None,  # Deprecated
                  default_param_bw: int = 8,
                  default_activation_bw: int = 8,
-                 use_symmetric_encodings: bool = False, use_cuda: bool = True,
-                 device: int = 0, config_file: str = None,
+                 use_symmetric_encodings: bool = None,  # Deprecated
+                 use_cuda: bool = None,  # Deprecated
+                 device: int = None,  # Deprecated
+                 config_file: Optional[str] = None,
                  default_data_type: QuantizationDataType = QuantizationDataType.int,
-                 user_onnx_libs: List[str] = None, path: str = None):
+                 user_onnx_libs: List[str] = None,
+                 providers: Optional[Sequence[str | Tuple[str, Dict[Any, Any]]]] = None,
+                 path: Optional[str] = None):
+        # pylint: disable=too-many-branches, too-many-statements
+        if rounding_mode is not None:
+            if rounding_mode == 'nearest':
+                warnings.warn(_red("Passing rounding_mode='nearest' is no longer needed "
+                                   "and will be deprecated in later versions."),
+                              DeprecationWarning, stacklevel=2)
+            else:
+                raise TypeError("'rounding_mode' parameter is no longer supported.")
+
+        if use_symmetric_encodings is not None:
+            warnings.warn(_red("Passing `use_symmetric_encodings` is not needed and will be deprecated in later versions."),
+                          DeprecationWarning, stacklevel=2)
+
+        if device is not None:
+            warnings.warn(_red("Passing `device` will be deprecated in later versions. "
+                               "Please use the `providers` argument instead to specify the CUDA device."),
+                          DeprecationWarning, stacklevel=2)
+            if providers is not None:
+                raise RuntimeError("Cannot provide `device` and `providers` at the same time.")
+
+        if use_cuda is not None:
+            warnings.warn(_red("Passing `use_cuda` will be deprecated in later versions. "
+                               "Please use the `providers` argument instead."),
+                          DeprecationWarning, stacklevel=2)
+            if providers is not None:
+                raise RuntimeError("Cannot provide `use_cuda` and `providers` at the same time.")
+
+            # Legacy behavior of use_cuda
+            if "CUDAExecutionProvider" not in ort.get_available_providers():
+                use_cuda = False
+
+            device = device or 0
+            if use_cuda:
+                providers = [('CUDAExecutionProvider', {'device_id': device}), 'CPUExecutionProvider']
+            else:
+                providers = ['CPUExecutionProvider']
+
+        if not providers:
+            providers = ort.get_available_providers()
+
         if isinstance(quant_scheme, str):
             quant_scheme = QuantScheme.from_str(quant_scheme)
 
         if isinstance(model, ModelProto):
             model = ONNXModel(model)
 
+        op_domain = "aimet.customop.cpu"
+        for provider in providers:
+            if provider == "CUDAExecutionProvider" or provider[0] == "CUDAExecutionProvider":
+                op_domain = "aimet.customop.cuda"
+
         self.model = model
+        self._op_domain = op_domain
+        self.providers = providers
 
         if not dummy_input:
             dummy_input = make_dummy_input(self.model.model)
@@ -204,16 +257,6 @@ def __init__(self,
         self._default_param_bw = default_param_bw
         self._default_activation_bw = default_activation_bw
         self._default_quantization_data_type = default_data_type
-        self._use_symmetric_encodings = use_symmetric_encodings
-        self._use_cuda = use_cuda
-        if 'CUDAExecutionProvider' not in ort.get_available_providers():
-            self._use_cuda = False
-        if self._use_cuda:
-            self._op_domain = "aimet.customop.cuda"
-            self.providers = [('CUDAExecutionProvider', {'device_id': device, 'cudnn_conv_algo_search': 'DEFAULT'}), 'CPUExecutionProvider']
-        else:
-            self._op_domain = "aimet.customop.cpu"
-            self.providers = ['CPUExecutionProvider']
         self._user_onnx_libs = user_onnx_libs
         self.param_names = []
         self.input_quantizers_name = []
@@ -465,7 +508,7 @@ def _insert_param_quantization_nodes(self):
                 rounding_mode=self._rounding_mode,
                 op_mode=OpMode.oneShotQuantizeDequantize,
                 bitwidth=self._default_param_bw,
-                use_symmetric_encodings=self._use_symmetric_encodings,
+                use_symmetric_encodings=False,
                 tensor_quantizer_params=tensor_quantizer_params)
 
     def _create_quant_info_object_for_param(self, param_name: str):
@@ -533,7 +576,7 @@ def _insert_activation_quantization_nodes(self):
                 rounding_mode=self._rounding_mode,
                 op_mode=OpMode.updateStats,
                 bitwidth=self._default_activation_bw,
-                use_symmetric_encodings=self._use_symmetric_encodings)
+                use_symmetric_encodings=False)
 
     @staticmethod
     def build_session(model: onnx.ModelProto, providers: List, user_onnx_libs: List[str] = None, path: str = None):
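For reference, a minimal usage sketch of the constructor after this change is shown below. The parameter names and the mapping of the deprecated `use_cuda`/`device` flags onto `providers` follow the diff above; the `aimet_onnx.quantsim` import path, the model file name, and the input shape are illustrative assumptions, not part of this PR.

```python
# Hedged usage sketch: parameter names come from the new __init__ signature above;
# the import path, model file, and input shape are assumptions for illustration.
import numpy as np
import onnx

from aimet_onnx.quantsim import QuantizationSimModel  # assumed import path

model = onnx.load("model.onnx")  # hypothetical model file
dummy_input = {"input": np.random.randn(1, 3, 224, 224).astype(np.float32)}

# New style: pass ONNX Runtime execution providers directly; a provider may be a
# plain name or a (name, options) tuple, exactly as onnxruntime expects.
sim = QuantizationSimModel(model,
                           dummy_input=dummy_input,
                           default_param_bw=8,
                           default_activation_bw=8,
                           providers=[("CUDAExecutionProvider", {"device_id": 0}),
                                      "CPUExecutionProvider"])

# Old style: still accepted, but now emits a DeprecationWarning and is translated
# into `providers` internally (CUDA only if CUDAExecutionProvider is available).
# sim = QuantizationSimModel(model, dummy_input=dummy_input, use_cuda=True, device=0)
```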