diff --git a/recipe b/recipe index 3490a22a0a3..066d77a6333 160000 --- a/recipe +++ b/recipe @@ -1 +1 @@ -Subproject commit 3490a22a0a3adeb7e4787fe70b1060b642efbae4 +Subproject commit 066d77a6333c42f24df5e76d31bfeeda43795af4 diff --git a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml index ea60c881619..05dfd008f97 100644 --- a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml @@ -139,6 +139,16 @@ actor_rollout_ref: mode: disabled record_file: null replay_file: null + qat: + enable: false + mode: w4a16 + group_size: 16 + ignore_patterns: + - lm_head + - embed_tokens + - re:.*mlp.gate$ + activation_observer: static_minmax + quantization_config_path: null load_weight: true ref: rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} diff --git a/verl/trainer/config/_generated_ppo_torchtitan_trainer.yaml b/verl/trainer/config/_generated_ppo_torchtitan_trainer.yaml index b9a8b3aaf84..2aab1fc6a36 100644 --- a/verl/trainer/config/_generated_ppo_torchtitan_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_torchtitan_trainer.yaml @@ -121,6 +121,16 @@ actor_rollout_ref: mode: disabled record_file: null replay_file: null + qat: + enable: false + mode: w4a16 + group_size: 16 + ignore_patterns: + - lm_head + - embed_tokens + - re:.*mlp.gate$ + activation_observer: static_minmax + quantization_config_path: null ref: optim: _target_: verl.workers.config.TorchtitanOptimizerConfig diff --git a/verl/trainer/config/_generated_ppo_trainer.yaml b/verl/trainer/config/_generated_ppo_trainer.yaml index 6b97103ae9f..4787cf25eed 100644 --- a/verl/trainer/config/_generated_ppo_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_trainer.yaml @@ -120,13 +120,6 @@ actor_rollout_ref: mode: disabled record_file: null replay_file: null - grad_clip: 1.0 - ulysses_sequence_parallel_size: 1 - entropy_from_logits_with_chunking: false - entropy_checkpointing: 
false - use_remove_padding: ${oc.select:actor_rollout_ref.model.use_remove_padding,false} - calculate_sum_pi_squared: false - sum_pi_squared_checkpointing: false qat: enable: false mode: w4a16 @@ -137,6 +130,13 @@ actor_rollout_ref: - re:.*mlp.gate$ activation_observer: static_minmax quantization_config_path: null + grad_clip: 1.0 + ulysses_sequence_parallel_size: 1 + entropy_from_logits_with_chunking: false + entropy_checkpointing: false + use_remove_padding: ${oc.select:actor_rollout_ref.model.use_remove_padding,false} + calculate_sum_pi_squared: false + sum_pi_squared_checkpointing: false ref: rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} strategy: ${actor_rollout_ref.actor.strategy} diff --git a/verl/trainer/config/_generated_ppo_veomni_trainer.yaml b/verl/trainer/config/_generated_ppo_veomni_trainer.yaml index 4528e0d667d..956c725f433 100644 --- a/verl/trainer/config/_generated_ppo_veomni_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_veomni_trainer.yaml @@ -120,6 +120,16 @@ actor_rollout_ref: mode: disabled record_file: null replay_file: null + qat: + enable: false + mode: w4a16 + group_size: 16 + ignore_patterns: + - lm_head + - embed_tokens + - re:.*mlp.gate$ + activation_observer: static_minmax + quantization_config_path: null ref: rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} strategy: veomni diff --git a/verl/trainer/config/actor/actor.yaml b/verl/trainer/config/actor/actor.yaml index bffe8aec484..07cad10391b 100644 --- a/verl/trainer/config/actor/actor.yaml +++ b/verl/trainer/config/actor/actor.yaml @@ -259,3 +259,35 @@ router_replay: # Required when mode is 'replay' replay_file: null +# QAT (Quantization-Aware Training) configuration +# When enabled: +# - QAT is automatically applied to actor model during training +# - Fused scales (QKV/GateUp) are automatically enabled for training-inference consistency +# - Fast quantization is used when syncing weights to vLLM rollout +# Supported modes: "w4a16" (NVFP4 weight-only) +# Note: 
"w4a4" mode is included in the code but currently has KL divergence issues and is NOT recommended for use. +# For usage examples, see: https://github.com/verl-project/verl-recipe/blob/main/qat/README.md +qat: + + # Whether to enable QAT + enable: false + + # Quantization mode: "w4a16" (weight-only). "w4a4" is experimental and not recommended. + mode: "w4a16" + + # Quantization group size (NVFP4 requires 16) + group_size: 16 + + # Patterns to ignore (e.g., lm_head, embed_tokens) + ignore_patterns: + + - "lm_head" + - "embed_tokens" + - "re:.*mlp.gate$" + + # Activation observer for W4A4 mode: "static_minmax", "memoryless_minmax", or "minmax" + activation_observer: "static_minmax" + + # Path to vLLM quantization config JSON file + quantization_config_path: null + diff --git a/verl/trainer/config/actor/dp_actor.yaml b/verl/trainer/config/actor/dp_actor.yaml index 7fbe49c019e..fc0a16be609 100644 --- a/verl/trainer/config/actor/dp_actor.yaml +++ b/verl/trainer/config/actor/dp_actor.yaml @@ -48,35 +48,3 @@ calculate_sum_pi_squared: False # Enable gradient checkpointing for sum_pi_squared computation (saves memory) sum_pi_squared_checkpointing: False - -# QAT (Quantization-Aware Training) configuration -# When enabled: -# - QAT is automatically applied to actor model during training -# - Fused scales (QKV/GateUp) are automatically enabled for training-inference consistency -# - Fast quantization is used when syncing weights to vLLM rollout -# Supported modes: "w4a16" (NVFP4 weight-only) -# Note: "w4a4" mode is included in the code but currently has KL divergence issues and is NOT recommended for use. -# For usage examples, see: https://github.com/verl-project/verl-recipe/blob/main/qat/README.md -qat: - - # Whether to enable QAT - enable: false - - # Quantization mode: "w4a16" (weight-only). "w4a4" is experimental and not recommended. 
- mode: "w4a16" - - # Quantization group size (NVFP4 requires 16) - group_size: 16 - - # Patterns to ignore (e.g., lm_head, embed_tokens) - ignore_patterns: - - - "lm_head" - - "embed_tokens" - - "re:.*mlp.gate$" - - # Activation observer for W4A4 mode: "static_minmax", "memoryless_minmax", or "minmax" - activation_observer: "static_minmax" - - # Path to vLLM quantization config JSON file - quantization_config_path: null diff --git a/verl/utils/modelopt/__init__.py b/verl/utils/modelopt/__init__.py new file mode 100644 index 00000000000..30ddbd1a860 --- /dev/null +++ b/verl/utils/modelopt/__init__.py @@ -0,0 +1,42 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""ModelOpt integration for NVFP4 quantization with Megatron QAT training and vLLM inference.""" + +from verl.utils.modelopt.megatron_qat_patch import ( + apply_qat_patch, + revert_qat_patch, +) +from verl.utils.modelopt.qat_weight_exporter import QATWeightExporter +from verl.utils.modelopt.quantize import ( + apply_qat, + build_quantize_config, +) +from verl.utils.modelopt.vllm_modelopt_patch import ( + apply_modelopt_nvfp4_patches, + modelopt_process_weights_after_loading, + prepare_modelopt_for_weight_reload, +) + +__all__ = [ + "build_quantize_config", + "apply_qat", + "QATWeightExporter", + "apply_modelopt_nvfp4_patches", + "prepare_modelopt_for_weight_reload", + "modelopt_process_weights_after_loading", + "apply_qat_patch", + "revert_qat_patch", +] diff --git a/verl/utils/modelopt/megatron_qat_patch.py b/verl/utils/modelopt/megatron_qat_patch.py new file mode 100644 index 00000000000..5fdf1d727c1 --- /dev/null +++ b/verl/utils/modelopt/megatron_qat_patch.py @@ -0,0 +1,424 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Megatron-Core / megatron-bridge monkey patches for QAT workflows. + +Patches SwiGLU sharded state-dict, EP gather, extract_sort_key, local-to-global +name mapping, build_conversion_tasks, and parallelism type detection to support +SequentialMLP and quantised wrappers. 
logger = logging.getLogger(__name__)

# Parameter-name patterns used by the patched helpers; compiled once at import.
_LAYER_INDEX_RE = re.compile(r"layers\.(\d+)")
_GROUPED_EXPERT_RE = re.compile(r"(?:bias|weight)(\d+)")
_SEQUENTIAL_EXPERT_RE = re.compile(r"local_experts\.(\d+)")
_DOTTED_SEQUENTIAL_EXPERT_RE = re.compile(r"\.local_experts\.(\d+)\.")


