
Commit b740472

fix one step off ci
1 parent 4d00f75 commit b740472

File tree: 6 files changed, +106 / -56 lines

File renamed without changes.

.github/workflows/stash/e2e_one_step_off_policy.yml renamed to .github/workflows/e2e_one_step_off_policy.yml

Lines changed: 6 additions & 6 deletions
@@ -41,21 +41,21 @@ on:
       - main
       - v0.*
     paths:
-      - "**/*.py"
+      - "../../setup.py"
       - "!**/*.md"
       - "!**/*.sh"
       # Other entrypoints
       - "!examples/*trainer*"
       - "!tests/**"
       - "!verl/trainer/main_*.py"
       - "!verl/trainer/fsdp_sft_trainer.py"
-      - "verl/experimental/one_step_off_policy"
+      - "../../verl/experimental/one_step_off_policy"
   pull_request:
     branches:
       - main
       - v0.*
     paths:
-      - "**/*.py"
+      - "../../setup.py"
       - "!**/*.md"
       - "!**/*.sh"
       # Other entrypoints
@@ -64,11 +64,11 @@ on:
       - "!verl/trainer/main_*.py"
       - "!verl/trainer/fsdp_sft_trainer.py"
       # Home
-      - "verl/experimental/one_step_off_policy"
+      - "../../verl/experimental/one_step_off_policy"
       # Entrypoints
       - ".github/workflows/e2e_one_step_off_policy.yml"
-      - "examples/data_preprocess/gsm8k.py"
-      - "tests/special_e2e/run_one_step_off_policy.sh"
+      - "../../examples/data_preprocess/gsm8k.py"
+      - "../../tests/special_e2e/run_one_step_off_policy.sh"

 # Cancel jobs on the same ref if a new one is triggered
 concurrency:

verl/experimental/one_step_off_policy/main_ppo.py

Lines changed: 18 additions & 6 deletions
@@ -170,12 +170,24 @@ def run(self, config):
         processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)

         # Load the reward manager for training and validation.
-        reward_fn = load_reward_manager(
-            config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {})
-        )
-        val_reward_fn = load_reward_manager(
-            config, tokenizer, num_examine=1, **config.reward_model.get("reward_kwargs", {})
-        )
+        use_reward_loop = config.reward_model.use_reward_loop
+        if not use_reward_loop:
+            print(
+                "WARNING: Init reward manager in single controller will be deprecated. "
+                "Please set config.reward_model.use_reward_loop to use distributed reward manager."
+            )
+            # Load the reward manager for training and validation.
+            reward_fn = load_reward_manager(
+                config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {})
+            )
+            val_reward_fn = load_reward_manager(
+                config, tokenizer, num_examine=1, **config.reward_model.get("reward_kwargs", {})
+            )
+        else:
+            # reward_loop will use init a reward loop manager in ray_trainer
+            # and use it to compute reward score
+            reward_fn = None
+            val_reward_fn = None

         resource_pool_manager = create_resource_pool_manager(config, role_worker_mapping.keys())

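The branch added to `run()` above reduces to a single flag check. Below is a minimal sketch of that control flow, assuming only the names visible in the diff (`config.reward_model.use_reward_loop`, `load_reward_manager`); the wrapper function `build_reward_fns` is hypothetical and not part of the commit.

```python
# Illustrative sketch only: the reward-manager selection added to run().
def build_reward_fns(config, tokenizer, load_reward_manager):
    if not config.reward_model.use_reward_loop:
        # Legacy path: reward managers are built in the single-controller process.
        kwargs = config.reward_model.get("reward_kwargs", {})
        reward_fn = load_reward_manager(config, tokenizer, num_examine=0, **kwargs)
        val_reward_fn = load_reward_manager(config, tokenizer, num_examine=1, **kwargs)
        return reward_fn, val_reward_fn
    # Reward-loop path: the trainer builds a RewardLoopManager later,
    # so nothing is constructed here.
    return None, None
```
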
verl/experimental/one_step_off_policy/ray_trainer.py

Lines changed: 24 additions & 6 deletions
@@ -182,12 +182,27 @@ def _init_models(self):
         self._create_weight_sync_group()

     def _init_async_rollout_manager(self):
+
+        # infrastructure overview: https://verl.readthedocs.io/en/latest/advance/reward_loop.html#architecture-design
+        # agent_reward_loop: streaming reward computation with actor rollout
+        # two conditions satisfied: (1) no reward model, or (2) reward model with extra resource pool
+        enable_agent_reward_loop = self.use_reward_loop and (
+            not self.use_rm or self.config.reward_model.enable_resource_pool
+        )
+        # if enable_agent_reward_loop, we directly pass reward_loop_workers to agent loop manager
+        # to stream reward computation with actor rollout
+        self.reward_loop_worker_handles = self.reward_loop_manager.reward_loop_workers if enable_agent_reward_loop else None
+        reward_loop_worker_handles = self.reward_loop_manager.reward_loop_workers if enable_agent_reward_loop else None
+
+
         # create async rollout manager and request scheduler
         assert self.config.actor_rollout_ref.rollout.mode == "async"
         from verl.experimental.one_step_off_policy.agent_loop import OneStepOffAgentLoopManager

         self.async_rollout_mode = True
-        self.async_rollout_manager = OneStepOffAgentLoopManager(config=self.config, worker_group=self.rollout_wg)
+        self.async_rollout_manager = OneStepOffAgentLoopManager(config=self.config,
+                                                                worker_group=self.rollout_wg,
+                                                                reward_loop_worker_handles=reward_loop_worker_handles)

     def _create_weight_sync_group(self):
         from verl.utils.device import get_nccl_backend
@@ -356,7 +371,7 @@ async def fit(self):

         # perform validation before training
         # currently, we only support validation using the reward_function.
-        if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
+        if self.config.trainer.get("val_before_train", True):
             val_metrics = self._validate()
             assert val_metrics, f"{val_metrics=}"
             pprint(f"Initial validation metrics: {val_metrics}")
@@ -390,6 +405,8 @@ async def fit(self):
         batch_data_future = asyncio.create_task(self._async_gen_next_batch(continuous_iterator))
         while batch_data_future is not None:
             batch_data_future = await self.fit_step(batch_data_future, continuous_iterator)
+            if self.is_last_step:
+                return

     async def fit_step(self, batch_data_future, continuous_iterator):
         """
@@ -469,19 +486,20 @@ async def _fit_generate(self, batch_data_future, continuous_iterator):

         # sync weights from actor to rollout
         with marked_timer("sync_rollout_weights", timing_raw, color="purple"):
-            self.sync_rollout_weights()
+            self._fit_update_weights()
             await self.async_rollout_manager.clear_kv_cache()

         # async next generation
         if not self.is_last_step:
             batch_data_future = asyncio.create_task(self._async_gen_next_batch(continuous_iterator))
             await asyncio.sleep(0)
+        else:
+            batch_data_future = None

         return batch, batch_data_future


     def _fit_update_weights(self):
         # TODO: use checkpoint engine to update weight
-        self.sync_rollout_weights()
-
+        # self.sync_rollout_weights()
+        pass

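The `enable_agent_reward_loop` condition above is the gate for streaming reward computation alongside rollout. A minimal sketch of that predicate, using stand-in boolean arguments for the flags in the diff (the function name is hypothetical):

```python
# Illustrative sketch only: when reward computation can be streamed with rollout.
def can_stream_reward(use_reward_loop: bool, use_rm: bool, rm_has_extra_pool: bool) -> bool:
    # Streaming requires the reward loop to be enabled AND either
    # (1) no reward model at all, or
    # (2) a reward model running on its own resource pool, so it does not
    #     contend with the rollout workers.
    return use_reward_loop and (not use_rm or rm_has_extra_pool)


# No RM, or RM with its own pool: streamed. RM sharing the rollout pool: not streamed.
assert can_stream_reward(True, False, False)
assert can_stream_reward(True, True, True)
assert not can_stream_reward(True, True, False)
assert not can_stream_reward(False, False, False)
```
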
verl/trainer/ppo/ray_trainer_for_separation.py

Lines changed: 55 additions & 35 deletions
@@ -32,6 +32,7 @@
 from tqdm import tqdm

 from verl import DataProto
+from verl.checkpoint_engine import CheckpointEngineManager
 from verl.experimental.dataset.sampler import AbstractCurriculumSampler
 from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup, ResourcePoolManager
 from verl.single_controller.ray.base import create_colocated_worker_cls
@@ -118,8 +119,16 @@ def init_workers(self):
         self._create_worker_classes()
         self._init_worker_groups()
         self._init_models()
+        self._init_reward_loop()
         self._init_async_rollout_manager()

+        self.checkpoint_manager = CheckpointEngineManager(
+            backend=self.config.actor_rollout_ref.rollout.checkpoint_engine.backend,
+            trainer=self.actor_rollout_wg,
+            replicas=self.async_rollout_manager.rollout_replicas,
+        )
+
+
     def _init_resource_pools(self):
         self.resource_pool_manager.create_resource_pool()
         self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}
@@ -212,6 +221,21 @@ def _init_models(self):
         self.actor_rollout_wg = self.all_wg[str(Role.ActorRollout)]
         self.actor_rollout_wg.init_model()

+    def _init_reward_loop(self):
+        if self.use_reward_loop:
+            # create reward loop manager
+            if self.use_reward_loop:
+                from verl.experimental.reward_loop import RewardLoopManager
+
+                # initalize reward loop manager
+                # reward model (colocate or standalone): get resource_pool
+                # no reward model: resource_pool = None
+                resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) if self.use_rm else None
+                self.reward_loop_manager = RewardLoopManager(
+                    config=self.config,
+                    rm_resource_pool=resource_pool,
+                )
+
     def _init_async_rollout_manager(self):
         pass
@@ -247,7 +271,7 @@ def fit(self):

         # perform validation before training
         # currently, we only support validation using the reward_function.
-        if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
+        if self.config.trainer.get("val_before_train", True):
             val_metrics = self._validate()
             assert val_metrics, f"{val_metrics=}"
             pprint(f"Initial validation metrics: {val_metrics}")
@@ -279,6 +303,8 @@ def fit(self):
         for batch_dict in self.train_dataloader:
             self.epoch = epoch
             self.fit_step(batch_dict)
+            if self.is_last_step:
+                return

     def fit_step(self, batch_dict: Any = None):
         """
@@ -367,9 +393,6 @@ def _fit_generate(self, batch: DataProto = None) -> DataProto:
             gen_batch_output.meta_info.pop("timing", None)

         if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
-            if self.reward_fn is None:
-                raise ValueError("A reward_fn is required for REMAX advantage estimation.")
-
             with marked_timer("gen_max", timing_raw, color="purple"):
                 gen_baseline_batch = deepcopy(gen_batch)
                 gen_baseline_batch.meta_info["do_sample"] = False
@@ -386,17 +409,16 @@ def _fit_generate(self, batch: DataProto = None) -> DataProto:
                 # compute reward model score on batch
                 rm_scores = None
                 if self.use_rm and "rm_scores" not in batch.batch.keys():
-                    if not self.use_reward_loop:
-                        rm_scores = self.rm_wg.compute_rm_score(batch)
-                    else:
-                        assert self.reward_loop_manager is not None, "RewardLoopManager is None"
-                        rm_scores = self.reward_loop_manager.compute_rm_score(batch)
-                    batch = batch.union(rm_scores)
+                    batch_reward = self._compute_reward_colocate(batch)
+                    batch = batch.union(batch_reward)

                 # Compute or extract reward for REMAX baseline
-                reward_baseline_tensor = self._compute_or_extract_reward(
-                    batch, reward_fn=self.reward_fn, sum_reward=True
-                )
+                if not self.use_reward_loop:
+                    reward_baseline_tensor = self._compute_reward_legacy(
+                        batch, reward_fn=self.reward_fn, sum_reward=True
+                    )
+                else:
+                    reward_baseline_tensor = batch.batch["rm_scores"].sum(dim=-1)

                 keys_to_pop = set(gen_baseline_output.batch.keys())
                 if rm_scores is not None:
@@ -435,22 +457,23 @@ def _fit_compute_reward(self, batch: DataProto) -> DataProto:
         with marked_timer("reward", timing_raw, color="yellow"):
             # compute reward model score
             if self.use_rm and "rm_scores" not in batch.batch.keys():
-                if not self.use_reward_loop:
-                    self.reward_tensor = self.rm_wg.compute_rm_score(batch)
+                batch_reward = self._compute_reward_colocate(batch)
+                batch = batch.union(batch_reward)
+
+            # Compute or extract reward_tensor and reward_extra_infos_dict for training
+            if not self.use_reward_loop:
+                if self.config.reward_model.launch_reward_fn_async:
+                    self.future_reward = compute_reward_async.remote(
+                        data=batch, config=self.config, tokenizer=self.tokenizer
+                    )
                 else:
-                    assert self.reward_loop_manager is not None
-                    self.reward_tensor = self.reward_loop_manager.compute_rm_score(batch)
-                batch = batch.union(self.reward_tensor)
-
-            # Compute or extract reward for training
-            if self.config.reward_model.launch_reward_fn_async:
-                self.future_reward = compute_reward_async.remote(
-                    data=batch, config=self.config, tokenizer=self.tokenizer
-                )
+                    self.reward_tensor, self.reward_extra_infos_dict = self._compute_reward_legacy(
+                        batch, reward_fn=self.reward_fn, reward_for_val=False
+                    )
             else:
-                self.reward_tensor, self.reward_extra_infos_dict = self._compute_or_extract_reward(
-                    batch, reward_fn=self.reward_fn, reward_for_val=False
-                )
+                self.reward_tensor = batch.batch["rm_scores"]
+                reward_extra_keys = batch.meta_info.get("reward_extra_keys", [])
+                self.reward_extra_infos_dict = {key: batch.non_tensor_batch[key] for key in reward_extra_keys}
         return batch

     def _fit_compute_log_prob(self, batch: DataProto) -> DataProto:
@@ -620,11 +643,8 @@ def _fit_dump_data(self, batch: DataProto):
     def _fit_validate(self):
         metrics = self.metrics
         timing_raw = self.timing_raw
-        if (
-            self.val_reward_fn is not None
-            and self.config.trainer.test_freq > 0
-            and (self.is_last_step or self.global_steps % self.config.trainer.test_freq == 0)
-        ):
+        if self.config.trainer.test_freq > 0 and (
+            self.is_last_step or self.global_steps % self.config.trainer.test_freq == 0):
             with marked_timer("testing", timing_raw, color="green"):
                 val_metrics: dict = self._validate()
                 if self.is_last_step:
@@ -652,10 +672,11 @@ def _fit_save_checkpoint(self):
             print("Force saving checkpoint: ESI instance expiration approaching.")
         with marked_timer("save_checkpoint", timing_raw, color="green"):
             # sleep replicas to avoid OOM during checkpoint saving
-            self.checkpoint_manager.sleep_replicas()
+            # self.checkpoint_manager.sleep_replicas()
             self._save_checkpoint()
             # wake replicas to avoid OOM during checkpoint saving
-            self.checkpoint_manager.update_weights()
+            # TODO: Check separation is needed.
+            # self.checkpoint_manager.update_weights()

     def _fit_stop_profile(self):
         timing_raw = self.timing_raw
@@ -727,4 +748,3 @@ def _fit_postprocess_step(self):
         self.actor_rollout_wg.async_calls_finalize_fn_exec(blocking=True)
         pprint(f"Final validation metrics: {self.last_val_metrics}")
         self.progress_bar.close()
-        return

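After this change, `_fit_compute_reward` chooses between three reward paths: the legacy reward_fn launched asynchronously, the legacy reward_fn computed synchronously, and the reward-loop path that only reads scores already attached to the batch. The sketch below mirrors that dispatch with simplified stand-ins; the helper names and the dict-like `batch` are assumptions for illustration, not the trainer's real signatures.

```python
# Illustrative sketch only: the three reward paths selected in _fit_compute_reward.
def compute_training_reward(batch, *, use_reward_loop, launch_async,
                            reward_fn_sync, reward_fn_async):
    """Return (reward_tensor, reward_extra_infos_dict, future_reward)."""
    if not use_reward_loop:
        if launch_async:
            # 1) Legacy reward_fn launched asynchronously: keep the future and
            #    resolve the tensor later in the training step.
            return None, None, reward_fn_async(batch)
        # 2) Legacy reward_fn computed synchronously.
        reward_tensor, extra = reward_fn_sync(batch)
        return reward_tensor, extra, None
    # 3) Reward loop: scores were streamed onto the batch during rollout, so
    #    they only need to be read back out here.
    reward_tensor = batch["rm_scores"]
    extra_keys = batch.get("reward_extra_keys", [])
    extra = {key: batch[key] for key in extra_keys}
    return reward_tensor, extra, None
```
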
verl/workers/rollout/sglang_rollout/async_sglang_server.py

Lines changed: 3 additions & 3 deletions
@@ -266,7 +266,7 @@ async def launch_server(self, master_address: str = None, master_port: int = None):
             "dp_size": self.config.data_parallel_size,
             "ep_size": self.config.expert_parallel_size,
             "node_rank": self.node_rank,
-            "load_format": self.config.load_format,
+            "load_format": "auto",
             "dist_init_addr": dist_init_addr,
             "nnodes": self.nnodes,
             "trust_remote_code": self.model_config.trust_remote_code,
@@ -391,8 +391,8 @@ async def sleep(self):
             logger.info("skip sleep in standalone mode")

     async def clear_kv_cache(self):
-        obj = ReleaseMemoryOccupationReqInput(tags=["kv_cache"])
-        await self.tokenizer_manager.release_memory_occupation(obj, None)
+        if self.node_rank == 0:
+            await self.tokenizer_manager.flush_cache()

     async def generate(
         self,

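Two behavioural notes on this last file: `load_format` is now pinned to `"auto"` when the server is launched, and `clear_kv_cache` now asks the SGLang `tokenizer_manager` to flush its cache from node rank 0 only, rather than releasing KV-cache memory occupation on every node. A minimal sketch of the rank guard, with `flush` standing in for `tokenizer_manager.flush_cache`:

```python
import asyncio

# Illustrative sketch only: run the cache flush from a single node rank,
# mirroring the guard added to clear_kv_cache above.
async def clear_kv_cache(node_rank: int, flush) -> None:
    if node_rank == 0:
        await flush()

async def _demo() -> None:
    async def fake_flush() -> None:
        print("cache flushed")

    await clear_kv_cache(node_rank=0, flush=fake_flush)  # flushes
    await clear_kv_cache(node_rank=1, flush=fake_flush)  # no-op

if __name__ == "__main__":
    asyncio.run(_demo())
```
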