Skip to content

Commit fe83dad

Browse files
committed
Refactor QAT weight exporter
1 parent 7a1ebfa commit fe83dad

File tree

13 files changed

+510
-1161
lines changed

13 files changed

+510
-1161
lines changed

verl/trainer/config/_generated_ppo_megatron_trainer.yaml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,16 @@ actor_rollout_ref:
139139
mode: disabled
140140
record_file: null
141141
replay_file: null
142+
qat:
143+
enable: false
144+
mode: w4a16
145+
group_size: 16
146+
ignore_patterns:
147+
- lm_head
148+
- embed_tokens
149+
- re:.*mlp.gate$
150+
activation_observer: static_minmax
151+
quantization_config_path: null
142152
load_weight: true
143153
ref:
144154
rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1}

verl/trainer/config/_generated_ppo_trainer.yaml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -120,13 +120,6 @@ actor_rollout_ref:
120120
mode: disabled
121121
record_file: null
122122
replay_file: null
123-
grad_clip: 1.0
124-
ulysses_sequence_parallel_size: 1
125-
entropy_from_logits_with_chunking: false
126-
entropy_checkpointing: false
127-
use_remove_padding: ${oc.select:actor_rollout_ref.model.use_remove_padding,false}
128-
calculate_sum_pi_squared: false
129-
sum_pi_squared_checkpointing: false
130123
qat:
131124
enable: false
132125
mode: w4a16
@@ -137,6 +130,13 @@ actor_rollout_ref:
137130
- re:.*mlp.gate$
138131
activation_observer: static_minmax
139132
quantization_config_path: null
133+
grad_clip: 1.0
134+
ulysses_sequence_parallel_size: 1
135+
entropy_from_logits_with_chunking: false
136+
entropy_checkpointing: false
137+
use_remove_padding: ${oc.select:actor_rollout_ref.model.use_remove_padding,false}
138+
calculate_sum_pi_squared: false
139+
sum_pi_squared_checkpointing: false
140140
ref:
141141
rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1}
142142
strategy: ${actor_rollout_ref.actor.strategy}

verl/trainer/config/_generated_ppo_veomni_trainer.yaml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,16 @@ actor_rollout_ref:
120120
mode: disabled
121121
record_file: null
122122
replay_file: null
123+
qat:
124+
enable: false
125+
mode: w4a16
126+
group_size: 16
127+
ignore_patterns:
128+
- lm_head
129+
- embed_tokens
130+
- re:.*mlp.gate$
131+
activation_observer: static_minmax
132+
quantization_config_path: null
123133
ref:
124134
rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1}
125135
strategy: veomni

verl/utils/modelopt/__init__.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,12 @@
1515

1616
"""ModelOpt integration for NVFP4 quantization with Megatron QAT training and vLLM inference."""
1717

18+
from verl.utils.modelopt.megatron_qat_patch import (
19+
apply_qat_patch,
20+
revert_qat_patch,
21+
)
22+
from verl.utils.modelopt.qat_weight_exporter import QATWeightExporter
1823
from verl.utils.modelopt.quantize import (
19-
QuantizationMetadata,
2024
apply_qat,
2125
build_quantize_config,
2226
)
@@ -25,18 +29,11 @@
2529
modelopt_process_weights_after_loading,
2630
prepare_modelopt_for_weight_reload,
2731
)
28-
from verl.utils.modelopt.weight_processor import QATWeightPostProcessor
29-
from verl.utils.modelopt.megatron_qat_patch import (
30-
apply_qat_patch,
31-
revert_qat_patch,
32-
)
33-
3432

