
Commit 62973bd

Commit message: _validate
Parent: f5e1316

File tree (3 files changed, +159 −11 lines):
  verl/trainer/constants_ppo.py
  verl/trainer/main_ppo_sync.py
  verl/workers/utils/padding.py

verl/trainer/constants_ppo.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -36,6 +36,7 @@
         # https://www.hiascend.com/document/detail/zh/canncommercial/83RC1/maintenref/envvar/envref_07_0143.html
         "HCCL_HOST_SOCKET_PORT_RANGE": "auto",
         "HCCL_NPU_SOCKET_PORT_RANGE": "auto",
+        "TQ_ZERO_COPY_SERIALIZATION": "1",
     },
 }
```
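
Judging by the name alone, the new `TQ_ZERO_COPY_SERIALIZATION` default presumably enables zero-copy serialization for the TransferQueue (the `tq` module exercised by the trainer changes below). This is an inference from the variable prefix; the commit itself does not document it.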

verl/trainer/main_ppo_sync.py

Lines changed: 155 additions & 10 deletions
```diff
@@ -60,6 +60,7 @@
     compute_throughout_metrics,
     compute_timing_metrics,
     compute_variance_proxy_metrics,
+    process_validation_metrics,
 )
 from verl.trainer.ppo.ray_trainer import apply_kl_penalty, compute_advantage
 from verl.trainer.ppo.rollout_corr_helper import compute_rollout_correction_and_add_to_batch
```
```diff
@@ -78,7 +79,7 @@
 from verl.utils.py_functional import rename_dict
 from verl.utils.seqlen_balancing import calculate_workload, get_seqlen_balanced_partitions, log_seqlen_unbalance
 from verl.utils.tensordict_utils import list_of_dict_to_tensordict
-from verl.utils.tracking import Tracking
+from verl.utils.tracking import Tracking, ValidationGenerationsLogger
 from verl.workers.config import CriticConfig
 from verl.workers.engine_workers import ActorRolloutRefWorker, TrainingWorker, TrainingWorkerConfig
 from verl.workers.utils.losses import value_loss
```
```diff
@@ -159,7 +160,7 @@ def sample(self, partition_id: str, global_steps: int = None, batch_size: int =
         Returns:
             KVBatchMeta: A batch of data.
         """
-        assert (global_steps or batch_size) and (not (global_steps and batch_size)), (
+        assert (global_steps is not None or batch_size) and (not (global_steps is not None and batch_size)), (
             "Either global_steps or batch_size must be specified, but not both."
         )

```
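
A note on the `is not None` change: `global_steps=0` is falsy in Python, so the old truthiness test would reject a legitimate step 0 passed on its own. A minimal standalone sketch of the two predicates (hypothetical values, not verl code):

```python
global_steps, batch_size = 0, None  # caller passes only global_steps=0

# old predicate: truthiness conflates 0 with "not provided", so it fails here
old_ok = (global_steps or batch_size) and not (global_steps and batch_size)
print(bool(old_ok))  # False -> the assert would fire

# new predicate: only None counts as "not provided", so step 0 is accepted
new_ok = (global_steps is not None or batch_size) and not (global_steps is not None and batch_size)
print(bool(new_ok))  # True
```
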
```diff
@@ -344,15 +345,15 @@ def generate_sequences(self, prompts: TensorDict) -> None:
         """
         # mark prompts as pending in replay buffer
         global_steps = prompts["global_steps"]
-        partition_id = "train" if not prompts.get("validate", False) else "val"
+        partition_id = "train" if "validate" not in prompts else "val"
         items = {uid: {"global_steps": global_steps, "status": "running"} for uid in prompts["uid"]}
         self.replay_buffer.add(partition_id, items)

         chunkes = prompts.chunk(len(self.agent_loop_workers))
         ray.get(
             [
                 worker.generate_sequences.remote(chunk)
-                for worker, chunk in zip(self.agent_loop_workers, chunkes, strict=True)
+                for worker, chunk in zip(self.agent_loop_workers, chunkes, strict=False)
             ]
         )

```
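
Two behavior changes land in this hunk. The partition routing now keys off the presence of the `validate` entry rather than its value, matching how the new `_validate` attaches it via `assign_non_tensor_data`. And relaxing `strict=True` to `strict=False` tolerates `prompts.chunk(...)` returning fewer chunks than there are agent loop workers, which plausibly happens when a small validation batch is split; a strict zip would raise instead of leaving the extra workers idle. A standalone illustration of the zip semantics (Python 3.10+, not verl code):

```python
workers = ["w0", "w1", "w2", "w3"]
chunks = ["c0", "c1"]  # e.g. a small batch split into fewer chunks than workers

# strict=True raises on a length mismatch between the iterables
try:
    list(zip(workers, chunks, strict=True))
except ValueError as exc:
    print(exc)  # zip() argument 2 is shorter than argument 1

# strict=False stops at the shorter iterable: extra workers simply get no chunk
print(list(zip(workers, chunks, strict=False)))  # [('w0', 'c0'), ('w1', 'c1')]
```
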
```diff
@@ -634,8 +635,131 @@ def _load_checkpoint(self):
     def _save_checkpoint(self):
         raise NotImplementedError

-    def _validate(self) -> dict:
-        raise NotImplementedError
+    def _validate(self) -> dict[str, float]:
+        # Lists to collect samples for the table
+        sample_uids = []
+        sample_inputs = []
+        sample_outputs = []
+        sample_gts = []
+        sample_scores = []
+        sample_turns = []
+        data_sources = []
+        reward_extra_infos_dict: dict[str, list] = defaultdict(list)
+
+        for batch_dict in self.val_dataloader:
+            # 1. put batch to agent loop manager
+            batch_dict["uid"] = np.array(
+                [str(uuid.uuid4()) for _ in range(len(batch_dict["raw_prompt"]))], dtype=object
+            )
+            batch = tu.get_tensordict(batch_dict)
+            tu.assign_non_tensor_data(batch, "global_steps", self.global_steps)
+            tu.assign_non_tensor_data(batch, "validate", True)
+            self.agent_loop_manager.generate_sequences(batch)
+
+            # 2. sample batch from replay buffer
+            batch = self.replay_buffer.sample(partition_id="val", global_steps=self.global_steps)
+
+            # 3. [OPTIONAL] compute reward score with colocated reward model
+            if self.reward_loop_manager.reward_loop_worker_handles is None:
+                self.checkpoint_manager.sleep_replicas()
+                batch = self._compute_reward_colocate(batch)
+                self.checkpoint_manager.update_weights()
+
+            # 4. collect necessary data for logging
+            fields = ["uid", "prompts", "responses", "rm_scores", "num_turns", "reward_model", "data_source"]
+            data = tq.kv_batch_get(keys=batch.keys, partition_id=batch.partition_id, fields=fields)
+            data["prompts"] = data["prompts"].to_padded_tensor(padding=self.tokenizer.pad_token_id)
+            data["responses"] = data["responses"].to_padded_tensor(padding=self.tokenizer.pad_token_id)
+
+            sample_uids.extend(data.pop("uid").tolist())
+            output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in data["responses"]]
+            sample_outputs.extend(output_texts)
+            input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in data["prompts"]]
+            sample_inputs.extend(input_texts)
+            scores = data["rm_scores"].sum(dim=1).tolist()
+            sample_scores.extend(scores)
+            sample_turns.extend(data.pop("num_turns").tolist())
+            reward_extra_infos_dict["reward"].extend(scores)
+
+            reward_model = data.pop("reward_model", None)
+            if reward_model is not None:
+                sample_gts.extend([item.get("ground_truth", None) for item in reward_model.tolist()])
+            else:
+                sample_gts.extend([None] * len(batch))
+
+            data_source = data.pop("data_source", None)
+            if data_source is not None:
+                data_sources.extend(data_source.tolist())
+            else:
+                data_sources.extend(["unknown"] * len(batch))
+
+            # 5. cleanup transfer queue and replay buffer
+            tq.kv_clear(keys=batch.keys, partition_id=batch.partition_id)
+            self.replay_buffer.remove(batch.partition_id, batch.keys)
+
+        # logger to wandb
+        self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores)
+
+        # dump to local dir
+        val_data_dir = self.config.trainer.get("validation_data_dir", None)
+        if val_data_dir:
+            self._dump_generations(
+                inputs=sample_inputs,
+                outputs=sample_outputs,
+                gts=sample_gts,
+                scores=sample_scores,
+                reward_extra_infos_dict=reward_extra_infos_dict,
+                dump_path=val_data_dir,
+            )
+
+        return self._val_metrics_update(data_sources, sample_uids, reward_extra_infos_dict, sample_turns)
+
+    def _maybe_log_val_generations(self, inputs, outputs, scores):
+        """Log a table of validation samples to the configured logger (wandb or swanlab)"""
+        generations_to_log = self.config.trainer.log_val_generations
+        if generations_to_log == 0:
+            return
+
+        # Create tuples of (input, output, score) and sort by input text
+        samples = list(zip(inputs, outputs, scores, strict=True))
+        samples.sort(key=lambda x: x[0])  # Sort by input text
+
+        # Use fixed random seed for deterministic shuffling
+        rng = np.random.RandomState(42)
+        rng.shuffle(samples)
+
+        # Take first N samples after shuffling
+        samples = samples[:generations_to_log]
+
+        # Log to each configured logger
+        self.validation_generations_logger.log(self.config.trainer.logger, samples, self.global_steps)
+
+    def _val_metrics_update(self, data_sources, sample_uids, reward_extra_infos_dict, sample_turns) -> dict[str, float]:
+        data_src2var2metric2val = process_validation_metrics(data_sources, sample_uids, reward_extra_infos_dict)
+        metric_dict = {}
+        for data_source, var2metric2val in data_src2var2metric2val.items():
+            core_var = "acc" if "acc" in var2metric2val else "reward"
+            for var_name, metric2val in var2metric2val.items():
+                n_max = max([int(name.split("@")[-1].split("/")[0]) for name in metric2val.keys()])
+                for metric_name, metric_val in metric2val.items():
+                    if (
+                        (var_name == core_var)
+                        and any(metric_name.startswith(pfx) for pfx in ["mean", "maj", "best"])
+                        and (f"@{n_max}" in metric_name)
+                    ):
+                        metric_sec = "val-core"
+                    else:
+                        metric_sec = "val-aux"
+                    pfx = f"{metric_sec}/{data_source}/{var_name}/{metric_name}"
+                    metric_dict[pfx] = metric_val
+
+        if len(sample_turns) > 0:
+            sample_turns = np.array(sample_turns)
+            metric_dict["val-aux/num_turns/min"] = sample_turns.min()
+            metric_dict["val-aux/num_turns/max"] = sample_turns.max()
+            metric_dict["val-aux/num_turns/mean"] = sample_turns.mean()
+
+        return metric_dict

     def _start_profiling(self) -> None:
         """Start profiling for all worker groups if profiling is enabled."""
```
```diff
@@ -930,6 +1054,7 @@ def _compute_metrics(self, batch: KVBatchMeta, metrics, timing_raw, global_steps
             "returns",
             "rm_scores",
             "token_level_rewards",
+            "num_turns",
         ]
         data = tq.kv_batch_get(keys=batch.keys, partition_id=batch.partition_id, fields=fields)
         data = data.to_padded_tensor()
```
```diff
@@ -948,19 +1073,40 @@ def fit(self):
         gradient_norm = metrics.get("actor/grad_norm", None)
         metrics.update(compute_variance_proxy_metrics(batch=batch, gradient_norm=gradient_norm))

+        # 3. other auxiliary metrics
+        num_turns = np.array(data.pop("num_turns").tolist())
+        metrics.update(
+            {
+                "training/num_turns/mean": num_turns.mean(),
+                "training/num_turns/max": num_turns.max(),
+                "training/num_turns/min": num_turns.min(),
+            }
+        )
+
     def fit(self):
         self.logger = Tracking(
             project_name=self.config.trainer.project_name,
             experiment_name=self.config.trainer.experiment_name,
             default_backend=self.config.trainer.logger,
             config=OmegaConf.to_container(self.config, resolve=True),
         )
+        self.validation_generations_logger = ValidationGenerationsLogger(
+            project_name=self.config.trainer.project_name,
+            experiment_name=self.config.trainer.experiment_name,
+        )

         # load checkpoint and update weights before doing anything
         self._load_checkpoint()
         self.checkpoint_manager.update_weights()

-        # TODO(wuxibin): validate before train
+        # perform validation before training
+        if self.config.trainer.get("val_before_train", True):
+            val_metrics = self._validate()
+            assert val_metrics, f"{val_metrics=}"
+            pprint(f"Initial validation metrics: {val_metrics}")
+            self.logger.log(data=val_metrics, step=self.global_steps)
+            if self.config.trainer.get("val_only", False):
+                return

         current_epoch = self.global_steps // len(self.train_dataloader)
         progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")
```
```diff
@@ -981,7 +1127,7 @@ def fit(self):
             is_last_step = self.global_steps >= self.total_training_steps
             metrics, timing_raw = {}, {}

-            # 1. perform rollout, update critic, and update actor
+            # 1. perform rollout and actor/critic training
             self._start_profiling()
             with marked_timer("step", timing_raw):
                 batch = self.step(batch_dict, metrics, timing_raw)
```
```diff
@@ -1011,14 +1157,13 @@ def fit(self):
             # 5. record metrics
             self._compute_metrics(batch, metrics, timing_raw, global_steps=self.global_steps, epoch=epoch)

-            # remove items from transfer queue and replay buffer
+            # 6. cleanup transfer queue and replay buffer
             tq.kv_clear(keys=batch.keys, partition_id=batch.partition_id)
             self.replay_buffer.remove(batch.partition_id, batch.keys)

             self.logger.log(data=metrics, step=self.global_steps)
             progress_bar.update(1)
             self.global_steps += 1
-
             if is_last_step:
                 pprint(f"Final validation metrics: {last_val_metrics}")
                 progress_bar.close()
```

verl/workers/utils/padding.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -160,4 +160,6 @@ def response_to_nested(tensor: torch.Tensor, response_mask: torch.Tensor) -> tor
     response_list = []
     for i in range(tensor.shape[0]):
         response_list.append(tensor[i, : response_lens[i]])
-    return torch.nested.as_nested_tensor(response_list, layout=torch.jagged)
+    # FIXME: switch to jagged layout
+    # return torch.nested.as_nested_tensor(response_list, layout=torch.jagged)
+    return torch.nested.as_nested_tensor(response_list, layout=torch.strided)
```
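
Context for the layout swap: PyTorch nested tensors come in the older strided layout and the newer jagged layout, and operator coverage differs between the two, which is presumably what the FIXME refers to. A minimal standalone sketch of the API (PyTorch 2.x; illustrative, not verl code):

```python
import torch

response_list = [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0])]

# strided layout: each component keeps its own strides; broader op support today
nt_strided = torch.nested.as_nested_tensor(response_list, layout=torch.strided)

# jagged layout: one packed values buffer plus offsets; the commented-out
# FIXME line above would restore this form
nt_jagged = torch.nested.as_nested_tensor(response_list, layout=torch.jagged)

print(nt_strided.is_nested, nt_jagged.layout)  # True torch.jagged
```
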
