From d4e479e690a09066c2e1c38df2dee7a064090494 Mon Sep 17 00:00:00 2001 From: larkz-nv Date: Thu, 12 Feb 2026 06:46:53 -0800 Subject: [PATCH 01/10] support nvfp4 qat training with modelopt and megatron --- verl/trainer/config/engine/megatron.yaml | 3 ++ verl/workers/config/engine.py | 2 + verl/workers/megatron_workers.py | 33 +++++++++++++++ verl/workers/rollout/vllm_rollout/utils.py | 10 ++++- .../rollout/vllm_rollout/vllm_async_server.py | 40 ++++++++++++++++++- .../rollout/vllm_rollout/vllm_rollout.py | 1 + 6 files changed, 87 insertions(+), 2 deletions(-) diff --git a/verl/trainer/config/engine/megatron.yaml b/verl/trainer/config/engine/megatron.yaml index b588a96c1b3..ec6933a4e55 100644 --- a/verl/trainer/config/engine/megatron.yaml +++ b/verl/trainer/config/engine/megatron.yaml @@ -72,6 +72,9 @@ override_transformer_config: # Attention backend to use (flash,fused,unfused,local,auto). Defaults to auto in mcore, flash in verl attention_backend: flash +# # Quantization method. None for no quantization, "nvfp4_qat" for Quantization-Aware Training +quantization: null + override_mcore_model_config: {} # oc.select: default val for ref.megatron.use_mbridge diff --git a/verl/workers/config/engine.py b/verl/workers/config/engine.py index e09dfb20a7f..afcc81bf142 100644 --- a/verl/workers/config/engine.py +++ b/verl/workers/config/engine.py @@ -102,6 +102,7 @@ class McoreEngineConfig(EngineConfig): override_transformer_config (dict[str, Any]): Override configuration for transformer. use_mbridge (bool): Whether to use MBridge for communication. dtype (str): Mixed precision training param dtype, default "bfloat16" + quantization (Optional[str]): Quantization method to use. None for no quantization, "nvfp4_qat" for QAT. 
""" # sequence_parallel is not listed as a frozen field for auto-correction purpose @@ -124,6 +125,7 @@ class McoreEngineConfig(EngineConfig): use_mbridge: bool = True vanilla_mbridge: bool = True strategy: str = "megatron" + quantization: Optional[str] = None def __post_init__(self) -> None: super().__post_init__() diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py index aa7613fbc78..187beaea9ef 100644 --- a/verl/workers/megatron_workers.py +++ b/verl/workers/megatron_workers.py @@ -73,6 +73,7 @@ simple_timer, ) from verl.utils.profiler.performance import reduce_timing, topk_reduce_ratio_min_max +from verl.utils.qat_utils import QATConfig, apply_qat, is_qat_enabled from verl.utils.ray_utils import get_event_loop from verl.utils.torch_functional import use_original_torch_compile from verl.workers.actor.megatron_actor import MegatronPPOActor @@ -442,6 +443,16 @@ def _build_model_optimizer( if self.rank == 0: print_model_size(actor_module[0]) log_gpu_memory_usage("After MegatronPPOActor init", logger=logger) + quantization = self.config.actor.megatron.get("quantization", None) + if quantization is not None: + if is_qat_enabled(quantization): + print(f"[lark]: Applying QAT with method: {quantization}") + qat_config = QATConfig(enabled=True, quant_method=quantization) + print("[lark]: length of actor_module:", len(actor_module)) + for i in range(len(actor_module)): + actor_module[i] = apply_qat(actor_module[i], qat_config) + print("[lark]: QAT applied to all actor model chunks") + elif self._is_ref: wrap_config = McoreModuleWrapperConfig( is_value_model=False, # ref is not value model @@ -717,6 +728,28 @@ async def rollout_mode(self): self.tf_config, self.layer_name_mapping, ) + if is_qat_enabled(self.config.actor.megatron.quantization): + print("[lark]: rollout mode: quantizing weights with QAT") + from verl.utils.qat_post_utils import QATWeightPostProcessor + + qat_weight_post_processor = QATWeightPostProcessor( + 
self.actor.actor_module, "nvfp4", self.dtype, use_calibrated_scale_2=True + ) + per_tensor_param = qat_weight_post_processor.process_weights_iterator(per_tensor_param) + + # per_tensor_param = list(per_tensor_param) + # rank = torch.distributed.get_rank() + # state_dict = {} + # for name, weight in per_tensor_param: + # state_dict[name] = weight.data.cpu() + # path = f"/apps/quant_models/qwen3_8b_nvfp4/model_rank_{rank}.pt" + # torch.save(state_dict, path) + # del state_dict + # print(f"[lark]: saved state_dict to {path}") + + # import time + # time.sleep(1000) + if self.config.rollout.free_cache_engine: await self.rollout.resume(tags=["weights"]) diff --git a/verl/workers/rollout/vllm_rollout/utils.py b/verl/workers/rollout/vllm_rollout/utils.py index cbfedb879fa..fc4ff5b915c 100644 --- a/verl/workers/rollout/vllm_rollout/utils.py +++ b/verl/workers/rollout/vllm_rollout/utils.py @@ -29,7 +29,7 @@ from verl.utils.vllm import TensorLoRARequest, VLLMHijack from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader from verl.utils.vllm.vllm_fp8_utils import apply_vllm_fp8_patches, is_fp8_model, load_quanted_weights - +from verl.utils.modelopt_utils import apply_vllm_modelopt_patches logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) @@ -138,6 +138,8 @@ def __new__(cls, **kwargs): # 2. patch online fp8 quant if os.environ.get("VERL_VLLM_FP8_QUANT_ENABLED", "0") == "1": apply_vllm_fp8_patches() + elif os.environ.get("VERL_VLLM_NVFP4_QUANT_ENABLED", "0") == "1": + apply_vllm_modelopt_patches() # TODO: For ascend NPU, when the corresponding vllm-ascend version is upgraded to v0.13.0, # please remove the VLLM_ASCEND_REQUIRED_ENV_VARS variable replacement action. 
@@ -225,6 +227,12 @@ def _update_weights(self, weights: list[tuple[str, torch.Tensor]], peft_config: logger.info("Loading standard weights (non-FP8, async)") self.model_runner.model.load_weights(weights) + from vllm.model_executor.model_loader.utils import process_weights_after_loading + model_config = self.model_runner.vllm_config.model_config + device = next(self.model_runner.model.parameters()).device + process_weights_after_loading(self.model_runner.model, model_config, device) + # from vllm.model_executor.layers.quantization.modelopt import ModelOptNvFp4LinearMethod + def _get_zmq_handle(self) -> str: """Get ZMQ handle for communication.""" if not hasattr(self, "device_uuid") or not self.device_uuid: diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py index 196c72bc378..08e2cb4de6a 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -225,7 +225,7 @@ async def launch_server(self, master_address: str = None, master_port: int = Non quantization = self.config.quantization if quantization is not None: - _SUPPORTED_QUANTIZATION = ["fp8", "torchao"] + _SUPPORTED_QUANTIZATION = ["fp8", "torchao", "nvfp4_qat"] if quantization not in _SUPPORTED_QUANTIZATION: raise ValueError(f"Currently only support {_SUPPORTED_QUANTIZATION} quantization, got: {quantization}") @@ -242,6 +242,41 @@ async def launch_server(self, master_address: str = None, master_port: int = Non apply_vllm_fp8_patches() # for subprocesses patching os.environ["VERL_VLLM_FP8_QUANT_ENABLED"] = "1" + elif quantization == "nvfp4_qat": + print("[lark]: vllm quantization is nvfp4_qat") + fp4_block_quant_kwargs = { + "config_groups": { + "group_0": { + "input_activations": { + "dynamic": "false", + "num_bits": 4, + "type": "float", + "group_size": 16 + }, + "weights": { + "dynamic": "false", + "num_bits": 4, + "type": "float", + "group_size": 16 + }, + "targets": 
[ + "Linear" + ] + } + }, + "ignore": [ + "lm_head" + ], + "quant_algo": "NVFP4", + "producer": { + "name": "modelopt", + "version": "0.40.0.dev89+g0ec5e200f.d20251127" + }, + "quant_method": "modelopt" + } + from verl.utils.modelopt_utils import apply_vllm_modelopt_patches + apply_vllm_modelopt_patches() + os.environ["VERL_VLLM_NVFP4_QUANT_ENABLED"] = "1" hf_overrides = {} if quantization is not None and self.config.quantization_config_file is not None: @@ -249,6 +284,9 @@ async def launch_server(self, master_address: str = None, master_port: int = Non if quantization == "fp8": hf_overrides["quantization_config"] = fp8_block_quant_kwargs + elif quantization == "nvfp4_qat": + hf_overrides["quantization_config"] = fp4_block_quant_kwargs + quantization = "modelopt" args = { "dtype": self.config.dtype, diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout.py b/verl/workers/rollout/vllm_rollout/vllm_rollout.py index ebbb6e19e48..3880d2756a2 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_rollout.py +++ b/verl/workers/rollout/vllm_rollout/vllm_rollout.py @@ -169,6 +169,7 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None # model parameters are in fp32 full precision weight = weight.to(dtype, non_blocking=True) + # fill the tensor bucket if offset + weight.nbytes > bucket_size: get_torch_device().synchronize() From deb8e8f17b7e1dc08c5eaebf63720a05c9f731f1 Mon Sep 17 00:00:00 2001 From: larkz-nv Date: Thu, 12 Feb 2026 06:47:20 -0800 Subject: [PATCH 02/10] clean some logic --- verl/trainer/config/engine/megatron.yaml | 5 +- verl/utils/modelopt_qat_utils.py | 737 ++++++++++++++++ verl/utils/modelopt_vllm_utils.py | 823 ++++++++++++++++++ verl/workers/config/engine.py | 4 +- verl/workers/megatron_workers.py | 36 +- verl/workers/rollout/vllm_rollout/utils.py | 2 +- .../rollout/vllm_rollout/vllm_async_server.py | 41 +- 7 files changed, 1585 insertions(+), 63 deletions(-) create mode 100644 verl/utils/modelopt_qat_utils.py create mode 
100644 verl/utils/modelopt_vllm_utils.py diff --git a/verl/trainer/config/engine/megatron.yaml b/verl/trainer/config/engine/megatron.yaml index ec6933a4e55..cbf2c53c733 100644 --- a/verl/trainer/config/engine/megatron.yaml +++ b/verl/trainer/config/engine/megatron.yaml @@ -72,9 +72,12 @@ override_transformer_config: # Attention backend to use (flash,fused,unfused,local,auto). Defaults to auto in mcore, flash in verl attention_backend: flash -# # Quantization method. None for no quantization, "nvfp4_qat" for Quantization-Aware Training +# # Quantization method. None for no quantization, "nvfp4" for NVFP4 quantization quantization: null +# Whether to enable Quantization-Aware Training (QAT). Default False. +enable_qat: False + override_mcore_model_config: {} # oc.select: default val for ref.megatron.use_mbridge diff --git a/verl/utils/modelopt_qat_utils.py b/verl/utils/modelopt_qat_utils.py new file mode 100644 index 00000000000..ed21398a5fc --- /dev/null +++ b/verl/utils/modelopt_qat_utils.py @@ -0,0 +1,737 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ModelOpt Quantization-Aware Training (QAT) utilities for Megatron models. 
+ +Includes: +- QAT application via ModelOpt (apply_qat) +- QAT weight post-processing for exporting quantized weights to vLLM rollout (QATWeightPostProcessor) +""" + +import re +from dataclasses import dataclass +from typing import Any, Iterator + +import torch +import torch.nn as nn + +import modelopt.torch.quantization as mtq +from modelopt.torch.export.quant_utils import ( + QUANTIZATION_NONE, + QUANTIZATION_NVFP4, + get_quantization_format, + get_weight_block_size, + to_quantized_weight, +) +from modelopt.torch.quantization.qtensor.nvfp4_tensor import NVFP4QTensor + +from verl.utils.megatron_utils import unwrap_model + +# --------------------------------------------------------------------------- +# NVFP4 quantization config +# --------------------------------------------------------------------------- + +NVFP4_WEIGHT_ONLY_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": {"enable": False}, + "nn.BatchNorm1d": {"*": {"enable": False}}, + "nn.BatchNorm2d": {"*": {"enable": False}}, + "nn.BatchNorm3d": {"*": {"enable": False}}, + "nn.LeakyReLU": {"*": {"enable": False}}, + "*lm_head*": {"enable": False}, + "*proj_out.*": {"enable": False}, # Whisper: lm_head has key name proj_out + "*block_sparse_moe.gate*": {"enable": False}, # Skip MOE router + "*router*": {"enable": False}, # Skip MOE router + "*mlp.gate.*": {"enable": False}, # Skip MOE router + "*mlp.shared_expert_gate.*": {"enable": False}, # Skip MOE router + "*linear_attn.conv1d*": {"enable": False}, + "*mixer.conv1d*": {"enable": False}, + "*output_layer*": {"enable": False}, + "output.*": {"enable": False}, + "default": {"enable": False}, + }, + "algorithm": "max", +} + +# --------------------------------------------------------------------------- +# QAT application +# --------------------------------------------------------------------------- + + +def 
apply_qat(model: nn.Module, quant_method: str): + """Apply Quantization-Aware Training to the model. + + Args: + model: The Megatron model to apply QAT to. + quant_method: Quantization method (currently only ``"nvfp4"`` is supported). + + Returns: + The quantized model. + """ + if quant_method != "nvfp4": + raise ValueError(f"Only 'nvfp4' is supported, got: {quant_method}") + + mtq.quantize(model, NVFP4_WEIGHT_ONLY_CFG) + return model + + +# --------------------------------------------------------------------------- +# QAT weight post-processing (for exporting quantized weights to rollout) +# --------------------------------------------------------------------------- + +@dataclass +class QuantizationMetadata: + """Metadata for a quantized module.""" + + qformat: str + weight_quantizer: Any + input_quantizer: Any + module: torch.nn.Module + vpp_idx: int + block_size: int = 16 # Default NVFP4 block size + + +class QATWeightPostProcessor: + """ + Post-processor for extracting quantization info from QAT trained modules + and converting bf16 weights to quantized formats (e.g., NVFP4). + + Key Design: + 1. Collect quantization metadata (quantizers, amax, block_size) from QAT modules + 2. Process all_gathered bf16 weights to compute quantized weights and scaling factors + 3. The scaling factors are computed on the merged (all_gathered) weights to ensure + correct block boundaries for per-block quantization (NVFP4) + + Note on TP (Tensor Parallelism): + - For NVFP4, weight_scale_2 (global scale) should ideally be computed from the full + (all_gathered) weight to ensure consistency across TP ranks. + - If use_calibrated_scale_2=True (default), we use the QAT calibrated amax which may + only reflect the local shard's statistics. + - If use_calibrated_scale_2=False, we recompute weight_scale_2 from the merged weight. 
+ """ + + def __init__( + self, + actor_module: list, + quantization_method: str = "nvfp4", + dtype: torch.dtype = torch.bfloat16, + use_calibrated_scale_2: bool = True, + ): + """ + Initialize the QAT weight post-processor. + + Args: + actor_module: List of QAT trained model chunks (vpp chunks) + quantization_method: Quantization method (nvfp4, fp8, etc.) + dtype: Original data type (bf16) + use_calibrated_scale_2: If True, use QAT calibrated amax for weight_scale_2. + If False, recompute weight_scale_2 from merged weights. Recommended to set + False when using TP to ensure consistent global scale. + """ + self.actor_module = actor_module + self.quantization_method = quantization_method + self.dtype = dtype + self.use_calibrated_scale_2 = use_calibrated_scale_2 + self.quant_metadata: dict[str, QuantizationMetadata] = {} + + self._build_quantization_metadata() + self._log_initialization_info() + + def _build_quantization_metadata(self): + """ + Extract quantization metadata from all modules in actor_module. 
+ Stores: {param_name: QuantizationMetadata} + + Supports both dense and MoE (Mixture of Experts) models: + - Dense: decoder.layers.X.mlp.linear_fc1, decoder.layers.X.mlp.linear_fc2 + - MoE SequentialMLP: decoder.layers.X.mlp.experts.local_experts.Y.linear_fc1/fc2 + - MoE shared_experts: decoder.layers.X.mlp.shared_experts.linear_fc1/fc2 + - MoE TEGroupedMLP: decoder.layers.X.mlp.experts.linear_fc1/fc2 (grouped) + """ + for vpp_idx, module in enumerate(self.actor_module): + model = unwrap_model(module) + + for name, submodule in model.named_modules(): + # Handle MoE SequentialMLP - need to iterate over local_experts + if self._is_sequential_mlp(submodule): + self._build_moe_sequential_mlp_metadata(name, submodule, vpp_idx) + continue + + # Handle MoE TEGroupedMLP - grouped experts with linear_fc1/fc2 + if self._is_te_grouped_mlp(submodule): + self._build_moe_te_grouped_mlp_metadata(name, submodule, vpp_idx) + continue + + # Handle regular quantized modules (dense layers, shared_experts, etc.) 
+ qformat = get_quantization_format(submodule) + if qformat == QUANTIZATION_NONE: + continue + + block_size = get_weight_block_size(submodule) + if block_size == 0: + continue + + weight_quantizer = getattr(submodule, "weight_quantizer", None) + input_quantizer = getattr(submodule, "input_quantizer", None) + + metadata = QuantizationMetadata( + qformat=qformat, + weight_quantizer=weight_quantizer, + input_quantizer=input_quantizer, + module=submodule, + vpp_idx=vpp_idx, + block_size=block_size, + ) + + for param_name, _ in submodule.named_parameters(recurse=False): + full_name = f"{name}.{param_name}" if name else param_name + self.quant_metadata[full_name] = metadata + + def _is_sequential_mlp(self, module: torch.nn.Module) -> bool: + """Check if module is a MoE SequentialMLP.""" + module_type_name = type(module).__name__ + return "SequentialMLP" in module_type_name and hasattr(module, "local_experts") + + def _is_te_grouped_mlp(self, module: torch.nn.Module) -> bool: + """Check if module is a MoE TEGroupedMLP (Transformer Engine Grouped MLP).""" + module_type_name = type(module).__name__ + return "TEGroupedMLP" in module_type_name or "GroupedMLP" in module_type_name + + def _build_moe_sequential_mlp_metadata( + self, + base_name: str, + sequential_mlp: torch.nn.Module, + vpp_idx: int, + ): + """ + Build quantization metadata for MoE SequentialMLP. 
+ + SequentialMLP structure: + - local_experts: list of MLP experts, each with linear_fc1 and linear_fc2 + - Each expert's linear layers may have quantizers attached + + Args: + base_name: Base module name (e.g., 'decoder.layers.0.mlp.experts') + sequential_mlp: The SequentialMLP module + vpp_idx: Virtual pipeline parallel index + """ + if not hasattr(sequential_mlp, "local_experts"): + return + + for expert_idx, expert in enumerate(sequential_mlp.local_experts): + # Process linear_fc1 and linear_fc2 for each expert + for linear_name in ["linear_fc1", "linear_fc2"]: + linear_module = getattr(expert, linear_name, None) + if linear_module is None: + continue + + qformat = get_quantization_format(linear_module) + if qformat == QUANTIZATION_NONE: + continue + + block_size = get_weight_block_size(linear_module) + if block_size == 0: + continue + + weight_quantizer = getattr(linear_module, "weight_quantizer", None) + input_quantizer = getattr(linear_module, "input_quantizer", None) + + metadata = QuantizationMetadata( + qformat=qformat, + weight_quantizer=weight_quantizer, + input_quantizer=input_quantizer, + module=linear_module, + vpp_idx=vpp_idx, + block_size=block_size, + ) + + # Build full parameter name + # Format: {base_name}.local_experts.{expert_idx}.{linear_name}.weight + for param_name, _ in linear_module.named_parameters(recurse=False): + full_name = f"{base_name}.local_experts.{expert_idx}.{linear_name}.{param_name}" + self.quant_metadata[full_name] = metadata + + def _build_moe_te_grouped_mlp_metadata( + self, + base_name: str, + te_grouped_mlp: torch.nn.Module, + vpp_idx: int, + ): + """ + Build quantization metadata for MoE TEGroupedMLP. 
+ + TEGroupedMLP structure (Transformer Engine): + - linear_fc1: grouped linear layer for all experts + - linear_fc2: grouped linear layer for all experts + - Weights are stored as 3D tensors [num_experts, out_dim, in_dim] + + Args: + base_name: Base module name (e.g., 'decoder.layers.0.mlp.experts') + te_grouped_mlp: The TEGroupedMLP module + vpp_idx: Virtual pipeline parallel index + """ + for linear_name in ["linear_fc1", "linear_fc2"]: + linear_module = getattr(te_grouped_mlp, linear_name, None) + if linear_module is None: + continue + + qformat = get_quantization_format(linear_module) + if qformat == QUANTIZATION_NONE: + continue + + block_size = get_weight_block_size(linear_module) + if block_size == 0: + continue + + weight_quantizer = getattr(linear_module, "weight_quantizer", None) + input_quantizer = getattr(linear_module, "input_quantizer", None) + + metadata = QuantizationMetadata( + qformat=qformat, + weight_quantizer=weight_quantizer, + input_quantizer=input_quantizer, + module=linear_module, + vpp_idx=vpp_idx, + block_size=block_size, + ) + + # Build full parameter name + # Format: {base_name}.{linear_name}.weight + for param_name, _ in linear_module.named_parameters(recurse=False): + full_name = f"{base_name}.{linear_name}.{param_name}" + self.quant_metadata[full_name] = metadata + + def _log_initialization_info(self): + """Log initialization information for debugging.""" + print(f"[QAT PostProcessor] Initialized with quantization method: {self.quantization_method}") + print(f"[QAT PostProcessor] Found {len(self.quant_metadata)} quantized parameters") + + # Log sample parameters from layer 0 for debugging + for name, metadata in self.quant_metadata.items(): + if "layers.0" in name and "weight" in name: + print( + f"[QAT PostProcessor] Sample: {name}, qformat={metadata.qformat}, block_size={metadata.block_size}, module type: {type(metadata.module)}" + ) + + def _log_initialization_info(self): + """Log initialization information for debugging.""" + 
print(f"[QAT PostProcessor] Initialized with quantization method: {self.quantization_method}") + print(f"[QAT PostProcessor] Found {len(self.quant_metadata)} quantized parameters") + + # Log sample parameters from layer 0 for debugging (including MoE experts) + moe_expert_count = 0 + for name, metadata in self.quant_metadata.items(): + if "layers.0" in name and "weight" in name: + if "local_experts" in name: + moe_expert_count += 1 + if moe_expert_count <= 2: # Only log first 2 experts + print( + f"[QAT PostProcessor] MoE Expert Sample: {name}, qformat={metadata.qformat}, block_size={metadata.block_size}" + ) + elif "shared_experts" in name: + print( + f"[QAT PostProcessor] Shared Expert Sample: {name}, qformat={metadata.qformat}, block_size={metadata.block_size}" + ) + else: + print( + f"[QAT PostProcessor] Dense Sample: {name}, qformat={metadata.qformat}, block_size={metadata.block_size}, module type: {type(metadata.module)}" + ) + + if moe_expert_count > 0: + print(f"[QAT PostProcessor] Total MoE expert layers in layer 0: {moe_expert_count}") + + def _find_matching_metadata(self, param_name: str) -> QuantizationMetadata | None: + """ + Find matching quantization metadata for a parameter name. + Handles potential name variations between training and export. + """ + # Direct match + if param_name in self.quant_metadata: + return self.quant_metadata[param_name] + + # Try removing common prefixes/suffixes + variations = [ + param_name, + param_name.replace("module.", ""), + param_name.replace("model.", ""), + ] + + for var in variations: + if var in self.quant_metadata: + return self.quant_metadata[var] + + return None + + def _quantize_weight( + self, + name: str, + weight: torch.Tensor, + metadata: QuantizationMetadata, + ) -> Iterator[tuple[str, torch.Tensor]]: + """ + Quantize a single weight parameter. 
+ + Args: + name: Parameter name + weight: The all_gathered bf16 weight tensor + metadata: Quantization metadata + + Yields: + (param_name, param_tensor) for quantized weight and scaling factors + """ + qformat = metadata.qformat + + if qformat == QUANTIZATION_NVFP4: + # print("[lark]: quantize_weight name:", name, "weight:", weight.shape, "metadata:", metadata) + yield from self._quantize_nvfp4(name, weight, metadata) + else: + # Unknown format, pass through with warning + print(f"[QAT PostProcessor] Warning: Unknown qformat {qformat} for {name}, passing through") + yield (name, weight) + + def _quantize_nvfp4( + self, + name: str, + weight: torch.Tensor, + metadata: QuantizationMetadata, + ) -> Iterator[tuple[str, torch.Tensor]]: + """ + NVFP4 quantization implementation. + + NVFP4 uses two-level scaling: + - weight_scale_2 (global): per-tensor scale = amax / (6.0 * 448.0) + - weight_scale (per-block): per-block scale in FP8 format + + The weight is packed into uint8 format (2 x FP4 values per byte). + + Yields: + (name, quantized_weight): Packed uint8 weight + (name + "_scale", weight_scale): Per-block FP8 scaling factors + (name + "_scale_2", weight_scale_2): Global scaling factor + (name + "_input_scale", input_scale): Input activation scale (if available) + """ + weight_quantizer = metadata.weight_quantizer + input_quantizer = metadata.input_quantizer + block_size = metadata.block_size + qformat = metadata.qformat + + # # Ensure weight is in float for quantization computation + # weight_float = weight.float() + + # Step 1: Compute weight_scale_2 (global scale) + # For TP sharding, we should recompute weight_scale_2 from merged weight + # to ensure consistent global scale across all TP ranks. 
+ if self.use_calibrated_scale_2 and weight_quantizer is not None and hasattr(weight_quantizer, "_amax"): + # Use QAT calibrated amax (may only reflect local shard statistics) + # weight_scale_2 = amax / (6.0 * 448.0) + weight_scale_2 = NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer) + else: + # Compute from all_gathered weight directly (recommended for TP) + # weight_scale_2 = max(abs(weight)) / (6.0 * 448.0) + weight_scale_2 = NVFP4QTensor.get_weights_scaling_factor_2(weight) + + # Step 2: Compute weight_scale (per-block scale) + # This MUST be computed on the all_gathered (merged) weight to ensure + # correct block boundaries + # weight_scale shape: [out_dim, in_dim / block_size], dtype: float8_e4m3fn + weight_scale = NVFP4QTensor.get_weights_scaling_factor( + weight, + block_size, + weights_scaling_factor_2=weight_scale_2.to(weight.device), + )[0] + + # Step 3: Quantize weight to NVFP4 packed format + quantized_weight = to_quantized_weight( + weight, + weight_scale, + qformat, + weight_scale_2, + block_size, + ) + + # Yield quantized weight + yield (name, quantized_weight) + + # Yield scaling factors + # Note: Use consistent naming convention with ModelOpt export + scale_name = name.replace(".weight", ".weight_scale") + if scale_name == name: + scale_name = name + "_scale" + yield (scale_name, weight_scale) + + scale_2_name = name.replace(".weight", ".weight_scale_2") + if scale_2_name == name: + scale_2_name = name + "_scale_2" + yield (scale_2_name, weight_scale_2) + + # Step 4: Export input_scale (activation quantization) if available + if input_quantizer is not None: + input_scale = self._get_input_scale(input_quantizer) + if input_scale is not None: + input_scale_name = name.replace(".weight", ".input_scale") + if input_scale_name == name: + input_scale_name = name + "_input_scale" + yield (input_scale_name, input_scale) + + def _get_input_scale(self, input_quantizer) -> torch.Tensor | None: + """ + Get input activation scaling 
factor from quantizer. + + Args: + input_quantizer: The input quantizer from the module + + Returns: + Input scaling factor tensor or None + """ + if input_quantizer is None: + return None + + if not hasattr(input_quantizer, "_amax"): + return None + + amax = input_quantizer._amax + if amax is None: + return None + + # For NVFP4, use the NVFP4QTensor method + if hasattr(NVFP4QTensor, "get_activation_scaling_factor"): + return NVFP4QTensor.get_activation_scaling_factor(input_quantizer) + + return amax.float() / (6.0 * 448.0) + + def process_weights_iterator( + self, + per_tensor_param: Iterator[tuple[str, torch.Tensor]], + ) -> Iterator[tuple[str, torch.Tensor]]: + """ + Process an iterator of weights and yield quantized results. + + This method wraps per_tensor_generator output and applies quantization + to each weight, yielding the quantized weights and scaling factors. + + Args: + per_tensor_param: Iterator of (name, bf16_weight) from per_tensor_generator + + Yields: + (name, tensor): Quantized weight and associated scaling factors + """ + for name, param in per_tensor_param: + # quantize_single_tensor returns a list of (name, tensor) tuples + # For NVFP4: [(name, quant_weight), (name_scale, scale), (name_scale_2, scale_2), ...] + # For non-quantized: [(name, original_weight)] + quantized_results = self.quantize_single_tensor(name, param) + for q_name, q_tensor in quantized_results: + yield (q_name, q_tensor) + + def quantize_single_tensor( + self, + name: str, + weight: torch.Tensor, + ) -> list[tuple[str, torch.Tensor]]: + """ + Quantize a single tensor and return all related tensors as a list. + + This method is designed to be called AFTER weight_converter.convert_param, + so the name should already be in HF format (e.g., 'model.layers.0.self_attn.q_proj.weight'). 
+ + Args: + name: Parameter name in HF format + weight: Single tensor to quantize + + Returns: + List of (param_name, param_tensor) tuples: + - (name, quantized_weight) + - (name.replace('.weight', '.weight_scale'), weight_scale) # for NVFP4 + - (name.replace('.weight', '.weight_scale_2'), weight_scale_2) # for NVFP4 + """ + # Find matching metadata using the original mcore name pattern + # Since name is now in HF format, we need to check if this layer type should be quantized + metadata = self._find_matching_metadata_by_hf_name(name) + + if metadata is None: + # Not quantized, return original tensor + return [(name, weight)] + + # Quantize this tensor + return list(self._quantize_weight(name, weight, metadata)) + + def _find_matching_metadata_by_hf_name(self, hf_name: str) -> QuantizationMetadata | None: + """ + Find matching quantization metadata for an HF-format parameter name. + + This maps HF names back to the original mcore names to find metadata. + E.g., 'model.layers.0.self_attn.q_proj.weight' -> check if qkv layer is quantized + + The mapping logic: + - HF q_proj/k_proj/v_proj.weight -> mcore linear_qkv.weight + - HF o_proj.weight -> mcore linear_proj.weight + - HF gate_proj/up_proj.weight -> mcore linear_fc1.weight + - HF down_proj.weight -> mcore linear_fc2.weight + """ + import re + + # Only process weight parameters + if not hf_name.endswith(".weight") or hf_name.endswith("._amax") or "norm" in hf_name: + return None + + # Extract layer number from HF name + layer_match = re.search(r"layers?\.(\d+)\.", hf_name) + if not layer_match: + # Not a layer parameter (e.g., embed_tokens, lm_head, norm) + # Check for direct matches + return self._find_non_layer_metadata(hf_name) + + layer_num = layer_match.group(1) + + # Determine the mcore module name based on HF name pattern + mcore_patterns = [] + + if "self_attn" in hf_name: + if any(proj in hf_name for proj in ["q_proj", "k_proj", "v_proj"]): + 
mcore_patterns.append(f"decoder.layers.{layer_num}.self_attention.linear_qkv.weight") + elif "o_proj" in hf_name: + mcore_patterns.append(f"decoder.layers.{layer_num}.self_attention.linear_proj.weight") + elif "mlp" in hf_name: + if any(proj in hf_name for proj in ["gate_proj", "up_proj"]): + mcore_patterns.append(f"decoder.layers.{layer_num}.mlp.linear_fc1.weight") + elif "down_proj" in hf_name: + mcore_patterns.append(f"decoder.layers.{layer_num}.mlp.linear_fc2.weight") + + # Try to find matching metadata + for pattern in mcore_patterns: + if pattern in self.quant_metadata: + return self.quant_metadata[pattern] + + # If no exact match, try to find any metadata from the same layer + # This handles cases where the exact name might be slightly different + for mcore_name, metadata in self.quant_metadata.items(): + if f"layers.{layer_num}." in mcore_name: + # Found a quantized module in the same layer + # For QAT, if any module in the layer is quantized, all Linear layers should be + if ".weight" in mcore_name: + return metadata + + return None + + def _find_non_layer_metadata(self, hf_name: str) -> QuantizationMetadata | None: + """Find metadata for non-layer parameters (embed_tokens, lm_head, etc.).""" + # Map HF names to mcore names for non-layer parameters + name_mapping = { + "model.embed_tokens.weight": "embedding.word_embeddings.weight", + "lm_head.weight": "output_layer.weight", + "model.norm.weight": "decoder.final_layernorm.weight", + } + + mcore_name = name_mapping.get(hf_name) + if mcore_name and mcore_name in self.quant_metadata: + return self.quant_metadata[mcore_name] + + return None + + def _find_matching_metadata_by_hf_name(self, hf_name: str) -> QuantizationMetadata | None: + """ + Find matching quantization metadata for an HF-format parameter name. + + This maps HF names back to the original mcore names to find metadata. 
+ E.g., 'model.layers.0.self_attn.q_proj.weight' -> check if qkv layer is quantized + + The mapping logic: + - HF q_proj/k_proj/v_proj.weight -> mcore linear_qkv.weight + - HF o_proj.weight -> mcore linear_proj.weight + - HF gate_proj/up_proj.weight -> mcore linear_fc1.weight + - HF down_proj.weight -> mcore linear_fc2.weight + - HF experts.X.gate_proj/up_proj.weight -> mcore experts.local_experts.X.linear_fc1.weight + - HF experts.X.down_proj.weight -> mcore experts.local_experts.X.linear_fc2.weight + - HF shared_expert.gate_proj/up_proj.weight -> mcore shared_experts.linear_fc1.weight + - HF shared_expert.down_proj.weight -> mcore shared_experts.linear_fc2.weight + """ + import re + + # Only process weight parameters + if not hf_name.endswith(".weight") or hf_name.endswith("._amax") or "norm" in hf_name: + return None + + # Extract layer number from HF name + layer_match = re.search(r"layers?\.(\d+)\.", hf_name) + if not layer_match: + # Not a layer parameter (e.g., embed_tokens, lm_head, norm) + # Check for direct matches + return self._find_non_layer_metadata(hf_name) + + layer_num = layer_match.group(1) + + # Determine the mcore module name based on HF name pattern + mcore_patterns = [] + + if "self_attn" in hf_name: + if any(proj in hf_name for proj in ["q_proj", "k_proj", "v_proj"]): + mcore_patterns.append(f"decoder.layers.{layer_num}.self_attention.linear_qkv.weight") + elif "o_proj" in hf_name: + mcore_patterns.append(f"decoder.layers.{layer_num}.self_attention.linear_proj.weight") + elif "mlp" in hf_name: + # Check for MoE expert patterns first + expert_match = re.search(r"experts\.(\d+)\.", hf_name) + if expert_match: + expert_id = expert_match.group(1) + if any(proj in hf_name for proj in ["gate_proj", "up_proj"]): + # MoE expert gate_proj/up_proj -> local_experts.X.linear_fc1 + mcore_patterns.append( + f"decoder.layers.{layer_num}.mlp.experts.local_experts.{expert_id}.linear_fc1.weight" + ) + elif "down_proj" in hf_name: + # MoE expert down_proj -> 
local_experts.X.linear_fc2 + mcore_patterns.append( + f"decoder.layers.{layer_num}.mlp.experts.local_experts.{expert_id}.linear_fc2.weight" + ) + elif "shared_expert" in hf_name: + # Shared expert patterns (Qwen2Moe, DeepSeekV3, etc.) + if any(proj in hf_name for proj in ["gate_proj", "up_proj"]): + mcore_patterns.append(f"decoder.layers.{layer_num}.mlp.shared_experts.linear_fc1.weight") + elif "down_proj" in hf_name: + mcore_patterns.append(f"decoder.layers.{layer_num}.mlp.shared_experts.linear_fc2.weight") + elif "gate.weight" in hf_name: + # MoE router gate + mcore_patterns.append(f"decoder.layers.{layer_num}.mlp.router.weight") + elif any(proj in hf_name for proj in ["gate_proj", "up_proj"]): + # Dense MLP gate_proj/up_proj + mcore_patterns.append(f"decoder.layers.{layer_num}.mlp.linear_fc1.weight") + elif "down_proj" in hf_name: + # Dense MLP down_proj + mcore_patterns.append(f"decoder.layers.{layer_num}.mlp.linear_fc2.weight") + + # Try to find matching metadata + for pattern in mcore_patterns: + if pattern in self.quant_metadata: + return self.quant_metadata[pattern] + + # If no exact match, try to find any metadata from the same layer + # This handles cases where the exact name might be slightly different + for mcore_name, metadata in self.quant_metadata.items(): + if f"layers.{layer_num}." in mcore_name: + # For MoE, check if we're looking for expert weights + if "experts" in hf_name: + if "experts" in mcore_name and ".weight" in mcore_name: + return metadata + # For dense, check if any module in the layer is quantized + elif ".weight" in mcore_name: + return metadata + + return None \ No newline at end of file diff --git a/verl/utils/modelopt_vllm_utils.py b/verl/utils/modelopt_vllm_utils.py new file mode 100644 index 00000000000..417fc359365 --- /dev/null +++ b/verl/utils/modelopt_vllm_utils.py @@ -0,0 +1,823 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Callable, Optional +from unittest.mock import patch + +import torch + +logger = logging.getLogger(__name__) +from torch.nn import Parameter + + +NVFP4_BLOCK_QUANT_KWARGS = { + "config_groups": { + "group_0": { + "input_activations": { + "dynamic": "false", + "num_bits": 4, + "type": "float", + "group_size": 16 + }, + "weights": { + "dynamic": "false", + "num_bits": 4, + "type": "float", + "group_size": 16 + }, + "targets": [ + "Linear" + ] + } + }, + "ignore": [ + "lm_head" + ], + "quant_algo": "NVFP4", + "producer": { + "name": "modelopt", + "version": "0.40.0.dev89+g0ec5e200f.d20251127" + }, + "quant_method": "modelopt" +} + + + +def _create_param_from_subclass_attributes(custom_data: torch.Tensor, custom_weight) -> Parameter: + """ + Helper to preserve custom attributes from ModelWeightParameter and + PerTensorScaleParameter when creating new Parameters. 
+ """ + param = Parameter(custom_data, requires_grad=False) + base_param_dir = dir(torch.nn.Parameter) + custom_weight_dir = dir(custom_weight) + # Find the attributes that are unique to the custom parameter + custom_attributes = [attr for attr in custom_weight_dir if attr not in base_param_dir and not attr.startswith("__")] + # Set the custom attributes into the base parameter object + for attr in custom_attributes: + setattr(param, attr, getattr(custom_weight, attr)) + return param + + +def process_weights_after_loading_modelopt(self, layer: torch.nn.Module) -> None: + if getattr(layer, "prefix", None) == "model.layers.27.mlp.gate_up_proj" or getattr(layer, "prefix", "").startswith( + "model.layers.27.self_attn" + ): + print( + f"##VLLM##: {getattr(layer, 'prefix', None)}: {layer.params_dtype} bias: {getattr(layer, 'bias', None)} {layer.weight.data[0, :4]}, scale: {layer.weight_scale.data[0, :4]}, scale_2: {layer.weight_scale_2.data[0]}" + ) + import vllm._custom_ops as ops + from torch.nn import Parameter + from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + marlin_make_workspace_new, + marlin_permute_bias, + marlin_permute_scales, + ) + from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + mxfp4_marlin_process_scales, + nvfp4_marlin_process_global_scale, + nvfp4_marlin_process_scales, + ) + from vllm.model_executor.layers.quantization.utils.quant_utils import swizzle_blockscale + + def _create_param_from_subclass_attributes(custom_data, custom_weight): + param = Parameter(custom_data, requires_grad=False) + base_param_dir = dir(torch.nn.Parameter) + custom_weight_dir = dir(custom_weight) + # Find the attributes that are unique to the custom parameter + custom_attributes = [ + attr for attr in custom_weight_dir if attr not in base_param_dir and not attr.startswith("__") + ] + # Set the custom attributes into the base parameter object + for attr in custom_attributes: + setattr(param, attr, getattr(custom_weight, 
attr)) + + return param + + def prepare_fp4_layer_for_marlin(layer: torch.nn.Module, weight_scale_2_max: torch.Tensor) -> None: + logger.warning_once( + "Your GPU does not have native support for FP4 computation but " + "FP4 quantization is being used. Weight-only FP4 compression will " + "be used leveraging the Marlin kernel. This may degrade " + "performance for compute-heavy workloads." + ) + + is_nvfp4 = hasattr(layer, "weight_scale_2") + group_size = 16 if is_nvfp4 else 32 + + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + param_dtype = layer.params_dtype + + assert layer.weight.shape == (part_size_n, part_size_k // 2) + + device = layer.weight.device + + # WORKSPACE + if getattr(layer, "workspace", None) is None: + layer.workspace = marlin_make_workspace_new(device) + + # WEIGHT + # Repack weights to marlin format + perm = torch.empty(0, dtype=torch.int, device=device) + qweight = layer.weight.view(torch.int32).T.contiguous() + + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=qweight, + perm=perm, + size_k=part_size_k, + size_n=part_size_n, + num_bits=4, + ) + layer.marlin_weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) + + # WEIGHT SCALES + # Permute scales + weight_scale = layer.weight_scale.T.contiguous() + + if not is_nvfp4: + weight_scale = weight_scale.view(torch.float8_e8m0fnu) + + weight_scale = weight_scale.to(param_dtype) + weight_scale = marlin_permute_scales( + s=weight_scale, size_k=part_size_k, size_n=part_size_n, group_size=group_size + ) + + if is_nvfp4: + weight_scale = nvfp4_marlin_process_scales(weight_scale) + layer.marlin_weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + + weight_scale_2 = weight_scale_2_max.to(param_dtype) + weight_scale_2 = nvfp4_marlin_process_global_scale(weight_scale_2) + layer.marlin_weight_scale_2 = torch.nn.Parameter(weight_scale_2, requires_grad=False) + else: + weight_scale = mxfp4_marlin_process_scales(weight_scale) + 
layer.marlin_weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + + if hasattr(layer, "bias") and layer.bias is not None: + assert layer.bias.shape == (part_size_n,) + bias = marlin_permute_bias(layer.bias) + layer.bias = torch.nn.Parameter(bias, requires_grad=False) + + return + + # global scales: + input_scale_2 = layer.input_scale.data + layer.input_scale = _create_param_from_subclass_attributes(input_scale_2, layer.input_scale) + input_scale_2_max = input_scale_2.max().to(torch.float32) + + weight_scale_2 = layer.weight_scale_2.data + layer.weight_scale_2 = _create_param_from_subclass_attributes(weight_scale_2, layer.weight_scale_2) + weight_scale_2_max = weight_scale_2.max().to(torch.float32) + + layer.alpha = Parameter(input_scale_2_max * weight_scale_2_max, requires_grad=False) + + # Calculate `1 / input_scale` so that we don't need to do so at runtime + layer.input_scale_inv = Parameter((1 / layer.input_scale).to(torch.float32), requires_grad=False) + + # Swizzle the weight blockscale. 
+ # contracting dimension is input dimension + # block_size = 16; + assert layer.weight_scale.dtype == torch.float8_e4m3fn, "Weight Block scale must be represented as FP8-E4M3" + + if self.backend == "marlin": + weight = layer.weight.data + weight_scale = layer.weight_scale.data + layer.weight = _create_param_from_subclass_attributes(weight, layer.weight) + layer.weight_scale = _create_param_from_subclass_attributes(weight_scale, layer.weight_scale) + prepare_fp4_layer_for_marlin(layer, weight_scale_2_max) + + if getattr(layer, "prefix", None) == "model.layers.27.mlp.gate_up_proj" or getattr( + layer, "prefix", "" + ).startswith("model.layers.27.self_attn"): + print( + f"##VLLM-MARLIN##: {getattr(layer, 'prefix', None)}: {layer.marlin_weight.data[0, :4]}, scale: {layer.marlin_weight_scale.data[0, :4]}, scale_2: {layer.marlin_weight_scale_2.data}" + ) + + del layer.alpha + # del layer.input_scale + elif self.backend == "flashinfer-trtllm": + # FlashInfer TRTLLM FP4 GEMM requires a different weight layout. + # FlashInfer provides nvfp4_quantize to quantize + shuffle the + # layout but we use our own quantization so we have to call + # shuffles ourselves. 
+ from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a + + weight = layer.weight.data + weight_scale = layer.weight_scale.data + + epilogue_tile_m = 128 + weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m) + weight_scale = ( + shuffle_matrix_sf_a(weight_scale.view(torch.uint8), epilogue_tile_m) + .reshape(weight_scale.shape) + .view(torch.float8_e4m3fn) + ) + + layer.weight_scale = _create_param_from_subclass_attributes(weight_scale, layer.weight_scale) + layer.weight = _create_param_from_subclass_attributes(weight, layer.weight) + else: + swizzled_weight_scale = swizzle_blockscale(layer.weight_scale) + layer.weight_scale = _create_param_from_subclass_attributes(swizzled_weight_scale, layer.weight_scale) + layer.weight = _create_param_from_subclass_attributes(layer.weight.data, layer.weight) + +def apply_modelopt( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant + from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import apply_fp4_marlin_linear + from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm + + if self.backend == "marlin": + # if getattr(layer, "prefix", None) == "model.layers.27.mlp.gate_up_proj" or getattr(layer, "prefix", "").startswith("model.layers.27.self_attn"): + # print(f"##VLLM-MARLIN##: {getattr(layer, 'prefix', None)}: {layer.marlin_weight.data[0, :4]}, scale: {layer.marlin_weight_scale.data[0, :4]}, scale_2: {layer.marlin_weight_scale_2.data}") + return apply_fp4_marlin_linear( + input=x, + weight=layer.marlin_weight, + weight_scale=layer.marlin_weight_scale, + weight_scale_2=layer.marlin_weight_scale_2, + workspace=layer.workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias, + ) + + output_dtype = x.dtype + output_shape = [x.shape[0], layer.weight.shape[0]] + + # quantize BF16 or FP16 to (FP4 and interleaved 
block scale) + x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv) + + # validate dtypes of quantized input, input block scale, + # weight and weight_blockscale + assert x_fp4.dtype == torch.uint8 + assert layer.weight.dtype == torch.uint8 + assert x_blockscale.dtype == torch.float8_e4m3fn + assert layer.weight_scale.dtype == torch.float8_e4m3fn + assert layer.alpha.dtype == torch.float32 + + mm_args = ( + x_fp4, + layer.weight, + x_blockscale, + layer.weight_scale, + layer.alpha, + output_dtype, + ) + if self.backend == "flashinfer-trtllm": + out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm") + elif self.backend == "flashinfer-cutlass": + out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass") + else: + out = cutlass_scaled_fp4_mm(*mm_args) + + if bias is not None: + out = out + bias + return out.view(*output_shape) + + +# ============================================================================= +# ModelOptNvFp4FusedMoE Patches +# ============================================================================= + + +def process_weights_after_loading_moe(self, layer: torch.nn.Module) -> None: + """ + Patched process_weights_after_loading for ModelOptNvFp4FusedMoE. + + Key modifications compared to original: + 1. Preserves original weights in separate attributes (marlin_w13_weight, etc.) + 2. Uses _create_param_from_subclass_attributes to preserve parameter metadata + 3. 
Computes weight_scale_2_max before processing for Marlin + """ + import vllm._custom_ops as ops + from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( + prepare_static_weights_for_trtllm_fp4_moe, + reorder_w1w3_to_w3w1, + ) + from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( + FlashinferMoeBackend, + is_flashinfer_supporting_global_sf, + ) + from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + marlin_make_workspace_new, + marlin_permute_scales, + ) + from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + nvfp4_marlin_process_global_scale, + nvfp4_marlin_process_scales, + ) + from vllm.model_executor.layers.quantization.utils.quant_utils import swizzle_blockscale + + def prepare_moe_fp4_layer_for_marlin_patched( + layer: torch.nn.Module, + w13_weight_scale_2_per_expert: torch.Tensor, + w2_weight_scale_2_per_expert: torch.Tensor, + ) -> None: + """ + Modified prepare_moe_fp4_layer_for_marlin that: + 1. Takes per-expert weight_scale_2 values (not max!) + 2. 
Saves to marlin_* attributes instead of overwriting originals + + Args: + w13_weight_scale_2_per_expert: shape (num_experts,) - per-expert scales + w2_weight_scale_2_per_expert: shape (num_experts,) - per-expert scales + """ + logger.warning("Using patched prepare_moe_fp4_layer_for_marlin for NVFP4 MoE") + + group_size = 16 # NVFP4 uses group_size=16 + + e = layer.num_experts + k = layer.hidden_size + n = layer.intermediate_size_per_partition + + device = layer.w13_weight.device + param_dtype = layer.params_dtype + + # WORKSPACE + if getattr(layer, "workspace", None) is None: + layer.workspace = marlin_make_workspace_new(device, 4) + + perm = torch.empty(0, dtype=torch.int, device=device) + + # WEIGHT - Repack weights to marlin format + for name in ["w13_weight", "w2_weight"]: + weight = getattr(layer, name) + tensor_list = [] + if "w13" in name: + size_n, size_k = n * 2, k + else: + size_n, size_k = k, n + + assert weight.shape == (e, size_n, size_k // 2), ( + f"Weight shape mismatch for {name}: expected {(e, size_n, size_k // 2)}, got {weight.shape}" + ) + + for i in range(e): + qweight = weight[i].view(torch.int32).T.contiguous() + + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=qweight, + perm=perm, + size_k=size_k, + size_n=size_n, + num_bits=4, + ) + tensor_list.append(marlin_qweight) + + marlin_weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + marlin_weight = Parameter(marlin_weight, requires_grad=False) + + # Save to marlin_* attribute instead of overwriting original + marlin_attr_name = "marlin_" + name + setattr(layer, marlin_attr_name, marlin_weight) + + # WEIGHT SCALES - Permute scales + for name, weight_scale_2_per_expert in [ + ("w13", w13_weight_scale_2_per_expert), + ("w2", w2_weight_scale_2_per_expert), + ]: + scales = getattr(layer, name + "_weight_scale") + scales = scales.to(param_dtype) + + # Convert per-expert global scale to param_dtype + global_scale = weight_scale_2_per_expert.to(param_dtype) + + tensor_list = [] + if 
"w13" in name: + size_n, size_k = n * 2, k + else: + size_n, size_k = k, n + + for i in range(e): + scale = scales[i].T + + marlin_scales = marlin_permute_scales( + s=scale, + size_k=size_k, + size_n=size_n, + group_size=group_size, + ) + marlin_scales = nvfp4_marlin_process_scales(marlin_scales) + tensor_list.append(marlin_scales) + + marlin_scales_combined = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + marlin_scales_combined = Parameter(marlin_scales_combined, requires_grad=False) + + # Save to marlin_* attribute + setattr(layer, "marlin_" + name + "_weight_scale", marlin_scales_combined) + + # Process per-expert global scale (shape: num_experts) + global_scale = nvfp4_marlin_process_global_scale(global_scale) + global_scale = Parameter(global_scale, requires_grad=False) + setattr(layer, "marlin_" + name + "_weight_scale_2", global_scale) + + # ========== Main processing logic ========== + + # GEMM 1 processing + gemm1_weight = layer.w13_weight.data + gemm1_weight_scale = layer.w13_weight_scale.data + + if ( + self.allow_flashinfer + and ( + self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS + or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM + ) + and self.moe.is_act_and_mul + ): + gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(gemm1_weight, gemm1_weight_scale, dim=-2) + + layer.w13_weight = _create_param_from_subclass_attributes(gemm1_weight, layer.w13_weight) + layer.w13_weight_scale = _create_param_from_subclass_attributes(gemm1_weight_scale, layer.w13_weight_scale) + + # Common processing for w13_weight_scale_2 + # IMPORTANT: Keep the original shape (num_experts, 2) for subsequent weight loading + # Only compute the max value for Marlin, but don't modify the original parameter shape + if self.moe.is_act_and_mul and not torch.allclose(layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]): + logger.warning("w1_weight_scale_2 must match w3_weight_scale_2. 
Accuracy may be affected.") + + # Keep original data and shape - DO NOT reduce dimension! + w13_weight_scale_2_data = layer.w13_weight_scale_2.data # Keep original shape: (num_experts, 2) + layer.w13_weight_scale_2 = _create_param_from_subclass_attributes(w13_weight_scale_2_data, layer.w13_weight_scale_2) + # Get per-expert scales (shape: num_experts) for Marlin - NOT the max! + # This is what the original code uses after reducing [:, 0] + w13_weight_scale_2_per_expert = layer.w13_weight_scale_2[:, 0].clone() + # Also keep a 1D version for g1_alphas calculation (following original logic) + w13_weight_scale_2_1d = layer.w13_weight_scale_2[:, 0] + + # Common processing for input scales and alphas + # IMPORTANT: Keep original input_scale shapes for subsequent weight loading + use_global_sf = self.allow_flashinfer and is_flashinfer_supporting_global_sf(self.flashinfer_moe_backend) + + # Keep original w13_input_scale data and shape + w13_input_scale_data = layer.w13_input_scale.data + layer.w13_input_scale = _create_param_from_subclass_attributes(w13_input_scale_data, layer.w13_input_scale) + + # Compute derived values for runtime use + if use_global_sf: + w13_input_scale_for_alpha = layer.w13_input_scale.max().to(torch.float32).expand(layer.num_experts) + else: + w13_input_scale_for_alpha = layer.w13_input_scale.max(dim=1).values.to(torch.float32) + + layer.g1_alphas = Parameter( + (w13_input_scale_for_alpha * w13_weight_scale_2_1d).to(torch.float32), + requires_grad=False, + ) + + # This is for quantization, so we need to invert it. + layer.w13_input_scale_quant = Parameter((1 / w13_input_scale_for_alpha).to(torch.float32), requires_grad=False) + + # GEMM 2 processing + # Keep original w2_weight_scale_2 data and shape + w2_weight_scale_2_data = layer.w2_weight_scale_2.data + layer.w2_weight_scale_2 = _create_param_from_subclass_attributes(w2_weight_scale_2_data, layer.w2_weight_scale_2) + # Get per-expert scales (shape: num_experts) for Marlin - NOT the max! 
+ w2_weight_scale_2_per_expert = layer.w2_weight_scale_2.clone() + + # Keep original w2_input_scale data and shape + w2_input_scale_data = layer.w2_input_scale.data + layer.w2_input_scale = _create_param_from_subclass_attributes(w2_input_scale_data, layer.w2_input_scale) + + if use_global_sf: + w2_input_scale_for_alpha = layer.w2_input_scale.max().to(torch.float32).expand(layer.num_experts) + else: + w2_input_scale_for_alpha = layer.w2_input_scale + layer.g2_alphas = Parameter( + (w2_input_scale_for_alpha * layer.w2_weight_scale_2).to(torch.float32), + requires_grad=False, + ) + + # This is for quantization, so we need to invert it. + layer.w2_input_scale_quant = Parameter((1 / w2_input_scale_for_alpha).to(torch.float32), requires_grad=False) + + # ========== Backend-specific processing ========== + + if self.allow_flashinfer and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: + # TensorRT-LLM specific processing + ( + gemm1_weights_fp4_shuffled, + gemm1_scales_fp4_shuffled, + gemm2_weights_fp4_shuffled, + gemm2_scales_fp4_shuffled, + ) = prepare_static_weights_for_trtllm_fp4_moe( + layer.w13_weight, + layer.w2_weight, + layer.w13_weight_scale, + layer.w2_weight_scale, + layer.w2_weight.size(-2), # hidden_size + layer.w13_weight.size(-2) // 2, # intermediate_size + layer.w13_weight.size(0), # num_experts + ) + logger.debug("Finished shuffling weights for TRT-LLM MOE") + + layer.gemm1_weights_fp4_shuffled = Parameter(gemm1_weights_fp4_shuffled, requires_grad=False) + layer.gemm2_weights_fp4_shuffled = Parameter(gemm2_weights_fp4_shuffled, requires_grad=False) + layer.gemm1_scales_fp4_shuffled = Parameter(gemm1_scales_fp4_shuffled, requires_grad=False) + layer.gemm2_scales_fp4_shuffled = Parameter(gemm2_scales_fp4_shuffled, requires_grad=False) + + # Additional parameter needed for TRT-LLM + layer.g1_scale_c = Parameter( + (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32), + requires_grad=False, + ) + + # Clean up weights that won't 
be used by TRT-LLM + del layer.w2_weight + del layer.w2_weight_scale + del layer.w13_weight + del layer.w13_weight_scale + + elif self.use_marlin: + # Marlin processing - use patched version + # Pass per-expert scales (shape: num_experts), NOT scalar max values! + prepare_moe_fp4_layer_for_marlin_patched(layer, w13_weight_scale_2_per_expert, w2_weight_scale_2_per_expert) + # Delete attributes not needed for Marlin + del layer.g1_alphas + del layer.g2_alphas + del layer.w13_input_scale_quant + del layer.w2_input_scale_quant + + else: + # Non-TRT-LLM processing (Cutlass or non-flashinfer) + w13_blockscale_swizzled = swizzle_blockscale(layer.w13_weight_scale) + layer.w13_weight_scale = Parameter(w13_blockscale_swizzled, requires_grad=False) + + w13_weight = layer.w13_weight + intermediate_size_pad = w13_blockscale_swizzled.size(1) - w13_weight.size(1) + if intermediate_size_pad: + # padding gated activations will require to split w1 and w3 + # and pad them individually + assert not self.moe.is_act_and_mul, ( + "The intermediate size required padding, but padding is not implemented for gated activations" + ) + + layer.w13_weight = Parameter( + torch.nn.functional.pad(w13_weight, (0, 0, 0, intermediate_size_pad)), + requires_grad=False, + ) + layer.w2_weight = Parameter( + torch.nn.functional.pad(layer.w2_weight, (0, intermediate_size_pad // 2, 0, 0)), + requires_grad=False, + ) + layer.w2_weight_scale = Parameter( + torch.nn.functional.pad(layer.w2_weight_scale, (0, intermediate_size_pad // 16)), + requires_grad=False, + ) + + w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale) + layer.w2_weight_scale = Parameter(w2_blockscale_swizzled, requires_grad=False) + + +def apply_moe( + self, + layer, # FusedMoE + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: int | None = None, + num_expert_group: int | None = None, + global_num_experts: int = -1, + expert_map: torch.Tensor | 
None = None, + custom_routing_function: Callable | None = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: torch.Tensor | None = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, +) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + """ + Patched apply method for ModelOptNvFp4FusedMoE. + + Key modification for Marlin: Uses marlin_* attributes instead of originals. + """ + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe + from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( + flashinfer_trtllm_fp4_moe, + ) + from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( + FlashinferMoeBackend, + ) + from vllm.scalar_type import scalar_types + + if not self.moe.is_act_and_mul: + assert self.allow_flashinfer and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS, ( + "Non-gated activations are only supported by the flashinfer CUTLASS backend for modelopt checkpoints" + ) + + if self.allow_flashinfer and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: + if enable_eplb: + raise NotImplementedError("EPLB not supported for `ModelOptNvFp4FusedMoE` yet.") + return flashinfer_trtllm_fp4_moe( + layer=layer, + x=x, + router_logits=router_logits, + top_k=top_k, + global_num_experts=global_num_experts, + num_expert_group=num_expert_group, + topk_group=topk_group, + custom_routing_function=custom_routing_function, + e_score_correction_bias=e_score_correction_bias, + ) + + topk_weights, topk_ids, _ = layer.select_experts( + hidden_states=x, + router_logits=router_logits, + ) + + if self.use_marlin: + # Use marlin_* attributes instead of original attributes + return fused_marlin_moe( + x, + 
layer.marlin_w13_weight, + layer.marlin_w2_weight, + None, # bias1 + None, # bias2 + layer.marlin_w13_weight_scale, + layer.marlin_w2_weight_scale, + router_logits, + topk_weights, + topk_ids, + quant_type_id=scalar_types.float4_e2m1f.id, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + expert_map=expert_map, + global_scale1=layer.marlin_w13_weight_scale_2, + global_scale2=layer.marlin_w2_weight_scale_2, + workspace=layer.workspace, + input_dtype=self.marlin_input_dtype, + ) + + elif self.allow_flashinfer: + assert self.flashinfer_moe_backend in ( + FlashinferMoeBackend.CUTLASS, + FlashinferMoeBackend.CUTEDSL, + ) + if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: + from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( + flashinfer_cutlass_moe_fp4, + ) + + flashinfer_fn_moe_fp4 = flashinfer_cutlass_moe_fp4 + else: + from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import ( + flashinfer_cutedsl_moe_fp4, + ) + + flashinfer_fn_moe_fp4 = flashinfer_cutedsl_moe_fp4 + + assert self.moe_quant_config is not None + return flashinfer_fn_moe_fp4( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + quant_config=self.moe_quant_config, + inplace=False, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + else: + # If no modular kernel is provided, use cutlass_moe_fp4 for TP case + # only (no EP). 
+ from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4 + + assert self.moe_quant_config is not None + return cutlass_moe_fp4( + a=x, + w1_fp4=layer.w13_weight, + w2_fp4=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + quant_config=self.moe_quant_config, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + m=x.shape[0], + n=layer.w2_weight.shape[2] * 2, + k=x.shape[1], + e=layer.w13_weight.shape[0], + ) + + +def process_weights_after_loading_kv(self, layer) -> None: + """Modified version of BaseKVCacheMethod.process_weights_after_loading. + + Doesn't delete k_scale, v_scale, q_scale, and prob_scale parameters to allow + for dynamic updates during refit. + """ + # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0 + # regardless whether the kv-scale is available in the checkpoint. + # No need to process kv scales after loading if we are going to + # calculate them on the fly. + from vllm.platforms import current_platform + + if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales: + if layer.k_scale > 0.0 and layer.v_scale > 0.0: + # We prefer to use separate k_scale and v_scale if present + k_scale = layer.k_scale.to("cpu").tolist() + v_scale = layer.v_scale.to("cpu").tolist() + if current_platform.is_fp8_fnuz(): + k_scale *= 2 + v_scale *= 2 + elif layer.k_scale < 0.0 and layer.v_scale < 0.0: + # If no scales were loaded (both scales are invalid negative + # values), use the default value of 1.0 + k_scale = 1.0 + v_scale = 1.0 + else: + # If we find a single kv_scale in the checkpoint, we remap + # kv_scale to k_scale during weight loading, and duplicate + # k_scale to v_scale here + assert layer.k_scale > 0.0 + scale_to_duplicate = max(layer.k_scale, layer.v_scale) + k_scale = scale_to_duplicate.to("cpu").tolist() + v_scale = scale_to_duplicate.to("cpu").tolist() + if current_platform.is_fp8_fnuz(): + k_scale *= 2 + v_scale *= 2 + + if not isinstance(k_scale, 
float) or not isinstance(v_scale, float):
+            raise ValueError("Only support per-tensor scaling factor for fp8 KV cache")
+
+        if layer.q_scale < 0.0:
+            layer._q_scale.copy_(k_scale)
+            layer._q_scale_float = k_scale
+
+        # These are used in the final Attention.forward()
+        layer._k_scale.copy_(k_scale)
+        layer._v_scale.copy_(v_scale)
+        layer._k_scale_float = k_scale
+        layer._v_scale_float = v_scale
+
+    if layer.q_scale > 0.0:
+        q_scale = layer.q_scale
+        if current_platform.is_fp8_fnuz():
+            q_scale *= 2
+        layer.calculate_kv_scales = False
+    else:
+        q_scale = 1.0
+    if layer.prob_scale > 0.0:
+        prob_scale = layer.prob_scale
+        if current_platform.is_fp8_fnuz():
+            prob_scale *= 2
+    else:
+        prob_scale = 1.0
+
+    is_singleton_float = (
+        lambda x: isinstance(x, float) or isinstance(x, torch.Tensor) and x.numel() == 1 and x.is_floating_point()
+    )
+    if not is_singleton_float(q_scale) or not is_singleton_float(prob_scale):
+        raise ValueError("Only support per-tensor scaling factor for fp8-quantized Q/prob")
+
+    # These are used in the final Attention.forward()
+    layer._q_scale.copy_(q_scale)
+    layer._q_scale_float = q_scale.item() if isinstance(q_scale, torch.Tensor) else q_scale
+
+    layer._prob_scale.copy_(prob_scale)
+
+
+def apply_vllm_modelopt_patches():
+    func1_path = (
+        "vllm.model_executor.layers.quantization.modelopt.ModelOptNvFp4LinearMethod.process_weights_after_loading"
+    )
+    patcher1 = patch(func1_path, process_weights_after_loading_modelopt)
+    patcher1.start()
+    func2_path = "vllm.model_executor.layers.quantization.modelopt.ModelOptNvFp4LinearMethod.apply"
+    patcher2 = patch(func2_path, apply_modelopt)
+    patcher2.start()
+    # Patch ModelOptNvFp4FusedMoE
+    func3_path = "vllm.model_executor.layers.quantization.modelopt.ModelOptNvFp4FusedMoE.process_weights_after_loading"
+    patcher3 = patch(func3_path, process_weights_after_loading_moe)
+    patcher3.start()
+    func4_path = "vllm.model_executor.layers.quantization.modelopt.ModelOptNvFp4FusedMoE.apply"
+    patcher4 = 
patch(func4_path, apply_moe) + patcher4.start() + # Static scales mode: patch process_weights_after_loading to preserve k_scale/v_scale for manual updates + func5_path = "vllm.model_executor.layers.quantization.kv_cache.BaseKVCacheMethod.process_weights_after_loading" + patcher5 = patch(func5_path, process_weights_after_loading_kv) + patcher5.start() \ No newline at end of file diff --git a/verl/workers/config/engine.py b/verl/workers/config/engine.py index afcc81bf142..86828322c81 100644 --- a/verl/workers/config/engine.py +++ b/verl/workers/config/engine.py @@ -102,7 +102,8 @@ class McoreEngineConfig(EngineConfig): override_transformer_config (dict[str, Any]): Override configuration for transformer. use_mbridge (bool): Whether to use MBridge for communication. dtype (str): Mixed precision training param dtype, default "bfloat16" - quantization (Optional[str]): Quantization method to use. None for no quantization, "nvfp4_qat" for QAT. + quantization (Optional[str]): Quantization method to use. None for no quantization, "nvfp4" for NVFP4 quantization. + enable_qat (bool): Whether to enable Quantization-Aware Training (QAT). Default False. 
""" # sequence_parallel is not listed as a frozen field for auto-correction purpose @@ -126,6 +127,7 @@ class McoreEngineConfig(EngineConfig): vanilla_mbridge: bool = True strategy: str = "megatron" quantization: Optional[str] = None + enable_qat: bool = False def __post_init__(self) -> None: super().__post_init__() diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py index 187beaea9ef..e319500a720 100644 --- a/verl/workers/megatron_workers.py +++ b/verl/workers/megatron_workers.py @@ -73,7 +73,7 @@ simple_timer, ) from verl.utils.profiler.performance import reduce_timing, topk_reduce_ratio_min_max -from verl.utils.qat_utils import QATConfig, apply_qat, is_qat_enabled +from verl.utils.modelopt_qat_utils import apply_qat from verl.utils.ray_utils import get_event_loop from verl.utils.torch_functional import use_original_torch_compile from verl.workers.actor.megatron_actor import MegatronPPOActor @@ -444,14 +444,13 @@ def _build_model_optimizer( print_model_size(actor_module[0]) log_gpu_memory_usage("After MegatronPPOActor init", logger=logger) quantization = self.config.actor.megatron.get("quantization", None) - if quantization is not None: - if is_qat_enabled(quantization): - print(f"[lark]: Applying QAT with method: {quantization}") - qat_config = QATConfig(enabled=True, quant_method=quantization) - print("[lark]: length of actor_module:", len(actor_module)) - for i in range(len(actor_module)): - actor_module[i] = apply_qat(actor_module[i], qat_config) - print("[lark]: QAT applied to all actor model chunks") + enable_qat = self.config.actor.megatron.get("enable_qat", False) + if quantization is not None and enable_qat: + print(f"[lark]: Applying QAT with quantization: {quantization}") + print("[lark]: length of actor_module:", len(actor_module)) + for i in range(len(actor_module)): + actor_module[i] = apply_qat(actor_module[i], quantization) + print("[lark]: QAT applied to all actor model chunks") elif self._is_ref: wrap_config = 
McoreModuleWrapperConfig( @@ -728,28 +727,15 @@ async def rollout_mode(self): self.tf_config, self.layer_name_mapping, ) - if is_qat_enabled(self.config.actor.megatron.quantization): + if self.config.actor.megatron.get("enable_qat", False): print("[lark]: rollout mode: quantizing weights with QAT") - from verl.utils.qat_post_utils import QATWeightPostProcessor + from verl.utils.modelopt_qat_utils import QATWeightPostProcessor qat_weight_post_processor = QATWeightPostProcessor( - self.actor.actor_module, "nvfp4", self.dtype, use_calibrated_scale_2=True + self.actor.actor_module, "nvfp4" ) per_tensor_param = qat_weight_post_processor.process_weights_iterator(per_tensor_param) - # per_tensor_param = list(per_tensor_param) - # rank = torch.distributed.get_rank() - # state_dict = {} - # for name, weight in per_tensor_param: - # state_dict[name] = weight.data.cpu() - # path = f"/apps/quant_models/qwen3_8b_nvfp4/model_rank_{rank}.pt" - # torch.save(state_dict, path) - # del state_dict - # print(f"[lark]: saved state_dict to {path}") - - # import time - # time.sleep(1000) - if self.config.rollout.free_cache_engine: await self.rollout.resume(tags=["weights"]) diff --git a/verl/workers/rollout/vllm_rollout/utils.py b/verl/workers/rollout/vllm_rollout/utils.py index fc4ff5b915c..8f7d2173f27 100644 --- a/verl/workers/rollout/vllm_rollout/utils.py +++ b/verl/workers/rollout/vllm_rollout/utils.py @@ -29,7 +29,7 @@ from verl.utils.vllm import TensorLoRARequest, VLLMHijack from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader from verl.utils.vllm.vllm_fp8_utils import apply_vllm_fp8_patches, is_fp8_model, load_quanted_weights -from verl.utils.modelopt_utils import apply_vllm_modelopt_patches +from verl.utils.modelopt_vllm_utils import apply_vllm_modelopt_patches logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py 
b/verl/workers/rollout/vllm_rollout/vllm_async_server.py index 08e2cb4de6a..e4a4416412a 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -225,7 +225,7 @@ async def launch_server(self, master_address: str = None, master_port: int = Non quantization = self.config.quantization if quantization is not None: - _SUPPORTED_QUANTIZATION = ["fp8", "torchao", "nvfp4_qat"] + _SUPPORTED_QUANTIZATION = ["fp8", "torchao", "nvfp4"] if quantization not in _SUPPORTED_QUANTIZATION: raise ValueError(f"Currently only support {_SUPPORTED_QUANTIZATION} quantization, got: {quantization}") @@ -242,39 +242,10 @@ async def launch_server(self, master_address: str = None, master_port: int = Non apply_vllm_fp8_patches() # for subprocesses patching os.environ["VERL_VLLM_FP8_QUANT_ENABLED"] = "1" - elif quantization == "nvfp4_qat": - print("[lark]: vllm quantization is nvfp4_qat") - fp4_block_quant_kwargs = { - "config_groups": { - "group_0": { - "input_activations": { - "dynamic": "false", - "num_bits": 4, - "type": "float", - "group_size": 16 - }, - "weights": { - "dynamic": "false", - "num_bits": 4, - "type": "float", - "group_size": 16 - }, - "targets": [ - "Linear" - ] - } - }, - "ignore": [ - "lm_head" - ], - "quant_algo": "NVFP4", - "producer": { - "name": "modelopt", - "version": "0.40.0.dev89+g0ec5e200f.d20251127" - }, - "quant_method": "modelopt" - } - from verl.utils.modelopt_utils import apply_vllm_modelopt_patches + elif quantization == "nvfp4": + print("[lark]: vllm quantization is nvfp4") + from verl.utils.modelopt_vllm_utils import apply_vllm_modelopt_patches, NVFP4_BLOCK_QUANT_KWARGS + fp4_block_quant_kwargs = dict(NVFP4_BLOCK_QUANT_KWARGS) apply_vllm_modelopt_patches() os.environ["VERL_VLLM_NVFP4_QUANT_ENABLED"] = "1" @@ -284,7 +255,7 @@ async def launch_server(self, master_address: str = None, master_port: int = Non if quantization == "fp8": hf_overrides["quantization_config"] = 
fp8_block_quant_kwargs - elif quantization == "nvfp4_qat": + elif quantization == "nvfp4": hf_overrides["quantization_config"] = fp4_block_quant_kwargs quantization = "modelopt" From 3dbb2d2ac099402eedcf424949de4a9c5829c3a9 Mon Sep 17 00:00:00 2001 From: larkz-nv Date: Thu, 12 Feb 2026 06:47:30 -0800 Subject: [PATCH 03/10] update MoE model logic --- verl/utils/modelopt_qat_utils.py | 822 +++++++++++++++++++++--------- verl/utils/modelopt_vllm_utils.py | 48 ++ verl/workers/megatron_workers.py | 15 +- 3 files changed, 634 insertions(+), 251 deletions(-) diff --git a/verl/utils/modelopt_qat_utils.py b/verl/utils/modelopt_qat_utils.py index ed21398a5fc..45b20e3f3da 100644 --- a/verl/utils/modelopt_qat_utils.py +++ b/verl/utils/modelopt_qat_utils.py @@ -13,16 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""ModelOpt Quantization-Aware Training (QAT) utilities for Megatron models. - -Includes: -- QAT application via ModelOpt (apply_qat) -- QAT weight post-processing for exporting quantized weights to vLLM rollout (QATWeightPostProcessor) -""" import re from dataclasses import dataclass -from typing import Any, Iterator +from typing import Any, Iterator, Optional import torch import torch.nn as nn @@ -93,10 +87,6 @@ def apply_qat(model: nn.Module, quant_method: str): return model -# --------------------------------------------------------------------------- -# QAT weight post-processing (for exporting quantized weights to rollout) -# --------------------------------------------------------------------------- - @dataclass class QuantizationMetadata: """Metadata for a quantized module.""" @@ -107,6 +97,12 @@ class QuantizationMetadata: module: torch.nn.Module vpp_idx: int block_size: int = 16 # Default NVFP4 block size + # Fields for EP synchronization - store amax values for non-local experts + weight_amax: Optional[torch.Tensor] = None + input_amax: Optional[torch.Tensor] = None + is_local: bool = True # 
Whether this expert is local to current EP rank + global_expert_idx: Optional[int] = None # Global expert index for MoE experts + local_expert_idx: Optional[int] = None # Local expert index on this EP rank class QATWeightPostProcessor: @@ -126,6 +122,10 @@ class QATWeightPostProcessor: - If use_calibrated_scale_2=True (default), we use the QAT calibrated amax which may only reflect the local shard's statistics. - If use_calibrated_scale_2=False, we recompute weight_scale_2 from the merged weight. + Note on EP (Expert Parallelism): + - When EP is enabled, each rank only holds a subset of experts (local_experts) + - We synchronize metadata across all EP ranks to ensure complete metadata for all experts + - Local expert indices are converted to global expert indices for proper mapping """ def __init__( @@ -133,7 +133,7 @@ def __init__( actor_module: list, quantization_method: str = "nvfp4", dtype: torch.dtype = torch.bfloat16, - use_calibrated_scale_2: bool = True, + use_calibrated_scale_2: bool = False, ): """ Initialize the QAT weight post-processor. 
@@ -151,36 +151,202 @@ def __init__( self.dtype = dtype self.use_calibrated_scale_2 = use_calibrated_scale_2 self.quant_metadata: dict[str, QuantizationMetadata] = {} + self.ep_size, self.ep_rank, self.ep_group = self._get_ep_info() + self.pp_size, self.pp_rank, self.pp_group = self._get_pp_info() + self.num_local_experts = 0 # Will be determined during metadata building self._build_quantization_metadata() + + global_rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + # print(f"[QAT PostProcessor][Rank {global_rank}] After _build_quantization_metadata: " + # f"metadata_count={len(self.quant_metadata)}, ep_size={self.ep_size}, pp_size={self.pp_size}") + + # Synchronize metadata across EP ranks if EP is enabled + if self.ep_size > 1: + print(f"[QAT PostProcessor][Rank {global_rank}] Starting EP metadata sync...") + self._sync_quantization_metadata_across_ep() + print(f"[QAT PostProcessor][Rank {global_rank}] After EP sync: metadata_count={len(self.quant_metadata)}") + + # Synchronize metadata across PP ranks if PP is enabled + # This ensures all PP ranks have complete metadata for all layers + if self.pp_size > 1: + print(f"[QAT PostProcessor][Rank {global_rank}] Starting PP metadata sync...") + self._sync_quantization_metadata_across_pp() + print(f"[QAT PostProcessor][Rank {global_rank}] After PP sync: metadata_count={len(self.quant_metadata)}") + else: + print(f"[QAT PostProcessor][Rank {global_rank}] PP sync skipped: pp_size={self.pp_size}") + self._log_initialization_info() + def _get_ep_info(self) -> tuple[int, int, Any]: + """ + Get Expert Parallel information from Megatron parallel state. 
+ + Returns: + (ep_size, ep_rank, ep_group): EP world size, rank, and process group + """ + try: + from megatron.core import parallel_state as mpu + + ep_size = mpu.get_expert_model_parallel_world_size() + if ep_size > 1: + ep_rank = mpu.get_expert_model_parallel_rank() + ep_group = mpu.get_expert_model_parallel_group() + return ep_size, ep_rank, ep_group + except Exception: + # EP not enabled or mpu not available + pass + return 1, 0, None + + def _get_pp_info(self) -> tuple[int, int, Any]: + """ + Get Pipeline Parallel information from Megatron parallel state. + + Returns: + (pp_size, pp_rank, pp_group): PP world size, rank, and process group + """ + try: + from megatron.core import parallel_state as mpu + + pp_size = mpu.get_pipeline_model_parallel_world_size() + pp_rank = mpu.get_pipeline_model_parallel_rank() + pp_group = mpu.get_pipeline_model_parallel_group() + + if torch.distributed.get_rank() == 0: + print(f"[QAT PostProcessor] PP info: pp_size={pp_size}, pp_rank={pp_rank}, pp_group={pp_group}") + + if pp_size > 1: + return pp_size, pp_rank, pp_group + else: + return pp_size, pp_rank, None + except Exception as e: + if torch.distributed.get_rank() == 0: + print(f"[QAT PostProcessor] Warning: Failed to get PP info: {e}") + pass + return 1, 0, None + + def _extract_layer_index(self, name: str) -> Optional[int]: + """ + Extract layer index from parameter name. + + For mcore format: decoder.layers.{layer_idx}.xxx + + Returns: + Layer index or None if not a layer parameter + """ + match = re.search(r"layers\.(\d+)\.", name) + if match: + return int(match.group(1)) + return None + + def _get_num_layers_per_pp_stage(self) -> int: + """ + Get the number of layers per PP stage from local metadata. 
+ + This is calculated as max(local_layer_indices) + 1 + """ + max_layer_idx = -1 + for name in self.quant_metadata.keys(): + layer_idx = self._extract_layer_index(name) + if layer_idx is not None and layer_idx > max_layer_idx: + max_layer_idx = layer_idx + return max_layer_idx + 1 if max_layer_idx >= 0 else 0 + + def _convert_local_to_global_layer_name(self, name: str, source_pp_rank: int, num_layers_per_stage: int) -> str: + """ + Convert parameter name from local layer index to global layer index. + + Args: + name: Parameter name with local layer index (e.g., decoder.layers.0.xxx) + source_pp_rank: The PP rank this name came from + num_layers_per_stage: Number of layers per PP stage + + Returns: + Parameter name with global layer index + """ + local_layer_idx = self._extract_layer_index(name) + if local_layer_idx is None: + return name + + global_layer_idx = source_pp_rank * num_layers_per_stage + local_layer_idx + return re.sub(r"layers\.(\d+)\.", f"layers.{global_layer_idx}.", name, count=1) + + def _extract_local_expert_index(self, name: str) -> Optional[int]: + """ + Extract local expert index from parameter name. + + For SequentialMLP structure, the pattern is: + decoder.layers.{layer}.mlp.experts.local_experts.{local_idx}.linear_fc1/fc2.weight + + Args: + name: Parameter name in mcore format + + Returns: + Local expert index or None if not an expert parameter + """ + match = re.search(r"local_experts\.(\d+)\.", name) + if match: + return int(match.group(1)) + return None + + def _local_to_global_expert_index(self, local_idx: int) -> int: + """ + Convert local expert index to global expert index. 
+ + Global index = ep_rank * num_local_experts + local_idx + + Args: + local_idx: Local expert index on this EP rank + + Returns: + Global expert index + """ + return self.ep_rank * self.num_local_experts + local_idx + + def _convert_name_to_global_index(self, name: str, local_idx: int, global_idx: int) -> str: + """ + Convert parameter name from local to global expert index. + + Args: + name: Original parameter name with local index + local_idx: Local expert index + global_idx: Global expert index + + Returns: + Parameter name with global expert index + """ + return name.replace(f"local_experts.{local_idx}.", f"local_experts.{global_idx}.") + def _build_quantization_metadata(self): """ Extract quantization metadata from all modules in actor_module. Stores: {param_name: QuantizationMetadata} - Supports both dense and MoE (Mixture of Experts) models: - - Dense: decoder.layers.X.mlp.linear_fc1, decoder.layers.X.mlp.linear_fc2 - - MoE SequentialMLP: decoder.layers.X.mlp.experts.local_experts.Y.linear_fc1/fc2 - - MoE shared_experts: decoder.layers.X.mlp.shared_experts.linear_fc1/fc2 - - MoE TEGroupedMLP: decoder.layers.X.mlp.experts.linear_fc1/fc2 (grouped) + For EP training with SequentialMLP: + - Detects local expert indices and computes global indices + - Stores metadata with global expert indices as keys """ + # First pass: collect all local expert indices to determine num_local_experts + local_expert_indices = set() + for vpp_idx, module in enumerate(self.actor_module): model = unwrap_model(module) - for name, submodule in model.named_modules(): - # Handle MoE SequentialMLP - need to iterate over local_experts - if self._is_sequential_mlp(submodule): - self._build_moe_sequential_mlp_metadata(name, submodule, vpp_idx) - continue + local_idx = self._extract_local_expert_index(name) + if local_idx is not None: + local_expert_indices.add(local_idx) - # Handle MoE TEGroupedMLP - grouped experts with linear_fc1/fc2 - if self._is_te_grouped_mlp(submodule): - 
self._build_moe_te_grouped_mlp_metadata(name, submodule, vpp_idx) - continue + if local_expert_indices: + self.num_local_experts = max(local_expert_indices) + 1 + if torch.distributed.get_rank() == 0: + print(f"[QAT PostProcessor] Detected {self.num_local_experts} local experts per EP rank") - # Handle regular quantized modules (dense layers, shared_experts, etc.) + # Second pass: build metadata with global indices + for vpp_idx, module in enumerate(self.actor_module): + model = unwrap_model(module) + + for name, submodule in model.named_modules(): + # Check if this module is quantized qformat = get_quantization_format(submodule) if qformat == QUANTIZATION_NONE: continue @@ -192,6 +358,20 @@ def _build_quantization_metadata(self): weight_quantizer = getattr(submodule, "weight_quantizer", None) input_quantizer = getattr(submodule, "input_quantizer", None) + # Extract amax values for synchronization + weight_amax = None + input_amax = None + if weight_quantizer is not None and hasattr(weight_quantizer, "_amax"): + weight_amax = weight_quantizer._amax.clone().cpu() if weight_quantizer._amax is not None else None + if input_quantizer is not None and hasattr(input_quantizer, "_amax"): + input_amax = input_quantizer._amax.clone().cpu() if input_quantizer._amax is not None else None + + # Determine global expert index for MoE experts + local_expert_idx = self._extract_local_expert_index(name) + global_expert_idx = None + if local_expert_idx is not None and self.ep_size > 1: + global_expert_idx = self._local_to_global_expert_index(local_expert_idx) + metadata = QuantizationMetadata( qformat=qformat, weight_quantizer=weight_quantizer, @@ -199,164 +379,324 @@ def _build_quantization_metadata(self): module=submodule, vpp_idx=vpp_idx, block_size=block_size, + weight_amax=weight_amax, + input_amax=input_amax, + is_local=True, + global_expert_idx=global_expert_idx, + local_expert_idx=local_expert_idx, ) for param_name, _ in submodule.named_parameters(recurse=False): full_name = 
f"{name}.{param_name}" if name else param_name - self.quant_metadata[full_name] = metadata - def _is_sequential_mlp(self, module: torch.nn.Module) -> bool: - """Check if module is a MoE SequentialMLP.""" - module_type_name = type(module).__name__ - return "SequentialMLP" in module_type_name and hasattr(module, "local_experts") + # For EP training, store with global expert index as key + if local_expert_idx is not None and self.ep_size > 1: + global_name = self._convert_name_to_global_index(full_name, local_expert_idx, global_expert_idx) + self.quant_metadata[global_name] = metadata + else: + self.quant_metadata[full_name] = metadata - def _is_te_grouped_mlp(self, module: torch.nn.Module) -> bool: - """Check if module is a MoE TEGroupedMLP (Transformer Engine Grouped MLP).""" - module_type_name = type(module).__name__ - return "TEGroupedMLP" in module_type_name or "GroupedMLP" in module_type_name + def _log_initialization_info(self): + """Log initialization information for debugging.""" + global_rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 - def _build_moe_sequential_mlp_metadata( - self, - base_name: str, - sequential_mlp: torch.nn.Module, - vpp_idx: int, - ): + print( + f"[QAT PostProcessor][Rank {global_rank}] Initialized with quantization method: {self.quantization_method}" + ) + print(f"[QAT PostProcessor][Rank {global_rank}] Found {len(self.quant_metadata)} quantized parameters") + if self.ep_size > 1: + print( + f"[QAT PostProcessor][Rank {global_rank}] EP enabled: ep_size={self.ep_size}, ep_rank={self.ep_rank}, " + f"num_local_experts={self.num_local_experts}" + ) + if self.pp_size > 1: + local_count = sum(1 for m in self.quant_metadata.values() if m.is_local) + remote_count = sum(1 for m in self.quant_metadata.values() if not m.is_local) + print( + f"[QAT PostProcessor][Rank {global_rank}] PP enabled: pp_size={self.pp_size}, pp_rank={self.pp_rank}, " + f"local_params={local_count}, remote_params={remote_count}" + ) + + # 
Log all metadata entries for debugging + for name, metadata in self.quant_metadata.items(): + extra_info = "" + if metadata.global_expert_idx is not None: + extra_info = f", global_expert_idx={metadata.global_expert_idx}" + if not metadata.is_local: + extra_info += ", is_local=False" + print( + f"[QAT PostProcessor][Rank {global_rank}] Metadata: {name}, qformat={metadata.qformat}, " + f"block_size={metadata.block_size}{extra_info}" + ) + + def _sync_quantization_metadata_across_ep(self): """ - Build quantization metadata for MoE SequentialMLP. + Synchronize quantization metadata across all EP (Expert Parallel) ranks. - SequentialMLP structure: - - local_experts: list of MLP experts, each with linear_fc1 and linear_fc2 - - Each expert's linear layers may have quantizers attached + When EP is enabled, each rank only holds metadata for its local experts. + This method gathers metadata from all EP ranks and merges them so that + every rank has complete metadata for all experts. - Args: - base_name: Base module name (e.g., 'decoder.layers.0.mlp.experts') - sequential_mlp: The SequentialMLP module - vpp_idx: Virtual pipeline parallel index + For SequentialMLP structure: + - Local expert indices are converted to global indices + - Metadata is gathered and merged using global indices as keys + - Non-local experts have is_local=False and module/quantizers set to None """ - if not hasattr(sequential_mlp, "local_experts"): + if self.ep_size <= 1 or self.ep_group is None: return - for expert_idx, expert in enumerate(sequential_mlp.local_experts): - # Process linear_fc1 and linear_fc2 for each expert - for linear_name in ["linear_fc1", "linear_fc2"]: - linear_module = getattr(expert, linear_name, None) - if linear_module is None: - continue + # Prepare serializable metadata info for all_gather + # We can't send module/quantizer objects, so we extract necessary info + local_metadata_info = {} + for name, metadata in self.quant_metadata.items(): + # Only sync MoE expert metadata 
(containing "local_experts") + if "local_experts" not in name: + continue - qformat = get_quantization_format(linear_module) - if qformat == QUANTIZATION_NONE: - continue + local_metadata_info[name] = { + "qformat": metadata.qformat, + "block_size": metadata.block_size, + "vpp_idx": metadata.vpp_idx, + "weight_amax": metadata.weight_amax, + "input_amax": metadata.input_amax, + "global_expert_idx": metadata.global_expert_idx, + "local_expert_idx": metadata.local_expert_idx, + } + + # Also send num_local_experts for validation + sync_data = { + "metadata": local_metadata_info, + "num_local_experts": self.num_local_experts, + "ep_rank": self.ep_rank, + } - block_size = get_weight_block_size(linear_module) - if block_size == 0: - continue + # Gather metadata from all EP ranks + all_sync_data = [None] * self.ep_size + torch.distributed.all_gather_object(all_sync_data, sync_data, group=self.ep_group) + + # Validate that all ranks have the same num_local_experts + for rank_idx, data in enumerate(all_sync_data): + if data is not None and data["num_local_experts"] != self.num_local_experts: + print( + f"[QAT PostProcessor] Warning: EP rank {rank_idx} has " + f"{data['num_local_experts']} local experts, expected {self.num_local_experts}" + ) + + # Merge metadata from all ranks + for rank_idx, data in enumerate(all_sync_data): + if rank_idx == self.ep_rank: + # Skip local metadata (already have it) + continue + + if data is None: + continue - weight_quantizer = getattr(linear_module, "weight_quantizer", None) - input_quantizer = getattr(linear_module, "input_quantizer", None) + rank_metadata = data["metadata"] + for name, info in rank_metadata.items(): + if name in self.quant_metadata: + # Already have this metadata (shouldn't happen with proper global indices) + continue + # Create metadata entry for non-local experts + # Note: module and quantizers are not available for non-local experts metadata = QuantizationMetadata( - qformat=qformat, - 
weight_quantizer=weight_quantizer, - input_quantizer=input_quantizer, - module=linear_module, - vpp_idx=vpp_idx, - block_size=block_size, + qformat=info["qformat"], + weight_quantizer=None, # Not available for non-local + input_quantizer=None, # Not available for non-local + module=None, # Not available for non-local + vpp_idx=info["vpp_idx"], + block_size=info["block_size"], + weight_amax=info["weight_amax"], + input_amax=info["input_amax"], + is_local=False, # Mark as non-local + global_expert_idx=info["global_expert_idx"], + local_expert_idx=info["local_expert_idx"], ) + self.quant_metadata[name] = metadata - # Build full parameter name - # Format: {base_name}.local_experts.{expert_idx}.{linear_name}.weight - for param_name, _ in linear_module.named_parameters(recurse=False): - full_name = f"{base_name}.local_experts.{expert_idx}.{linear_name}.{param_name}" - self.quant_metadata[full_name] = metadata + # Count local vs non-local experts + num_local = sum(1 for m in self.quant_metadata.values() if m.is_local and m.global_expert_idx is not None) + num_remote = sum(1 for m in self.quant_metadata.values() if not m.is_local and m.global_expert_idx is not None) - def _build_moe_te_grouped_mlp_metadata( - self, - base_name: str, - te_grouped_mlp: torch.nn.Module, - vpp_idx: int, - ): + if torch.distributed.get_rank() == 0: + print( + f"[QAT PostProcessor] EP metadata sync complete. " + f"EP size: {self.ep_size}, Local expert params: {num_local}, " + f"Remote expert params: {num_remote}, Total metadata entries: {len(self.quant_metadata)}" + ) + + def _sync_quantization_metadata_across_pp(self): """ - Build quantization metadata for MoE TEGroupedMLP. + Synchronize quantization metadata across all PP (Pipeline Parallel) ranks. 
- TEGroupedMLP structure (Transformer Engine): - - linear_fc1: grouped linear layer for all experts - - linear_fc2: grouped linear layer for all experts - - Weights are stored as 3D tensors [num_experts, out_dim, in_dim] + When PP is enabled, each rank only holds layers for its pipeline stage. + This method gathers metadata from all PP ranks and merges them so that + every rank has complete metadata for all layers. - Args: - base_name: Base module name (e.g., 'decoder.layers.0.mlp.experts') - te_grouped_mlp: The TEGroupedMLP module - vpp_idx: Virtual pipeline parallel index + IMPORTANT: In Megatron's PP mode, each PP rank uses LOCAL layer indices + (starting from 0), not global layer indices. For example: + - PP rank 0 has decoder.layers.0 (globally layer 0) + - PP rank 1 has decoder.layers.0 (globally layer 1) + + This method converts local layer indices to global layer indices during sync. + + For MoE SequentialMLP structure with PP: + - Different PP ranks hold different decoder layers + - Each PP rank builds metadata only for its local layers + - We gather and merge metadata from all PP ranks + - Layer indices are converted from local to global during merge + - Non-local layers have is_local=False and module/quantizers set to None """ - for linear_name in ["linear_fc1", "linear_fc2"]: - linear_module = getattr(te_grouped_mlp, linear_name, None) - if linear_module is None: - continue + global_rank = torch.distributed.get_rank() - qformat = get_quantization_format(linear_module) - if qformat == QUANTIZATION_NONE: - continue + print( + f"[QAT PostProcessor][Rank {global_rank}] PP sync starting: " + f"pp_size={self.pp_size}, pp_rank={self.pp_rank}, pp_group={self.pp_group}, " + f"local_metadata_count={len(self.quant_metadata)}" + ) - block_size = get_weight_block_size(linear_module) - if block_size == 0: - continue + if self.pp_size <= 1: + print(f"[QAT PostProcessor][Rank {global_rank}] PP sync skipped: pp_size <= 1") + return - weight_quantizer = 
getattr(linear_module, "weight_quantizer", None) - input_quantizer = getattr(linear_module, "input_quantizer", None) + if self.pp_group is None: + print(f"[QAT PostProcessor][Rank {global_rank}] PP sync skipped: pp_group is None") + return - metadata = QuantizationMetadata( - qformat=qformat, - weight_quantizer=weight_quantizer, - input_quantizer=input_quantizer, - module=linear_module, - vpp_idx=vpp_idx, - block_size=block_size, - ) + # Verify PP group size matches expected pp_size + actual_pp_group_size = torch.distributed.get_world_size(group=self.pp_group) + print( + f"[QAT PostProcessor][Rank {global_rank}] PP group size verification: " + f"expected={self.pp_size}, actual={actual_pp_group_size}" + ) - # Build full parameter name - # Format: {base_name}.{linear_name}.weight - for param_name, _ in linear_module.named_parameters(recurse=False): - full_name = f"{base_name}.{linear_name}.{param_name}" - self.quant_metadata[full_name] = metadata + # Calculate number of layers per PP stage (needed for global layer index conversion) + num_layers_per_stage = self._get_num_layers_per_pp_stage() + print(f"[QAT PostProcessor][Rank {global_rank}] Detected {num_layers_per_stage} layers per PP stage") - def _log_initialization_info(self): - """Log initialization information for debugging.""" - print(f"[QAT PostProcessor] Initialized with quantization method: {self.quantization_method}") - print(f"[QAT PostProcessor] Found {len(self.quant_metadata)} quantized parameters") + # First, convert our local metadata to use global layer indices + # This is needed so we can properly merge with other PP ranks + local_metadata_with_global_indices = {} + for name, metadata in self.quant_metadata.items(): + global_name = self._convert_local_to_global_layer_name(name, self.pp_rank, num_layers_per_stage) + local_metadata_with_global_indices[global_name] = metadata - # Log sample parameters from layer 0 for debugging + # Update our metadata dict to use global layer indices + 
self.quant_metadata = local_metadata_with_global_indices + + # Prepare serializable metadata info for all_gather + # We can't send module/quantizer objects, so we extract necessary info + local_metadata_info = {} for name, metadata in self.quant_metadata.items(): - if "layers.0" in name and "weight" in name: + local_metadata_info[name] = { + "qformat": metadata.qformat, + "block_size": metadata.block_size, + "vpp_idx": metadata.vpp_idx, + "weight_amax": metadata.weight_amax, + "input_amax": metadata.input_amax, + "global_expert_idx": metadata.global_expert_idx, + "local_expert_idx": metadata.local_expert_idx, + "is_local": metadata.is_local, + } + + # Include PP rank info and num_layers_per_stage for global index conversion + sync_data = { + "metadata": local_metadata_info, + "pp_rank": self.pp_rank, + "num_local_experts": self.num_local_experts, + "num_layers_per_stage": num_layers_per_stage, + "global_rank": global_rank, + } + + print( + f"[QAT PostProcessor][Rank {global_rank}] Preparing to sync {len(local_metadata_info)} metadata entries, " + f"sample keys (global indices): {list(local_metadata_info.keys())[:3]}" + ) + + # Gather metadata from all PP ranks + all_sync_data = [None] * actual_pp_group_size + torch.distributed.all_gather_object(all_sync_data, sync_data, group=self.pp_group) + + # Debug: print what we received + print(f"[QAT PostProcessor][Rank {global_rank}] Received data from {len(all_sync_data)} PP ranks") + for i, data in enumerate(all_sync_data): + if data is not None: + sample_keys = list(data.get("metadata", {}).keys())[:2] print( - f"[QAT PostProcessor] Sample: {name}, qformat={metadata.qformat}, block_size={metadata.block_size}, module type: {type(metadata.module)}" + f"[QAT PostProcessor][Rank {global_rank}] PP rank {i}: " + f"received from global_rank={data.get('global_rank', 'unknown')}, " + f"pp_rank={data.get('pp_rank', 'unknown')}, " + f"metadata_count={len(data.get('metadata', {}))}, " + f"sample_keys={sample_keys}" ) - def 
_log_initialization_info(self): - """Log initialization information for debugging.""" - print(f"[QAT PostProcessor] Initialized with quantization method: {self.quantization_method}") - print(f"[QAT PostProcessor] Found {len(self.quant_metadata)} quantized parameters") + # Merge metadata from all PP ranks + local_metadata_before = len(self.quant_metadata) + for rank_idx, data in enumerate(all_sync_data): + if data is None: + print(f"[QAT PostProcessor][Rank {global_rank}] Skipping rank_idx={rank_idx}: data is None") + continue - # Log sample parameters from layer 0 for debugging (including MoE experts) - moe_expert_count = 0 - for name, metadata in self.quant_metadata.items(): - if "layers.0" in name and "weight" in name: - if "local_experts" in name: - moe_expert_count += 1 - if moe_expert_count <= 2: # Only log first 2 experts - print( - f"[QAT PostProcessor] MoE Expert Sample: {name}, qformat={metadata.qformat}, block_size={metadata.block_size}" - ) - elif "shared_experts" in name: - print( - f"[QAT PostProcessor] Shared Expert Sample: {name}, qformat={metadata.qformat}, block_size={metadata.block_size}" - ) - else: - print( - f"[QAT PostProcessor] Dense Sample: {name}, qformat={metadata.qformat}, block_size={metadata.block_size}, module type: {type(metadata.module)}" - ) + source_pp_rank = data.get("pp_rank") + + # Skip our own data - compare by pp_rank from the data, not by index + if source_pp_rank == self.pp_rank: + print( + f"[QAT PostProcessor][Rank {global_rank}] Skipping rank_idx={rank_idx}: same pp_rank={self.pp_rank}" + ) + continue + + rank_metadata = data["metadata"] + added_count = 0 + skipped_existing = 0 + + for name, info in rank_metadata.items(): + # The name already has global layer indices (converted by the sender) + if name in self.quant_metadata: + # Already have this metadata (shouldn't happen with correct global indices) + existing = self.quant_metadata[name] + if existing.is_local: + skipped_existing += 1 + continue + # If both are 
non-local, just keep existing + skipped_existing += 1 + continue - if moe_expert_count > 0: - print(f"[QAT PostProcessor] Total MoE expert layers in layer 0: {moe_expert_count}") + # Create metadata entry for layers from other PP ranks + # Note: module and quantizers are not available for non-local layers + metadata = QuantizationMetadata( + qformat=info["qformat"], + weight_quantizer=None, # Not available for non-local PP rank + input_quantizer=None, # Not available for non-local PP rank + module=None, # Not available for non-local PP rank + vpp_idx=info["vpp_idx"], + block_size=info["block_size"], + weight_amax=info["weight_amax"], + input_amax=info["input_amax"], + is_local=False, # Mark as non-local (from other PP rank) + global_expert_idx=info["global_expert_idx"], + local_expert_idx=info["local_expert_idx"], + ) + self.quant_metadata[name] = metadata + added_count += 1 + + print( + f"[QAT PostProcessor][Rank {global_rank}] From pp_rank={source_pp_rank}: " + f"added {added_count} metadata entries, skipped {skipped_existing} existing" + ) + + # Log statistics + metadata_added = len(self.quant_metadata) - local_metadata_before + local_count = sum(1 for m in self.quant_metadata.values() if m.is_local) + remote_count = sum(1 for m in self.quant_metadata.values() if not m.is_local) + + print( + f"[QAT PostProcessor][Rank {global_rank}] PP metadata sync complete. 
" + f"PP size: {self.pp_size}, PP rank: {self.pp_rank}, " + f"Local params: {local_count}, Remote params: {remote_count}, " + f"Metadata added from other PP ranks: {metadata_added}, " + f"Total metadata entries: {len(self.quant_metadata)}" + ) def _find_matching_metadata(self, param_name: str) -> QuantizationMetadata | None: """ @@ -443,6 +783,10 @@ def _quantize_nvfp4( # Use QAT calibrated amax (may only reflect local shard statistics) # weight_scale_2 = amax / (6.0 * 448.0) weight_scale_2 = NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer) + elif metadata.weight_amax is not None: + # Non-local expert (EP): Use synchronized amax from metadata + weight_amax = metadata.weight_amax.to(weight.device) + weight_scale_2 = weight_amax.float() / (6.0 * 448.0) else: # Compute from all_gathered weight directly (recommended for TP) # weight_scale_2 = max(abs(weight)) / (6.0 * 448.0) @@ -585,13 +929,20 @@ def _find_matching_metadata_by_hf_name(self, hf_name: str) -> QuantizationMetada - HF o_proj.weight -> mcore linear_proj.weight - HF gate_proj/up_proj.weight -> mcore linear_fc1.weight - HF down_proj.weight -> mcore linear_fc2.weight + - MoE experts: model.layers.X.mlp.experts.Y.gate_proj/up_proj/down_proj.weight + - MoE router (gate): model.layers.X.mlp.gate.weight -> NOT quantized (returns None) """ - import re # Only process weight parameters if not hf_name.endswith(".weight") or hf_name.endswith("._amax") or "norm" in hf_name: return None + # Check for MoE router (gate) - should NOT be quantized + # HF formats: model.layers.X.mlp.gate.weight (Qwen) + # model.layers.X.block_sparse_moe.gate.weight (Mixtral) + if self._is_moe_router(hf_name): + return None + # Extract layer number from HF name layer_match = re.search(r"layers?\.(\d+)\.", hf_name) if not layer_match: @@ -610,27 +961,88 @@ def _find_matching_metadata_by_hf_name(self, hf_name: str) -> QuantizationMetada elif "o_proj" in hf_name: 
mcore_patterns.append(f"decoder.layers.{layer_num}.self_attention.linear_proj.weight") elif "mlp" in hf_name: - if any(proj in hf_name for proj in ["gate_proj", "up_proj"]): - mcore_patterns.append(f"decoder.layers.{layer_num}.mlp.linear_fc1.weight") - elif "down_proj" in hf_name: - mcore_patterns.append(f"decoder.layers.{layer_num}.mlp.linear_fc2.weight") + # Check for MoE experts first + # HF format: model.layers.X.mlp.experts.Y.gate_proj/up_proj/down_proj.weight + # HF Mixtral format: model.layers.X.block_sparse_moe.experts.Y.w1/w2/w3.weight + expert_match = re.search(r"\.experts\.(\d+)\.", hf_name) + if expert_match: + expert_id = expert_match.group(1) # This is the global expert ID in HF format + # MoE expert layers - use global expert ID for SequentialMLP + if any(proj in hf_name for proj in ["gate_proj", "up_proj", "w1", "w3"]): + # Try TEGroupedMLP pattern first (all experts share same linear layer) + mcore_patterns.append(f"decoder.layers.{layer_num}.mlp.experts.linear_fc1.weight") + # Try SequentialMLP pattern with global expert index + mcore_patterns.append( + f"decoder.layers.{layer_num}.mlp.experts.local_experts.{expert_id}.linear_fc1.weight" + ) + elif any(proj in hf_name for proj in ["down_proj", "w2"]): + # Try TEGroupedMLP pattern first + mcore_patterns.append(f"decoder.layers.{layer_num}.mlp.experts.linear_fc2.weight") + # Try SequentialMLP pattern with global expert index + mcore_patterns.append( + f"decoder.layers.{layer_num}.mlp.experts.local_experts.{expert_id}.linear_fc2.weight" + ) + # Check for shared_expert (Qwen2 MoE) + elif "shared_expert" in hf_name: + if any(proj in hf_name for proj in ["gate_proj", "up_proj"]): + mcore_patterns.append(f"decoder.layers.{layer_num}.mlp.shared_experts.linear_fc1.weight") + elif "down_proj" in hf_name: + mcore_patterns.append(f"decoder.layers.{layer_num}.mlp.shared_experts.linear_fc2.weight") + else: + # Dense MLP + if any(proj in hf_name for proj in ["gate_proj", "up_proj"]): + 
mcore_patterns.append(f"decoder.layers.{layer_num}.mlp.linear_fc1.weight") + elif "down_proj" in hf_name: + mcore_patterns.append(f"decoder.layers.{layer_num}.mlp.linear_fc2.weight") # Try to find matching metadata for pattern in mcore_patterns: if pattern in self.quant_metadata: return self.quant_metadata[pattern] - # If no exact match, try to find any metadata from the same layer - # This handles cases where the exact name might be slightly different - for mcore_name, metadata in self.quant_metadata.items(): - if f"layers.{layer_num}." in mcore_name: - # Found a quantized module in the same layer - # For QAT, if any module in the layer is quantized, all Linear layers should be - if ".weight" in mcore_name: - return metadata + # # If no exact match, try to find any metadata from the same layer + # # This handles cases where the exact name might be slightly different + # for mcore_name, metadata in self.quant_metadata.items(): + # if f"layers.{layer_num}." in mcore_name: + # # Found a quantized module in the same layer + # # Skip router metadata - router should not be used for other layers + # if ".router." in mcore_name: + # continue + # # For QAT, if any module in the layer is quantized, all Linear layers should be + # if ".weight" in mcore_name: + # return metadata return None + def _is_moe_router(self, hf_name: str) -> bool: + """ + Check if the HF parameter name corresponds to a MoE router (gate). + + MoE router should NOT be quantized to maintain routing precision. + + Router naming patterns: + - Qwen/Qwen2/Qwen3 MoE: model.layers.X.mlp.gate.weight + - Mixtral: model.layers.X.block_sparse_moe.gate.weight + - Shared expert gate (Qwen2 MoE): model.layers.X.mlp.shared_expert_gate.weight + + Note: gate_proj is NOT the router, it's part of the MLP expert. 
+ """ + + # Pattern 1: Qwen/Qwen3 MoE router - model.layers.X.mlp.gate.weight + # Must be exactly ".mlp.gate.weight" not ".mlp.gate_proj.weight" + if re.search(r"\.mlp\.gate\.weight$", hf_name): + return True + + # Pattern 2: Mixtral router - model.layers.X.block_sparse_moe.gate.weight + if re.search(r"\.block_sparse_moe\.gate\.weight$", hf_name): + return True + + # Pattern 3: Qwen2 MoE shared expert gate - model.layers.X.mlp.shared_expert_gate.weight + if re.search(r"\.mlp\.shared_expert_gate\.weight$", hf_name): + return True + + return False + def _find_non_layer_metadata(self, hf_name: str) -> QuantizationMetadata | None: """Find metadata for non-layer parameters (embed_tokens, lm_head, etc.).""" # Map HF names to mcore names for non-layer parameters @@ -644,94 +1056,4 @@ def _find_non_layer_metadata(self, hf_name: str) -> QuantizationMetadata | None: if mcore_name and mcore_name in self.quant_metadata: return self.quant_metadata[mcore_name] - return None - - def _find_matching_metadata_by_hf_name(self, hf_name: str) -> QuantizationMetadata | None: - """ - Find matching quantization metadata for an HF-format parameter name. - - This maps HF names back to the original mcore names to find metadata. 
- E.g., 'model.layers.0.self_attn.q_proj.weight' -> check if qkv layer is quantized - - The mapping logic: - - HF q_proj/k_proj/v_proj.weight -> mcore linear_qkv.weight - - HF o_proj.weight -> mcore linear_proj.weight - - HF gate_proj/up_proj.weight -> mcore linear_fc1.weight - - HF down_proj.weight -> mcore linear_fc2.weight - - HF experts.X.gate_proj/up_proj.weight -> mcore experts.local_experts.X.linear_fc1.weight - - HF experts.X.down_proj.weight -> mcore experts.local_experts.X.linear_fc2.weight - - HF shared_expert.gate_proj/up_proj.weight -> mcore shared_experts.linear_fc1.weight - - HF shared_expert.down_proj.weight -> mcore shared_experts.linear_fc2.weight - """ - import re - - # Only process weight parameters - if not hf_name.endswith(".weight") or hf_name.endswith("._amax") or "norm" in hf_name: - return None - - # Extract layer number from HF name - layer_match = re.search(r"layers?\.(\d+)\.", hf_name) - if not layer_match: - # Not a layer parameter (e.g., embed_tokens, lm_head, norm) - # Check for direct matches - return self._find_non_layer_metadata(hf_name) - - layer_num = layer_match.group(1) - - # Determine the mcore module name based on HF name pattern - mcore_patterns = [] - - if "self_attn" in hf_name: - if any(proj in hf_name for proj in ["q_proj", "k_proj", "v_proj"]): - mcore_patterns.append(f"decoder.layers.{layer_num}.self_attention.linear_qkv.weight") - elif "o_proj" in hf_name: - mcore_patterns.append(f"decoder.layers.{layer_num}.self_attention.linear_proj.weight") - elif "mlp" in hf_name: - # Check for MoE expert patterns first - expert_match = re.search(r"experts\.(\d+)\.", hf_name) - if expert_match: - expert_id = expert_match.group(1) - if any(proj in hf_name for proj in ["gate_proj", "up_proj"]): - # MoE expert gate_proj/up_proj -> local_experts.X.linear_fc1 - mcore_patterns.append( - f"decoder.layers.{layer_num}.mlp.experts.local_experts.{expert_id}.linear_fc1.weight" - ) - elif "down_proj" in hf_name: - # MoE expert down_proj -> 
local_experts.X.linear_fc2 - mcore_patterns.append( - f"decoder.layers.{layer_num}.mlp.experts.local_experts.{expert_id}.linear_fc2.weight" - ) - elif "shared_expert" in hf_name: - # Shared expert patterns (Qwen2Moe, DeepSeekV3, etc.) - if any(proj in hf_name for proj in ["gate_proj", "up_proj"]): - mcore_patterns.append(f"decoder.layers.{layer_num}.mlp.shared_experts.linear_fc1.weight") - elif "down_proj" in hf_name: - mcore_patterns.append(f"decoder.layers.{layer_num}.mlp.shared_experts.linear_fc2.weight") - elif "gate.weight" in hf_name: - # MoE router gate - mcore_patterns.append(f"decoder.layers.{layer_num}.mlp.router.weight") - elif any(proj in hf_name for proj in ["gate_proj", "up_proj"]): - # Dense MLP gate_proj/up_proj - mcore_patterns.append(f"decoder.layers.{layer_num}.mlp.linear_fc1.weight") - elif "down_proj" in hf_name: - # Dense MLP down_proj - mcore_patterns.append(f"decoder.layers.{layer_num}.mlp.linear_fc2.weight") - - # Try to find matching metadata - for pattern in mcore_patterns: - if pattern in self.quant_metadata: - return self.quant_metadata[pattern] - - # If no exact match, try to find any metadata from the same layer - # This handles cases where the exact name might be slightly different - for mcore_name, metadata in self.quant_metadata.items(): - if f"layers.{layer_num}." 
in mcore_name: - # For MoE, check if we're looking for expert weights - if "experts" in hf_name: - if "experts" in mcore_name and ".weight" in mcore_name: - return metadata - # For dense, check if any module in the layer is quantized - elif ".weight" in mcore_name: - return metadata - return None \ No newline at end of file diff --git a/verl/utils/modelopt_vllm_utils.py b/verl/utils/modelopt_vllm_utils.py index 417fc359365..5942d37d7e4 100644 --- a/verl/utils/modelopt_vllm_utils.py +++ b/verl/utils/modelopt_vllm_utils.py @@ -44,6 +44,54 @@ } }, "ignore": [ + # "model.layers.0.mlp.gate", + # "model.layers.1.mlp.gate", + # "model.layers.10.mlp.gate", + # "model.layers.11.mlp.gate", + # "model.layers.12.mlp.gate", + # "model.layers.13.mlp.gate", + # "model.layers.14.mlp.gate", + # "model.layers.15.mlp.gate", + # "model.layers.16.mlp.gate", + # "model.layers.17.mlp.gate", + # "model.layers.18.mlp.gate", + # "model.layers.19.mlp.gate", + # "model.layers.2.mlp.gate", + # "model.layers.20.mlp.gate", + # "model.layers.21.mlp.gate", + # "model.layers.22.mlp.gate", + # "model.layers.23.mlp.gate", + # "model.layers.24.mlp.gate", + # "model.layers.25.mlp.gate", + # "model.layers.26.mlp.gate", + # "model.layers.27.mlp.gate", + # "model.layers.28.mlp.gate", + # "model.layers.29.mlp.gate", + # "model.layers.3.mlp.gate", + # "model.layers.30.mlp.gate", + # "model.layers.31.mlp.gate", + # "model.layers.32.mlp.gate", + # "model.layers.33.mlp.gate", + # "model.layers.34.mlp.gate", + # "model.layers.35.mlp.gate", + # "model.layers.36.mlp.gate", + # "model.layers.37.mlp.gate", + # "model.layers.38.mlp.gate", + # "model.layers.39.mlp.gate", + # "model.layers.4.mlp.gate", + # "model.layers.40.mlp.gate", + # "model.layers.41.mlp.gate", + # "model.layers.42.mlp.gate", + # "model.layers.43.mlp.gate", + # "model.layers.44.mlp.gate", + # "model.layers.45.mlp.gate", + # "model.layers.46.mlp.gate", + # "model.layers.47.mlp.gate", + # "model.layers.5.mlp.gate", + # "model.layers.6.mlp.gate", + # 
"model.layers.7.mlp.gate", + # "model.layers.8.mlp.gate", + # "model.layers.9.mlp.gate", "lm_head" ], "quant_algo": "NVFP4", diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py index e319500a720..901beb7a3f6 100644 --- a/verl/workers/megatron_workers.py +++ b/verl/workers/megatron_workers.py @@ -219,6 +219,19 @@ def _init_hf_config_and_tf_config( provider.moe_token_dispatcher_type = "alltoall" provider.moe_router_load_balancing_type = "none" + def quantization_layer_spec(config): + from megatron.core.post_training.modelopt.gpt.model_specs import get_gpt_modelopt_spec + return get_gpt_modelopt_spec( + config=config, + local_core_attention=False, + remap_te_layernorm=True, + real_quant_cfg="None", + use_arbitrary_attention_mask=False, + ) + + # from megatron.bridge.models.gpt_provider import quantization_layer_spec + provider.transformer_layer_spec = quantization_layer_spec + # Apply transformer config overrides for key, value in override_transformer_config.items(): setattr(provider, key, value) @@ -451,7 +464,7 @@ def _build_model_optimizer( for i in range(len(actor_module)): actor_module[i] = apply_qat(actor_module[i], quantization) print("[lark]: QAT applied to all actor model chunks") - + print(f"larkz module: {actor_module[0]}") elif self._is_ref: wrap_config = McoreModuleWrapperConfig( is_value_model=False, # ref is not value model From 74df2dc14cc718473c322fc215f65c7959d12919 Mon Sep 17 00:00:00 2001 From: larkz-nv Date: Thu, 12 Feb 2026 06:55:20 -0800 Subject: [PATCH 04/10] add megatron patch for qat training --- verl/models/mcore/qat_patch.py | 541 ++++++++++++++++++ verl/utils/modelopt_qat_utils.py | 3 - verl/utils/modelopt_vllm_utils.py | 156 ++--- verl/workers/megatron_workers.py | 28 +- .../rollout/vllm_rollout/vllm_async_server.py | 21 +- 5 files changed, 632 insertions(+), 117 deletions(-) create mode 100644 verl/models/mcore/qat_patch.py diff --git a/verl/models/mcore/qat_patch.py b/verl/models/mcore/qat_patch.py new file 
mode 100644 index 00000000000..ec381a0f5de --- /dev/null +++ b/verl/models/mcore/qat_patch.py @@ -0,0 +1,541 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Runtime patches for QAT (Quantization-Aware Training) with Megatron-Core. + +This module provides four independent monkey-patches that fix issues in older +versions of megatron-core / megatron-bridge when running QAT workflows: + +1. **SwiGLU sharded-state-dict patch** (``apply_swiglu_sharded_factory_patch``) + Older megatron-core raises ``NotImplementedError`` inside + ``apply_swiglu_sharded_factory`` when ``singleton_local_shards=True``. + The patch adds correct handling by splitting the sharded tensor key into + separate ``{key}_w`` / ``{key}_v`` entries. + +2. **EP gather_from_ep_ranks patch** (``apply_ep_gather_patch``) + The original ``MegatronParamMapping.gather_from_ep_ranks`` only supports + the TEGroupedMLP naming pattern (``weight`` / ``bias``). The patch + additionally supports the SequentialMLP pattern (``local_experts.``) + and adds better error handling. + +3. **extract_sort_key patch** (``apply_extract_sort_key_patch``) + The original ``extract_sort_key`` in megatron-bridge utils only recognises + expert numbers in TEGroupedMLP format (``weight`` / ``bias``). The + patch adds fallback support for the SequentialMLP pattern + (``local_experts.``). + +4. 
**build_conversion_tasks patch** (``apply_build_conversion_tasks_patch``) + The original ``MegatronModelBridge.build_conversion_tasks`` may return + ``None`` entries in the task list (for PP ranks that don't own certain + parameters and have no mapping). The patch filters out ``None`` entries + before returning so that callers never need to guard against them. + +Convenience entry-point:: + + from verl.models.mcore.qat_patch import apply_qat_patch + apply_qat_patch() # applies all patches at once +""" + +import gc +import logging +import re +from typing import Dict, Iterable, List, Optional + +import torch + +logger = logging.getLogger(__name__) + +# ====================================================================== +# 1. SwiGLU sharded-state-dict patch +# ====================================================================== + + +def apply_swiglu_sharded_factory_patch(): + """Patch ``megatron.core.transformer.mlp.apply_swiglu_sharded_factory`` + to support ``singleton_local_shards`` for SwiGLU MLP tensors. + + Idempotent – safe to call multiple times. 
+ """ + import megatron.core.transformer.mlp as mlp_module + from megatron.core.dist_checkpointing import ShardedTensor + from megatron.core.dist_checkpointing.mapping import ( + ReplicaId, + ShardedTensorFactory, + ) + + if getattr(mlp_module, "_swiglu_patched", False): + return + mlp_module._swiglu_patched = True + mlp_module._original_apply_swiglu_sharded_factory = mlp_module.apply_swiglu_sharded_factory + + def patched_apply_swiglu_sharded_factory( + original_sh_ten, sharded_offsets, singleton_local_shards: bool = False + ): + swiglu_shard_axis = 0 + prepend_axis_num = len(sharded_offsets) + original_shape = original_sh_ten.local_shape + local_axis_size = original_shape[swiglu_shard_axis] + assert ( + original_sh_ten.global_offset[swiglu_shard_axis + prepend_axis_num] + % local_axis_size + == 0 + ) + rank_offset = ( + original_sh_ten.global_offset[swiglu_shard_axis + prepend_axis_num] + // local_axis_size + ) + axis_frag = original_sh_ten.axis_fragmentations[ + swiglu_shard_axis + prepend_axis_num + ] + + @torch.no_grad() + def sh_ten_build_fn( + key: str, + t: torch.Tensor, + replica_id: ReplicaId, + flattened_range: Optional[slice], + ): + if singleton_local_shards: + offset_w = (swiglu_shard_axis + prepend_axis_num, rank_offset, axis_frag) + offset_v = (swiglu_shard_axis + prepend_axis_num, rank_offset, axis_frag) + w_key = f"{key}_w" + v_key = f"{key}_v" + else: + offset_w = (swiglu_shard_axis + prepend_axis_num, rank_offset, axis_frag * 2) + offset_v = ( + swiglu_shard_axis + prepend_axis_num, + rank_offset + axis_frag, + axis_frag * 2, + ) + w_key = key + v_key = key + + tensor_w, tensor_v = torch.chunk(t, 2, dim=swiglu_shard_axis) + return [ + ShardedTensor.from_rank_offsets( + w_key, tensor_w, *sharded_offsets, offset_w, + replica_id=replica_id, prepend_axis_num=prepend_axis_num, + ), + ShardedTensor.from_rank_offsets( + v_key, tensor_v, *sharded_offsets, offset_v, + replica_id=replica_id, prepend_axis_num=prepend_axis_num, + ), + ] + + def 
sh_ten_merge_fn(sub_state_dict): + with torch.no_grad(): + try: + return torch.cat(sub_state_dict) + except (RuntimeError, torch.cuda.OutOfMemoryError) as e: + logger.warning( + "CUDA OOM during tensor merge – falling back to CPU. (Error: %s)", e, + ) + merged = torch.cat([t.cpu() for t in sub_state_dict]) + gc.collect() + torch.cuda.empty_cache() + return merged + + return ShardedTensorFactory( + original_sh_ten.key, + original_sh_ten.data, + sh_ten_build_fn, + sh_ten_merge_fn, + original_sh_ten.replica_id, + flattened_range=original_sh_ten.flattened_range, + ) + + mlp_module.apply_swiglu_sharded_factory = patched_apply_swiglu_sharded_factory + logger.info("Applied QAT patch: apply_swiglu_sharded_factory now supports singleton_local_shards.") + + +def revert_swiglu_sharded_factory_patch(): + """Revert :func:`apply_swiglu_sharded_factory_patch`.""" + import megatron.core.transformer.mlp as mlp_module + + if not getattr(mlp_module, "_swiglu_patched", False): + return + mlp_module.apply_swiglu_sharded_factory = mlp_module._original_apply_swiglu_sharded_factory + mlp_module._swiglu_patched = False + logger.info("Reverted QAT patch: apply_swiglu_sharded_factory.") + + +# ====================================================================== +# 2. EP gather_from_ep_ranks patch +# ====================================================================== + + +def apply_ep_gather_patch(): + """Patch ``MegatronParamMapping.gather_from_ep_ranks`` in megatron-bridge + to support both SequentialMLP (``local_experts.``) and TEGroupedMLP + (``weight`` / ``bias``) naming patterns. + + Idempotent – safe to call multiple times. 
+ """ + from megatron.bridge.models.conversion.param_mapping import MegatronParamMapping + + if getattr(MegatronParamMapping, "_ep_gather_patched", False): + return + MegatronParamMapping._ep_gather_patched = True + MegatronParamMapping._original_gather_from_ep_ranks = MegatronParamMapping.gather_from_ep_ranks + + def _patched_gather_from_ep_ranks( + self, + megatron_weights: Optional[torch.Tensor], + megatron_module, # Optional[MegatronModule] + hf_param_name: Optional[str], + ) -> Dict[str, torch.Tensor]: + """Gather expert weights across EP ranks (supports SequentialMLP + TEGroupedMLP).""" + if megatron_module is None: + num_experts_per_rank = self.broadcast_obj_from_pp_rank(None, "num_experts_per_rank") + else: + model_config = self._get_config(megatron_module) + num_experts = model_config.num_moe_experts + num_experts_per_rank = num_experts // self.ep_size + num_experts_per_rank = self.broadcast_obj_from_pp_rank( + num_experts_per_rank, "num_experts_per_rank" + ) + + # --- Extract the local expert index from the Megatron param name --- + local_expert_number = None + + # Try SequentialMLP pattern first: local_experts. + local_experts_match = re.search(r"local_experts\.(\d+)", self.megatron_param) + if local_experts_match: + global_expert_number = int(local_experts_match.group(1)) + local_expert_number = global_expert_number % num_experts_per_rank + else: + # Fallback: TEGroupedMLP pattern – weight or bias + for key in (".weight", ".bias"): + if key in self.megatron_param: + suffix = self.megatron_param.split(key)[-1] + if suffix: # only if there is actually a number after the suffix + global_expert_number = int(suffix) + local_expert_number = global_expert_number % num_experts_per_rank + break + + if local_expert_number is None: + raise ValueError( + f"Could not extract expert number from parameter name: {self.megatron_param}. " + f"Expected either TEGroupedMLP pattern (weight/bias) or " + f"SequentialMLP pattern (local_experts.)." 
+ ) + + # Build HF param names for every EP rank + gathered_expert_param_names = [ + re.sub( + r"experts\.(\d+)", + f"experts.{int(local_expert_number) + num_experts_per_rank * i}", + str(hf_param_name), + ) + for i in range(self.ep_size) + ] + assert str(hf_param_name) in gathered_expert_param_names, ( + f"hf_param_name {hf_param_name} not in gathered_expert_param_names " + f"{gathered_expert_param_names}" + ) + + # All-gather across the EP group + gathered_weights = [torch.empty_like(megatron_weights) for _ in range(self.ep_size)] + torch.distributed.all_gather(gathered_weights, megatron_weights, group=self.ep_group) + + # Assemble the result dict (handles duplicate names via concatenation) + weights_dict: Dict[str, torch.Tensor] = {} + for i, param_name in enumerate(gathered_expert_param_names): + if param_name in weights_dict: + weights_dict[param_name] = torch.cat( + [weights_dict[param_name], gathered_weights[i].unsqueeze(0)], dim=0 + ) + else: + weights_dict[param_name] = gathered_weights[i].unsqueeze(0) + for param_name in weights_dict: + weights_dict[param_name] = weights_dict[param_name].squeeze() + + return weights_dict + + MegatronParamMapping.gather_from_ep_ranks = _patched_gather_from_ep_ranks + logger.info( + "Applied QAT patch: MegatronParamMapping.gather_from_ep_ranks " + "now supports SequentialMLP pattern." + ) + + +def revert_ep_gather_patch(): + """Revert :func:`apply_ep_gather_patch`.""" + from megatron.bridge.models.conversion.param_mapping import MegatronParamMapping + + if not getattr(MegatronParamMapping, "_ep_gather_patched", False): + return + MegatronParamMapping.gather_from_ep_ranks = MegatronParamMapping._original_gather_from_ep_ranks + MegatronParamMapping._ep_gather_patched = False + logger.info("Reverted QAT patch: MegatronParamMapping.gather_from_ep_ranks.") + + +# ====================================================================== +# 3. 
extract_sort_key patch +# ====================================================================== + + +def apply_extract_sort_key_patch(): + """Patch ``megatron.bridge.models.conversion.utils.extract_sort_key`` + to support the SequentialMLP naming pattern (``local_experts.``) in + addition to the original TEGroupedMLP pattern (``weight`` / ``bias``). + + Idempotent – safe to call multiple times. + """ + import megatron.bridge.models.conversion.utils as utils_module + + if getattr(utils_module, "_sort_key_patched", False): + return + utils_module._sort_key_patched = True + utils_module._original_extract_sort_key = utils_module.extract_sort_key + + def _patched_extract_sort_key(param_name: str): + """Extract sorting key based on layer and expert numbers.""" + numbers = [] + + # Find layer number + layer_match = re.search(r"layers\.(\d+)", param_name) + if layer_match: + numbers.append(int(layer_match.group(1))) + + # Find expert number – try multiple patterns + expert_number = None + + # Pattern 1: TEGroupedMLP format (e.g., weight15, bias15) + expert_match = re.search(r"(?:bias|weight)(\d+)", param_name) + if expert_match: + expert_number = int(expert_match.group(1)) + + # Pattern 2: SequentialMLP format (e.g., local_experts.15) + if expert_number is None: + local_experts_match = re.search(r"local_experts\.(\d+)", param_name) + if local_experts_match: + expert_number = int(local_experts_match.group(1)) + + if expert_number is not None: + numbers.append(expert_number) + + # Pad to ensure consistent comparison (max 2 numbers) + while len(numbers) < 2: + numbers.append(-1) + numbers = numbers[:2] + return numbers, param_name + + utils_module.extract_sort_key = _patched_extract_sort_key + logger.info( + "Applied QAT patch: extract_sort_key now supports SequentialMLP pattern." 
+ ) + + +def revert_extract_sort_key_patch(): + """Revert :func:`apply_extract_sort_key_patch`.""" + import megatron.bridge.models.conversion.utils as utils_module + + if not getattr(utils_module, "_sort_key_patched", False): + return + utils_module.extract_sort_key = utils_module._original_extract_sort_key + utils_module._sort_key_patched = False + logger.info("Reverted QAT patch: extract_sort_key.") + + +# ====================================================================== +# 4. build_conversion_tasks patch +# ====================================================================== + + +def apply_build_conversion_tasks_patch(): + """Patch ``MegatronModelBridge.build_conversion_tasks`` to filter out + ``None`` entries before returning the task list. + + The original implementation can leave ``None`` slots for PP ranks that + don't own certain parameters and have no mapping. Downstream code that + iterates over the returned list may break on ``None``. This patch + ensures only valid :class:`WeightConversionTask` objects are returned. + + Idempotent – safe to call multiple times. + """ + import itertools + + from megatron.bridge.models.conversion.model_bridge import ( + MegatronModelBridge, + WeightConversionTask, + _megatron_local_name_to_global, + ) + from megatron.bridge.models.conversion.utils import ( + get_module_and_param_from_name, + persistent_buffers, + ) + from megatron.bridge.utils.common_utils import print_rank_0 + from megatron.core import parallel_state + from megatron.core.utils import unwrap_model + + if getattr(MegatronModelBridge, "_build_tasks_patched", False): + return + MegatronModelBridge._build_tasks_patched = True + MegatronModelBridge._original_build_conversion_tasks = ( + MegatronModelBridge.build_conversion_tasks + ) + + def _patched_build_conversion_tasks(self, hf_pretrained, megatron_model): + """Construct conversion tasks between HF and Megatron (``None``-free). 
+ + Returns a list of :class:`WeightConversionTask` objects — ``None`` + entries are filtered out before the list is returned so that callers + never need to guard against them. + """ + # Ensure hf_pretrained has the required state structure + if not (hasattr(hf_pretrained, "state") and hasattr(hf_pretrained.state, "source")): + raise ValueError("hf_pretrained.state.source is required for weight ordering") + + hf_keys: Iterable[str] = hf_pretrained.state.source.get_all_keys() + + mapping_registry = self.mapping_registry() + unwrapped_model = unwrap_model(megatron_model)[0] + model_config = unwrapped_model.config + embeddings_are_tied = self._share_embeddings_and_output_weights(model_config, unwrapped_model) + pp_rank = parallel_state.get_pipeline_model_parallel_rank() + sorted_global_param_names_all_pp_ranks = self._megatron_global_param_names_all_pp_ranks( + megatron_model + ) + + # Filter out output_layer related parameters if embeddings are tied + if embeddings_are_tied: + sorted_global_param_names_all_pp_ranks = [ + name for name in sorted_global_param_names_all_pp_ranks if "output_layer" not in name + ] + + global_names_index_dict = { + name: idx for idx, name in enumerate(sorted_global_param_names_all_pp_ranks) + } + + tasks = [None] * len(sorted_global_param_names_all_pp_ranks) + for vp_stage, model in enumerate(megatron_model): + for local_name, _ in itertools.chain( + model.named_parameters(), persistent_buffers(model) + ): + if "_extra_state" in local_name or self._is_adapter_param_name(local_name): + continue + + local_name = self._unwrap_name(local_name) + global_name = _megatron_local_name_to_global( + megatron_model, model_config, local_name, vp_stage + ) + if global_name not in global_names_index_dict: + print_rank_0(f"WARNING: {global_name} not in global_names_index_dict") + continue + global_name_idx = global_names_index_dict[global_name] + mapping = mapping_registry.megatron_to_hf_lookup( + self._get_lora_unwrapped_name(global_name) + ) + + if not 
mapping: + logger.warning(f"WARNING: No mapping found for megatron_param: {global_name}") + continue + + # Ensure HF weights exist + if not mapping.allow_hf_name_mismatch: + if isinstance(mapping.hf_param, str): + if mapping.hf_param not in hf_keys: + logger.warning(f"WARNING: Can't find {mapping.hf_param} in hf_keys") + continue + else: + missing_params = [ + hf_param + for hf_param in mapping.hf_param.values() + if hf_param not in hf_keys + ] + if missing_params: + logger.warning( + f"WARNING: Can't find the following HF parameters in hf_keys: " + f"{missing_params}" + ) + continue + + local_module, local_weights = get_module_and_param_from_name( + megatron_model, local_name, vp_stage + ) + if local_module is not None and not hasattr(local_module, "config"): + setattr(local_module, "config", model_config) + + tasks[global_name_idx] = WeightConversionTask( + pp_rank=pp_rank, + vp_stage=vp_stage, + param_name=local_name, + global_param_name=global_name, + megatron_module=local_module, + param_weight=local_weights, + mapping=mapping, + ) + + # Fill the remaining slots for PP communications + for idx, global_name in enumerate(sorted_global_param_names_all_pp_ranks): + if tasks[idx] is None: + mapping = mapping_registry.megatron_to_hf_lookup( + self._get_lora_unwrapped_name(global_name) + ) + if mapping is None: + continue + tasks[idx] = WeightConversionTask( + pp_rank=pp_rank, + vp_stage=None, + param_name=global_name, + global_param_name=global_name, + megatron_module=None, + param_weight=None, + mapping=mapping, + ) + + tasks = [task for task in tasks if task is not None] + return tasks + + MegatronModelBridge.build_conversion_tasks = _patched_build_conversion_tasks + logger.info( + "Applied QAT patch: MegatronModelBridge.build_conversion_tasks " + "now filters out None entries." 
+ ) + + +def revert_build_conversion_tasks_patch(): + """Revert :func:`apply_build_conversion_tasks_patch`.""" + from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge + + if not getattr(MegatronModelBridge, "_build_tasks_patched", False): + return + MegatronModelBridge.build_conversion_tasks = ( + MegatronModelBridge._original_build_conversion_tasks + ) + MegatronModelBridge._build_tasks_patched = False + logger.info("Reverted QAT patch: MegatronModelBridge.build_conversion_tasks.") + + +# ====================================================================== +# Convenience: apply / revert all QAT patches at once +# ====================================================================== + + +def apply_qat_patch(): + """Apply **all** QAT-related patches. Idempotent.""" + apply_swiglu_sharded_factory_patch() + apply_ep_gather_patch() + apply_extract_sort_key_patch() + apply_build_conversion_tasks_patch() + + +def revert_qat_patch(): + """Revert **all** QAT-related patches.""" + revert_swiglu_sharded_factory_patch() + revert_ep_gather_patch() + revert_extract_sort_key_patch() + revert_build_conversion_tasks_patch() diff --git a/verl/utils/modelopt_qat_utils.py b/verl/utils/modelopt_qat_utils.py index 45b20e3f3da..7e8b63d401b 100644 --- a/verl/utils/modelopt_qat_utils.py +++ b/verl/utils/modelopt_qat_utils.py @@ -158,8 +158,6 @@ def __init__( self._build_quantization_metadata() global_rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 - # print(f"[QAT PostProcessor][Rank {global_rank}] After _build_quantization_metadata: " - # f"metadata_count={len(self.quant_metadata)}, ep_size={self.ep_size}, pp_size={self.pp_size}") # Synchronize metadata across EP ranks if EP is enabled if self.ep_size > 1: @@ -740,7 +738,6 @@ def _quantize_weight( qformat = metadata.qformat if qformat == QUANTIZATION_NVFP4: - # print("[lark]: quantize_weight name:", name, "weight:", weight.shape, "metadata:", metadata) yield from 
self._quantize_nvfp4(name, weight, metadata) else: # Unknown format, pass through with warning diff --git a/verl/utils/modelopt_vllm_utils.py b/verl/utils/modelopt_vllm_utils.py index 5942d37d7e4..18bbaca5ced 100644 --- a/verl/utils/modelopt_vllm_utils.py +++ b/verl/utils/modelopt_vllm_utils.py @@ -23,84 +23,69 @@ from torch.nn import Parameter -NVFP4_BLOCK_QUANT_KWARGS = { - "config_groups": { - "group_0": { - "input_activations": { - "dynamic": "false", - "num_bits": 4, - "type": "float", - "group_size": 16 - }, - "weights": { - "dynamic": "false", - "num_bits": 4, - "type": "float", - "group_size": 16 - }, - "targets": [ - "Linear" - ] - } - }, - "ignore": [ - # "model.layers.0.mlp.gate", - # "model.layers.1.mlp.gate", - # "model.layers.10.mlp.gate", - # "model.layers.11.mlp.gate", - # "model.layers.12.mlp.gate", - # "model.layers.13.mlp.gate", - # "model.layers.14.mlp.gate", - # "model.layers.15.mlp.gate", - # "model.layers.16.mlp.gate", - # "model.layers.17.mlp.gate", - # "model.layers.18.mlp.gate", - # "model.layers.19.mlp.gate", - # "model.layers.2.mlp.gate", - # "model.layers.20.mlp.gate", - # "model.layers.21.mlp.gate", - # "model.layers.22.mlp.gate", - # "model.layers.23.mlp.gate", - # "model.layers.24.mlp.gate", - # "model.layers.25.mlp.gate", - # "model.layers.26.mlp.gate", - # "model.layers.27.mlp.gate", - # "model.layers.28.mlp.gate", - # "model.layers.29.mlp.gate", - # "model.layers.3.mlp.gate", - # "model.layers.30.mlp.gate", - # "model.layers.31.mlp.gate", - # "model.layers.32.mlp.gate", - # "model.layers.33.mlp.gate", - # "model.layers.34.mlp.gate", - # "model.layers.35.mlp.gate", - # "model.layers.36.mlp.gate", - # "model.layers.37.mlp.gate", - # "model.layers.38.mlp.gate", - # "model.layers.39.mlp.gate", - # "model.layers.4.mlp.gate", - # "model.layers.40.mlp.gate", - # "model.layers.41.mlp.gate", - # "model.layers.42.mlp.gate", - # "model.layers.43.mlp.gate", - # "model.layers.44.mlp.gate", - # "model.layers.45.mlp.gate", - # 
"model.layers.46.mlp.gate", - # "model.layers.47.mlp.gate", - # "model.layers.5.mlp.gate", - # "model.layers.6.mlp.gate", - # "model.layers.7.mlp.gate", - # "model.layers.8.mlp.gate", - # "model.layers.9.mlp.gate", - "lm_head" - ], - "quant_algo": "NVFP4", - "producer": { - "name": "modelopt", - "version": "0.40.0.dev89+g0ec5e200f.d20251127" - }, - "quant_method": "modelopt" -} +def generate_nvfp4_ignore_list(num_layers: int, is_moe: bool) -> list[str]: + """ + Generate the ignore list for NVFP4 quantization based on model configuration. + + Args: + num_layers: Number of hidden layers in the model (from hf_config.num_hidden_layers) + is_moe: Whether the model is a Mixture of Experts model + + Returns: + List of layer names to ignore during quantization + """ + ignore_list = [] + + # For MoE models, ignore the gate layers (routing layers) + if is_moe: + for layer_idx in range(num_layers): + ignore_list.append(f"model.layers.{layer_idx}.mlp.gate") + + # Always ignore lm_head for stability + ignore_list.append("lm_head") + + return ignore_list + + +def get_nvfp4_block_quant_kwargs(num_layers: int, is_moe: bool) -> dict: + """ + Generate complete NVFP4 quantization configuration based on model properties. 
+ Args: + num_layers: Number of hidden layers in the model (from hf_config.num_hidden_layers) + is_moe: Whether the model is a Mixture of Experts model + + Returns: + Complete quantization configuration dictionary compatible with ModelOpt + """ + ignore_list = generate_nvfp4_ignore_list(num_layers, is_moe) + + return { + "config_groups": { + "group_0": { + "input_activations": { + "dynamic": "false", + "num_bits": 4, + "type": "float", + "group_size": 16 + }, + "weights": { + "dynamic": "false", + "num_bits": 4, + "type": "float", + "group_size": 16 + }, + "targets": [ + "Linear" + ] + } + }, + "ignore": ignore_list, + "quant_algo": "NVFP4", + "producer": { + "name": "modelopt", + }, + "quant_method": "modelopt" + } @@ -121,12 +106,6 @@ def _create_param_from_subclass_attributes(custom_data: torch.Tensor, custom_wei def process_weights_after_loading_modelopt(self, layer: torch.nn.Module) -> None: - if getattr(layer, "prefix", None) == "model.layers.27.mlp.gate_up_proj" or getattr(layer, "prefix", "").startswith( - "model.layers.27.self_attn" - ): - print( - f"##VLLM##: {getattr(layer, 'prefix', None)}: {layer.params_dtype} bias: {getattr(layer, 'bias', None)} {layer.weight.data[0, :4]}, scale: {layer.weight_scale.data[0, :4]}, scale_2: {layer.weight_scale_2.data[0]}" - ) import vllm._custom_ops as ops from torch.nn import Parameter from vllm.model_executor.layers.quantization.utils.marlin_utils import ( @@ -248,13 +227,6 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module, weight_scale_2_max: tor layer.weight_scale = _create_param_from_subclass_attributes(weight_scale, layer.weight_scale) prepare_fp4_layer_for_marlin(layer, weight_scale_2_max) - if getattr(layer, "prefix", None) == "model.layers.27.mlp.gate_up_proj" or getattr( - layer, "prefix", "" - ).startswith("model.layers.27.self_attn"): - print( - f"##VLLM-MARLIN##: {getattr(layer, 'prefix', None)}: {layer.marlin_weight.data[0, :4]}, scale: {layer.marlin_weight_scale.data[0, :4]}, scale_2: 
{layer.marlin_weight_scale_2.data}" - ) - del layer.alpha # del layer.input_scale elif self.backend == "flashinfer-trtllm": @@ -293,8 +265,6 @@ def apply_modelopt( from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm if self.backend == "marlin": - # if getattr(layer, "prefix", None) == "model.layers.27.mlp.gate_up_proj" or getattr(layer, "prefix", "").startswith("model.layers.27.self_attn"): - # print(f"##VLLM-MARLIN##: {getattr(layer, 'prefix', None)}: {layer.marlin_weight.data[0, :4]}, scale: {layer.marlin_weight_scale.data[0, :4]}, scale_2: {layer.marlin_weight_scale_2.data}") return apply_fp4_marlin_linear( input=x, weight=layer.marlin_weight, diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py index 901beb7a3f6..c6182f168c1 100644 --- a/verl/workers/megatron_workers.py +++ b/verl/workers/megatron_workers.py @@ -219,18 +219,19 @@ def _init_hf_config_and_tf_config( provider.moe_token_dispatcher_type = "alltoall" provider.moe_router_load_balancing_type = "none" - def quantization_layer_spec(config): - from megatron.core.post_training.modelopt.gpt.model_specs import get_gpt_modelopt_spec - return get_gpt_modelopt_spec( - config=config, - local_core_attention=False, - remap_te_layernorm=True, - real_quant_cfg="None", - use_arbitrary_attention_mask=False, - ) + enable_qat = self.config.actor.megatron.get("enable_qat", False) + if enable_qat: + from megatron.bridge.models.gpt_provider import quantization_layer_spec + provider.transformer_layer_spec = quantization_layer_spec + + # Patch megatron-core MLP to support singleton_local_shards + # in SwiGLU sharded state dict (required for QAT checkpointing) + from verl.models.mcore.qat_patch import apply_qat_patch + apply_qat_patch() - # from megatron.bridge.models.gpt_provider import quantization_layer_spec - provider.transformer_layer_spec = quantization_layer_spec + from megatron.bridge.models.conversion.param_mapping import AutoMapping + 
AutoMapping.register_module_type('QuantColumnParallelLinear', 'column') + AutoMapping.register_module_type('QuantRowParallelLinear', 'row') # Apply transformer config overrides for key, value in override_transformer_config.items(): @@ -459,12 +460,8 @@ def _build_model_optimizer( quantization = self.config.actor.megatron.get("quantization", None) enable_qat = self.config.actor.megatron.get("enable_qat", False) if quantization is not None and enable_qat: - print(f"[lark]: Applying QAT with quantization: {quantization}") - print("[lark]: length of actor_module:", len(actor_module)) for i in range(len(actor_module)): actor_module[i] = apply_qat(actor_module[i], quantization) - print("[lark]: QAT applied to all actor model chunks") - print(f"larkz module: {actor_module[0]}") elif self._is_ref: wrap_config = McoreModuleWrapperConfig( is_value_model=False, # ref is not value model @@ -741,7 +738,6 @@ async def rollout_mode(self): self.layer_name_mapping, ) if self.config.actor.megatron.get("enable_qat", False): - print("[lark]: rollout mode: quantizing weights with QAT") from verl.utils.modelopt_qat_utils import QATWeightPostProcessor qat_weight_post_processor = QATWeightPostProcessor( diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py index e4a4416412a..e32019d562b 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -21,6 +21,7 @@ from typing import Any, Callable, Optional import numpy as np +from numpy.random import f import ray import vllm.entrypoints.cli.serve from packaging import version @@ -60,9 +61,10 @@ from vllm.utils.argparse_utils import FlexibleArgumentParser if _VLLM_VERSION == version.parse("0.12.0"): - from vllm.entrypoints.harmony_utils import get_encoding + pass + # from vllm.entrypoints.harmony_utils import get_encoding - get_encoding() + # get_encoding() elif _VLLM_VERSION >= version.parse("0.13.0"): 
from vllm.entrypoints.openai.parser.harmony_utils import get_encoding @@ -243,9 +245,18 @@ async def launch_server(self, master_address: str = None, master_port: int = Non # for subprocesses patching os.environ["VERL_VLLM_FP8_QUANT_ENABLED"] = "1" elif quantization == "nvfp4": - print("[lark]: vllm quantization is nvfp4") - from verl.utils.modelopt_vllm_utils import apply_vllm_modelopt_patches, NVFP4_BLOCK_QUANT_KWARGS - fp4_block_quant_kwargs = dict(NVFP4_BLOCK_QUANT_KWARGS) + from verl.utils.modelopt_vllm_utils import apply_vllm_modelopt_patches, get_nvfp4_block_quant_kwargs + + num_layers = getattr(self.model_config.hf_config, "num_hidden_layers") + + is_moe = ( + hasattr(self.model_config.hf_config, "num_experts") or + hasattr(self.model_config.hf_config, "num_local_experts") or + hasattr(self.model_config.hf_config, "moe_intermediate_size") + ) + + fp4_block_quant_kwargs = get_nvfp4_block_quant_kwargs(num_layers, is_moe) + apply_vllm_modelopt_patches() os.environ["VERL_VLLM_NVFP4_QUANT_ENABLED"] = "1" From 8772ae8f7fc562aff8cd6d2c9b03b8312745af42 Mon Sep 17 00:00:00 2001 From: larkz-nv Date: Mon, 23 Feb 2026 23:39:38 -0800 Subject: [PATCH 05/10] Add patch for SequentialMLP expert numbering under EP --- verl/models/mcore/qat_patch.py | 161 ++++++++++++++++++++++++++++++++- 1 file changed, 157 insertions(+), 4 deletions(-) diff --git a/verl/models/mcore/qat_patch.py b/verl/models/mcore/qat_patch.py index ec381a0f5de..ad6475400ae 100644 --- a/verl/models/mcore/qat_patch.py +++ b/verl/models/mcore/qat_patch.py @@ -36,12 +36,31 @@ patch adds fallback support for the SequentialMLP pattern (``local_experts.``). -4. **build_conversion_tasks patch** (``apply_build_conversion_tasks_patch``) +4. **_megatron_local_name_to_global patch** + (``apply_local_name_to_global_patch``) + The original ``_megatron_local_name_to_global`` only converts local + expert numbers to global for the TEGroupedMLP pattern + (``mlp.experts.linear_fc`` + ``weight``/``bias``). 
The patch + adds support for the SequentialMLP pattern + (``mlp.experts.local_experts.``). Without this, expert numbers + remain local (e.g. 0-15 for 128 experts with EP=8) instead of being + mapped to global indices (0-127). + +5. **build_conversion_tasks patch** (``apply_build_conversion_tasks_patch``) The original ``MegatronModelBridge.build_conversion_tasks`` may return ``None`` entries in the task list (for PP ranks that don't own certain parameters and have no mapping). The patch filters out ``None`` entries before returning so that callers never need to guard against them. +6. **AutoMapping._detect_parallelism_type patch** + (``apply_detect_parallelism_type_patch``) + The original ``_detect_parallelism_type`` only matches + ``module_type == "TELayerNormColumnParallelLinear"`` exactly. ModelOpt + quantised wrappers produce class names like + ``QuantTELayerNormColumnParallelLinear`` that contain the substring but + don't match exactly. The patch broadens the check to + ``"LayerNormColumnParallelLinear" in module_type``. + Convenience entry-point:: from verl.models.mcore.qat_patch import apply_qat_patch @@ -295,11 +314,14 @@ def apply_extract_sort_key_patch(): Idempotent – safe to call multiple times. 
""" import megatron.bridge.models.conversion.utils as utils_module + import megatron.bridge.models.conversion.model_bridge as bridge_module if getattr(utils_module, "_sort_key_patched", False): return utils_module._sort_key_patched = True + bridge_module._sort_key_patched = True utils_module._original_extract_sort_key = utils_module.extract_sort_key + bridge_module._original_extract_sort_key = bridge_module.extract_sort_key def _patched_extract_sort_key(param_name: str): """Extract sorting key based on layer and expert numbers.""" @@ -334,6 +356,7 @@ def _patched_extract_sort_key(param_name: str): return numbers, param_name utils_module.extract_sort_key = _patched_extract_sort_key + bridge_module.extract_sort_key = _patched_extract_sort_key logger.info( "Applied QAT patch: extract_sort_key now supports SequentialMLP pattern." ) @@ -342,16 +365,88 @@ def _patched_extract_sort_key(param_name: str): def revert_extract_sort_key_patch(): """Revert :func:`apply_extract_sort_key_patch`.""" import megatron.bridge.models.conversion.utils as utils_module + import megatron.bridge.models.conversion.model_bridge as bridge_module + if not getattr(utils_module, "_sort_key_patched", False): return utils_module.extract_sort_key = utils_module._original_extract_sort_key + bridge_module.extract_sort_key = bridge_module._original_extract_sort_key utils_module._sort_key_patched = False + bridge_module._sort_key_patched = False logger.info("Reverted QAT patch: extract_sort_key.") # ====================================================================== -# 4. build_conversion_tasks patch +# 4. _megatron_local_name_to_global patch +# ====================================================================== + + +def apply_local_name_to_global_patch(): + """Patch ``_megatron_local_name_to_global`` in megatron-bridge + to support the SequentialMLP naming pattern (``local_experts.``) + for local-to-global expert number conversion under EP > 1. 
+ + The original function only handles the TEGroupedMLP pattern + (``mlp.experts.linear_fc`` with ``weight``/``bias``). The + patch adds an ``elif`` branch for SequentialMLP parameters whose + names contain ``mlp.experts.local_experts.``. + + Idempotent – safe to call multiple times. + """ + import megatron.bridge.models.conversion.model_bridge as bridge_module + from megatron.core import parallel_state + from megatron.core.utils import get_pg_size + + if getattr(bridge_module, "_local_name_to_global_patched", False): + return + bridge_module._local_name_to_global_patched = True + bridge_module._original_megatron_local_name_to_global = bridge_module._megatron_local_name_to_global + + _orig_fn = bridge_module._megatron_local_name_to_global + + def _patched_megatron_local_name_to_global(models, config, param_name, vp_stage=None): + param_name = _orig_fn(models, config, param_name, vp_stage) + + ep_group = parallel_state.get_expert_model_parallel_group() + if ( + ".mlp.experts.local_experts." in param_name + and get_pg_size(ep_group) > 1 + and ".adapter." not in param_name + ): + num_experts = config.num_moe_experts + num_experts_per_rank = num_experts // ep_group.size() + local_experts_match = re.search(r"\.local_experts\.(\d+)\.", param_name) + if local_experts_match: + local_expert_number = int(local_experts_match.group(1)) + global_expert_number = num_experts_per_rank * ep_group.rank() + local_expert_number + param_name = param_name.replace( + f".local_experts.{local_expert_number}.", + f".local_experts.{global_expert_number}.", + ) + + return param_name + + bridge_module._megatron_local_name_to_global = _patched_megatron_local_name_to_global + logger.info( + "Applied QAT patch: _megatron_local_name_to_global " + "now supports SequentialMLP pattern." 
+ ) + + +def revert_local_name_to_global_patch(): + """Revert :func:`apply_local_name_to_global_patch`.""" + import megatron.bridge.models.conversion.model_bridge as bridge_module + + if not getattr(bridge_module, "_local_name_to_global_patched", False): + return + bridge_module._megatron_local_name_to_global = bridge_module._original_megatron_local_name_to_global + bridge_module._local_name_to_global_patched = False + logger.info("Reverted QAT patch: _megatron_local_name_to_global.") + + +# ====================================================================== +# 5. build_conversion_tasks patch # ====================================================================== @@ -368,10 +463,10 @@ def apply_build_conversion_tasks_patch(): """ import itertools + import megatron.bridge.models.conversion.model_bridge as bridge_module from megatron.bridge.models.conversion.model_bridge import ( MegatronModelBridge, WeightConversionTask, - _megatron_local_name_to_global, ) from megatron.bridge.models.conversion.utils import ( get_module_and_param_from_name, @@ -429,7 +524,7 @@ def _patched_build_conversion_tasks(self, hf_pretrained, megatron_model): continue local_name = self._unwrap_name(local_name) - global_name = _megatron_local_name_to_global( + global_name = bridge_module._megatron_local_name_to_global( megatron_model, model_config, local_name, vp_stage ) if global_name not in global_names_index_dict: @@ -520,6 +615,60 @@ def revert_build_conversion_tasks_patch(): logger.info("Reverted QAT patch: MegatronModelBridge.build_conversion_tasks.") +# ====================================================================== +# 5. AutoMapping._detect_parallelism_type patch +# ====================================================================== + + +def apply_detect_parallelism_type_patch(): + """Patch ``AutoMapping._detect_parallelism_type`` to recognise quantised + ``LayerNormColumnParallelLinear`` variants (e.g. + ``QuantTELayerNormColumnParallelLinear``). 
+ + The original code only checks + ``module_type == "TELayerNormColumnParallelLinear"``. ModelOpt wraps this + into classes whose names still *contain* ``LayerNormColumnParallelLinear`` + but do not match exactly. The patch broadens the check to + ``"LayerNormColumnParallelLinear" in module_type``. + + Idempotent – safe to call multiple times. + """ + from megatron.bridge.models.conversion.param_mapping import AutoMapping + + if getattr(AutoMapping, "_detect_parallelism_patched", False): + return + AutoMapping._detect_parallelism_patched = True + AutoMapping._original_detect_parallelism_type = AutoMapping._detect_parallelism_type + + def _patched_detect_parallelism_type(self, module): + module_type = type(module).__name__ + if "LayerNormColumnParallelLinear" in module_type: + if self.megatron_param and ( + self.megatron_param.endswith("layer_norm_weight") + or self.megatron_param.endswith("layer_norm_bias") + ): + return "replicated" + return "column" + return AutoMapping._original_detect_parallelism_type(self, module) + + AutoMapping._detect_parallelism_type = _patched_detect_parallelism_type + logger.info( + "Applied QAT patch: AutoMapping._detect_parallelism_type " + "now supports quantised LayerNormColumnParallelLinear variants." 
+ ) + + +def revert_detect_parallelism_type_patch(): + """Revert :func:`apply_detect_parallelism_type_patch`.""" + from megatron.bridge.models.conversion.param_mapping import AutoMapping + + if not getattr(AutoMapping, "_detect_parallelism_patched", False): + return + AutoMapping._detect_parallelism_type = AutoMapping._original_detect_parallelism_type + AutoMapping._detect_parallelism_patched = False + logger.info("Reverted QAT patch: AutoMapping._detect_parallelism_type.") + + # ====================================================================== # Convenience: apply / revert all QAT patches at once # ====================================================================== @@ -530,7 +679,9 @@ def apply_qat_patch(): apply_swiglu_sharded_factory_patch() apply_ep_gather_patch() apply_extract_sort_key_patch() + apply_local_name_to_global_patch() apply_build_conversion_tasks_patch() + apply_detect_parallelism_type_patch() def revert_qat_patch(): @@ -538,4 +689,6 @@ def revert_qat_patch(): revert_swiglu_sharded_factory_patch() revert_ep_gather_patch() revert_extract_sort_key_patch() + revert_local_name_to_global_patch() revert_build_conversion_tasks_patch() + revert_detect_parallelism_type_patch() From 6e815f8d3150546c445fba066a9dc63df5675642 Mon Sep 17 00:00:00 2001 From: larkzhang-nv Date: Wed, 25 Feb 2026 02:39:28 -0800 Subject: [PATCH 06/10] Refactor modelopt utils and unify QAT config under actor --- verl/trainer/config/actor/actor.yaml | 32 +++++++ verl/trainer/config/actor/dp_actor.yaml | 32 ------- verl/trainer/config/engine/megatron.yaml | 6 -- verl/utils/modelopt/__init__.py | 44 +++++++++ verl/utils/modelopt/qat.py | 81 ++++++++++++++++ .../vllm_patch.py} | 68 +------------ .../weight_processor.py} | 96 ++----------------- verl/workers/config/actor.py | 2 +- verl/workers/config/engine.py | 4 - verl/workers/megatron_workers.py | 24 ++--- verl/workers/rollout/vllm_rollout/utils.py | 32 +++++-- .../rollout/vllm_rollout/vllm_async_server.py | 30 ++++-- 12 files 
changed, 221 insertions(+), 230 deletions(-) create mode 100644 verl/utils/modelopt/__init__.py create mode 100644 verl/utils/modelopt/qat.py rename verl/utils/{modelopt_vllm_utils.py => modelopt/vllm_patch.py} (94%) rename verl/utils/{modelopt_qat_utils.py => modelopt/weight_processor.py} (92%) diff --git a/verl/trainer/config/actor/actor.yaml b/verl/trainer/config/actor/actor.yaml index bffe8aec484..07cad10391b 100644 --- a/verl/trainer/config/actor/actor.yaml +++ b/verl/trainer/config/actor/actor.yaml @@ -259,3 +259,35 @@ router_replay: # Required when mode is 'replay' replay_file: null +# QAT (Quantization-Aware Training) configuration +# When enabled: +# - QAT is automatically applied to actor model during training +# - Fused scales (QKV/GateUp) are automatically enabled for training-inference consistency +# - Fast quantization is used when syncing weights to vLLM rollout +# Supported modes: "w4a16" (NVFP4 weight-only) +# Note: "w4a4" mode is included in the code but currently has KL divergence issues and is NOT recommended for use. +# For usage examples, see: https://github.com/verl-project/verl-recipe/blob/main/qat/README.md +qat: + + # Whether to enable QAT + enable: false + + # Quantization mode: "w4a16" (weight-only). "w4a4" is experimental and not recommended. 
+ mode: "w4a16" + + # Quantization group size (NVFP4 requires 16) + group_size: 16 + + # Patterns to ignore (e.g., lm_head, embed_tokens) + ignore_patterns: + + - "lm_head" + - "embed_tokens" + - "re:.*mlp.gate$" + + # Activation observer for W4A4 mode: "static_minmax", "memoryless_minmax", or "minmax" + activation_observer: "static_minmax" + + # Path to vLLM quantization config JSON file + quantization_config_path: null + diff --git a/verl/trainer/config/actor/dp_actor.yaml b/verl/trainer/config/actor/dp_actor.yaml index 7fbe49c019e..fc0a16be609 100644 --- a/verl/trainer/config/actor/dp_actor.yaml +++ b/verl/trainer/config/actor/dp_actor.yaml @@ -48,35 +48,3 @@ calculate_sum_pi_squared: False # Enable gradient checkpointing for sum_pi_squared computation (saves memory) sum_pi_squared_checkpointing: False - -# QAT (Quantization-Aware Training) configuration -# When enabled: -# - QAT is automatically applied to actor model during training -# - Fused scales (QKV/GateUp) are automatically enabled for training-inference consistency -# - Fast quantization is used when syncing weights to vLLM rollout -# Supported modes: "w4a16" (NVFP4 weight-only) -# Note: "w4a4" mode is included in the code but currently has KL divergence issues and is NOT recommended for use. -# For usage examples, see: https://github.com/verl-project/verl-recipe/blob/main/qat/README.md -qat: - - # Whether to enable QAT - enable: false - - # Quantization mode: "w4a16" (weight-only). "w4a4" is experimental and not recommended. 
- mode: "w4a16" - - # Quantization group size (NVFP4 requires 16) - group_size: 16 - - # Patterns to ignore (e.g., lm_head, embed_tokens) - ignore_patterns: - - - "lm_head" - - "embed_tokens" - - "re:.*mlp.gate$" - - # Activation observer for W4A4 mode: "static_minmax", "memoryless_minmax", or "minmax" - activation_observer: "static_minmax" - - # Path to vLLM quantization config JSON file - quantization_config_path: null diff --git a/verl/trainer/config/engine/megatron.yaml b/verl/trainer/config/engine/megatron.yaml index 23c40178397..f9b9f903f14 100644 --- a/verl/trainer/config/engine/megatron.yaml +++ b/verl/trainer/config/engine/megatron.yaml @@ -79,12 +79,6 @@ override_transformer_config: # Attention backend to use (flash,fused,unfused,local,auto). Defaults to auto in mcore, flash in verl attention_backend: flash -# # Quantization method. None for no quantization, "nvfp4" for NVFP4 quantization -quantization: null - -# Whether to enable Quantization-Aware Training (QAT). Default False. -enable_qat: False - override_mcore_model_config: {} # oc.select: default val for ref.megatron.use_mbridge diff --git a/verl/utils/modelopt/__init__.py b/verl/utils/modelopt/__init__.py new file mode 100644 index 00000000000..ebf63d1e0bf --- /dev/null +++ b/verl/utils/modelopt/__init__.py @@ -0,0 +1,44 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +ModelOpt integration for verl. + +Supports NVFP4 quantization with Megatron QAT training + vLLM low-precision inference. + +Module Structure: +- qat.py: QAT quantization config, apply_qat, QuantizationMetadata +- weight_processor.py: QATWeightPostProcessor for converting QAT weights to quantized format +- vllm_patch.py: vLLM monkey patches for NVFP4 inference (Linear, MoE, KV Cache) + +Usage: + # Training side + from verl.utils.modelopt import apply_qat, QATWeightPostProcessor + + # Inference side + from verl.utils.modelopt import apply_vllm_modelopt_patches +""" + +from verl.utils.modelopt.qat import NVFP4_WEIGHT_ONLY_CFG, QuantizationMetadata, apply_qat +from verl.utils.modelopt.vllm_patch import apply_vllm_modelopt_patches +from verl.utils.modelopt.weight_processor import QATWeightPostProcessor + +__all__ = [ + "NVFP4_WEIGHT_ONLY_CFG", + "apply_qat", + "QuantizationMetadata", + "QATWeightPostProcessor", + "apply_vllm_modelopt_patches", +] diff --git a/verl/utils/modelopt/qat.py b/verl/utils/modelopt/qat.py new file mode 100644 index 00000000000..3b0e65001ec --- /dev/null +++ b/verl/utils/modelopt/qat.py @@ -0,0 +1,81 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from dataclasses import dataclass +from typing import Any, Optional + +import torch +import torch.nn as nn + +import modelopt.torch.quantization as mtq +from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg + +# --------------------------------------------------------------------------- +# NVFP4 quantization config +# --------------------------------------------------------------------------- + +NVFP4_WEIGHT_ONLY_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": {"enable": False}, + **_default_disabled_quantizer_cfg, + }, + "algorithm": "max", +} + +# --------------------------------------------------------------------------- +# QAT application +# --------------------------------------------------------------------------- + + +def apply_qat(model: nn.Module, qat_mode: str): + """Apply Quantization-Aware Training to the model. + + Args: + model: The Megatron model to apply QAT to. + qat_mode: QAT mode, now only support "w4a16" for weight-only quantization. + + Returns: + The quantized model. 
+ """ + if qat_mode != "w4a16": + raise ValueError(f"Only 'w4a16' is supported, got: {qat_mode}") + + mtq.quantize(model, NVFP4_WEIGHT_ONLY_CFG) + return model + + +@dataclass +class QuantizationMetadata: + """Metadata for a quantized module.""" + + qformat: str + weight_quantizer: Any + input_quantizer: Any + module: torch.nn.Module + vpp_idx: int + block_size: int = 16 # Default NVFP4 block size + # Fields for EP synchronization - store amax values for non-local experts + weight_amax: Optional[torch.Tensor] = None + input_amax: Optional[torch.Tensor] = None + is_local: bool = True # Whether this expert is local to current EP rank + global_expert_idx: Optional[int] = None # Global expert index for MoE experts + local_expert_idx: Optional[int] = None # Local expert index on this EP rank diff --git a/verl/utils/modelopt_vllm_utils.py b/verl/utils/modelopt/vllm_patch.py similarity index 94% rename from verl/utils/modelopt_vllm_utils.py rename to verl/utils/modelopt/vllm_patch.py index 18bbaca5ced..7b2d695af45 100644 --- a/verl/utils/modelopt_vllm_utils.py +++ b/verl/utils/modelopt/vllm_patch.py @@ -23,72 +23,6 @@ from torch.nn import Parameter -def generate_nvfp4_ignore_list(num_layers: int, is_moe: bool) -> list[str]: - """ - Generate the ignore list for NVFP4 quantization based on model configuration. - - Args: - num_layers: Number of hidden layers in the model (from hf_config.num_hidden_layers) - is_moe: Whether the model is a Mixture of Experts model - - Returns: - List of layer names to ignore during quantization - """ - ignore_list = [] - - # For MoE models, ignore the gate layers (routing layers) - if is_moe: - for layer_idx in range(num_layers): - ignore_list.append(f"model.layers.{layer_idx}.mlp.gate") - - # Always ignore lm_head for stability - ignore_list.append("lm_head") - - return ignore_list - - -def get_nvfp4_block_quant_kwargs(num_layers: int, is_moe: bool) -> dict: - """ - Generate complete NVFP4 quantization configuration based on model properties. 
- Args: - num_layers: Number of hidden layers in the model (from hf_config.num_hidden_layers) - is_moe: Whether the model is a Mixture of Experts model - - Returns: - Complete quantization configuration dictionary compatible with ModelOpt - """ - ignore_list = generate_nvfp4_ignore_list(num_layers, is_moe) - - return { - "config_groups": { - "group_0": { - "input_activations": { - "dynamic": "false", - "num_bits": 4, - "type": "float", - "group_size": 16 - }, - "weights": { - "dynamic": "false", - "num_bits": 4, - "type": "float", - "group_size": 16 - }, - "targets": [ - "Linear" - ] - } - }, - "ignore": ignore_list, - "quant_algo": "NVFP4", - "producer": { - "name": "modelopt", - }, - "quant_method": "modelopt" - } - - - def _create_param_from_subclass_attributes(custom_data: torch.Tensor, custom_weight) -> Parameter: """ Helper to preserve custom attributes from ModelWeightParameter and @@ -838,4 +772,4 @@ def apply_vllm_modelopt_patches(): # Static scales mode: patch process_weights_after_loading to preserve k_scale/v_scale for manual updates func5_path = "vllm.model_executor.layers.quantization.kv_cache.BaseKVCacheMethod.process_weights_after_loading" patcher5 = patch(func5_path, process_weights_after_loading_kv) - patcher5.start() \ No newline at end of file + patcher5.start() diff --git a/verl/utils/modelopt_qat_utils.py b/verl/utils/modelopt/weight_processor.py similarity index 92% rename from verl/utils/modelopt_qat_utils.py rename to verl/utils/modelopt/weight_processor.py index 7e8b63d401b..1b1380a3d6b 100644 --- a/verl/utils/modelopt_qat_utils.py +++ b/verl/utils/modelopt/weight_processor.py @@ -15,13 +15,10 @@ import re -from dataclasses import dataclass from typing import Any, Iterator, Optional import torch -import torch.nn as nn -import modelopt.torch.quantization as mtq from modelopt.torch.export.quant_utils import ( QUANTIZATION_NONE, QUANTIZATION_NVFP4, @@ -32,77 +29,7 @@ from modelopt.torch.quantization.qtensor.nvfp4_tensor import NVFP4QTensor 
from verl.utils.megatron_utils import unwrap_model - -# --------------------------------------------------------------------------- -# NVFP4 quantization config -# --------------------------------------------------------------------------- - -NVFP4_WEIGHT_ONLY_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": {"enable": False}, - "nn.BatchNorm1d": {"*": {"enable": False}}, - "nn.BatchNorm2d": {"*": {"enable": False}}, - "nn.BatchNorm3d": {"*": {"enable": False}}, - "nn.LeakyReLU": {"*": {"enable": False}}, - "*lm_head*": {"enable": False}, - "*proj_out.*": {"enable": False}, # Whisper: lm_head has key name proj_out - "*block_sparse_moe.gate*": {"enable": False}, # Skip MOE router - "*router*": {"enable": False}, # Skip MOE router - "*mlp.gate.*": {"enable": False}, # Skip MOE router - "*mlp.shared_expert_gate.*": {"enable": False}, # Skip MOE router - "*linear_attn.conv1d*": {"enable": False}, - "*mixer.conv1d*": {"enable": False}, - "*output_layer*": {"enable": False}, - "output.*": {"enable": False}, - "default": {"enable": False}, - }, - "algorithm": "max", -} - -# --------------------------------------------------------------------------- -# QAT application -# --------------------------------------------------------------------------- - - -def apply_qat(model: nn.Module, quant_method: str): - """Apply Quantization-Aware Training to the model. - - Args: - model: The Megatron model to apply QAT to. - quant_method: Quantization method (currently only ``"nvfp4"`` is supported). - - Returns: - The quantized model. 
- """ - if quant_method != "nvfp4": - raise ValueError(f"Only 'nvfp4' is supported, got: {quant_method}") - - mtq.quantize(model, NVFP4_WEIGHT_ONLY_CFG) - return model - - -@dataclass -class QuantizationMetadata: - """Metadata for a quantized module.""" - - qformat: str - weight_quantizer: Any - input_quantizer: Any - module: torch.nn.Module - vpp_idx: int - block_size: int = 16 # Default NVFP4 block size - # Fields for EP synchronization - store amax values for non-local experts - weight_amax: Optional[torch.Tensor] = None - input_amax: Optional[torch.Tensor] = None - is_local: bool = True # Whether this expert is local to current EP rank - global_expert_idx: Optional[int] = None # Global expert index for MoE experts - local_expert_idx: Optional[int] = None # Local expert index on this EP rank +from verl.utils.modelopt.qat import QuantizationMetadata class QATWeightPostProcessor: @@ -131,7 +58,7 @@ class QATWeightPostProcessor: def __init__( self, actor_module: list, - quantization_method: str = "nvfp4", + qat_mode: str = "w4a16", dtype: torch.dtype = torch.bfloat16, use_calibrated_scale_2: bool = False, ): @@ -140,14 +67,14 @@ def __init__( Args: actor_module: List of QAT trained model chunks (vpp chunks) - quantization_method: Quantization method (nvfp4, fp8, etc.) + qat_mode: QAT mode, e.g. "w4a16" or "w4a4". dtype: Original data type (bf16) use_calibrated_scale_2: If True, use QAT calibrated amax for weight_scale_2. If False, recompute weight_scale_2 from merged weights. Recommended to set False when using TP to ensure consistent global scale. 
""" self.actor_module = actor_module - self.quantization_method = quantization_method + self.qat_mode = qat_mode self.dtype = dtype self.use_calibrated_scale_2 = use_calibrated_scale_2 self.quant_metadata: dict[str, QuantizationMetadata] = {} @@ -399,7 +326,7 @@ def _log_initialization_info(self): global_rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 print( - f"[QAT PostProcessor][Rank {global_rank}] Initialized with quantization method: {self.quantization_method}" + f"[QAT PostProcessor][Rank {global_rank}] Initialized with qat_mode: {self.qat_mode}" ) print(f"[QAT PostProcessor][Rank {global_rank}] Found {len(self.quant_metadata)} quantized parameters") if self.ep_size > 1: @@ -997,17 +924,6 @@ def _find_matching_metadata_by_hf_name(self, hf_name: str) -> QuantizationMetada if pattern in self.quant_metadata: return self.quant_metadata[pattern] - # # If no exact match, try to find any metadata from the same layer - # # This handles cases where the exact name might be slightly different - # for mcore_name, metadata in self.quant_metadata.items(): - # if f"layers.{layer_num}." in mcore_name: - # # Found a quantized module in the same layer - # # Skip router metadata - router should not be used for other layers - # if ".router." 
in mcore_name: - # continue - # # For QAT, if any module in the layer is quantized, all Linear layers should be - # if ".weight" in mcore_name: - # return metadata return None @@ -1053,4 +969,4 @@ def _find_non_layer_metadata(self, hf_name: str) -> QuantizationMetadata | None: if mcore_name and mcore_name in self.quant_metadata: return self.quant_metadata[mcore_name] - return None \ No newline at end of file + return None diff --git a/verl/workers/config/actor.py b/verl/workers/config/actor.py index 255acc7bdc1..b062342def8 100644 --- a/verl/workers/config/actor.py +++ b/verl/workers/config/actor.py @@ -179,6 +179,7 @@ class ActorConfig(BaseConfig): # batch_num_tokens: number of valid tokens in global batch # global_batch_size: global batch size global_batch_info: dict = field(default_factory=dict) + qat: QATConfig = field(default_factory=QATConfig) def __post_init__(self): """Validate actor configuration parameters.""" @@ -296,7 +297,6 @@ class FSDPActorConfig(ActorConfig): use_rollout_log_probs: bool = False calculate_sum_pi_squared: bool = False sum_pi_squared_checkpointing: bool = False - qat: QATConfig = field(default_factory=QATConfig) def __post_init__(self): """Validate FSDP actor configuration parameters.""" diff --git a/verl/workers/config/engine.py b/verl/workers/config/engine.py index b999971b895..08d7fa293aa 100644 --- a/verl/workers/config/engine.py +++ b/verl/workers/config/engine.py @@ -142,8 +142,6 @@ class McoreEngineConfig(EngineConfig): override_transformer_config (dict[str, Any]): Override configuration for transformer. use_mbridge (bool): Whether to use MBridge for communication. dtype (str): Mixed precision training param dtype, default "bfloat16" - quantization (Optional[str]): Quantization method to use. None for no quantization, "nvfp4" for NVFP4 quantization. - enable_qat (bool): Whether to enable Quantization-Aware Training (QAT). Default False. 
""" # sequence_parallel is not listed as a frozen field for auto-correction purpose @@ -168,8 +166,6 @@ class McoreEngineConfig(EngineConfig): use_mbridge: bool = True vanilla_mbridge: bool = True strategy: str = "megatron" - quantization: Optional[str] = None - enable_qat: bool = False def __post_init__(self) -> None: super().__post_init__() diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py index fea93003f51..22154c432b5 100644 --- a/verl/workers/megatron_workers.py +++ b/verl/workers/megatron_workers.py @@ -73,7 +73,7 @@ simple_timer, ) from verl.utils.profiler.performance import reduce_timing, topk_reduce_ratio_min_max -from verl.utils.modelopt_qat_utils import apply_qat +from verl.utils.modelopt import apply_qat from verl.utils.ray_utils import get_event_loop from verl.utils.torch_functional import use_original_torch_compile from verl.workers.actor.megatron_actor import MegatronPPOActor @@ -221,13 +221,11 @@ def _init_hf_config_and_tf_config( provider.moe_token_dispatcher_type = "alltoall" provider.moe_router_load_balancing_type = "none" - enable_qat = self.config.actor.megatron.get("enable_qat", False) - if enable_qat: + qat_enabled = self.config.actor.get("qat", {}).get("enable", False) + if qat_enabled: from megatron.bridge.models.gpt_provider import quantization_layer_spec provider.transformer_layer_spec = quantization_layer_spec - # Patch megatron-core MLP to support singleton_local_shards - # in SwiGLU sharded state dict (required for QAT checkpointing) from verl.models.mcore.qat_patch import apply_qat_patch apply_qat_patch() @@ -459,11 +457,11 @@ def _build_model_optimizer( if self.rank == 0: print_model_size(actor_module[0]) log_gpu_memory_usage("After MegatronPPOActor init", logger=logger) - quantization = self.config.actor.megatron.get("quantization", None) - enable_qat = self.config.actor.megatron.get("enable_qat", False) - if quantization is not None and enable_qat: + qat_config = self.config.actor.get("qat", {}) + if 
qat_config.get("enable", False): + qat_mode = qat_config.get("mode", "w4a16") for i in range(len(actor_module)): - actor_module[i] = apply_qat(actor_module[i], quantization) + actor_module[i] = apply_qat(actor_module[i], qat_mode) elif self._is_ref: wrap_config = McoreModuleWrapperConfig( is_value_model=False, # ref is not value model @@ -736,11 +734,13 @@ async def rollout_mode(self): self.tf_config, self.layer_name_mapping, ) - if self.config.actor.megatron.get("enable_qat", False): - from verl.utils.modelopt_qat_utils import QATWeightPostProcessor + qat_config = self.config.actor.get("qat", {}) + if qat_config.get("enable", False): + from verl.utils.modelopt import QATWeightPostProcessor + qat_mode = qat_config.get("mode", "w4a16") qat_weight_post_processor = QATWeightPostProcessor( - self.actor.actor_module, "nvfp4" + self.actor.actor_module, qat_mode ) per_tensor_param = qat_weight_post_processor.process_weights_iterator(per_tensor_param) diff --git a/verl/workers/rollout/vllm_rollout/utils.py b/verl/workers/rollout/vllm_rollout/utils.py index db134b318a8..ac280e1a261 100644 --- a/verl/workers/rollout/vllm_rollout/utils.py +++ b/verl/workers/rollout/vllm_rollout/utils.py @@ -30,7 +30,7 @@ from verl.utils.vllm import TensorLoRARequest, VLLMHijack from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader from verl.utils.vllm.vllm_fp8_utils import apply_vllm_fp8_patches, is_fp8_model, load_quanted_weights -from verl.utils.modelopt_vllm_utils import apply_vllm_modelopt_patches +from verl.utils.modelopt import apply_vllm_modelopt_patches logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) @@ -167,11 +167,16 @@ def __new__(cls, **kwargs): vllm_config = kwargs.get("vllm_config") quant_config = getattr(vllm_config, "quant_config", None) if vllm_config else None _is_qat_model = getattr(quant_config, "quant_format", None) == "nvfp4-pack-quantized" + _is_modelopt_qat = type(quant_config).__name__ == 
"ModelOptNvFp4Config" + print(f"type(quant_config).__name__: {type(quant_config).__name__} _is_modelopt_qat: {_is_modelopt_qat}") if _is_qat_model: from verl.utils.qat import apply_qat_patches apply_qat_patches() - logger.info("Applied QAT patches in vLLM worker subprocess") + logger.info("Applied QAT (compressed-tensors) patches in vLLM worker subprocess") + elif _is_modelopt_qat: + apply_vllm_modelopt_patches() + logger.info("Applied QAT (modelopt) patches in vLLM worker subprocess") # TODO: For ascend NPU, when the corresponding vllm-ascend version is upgraded to v0.13.0, # please remove the VLLM_ASCEND_REQUIRED_ENV_VARS variable replacement action. @@ -183,6 +188,7 @@ def __new__(cls, **kwargs): instance = super().__new__(cls) instance._is_qat_model = _is_qat_model + instance._is_modelopt_qat = _is_modelopt_qat return instance def monkey_patch_model(self, vocab_size: int): @@ -259,11 +265,18 @@ def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False break if self._is_qat_model: - # QAT: call process_weights_after_loading AFTER all buckets are received + # QAT (compressed-tensors): call process_weights_after_loading AFTER all buckets are received from verl.utils.qat import manual_process_weights_after_loading manual_process_weights_after_loading(self.model_runner.model) logger.info("QAT: process_weights_after_loading completed") + elif self._is_modelopt_qat: + from vllm.model_executor.model_loader.utils import process_weights_after_loading + + model = self.model_runner.model + model_config = self.model_runner.vllm_config.model_config + process_weights_after_loading(model, model_config, self.device) + logger.info("ModelOpt QAT: process_weights_after_loading completed") elif use_standard_weight_load: # Some post-load transforms are non-idempotent; run once after all buckets. 
from vllm.model_executor.model_loader.utils import process_weights_after_loading @@ -307,11 +320,14 @@ def _update_weights(self, weights: list[tuple[str, torch.Tensor]], peft_config: logger.info("Loading standard weights (non-FP8, async)") self.model_runner.model.load_weights(weights) - from vllm.model_executor.model_loader.utils import process_weights_after_loading - model_config = self.model_runner.vllm_config.model_config - device = next(self.model_runner.model.parameters()).device - process_weights_after_loading(self.model_runner.model, model_config, device) - # from vllm.model_executor.layers.quantization.modelopt import ModelOptNvFp4LinearMethod + if not getattr(self, '_is_modelopt_qat', False): + # Skip per-bucket process_weights_after_loading for modelopt QAT + # because the patched version is not idempotent (swizzle, etc.). + # It will be called once after all buckets in update_weights_from_ipc. + from vllm.model_executor.model_loader.utils import process_weights_after_loading + model_config = self.model_runner.vllm_config.model_config + device = next(self.model_runner.model.parameters()).device + process_weights_after_loading(self.model_runner.model, model_config, device) def _get_zmq_handle(self) -> str: """Get ZMQ handle for communication.""" diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py index 0d762116235..7b50e5ef23a 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -61,8 +61,7 @@ from vllm.utils.argparse_utils import FlexibleArgumentParser if _VLLM_VERSION == version.parse("0.12.0"): - pass - # from vllm.entrypoints.harmony_utils import get_encoding + from vllm.entrypoints.harmony_utils import get_encoding elif _VLLM_VERSION >= version.parse("0.13.0"): from vllm.entrypoints.openai.parser.harmony_utils import get_encoding @@ -238,18 +237,29 @@ async def launch_server(self, master_address: str = 
None, master_port: int = Non # Handle QAT (Quantization-Aware Training) configuration qat_config_dict = getattr(self.config, "qat", {}) or {} if qat_config_dict.get("enable", False): - # QAT uses compressed-tensors quantization, apply patches for dynamic weight loading - from verl.utils.qat import QATConfig, apply_qat_patches, load_quantization_config + from verl.utils.qat import QATConfig, load_quantization_config - apply_qat_patches() - - # Load quantization config from JSON file qat_config = QATConfig(**qat_config_dict) quantization_config_dict = load_quantization_config(qat_config) - hf_overrides["quantization_config"] = quantization_config_dict - quantization = "compressed-tensors" + quant_method = quantization_config_dict.get("quant_method", None) + + if quant_method == "modelopt": + + from verl.utils.modelopt import apply_vllm_modelopt_patches - logger.info("QAT quantization config injected to vLLM async server") + apply_vllm_modelopt_patches() + quantization = "modelopt" + elif quant_method == "compressed-tensors": + from verl.utils.qat import apply_qat_patches + + apply_qat_patches() + quantization = "compressed-tensors" + + else: + raise ValueError(f"Unsupported quant_method: {quant_method}") + logger.info(f"QAT quantization config injected (quant_method={quant_method})") + hf_overrides["quantization_config"] = quantization_config_dict + print(f"quantization config: {quantization_config_dict}") elif quantization is not None: # Handle other quantization methods (fp8, torchao) _SUPPORTED_QUANTIZATION = ["fp8", "torchao"] From ce4d53ddf8a4bc7b1783cfa12c42d03394c950b6 Mon Sep 17 00:00:00 2001 From: larkz-nv Date: Wed, 25 Feb 2026 17:34:34 -0800 Subject: [PATCH 07/10] Rename modelopt modules and modify vllm patch logic --- verl/utils/modelopt/__init__.py | 28 +- verl/utils/modelopt/qat.py | 81 -- verl/utils/modelopt/vllm_patch.py | 775 ------------------ verl/utils/modelopt/weight_processor.py | 2 +- verl/workers/megatron_workers.py | 3 +- 
verl/workers/rollout/vllm_rollout/utils.py | 21 +- .../rollout/vllm_rollout/vllm_async_server.py | 5 +- 7 files changed, 37 insertions(+), 878 deletions(-) delete mode 100644 verl/utils/modelopt/qat.py delete mode 100644 verl/utils/modelopt/vllm_patch.py diff --git a/verl/utils/modelopt/__init__.py b/verl/utils/modelopt/__init__.py index ebf63d1e0bf..a250acb18ae 100644 --- a/verl/utils/modelopt/__init__.py +++ b/verl/utils/modelopt/__init__.py @@ -19,26 +19,38 @@ Supports NVFP4 quantization with Megatron QAT training + vLLM low-precision inference. Module Structure: -- qat.py: QAT quantization config, apply_qat, QuantizationMetadata +- quantize.py: Quantization config builder, apply_qat, QuantizationMetadata - weight_processor.py: QATWeightPostProcessor for converting QAT weights to quantized format -- vllm_patch.py: vLLM monkey patches for NVFP4 inference (Linear, MoE, KV Cache) +- vllm_modelopt_patch.py: vLLM monkey patches for ModelOpt NVFP4 inference (Linear, MoE, KV Cache) Usage: # Training side from verl.utils.modelopt import apply_qat, QATWeightPostProcessor - # Inference side - from verl.utils.modelopt import apply_vllm_modelopt_patches + # Inference side (dynamic weight reload lifecycle) + from verl.utils.modelopt import apply_modelopt_nvfp4_patches, prepare_modelopt_for_weight_reload, modelopt_process_weights_after_loading """ -from verl.utils.modelopt.qat import NVFP4_WEIGHT_ONLY_CFG, QuantizationMetadata, apply_qat -from verl.utils.modelopt.vllm_patch import apply_vllm_modelopt_patches +from verl.utils.modelopt.quantize import ( + # DEFAULT_IGNORE_PATTERNS, + QuantizationMetadata, + apply_qat, + build_quantize_config, +) +from verl.utils.modelopt.vllm_modelopt_patch import ( + apply_modelopt_nvfp4_patches, + modelopt_process_weights_after_loading, + prepare_modelopt_for_weight_reload, +) from verl.utils.modelopt.weight_processor import QATWeightPostProcessor + __all__ = [ - "NVFP4_WEIGHT_ONLY_CFG", + "build_quantize_config", "apply_qat", 
"QuantizationMetadata", "QATWeightPostProcessor", - "apply_vllm_modelopt_patches", + "apply_modelopt_nvfp4_patches", + "prepare_modelopt_for_weight_reload", + "modelopt_process_weights_after_loading", ] diff --git a/verl/utils/modelopt/qat.py b/verl/utils/modelopt/qat.py deleted file mode 100644 index 3b0e65001ec..00000000000 --- a/verl/utils/modelopt/qat.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -from dataclasses import dataclass -from typing import Any, Optional - -import torch -import torch.nn as nn - -import modelopt.torch.quantization as mtq -from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg - -# --------------------------------------------------------------------------- -# NVFP4 quantization config -# --------------------------------------------------------------------------- - -NVFP4_WEIGHT_ONLY_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": {"enable": False}, - **_default_disabled_quantizer_cfg, - }, - "algorithm": "max", -} - -# --------------------------------------------------------------------------- -# QAT application -# --------------------------------------------------------------------------- - - -def apply_qat(model: nn.Module, qat_mode: str): - """Apply Quantization-Aware Training to the model. - - Args: - model: The Megatron model to apply QAT to. - qat_mode: QAT mode, now only support "w4a16" for weight-only quantization. - - Returns: - The quantized model. 
- """ - if qat_mode != "w4a16": - raise ValueError(f"Only 'w4a16' is supported, got: {qat_mode}") - - mtq.quantize(model, NVFP4_WEIGHT_ONLY_CFG) - return model - - -@dataclass -class QuantizationMetadata: - """Metadata for a quantized module.""" - - qformat: str - weight_quantizer: Any - input_quantizer: Any - module: torch.nn.Module - vpp_idx: int - block_size: int = 16 # Default NVFP4 block size - # Fields for EP synchronization - store amax values for non-local experts - weight_amax: Optional[torch.Tensor] = None - input_amax: Optional[torch.Tensor] = None - is_local: bool = True # Whether this expert is local to current EP rank - global_expert_idx: Optional[int] = None # Global expert index for MoE experts - local_expert_idx: Optional[int] = None # Local expert index on this EP rank diff --git a/verl/utils/modelopt/vllm_patch.py b/verl/utils/modelopt/vllm_patch.py deleted file mode 100644 index 7b2d695af45..00000000000 --- a/verl/utils/modelopt/vllm_patch.py +++ /dev/null @@ -1,775 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import logging -from typing import Callable, Optional -from unittest.mock import patch - -import torch - -logger = logging.getLogger(__name__) -from torch.nn import Parameter - - -def _create_param_from_subclass_attributes(custom_data: torch.Tensor, custom_weight) -> Parameter: - """ - Helper to preserve custom attributes from ModelWeightParameter and - PerTensorScaleParameter when creating new Parameters. - """ - param = Parameter(custom_data, requires_grad=False) - base_param_dir = dir(torch.nn.Parameter) - custom_weight_dir = dir(custom_weight) - # Find the attributes that are unique to the custom parameter - custom_attributes = [attr for attr in custom_weight_dir if attr not in base_param_dir and not attr.startswith("__")] - # Set the custom attributes into the base parameter object - for attr in custom_attributes: - setattr(param, attr, getattr(custom_weight, attr)) - return param - - -def process_weights_after_loading_modelopt(self, layer: torch.nn.Module) -> None: - import vllm._custom_ops as ops - from torch.nn import Parameter - from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - marlin_make_workspace_new, - marlin_permute_bias, - marlin_permute_scales, - ) - from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - mxfp4_marlin_process_scales, - nvfp4_marlin_process_global_scale, - nvfp4_marlin_process_scales, - ) - from vllm.model_executor.layers.quantization.utils.quant_utils import swizzle_blockscale - - def _create_param_from_subclass_attributes(custom_data, custom_weight): - param = Parameter(custom_data, requires_grad=False) - base_param_dir = dir(torch.nn.Parameter) - custom_weight_dir = dir(custom_weight) - # Find the attributes that are unique to the custom parameter - custom_attributes = [ - attr for attr in custom_weight_dir if attr not in base_param_dir and not attr.startswith("__") - ] - # Set the custom attributes into the base parameter object - for attr in custom_attributes: - setattr(param, 
attr, getattr(custom_weight, attr)) - - return param - - def prepare_fp4_layer_for_marlin(layer: torch.nn.Module, weight_scale_2_max: torch.Tensor) -> None: - logger.warning_once( - "Your GPU does not have native support for FP4 computation but " - "FP4 quantization is being used. Weight-only FP4 compression will " - "be used leveraging the Marlin kernel. This may degrade " - "performance for compute-heavy workloads." - ) - - is_nvfp4 = hasattr(layer, "weight_scale_2") - group_size = 16 if is_nvfp4 else 32 - - part_size_n = layer.output_size_per_partition - part_size_k = layer.input_size_per_partition - param_dtype = layer.params_dtype - - assert layer.weight.shape == (part_size_n, part_size_k // 2) - - device = layer.weight.device - - # WORKSPACE - if getattr(layer, "workspace", None) is None: - layer.workspace = marlin_make_workspace_new(device) - - # WEIGHT - # Repack weights to marlin format - perm = torch.empty(0, dtype=torch.int, device=device) - qweight = layer.weight.view(torch.int32).T.contiguous() - - marlin_qweight = ops.gptq_marlin_repack( - b_q_weight=qweight, - perm=perm, - size_k=part_size_k, - size_n=part_size_n, - num_bits=4, - ) - layer.marlin_weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) - - # WEIGHT SCALES - # Permute scales - weight_scale = layer.weight_scale.T.contiguous() - - if not is_nvfp4: - weight_scale = weight_scale.view(torch.float8_e8m0fnu) - - weight_scale = weight_scale.to(param_dtype) - weight_scale = marlin_permute_scales( - s=weight_scale, size_k=part_size_k, size_n=part_size_n, group_size=group_size - ) - - if is_nvfp4: - weight_scale = nvfp4_marlin_process_scales(weight_scale) - layer.marlin_weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) - - weight_scale_2 = weight_scale_2_max.to(param_dtype) - weight_scale_2 = nvfp4_marlin_process_global_scale(weight_scale_2) - layer.marlin_weight_scale_2 = torch.nn.Parameter(weight_scale_2, requires_grad=False) - else: - weight_scale = 
mxfp4_marlin_process_scales(weight_scale) - layer.marlin_weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) - - if hasattr(layer, "bias") and layer.bias is not None: - assert layer.bias.shape == (part_size_n,) - bias = marlin_permute_bias(layer.bias) - layer.bias = torch.nn.Parameter(bias, requires_grad=False) - - return - - # global scales: - input_scale_2 = layer.input_scale.data - layer.input_scale = _create_param_from_subclass_attributes(input_scale_2, layer.input_scale) - input_scale_2_max = input_scale_2.max().to(torch.float32) - - weight_scale_2 = layer.weight_scale_2.data - layer.weight_scale_2 = _create_param_from_subclass_attributes(weight_scale_2, layer.weight_scale_2) - weight_scale_2_max = weight_scale_2.max().to(torch.float32) - - layer.alpha = Parameter(input_scale_2_max * weight_scale_2_max, requires_grad=False) - - # Calculate `1 / input_scale` so that we don't need to do so at runtime - layer.input_scale_inv = Parameter((1 / layer.input_scale).to(torch.float32), requires_grad=False) - - # Swizzle the weight blockscale. - # contracting dimension is input dimension - # block_size = 16; - assert layer.weight_scale.dtype == torch.float8_e4m3fn, "Weight Block scale must be represented as FP8-E4M3" - - if self.backend == "marlin": - weight = layer.weight.data - weight_scale = layer.weight_scale.data - layer.weight = _create_param_from_subclass_attributes(weight, layer.weight) - layer.weight_scale = _create_param_from_subclass_attributes(weight_scale, layer.weight_scale) - prepare_fp4_layer_for_marlin(layer, weight_scale_2_max) - - del layer.alpha - # del layer.input_scale - elif self.backend == "flashinfer-trtllm": - # FlashInfer TRTLLM FP4 GEMM requires a different weight layout. - # FlashInfer provides nvfp4_quantize to quantize + shuffle the - # layout but we use our own quantization so we have to call - # shuffles ourselves. 
- from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a - - weight = layer.weight.data - weight_scale = layer.weight_scale.data - - epilogue_tile_m = 128 - weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m) - weight_scale = ( - shuffle_matrix_sf_a(weight_scale.view(torch.uint8), epilogue_tile_m) - .reshape(weight_scale.shape) - .view(torch.float8_e4m3fn) - ) - - layer.weight_scale = _create_param_from_subclass_attributes(weight_scale, layer.weight_scale) - layer.weight = _create_param_from_subclass_attributes(weight, layer.weight) - else: - swizzled_weight_scale = swizzle_blockscale(layer.weight_scale) - layer.weight_scale = _create_param_from_subclass_attributes(swizzled_weight_scale, layer.weight_scale) - layer.weight = _create_param_from_subclass_attributes(layer.weight.data, layer.weight) - -def apply_modelopt( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None, -) -> torch.Tensor: - from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant - from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import apply_fp4_marlin_linear - from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm - - if self.backend == "marlin": - return apply_fp4_marlin_linear( - input=x, - weight=layer.marlin_weight, - weight_scale=layer.marlin_weight_scale, - weight_scale_2=layer.marlin_weight_scale_2, - workspace=layer.workspace, - size_n=layer.output_size_per_partition, - size_k=layer.input_size_per_partition, - bias=bias, - ) - - output_dtype = x.dtype - output_shape = [x.shape[0], layer.weight.shape[0]] - - # quantize BF16 or FP16 to (FP4 and interleaved block scale) - x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv) - - # validate dtypes of quantized input, input block scale, - # weight and weight_blockscale - assert x_fp4.dtype == torch.uint8 - assert layer.weight.dtype == torch.uint8 - assert x_blockscale.dtype == torch.float8_e4m3fn - assert layer.weight_scale.dtype == 
torch.float8_e4m3fn - assert layer.alpha.dtype == torch.float32 - - mm_args = ( - x_fp4, - layer.weight, - x_blockscale, - layer.weight_scale, - layer.alpha, - output_dtype, - ) - if self.backend == "flashinfer-trtllm": - out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm") - elif self.backend == "flashinfer-cutlass": - out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass") - else: - out = cutlass_scaled_fp4_mm(*mm_args) - - if bias is not None: - out = out + bias - return out.view(*output_shape) - - -# ============================================================================= -# ModelOptNvFp4FusedMoE Patches -# ============================================================================= - - -def process_weights_after_loading_moe(self, layer: torch.nn.Module) -> None: - """ - Patched process_weights_after_loading for ModelOptNvFp4FusedMoE. - - Key modifications compared to original: - 1. Preserves original weights in separate attributes (marlin_w13_weight, etc.) - 2. Uses _create_param_from_subclass_attributes to preserve parameter metadata - 3. 
Computes weight_scale_2_max before processing for Marlin - """ - import vllm._custom_ops as ops - from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( - prepare_static_weights_for_trtllm_fp4_moe, - reorder_w1w3_to_w3w1, - ) - from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( - FlashinferMoeBackend, - is_flashinfer_supporting_global_sf, - ) - from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - marlin_make_workspace_new, - marlin_permute_scales, - ) - from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - nvfp4_marlin_process_global_scale, - nvfp4_marlin_process_scales, - ) - from vllm.model_executor.layers.quantization.utils.quant_utils import swizzle_blockscale - - def prepare_moe_fp4_layer_for_marlin_patched( - layer: torch.nn.Module, - w13_weight_scale_2_per_expert: torch.Tensor, - w2_weight_scale_2_per_expert: torch.Tensor, - ) -> None: - """ - Modified prepare_moe_fp4_layer_for_marlin that: - 1. Takes per-expert weight_scale_2 values (not max!) - 2. 
Saves to marlin_* attributes instead of overwriting originals - - Args: - w13_weight_scale_2_per_expert: shape (num_experts,) - per-expert scales - w2_weight_scale_2_per_expert: shape (num_experts,) - per-expert scales - """ - logger.warning("Using patched prepare_moe_fp4_layer_for_marlin for NVFP4 MoE") - - group_size = 16 # NVFP4 uses group_size=16 - - e = layer.num_experts - k = layer.hidden_size - n = layer.intermediate_size_per_partition - - device = layer.w13_weight.device - param_dtype = layer.params_dtype - - # WORKSPACE - if getattr(layer, "workspace", None) is None: - layer.workspace = marlin_make_workspace_new(device, 4) - - perm = torch.empty(0, dtype=torch.int, device=device) - - # WEIGHT - Repack weights to marlin format - for name in ["w13_weight", "w2_weight"]: - weight = getattr(layer, name) - tensor_list = [] - if "w13" in name: - size_n, size_k = n * 2, k - else: - size_n, size_k = k, n - - assert weight.shape == (e, size_n, size_k // 2), ( - f"Weight shape mismatch for {name}: expected {(e, size_n, size_k // 2)}, got {weight.shape}" - ) - - for i in range(e): - qweight = weight[i].view(torch.int32).T.contiguous() - - marlin_qweight = ops.gptq_marlin_repack( - b_q_weight=qweight, - perm=perm, - size_k=size_k, - size_n=size_n, - num_bits=4, - ) - tensor_list.append(marlin_qweight) - - marlin_weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) - marlin_weight = Parameter(marlin_weight, requires_grad=False) - - # Save to marlin_* attribute instead of overwriting original - marlin_attr_name = "marlin_" + name - setattr(layer, marlin_attr_name, marlin_weight) - - # WEIGHT SCALES - Permute scales - for name, weight_scale_2_per_expert in [ - ("w13", w13_weight_scale_2_per_expert), - ("w2", w2_weight_scale_2_per_expert), - ]: - scales = getattr(layer, name + "_weight_scale") - scales = scales.to(param_dtype) - - # Convert per-expert global scale to param_dtype - global_scale = weight_scale_2_per_expert.to(param_dtype) - - tensor_list = [] - if 
"w13" in name: - size_n, size_k = n * 2, k - else: - size_n, size_k = k, n - - for i in range(e): - scale = scales[i].T - - marlin_scales = marlin_permute_scales( - s=scale, - size_k=size_k, - size_n=size_n, - group_size=group_size, - ) - marlin_scales = nvfp4_marlin_process_scales(marlin_scales) - tensor_list.append(marlin_scales) - - marlin_scales_combined = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) - marlin_scales_combined = Parameter(marlin_scales_combined, requires_grad=False) - - # Save to marlin_* attribute - setattr(layer, "marlin_" + name + "_weight_scale", marlin_scales_combined) - - # Process per-expert global scale (shape: num_experts) - global_scale = nvfp4_marlin_process_global_scale(global_scale) - global_scale = Parameter(global_scale, requires_grad=False) - setattr(layer, "marlin_" + name + "_weight_scale_2", global_scale) - - # ========== Main processing logic ========== - - # GEMM 1 processing - gemm1_weight = layer.w13_weight.data - gemm1_weight_scale = layer.w13_weight_scale.data - - if ( - self.allow_flashinfer - and ( - self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS - or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM - ) - and self.moe.is_act_and_mul - ): - gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(gemm1_weight, gemm1_weight_scale, dim=-2) - - layer.w13_weight = _create_param_from_subclass_attributes(gemm1_weight, layer.w13_weight) - layer.w13_weight_scale = _create_param_from_subclass_attributes(gemm1_weight_scale, layer.w13_weight_scale) - - # Common processing for w13_weight_scale_2 - # IMPORTANT: Keep the original shape (num_experts, 2) for subsequent weight loading - # Only compute the max value for Marlin, but don't modify the original parameter shape - if self.moe.is_act_and_mul and not torch.allclose(layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]): - logger.warning("w1_weight_scale_2 must match w3_weight_scale_2. 
Accuracy may be affected.") - - # Keep original data and shape - DO NOT reduce dimension! - w13_weight_scale_2_data = layer.w13_weight_scale_2.data # Keep original shape: (num_experts, 2) - layer.w13_weight_scale_2 = _create_param_from_subclass_attributes(w13_weight_scale_2_data, layer.w13_weight_scale_2) - # Get per-expert scales (shape: num_experts) for Marlin - NOT the max! - # This is what the original code uses after reducing [:, 0] - w13_weight_scale_2_per_expert = layer.w13_weight_scale_2[:, 0].clone() - # Also keep a 1D version for g1_alphas calculation (following original logic) - w13_weight_scale_2_1d = layer.w13_weight_scale_2[:, 0] - - # Common processing for input scales and alphas - # IMPORTANT: Keep original input_scale shapes for subsequent weight loading - use_global_sf = self.allow_flashinfer and is_flashinfer_supporting_global_sf(self.flashinfer_moe_backend) - - # Keep original w13_input_scale data and shape - w13_input_scale_data = layer.w13_input_scale.data - layer.w13_input_scale = _create_param_from_subclass_attributes(w13_input_scale_data, layer.w13_input_scale) - - # Compute derived values for runtime use - if use_global_sf: - w13_input_scale_for_alpha = layer.w13_input_scale.max().to(torch.float32).expand(layer.num_experts) - else: - w13_input_scale_for_alpha = layer.w13_input_scale.max(dim=1).values.to(torch.float32) - - layer.g1_alphas = Parameter( - (w13_input_scale_for_alpha * w13_weight_scale_2_1d).to(torch.float32), - requires_grad=False, - ) - - # This is for quantization, so we need to invert it. - layer.w13_input_scale_quant = Parameter((1 / w13_input_scale_for_alpha).to(torch.float32), requires_grad=False) - - # GEMM 2 processing - # Keep original w2_weight_scale_2 data and shape - w2_weight_scale_2_data = layer.w2_weight_scale_2.data - layer.w2_weight_scale_2 = _create_param_from_subclass_attributes(w2_weight_scale_2_data, layer.w2_weight_scale_2) - # Get per-expert scales (shape: num_experts) for Marlin - NOT the max! 
- w2_weight_scale_2_per_expert = layer.w2_weight_scale_2.clone() - - # Keep original w2_input_scale data and shape - w2_input_scale_data = layer.w2_input_scale.data - layer.w2_input_scale = _create_param_from_subclass_attributes(w2_input_scale_data, layer.w2_input_scale) - - if use_global_sf: - w2_input_scale_for_alpha = layer.w2_input_scale.max().to(torch.float32).expand(layer.num_experts) - else: - w2_input_scale_for_alpha = layer.w2_input_scale - layer.g2_alphas = Parameter( - (w2_input_scale_for_alpha * layer.w2_weight_scale_2).to(torch.float32), - requires_grad=False, - ) - - # This is for quantization, so we need to invert it. - layer.w2_input_scale_quant = Parameter((1 / w2_input_scale_for_alpha).to(torch.float32), requires_grad=False) - - # ========== Backend-specific processing ========== - - if self.allow_flashinfer and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: - # TensorRT-LLM specific processing - ( - gemm1_weights_fp4_shuffled, - gemm1_scales_fp4_shuffled, - gemm2_weights_fp4_shuffled, - gemm2_scales_fp4_shuffled, - ) = prepare_static_weights_for_trtllm_fp4_moe( - layer.w13_weight, - layer.w2_weight, - layer.w13_weight_scale, - layer.w2_weight_scale, - layer.w2_weight.size(-2), # hidden_size - layer.w13_weight.size(-2) // 2, # intermediate_size - layer.w13_weight.size(0), # num_experts - ) - logger.debug("Finished shuffling weights for TRT-LLM MOE") - - layer.gemm1_weights_fp4_shuffled = Parameter(gemm1_weights_fp4_shuffled, requires_grad=False) - layer.gemm2_weights_fp4_shuffled = Parameter(gemm2_weights_fp4_shuffled, requires_grad=False) - layer.gemm1_scales_fp4_shuffled = Parameter(gemm1_scales_fp4_shuffled, requires_grad=False) - layer.gemm2_scales_fp4_shuffled = Parameter(gemm2_scales_fp4_shuffled, requires_grad=False) - - # Additional parameter needed for TRT-LLM - layer.g1_scale_c = Parameter( - (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32), - requires_grad=False, - ) - - # Clean up weights that won't 
be used by TRT-LLM - del layer.w2_weight - del layer.w2_weight_scale - del layer.w13_weight - del layer.w13_weight_scale - - elif self.use_marlin: - # Marlin processing - use patched version - # Pass per-expert scales (shape: num_experts), NOT scalar max values! - prepare_moe_fp4_layer_for_marlin_patched(layer, w13_weight_scale_2_per_expert, w2_weight_scale_2_per_expert) - # Delete attributes not needed for Marlin - del layer.g1_alphas - del layer.g2_alphas - del layer.w13_input_scale_quant - del layer.w2_input_scale_quant - - else: - # Non-TRT-LLM processing (Cutlass or non-flashinfer) - w13_blockscale_swizzled = swizzle_blockscale(layer.w13_weight_scale) - layer.w13_weight_scale = Parameter(w13_blockscale_swizzled, requires_grad=False) - - w13_weight = layer.w13_weight - intermediate_size_pad = w13_blockscale_swizzled.size(1) - w13_weight.size(1) - if intermediate_size_pad: - # padding gated activations will require to split w1 and w3 - # and pad them individually - assert not self.moe.is_act_and_mul, ( - "The intermediate size required padding, but padding is not implemented for gated activations" - ) - - layer.w13_weight = Parameter( - torch.nn.functional.pad(w13_weight, (0, 0, 0, intermediate_size_pad)), - requires_grad=False, - ) - layer.w2_weight = Parameter( - torch.nn.functional.pad(layer.w2_weight, (0, intermediate_size_pad // 2, 0, 0)), - requires_grad=False, - ) - layer.w2_weight_scale = Parameter( - torch.nn.functional.pad(layer.w2_weight_scale, (0, intermediate_size_pad // 16)), - requires_grad=False, - ) - - w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale) - layer.w2_weight_scale = Parameter(w2_blockscale_swizzled, requires_grad=False) - - -def apply_moe( - self, - layer, # FusedMoE - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | 
None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, -) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - """ - Patched apply method for ModelOptNvFp4FusedMoE. - - Key modification for Marlin: Uses marlin_* attributes instead of originals. - """ - from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe - from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( - flashinfer_trtllm_fp4_moe, - ) - from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( - FlashinferMoeBackend, - ) - from vllm.scalar_type import scalar_types - - if not self.moe.is_act_and_mul: - assert self.allow_flashinfer and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS, ( - "Non-gated activations are only supported by the flashinfer CUTLASS backend for modelopt checkpoints" - ) - - if self.allow_flashinfer and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: - if enable_eplb: - raise NotImplementedError("EPLB not supported for `ModelOptNvFp4FusedMoE` yet.") - return flashinfer_trtllm_fp4_moe( - layer=layer, - x=x, - router_logits=router_logits, - top_k=top_k, - global_num_experts=global_num_experts, - num_expert_group=num_expert_group, - topk_group=topk_group, - custom_routing_function=custom_routing_function, - e_score_correction_bias=e_score_correction_bias, - ) - - topk_weights, topk_ids, _ = layer.select_experts( - hidden_states=x, - router_logits=router_logits, - ) - - if self.use_marlin: - # Use marlin_* attributes instead of original attributes - return fused_marlin_moe( - x, - 
layer.marlin_w13_weight, - layer.marlin_w2_weight, - None, # bias1 - None, # bias2 - layer.marlin_w13_weight_scale, - layer.marlin_w2_weight_scale, - router_logits, - topk_weights, - topk_ids, - quant_type_id=scalar_types.float4_e2m1f.id, - apply_router_weight_on_input=apply_router_weight_on_input, - global_num_experts=global_num_experts, - expert_map=expert_map, - global_scale1=layer.marlin_w13_weight_scale_2, - global_scale2=layer.marlin_w2_weight_scale_2, - workspace=layer.workspace, - input_dtype=self.marlin_input_dtype, - ) - - elif self.allow_flashinfer: - assert self.flashinfer_moe_backend in ( - FlashinferMoeBackend.CUTLASS, - FlashinferMoeBackend.CUTEDSL, - ) - if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: - from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( - flashinfer_cutlass_moe_fp4, - ) - - flashinfer_fn_moe_fp4 = flashinfer_cutlass_moe_fp4 - else: - from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import ( - flashinfer_cutedsl_moe_fp4, - ) - - flashinfer_fn_moe_fp4 = flashinfer_cutedsl_moe_fp4 - - assert self.moe_quant_config is not None - return flashinfer_fn_moe_fp4( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - quant_config=self.moe_quant_config, - inplace=False, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, - ) - else: - # If no modular kernel is provided, use cutlass_moe_fp4 for TP case - # only (no EP). 
- from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4 - - assert self.moe_quant_config is not None - return cutlass_moe_fp4( - a=x, - w1_fp4=layer.w13_weight, - w2_fp4=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - quant_config=self.moe_quant_config, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, - m=x.shape[0], - n=layer.w2_weight.shape[2] * 2, - k=x.shape[1], - e=layer.w13_weight.shape[0], - ) - - -def process_weights_after_loading_kv(self, layer) -> None: - """Modified version of BaseKVCacheMethod.process_weights_after_loading. - - Doesn't delete k_scale, v_scale, q_scale, and prob_scale parameters to allow - for dynamic updates during refit. - """ - # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0 - # regardless whether the kv-scale is available in the checkpoint. - # No need to process kv scales after loading if we are going to - # calculate them on the fly. - from vllm.platforms import current_platform - - if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales: - if layer.k_scale > 0.0 and layer.v_scale > 0.0: - # We prefer to use separate k_scale and v_scale if present - k_scale = layer.k_scale.to("cpu").tolist() - v_scale = layer.v_scale.to("cpu").tolist() - if current_platform.is_fp8_fnuz(): - k_scale *= 2 - v_scale *= 2 - elif layer.k_scale < 0.0 and layer.v_scale < 0.0: - # If no scales were loaded (both scales are invalid negative - # values), use the default value of 1.0 - k_scale = 1.0 - v_scale = 1.0 - else: - # If we find a single kv_scale in the checkpoint, we remap - # kv_scale to k_scale during weight loading, and duplicate - # k_scale to v_scale here - assert layer.k_scale > 0.0 - scale_to_duplicate = max(layer.k_scale, layer.v_scale) - k_scale = scale_to_duplicate.to("cpu").tolist() - v_scale = scale_to_duplicate.to("cpu").tolist() - if current_platform.is_fp8_fnuz(): - k_scale *= 2 - v_scale *= 2 - - if not isinstance(k_scale, 
float) or not isinstance(v_scale, float): - raise ValueError("Only support per-tensor scaling factor for fp8 KV cache") - - if layer.q_scale < 0.0: - layer._q_scale.copy_(k_scale) - layer._q_scale_float = k_scale - - # These are used in the final Attention.forward() - layer._k_scale.copy_(k_scale) - layer._v_scale.copy_(v_scale) - layer._k_scale_float = k_scale - layer._v_scale_float = v_scale - - if layer.q_scale > 0.0: - q_scale = layer.q_scale - if current_platform.is_fp8_fnuz(): - q_scale *= 2 - layer.calculate_kv_scales = False - else: - q_scale = 1.0 - if layer.prob_scale > 0.0: - prob_scale = layer.prob_scale - if current_platform.is_fp8_fnuz(): - prob_scale *= 2 - else: - prob_scale = 1.0 - - is_singleton_float = ( - lambda x: isinstance(x, float) or isinstance(x, torch.Tensor) and x.numel() == 1 and x.is_floating_point() - ) - if not is_singleton_float(q_scale) or not is_singleton_float(prob_scale): - raise ValueError("Only support per-tensor scaling factorfor fp8-quantized Q/prob") - - # These are used in the final Attention.forward() - layer._q_scale.copy_(q_scale) - layer._q_scale_float = q_scale.item() if isinstance(q_scale, torch.Tensor) else q_scale - - layer._prob_scale.copy_(prob_scale) - - -def apply_vllm_modelopt_patches(): - func1_path = ( - "vllm.model_executor.layers.quantization.modelopt.ModelOptNvFp4LinearMethod.process_weights_after_loading" - ) - patcher1 = patch(func1_path, process_weights_after_loading_modelopt) - patcher1.start() - func2_path = "vllm.model_executor.layers.quantization.modelopt.ModelOptNvFp4LinearMethod.apply" - patcher2 = patch(func2_path, apply_modelopt) - patcher2.start() - # Patch ModelOptNvFp4FusedMoE - func3_path = "vllm.model_executor.layers.quantization.modelopt.ModelOptNvFp4FusedMoE.process_weights_after_loading" - patcher3 = patch(func3_path, process_weights_after_loading_moe) - patcher3.start() - func4_path = "vllm.model_executor.layers.quantization.modelopt.ModelOptNvFp4FusedMoE.apply" - patcher4 = 
patch(func4_path, apply_moe) - patcher4.start() - # Static scales mode: patch process_weights_after_loading to preserve k_scale/v_scale for manual updates - func5_path = "vllm.model_executor.layers.quantization.kv_cache.BaseKVCacheMethod.process_weights_after_loading" - patcher5 = patch(func5_path, process_weights_after_loading_kv) - patcher5.start() diff --git a/verl/utils/modelopt/weight_processor.py b/verl/utils/modelopt/weight_processor.py index 1b1380a3d6b..8c126216da8 100644 --- a/verl/utils/modelopt/weight_processor.py +++ b/verl/utils/modelopt/weight_processor.py @@ -29,7 +29,7 @@ from modelopt.torch.quantization.qtensor.nvfp4_tensor import NVFP4QTensor from verl.utils.megatron_utils import unwrap_model -from verl.utils.modelopt.qat import QuantizationMetadata +from verl.utils.modelopt.quantize import QuantizationMetadata class QATWeightPostProcessor: diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py index 22154c432b5..c6226f8c8a8 100644 --- a/verl/workers/megatron_workers.py +++ b/verl/workers/megatron_workers.py @@ -460,8 +460,9 @@ def _build_model_optimizer( qat_config = self.config.actor.get("qat", {}) if qat_config.get("enable", False): qat_mode = qat_config.get("mode", "w4a16") + ignore_patterns = qat_config.get("ignore_patterns", None) for i in range(len(actor_module)): - actor_module[i] = apply_qat(actor_module[i], qat_mode) + actor_module[i] = apply_qat(actor_module[i], qat_mode, ignore_patterns=ignore_patterns) elif self._is_ref: wrap_config = McoreModuleWrapperConfig( is_value_model=False, # ref is not value model diff --git a/verl/workers/rollout/vllm_rollout/utils.py b/verl/workers/rollout/vllm_rollout/utils.py index ac280e1a261..df3e74e11b8 100644 --- a/verl/workers/rollout/vllm_rollout/utils.py +++ b/verl/workers/rollout/vllm_rollout/utils.py @@ -30,7 +30,6 @@ from verl.utils.vllm import TensorLoRARequest, VLLMHijack from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader from 
verl.utils.vllm.vllm_fp8_utils import apply_vllm_fp8_patches, is_fp8_model, load_quanted_weights -from verl.utils.modelopt import apply_vllm_modelopt_patches logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) @@ -168,15 +167,16 @@ def __new__(cls, **kwargs): quant_config = getattr(vllm_config, "quant_config", None) if vllm_config else None _is_qat_model = getattr(quant_config, "quant_format", None) == "nvfp4-pack-quantized" _is_modelopt_qat = type(quant_config).__name__ == "ModelOptNvFp4Config" - print(f"type(quant_config).__name__: {type(quant_config).__name__} _is_modelopt_qat: {_is_modelopt_qat}") if _is_qat_model: from verl.utils.qat import apply_qat_patches apply_qat_patches() logger.info("Applied QAT (compressed-tensors) patches in vLLM worker subprocess") elif _is_modelopt_qat: - apply_vllm_modelopt_patches() - logger.info("Applied QAT (modelopt) patches in vLLM worker subprocess") + from verl.utils.modelopt import apply_modelopt_nvfp4_patches + + apply_modelopt_nvfp4_patches() + logger.info("Applied ModelOpt NVFP4 patches in vLLM worker subprocess") # TODO: For ascend NPU, when the corresponding vllm-ascend version is upgraded to v0.13.0, # please remove the VLLM_ASCEND_REQUIRED_ENV_VARS variable replacement action. 
@@ -232,11 +232,16 @@ def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False ) if self._is_qat_model: - # QAT: Prepare for weight loading BEFORE receiving any buckets + # QAT (compressed-tensors): Prepare for weight loading BEFORE receiving any buckets from verl.utils.qat import prepare_qat_for_load_weights prepare_qat_for_load_weights(self.model_runner.model, device=self.device) logger.info("QAT: prepare_qat_for_load_weights completed") + elif self._is_modelopt_qat: + from verl.utils.modelopt.vllm_modelopt_patch import prepare_modelopt_for_weight_reload + + prepare_modelopt_for_weight_reload(self.model_runner.model, device=self.device) + logger.info("ModelOpt: prepare_modelopt_for_weight_reload completed") elif use_standard_weight_load: # Re-apply here because async IPC weight sync can happen long after init and lose MoE weight_loader attrs. patch_vllm_moe_model_weight_loader(self.model_runner.model) @@ -271,11 +276,9 @@ def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False manual_process_weights_after_loading(self.model_runner.model) logger.info("QAT: process_weights_after_loading completed") elif self._is_modelopt_qat: - from vllm.model_executor.model_loader.utils import process_weights_after_loading + from verl.utils.modelopt.vllm_modelopt_patch import modelopt_process_weights_after_loading - model = self.model_runner.model - model_config = self.model_runner.vllm_config.model_config - process_weights_after_loading(model, model_config, self.device) + modelopt_process_weights_after_loading(self.model_runner.model) logger.info("ModelOpt QAT: process_weights_after_loading completed") elif use_standard_weight_load: # Some post-load transforms are non-idempotent; run once after all buckets. 
diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py index 7b50e5ef23a..302fb50068f 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -244,10 +244,9 @@ async def launch_server(self, master_address: str = None, master_port: int = Non quant_method = quantization_config_dict.get("quant_method", None) if quant_method == "modelopt": + from verl.utils.modelopt import apply_modelopt_nvfp4_patches - from verl.utils.modelopt import apply_vllm_modelopt_patches - - apply_vllm_modelopt_patches() + apply_modelopt_nvfp4_patches() quantization = "modelopt" elif quant_method == "compressed-tensors": from verl.utils.qat import apply_qat_patches From 7a1ebfa4fe05e58cbd5662e34d9f503118924f6f Mon Sep 17 00:00:00 2001 From: larkz-nv Date: Wed, 25 Feb 2026 19:29:05 -0800 Subject: [PATCH 08/10] Consolidate ModelOpt modules --- verl/utils/modelopt/__init__.py | 25 +- .../modelopt/megatron_qat_patch.py} | 196 +----- verl/utils/modelopt/quantize.py | 131 ++++ verl/utils/modelopt/vllm_modelopt_patch.py | 628 ++++++++++++++++++ verl/workers/megatron_workers.py | 2 +- 5 files changed, 792 insertions(+), 190 deletions(-) rename verl/{models/mcore/qat_patch.py => utils/modelopt/megatron_qat_patch.py} (71%) create mode 100644 verl/utils/modelopt/quantize.py create mode 100644 verl/utils/modelopt/vllm_modelopt_patch.py diff --git a/verl/utils/modelopt/__init__.py b/verl/utils/modelopt/__init__.py index a250acb18ae..5b1750e0378 100644 --- a/verl/utils/modelopt/__init__.py +++ b/verl/utils/modelopt/__init__.py @@ -13,26 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -ModelOpt integration for verl. - -Supports NVFP4 quantization with Megatron QAT training + vLLM low-precision inference. 
- -Module Structure: -- quantize.py: Quantization config builder, apply_qat, QuantizationMetadata -- weight_processor.py: QATWeightPostProcessor for converting QAT weights to quantized format -- vllm_modelopt_patch.py: vLLM monkey patches for ModelOpt NVFP4 inference (Linear, MoE, KV Cache) - -Usage: - # Training side - from verl.utils.modelopt import apply_qat, QATWeightPostProcessor - - # Inference side (dynamic weight reload lifecycle) - from verl.utils.modelopt import apply_modelopt_nvfp4_patches, prepare_modelopt_for_weight_reload, modelopt_process_weights_after_loading -""" +"""ModelOpt integration for NVFP4 quantization with Megatron QAT training and vLLM inference.""" from verl.utils.modelopt.quantize import ( - # DEFAULT_IGNORE_PATTERNS, QuantizationMetadata, apply_qat, build_quantize_config, @@ -43,6 +26,10 @@ prepare_modelopt_for_weight_reload, ) from verl.utils.modelopt.weight_processor import QATWeightPostProcessor +from verl.utils.modelopt.megatron_qat_patch import ( + apply_qat_patch, + revert_qat_patch, +) __all__ = [ @@ -53,4 +40,6 @@ "apply_modelopt_nvfp4_patches", "prepare_modelopt_for_weight_reload", "modelopt_process_weights_after_loading", + "apply_qat_patch", + "revert_qat_patch", ] diff --git a/verl/models/mcore/qat_patch.py b/verl/utils/modelopt/megatron_qat_patch.py similarity index 71% rename from verl/models/mcore/qat_patch.py rename to verl/utils/modelopt/megatron_qat_patch.py index ad6475400ae..025686394db 100644 --- a/verl/models/mcore/qat_patch.py +++ b/verl/utils/modelopt/megatron_qat_patch.py @@ -13,80 +13,25 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Runtime patches for QAT (Quantization-Aware Training) with Megatron-Core. - -This module provides four independent monkey-patches that fix issues in older -versions of megatron-core / megatron-bridge when running QAT workflows: - -1. 
**SwiGLU sharded-state-dict patch** (``apply_swiglu_sharded_factory_patch``) - Older megatron-core raises ``NotImplementedError`` inside - ``apply_swiglu_sharded_factory`` when ``singleton_local_shards=True``. - The patch adds correct handling by splitting the sharded tensor key into - separate ``{key}_w`` / ``{key}_v`` entries. - -2. **EP gather_from_ep_ranks patch** (``apply_ep_gather_patch``) - The original ``MegatronParamMapping.gather_from_ep_ranks`` only supports - the TEGroupedMLP naming pattern (``weight`` / ``bias``). The patch - additionally supports the SequentialMLP pattern (``local_experts.``) - and adds better error handling. - -3. **extract_sort_key patch** (``apply_extract_sort_key_patch``) - The original ``extract_sort_key`` in megatron-bridge utils only recognises - expert numbers in TEGroupedMLP format (``weight`` / ``bias``). The - patch adds fallback support for the SequentialMLP pattern - (``local_experts.``). - -4. **_megatron_local_name_to_global patch** - (``apply_local_name_to_global_patch``) - The original ``_megatron_local_name_to_global`` only converts local - expert numbers to global for the TEGroupedMLP pattern - (``mlp.experts.linear_fc`` + ``weight``/``bias``). The patch - adds support for the SequentialMLP pattern - (``mlp.experts.local_experts.``). Without this, expert numbers - remain local (e.g. 0-15 for 128 experts with EP=8) instead of being - mapped to global indices (0-127). - -5. **build_conversion_tasks patch** (``apply_build_conversion_tasks_patch``) - The original ``MegatronModelBridge.build_conversion_tasks`` may return - ``None`` entries in the task list (for PP ranks that don't own certain - parameters and have no mapping). The patch filters out ``None`` entries - before returning so that callers never need to guard against them. - -6. 
**AutoMapping._detect_parallelism_type patch** - (``apply_detect_parallelism_type_patch``) - The original ``_detect_parallelism_type`` only matches - ``module_type == "TELayerNormColumnParallelLinear"`` exactly. ModelOpt - quantised wrappers produce class names like - ``QuantTELayerNormColumnParallelLinear`` that contain the substring but - don't match exactly. The patch broadens the check to - ``"LayerNormColumnParallelLinear" in module_type``. - -Convenience entry-point:: - - from verl.models.mcore.qat_patch import apply_qat_patch - apply_qat_patch() # applies all patches at once +"""Megatron-Core / megatron-bridge monkey patches for QAT workflows. + +Patches SwiGLU sharded state-dict, EP gather, extract_sort_key, local-to-global +name mapping, build_conversion_tasks, and parallelism type detection to support +SequentialMLP and quantised wrappers. """ import gc import logging import re -from typing import Dict, Iterable, List, Optional +from typing import Iterable, Optional import torch logger = logging.getLogger(__name__) -# ====================================================================== -# 1. SwiGLU sharded-state-dict patch -# ====================================================================== - def apply_swiglu_sharded_factory_patch(): - """Patch ``megatron.core.transformer.mlp.apply_swiglu_sharded_factory`` - to support ``singleton_local_shards`` for SwiGLU MLP tensors. - - Idempotent – safe to call multiple times. - """ + """Patch ``apply_swiglu_sharded_factory`` to support ``singleton_local_shards``.""" import megatron.core.transformer.mlp as mlp_module from megatron.core.dist_checkpointing import ShardedTensor from megatron.core.dist_checkpointing.mapping import ( @@ -190,18 +135,8 @@ def revert_swiglu_sharded_factory_patch(): logger.info("Reverted QAT patch: apply_swiglu_sharded_factory.") -# ====================================================================== -# 2. 
EP gather_from_ep_ranks patch -# ====================================================================== - - def apply_ep_gather_patch(): - """Patch ``MegatronParamMapping.gather_from_ep_ranks`` in megatron-bridge - to support both SequentialMLP (``local_experts.``) and TEGroupedMLP - (``weight`` / ``bias``) naming patterns. - - Idempotent – safe to call multiple times. - """ + """Patch ``gather_from_ep_ranks`` to support SequentialMLP and TEGroupedMLP naming.""" from megatron.bridge.models.conversion.param_mapping import MegatronParamMapping if getattr(MegatronParamMapping, "_ep_gather_patched", False): @@ -214,8 +149,7 @@ def _patched_gather_from_ep_ranks( megatron_weights: Optional[torch.Tensor], megatron_module, # Optional[MegatronModule] hf_param_name: Optional[str], - ) -> Dict[str, torch.Tensor]: - """Gather expert weights across EP ranks (supports SequentialMLP + TEGroupedMLP).""" + ) -> dict[str, torch.Tensor]: if megatron_module is None: num_experts_per_rank = self.broadcast_obj_from_pp_rank(None, "num_experts_per_rank") else: @@ -226,16 +160,15 @@ def _patched_gather_from_ep_ranks( num_experts_per_rank, "num_experts_per_rank" ) - # --- Extract the local expert index from the Megatron param name --- local_expert_number = None - # Try SequentialMLP pattern first: local_experts. + # SequentialMLP pattern: local_experts. local_experts_match = re.search(r"local_experts\.(\d+)", self.megatron_param) if local_experts_match: global_expert_number = int(local_experts_match.group(1)) local_expert_number = global_expert_number % num_experts_per_rank else: - # Fallback: TEGroupedMLP pattern – weight or bias + # TEGroupedMLP pattern: weight or bias for key in (".weight", ".bias"): if key in self.megatron_param: suffix = self.megatron_param.split(key)[-1] @@ -246,12 +179,10 @@ def _patched_gather_from_ep_ranks( if local_expert_number is None: raise ValueError( - f"Could not extract expert number from parameter name: {self.megatron_param}. 
" - f"Expected either TEGroupedMLP pattern (weight/bias) or " - f"SequentialMLP pattern (local_experts.)." + f"Cannot extract expert number from: {self.megatron_param}. " + f"Expected TEGroupedMLP (weight/bias) or SequentialMLP (local_experts.)." ) - # Build HF param names for every EP rank gathered_expert_param_names = [ re.sub( r"experts\.(\d+)", @@ -261,16 +192,13 @@ def _patched_gather_from_ep_ranks( for i in range(self.ep_size) ] assert str(hf_param_name) in gathered_expert_param_names, ( - f"hf_param_name {hf_param_name} not in gathered_expert_param_names " - f"{gathered_expert_param_names}" + f"hf_param_name {hf_param_name} not in {gathered_expert_param_names}" ) - # All-gather across the EP group gathered_weights = [torch.empty_like(megatron_weights) for _ in range(self.ep_size)] torch.distributed.all_gather(gathered_weights, megatron_weights, group=self.ep_group) - # Assemble the result dict (handles duplicate names via concatenation) - weights_dict: Dict[str, torch.Tensor] = {} + weights_dict: dict[str, torch.Tensor] = {} for i, param_name in enumerate(gathered_expert_param_names): if param_name in weights_dict: weights_dict[param_name] = torch.cat( @@ -301,18 +229,8 @@ def revert_ep_gather_patch(): logger.info("Reverted QAT patch: MegatronParamMapping.gather_from_ep_ranks.") -# ====================================================================== -# 3. extract_sort_key patch -# ====================================================================== - - def apply_extract_sort_key_patch(): - """Patch ``megatron.bridge.models.conversion.utils.extract_sort_key`` - to support the SequentialMLP naming pattern (``local_experts.``) in - addition to the original TEGroupedMLP pattern (``weight`` / ``bias``). - - Idempotent – safe to call multiple times. 
- """ + """Patch ``extract_sort_key`` to support SequentialMLP naming pattern.""" import megatron.bridge.models.conversion.utils as utils_module import megatron.bridge.models.conversion.model_bridge as bridge_module @@ -324,23 +242,19 @@ def apply_extract_sort_key_patch(): bridge_module._original_extract_sort_key = bridge_module.extract_sort_key def _patched_extract_sort_key(param_name: str): - """Extract sorting key based on layer and expert numbers.""" numbers = [] - - # Find layer number layer_match = re.search(r"layers\.(\d+)", param_name) if layer_match: numbers.append(int(layer_match.group(1))) - # Find expert number – try multiple patterns expert_number = None - # Pattern 1: TEGroupedMLP format (e.g., weight15, bias15) + # TEGroupedMLP: weight, bias expert_match = re.search(r"(?:bias|weight)(\d+)", param_name) if expert_match: expert_number = int(expert_match.group(1)) - # Pattern 2: SequentialMLP format (e.g., local_experts.15) + # SequentialMLP: local_experts. if expert_number is None: local_experts_match = re.search(r"local_experts\.(\d+)", param_name) if local_experts_match: @@ -349,7 +263,6 @@ def _patched_extract_sort_key(param_name: str): if expert_number is not None: numbers.append(expert_number) - # Pad to ensure consistent comparison (max 2 numbers) while len(numbers) < 2: numbers.append(-1) numbers = numbers[:2] @@ -366,7 +279,6 @@ def revert_extract_sort_key_patch(): """Revert :func:`apply_extract_sort_key_patch`.""" import megatron.bridge.models.conversion.utils as utils_module import megatron.bridge.models.conversion.model_bridge as bridge_module - if not getattr(utils_module, "_sort_key_patched", False): return @@ -377,23 +289,9 @@ def revert_extract_sort_key_patch(): logger.info("Reverted QAT patch: extract_sort_key.") -# ====================================================================== -# 4. 
_megatron_local_name_to_global patch -# ====================================================================== - - def apply_local_name_to_global_patch(): - """Patch ``_megatron_local_name_to_global`` in megatron-bridge - to support the SequentialMLP naming pattern (``local_experts.``) - for local-to-global expert number conversion under EP > 1. - - The original function only handles the TEGroupedMLP pattern - (``mlp.experts.linear_fc`` with ``weight``/``bias``). The - patch adds an ``elif`` branch for SequentialMLP parameters whose - names contain ``mlp.experts.local_experts.``. - - Idempotent – safe to call multiple times. - """ + """Patch ``_megatron_local_name_to_global`` to support SequentialMLP + local-to-global expert number conversion under EP.""" import megatron.bridge.models.conversion.model_bridge as bridge_module from megatron.core import parallel_state from megatron.core.utils import get_pg_size @@ -445,22 +343,8 @@ def revert_local_name_to_global_patch(): logger.info("Reverted QAT patch: _megatron_local_name_to_global.") -# ====================================================================== -# 5. build_conversion_tasks patch -# ====================================================================== - - def apply_build_conversion_tasks_patch(): - """Patch ``MegatronModelBridge.build_conversion_tasks`` to filter out - ``None`` entries before returning the task list. - - The original implementation can leave ``None`` slots for PP ranks that - don't own certain parameters and have no mapping. Downstream code that - iterates over the returned list may break on ``None``. This patch - ensures only valid :class:`WeightConversionTask` objects are returned. - - Idempotent – safe to call multiple times. 
- """ + """Patch ``build_conversion_tasks`` to filter out ``None`` entries.""" import itertools import megatron.bridge.models.conversion.model_bridge as bridge_module @@ -484,13 +368,6 @@ def apply_build_conversion_tasks_patch(): ) def _patched_build_conversion_tasks(self, hf_pretrained, megatron_model): - """Construct conversion tasks between HF and Megatron (``None``-free). - - Returns a list of :class:`WeightConversionTask` objects — ``None`` - entries are filtered out before the list is returned so that callers - never need to guard against them. - """ - # Ensure hf_pretrained has the required state structure if not (hasattr(hf_pretrained, "state") and hasattr(hf_pretrained.state, "source")): raise ValueError("hf_pretrained.state.source is required for weight ordering") @@ -505,7 +382,6 @@ def _patched_build_conversion_tasks(self, hf_pretrained, megatron_model): megatron_model ) - # Filter out output_layer related parameters if embeddings are tied if embeddings_are_tied: sorted_global_param_names_all_pp_ranks = [ name for name in sorted_global_param_names_all_pp_ranks if "output_layer" not in name @@ -539,7 +415,6 @@ def _patched_build_conversion_tasks(self, hf_pretrained, megatron_model): logger.warning(f"WARNING: No mapping found for megatron_param: {global_name}") continue - # Ensure HF weights exist if not mapping.allow_hf_name_mismatch: if isinstance(mapping.hf_param, str): if mapping.hf_param not in hf_keys: @@ -574,7 +449,6 @@ def _patched_build_conversion_tasks(self, hf_pretrained, megatron_model): mapping=mapping, ) - # Fill the remaining slots for PP communications for idx, global_name in enumerate(sorted_global_param_names_all_pp_ranks): if tasks[idx] is None: mapping = mapping_registry.megatron_to_hf_lookup( @@ -615,24 +489,9 @@ def revert_build_conversion_tasks_patch(): logger.info("Reverted QAT patch: MegatronModelBridge.build_conversion_tasks.") -# ====================================================================== -# 5. 
AutoMapping._detect_parallelism_type patch -# ====================================================================== - - def apply_detect_parallelism_type_patch(): - """Patch ``AutoMapping._detect_parallelism_type`` to recognise quantised - ``LayerNormColumnParallelLinear`` variants (e.g. - ``QuantTELayerNormColumnParallelLinear``). - - The original code only checks - ``module_type == "TELayerNormColumnParallelLinear"``. ModelOpt wraps this - into classes whose names still *contain* ``LayerNormColumnParallelLinear`` - but do not match exactly. The patch broadens the check to - ``"LayerNormColumnParallelLinear" in module_type``. - - Idempotent – safe to call multiple times. - """ + """Patch ``_detect_parallelism_type`` to recognise quantised + ``LayerNormColumnParallelLinear`` variants via substring matching.""" from megatron.bridge.models.conversion.param_mapping import AutoMapping if getattr(AutoMapping, "_detect_parallelism_patched", False): @@ -669,13 +528,8 @@ def revert_detect_parallelism_type_patch(): logger.info("Reverted QAT patch: AutoMapping._detect_parallelism_type.") -# ====================================================================== -# Convenience: apply / revert all QAT patches at once -# ====================================================================== - - def apply_qat_patch(): - """Apply **all** QAT-related patches. Idempotent.""" + """Apply all QAT-related patches.""" apply_swiglu_sharded_factory_patch() apply_ep_gather_patch() apply_extract_sort_key_patch() @@ -685,7 +539,7 @@ def apply_qat_patch(): def revert_qat_patch(): - """Revert **all** QAT-related patches.""" + """Revert all QAT-related patches.""" revert_swiglu_sharded_factory_patch() revert_ep_gather_patch() revert_extract_sort_key_patch() diff --git a/verl/utils/modelopt/quantize.py b/verl/utils/modelopt/quantize.py new file mode 100644 index 00000000000..12fb811d292 --- /dev/null +++ b/verl/utils/modelopt/quantize.py @@ -0,0 +1,131 @@ +# Copyright 2025 Bytedance Ltd. 
and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ModelOpt NVFP4 quantization config and application for Megatron QAT.""" + +import logging +from dataclasses import dataclass +from typing import Any, Optional + +import torch +import torch.nn as nn + +import modelopt.torch.quantization as mtq +from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg + + +logger = logging.getLogger(__name__) + + +_NVFP4_W4A16_QUANTIZER_CFG = { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": {"enable": False}, +} + + +def _ignore_patterns_to_quant_cfg(ignore_patterns: list[str]) -> dict: + """Convert user-provided ignore patterns to ModelOpt ``quant_cfg`` entries. + + Each pattern is wrapped with ``*`` on both ends (if not already present) + so that it performs glob-style substring matching against module names. + For example, ``"lm_head"`` becomes ``"*lm_head*"`` and ``"mlp.gate."`` + becomes ``"*mlp.gate.*"`` (the trailing dot prevents matching + ``mlp.gate_proj``). 
+    """
+    cfg = {}
+    for pattern in ignore_patterns:
+        key = pattern
+        if not key.startswith("*"):
+            key = f"*{key}"
+        if not key.endswith("*"):
+            key = f"{key}*"
+        cfg[key] = {"enable": False}
+    return cfg
+
+
+def build_quantize_config(
+    qat_mode: str,
+    ignore_patterns: list[str] | None = None,
+) -> dict:
+    """Build a complete ModelOpt quantization config for ``mtq.quantize``.
+
+    Args:
+        qat_mode: Quantization mode. Currently only ``"w4a16"`` is supported.
+        ignore_patterns: Layer name patterns to skip quantization for.
+            Uses glob-style matching (e.g. ``"lm_head"`` matches ``*lm_head*``).
+            If *None*, no layers are excluded from quantization.
+
+    Returns:
+        A config dict suitable for ``mtq.quantize()``.
+    """
+    if qat_mode != "w4a16":
+        raise ValueError(f"Only 'w4a16' is supported, got: {qat_mode}")
+
+    if ignore_patterns is None:
+        ignore_patterns = []
+
+    ignore_cfg = _ignore_patterns_to_quant_cfg(ignore_patterns)
+
+    quant_cfg = {
+        **_NVFP4_W4A16_QUANTIZER_CFG,
+        **_default_disabled_quantizer_cfg,
+        **ignore_cfg,
+    }
+    logger.info("Built NVFP4 %s quantize config, ignore_patterns=%s", qat_mode, ignore_patterns)
+
+    return {"quant_cfg": quant_cfg, "algorithm": "max"}
+
+
+def apply_qat(
+    model: nn.Module,
+    qat_mode: str,
+    ignore_patterns: list[str] | None = None,
+) -> nn.Module:
+    """Apply Quantization-Aware Training to a Megatron model.
+
+    Args:
+        model: The Megatron model to quantize.
+        qat_mode: Quantization mode. Currently only ``"w4a16"`` is supported.
+        ignore_patterns: Layer name patterns to skip quantization for.
+            If *None*, no layers are excluded from quantization.
+
+    Returns:
+        The quantized model (modified in-place).
+ """ + config = build_quantize_config(qat_mode, ignore_patterns) + mtq.quantize(model, config) + return model + + +@dataclass +class QuantizationMetadata: + """Metadata for a quantized module.""" + + qformat: str + weight_quantizer: Any + input_quantizer: Any + module: torch.nn.Module + vpp_idx: int + block_size: int = 16 # Default NVFP4 block size + weight_amax: Optional[torch.Tensor] = None + input_amax: Optional[torch.Tensor] = None + is_local: bool = True + global_expert_idx: Optional[int] = None + local_expert_idx: Optional[int] = None diff --git a/verl/utils/modelopt/vllm_modelopt_patch.py b/verl/utils/modelopt/vllm_modelopt_patch.py new file mode 100644 index 00000000000..ca7f99abd96 --- /dev/null +++ b/verl/utils/modelopt/vllm_modelopt_patch.py @@ -0,0 +1,628 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +vLLM ModelOpt NVFP4 Patches for Dynamic Weight Updates (Marlin Backend). + +Enables dynamic weight reloading for NVFP4 quantized models in vLLM +using the ModelOpt quantization path with the Marlin kernel backend. + +Saves parameter metadata on first load and deletes HF parameters. Before +reload, HF parameters are rebuilt from metadata, loaded, then re-converted +to Marlin format in-place via copy_ (preserving CUDA Graph tensor addresses). 
+ +Supported schemes: +- Dense: ModelOptNvFp4LinearMethod (Marlin backend) +- MoE: ModelOptNvFp4FusedMoE (Marlin backend) +- KV: BaseKVCacheMethod (preserves scales for reload) +""" + +import logging +import os +from typing import Optional + +import torch +from torch.nn import Parameter + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +# ============================================================================ +# Utility Functions +# ============================================================================ + +def save_param_meta(layer: torch.nn.Module, param_name: str): + """Save parameter metadata (shape, dtype, param_class, dims) for later rebuild.""" + if not hasattr(layer, "_hf_param_meta"): + layer._hf_param_meta = {} + + param = getattr(layer, param_name, None) + if param is None: + return + + meta = { + "shape": tuple(param.shape), + "dtype": param.dtype, + "device": str(param.device), + "param_class": type(param), + } + + if hasattr(param, "_input_dim"): + meta["input_dim"] = param._input_dim + if hasattr(param, "_output_dim"): + meta["output_dim"] = param._output_dim + + layer._hf_param_meta[param_name] = meta + + +def _create_param_from_meta( + module: torch.nn.Module, + param_name: str, + meta: dict, + device: Optional[torch.device] = None, +) -> Parameter: + """Create a Parameter from saved metadata. 
Used by rebuild and tensor swap.""" + shape = meta["shape"] + dtype = meta["dtype"] + dev = device or meta.get("device", "cuda") + param_class = meta.get("param_class", Parameter) + + weight_loaders = getattr(module, "_weight_loaders", {}) + weight_loader = weight_loaders.get(param_name) + + data = torch.empty(shape, dtype=dtype, device=dev) + + try: + if param_class is not Parameter and weight_loader is not None: + kwargs = {"data": data, "weight_loader": weight_loader} + if "input_dim" in meta: + kwargs["input_dim"] = meta["input_dim"] + if "output_dim" in meta: + kwargs["output_dim"] = meta["output_dim"] + new_param = param_class(**kwargs) + else: + new_param = Parameter(data, requires_grad=False) + if weight_loader is not None: + new_param.weight_loader = weight_loader + except Exception as e: + logger.warning(f"Failed to create param {param_name} with class {param_class}: {e}, using Parameter") + new_param = Parameter(data, requires_grad=False) + if weight_loader is not None: + new_param.weight_loader = weight_loader + + return new_param + + +def _check_first_call(layer: torch.nn.Module) -> bool: + """Check if this is the first process_weights call, and increment counter.""" + count = getattr(layer, "_process_weights_call_count", 0) + layer._process_weights_call_count = count + 1 + return count == 0 + + +def _save_weight_loaders(layer: torch.nn.Module, param_names: list[str]): + """Save weight_loader references from parameters before they are overwritten.""" + if not hasattr(layer, "_weight_loaders"): + layer._weight_loaders = {} + for pname in param_names: + param = getattr(layer, pname, None) + if param is not None and hasattr(param, "weight_loader"): + layer._weight_loaders[pname] = param.weight_loader + + +def _update_ref_or_create(layer, ref_name, new_data): + """Copy new_data into existing tensor ref (CUDA Graph safe), or create new Parameter.""" + refs = getattr(layer, "_marlin_tensor_refs", {}) + ref = refs.get(ref_name) + if ref is not None: + 
ref.copy_(new_data) + setattr(layer, ref_name, Parameter(ref, requires_grad=False)) + else: + logger.warning(f"_marlin_tensor_refs['{ref_name}'] not found, creating new Parameter") + t = new_data.clone() if isinstance(new_data, torch.Tensor) else torch.tensor(new_data) + setattr(layer, ref_name, Parameter(t, requires_grad=False)) + + +# ============================================================================ +# ModelOptParamMetaDict +# ============================================================================ + +class ModelOptParamMetaDict(dict): + """ + Dict-like class for parameter management with metadata-based rebuild + and tensor swap. Supports: + - Rebuild of deleted parameters from saved metadata + - Tensor swap for parameters with shape changes (address stability for CUDA Graph) + """ + + def __init__(self, model: torch.nn.Module, device: Optional[torch.device] = None): + super().__init__() + self.device = device + + actual_model = model + if hasattr(model, "model"): + actual_model = model.model + self._model = actual_model + + self._layer_meta_cache: dict[str, dict] = {} + self._tensor_swap_layers: dict[str, dict] = {} + + self._build_mappings() + + for name, param in actual_model.named_parameters(): + self[name] = param + + def _build_mappings(self): + """Build layer metadata cache for rebuild and tensor swap.""" + for layer_name, module in self._model.named_modules(): + if not hasattr(module, "_hf_param_meta"): + continue + + self._layer_meta_cache[layer_name] = { + "module": module, + "meta": module._hf_param_meta, + } + + marlin_refs = getattr(module, "_marlin_tensor_refs", {}) + for param_name, meta in module._hf_param_meta.items(): + if param_name in marlin_refs: + key = f"{layer_name}.{param_name}" if layer_name else param_name + self._tensor_swap_layers[key] = { + "module": module, + "param_name": param_name, + "marlin_ref": marlin_refs[param_name], + "hf_meta": meta, + } + + def _try_rebuild(self, key: str) -> Optional[Parameter]: + parts = 
key.rsplit(".", 1) + if len(parts) != 2: + return None + layer_name, param_name = parts + if layer_name not in self._layer_meta_cache: + return None + cache_entry = self._layer_meta_cache[layer_name] + module = cache_entry["module"] + meta = cache_entry["meta"] + if param_name not in meta: + return None + if hasattr(module, param_name): + param = getattr(module, param_name) + if param is not None: + return param + new_param = _create_param_from_meta(module, param_name, meta[param_name], self.device) + module.register_parameter(param_name, new_param) + return new_param + + def prepare_for_reload(self) -> None: + """Replace kernel-format tensors with HF-shape tensors for reload.""" + for _key, swap_info in self._tensor_swap_layers.items(): + module = swap_info["module"] + param_name = swap_info["param_name"] + hf_meta = swap_info["hf_meta"] + if hasattr(module, param_name): + new_param = _create_param_from_meta(module, param_name, hf_meta, self.device) + setattr(module, param_name, new_param) + + def __getitem__(self, key: str) -> Parameter: + if key in dict.keys(self): + return super().__getitem__(key) + param = self._try_rebuild(key) + if param is not None: + self[key] = param + return param + raise KeyError(f"Parameter not found: {key}") + + def __contains__(self, key: str) -> bool: + if super().__contains__(key): + return True + parts = key.rsplit(".", 1) + if len(parts) == 2: + layer_name, param_name = parts + if layer_name in self._layer_meta_cache: + if param_name in self._layer_meta_cache[layer_name]["meta"]: + return True + return False + + def get(self, key: str, default=None): + try: + return self[key] + except KeyError: + return default + + +# ============================================================================ +# Dense Linear Patch (Marlin) +# ============================================================================ + +_DENSE_HF_PARAMS = ["weight", "weight_scale", "input_scale", "weight_scale_2"] + + +def _modelopt_dense_process_weights(self, 
layer: torch.nn.Module) -> None: + """ + Replacement for ModelOptNvFp4LinearMethod.process_weights_after_loading. + + First call: save metadata + weight_loaders, convert HF→Marlin format, + save _marlin_tensor_refs for CUDA Graph stability. + Subsequent: read reloaded HF data, convert, copy_ into saved refs. + """ + import vllm._custom_ops as ops + from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + marlin_make_workspace_new, + marlin_permute_scales, + ) + from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + nvfp4_marlin_process_global_scale, + nvfp4_marlin_process_scales, + ) + + is_first_call = _check_first_call(layer) + + if is_first_call: + for pname in _DENSE_HF_PARAMS: + save_param_meta(layer, pname) + _save_weight_loaders(layer, _DENSE_HF_PARAMS) + + weight_data = layer.weight.data + weight_scale_data = layer.weight_scale.data + weight_scale_2_data = layer.weight_scale_2.data + + assert weight_scale_data.dtype == torch.float8_e4m3fn + + device = weight_data.device + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + param_dtype = layer.params_dtype + group_size = 16 + weight_scale_2_max = weight_scale_2_data.max().to(torch.float32) + + if is_first_call: + layer.workspace = marlin_make_workspace_new(device) + + perm = torch.empty(0, dtype=torch.int, device=device) + qweight = weight_data.view(torch.int32).T.contiguous() + marlin_weight = ops.gptq_marlin_repack( + b_q_weight=qweight, perm=perm, + size_k=part_size_k, size_n=part_size_n, num_bits=4, + ) + + weight_scale = weight_scale_data.T.contiguous().to(param_dtype) + weight_scale = marlin_permute_scales( + s=weight_scale, size_k=part_size_k, size_n=part_size_n, group_size=group_size, + ) + marlin_weight_scale = nvfp4_marlin_process_scales(weight_scale) + marlin_weight_scale_2 = nvfp4_marlin_process_global_scale(weight_scale_2_max.to(param_dtype)) + + if is_first_call: + layer.weight = Parameter(marlin_weight, 
requires_grad=False) + layer.weight_scale = Parameter(marlin_weight_scale, requires_grad=False) + layer.weight_scale_2 = Parameter(marlin_weight_scale_2, requires_grad=False) + layer._marlin_tensor_refs = { + "weight": layer.weight.data, + "weight_scale": layer.weight_scale.data, + "weight_scale_2": layer.weight_scale_2.data, + } + else: + _update_ref_or_create(layer, "weight", marlin_weight) + _update_ref_or_create(layer, "weight_scale", marlin_weight_scale) + _update_ref_or_create(layer, "weight_scale_2", marlin_weight_scale_2) + + for attr in ["input_scale", "alpha", "input_scale_inv"]: + if hasattr(layer, attr): + delattr(layer, attr) + + +# ============================================================================ +# MoE Helpers (Marlin) +# ============================================================================ + +def _marlin_repack_experts(packed, perm, size_k, size_n, num_experts): + """Repack weight for each expert into Marlin format and stack.""" + import vllm._custom_ops as ops + + result = [] + for i in range(num_experts): + qweight = packed[i].view(torch.int32).T.contiguous() + result.append( + ops.gptq_marlin_repack( + b_q_weight=qweight, perm=perm, + size_k=size_k, size_n=size_n, num_bits=4, + ) + ) + return torch.stack(result) + + +def _marlin_process_scales_experts(scale_hf, param_dtype, size_k, size_n, group_size, num_experts): + """Process scales for each expert into Marlin format and stack.""" + from vllm.model_executor.layers.quantization.utils.marlin_utils import marlin_permute_scales + from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import nvfp4_marlin_process_scales + + result = [] + scales = scale_hf.to(param_dtype) + for i in range(num_experts): + s = marlin_permute_scales(s=scales[i].T, size_k=size_k, size_n=size_n, group_size=group_size) + result.append(nvfp4_marlin_process_scales(s)) + return torch.stack(result) + + +# ============================================================================ +# MoE Patch 
(Marlin) +# ============================================================================ + +_MOE_HF_PARAMS = [ + "w13_weight", "w2_weight", "w13_weight_scale", "w2_weight_scale", + "w13_weight_scale_2", "w2_weight_scale_2", "w13_input_scale", "w2_input_scale", +] + + +def _modelopt_moe_marlin_convert(self, layer: torch.nn.Module, is_first_call: bool) -> None: + """Convert MoE layer weights between HF and Marlin format.""" + from vllm.model_executor.layers.quantization.utils.marlin_utils import marlin_make_workspace_new + from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import nvfp4_marlin_process_global_scale + + group_size = 16 + e = layer.num_experts + k = layer.hidden_size + n = layer.intermediate_size_per_partition + device = layer.w13_weight.device + param_dtype = layer.params_dtype + + if is_first_call: + layer.workspace = marlin_make_workspace_new(device, 4) + + perm = torch.empty(0, dtype=torch.int, device=device) + size_n_w13, size_k_w13 = n * 2, k + size_n_w2, size_k_w2 = k, n + + # Repack weights + w13_weight_marlin = _marlin_repack_experts(layer.w13_weight.data, perm, size_k_w13, size_n_w13, e) + w2_weight_marlin = _marlin_repack_experts(layer.w2_weight.data, perm, size_k_w2, size_n_w2, e) + + # Process scales + w13_weight_scale_marlin = _marlin_process_scales_experts( + layer.w13_weight_scale.data, param_dtype, size_k_w13, size_n_w13, group_size, e, + ) + w2_weight_scale_marlin = _marlin_process_scales_experts( + layer.w2_weight_scale.data, param_dtype, size_k_w2, size_n_w2, group_size, e, + ) + + # Process global scales (w13_weight_scale_2 is already (E,) after common processing) + w13_scale_2_processed = nvfp4_marlin_process_global_scale(layer.w13_weight_scale_2.data.to(param_dtype)) + w2_scale_2_processed = nvfp4_marlin_process_global_scale(layer.w2_weight_scale_2.data.to(param_dtype)) + + if is_first_call: + layer.w13_weight = Parameter(w13_weight_marlin, requires_grad=False) + layer.w2_weight = Parameter(w2_weight_marlin, 
requires_grad=False) + layer.w13_weight_scale = Parameter(w13_weight_scale_marlin, requires_grad=False) + layer.w2_weight_scale = Parameter(w2_weight_scale_marlin, requires_grad=False) + layer.w13_weight_scale_2 = Parameter(w13_scale_2_processed, requires_grad=False) + layer.w2_weight_scale_2 = Parameter(w2_scale_2_processed, requires_grad=False) + if not hasattr(layer, "_marlin_tensor_refs"): + layer._marlin_tensor_refs = {} + for rn in ["w13_weight", "w2_weight", "w13_weight_scale", "w2_weight_scale", + "w13_weight_scale_2", "w2_weight_scale_2"]: + layer._marlin_tensor_refs[rn] = getattr(layer, rn).data + else: + for rn, nd in [ + ("w13_weight", w13_weight_marlin), ("w2_weight", w2_weight_marlin), + ("w13_weight_scale", w13_weight_scale_marlin), ("w2_weight_scale", w2_weight_scale_marlin), + ("w13_weight_scale_2", w13_scale_2_processed), ("w2_weight_scale_2", w2_scale_2_processed), + ]: + _update_ref_or_create(layer, rn, nd) + + for attr in ["w13_input_scale", "w2_input_scale"]: + if hasattr(layer, attr): + delattr(layer, attr) + + +def _modelopt_moe_process_weights(self, layer: torch.nn.Module) -> None: + """ + Replacement for ModelOptNvFp4FusedMoE.process_weights_after_loading (Marlin). + + First call: save metadata + weight_loaders, convert HF→Marlin format, + save _marlin_tensor_refs for CUDA Graph stability. + Subsequent: read reloaded HF data, convert, copy_ into saved refs. + """ + is_first_call = _check_first_call(layer) + + if is_first_call: + for pname in _MOE_HF_PARAMS: + save_param_meta(layer, pname) + _save_weight_loaders(layer, _MOE_HF_PARAMS) + + # ---- w13_weight_scale_2: reduce (E, 2) → (E,) ---- + if self.moe.is_act_and_mul and not torch.allclose( + layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1] + ): + logger.warning("w1_weight_scale_2 must match w3_weight_scale_2. 
Accuracy may be affected.") + + w13_weight_scale_2 = layer.w13_weight_scale_2.data + if w13_weight_scale_2.dim() == 2: + w13_weight_scale_2 = w13_weight_scale_2[:, 0] + layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False) + + _modelopt_moe_marlin_convert(self, layer, is_first_call) + + self.moe_quant_config = self.get_fused_moe_quant_config(layer) + + +# ============================================================================ +# KV Cache Patch +# ============================================================================ + +def _modelopt_kv_process_weights(self, layer) -> None: + """ + Replacement for BaseKVCacheMethod.process_weights_after_loading. + Doesn't delete k_scale, v_scale, q_scale, prob_scale to allow + for dynamic updates during refit. + """ + from vllm.platforms import current_platform + + if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales: + if layer.k_scale > 0.0 and layer.v_scale > 0.0: + k_scale = layer.k_scale.to("cpu").tolist() + v_scale = layer.v_scale.to("cpu").tolist() + if current_platform.is_fp8_fnuz(): + k_scale *= 2 + v_scale *= 2 + elif layer.k_scale < 0.0 and layer.v_scale < 0.0: + k_scale = 1.0 + v_scale = 1.0 + else: + assert layer.k_scale > 0.0 + scale_to_duplicate = max(layer.k_scale, layer.v_scale) + k_scale = scale_to_duplicate.to("cpu").tolist() + v_scale = scale_to_duplicate.to("cpu").tolist() + if current_platform.is_fp8_fnuz(): + k_scale *= 2 + v_scale *= 2 + + if not isinstance(k_scale, float) or not isinstance(v_scale, float): + raise ValueError("Only support per-tensor scaling factor for fp8 KV cache") + + if layer.q_scale < 0.0: + layer._q_scale.copy_(k_scale) + layer._q_scale_float = k_scale + + layer._k_scale.copy_(k_scale) + layer._v_scale.copy_(v_scale) + layer._k_scale_float = k_scale + layer._v_scale_float = v_scale + + if layer.q_scale > 0.0: + q_scale = layer.q_scale + if current_platform.is_fp8_fnuz(): + q_scale *= 2 + layer.calculate_kv_scales = False + else: + 
q_scale = 1.0 + if layer.prob_scale > 0.0: + prob_scale = layer.prob_scale + if current_platform.is_fp8_fnuz(): + prob_scale *= 2 + else: + prob_scale = 1.0 + + is_singleton_float = ( + lambda x: isinstance(x, float) or isinstance(x, torch.Tensor) and x.numel() == 1 and x.is_floating_point() + ) + if not is_singleton_float(q_scale) or not is_singleton_float(prob_scale): + raise ValueError("Only support per-tensor scaling factor for fp8-quantized Q/prob") + + layer._q_scale.copy_(q_scale) + layer._q_scale_float = q_scale.item() if isinstance(q_scale, torch.Tensor) else q_scale + layer._prob_scale.copy_(prob_scale) + + +# ============================================================================ +# Patch Application & Entry Points +# ============================================================================ + +_patched = False + + +def prepare_modelopt_for_weight_reload(model, device=None): + """ + Prepare ModelOpt model for weight reloading. Call ONCE before each reload cycle. + + 1. Builds ModelOptParamMetaDict from saved metadata + 2. Swaps kernel-format tensors back to HF-shape for weight_loader compatibility + 3. 
Rebuilds any deleted parameters from metadata + + Args: + model: vLLM model + device: Device for created parameters + """ + inner_model = model + if hasattr(model, "model"): + inner_model = model.model + + param_meta = ModelOptParamMetaDict(inner_model, device=device) + + param_meta.prepare_for_reload() + logger.info(f"[prepare_modelopt] Tensor swap prepared for {len(param_meta._tensor_swap_layers)} layers") + + rebuilt_count = 0 + for layer_name, cache_entry in param_meta._layer_meta_cache.items(): + module = cache_entry["module"] + for param_name, pm in cache_entry["meta"].items(): + existing = getattr(module, param_name, None) + if existing is not None: + hf_shape = tuple(pm["shape"]) + hf_dtype = pm["dtype"] + if ( + tuple(existing.shape) == hf_shape + and existing.dtype == hf_dtype + and hasattr(existing, "weight_loader") + ): + continue + new_param = _create_param_from_meta(module, param_name, pm, device) + module.register_parameter(param_name, new_param) + rebuilt_count += 1 + + logger.info(f"[prepare_modelopt] Rebuilt {rebuilt_count} parameters") + inner_model._param_meta_for_restore = param_meta + return param_meta + + +def modelopt_process_weights_after_loading(model): + """Trigger weight post-processing for all quantized layers after load_weights.""" + dense_count = 0 + moe_count = 0 + + actual_model = model + if hasattr(model, "model"): + actual_model = model.model + + for module in actual_model.modules(): + if hasattr(module, "scheme"): + module.scheme.process_weights_after_loading(module) + dense_count += 1 + + quant_method = getattr(module, "quant_method", None) + if quant_method is not None and not hasattr(module, "scheme"): + if hasattr(quant_method, "process_weights_after_loading"): + if "KVCache" in quant_method.__class__.__name__: + continue + quant_method.process_weights_after_loading(module) + moe_count += 1 + + logger.debug(f"Processed {dense_count} dense layers, {moe_count} MoE layers") + return dense_count + moe_count + + +def 
apply_modelopt_nvfp4_patches(): + """Apply ModelOpt NVFP4 patches to support dynamic weight updates. Call before model loading.""" + global _patched + + if _patched: + logger.warning("ModelOpt NVFP4 patches already applied, skipping") + return + + logger.info("Applying ModelOpt NVFP4 patches for dynamic weight loading...") + + from vllm.model_executor.layers.quantization.modelopt import ( + ModelOptNvFp4LinearMethod, + ModelOptNvFp4FusedMoE, + ) + from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod + + ModelOptNvFp4LinearMethod.process_weights_after_loading = _modelopt_dense_process_weights + ModelOptNvFp4FusedMoE.process_weights_after_loading = _modelopt_moe_process_weights + BaseKVCacheMethod.process_weights_after_loading = _modelopt_kv_process_weights + + _patched = True + logger.info("Applied 3 ModelOpt NVFP4 patches (Dense, MoE, KV)") \ No newline at end of file diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py index c6226f8c8a8..452e76af9fe 100644 --- a/verl/workers/megatron_workers.py +++ b/verl/workers/megatron_workers.py @@ -226,7 +226,7 @@ def _init_hf_config_and_tf_config( from megatron.bridge.models.gpt_provider import quantization_layer_spec provider.transformer_layer_spec = quantization_layer_spec - from verl.models.mcore.qat_patch import apply_qat_patch + from verl.utils.modelopt.megatron_qat_patch import apply_qat_patch apply_qat_patch() from megatron.bridge.models.conversion.param_mapping import AutoMapping From fe83dadb7348451cb0af2e59773c4e9ce8e2b92b Mon Sep 17 00:00:00 2001 From: larkz-nv Date: Thu, 26 Feb 2026 18:58:04 -0800 Subject: [PATCH 09/10] Refactor QAT weight exporter --- .../_generated_ppo_megatron_trainer.yaml | 10 + .../config/_generated_ppo_trainer.yaml | 14 +- .../config/_generated_ppo_veomni_trainer.yaml | 10 + verl/utils/modelopt/__init__.py | 15 +- verl/utils/modelopt/megatron_qat_patch.py | 117 +-- verl/utils/modelopt/qat_weight_exporter.py | 365 +++++++ 
verl/utils/modelopt/quantize.py | 57 +- verl/utils/modelopt/vllm_modelopt_patch.py | 76 +- verl/utils/modelopt/weight_processor.py | 972 ------------------ verl/workers/megatron_workers.py | 21 +- verl/workers/rollout/vllm_rollout/utils.py | 10 +- .../rollout/vllm_rollout/vllm_async_server.py | 3 +- .../rollout/vllm_rollout/vllm_rollout.py | 1 - 13 files changed, 510 insertions(+), 1161 deletions(-) create mode 100644 verl/utils/modelopt/qat_weight_exporter.py delete mode 100644 verl/utils/modelopt/weight_processor.py diff --git a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml index ea60c881619..05dfd008f97 100644 --- a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml @@ -139,6 +139,16 @@ actor_rollout_ref: mode: disabled record_file: null replay_file: null + qat: + enable: false + mode: w4a16 + group_size: 16 + ignore_patterns: + - lm_head + - embed_tokens + - re:.*mlp.gate$ + activation_observer: static_minmax + quantization_config_path: null load_weight: true ref: rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} diff --git a/verl/trainer/config/_generated_ppo_trainer.yaml b/verl/trainer/config/_generated_ppo_trainer.yaml index 6b97103ae9f..4787cf25eed 100644 --- a/verl/trainer/config/_generated_ppo_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_trainer.yaml @@ -120,13 +120,6 @@ actor_rollout_ref: mode: disabled record_file: null replay_file: null - grad_clip: 1.0 - ulysses_sequence_parallel_size: 1 - entropy_from_logits_with_chunking: false - entropy_checkpointing: false - use_remove_padding: ${oc.select:actor_rollout_ref.model.use_remove_padding,false} - calculate_sum_pi_squared: false - sum_pi_squared_checkpointing: false qat: enable: false mode: w4a16 @@ -137,6 +130,13 @@ actor_rollout_ref: - re:.*mlp.gate$ activation_observer: static_minmax quantization_config_path: null + grad_clip: 1.0 + 
ulysses_sequence_parallel_size: 1 + entropy_from_logits_with_chunking: false + entropy_checkpointing: false + use_remove_padding: ${oc.select:actor_rollout_ref.model.use_remove_padding,false} + calculate_sum_pi_squared: false + sum_pi_squared_checkpointing: false ref: rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} strategy: ${actor_rollout_ref.actor.strategy} diff --git a/verl/trainer/config/_generated_ppo_veomni_trainer.yaml b/verl/trainer/config/_generated_ppo_veomni_trainer.yaml index 4528e0d667d..956c725f433 100644 --- a/verl/trainer/config/_generated_ppo_veomni_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_veomni_trainer.yaml @@ -120,6 +120,16 @@ actor_rollout_ref: mode: disabled record_file: null replay_file: null + qat: + enable: false + mode: w4a16 + group_size: 16 + ignore_patterns: + - lm_head + - embed_tokens + - re:.*mlp.gate$ + activation_observer: static_minmax + quantization_config_path: null ref: rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} strategy: veomni diff --git a/verl/utils/modelopt/__init__.py b/verl/utils/modelopt/__init__.py index 5b1750e0378..30ddbd1a860 100644 --- a/verl/utils/modelopt/__init__.py +++ b/verl/utils/modelopt/__init__.py @@ -15,8 +15,12 @@ """ModelOpt integration for NVFP4 quantization with Megatron QAT training and vLLM inference.""" +from verl.utils.modelopt.megatron_qat_patch import ( + apply_qat_patch, + revert_qat_patch, +) +from verl.utils.modelopt.qat_weight_exporter import QATWeightExporter from verl.utils.modelopt.quantize import ( - QuantizationMetadata, apply_qat, build_quantize_config, ) @@ -25,18 +29,11 @@ modelopt_process_weights_after_loading, prepare_modelopt_for_weight_reload, ) -from verl.utils.modelopt.weight_processor import QATWeightPostProcessor -from verl.utils.modelopt.megatron_qat_patch import ( - apply_qat_patch, - revert_qat_patch, -) - __all__ = [ "build_quantize_config", "apply_qat", - "QuantizationMetadata", - "QATWeightPostProcessor", + "QATWeightExporter", 
"apply_modelopt_nvfp4_patches", "prepare_modelopt_for_weight_reload", "modelopt_process_weights_after_loading", diff --git a/verl/utils/modelopt/megatron_qat_patch.py b/verl/utils/modelopt/megatron_qat_patch.py index 025686394db..c5a4f52f5dc 100644 --- a/verl/utils/modelopt/megatron_qat_patch.py +++ b/verl/utils/modelopt/megatron_qat_patch.py @@ -44,25 +44,14 @@ def apply_swiglu_sharded_factory_patch(): mlp_module._swiglu_patched = True mlp_module._original_apply_swiglu_sharded_factory = mlp_module.apply_swiglu_sharded_factory - def patched_apply_swiglu_sharded_factory( - original_sh_ten, sharded_offsets, singleton_local_shards: bool = False - ): + def patched_apply_swiglu_sharded_factory(original_sh_ten, sharded_offsets, singleton_local_shards: bool = False): swiglu_shard_axis = 0 prepend_axis_num = len(sharded_offsets) original_shape = original_sh_ten.local_shape local_axis_size = original_shape[swiglu_shard_axis] - assert ( - original_sh_ten.global_offset[swiglu_shard_axis + prepend_axis_num] - % local_axis_size - == 0 - ) - rank_offset = ( - original_sh_ten.global_offset[swiglu_shard_axis + prepend_axis_num] - // local_axis_size - ) - axis_frag = original_sh_ten.axis_fragmentations[ - swiglu_shard_axis + prepend_axis_num - ] + assert original_sh_ten.global_offset[swiglu_shard_axis + prepend_axis_num] % local_axis_size == 0 + rank_offset = original_sh_ten.global_offset[swiglu_shard_axis + prepend_axis_num] // local_axis_size + axis_frag = original_sh_ten.axis_fragmentations[swiglu_shard_axis + prepend_axis_num] @torch.no_grad() def sh_ten_build_fn( @@ -89,12 +78,20 @@ def sh_ten_build_fn( tensor_w, tensor_v = torch.chunk(t, 2, dim=swiglu_shard_axis) return [ ShardedTensor.from_rank_offsets( - w_key, tensor_w, *sharded_offsets, offset_w, - replica_id=replica_id, prepend_axis_num=prepend_axis_num, + w_key, + tensor_w, + *sharded_offsets, + offset_w, + replica_id=replica_id, + prepend_axis_num=prepend_axis_num, ), ShardedTensor.from_rank_offsets( - v_key, tensor_v, 
*sharded_offsets, offset_v, - replica_id=replica_id, prepend_axis_num=prepend_axis_num, + v_key, + tensor_v, + *sharded_offsets, + offset_v, + replica_id=replica_id, + prepend_axis_num=prepend_axis_num, ), ] @@ -104,7 +101,8 @@ def sh_ten_merge_fn(sub_state_dict): return torch.cat(sub_state_dict) except (RuntimeError, torch.cuda.OutOfMemoryError) as e: logger.warning( - "CUDA OOM during tensor merge – falling back to CPU. (Error: %s)", e, + "CUDA OOM during tensor merge – falling back to CPU. (Error: %s)", + e, ) merged = torch.cat([t.cpu() for t in sub_state_dict]) gc.collect() @@ -156,9 +154,7 @@ def _patched_gather_from_ep_ranks( model_config = self._get_config(megatron_module) num_experts = model_config.num_moe_experts num_experts_per_rank = num_experts // self.ep_size - num_experts_per_rank = self.broadcast_obj_from_pp_rank( - num_experts_per_rank, "num_experts_per_rank" - ) + num_experts_per_rank = self.broadcast_obj_from_pp_rank(num_experts_per_rank, "num_experts_per_rank") local_expert_number = None @@ -212,10 +208,7 @@ def _patched_gather_from_ep_ranks( return weights_dict MegatronParamMapping.gather_from_ep_ranks = _patched_gather_from_ep_ranks - logger.info( - "Applied QAT patch: MegatronParamMapping.gather_from_ep_ranks " - "now supports SequentialMLP pattern." 
- ) + logger.info("Applied QAT patch: MegatronParamMapping.gather_from_ep_ranks now supports SequentialMLP pattern.") def revert_ep_gather_patch(): @@ -231,8 +224,8 @@ def revert_ep_gather_patch(): def apply_extract_sort_key_patch(): """Patch ``extract_sort_key`` to support SequentialMLP naming pattern.""" - import megatron.bridge.models.conversion.utils as utils_module import megatron.bridge.models.conversion.model_bridge as bridge_module + import megatron.bridge.models.conversion.utils as utils_module if getattr(utils_module, "_sort_key_patched", False): return @@ -270,15 +263,13 @@ def _patched_extract_sort_key(param_name: str): utils_module.extract_sort_key = _patched_extract_sort_key bridge_module.extract_sort_key = _patched_extract_sort_key - logger.info( - "Applied QAT patch: extract_sort_key now supports SequentialMLP pattern." - ) + logger.info("Applied QAT patch: extract_sort_key now supports SequentialMLP pattern.") def revert_extract_sort_key_patch(): """Revert :func:`apply_extract_sort_key_patch`.""" - import megatron.bridge.models.conversion.utils as utils_module import megatron.bridge.models.conversion.model_bridge as bridge_module + import megatron.bridge.models.conversion.utils as utils_module if not getattr(utils_module, "_sort_key_patched", False): return @@ -307,11 +298,7 @@ def _patched_megatron_local_name_to_global(models, config, param_name, vp_stage= param_name = _orig_fn(models, config, param_name, vp_stage) ep_group = parallel_state.get_expert_model_parallel_group() - if ( - ".mlp.experts.local_experts." in param_name - and get_pg_size(ep_group) > 1 - and ".adapter." not in param_name - ): + if ".mlp.experts.local_experts." in param_name and get_pg_size(ep_group) > 1 and ".adapter." 
not in param_name: num_experts = config.num_moe_experts num_experts_per_rank = num_experts // ep_group.size() local_experts_match = re.search(r"\.local_experts\.(\d+)\.", param_name) @@ -326,10 +313,7 @@ def _patched_megatron_local_name_to_global(models, config, param_name, vp_stage= return param_name bridge_module._megatron_local_name_to_global = _patched_megatron_local_name_to_global - logger.info( - "Applied QAT patch: _megatron_local_name_to_global " - "now supports SequentialMLP pattern." - ) + logger.info("Applied QAT patch: _megatron_local_name_to_global now supports SequentialMLP pattern.") def revert_local_name_to_global_patch(): @@ -363,9 +347,7 @@ def apply_build_conversion_tasks_patch(): if getattr(MegatronModelBridge, "_build_tasks_patched", False): return MegatronModelBridge._build_tasks_patched = True - MegatronModelBridge._original_build_conversion_tasks = ( - MegatronModelBridge.build_conversion_tasks - ) + MegatronModelBridge._original_build_conversion_tasks = MegatronModelBridge.build_conversion_tasks def _patched_build_conversion_tasks(self, hf_pretrained, megatron_model): if not (hasattr(hf_pretrained, "state") and hasattr(hf_pretrained.state, "source")): @@ -378,24 +360,18 @@ def _patched_build_conversion_tasks(self, hf_pretrained, megatron_model): model_config = unwrapped_model.config embeddings_are_tied = self._share_embeddings_and_output_weights(model_config, unwrapped_model) pp_rank = parallel_state.get_pipeline_model_parallel_rank() - sorted_global_param_names_all_pp_ranks = self._megatron_global_param_names_all_pp_ranks( - megatron_model - ) + sorted_global_param_names_all_pp_ranks = self._megatron_global_param_names_all_pp_ranks(megatron_model) if embeddings_are_tied: sorted_global_param_names_all_pp_ranks = [ name for name in sorted_global_param_names_all_pp_ranks if "output_layer" not in name ] - global_names_index_dict = { - name: idx for idx, name in enumerate(sorted_global_param_names_all_pp_ranks) - } + global_names_index_dict = 
{name: idx for idx, name in enumerate(sorted_global_param_names_all_pp_ranks)} tasks = [None] * len(sorted_global_param_names_all_pp_ranks) for vp_stage, model in enumerate(megatron_model): - for local_name, _ in itertools.chain( - model.named_parameters(), persistent_buffers(model) - ): + for local_name, _ in itertools.chain(model.named_parameters(), persistent_buffers(model)): if "_extra_state" in local_name or self._is_adapter_param_name(local_name): continue @@ -407,9 +383,7 @@ def _patched_build_conversion_tasks(self, hf_pretrained, megatron_model): print_rank_0(f"WARNING: {global_name} not in global_names_index_dict") continue global_name_idx = global_names_index_dict[global_name] - mapping = mapping_registry.megatron_to_hf_lookup( - self._get_lora_unwrapped_name(global_name) - ) + mapping = mapping_registry.megatron_to_hf_lookup(self._get_lora_unwrapped_name(global_name)) if not mapping: logger.warning(f"WARNING: No mapping found for megatron_param: {global_name}") @@ -421,23 +395,16 @@ def _patched_build_conversion_tasks(self, hf_pretrained, megatron_model): logger.warning(f"WARNING: Can't find {mapping.hf_param} in hf_keys") continue else: - missing_params = [ - hf_param - for hf_param in mapping.hf_param.values() - if hf_param not in hf_keys - ] + missing_params = [hf_param for hf_param in mapping.hf_param.values() if hf_param not in hf_keys] if missing_params: logger.warning( - f"WARNING: Can't find the following HF parameters in hf_keys: " - f"{missing_params}" + f"WARNING: Can't find the following HF parameters in hf_keys: {missing_params}" ) continue - local_module, local_weights = get_module_and_param_from_name( - megatron_model, local_name, vp_stage - ) + local_module, local_weights = get_module_and_param_from_name(megatron_model, local_name, vp_stage) if local_module is not None and not hasattr(local_module, "config"): - setattr(local_module, "config", model_config) + local_module.config = model_config tasks[global_name_idx] = WeightConversionTask( 
pp_rank=pp_rank, @@ -451,9 +418,7 @@ def _patched_build_conversion_tasks(self, hf_pretrained, megatron_model): for idx, global_name in enumerate(sorted_global_param_names_all_pp_ranks): if tasks[idx] is None: - mapping = mapping_registry.megatron_to_hf_lookup( - self._get_lora_unwrapped_name(global_name) - ) + mapping = mapping_registry.megatron_to_hf_lookup(self._get_lora_unwrapped_name(global_name)) if mapping is None: continue tasks[idx] = WeightConversionTask( @@ -470,10 +435,7 @@ def _patched_build_conversion_tasks(self, hf_pretrained, megatron_model): return tasks MegatronModelBridge.build_conversion_tasks = _patched_build_conversion_tasks - logger.info( - "Applied QAT patch: MegatronModelBridge.build_conversion_tasks " - "now filters out None entries." - ) + logger.info("Applied QAT patch: MegatronModelBridge.build_conversion_tasks now filters out None entries.") def revert_build_conversion_tasks_patch(): @@ -482,9 +444,7 @@ def revert_build_conversion_tasks_patch(): if not getattr(MegatronModelBridge, "_build_tasks_patched", False): return - MegatronModelBridge.build_conversion_tasks = ( - MegatronModelBridge._original_build_conversion_tasks - ) + MegatronModelBridge.build_conversion_tasks = MegatronModelBridge._original_build_conversion_tasks MegatronModelBridge._build_tasks_patched = False logger.info("Reverted QAT patch: MegatronModelBridge.build_conversion_tasks.") @@ -503,8 +463,7 @@ def _patched_detect_parallelism_type(self, module): module_type = type(module).__name__ if "LayerNormColumnParallelLinear" in module_type: if self.megatron_param and ( - self.megatron_param.endswith("layer_norm_weight") - or self.megatron_param.endswith("layer_norm_bias") + self.megatron_param.endswith("layer_norm_weight") or self.megatron_param.endswith("layer_norm_bias") ): return "replicated" return "column" diff --git a/verl/utils/modelopt/qat_weight_exporter.py b/verl/utils/modelopt/qat_weight_exporter.py new file mode 100644 index 00000000000..8b4e3787ef8 --- 
/dev/null +++ b/verl/utils/modelopt/qat_weight_exporter.py @@ -0,0 +1,365 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import re +from dataclasses import dataclass +from typing import Any, Iterator, Optional + +import torch +from modelopt.torch.export.quant_utils import ( + QUANTIZATION_NONE, + QUANTIZATION_NVFP4, + get_quantization_format, + get_weight_block_size, + to_quantized_weight, +) +from modelopt.torch.quantization.qtensor.nvfp4_tensor import NVFP4QTensor + +logger = logging.getLogger(__name__) + +# NVFP4 two-level scaling denominator: FP4_MAX (6.0) * FP8_MAX (448.0). +_NVFP4_AMAX_DENOMINATOR = 6.0 * 448.0 + + +@dataclass +class _QuantMeta: + """Quantization metadata for a single parameter.""" + + qformat: str + block_size: int + weight_amax: Optional[torch.Tensor] + input_amax: Optional[torch.Tensor] = None + input_quantizer: Any = None + + +class QATWeightExporter: + """Export QAT-trained bf16 weights as quantized weights (e.g. NVFP4).""" + + def __init__( + self, + actor_module: list, + qat_mode: str = "w4a16", + bridge: Any = None, + ): + self.qat_mode = qat_mode + self._actor_module = actor_module + + self._registry = self._get_mapping_registry(bridge) + if self._registry is None: + raise ValueError( + "QATWeightExporter requires a bridge with a valid MappingRegistry. 
" + "Ensure use_mbridge=True and vanilla_mbridge=False." + ) + + self._pp_size, self._pp_rank, self._pp_group = _get_parallel_info("pp") + self._ep_size, self._ep_rank, self._ep_group = _get_parallel_info("ep") + + self._config = self._get_model_config(actor_module) + self._num_local_experts = self._count_local_experts(actor_module) + + self._metadata: dict[str, _QuantMeta] = {} + self._collect_metadata(actor_module) + + if self._pp_size > 1 and self._pp_group is not None: + self._sync_metadata(self._pp_group) + if self._ep_size > 1 and self._ep_group is not None: + self._sync_metadata(self._ep_group) + + self._log_init_summary() + + def process_weights_iterator( + self, + per_tensor_param: Iterator[tuple[str, torch.Tensor]], + ) -> Iterator[tuple[str, torch.Tensor]]: + """Wrap a weight iterator to apply quantization. + + For each ``(hf_name, bf16_weight)`` from the iterator, yields the + quantized weight plus its scaling factors when the parameter is + quantized, or the original tensor unchanged otherwise. 
+ """ + for hf_name, weight in per_tensor_param: + meta = self._resolve_quant_metadata(hf_name) + if meta is None: + yield (hf_name, weight) + else: + yield from self._quantize_weight(hf_name, weight, meta) + + @staticmethod + def _get_mapping_registry(bridge) -> Any: + """Extract the ``MappingRegistry`` from *bridge*, or return ``None``.""" + if bridge is None: + return None + try: + return bridge._model_bridge.mapping_registry() + except Exception as exc: + logger.warning("Failed to get mapping registry from bridge: %s", exc) + return None + + @staticmethod + def _get_model_config(actor_module): + """Return the ``TransformerConfig`` from the first model chunk.""" + try: + from verl.utils.megatron_utils import unwrap_model + + model = unwrap_model(actor_module[0]) + return getattr(model, "config", None) + except Exception: + return None + + @staticmethod + def _count_local_experts(actor_module) -> int: + """Count distinct ``local_experts.`` indices across all model chunks.""" + from verl.utils.megatron_utils import unwrap_model + + indices: set[int] = set() + for module in actor_module: + model = unwrap_model(module) + for name, _ in model.named_modules(): + m = re.search(r"local_experts\.(\d+)", name) + if m: + indices.add(int(m.group(1))) + return max(indices) + 1 if indices else 0 + + def _collect_metadata(self, actor_module: list) -> None: + """Walk all QAT modules and populate ``self._metadata``.""" + from verl.utils.megatron_utils import unwrap_model + + for vpp_idx, module in enumerate(actor_module): + model = unwrap_model(module) + for name, submodule in model.named_modules(): + qformat = get_quantization_format(submodule) + if qformat == QUANTIZATION_NONE: + continue + block_size = get_weight_block_size(submodule) + if block_size == 0: + continue + + w_q = getattr(submodule, "weight_quantizer", None) + i_q = getattr(submodule, "input_quantizer", None) + w_amax = w_q._amax.clone().cpu() if w_q and getattr(w_q, "_amax", None) is not None else None + i_amax 
= i_q._amax.clone().cpu() if i_q and getattr(i_q, "_amax", None) is not None else None + + meta = _QuantMeta( + qformat=qformat, + block_size=block_size, + weight_amax=w_amax, + input_amax=i_amax, + input_quantizer=i_q, + ) + + for pname, _ in submodule.named_parameters(recurse=False): + full_name = f"{name}.{pname}" if name else pname + global_name = self._local_to_global_param_name(full_name, vpp_idx) + self._metadata[global_name] = meta + + def _local_to_global_param_name(self, name: str, vpp_idx: int) -> str: + """Convert a local parameter name to global (PP layers + EP experts).""" + if self._pp_size > 1 and "layers." in name and self._config is not None: + from megatron.bridge.models.conversion.model_bridge import ( + _megatron_local_name_to_global, + ) + + name = _megatron_local_name_to_global(self._actor_module, self._config, name, vpp_idx) + + # SequentialMLP ``local_experts.{idx}`` needs manual global conversion; + # TEGroupedMLP is already handled by ``_megatron_local_name_to_global``. 
+        if self._ep_size > 1 and self._num_local_experts > 0:
+            m = re.search(r"local_experts\.(\d+)\.", name)
+            if m:
+                local_idx = int(m.group(1))
+                global_idx = self._ep_rank * self._num_local_experts + local_idx
+                name = name.replace(
+                    f"local_experts.{local_idx}.",
+                    f"local_experts.{global_idx}.",
+                    1,
+                )
+
+        return name
+
+    def _sync_metadata(self, group) -> None:
+        """Gather and merge metadata across the given process group."""
+        world_size = torch.distributed.get_world_size(group=group)
+
+        local_info = {
+            name: {
+                "qformat": m.qformat,
+                "block_size": m.block_size,
+                "weight_amax": m.weight_amax,
+                "input_amax": m.input_amax,
+            }
+            for name, m in self._metadata.items()
+        }
+
+        gathered: list[dict | None] = [None] * world_size
+        torch.distributed.all_gather_object(gathered, local_info, group=group)
+
+        for rank_info in gathered:
+            if rank_info is None:
+                continue
+            for name, info in rank_info.items():
+                if name in self._metadata:
+                    continue
+                self._metadata[name] = _QuantMeta(
+                    qformat=info["qformat"],
+                    block_size=info["block_size"],
+                    weight_amax=info["weight_amax"],
+                    input_amax=info["input_amax"],
+                    input_quantizer=None,
+                )
+
+    def _resolve_quant_metadata(self, hf_name: str) -> Optional[_QuantMeta]:
+        """Resolve *hf_name* -> Megatron param name -> quantization metadata.
+
+        Returns ``None`` for parameters that are not quantized (norms,
+        embeddings, MoE routers, etc.).
+ """ + if not hf_name.endswith(".weight") or "norm" in hf_name: + return None + + for resolved in _iter_hf_to_megatron_matches(self._registry, hf_name): + meta = self._metadata.get(resolved.megatron_param) + if meta is not None: + return meta + + return None + + def _quantize_weight( + self, + name: str, + weight: torch.Tensor, + meta: _QuantMeta, + ) -> Iterator[tuple[str, torch.Tensor]]: + """Dispatch to the format-specific quantiser.""" + if meta.qformat == QUANTIZATION_NVFP4: + yield from self._quantize_nvfp4(name, weight, meta) + else: + logger.warning("Unsupported qformat %s for %s; passing through", meta.qformat, name) + yield (name, weight) + + def _quantize_nvfp4( + self, + name: str, + weight: torch.Tensor, + meta: _QuantMeta, + ) -> Iterator[tuple[str, torch.Tensor]]: + """NVFP4 two-level quantization. + + Produces up to four tensors: + ``(name, packed_uint8_weight)`` + ``(weight_scale, per_block_fp8_scale)`` + ``(weight_scale_2, global_scale_from_amax)`` + ``(input_scale, activation_scale)`` -- only when available + """ + w_amax = meta.weight_amax.to(weight.device) + w_scale_2 = w_amax.float() / _NVFP4_AMAX_DENOMINATOR + + w_scale = NVFP4QTensor.get_weights_scaling_factor( + weight, + meta.block_size, + weights_scaling_factor_2=w_scale_2.to(weight.device), + )[0] + + quantized = to_quantized_weight(weight, w_scale, meta.qformat, w_scale_2, meta.block_size) + + yield (name, quantized) + yield (_derive_scale_name(name, "weight_scale"), w_scale) + yield (_derive_scale_name(name, "weight_scale_2"), w_scale_2) + + input_scale = _compute_input_scale(meta) + if input_scale is not None: + yield (_derive_scale_name(name, "input_scale"), input_scale) + + def _log_init_summary(self) -> None: + """Log a one-line initialisation summary.""" + rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + logger.info( + "[QAT Exporter][Rank %d] mode=%s, metadata_count=%d, pp=%d/%d, ep=%d/%d", + rank, + self.qat_mode, + len(self._metadata), + 
self._pp_rank, + self._pp_size, + self._ep_rank, + self._ep_size, + ) + + +def _iter_hf_to_megatron_matches(registry, hf_name: str): + """Yield all resolved mappings whose HF pattern matches *hf_name*.""" + for pattern_info, mapping in registry._reverse_patterns: + if isinstance(mapping.hf_param, str): + pattern = pattern_info + if pattern is None: + if mapping.hf_param == hf_name: + yield mapping + else: + match = pattern.match(hf_name) + if match: + yield mapping.resolve(match.groups()) + else: + patterns_dict = pattern_info + for key, pattern in patterns_dict.items(): + if pattern is None: + if mapping.hf_param[key] == hf_name: + yield mapping.resolve(()) + else: + match = pattern.match(hf_name) + if match: + yield mapping.resolve(match.groups()) + + +def _get_parallel_info(kind: str) -> tuple[int, int, Any]: + """Return ``(world_size, rank, process_group)`` for *kind* in {pp, ep}.""" + try: + from megatron.core import parallel_state as mpu + + if kind == "pp": + size = mpu.get_pipeline_model_parallel_world_size() + rank = mpu.get_pipeline_model_parallel_rank() + group = mpu.get_pipeline_model_parallel_group() if size > 1 else None + elif kind == "ep": + size = mpu.get_expert_model_parallel_world_size() + rank = mpu.get_expert_model_parallel_rank() if size > 1 else 0 + group = mpu.get_expert_model_parallel_group() if size > 1 else None + else: + return 1, 0, None + return size, rank, group + except Exception: + return 1, 0, None + + +def _derive_scale_name(weight_name: str, suffix: str) -> str: + """Derive a scale parameter name from a weight parameter name. 
+ + ``"model.layers.0.self_attn.q_proj.weight"`` + -> ``"model.layers.0.self_attn.q_proj.weight_scale"`` + """ + result = weight_name.replace(".weight", f".{suffix}") + return result if result != weight_name else f"{weight_name}_{suffix}" + + +def _compute_input_scale(meta: _QuantMeta) -> Optional[torch.Tensor]: + """Derive the activation scale from the quantizer or synced amax.""" + if meta.input_quantizer is not None: + if hasattr(NVFP4QTensor, "get_activation_scaling_factor"): + return NVFP4QTensor.get_activation_scaling_factor(meta.input_quantizer) + if hasattr(meta.input_quantizer, "_amax") and meta.input_quantizer._amax is not None: + return meta.input_quantizer._amax.float() / _NVFP4_AMAX_DENOMINATOR + + if meta.input_amax is not None: + return meta.input_amax.float() / _NVFP4_AMAX_DENOMINATOR + + return None diff --git a/verl/utils/modelopt/quantize.py b/verl/utils/modelopt/quantize.py index 12fb811d292..20307259249 100644 --- a/verl/utils/modelopt/quantize.py +++ b/verl/utils/modelopt/quantize.py @@ -16,16 +16,11 @@ """ModelOpt NVFP4 quantization config and application for Megatron QAT.""" import logging -from dataclasses import dataclass -from typing import Any, Optional - -import torch -import torch.nn as nn import modelopt.torch.quantization as mtq +import torch.nn as nn from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg - logger = logging.getLogger(__name__) @@ -41,14 +36,7 @@ def _ignore_patterns_to_quant_cfg(ignore_patterns: list[str]) -> dict: - """Convert user-provided ignore patterns to ModelOpt ``quant_cfg`` entries. - - Each pattern is wrapped with ``*`` on both ends (if not already present) - so that it performs glob-style substring matching against module names. - For example, ``"lm_head"`` becomes ``"*lm_head*"`` and ``"mlp.gate."`` - becomes ``"*mlp.gate.*"`` (the trailing dot prevents matching - ``mlp.gate_proj``). 
- """ + """Convert user-provided ignore patterns to ModelOpt ``quant_cfg`` entries.""" cfg = {} for pattern in ignore_patterns: key = pattern @@ -64,17 +52,7 @@ def build_quantize_config( qat_mode: str, ignore_patterns: list[str] | None = None, ) -> dict: - """Build a complete ModelOpt quantization config for ``mtq.quantize``. - - Args: - qat_mode: Quantization mode. Currently only ``"w4a16"`` is supported. - ignore_patterns: Layer name patterns to skip quantization for. - Uses glob-style matching (e.g. ``"lm_head"`` matches ``*lm_head*``). - If *None*, uses :data:`DEFAULT_IGNORE_PATTERNS`. - - Returns: - A config dict suitable for ``mtq.quantize()``. - """ + """Build a complete ModelOpt quantization config for ``mtq.quantize``.""" if qat_mode != "w4a16": raise ValueError(f"Only 'w4a16' is supported, got: {qat_mode}") @@ -98,34 +76,7 @@ def apply_qat( qat_mode: str, ignore_patterns: list[str] | None = None, ) -> nn.Module: - """Apply Quantization-Aware Training to a Megatron model. - - Args: - model: The Megatron model to quantize. - qat_mode: Quantization mode. Currently only ``"w4a16"`` is supported. - ignore_patterns: Layer name patterns to skip quantization for. - If *None*, uses :data:`DEFAULT_IGNORE_PATTERNS`. - - Returns: - The quantized model (modified in-place). 
- """ + """Apply Quantization-Aware Training to a Megatron model.""" config = build_quantize_config(qat_mode, ignore_patterns) mtq.quantize(model, config) return model - - -@dataclass -class QuantizationMetadata: - """Metadata for a quantized module.""" - - qformat: str - weight_quantizer: Any - input_quantizer: Any - module: torch.nn.Module - vpp_idx: int - block_size: int = 16 # Default NVFP4 block size - weight_amax: Optional[torch.Tensor] = None - input_amax: Optional[torch.Tensor] = None - is_local: bool = True - global_expert_idx: Optional[int] = None - local_expert_idx: Optional[int] = None diff --git a/verl/utils/modelopt/vllm_modelopt_patch.py b/verl/utils/modelopt/vllm_modelopt_patch.py index ca7f99abd96..484685ac0a8 100644 --- a/verl/utils/modelopt/vllm_modelopt_patch.py +++ b/verl/utils/modelopt/vllm_modelopt_patch.py @@ -44,6 +44,7 @@ # Utility Functions # ============================================================================ + def save_param_meta(layer: torch.nn.Module, param_name: str): """Save parameter metadata (shape, dtype, param_class, dims) for later rebuild.""" if not hasattr(layer, "_hf_param_meta"): @@ -140,6 +141,7 @@ def _update_ref_or_create(layer, ref_name, new_data): # ModelOptParamMetaDict # ============================================================================ + class ModelOptParamMetaDict(dict): """ Dict-like class for parameter management with metadata-based rebuild @@ -295,13 +297,19 @@ def _modelopt_dense_process_weights(self, layer: torch.nn.Module) -> None: perm = torch.empty(0, dtype=torch.int, device=device) qweight = weight_data.view(torch.int32).T.contiguous() marlin_weight = ops.gptq_marlin_repack( - b_q_weight=qweight, perm=perm, - size_k=part_size_k, size_n=part_size_n, num_bits=4, + b_q_weight=qweight, + perm=perm, + size_k=part_size_k, + size_n=part_size_n, + num_bits=4, ) weight_scale = weight_scale_data.T.contiguous().to(param_dtype) weight_scale = marlin_permute_scales( - s=weight_scale, 
size_k=part_size_k, size_n=part_size_n, group_size=group_size, + s=weight_scale, + size_k=part_size_k, + size_n=part_size_n, + group_size=group_size, ) marlin_weight_scale = nvfp4_marlin_process_scales(weight_scale) marlin_weight_scale_2 = nvfp4_marlin_process_global_scale(weight_scale_2_max.to(param_dtype)) @@ -329,6 +337,7 @@ def _modelopt_dense_process_weights(self, layer: torch.nn.Module) -> None: # MoE Helpers (Marlin) # ============================================================================ + def _marlin_repack_experts(packed, perm, size_k, size_n, num_experts): """Repack weight for each expert into Marlin format and stack.""" import vllm._custom_ops as ops @@ -338,8 +347,11 @@ def _marlin_repack_experts(packed, perm, size_k, size_n, num_experts): qweight = packed[i].view(torch.int32).T.contiguous() result.append( ops.gptq_marlin_repack( - b_q_weight=qweight, perm=perm, - size_k=size_k, size_n=size_n, num_bits=4, + b_q_weight=qweight, + perm=perm, + size_k=size_k, + size_n=size_n, + num_bits=4, ) ) return torch.stack(result) @@ -363,8 +375,14 @@ def _marlin_process_scales_experts(scale_hf, param_dtype, size_k, size_n, group_ # ============================================================================ _MOE_HF_PARAMS = [ - "w13_weight", "w2_weight", "w13_weight_scale", "w2_weight_scale", - "w13_weight_scale_2", "w2_weight_scale_2", "w13_input_scale", "w2_input_scale", + "w13_weight", + "w2_weight", + "w13_weight_scale", + "w2_weight_scale", + "w13_weight_scale_2", + "w2_weight_scale_2", + "w13_input_scale", + "w2_input_scale", ] @@ -393,10 +411,20 @@ def _modelopt_moe_marlin_convert(self, layer: torch.nn.Module, is_first_call: bo # Process scales w13_weight_scale_marlin = _marlin_process_scales_experts( - layer.w13_weight_scale.data, param_dtype, size_k_w13, size_n_w13, group_size, e, + layer.w13_weight_scale.data, + param_dtype, + size_k_w13, + size_n_w13, + group_size, + e, ) w2_weight_scale_marlin = _marlin_process_scales_experts( - 
layer.w2_weight_scale.data, param_dtype, size_k_w2, size_n_w2, group_size, e, + layer.w2_weight_scale.data, + param_dtype, + size_k_w2, + size_n_w2, + group_size, + e, ) # Process global scales (w13_weight_scale_2 is already (E,) after common processing) @@ -412,14 +440,23 @@ def _modelopt_moe_marlin_convert(self, layer: torch.nn.Module, is_first_call: bo layer.w2_weight_scale_2 = Parameter(w2_scale_2_processed, requires_grad=False) if not hasattr(layer, "_marlin_tensor_refs"): layer._marlin_tensor_refs = {} - for rn in ["w13_weight", "w2_weight", "w13_weight_scale", "w2_weight_scale", - "w13_weight_scale_2", "w2_weight_scale_2"]: + for rn in [ + "w13_weight", + "w2_weight", + "w13_weight_scale", + "w2_weight_scale", + "w13_weight_scale_2", + "w2_weight_scale_2", + ]: layer._marlin_tensor_refs[rn] = getattr(layer, rn).data else: for rn, nd in [ - ("w13_weight", w13_weight_marlin), ("w2_weight", w2_weight_marlin), - ("w13_weight_scale", w13_weight_scale_marlin), ("w2_weight_scale", w2_weight_scale_marlin), - ("w13_weight_scale_2", w13_scale_2_processed), ("w2_weight_scale_2", w2_scale_2_processed), + ("w13_weight", w13_weight_marlin), + ("w2_weight", w2_weight_marlin), + ("w13_weight_scale", w13_weight_scale_marlin), + ("w2_weight_scale", w2_weight_scale_marlin), + ("w13_weight_scale_2", w13_scale_2_processed), + ("w2_weight_scale_2", w2_scale_2_processed), ]: _update_ref_or_create(layer, rn, nd) @@ -444,9 +481,7 @@ def _modelopt_moe_process_weights(self, layer: torch.nn.Module) -> None: _save_weight_loaders(layer, _MOE_HF_PARAMS) # ---- w13_weight_scale_2: reduce (E, 2) → (E,) ---- - if self.moe.is_act_and_mul and not torch.allclose( - layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1] - ): + if self.moe.is_act_and_mul and not torch.allclose(layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]): logger.warning("w1_weight_scale_2 must match w3_weight_scale_2. 
Accuracy may be affected.") w13_weight_scale_2 = layer.w13_weight_scale_2.data @@ -463,6 +498,7 @@ def _modelopt_moe_process_weights(self, layer: torch.nn.Module) -> None: # KV Cache Patch # ============================================================================ + def _modelopt_kv_process_weights(self, layer) -> None: """ Replacement for BaseKVCacheMethod.process_weights_after_loading. @@ -614,15 +650,15 @@ def apply_modelopt_nvfp4_patches(): logger.info("Applying ModelOpt NVFP4 patches for dynamic weight loading...") + from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.modelopt import ( - ModelOptNvFp4LinearMethod, ModelOptNvFp4FusedMoE, + ModelOptNvFp4LinearMethod, ) - from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod ModelOptNvFp4LinearMethod.process_weights_after_loading = _modelopt_dense_process_weights ModelOptNvFp4FusedMoE.process_weights_after_loading = _modelopt_moe_process_weights BaseKVCacheMethod.process_weights_after_loading = _modelopt_kv_process_weights _patched = True - logger.info("Applied 3 ModelOpt NVFP4 patches (Dense, MoE, KV)") \ No newline at end of file + logger.info("Applied 3 ModelOpt NVFP4 patches (Dense, MoE, KV)") diff --git a/verl/utils/modelopt/weight_processor.py b/verl/utils/modelopt/weight_processor.py deleted file mode 100644 index 8c126216da8..00000000000 --- a/verl/utils/modelopt/weight_processor.py +++ /dev/null @@ -1,972 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import re -from typing import Any, Iterator, Optional - -import torch - -from modelopt.torch.export.quant_utils import ( - QUANTIZATION_NONE, - QUANTIZATION_NVFP4, - get_quantization_format, - get_weight_block_size, - to_quantized_weight, -) -from modelopt.torch.quantization.qtensor.nvfp4_tensor import NVFP4QTensor - -from verl.utils.megatron_utils import unwrap_model -from verl.utils.modelopt.quantize import QuantizationMetadata - - -class QATWeightPostProcessor: - """ - Post-processor for extracting quantization info from QAT trained modules - and converting bf16 weights to quantized formats (e.g., NVFP4). - - Key Design: - 1. Collect quantization metadata (quantizers, amax, block_size) from QAT modules - 2. Process all_gathered bf16 weights to compute quantized weights and scaling factors - 3. The scaling factors are computed on the merged (all_gathered) weights to ensure - correct block boundaries for per-block quantization (NVFP4) - - Note on TP (Tensor Parallelism): - - For NVFP4, weight_scale_2 (global scale) should ideally be computed from the full - (all_gathered) weight to ensure consistency across TP ranks. - - If use_calibrated_scale_2=True (default), we use the QAT calibrated amax which may - only reflect the local shard's statistics. - - If use_calibrated_scale_2=False, we recompute weight_scale_2 from the merged weight. 
- Note on EP (Expert Parallelism): - - When EP is enabled, each rank only holds a subset of experts (local_experts) - - We synchronize metadata across all EP ranks to ensure complete metadata for all experts - - Local expert indices are converted to global expert indices for proper mapping - """ - - def __init__( - self, - actor_module: list, - qat_mode: str = "w4a16", - dtype: torch.dtype = torch.bfloat16, - use_calibrated_scale_2: bool = False, - ): - """ - Initialize the QAT weight post-processor. - - Args: - actor_module: List of QAT trained model chunks (vpp chunks) - qat_mode: QAT mode, e.g. "w4a16" or "w4a4". - dtype: Original data type (bf16) - use_calibrated_scale_2: If True, use QAT calibrated amax for weight_scale_2. - If False, recompute weight_scale_2 from merged weights. Recommended to set - False when using TP to ensure consistent global scale. - """ - self.actor_module = actor_module - self.qat_mode = qat_mode - self.dtype = dtype - self.use_calibrated_scale_2 = use_calibrated_scale_2 - self.quant_metadata: dict[str, QuantizationMetadata] = {} - self.ep_size, self.ep_rank, self.ep_group = self._get_ep_info() - self.pp_size, self.pp_rank, self.pp_group = self._get_pp_info() - self.num_local_experts = 0 # Will be determined during metadata building - - self._build_quantization_metadata() - - global_rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 - - # Synchronize metadata across EP ranks if EP is enabled - if self.ep_size > 1: - print(f"[QAT PostProcessor][Rank {global_rank}] Starting EP metadata sync...") - self._sync_quantization_metadata_across_ep() - print(f"[QAT PostProcessor][Rank {global_rank}] After EP sync: metadata_count={len(self.quant_metadata)}") - - # Synchronize metadata across PP ranks if PP is enabled - # This ensures all PP ranks have complete metadata for all layers - if self.pp_size > 1: - print(f"[QAT PostProcessor][Rank {global_rank}] Starting PP metadata sync...") - 
self._sync_quantization_metadata_across_pp() - print(f"[QAT PostProcessor][Rank {global_rank}] After PP sync: metadata_count={len(self.quant_metadata)}") - else: - print(f"[QAT PostProcessor][Rank {global_rank}] PP sync skipped: pp_size={self.pp_size}") - - self._log_initialization_info() - - def _get_ep_info(self) -> tuple[int, int, Any]: - """ - Get Expert Parallel information from Megatron parallel state. - - Returns: - (ep_size, ep_rank, ep_group): EP world size, rank, and process group - """ - try: - from megatron.core import parallel_state as mpu - - ep_size = mpu.get_expert_model_parallel_world_size() - if ep_size > 1: - ep_rank = mpu.get_expert_model_parallel_rank() - ep_group = mpu.get_expert_model_parallel_group() - return ep_size, ep_rank, ep_group - except Exception: - # EP not enabled or mpu not available - pass - return 1, 0, None - - def _get_pp_info(self) -> tuple[int, int, Any]: - """ - Get Pipeline Parallel information from Megatron parallel state. - - Returns: - (pp_size, pp_rank, pp_group): PP world size, rank, and process group - """ - try: - from megatron.core import parallel_state as mpu - - pp_size = mpu.get_pipeline_model_parallel_world_size() - pp_rank = mpu.get_pipeline_model_parallel_rank() - pp_group = mpu.get_pipeline_model_parallel_group() - - if torch.distributed.get_rank() == 0: - print(f"[QAT PostProcessor] PP info: pp_size={pp_size}, pp_rank={pp_rank}, pp_group={pp_group}") - - if pp_size > 1: - return pp_size, pp_rank, pp_group - else: - return pp_size, pp_rank, None - except Exception as e: - if torch.distributed.get_rank() == 0: - print(f"[QAT PostProcessor] Warning: Failed to get PP info: {e}") - pass - return 1, 0, None - - def _extract_layer_index(self, name: str) -> Optional[int]: - """ - Extract layer index from parameter name. 
- - For mcore format: decoder.layers.{layer_idx}.xxx - - Returns: - Layer index or None if not a layer parameter - """ - match = re.search(r"layers\.(\d+)\.", name) - if match: - return int(match.group(1)) - return None - - def _get_num_layers_per_pp_stage(self) -> int: - """ - Get the number of layers per PP stage from local metadata. - - This is calculated as max(local_layer_indices) + 1 - """ - max_layer_idx = -1 - for name in self.quant_metadata.keys(): - layer_idx = self._extract_layer_index(name) - if layer_idx is not None and layer_idx > max_layer_idx: - max_layer_idx = layer_idx - return max_layer_idx + 1 if max_layer_idx >= 0 else 0 - - def _convert_local_to_global_layer_name(self, name: str, source_pp_rank: int, num_layers_per_stage: int) -> str: - """ - Convert parameter name from local layer index to global layer index. - - Args: - name: Parameter name with local layer index (e.g., decoder.layers.0.xxx) - source_pp_rank: The PP rank this name came from - num_layers_per_stage: Number of layers per PP stage - - Returns: - Parameter name with global layer index - """ - local_layer_idx = self._extract_layer_index(name) - if local_layer_idx is None: - return name - - global_layer_idx = source_pp_rank * num_layers_per_stage + local_layer_idx - return re.sub(r"layers\.(\d+)\.", f"layers.{global_layer_idx}.", name, count=1) - - def _extract_local_expert_index(self, name: str) -> Optional[int]: - """ - Extract local expert index from parameter name. - - For SequentialMLP structure, the pattern is: - decoder.layers.{layer}.mlp.experts.local_experts.{local_idx}.linear_fc1/fc2.weight - - Args: - name: Parameter name in mcore format - - Returns: - Local expert index or None if not an expert parameter - """ - match = re.search(r"local_experts\.(\d+)\.", name) - if match: - return int(match.group(1)) - return None - - def _local_to_global_expert_index(self, local_idx: int) -> int: - """ - Convert local expert index to global expert index. 
- - Global index = ep_rank * num_local_experts + local_idx - - Args: - local_idx: Local expert index on this EP rank - - Returns: - Global expert index - """ - return self.ep_rank * self.num_local_experts + local_idx - - def _convert_name_to_global_index(self, name: str, local_idx: int, global_idx: int) -> str: - """ - Convert parameter name from local to global expert index. - - Args: - name: Original parameter name with local index - local_idx: Local expert index - global_idx: Global expert index - - Returns: - Parameter name with global expert index - """ - return name.replace(f"local_experts.{local_idx}.", f"local_experts.{global_idx}.") - - def _build_quantization_metadata(self): - """ - Extract quantization metadata from all modules in actor_module. - Stores: {param_name: QuantizationMetadata} - - For EP training with SequentialMLP: - - Detects local expert indices and computes global indices - - Stores metadata with global expert indices as keys - """ - # First pass: collect all local expert indices to determine num_local_experts - local_expert_indices = set() - - for vpp_idx, module in enumerate(self.actor_module): - model = unwrap_model(module) - for name, submodule in model.named_modules(): - local_idx = self._extract_local_expert_index(name) - if local_idx is not None: - local_expert_indices.add(local_idx) - - if local_expert_indices: - self.num_local_experts = max(local_expert_indices) + 1 - if torch.distributed.get_rank() == 0: - print(f"[QAT PostProcessor] Detected {self.num_local_experts} local experts per EP rank") - - # Second pass: build metadata with global indices - for vpp_idx, module in enumerate(self.actor_module): - model = unwrap_model(module) - - for name, submodule in model.named_modules(): - # Check if this module is quantized - qformat = get_quantization_format(submodule) - if qformat == QUANTIZATION_NONE: - continue - - block_size = get_weight_block_size(submodule) - if block_size == 0: - continue - - weight_quantizer = 
getattr(submodule, "weight_quantizer", None) - input_quantizer = getattr(submodule, "input_quantizer", None) - - # Extract amax values for synchronization - weight_amax = None - input_amax = None - if weight_quantizer is not None and hasattr(weight_quantizer, "_amax"): - weight_amax = weight_quantizer._amax.clone().cpu() if weight_quantizer._amax is not None else None - if input_quantizer is not None and hasattr(input_quantizer, "_amax"): - input_amax = input_quantizer._amax.clone().cpu() if input_quantizer._amax is not None else None - - # Determine global expert index for MoE experts - local_expert_idx = self._extract_local_expert_index(name) - global_expert_idx = None - if local_expert_idx is not None and self.ep_size > 1: - global_expert_idx = self._local_to_global_expert_index(local_expert_idx) - - metadata = QuantizationMetadata( - qformat=qformat, - weight_quantizer=weight_quantizer, - input_quantizer=input_quantizer, - module=submodule, - vpp_idx=vpp_idx, - block_size=block_size, - weight_amax=weight_amax, - input_amax=input_amax, - is_local=True, - global_expert_idx=global_expert_idx, - local_expert_idx=local_expert_idx, - ) - - for param_name, _ in submodule.named_parameters(recurse=False): - full_name = f"{name}.{param_name}" if name else param_name - - # For EP training, store with global expert index as key - if local_expert_idx is not None and self.ep_size > 1: - global_name = self._convert_name_to_global_index(full_name, local_expert_idx, global_expert_idx) - self.quant_metadata[global_name] = metadata - else: - self.quant_metadata[full_name] = metadata - - def _log_initialization_info(self): - """Log initialization information for debugging.""" - global_rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 - - print( - f"[QAT PostProcessor][Rank {global_rank}] Initialized with qat_mode: {self.qat_mode}" - ) - print(f"[QAT PostProcessor][Rank {global_rank}] Found {len(self.quant_metadata)} quantized parameters") - if 
self.ep_size > 1: - print( - f"[QAT PostProcessor][Rank {global_rank}] EP enabled: ep_size={self.ep_size}, ep_rank={self.ep_rank}, " - f"num_local_experts={self.num_local_experts}" - ) - if self.pp_size > 1: - local_count = sum(1 for m in self.quant_metadata.values() if m.is_local) - remote_count = sum(1 for m in self.quant_metadata.values() if not m.is_local) - print( - f"[QAT PostProcessor][Rank {global_rank}] PP enabled: pp_size={self.pp_size}, pp_rank={self.pp_rank}, " - f"local_params={local_count}, remote_params={remote_count}" - ) - - # Log all metadata entries for debugging - for name, metadata in self.quant_metadata.items(): - extra_info = "" - if metadata.global_expert_idx is not None: - extra_info = f", global_expert_idx={metadata.global_expert_idx}" - if not metadata.is_local: - extra_info += ", is_local=False" - print( - f"[QAT PostProcessor][Rank {global_rank}] Metadata: {name}, qformat={metadata.qformat}, " - f"block_size={metadata.block_size}{extra_info}" - ) - - def _sync_quantization_metadata_across_ep(self): - """ - Synchronize quantization metadata across all EP (Expert Parallel) ranks. - - When EP is enabled, each rank only holds metadata for its local experts. - This method gathers metadata from all EP ranks and merges them so that - every rank has complete metadata for all experts. 
- - For SequentialMLP structure: - - Local expert indices are converted to global indices - - Metadata is gathered and merged using global indices as keys - - Non-local experts have is_local=False and module/quantizers set to None - """ - if self.ep_size <= 1 or self.ep_group is None: - return - - # Prepare serializable metadata info for all_gather - # We can't send module/quantizer objects, so we extract necessary info - local_metadata_info = {} - for name, metadata in self.quant_metadata.items(): - # Only sync MoE expert metadata (containing "local_experts") - if "local_experts" not in name: - continue - - local_metadata_info[name] = { - "qformat": metadata.qformat, - "block_size": metadata.block_size, - "vpp_idx": metadata.vpp_idx, - "weight_amax": metadata.weight_amax, - "input_amax": metadata.input_amax, - "global_expert_idx": metadata.global_expert_idx, - "local_expert_idx": metadata.local_expert_idx, - } - - # Also send num_local_experts for validation - sync_data = { - "metadata": local_metadata_info, - "num_local_experts": self.num_local_experts, - "ep_rank": self.ep_rank, - } - - # Gather metadata from all EP ranks - all_sync_data = [None] * self.ep_size - torch.distributed.all_gather_object(all_sync_data, sync_data, group=self.ep_group) - - # Validate that all ranks have the same num_local_experts - for rank_idx, data in enumerate(all_sync_data): - if data is not None and data["num_local_experts"] != self.num_local_experts: - print( - f"[QAT PostProcessor] Warning: EP rank {rank_idx} has " - f"{data['num_local_experts']} local experts, expected {self.num_local_experts}" - ) - - # Merge metadata from all ranks - for rank_idx, data in enumerate(all_sync_data): - if rank_idx == self.ep_rank: - # Skip local metadata (already have it) - continue - - if data is None: - continue - - rank_metadata = data["metadata"] - for name, info in rank_metadata.items(): - if name in self.quant_metadata: - # Already have this metadata (shouldn't happen with proper global 
indices) - continue - - # Create metadata entry for non-local experts - # Note: module and quantizers are not available for non-local experts - metadata = QuantizationMetadata( - qformat=info["qformat"], - weight_quantizer=None, # Not available for non-local - input_quantizer=None, # Not available for non-local - module=None, # Not available for non-local - vpp_idx=info["vpp_idx"], - block_size=info["block_size"], - weight_amax=info["weight_amax"], - input_amax=info["input_amax"], - is_local=False, # Mark as non-local - global_expert_idx=info["global_expert_idx"], - local_expert_idx=info["local_expert_idx"], - ) - self.quant_metadata[name] = metadata - - # Count local vs non-local experts - num_local = sum(1 for m in self.quant_metadata.values() if m.is_local and m.global_expert_idx is not None) - num_remote = sum(1 for m in self.quant_metadata.values() if not m.is_local and m.global_expert_idx is not None) - - if torch.distributed.get_rank() == 0: - print( - f"[QAT PostProcessor] EP metadata sync complete. " - f"EP size: {self.ep_size}, Local expert params: {num_local}, " - f"Remote expert params: {num_remote}, Total metadata entries: {len(self.quant_metadata)}" - ) - - def _sync_quantization_metadata_across_pp(self): - """ - Synchronize quantization metadata across all PP (Pipeline Parallel) ranks. - - When PP is enabled, each rank only holds layers for its pipeline stage. - This method gathers metadata from all PP ranks and merges them so that - every rank has complete metadata for all layers. - - IMPORTANT: In Megatron's PP mode, each PP rank uses LOCAL layer indices - (starting from 0), not global layer indices. For example: - - PP rank 0 has decoder.layers.0 (globally layer 0) - - PP rank 1 has decoder.layers.0 (globally layer 1) - - This method converts local layer indices to global layer indices during sync. 
- - For MoE SequentialMLP structure with PP: - - Different PP ranks hold different decoder layers - - Each PP rank builds metadata only for its local layers - - We gather and merge metadata from all PP ranks - - Layer indices are converted from local to global during merge - - Non-local layers have is_local=False and module/quantizers set to None - """ - global_rank = torch.distributed.get_rank() - - print( - f"[QAT PostProcessor][Rank {global_rank}] PP sync starting: " - f"pp_size={self.pp_size}, pp_rank={self.pp_rank}, pp_group={self.pp_group}, " - f"local_metadata_count={len(self.quant_metadata)}" - ) - - if self.pp_size <= 1: - print(f"[QAT PostProcessor][Rank {global_rank}] PP sync skipped: pp_size <= 1") - return - - if self.pp_group is None: - print(f"[QAT PostProcessor][Rank {global_rank}] PP sync skipped: pp_group is None") - return - - # Verify PP group size matches expected pp_size - actual_pp_group_size = torch.distributed.get_world_size(group=self.pp_group) - print( - f"[QAT PostProcessor][Rank {global_rank}] PP group size verification: " - f"expected={self.pp_size}, actual={actual_pp_group_size}" - ) - - # Calculate number of layers per PP stage (needed for global layer index conversion) - num_layers_per_stage = self._get_num_layers_per_pp_stage() - print(f"[QAT PostProcessor][Rank {global_rank}] Detected {num_layers_per_stage} layers per PP stage") - - # First, convert our local metadata to use global layer indices - # This is needed so we can properly merge with other PP ranks - local_metadata_with_global_indices = {} - for name, metadata in self.quant_metadata.items(): - global_name = self._convert_local_to_global_layer_name(name, self.pp_rank, num_layers_per_stage) - local_metadata_with_global_indices[global_name] = metadata - - # Update our metadata dict to use global layer indices - self.quant_metadata = local_metadata_with_global_indices - - # Prepare serializable metadata info for all_gather - # We can't send module/quantizer objects, so we 
extract necessary info - local_metadata_info = {} - for name, metadata in self.quant_metadata.items(): - local_metadata_info[name] = { - "qformat": metadata.qformat, - "block_size": metadata.block_size, - "vpp_idx": metadata.vpp_idx, - "weight_amax": metadata.weight_amax, - "input_amax": metadata.input_amax, - "global_expert_idx": metadata.global_expert_idx, - "local_expert_idx": metadata.local_expert_idx, - "is_local": metadata.is_local, - } - - # Include PP rank info and num_layers_per_stage for global index conversion - sync_data = { - "metadata": local_metadata_info, - "pp_rank": self.pp_rank, - "num_local_experts": self.num_local_experts, - "num_layers_per_stage": num_layers_per_stage, - "global_rank": global_rank, - } - - print( - f"[QAT PostProcessor][Rank {global_rank}] Preparing to sync {len(local_metadata_info)} metadata entries, " - f"sample keys (global indices): {list(local_metadata_info.keys())[:3]}" - ) - - # Gather metadata from all PP ranks - all_sync_data = [None] * actual_pp_group_size - torch.distributed.all_gather_object(all_sync_data, sync_data, group=self.pp_group) - - # Debug: print what we received - print(f"[QAT PostProcessor][Rank {global_rank}] Received data from {len(all_sync_data)} PP ranks") - for i, data in enumerate(all_sync_data): - if data is not None: - sample_keys = list(data.get("metadata", {}).keys())[:2] - print( - f"[QAT PostProcessor][Rank {global_rank}] PP rank {i}: " - f"received from global_rank={data.get('global_rank', 'unknown')}, " - f"pp_rank={data.get('pp_rank', 'unknown')}, " - f"metadata_count={len(data.get('metadata', {}))}, " - f"sample_keys={sample_keys}" - ) - - # Merge metadata from all PP ranks - local_metadata_before = len(self.quant_metadata) - for rank_idx, data in enumerate(all_sync_data): - if data is None: - print(f"[QAT PostProcessor][Rank {global_rank}] Skipping rank_idx={rank_idx}: data is None") - continue - - source_pp_rank = data.get("pp_rank") - - # Skip our own data - compare by pp_rank from 
the data, not by index - if source_pp_rank == self.pp_rank: - print( - f"[QAT PostProcessor][Rank {global_rank}] Skipping rank_idx={rank_idx}: same pp_rank={self.pp_rank}" - ) - continue - - rank_metadata = data["metadata"] - added_count = 0 - skipped_existing = 0 - - for name, info in rank_metadata.items(): - # The name already has global layer indices (converted by the sender) - if name in self.quant_metadata: - # Already have this metadata (shouldn't happen with correct global indices) - existing = self.quant_metadata[name] - if existing.is_local: - skipped_existing += 1 - continue - # If both are non-local, just keep existing - skipped_existing += 1 - continue - - # Create metadata entry for layers from other PP ranks - # Note: module and quantizers are not available for non-local layers - metadata = QuantizationMetadata( - qformat=info["qformat"], - weight_quantizer=None, # Not available for non-local PP rank - input_quantizer=None, # Not available for non-local PP rank - module=None, # Not available for non-local PP rank - vpp_idx=info["vpp_idx"], - block_size=info["block_size"], - weight_amax=info["weight_amax"], - input_amax=info["input_amax"], - is_local=False, # Mark as non-local (from other PP rank) - global_expert_idx=info["global_expert_idx"], - local_expert_idx=info["local_expert_idx"], - ) - self.quant_metadata[name] = metadata - added_count += 1 - - print( - f"[QAT PostProcessor][Rank {global_rank}] From pp_rank={source_pp_rank}: " - f"added {added_count} metadata entries, skipped {skipped_existing} existing" - ) - - # Log statistics - metadata_added = len(self.quant_metadata) - local_metadata_before - local_count = sum(1 for m in self.quant_metadata.values() if m.is_local) - remote_count = sum(1 for m in self.quant_metadata.values() if not m.is_local) - - print( - f"[QAT PostProcessor][Rank {global_rank}] PP metadata sync complete. 
" - f"PP size: {self.pp_size}, PP rank: {self.pp_rank}, " - f"Local params: {local_count}, Remote params: {remote_count}, " - f"Metadata added from other PP ranks: {metadata_added}, " - f"Total metadata entries: {len(self.quant_metadata)}" - ) - - def _find_matching_metadata(self, param_name: str) -> QuantizationMetadata | None: - """ - Find matching quantization metadata for a parameter name. - Handles potential name variations between training and export. - """ - # Direct match - if param_name in self.quant_metadata: - return self.quant_metadata[param_name] - - # Try removing common prefixes/suffixes - variations = [ - param_name, - param_name.replace("module.", ""), - param_name.replace("model.", ""), - ] - - for var in variations: - if var in self.quant_metadata: - return self.quant_metadata[var] - - return None - - def _quantize_weight( - self, - name: str, - weight: torch.Tensor, - metadata: QuantizationMetadata, - ) -> Iterator[tuple[str, torch.Tensor]]: - """ - Quantize a single weight parameter. - - Args: - name: Parameter name - weight: The all_gathered bf16 weight tensor - metadata: Quantization metadata - - Yields: - (param_name, param_tensor) for quantized weight and scaling factors - """ - qformat = metadata.qformat - - if qformat == QUANTIZATION_NVFP4: - yield from self._quantize_nvfp4(name, weight, metadata) - else: - # Unknown format, pass through with warning - print(f"[QAT PostProcessor] Warning: Unknown qformat {qformat} for {name}, passing through") - yield (name, weight) - - def _quantize_nvfp4( - self, - name: str, - weight: torch.Tensor, - metadata: QuantizationMetadata, - ) -> Iterator[tuple[str, torch.Tensor]]: - """ - NVFP4 quantization implementation. - - NVFP4 uses two-level scaling: - - weight_scale_2 (global): per-tensor scale = amax / (6.0 * 448.0) - - weight_scale (per-block): per-block scale in FP8 format - - The weight is packed into uint8 format (2 x FP4 values per byte). 
- - Yields: - (name, quantized_weight): Packed uint8 weight - (name + "_scale", weight_scale): Per-block FP8 scaling factors - (name + "_scale_2", weight_scale_2): Global scaling factor - (name + "_input_scale", input_scale): Input activation scale (if available) - """ - weight_quantizer = metadata.weight_quantizer - input_quantizer = metadata.input_quantizer - block_size = metadata.block_size - qformat = metadata.qformat - - # # Ensure weight is in float for quantization computation - # weight_float = weight.float() - - # Step 1: Compute weight_scale_2 (global scale) - # For TP sharding, we should recompute weight_scale_2 from merged weight - # to ensure consistent global scale across all TP ranks. - if self.use_calibrated_scale_2 and weight_quantizer is not None and hasattr(weight_quantizer, "_amax"): - # Use QAT calibrated amax (may only reflect local shard statistics) - # weight_scale_2 = amax / (6.0 * 448.0) - weight_scale_2 = NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer) - elif metadata.weight_amax is not None: - # Non-local expert (EP): Use synchronized amax from metadata - weight_amax = metadata.weight_amax.to(weight.device) - weight_scale_2 = weight_amax.float() / (6.0 * 448.0) - else: - # Compute from all_gathered weight directly (recommended for TP) - # weight_scale_2 = max(abs(weight)) / (6.0 * 448.0) - weight_scale_2 = NVFP4QTensor.get_weights_scaling_factor_2(weight) - - # Step 2: Compute weight_scale (per-block scale) - # This MUST be computed on the all_gathered (merged) weight to ensure - # correct block boundaries - # weight_scale shape: [out_dim, in_dim / block_size], dtype: float8_e4m3fn - weight_scale = NVFP4QTensor.get_weights_scaling_factor( - weight, - block_size, - weights_scaling_factor_2=weight_scale_2.to(weight.device), - )[0] - - # Step 3: Quantize weight to NVFP4 packed format - quantized_weight = to_quantized_weight( - weight, - weight_scale, - qformat, - weight_scale_2, - block_size, - ) - - # Yield 
quantized weight - yield (name, quantized_weight) - - # Yield scaling factors - # Note: Use consistent naming convention with ModelOpt export - scale_name = name.replace(".weight", ".weight_scale") - if scale_name == name: - scale_name = name + "_scale" - yield (scale_name, weight_scale) - - scale_2_name = name.replace(".weight", ".weight_scale_2") - if scale_2_name == name: - scale_2_name = name + "_scale_2" - yield (scale_2_name, weight_scale_2) - - # Step 4: Export input_scale (activation quantization) if available - if input_quantizer is not None: - input_scale = self._get_input_scale(input_quantizer) - if input_scale is not None: - input_scale_name = name.replace(".weight", ".input_scale") - if input_scale_name == name: - input_scale_name = name + "_input_scale" - yield (input_scale_name, input_scale) - - def _get_input_scale(self, input_quantizer) -> torch.Tensor | None: - """ - Get input activation scaling factor from quantizer. - - Args: - input_quantizer: The input quantizer from the module - - Returns: - Input scaling factor tensor or None - """ - if input_quantizer is None: - return None - - if not hasattr(input_quantizer, "_amax"): - return None - - amax = input_quantizer._amax - if amax is None: - return None - - # For NVFP4, use the NVFP4QTensor method - if hasattr(NVFP4QTensor, "get_activation_scaling_factor"): - return NVFP4QTensor.get_activation_scaling_factor(input_quantizer) - - return amax.float() / (6.0 * 448.0) - - def process_weights_iterator( - self, - per_tensor_param: Iterator[tuple[str, torch.Tensor]], - ) -> Iterator[tuple[str, torch.Tensor]]: - """ - Process an iterator of weights and yield quantized results. - - This method wraps per_tensor_generator output and applies quantization - to each weight, yielding the quantized weights and scaling factors. 
- - Args: - per_tensor_param: Iterator of (name, bf16_weight) from per_tensor_generator - - Yields: - (name, tensor): Quantized weight and associated scaling factors - """ - for name, param in per_tensor_param: - # quantize_single_tensor returns a list of (name, tensor) tuples - # For NVFP4: [(name, quant_weight), (name_scale, scale), (name_scale_2, scale_2), ...] - # For non-quantized: [(name, original_weight)] - quantized_results = self.quantize_single_tensor(name, param) - for q_name, q_tensor in quantized_results: - yield (q_name, q_tensor) - - def quantize_single_tensor( - self, - name: str, - weight: torch.Tensor, - ) -> list[tuple[str, torch.Tensor]]: - """ - Quantize a single tensor and return all related tensors as a list. - - This method is designed to be called AFTER weight_converter.convert_param, - so the name should already be in HF format (e.g., 'model.layers.0.self_attn.q_proj.weight'). - - Args: - name: Parameter name in HF format - weight: Single tensor to quantize - - Returns: - List of (param_name, param_tensor) tuples: - - (name, quantized_weight) - - (name.replace('.weight', '.weight_scale'), weight_scale) # for NVFP4 - - (name.replace('.weight', '.weight_scale_2'), weight_scale_2) # for NVFP4 - """ - # Find matching metadata using the original mcore name pattern - # Since name is now in HF format, we need to check if this layer type should be quantized - metadata = self._find_matching_metadata_by_hf_name(name) - - if metadata is None: - # Not quantized, return original tensor - return [(name, weight)] - - # Quantize this tensor - return list(self._quantize_weight(name, weight, metadata)) - - def _find_matching_metadata_by_hf_name(self, hf_name: str) -> QuantizationMetadata | None: - """ - Find matching quantization metadata for an HF-format parameter name. - - This maps HF names back to the original mcore names to find metadata. 
- E.g., 'model.layers.0.self_attn.q_proj.weight' -> check if qkv layer is quantized - - The mapping logic: - - HF q_proj/k_proj/v_proj.weight -> mcore linear_qkv.weight - - HF o_proj.weight -> mcore linear_proj.weight - - HF gate_proj/up_proj.weight -> mcore linear_fc1.weight - - HF down_proj.weight -> mcore linear_fc2.weight - - MoE experts: model.layers.X.mlp.experts.Y.gate_proj/up_proj/down_proj.weight - - MoE router (gate): model.layers.X.mlp.gate.weight -> NOT quantized (returns None) - """ - - # Only process weight parameters - if not hf_name.endswith(".weight") or hf_name.endswith("._amax") or "norm" in hf_name: - return None - - # Check for MoE router (gate) - should NOT be quantized - # HF formats: model.layers.X.mlp.gate.weight (Qwen) - # model.layers.X.block_sparse_moe.gate.weight (Mixtral) - if self._is_moe_router(hf_name): - return None - - # Extract layer number from HF name - layer_match = re.search(r"layers?\.(\d+)\.", hf_name) - if not layer_match: - # Not a layer parameter (e.g., embed_tokens, lm_head, norm) - # Check for direct matches - return self._find_non_layer_metadata(hf_name) - - layer_num = layer_match.group(1) - - # Determine the mcore module name based on HF name pattern - mcore_patterns = [] - - if "self_attn" in hf_name: - if any(proj in hf_name for proj in ["q_proj", "k_proj", "v_proj"]): - mcore_patterns.append(f"decoder.layers.{layer_num}.self_attention.linear_qkv.weight") - elif "o_proj" in hf_name: - mcore_patterns.append(f"decoder.layers.{layer_num}.self_attention.linear_proj.weight") - elif "mlp" in hf_name: - # Check for MoE experts first - # HF format: model.layers.X.mlp.experts.Y.gate_proj/up_proj/down_proj.weight - # HF Mixtral format: model.layers.X.block_sparse_moe.experts.Y.w1/w2/w3.weight - expert_match = re.search(r"\.experts\.(\d+)\.", hf_name) - if expert_match: - expert_id = expert_match.group(1) # This is the global expert ID in HF format - # MoE expert layers - use global expert ID for SequentialMLP - if any(proj 
in hf_name for proj in ["gate_proj", "up_proj", "w1", "w3"]): - # Try TEGroupedMLP pattern first (all experts share same linear layer) - mcore_patterns.append(f"decoder.layers.{layer_num}.mlp.experts.linear_fc1.weight") - # Try SequentialMLP pattern with global expert index - mcore_patterns.append( - f"decoder.layers.{layer_num}.mlp.experts.local_experts.{expert_id}.linear_fc1.weight" - ) - elif any(proj in hf_name for proj in ["down_proj", "w2"]): - # Try TEGroupedMLP pattern first - mcore_patterns.append(f"decoder.layers.{layer_num}.mlp.experts.linear_fc2.weight") - # Try SequentialMLP pattern with global expert index - mcore_patterns.append( - f"decoder.layers.{layer_num}.mlp.experts.local_experts.{expert_id}.linear_fc2.weight" - ) - # Check for shared_expert (Qwen2 MoE) - elif "shared_expert" in hf_name: - if any(proj in hf_name for proj in ["gate_proj", "up_proj"]): - mcore_patterns.append(f"decoder.layers.{layer_num}.mlp.shared_experts.linear_fc1.weight") - elif "down_proj" in hf_name: - mcore_patterns.append(f"decoder.layers.{layer_num}.mlp.shared_experts.linear_fc2.weight") - else: - # Dense MLP - if any(proj in hf_name for proj in ["gate_proj", "up_proj"]): - mcore_patterns.append(f"decoder.layers.{layer_num}.mlp.linear_fc1.weight") - elif "down_proj" in hf_name: - mcore_patterns.append(f"decoder.layers.{layer_num}.mlp.linear_fc2.weight") - - # Try to find matching metadata - for pattern in mcore_patterns: - if pattern in self.quant_metadata: - return self.quant_metadata[pattern] - - - return None - - def _is_moe_router(self, hf_name: str) -> bool: - """ - Check if the HF parameter name corresponds to a MoE router (gate). - - MoE router should NOT be quantized to maintain routing precision. 
- - Router naming patterns: - - Qwen/Qwen2/Qwen3 MoE: model.layers.X.mlp.gate.weight - - Mixtral: model.layers.X.block_sparse_moe.gate.weight - - Shared expert gate (Qwen2 MoE): model.layers.X.mlp.shared_expert_gate.weight - - Note: gate_proj is NOT the router, it's part of the MLP expert. - """ - - # Pattern 1: Qwen/Qwen3 MoE router - model.layers.X.mlp.gate.weight - # Must be exactly ".mlp.gate.weight" not ".mlp.gate_proj.weight" - if re.search(r"\.mlp\.gate\.weight$", hf_name): - return True - - # Pattern 2: Mixtral router - model.layers.X.block_sparse_moe.gate.weight - if re.search(r"\.block_sparse_moe\.gate\.weight$", hf_name): - return True - - # Pattern 3: Qwen2 MoE shared expert gate - model.layers.X.mlp.shared_expert_gate.weight - if re.search(r"\.mlp\.shared_expert_gate\.weight$", hf_name): - return True - - return False - - def _find_non_layer_metadata(self, hf_name: str) -> QuantizationMetadata | None: - """Find metadata for non-layer parameters (embed_tokens, lm_head, etc.).""" - # Map HF names to mcore names for non-layer parameters - name_mapping = { - "model.embed_tokens.weight": "embedding.word_embeddings.weight", - "lm_head.weight": "output_layer.weight", - "model.norm.weight": "decoder.final_layernorm.weight", - } - - mcore_name = name_mapping.get(hf_name) - if mcore_name and mcore_name in self.quant_metadata: - return self.quant_metadata[mcore_name] - - return None diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py index 452e76af9fe..e99b877b1d7 100644 --- a/verl/workers/megatron_workers.py +++ b/verl/workers/megatron_workers.py @@ -64,6 +64,7 @@ ) from verl.utils.memory_utils import aggressive_empty_cache from verl.utils.model import get_hf_model_path, load_mcore_dist_weights, load_megatron_gptmodel_weights +from verl.utils.modelopt import apply_qat from verl.utils.profiler import ( DistProfiler, DistProfilerExtension, @@ -73,7 +74,6 @@ simple_timer, ) from verl.utils.profiler.performance import reduce_timing, 
topk_reduce_ratio_min_max -from verl.utils.modelopt import apply_qat from verl.utils.ray_utils import get_event_loop from verl.utils.torch_functional import use_original_torch_compile from verl.workers.actor.megatron_actor import MegatronPPOActor @@ -224,14 +224,17 @@ def _init_hf_config_and_tf_config( qat_enabled = self.config.actor.get("qat", {}).get("enable", False) if qat_enabled: from megatron.bridge.models.gpt_provider import quantization_layer_spec + provider.transformer_layer_spec = quantization_layer_spec from verl.utils.modelopt.megatron_qat_patch import apply_qat_patch + apply_qat_patch() from megatron.bridge.models.conversion.param_mapping import AutoMapping - AutoMapping.register_module_type('QuantColumnParallelLinear', 'column') - AutoMapping.register_module_type('QuantRowParallelLinear', 'row') + + AutoMapping.register_module_type("QuantColumnParallelLinear", "column") + AutoMapping.register_module_type("QuantRowParallelLinear", "row") # Apply transformer config overrides for key, value in override_transformer_config.items(): @@ -737,14 +740,14 @@ async def rollout_mode(self): ) qat_config = self.config.actor.get("qat", {}) if qat_config.get("enable", False): - from verl.utils.modelopt import QATWeightPostProcessor + from verl.utils.modelopt import QATWeightExporter qat_mode = qat_config.get("mode", "w4a16") - qat_weight_post_processor = QATWeightPostProcessor( - self.actor.actor_module, qat_mode - ) - per_tensor_param = qat_weight_post_processor.process_weights_iterator(per_tensor_param) - + qat_weight_exporter = QATWeightExporter(self.actor.actor_module, qat_mode, bridge=self.bridge) + # qat_weight_exporter = QATWeightExporter( + # self.actor.actor_module, qat_mode + # ) + per_tensor_param = qat_weight_exporter.process_weights_iterator(per_tensor_param) if self.config.rollout.free_cache_engine: await self.rollout.resume(tags=["weights"]) diff --git a/verl/workers/rollout/vllm_rollout/utils.py b/verl/workers/rollout/vllm_rollout/utils.py index 
df3e74e11b8..d964aef5e1d 100644 --- a/verl/workers/rollout/vllm_rollout/utils.py +++ b/verl/workers/rollout/vllm_rollout/utils.py @@ -30,6 +30,7 @@ from verl.utils.vllm import TensorLoRARequest, VLLMHijack from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader from verl.utils.vllm.vllm_fp8_utils import apply_vllm_fp8_patches, is_fp8_model, load_quanted_weights + logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) @@ -323,15 +324,6 @@ def _update_weights(self, weights: list[tuple[str, torch.Tensor]], peft_config: logger.info("Loading standard weights (non-FP8, async)") self.model_runner.model.load_weights(weights) - if not getattr(self, '_is_modelopt_qat', False): - # Skip per-bucket process_weights_after_loading for modelopt QAT - # because the patched version is not idempotent (swizzle, etc.). - # It will be called once after all buckets in update_weights_from_ipc. - from vllm.model_executor.model_loader.utils import process_weights_after_loading - model_config = self.model_runner.vllm_config.model_config - device = next(self.model_runner.model.parameters()).device - process_weights_after_loading(self.model_runner.model, model_config, device) - def _get_zmq_handle(self) -> str: """Get ZMQ handle for communication.""" if not hasattr(self, "device_uuid") or not self.device_uuid: diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py index 302fb50068f..1571fc6e639 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -21,7 +21,6 @@ from typing import Any, Callable, Optional import numpy as np -from numpy.random import f import ray import vllm.entrypoints.cli.serve from packaging import version @@ -253,7 +252,7 @@ async def launch_server(self, master_address: str = None, master_port: int = Non apply_qat_patches() quantization = "compressed-tensors" - + else: raise 
ValueError(f"Unsupported quant_method: {quant_method}") logger.info(f"QAT quantization config injected (quant_method={quant_method})") diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout.py b/verl/workers/rollout/vllm_rollout/vllm_rollout.py index 8a014c7ecdf..75efb81d892 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_rollout.py +++ b/verl/workers/rollout/vllm_rollout/vllm_rollout.py @@ -198,7 +198,6 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None # transfer volume. # weight = weight.to(dtype, non_blocking=True) - # fill the tensor bucket if offset + weight.nbytes > bucket_size: get_torch_device().synchronize() From 407b8f016e53a0f31acda53da3ed2f2eb773c8a3 Mon Sep 17 00:00:00 2001 From: larkz-nv Date: Fri, 27 Feb 2026 02:23:08 -0800 Subject: [PATCH 10/10] Refactor QAT patch to skip quantizer params --- recipe | 2 +- verl/utils/modelopt/megatron_qat_patch.py | 141 +++++----------------- 2 files changed, 30 insertions(+), 113 deletions(-) diff --git a/recipe b/recipe index 3490a22a0a3..066d77a6333 160000 --- a/recipe +++ b/recipe @@ -1 +1 @@ -Subproject commit 3490a22a0a3adeb7e4787fe70b1060b642efbae4 +Subproject commit 066d77a6333c42f24df5e76d31bfeeda43795af4 diff --git a/verl/utils/modelopt/megatron_qat_patch.py b/verl/utils/modelopt/megatron_qat_patch.py index c5a4f52f5dc..5fdf1d727c1 100644 --- a/verl/utils/modelopt/megatron_qat_patch.py +++ b/verl/utils/modelopt/megatron_qat_patch.py @@ -23,7 +23,7 @@ import gc import logging import re -from typing import Iterable, Optional +from typing import Optional import torch @@ -327,126 +327,43 @@ def revert_local_name_to_global_patch(): logger.info("Reverted QAT patch: _megatron_local_name_to_global.") -def apply_build_conversion_tasks_patch(): - """Patch ``build_conversion_tasks`` to filter out ``None`` entries.""" - import itertools +def apply_skip_quantizer_params_patch(): + """Extend ``_is_adapter_param_name`` to also skip ModelOpt quantizer parameters. 
- import megatron.bridge.models.conversion.model_bridge as bridge_module - from megatron.bridge.models.conversion.model_bridge import ( - MegatronModelBridge, - WeightConversionTask, - ) - from megatron.bridge.models.conversion.utils import ( - get_module_and_param_from_name, - persistent_buffers, - ) - from megatron.bridge.utils.common_utils import print_rank_0 - from megatron.core import parallel_state - from megatron.core.utils import unwrap_model + After ``mtq.quantize()``, quantizer sub-modules (``weight_quantizer``, + ``input_quantizer``) are registered in the model tree. Their internal + parameters (e.g. ``_amax``) have no HF counterpart and must not enter + the Bridge's conversion pipeline. + """ + from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge - if getattr(MegatronModelBridge, "_build_tasks_patched", False): + if getattr(MegatronModelBridge, "_quantizer_filter_patched", False): return - MegatronModelBridge._build_tasks_patched = True - MegatronModelBridge._original_build_conversion_tasks = MegatronModelBridge.build_conversion_tasks - - def _patched_build_conversion_tasks(self, hf_pretrained, megatron_model): - if not (hasattr(hf_pretrained, "state") and hasattr(hf_pretrained.state, "source")): - raise ValueError("hf_pretrained.state.source is required for weight ordering") - - hf_keys: Iterable[str] = hf_pretrained.state.source.get_all_keys() - - mapping_registry = self.mapping_registry() - unwrapped_model = unwrap_model(megatron_model)[0] - model_config = unwrapped_model.config - embeddings_are_tied = self._share_embeddings_and_output_weights(model_config, unwrapped_model) - pp_rank = parallel_state.get_pipeline_model_parallel_rank() - sorted_global_param_names_all_pp_ranks = self._megatron_global_param_names_all_pp_ranks(megatron_model) - - if embeddings_are_tied: - sorted_global_param_names_all_pp_ranks = [ - name for name in sorted_global_param_names_all_pp_ranks if "output_layer" not in name - ] + 
MegatronModelBridge._quantizer_filter_patched = True + MegatronModelBridge._original_is_adapter_param_name = MegatronModelBridge._is_adapter_param_name - global_names_index_dict = {name: idx for idx, name in enumerate(sorted_global_param_names_all_pp_ranks)} + _orig = MegatronModelBridge._is_adapter_param_name - tasks = [None] * len(sorted_global_param_names_all_pp_ranks) - for vp_stage, model in enumerate(megatron_model): - for local_name, _ in itertools.chain(model.named_parameters(), persistent_buffers(model)): - if "_extra_state" in local_name or self._is_adapter_param_name(local_name): - continue + def _patched_is_adapter_param_name(self, param_name: str) -> bool: + if _orig(self, param_name): + return True + return "_quantizer" in param_name - local_name = self._unwrap_name(local_name) - global_name = bridge_module._megatron_local_name_to_global( - megatron_model, model_config, local_name, vp_stage - ) - if global_name not in global_names_index_dict: - print_rank_0(f"WARNING: {global_name} not in global_names_index_dict") - continue - global_name_idx = global_names_index_dict[global_name] - mapping = mapping_registry.megatron_to_hf_lookup(self._get_lora_unwrapped_name(global_name)) - - if not mapping: - logger.warning(f"WARNING: No mapping found for megatron_param: {global_name}") - continue - - if not mapping.allow_hf_name_mismatch: - if isinstance(mapping.hf_param, str): - if mapping.hf_param not in hf_keys: - logger.warning(f"WARNING: Can't find {mapping.hf_param} in hf_keys") - continue - else: - missing_params = [hf_param for hf_param in mapping.hf_param.values() if hf_param not in hf_keys] - if missing_params: - logger.warning( - f"WARNING: Can't find the following HF parameters in hf_keys: {missing_params}" - ) - continue - - local_module, local_weights = get_module_and_param_from_name(megatron_model, local_name, vp_stage) - if local_module is not None and not hasattr(local_module, "config"): - local_module.config = model_config - - 
tasks[global_name_idx] = WeightConversionTask( - pp_rank=pp_rank, - vp_stage=vp_stage, - param_name=local_name, - global_param_name=global_name, - megatron_module=local_module, - param_weight=local_weights, - mapping=mapping, - ) - - for idx, global_name in enumerate(sorted_global_param_names_all_pp_ranks): - if tasks[idx] is None: - mapping = mapping_registry.megatron_to_hf_lookup(self._get_lora_unwrapped_name(global_name)) - if mapping is None: - continue - tasks[idx] = WeightConversionTask( - pp_rank=pp_rank, - vp_stage=None, - param_name=global_name, - global_param_name=global_name, - megatron_module=None, - param_weight=None, - mapping=mapping, - ) - - tasks = [task for task in tasks if task is not None] - return tasks - - MegatronModelBridge.build_conversion_tasks = _patched_build_conversion_tasks - logger.info("Applied QAT patch: MegatronModelBridge.build_conversion_tasks now filters out None entries.") + MegatronModelBridge._is_adapter_param_name = _patched_is_adapter_param_name + logger.info( + "Applied QAT patch: _is_adapter_param_name now also skips ModelOpt quantizer parameters (*_quantizer*)." 
+ ) -def revert_build_conversion_tasks_patch(): - """Revert :func:`apply_build_conversion_tasks_patch`.""" +def revert_skip_quantizer_params_patch(): + """Revert :func:`apply_skip_quantizer_params_patch`.""" from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge - if not getattr(MegatronModelBridge, "_build_tasks_patched", False): + if not getattr(MegatronModelBridge, "_quantizer_filter_patched", False): return - MegatronModelBridge.build_conversion_tasks = MegatronModelBridge._original_build_conversion_tasks - MegatronModelBridge._build_tasks_patched = False - logger.info("Reverted QAT patch: MegatronModelBridge.build_conversion_tasks.") + MegatronModelBridge._is_adapter_param_name = MegatronModelBridge._original_is_adapter_param_name + MegatronModelBridge._quantizer_filter_patched = False + logger.info("Reverted QAT patch: _is_adapter_param_name (quantizer filter).") def apply_detect_parallelism_type_patch(): @@ -493,7 +410,7 @@ def apply_qat_patch(): apply_ep_gather_patch() apply_extract_sort_key_patch() apply_local_name_to_global_patch() - apply_build_conversion_tasks_patch() + apply_skip_quantizer_params_patch() apply_detect_parallelism_type_patch() @@ -503,5 +420,5 @@ def revert_qat_patch(): revert_ep_gather_patch() revert_extract_sort_key_patch() revert_local_name_to_global_patch() - revert_build_conversion_tasks_patch() + revert_skip_quantizer_params_patch() revert_detect_parallelism_type_patch()