def _resolve_local_expert_number(megatron_param: str, num_experts_per_rank: int) -> int:
    """Return the rank-local expert index encoded in *megatron_param*.

    Two naming schemes are recognised:

    * SequentialMLP: ``...local_experts.<N>...``
    * TEGroupedMLP: the name ends in ``.weight<N>`` or ``.bias<N>``

    In both cases ``<N>`` is reduced modulo *num_experts_per_rank*.

    Raises:
        ValueError: if neither scheme yields a number. This includes names
            whose text after ``.weight`` / ``.bias`` is not numeric (for
            example ``...weight_quantizer._amax``), which previously surfaced
            as an opaque ``int()`` conversion error.
    """
    sequential = _SEQUENTIAL_EXPERT_RE.search(megatron_param)
    if sequential:
        return int(sequential.group(1)) % num_experts_per_rank

    for key in (".weight", ".bias"):
        if key not in megatron_param:
            continue
        suffix = megatron_param.split(key)[-1]
        # Only a purely numeric tail is an expert index.
        if suffix.isdecimal():
            return int(suffix) % num_experts_per_rank

    raise ValueError(
        f"Cannot extract expert number from: {megatron_param}. "
        f"Expected TEGroupedMLP (weight/bias) or SequentialMLP (local_experts.)."
    )


def _extract_sort_key(param_name: str):
    """Build the ``([layer, expert], name)`` sort key for *param_name*.

    The layer number comes from ``layers.<N>``. The expert number comes from a
    TEGroupedMLP suffix (``weight<N>`` / ``bias<N>``) or, failing that, from a
    SequentialMLP ``local_experts.<N>`` segment. Numbers are appended in the
    order found and the list is padded with ``-1`` to exactly two entries.
    """
    numbers = []
    layer = _LAYER_INDEX_RE.search(param_name)
    if layer:
        numbers.append(int(layer.group(1)))

    expert = _GROUPED_EXPERT_RE.search(param_name) or _SEQUENTIAL_EXPERT_RE.search(param_name)
    if expert:
        numbers.append(int(expert.group(1)))

    numbers.extend([-1] * (2 - len(numbers)))
    return numbers[:2], param_name


def _globalize_sequential_expert_name(param_name: str, num_experts_per_rank: int, ep_rank: int) -> str:
    """Rewrite ``.local_experts.<local>.`` in *param_name* to the global index.

    The global index is ``num_experts_per_rank * ep_rank + local``. Names
    without a ``.local_experts.<N>.`` segment are returned unchanged.
    """
    found = _DOTTED_SEQUENTIAL_EXPERT_RE.search(param_name)
    if not found:
        return param_name
    local_expert_number = int(found.group(1))
    global_expert_number = num_experts_per_rank * ep_rank + local_expert_number
    return param_name.replace(
        f".local_experts.{local_expert_number}.",
        f".local_experts.{global_expert_number}.",
    )


def apply_swiglu_sharded_factory_patch():
    """Patch ``apply_swiglu_sharded_factory`` to support ``singleton_local_shards``.

    The original function is saved on the ``mlp`` module as
    ``_original_apply_swiglu_sharded_factory``. Calling this twice is a no-op.
    """
    import megatron.core.transformer.mlp as mlp_module
    from megatron.core.dist_checkpointing import ShardedTensor
    from megatron.core.dist_checkpointing.mapping import (
        ReplicaId,
        ShardedTensorFactory,
    )

    if getattr(mlp_module, "_swiglu_patched", False):
        return
    # Read the original first: if this lookup fails, no flag has been set and
    # a later call can retry instead of silently skipping the patch.
    original_factory = mlp_module.apply_swiglu_sharded_factory

    def patched_apply_swiglu_sharded_factory(original_sh_ten, sharded_offsets, singleton_local_shards: bool = False):
        swiglu_shard_axis = 0
        prepend_axis_num = len(sharded_offsets)
        global_axis = swiglu_shard_axis + prepend_axis_num
        local_axis_size = original_sh_ten.local_shape[swiglu_shard_axis]
        assert original_sh_ten.global_offset[global_axis] % local_axis_size == 0
        rank_offset = original_sh_ten.global_offset[global_axis] // local_axis_size
        axis_frag = original_sh_ten.axis_fragmentations[global_axis]

        @torch.no_grad()
        def sh_ten_build_fn(
            key: str,
            t: torch.Tensor,
            replica_id: ReplicaId,
            flattened_range: Optional[slice],
        ):
            if singleton_local_shards:
                # Each half gets its own key, so the fragmentation is unchanged.
                offset_w = (global_axis, rank_offset, axis_frag)
                offset_v = (global_axis, rank_offset, axis_frag)
                w_key = f"{key}_w"
                v_key = f"{key}_v"
            else:
                # Both halves share one key: double the fragmentation and
                # place the second half after all first halves.
                offset_w = (global_axis, rank_offset, axis_frag * 2)
                offset_v = (global_axis, rank_offset + axis_frag, axis_frag * 2)
                w_key = key
                v_key = key

            tensor_w, tensor_v = torch.chunk(t, 2, dim=swiglu_shard_axis)
            return [
                ShardedTensor.from_rank_offsets(
                    w_key,
                    tensor_w,
                    *sharded_offsets,
                    offset_w,
                    replica_id=replica_id,
                    prepend_axis_num=prepend_axis_num,
                ),
                ShardedTensor.from_rank_offsets(
                    v_key,
                    tensor_v,
                    *sharded_offsets,
                    offset_v,
                    replica_id=replica_id,
                    prepend_axis_num=prepend_axis_num,
                ),
            ]

        def sh_ten_merge_fn(sub_state_dict):
            with torch.no_grad():
                try:
                    return torch.cat(sub_state_dict)
                except (RuntimeError, torch.cuda.OutOfMemoryError) as e:
                    # NOTE(review): RuntimeError is broader than OOM; any
                    # RuntimeError from the first ``cat`` triggers the CPU retry.
                    logger.warning(
                        "CUDA OOM during tensor merge – falling back to CPU. (Error: %s)",
                        e,
                    )
                    merged = torch.cat([t.cpu() for t in sub_state_dict])
                    gc.collect()
                    torch.cuda.empty_cache()
                    return merged

        return ShardedTensorFactory(
            original_sh_ten.key,
            original_sh_ten.data,
            sh_ten_build_fn,
            sh_ten_merge_fn,
            original_sh_ten.replica_id,
            flattened_range=original_sh_ten.flattened_range,
        )

    mlp_module._original_apply_swiglu_sharded_factory = original_factory
    mlp_module.apply_swiglu_sharded_factory = patched_apply_swiglu_sharded_factory
    mlp_module._swiglu_patched = True
    logger.info("Applied QAT patch: apply_swiglu_sharded_factory now supports singleton_local_shards.")


def revert_swiglu_sharded_factory_patch():
    """Revert :func:`apply_swiglu_sharded_factory_patch`."""
    import megatron.core.transformer.mlp as mlp_module

    if not getattr(mlp_module, "_swiglu_patched", False):
        return
    mlp_module.apply_swiglu_sharded_factory = mlp_module._original_apply_swiglu_sharded_factory
    mlp_module._swiglu_patched = False
    logger.info("Reverted QAT patch: apply_swiglu_sharded_factory.")


def apply_ep_gather_patch():
    """Patch ``gather_from_ep_ranks`` to support SequentialMLP and TEGroupedMLP naming."""
    from megatron.bridge.models.conversion.param_mapping import MegatronParamMapping

    if getattr(MegatronParamMapping, "_ep_gather_patched", False):
        return
    original_gather = MegatronParamMapping.gather_from_ep_ranks

    def _patched_gather_from_ep_ranks(
        self,
        megatron_weights: Optional[torch.Tensor],
        megatron_module,  # Optional[MegatronModule]
        hf_param_name: Optional[str],
    ) -> dict[str, torch.Tensor]:
        # Ranks without the module receive the value via the PP broadcast.
        if megatron_module is None:
            num_experts_per_rank = self.broadcast_obj_from_pp_rank(None, "num_experts_per_rank")
        else:
            model_config = self._get_config(megatron_module)
            num_experts = model_config.num_moe_experts
            num_experts_per_rank = num_experts // self.ep_size
            num_experts_per_rank = self.broadcast_obj_from_pp_rank(num_experts_per_rank, "num_experts_per_rank")

        local_expert_number = _resolve_local_expert_number(self.megatron_param, num_experts_per_rank)

        # One HF name per EP rank: the same local slot shifted by the rank's
        # expert offset.
        gathered_expert_param_names = [
            re.sub(
                r"experts\.(\d+)",
                f"experts.{int(local_expert_number) + num_experts_per_rank * i}",
                str(hf_param_name),
            )
            for i in range(self.ep_size)
        ]
        assert str(hf_param_name) in gathered_expert_param_names, (
            f"hf_param_name {hf_param_name} not in {gathered_expert_param_names}"
        )

        gathered_weights = [torch.empty_like(megatron_weights) for _ in range(self.ep_size)]
        torch.distributed.all_gather(gathered_weights, megatron_weights, group=self.ep_group)

        weights_dict: dict[str, torch.Tensor] = {}
        for i, param_name in enumerate(gathered_expert_param_names):
            stacked = gathered_weights[i].unsqueeze(0)
            if param_name in weights_dict:
                weights_dict[param_name] = torch.cat([weights_dict[param_name], stacked], dim=0)
            else:
                weights_dict[param_name] = stacked
        # NOTE(review): ``squeeze()`` without ``dim`` drops every size-1 axis,
        # not only the stacking axis — confirm expert tensors never carry a
        # genuine singleton dimension.
        for param_name in weights_dict:
            weights_dict[param_name] = weights_dict[param_name].squeeze()

        return weights_dict

    MegatronParamMapping._original_gather_from_ep_ranks = original_gather
    MegatronParamMapping.gather_from_ep_ranks = _patched_gather_from_ep_ranks
    MegatronParamMapping._ep_gather_patched = True
    logger.info("Applied QAT patch: MegatronParamMapping.gather_from_ep_ranks now supports SequentialMLP pattern.")


