Skip to content

Commit 1ee1fef

Browse files
committed
Improve trtllm unittest ray.init and shutdown handling; add rollout config doc
1 parent afbcfea commit 1ee1fef

File tree

4 files changed

+39
-6
lines changed

4 files changed

+39
-6
lines changed

tests/workers/rollout/rollout_trtllm/test_adapter.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,17 @@ def test_init_without_device_mesh(self):
142142

143143
try:
144144
os.environ.setdefault("TLLM_RAY_FORCE_LOCAL_CLUSTER", "1")
145-
ray.init(address="local", ignore_reinit_error=True, include_dashboard=False)
145+
ray.init(
146+
runtime_env={
147+
"env_vars": {
148+
"TOKENIZERS_PARALLELISM": "true",
149+
"NCCL_DEBUG": "WARN",
150+
"VLLM_LOGGING_LEVEL": "INFO",
151+
"VLLM_USE_V1": "1",
152+
}
153+
},
154+
ignore_reinit_error=True,
155+
)
146156

147157
config_dir = os.path.abspath("verl/verl/trainer/config")
148158
if not os.path.exists(config_dir):
@@ -187,5 +197,5 @@ def test_init_without_device_mesh(self):
187197
os.environ.pop("RANK", None)
188198
else:
189199
os.environ["RANK"] = prev_rank
200+
print("\nShutting down Ray...")
190201
ray.shutdown()
191-
subprocess.run(["ray", "stop"], capture_output=True)

tests/workers/rollout/rollout_trtllm/test_async_server.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,17 @@ def test_async_generate(self):
170170
"""Test TRT-LLM generate method with real model."""
171171
try:
172172
os.environ.setdefault("TLLM_RAY_FORCE_LOCAL_CLUSTER", "1")
173-
ray.init(address="local", ignore_reinit_error=True, include_dashboard=False)
173+
ray.init(
174+
runtime_env={
175+
"env_vars": {
176+
"TOKENIZERS_PARALLELISM": "true",
177+
"NCCL_DEBUG": "WARN",
178+
"VLLM_LOGGING_LEVEL": "INFO",
179+
"VLLM_USE_V1": "1",
180+
}
181+
},
182+
ignore_reinit_error=True,
183+
)
174184

175185
rollout_config, model_config = self._build_rollout_config(response_length=50)
176186

@@ -209,14 +219,24 @@ def test_async_generate(self):
209219
print(f"Log probs: {result.log_probs[:10]}...") # Print first 10 log probs
210220

211221
finally:
222+
print("\nShutting down Ray...")
212223
ray.shutdown()
213-
subprocess.run(["ray", "stop"], capture_output=True)
214224

215225
def test_async_memory_management(self):
216226
"""Test TRT-LLM async memory management (sleep) reduces memory usage."""
217227
try:
218228
os.environ.setdefault("TLLM_RAY_FORCE_LOCAL_CLUSTER", "1")
219-
ray.init(address="local", ignore_reinit_error=True, include_dashboard=False)
229+
ray.init(
230+
runtime_env={
231+
"env_vars": {
232+
"TOKENIZERS_PARALLELISM": "true",
233+
"NCCL_DEBUG": "WARN",
234+
"VLLM_LOGGING_LEVEL": "INFO",
235+
"VLLM_USE_V1": "1",
236+
}
237+
},
238+
ignore_reinit_error=True,
239+
)
220240

221241
rollout_config, model_config = self._build_rollout_config(free_cache_engine=True)
222242

@@ -271,5 +291,5 @@ def get_gpu_memory_mb_for_device(device_uuid: str) -> float:
271291
)
272292

273293
finally:
294+
print("\nShutting down Ray...")
274295
ray.shutdown()
275-
subprocess.run(["ray", "stop"], capture_output=True)

verl/trainer/config/rollout/rollout.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,7 @@ profiler:
313313
# choices: npu, torch
314314
tool: ${oc.select:global_profiler.tool,null}
315315

316+
# global tool config
316317
global_tool_config: ${oc.select:global_profiler.global_tool_config,null}
317318

318319
# whether enable profile on rollout

verl/workers/rollout/trtllm_rollout/trtllm_async_server.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,8 @@ async def launch_server(self):
178178
self.llm = await AsyncLLM(**llm_kwargs)
179179

180180
trtllm_server = OpenAIServer(
181+
# TODO: update to generator in future
182+
# generator=self.llm,
181183
llm=self.llm,
182184
model=self.model_config.local_path,
183185
tool_parser=None,

0 commit comments

Comments (0)