 from tqdm import tqdm

 from verl import DataProto
+from verl.checkpoint_engine import CheckpointEngineManager
 from verl.experimental.dataset.sampler import AbstractCurriculumSampler
 from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup, ResourcePoolManager
 from verl.single_controller.ray.base import create_colocated_worker_cls
@@ -118,8 +119,16 @@ def init_workers(self):
         self._create_worker_classes()
         self._init_worker_groups()
         self._init_models()
+        self._init_reward_loop()
         self._init_async_rollout_manager()

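+        # hand the trainer worker group and the rollout replicas to the checkpoint engine;
+        # based on the update_weights() call further below, this is assumed to push fresh actor weights to rollout replicas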
+        self.checkpoint_manager = CheckpointEngineManager(
+            backend=self.config.actor_rollout_ref.rollout.checkpoint_engine.backend,
+            trainer=self.actor_rollout_wg,
+            replicas=self.async_rollout_manager.rollout_replicas,
+        )
+
+
     def _init_resource_pools(self):
         self.resource_pool_manager.create_resource_pool()
         self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}
@@ -212,6 +221,21 @@ def _init_models(self):
         self.actor_rollout_wg = self.all_wg[str(Role.ActorRollout)]
         self.actor_rollout_wg.init_model()

+    def _init_reward_loop(self):
+        # create reward loop manager
+        if self.use_reward_loop:
+            from verl.experimental.reward_loop import RewardLoopManager
+
+            # initialize reward loop manager
+            # reward model (colocate or standalone): get resource_pool
+            # no reward model: resource_pool = None
+            resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) if self.use_rm else None
+            self.reward_loop_manager = RewardLoopManager(
+                config=self.config,
+                rm_resource_pool=resource_pool,
+            )
+
     def _init_async_rollout_manager(self):
         pass

@@ -247,7 +271,7 @@ def fit(self):

         # perform validation before training
         # currently, we only support validation using the reward_function.
-        if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
+        if self.config.trainer.get("val_before_train", True):
             val_metrics = self._validate()
             assert val_metrics, f"{val_metrics=}"
             pprint(f"Initial validation metrics: {val_metrics}")
@@ -279,6 +303,8 @@ def fit(self):
         for batch_dict in self.train_dataloader:
             self.epoch = epoch
             self.fit_step(batch_dict)
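+            # stop the epoch loop as soon as the final configured training step has run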
+            if self.is_last_step:
+                return

     def fit_step(self, batch_dict: Any = None):
         """
@@ -367,9 +393,6 @@ def _fit_generate(self, batch: DataProto = None) -> DataProto:
         gen_batch_output.meta_info.pop("timing", None)

         if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
-            if self.reward_fn is None:
-                raise ValueError("A reward_fn is required for REMAX advantage estimation.")
-
             with marked_timer("gen_max", timing_raw, color="purple"):
                 gen_baseline_batch = deepcopy(gen_batch)
                 gen_baseline_batch.meta_info["do_sample"] = False
@@ -386,17 +409,16 @@ def _fit_generate(self, batch: DataProto = None) -> DataProto:
                 # compute reward model score on batch
                 rm_scores = None
                 if self.use_rm and "rm_scores" not in batch.batch.keys():
-                    if not self.use_reward_loop:
-                        rm_scores = self.rm_wg.compute_rm_score(batch)
-                    else:
-                        assert self.reward_loop_manager is not None, "RewardLoopManager is None"
-                        rm_scores = self.reward_loop_manager.compute_rm_score(batch)
-                    batch = batch.union(rm_scores)
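+                    # colocated reward-model path; keep the result in rm_scores so the baseline key cleanup below still sees it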
+                    rm_scores = self._compute_reward_colocate(batch)
+                    batch = batch.union(rm_scores)

                 # Compute or extract reward for REMAX baseline
-                reward_baseline_tensor = self._compute_or_extract_reward(
-                    batch, reward_fn=self.reward_fn, sum_reward=True
-                )
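+                # legacy path recomputes the baseline reward via reward_fn; the reward-loop path sums the rm_scores already attached to the batch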
+                if not self.use_reward_loop:
+                    reward_baseline_tensor = self._compute_reward_legacy(
+                        batch, reward_fn=self.reward_fn, sum_reward=True
+                    )
+                else:
+                    reward_baseline_tensor = batch.batch["rm_scores"].sum(dim=-1)

                 keys_to_pop = set(gen_baseline_output.batch.keys())
                 if rm_scores is not None:
@@ -435,22 +457,23 @@ def _fit_compute_reward(self, batch: DataProto) -> DataProto:
         with marked_timer("reward", timing_raw, color="yellow"):
             # compute reward model score
             if self.use_rm and "rm_scores" not in batch.batch.keys():
-                if not self.use_reward_loop:
-                    self.reward_tensor = self.rm_wg.compute_rm_score(batch)
+                batch_reward = self._compute_reward_colocate(batch)
+                batch = batch.union(batch_reward)
+
+            # Compute or extract reward_tensor and reward_extra_infos_dict for training
+            if not self.use_reward_loop:
+                if self.config.reward_model.launch_reward_fn_async:
+                    self.future_reward = compute_reward_async.remote(
+                        data=batch, config=self.config, tokenizer=self.tokenizer
+                    )
                 else:
-                    assert self.reward_loop_manager is not None
-                    self.reward_tensor = self.reward_loop_manager.compute_rm_score(batch)
-                batch = batch.union(self.reward_tensor)
-
-            # Compute or extract reward for training
-            if self.config.reward_model.launch_reward_fn_async:
-                self.future_reward = compute_reward_async.remote(
-                    data=batch, config=self.config, tokenizer=self.tokenizer
-                )
+                    self.reward_tensor, self.reward_extra_infos_dict = self._compute_reward_legacy(
+                        batch, reward_fn=self.reward_fn, reward_for_val=False
+                    )
             else:
-                self.reward_tensor, self.reward_extra_infos_dict = self._compute_or_extract_reward(
-                    batch, reward_fn=self.reward_fn, reward_for_val=False
-                )
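+                # reward-loop path: rm_scores (and any extra-info keys) were attached to the batch earlier, so just read them out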
+                self.reward_tensor = batch.batch["rm_scores"]
+                reward_extra_keys = batch.meta_info.get("reward_extra_keys", [])
+                self.reward_extra_infos_dict = {key: batch.non_tensor_batch[key] for key in reward_extra_keys}
         return batch

     def _fit_compute_log_prob(self, batch: DataProto) -> DataProto:
@@ -620,11 +643,8 @@ def _fit_dump_data(self, batch: DataProto):
     def _fit_validate(self):
         metrics = self.metrics
         timing_raw = self.timing_raw
-        if (
-            self.val_reward_fn is not None
-            and self.config.trainer.test_freq > 0
-            and (self.is_last_step or self.global_steps % self.config.trainer.test_freq == 0)
-        ):
+        if self.config.trainer.test_freq > 0 and (
+            self.is_last_step or self.global_steps % self.config.trainer.test_freq == 0
+        ):
             with marked_timer("testing", timing_raw, color="green"):
                 val_metrics: dict = self._validate()
                 if self.is_last_step:
@@ -652,10 +672,11 @@ def _fit_save_checkpoint(self):
             print("Force saving checkpoint: ESI instance expiration approaching.")
         with marked_timer("save_checkpoint", timing_raw, color="green"):
             # sleep replicas to avoid OOM during checkpoint saving
-            self.checkpoint_manager.sleep_replicas()
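+            # NOTE: replica sleep/wake around checkpoint saving is disabled for now (see the TODO below)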
+            # self.checkpoint_manager.sleep_replicas()
             self._save_checkpoint()
             # wake replicas after checkpoint saving
-            self.checkpoint_manager.update_weights()
+            # TODO: check whether a separate wake/update step is still needed here.
+            # self.checkpoint_manager.update_weights()

     def _fit_stop_profile(self):
         timing_raw = self.timing_raw
@@ -727,4 +748,3 @@ def _fit_postprocess_step(self):
             self.actor_rollout_wg.async_calls_finalize_fn_exec(blocking=True)
             pprint(f"Final validation metrics: {self.last_val_metrics}")
             self.progress_bar.close()
-            return