Skip to content

Commit 0442d6f

Browse files
committed
feat(qat): support QAT in FSDPEngine for the new unified engine_workers architecture
1 parent 6f4942b commit 0442d6f

File tree

9 files changed

+183
-9
lines changed

9 files changed

+183
-9
lines changed

verl/trainer/config/_generated_ppo_megatron_trainer.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ actor_rollout_ref:
324324
quantization: null
325325
quantization_config_file: null
326326
mtp: ${oc.select:actor_rollout_ref.model.mtp, null}
327-
qat: ${oc.select:actor_rollout_ref.actor.qat,null}
327+
qat: ${oc.select:actor_rollout_ref.actor.fsdp_config.qat,null}
328328
layer_name_map:
329329
qkv_layer_name: qkv
330330
gate_proj_layer_name: gate_up

verl/trainer/config/_generated_ppo_torchtitan_trainer.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ actor_rollout_ref:
313313
quantization: null
314314
quantization_config_file: null
315315
mtp: ${oc.select:actor_rollout_ref.model.mtp, null}
316-
qat: ${oc.select:actor_rollout_ref.actor.qat,null}
316+
qat: ${oc.select:actor_rollout_ref.actor.fsdp_config.qat,null}
317317
layered_summon: false
318318
model:
319319
_target_: verl.workers.config.HFModelConfig

verl/trainer/config/_generated_ppo_trainer.yaml

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,17 @@ actor_rollout_ref:
4545
forward_only: false
4646
strategy: fsdp
4747
dtype: bfloat16
48+
qat:
49+
_target_: verl.workers.config.QATEngineConfig
50+
enable: false
51+
mode: w4a16
52+
group_size: 16
53+
ignore_patterns:
54+
- lm_head
55+
- embed_tokens
56+
- re:.*mlp.gate$
57+
activation_observer: static_minmax
58+
quantization_config_path: null
4859
_target_: verl.workers.config.FSDPActorConfig
4960
rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1}
5061
strategy: fsdp
@@ -196,6 +207,17 @@ actor_rollout_ref:
196207
forward_only: true
197208
strategy: fsdp
198209
dtype: bfloat16
210+
qat:
211+
_target_: verl.workers.config.QATEngineConfig
212+
enable: false
213+
mode: w4a16
214+
group_size: 16
215+
ignore_patterns:
216+
- lm_head
217+
- embed_tokens
218+
- re:.*mlp.gate$
219+
activation_observer: static_minmax
220+
quantization_config_path: null
199221
_target_: verl.workers.config.FSDPActorConfig
200222
ulysses_sequence_parallel_size: ${oc.select:actor_rollout_ref.actor.ulysses_sequence_parallel_size,1}
201223
entropy_from_logits_with_chunking: false
@@ -312,7 +334,7 @@ actor_rollout_ref:
312334
quantization: null
313335
quantization_config_file: null
314336
mtp: ${oc.select:actor_rollout_ref.model.mtp, null}
315-
qat: ${oc.select:actor_rollout_ref.actor.qat,null}
337+
qat: ${oc.select:actor_rollout_ref.actor.fsdp_config.qat,null}
316338
layered_summon: false
317339
model:
318340
_target_: verl.workers.config.HFModelConfig
@@ -436,6 +458,17 @@ critic:
436458
forward_only: false
437459
strategy: fsdp
438460
dtype: bfloat16
461+
qat:
462+
_target_: verl.workers.config.QATEngineConfig
463+
enable: false
464+
mode: w4a16
465+
group_size: 16
466+
ignore_patterns:
467+
- lm_head
468+
- embed_tokens
469+
- re:.*mlp.gate$
470+
activation_observer: static_minmax
471+
quantization_config_path: null
439472
path: ~/models/deepseek-llm-7b-chat
440473
tokenizer_path: ${oc.select:actor_rollout_ref.model.path,"~/models/deepseek-llm-7b-chat"}
441474
override_config: {}

verl/trainer/config/_generated_ppo_veomni_trainer.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ actor_rollout_ref:
294294
quantization: null
295295
quantization_config_file: null
296296
mtp: ${oc.select:actor_rollout_ref.model.mtp, null}
297-
qat: ${oc.select:actor_rollout_ref.actor.qat,null}
297+
qat: ${oc.select:actor_rollout_ref.actor.fsdp_config.qat,null}
298298
layered_summon: false
299299
model:
300300
_target_: verl.workers.config.HFModelConfig

verl/trainer/config/engine/fsdp.yaml

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,31 @@ strategy: fsdp
6161

