Skip to content

Commit 6e815f8

Browse files
committed
Refactor modelopt utils and unify QAT config under actor
1 parent 964ccac commit 6e815f8

File tree

12 files changed

+221
-230
lines changed

12 files changed

+221
-230
lines changed

verl/trainer/config/actor/actor.yaml

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,3 +259,35 @@ router_replay:
259259
# Required when mode is 'replay'
260260
replay_file: null
261261

262+
# QAT (Quantization-Aware Training) configuration
263+
# When enabled:
264+
# - QAT is automatically applied to actor model during training
265+
# - Fused scales (QKV/GateUp) are automatically enabled for training-inference consistency
266+
# - Fast quantization is used when syncing weights to vLLM rollout
267+
# Supported modes: "w4a16" (NVFP4 weight-only)
268+
# Note: "w4a4" mode is included in the code but currently has KL divergence issues and is NOT recommended for use.
269+
# For usage examples, see: https://github.com/verl-project/verl-recipe/blob/main/qat/README.md
270+
qat:
271+
272+
# Whether to enable QAT
273+
enable: false
274+
275+
# Quantization mode: "w4a16" (weight-only). "w4a4" is experimental and not recommended.
276+
mode: "w4a16"
277+
278+
# Quantization group size (NVFP4 requires 16)
279+
group_size: 16
280+
281+
# Patterns to ignore (e.g., lm_head, embed_tokens)
282+
ignore_patterns:
283+
284+
- "lm_head"
285+
- "embed_tokens"
286+
- "re:.*mlp.gate$"
287+
288+
# Activation observer for W4A4 mode: "static_minmax", "memoryless_minmax", or "minmax"
289+
activation_observer: "static_minmax"
290+
291+
# Path to vLLM quantization config JSON file
292+
quantization_config_path: null
293+

verl/trainer/config/actor/dp_actor.yaml

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -48,35 +48,3 @@ calculate_sum_pi_squared: False
4848

4949
# Enable gradient checkpointing for sum_pi_squared computation (saves memory)
5050
sum_pi_squared_checkpointing: False
51-
52-
# QAT (Quantization-Aware Training) configuration
53-
# When enabled:
54-
# - QAT is automatically applied to actor model during training
55-
# - Fused scales (QKV/GateUp) are automatically enabled for training-inference consistency
56-
# - Fast quantization is used when syncing weights to vLLM rollout
57-
# Supported modes: "w4a16" (NVFP4 weight-only)
58-
# Note: "w4a4" mode is included in the code but currently has KL divergence issues and is NOT recommended for use.
59-
# For usage examples, see: https://github.com/verl-project/verl-recipe/blob/main/qat/README.md
60-
qat:
61-
62-
# Whether to enable QAT
63-
enable: false
64-
65-
# Quantization mode: "w4a16" (weight-only). "w4a4" is experimental and not recommended.
66-
mode: "w4a16"
67-
68-
# Quantization group size (NVFP4 requires 16)
69-
group_size: 16
70-
71-
# Patterns to ignore (e.g., lm_head, embed_tokens)
72-
ignore_patterns:
73-
74-
- "lm_head"
75-
- "embed_tokens"
76-
- "re:.*mlp.gate$"
77-
78-
# Activation observer for W4A4 mode: "static_minmax", "memoryless_minmax", or "minmax"
79-
activation_observer: "static_minmax"
80-
81-
# Path to vLLM quantization config JSON file
82-
quantization_config_path: null

verl/trainer/config/engine/megatron.yaml

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,6 @@ override_transformer_config:
7979
# Attention backend to use (flash,fused,unfused,local,auto). Defaults to auto in mcore, flash in verl
8080
attention_backend: flash
8181

82-
# Quantization method. None for no quantization, "nvfp4" for NVFP4 quantization
83-
quantization: null
84-
85-
# Whether to enable Quantization-Aware Training (QAT). Default False.
86-
enable_qat: False
87-
8882
override_mcore_model_config: {}
8983

9084
# oc.select: default val for ref.megatron.use_mbridge

