6060 compute_throughout_metrics ,
6161 compute_timing_metrics ,
6262 compute_variance_proxy_metrics ,
63+ process_validation_metrics ,
6364)
6465from verl .trainer .ppo .ray_trainer import apply_kl_penalty , compute_advantage
6566from verl .trainer .ppo .rollout_corr_helper import compute_rollout_correction_and_add_to_batch
7879from verl .utils .py_functional import rename_dict
7980from verl .utils .seqlen_balancing import calculate_workload , get_seqlen_balanced_partitions , log_seqlen_unbalance
8081from verl .utils .tensordict_utils import list_of_dict_to_tensordict
81- from verl .utils .tracking import Tracking
82+ from verl .utils .tracking import Tracking , ValidationGenerationsLogger
8283from verl .workers .config import CriticConfig
8384from verl .workers .engine_workers import ActorRolloutRefWorker , TrainingWorker , TrainingWorkerConfig
8485from verl .workers .utils .losses import value_loss
@@ -159,7 +160,7 @@ def sample(self, partition_id: str, global_steps: int = None, batch_size: int =
159160 Returns:
160161 KVBatchMeta: A batch of data.
161162 """
162- assert (global_steps or batch_size ) and (not (global_steps and batch_size )), (
163+ assert (global_steps is not None or batch_size ) and (not (global_steps is not None and batch_size )), (
163164 "Either global_steps or batch_size must be specified, but not both."
164165 )
165166
@@ -344,15 +345,15 @@ def generate_sequences(self, prompts: TensorDict) -> None:
344345 """
345346 # mark prompts as pending in replay buffer
346347 global_steps = prompts ["global_steps" ]
347- partition_id = "train" if not prompts . get ( "validate" , False ) else "val"
348+ partition_id = "train" if "validate" not in prompts else "val"
348349 items = {uid : {"global_steps" : global_steps , "status" : "running" } for uid in prompts ["uid" ]}
349350 self .replay_buffer .add (partition_id , items )
350351
351352 chunkes = prompts .chunk (len (self .agent_loop_workers ))
352353 ray .get (
353354 [
354355 worker .generate_sequences .remote (chunk )
355- for worker , chunk in zip (self .agent_loop_workers , chunkes , strict = True )
356+ for worker , chunk in zip (self .agent_loop_workers , chunkes , strict = False )
356357 ]
357358 )
358359
@@ -634,8 +635,131 @@ def _load_checkpoint(self):
    def _save_checkpoint(self):
        """Persist trainer state (model weights, optimizer, dataloader position) to disk.

        Not implemented in this trainer yet; subclasses or a future revision must
        override/implement it. Raises NotImplementedError unconditionally.
        """
        raise NotImplementedError
636637
    def _validate(self) -> dict[str, float]:
        """Run one full pass over the validation dataloader and return scalar metrics.

        For each validation batch: generate rollouts via the agent loop manager,
        collect them back from the replay buffer, optionally score them with a
        colocated reward model, accumulate per-sample data for logging, then clean
        up the transfer queue / replay buffer. Finally logs a sample table, dumps
        generations to disk when configured, and aggregates metrics.

        Returns:
            dict[str, float]: metric name -> value, as produced by
            ``_val_metrics_update`` (keys under "val-core/" and "val-aux/").
        """
        # Lists to collect per-sample data for the logging table and metrics.
        sample_uids = []
        sample_inputs = []
        sample_outputs = []
        sample_gts = []
        sample_scores = []
        sample_turns = []
        data_sources = []
        reward_extra_infos_dict: dict[str, list] = defaultdict(list)

        for batch_dict in self.val_dataloader:
            # 1. put batch to agent loop manager; each prompt gets a fresh UUID so
            # it can be tracked through the replay buffer / transfer queue.
            batch_dict["uid"] = np.array(
                [str(uuid.uuid4()) for _ in range(len(batch_dict["raw_prompt"]))], dtype=object
            )
            batch = tu.get_tensordict(batch_dict)
            tu.assign_non_tensor_data(batch, "global_steps", self.global_steps)
            # "validate" routes these rollouts to the "val" partition downstream.
            tu.assign_non_tensor_data(batch, "validate", True)
            self.agent_loop_manager.generate_sequences(batch)

            # 2. sample the finished rollouts back from the replay buffer.
            batch = self.replay_buffer.sample(partition_id="val", global_steps=self.global_steps)

            # 3. [OPTIONAL] compute reward score with colocated reward model.
            # NOTE(review): replicas are put to sleep to free resources for the
            # reward model, then weights are restored afterwards — order matters.
            if self.reward_loop_manager.reward_loop_worker_handles is None:
                self.checkpoint_manager.sleep_replicas()
                batch = self._compute_reward_colocate(batch)
                self.checkpoint_manager.update_weights()

            # 4. collect necessary data for logging.
            fields = ["uid", "prompts", "responses", "rm_scores", "num_turns", "reward_model", "data_source"]
            data = tq.kv_batch_get(keys=batch.keys, partition_id=batch.partition_id, fields=fields)
            # Pad variable-length sequences so they can be decoded row-wise;
            # pad tokens are dropped again by skip_special_tokens below.
            data["prompts"] = data["prompts"].to_padded_tensor(padding=self.tokenizer.pad_token_id)
            data["responses"] = data["responses"].to_padded_tensor(padding=self.tokenizer.pad_token_id)

            sample_uids.extend(data.pop("uid").tolist())
            output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in data["responses"]]
            sample_outputs.extend(output_texts)
            input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in data["prompts"]]
            sample_inputs.extend(input_texts)
            # Per-sample scalar score: sum of token-level rm scores over the sequence.
            scores = data["rm_scores"].sum(dim=1).tolist()
            sample_scores.extend(scores)
            sample_turns.extend(data.pop("num_turns").tolist())
            reward_extra_infos_dict["reward"].extend(scores)

            # Ground truths are optional; pad with None when absent so all the
            # per-sample lists stay the same length.
            reward_model = data.pop("reward_model", None)
            if reward_model is not None:
                sample_gts.extend([item.get("ground_truth", None) for item in reward_model.tolist()])
            else:
                sample_gts.extend([None] * len(batch))

            data_source = data.pop("data_source", None)
            if data_source is not None:
                data_sources.extend(data_source.tolist())
            else:
                data_sources.extend(["unknown"] * len(batch))

            # 5. cleanup transfer queue and replay buffer for this batch.
            tq.kv_clear(keys=batch.keys, partition_id=batch.partition_id)
            self.replay_buffer.remove(batch.partition_id, batch.keys)

        # Log a sample table to the experiment tracker (wandb/swanlab), if enabled.
        self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores)

        # Dump full generations to a local directory when configured.
        val_data_dir = self.config.trainer.get("validation_data_dir", None)
        if val_data_dir:
            self._dump_generations(
                inputs=sample_inputs,
                outputs=sample_outputs,
                gts=sample_gts,
                scores=sample_scores,
                reward_extra_infos_dict=reward_extra_infos_dict,
                dump_path=val_data_dir,
            )

        return self._val_metrics_update(data_sources, sample_uids, reward_extra_infos_dict, sample_turns)