6262
# Mixed precision training param dtype
6363
dtype: bfloat16 # ["bfloat16", "float16"]
64+
65+
# QAT (Quantization-Aware Training) configuration
66+
qat:
67+
68+
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
69+
_target_: verl.workers.config.QATEngineConfig
70+
71+
# Whether to enable QAT
72+
enable: false
73+
74+
# Quantization mode: "w4a16" (weight-only). "w4a4" is experimental and not recommended.
75+
mode: "w4a16"
76+
77+
# Quantization group size (NVFP4 requires 16)
78+
group_size: 16
79+
80+
# Patterns to ignore (e.g., lm_head, embed_tokens)
81+
ignore_patterns:
82+
83+
- "lm_head"
84+
- "embed_tokens"
85+
- "re:.*mlp.gate$"
86+
87+
# Activation observer for W4A4 mode
88+
activation_observer: "static_minmax"
89+
90+
# Path to vLLM quantization config JSON file
91+
quantization_config_path: null

verl/trainer/config/rollout/rollout.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -387,5 +387,5 @@ quantization_config_file: null
387387
# MTP configuration, reuse model configuration
388388
mtp: ${oc.select:actor_rollout_ref.model.mtp, null}
389389

390-
# QAT configuration (inherited from actor.qat)
391-
qat: ${oc.select:actor_rollout_ref.actor.qat,null}
390+
# QAT configuration (inherited from actor's engine config)
391+
qat: ${oc.select:actor_rollout_ref.actor.fsdp_config.qat,null}

verl/workers/config/engine.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
"VeOmniEngineConfig",
3232
"EngineConfig",
3333
"EngineRouterReplayConfig",
34+
"QATEngineConfig",
3435
]
3536

3637

@@ -177,6 +178,27 @@ def __post_init__(self) -> None:
177178
self.sequence_parallel = False
178179

179180

181+
@dataclass
182+
class QATEngineConfig(BaseConfig):
183+
"""Configuration for QAT (Quantization-Aware Training) within an engine.
184+
185+
Args:
186+
enable (bool): Whether to enable QAT, default False
187+
mode (str): Quantization mode, "w4a16" or "w4a4", default "w4a16"
188+
group_size (int): Group size for blockwise quantization, default 16
189+
ignore_patterns (list[str]): Module name patterns to exclude from quantization
190+
activation_observer (str): Observer strategy for activation global_scale (W4A4 only)
191+
quantization_config_path (Optional[str]): Path to quantization config JSON for vLLM
192+
"""
193+
194+
enable: bool = False
195+
mode: str = "w4a16"
196+
group_size: int = 16
197+
ignore_patterns: list[str] = field(default_factory=lambda: ["lm_head", "embed_tokens", "re:.*mlp.gate$"])
198+
activation_observer: str = "static_minmax"
199+
quantization_config_path: Optional[str] = None
200+
201+
180202
@dataclass
181203
class FSDPEngineConfig(EngineConfig):
182204
"""Configuration for FSDP (Fully Sharded Data Parallel).
@@ -199,6 +221,7 @@ class FSDPEngineConfig(EngineConfig):
199221
debugging.
200222
mixed_precision (Optional[dict[str, Any]]): Mixed precision configuration for FSDP, default None
201223
dtype (str): Mixed precision training param dtype, default "bfloat16"
224+
qat (QATEngineConfig): QAT configuration, default disabled
202225
"""
203226

204227
# ulysses_sequence_parallel_size is mutable for backward compatibility
@@ -218,6 +241,7 @@ class FSDPEngineConfig(EngineConfig):
218241
use_torch_compile: bool = True
219242
entropy_checkpointing: bool = False
220243
strategy: str = "fsdp"
244+
qat: QATEngineConfig = field(default_factory=QATEngineConfig)
221245

222246
def __post_init__(self):
223247
super().__post_init__()

verl/workers/engine/fsdp/transformer_impl.py