def revert_ep_gather_patch():
    """Revert :func:`apply_ep_gather_patch`."""
    from megatron.bridge.models.conversion.param_mapping import MegatronParamMapping

    if not getattr(MegatronParamMapping, "_ep_gather_patched", False):
        return
    MegatronParamMapping.gather_from_ep_ranks = MegatronParamMapping._original_gather_from_ep_ranks
    MegatronParamMapping._ep_gather_patched = False
    logger.info("Reverted QAT patch: MegatronParamMapping.gather_from_ep_ranks.")


def apply_extract_sort_key_patch():
    """Patch ``extract_sort_key`` to support SequentialMLP naming pattern.

    The replacement is installed on both the ``utils`` module and the
    ``model_bridge`` module, which holds its own reference to the function.
    """
    import megatron.bridge.models.conversion.model_bridge as bridge_module
    import megatron.bridge.models.conversion.utils as utils_module

    if getattr(utils_module, "_sort_key_patched", False):
        return
    # Capture both originals before mutating anything so a missing attribute
    # cannot leave the modules half-patched.
    utils_original = utils_module.extract_sort_key
    bridge_original = bridge_module.extract_sort_key

    def _patched_extract_sort_key(param_name: str):
        return _extract_sort_key(param_name)

    utils_module._original_extract_sort_key = utils_original
    bridge_module._original_extract_sort_key = bridge_original
    utils_module.extract_sort_key = _patched_extract_sort_key
    bridge_module.extract_sort_key = _patched_extract_sort_key
    utils_module._sort_key_patched = True
    bridge_module._sort_key_patched = True
    logger.info("Applied QAT patch: extract_sort_key now supports SequentialMLP pattern.")


def revert_extract_sort_key_patch():
    """Revert :func:`apply_extract_sort_key_patch`."""
    import megatron.bridge.models.conversion.model_bridge as bridge_module
    import megatron.bridge.models.conversion.utils as utils_module

    if not getattr(utils_module, "_sort_key_patched", False):
        return
    utils_module.extract_sort_key = utils_module._original_extract_sort_key
    bridge_module.extract_sort_key = bridge_module._original_extract_sort_key
    utils_module._sort_key_patched = False
    bridge_module._sort_key_patched = False
    logger.info("Reverted QAT patch: extract_sort_key.")


def apply_local_name_to_global_patch():
    """Patch ``_megatron_local_name_to_global`` to support SequentialMLP
    local-to-global expert number conversion under EP."""
    import megatron.bridge.models.conversion.model_bridge as bridge_module
    from megatron.core import parallel_state
    from megatron.core.utils import get_pg_size

    if getattr(bridge_module, "_local_name_to_global_patched", False):
        return
    _orig_fn = bridge_module._megatron_local_name_to_global

    def _patched_megatron_local_name_to_global(models, config, param_name, vp_stage=None):
        # Let the original perform its own conversion first.
        param_name = _orig_fn(models, config, param_name, vp_stage)

        ep_group = parallel_state.get_expert_model_parallel_group()
        if ".mlp.experts.local_experts." in param_name and get_pg_size(ep_group) > 1 and ".adapter." not in param_name:
            num_experts_per_rank = config.num_moe_experts // ep_group.size()
            param_name = _globalize_sequential_expert_name(param_name, num_experts_per_rank, ep_group.rank())

        return param_name

    bridge_module._original_megatron_local_name_to_global = _orig_fn
    bridge_module._megatron_local_name_to_global = _patched_megatron_local_name_to_global
    bridge_module._local_name_to_global_patched = True
    logger.info("Applied QAT patch: _megatron_local_name_to_global now supports SequentialMLP pattern.")


def revert_local_name_to_global_patch():
    """Revert :func:`apply_local_name_to_global_patch`."""
    import megatron.bridge.models.conversion.model_bridge as bridge_module

    if not getattr(bridge_module, "_local_name_to_global_patched", False):
        return
    bridge_module._megatron_local_name_to_global = bridge_module._original_megatron_local_name_to_global
    bridge_module._local_name_to_global_patched = False
    logger.info("Reverted QAT patch: _megatron_local_name_to_global.")


def apply_skip_quantizer_params_patch():
    """Extend ``_is_adapter_param_name`` to also skip ModelOpt quantizer parameters.

    After ``mtq.quantize()``, quantizer sub-modules (``weight_quantizer``,
    ``input_quantizer``) are registered in the model tree. Their internal
    parameters (e.g. ``_amax``) have no HF counterpart and must not enter
    the Bridge's conversion pipeline.
    """
    from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge

    if getattr(MegatronModelBridge, "_quantizer_filter_patched", False):
        return
    _orig = MegatronModelBridge._is_adapter_param_name

    def _patched_is_adapter_param_name(self, param_name: str) -> bool:
        if _orig(self, param_name):
            return True
        return "_quantizer" in param_name

    MegatronModelBridge._original_is_adapter_param_name = _orig
    MegatronModelBridge._is_adapter_param_name = _patched_is_adapter_param_name
    MegatronModelBridge._quantizer_filter_patched = True
    logger.info(
        "Applied QAT patch: _is_adapter_param_name now also skips ModelOpt quantizer parameters (*_quantizer*)."
    )


def revert_skip_quantizer_params_patch():
    """Revert :func:`apply_skip_quantizer_params_patch`."""
    from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge

    if not getattr(MegatronModelBridge, "_quantizer_filter_patched", False):
        return
    MegatronModelBridge._is_adapter_param_name = MegatronModelBridge._original_is_adapter_param_name
    MegatronModelBridge._quantizer_filter_patched = False
    logger.info("Reverted QAT patch: _is_adapter_param_name (quantizer filter).")


def apply_detect_parallelism_type_patch():
    """Patch ``_detect_parallelism_type`` to recognise quantised
    ``LayerNormColumnParallelLinear`` variants via substring matching."""
    from megatron.bridge.models.conversion.param_mapping import AutoMapping

    if getattr(AutoMapping, "_detect_parallelism_patched", False):
        return
    original_detect = AutoMapping._detect_parallelism_type

    def _patched_detect_parallelism_type(self, module):
        module_type = type(module).__name__
        # Substring match so wrapper classes whose name merely contains
        # "LayerNormColumnParallelLinear" are recognised as well.
        if "LayerNormColumnParallelLinear" in module_type:
            if self.megatron_param and self.megatron_param.endswith(("layer_norm_weight", "layer_norm_bias")):
                return "replicated"
            return "column"
        return AutoMapping._original_detect_parallelism_type(self, module)

    AutoMapping._original_detect_parallelism_type = original_detect
    AutoMapping._detect_parallelism_type = _patched_detect_parallelism_type
    AutoMapping._detect_parallelism_patched = True
    logger.info(
        "Applied QAT patch: AutoMapping._detect_parallelism_type "
        "now supports quantised LayerNormColumnParallelLinear variants."
    )


def revert_detect_parallelism_type_patch():
    """Revert :func:`apply_detect_parallelism_type_patch`."""
    from megatron.bridge.models.conversion.param_mapping import AutoMapping

    if not getattr(AutoMapping, "_detect_parallelism_patched", False):
        return
    AutoMapping._detect_parallelism_type = AutoMapping._original_detect_parallelism_type
    AutoMapping._detect_parallelism_patched = False
    logger.info("Reverted QAT patch: AutoMapping._detect_parallelism_type.")


def apply_qat_patch():
    """Apply all QAT-related patches."""
    apply_swiglu_sharded_factory_patch()
    apply_ep_gather_patch()
    apply_extract_sort_key_patch()
    apply_local_name_to_global_patch()
    apply_skip_quantizer_params_patch()
    apply_detect_parallelism_type_patch()


