
Commit d149dd9

feat: add core_algo for GDPO

1 parent 5ca64e2 commit d149dd9

File tree

5 files changed: +184 -12 lines

examples/gdpo_trainer/run_gdpo.sh

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+export DATA_DIR="verl/dataset/rlla_4k"
+export BASE_MODEL="Qwen/Qwen2.5-1.5B-Instruct"
+export EXPERIMENT_NAME="qwen2.5-1.5B-GDPO"
+export CKPT_DIR="verl/results/gdpo"
+
+PROJECT_DIR="$(pwd)"
+
+trainer_n_gpus_per_node=8
+trainer_nnodes=1
+
+python3 -u -m verl.trainer.main_ppo \
+    algorithm.adv_estimator=gdpo \
+    data.train_files=$DATA_DIR/train.parquet \
+    data.val_files=$DATA_DIR/test.parquet \
+    data.train_batch_size=32 \
+    data.val_batch_size=16 \
+    data.max_prompt_length=2048 \
+    data.max_response_length=1024 \
+    actor_rollout_ref.model.path=$BASE_MODEL \
+    actor_rollout_ref.actor.optim.lr=1e-6 \
+    actor_rollout_ref.model.use_remove_padding=True \
+    actor_rollout_ref.actor.ppo_mini_batch_size=4 \
+    actor_rollout_ref.actor.use_dynamic_bsz=True \
+    actor_rollout_ref.actor.use_kl_loss=False \
+    actor_rollout_ref.actor.kl_loss_coef=0.001 \
+    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
+    actor_rollout_ref.model.enable_gradient_checkpointing=True \
+    actor_rollout_ref.actor.fsdp_config.param_offload=False \
+    actor_rollout_ref.actor.fsdp_config.grad_offload=False \
+    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
+    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
+    actor_rollout_ref.rollout.name=vllm \
+    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
+    actor_rollout_ref.rollout.n=4 \
+    actor_rollout_ref.ref.fsdp_config.param_offload=True \
+    algorithm.kl_ctrl.kl_coef=0.001 \
+    reward.custom_reward_function.path="$PROJECT_DIR/verl/experimental/reward_loop/reward_manager/gdpo.py" \
+    reward.custom_reward_function.name=compute_score \
+    trainer.critic_warmup=0 \
+    trainer.logger=['console'] \
+    trainer.project_name=Var_inspect \
+    trainer.n_gpus_per_node=$trainer_n_gpus_per_node \
+    trainer.experiment_name=$EXPERIMENT_NAME \
+    trainer.nnodes=$trainer_nnodes \
+    trainer.save_freq=5 \
+    trainer.test_freq=10 \
+    trainer.default_local_dir=$CKPT_DIR \
+    trainer.total_epochs=15 \
+    trainer.val_before_train=False 2>&1 | tee ${LOG_PATH}

verl/experimental/reward_loop/reward_manager/gdpo.py

Lines changed: 8 additions & 0 deletions
@@ -78,6 +78,14 @@ async def run_single(self, data: DataProto) -> dict:
            ),
        )
 
+        # result = {
+        #     "score": score,
+        #     "score_list": [fomrat_score, correctness_score],
+        # }
+
+        # return {"reward_score": reward, "reward_extra_info": reward_extra_info}
+        # reward_extra_info = {"score": score, "score_list": [fomrat_score, correctness_score]}
+
        reward_extra_info = {}
 
        score: float

verl/trainer/ppo/core_algos.py

Lines changed: 68 additions & 0 deletions
@@ -96,6 +96,7 @@ class AdvantageEstimator(str, Enum):
 
     GAE = "gae"
     GRPO = "grpo"
+    GDPO = "gdpo"
     REINFORCE_PLUS_PLUS = "reinforce_plus_plus"
     REINFORCE_PLUS_PLUS_BASELINE = "reinforce_plus_plus_baseline"
     REMAX = "remax"

@@ -2370,3 +2371,70 @@ def compute_policy_loss_bypass_mode(
     pg_metrics.update(rollout_metrics)
 
     return pg_loss, pg_metrics
+
+
+@register_adv_est(AdvantageEstimator.GDPO)  # or simply: @register_adv_est("gdpo")
+def compute_gdpo_outcome_advantage(
+    token_level_rewards: torch.Tensor,
+    response_mask: torch.Tensor,
+    index: np.ndarray,
+    epsilon: float = 1e-6,
+    norm_adv_by_std_in_grpo: bool = True,
+    config: Optional[AlgoConfig] = None,
+    score_list: Optional[list[torch.Tensor]] = None,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Compute advantage for GDPO, operating only on outcome rewards
+    (one scalar reward per response for each score component).
+
+    Args:
+        token_level_rewards: `(torch.Tensor)`
+            shape is (bs, response_length)
+        response_mask: `(torch.Tensor)`
+            shape is (bs, response_length)
+        index: `(np.ndarray)`
+            index array for grouping
+        epsilon: `(float)`
+            small value to avoid division by zero
+        norm_adv_by_std_in_grpo: `(bool)`
+            whether to scale the GRPO advantage by the group std
+        config: `(Optional[AlgoConfig])`
+            algorithm configuration object
+        score_list: `(Optional[list[torch.Tensor]])`
+            list of per-component score tensors for GDPO
+
+    Note:
+        Ref GDPO (https://arxiv.org/abs/2601.05242).
+
+    Returns:
+        advantages: `(torch.Tensor)`
+            shape is (bs, response_length)
+        Returns: `(torch.Tensor)`
+            shape is (bs, response_length)
+    """
+    if score_list is None:
+        # fall back to the aggregated reward when no multi-score is provided
+        score_list = [token_level_rewards]
+        print("------no multi-score found---------")  # for debug
+    num_scores = len(score_list)
+    new_advantage = None
+    for i in range(num_scores):
+        token_level_scores = score_list[i]
+
+        # normalize each score component within its prompt group, as in GRPO
+        normalized_score, _ = compute_grpo_outcome_advantage(
+            token_level_rewards=token_level_scores,
+            response_mask=response_mask,
+            index=index,
+            epsilon=epsilon,
+            norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
+            config=config,
+        )
+
+        # sum the per-component normalized advantages
+        if new_advantage is None:
+            new_advantage = normalized_score
+        else:
+            new_advantage += normalized_score
+
+    # whiten the summed advantage over valid response tokens only
+    advantages = verl_F.masked_whiten(new_advantage, response_mask) * response_mask
+
+    return advantages, advantages
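
For intuition, the estimator above can be read as: GRPO-normalize each score component within its prompt group, sum the per-component advantages, then whiten the sum over valid response tokens. Below is a minimal standalone sketch of that computation, with simplified stand-ins for compute_grpo_outcome_advantage and verl_F.masked_whiten; the helper names (group_normalize, gdpo_advantage_sketch) and the toy numbers are illustrative only, not part of the commit.

# Standalone sketch of the GDPO advantage idea (not the verl implementation):
# normalize each score component within its prompt group, sum the per-component
# advantages, then whiten the sum over valid response tokens.
import numpy as np
import torch


def group_normalize(scores: torch.Tensor, index: np.ndarray, eps: float = 1e-6) -> torch.Tensor:
    """Normalize a (bs,) score vector within each prompt group given by `index`."""
    out = torch.zeros_like(scores)
    for gid in np.unique(index):
        mask = torch.from_numpy(index == gid)
        group = scores[mask]
        out[mask] = (group - group.mean()) / (group.std() + eps)
    return out


def gdpo_advantage_sketch(
    score_list: list[torch.Tensor],   # each (bs,): one scalar per response and component
    response_mask: torch.Tensor,      # (bs, response_length), 1 for valid tokens
    index: np.ndarray,                # (bs,): prompt-group id of each response
    eps: float = 1e-6,
) -> torch.Tensor:
    summed = sum(group_normalize(s, index, eps) for s in score_list)  # (bs,)
    adv = summed.unsqueeze(-1) * response_mask                        # broadcast to tokens
    # whiten over valid tokens only
    valid = response_mask.bool()
    mean, std = adv[valid].mean(), adv[valid].std()
    return (adv - mean) / (std + eps) * response_mask


# toy usage: 4 responses to 2 prompts, 2 score components (format, correctness)
mask = torch.ones(4, 5)
idx = np.array([0, 0, 1, 1])
fmt = torch.tensor([1.0, 0.0, 1.0, 1.0])
cor = torch.tensor([0.0, 1.0, 1.0, 0.0])
print(gdpo_advantage_sketch([fmt, cor], mask, idx).shape)  # torch.Size([4, 5])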

verl/trainer/ppo/ray_trainer.py

Lines changed: 52 additions & 0 deletions
@@ -211,6 +211,58 @@ def compute_advantage(
        rollout_is_weights = data.batch.get("rollout_is_weights", None)
        adv_kwargs["rollout_is_weights"] = rollout_is_weights
 
+    if adv_estimator == AdvantageEstimator.GDPO:
+        assert "score_list" in data.non_tensor_batch, (
+            "GDPO needs multi-scores. "
+            "Please change the config reward.custom_reward_function.path to point to gdpo.py, or "
+            "change the reward function so that it computes multi-scores."
+        )
+
+        # prompt_length = prompt_ids.size(1)
+        # response_length = attention_mask[:, prompt_length:].sum(dim=1) - 1
+        # rm_scores = torch.zeros_like(response_mask, dtype=torch.float32)
+        # rm_scores[torch.arange(response_mask.size(0)), response_length] =
+        #     torch.tensor(scores, dtype=torch.float32)
+        # batch["rm_scores"] = rm_scores
+
+        # batch = TensorDict(
+        #     {
+        #         "prompts": prompt_ids,  # [bsz, prompt_length]
+        #         "responses": response_ids,  # [bsz, response_length]
+        #         "response_mask": response_mask,  # [bsz, response_length]
+        #         "input_ids": input_ids,  # [bsz, prompt_length + response_length]
+        #         "attention_mask": attention_mask,  # [bsz, prompt_length + response_length]
+        #         # position_ids: [bsz, 3, prompt_length + response_length]
+        #         #     or [bsz, prompt_length + response_length]
+        #         "position_ids": position_ids,
+        #         **optional_outputs,
+        #     },
+        #     batch_size=len(inputs),
+        # )
+        score_list = []
+        multi_score_tensor = torch.tensor(
+            data.non_tensor_batch["score_list"], dtype=torch.float32
+        )  # [bsz, score_num]
+        print(f"----------multi_score_tensor:{multi_score_tensor.shape}")
+
+        for i in range(multi_score_tensor.shape[1]):
+            rm_score = multi_score_tensor[:, i]
+            prompt_length = data.batch["prompts"].size(1)
+            # index of the last valid response token for each sample
+            response_length = data.batch["attention_mask"][:, prompt_length:].sum(dim=1) - 1
+            rm_scores = torch.zeros_like(data.batch["response_mask"], dtype=torch.float32)
+            # place the i-th score component at the final response token
+            rm_scores[torch.arange(data.batch["response_mask"].size(0)), response_length] = torch.tensor(
+                rm_score, dtype=torch.float32
+            )
+            score_list.append(rm_scores)
+
+        # sum_score_tensor = data.batch["token_level_rewards"]
+
+        # rm_scores[torch.arange(rm_scores.size(0)), valid_response_length - 1] = torch.tensor(
+        #     scores, dtype=torch.float32
+        # )
+        adv_kwargs["score_list"] = score_list
+
+        # np.array([[format_score, correct_score] for info in reward_extra_infos])
    # calculate advantage estimator
    advantages, returns = adv_estimator_fn(**adv_kwargs)
    data.batch["advantages"] = advantages
verl/utils/reward_score/rlla.py

Lines changed: 6 additions & 12 deletions
@@ -306,9 +306,6 @@ def compute_score(solution_str, ground_truth, step=0):
     format_max_possible = 1.0
     format_min_possible = 0.0
 
-    length_max_possible = 1.0
-    length_min_possible = 0.0
-
     completions = [[{"role": "assistant", "content": predict_str}]]
     answer = [ground_truth]
 

@@ -317,14 +314,11 @@ def compute_score(solution_str, ground_truth, step=0):
         completions, answer, step, tool_max_possible, tool_min_possible
     )[0]
 
-    if str(os.getenv("WITHLENGTH", 0)) == "1":
-        print("WITHLENGTH is set to 1, so length score is set!")
-        length_score = customize_length_reward_func(
-            completions, answer, step, length_max_possible, length_min_possible
-        )[0]
-    else:
-        length_score = 0
+    score = fomrat_score + correctness_score
 
-    score = fomrat_score + correctness_score + length_score
+    result = {
+        "score": score,
+        "score_list": [fomrat_score, correctness_score],
+    }
 
-    return score, fomrat_score, correctness_score, length_score
+    return result
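
With this change, compute_score returns a dict whose "score_list" carries the individual components (format and correctness) that the GDPO branch in compute_advantage consumes. A toy caller showing the new return contract follows; the scoring rules in toy_compute_score are stand-ins, not the real rlla reward functions.

# Toy illustration of the new return contract: a dict with the summed "score"
# and the per-component "score_list" later stacked into non_tensor_batch.
import numpy as np


def toy_compute_score(solution_str: str, ground_truth: str, step: int = 0) -> dict:
    # stand-in scoring rules, not the rlla reward functions
    format_score = 1.0 if solution_str.strip().endswith("</answer>") else 0.0
    correctness_score = 1.0 if ground_truth in solution_str else 0.0
    score = format_score + correctness_score
    return {"score": score, "score_list": [format_score, correctness_score]}


# a rollout batch of two responses to the same prompt
results = [
    toy_compute_score("<answer>42</answer>", "42"),
    toy_compute_score("the answer is 41", "42"),
]

# what the trainer later reads as data.non_tensor_batch["score_list"]
score_list = np.array([r["score_list"] for r in results])  # shape [bsz, score_num]
print(score_list)
# [[1. 1.]
#  [0. 0.]]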
