Skip to content

Commit 96d61db

Browse files
michaelgtuttlequic-mtuttle
authored andcommitted
Debug aimet_onnx quantizer group creation
Signed-off-by: Michael Tuttle <quic_mtuttle@quicinc.com> Co-authored-by: Michael Tuttle <quic_mtuttle@quicinc.com>
1 parent 88d29c5 commit 96d61db

File tree

3 files changed

+187
-237
lines changed

3 files changed

+187
-237
lines changed

TrainingExtensions/onnx/src/python/aimet_onnx/amp/quantizer_groups.py

Lines changed: 80 additions & 183 deletions
Original file line numberDiff line numberDiff line change
@@ -42,16 +42,14 @@
4242
from dataclasses import dataclass, field
4343

4444
from aimet_common.connected_graph.operation import Op
45-
from aimet_common.connected_graph.connectedgraph_utils import get_all_input_ops, get_all_output_ops
4645

4746
from aimet_common.amp.utils import CANDIDATE_WITH_DTYPE
4847

49-
from aimet_common.connected_graph.connectedgraph import get_ordered_ops
5048
from aimet_common.amp.quantizer_groups import QuantizerGroupBase, get_supported_candidates_for_quantizers, \
51-
compute_baseline_candidate_options, find_valid_ops
49+
compute_baseline_candidate_options
5250
from aimet_common.utils import AimetLogger
5351

54-
from aimet_onnx.meta.connectedgraph import ConnectedGraph
52+
from aimet_onnx.meta.connectedgraph import ConnectedGraph, Product
5553
from aimet_onnx.quantsim import QuantizationSimModel
5654
from aimet_onnx.qc_quantize_op import QcQuantizeOp
5755

@@ -163,208 +161,107 @@ def get_param_quantizers(self, name_to_quantizer_dict):
163161
ops_not_to_traverse = ['Shape']
164162

165163