def revert_qat_patch():
    """Revert all QAT-related patches."""
    revert_swiglu_sharded_factory_patch()
    revert_ep_gather_patch()
    revert_extract_sort_key_patch()
    revert_local_name_to_global_patch()
    revert_skip_quantizer_params_patch()
    revert_detect_parallelism_type_patch()
logger = logging.getLogger(__name__)

# NVFP4 two-level scaling denominator: FP4_MAX (6.0) * FP8_MAX (448.0).
_NVFP4_AMAX_DENOMINATOR = 6.0 * 448.0

# ``local_experts.<idx>`` as it appears in module names and parameter names.
_LOCAL_EXPERT_MODULE_RE = re.compile(r"local_experts\.(\d+)")
_LOCAL_EXPERT_PARAM_RE = re.compile(r"local_experts\.(\d+)\.")


@dataclass
class _QuantMeta:
    """Quantization metadata for a single parameter."""

    # Format identifier as returned by ``get_quantization_format``.
    qformat: str
    # Block size as returned by ``get_weight_block_size``.
    block_size: int
    # CPU copy of the weight quantizer's ``_amax``; None when it has none.
    weight_amax: Optional[torch.Tensor]
    # CPU copy of the input quantizer's ``_amax``; None when it has none.
    input_amax: Optional[torch.Tensor] = None
    # Live input quantizer module; None for entries received from other ranks.
    input_quantizer: Any = None


class QATWeightExporter:
    """Export QAT-trained bf16 weights as quantized weights (e.g. NVFP4)."""

    def __init__(
        self,
        actor_module: list,
        qat_mode: str = "w4a16",
        bridge: Any = None,
    ):
        """Collect quantization metadata from *actor_module* and sync it.

        Args:
            actor_module: list of model chunks to scan for quantized modules.
            qat_mode: mode label; stored and reported in the init log line.
            bridge: object exposing ``_model_bridge.mapping_registry()``.

        Raises:
            ValueError: if no mapping registry can be obtained from *bridge*.
        """
        self.qat_mode = qat_mode
        self._actor_module = actor_module

        self._registry = self._get_mapping_registry(bridge)
        if self._registry is None:
            raise ValueError(
                "QATWeightExporter requires a bridge with a valid MappingRegistry. "
                "Ensure use_mbridge=True and vanilla_mbridge=False."
            )

        self._pp_size, self._pp_rank, self._pp_group = _get_parallel_info("pp")
        self._ep_size, self._ep_rank, self._ep_group = _get_parallel_info("ep")

        self._config = self._get_model_config(actor_module)
        self._num_local_experts = self._count_local_experts(actor_module)

        self._metadata: dict[str, _QuantMeta] = {}
        self._collect_metadata(actor_module)

        # Each rank only sees its own layers/experts; merge the others' entries.
        if self._pp_size > 1 and self._pp_group is not None:
            self._sync_metadata(self._pp_group)
        if self._ep_size > 1 and self._ep_group is not None:
            self._sync_metadata(self._ep_group)

        self._log_init_summary()

    def process_weights_iterator(
        self,
        per_tensor_param: Iterator[tuple[str, torch.Tensor]],
    ) -> Iterator[tuple[str, torch.Tensor]]:
        """Wrap a weight iterator to apply quantization.

        For each ``(hf_name, bf16_weight)`` from the iterator, yields the
        quantized weight plus its scaling factors when the parameter is
        quantized, or the original tensor unchanged otherwise.
        """
        for hf_name, weight in per_tensor_param:
            meta = self._resolve_quant_metadata(hf_name)
            if meta is None:
                yield (hf_name, weight)
            else:
                yield from self._quantize_weight(hf_name, weight, meta)

    @staticmethod
    def _get_mapping_registry(bridge) -> Any:
        """Extract the ``MappingRegistry`` from *bridge*, or return ``None``."""
        if bridge is None:
            return None
        try:
            return bridge._model_bridge.mapping_registry()
        except Exception as exc:
            logger.warning("Failed to get mapping registry from bridge: %s", exc)
            return None

    @staticmethod
    def _get_model_config(actor_module):
        """Return the ``config`` attribute of the first unwrapped model chunk.

        Returns ``None`` when the chunk cannot be unwrapped (the failure is
        logged) or when the unwrapped model has no ``config`` attribute.
        """
        try:
            from verl.utils.megatron_utils import unwrap_model

            model = unwrap_model(actor_module[0])
        except Exception as exc:
            # Best effort: callers treat a missing config as "skip PP renaming".
            logger.warning("Failed to read model config from actor module: %s", exc)
            return None
        return getattr(model, "config", None)

    @staticmethod
    def _count_local_experts(actor_module) -> int:
        """Return ``max local_experts.<idx> + 1`` over all model chunks, or 0."""
        from verl.utils.megatron_utils import unwrap_model

        indices: set[int] = set()
        for module in actor_module:
            model = unwrap_model(module)
            for name, _ in model.named_modules():
                m = _LOCAL_EXPERT_MODULE_RE.search(name)
                if m:
                    indices.add(int(m.group(1)))
        return max(indices) + 1 if indices else 0

    def _collect_metadata(self, actor_module: list) -> None:
        """Walk all QAT modules and populate ``self._metadata``."""
        from verl.utils.megatron_utils import unwrap_model

        for vpp_idx, module in enumerate(actor_module):
            model = unwrap_model(module)
            for name, submodule in model.named_modules():
                qformat = get_quantization_format(submodule)
                if qformat == QUANTIZATION_NONE:
                    continue
                block_size = get_weight_block_size(submodule)
                if block_size == 0:
                    continue

                w_q = getattr(submodule, "weight_quantizer", None)
                i_q = getattr(submodule, "input_quantizer", None)
                # Explicit ``is not None``: module truthiness can depend on
                # ``__len__`` for container-style modules.
                w_amax = getattr(w_q, "_amax", None) if w_q is not None else None
                i_amax = getattr(i_q, "_amax", None) if i_q is not None else None

                meta = _QuantMeta(
                    qformat=qformat,
                    block_size=block_size,
                    weight_amax=w_amax.clone().cpu() if w_amax is not None else None,
                    input_amax=i_amax.clone().cpu() if i_amax is not None else None,
                    input_quantizer=i_q,
                )

                # Every direct parameter of the module shares the same entry.
                for pname, _ in submodule.named_parameters(recurse=False):
                    full_name = f"{name}.{pname}" if name else pname
                    global_name = self._local_to_global_param_name(full_name, vpp_idx)
                    self._metadata[global_name] = meta

    def _local_to_global_param_name(self, name: str, vpp_idx: int) -> str:
        """Convert a local parameter name to global (PP layers + EP experts).

        The EP expert offset is added only when the ``local_experts.<idx>.``
        index is still the local one after the bridge conversion, so a bridge
        that already globalises SequentialMLP experts does not cause the
        offset to be applied twice.
        """
        local_match = _LOCAL_EXPERT_PARAM_RE.search(name)

        if self._pp_size > 1 and "layers." in name and self._config is not None:
            from megatron.bridge.models.conversion.model_bridge import (
                _megatron_local_name_to_global,
            )

            name = _megatron_local_name_to_global(self._actor_module, self._config, name, vpp_idx)

        if self._ep_size > 1 and self._num_local_experts > 0 and local_match:
            name = _offset_local_expert_index(
                name,
                int(local_match.group(1)),
                self._ep_rank * self._num_local_experts,
            )

        return name

    def _sync_metadata(self, group) -> None:
        """Gather and merge metadata across the given process group.

        Entries already present locally win; entries learned from other ranks
        carry no ``input_quantizer`` because only plain values are exchanged.
        """
        world_size = torch.distributed.get_world_size(group=group)

        local_info = {
            name: {
                "qformat": m.qformat,
                "block_size": m.block_size,
                "weight_amax": m.weight_amax,
                "input_amax": m.input_amax,
            }
            for name, m in self._metadata.items()
        }

        gathered: list = [None] * world_size
        torch.distributed.all_gather_object(gathered, local_info, group=group)

        for rank_info in gathered:
            if rank_info is None:
                continue
            for name, info in rank_info.items():
                if name in self._metadata:
                    continue
                self._metadata[name] = _QuantMeta(
                    qformat=info["qformat"],
                    block_size=info["block_size"],
                    weight_amax=info["weight_amax"],
                    input_amax=info["input_amax"],
                    input_quantizer=None,
                )

    def _resolve_quant_metadata(self, hf_name: str) -> Optional[_QuantMeta]:
        """Resolve *hf_name* -> Megatron param name -> quantisation metadata.

        Returns ``None`` for names that do not end in ``.weight``, names
        containing ``norm``, and names with no metadata entry.
        """
        if not hf_name.endswith(".weight") or "norm" in hf_name:
            return None

        for resolved in _iter_hf_to_megatron_matches(self._registry, hf_name):
            meta = self._metadata.get(resolved.megatron_param)
            if meta is not None:
                return meta

        return None

    def _quantize_weight(
        self,
        name: str,
        weight: torch.Tensor,
        meta: _QuantMeta,
    ) -> Iterator[tuple[str, torch.Tensor]]:
        """Dispatch to the format-specific quantiser."""
        if meta.qformat == QUANTIZATION_NVFP4:
            yield from self._quantize_nvfp4(name, weight, meta)
        else:
            logger.warning("Unsupported qformat %s for %s; passing through", meta.qformat, name)
            yield (name, weight)

    def _quantize_nvfp4(
        self,
        name: str,
        weight: torch.Tensor,
        meta: _QuantMeta,
    ) -> Iterator[tuple[str, torch.Tensor]]:
        """NVFP4 two-level quantization.

        Produces up to four tensors:
          ``(name, packed_uint8_weight)``
          ``(weight_scale, per_block_fp8_scale)``
          ``(weight_scale_2, global_scale_from_amax)``
          ``(input_scale, activation_scale)`` -- only when available

        Raises:
            ValueError: if no weight amax was collected for this parameter.
        """
        if meta.weight_amax is None:
            raise ValueError(
                f"Cannot export {name} as NVFP4: no weight amax was collected for it "
                "(the weight quantizer had no _amax when metadata was gathered)."
            )

        w_amax = meta.weight_amax.to(weight.device)
        w_scale_2 = w_amax.float() / _NVFP4_AMAX_DENOMINATOR

        w_scale = NVFP4QTensor.get_weights_scaling_factor(
            weight,
            meta.block_size,
            weights_scaling_factor_2=w_scale_2,
        )[0]

        quantized = to_quantized_weight(weight, w_scale, meta.qformat, w_scale_2, meta.block_size)

        yield (name, quantized)
        yield (_derive_scale_name(name, "weight_scale"), w_scale)
        yield (_derive_scale_name(name, "weight_scale_2"), w_scale_2)

        input_scale = _compute_input_scale(meta)
        if input_scale is not None:
            yield (_derive_scale_name(name, "input_scale"), input_scale)

    def _log_init_summary(self) -> None:
        """Log a one-line initialisation summary."""
        rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
        logger.info(
            "[QAT Exporter][Rank %d] mode=%s, metadata_count=%d, pp=%d/%d, ep=%d/%d",
            rank,
            self.qat_mode,
            len(self._metadata),
            self._pp_rank,
            self._pp_size,
            self._ep_rank,
            self._ep_size,
        )


