@@ -1253,6 +1253,8 @@ def _to_onnx_qdq(self) -> onnx.ModelProto:
12531253 model_copy = onnx .ModelProto ()
12541254 model_copy .CopyFrom (self .model .model )
12551255
1256+ self ._overwrite_parameters (model_copy , self ._get_qdq_parameters ())
1257+
12561258 aimet_qc_quantize_nodes = [
12571259 node for node in model_copy .graph .node
12581260 if node .op_type == "QcQuantizeOp"
@@ -1289,6 +1291,102 @@ def _to_onnx_qdq(self) -> onnx.ModelProto:
12891291 model_copy = onnx .version_converter .convert_version (model_copy , desired_onnx_opset_version )
12901292 return model_copy
12911293
def _get_qdq_parameters(self):
    """
    Evaluate the quantize-dequantized value of every parameter whose
    quantizer bitwidth is at most 16.

    A small stand-alone ONNX graph is assembled from the nodes of
    ``self.model.model`` that consume one of the selected parameters (plus
    the initializers / ``Constant`` nodes that produce them) and is executed
    on CPU with no runtime inputs.

    NOTE(review): this presumes the node consuming parameter ``<name>``
    emits an output called ``<name>_qdq`` -- confirm against the code that
    inserts the quantizer nodes.

    :return: Dictionary mapping ``"<param name>_qdq"`` to the array computed
             by the session for that output
    """
    # "<param>_qdq" -> parameter product, restricted to bitwidth <= 16
    qdq_products = {}
    for op in self.connected_graph.get_all_ops().values():
        for product, _ in op.parameters.values():
            if self.qc_quantize_op_dict[product.name].bitwidth <= 16:
                qdq_products[f"{product.name}_qdq"] = product

    # Names of the original (un-suffixed) parameters selected above
    source_names = {product.name for product in qdq_products.values()}
    source_graph = self.model.model.graph

    def _belongs_to_partial_graph(node):
        # Keep nodes that read a selected parameter ...
        if any(name in source_names for name in node.input):
            return True
        # ... and Constant nodes that produce one
        return node.op_type == "Constant" and \
            any(name in source_names for name in node.output)

    graph_outputs = [
        onnx.helper.make_tensor_value_info(output_name,
                                           onnx.TensorProto.FLOAT,
                                           shape=product.shape)
        for output_name, product in qdq_products.items()
    ]
    graph_initializers = [
        tensor for tensor in source_graph.initializer
        if tensor.name in source_names
    ]
    graph_nodes = [
        node for node in source_graph.node
        if _belongs_to_partial_graph(node)
    ]

    partial_model = onnx.helper.make_model(
        graph=onnx.helper.make_graph(
            name="partial",
            inputs=[],
            outputs=graph_outputs,
            initializer=graph_initializers,
            nodes=graph_nodes,
        )
    )

    sess = self.build_session(partial_model, ["CPUExecutionProvider"])
    output_names = list(qdq_products)
    results = sess.run(output_names, {})
    return dict(zip(output_names, results))
1338+
@staticmethod
def _overwrite_parameters(model: onnx.ModelProto, parameters: Dict[str, np.ndarray]):
    """
    Overwrite parameter values stored in ``model`` in place.

    Each key of ``parameters`` has the form ``"<name>_qdq"``; the value
    replaces the data of the initializer called ``<name>`` or, failing that,
    of the ``Constant`` node whose first output is ``<name>``.

    :param model: ONNX model whose initializers / Constant nodes are modified in place
    :param parameters: Mapping from ``"<name>_qdq"`` to the new value.
                       This mapping is not modified.
    :raises RuntimeError: If some entry of ``parameters`` matches neither an
                          initializer nor a Constant node, or if a matched
                          Constant node holds string data
    """
    # Consume a private copy so the caller's dictionary is left untouched.
    remaining = dict(parameters)

    # Initializers get first pick; whatever they claim is removed from
    # ``remaining`` so a Constant node cannot claim the same entry again.
    initializers = [
        (init, remaining.pop(f"{init.name}_qdq"))
        for init in model.graph.initializer
        if f"{init.name}_qdq" in remaining
    ]
    constants = [
        (node, remaining.pop(f"{node.output[0]}_qdq"))
        for node in model.graph.node
        if node.op_type == "Constant" and f"{node.output[0]}_qdq" in remaining
    ]

    # Anything still left was matched by neither an initializer nor a Constant.
    not_found = list(remaining)
    if not_found:
        raise RuntimeError(
            f"Couldn't find parameters: {not_found}"
        )

    # Validate every constant before mutating anything so that a failure
    # does not leave the model half-updated.
    for const, _ in constants:
        if any(attr.name in ("value_string", "value_strings")
               for attr in const.attribute):
            raise RuntimeError(f"String constant {const.name} can't be quantized")

    def _set_raw_data(tensor, value):
        # A TensorProto must carry its values in exactly one field; drop any
        # typed storage before writing raw_data.
        # NOTE(review): ``value`` is written as-is; this assumes its dtype
        # already matches ``tensor.data_type`` -- confirm with the producer.
        for field in ("float_data", "int32_data", "int64_data",
                      "double_data", "uint64_data", "string_data"):
            tensor.ClearField(field)
        tensor.raw_data = value.tobytes()

    for init, qdq_param in initializers:
        _set_raw_data(init, qdq_param)

    for const, qdq_param in constants:
        # Only the first recognised value attribute is rewritten.
        for attr in const.attribute:
            if attr.name == "value":
                _set_raw_data(attr.t, qdq_param)
            elif attr.name == "value_float":
                # AttributeProto stores a scalar float in field ``f``
                attr.f = float(qdq_param)
            elif attr.name == "value_floats":
                attr.ClearField("floats")
                attr.floats.extend(qdq_param.astype(np.float32).tolist())
            elif attr.name == "value_int":
                # AttributeProto stores a scalar int in field ``i``
                attr.i = int(qdq_param)
            elif attr.name == "value_ints":
                attr.ClearField("ints")
                attr.ints.extend(qdq_param.astype(np.int64).tolist())
            else:
                continue
            break
1389+
12921390
12931391# pylint: disable=too-many-locals, too-many-branches
12941392def load_encodings_to_sim (quant_sim_model : QuantizationSimModel , onnx_encoding_path : str , strict = True ) -> \
0 commit comments