166-
def find_op_groups(connected_graph: ConnectedGraph) -> Dict:
164+
def find_quantizer_group(sim: QuantizationSimModel):
167165
"""
168-
Finds parent child groups based on following rules.
169-
1) If there is a direct connection between two ops, op1 and op2, then op1 is parent of op2 and they form a group
170-
2) If the input to an op (op1) is shared with another op (op2), the op producing the input (op0) is the parent,
171-
and op1 and op2 are the children
166+
Create quantizer groups following these rules:
172167
173-
:param connected_graph: Connected graph
174-
:return: Dict of parent (key) and children (value) groups
175-
"""
176-
# Get ordered ops in Connected graph
177-
ordered_ops = get_ordered_ops(connected_graph.starting_ops)
178-
valid_ops = find_valid_ops(connected_graph, ops_not_to_traverse)
168+
1) All quantized tensors exist in exactly 1 quantizer group
169+
2) A parameter's quantizer group contains all other tensors which feed into all ops that the parameter feeds into
170+
3) Any quantizer group must not be decomposable into multiple quantizer groups that still follow rules 1 and 2
179171
180-
parent_child_op_groups = defaultdict(list)
181-
map_for_skipped_ops = {}
172+
Note that two activations feeding into the same binary op would not fall into the same quantizer group using the above
173+
definition, while an activation and parameter feeding into a binary op would be in the same group
174+
"""
175+
quantized_tensors = {name for name, quantizer in sim.qc_quantize_op_dict.items() if quantizer.enabled}
176+
visited_tensors = set()
177+
quantizer_groups = []
182178

183-
for op in ordered_ops:
184-
if op.dotted_name not in valid_ops or op.type in op_types_to_ignore:
179+
for tensor_name in quantized_tensors:
180+
# Avoid re-creating duplicate quantizer groups
181+
if tensor_name in visited_tensors:
185182
continue
186-
_find_parent_child_op_groups(op, parent_child_op_groups, map_for_skipped_ops)
187183

188-
return parent_child_op_groups
184+
# Get all tensors belonging to the same group
185+
# TODO: Derive op_types_to_ignore from config file
186+
related_tensors = _get_related_quantizers(tensor_name,
187+
quantized_tensors,
188+
sim.connected_graph,
189+
op_types_to_ignore)
189190

191+
visited_tensors |= related_tensors
190192

191-
def _find_parent_child_op_groups(op: Op, parent_child_op_groups: Dict, map_for_skipped_ops: Dict):
192-
"""
193-
Finds op groups along the parent to child flow
194-
:param op: Op
195-
:param parent_child_op_groups: parent child op groups dict
196-
:param map_for_skipped_ops: map to find first skipped parents of skipped ops
197-
"""
198-
if op.outputs:
199-
consumers = op.output_ops
200-
for consumer in consumers:
201-
dotted_name = op.dotted_name
202-
if consumer.type in ops_not_to_traverse:
203-
continue
204-
if op.dotted_name in map_for_skipped_ops:
205-
dotted_name = map_for_skipped_ops[op.dotted_name]
206-
207-
if consumer.type in op_types_to_ignore:
208-
map_for_skipped_ops[consumer.dotted_name] = dotted_name
209-
_find_parent_child_op_groups(consumer, parent_child_op_groups, map_for_skipped_ops)
210-
else:
211-
if consumer.dotted_name not in parent_child_op_groups[dotted_name]:
212-
parent_child_op_groups[dotted_name].append(consumer.dotted_name)
213-
if not consumers and op.dotted_name in map_for_skipped_ops and \
214-
map_for_skipped_ops[op.dotted_name] not in parent_child_op_groups:
215-
parent_child_op_groups[map_for_skipped_ops[op.dotted_name]] = []
216-
else:
217-
dotted_name = op.dotted_name
218-
parent_child_op_groups[dotted_name].append(None)
219-
220-
221-
def find_quantizer_group(sim: QuantizationSimModel) -> Tuple[Dict, List[QuantizerGroup]]:
222-
"""
223-
Finds quantizer groups in a quantization sim
224-
:param sim: Quantization sim
225-
:return: Dictionary of quantized op name to sim.quantizer_config object, List of quantizer groups
226-
"""
227-
# Get connected graph from quantsim
228-
connected_graph = sim.connected_graph
193+
# Use ConnectedGraph to determine which tensors are parameters vs. activations
194+
parameters = {tensor for tensor in related_tensors if sim.connected_graph.get_all_products()[tensor].is_parm}
195+
activations = related_tensors - parameters
229196

230-
if connected_graph is None:
231-
raise AssertionError('Aborting Auto Mixed Precision, connected graph needs to exist for Auto Mixed precision')
197+
quantizer_group = QuantizerGroup(tuple(parameters), tuple(activations))
198+
logger.debug('Quantizer Group added: %s', quantizer_group)
199+
quantizer_groups.append(quantizer_group)
232200

233-
# Find parent to children mapping for connected graph ops
234-
parent_child_op_groups = find_op_groups(connected_graph)
201+
return sim.qc_quantize_op_dict, quantizer_groups
235202

236-
# Find mapping of quantized op name to quantizer info
237-
op_name_to_quantizer_dict = _get_op_name_to_act_quantizer_name_dicts(sim)
238-
op_to_param_dict = _get_op_to_param_name_dict(sim)
239203

240-
quantizer_groups = []
204+
def _get_related_quantizers(tensor: str, quantized_tensors: set[str], connected_graph: ConnectedGraph, pass_through_op_types: List[str]):
205+
"""
206+
Get all tensors for which the valid configurations depend on the configuration of `tensor`.
241207
242-
_add_input_quantizer_group(op_to_param_dict, sim, quantizer_groups)
243-
244-
for parent, children in parent_child_op_groups.items():
245-
activation_quantizers = []
246-
parameter_quantizers = []
247-
if parent in op_name_to_quantizer_dict:
248-
activation_quantizers.append(op_name_to_quantizer_dict[parent])
249-
for child in children:
250-
if child and child in op_to_param_dict:
251-
parameter_quantizers.append(op_to_param_dict[child])
252-
child_cg_op = connected_graph.get_op_from_module_name(child)
253-
for inp_prod in child_cg_op.inputs:
254-
if inp_prod.is_const and inp_prod.name in sim.qc_quantize_op_dict and \
255-
sim.qc_quantize_op_dict[inp_prod.name].enabled:
256-
activation_quantizers.append(inp_prod.name)
257-
if activation_quantizers or parameter_quantizers:
258-
_add_quantizer_group(quantizer_groups, tuple(activation_quantizers), tuple(parameter_quantizers))
259-
260-
_add_output_quantizer_group(op_name_to_quantizer_dict, sim, quantizer_groups)
208+
Dependant tensors are all inputs to all ops that consume `tensor`, and all tensors which depend on these tensors.
209+
"""
210+
tensor_queue = [tensor]
211+
related_quantized_tensors = set()
212+
visited_ops = set()
261213

262-
return sim.qc_quantize_op_dict, quantizer_groups
214+
while tensor_queue:
215+
name = tensor_queue.pop(0)
216+
product: Product = connected_graph.get_all_products()[name]
263217

218+
# Find all ops that consume this tensor
219+
consumers = _get_tensor_consumers(product, pass_through_op_types)
264220

265-
def _add_quantizer_group(quantizer_groups: List[QuantizerGroup], activation_quantizers: Tuple,
266-
parameter_quantizers: Tuple):
267-
"""
268-
Adds quantizer group to the quantizer groups list
269-
:param quantizer_groups: List of Quantizer groups
270-
:param activation_quantizers: Tuple of activation quantizers
271-
:param parameter_quantizers: Tuple of parameter quantizers
272-
"""
273-
quantizer_group = QuantizerGroup(parameter_quantizers=parameter_quantizers,
274-
activation_quantizers=activation_quantizers)
275-
if quantizer_group not in quantizer_groups:
276-
quantizer_groups.append(quantizer_group)
277-
logger.info('Quantizer Group added: %s', quantizer_group)
221+
# Ignore already-visited ops
222+
consumers -= visited_ops
223+
visited_ops |= consumers
278224

225+
# For any consumer which has a quantized parameter, add all inputs to the quantizer group
226+
input_tensors = {name}
227+
for op in consumers:
228+
if any(name in quantized_tensors for name in op.parameters.keys()):
229+
input_tensors |= set(t.name for t in _get_op_input_tensors(op, pass_through_op_types))
279230

280-
def _add_input_quantizer_group(op_to_param_dict: Dict, sim: QuantizationSimModel, quantizer_groups: List):
281-
"""
282-
Adds input's (of the model) quantizer group
283-
:param op_to_param_dict: Key: op_name Value: Weight name associated
284-
:param sim: Quantization Sim
285-
:param quantizer_groups: Quantizer Groups List
286-
"""
287-
conn_graph_ops = get_all_input_ops(sim.connected_graph)
288-
act_and_param_quants = []
289-
for input_op in conn_graph_ops:
290-
parent_child_op_groups = {input_op.dotted_name: [input_op.dotted_name]}
291-
if input_op.type in op_types_to_ignore:
292-
parent_child_op_groups, map_for_skipped_ops = {input_op.dotted_name: []}, {}
293-
_find_parent_child_op_groups(input_op, parent_child_op_groups, map_for_skipped_ops)
294-
parameter_quantizers = []
295-
activation_quantizers = []
296-
for child_name in parent_child_op_groups[input_op.dotted_name]:
297-
if child_name in op_to_param_dict:
298-
parameter_quantizers.append(op_to_param_dict[child_name])
299-
for input_product in input_op.inputs:
300-
activation_quantizer = input_product.tensor_dict[input_op]
301-
if isinstance(activation_quantizer, str) and \
302-
activation_quantizer in sim.activation_names and \
303-
sim.qc_quantize_op_dict[activation_quantizer].enabled:
304-
activation_quantizers.append(input_product.tensor_dict[input_op])
305-
if activation_quantizers or parameter_quantizers:
306-
for quant in act_and_param_quants:
307-
if set(quant[0]).intersection(set(activation_quantizers)) or \
308-
set(quant[1]).intersection(set(parameter_quantizers)):
309-
quant[0].extend(activation_quantizers)
310-
quant[1].extend(parameter_quantizers)
311-
break
312-
else:
313-
act_and_param_quants.append([activation_quantizers, parameter_quantizers])
314-
315-
for quant in act_and_param_quants:
316-
_add_quantizer_group(quantizer_groups, tuple(set(quant[0])), tuple(set(quant[1])))
317-
318-
319-
def _add_output_quantizer_group(op_name_to_quantizer_dict: Dict, sim: QuantizationSimModel, quantizer_groups: List):
320-
"""
321-
Adds output's (of the model) quantizer group
322-
:param op_name_to_quantizer_dict: Key: op_name Value: quantizer associated with op name
323-
:param sim: Quantization Sim
324-
:param quantizer_groups: Quantizer Groups List
325-
"""
326-
conn_graph_ops = get_all_output_ops(sim.connected_graph)
327-
for output_op in conn_graph_ops:
328-
activation_quantizers = []
329-
if output_op.dotted_name in op_name_to_quantizer_dict:
330-
activation_quantizers.append(op_name_to_quantizer_dict[output_op.dotted_name])
331-
if activation_quantizers:
332-
_add_quantizer_group(quantizer_groups, tuple(activation_quantizers), ())
231+
# Only look at quantized tensors which we haven't visited
232+
input_tensors = (input_tensors & quantized_tensors) - related_quantized_tensors
233+
related_quantized_tensors |= input_tensors
333234

235+
# Add newly found tensors to tensor_queue
236+
for item in input_tensors:
237+
if item not in tensor_queue:
238+
tensor_queue.append(item)
334239

335-
def _get_op_to_param_name_dict(sim: QuantizationSimModel) -> Dict:
336-
"""
337-
Creates the dict where param name (weight) is mapped to op's name
338-
:param sim: Quantization Sim
339-
"""
340-
op_to_param_dict = {}
341-
conn_graph_ops = sim.connected_graph.get_all_ops()
342-
for op in conn_graph_ops.values():
343-
for param_name in op.parameters:
344-
_, param_type = op.parameters[param_name]
345-
if param_type == 'weight' and sim.qc_quantize_op_dict[param_name].enabled:
346-
op_to_param_dict[op.dotted_name] = param_name
240+
return related_quantized_tensors
347241

348-
return op_to_param_dict
349242

243+
def _get_op_input_tensors(op: Op, pass_through_op_types: List[str]) -> List[Product]:
244+
""" Get all input tensors to `op`, traversing through ops of type `pass_through_op_types` """
245+
inputs = []
246+
for inp in op.inputs:
247+
# Pass through ops which don't have output quantizers if necessary
248+
while inp.producer and inp.producer.type in pass_through_op_types:
249+
inp = inp.producer.inputs[0]
250+
inputs.append(inp)
350251

351-
def _get_op_name_to_act_quantizer_name_dicts(sim: QuantizationSimModel) -> Dict:
352-
"""
353-
Creates the dict where param quantizers if enabled are mapped to their param_names and activation
354-
quantizer if enabled is mapped to it's inputs name
355-
:param sim: Quantization Sim
356-
:return op_name_to_activation_quantizer_name_dict
357-
"""
358-
op_name_to_activation_quantizer_name_dict = {}
359-
for node in sim.model.model.graph.node:
360-
if 'QcQuantizeOp' in node.name:
361-
continue
362-
for output_product in node.output:
363-
if output_product in sim.activation_names:
364-
activation_quantizer_op = sim.qc_quantize_op_dict[output_product]
365-
if activation_quantizer_op.enabled:
366-
op_name_to_activation_quantizer_name_dict[node.name] = output_product
367-
return op_name_to_activation_quantizer_name_dict
252+
return inputs
253+
254+
255+
def _get_tensor_consumers(product: Product, pass_through_op_types: List[str]):
256+
""" Get all consumers of `product`, traversing through ops of type `pass_through_op_types` """
257+
consumers = set()
258+
for consumer in product.consumers:
259+
if consumer.type in pass_through_op_types:
260+
consumers |= {op for output in consumer.outputs for op in _get_tensor_consumers(output, pass_through_op_types)}
261+
else:
262+
consumers.add(consumer)
263+
264+
return consumers
368265

369266

370267
def find_supported_candidates(quantizer_groups: List[QuantizerGroup],

TrainingExtensions/onnx/test/python/test_mixed_precision.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def eval_func(model, args):
124124
# quantizer group 1
125125
input_quantizer = sim.qc_quantize_op_dict['input']
126126

127-
conv0_param_quantizer = list(sim.qc_quantize_op_dict.values())[0]
127+
conv0_param_quantizer = sim.qc_quantize_op_dict[sim.connected_graph.ordered_ops[0].inputs[1].name]
128128
if input_quantizer.enabled and conv0_param_quantizer.enabled:
129129
quantizer_1 = (
130130
(input_quantizer.bitwidth, QuantizationDataType.int),
@@ -268,7 +268,8 @@ def test_phase1(self, sim, candidates, forward_pass_callback, eval_callback_phas
268268
results_dir, True, forward_pass_callback)
269269
algo.set_baseline()
270270

271-
candidate = algo.quantizer_groups[0].get_candidate(algo._module_name_dict)
271+
input_group = [qg for qg in algo.quantizer_groups if "input" in qg.activation_quantizers][0]
272+
candidate = input_group.get_candidate(algo._module_name_dict)
272273
# Check if quantizer group is set to maximum bitwidth
273274
assert algo.baseline_candidate == candidate
274275

@@ -400,13 +401,14 @@ def _run_phase2(self, algo, allowed_accuracy_drop, search_algo):
400401
algo.min_candidate = W16A8
401402
fp32_acc = 1.0
402403

404+
input_group = [qg for qg in algo.quantizer_groups if "input" in qg.activation_quantizers][0]
405+
fc_group = [qg for qg in algo.quantizer_groups if "fc.weight" in qg.parameter_quantizers][0]
403406
accuracy_list = [
404-
(algo.quantizer_groups[0], W8A16, phase1_eval_score_lookup_table[(W8A16, "fp32")], 100),
405-
(algo.quantizer_groups[0], W16A8, phase1_eval_score_lookup_table[(W16A8, "fp32")], 90),
406-
(algo.quantizer_groups[8], W8A16, phase1_eval_score_lookup_table[("fp32", W8A16)], 80),
407-
(algo.quantizer_groups[8], W16A8, phase1_eval_score_lookup_table[("fp32", W16A8)], 70),
407+
(input_group, W8A16, phase1_eval_score_lookup_table[(W8A16, "fp32")], 100),
408+
(input_group, W16A8, phase1_eval_score_lookup_table[(W16A8, "fp32")], 90),
409+
(fc_group, W8A16, phase1_eval_score_lookup_table[("fp32", W8A16)], 80),
410+
(fc_group, W16A8, phase1_eval_score_lookup_table[("fp32", W16A8)], 70),
408411
]
409-
410412
return algo._create_pareto_front_list(allowed_accuracy_drop, accuracy_list, fp32_acc,
411413
algo.baseline_candidate, algo.min_candidate, search_algo, phase2_reverse = False)
412414

0 commit comments

Comments
 (0)