3533
__all__ = [
3634
"build_quantize_config",
3735
"apply_qat",
38-
"QuantizationMetadata",
39-
"QATWeightPostProcessor",
36+
"QATWeightExporter",
4037
"apply_modelopt_nvfp4_patches",
4138
"prepare_modelopt_for_weight_reload",
4239
"modelopt_process_weights_after_loading",

verl/utils/modelopt/megatron_qat_patch.py

Lines changed: 38 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -44,25 +44,14 @@ def apply_swiglu_sharded_factory_patch():
4444
mlp_module._swiglu_patched = True
4545
mlp_module._original_apply_swiglu_sharded_factory = mlp_module.apply_swiglu_sharded_factory
4646

47-
def patched_apply_swiglu_sharded_factory(
48-
original_sh_ten, sharded_offsets, singleton_local_shards: bool = False
49-
):
47+
def patched_apply_swiglu_sharded_factory(original_sh_ten, sharded_offsets, singleton_local_shards: bool = False):
5048
swiglu_shard_axis = 0
5149
prepend_axis_num = len(sharded_offsets)
5250
original_shape = original_sh_ten.local_shape
5351
local_axis_size = original_shape[swiglu_shard_axis]
54-
assert (
55-
original_sh_ten.global_offset[swiglu_shard_axis + prepend_axis_num]
56-
% local_axis_size
57-
== 0
58-
)
59-
rank_offset = (
60-
original_sh_ten.global_offset[swiglu_shard_axis + prepend_axis_num]
61-
// local_axis_size
62-
)
63-
axis_frag = original_sh_ten.axis_fragmentations[
64-
swiglu_shard_axis + prepend_axis_num
65-
]
52+
assert original_sh_ten.global_offset[swiglu_shard_axis + prepend_axis_num] % local_axis_size == 0
53+
rank_offset = original_sh_ten.global_offset[swiglu_shard_axis + prepend_axis_num] // local_axis_size
54+
axis_frag = original_sh_ten.axis_fragmentations[swiglu_shard_axis + prepend_axis_num]
6655

6756
@torch.no_grad()
6857
def sh_ten_build_fn(
@@ -89,12 +78,20 @@ def sh_ten_build_fn(
8978
tensor_w, tensor_v = torch.chunk(t, 2, dim=swiglu_shard_axis)
9079
return [
9180
ShardedTensor.from_rank_offsets(
92-
w_key, tensor_w, *sharded_offsets, offset_w,
93-
replica_id=replica_id, prepend_axis_num=prepend_axis_num,
81+
w_key,
82+
tensor_w,
83+
*sharded_offsets,
84+
offset_w,
85+
replica_id=replica_id,
86+
prepend_axis_num=prepend_axis_num,
9487
),
9588
ShardedTensor.from_rank_offsets(
96-
v_key, tensor_v, *sharded_offsets, offset_v,
97-
replica_id=replica_id, prepend_axis_num=prepend_axis_num,
89+
v_key,
90+
tensor_v,
91+
*sharded_offsets,
92+
offset_v,
93+
replica_id=replica_id,
94+
prepend_axis_num=prepend_axis_num,
9895
),
9996
]
10097

@@ -104,7 +101,8 @@ def sh_ten_merge_fn(sub_state_dict):
104101
return torch.cat(sub_state_dict)
105102
except (RuntimeError, torch.cuda.OutOfMemoryError) as e:
106103
logger.warning(
107-
"CUDA OOM during tensor merge – falling back to CPU. (Error: %s)", e,
104+
"CUDA OOM during tensor merge – falling back to CPU. (Error: %s)",
105+
e,
108106
)
109107
merged = torch.cat([t.cpu() for t in sub_state_dict])
110108
gc.collect()
@@ -156,9 +154,7 @@ def _patched_gather_from_ep_ranks(
156154
model_config = self._get_config(megatron_module)
157155
num_experts = model_config.num_moe_experts
158156
num_experts_per_rank = num_experts // self.ep_size
159-
num_experts_per_rank = self.broadcast_obj_from_pp_rank(
160-
num_experts_per_rank, "num_experts_per_rank"
161-
)
157+
num_experts_per_rank = self.broadcast_obj_from_pp_rank(num_experts_per_rank, "num_experts_per_rank")
162158

163159
local_expert_number = None
164160

@@ -212,10 +208,7 @@ def _patched_gather_from_ep_ranks(
212208
return weights_dict
213209

214210
MegatronParamMapping.gather_from_ep_ranks = _patched_gather_from_ep_ranks
215-
logger.info(
216-
"Applied QAT patch: MegatronParamMapping.gather_from_ep_ranks "
217-
"now supports SequentialMLP pattern."
218-
)
211+
logger.info("Applied QAT patch: MegatronParamMapping.gather_from_ep_ranks now supports SequentialMLP pattern.")
219212

220213

221214
def revert_ep_gather_patch():
@@ -231,8 +224,8 @@ def revert_ep_gather_patch():
231224

232225
def apply_extract_sort_key_patch():
233226
"""Patch ``extract_sort_key`` to support SequentialMLP naming pattern."""
234-
import megatron.bridge.models.conversion.utils as utils_module
235227
import megatron.bridge.models.conversion.model_bridge as bridge_module
228+
import megatron.bridge.models.conversion.utils as utils_module
236229

237230
if getattr(utils_module, "_sort_key_patched", False):
238231
return
@@ -270,15 +263,13 @@ def _patched_extract_sort_key(param_name: str):
270263

271264
utils_module.extract_sort_key = _patched_extract_sort_key
272265
bridge_module.extract_sort_key = _patched_extract_sort_key
273-
logger.info(
274-
"Applied QAT patch: extract_sort_key now supports SequentialMLP pattern."
275-
)
266+
logger.info("Applied QAT patch: extract_sort_key now supports SequentialMLP pattern.")
276267

277268

278269
def revert_extract_sort_key_patch():
279270
"""Revert :func:`apply_extract_sort_key_patch`."""
280-
import megatron.bridge.models.conversion.utils as utils_module
281271
import megatron.bridge.models.conversion.model_bridge as bridge_module
272+
import megatron.bridge.models.conversion.utils as utils_module
282273

283274
if not getattr(utils_module, "_sort_key_patched", False):
284275
return
@@ -307,11 +298,7 @@ def _patched_megatron_local_name_to_global(models, config, param_name, vp_stage=
307298
param_name = _orig_fn(models, config, param_name, vp_stage)
308299

309300
ep_group = parallel_state.get_expert_model_parallel_group()
310-
if (
311-
".mlp.experts.local_experts." in param_name
312-
and get_pg_size(ep_group) > 1
313-
and ".adapter." not in param_name
314-
):
301+
if ".mlp.experts.local_experts." in param_name and get_pg_size(ep_group) > 1 and ".adapter." not in param_name:
315302
num_experts = config.num_moe_experts
316303
num_experts_per_rank = num_experts // ep_group.size()
317304
local_experts_match = re.search(r"\.local_experts\.(\d+)\.", param_name)
@@ -326,10 +313,7 @@ def _patched_megatron_local_name_to_global(models, config, param_name, vp_stage=
326313
return param_name
327314

328315
bridge_module._megatron_local_name_to_global = _patched_megatron_local_name_to_global
329-
logger.info(
330-
"Applied QAT patch: _megatron_local_name_to_global "
331-
"now supports SequentialMLP pattern."
332-
)
316+
logger.info("Applied QAT patch: _megatron_local_name_to_global now supports SequentialMLP pattern.")
333317

334318

335319
def revert_local_name_to_global_patch():
@@ -363,9 +347,7 @@ def apply_build_conversion_tasks_patch():
363347
if getattr(MegatronModelBridge, "_build_tasks_patched", False):
364348
return
365349
MegatronModelBridge._build_tasks_patched = True
366-
MegatronModelBridge._original_build_conversion_tasks = (
367-
MegatronModelBridge.build_conversion_tasks
368-
)
350+
MegatronModelBridge._original_build_conversion_tasks = MegatronModelBridge.build_conversion_tasks
369351

370352
def _patched_build_conversion_tasks(self, hf_pretrained, megatron_model):
371353
if not (hasattr(hf_pretrained, "state") and hasattr(hf_pretrained.state, "source")):
@@ -378,24 +360,18 @@ def _patched_build_conversion_tasks(self, hf_pretrained, megatron_model):
378360
model_config = unwrapped_model.config
379361
embeddings_are_tied = self._share_embeddings_and_output_weights(model_config, unwrapped_model)
380362
pp_rank = parallel_state.get_pipeline_model_parallel_rank()
381-
sorted_global_param_names_all_pp_ranks = self._megatron_global_param_names_all_pp_ranks(
382-
megatron_model
383-
)
363+
sorted_global_param_names_all_pp_ranks = self._megatron_global_param_names_all_pp_ranks(megatron_model)
384364

385365
if embeddings_are_tied:
386366
sorted_global_param_names_all_pp_ranks = [
387367
name for name in sorted_global_param_names_all_pp_ranks if "output_layer" not in name
388368
]
389369

390-
global_names_index_dict = {
391-
name: idx for idx, name in enumerate(sorted_global_param_names_all_pp_ranks)
392-
}
370+
global_names_index_dict = {name: idx for idx, name in enumerate(sorted_global_param_names_all_pp_ranks)}
393371

394372
tasks = [None] * len(sorted_global_param_names_all_pp_ranks)
395373
for vp_stage, model in enumerate(megatron_model):
396-
for local_name, _ in itertools.chain(
397-
model.named_parameters(), persistent_buffers(model)
398-
):
374+
for local_name, _ in itertools.chain(model.named_parameters(), persistent_buffers(model)):
399375
if "_extra_state" in local_name or self._is_adapter_param_name(local_name):
400376
continue
401377

@@ -407,9 +383,7 @@ def _patched_build_conversion_tasks(self, hf_pretrained, megatron_model):
407383
print_rank_0(f"WARNING: {global_name} not in global_names_index_dict")
408384
continue
409385
global_name_idx = global_names_index_dict[global_name]
410-
mapping = mapping_registry.megatron_to_hf_lookup(
411-
self._get_lora_unwrapped_name(global_name)
412-
)
386+
mapping = mapping_registry.megatron_to_hf_lookup(self._get_lora_unwrapped_name(global_name))
413387

414388
if not mapping:
415389
logger.warning(f"WARNING: No mapping found for megatron_param: {global_name}")
@@ -421,23 +395,16 @@ def _patched_build_conversion_tasks(self, hf_pretrained, megatron_model):
421395
logger.warning(f"WARNING: Can't find {mapping.hf_param} in hf_keys")
422396
continue
423397
else:
424-
missing_params = [
425-
hf_param
426-
for hf_param in mapping.hf_param.values()
427-
if hf_param not in hf_keys
428-
]
398+
missing_params = [hf_param for hf_param in mapping.hf_param.values() if hf_param not in hf_keys]
429399
if missing_params:
430400
logger.warning(
431-
f"WARNING: Can't find the following HF parameters in hf_keys: "
432-
f"{missing_params}"
401+
f"WARNING: Can't find the following HF parameters in hf_keys: {missing_params}"
433402
)
434403
continue
435404

436-
local_module, local_weights = get_module_and_param_from_name(
437-
megatron_model, local_name, vp_stage
438-
)
405+
local_module, local_weights = get_module_and_param_from_name(megatron_model, local_name, vp_stage)
439406
if local_module is not None and not hasattr(local_module, "config"):
440-
setattr(local_module, "config", model_config)
407+
local_module.config = model_config
441408

442409
tasks[global_name_idx] = WeightConversionTask(
443410
pp_rank=pp_rank,
@@ -451,9 +418,7 @@ def _patched_build_conversion_tasks(self, hf_pretrained, megatron_model):
451418

452419
for idx, global_name in enumerate(sorted_global_param_names_all_pp_ranks):
453420
if tasks[idx] is None:
454-
mapping = mapping_registry.megatron_to_hf_lookup(
455-
self._get_lora_unwrapped_name(global_name)
456-
)
421+
mapping = mapping_registry.megatron_to_hf_lookup(self._get_lora_unwrapped_name(global_name))
457422
if mapping is None:
458423
continue
459424
tasks[idx] = WeightConversionTask(
@@ -470,10 +435,7 @@ def _patched_build_conversion_tasks(self, hf_pretrained, megatron_model):
470435
return tasks
471436

472437
MegatronModelBridge.build_conversion_tasks = _patched_build_conversion_tasks
473-
logger.info(
474-
"Applied QAT patch: MegatronModelBridge.build_conversion_tasks "
475-
"now filters out None entries."
476-
)
438+
logger.info("Applied QAT patch: MegatronModelBridge.build_conversion_tasks now filters out None entries.")
477439

478440

479441
def revert_build_conversion_tasks_patch():
@@ -482,9 +444,7 @@ def revert_build_conversion_tasks_patch():
482444

483445
if not getattr(MegatronModelBridge, "_build_tasks_patched", False):
484446
return
485-
MegatronModelBridge.build_conversion_tasks = (
486-
MegatronModelBridge._original_build_conversion_tasks
487-
)
447+
MegatronModelBridge.build_conversion_tasks = MegatronModelBridge._original_build_conversion_tasks
488448
MegatronModelBridge._build_tasks_patched = False
489449
logger.info("Reverted QAT patch: MegatronModelBridge.build_conversion_tasks.")
490450

@@ -503,8 +463,7 @@ def _patched_detect_parallelism_type(self, module):
503463
module_type = type(module).__name__
504464
if "LayerNormColumnParallelLinear" in module_type:
505465
if self.megatron_param and (
506-
self.megatron_param.endswith("layer_norm_weight")
507-
or self.megatron_param.endswith("layer_norm_bias")
466+
self.megatron_param.endswith("layer_norm_weight") or self.megatron_param.endswith("layer_norm_bias")
508467
):
509468
return "replicated"
510469
return "column"

0 commit comments

Comments
 (0)