
Commit a444995

Lee, Kyunggeun (quic-kyunggeu) authored and committed
Delegate quantization kernel to torch.fake_quantize
Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
Co-authored-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
1 parent 8acfe33 commit a444995
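
The commit message above describes delegating AIMET's quantize-dequantize kernel to PyTorch's built-in fake-quantization ops. As a minimal sketch of the idea (not AIMET code, and relying only on PyTorch's documented semantics for torch.fake_quantize_per_tensor_affine), the builtin computes (clamp(round(x / scale) + zero_point, qmin, qmax) - zero_point) * scale, which lines up with AIMET's convention when zero_point = -offset:

    import torch

    x = torch.randn(16)
    scale, offset, qmin, qmax = 0.1, -10, -128, 127   # illustrative values, not from the diff

    # AIMET-style quantize-dequantize: subtract the offset, clamp, add it back
    manual = (torch.round(x / scale) - offset).clamp(qmin, qmax).add(offset).mul(scale)

    # PyTorch builtin with zero_point = -offset
    builtin = torch.fake_quantize_per_tensor_affine(x, scale, -offset, qmin, qmax)

    assert torch.allclose(manual, builtin)

The diff below gates this fast path on per-tensor or per-channel parameters, the default rounding function, and no gradient requirement for scale/offset.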

File tree

2 files changed: +167 −3 lines


TrainingExtensions/torch/src/python/aimet_torch/v2/quantization/affine/backends/torch_builtins.py

Lines changed: 83 additions & 3 deletions
@@ -36,14 +36,18 @@
 # =============================================================================
 """ Default quantization backend for quantizing weights and activations """
 import functools
-from typing import Callable, Optional, List
+from packaging import version
+from typing import Callable, Optional, List, Tuple
 import torch
 from aimet_torch.v2.utils import _is_expandable, _ContextManager
 import aimet_torch.v2.experimental.onnx._export as _onnx
-from packaging import version


-if version.parse(torch.__version__) >= version.parse("2.0.0"):
+_torch_version: Tuple[int, int, int] = (version.parse(torch.__version__).major,
+                                        version.parse(torch.__version__).minor,
+                                        version.parse(torch.__version__).micro)
+
+if _torch_version >= (2, 0, 0):
     _compile = torch.compile
 else:
     _compile = lambda fn: fn
@@ -155,6 +159,8 @@ def quantize(tensor: torch.Tensor, scale: torch.Tensor, offset: torch.Tensor,



+_ALLOW_FAST_FORWARD = True # temporary flag for debugging
+
 @_onnx.register_symbolic(_onnx.quantize_dequantize_symbolic)
 def quantize_dequantize(tensor: torch.Tensor, scale: torch.Tensor, offset: torch.Tensor,
                         qmin: int, qmax: int, block_size: Optional[List] = None) -> torch.Tensor:
@@ -170,6 +176,26 @@ def quantize_dequantize(tensor: torch.Tensor, scale: torch.Tensor, offset: torch
     """
     _validate_arguments(tensor, scale, qmin, qmax, block_size)

+    _fast_forward = _ALLOW_FAST_FORWARD
+
+    # torch.fake_quantize doesn't support blockwise quantization
+    _fast_forward &= block_size is None
+
+    # torch.fake_quantize doesn't support JIT tracing
+    _fast_forward &= not torch.jit.is_tracing()
+
+    # torch.fake_quantize doesn't compute gradients for scale/offset
+    _fast_forward &= (not scale.requires_grad and not offset.requires_grad) or (not torch.is_grad_enabled())
+
+    # if user explicitly designated specific rounding function, honor it strictly
+    _fast_forward &= (_round_fn == torch.round and _round_fn_inplace == torch.round_)
+
+    if _fast_forward:
+        ret = _torch_fake_quantize(tensor, scale, offset, qmin, qmax)
+
+        if ret is not None:
+            return ret
+
     output_dtype = internal_dtype = tensor.dtype

     if not _is_numerically_stable(internal_dtype, qmin, qmax):
@@ -190,6 +216,60 @@ def quantize_dequantize(tensor: torch.Tensor, scale: torch.Tensor, offset: torch
                                 offset.to(internal_dtype),
                                 qmin, qmax).to(output_dtype).view(orig_tensor_shape)

+
+def _torch_fake_quantize(tensor: torch.Tensor,
+                         scale: torch.Tensor,
+                         offset: torch.Tensor,
+                         qmin: int,
+                         qmax: int) -> Optional[torch.Tensor]:
+    scale_internal_dtype = torch.float32
+    tensor_internal_dtype = tensor.dtype
+
+    if _torch_version < (2, 6, 0) and tensor_internal_dtype == torch.bfloat16:
+        # torch.fake_quantize only supports bfloat16 in >=2.6.0
+        tensor_internal_dtype = torch.float32
+
+    is_per_tensor = scale.numel() == offset.numel() == 1
+
+    if is_per_tensor:
+        return torch.fake_quantize_per_tensor_affine(tensor.to(tensor_internal_dtype),
+                                                     scale.view(()).to(scale_internal_dtype),
+                                                     -offset.to(torch.int32).view(()),
+                                                     qmin, qmax).to(tensor.dtype)
+
+    scale = scale.view(*(1 for _ in range(tensor.dim() - scale.dim())),
+                       *scale.shape)
+    offset = offset.view(*(1 for _ in range(tensor.dim() - offset.dim())),
+                         *offset.shape)
+
+    is_per_channel = scale.shape == offset.shape and all(
+        scale_dim in (1, tensor_dim)
+        for scale_dim, tensor_dim
+        in zip(scale.shape, tensor.shape)
+    )
+
+    if is_per_channel:
+        axes = [
+            axis for axis, scale_dim in enumerate(scale.shape) if scale_dim != 1
+        ]
+        assert axes
+
+        if len(axes) == 1:
+            axis, = axes
+            try:
+                return torch.fake_quantize_per_channel_affine(tensor.to(tensor_internal_dtype),
+                                                              scale.flatten().to(scale_internal_dtype),
+                                                              -offset.to(torch.int32).flatten(),
+                                                              axis, qmin, qmax).to(tensor.dtype)
+            except RuntimeError:
+                # NOTE: torch.fake_quantize_per_channel_affine throws runtime error
+                # if zero_point is not in [qmin, qmax]. In practice, this error will
+                # almost never occur because per-channel quantization always uses zero_point=0
+                return None
+
+    return None
+
+
 @_onnx.register_symbolic(_onnx.dequantize_symbolic)
 def dequantize(tensor: torch.Tensor, scale: torch.Tensor, offset: torch.Tensor, block_size: Optional[List] = None) \
         -> torch.Tensor:
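
A hedged sketch of how the per-channel branch above maps AIMET parameters onto torch.fake_quantize_per_channel_affine (the shapes and values here are assumptions for illustration, not taken from the diff): a scale of shape (C, 1, 1) has exactly one non-singleton axis, which becomes the channel axis, and the flattened negated offset becomes the zero_point vector.

    import torch

    x = torch.randn(4, 8, 8)
    scale = torch.full((4, 1, 1), 0.1)     # one scale per channel along axis 0
    offset = torch.zeros(4, 1, 1)          # zero_point = -offset, matching _torch_fake_quantize

    # same axis detection as in the diff: the single non-singleton axis is the channel axis
    axes = [axis for axis, dim in enumerate(scale.shape) if dim != 1]
    assert axes == [0]

    y = torch.fake_quantize_per_channel_affine(x,
                                               scale.flatten(),
                                               (-offset).to(torch.int32).flatten(),
                                               axes[0], -128, 127)

If more than one axis has a non-unit size, the new helper returns None instead, and the caller falls back to the original kernel.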

TrainingExtensions/torch/test/python/v2/quantization/affine/backends/test_backend.py

Lines changed: 84 additions & 0 deletions
@@ -737,3 +737,87 @@ def test_invalid_block_size(self, backend_module):
                                            block_size=[1, 3])
         backend_module._validate_arguments(torch.randn(1, 4), torch.randn(1, 2), torch.randn(1, 2),
                                            block_size=[1, 2])
+
+
+
+@pytest.mark.parametrize(
+    "qmin, qmax, offset", [
+        (-8, 7, 0),
+        (0, 15, 0),
+        (0, 15, 8),
+        (-128, 127, 0),
+        (0, 255, 0),
+        (0, 255, 128),
+        (-2**15, 2**15-1, 0),
+        (0, 2**16-1, 0),
+        (0, 2**16-1, 2**15),
+    ])
+@pytest.mark.parametrize("device", [
+    "cpu",
+    *(("cuda",) if torch.cuda.is_available() else ())
+])
+@pytest.mark.parametrize("dtype", [
+    torch.float32,
+    torch.float16,
+    torch.bfloat16,
+])
+def test_cross_validate_torch_fake_quantize(qmin, qmax, offset, dtype, device):
+    """
+    Given same inputs, the following three functions should always produce the same output
+      * quantize_dequantize
+      * QuantDequantFunc.apply
+      * _torch_fake_quantize
+    """
+    scale = torch.tensor([0.1], dtype=torch.float32, device=device)
+    offset = torch.tensor([offset], dtype=torch.float32, device=device)
+    tensor = scale * torch.tensor([
+        qmin - .5, qmin, qmin + .5, qmax - .5, qmax, qmax + .5
+    ], device=device)
+    tensor = tensor.to(dtype)
+
+    expected = tensor.to(torch.float32)\
+                     .div(scale)\
+                     .round()\
+                     .sub(offset)\
+                     .clamp(qmin, qmax)\
+                     .add(offset)\
+                     .mul(scale)\
+                     .to(dtype)
+
+    # Allow off-by-one error for float16 and bfloat16
+    atol = scale.item() if dtype in (torch.float16, torch.bfloat16) else 1e-8
+
+    out1 = torch_builtins.quantize_dequantize(tensor, scale, offset, qmin, qmax)
+    out2 = torch_builtins.QuantDequantFunc.apply(tensor, scale, offset, qmin, qmax).to(dtype)
+    out3 = torch_builtins._torch_fake_quantize(tensor, scale, offset, qmin, qmax)
+
+    assert torch.allclose(out1, expected, atol=atol)
+    assert torch.allclose(out2, expected, atol=atol)
+    if out3 is not None:
+        assert torch.allclose(out3, expected, atol=atol)
+
+    scale = torch.stack([
+        scale,
+        scale,
+    ])
+    offset = torch.stack([
+        offset,
+        offset,
+    ])
+    tensor = torch.stack([
+        tensor,
+        tensor
+    ])
+    expected = torch.stack([
+        expected,
+        expected,
+    ])
+
+    out1 = torch_builtins.quantize_dequantize(tensor, scale, offset, qmin, qmax)
+    out2 = torch_builtins.QuantDequantFunc.apply(tensor, scale, offset, qmin, qmax).to(dtype)
+    out3 = torch_builtins._torch_fake_quantize(tensor, scale, offset, qmin, qmax)
+
+    assert torch.allclose(out1, expected, atol=atol)
+    assert torch.allclose(out2, expected, atol=atol)
+    if out3 is not None:
+        assert torch.allclose(out3, expected, atol=atol)
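
To run just the new cross-validation test locally, an invocation along these lines should work (the path and test name come from the diff above; adjust to your environment):

    pytest TrainingExtensions/torch/test/python/v2/quantization/affine/backends/test_backend.py -k test_cross_validate_torch_fake_quantize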
