@@ -133,6 +133,12 @@ def __init__(
133133 self ._is_offload_optimizer = self .engine_config .optimizer_offload
134134 self ._is_lora = self .model_config .lora_rank > 0
135135
136+ # QAT (Quantization-Aware Training)
137+ self ._qat_config = getattr (self .engine_config , "qat" , None )
138+ self ._qat_enabled = self ._qat_config is not None and getattr (self ._qat_config , "enable" , False )
139+ if self ._qat_enabled :
140+ logger .info (f"QAT enabled: mode={ self ._qat_config .mode } , group_size={ self ._qat_config .group_size } " )
141+
136142 if self .engine_config .entropy_from_logits_with_chunking :
137143 entropy_from_logits = verl_F .entropy_from_logits_with_chunking
138144 else :
@@ -435,6 +441,58 @@ def _build_lr_scheduler(self, optimizer):
435441 raise NotImplementedError (f"LR scheduler type { lr_scheduler_type } is not supported" )
436442 return lr_scheduler
437443
444+ def _apply_qat (self , module ):
445+ """Apply QAT transformations to the model before FSDP wrapping."""
446+ from verl .utils .qat .core import apply_qat , enable_qat_fuse
447+
448+ module = apply_qat (
449+ module ,
450+ {
451+ "enable" : self ._qat_config .enable ,
452+ "mode" : self ._qat_config .mode ,
453+ "group_size" : self ._qat_config .group_size ,
454+ "ignore_patterns" : list (self ._qat_config .ignore_patterns ),
455+ "activation_observer" : self ._qat_config .activation_observer ,
456+ },
457+ )
458+ enable_qat_fuse (module )
459+
460+ if self ._qat_config .mode == "w4a4" :
461+ self ._restore_w4a4_input_scales (module , self .model_config .local_path )
462+
463+ return module
464+
465+ def _restore_w4a4_input_scales (self , model , model_path ):
466+ """Restore input_global_scale and input_amax from checkpoint for W4A4 mode."""
467+ import glob
468+
469+ from safetensors import safe_open
470+
471+ safetensor_files = glob .glob (f"{ model_path } /model*.safetensors" )
472+ loaded_count = 0
473+
474+ for sf_path in safetensor_files :
475+ with safe_open (sf_path , framework = "pt" ) as f :
476+ for key in f .keys ():
477+ if "input_global_scale" in key :
478+ module_path = key .replace (".input_global_scale" , "" )
479+ amax_key = f"{ module_path } .input_amax"
480+
481+ module = model
482+ for part in module_path .split ("." ):
483+ module = getattr (module , part )
484+
485+ scale_val = f .get_tensor (key )
486+ val = scale_val .item () if scale_val .numel () == 1 else scale_val .max ().item ()
487+ module .input_global_scale .fill_ (val )
488+
489+ amax_val = f .get_tensor (amax_key )
490+ amax = amax_val .item () if amax_val .numel () == 1 else amax_val .max ().item ()
491+ module .input_amax .fill_ (amax )
492+ loaded_count += 1
493+
494+ logger .info (f"[QAT W4A4] Restored { loaded_count } input_global_scale/input_amax from { model_path } " )
495+
438496 def _build_model_optimizer (self ):
439497 from verl .utils .model import print_model_size
440498
@@ -444,6 +502,10 @@ def _build_model_optimizer(self):
444502 if self ._is_lora :
445503 module = self ._build_lora_module (module )
446504
505+ # Apply QAT before FSDP wrapping (training only)
506+ if self ._qat_enabled and not self .engine_config .forward_only :
507+ module = self ._apply_qat (module )
508+
447509 # Synchronize all distributed processes before proceeding
448510 torch .distributed .barrier ()
449511 if self .rank == 0 :
@@ -567,6 +629,12 @@ def optimizer_step(self):
567629 self .optimizer .zero_grad ()
568630 else :
569631 self .optimizer .step ()
632+
633+ if self ._qat_enabled :
634+ from verl .utils .qat .core import invalidate_all_scales
635+
636+ invalidate_all_scales (self .module )
637+
570638 return grad_norm .item ()
571639
572640 def lr_scheduler_step (self ):
@@ -699,8 +767,29 @@ def get_per_tensor_param(self, layered_summon=False, base_sync_done=False, **kwa
699767 )
700768 for name , param in params .items ()
701769 )
702- # return per_tensor_param, peft_config
703- # Convert peft_config to dict for vLLM compatibility (PEFTHelper.from_dict expects dict)
770+
771+ if self ._qat_enabled :
772+ from verl .utils .qat .quantizer import QATQuantizer
773+ from verl .utils .torch_dtypes import PrecisionType
774+
775+ mixed_precision_config = self .engine_config .mixed_precision
776+ if mixed_precision_config is not None :
777+ param_dtype = PrecisionType .to_dtype (mixed_precision_config .get ("param_dtype" , "bf16" ))
778+ else :
779+ param_dtype = torch .bfloat16
780+
781+ quantizer = QATQuantizer (
782+ mode = self ._qat_config .mode ,
783+ group_size = self ._qat_config .group_size ,
784+ ignore_patterns = list (self ._qat_config .ignore_patterns ),
785+ device = torch .device (get_device_id ()),
786+ param_dtype = param_dtype ,
787+ )
788+ per_tensor_param = quantizer .quantize_with_fusion (
789+ per_tensor_param ,
790+ target_device = torch .device ("cpu" ),
791+ )
792+
704793 peft_config_dict = peft_config .to_dict () if peft_config is not None else None
705794 return per_tensor_param , peft_config_dict
706795
0 commit comments