Commit 81050ce

Fix bugs
1 parent e193d0d commit 81050ce

1 file changed (+5 -10 lines)

1 file changed

+5
-10
lines changed

verl/workers/rollout/trtllm_rollout/trtllm_async_server.py

Lines changed: 5 additions & 10 deletions
@@ -184,7 +184,7 @@ async def launch_server(self):
 
     async def generate(
         self,
-        prompt_ids: str,
+        prompt_ids: Union[str, list[int]],
         sampling_params: dict[str, Any],
         request_id: str,
         image_data: Optional[list[Any]] = None,
@@ -200,9 +200,8 @@ async def generate(
         sampling_params.update(self.sampling_args)
 
         trt_llm_sampling_params = SamplingParams(**sampling_params)
-        if self.is_vlm_model:
-            org_prompt = self.llm.tokenizer.decode(prompt_ids)
-            if image_data or video_data:
+        if self.is_vlm_model and (image_data or video_data):
+            org_prompt = self.llm.tokenizer.decode(prompt_ids)
             input_dict = {
                 "prompt": org_prompt,
                 "multi_modal_data": {},
@@ -217,11 +216,6 @@ async def generate(
                 inputs=input_dict,
                 sampling_params=trt_llm_sampling_params,
             )
-            else:
-                outputs = await self.llm.generate_async(
-                    inputs=prompt_ids,
-                    sampling_params=trt_llm_sampling_params,
-                )
         else:
             outputs = await self.llm.generate_async(
                 inputs=prompt_ids,
@@ -230,7 +224,8 @@
         token_ids = outputs.outputs[0].token_ids
         log_probs = None
         if outputs.outputs[0].logprobs is not None:
-            log_probs = [logprobs[token_ids[i]].logprob for i, logprobs in enumerate(outputs.outputs[0].logprobs)]
+            # When logprobs=1, TRT-LLM returns only the sampled token's logprob at each position
+            log_probs = [list(d.values())[0].logprob for d in outputs.outputs[0].logprobs]
         return TokenOutput(token_ids=token_ids, log_probs=log_probs)
 
     async def wake_up(self):
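
For context, here is the control flow of generate() after the fix, as a minimal sketch: the method body is trimmed to the branching this commit changes, and the multi_modal_data payload is elided here just as it is in the diff. Token ids are now decoded back to text only when the model is a VLM and multimodal data is actually supplied; the duplicated token-id fallback that used to sit inside the VLM branch is removed, since the combined condition routes that case to the outer else.

    trt_llm_sampling_params = SamplingParams(**sampling_params)
    if self.is_vlm_model and (image_data or video_data):
        # Multimodal path: TRT-LLM takes a text prompt alongside media inputs,
        # so the token ids are decoded back to text, but now only when media
        # is actually present.
        org_prompt = self.llm.tokenizer.decode(prompt_ids)
        input_dict = {
            "prompt": org_prompt,
            "multi_modal_data": {},  # filled from image_data/video_data (elided)
        }
        outputs = await self.llm.generate_async(
            inputs=input_dict,
            sampling_params=trt_llm_sampling_params,
        )
    else:
        # Text-only path, including VLM requests without media: token ids pass
        # straight through. This single branch replaces the duplicated fallback
        # that previously lived inside the VLM branch.
        outputs = await self.llm.generate_async(
            inputs=prompt_ids,
            sampling_params=trt_llm_sampling_params,
        )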

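The logprob change is easier to see with the data shape in hand. Below is a minimal, self-contained sketch, assuming the structure the new comprehension implies: outputs.outputs[0].logprobs is a list with one dict per generated position, each mapping a token id to an object with a .logprob attribute. The Logprob dataclass is a stand-in for illustration, not the real TRT-LLM class.

    from dataclasses import dataclass

    @dataclass
    class Logprob:
        # Stand-in for the per-position logprob entry; illustrative only.
        logprob: float

    # With logprobs=1, each position's dict holds exactly one entry:
    # the sampled token's id mapped to its logprob object.
    token_ids = [17, 42, 7]
    logprobs = [{17: Logprob(-0.11)}, {42: Logprob(-1.32)}, {7: Logprob(-0.53)}]

    # Old extraction: index each dict by the sampled token id. This can raise
    # KeyError if the stored key and token_ids[i] ever fall out of alignment.
    old = [lp[token_ids[i]].logprob for i, lp in enumerate(logprobs)]

    # New extraction: take the single stored entry regardless of its key,
    # which logprobs=1 guarantees to be the sampled token's.
    new = [list(d.values())[0].logprob for d in logprobs]

    assert old == new == [-0.11, -1.32, -0.53]

When keys and token ids agree the two forms return the same values; the new form simply drops the dependence on the dict key.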