# =============================================================================
""" Default quantization backend for quantizing weights and activations """
import functools
-from typing import Callable, Optional, List
+from packaging import version
+from typing import Callable, Optional, List, Tuple
import torch
from aimet_torch.v2.utils import _is_expandable, _ContextManager
import aimet_torch.v2.experimental.onnx._export as _onnx
-from packaging import version


-if version.parse(torch.__version__) >= version.parse("2.0.0"):
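+# Cache the installed torch version as a (major, minor, micro) tuple so that
+# the version checks below reduce to simple tuple comparisons.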
+_torch_version: Tuple[int, int, int] = (version.parse(torch.__version__).major,
+                                        version.parse(torch.__version__).minor,
+                                        version.parse(torch.__version__).micro)
+
+if _torch_version >= (2, 0, 0):
    _compile = torch.compile
else:
    _compile = lambda fn: fn
@@ -155,6 +159,8 @@ def quantize(tensor: torch.Tensor, scale: torch.Tensor, offset: torch.Tensor,



+_ALLOW_FAST_FORWARD = True   # temporary flag for debugging
+
@_onnx.register_symbolic(_onnx.quantize_dequantize_symbolic)
def quantize_dequantize(tensor: torch.Tensor, scale: torch.Tensor, offset: torch.Tensor,
                        qmin: int, qmax: int, block_size: Optional[List] = None) -> torch.Tensor:
@@ -170,6 +176,26 @@ def quantize_dequantize(tensor: torch.Tensor, scale: torch.Tensor, offset: torch
170176 """
171177 _validate_arguments (tensor , scale , qmin , qmax , block_size )
172178
179+ _fast_forward = _ALLOW_FAST_FORWARD
180+
181+ # torch.fake_quantize doesn't support blockwise quantization
182+ _fast_forward &= block_size is None
183+
184+ # torch.fake_quantize doesn't support JIT tracing
185+ _fast_forward &= not torch .jit .is_tracing ()
186+
187+ # torch.fake_quantize doesn't compute gradients for scale/offset
188+ _fast_forward &= (not scale .requires_grad and not offset .requires_grad ) or (not torch .is_grad_enabled ())
189+
+    # if the user explicitly designated a specific rounding function, honor it strictly
+    _fast_forward &= (_round_fn == torch.round and _round_fn_inplace == torch.round_)
+
+    if _fast_forward:
+        ret = _torch_fake_quantize(tensor, scale, offset, qmin, qmax)
+
+        if ret is not None:
+            return ret
+
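+    # Fall through to the generic implementation below when the
+    # torch.fake_quantize fast path is not applicable or declined the input.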
    output_dtype = internal_dtype = tensor.dtype

    if not _is_numerically_stable(internal_dtype, qmin, qmax):
@@ -190,6 +216,60 @@ def quantize_dequantize(tensor: torch.Tensor, scale: torch.Tensor, offset: torch
                             offset.to(internal_dtype),
                             qmin, qmax).to(output_dtype).view(orig_tensor_shape)

+
+def _torch_fake_quantize(tensor: torch.Tensor,
+                         scale: torch.Tensor,
+                         offset: torch.Tensor,
+                         qmin: int,
+                         qmax: int) -> Optional[torch.Tensor]:
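+    # Dispatch to torch.fake_quantize_per_tensor_affine or
+    # torch.fake_quantize_per_channel_affine when the scale/offset layout allows
+    # it; return None to signal that the caller should take the generic path.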
+    scale_internal_dtype = torch.float32
+    tensor_internal_dtype = tensor.dtype
+
+    if _torch_version < (2, 6, 0) and tensor_internal_dtype == torch.bfloat16:
+        # torch.fake_quantize only supports bfloat16 in >=2.6.0
+        tensor_internal_dtype = torch.float32
+
+    is_per_tensor = scale.numel() == offset.numel() == 1
+
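+    # torch's fake-quantize ops take a zero_point argument; this backend's
+    # offset carries the opposite sign, so `-offset` is passed below.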
+    if is_per_tensor:
+        return torch.fake_quantize_per_tensor_affine(tensor.to(tensor_internal_dtype),
+                                                     scale.view(()).to(scale_internal_dtype),
+                                                     -offset.to(torch.int32).view(()),
+                                                     qmin, qmax).to(tensor.dtype)
+
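+    # Right-align scale/offset with the tensor by prepending singleton
+    # dimensions, so the channel axis can be found by comparing shapes.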
+    scale = scale.view(*(1 for _ in range(tensor.dim() - scale.dim())),
+                       *scale.shape)
+    offset = offset.view(*(1 for _ in range(tensor.dim() - offset.dim())),
+                         *offset.shape)
+
+    is_per_channel = scale.shape == offset.shape and all(
+        scale_dim in (1, tensor_dim)
+        for scale_dim, tensor_dim
+        in zip(scale.shape, tensor.shape)
+    )
+
+    if is_per_channel:
+        axes = [
+            axis for axis, scale_dim in enumerate(scale.shape) if scale_dim != 1
+        ]
+        assert axes
+
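+        # torch.fake_quantize_per_channel_affine only takes a single channel
+        # axis; if scale varies along more than one axis, give up and return None.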
+        if len(axes) == 1:
+            axis, = axes
+            try:
+                return torch.fake_quantize_per_channel_affine(tensor.to(tensor_internal_dtype),
+                                                              scale.flatten().to(scale_internal_dtype),
+                                                              -offset.to(torch.int32).flatten(),
+                                                              axis, qmin, qmax).to(tensor.dtype)
+            except RuntimeError:
+                # NOTE: torch.fake_quantize_per_channel_affine throws a RuntimeError
+                # if zero_point is not in [qmin, qmax]. In practice, this error will
+                # almost never occur because per-channel quantization always uses zero_point=0
+                return None
+
+    return None
+
+
@_onnx.register_symbolic(_onnx.dequantize_symbolic)
def dequantize(tensor: torch.Tensor, scale: torch.Tensor, offset: torch.Tensor, block_size: Optional[List] = None) \
        -> torch.Tensor: