Skip to content

Commit 407b8f0

Browse files
committed
Refactor QAT patch to skip quantizer params
1 parent a19d986 commit 407b8f0

File tree

2 files changed

+30
-113
lines changed

2 files changed

+30
-113
lines changed

verl/utils/modelopt/megatron_qat_patch.py

Lines changed: 29 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import gc
2424
import logging
2525
import re
26-
from typing import Iterable, Optional
26+
from typing import Optional
2727

2828
import torch
2929

@@ -327,126 +327,43 @@ def revert_local_name_to_global_patch():
327327
logger.info("Reverted QAT patch: _megatron_local_name_to_global.")
328328

329329

330-
def apply_build_conversion_tasks_patch():
331-
"""Patch ``build_conversion_tasks`` to filter out ``None`` entries."""
332-
import itertools
330+
def apply_skip_quantizer_params_patch():
331+
"""Extend ``_is_adapter_param_name`` to also skip ModelOpt quantizer parameters.
333332
334-
import megatron.bridge.models.conversion.model_bridge as bridge_module
335-
from megatron.bridge.models.conversion.model_bridge import (
336-
MegatronModelBridge,
337-
WeightConversionTask,
338-
)
339-
from megatron.bridge.models.conversion.utils import (
340-
get_module_and_param_from_name,
341-
persistent_buffers,
342-
)
343-
from megatron.bridge.utils.common_utils import print_rank_0
344-
from megatron.core import parallel_state
345-
from megatron.core.utils import unwrap_model
333+
After ``mtq.quantize()``, quantizer sub-modules (``weight_quantizer``,
334+
``input_quantizer``) are registered in the model tree. Their internal
335+
parameters (e.g. ``_amax``) have no HF counterpart and must not enter
336+
the Bridge's conversion pipeline.
337+
"""
338+
from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge
346339

347-
if getattr(MegatronModelBridge, "_build_tasks_patched", False):
340+
if getattr(MegatronModelBridge, "_quantizer_filter_patched", False):
348341
return
349-
MegatronModelBridge._build_tasks_patched = True
350-
MegatronModelBridge._original_build_conversion_tasks = MegatronModelBridge.build_conversion_tasks
351-
352-
def _patched_build_conversion_tasks(self, hf_pretrained, megatron_model):
353-
if not (hasattr(hf_pretrained, "state") and hasattr(hf_pretrained.state, "source")):
354-
raise ValueError("hf_pretrained.state.source is required for weight ordering")
355-
356-
hf_keys: Iterable[str] = hf_pretrained.state.source.get_all_keys()
357-
358-
mapping_registry = self.mapping_registry()
359-
unwrapped_model = unwrap_model(megatron_model)[0]
360-
model_config = unwrapped_model.config
361-
embeddings_are_tied = self._share_embeddings_and_output_weights(model_config, unwrapped_model)
362-
pp_rank = parallel_state.get_pipeline_model_parallel_rank()
363-
sorted_global_param_names_all_pp_ranks = self._megatron_global_param_names_all_pp_ranks(megatron_model)
364-
365-
if embeddings_are_tied:
366-
sorted_global_param_names_all_pp_ranks = [
367-
name for name in sorted_global_param_names_all_pp_ranks if "output_layer" not in name
368-
]
342+
MegatronModelBridge._quantizer_filter_patched = True
343+
MegatronModelBridge._original_is_adapter_param_name = MegatronModelBridge._is_adapter_param_name
369344

370-
global_names_index_dict = {name: idx for idx, name in enumerate(sorted_global_param_names_all_pp_ranks)}
345+
_orig = MegatronModelBridge._is_adapter_param_name
371346

372-
tasks = [None] * len(sorted_global_param_names_all_pp_ranks)
373-
for vp_stage, model in enumerate(megatron_model):
374-
for local_name, _ in itertools.chain(model.named_parameters(), persistent_buffers(model)):
375-
if "_extra_state" in local_name or self._is_adapter_param_name(local_name):
376-
continue
347+
def _patched_is_adapter_param_name(self, param_name: str) -> bool:
348+
if _orig(self, param_name):
349+
return True
350+
return "_quantizer" in param_name
377351

