@@ -184,7 +184,7 @@ async def launch_server(self):
184184
185185 async def generate (
186186 self ,
187- prompt_ids : str ,
187+ prompt_ids : Union [ str , list [ int ]] ,
188188 sampling_params : dict [str , Any ],
189189 request_id : str ,
190190 image_data : Optional [list [Any ]] = None ,
@@ -200,9 +200,8 @@ async def generate(
200200 sampling_params .update (self .sampling_args )
201201
202202 trt_llm_sampling_params = SamplingParams (** sampling_params )
203- if self .is_vlm_model :
204- org_prompt = self .llm .tokenizer .decode (prompt_ids )
205- if image_data or video_data :
203+ if self .is_vlm_model and (image_data or video_data ):
204+ org_prompt = self .llm .tokenizer .decode (prompt_ids )
206205 input_dict = {
207206 "prompt" : org_prompt ,
208207 "multi_modal_data" : {},
@@ -217,11 +216,6 @@ async def generate(
217216 inputs = input_dict ,
218217 sampling_params = trt_llm_sampling_params ,
219218 )
220- else :
221- outputs = await self .llm .generate_async (
222- inputs = prompt_ids ,
223- sampling_params = trt_llm_sampling_params ,
224- )
225219 else :
226220 outputs = await self .llm .generate_async (
227221 inputs = prompt_ids ,
@@ -230,7 +224,8 @@ async def generate(
230224 token_ids = outputs .outputs [0 ].token_ids
231225 log_probs = None
232226 if outputs .outputs [0 ].logprobs is not None :
233- log_probs = [logprobs [token_ids [i ]].logprob for i , logprobs in enumerate (outputs .outputs [0 ].logprobs )]
227+ # When logprobs=1, TRT-LLM returns only the sampled token's logprob at each position
228+ log_probs = [list (d .values ())[0 ].logprob for d in outputs .outputs [0 ].logprobs ]
234229 return TokenOutput (token_ids = token_ids , log_probs = log_probs )
235230
236231 async def wake_up (self ):
0 commit comments