Skip to content

Commit 190047d

Browse files
committed
update test
1 parent e41f903 commit 190047d

File tree

1 file changed

+19
-12
lines changed

1 file changed

+19
-12
lines changed

tests/utils/test_bucketed_weight_transfer.py

Lines changed: 19 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -24,28 +24,34 @@
2424
import pytest
2525
import torch
2626

27-
HAS_CUDA = torch.cuda.is_available()
27+
from verl.utils.device import get_device_name, get_torch_device
28+
2829
PROCESS_TIMEOUT = 60
2930

31+
HAS_ACCELERATOR = get_device_name() != "cpu"
32+
HAS_CUDA = torch.cuda.is_available()
33+
3034

3135
def _unique_zmq_handle():
3236
return f"ipc:///tmp/test-bwt-{uuid.uuid4().hex}.sock"
3337

3438

3539
def _generate_weights(weight_specs, seed):
36-
"""Deterministically generate weights on CUDA from specs.
40+
"""Deterministically generate weights on the best available device from specs.
3741
3842
Args:
3943
weight_specs: list of (name, shape, dtype) tuples
4044
seed: random seed for reproducibility
4145
Returns:
42-
list of (name, tensor_on_cuda) tuples
46+
list of (name, tensor_on_device) tuples
4347
"""
44-
torch.cuda.manual_seed(seed)
48+
device_name = get_device_name()
49+
device = torch.device(f"{device_name}:0")
50+
get_torch_device().manual_seed(seed)
4551
weights = []
4652
for name, shape, dtype in weight_specs:
47-
# Generate in float32 on CUDA then cast, since torch.randn doesn't support all dtypes
48-
t = torch.randn(shape, dtype=torch.float32, device="cuda:0").to(dtype)
53+
# Generate in float32 then cast, since torch.randn doesn't support all dtypes
54+
t = torch.randn(shape, dtype=torch.float32, device=device).to(dtype)
4955
weights.append((name, t))
5056
return weights
5157

@@ -54,7 +60,7 @@ def _generate_weights(weight_specs, seed):
5460
# Process entry points (must be module-level for pickling with spawn)
5561
# ---------------------------------------------------------------------------
5662
def _sender_fn(zmq_handle, weight_specs, seed, bucket_size_mb, use_shm):
57-
"""Sender process: generate weights, move to GPU, send."""
63+
"""Sender process: generate weights, move to device, send."""
5864
from verl.workers.rollout.bucketed_weight_transfer import BucketedWeightSender
5965

6066
weights = _generate_weights(weight_specs, seed)
@@ -68,9 +74,10 @@ def _sender_fn(zmq_handle, weight_specs, seed, bucket_size_mb, use_shm):
6874

6975
def _receiver_fn(zmq_handle, use_shm, result_queue):
7076
"""Receiver process: receive weights, send back (name, dtype, shape, checksum)."""
77+
from verl.utils.device import get_device_name
7178
from verl.workers.rollout.bucketed_weight_transfer import BucketedWeightReceiver
7279

73-
device = torch.device("cuda:0")
80+
device = torch.device(f"{get_device_name()}:0")
7481
receiver = BucketedWeightReceiver(
7582
zmq_handle=zmq_handle,
7683
device=device,
@@ -114,7 +121,7 @@ def _transfer_and_validate(weight_specs, bucket_size_mb, use_shm):
114121

115122
summaries = result_queue.get(timeout=5)
116123

117-
# Regenerate expected weights on CUDA with the same seed
124+
# Regenerate expected weights on device with the same seed
118125
expected = _generate_weights(weight_specs, seed)
119126

120127
assert len(summaries) == len(expected), f"Expected {len(expected)} weights, got {len(summaries)}"
@@ -136,7 +143,7 @@ def _transfer_and_validate(weight_specs, bucket_size_mb, use_shm):
136143
# ---------------------------------------------------------------------------
137144
# Shared memory tests
138145
# ---------------------------------------------------------------------------
139-
@pytest.mark.skipif(not HAS_CUDA, reason="Requires CUDA")
146+
@pytest.mark.skipif(not HAS_ACCELERATOR, reason="Requires CUDA or NPU")
140147
class TestBucketedWeightTransferSHM:
141148
"""Test BucketedWeightSender/Receiver via shared memory path."""
142149

@@ -170,9 +177,9 @@ def test_empty_weights(self):
170177

171178

172179
# ---------------------------------------------------------------------------
173-
# CUDA IPC tests
180+
# CUDA IPC tests (CUDA only — IPC is not supported on NPU)
174181
# ---------------------------------------------------------------------------
175-
@pytest.mark.skipif(not HAS_CUDA, reason="Requires CUDA")
182+
@pytest.mark.skipif(not HAS_CUDA, reason="Requires CUDA (IPC not supported on NPU)")
176183
class TestBucketedWeightTransferIPC:
177184
"""Test BucketedWeightSender/Receiver via CUDA IPC path."""
178185

0 commit comments

Comments (0)