378-
local_name = self._unwrap_name(local_name)
379-
global_name = bridge_module._megatron_local_name_to_global(
380-
megatron_model, model_config, local_name, vp_stage
381-
)
382-
if global_name not in global_names_index_dict:
383-
print_rank_0(f"WARNING: {global_name} not in global_names_index_dict")
384-
continue
385-
global_name_idx = global_names_index_dict[global_name]
386-
mapping = mapping_registry.megatron_to_hf_lookup(self._get_lora_unwrapped_name(global_name))
387-
388-
if not mapping:
389-
logger.warning(f"WARNING: No mapping found for megatron_param: {global_name}")
390-
continue
391-
392-
if not mapping.allow_hf_name_mismatch:
393-
if isinstance(mapping.hf_param, str):
394-
if mapping.hf_param not in hf_keys:
395-
logger.warning(f"WARNING: Can't find {mapping.hf_param} in hf_keys")
396-
continue
397-
else:
398-
missing_params = [hf_param for hf_param in mapping.hf_param.values() if hf_param not in hf_keys]
399-
if missing_params:
400-
logger.warning(
401-
f"WARNING: Can't find the following HF parameters in hf_keys: {missing_params}"
402-
)
403-
continue
404-
405-
local_module, local_weights = get_module_and_param_from_name(megatron_model, local_name, vp_stage)
406-
if local_module is not None and not hasattr(local_module, "config"):
407-
local_module.config = model_config
408-
409-
tasks[global_name_idx] = WeightConversionTask(
410-
pp_rank=pp_rank,
411-
vp_stage=vp_stage,
412-
param_name=local_name,
413-
global_param_name=global_name,
414-
megatron_module=local_module,
415-
param_weight=local_weights,
416-
mapping=mapping,
417-
)
418-
419-
for idx, global_name in enumerate(sorted_global_param_names_all_pp_ranks):
420-
if tasks[idx] is None:
421-
mapping = mapping_registry.megatron_to_hf_lookup(self._get_lora_unwrapped_name(global_name))
422-
if mapping is None:
423-
continue
424-
tasks[idx] = WeightConversionTask(
425-
pp_rank=pp_rank,
426-
vp_stage=None,
427-
param_name=global_name,
428-
global_param_name=global_name,
429-
megatron_module=None,
430-
param_weight=None,
431-
mapping=mapping,
432-
)
433-
434-
tasks = [task for task in tasks if task is not None]
435-
return tasks
436-
437-
MegatronModelBridge.build_conversion_tasks = _patched_build_conversion_tasks
438-
logger.info("Applied QAT patch: MegatronModelBridge.build_conversion_tasks now filters out None entries.")
352+
MegatronModelBridge._is_adapter_param_name = _patched_is_adapter_param_name
353+
logger.info(
354+
"Applied QAT patch: _is_adapter_param_name now also skips ModelOpt quantizer parameters (*_quantizer*)."
355+
)
439356

440357

441-
def revert_build_conversion_tasks_patch():
442-
"""Revert :func:`apply_build_conversion_tasks_patch`."""
358+
def revert_skip_quantizer_params_patch():
359+
"""Revert :func:`apply_skip_quantizer_params_patch`."""
443360
from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge
444361

445-
if not getattr(MegatronModelBridge, "_build_tasks_patched", False):
362+
if not getattr(MegatronModelBridge, "_quantizer_filter_patched", False):
446363
return
447-
MegatronModelBridge.build_conversion_tasks = MegatronModelBridge._original_build_conversion_tasks
448-
MegatronModelBridge._build_tasks_patched = False
449-
logger.info("Reverted QAT patch: MegatronModelBridge.build_conversion_tasks.")
364+
MegatronModelBridge._is_adapter_param_name = MegatronModelBridge._original_is_adapter_param_name
365+
MegatronModelBridge._quantizer_filter_patched = False
366+
logger.info("Reverted QAT patch: _is_adapter_param_name (quantizer filter).")
450367

451368

452369
def apply_detect_parallelism_type_patch():
@@ -493,7 +410,7 @@ def apply_qat_patch():
493410
apply_ep_gather_patch()
494411
apply_extract_sort_key_patch()
495412
apply_local_name_to_global_patch()
496-
apply_build_conversion_tasks_patch()
413+
apply_skip_quantizer_params_patch()
497414
apply_detect_parallelism_type_patch()
498415

499416

@@ -503,5 +420,5 @@ def revert_qat_patch():
503420
revert_ep_gather_patch()
504421
revert_extract_sort_key_patch()
505422
revert_local_name_to_global_patch()
506-
revert_build_conversion_tasks_patch()
423+
revert_skip_quantizer_params_patch()
507424
revert_detect_parallelism_type_patch()

0 commit comments

Comments (0)