verl/utils/modelopt/__init__.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Copyright 2025 Bytedance Ltd. and/or its affiliates
2+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""
17+
ModelOpt integration for verl.
18+
19+
Supports NVFP4 quantization with Megatron QAT training + vLLM low-precision inference.
20+
21+
Module Structure:
22+
- qat.py: QAT quantization config, apply_qat, QuantizationMetadata
23+
- weight_processor.py: QATWeightPostProcessor for converting QAT weights to quantized format
24+
- vllm_patch.py: vLLM monkey patches for NVFP4 inference (Linear, MoE, KV Cache)
25+
26+
Usage:
27+
# Training side
28+
from verl.utils.modelopt import apply_qat, QATWeightPostProcessor
29+
30+
# Inference side
31+
from verl.utils.modelopt import apply_vllm_modelopt_patches
32+
"""
33+
34+
from verl.utils.modelopt.qat import NVFP4_WEIGHT_ONLY_CFG, QuantizationMetadata, apply_qat
35+
from verl.utils.modelopt.vllm_patch import apply_vllm_modelopt_patches
36+
from verl.utils.modelopt.weight_processor import QATWeightPostProcessor
37+
38+
# Public API re-exported at the package root (see module docstring for the
# training-side vs. inference-side split).
__all__ = [
    "NVFP4_WEIGHT_ONLY_CFG",
    "apply_qat",
    "QuantizationMetadata",
    "QATWeightPostProcessor",
    "apply_vllm_modelopt_patches",
]

verl/utils/modelopt/qat.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# Copyright 2025 Bytedance Ltd. and/or its affiliates
2+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
17+
from dataclasses import dataclass
18+
from typing import Any, Optional
19+
20+
import torch
21+
import torch.nn as nn
22+
23+
import modelopt.torch.quantization as mtq
24+
from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg
25+
26+
# ---------------------------------------------------------------------------
27+
# NVFP4 quantization config
28+
# ---------------------------------------------------------------------------
29+
30+
# ModelOpt config for NVFP4 weight-only ("w4a16") QAT:
# - weights quantized with num_bits=(2, 1) (FP4 E2M1) in dynamic 16-element
#   blocks along the last dim, scale_bits=(4, 3) (FP8 E4M3 block scales)
# - activations left unquantized (input quantizer disabled)
# - "max" calibration algorithm for amax collection
NVFP4_WEIGHT_ONLY_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
            "axis": None,
            "enable": True,
        },
        # Weight-only mode: no activation quantization.
        "*input_quantizer": {"enable": False},
        # Disable every other quantizer ModelOpt would otherwise insert.
        **_default_disabled_quantizer_cfg,
    },
    "algorithm": "max",
}
43+
44+
# ---------------------------------------------------------------------------
45+
# QAT application
46+
# ---------------------------------------------------------------------------
47+
48+
49+
def apply_qat(model: nn.Module, qat_mode: str):
    """Apply Quantization-Aware Training to the model.

    Args:
        model: The Megatron model to apply QAT to.
        qat_mode: QAT mode; only "w4a16" (NVFP4 weight-only) is supported.

    Returns:
        The quantized model.

    Raises:
        ValueError: If ``qat_mode`` is anything other than "w4a16".
    """
    if qat_mode == "w4a16":
        # ModelOpt mutates the model in place and also returns it.
        mtq.quantize(model, NVFP4_WEIGHT_ONLY_CFG)
        return model
    raise ValueError(f"Only 'w4a16' is supported, got: {qat_mode}")
64+
65+
66+
@dataclass
class QuantizationMetadata:
    """Metadata recorded for one quantized module.

    Captures the ModelOpt quantizer handles plus enough placement info
    (pipeline stage, expert indices) for later weight post-processing and
    expert-parallel (EP) amax synchronization.
    """

    # Quantization format identifier (e.g. the qformat string used downstream).
    qformat: str
    # ModelOpt weight/input quantizer objects attached to the module.
    weight_quantizer: Any
    input_quantizer: Any
    # The quantized module itself.
    module: torch.nn.Module
    # Index of the (virtual) pipeline stage this module lives in — presumably
    # the vpp chunk index; confirm against the collector that fills this in.
    vpp_idx: int
    block_size: int = 16  # Default NVFP4 block size
    # Fields for EP synchronization - store amax values for non-local experts
    weight_amax: Optional[torch.Tensor] = None
    input_amax: Optional[torch.Tensor] = None
    is_local: bool = True  # Whether this expert is local to current EP rank
    global_expert_idx: Optional[int] = None  # Global expert index for MoE experts
    local_expert_idx: Optional[int] = None  # Local expert index on this EP rank
Lines changed: 1 addition & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -23,72 +23,6 @@
2323
from torch.nn import Parameter
2424

2525

26-
def generate_nvfp4_ignore_list(num_layers: int, is_moe: bool) -> list[str]:
    """
    Build the list of module names excluded from NVFP4 quantization.

    Args:
        num_layers: Number of hidden layers in the model (from hf_config.num_hidden_layers)
        is_moe: Whether the model is a Mixture of Experts model

    Returns:
        List of layer names to ignore during quantization
    """
    # MoE routing (gate) layers are kept unquantized; dense models skip this.
    ignored = [f"model.layers.{idx}.mlp.gate" for idx in range(num_layers)] if is_moe else []
    # The output head is always excluded for stability.
    ignored.append("lm_head")
    return ignored
48-
49-
50-
def get_nvfp4_block_quant_kwargs(num_layers: int, is_moe: bool) -> dict:
    """
    Generate complete NVFP4 quantization configuration based on model properties.

    Args:
        num_layers: Number of hidden layers in the model (from hf_config.num_hidden_layers)
        is_moe: Whether the model is a Mixture of Experts model

    Returns:
        Complete quantization configuration dictionary compatible with ModelOpt
    """
    # Shared 4-bit float scheme for both weights and activations. NOTE: the
    # "false" strings are the literal values the consumer expects here — not
    # Python booleans.
    fp4_scheme = {
        "dynamic": "false",
        "num_bits": 4,
        "type": "float",
        "group_size": 16,
    }
    return {
        "config_groups": {
            "group_0": {
                # Copy the scheme so the two entries are independent dicts.
                "input_activations": dict(fp4_scheme),
                "weights": dict(fp4_scheme),
                "targets": ["Linear"],
            },
        },
        "ignore": generate_nvfp4_ignore_list(num_layers, is_moe),
        "quant_algo": "NVFP4",
        "producer": {"name": "modelopt"},
        "quant_method": "modelopt",
    }
89-
90-
91-
9226
def _create_param_from_subclass_attributes(custom_data: torch.Tensor, custom_weight) -> Parameter:
9327
"""
9428
Helper to preserve custom attributes from ModelWeightParameter and
@@ -838,4 +772,4 @@ def apply_vllm_modelopt_patches():
838772
# Static scales mode: patch process_weights_after_loading to preserve k_scale/v_scale for manual updates
839773
func5_path = "vllm.model_executor.layers.quantization.kv_cache.BaseKVCacheMethod.process_weights_after_loading"
840774
patcher5 = patch(func5_path, process_weights_after_loading_kv)
841-
patcher5.start()
775+
patcher5.start()

0 commit comments

Comments
 (0)