
Commit 661a107

Lee, Kyunggeun (quic-kyunggeu) authored and committed
Add aimet_torch.onnx.export API reference
Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
Co-authored-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
1 parent fd6c9fb commit 661a107

File tree

2 files changed: +44 −5 lines changed
  • Docs/apiref/torch
  • TrainingExtensions/torch/src/python/aimet_torch

Docs/apiref/torch/index.rst

Lines changed: 3 additions & 1 deletion

@@ -9,9 +9,10 @@ aimet_torch API
    Migrate to aimet_torch 2 <migration_guide>
    aimet_torch.quantsim <quantsim>
-   aimet_torch.adaround <adaround>
    aimet_torch.nn <nn>
    aimet_torch.quantization <quantization>
+   aimet_torch.onnx <onnx>
+   aimet_torch.adaround <adaround>
    aimet_torch.seq_mse <seq_mse>
    aimet_torch.adascale <adascale>
    aimet_torch.quantsim.config_utils <lpbq>
@@ -42,6 +43,7 @@ aimet_torch
 - :ref:`aimet_torch.quantsim <apiref-torch-quantsim>`
 - :ref:`aimet_torch.nn <apiref-torch-nn>`
 - :ref:`aimet_torch.quantization <apiref-torch-quantization>`
+- :ref:`aimet_torch.onnx (beta) <apiref-torch-onnx>`
 - :ref:`aimet_torch.adaround <apiref-torch-adaround>`
 - :ref:`aimet_torch.seq_mse <apiref-torch-seq-mse>`
 - :ref:`aimet_torch.adascale <apiref-torch-adascale>`

TrainingExtensions/torch/src/python/aimet_torch/onnx.py

Lines changed: 41 additions & 4 deletions

@@ -72,10 +72,47 @@ def export(model: Union[torch.nn.Module, QuantizationSimModel],
            export_int32_bias: bool = True,
            **kwargs):
     """
-    Export QuantizationSimModel to onnx model with
-    QuantizeLinear/DequantizeLinear embedded in the graph.
-
-    This function takes set of same arguments as torch.onnx.export()
+    Export :class:`QuantizationSimModel` to an onnx model with
+    onnx `QuantizeLinear`_ and `DequantizeLinear`_ embedded in the graph.
+
+    This function takes the same set of arguments as `torch.onnx.export()`_
+
+    Args:
+        model: The model to be exported
+        args: Same as `torch.onnx.export()`
+        f: Same as `torch.onnx.export()`
+        export_int32_bias (bool, optional):
+            If true, generate and export int32 bias encoding on the fly (default: `True`)
+        **kwargs: Same as `torch.onnx.export()`
+
+    .. note::
+        Unlike `torch.onnx.export()`, this function allows up to opset 21
+        to support 4/16-bit quantization, which is only available in opset 21.
+        However, exporting to opset 21 is a beta feature and not fully stable yet.
+        For robustness, opset 20 or lower is recommended whenever possible.
+
+    .. note::
+        Dynamo-based export (`dynamo=True`) is not supported yet.
+
+    .. _torch.onnx.export(): https://docs.pytorch.org/docs/stable/onnx_torchscript.html#torch.onnx.export
+    .. _QuantizeLinear: https://onnx.ai/onnx/operators/onnx__QuantizeLinear.html
+    .. _DequantizeLinear: https://onnx.ai/onnx/operators/onnx__DequantizeLinear.html
+
+    Examples:
+
+        >>> aimet_torch.onnx.export(sim.model, x, f="model.onnx",
+        ...                         input_names=["input"], output_names=["output"],
+        ...                         opset_version=21, export_int32_bias=True)
+        ...
+        >>> import onnxruntime as ort
+        >>> options = ort.SessionOptions()
+        >>> options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
+        >>> sess = ort.InferenceSession("model.onnx", sess_options=options)
+        >>> onnx_output, = sess.run(None, {"input": x.detach().numpy()})
+        >>> torch.nn.functional.cosine_similarity(torch.from_numpy(onnx_output), sim.model(x))
+        tensor([1.0000, 0.9999, 1.0000, ..., 1.0000, 1.0000, 1.0000],
+               grad_fn=<AliasBackward0>)
     """
     if isinstance(model, QuantizationSimModel):
         model = model.model
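For readers unfamiliar with the `QuantizeLinear`/`DequantizeLinear` pairs this export embeds, here is a minimal plain-Python sketch of the per-element arithmetic those ONNX operators define. The scale, zero-point, and int8 clamp range below are illustrative values chosen for the example, not taken from this commit:

```python
def quantize_linear(x, scale, zero_point, qmin=-128, qmax=127):
    """ONNX QuantizeLinear: q = clamp(round(x / scale) + zero_point, qmin, qmax)."""
    q = round(x / scale) + zero_point
    return max(qmin, min(qmax, q))


def dequantize_linear(q, scale, zero_point):
    """ONNX DequantizeLinear: x_hat = (q - zero_point) * scale."""
    return (q - zero_point) * scale


# Round-trip a value through a Q/DQ pair (int8, symmetric: zero_point = 0)
scale, zero_point = 0.05, 0
q = quantize_linear(0.7, scale, zero_point)        # 14
x_hat = dequantize_linear(q, scale, zero_point)    # ~0.7, up to quantization error

# Values outside the representable range saturate at the clamp bounds
q_sat = quantize_linear(100.0, scale, zero_point)  # clamped to 127
```

A QDQ pair is thus a lossy identity: the dequantized value differs from the input by at most half a scale step (unless clamped), which is why the docstring example checks cosine similarity between the onnxruntime output and the simulated model rather than exact equality.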
