Skip to content

Commit 4adb81b

Browse files
committed
clean up
1 parent f3c6ce5 commit 4adb81b

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

tests/utils/test_bucketed_weight_transfer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,10 @@
2828

2929
PROCESS_TIMEOUT = 60
3030

31+
# Use string checks to avoid initializing CUDA in the main pytest process,
32+
# which would make subsequent fork-based multiprocessing in other tests unsafe.
3133
HAS_ACCELERATOR = get_device_name() != "cpu"
32-
HAS_CUDA = torch.cuda.is_available()
34+
HAS_CUDA = "cuda" in get_device_name()
3335

3436

3537
def _unique_zmq_handle():

verl/workers/rollout/bucketed_weight_transfer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def create_shared_memory(size: int, name: str):
5858
shm = shared_memory.SharedMemory(name=name, create=True, size=size)
5959
except FileExistsError:
6060
shm = shared_memory.SharedMemory(name=name)
61+
assert shm.size >= size, f"Stale shm segment '{name}': expected {size} bytes, got {shm.size}"
6162
return shm
6263

6364

@@ -286,13 +287,15 @@ def _cleanup(self):
286287
if self.socket is not None:
287288
self.socket.close()
288289
self.socket = None
290+
# Synchronize before releasing the buffer to ensure all async ops
291+
# referencing it (e.g. clone, .to()) have completed.
292+
get_torch_device().synchronize()
289293
del self.buffer
290294
self.buffer = None
291295
if self.shm is not None:
292296
self.shm.close()
293297
del self.shm
294298
self.shm = None
295-
get_torch_device().synchronize()
296299
gc.collect()
297300
get_torch_device().ipc_collect()
298301
get_torch_device().empty_cache()

0 commit comments

Comments (0)