 # pylint: disable=redefined-builtin
 """ Affine quantizers """
 
-import abc
 from itertools import chain, repeat
 from typing import Optional, List, Dict, Tuple, overload
 import contextlib
@@ -160,67 +159,92 @@ def __init__(self, shape, *args, **kwargs):
             raise RuntimeError(f'Encoding analyzer of shape {self.encoding_analyzer.observer.shape} '
                                f'is incompatible with quantizer of shape {self.shape}.')
 
-    @abc.abstractmethod
-    def get_min(self, dtype=None) -> torch.Tensor:
+        self.register_quantization_parameter('min', nn.Parameter(-torch.ones(self.shape)))
+        self.register_quantization_parameter('max', nn.Parameter(torch.ones(self.shape)))
+
+    def get_min(self, dtype=None) -> Optional[torch.Tensor]:
         """
         Compute quantization min to be used for forward pass.
-        Return None f the quantizer is not initialized yet.
 
-        Args:
-            dtype (torch.dtype): dtype of the computed min
-
-        Returns:
-            Quantization min
+        NOTE: self.min may not be equal to self.get_min().
+              self.get_min() returns slightly recalibrated version of self.min.
 
+        :param dtype: dtype of the computed min. Use of self.min.dtype by default.
+        :return: Quantization min
         """
+        if not self.is_initialized():
+            return None
+        return self.get_scale(dtype) * (self.get_offset(dtype) + self.qmin)
 
-    @abc.abstractmethod
-    def get_max(self, dtype=None) -> torch.Tensor:
+    def get_max(self, dtype=None) -> Optional[torch.Tensor]:
         """
         Compute quantization max to be used for forward pass.
-        Return None f the quantizer is not initialized yet.
-
-        Args:
-            dtype (torch.dtype): dtype of the computed max
 
-        Returns:
-            Quantization max
+        NOTE: self.max may not be equal to self.get_max()
+              self.get_max() returns slightly recalibrated version of self.max.
 
+        :param dtype: dtype of the computed max. Use of self.min.dtype by default.
+        :return: Quantization max
         """
+        if not self.is_initialized():
+            return None
+        return self.get_scale(dtype) * (self.get_offset(dtype) + self.qmax)
 
-    @abc.abstractmethod
-    def get_scale(self, dtype=None) -> torch.Tensor:
+
+    def get_scale(self, dtype=None) -> Optional[torch.Tensor]:
         """
         Compute quantization scale to be used for forward pass.
-        Return None f the quantizer is not initialized yet.
+        Return None if the quantizer is not initialized yet.
 
         Args:
             dtype (torch.dtype): dtype of the computed scale
 
         Returns:
             Quantization scale
-
         """
+        if not self.is_initialized():
+            return None
+
+        dtype = dtype or torch.float32
+        num_steps = self.qmax - self.qmin
+
+        scale = (self.max.to(dtype) - self.min.to(dtype)) / num_steps
+        return scale.to(dtype)
 
-    @abc.abstractmethod
-    def get_offset(self, dtype=None) -> torch.Tensor:
+    def get_offset(self, dtype=None) -> Optional[torch.Tensor]:
         """
         Compute quantization offset to be used for forward pass.
-        Return None f the quantizer is not initialized yet.
+        Return None if the quantizer is not initialized yet.
 
         Args:
             dtype (torch.dtype): dtype of the computed offset
 
         Returns:
             Quantization offset
-
         """
+        if not self.is_initialized():
+            return None
+
+        dtype = dtype or torch.float32
+
+        if self.symmetric:
+            offset = torch.full_like(self.min,
+                                     fill_value=-round((self.qmin + self.qmax) / 2),
+                                     requires_grad=False,
+                                     dtype=dtype)
+        else:
+            offset = ste_round(self.min.to(dtype) / self.get_scale(dtype)) - self.qmin
+
+        return offset.to(dtype)
 
-    @abc.abstractmethod
+    @torch.no_grad()
     def set_range(self, min: torch.Tensor, max: torch.Tensor):
         """
         Set quantization parameters to the given min-max range
         """
+        with SafeGatheredParameters(self.parameters(recurse=False), modifier_rank=0):
+            self.min.copy_(min)
+            self.max.copy_(max)
 
     def get_encodings(self) -> Optional[AffineEncoding]:
         """
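For orientation, the arithmetic that the relocated get_scale/get_offset/get_min/get_max implement can be sketched standalone as below. This is an illustrative reconstruction, not quantizer API: the names affine_encoding, min_t and max_t are invented for the example, and plain torch.round stands in for the ste_round straight-through estimator used in the actual code.

import torch

def affine_encoding(min_t: torch.Tensor, max_t: torch.Tensor,
                    qmin: int, qmax: int, symmetric: bool):
    # Scale spreads the observed min-max range over the integer grid.
    num_steps = qmax - qmin
    scale = (max_t - min_t) / num_steps
    if symmetric:
        # Symmetric quantizers pin the offset to the middle of the integer grid.
        offset = torch.full_like(min_t, -round((qmin + qmax) / 2))
    else:
        # Asymmetric quantizers derive the offset from min (ste_round in the real code).
        offset = torch.round(min_t / scale) - qmin
    # get_min()/get_max() return these grid-aligned ("recalibrated") bounds,
    # which may differ slightly from the raw min/max parameters.
    return scale, offset, scale * (offset + qmin), scale * (offset + qmax)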
@@ -333,21 +357,6 @@ def signed(self) -> bool: # pylint: disable=missing-function-docstring
     def signed(self, signed: bool):
         self._set_signed(signed)
 
-
-class MinMaxQuantizer(AffineQuantizerBase):  # pylint: disable=abstract-method
-    """
-    Affine quantizer with min-max as trainable parameters
-    """
-
-    min: torch.nn.Parameter
-    max: torch.nn.Parameter
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-        self.register_quantization_parameter('min', nn.Parameter(-torch.ones(self.shape)))
-        self.register_quantization_parameter('max', nn.Parameter(torch.ones(self.shape)))
-
     @contextlib.contextmanager
     def compute_encodings(self):
         """
@@ -360,24 +369,28 @@ def compute_encodings(self):
             return
 
         original_forward = self.forward
+        shape = self.shape
+
+        try:
+            dtype, device = next((p.dtype, p.device) for p in self.parameters())
+        except StopIteration as e:
+            raise RuntimeError from e
 
         @functools.wraps(original_forward)
         def forward_wrapper(input):
             input = input.as_subclass(torch.Tensor)
-            expanded_input = torch_builtins.reshape_tensor_for_blocks(input, self.shape, self.block_size)
+            expanded_input = torch_builtins.reshape_tensor_for_blocks(input, shape, self.block_size)
             batch_statistics = self.encoding_analyzer.update_stats(expanded_input)
             num_steps = self.qmax - self.qmin
             dynamic_min, dynamic_max = \
                 self.encoding_analyzer.compute_encodings_from_stats(batch_statistics,
                                                                     num_steps,
                                                                     self.symmetric)
             if self.block_size is not None:
-                dynamic_min = dynamic_min.view(self.min.shape)
-                dynamic_max = dynamic_max.view(self.max.shape)
-            dynamic_min = dynamic_min.to(dtype=self.min.dtype,
-                                         device=self.min.device).expand_as(self.min)
-            dynamic_max = dynamic_max.to(dtype=self.max.dtype,
-                                         device=self.max.device).expand_as(self.max)
+                dynamic_min = dynamic_min.view(shape)
+                dynamic_max = dynamic_max.view(shape)
+            dynamic_min = dynamic_min.to(dtype=dtype, device=device).expand(shape)
+            dynamic_max = dynamic_max.to(dtype=dtype, device=device).expand(shape)
 
             with patch_attr(self, 'min', dynamic_min),\
                  patch_attr(self, 'max', dynamic_max):
@@ -395,8 +408,8 @@ def forward_wrapper(input):
             num_steps = self.qmax - self.qmin
             enc_min, enc_max = self.encoding_analyzer.compute_encodings(num_steps, self.symmetric)
             if self.block_size is not None:
-                enc_min = enc_min.view(self.min.shape)
-                enc_max = enc_max.view(self.max.shape)
+                enc_min = enc_min.view(shape)
+                enc_max = enc_max.view(shape)
             _flag_extreme_min_max(enc_min, enc_max)
 
         except StatisticsNotFoundError:
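Taken together, the hunks above describe the calibration flow: inside the compute_encodings() context, forward passes feed the encoding analyzer, and on exit the computed range is written back through set_range(). A hedged usage sketch follows; the import path and the Quantize constructor arguments are assumptions for illustration rather than facts from this diff, and calibration_batches is a placeholder iterable of tensors.

import torch
from aimet_torch.v2.quantization.affine import Quantize  # assumed import path

q = Quantize(shape=(1,), bitwidth=8, symmetric=False)  # assumed constructor signature
with q.compute_encodings():            # wraps forward to observe dynamic min/max
    for batch in calibration_batches:  # placeholder: your calibration data
        _ = q(batch)
# On exit, set_range(enc_min, enc_max) has been called, so the quantizer is
# initialized and get_scale()/get_offset() return the calibrated encoding.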
@@ -407,79 +420,11 @@ def forward_wrapper(input):
 
         self.set_range(enc_min, enc_max)
 
-    def get_min(self, dtype=None) -> Optional[torch.Tensor]:
-        """
-        Compute quantization min to be used for forward pass.
-
-        NOTE: self.min may not be equal to self.get_min().
-              self.get_min() returns slightly recalibrated version of self.min.
-
-        :param dtype: dtype of the computed min. Use of self.min.dtype by default.
-        :return: Quantization min
-        """
-        if not self.is_initialized():
-            return None
-        return self.get_scale(dtype) * (self.get_offset(dtype) + self.qmin)
-
-    def get_max(self, dtype=None) -> Optional[torch.Tensor]:
-        """
-        Compute quantization max to be used for forward pass.
-
-        NOTE: self.max may not be equal to self.get_max()
-              self.get_max() returns slightly recalibrated version of self.max.
-
-        :param dtype: dtype of the computed max. Use of self.min.dtype by default.
-        :return: Quantization max
-        """
-        if not self.is_initialized():
-            return None
-        return self.get_scale(dtype) * (self.get_offset(dtype) + self.qmax)
-
-    def get_scale(self, dtype=None) -> Optional[torch.Tensor]:
-        """
-        Compute quantization scale to be used for forward pass.
-
-        :param dtype: dtype of the computed scale. Use of self.min.dtype by default.
-        :return: Quantization scale
-        """
-        if not self.is_initialized():
-            return None
-
-        dtype = dtype or torch.float32
-        num_steps = self.qmax - self.qmin
-
-        scale = (self.max.to(dtype) - self.min.to(dtype)) / num_steps
-        return scale.to(dtype)
-
-    def get_offset(self, dtype=None) -> Optional[torch.Tensor]:
-        """
-        Compute quantization offset to be used for forward pass.
 
-        :param dtype: dtype of the computed offset. Use of self.min.dtype by default.
-        :return: Quantization offset
-        """
-        if not self.is_initialized():
-            return None
-
-        dtype = dtype or torch.float32
-
-        if self.symmetric:
-            offset = torch.full_like(self.min,
-                                     fill_value=-round((self.qmin + self.qmax) / 2),
-                                     requires_grad=False,
-                                     dtype=dtype)
-        else:
-            offset = ste_round(self.min.to(dtype) / self.get_scale(dtype)) - self.qmin
-
-        return offset.to(dtype)
-
-    def set_range(self, min: torch.Tensor, max: torch.Tensor):
-        """
-        Set quantization parameters to the given min-max range
-        """
-        with torch.no_grad(), SafeGatheredParameters(self.parameters(recurse=False), modifier_rank=0):
-            self.min.copy_(min)
-            self.max.copy_(max)
+class MinMaxQuantizer(AffineQuantizerBase):  # pylint: disable=abstract-method
+    """
+    Affine quantizer with min-max as trainable parameters
+    """
 
 
 class Quantize(MinMaxQuantizer):