
Commit ab7b19d

Lee, Kyunggeun (quic-kyunggeu) authored and committed
Add fold_param_quantizers as a public feature
Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
Co-authored-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
1 parent a444995 commit ab7b19d

7 files changed: +212 −6 lines changed

TrainingExtensions/torch/src/python/aimet_torch/v2/nn/base.py

Lines changed: 65 additions & 4 deletions
@@ -46,14 +46,19 @@
 from torch import nn
 
 from aimet_torch.utils import is_vector_encoding
-from aimet_torch.v2.quantization.affine.encoding import VectorEncoding, AffineEncoding
+from aimet_torch.v2.quantization.affine.encoding import (
+    AffineEncoding,
+    GroupedBlockEncoding,
+    VectorEncoding,
+)
 from aimet_torch.v2.quantization.affine import (
     AffineQuantizerBase,
     GroupedBlockQuantizeDequantize,
     QuantizeDequantize,
 )
+from aimet_torch.v2.quantization.float import FloatEncoding, FloatQuantizeDequantize
 
-from aimet_torch.v2.quantization.tensor import QuantizedTensorBase
+from aimet_torch.v2.quantization.tensor import QuantizedTensorBase, DequantizedTensor
 from aimet_torch.v2.quantization.base import QuantizerBase
 from aimet_torch.v2.utils import (
     patch_attr,
@@ -750,10 +755,39 @@ def _create_int32_bias_quantizer(self, input, _): # pylint: disable=redefined-bu
     def _derive_bias_scale(self, input_scale: Optional[torch.Tensor], weight_scale: Optional[torch.Tensor]):
         raise NotImplementedError
 
-    def _fold_param_quantizers(self):
+    def fold_param_quantizers(self):
         """
-        Fold param quantizers into parameters to speed up inference.
+        Fold parameter quantizers into their associated parameters to accelerate inference.
+
+        Example:
+
+            >>> qlinear = QuantizedLinear(10, 10)
+            >>> qlinear.param_quantizers["weight"] = QuantizeDequantize((), -128, 127, symmetric=True)
+            >>> type(qlinear.weight)
+            <class 'torch.nn.parameter.Parameter'>
+            >>> qlinear
+            QuantizedLinear(
+              in_features=10, out_features=10, bias=True
+              (param_quantizers): ModuleDict(
+                (weight): QuantizeDequantize(shape=(), qmin=-128, qmax=127, symmetric=True)
+                (bias): None
+              )
+            )
+            >>> qlinear.fold_param_quantizers()
+            >>> type(qlinear.weight)
+            <class 'aimet_torch.v2.quantization.tensor.DequantizedTensor'>
+            >>> qlinear
+            QuantizedLinear(
+              in_features=10, out_features=10, bias=True
+              (param_quantizers): ModuleDict(
+                (weight): None
+                (bias): None
+              )
+            )
         """
+        return self._fold_param_quantizers()
+
+    def _fold_param_quantizers(self):
         self._compute_param_encodings(overwrite=False)
 
         for param_name, param_qtzr in self.param_quantizers.items():
@@ -765,6 +799,33 @@ def _fold_param_quantizers(self):
             setattr(self, param_name, torch.nn.Parameter(qdq_param, requires_grad=param.requires_grad))
             self.param_quantizers[param_name] = None
 
+    def _unfold_param_quantizers(self):
+        """
+        Re-instantiate param quantizers for ease of export
+        """
+        for param_name, qdq_param in self.named_parameters():
+            if not isinstance(qdq_param, DequantizedTensor):
+                continue
+
+            if qdq_param.encoding is None:
+                continue
+
+            if isinstance(qdq_param.encoding, GroupedBlockEncoding):
+                param_qtzr = GroupedBlockQuantizeDequantize.from_encodings(qdq_param.encoding)
+            elif isinstance(qdq_param.encoding, AffineEncoding):
+                param_qtzr = QuantizeDequantize.from_encodings(qdq_param.encoding)
+            elif isinstance(qdq_param.encoding, FloatEncoding):
+                param_qtzr = FloatQuantizeDequantize.from_encodings(qdq_param.encoding)
+            else:
+                raise ValueError
+
+            if not param_qtzr:
+                continue
+
+            param = qdq_param.as_subclass(torch.Tensor)
+            setattr(self, param_name, torch.nn.Parameter(param, requires_grad=param.requires_grad))
+            self.param_quantizers[param_name] = param_qtzr
+
 
 def _remove_quantizers(quantizers, keys):
     orig_quantizers = {key: quantizers[key] for key in keys}
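
A minimal usage sketch of the new public method on a single quantized layer, following the docstring example above; the import paths for QuantizedLinear and DequantizedTensor are assumptions inferred from this diff rather than part of the commit:

# Fold the weight quantizer of a standalone quantized layer, mirroring the
# docstring example above. Assumes aimet_torch.v2 is installed.
import torch
from aimet_torch.v2.nn import QuantizedLinear
from aimet_torch.v2.quantization.affine import QuantizeDequantize
from aimet_torch.v2.quantization.tensor import DequantizedTensor

qlinear = QuantizedLinear(10, 10)
qlinear.param_quantizers["weight"] = QuantizeDequantize((), -128, 127, symmetric=True)

# fold_param_quantizers() computes the weight encoding if needed, replaces the
# weight with its quantize-dequantized value, and clears the quantizer slot.
qlinear.fold_param_quantizers()

assert qlinear.param_quantizers["weight"] is None
assert isinstance(qlinear.weight, DequantizedTensor)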

TrainingExtensions/torch/src/python/aimet_torch/v2/quantization/affine/quantizer.py

Lines changed: 31 additions & 0 deletions
@@ -232,6 +232,21 @@ def get_encodings(self) -> Optional[AffineEncoding]:
                                   self.qmin, self.qmax, self._symmetric, self.block_size)
         return None
 
+    @classmethod
+    def from_encodings(cls, encodings: AffineEncoding) -> "AffineQuantizerBase":
+        if not isinstance(encodings, AffineEncoding):
+            raise TypeError(f"Expected {AffineEncoding}; got {type(encodings)}")
+
+        qtzr = cls(shape=encodings.scale.shape,
+                   qmin=encodings.qmin,
+                   qmax=encodings.qmax,
+                   symmetric=encodings.symmetry,
+                   block_size=encodings.block_size)
+
+        qtzr.set_range(encodings.min, encodings.max)
+
+        return qtzr
+
     @torch.no_grad()
     def get_legacy_encodings(self) -> Optional[List[Dict]]:
         """
@@ -863,3 +878,19 @@ def get_encodings(self) -> Optional[GroupedBlockEncoding]:
                                    decompressed_bw=self.decompressed_bw,
                                    per_channel_scale=per_channel_scale)
         return None
+
+    @classmethod
+    def from_encodings(cls, encodings: GroupedBlockEncoding) -> "GroupedBlockQuantizeDequantize":
+        if not isinstance(encodings, GroupedBlockEncoding):
+            raise TypeError(f"Expected {GroupedBlockEncoding}; got {type(encodings)}")
+
+        qtzr = cls(shape=encodings.scale.shape,
+                   bitwidth=encodings.bitwidth,
+                   symmetric=encodings.symmetry,
+                   decompressed_bw=encodings.decompressed_bw,
+                   block_size=encodings.block_size,
+                   block_grouping=encodings.block_grouping)
+
+        qtzr.set_range(encodings.min, encodings.max)
+
+        return qtzr
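
A hedged round-trip sketch of the new from_encodings() classmethod, assuming the QuantizeDequantize constructor and the set_range()/get_encodings()/get_min()/get_max() calls used elsewhere in this commit:

# Serialize a calibrated affine quantizer to an AffineEncoding, then rebuild
# an equivalent quantizer from that encoding (new in this commit).
import torch
from aimet_torch.v2.quantization.affine import QuantizeDequantize

qtzr = QuantizeDequantize(shape=(), qmin=-128, qmax=127, symmetric=True)
qtzr.set_range(torch.tensor(-1.0), torch.tensor(1.0))   # stand-in calibrated range

encoding = qtzr.get_encodings()                          # AffineEncoding
restored = QuantizeDequantize.from_encodings(encoding)

assert torch.allclose(qtzr.get_min(), restored.get_min())
assert torch.allclose(qtzr.get_max(), restored.get_max())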

TrainingExtensions/torch/src/python/aimet_torch/v2/quantization/base/quantizer.py

Lines changed: 7 additions & 0 deletions
@@ -124,6 +124,13 @@ def set_encodings(self, encodings: EncodingBase):
         """
         raise NotImplementedError
 
+    @classmethod
+    @abc.abstractmethod
+    def from_encodings(cls, encodings: EncodingBase) -> "QuantizerBase":
+        """
+        Create quantizer object from encoding object
+        """
+
     def register_quantization_parameter(self, name: str, param: nn.Parameter):
         """
         Register quantization parameter.
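
For illustration only, a hypothetical helper (not part of this commit) showing how the abstract from_encodings() contract is consumed, mirroring the type dispatch in BaseQuantizationMixin._unfold_param_quantizers above:

# Rebuild the matching concrete quantizer from an encoding object.
from aimet_torch.v2.quantization.affine import QuantizeDequantize
from aimet_torch.v2.quantization.affine.encoding import AffineEncoding
from aimet_torch.v2.quantization.float import FloatEncoding, FloatQuantizeDequantize

def rebuild_quantizer(encoding):
    """Illustrative dispatch by encoding type; not part of this commit."""
    if isinstance(encoding, AffineEncoding):
        return QuantizeDequantize.from_encodings(encoding)
    if isinstance(encoding, FloatEncoding):
        return FloatQuantizeDequantize.from_encodings(encoding)
    raise TypeError(f"No quantizer registered for {type(encoding)}")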

TrainingExtensions/torch/src/python/aimet_torch/v2/quantization/float/quantizer.py

Lines changed: 13 additions & 0 deletions
@@ -246,6 +246,19 @@ def get_encodings(self) -> Optional[FloatEncoding]:
             return FloatEncoding(self.mantissa_bits, self.exponent_bits, self.maxval)
         return None
 
+    @classmethod
+    def from_encodings(cls, encodings: FloatEncoding) -> "FloatQuantizeDequantize":
+        if not isinstance(encodings, FloatEncoding):
+            raise TypeError(f"Expected {FloatEncoding}; got {type(encodings)}")
+
+        qtzr = cls(exponent_bits=encodings.exponent_bits,
+                   mantissa_bits=encodings.mantissa_bits)
+
+        if encodings.maxval is not None:
+            qtzr.maxval.copy_(encodings.maxval)
+
+        return qtzr
+
     @contextlib.contextmanager
     def compute_encodings(self):
         """

TrainingExtensions/torch/src/python/aimet_torch/v2/quantsim/quantsim.py

Lines changed: 55 additions & 1 deletion
@@ -67,6 +67,7 @@
 from aimet_torch.v2.nn import BaseQuantizationMixin, QuantizationMixin, UnknownModuleError
 from aimet_torch.v2.nn.fake_quant import _legacy_impl
 from aimet_torch.v2._builder import _V2LazyQuantizeWrapper
+from aimet_torch.v2.quantization import DequantizedTensor
 from aimet_torch.v2.quantization.base import QuantizerBase
 from aimet_torch.v2.quantization.affine import AffineQuantizerBase
 from aimet_torch.v2.quantization.encoding_analyzer import PercentileEncodingAnalyzer
@@ -628,6 +629,58 @@ def _concretize_int32_bias_quantizers(self, args):
             for qmodule, qtzr in orig_bias_quantizers.items():
                 qmodule.param_quantizers["bias"] = qtzr
 
+    def fold_param_quantizers(self):
+        """
+        Fold parameter quantizers into their associated parameters to accelerate inference.
+
+        Example:
+
+            >>> sim = QuantizationSimModel(...)
+            >>> type(sim.model[0].weight)
+            <class 'torch.nn.parameter.Parameter'>
+            >>> sim.model[0]
+            QuantizedLinear(
+              in_features=10, out_features=10, bias=True
+              (param_quantizers): ModuleDict(
+                (weight): QuantizeDequantize(shape=(), qmin=-128, qmax=127, symmetric=True)
+                (bias): None
+              )
+            )
+            >>> sim.fold_param_quantizers()
+            >>> type(sim.model[0].weight)
+            <class 'aimet_torch.v2.quantization.tensor.DequantizedTensor'>
+            >>> sim.model[0]
+            QuantizedLinear(
+              in_features=10, out_features=10, bias=True
+              (param_quantizers): ModuleDict(
+                (weight): None
+                (bias): None
+              )
+            )
+        """
+        for qmodule in self.qmodules():
+            qmodule.fold_param_quantizers()
+
+
+@contextlib.contextmanager
+def _temporarily_unfold_param_quantizers(sim: QuantizationSimModel):
+    # pylint: disable=protected-access
+    """
+    Temporarily re-instantiate param quantizers for ease of export
+    """
+    modules_with_folded_parameters = [
+        qmodule for qmodule in sim.qmodules()
+        if any(isinstance(param, DequantizedTensor) for param in qmodule.parameters())
+    ]
+
+    try:
+        for qmodule in modules_with_folded_parameters:
+            qmodule._unfold_param_quantizers()
+        yield
+    finally:
+        for qmodule in modules_with_folded_parameters:
+            qmodule._fold_param_quantizers()
+
 
 class _QuantizationSimOnnxExport:
     """
@@ -660,7 +713,8 @@ def export(self,
                              "Other quantizer types are not supported.")
 
         with tempfile.TemporaryDirectory() as tmp_dir:
-            with self.sim._concretize_int32_bias_quantizers(args), \
+            with _temporarily_unfold_param_quantizers(self.sim), \
+                 self.sim._concretize_int32_bias_quantizers(args), \
                  self.sim._apply_qdq_to_model_parameters(self.sim.model):
                 tmp_onnx_path = os.path.join(tmp_dir, "quantized_model.onnx")
                 export(self.sim.model, args, tmp_onnx_path, *posargs, **kwargs)
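
A hedged end-to-end sketch of the sim-level workflow this commit enables, mirroring the new unit test in test_quantsim.py below; the toy model, dummy input, and output directory are placeholders:

# Calibrate, fold every parameter quantizer, then export.
import torch
from aimet_torch.v2.quantsim import QuantizationSimModel

model = torch.nn.Sequential(torch.nn.Linear(10, 10))
x = torch.randn(10, 10)

sim = QuantizationSimModel(model, x)
sim.compute_encodings(lambda m: m(x))

sim.fold_param_quantizers()   # weights become DequantizedTensor; param quantizers set to None

# Export is expected to be unaffected: the new test below checks that the
# encodings exported before and after folding are identical.
sim.export("/tmp", "quantized_model", x)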

TrainingExtensions/torch/test/python/v2/experimental/test_onnx.py

Lines changed: 5 additions & 1 deletion
@@ -275,7 +275,8 @@ def test_export_torchvision_models(model_factory, input_shape):
 @torch.no_grad()
 @pytest.mark.parametrize("encoding_version", ["0.6.1", "1.0.0"])
 @pytest.mark.parametrize("lpbq", (False, True))
-def test_quantsim_export_resnet18(encoding_version, lpbq: bool):
+@pytest.mark.parametrize("fold_param_quantizers", (False, True))
+def test_quantsim_export_resnet18(encoding_version, lpbq: bool, fold_param_quantizers: bool):
     """
     When: Export quantized torchvision model using quantsim.export
     """
@@ -319,6 +320,9 @@ def test_quantsim_export_resnet18(encoding_version, lpbq: bool):
     with remove_activation_quantizers(sim.model):
         expected_out = sim.model(x)
 
+    if fold_param_quantizers:
+        sim.fold_param_quantizers()
+
     with tempfile.TemporaryDirectory() as dirname:
         onnx_path = os.path.join(dirname, "torchvision_model.onnx")
         encodings_path = os.path.join(dirname, "torchvision_model.encodings")

TrainingExtensions/torch/test/python/v2/quantsim/test_quantsim.py

Lines changed: 36 additions & 0 deletions
@@ -46,6 +46,7 @@
 from aimet_common.defs import QuantizationDataType, QuantScheme
 from aimet_torch import onnx_utils
 from aimet_torch.v2.quantsim import QuantizationSimModel, load_encodings_to_sim
+from aimet_torch.v2.quantization import DequantizedTensor
 from aimet_torch.v2.quantization.encoding_analyzer import PercentileEncodingAnalyzer
 from aimet_torch.v2.quantization.base import QuantizerBase
 from aimet_torch.v2.quantization.affine import AffineQuantizerBase, GroupedBlockQuantizeDequantize, QuantizeDequantize
@@ -1242,6 +1243,41 @@ def test_compute_encodings_optional_arg(self):
         assert torch.equal(qtzr_a.get_min(), qtzr_b.get_min())
         assert torch.equal(qtzr_a.get_max(), qtzr_b.get_max())
 
+    @pytest.mark.parametrize("data_type", [QuantizationDataType.int, QuantizationDataType.float])
+    def test_fold_param_quantizers(self, tmpdir, data_type):
+        model = torch.nn.Sequential(
+            torch.nn.Linear(10, 10),
+        )
+        x = torch.randn(10, 10)
+        sim = QuantizationSimModel(model, x,
+                                   default_param_bw=16,
+                                   default_output_bw=16,
+                                   default_data_type=data_type)
+        sim.compute_encodings(lambda model: model(x))
+
+        sim.export(tmpdir, "before_fold", x)
+
+        """
+        When: Call fold_param_quantizers()
+        Then: 1. All param quantizers should be folded to the parameter
+              2. Export artifact of sim.export() should not be affected
+        """
+        sim.fold_param_quantizers()
+        assert sim.model[0].param_quantizers["weight"] is None
+        assert isinstance(sim.model[0].weight, DequantizedTensor)
+
+        sim.export(tmpdir, "after_fold", x)
+
+        with open(os.path.join(tmpdir, "before_fold.encodings")) as f:
+            encodings_before_fold = json.load(f)
+        with open(os.path.join(tmpdir, "after_fold.encodings")) as f:
+            encodings_after_fold = json.load(f)
+
+        assert encodings_before_fold == encodings_after_fold
+
+        # trivial sanity check
+        assert [enc["name"] for enc in encodings_before_fold["param_encodings"]] == ["0.weight"]
+
 
 class TestQuantsimUtilities: