|
42 | 42 | from dataclasses import dataclass, field |
43 | 43 |
|
44 | 44 | from aimet_common.connected_graph.operation import Op |
45 | | -from aimet_common.connected_graph.connectedgraph_utils import get_all_input_ops, get_all_output_ops |
46 | 45 |
|
47 | 46 | from aimet_common.amp.utils import CANDIDATE_WITH_DTYPE |
48 | 47 |
|
49 | | -from aimet_common.connected_graph.connectedgraph import get_ordered_ops |
50 | 48 | from aimet_common.amp.quantizer_groups import QuantizerGroupBase, get_supported_candidates_for_quantizers, \ |
51 | | - compute_baseline_candidate_options, find_valid_ops |
| 49 | + compute_baseline_candidate_options |
52 | 50 | from aimet_common.utils import AimetLogger |
53 | 51 |
|
54 | | -from aimet_onnx.meta.connectedgraph import ConnectedGraph |
| 52 | +from aimet_onnx.meta.connectedgraph import ConnectedGraph, Product |
55 | 53 | from aimet_onnx.quantsim import QuantizationSimModel |
56 | 54 | from aimet_onnx.qc_quantize_op import QcQuantizeOp |
57 | 55 |
|
@@ -163,208 +161,107 @@ def get_param_quantizers(self, name_to_quantizer_dict): |
163 | 161 | ops_not_to_traverse = ['Shape'] |
164 | 162 |
|
165 | 163 |
|
166 | | -def find_op_groups(connected_graph: ConnectedGraph) -> Dict: |
| 164 | +def find_quantizer_group(sim: QuantizationSimModel): |
167 | 165 | """ |
168 | | - Finds parent child groups based on following rules. |
169 | | - 1) If there is a direct connection between two ops, op1 and op2, then op1 is parent of op2 and they form a group |
170 | | - 2) If the input to an op (op1) is shared with another op (op2), the op producing the input (op0) is the parent, |
171 | | - and op1 and op2 are the children |
| 166 | + Create quantizer groups following these rules: |
172 | 167 |
|
173 | | - :param connected_graph: Connected graph |
174 | | - :return: Dict of parent (key) and children (value) groups |
175 | | - """ |
176 | | - # Get ordered ops in Connected graph |
177 | | - ordered_ops = get_ordered_ops(connected_graph.starting_ops) |
178 | | - valid_ops = find_valid_ops(connected_graph, ops_not_to_traverse) |
| 168 | + 1) All quantized tensors exist in exactly 1 quantizer group |
| 169 | + 2) A parameter's quantizer group contains all other tensors which feed into all ops that the parameter feeds into |
| 170 | + 3) Any quantizer group must not be decomposable into multiple quantizer groups that still follow rules 1 and 2 |
179 | 171 |
|
180 | | - parent_child_op_groups = defaultdict(list) |
181 | | - map_for_skipped_ops = {} |
| 172 | + Note that two activations feeding into the same binary op would not fall into the same quantizer group using the above |
| 173 | + definition, while an activation and parameter feeding into a binary op would be in the same group |
| 174 | + """ |
| 175 | + quantized_tensors = {name for name, quantizer in sim.qc_quantize_op_dict.items() if quantizer.enabled} |
| 176 | + visited_tensors = set() |
| 177 | + quantizer_groups = [] |
182 | 178 |
|
183 | | - for op in ordered_ops: |
184 | | - if op.dotted_name not in valid_ops or op.type in op_types_to_ignore: |
| 179 | + for tensor_name in quantized_tensors: |
| 180 | + # Avoid re-creating duplicate quantizer groups |
| 181 | + if tensor_name in visited_tensors: |
185 | 182 | continue |
186 | | - _find_parent_child_op_groups(op, parent_child_op_groups, map_for_skipped_ops) |
187 | 183 |
|
188 | | - return parent_child_op_groups |
| 184 | + # Get all tensors belonging to the same group |
| 185 | + # TODO: Derive op_types_to_ignore from config file |
| 186 | + related_tensors = _get_related_quantizers(tensor_name, |
| 187 | + quantized_tensors, |
| 188 | + sim.connected_graph, |
| 189 | + op_types_to_ignore) |
189 | 190 |
|
| 191 | + visited_tensors |= related_tensors |
190 | 192 |
|
191 | | -def _find_parent_child_op_groups(op: Op, parent_child_op_groups: Dict, map_for_skipped_ops: Dict): |
192 | | - """ |
193 | | - Finds op groups along the parent to child flow |
194 | | - :param op: Op |
195 | | - :param parent_child_op_groups: parent child op groups dict |
196 | | - :param map_for_skipped_ops: map to find first skipped parents of skipped ops |
197 | | - """ |
198 | | - if op.outputs: |
199 | | - consumers = op.output_ops |
200 | | - for consumer in consumers: |
201 | | - dotted_name = op.dotted_name |
202 | | - if consumer.type in ops_not_to_traverse: |
203 | | - continue |
204 | | - if op.dotted_name in map_for_skipped_ops: |
205 | | - dotted_name = map_for_skipped_ops[op.dotted_name] |
206 | | - |
207 | | - if consumer.type in op_types_to_ignore: |
208 | | - map_for_skipped_ops[consumer.dotted_name] = dotted_name |
209 | | - _find_parent_child_op_groups(consumer, parent_child_op_groups, map_for_skipped_ops) |
210 | | - else: |
211 | | - if consumer.dotted_name not in parent_child_op_groups[dotted_name]: |
212 | | - parent_child_op_groups[dotted_name].append(consumer.dotted_name) |
213 | | - if not consumers and op.dotted_name in map_for_skipped_ops and \ |
214 | | - map_for_skipped_ops[op.dotted_name] not in parent_child_op_groups: |
215 | | - parent_child_op_groups[map_for_skipped_ops[op.dotted_name]] = [] |
216 | | - else: |
217 | | - dotted_name = op.dotted_name |
218 | | - parent_child_op_groups[dotted_name].append(None) |
219 | | - |
220 | | - |
221 | | -def find_quantizer_group(sim: QuantizationSimModel) -> Tuple[Dict, List[QuantizerGroup]]: |
222 | | - """ |
223 | | - Finds quantizer groups in a quantization sim |
224 | | - :param sim: Quantization sim |
225 | | - :return: Dictionary of quantized op name to sim.quantizer_config object, List of quantizer groups |
226 | | - """ |
227 | | - # Get connected graph from quantsim |
228 | | - connected_graph = sim.connected_graph |
| 193 | + # Use ConnectedGraph to determine which tensors are parameters vs. activations |
| 194 | + parameters = {tensor for tensor in related_tensors if sim.connected_graph.get_all_products()[tensor].is_parm} |
| 195 | + activations = related_tensors - parameters |
229 | 196 |
|
230 | | - if connected_graph is None: |
231 | | - raise AssertionError('Aborting Auto Mixed Precision, connected graph needs to exist for Auto Mixed precision') |
| 197 | + quantizer_group = QuantizerGroup(tuple(parameters), tuple(activations)) |
| 198 | + logger.debug('Quantizer Group added: %s', quantizer_group) |
| 199 | + quantizer_groups.append(quantizer_group) |
232 | 200 |
|
233 | | - # Find parent to children mapping for connected graph ops |
234 | | - parent_child_op_groups = find_op_groups(connected_graph) |
| 201 | + return sim.qc_quantize_op_dict, quantizer_groups |
235 | 202 |
|
236 | | - # Find mapping of quantized op name to quantizer info |
237 | | - op_name_to_quantizer_dict = _get_op_name_to_act_quantizer_name_dicts(sim) |
238 | | - op_to_param_dict = _get_op_to_param_name_dict(sim) |
239 | 203 |
|
240 | | - quantizer_groups = [] |
| 204 | +def _get_related_quantizers(tensor: str, quantized_tensors: set[str], connected_graph: ConnectedGraph, pass_through_op_types: List[str]): |
| 205 | + """ |
| 206 | + Get all tensors for which the valid configurations depend on the configuration of `tensor`. |
241 | 207 |
|
242 | | - _add_input_quantizer_group(op_to_param_dict, sim, quantizer_groups) |
243 | | - |
244 | | - for parent, children in parent_child_op_groups.items(): |
245 | | - activation_quantizers = [] |
246 | | - parameter_quantizers = [] |
247 | | - if parent in op_name_to_quantizer_dict: |
248 | | - activation_quantizers.append(op_name_to_quantizer_dict[parent]) |
249 | | - for child in children: |
250 | | - if child and child in op_to_param_dict: |
251 | | - parameter_quantizers.append(op_to_param_dict[child]) |
252 | | - child_cg_op = connected_graph.get_op_from_module_name(child) |
253 | | - for inp_prod in child_cg_op.inputs: |
254 | | - if inp_prod.is_const and inp_prod.name in sim.qc_quantize_op_dict and \ |
255 | | - sim.qc_quantize_op_dict[inp_prod.name].enabled: |
256 | | - activation_quantizers.append(inp_prod.name) |
257 | | - if activation_quantizers or parameter_quantizers: |
258 | | - _add_quantizer_group(quantizer_groups, tuple(activation_quantizers), tuple(parameter_quantizers)) |
259 | | - |
260 | | - _add_output_quantizer_group(op_name_to_quantizer_dict, sim, quantizer_groups) |
| 208 | + Dependant tensors are all inputs to all ops that consume `tensor`, and all tensors which depend on these tensors. |
| 209 | + """ |
| 210 | + tensor_queue = [tensor] |
| 211 | + related_quantized_tensors = set() |
| 212 | + visited_ops = set() |
261 | 213 |
|
262 | | - return sim.qc_quantize_op_dict, quantizer_groups |
| 214 | + while tensor_queue: |
| 215 | + name = tensor_queue.pop(0) |
| 216 | + product: Product = connected_graph.get_all_products()[name] |
263 | 217 |
|
| 218 | + # Find all ops that consume this tensor |
| 219 | + consumers = _get_tensor_consumers(product, pass_through_op_types) |
264 | 220 |
|
265 | | -def _add_quantizer_group(quantizer_groups: List[QuantizerGroup], activation_quantizers: Tuple, |
266 | | - parameter_quantizers: Tuple): |
267 | | - """ |
268 | | - Adds quantizer group to the quantizer groups list |
269 | | - :param quantizer_groups: List of Quantizer groups |
270 | | - :param activation_quantizers: Tuple of activation quantizers |
271 | | - :param parameter_quantizers: Tuple of parameter quantizers |
272 | | - """ |
273 | | - quantizer_group = QuantizerGroup(parameter_quantizers=parameter_quantizers, |
274 | | - activation_quantizers=activation_quantizers) |
275 | | - if quantizer_group not in quantizer_groups: |
276 | | - quantizer_groups.append(quantizer_group) |
277 | | - logger.info('Quantizer Group added: %s', quantizer_group) |
| 221 | + # Ignore already-visited ops |
| 222 | + consumers -= visited_ops |
| 223 | + visited_ops |= consumers |
278 | 224 |
|
| 225 | + # For any consumer which has a quantized parameter, add all inputs to the quantizer group |
| 226 | + input_tensors = {name} |
| 227 | + for op in consumers: |
| 228 | + if any(name in quantized_tensors for name in op.parameters.keys()): |
| 229 | + input_tensors |= set(t.name for t in _get_op_input_tensors(op, pass_through_op_types)) |
279 | 230 |
|
280 | | -def _add_input_quantizer_group(op_to_param_dict: Dict, sim: QuantizationSimModel, quantizer_groups: List): |
281 | | - """ |
282 | | - Adds input's (of the model) quantizer group |
283 | | - :param op_to_param_dict: Key: op_name Value: Weight name associated |
284 | | - :param sim: Quantization Sim |
285 | | - :param quantizer_groups: Quantizer Groups List |
286 | | - """ |
287 | | - conn_graph_ops = get_all_input_ops(sim.connected_graph) |
288 | | - act_and_param_quants = [] |
289 | | - for input_op in conn_graph_ops: |
290 | | - parent_child_op_groups = {input_op.dotted_name: [input_op.dotted_name]} |
291 | | - if input_op.type in op_types_to_ignore: |
292 | | - parent_child_op_groups, map_for_skipped_ops = {input_op.dotted_name: []}, {} |
293 | | - _find_parent_child_op_groups(input_op, parent_child_op_groups, map_for_skipped_ops) |
294 | | - parameter_quantizers = [] |
295 | | - activation_quantizers = [] |
296 | | - for child_name in parent_child_op_groups[input_op.dotted_name]: |
297 | | - if child_name in op_to_param_dict: |
298 | | - parameter_quantizers.append(op_to_param_dict[child_name]) |
299 | | - for input_product in input_op.inputs: |
300 | | - activation_quantizer = input_product.tensor_dict[input_op] |
301 | | - if isinstance(activation_quantizer, str) and \ |
302 | | - activation_quantizer in sim.activation_names and \ |
303 | | - sim.qc_quantize_op_dict[activation_quantizer].enabled: |
304 | | - activation_quantizers.append(input_product.tensor_dict[input_op]) |
305 | | - if activation_quantizers or parameter_quantizers: |
306 | | - for quant in act_and_param_quants: |
307 | | - if set(quant[0]).intersection(set(activation_quantizers)) or \ |
308 | | - set(quant[1]).intersection(set(parameter_quantizers)): |
309 | | - quant[0].extend(activation_quantizers) |
310 | | - quant[1].extend(parameter_quantizers) |
311 | | - break |
312 | | - else: |
313 | | - act_and_param_quants.append([activation_quantizers, parameter_quantizers]) |
314 | | - |
315 | | - for quant in act_and_param_quants: |
316 | | - _add_quantizer_group(quantizer_groups, tuple(set(quant[0])), tuple(set(quant[1]))) |
317 | | - |
318 | | - |
319 | | -def _add_output_quantizer_group(op_name_to_quantizer_dict: Dict, sim: QuantizationSimModel, quantizer_groups: List): |
320 | | - """ |
321 | | - Adds output's (of the model) quantizer group |
322 | | - :param op_name_to_quantizer_dict: Key: op_name Value: quantizer associated with op name |
323 | | - :param sim: Quantization Sim |
324 | | - :param quantizer_groups: Quantizer Groups List |
325 | | - """ |
326 | | - conn_graph_ops = get_all_output_ops(sim.connected_graph) |
327 | | - for output_op in conn_graph_ops: |
328 | | - activation_quantizers = [] |
329 | | - if output_op.dotted_name in op_name_to_quantizer_dict: |
330 | | - activation_quantizers.append(op_name_to_quantizer_dict[output_op.dotted_name]) |
331 | | - if activation_quantizers: |
332 | | - _add_quantizer_group(quantizer_groups, tuple(activation_quantizers), ()) |
| 231 | + # Only look at quantized tensors which we haven't visited |
| 232 | + input_tensors = (input_tensors & quantized_tensors) - related_quantized_tensors |
| 233 | + related_quantized_tensors |= input_tensors |
333 | 234 |
|
| 235 | + # Add newly found tensors to tensor_queue |
| 236 | + for item in input_tensors: |
| 237 | + if item not in tensor_queue: |
| 238 | + tensor_queue.append(item) |
334 | 239 |
|
335 | | -def _get_op_to_param_name_dict(sim: QuantizationSimModel) -> Dict: |
336 | | - """ |
337 | | - Creates the dict where param name (weight) is mapped to op's name |
338 | | - :param sim: Quantization Sim |
339 | | - """ |
340 | | - op_to_param_dict = {} |
341 | | - conn_graph_ops = sim.connected_graph.get_all_ops() |
342 | | - for op in conn_graph_ops.values(): |
343 | | - for param_name in op.parameters: |
344 | | - _, param_type = op.parameters[param_name] |
345 | | - if param_type == 'weight' and sim.qc_quantize_op_dict[param_name].enabled: |
346 | | - op_to_param_dict[op.dotted_name] = param_name |
| 240 | + return related_quantized_tensors |
347 | 241 |
|
348 | | - return op_to_param_dict |
349 | 242 |
|
| 243 | +def _get_op_input_tensors(op: Op, pass_through_op_types: List[str]) -> List[Product]: |
| 244 | + """ Get all input tensors to `op`, traversing through ops of type `pass_through_op_types` """ |
| 245 | + inputs = [] |
| 246 | + for inp in op.inputs: |
| 247 | + # Pass through ops which don't have output quantizers if necessary |
| 248 | + while inp.producer and inp.producer.type in pass_through_op_types: |
| 249 | + inp = inp.producer.inputs[0] |
| 250 | + inputs.append(inp) |
350 | 251 |
|
351 | | -def _get_op_name_to_act_quantizer_name_dicts(sim: QuantizationSimModel) -> Dict: |
352 | | - """ |
353 | | - Creates the dict where param quantizers if enabled are mapped to their param_names and activation |
354 | | - quantizer if enabled is mapped to it's inputs name |
355 | | - :param sim: Quantization Sim |
356 | | - :return op_name_to_activation_quantizer_name_dict |
357 | | - """ |
358 | | - op_name_to_activation_quantizer_name_dict = {} |
359 | | - for node in sim.model.model.graph.node: |
360 | | - if 'QcQuantizeOp' in node.name: |
361 | | - continue |
362 | | - for output_product in node.output: |
363 | | - if output_product in sim.activation_names: |
364 | | - activation_quantizer_op = sim.qc_quantize_op_dict[output_product] |
365 | | - if activation_quantizer_op.enabled: |
366 | | - op_name_to_activation_quantizer_name_dict[node.name] = output_product |
367 | | - return op_name_to_activation_quantizer_name_dict |
| 252 | + return inputs |
| 253 | + |
| 254 | + |
| 255 | +def _get_tensor_consumers(product: Product, pass_through_op_types: List[str]): |
| 256 | + """ Get all consumers of `product`, traversing through ops of type `pass_through_op_types` """ |
| 257 | + consumers = set() |
| 258 | + for consumer in product.consumers: |
| 259 | + if consumer.type in pass_through_op_types: |
| 260 | + consumers |= {op for output in consumer.outputs for op in _get_tensor_consumers(output, pass_through_op_types)} |
| 261 | + else: |
| 262 | + consumers.add(consumer) |
| 263 | + |
| 264 | + return consumers |
368 | 265 |
|
369 | 266 |
|
370 | 267 | def find_supported_candidates(quantizer_groups: List[QuantizerGroup], |
|
0 commit comments