Commit 5962a36

Kyunggeun Lee (quic-kyunggeu) authored and committed
Pull up compute_encodings and get_scale/offset/min/max to AffineQuantizerBase
Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
Co-authored-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
1 parent aaec17f commit 5962a36
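
For context, the getters being pulled up derive scale and offset from the trainable min/max range. Below is a standalone sketch of that math, mirroring the new base-class implementations in the diff that follows (plain torch, not AIMET code; the real code uses a straight-through-estimator round so gradients flow through the rounding):

    import torch

    def derive_scale_offset(min_: torch.Tensor, max_: torch.Tensor,
                            qmin: int, qmax: int, symmetric: bool):
        # scale spreads the min..max range over the integer grid qmin..qmax
        num_steps = qmax - qmin
        scale = (max_ - min_) / num_steps
        if symmetric:
            # symmetric quantizers pin the offset to the middle of the grid
            offset = torch.full_like(min_, -round((qmin + qmax) / 2))
        else:
            offset = torch.round(min_ / scale) - qmin
        return scale, offset

    # get_min()/get_max() then return the range implied by scale and offset:
    #   min = scale * (offset + qmin),  max = scale * (offset + qmax)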

File tree

1 file changed: +67 -122 lines changed
  • TrainingExtensions/torch/src/python/aimet_torch/v2/quantization/affine


TrainingExtensions/torch/src/python/aimet_torch/v2/quantization/affine/quantizer.py

Lines changed: 67 additions & 122 deletions
@@ -37,7 +37,6 @@
 # pylint: disable=redefined-builtin
 """ Affine quantizers """
 
-import abc
 from itertools import chain, repeat
 from typing import Optional, List, Dict, Tuple, overload
 import contextlib
@@ -160,67 +159,92 @@ def __init__(self, shape, *args, **kwargs):
             raise RuntimeError(f'Encoding analyzer of shape {self.encoding_analyzer.observer.shape} '
                                f'is incompatible with quantizer of shape {self.shape}.')
 
-    @abc.abstractmethod
-    def get_min(self, dtype=None) -> torch.Tensor:
+        self.register_quantization_parameter('min', nn.Parameter(-torch.ones(self.shape)))
+        self.register_quantization_parameter('max', nn.Parameter(torch.ones(self.shape)))
+
+    def get_min(self, dtype=None) -> Optional[torch.Tensor]:
         """
         Compute quantization min to be used for forward pass.
-        Return None f the quantizer is not initialized yet.
 
-        Args:
-            dtype (torch.dtype): dtype of the computed min
-
-        Returns:
-            Quantization min
+        NOTE: self.min may not be equal to self.get_min().
+              self.get_min() returns slightly recalibrated version of self.min.
 
+        :param dtype: dtype of the computed min. Use of self.min.dtype by default.
+        :return: Quantization min
         """
+        if not self.is_initialized():
+            return None
+        return self.get_scale(dtype) * (self.get_offset(dtype) + self.qmin)
 
-    @abc.abstractmethod
-    def get_max(self, dtype=None) -> torch.Tensor:
+    def get_max(self, dtype=None) -> Optional[torch.Tensor]:
         """
         Compute quantization max to be used for forward pass.
-        Return None f the quantizer is not initialized yet.
-
-        Args:
-            dtype (torch.dtype): dtype of the computed max
 
-        Returns:
-            Quantization max
+        NOTE: self.max may not be equal to self.get_max()
+              self.get_max() returns slightly recalibrated version of self.max.
 
+        :param dtype: dtype of the computed max. Use of self.min.dtype by default.
+        :return: Quantization max
         """
+        if not self.is_initialized():
+            return None
+        return self.get_scale(dtype) * (self.get_offset(dtype) + self.qmax)
 
-    @abc.abstractmethod
-    def get_scale(self, dtype=None) -> torch.Tensor:
+
+    def get_scale(self, dtype=None) -> Optional[torch.Tensor]:
         """
         Compute quantization scale to be used for forward pass.
-        Return None f the quantizer is not initialized yet.
+        Return None if the quantizer is not initialized yet.
 
         Args:
             dtype (torch.dtype): dtype of the computed scale
 
         Returns:
             Quantization scale
-
         """
+        if not self.is_initialized():
+            return None
+
+        dtype = dtype or torch.float32
+        num_steps = self.qmax - self.qmin
+
+        scale = (self.max.to(dtype) - self.min.to(dtype)) / num_steps
+        return scale.to(dtype)
 
-    @abc.abstractmethod
-    def get_offset(self, dtype=None) -> torch.Tensor:
+    def get_offset(self, dtype=None) -> Optional[torch.Tensor]:
         """
         Compute quantization offset to be used for forward pass.
-        Return None f the quantizer is not initialized yet.
+        Return None if the quantizer is not initialized yet.
 
         Args:
             dtype (torch.dtype): dtype of the computed offset
 
         Returns:
             Quantization offset
-
         """
+        if not self.is_initialized():
+            return None
+
+        dtype = dtype or torch.float32
+
+        if self.symmetric:
+            offset = torch.full_like(self.min,
+                                     fill_value=-round((self.qmin + self.qmax) / 2),
+                                     requires_grad=False,
+                                     dtype=dtype)
+        else:
+            offset = ste_round(self.min.to(dtype) / self.get_scale(dtype)) - self.qmin
+
+        return offset.to(dtype)
 
-    @abc.abstractmethod
+    @torch.no_grad()
     def set_range(self, min: torch.Tensor, max: torch.Tensor):
         """
         Set quantization parameters to the given min-max range
         """
+        with SafeGatheredParameters(self.parameters(recurse=False), modifier_rank=0):
+            self.min.copy_(min)
+            self.max.copy_(max)
 
     def get_encodings(self) -> Optional[AffineEncoding]:
         """
@@ -333,21 +357,6 @@ def signed(self) -> bool: # pylint: disable=missing-function-docstring
     def signed(self, signed: bool):
         self._set_signed(signed)
 
-
-class MinMaxQuantizer(AffineQuantizerBase): # pylint: disable=abstract-method
-    """
-    Affine quantizer with min-max as trainable parameters
-    """
-
-    min: torch.nn.Parameter
-    max: torch.nn.Parameter
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-        self.register_quantization_parameter('min', nn.Parameter(-torch.ones(self.shape)))
-        self.register_quantization_parameter('max', nn.Parameter(torch.ones(self.shape)))
-
     @contextlib.contextmanager
     def compute_encodings(self):
         """
@@ -360,24 +369,28 @@ def compute_encodings(self):
             return
 
         original_forward = self.forward
+        shape = self.shape
+
+        try:
+            dtype, device = next((p.dtype, p.device) for p in self.parameters())
+        except StopIteration as e:
+            raise RuntimeError from e
 
         @functools.wraps(original_forward)
         def forward_wrapper(input):
             input = input.as_subclass(torch.Tensor)
-            expanded_input = torch_builtins.reshape_tensor_for_blocks(input, self.shape, self.block_size)
+            expanded_input = torch_builtins.reshape_tensor_for_blocks(input, shape, self.block_size)
             batch_statistics = self.encoding_analyzer.update_stats(expanded_input)
             num_steps = self.qmax - self.qmin
             dynamic_min, dynamic_max =\
                 self.encoding_analyzer.compute_encodings_from_stats(batch_statistics,
                                                                     num_steps,
                                                                     self.symmetric)
             if self.block_size is not None:
-                dynamic_min = dynamic_min.view(self.min.shape)
-                dynamic_max = dynamic_max.view(self.max.shape)
-            dynamic_min = dynamic_min.to(dtype=self.min.dtype,
-                                         device=self.min.device).expand_as(self.min)
-            dynamic_max = dynamic_max.to(dtype=self.max.dtype,
-                                         device=self.max.device).expand_as(self.max)
+                dynamic_min = dynamic_min.view(shape)
+                dynamic_max = dynamic_max.view(shape)
+            dynamic_min = dynamic_min.to(dtype=dtype, device=device).expand(shape)
+            dynamic_max = dynamic_max.to(dtype=dtype, device=device).expand(shape)
 
             with patch_attr(self, 'min', dynamic_min),\
                  patch_attr(self, 'max', dynamic_max):
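
Beyond swapping self.shape for a captured shape, this hunk hoists the dtype/device lookup out of the per-batch wrapper: both are read once from the module's parameters, with an explicit error if there are none. The same pattern in isolation (generic sketch, not AIMET code):

    import torch
    from torch import nn

    module = nn.Linear(4, 4)

    # Read dtype/device once up front; fail loudly for parameter-less modules.
    try:
        dtype, device = next((p.dtype, p.device) for p in module.parameters())
    except StopIteration as e:
        raise RuntimeError("module has no parameters to infer dtype/device from") from e

    x = torch.randn(2, 4).to(dtype=dtype, device=device)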
@@ -395,8 +408,8 @@ def forward_wrapper(input):
             num_steps = self.qmax - self.qmin
             enc_min, enc_max = self.encoding_analyzer.compute_encodings(num_steps, self.symmetric)
             if self.block_size is not None:
-                enc_min = enc_min.view(self.min.shape)
-                enc_max = enc_max.view(self.max.shape)
+                enc_min = enc_min.view(shape)
+                enc_max = enc_max.view(shape)
             _flag_extreme_min_max(enc_min, enc_max)
 
         except StatisticsNotFoundError:
@@ -407,79 +420,11 @@ def forward_wrapper(input):
 
         self.set_range(enc_min, enc_max)
 
-    def get_min(self, dtype=None) -> Optional[torch.Tensor]:
-        """
-        Compute quantization min to be used for forward pass.
-
-        NOTE: self.min may not be equal to self.get_min().
-              self.get_min() returns slightly recalibrated version of self.min.
-
-        :param dtype: dtype of the computed min. Use of self.min.dtype by default.
-        :return: Quantization min
-        """
-        if not self.is_initialized():
-            return None
-        return self.get_scale(dtype) * (self.get_offset(dtype) + self.qmin)
-
-    def get_max(self, dtype=None) -> Optional[torch.Tensor]:
-        """
-        Compute quantization max to be used for forward pass.
-
-        NOTE: self.max may not be equal to self.get_max()
-              self.get_max() returns slightly recalibrated version of self.max.
-
-        :param dtype: dtype of the computed max. Use of self.min.dtype by default.
-        :return: Quantization max
-        """
-        if not self.is_initialized():
-            return None
-        return self.get_scale(dtype) * (self.get_offset(dtype) + self.qmax)
-
-    def get_scale(self, dtype=None) -> Optional[torch.Tensor]:
-        """
-        Compute quantization scale to be used for forward pass.
-
-        :param dtype: dtype of the computed scale. Use of self.min.dtype by default.
-        :return: Quantization scale
-        """
-        if not self.is_initialized():
-            return None
-
-        dtype = dtype or torch.float32
-        num_steps = self.qmax - self.qmin
-
-        scale = (self.max.to(dtype) - self.min.to(dtype)) / num_steps
-        return scale.to(dtype)
-
-    def get_offset(self, dtype=None) -> Optional[torch.Tensor]:
-        """
-        Compute quantization offset to be used for forward pass.
 
-        :param dtype: dtype of the computed offset. Use of self.min.dtype by default.
-        :return: Quantization offset
-        """
-        if not self.is_initialized():
-            return None
-
-        dtype = dtype or torch.float32
-
-        if self.symmetric:
-            offset = torch.full_like(self.min,
-                                     fill_value=-round((self.qmin + self.qmax) / 2),
-                                     requires_grad=False,
-                                     dtype=dtype)
-        else:
-            offset = ste_round(self.min.to(dtype) / self.get_scale(dtype)) - self.qmin
-
-        return offset.to(dtype)
-
-    def set_range(self, min: torch.Tensor, max: torch.Tensor):
-        """
-        Set quantization parameters to the given min-max range
-        """
-        with torch.no_grad(), SafeGatheredParameters(self.parameters(recurse=False), modifier_rank=0):
-            self.min.copy_(min)
-            self.max.copy_(max)
+class MinMaxQuantizer(AffineQuantizerBase): # pylint: disable=abstract-method
+    """
+    Affine quantizer with min-max as trainable parameters
+    """
 
 
 class Quantize(MinMaxQuantizer):
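
Since compute_encodings (and set_range, which it calls on exit) now live on AffineQuantizerBase, the usual calibration flow works for any affine quantizer. A hedged usage sketch, with the constructor arguments assumed from the public v2 API:

    import torch
    from aimet_torch.v2.quantization.affine import Quantize  # assumed import path

    q = Quantize(shape=(1,), bitwidth=8, symmetric=False)    # constructor args assumed

    # Inside the context, forward passes feed the encoding analyzer's statistics
    # and run with dynamically computed min/max; on exit, set_range() installs
    # the calibrated range.
    with q.compute_encodings():
        _ = q(torch.randn(32, 16))

    out = q(torch.randn(32, 16))  # quantizes with the calibrated min/max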