716+
717+ def _maybe_log_val_generations (self , inputs , outputs , scores ):
718+ """Log a table of validation samples to the configured logger (wandb or swanlab)"""
719+ generations_to_log = self .config .trainer .log_val_generations
720+ if generations_to_log == 0 :
721+ return
722+
723+ # Create tuples of (input, output, score) and sort by input text
724+ samples = list (zip (inputs , outputs , scores , strict = True ))
725+ samples .sort (key = lambda x : x [0 ]) # Sort by input text
726+
727+ # Use fixed random seed for deterministic shuffling
728+ rng = np .random .RandomState (42 )
729+ rng .shuffle (samples )
730+
731+ # Take first N samples after shuffling
732+ samples = samples [:generations_to_log ]
733+
734+ # Log to each configured logger
735+ self .validation_generations_logger .log (self .config .trainer .logger , samples , self .global_steps )
736+
737+ def _val_metrics_update (self , data_sources , sample_uids , reward_extra_infos_dict , sample_turns ) -> dict [str , float ]:
738+ data_src2var2metric2val = process_validation_metrics (data_sources , sample_uids , reward_extra_infos_dict )
739+ metric_dict = {}
740+ for data_source , var2metric2val in data_src2var2metric2val .items ():
741+ core_var = "acc" if "acc" in var2metric2val else "reward"
742+ for var_name , metric2val in var2metric2val .items ():
743+ n_max = max ([int (name .split ("@" )[- 1 ].split ("/" )[0 ]) for name in metric2val .keys ()])
744+ for metric_name , metric_val in metric2val .items ():
745+ if (
746+ (var_name == core_var )
747+ and any (metric_name .startswith (pfx ) for pfx in ["mean" , "maj" , "best" ])
748+ and (f"@{ n_max } " in metric_name )
749+ ):
750+ metric_sec = "val-core"
751+ else :
752+ metric_sec = "val-aux"
753+ pfx = f"{ metric_sec } /{ data_source } /{ var_name } /{ metric_name } "
754+ metric_dict [pfx ] = metric_val
755+
756+ if len (sample_turns ) > 0 :
757+ sample_turns = np .array (sample_turns )
758+ metric_dict ["val-aux/num_turns/min" ] = sample_turns .min ()
759+ metric_dict ["val-aux/num_turns/max" ] = sample_turns .max ()
760+ metric_dict ["val-aux/num_turns/mean" ] = sample_turns .mean ()
761+
762+ return metric_dict
639763
640764 def _start_profiling (self ) -> None :
641765 """Start profiling for all worker groups if profiling is enabled."""
@@ -930,6 +1054,7 @@ def _compute_metrics(self, batch: KVBatchMeta, metrics, timing_raw, global_steps
9301054 "returns" ,
9311055 "rm_scores" ,
9321056 "token_level_rewards" ,
1057+ "num_turns" ,
9331058 ]
9341059 data = tq .kv_batch_get (keys = batch .keys , partition_id = batch .partition_id , fields = fields )
9351060 data = data .to_padded_tensor ()
@@ -948,19 +1073,40 @@ def _compute_metrics(self, batch: KVBatchMeta, metrics, timing_raw, global_steps
9481073 gradient_norm = metrics .get ("actor/grad_norm" , None )
9491074 metrics .update (compute_variance_proxy_metrics (batch = batch , gradient_norm = gradient_norm ))
9501075
1076+ # 3. other auxiliary metrics
1077+ num_turns = np .array (data .pop ("num_turns" ).tolist ())
1078+ metrics .update (
1079+ {
1080+ "training/num_turns/mean" : num_turns .mean (),
1081+ "training/num_turns/max" : num_turns .max (),
1082+ "training/num_turns/min" : num_turns .min (),
1083+ }
1084+ )
1085+
9511086 def fit (self ):
9521087 self .logger = Tracking (
9531088 project_name = self .config .trainer .project_name ,
9541089 experiment_name = self .config .trainer .experiment_name ,
9551090 default_backend = self .config .trainer .logger ,
9561091 config = OmegaConf .to_container (self .config , resolve = True ),
9571092 )
1093+ self .validation_generations_logger = ValidationGenerationsLogger (
1094+ project_name = self .config .trainer .project_name ,
1095+ experiment_name = self .config .trainer .experiment_name ,
1096+ )
9581097
9591098 # load checkpoint and update weights before doing anything
9601099 self ._load_checkpoint ()
9611100 self .checkpoint_manager .update_weights ()
9621101
963- # TODO(wuxibin): validate before train
1102+ # perform validation before training
1103+ if self .config .trainer .get ("val_before_train" , True ):
1104+ val_metrics = self ._validate ()
1105+ assert val_metrics , f"{ val_metrics = } "
1106+ pprint (f"Initial validation metrics: { val_metrics } " )
1107+ self .logger .log (data = val_metrics , step = self .global_steps )
1108+ if self .config .trainer .get ("val_only" , False ):
1109+ return
9641110
9651111 current_epoch = self .global_steps // len (self .train_dataloader )
9661112 progress_bar = tqdm (total = self .total_training_steps , initial = self .global_steps , desc = "Training Progress" )
@@ -981,7 +1127,7 @@ def fit(self):
9811127 is_last_step = self .global_steps >= self .total_training_steps
9821128 metrics , timing_raw = {}, {}
9831129
984- # 1. perform rollout, update critic, and update actor
1130+ # 1. perform rollout and actor/critic training
9851131 self ._start_profiling ()
9861132 with marked_timer ("step" , timing_raw ):
9871133 batch = self .step (batch_dict , metrics , timing_raw )
@@ -1011,14 +1157,13 @@ def fit(self):
10111157 # 5. record metrics
10121158 self ._compute_metrics (batch , metrics , timing_raw , global_steps = self .global_steps , epoch = epoch )
10131159
1014- # remove items from transfer queue and replay buffer
1160+ # 6. cleanup transfer queue and replay buffer
10151161 tq .kv_clear (keys = batch .keys , partition_id = batch .partition_id )
10161162 self .replay_buffer .remove (batch .partition_id , batch .keys )
10171163
10181164 self .logger .log (data = metrics , step = self .global_steps )
10191165 progress_bar .update (1 )
10201166 self .global_steps += 1
1021-
10221167 if is_last_step :
10231168 pprint (f"Final validation metrics: { last_val_metrics } " )
10241169 progress_bar .close ()
0 commit comments