2424import pytest
2525import torch
2626
27- HAS_CUDA = torch .cuda .is_available ()
27+ from verl .utils .device import get_device_name , get_torch_device
28+
2829PROCESS_TIMEOUT = 60
2930
31+ HAS_ACCELERATOR = get_device_name () != "cpu"
32+ HAS_CUDA = torch .cuda .is_available ()
33+
3034
3135def _unique_zmq_handle ():
3236 return f"ipc:///tmp/test-bwt-{ uuid .uuid4 ().hex } .sock"
3337
3438
def _generate_weights(weight_specs, seed):
    """Deterministically generate weights on the best available device from specs.

    Args:
        weight_specs: list of (name, shape, dtype) tuples
        seed: random seed for reproducibility
    Returns:
        list of (name, tensor_on_device) tuples
    """
    target = torch.device(f"{get_device_name()}:0")
    # Seed the accelerator RNG so repeated calls with the same seed produce
    # identical tensors (the test regenerates expected weights this way).
    get_torch_device().manual_seed(seed)
    # torch.randn does not support every dtype, so sample in float32 and cast.
    return [
        (name, torch.randn(shape, dtype=torch.float32, device=target).to(dtype))
        for name, shape, dtype in weight_specs
    ]
5157
@@ -54,7 +60,7 @@ def _generate_weights(weight_specs, seed):
5460# Process entry points (must be module-level for pickling with spawn)
5561# ---------------------------------------------------------------------------
5662def _sender_fn (zmq_handle , weight_specs , seed , bucket_size_mb , use_shm ):
57- """Sender process: generate weights, move to GPU , send."""
63+ """Sender process: generate weights, move to device , send."""
5864 from verl .workers .rollout .bucketed_weight_transfer import BucketedWeightSender
5965
6066 weights = _generate_weights (weight_specs , seed )
@@ -68,9 +74,10 @@ def _sender_fn(zmq_handle, weight_specs, seed, bucket_size_mb, use_shm):
6874
6975def _receiver_fn (zmq_handle , use_shm , result_queue ):
7076 """Receiver process: receive weights, send back (name, dtype, shape, checksum)."""
77+ from verl .utils .device import get_device_name
7178 from verl .workers .rollout .bucketed_weight_transfer import BucketedWeightReceiver
7279
73- device = torch .device ("cuda :0" )
80+ device = torch .device (f" { get_device_name () } :0" )
7481 receiver = BucketedWeightReceiver (
7582 zmq_handle = zmq_handle ,
7683 device = device ,
@@ -114,7 +121,7 @@ def _transfer_and_validate(weight_specs, bucket_size_mb, use_shm):
114121
115122 summaries = result_queue .get (timeout = 5 )
116123
117- # Regenerate expected weights on CUDA with the same seed
124+ # Regenerate expected weights on device with the same seed
118125 expected = _generate_weights (weight_specs , seed )
119126
120127 assert len (summaries ) == len (expected ), f"Expected { len (expected )} weights, got { len (summaries )} "
@@ -136,7 +143,7 @@ def _transfer_and_validate(weight_specs, bucket_size_mb, use_shm):
136143# ---------------------------------------------------------------------------
137144# Shared memory tests
138145# ---------------------------------------------------------------------------
139- @pytest .mark .skipif (not HAS_CUDA , reason = "Requires CUDA" )
146+ @pytest .mark .skipif (not HAS_ACCELERATOR , reason = "Requires CUDA or NPU " )
140147class TestBucketedWeightTransferSHM :
141148 """Test BucketedWeightSender/Receiver via shared memory path."""
142149
@@ -170,9 +177,9 @@ def test_empty_weights(self):
170177
171178
172179# ---------------------------------------------------------------------------
173- # CUDA IPC tests
180+ # CUDA IPC tests (CUDA only — IPC is not supported on NPU)
174181# ---------------------------------------------------------------------------
175- @pytest .mark .skipif (not HAS_CUDA , reason = "Requires CUDA" )
182+ @pytest .mark .skipif (not HAS_CUDA , reason = "Requires CUDA (IPC not supported on NPU) " )
176183class TestBucketedWeightTransferIPC :
177184 """Test BucketedWeightSender/Receiver via CUDA IPC path."""
178185
0 commit comments