def _offset_local_expert_index(name: str, local_idx: int, offset: int) -> str:
    """Shift the first ``local_experts.<local_idx>.`` in *name* by *offset*.

    *name* is returned unchanged when it has no ``local_experts.<N>.`` segment
    or when that segment's index is no longer *local_idx*, which means an
    earlier conversion step already rewrote it.
    """
    current = _LOCAL_EXPERT_PARAM_RE.search(name)
    if current is None or int(current.group(1)) != local_idx:
        return name
    return name.replace(
        f"local_experts.{local_idx}.",
        f"local_experts.{local_idx + offset}.",
        1,
    )


def _iter_hf_to_megatron_matches(registry, hf_name: str):
    """Yield all resolved mappings whose HF pattern matches *hf_name*.

    ``registry._reverse_patterns`` is iterated as ``(pattern_info, mapping)``
    pairs. For a string ``mapping.hf_param`` the pattern info is a single
    compiled pattern or ``None`` (exact-name comparison); otherwise it is a
    dict of such patterns keyed like ``mapping.hf_param``.
    """
    for pattern_info, mapping in registry._reverse_patterns:
        if isinstance(mapping.hf_param, str):
            if pattern_info is None:
                if mapping.hf_param == hf_name:
                    yield mapping
                continue
            match = pattern_info.match(hf_name)
            if match:
                yield mapping.resolve(match.groups())
            continue

        for key, pattern in pattern_info.items():
            if pattern is None:
                if mapping.hf_param[key] == hf_name:
                    yield mapping.resolve(())
            else:
                match = pattern.match(hf_name)
                if match:
                    yield mapping.resolve(match.groups())


def _get_parallel_info(kind: str) -> tuple[int, int, Any]:
    """Return ``(world_size, rank, process_group)`` for *kind* in {pp, ep}.

    Falls back to ``(1, 0, None)`` for any other *kind* and when the Megatron
    parallel state cannot be queried; the latter is logged.
    """
    try:
        from megatron.core import parallel_state as mpu

        if kind == "pp":
            size = mpu.get_pipeline_model_parallel_world_size()
            rank = mpu.get_pipeline_model_parallel_rank()
            group = mpu.get_pipeline_model_parallel_group() if size > 1 else None
        elif kind == "ep":
            size = mpu.get_expert_model_parallel_world_size()
            rank = mpu.get_expert_model_parallel_rank() if size > 1 else 0
            group = mpu.get_expert_model_parallel_group() if size > 1 else None
        else:
            return 1, 0, None
        return size, rank, group
    except Exception as exc:
        # Best effort by design, but make the fallback visible: silently
        # assuming size 1 would skip metadata synchronisation.
        logger.warning("Could not query %s parallel state; assuming size 1: %s", kind, exc)
        return 1, 0, None


def _derive_scale_name(weight_name: str, suffix: str) -> str:
    """Derive a scale parameter name from a weight parameter name.

    ``"model.layers.0.self_attn.q_proj.weight"``
    -> ``"model.layers.0.self_attn.q_proj.weight_scale"``

    Only a trailing ``.weight`` is replaced, so an earlier ``.weight``
    substring inside the name is left intact. Names that do not end in
    ``.weight`` get ``_<suffix>`` appended.
    """
    tail = ".weight"
    if weight_name.endswith(tail):
        return f"{weight_name[: -len(tail)]}.{suffix}"
    return f"{weight_name}_{suffix}"


def _compute_input_scale(meta: "_QuantMeta") -> Optional[torch.Tensor]:
    """Derive the activation scale from the quantizer or synced amax.

    Preference order: ``NVFP4QTensor.get_activation_scaling_factor`` on the
    live quantizer when that helper exists, then the live quantizer's
    ``_amax``, then the synced ``input_amax``. Returns ``None`` when none of
    these is available.
    """
    quantizer = meta.input_quantizer
    if quantizer is not None:
        if hasattr(NVFP4QTensor, "get_activation_scaling_factor"):
            return NVFP4QTensor.get_activation_scaling_factor(quantizer)
        if getattr(quantizer, "_amax", None) is not None:
            return quantizer._amax.float() / _NVFP4_AMAX_DENOMINATOR

    if meta.input_amax is not None:
        return meta.input_amax.float() / _NVFP4_AMAX_DENOMINATOR

    return None
