 import gc
 import logging
 import re
-from typing import Iterable, Optional
+from typing import Optional
 
 import torch
 
@@ -327,126 +327,43 @@ def revert_local_name_to_global_patch():
     logger.info("Reverted QAT patch: _megatron_local_name_to_global.")
 
 
-def apply_build_conversion_tasks_patch():
-    """Patch ``build_conversion_tasks`` to filter out ``None`` entries."""
-    import itertools
+def apply_skip_quantizer_params_patch():
+    """Extend ``_is_adapter_param_name`` to also skip ModelOpt quantizer parameters.
 
-    import megatron.bridge.models.conversion.model_bridge as bridge_module
-    from megatron.bridge.models.conversion.model_bridge import (
-        MegatronModelBridge,
-        WeightConversionTask,
-    )
-    from megatron.bridge.models.conversion.utils import (
-        get_module_and_param_from_name,
-        persistent_buffers,
-    )
-    from megatron.bridge.utils.common_utils import print_rank_0
-    from megatron.core import parallel_state
-    from megatron.core.utils import unwrap_model
+    After ``mtq.quantize()``, quantizer sub-modules (``weight_quantizer``,
+    ``input_quantizer``) are registered in the model tree. Their internal
+    parameters (e.g. ``_amax``) have no HF counterpart and must not enter
+    the Bridge's conversion pipeline.
+    """
+    from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge
 
-    if getattr(MegatronModelBridge, "_build_tasks_patched", False):
+    if getattr(MegatronModelBridge, "_quantizer_filter_patched", False):
         return
-    MegatronModelBridge._build_tasks_patched = True
-    MegatronModelBridge._original_build_conversion_tasks = MegatronModelBridge.build_conversion_tasks
-
-    def _patched_build_conversion_tasks(self, hf_pretrained, megatron_model):
-        if not (hasattr(hf_pretrained, "state") and hasattr(hf_pretrained.state, "source")):
-            raise ValueError("hf_pretrained.state.source is required for weight ordering")
-
-        hf_keys: Iterable[str] = hf_pretrained.state.source.get_all_keys()
-
-        mapping_registry = self.mapping_registry()
-        unwrapped_model = unwrap_model(megatron_model)[0]
-        model_config = unwrapped_model.config
-        embeddings_are_tied = self._share_embeddings_and_output_weights(model_config, unwrapped_model)
-        pp_rank = parallel_state.get_pipeline_model_parallel_rank()
-        sorted_global_param_names_all_pp_ranks = self._megatron_global_param_names_all_pp_ranks(megatron_model)
-
-        if embeddings_are_tied:
-            sorted_global_param_names_all_pp_ranks = [
-                name for name in sorted_global_param_names_all_pp_ranks if "output_layer" not in name
-            ]
+    MegatronModelBridge._quantizer_filter_patched = True
+    MegatronModelBridge._original_is_adapter_param_name = MegatronModelBridge._is_adapter_param_name
 
-        global_names_index_dict = {name: idx for idx, name in enumerate(sorted_global_param_names_all_pp_ranks)}
+    _orig = MegatronModelBridge._is_adapter_param_name
 
-        tasks = [None] * len(sorted_global_param_names_all_pp_ranks)
-        for vp_stage, model in enumerate(megatron_model):
-            for local_name, _ in itertools.chain(model.named_parameters(), persistent_buffers(model)):
-                if "_extra_state" in local_name or self._is_adapter_param_name(local_name):
-                    continue
+    def _patched_is_adapter_param_name(self, param_name: str) -> bool:
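+        # Illustrative: after ``mtq.quantize()`` the model exposes names such as
+        # "<module_path>.weight_quantizer._amax" or "<module_path>.input_quantizer._amax";
+        # both contain "_quantizer" and are caught by the substring check below.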
+        if _orig(self, param_name):
+            return True
+        return "_quantizer" in param_name
 
-                local_name = self._unwrap_name(local_name)
-                global_name = bridge_module._megatron_local_name_to_global(
-                    megatron_model, model_config, local_name, vp_stage
-                )
-                if global_name not in global_names_index_dict:
-                    print_rank_0(f"WARNING: {global_name} not in global_names_index_dict")
-                    continue
-                global_name_idx = global_names_index_dict[global_name]
-                mapping = mapping_registry.megatron_to_hf_lookup(self._get_lora_unwrapped_name(global_name))
-
-                if not mapping:
-                    logger.warning(f"WARNING: No mapping found for megatron_param: {global_name}")
-                    continue
-
-                if not mapping.allow_hf_name_mismatch:
-                    if isinstance(mapping.hf_param, str):
-                        if mapping.hf_param not in hf_keys:
-                            logger.warning(f"WARNING: Can't find {mapping.hf_param} in hf_keys")
-                            continue
-                    else:
-                        missing_params = [hf_param for hf_param in mapping.hf_param.values() if hf_param not in hf_keys]
-                        if missing_params:
-                            logger.warning(
-                                f"WARNING: Can't find the following HF parameters in hf_keys: {missing_params}"
-                            )
-                            continue
-
-                local_module, local_weights = get_module_and_param_from_name(megatron_model, local_name, vp_stage)
-                if local_module is not None and not hasattr(local_module, "config"):
-                    local_module.config = model_config
-
-                tasks[global_name_idx] = WeightConversionTask(
-                    pp_rank=pp_rank,
-                    vp_stage=vp_stage,
-                    param_name=local_name,
-                    global_param_name=global_name,
-                    megatron_module=local_module,
-                    param_weight=local_weights,
-                    mapping=mapping,
-                )
-
-        for idx, global_name in enumerate(sorted_global_param_names_all_pp_ranks):
-            if tasks[idx] is None:
-                mapping = mapping_registry.megatron_to_hf_lookup(self._get_lora_unwrapped_name(global_name))
-                if mapping is None:
-                    continue
-                tasks[idx] = WeightConversionTask(
-                    pp_rank=pp_rank,
-                    vp_stage=None,
-                    param_name=global_name,
-                    global_param_name=global_name,
-                    megatron_module=None,
-                    param_weight=None,
-                    mapping=mapping,
-                )
-
-        tasks = [task for task in tasks if task is not None]
-        return tasks
-
-    MegatronModelBridge.build_conversion_tasks = _patched_build_conversion_tasks
-    logger.info("Applied QAT patch: MegatronModelBridge.build_conversion_tasks now filters out None entries.")
+    MegatronModelBridge._is_adapter_param_name = _patched_is_adapter_param_name
+    logger.info(
+        "Applied QAT patch: _is_adapter_param_name now also skips ModelOpt quantizer parameters (*_quantizer*)."
+    )
 
 
-def revert_build_conversion_tasks_patch():
-    """Revert :func:`apply_build_conversion_tasks_patch`."""
+def revert_skip_quantizer_params_patch():
+    """Revert :func:`apply_skip_quantizer_params_patch`."""
     from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge
 
-    if not getattr(MegatronModelBridge, "_build_tasks_patched", False):
+    if not getattr(MegatronModelBridge, "_quantizer_filter_patched", False):
         return
-    MegatronModelBridge.build_conversion_tasks = MegatronModelBridge._original_build_conversion_tasks
-    MegatronModelBridge._build_tasks_patched = False
-    logger.info("Reverted QAT patch: MegatronModelBridge.build_conversion_tasks.")
+    MegatronModelBridge._is_adapter_param_name = MegatronModelBridge._original_is_adapter_param_name
+    MegatronModelBridge._quantizer_filter_patched = False
+    logger.info("Reverted QAT patch: _is_adapter_param_name (quantizer filter).")
 
 
 def apply_detect_parallelism_type_patch():
@@ -493,7 +410,7 @@ def apply_qat_patch():
     apply_ep_gather_patch()
     apply_extract_sort_key_patch()
     apply_local_name_to_global_patch()
-    apply_build_conversion_tasks_patch()
+    apply_skip_quantizer_params_patch()
     apply_detect_parallelism_type_patch()
 
 
@@ -503,5 +420,5 @@ def revert_qat_patch():
     revert_ep_gather_patch()
     revert_extract_sort_key_patch()
     revert_local_name_to_global_patch()
-    revert_build_conversion_tasks_patch()
+    revert_skip_quantizer_params_patch()
     revert_detect_parallelism_type_patch()
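
For context, a minimal usage sketch of the new patch pair around a quantized export. This is illustrative only: `bridge`, `model`, and `forward_loop` are assumed to already exist, and `mtq.FP8_DEFAULT_CFG` / `save_hf_pretrained` stand in for whatever quantization config and export entry point the surrounding workflow actually uses.

    import modelopt.torch.quantization as mtq

    apply_qat_patch()  # installs the quantizer-parameter filter along with the other QAT patches
    try:
        # Quantize first so the *_quantizer sub-modules (and their _amax
        # parameters) exist; the patched predicate then keeps them out of
        # the Megatron -> HF weight conversion during export.
        model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)
        bridge.save_hf_pretrained(model, "./hf_export")
    finally:
        revert_qat_patch()  # restore the original MegatronModelBridge methods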