
Commit 557213d

Fix nullptr error in FloatEncoding (#4753)

Lee, Kyunggeun (quic-kyunggeu) authored and committed

Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
Co-authored-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>

1 parent 6fef4ca · commit 557213d

2 files changed: +51 −7 lines changed

TrainingExtensions/torch/src/python/aimet_torch/v2/quantization/float/encoding.py

Lines changed: 14 additions & 6 deletions

```diff
@@ -37,7 +37,7 @@
 # pylint: disable=redefined-builtin
 """ Float encoding definition """
 
-from typing import Union, List, Dict
+from typing import Union, List, Dict, Optional
 import torch
 from torch._C._nn import _parse_to as parse_to_args
 
@@ -52,7 +52,7 @@ class FloatEncoding(EncodingBase):
     """
     Encoding object for float quantization
     """
-    def __init__(self, mantissa_bits: int, exponent_bits: int, maxval: torch.Tensor):
+    def __init__(self, mantissa_bits: int, exponent_bits: int, maxval: Optional[torch.Tensor]):
         self._mantissa_bits = mantissa_bits
         self._exponent_bits = exponent_bits
         self._maxval = maxval
@@ -109,19 +109,27 @@ def to(self, *args, **kwargs):
         Changes dtype of data in quantizer encoding or device where the data is.
         Behaves similar to torch.Tensor.to
         """
+        if self._maxval is None:
+            return self
+
+        current_dtype = self._maxval.dtype
+        current_device = self._maxval.device
+
         to_args = parse_to_args(*args, **kwargs)
         device, dtype, _, _ = to_args
-        dtype = dtype if dtype else self._maxval.dtype
-        device = device if device else self._maxval.device
 
-        if dtype is self._maxval.dtype and device is self._maxval.device:
+        dtype = dtype or current_dtype
+        device = device or current_device
+
+        if dtype == current_dtype and device == current_device:
             return self
 
-        if not dtype.is_floating_point:
+        if dtype and not dtype.is_floating_point:
             raise RuntimeError(f"Cannot change encoding data dtype to {dtype}, "
                                "only floating point data types are supported")
 
         maxval = self._maxval.to(dtype=dtype, device=device)
+
         return type(self)(self._mantissa_bits, self._exponent_bits, maxval)
 
     def quantize(self, input: torch.Tensor) -> torch.Tensor:
```
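
For context, the crash this commit fixes: with `maxval=None` (now allowed by the `Optional[torch.Tensor]` signature), the old `to()` dereferenced `self._maxval.dtype` on its first use and raised an `AttributeError` on `None`, the "nullptr error" of the commit title. Below is a minimal standalone sketch of the fixed control flow; the class name `FloatEncodingSketch` is hypothetical, and the logic simply mirrors the diff above.

```python
# Standalone sketch (hypothetical class name; logic mirrors the diff above)
# of the fixed FloatEncoding.to() control flow.
from typing import Optional

import torch
from torch._C._nn import _parse_to as parse_to_args


class FloatEncodingSketch:
    def __init__(self, mantissa_bits: int, exponent_bits: int,
                 maxval: Optional[torch.Tensor]):
        self._mantissa_bits = mantissa_bits
        self._exponent_bits = exponent_bits
        self._maxval = maxval

    def to(self, *args, **kwargs):
        # The fix: with no maxval there is nothing to move or cast, so
        # return self before any attribute access on None.
        if self._maxval is None:
            return self

        current_dtype = self._maxval.dtype
        current_device = self._maxval.device

        # parse_to_args normalizes torch.Tensor.to()-style arguments into
        # a (device, dtype, non_blocking, memory_format) tuple.
        device, dtype, _, _ = parse_to_args(*args, **kwargs)
        dtype = dtype or current_dtype
        device = device or current_device

        if dtype == current_dtype and device == current_device:
            return self

        if not dtype.is_floating_point:
            raise RuntimeError(f"Cannot change encoding data dtype to {dtype}, "
                               "only floating point data types are supported")

        maxval = self._maxval.to(dtype=dtype, device=device)
        return type(self)(self._mantissa_bits, self._exponent_bits, maxval)


# With maxval=None, .to() is now a safe no-op instead of an AttributeError:
enc = FloatEncodingSketch(mantissa_bits=10, exponent_bits=5, maxval=None)
assert enc.to(device="cpu", dtype=torch.float16) is enc
```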

TrainingExtensions/torch/test/python/v2/quantization/float/test_float_quantizer.py

Lines changed: 37 additions & 1 deletion

```diff
@@ -43,7 +43,7 @@
 import numpy as np
 import warnings
 from aimet_torch.v2.quantization.encoding_analyzer import MinMaxEncodingAnalyzer
-from aimet_torch.v2.quantization.float import FloatQuantizeDequantize
+from aimet_torch.v2.quantization.float import FloatQuantizeDequantize, FloatEncoding
 from aimet_torch.v2.quantization.float.quantizer import _ieee_float_max_representable_value
 from aimet_torch.fp_quantization import fake_cast_to_ieee_float
 
@@ -199,3 +199,39 @@ def test_onnx_export():
     qdq = FloatQuantizeDequantize(dtype=torch.float16)
     with tempfile.TemporaryFile() as f:
         torch.onnx.export(qdq, torch.randn(10, 10), f)
+
+
+def test_float_encoding_to():
+    """
+    Given: FloatEncoding with maxval=None
+    When: Call .to()
+    Then: Should return identical object
+    """
+    encoding = FloatEncoding(exponent_bits=5, mantissa_bits=10, maxval=None)
+    new_encoding = encoding.to(device="cpu", dtype=torch.float16)
+    assert new_encoding is encoding
+
+    """
+    Given: FloatEncoding with maxval=torch.tensor(124.)
+    """
+    encoding = FloatEncoding(exponent_bits=5,
+                             mantissa_bits=10,
+                             maxval=torch.tensor(124.))
+    """
+    When: Call .to() with same dtype and device
+    Then: Should return identical object
+    """
+    new_encoding = encoding.to(device="cpu", dtype=torch.float32)
+    assert new_encoding is encoding
+
+    """
+    When: Call .to() with new dtype and device
+    Then: 1. New encoding object should be in proper dtype and device
+          2. Old encoding object should not be affected
+    """
+    new_encoding = encoding.to(device="cpu", dtype=torch.float16)
+    assert new_encoding.maxval.device == torch.device("cpu")
+    assert new_encoding.maxval.dtype == torch.float16
+
+    assert encoding.maxval.device == torch.device("cpu")
+    assert encoding.maxval.dtype == torch.float32
```
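
A side note on why the early-return check in encoding.py likely moved from `is` to `==`: torch.dtype objects are singletons, so identity comparison happened to work for dtypes, but torch.device instances generally compare equal by value while being distinct objects, so the old identity check could miss a match and force a needless copy of the encoding. A quick illustration, assuming stock PyTorch behavior:

```python
import torch

# torch.dtype objects are singletons, so the old identity check happened
# to work for dtypes:
assert torch.float32 is torch.float32

# torch.device instances compare equal by value but are, in general,
# distinct objects, so `device is other_device` could miss a match that
# `==` catches:
a, b = torch.device("cpu"), torch.device("cpu")
assert a == b
print(a is b)  # often False: equal values, different objects
```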
