Skip to content

Commit 24e2d4a

Browse files
Lee, Kyunggeun (quic-kyunggeu)
authored and committed
Export spot-on weights in aimet-onnx QDQ export
Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com> Co-authored-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
1 parent aecf48b commit 24e2d4a

File tree

2 files changed

+98
-4
lines changed

2 files changed

+98
-4
lines changed

TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1253,6 +1253,8 @@ def _to_onnx_qdq(self) -> onnx.ModelProto:
12531253
model_copy = onnx.ModelProto()
12541254
model_copy.CopyFrom(self.model.model)
12551255

1256+
self._overwrite_parameters(model_copy, self._get_qdq_parameters())
1257+
12561258
aimet_qc_quantize_nodes = [
12571259
node for node in model_copy.graph.node
12581260
if node.op_type == "QcQuantizeOp"
@@ -1289,6 +1291,102 @@ def _to_onnx_qdq(self) -> onnx.ModelProto:
12891291
model_copy = onnx.version_converter.convert_version(model_copy, desired_onnx_opset_version)
12901292
return model_copy
12911293

1294+
def _get_qdq_parameters(self):
    """Compute the spot-on (quantize-dequantize) values of all quantizable parameters.

    Builds a minimal ONNX model whose graph outputs are the QDQ-ed parameter
    tensors (only parameters whose quantizer bitwidth is <= 16), runs it once
    on CPU with no inputs, and returns the results.

    :return: Dict mapping ``"<param_name>_qdq"`` to the QDQ-ed numpy array.
    """
    # Gather every parameter product whose quantizer uses <= 16 bits.
    # (Wider quantizers, e.g. int32 bias, are handled elsewhere.)
    quantizable_params = {}
    for op in self.connected_graph.get_all_ops().values():
        for product, _ in op.parameters.values():
            if self.qc_quantize_op_dict[product.name].bitwidth <= 16:
                quantizable_params[product.name] = product

    param_names = set(quantizable_params)
    qdq_params = {
        f"{name}_qdq": product
        for name, product in quantizable_params.items()
    }

    # Graph outputs: one value info per QDQ-ed parameter.
    graph_outputs = [
        onnx.helper.make_tensor_value_info(output_name,
                                           onnx.TensorProto.FLOAT,
                                           shape=product.shape)
        for output_name, product in qdq_params.items()
    ]
    # Keep only the initializers that hold the selected parameters.
    graph_initializers = [
        init for init in self.model.model.graph.initializer
        if init.name in param_names
    ]
    # Keep the nodes that consume a selected parameter (e.g. its quantize op)
    # plus any Constant node that produces one.
    graph_nodes = [
        node for node in self.model.model.graph.node
        if any(inp in param_names for inp in node.input) or
           (node.op_type == "Constant" and
            any(out in param_names for out in node.output))
    ]

    partial_model = onnx.helper.make_model(
        graph=onnx.helper.make_graph(
            name="partial",
            inputs=[],
            outputs=graph_outputs,
            initializer=graph_initializers,
            nodes=graph_nodes,
        )
    )

    # Constant fold on CPU: no feeds needed, everything is baked in.
    sess = self.build_session(partial_model, ["CPUExecutionProvider"])
    results = sess.run(list(qdq_params.keys()), {})
    return dict(zip(qdq_params.keys(), results))
1338+
1339+
@staticmethod
def _overwrite_parameters(model: onnx.ModelProto, parameters: Dict[str, np.ndarray]):
    """Overwrite the parameters of ``model`` in place with their QDQ values.

    :param model: Model whose initializers and Constant nodes are overwritten.
    :param parameters: Mapping from ``"<param_name>_qdq"`` to the new value.
        Matched entries are popped from this dict.
    :raises RuntimeError: If any entry in ``parameters`` matches neither an
        initializer nor a Constant node, or if a matched Constant holds strings.
    """
    initializers = [
        (init, parameters.pop(f"{init.name}_qdq"))
        for init in model.graph.initializer
        if f"{init.name}_qdq" in parameters
    ]
    constants = [
        (node, parameters.pop(f"{node.output[0]}_qdq"))
        for node in model.graph.node
        if node.op_type == "Constant" and f"{node.output[0]}_qdq" in parameters
    ]

    # Everything still left in ``parameters`` was matched by neither an
    # initializer nor a Constant node. (The popped keys carry the "_qdq"
    # suffix while tensor names don't, so a set subtraction against the
    # matched names would never remove anything — just check the leftovers.)
    if parameters:
        raise RuntimeError(
            f"Couldn't find parameters: {list(parameters.keys())}"
        )

    for const, _ in constants:
        if any(attr.name in ("value_string", "value_strings")
               for attr in const.attribute):
            raise RuntimeError(f"String constant {const.name} can't be quantized")

    for init, qdq_param in initializers:
        # NOTE(review): assumes the initializer stores its data in raw_data
        # and that qdq_param's dtype matches init's data_type — a typed field
        # such as float_data would be left stale. Confirm against exporters.
        init.raw_data = qdq_param.tobytes()

    for const, qdq_param in constants:
        for attr in const.attribute:
            if attr.name == "value":
                # Tensor-valued constant: swap the raw bytes in place.
                attr.t.raw_data = qdq_param.tobytes()
                break
            if attr.name == "value_float":
                # Bug fix: AttributeProto has no ``float`` field; a single
                # float is stored in field ``f``.
                attr.f = float(qdq_param)
                break
            if attr.name == "value_floats":
                attr.ClearField("floats")
                attr.floats.extend(qdq_param.astype(np.float32).tolist())
                break
            if attr.name == "value_int":
                # Bug fix: AttributeProto has no ``int`` field; a single
                # int is stored in field ``i``.
                attr.i = int(qdq_param)
                break
            if attr.name == "value_ints":
                # Bug fix: the int list must go into ``ints``, not ``floats``.
                attr.ClearField("ints")
                attr.ints.extend(qdq_param.astype(np.int64).tolist())
                break
1389+
12921390

12931391
# pylint: disable=too-many-locals, too-many-branches
12941392
def load_encodings_to_sim(quant_sim_model: QuantizationSimModel, onnx_encoding_path: str, strict=True) -> \

TrainingExtensions/onnx/test/python/test_quantsim.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2410,12 +2410,8 @@ def test_onnx_qdq(model_factory,
24102410

24112411
if export_int32_bias_encodings:
24122412
sim._concretize_int32_bias_quantizers()
2413-
24142413
# FIXME: Need extra tolerance due to numerical instability of AIMET int32 bias qdq.
24152414
tolerance += 1
2416-
if any(node.op_type == "InstanceNormalization" for node in sim.model.graph().node):
2417-
# FIXME: InstanceNormalization is especially more unstable with int32 bias qdq.
2418-
tolerance += 2
24192415

24202416
out_sim, = sim.session.run(None, {"input": input})
24212417

0 commit comments

Comments
 (0)