+ +"""ModelOpt NVFP4 quantization config and application for Megatron QAT.""" + +import logging + +import modelopt.torch.quantization as mtq +import torch.nn as nn +from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg + +logger = logging.getLogger(__name__) + + +_NVFP4_W4A16_QUANTIZER_CFG = { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": {"enable": False}, +} + + +def _ignore_patterns_to_quant_cfg(ignore_patterns: list[str]) -> dict: + """Convert user-provided ignore patterns to ModelOpt ``quant_cfg`` entries.""" + cfg = {} + for pattern in ignore_patterns: + key = pattern + if not key.startswith("*"): + key = f"*{key}" + if not key.endswith("*"): + key = f"{key}*" + cfg[key] = {"enable": False} + return cfg + + +def build_quantize_config( + qat_mode: str, + ignore_patterns: list[str] | None = None, +) -> dict: + """Build a complete ModelOpt quantization config for ``mtq.quantize``.""" + if qat_mode != "w4a16": + raise ValueError(f"Only 'w4a16' is supported, got: {qat_mode}") + + if ignore_patterns is None: + ignore_patterns = [] + + ignore_cfg = _ignore_patterns_to_quant_cfg(ignore_patterns) + + quant_cfg = { + **_NVFP4_W4A16_QUANTIZER_CFG, + **_default_disabled_quantizer_cfg, + **ignore_cfg, + } + logger.info("Built NVFP4 %s quantize config, ignore_patterns=%s", qat_mode, ignore_patterns) + + return {"quant_cfg": quant_cfg, "algorithm": "max"} + + +def apply_qat( + model: nn.Module, + qat_mode: str, + ignore_patterns: list[str] | None = None, +) -> nn.Module: + """Apply Quantization-Aware Training to a Megatron model.""" + config = build_quantize_config(qat_mode, ignore_patterns) + mtq.quantize(model, config) + return model diff --git a/verl/utils/modelopt/vllm_modelopt_patch.py b/verl/utils/modelopt/vllm_modelopt_patch.py new file mode 100644 index 00000000000..484685ac0a8 --- /dev/null +++ 
b/verl/utils/modelopt/vllm_modelopt_patch.py @@ -0,0 +1,664 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +vLLM ModelOpt NVFP4 Patches for Dynamic Weight Updates (Marlin Backend). + +Enables dynamic weight reloading for NVFP4 quantized models in vLLM +using the ModelOpt quantization path with the Marlin kernel backend. + +Saves parameter metadata on first load and deletes HF parameters. Before +reload, HF parameters are rebuilt from metadata, loaded, then re-converted +to Marlin format in-place via copy_ (preserving CUDA Graph tensor addresses). 
def save_param_meta(layer: torch.nn.Module, param_name: str):
    """Record shape/dtype/device/class of ``layer.<param_name>`` for a later rebuild.

    The record is stored in ``layer._hf_param_meta[param_name]`` (the dict is
    created on first use, even when the parameter itself is absent). When the
    parameter carries ``_input_dim`` / ``_output_dim`` attributes they are
    stored as ``input_dim`` / ``output_dim``. Nothing is recorded if the
    attribute is missing or ``None``.
    """
    if not hasattr(layer, "_hf_param_meta"):
        layer._hf_param_meta = {}

    tensor = getattr(layer, param_name, None)
    if tensor is None:
        return

    record = dict(
        shape=tuple(tensor.shape),
        dtype=tensor.dtype,
        device=str(tensor.device),
        param_class=type(tensor),
    )

    # Optional sharding hints, copied only when the parameter exposes them.
    for meta_key, private_attr in (("input_dim", "_input_dim"), ("output_dim", "_output_dim")):
        if hasattr(tensor, private_attr):
            record[meta_key] = getattr(tensor, private_attr)

    layer._hf_param_meta[param_name] = record
Used by rebuild and tensor swap.""" + shape = meta["shape"] + dtype = meta["dtype"] + dev = device or meta.get("device", "cuda") + param_class = meta.get("param_class", Parameter) + + weight_loaders = getattr(module, "_weight_loaders", {}) + weight_loader = weight_loaders.get(param_name) + + data = torch.empty(shape, dtype=dtype, device=dev) + + try: + if param_class is not Parameter and weight_loader is not None: + kwargs = {"data": data, "weight_loader": weight_loader} + if "input_dim" in meta: + kwargs["input_dim"] = meta["input_dim"] + if "output_dim" in meta: + kwargs["output_dim"] = meta["output_dim"] + new_param = param_class(**kwargs) + else: + new_param = Parameter(data, requires_grad=False) + if weight_loader is not None: + new_param.weight_loader = weight_loader + except Exception as e: + logger.warning(f"Failed to create param {param_name} with class {param_class}: {e}, using Parameter") + new_param = Parameter(data, requires_grad=False) + if weight_loader is not None: + new_param.weight_loader = weight_loader + + return new_param + + +def _check_first_call(layer: torch.nn.Module) -> bool: + """Check if this is the first process_weights call, and increment counter.""" + count = getattr(layer, "_process_weights_call_count", 0) + layer._process_weights_call_count = count + 1 + return count == 0 + + +def _save_weight_loaders(layer: torch.nn.Module, param_names: list[str]): + """Save weight_loader references from parameters before they are overwritten.""" + if not hasattr(layer, "_weight_loaders"): + layer._weight_loaders = {} + for pname in param_names: + param = getattr(layer, pname, None) + if param is not None and hasattr(param, "weight_loader"): + layer._weight_loaders[pname] = param.weight_loader + + +def _update_ref_or_create(layer, ref_name, new_data): + """Copy new_data into existing tensor ref (CUDA Graph safe), or create new Parameter.""" + refs = getattr(layer, "_marlin_tensor_refs", {}) + ref = refs.get(ref_name) + if ref is not None: + 
class ModelOptParamMetaDict(dict):
    """Parameter lookup table that can recreate parameters from saved metadata.

    The dict is seeded with ``named_parameters()`` of the model (or of
    ``model.model`` when that attribute exists). In addition it indexes every
    sub-module that carries ``_hf_param_meta`` so that:

    - a key that is not stored but is described by metadata can be rebuilt
      on access (``__getitem__`` / ``get``) and is reported by ``in``;
    - parameters that also appear in the module's ``_marlin_tensor_refs``
      can be swapped back to their recorded HF shape via
      :meth:`prepare_for_reload`.
    """

    def __init__(self, model: torch.nn.Module, device: Optional[torch.device] = None):
        super().__init__()
        self.device = device
        # Unwrap one level when the model exposes an inner ``.model``.
        self._model = getattr(model, "model", model)

        self._layer_meta_cache: dict[str, dict] = {}
        self._tensor_swap_layers: dict[str, dict] = {}
        self._build_mappings()

        for qualified_name, parameter in self._model.named_parameters():
            self[qualified_name] = parameter

    def _build_mappings(self):
        """Index modules with ``_hf_param_meta`` and their swap candidates."""
        for layer_name, module in self._model.named_modules():
            if not hasattr(module, "_hf_param_meta"):
                continue

            saved_meta = module._hf_param_meta
            self._layer_meta_cache[layer_name] = {"module": module, "meta": saved_meta}

            kernel_refs = getattr(module, "_marlin_tensor_refs", {})
            prefix = f"{layer_name}." if layer_name else ""
            for param_name, hf_meta in saved_meta.items():
                if param_name not in kernel_refs:
                    continue
                self._tensor_swap_layers[f"{prefix}{param_name}"] = {
                    "module": module,
                    "param_name": param_name,
                    "marlin_ref": kernel_refs[param_name],
                    "hf_meta": hf_meta,
                }

    def _try_rebuild(self, key: str) -> Optional[Parameter]:
        """Return the live attribute for *key*, or recreate it from metadata.

        Returns ``None`` when *key* has no ``layer.param`` form or no
        metadata is known for it.
        """
        layer_name, dot, param_name = key.rpartition(".")
        if not dot:
            return None

        entry = self._layer_meta_cache.get(layer_name)
        if entry is None:
            return None

        module, saved_meta = entry["module"], entry["meta"]
        if param_name not in saved_meta:
            return None

        existing = getattr(module, param_name, None)
        if existing is not None:
            return existing

        rebuilt = _create_param_from_meta(module, param_name, saved_meta[param_name], self.device)
        module.register_parameter(param_name, rebuilt)
        return rebuilt

    def prepare_for_reload(self) -> None:
        """Replace kernel-format tensors with freshly allocated HF-shape ones."""
        for swap in self._tensor_swap_layers.values():
            module, param_name = swap["module"], swap["param_name"]
            if not hasattr(module, param_name):
                continue
            replacement = _create_param_from_meta(module, param_name, swap["hf_meta"], self.device)
            setattr(module, param_name, replacement)

    def __getitem__(self, key: str) -> Parameter:
        # Use the raw dict membership test: ``in self`` would also report
        # keys that are only described by metadata.
        if dict.__contains__(self, key):
            return dict.__getitem__(self, key)

        rebuilt = self._try_rebuild(key)
        if rebuilt is None:
            raise KeyError(f"Parameter not found: {key}")
        self[key] = rebuilt
        return rebuilt

    def __contains__(self, key: str) -> bool:
        if dict.__contains__(self, key):
            return True

        pieces = key.rsplit(".", 1)
        if len(pieces) != 2:
            return False
        layer_name, param_name = pieces
        entry = self._layer_meta_cache.get(layer_name)
        return entry is not None and param_name in entry["meta"]

    def get(self, key: str, default=None):
        """Like ``dict.get`` but goes through the rebuilding ``__getitem__``."""
        try:
            return self[key]
        except KeyError:
            return default
_modelopt_dense_process_weights(self, layer: torch.nn.Module) -> None: + """ + Replacement for ModelOptNvFp4LinearMethod.process_weights_after_loading. + + First call: save metadata + weight_loaders, convert HF→Marlin format, + save _marlin_tensor_refs for CUDA Graph stability. + Subsequent: read reloaded HF data, convert, copy_ into saved refs. + """ + import vllm._custom_ops as ops + from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + marlin_make_workspace_new, + marlin_permute_scales, + ) + from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + nvfp4_marlin_process_global_scale, + nvfp4_marlin_process_scales, + ) + + is_first_call = _check_first_call(layer) + + if is_first_call: + for pname in _DENSE_HF_PARAMS: + save_param_meta(layer, pname) + _save_weight_loaders(layer, _DENSE_HF_PARAMS) + + weight_data = layer.weight.data + weight_scale_data = layer.weight_scale.data + weight_scale_2_data = layer.weight_scale_2.data + + assert weight_scale_data.dtype == torch.float8_e4m3fn + + device = weight_data.device + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + param_dtype = layer.params_dtype + group_size = 16 + weight_scale_2_max = weight_scale_2_data.max().to(torch.float32) + + if is_first_call: + layer.workspace = marlin_make_workspace_new(device) + + perm = torch.empty(0, dtype=torch.int, device=device) + qweight = weight_data.view(torch.int32).T.contiguous() + marlin_weight = ops.gptq_marlin_repack( + b_q_weight=qweight, + perm=perm, + size_k=part_size_k, + size_n=part_size_n, + num_bits=4, + ) + + weight_scale = weight_scale_data.T.contiguous().to(param_dtype) + weight_scale = marlin_permute_scales( + s=weight_scale, + size_k=part_size_k, + size_n=part_size_n, + group_size=group_size, + ) + marlin_weight_scale = nvfp4_marlin_process_scales(weight_scale) + marlin_weight_scale_2 = nvfp4_marlin_process_global_scale(weight_scale_2_max.to(param_dtype)) + + if 
def _marlin_repack_experts(packed, perm, size_k, size_n, num_experts):
    """Repack each expert's weight with ``gptq_marlin_repack`` and stack the results.

    For every index ``i`` in ``range(num_experts)`` the slice ``packed[i]`` is
    reinterpreted as int32, transposed, made contiguous and passed to
    ``ops.gptq_marlin_repack`` with ``num_bits=4``. The per-expert outputs are
    combined with ``torch.stack``.
    """
    import vllm._custom_ops as ops

    def _repack_one(expert_idx):
        transposed = packed[expert_idx].view(torch.int32).T.contiguous()
        return ops.gptq_marlin_repack(
            b_q_weight=transposed,
            perm=perm,
            size_k=size_k,
            size_n=size_n,
            num_bits=4,
        )

    return torch.stack([_repack_one(expert_idx) for expert_idx in range(num_experts)])
