@@ -211,6 +211,35 @@ def compute_advantage(
         rollout_is_weights = data.batch.get("rollout_is_weights", None)
         adv_kwargs["rollout_is_weights"] = rollout_is_weights

+    if adv_estimator == AdvantageEstimator.GDPO:
+        assert "score_list" in data.non_tensor_batch, (
+            "GDPO needs multiple scores per sample. "
+            "Please change the config reward.custom_reward_function.path to point at gdpo.py, "
+            "or change the reward function to compute multiple scores."
+        )
+
+        # Build one token-level reward tensor per score channel: each response's
+        # scalar score is placed on its last valid response token, mirroring how
+        # single-score token-level rewards are constructed.
+        # e.g. score_list[i] == [format_score, correct_score] collected from
+        # reward_extra_info by the reward manager.
+        multi_score_tensor = torch.tensor(
+            data.non_tensor_batch["score_list"], dtype=torch.float32
+        )  # [bsz, score_num]
+
+        prompt_length = data.batch["prompts"].size(1)
+        # index of the last valid token within each response
+        last_token_idx = data.batch["attention_mask"][:, prompt_length:].sum(dim=1) - 1
+        batch_idx = torch.arange(data.batch["response_mask"].size(0))
+
+        score_list = []
+        for i in range(multi_score_tensor.shape[1]):
+            rm_scores = torch.zeros_like(data.batch["response_mask"], dtype=torch.float32)
+            rm_scores[batch_idx, last_token_idx] = multi_score_tensor[:, i]
+            score_list.append(rm_scores)
+
+        adv_kwargs["score_list"] = score_list
+
     # calculate advantage estimator
     advantages, returns = adv_estimator_fn(**adv_kwargs)
     data.batch["advantages"] = advantages
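
For context, here is a minimal sketch of what a `gdpo.py`-style multi-score reward function might look like. The `compute_score` name and dict-return convention follow verl's custom-reward-function interface as I understand it; the routing of the extra `"score_list"` key into `data.non_tensor_batch["score_list"]` via `reward_extra_info` (as the `[[format_score, correct_score] for info in reward_extra_infos]` comment in the original code suggests) is an assumption, not something this hunk shows.

```python
# Hypothetical gdpo.py-style reward function returning multiple score channels.
# Assumption: the reward manager forwards the extra "score_list" key into
# data.non_tensor_batch["score_list"] with shape [bsz, score_num].
def compute_score(data_source, solution_str, ground_truth, extra_info=None):
    # Two illustrative channels: formatting and correctness.
    format_score = 1.0 if solution_str.strip().endswith("</answer>") else 0.0
    correct_score = 1.0 if ground_truth in solution_str else 0.0
    return {
        "score": format_score + correct_score,  # scalar for the usual reward path
        "score_list": [format_score, correct_score],  # per-channel scores for GDPO
    }
```

Returning the summed scalar under `"score"` should keep the usual `token_level_rewards` path working while GDPO consumes the per-channel list.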
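On the consuming side, the `adv_estimator_fn` registered for `AdvantageEstimator.GDPO` receives `score_list` through `adv_kwargs`. That estimator is not part of this hunk, so the sketch below is only a guess at its shape: it assumes GDPO group-normalizes each reward channel independently (GRPO-style) and sums the per-channel advantages. The function name, signature, and normalization scheme are all assumptions, not verl's actual implementation.

```python
from collections import defaultdict

import torch


def gdpo_outcome_advantage(score_list, response_mask, index, eps=1e-6):
    """Hypothetical GDPO estimator: per-channel GRPO-style normalization.

    score_list: list of [bsz, response_length] tensors, one per reward channel,
        each with the scalar score on the last valid response token (as above).
    index: per-sample prompt uid; responses sharing a uid form one group.
    """
    advantages = torch.zeros_like(response_mask, dtype=torch.float32)
    for rm_scores in score_list:
        scores = rm_scores.sum(dim=-1)  # [bsz], one scalar score per response
        id2scores = defaultdict(list)
        for i, uid in enumerate(index):
            id2scores[uid].append(scores[i])
        id2mean = {uid: torch.stack(v).mean() for uid, v in id2scores.items()}
        id2std = {
            uid: torch.stack(v).std() if len(v) > 1 else torch.tensor(1.0)
            for uid, v in id2scores.items()
        }
        # Normalize each channel within its prompt group, then accumulate.
        for i, uid in enumerate(index):
            scores[i] = (scores[i] - id2mean[uid]) / (id2std[uid] + eps)
        advantages += scores.unsqueeze(-1) * response_mask
    return advantages, advantages  # outcome-style: returns mirror advantages
```

Normalizing each channel separately, rather than normalizing the summed reward, is the behavior the multi-score plumbing above would enable; whether this matches the actual GDPO estimator should be checked against `gdpo.py` and the registered estimator function.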