
Commit b6771f8

Lee, Kyunggeun (quic-kyunggeu) authored and committed
Allow torch.nn.Module as input of aimet_torch.onnx.export
Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
Co-authored-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
1 parent 60dab14 commit b6771f8


4 files changed: +259 -233 lines changed

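For context, a minimal usage sketch of the API this commit touches (not part of the diff): with this change, aimet_torch.onnx.export accepts either a QuantizationSimModel or a plain torch.nn.Module (typically sim.model) as its first argument. The QuantizationSimModel constructor call, the alias aimet_onnx, the omitted calibration step, and the file name model_qdq.onnx are assumptions about the surrounding AIMET workflow, not something this commit defines.

# Usage sketch (editor's addition, not part of this commit).
import torch
from aimet_torch import onnx as aimet_onnx
from aimet_torch.quantsim import QuantizationSimModel

model = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.ReLU())
dummy_input = torch.randn(1, 10)

# Assumed constructor signature; calibration (compute_encodings) would
# normally happen here and is omitted.
sim = QuantizationSimModel(model, dummy_input)

# Previously supported: pass the QuantizationSimModel itself.
aimet_onnx.export(sim, (dummy_input,), "model_qdq.onnx")

# Newly allowed by this commit: pass a plain torch.nn.Module (e.g. sim.model);
# export() unwraps a QuantizationSimModel to its underlying model internally.
aimet_onnx.export(sim.model, (dummy_input,), "model_qdq.onnx")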

TrainingExtensions/torch/src/python/aimet_torch/onnx.py

Lines changed: 243 additions & 13 deletions
@@ -35,25 +35,255 @@
 # @@-COPYRIGHT-END-@@
 # =============================================================================
 """ Defines onnx export API """
-from .v2.quantsim.quantsim import QuantizationSimModel, _QuantizationSimOnnxExport
+import contextlib
+import io
+import os
+import tempfile
+from typing import Any, Mapping, Tuple, Union

+import onnx
+import torch

-def export(sim: QuantizationSimModel, *args, **kwargs):
+from aimet_common.onnx._utils import _add_onnx_qdq_nodes
+
+from .nn import QuantizationMixin
+from .quantization import DequantizedTensor
+from .quantization.base import EncodingBase
+from .quantization.affine import AffineQuantizerBase
+from .quantization.float import FloatQuantizeDequantize
+from .quantsim import QuantizationSimModel
+from .v2.experimental import onnx as _onnx
+
+
+def export(model: Union[torch.nn.Module, QuantizationSimModel],
+           args: Union[Tuple[Any, ...], torch.Tensor],
+           f: Union[str, io.BytesIO],
+           *posargs, **kwargs):
     """
-    Export :class:`QuantizationSimModel` object to onnx model with
+    Export QuantizationSimModel to onnx model with
     QuantizeLinear/DequantizeLinear embedded in the graph.

-    This function takes set of same arguments as torch.onnx.export(),
-    except that the first argument is a QuantizationSimModel object, not a nn.Module.
+    This function takes set of same arguments as torch.onnx.export()
+    """
+    if isinstance(model, QuantizationSimModel):
+        model = model.model
+
+    if not isinstance(model, torch.nn.Module):
+        raise RuntimeError(
+            f"aimet_torch.export only supports torch.nn.Module or QuantizationSimModel; got {type(model)}"
+        )
+
+    onnx_model, tensor_to_encoding_map = _to_onnx(model, args, *posargs, **kwargs)
+    onnx_qdq_model = _to_onnx_qdq(onnx_model, tensor_to_encoding_map)
+    onnx.save(onnx_qdq_model, f)
+
+
+def _to_onnx(model: torch.nn.Module,
+             args: Union[Tuple[Any, ...], torch.Tensor],
+             *posargs, **kwargs):
+    # pylint: disable=protected-access
+    _check_unsupported_quantizers(model)
+
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        with _temporarily_unfold_param_quantizers(model), \
+             _concretize_int32_bias_quantizers(model, args), \
+             QuantizationSimModel._apply_qdq_to_model_parameters(model), \
+             _remove_fp16_quantizers(model):
+            tmp_onnx_path = os.path.join(tmp_dir, "quantized_model.onnx")
+            _onnx.export(model, args, tmp_onnx_path, *posargs, **kwargs)
+        onnx_model = onnx.load(tmp_onnx_path)
+
+    param_names = {
+        f"{layer_name}.{param_name}"
+        for layer_name, layer in model.named_modules()
+        if isinstance(layer, QuantizationMixin)
+        for param_name, quantizer in layer.param_quantizers.items()
+        if quantizer
+    }
+
+    tensor_to_encoding_map: Mapping[str, Tuple[EncodingBase, bool]]
+    tensor_to_encoding_map = {
+        name: (encoding, name in param_names)
+        for name, encoding in _onnx.remove_quantization_nodes_from_onnx_graph(onnx_model).items()
+    }
+    return onnx_model, tensor_to_encoding_map
+
+
+@contextlib.contextmanager
+def _concretize_int32_bias_quantizers(model, args):
+    if not isinstance(args, (tuple, list)):
+        args = (args,)
+
+    handles = []
+    orig_bias_quantizers = {
+        qmodule: qmodule.param_quantizers["bias"]
+        for qmodule in model.modules()
+        if isinstance(qmodule, QuantizationMixin)
+        and "bias" in qmodule.param_quantizers
+        and qmodule.bias is not None
+    }
+
+    try:
+        for qmodule, qtzr in orig_bias_quantizers.items():
+            if qtzr is not None:
+                # Bias quantizer already exists.
+                # This means the user created bias quantizer by him/herself
+                # In this case, we honor the custom bias quantizer defined by the user
+                continue
+
+            if "weight" in qmodule.param_quantizers and \
+                    isinstance(qmodule.param_quantizers["weight"], AffineQuantizerBase):
+                # pylint: disable=protected-access
+                handle = qmodule.register_forward_hook(type(qmodule)._create_int32_bias_quantizer)
+                handles.append(handle)
+        try:
+            model(*args)
+        finally:
+            for handle in handles:
+                handle.remove()
+        yield
+    finally:
+        for qmodule, qtzr in orig_bias_quantizers.items():
+            qmodule.param_quantizers["bias"] = qtzr
+
+
+@contextlib.contextmanager
+def _temporarily_unfold_param_quantizers(model: torch.nn.Module):
+    # pylint: disable=protected-access
+    """
+    Temporarily re-instantiate param quantizers for ease of export
+    """
+    modules_with_folded_parameters = [
+        qmodule for qmodule in model.modules()
+        if isinstance(qmodule, QuantizationMixin) and
+        any(isinstance(param, DequantizedTensor) for param in qmodule.parameters())
+    ]
+
+    try:
+        for qmodule in modules_with_folded_parameters:
+            qmodule._unfold_param_quantizers()
+        yield
+    finally:
+        for qmodule in modules_with_folded_parameters:
+            qmodule._fold_param_quantizers()
+
+
+@contextlib.contextmanager
+def _remove_fp16_quantizers(model: torch.nn.Module):
     """
-    if not isinstance(sim, QuantizationSimModel):
-        raise RuntimeError(f"Expected {QuantizationSimModel} object; got {type(sim)}")
+    Temporarily remove [b]float16 quantizers for sim.onnx.export,
+    as sim.onnx.export does NOT support exporting [b]float16 quantizers.
+    """
+    original_containers = {}

     try:
-        embed_qdq = kwargs.pop("embed_qdq")
-    except KeyError:
-        embed_qdq = True
+        for qmodule in model.modules():
+            if not isinstance(qmodule, QuantizationMixin):
+                continue
+
+            for name, qtzr in qmodule.param_quantizers.items():
+                if isinstance(qtzr, FloatQuantizeDequantize) and \
+                        (qtzr.is_float16() or qtzr.is_bfloat16()):
+                    original_containers[(qmodule.param_quantizers, name)] = qtzr
+                    qmodule.param_quantizers[name] = None
+
+            for i, qtzr in enumerate(qmodule.input_quantizers):
+                if isinstance(qtzr, FloatQuantizeDequantize) and \
+                        (qtzr.is_float16() or qtzr.is_bfloat16()):
+                    original_containers[(qmodule.input_quantizers, i)] = qtzr
+                    qmodule.input_quantizers[i] = None
+
+            for i, qtzr in enumerate(qmodule.output_quantizers):
+                if isinstance(qtzr, FloatQuantizeDequantize) and \
+                        (qtzr.is_float16() or qtzr.is_bfloat16()):
+                    original_containers[(qmodule.output_quantizers, i)] = qtzr
+                    qmodule.output_quantizers[i] = None
+
+        yield
+
+    finally:
+        for (container, key), qtzr in original_containers.items():
+            container[key] = qtzr
+
+
+def _to_onnx_qdq(onnx_model: onnx.ModelProto,
+                 tensor_to_encoding_map: Mapping[str, Tuple[EncodingBase, bool]]) -> onnx.ModelProto:
+    qnn_encodings = {
+        name: encoding.to_qnn_encoding_dict("2.0.0.beta")
+        for name, (encoding, _) in tensor_to_encoding_map.items()
+    }
+    qnn_encodings = {
+        name: encoding for name, encoding in qnn_encodings.items() if encoding
+    }
+
+    qdq_tensor_names = {
+        fp_tensor_name: f"{fp_tensor_name}_qdq"
+        for fp_tensor_name in qnn_encodings
+    }
+
+    onnx_opset_version = next(opset.version for opset in onnx_model.opset_import if opset.domain == "")
+
+    # Add onnx QDQ nodes in batch
+    _add_onnx_qdq_nodes(onnx_model,
+                        input_names=qnn_encodings.keys(),
+                        output_names=qdq_tensor_names.values(),
+                        node_name_prefixes=qnn_encodings.keys(),
+                        encodings=qnn_encodings.values(),
+                        onnx_opset=onnx_opset_version)
+
+    # Restore model output names from "{output}_qdq" to "{output}"
+    _restore_model_output_names(onnx_model, qdq_tensor_names)
+
+    return onnx_model
+
+
+def _check_unsupported_quantizers(module: torch.nn.Module):
+    for qtzr in module.modules():
+        if isinstance(qtzr, FloatQuantizeDequantize):
+            if not qtzr.is_float16() and not qtzr.is_bfloat16():
+                msg = " ".join([
+                    "sim.onnx.export doesn't support exporting floating point encodings",
+                    f"except [b]float16. Got {qtzr.bitwidth}-bit float encoding",
+                ])
+                raise RuntimeError(msg)
+
+
+def _rename_inputs(onnx_model: onnx.ModelProto, new_names: Mapping[str, str]):
+    for node in onnx_model.graph.node:
+        for i, old_name in enumerate(node.input):
+            new_name = new_names.get(old_name, None)
+            if new_name is not None:
+                node.input[i] = new_name
+
+
+def _rename_outputs(onnx_model: onnx.ModelProto, new_names: Mapping[str, str]):
+    for node in onnx_model.graph.node:
+        for i, old_name in enumerate(node.output):
+            new_name = new_names.get(old_name, None)
+            if new_name is not None:
+                node.output[i] = new_name
+
+
+def _restore_model_output_names(onnx_model: onnx.ModelProto, new_names: Mapping[str, str]):
+    """
+    Rename model outputs. Assuming "output" is the model output,
+
+    before:
+        Softmax ----> output -------> QDQ -------> output_qdq
+
+    after:
+        Softmax ----> output__ -----> QDQ -------> output
+    """
+    _new_names = {
+        output.name: f"{output.name}__"
+        for output in onnx_model.graph.output
+        if output.name in new_names
+    }
+    _rename_inputs(onnx_model, _new_names)

-    _QuantizationSimOnnxExport(sim).export(*args,
-                                           embed_qdq=embed_qdq,
-                                           **kwargs)
+    _new_names.update({
+        new_names[output.name]: output.name
+        for output in onnx_model.graph.output
+        if output.name in new_names
+    })
+    _rename_outputs(onnx_model, _new_names)
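The combined effect of _to_onnx_qdq and _restore_model_output_names can be sanity-checked on the saved file. Below is a minimal inspection sketch (editor's addition), assuming the export above wrote model_qdq.onnx; it relies only on the public onnx API.

# Inspection sketch: counts the QuantizeLinear/DequantizeLinear nodes embedded
# by _add_onnx_qdq_nodes and confirms that _restore_model_output_names kept the
# original graph output names.
from collections import Counter

import onnx

onnx_model = onnx.load("model_qdq.onnx")   # assumed file name from the example above

op_counts = Counter(node.op_type for node in onnx_model.graph.node)
print("QuantizeLinear nodes:", op_counts.get("QuantizeLinear", 0))
print("DequantizeLinear nodes:", op_counts.get("DequantizeLinear", 0))

# Graph outputs keep their original names (e.g. "output", not "output_qdq").
print("outputs:", [output.name for output in onnx_model.graph.output])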

TrainingExtensions/torch/src/python/aimet_torch/v2/experimental/onnx/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -35,3 +35,4 @@
 # @@-COPYRIGHT-END-@@
 # =============================================================================
 """ Utility APIs for onnx export """
+from ._export import export, remove_quantization_nodes_from_onnx_graph
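This re-export is what onnx.py consumes through its new "from .v2.experimental import onnx as _onnx" import. A small relationship sketch (editor's addition); the comments only restate how the two names are used inside _to_onnx.

# Relationship sketch; mirrors the alias added in onnx.py.
from aimet_torch.v2.experimental import onnx as _onnx

# _onnx.export(model, args, path, ...) is the low-level exporter that _to_onnx
# calls inside its temporary directory.
# _onnx.remove_quantization_nodes_from_onnx_graph(onnx_model) strips the
# quantization nodes from the exported graph and returns a
# {tensor name: encoding} mapping, which _to_onnx turns into tensor_to_encoding_map.
assert callable(_onnx.export)
assert callable(_onnx.remove_quantization_nodes_from_onnx_graph)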