============================================================================ +# MoE Patch (Marlin) +# ============================================================================ + +_MOE_HF_PARAMS = [ + "w13_weight", + "w2_weight", + "w13_weight_scale", + "w2_weight_scale", + "w13_weight_scale_2", + "w2_weight_scale_2", + "w13_input_scale", + "w2_input_scale", +] + + +def _modelopt_moe_marlin_convert(self, layer: torch.nn.Module, is_first_call: bool) -> None: + """Convert MoE layer weights between HF and Marlin format.""" + from vllm.model_executor.layers.quantization.utils.marlin_utils import marlin_make_workspace_new + from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import nvfp4_marlin_process_global_scale + + group_size = 16 + e = layer.num_experts + k = layer.hidden_size + n = layer.intermediate_size_per_partition + device = layer.w13_weight.device + param_dtype = layer.params_dtype + + if is_first_call: + layer.workspace = marlin_make_workspace_new(device, 4) + + perm = torch.empty(0, dtype=torch.int, device=device) + size_n_w13, size_k_w13 = n * 2, k + size_n_w2, size_k_w2 = k, n + + # Repack weights + w13_weight_marlin = _marlin_repack_experts(layer.w13_weight.data, perm, size_k_w13, size_n_w13, e) + w2_weight_marlin = _marlin_repack_experts(layer.w2_weight.data, perm, size_k_w2, size_n_w2, e) + + # Process scales + w13_weight_scale_marlin = _marlin_process_scales_experts( + layer.w13_weight_scale.data, + param_dtype, + size_k_w13, + size_n_w13, + group_size, + e, + ) + w2_weight_scale_marlin = _marlin_process_scales_experts( + layer.w2_weight_scale.data, + param_dtype, + size_k_w2, + size_n_w2, + group_size, + e, + ) + + # Process global scales (w13_weight_scale_2 is already (E,) after common processing) + w13_scale_2_processed = nvfp4_marlin_process_global_scale(layer.w13_weight_scale_2.data.to(param_dtype)) + w2_scale_2_processed = nvfp4_marlin_process_global_scale(layer.w2_weight_scale_2.data.to(param_dtype)) + + if is_first_call: + 
def _modelopt_moe_process_weights(self, layer: torch.nn.Module) -> None:
    """
    Replacement for ModelOptNvFp4FusedMoE.process_weights_after_loading (Marlin).

    First call: save metadata + weight_loaders, convert HF→Marlin format,
    save _marlin_tensor_refs for CUDA Graph stability.
    Subsequent: read reloaded HF data, convert, copy_ into saved refs.

    After conversion, ``self.moe_quant_config`` is refreshed from
    ``self.get_fused_moe_quant_config(layer)``.
    """
    is_first_call = _check_first_call(layer)

    if is_first_call:
        for pname in _MOE_HF_PARAMS:
            save_param_meta(layer, pname)
        _save_weight_loaders(layer, _MOE_HF_PARAMS)

    # ---- w13_weight_scale_2: reduce (E, 2) → (E,) ----
    w13_weight_scale_2 = layer.w13_weight_scale_2.data
    if w13_weight_scale_2.dim() == 2:
        # The w1/w3 consistency check indexes the second axis, so it must
        # only run while the scale is still two-dimensional; an already
        # reduced 1-D scale is passed through unchanged instead of raising.
        if self.moe.is_act_and_mul and not torch.allclose(w13_weight_scale_2[:, 0], w13_weight_scale_2[:, 1]):
            logger.warning("w1_weight_scale_2 must match w3_weight_scale_2. Accuracy may be affected.")
        w13_weight_scale_2 = w13_weight_scale_2[:, 0]
    layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)

    _modelopt_moe_marlin_convert(self, layer, is_first_call)

    self.moe_quant_config = self.get_fused_moe_quant_config(layer)
q_scale = 1.0 + if layer.prob_scale > 0.0: + prob_scale = layer.prob_scale + if current_platform.is_fp8_fnuz(): + prob_scale *= 2 + else: + prob_scale = 1.0 + + is_singleton_float = ( + lambda x: isinstance(x, float) or isinstance(x, torch.Tensor) and x.numel() == 1 and x.is_floating_point() + ) + if not is_singleton_float(q_scale) or not is_singleton_float(prob_scale): + raise ValueError("Only support per-tensor scaling factor for fp8-quantized Q/prob") + + layer._q_scale.copy_(q_scale) + layer._q_scale_float = q_scale.item() if isinstance(q_scale, torch.Tensor) else q_scale + layer._prob_scale.copy_(prob_scale) + + +# ============================================================================ +# Patch Application & Entry Points +# ============================================================================ + +_patched = False + + +def prepare_modelopt_for_weight_reload(model, device=None): + """ + Prepare ModelOpt model for weight reloading. Call ONCE before each reload cycle. + + 1. Builds ModelOptParamMetaDict from saved metadata + 2. Swaps kernel-format tensors back to HF-shape for weight_loader compatibility + 3. 
def modelopt_process_weights_after_loading(model):
    """Run ``process_weights_after_loading`` on every quantised sub-module.

    Walks ``model.model`` when that attribute exists, otherwise ``model``.
    For each sub-module:

    - if it has a ``scheme`` attribute, ``scheme.process_weights_after_loading``
      is invoked (counted as "dense" in the debug log);
    - otherwise, if it has a non-``None`` ``quant_method`` exposing
      ``process_weights_after_loading`` and the method's class name does not
      contain ``"KVCache"``, that hook is invoked (counted as "MoE" in the
      debug log, regardless of the actual layer type).

    Returns:
        Total number of hooks that were invoked.
    """
    target = getattr(model, "model", model)

    scheme_hits = 0
    method_hits = 0
    for submodule in target.modules():
        has_scheme = hasattr(submodule, "scheme")
        if has_scheme:
            submodule.scheme.process_weights_after_loading(submodule)
            scheme_hits += 1

        method = getattr(submodule, "quant_method", None)
        if has_scheme or method is None:
            continue
        if not hasattr(method, "process_weights_after_loading"):
            continue
        # KV-cache methods are intentionally left alone here.
        if "KVCache" in method.__class__.__name__:
            continue
        method.process_weights_after_loading(submodule)
        method_hits += 1

    logger.debug(f"Processed {scheme_hits} dense layers, {method_hits} MoE layers")
    return scheme_hits + method_hits