Lines changed: 91 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,12 @@ def __init__(
133133
self._is_offload_optimizer = self.engine_config.optimizer_offload
134134
self._is_lora = self.model_config.lora_rank > 0
135135

136+
# QAT (Quantization-Aware Training)
137+
self._qat_config = getattr(self.engine_config, "qat", None)
138+
self._qat_enabled = self._qat_config is not None and getattr(self._qat_config, "enable", False)
139+
if self._qat_enabled:
140+
logger.info(f"QAT enabled: mode={self._qat_config.mode}, group_size={self._qat_config.group_size}")
141+
136142
if self.engine_config.entropy_from_logits_with_chunking:
137143
entropy_from_logits = verl_F.entropy_from_logits_with_chunking
138144
else:
@@ -435,6 +441,58 @@ def _build_lr_scheduler(self, optimizer):
435441
raise NotImplementedError(f"LR scheduler type {lr_scheduler_type} is not supported")
436442
return lr_scheduler
437443

444+
def _apply_qat(self, module):
445+
"""Apply QAT transformations to the model before FSDP wrapping."""
446+
from verl.utils.qat.core import apply_qat, enable_qat_fuse
447+
448+
module = apply_qat(
449+
module,
450+
{
451+
"enable": self._qat_config.enable,
452+
"mode": self._qat_config.mode,
453+
"group_size": self._qat_config.group_size,
454+
"ignore_patterns": list(self._qat_config.ignore_patterns),
455+
"activation_observer": self._qat_config.activation_observer,
456+
},
457+
)
458+
enable_qat_fuse(module)
459+
460+
if self._qat_config.mode == "w4a4":
461+
self._restore_w4a4_input_scales(module, self.model_config.local_path)
462+
463+
return module
464+
465+
def _restore_w4a4_input_scales(self, model, model_path):
466+
"""Restore input_global_scale and input_amax from checkpoint for W4A4 mode."""
467+
import glob
468+
469+
from safetensors import safe_open
470+
471+
safetensor_files = glob.glob(f"{model_path}/model*.safetensors")
472+
loaded_count = 0
473+
474+
for sf_path in safetensor_files:
475+
with safe_open(sf_path, framework="pt") as f:
476+
for key in f.keys():
477+
if "input_global_scale" in key:
478+
module_path = key.replace(".input_global_scale", "")
479+
amax_key = f"{module_path}.input_amax"
480+
481+
module = model
482+
for part in module_path.split("."):
483+
module = getattr(module, part)
484+
485+
scale_val = f.get_tensor(key)
486+
val = scale_val.item() if scale_val.numel() == 1 else scale_val.max().item()
487+
module.input_global_scale.fill_(val)
488+
489+
amax_val = f.get_tensor(amax_key)
490+
amax = amax_val.item() if amax_val.numel() == 1 else amax_val.max().item()
491+
module.input_amax.fill_(amax)
492+
loaded_count += 1
493+
494+
logger.info(f"[QAT W4A4] Restored {loaded_count} input_global_scale/input_amax from {model_path}")
495+
438496
def _build_model_optimizer(self):
439497
from verl.utils.model import print_model_size
440498

@@ -444,6 +502,10 @@ def _build_model_optimizer(self):
444502
if self._is_lora:
445503
module = self._build_lora_module(module)
446504

505+
# Apply QAT before FSDP wrapping (training only)
506+
if self._qat_enabled and not self.engine_config.forward_only:
507+
module = self._apply_qat(module)
508+
447509
# Synchronize all distributed processes before proceeding
448510
torch.distributed.barrier()
449511
if self.rank == 0:
@@ -567,6 +629,12 @@ def optimizer_step(self):
567629
self.optimizer.zero_grad()
568630
else:
569631
self.optimizer.step()
632+
633+
if self._qat_enabled:
634+
from verl.utils.qat.core import invalidate_all_scales
635+
636+
invalidate_all_scales(self.module)
637+
570638
return grad_norm.item()
571639

572640
def lr_scheduler_step(self):
@@ -699,8 +767,29 @@ def get_per_tensor_param(self, layered_summon=False, base_sync_done=False, **kwa
699767
)
700768
for name, param in params.items()
701769
)
702-
# return per_tensor_param, peft_config
703-
# Convert peft_config to dict for vLLM compatibility (PEFTHelper.from_dict expects dict)
770+
771+
if self._qat_enabled:
772+
from verl.utils.qat.quantizer import QATQuantizer
773+
from verl.utils.torch_dtypes import PrecisionType
774+
775+
mixed_precision_config = self.engine_config.mixed_precision
776+
if mixed_precision_config is not None:
777+
param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16"))
778+
else:
779+
param_dtype = torch.bfloat16
780+
781+
quantizer = QATQuantizer(
782+
mode=self._qat_config.mode,
783+
group_size=self._qat_config.group_size,
784+
ignore_patterns=list(self._qat_config.ignore_patterns),
785+
device=torch.device(get_device_id()),
786+
param_dtype=param_dtype,
787+
)
788+
per_tensor_param = quantizer.quantize_with_fusion(
789+
per_tensor_param,
790+
target_device=torch.device("cpu"),
791+
)
792+
704793
peft_config_dict = peft_config.to_dict() if peft_config is not None else None
705794
return per_tensor_param, peft_config_dict
706795

0 commit comments

Comments
 (0)