@@ -72,10 +72,47 @@ def export(model: Union[torch.nn.Module, QuantizationSimModel],
            export_int32_bias: bool = True,
            **kwargs):
     """
-    Export QuantizationSimModel to onnx model with
-    QuantizeLinear/DequantizeLinear embedded in the graph.
-
-    This function takes set of same arguments as torch.onnx.export()
+    Export :class:`QuantizationSimModel` to an onnx model with
+    onnx `QuantizeLinear`_ and `DequantizeLinear`_ embedded in the graph.
+
+    This function takes the same set of arguments as `torch.onnx.export()`_.
+
+    Args:
+        model: The model to be exported
+        args: Same as `torch.onnx.export()`
+        f: Same as `torch.onnx.export()`
+        export_int32_bias (bool, optional):
+            If True, generate and export int32 bias encoding on the fly (default: `True`)
+        **kwargs: Same as `torch.onnx.export()`
+
+
+    .. note::
+        Unlike `torch.onnx.export()`, this function allows opset versions up to 21
+        to support 4/16-bit quantization, which is only available in opset 21.
+        However, exporting to opset 21 is a beta feature and not fully stable yet.
+        For robustness, opset 20 or lower is recommended whenever possible.
+
+    .. note::
+        Dynamo-based export (`dynamo=True`) is not supported yet.
+
+    .. _torch.onnx.export(): https://docs.pytorch.org/docs/stable/onnx_torchscript.html#torch.onnx.export
+    .. _QuantizeLinear: https://onnx.ai/onnx/operators/onnx__QuantizeLinear.html
+    .. _DequantizeLinear: https://onnx.ai/onnx/operators/onnx__DequantizeLinear.html
+
+    Examples:
+
+        >>> aimet_torch.onnx.export(sim.model, x, f="model.onnx",
+        ...                         input_names=["input"], output_names=["output"],
+        ...                         opset_version=21, export_int32_bias=True)
+        ...
+        >>> import onnxruntime as ort
+        >>> options = ort.SessionOptions()
+        >>> options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
+        >>> sess = ort.InferenceSession("model.onnx", sess_options=options)
+        >>> onnx_output, = sess.run(None, {"input": x.detach().numpy()})
+        >>> torch.nn.functional.cosine_similarity(torch.from_numpy(onnx_output), sim.model(x))
+        tensor([1.0000, 0.9999, 1.0000, ..., 1.0000, 1.0000, 1.0000],
+               grad_fn=<AliasBackward0>)
     """
     if isinstance(model, QuantizationSimModel):
         model = model.model
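
For context beyond the diff, here is a minimal end-to-end sketch of how this export entry point is typically reached. The surrounding calls (`QuantizationSimModel(...)`, `sim.compute_encodings(...)`) are assumptions based on the aimet_torch quantsim API and are not part of this change; check your installed version's documentation for the exact signatures.

```python
# Minimal usage sketch (assumed API, not part of this diff): build a toy
# QuantizationSimModel, calibrate it, then export with QuantizeLinear /
# DequantizeLinear ops embedded in the onnx graph.
import torch
from aimet_torch.quantsim import QuantizationSimModel
import aimet_torch.onnx

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU()).eval()
x = torch.randn(1, 16)

sim = QuantizationSimModel(model, dummy_input=x)

# Calibrate quantizer encodings with representative data; the two-argument
# callback form shown here follows the classic API and may differ across
# aimet_torch releases.
sim.compute_encodings(lambda m, _: m(x), None)

# Export with int32 bias encodings generated on the fly. Opset 20 is chosen
# because opset 21 export is still a beta feature per the docstring above.
aimet_torch.onnx.export(sim.model, x, f="model.onnx",
                        input_names=["input"], output_names=["output"],
                        opset_version=20, export_int32_bias=True)
```

The exported "model.onnx" can then be validated against `sim.model` with onnxruntime, as in the docstring's own example.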