apply_modelopt_nvfp4_patches(): + """Apply ModelOpt NVFP4 patches to support dynamic weight updates. Call before model loading.""" + global _patched + + if _patched: + logger.warning("ModelOpt NVFP4 patches already applied, skipping") + return + + logger.info("Applying ModelOpt NVFP4 patches for dynamic weight loading...") + + from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod + from vllm.model_executor.layers.quantization.modelopt import ( + ModelOptNvFp4FusedMoE, + ModelOptNvFp4LinearMethod, + ) + + ModelOptNvFp4LinearMethod.process_weights_after_loading = _modelopt_dense_process_weights + ModelOptNvFp4FusedMoE.process_weights_after_loading = _modelopt_moe_process_weights + BaseKVCacheMethod.process_weights_after_loading = _modelopt_kv_process_weights + + _patched = True + logger.info("Applied 3 ModelOpt NVFP4 patches (Dense, MoE, KV)") diff --git a/verl/workers/config/actor.py b/verl/workers/config/actor.py index 66071e8ec20..f33bb9269cc 100644 --- a/verl/workers/config/actor.py +++ b/verl/workers/config/actor.py @@ -182,6 +182,7 @@ class ActorConfig(BaseConfig): # batch_num_tokens: number of valid tokens in global batch # global_batch_size: global batch size global_batch_info: dict = field(default_factory=dict) + qat: QATConfig = field(default_factory=QATConfig) def __post_init__(self): """Validate actor configuration parameters.""" @@ -299,7 +300,6 @@ class FSDPActorConfig(ActorConfig): use_rollout_log_probs: bool = False calculate_sum_pi_squared: bool = False sum_pi_squared_checkpointing: bool = False - qat: QATConfig = field(default_factory=QATConfig) def __post_init__(self): """Validate FSDP actor configuration parameters.""" diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py index f78dcde56e2..c6d7323e703 100644 --- a/verl/workers/megatron_workers.py +++ b/verl/workers/megatron_workers.py @@ -64,6 +64,7 @@ ) from verl.utils.memory_utils import aggressive_empty_cache from verl.utils.model import 
get_hf_model_path, load_mcore_dist_weights, load_megatron_gptmodel_weights +from verl.utils.modelopt import apply_qat from verl.utils.profiler import ( DistProfiler, DistProfilerExtension, @@ -219,6 +220,21 @@ def _init_hf_config_and_tf_config( provider.moe_token_dispatcher_type = "alltoall" provider.moe_router_load_balancing_type = "none" + qat_enabled = self.config.actor.get("qat", {}).get("enable", False) + if qat_enabled: + from megatron.bridge.models.gpt_provider import quantization_layer_spec + + provider.transformer_layer_spec = quantization_layer_spec + + from verl.utils.modelopt.megatron_qat_patch import apply_qat_patch + + apply_qat_patch() + + from megatron.bridge.models.conversion.param_mapping import AutoMapping + + AutoMapping.register_module_type("QuantColumnParallelLinear", "column") + AutoMapping.register_module_type("QuantRowParallelLinear", "row") + # Apply transformer config overrides for key, value in override_transformer_config.items(): setattr(provider, key, value) @@ -443,6 +459,12 @@ def _build_model_optimizer( if self.rank == 0: print_model_size(actor_module[0]) log_gpu_memory_usage("After MegatronPPOActor init", logger=logger) + qat_config = self.config.actor.get("qat", {}) + if qat_config.get("enable", False): + qat_mode = qat_config.get("mode", "w4a16") + ignore_patterns = qat_config.get("ignore_patterns", None) + for i in range(len(actor_module)): + actor_module[i] = apply_qat(actor_module[i], qat_mode, ignore_patterns=ignore_patterns) elif self._is_ref: wrap_config = McoreModuleWrapperConfig( is_value_model=False, # ref is not value model @@ -716,6 +738,16 @@ async def rollout_mode(self): self.tf_config, self.layer_name_mapping, ) + qat_config = self.config.actor.get("qat", {}) + if qat_config.get("enable", False): + from verl.utils.modelopt import QATWeightExporter + + qat_mode = qat_config.get("mode", "w4a16") + qat_weight_exporter = QATWeightExporter(self.actor.actor_module, qat_mode, bridge=self.bridge) + # qat_weight_exporter = 
QATWeightExporter( + # self.actor.actor_module, qat_mode + # ) + per_tensor_param = qat_weight_exporter.process_weights_iterator(per_tensor_param) if self.config.rollout.free_cache_engine: await self.rollout.resume(tags=["weights"]) diff --git a/verl/workers/rollout/vllm_rollout/utils.py b/verl/workers/rollout/vllm_rollout/utils.py index 7fa3b1dd67c..fea544a2a8b 100644 --- a/verl/workers/rollout/vllm_rollout/utils.py +++ b/verl/workers/rollout/vllm_rollout/utils.py @@ -167,11 +167,17 @@ def __new__(cls, **kwargs): vllm_config = kwargs.get("vllm_config") quant_config = getattr(vllm_config, "quant_config", None) if vllm_config else None _is_qat_model = getattr(quant_config, "quant_format", None) == "nvfp4-pack-quantized" + _is_modelopt_qat = type(quant_config).__name__ == "ModelOptNvFp4Config" if _is_qat_model: from verl.utils.qat import apply_qat_patches apply_qat_patches() - logger.info("Applied QAT patches in vLLM worker subprocess") + logger.info("Applied QAT (compressed-tensors) patches in vLLM worker subprocess") + elif _is_modelopt_qat: + from verl.utils.modelopt import apply_modelopt_nvfp4_patches + + apply_modelopt_nvfp4_patches() + logger.info("Applied ModelOpt NVFP4 patches in vLLM worker subprocess") # TODO: For ascend NPU, when the corresponding vllm-ascend version is upgraded to v0.13.0, # please remove the VLLM_ASCEND_REQUIRED_ENV_VARS variable replacement action. 
@@ -183,6 +189,7 @@ def __new__(cls, **kwargs): instance = super().__new__(cls) instance._is_qat_model = _is_qat_model + instance._is_modelopt_qat = _is_modelopt_qat return instance def monkey_patch_model(self, vocab_size: int): @@ -226,11 +233,16 @@ def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False ) if self._is_qat_model: - # QAT: Prepare for weight loading BEFORE receiving any buckets + # QAT (compressed-tensors): Prepare for weight loading BEFORE receiving any buckets from verl.utils.qat import prepare_qat_for_load_weights prepare_qat_for_load_weights(self.model_runner.model, device=self.device) logger.info("QAT: prepare_qat_for_load_weights completed") + elif self._is_modelopt_qat: + from verl.utils.modelopt.vllm_modelopt_patch import prepare_modelopt_for_weight_reload + + prepare_modelopt_for_weight_reload(self.model_runner.model, device=self.device) + logger.info("ModelOpt: prepare_modelopt_for_weight_reload completed") elif use_standard_weight_load: # Re-apply here because async IPC weight sync can happen long after init and lose MoE weight_loader attrs. 
patch_vllm_moe_model_weight_loader(self.model_runner.model) @@ -259,11 +271,16 @@ def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False break if self._is_qat_model: - # QAT: call process_weights_after_loading AFTER all buckets are received + # QAT (compressed-tensors): call process_weights_after_loading AFTER all buckets are received from verl.utils.qat import manual_process_weights_after_loading manual_process_weights_after_loading(self.model_runner.model) logger.info("QAT: process_weights_after_loading completed") + elif self._is_modelopt_qat: + from verl.utils.modelopt.vllm_modelopt_patch import modelopt_process_weights_after_loading + + modelopt_process_weights_after_loading(self.model_runner.model) + logger.info("ModelOpt QAT: process_weights_after_loading completed") elif use_standard_weight_load: # Some post-load transforms are non-idempotent; run once after all buckets. from vllm.model_executor.model_loader.utils import process_weights_after_loading diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py index 3fe3aef3216..55c0a72a1db 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -240,18 +240,28 @@ async def launch_server(self, master_address: str = None, master_port: int = Non # Handle QAT (Quantization-Aware Training) configuration qat_config_dict = getattr(self.config, "qat", {}) or {} if qat_config_dict.get("enable", False): - # QAT uses compressed-tensors quantization, apply patches for dynamic weight loading - from verl.utils.qat import QATConfig, apply_qat_patches, load_quantization_config + from verl.utils.qat import QATConfig, load_quantization_config - apply_qat_patches() - - # Load quantization config from JSON file qat_config = QATConfig(**qat_config_dict) quantization_config_dict = load_quantization_config(qat_config) - hf_overrides["quantization_config"] = 
quantization_config_dict - quantization = "compressed-tensors" + quant_method = quantization_config_dict.get("quant_method", None) + + if quant_method == "modelopt": + from verl.utils.modelopt import apply_modelopt_nvfp4_patches + + apply_modelopt_nvfp4_patches() + quantization = "modelopt" + elif quant_method == "compressed-tensors": + from verl.utils.qat import apply_qat_patches - logger.info("QAT quantization config injected to vLLM async server") + apply_qat_patches() + quantization = "compressed-tensors" + + else: + raise ValueError(f"Unsupported quant_method: {quant_method}") + logger.info(f"QAT quantization config injected (quant_method={quant_method})") + hf_overrides["quantization_config"] = quantization_config_dict + print(f"quantization config: {quantization_config_dict}") elif quantization is not None: # Handle other quantization methods (fp8, torchao) _SUPPORTED_QUANTIZATION = ["fp8", "torchao"]