
[trainer] feat: add support for the GDPO algorithm #5409

Open
yue-zeng-yue wants to merge 3 commits into verl-project:main from yue-zeng-yue:feat-gdpo

Conversation


@yue-zeng-yue commented Feb 26, 2026

Description

This PR introduces support for the GDPO algorithm into the training framework. It enables the system to handle multiple reward functions effectively, which is a key requirement for multi-objective reinforcement learning tasks.

Reference Paper & Experimental Results

GDPO Paper: https://arxiv.org/abs/2601.05242

Performance Benchmarks:

Tool Calling: GDPO increases the correct format ratio from 76.33% to 80.66% and improves overall accuracy from 30.18% to 32.81% (BFCL-v3).

Math Reasoning: On AIME (DeepSeek-R1-7B), GDPO improves accuracy from 50.2% to 53.1% and reduces length-exceeding violations from 2.1% to 0.2%.

Coding Reasoning: In CodeContests (3-objective optimization), GDPO reduces the bug ratio significantly from 13.2% to 3.9% compared to GRPO.

Changes Made

Core Algorithm: Updated verl/trainer/ppo/core_algos.py to implement GDPO advantage computation logic.

Ray Trainer: Modified verl/trainer/ppo/ray_trainer.py to support multi-reward data preparation and metrics logging.

Configuration: Added necessary parameters in verl/trainer/config/algorithm.py to enable GDPO-specific settings.
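For illustration, the GDPO-specific settings could look roughly like the sketch below. The field names gdpo_reward_keys and gdpo_reward_weights are taken from this PR's discussion; the dataclass layout, defaults, and the adv_estimator value are assumptions, not the actual contents of verl/trainer/config/algorithm.py.

```python
from dataclasses import dataclass, field

@dataclass
class AlgorithmConfig:
    # Hypothetical sketch; the real verl/trainer/config/algorithm.py may differ.
    adv_estimator: str = "gdpo"
    # Keys under which per-dimension rewards are stored (e.g. in non_tensor_batch).
    gdpo_reward_keys: list = field(default_factory=lambda: ["format", "correctness", "length"])
    # Weight applied to each reward dimension during aggregation; must align with the keys.
    gdpo_reward_weights: list = field(default_factory=lambda: [0.5, 1.0, 0.2])

cfg = AlgorithmConfig()
```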

Motivation

The current framework's limitations made it difficult to natively implement algorithms like GDPO that require complex or multiple reward function handling. This update modifies the core PPO trainer and Ray trainer to fully support GDPO.


CLAassistant commented Feb 26, 2026

CLA assistant check
All committers have signed the CLA.

Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request introduces support for the GDPO algorithm. The changes span configuration, the core algorithm implementation, and integration into the Ray trainer, including data preparation and metrics logging. While the implementation is largely correct, I've identified a significant performance issue in the core GDPO advantage computation logic that should be addressed. My review includes a suggestion to vectorize a performance-critical loop.

Comment on lines +413 to +427
normalized = torch.zeros_like(reward_components)

id2indices = defaultdict(list)
for i in range(bs):
    id2indices[index[i]].append(i)

for group_id, indices in id2indices.items():
    idx_tensor = torch.tensor(indices, device=reward_components.device)
    if len(indices) == 1:
        normalized[indices[0]] = 0.0
    else:
        group_rewards = reward_components[idx_tensor]  # (group_size, n_rewards)
        group_mean = group_rewards.mean(dim=0)
        group_std = group_rewards.std(dim=0)
        normalized[idx_tensor] = (group_rewards - group_mean) / (group_std + epsilon)
Contributor

Severity: high

The current implementation for group-wise normalization iterates over each group using a Python loop. This can be a performance bottleneck, especially with a large number of groups or when running on GPU. This part of the computation should be vectorized to improve performance, similar to how compute_grpo_vectorized_outcome_advantage is implemented. A vectorized approach using torch.unique and scatter_add_ would be much more efficient.

Suggested change
normalized = torch.zeros_like(reward_components)
id2indices = defaultdict(list)
for i in range(bs):
    id2indices[index[i]].append(i)
for group_id, indices in id2indices.items():
    idx_tensor = torch.tensor(indices, device=reward_components.device)
    if len(indices) == 1:
        normalized[indices[0]] = 0.0
    else:
        group_rewards = reward_components[idx_tensor]  # (group_size, n_rewards)
        group_mean = group_rewards.mean(dim=0)
        group_std = group_rewards.std(dim=0)
        normalized[idx_tensor] = (group_rewards - group_mean) / (group_std + epsilon)
g = as_torch_index(index, device=reward_components.device)
unique_groups, group_indices, group_counts = torch.unique(g, return_inverse=True, return_counts=True)
group_sum = torch.zeros(
    (unique_groups.shape[0], n_rewards), device=g.device, dtype=reward_components.dtype
).scatter_add_(0, group_indices.unsqueeze(1).expand(-1, n_rewards), reward_components)
group_means_per_group = group_sum / group_counts.unsqueeze(1).clamp(min=1)
group_means = group_means_per_group[group_indices]
group_sum_sq = torch.zeros(
    (unique_groups.shape[0], n_rewards), device=g.device, dtype=reward_components.dtype
).scatter_add_(0, group_indices.unsqueeze(1).expand(-1, n_rewards), reward_components.pow(2))
group_means_sq_per_group = group_sum_sq / group_counts.unsqueeze(1).clamp(min=1)
group_vars_per_group = (group_means_sq_per_group - group_means_per_group.pow(2)).clamp(min=0)
group_stds_per_group = torch.sqrt(group_vars_per_group)
group_stds = group_stds_per_group[group_indices]
normalized = (reward_components - group_means) / (group_stds + epsilon)
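The suggested change computes each group's variance via the identity Var(x) = E[x²] − (E[x])², which is what the two scatter_add_ calls implement. Note one subtle semantic difference: the original loop uses torch.std's default unbiased (N−1) estimator, while the identity yields the population (N) standard deviation, so small groups will normalize slightly differently. A minimal pure-Python sketch of the same group-wise normalization (hypothetical helper, torch-free, for illustration only):

```python
from collections import defaultdict
import math

def groupwise_normalize(rewards, index, epsilon=1e-6):
    """Normalize each reward within its prompt group: (r - mean) / (std + eps).

    rewards: one reward dimension as a list of floats; index: group id per sample.
    Uses the population-variance identity E[x^2] - E[x]^2, matching the
    scatter_add_-based suggestion (not torch.std's unbiased estimator).
    """
    sums, sums_sq, counts = defaultdict(float), defaultdict(float), defaultdict(int)
    for r, g in zip(rewards, index):
        sums[g] += r
        sums_sq[g] += r * r
        counts[g] += 1
    out = []
    for r, g in zip(rewards, index):
        mean = sums[g] / counts[g]
        var = max(sums_sq[g] / counts[g] - mean * mean, 0.0)  # clamp fp noise at 0
        out.append((r - mean) / (math.sqrt(var) + epsilon))
    return out

# Two groups: "a" has rewards [1, 3], "b" has identical rewards [2, 2].
normed = groupwise_normalize([1.0, 3.0, 2.0, 2.0], ["a", "a", "b", "b"])
```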

@tardis-key
Collaborator

There are multiple scores in GDPO rewards, and they will be used in the advantage estimator later. The current implementation does not seem to align with this algorithm.

checkpoint_engine_config = omega_conf_to_dataclass(self.config.actor_rollout_ref.rollout.checkpoint_engine)
self.checkpoint_manager = CheckpointEngineManager(
    config=checkpoint_engine_config,
    backend=self.config.actor_rollout_ref.rollout.checkpoint_engine.backend,
Collaborator
Please rebase onto main first; this should not be changed.

Author

OK, got it.

Updated core_algos.py and ray_trainer.py to implement and integrate the GDPO algorithm into the training framework.
@yue-zeng-yue
Author

There are multiple scores in GDPO rewards, and they will be used in the advantage estimator later. The current implementation does not seem to align with this algorithm.

Thank you for the review! I’ve double-checked the code, and it seems to align with the underlying GDPO calculation logic. Could you please point out specifically which part doesn't match the algorithm? I’d appreciate more details so I can fix it properly.

@tongyx361 self-assigned this Feb 27, 2026
@tardis-key self-requested a review February 28, 2026 01:26
@tardis-key
Collaborator

In verl/trainer/ppo/core_algos.py: line 395:
reward_components: (bs, N_rewards) – per-sample scores for each reward dimension.

These sample-level rewards do not take prompt_mask or attention_mask into account, resulting in significant differences from the official GDPO implementation; see GDPO/verl-GDPO/verl/trainer/main_ppo.py, line 85:
score, fomrat_score, correctness_score, length_score = compute_score_fn(solution_str=sequences_str, ground_truth=ground_truth, step=step)

You can refer to #5422, which adapts reward_score.

@yue-zeng-yue
Author

In verl/trainer/ppo/core_algos.py: line 395: reward_components: (bs, N_rewards) – per-sample scores for each reward dimension.

These sample-level rewards do not take into account prompt_mask or attention_mask, resulting in significant differences compared to the official GDPO implementation, in GDPO/verl-GDPO/verl/trainer/main_ppo.py: line 85 score, fomrat_score, correctness_score, length_score = compute_score_fn(solution_str=sequences_str, ground_truth=ground_truth, step=step)

you can refer to #5422, which adapted reward_score

Thanks for the review and for referencing #5422.

To clarify, my compute_gdpo_outcome_advantage in core_algos.py is the advantage computation step — it receives pre-computed per-dimension rewards as input and implements the three GDPO steps (group-wise decoupled normalization → weighted aggregation → batch-level normalization), which is mathematically consistent with the paper.

I agree that my PR currently lacks the multi-reward scoring pipeline (reward manager + compute_score function) to produce those per-dimension rewards. My PR assumes they are already available in non_tensor_batch, and the GDPO-related fields in algorithm.py (i.e., gdpo_reward_keys and gdpo_reward_weights) only define configuration parameters without the actual reward computation logic.
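For reference, the three steps named above (group-wise decoupled normalization → weighted aggregation → batch-level normalization) can be sketched in plain Python. This is an illustrative, torch-free toy, not the actual compute_gdpo_outcome_advantage implementation; function and variable names are hypothetical.

```python
import math

def gdpo_advantages(reward_components, index, weights, epsilon=1e-6):
    """Sketch of the three GDPO steps (names illustrative):
    1. group-wise decoupled normalization of each reward dimension,
    2. weighted aggregation across dimensions,
    3. batch-level normalization of the aggregate.
    reward_components: list of per-sample [n_rewards] lists; index: group id per sample.
    """
    bs, n_rewards = len(reward_components), len(weights)
    # Step 1: normalize each dimension independently within its prompt group.
    normalized = [[0.0] * n_rewards for _ in range(bs)]
    groups = {}
    for i, g in enumerate(index):
        groups.setdefault(g, []).append(i)
    for d in range(n_rewards):
        for members in groups.values():
            vals = [reward_components[i][d] for i in members]
            mean = sum(vals) / len(vals)
            var = sum((v - mean) ** 2 for v in vals) / len(vals)
            for i in members:
                normalized[i][d] = (reward_components[i][d] - mean) / (math.sqrt(var) + epsilon)
    # Step 2: weighted sum over reward dimensions.
    agg = [sum(w * normalized[i][d] for d, w in enumerate(weights)) for i in range(bs)]
    # Step 3: batch-level normalization of the aggregated score.
    mean = sum(agg) / bs
    var = sum((a - mean) ** 2 for a in agg) / bs
    return [(a - mean) / (math.sqrt(var) + epsilon) for a in agg]

# Toy batch: two prompt groups ("a", "b"), two reward dimensions.
adv = gdpo_advantages([[1, 0], [3, 1], [2, 5], [2, 7]], ["a", "a", "b", "b"], [1.0, 0.5])
```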

@yue-zeng-yue
Author

In verl/trainer/ppo/core_algos.py: line 395: reward_components: (bs, N_rewards) – per-sample scores for each reward dimension.

These sample-level rewards do not take into account prompt_mask or attention_mask, resulting in significant differences compared to the official GDPO implementation, in GDPO/verl-GDPO/verl/trainer/main_ppo.py: line 85 score, fomrat_score, correctness_score, length_score = compute_score_fn(solution_str=sequences_str, ground_truth=ground_truth, step=step)

you can refer to #5422, which adapted reward_score

My design intention was to keep the advantage computation decoupled from any specific reward function, so that users can plug in their own compute_score that returns a dict with custom reward keys.
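As a concrete illustration of that plug-in design, a user-defined compute_score could return a dict keyed by custom reward names matching gdpo_reward_keys in the config. This is a hypothetical sketch, not an interface this PR ships:

```python
def compute_score(solution_str: str, ground_truth: str) -> dict:
    """Hypothetical multi-objective scorer returning one value per reward key.
    The keys would need to match gdpo_reward_keys in the algorithm config."""
    format_ok = solution_str.strip().endswith("</answer>")
    correct = ground_truth in solution_str
    return {
        "format": 1.0 if format_ok else 0.0,
        "correctness": 1.0 if correct else 0.0,
        "length": -len(solution_str) / 1000.0,  # shorter responses score higher
    }

scores = compute_score("The answer is 42 </answer>", "42")
```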

