Skip to content

Commit 5b5c527

Browse files
committed
move file
1 parent f46b033 commit 5b5c527

File tree

5 files changed

+6
-6
lines changed

5 files changed

+6
-6
lines changed

tests/utils/test_bucketed_weight_transfer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def _generate_weights(weight_specs, seed):
5555
# ---------------------------------------------------------------------------
5656
def _sender_fn(zmq_handle, weight_specs, seed, bucket_size_mb, use_shm):
5757
"""Sender process: generate weights, move to GPU, send."""
58-
from verl.workers.rollout.vllm_rollout.bucketed_weight_transfer import BucketedWeightSender
58+
from verl.workers.rollout.bucketed_weight_transfer import BucketedWeightSender
5959

6060
weights = _generate_weights(weight_specs, seed)
6161
sender = BucketedWeightSender(
@@ -68,7 +68,7 @@ def _sender_fn(zmq_handle, weight_specs, seed, bucket_size_mb, use_shm):
6868

6969
def _receiver_fn(zmq_handle, use_shm, result_queue):
7070
"""Receiver process: receive weights, send back (name, dtype, shape, checksum)."""
71-
from verl.workers.rollout.vllm_rollout.bucketed_weight_transfer import BucketedWeightReceiver
71+
from verl.workers.rollout.bucketed_weight_transfer import BucketedWeightReceiver
7272

7373
device = torch.device("cuda:0")
7474
receiver = BucketedWeightReceiver(

tests/utils/test_shared_memory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import torch
2020

21-
from verl.workers.rollout.vllm_rollout.bucketed_weight_transfer import create_shared_memory, rebuild_shared_memory
21+
from verl.workers.rollout.bucketed_weight_transfer import create_shared_memory, rebuild_shared_memory
2222

2323

2424
class TestSharedMemory(unittest.TestCase):

verl/workers/rollout/vllm_rollout/bucketed_weight_transfer.py renamed to verl/workers/rollout/bucketed_weight_transfer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def _init_buffer(self):
169169

170170
# Create unique name for shared memory
171171
shm_name = f"verl_weights_{uuid.uuid4().hex}"
172-
shm = shared_memory.SharedMemory(name=shm_name, create=True, size=self.bucket_size)
172+
shm = create_shared_memory(self.bucket_size, shm_name)
173173
buffer = torch.frombuffer(shm.buf, dtype=torch.uint8)
174174

175175
comm_metadata = {"name": shm_name, "size": self.bucket_size}

verl/workers/rollout/vllm_rollout/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False
145145
"""Update the weights of the rollout model."""
146146
from vllm.platforms import current_platform
147147

148-
from verl.workers.rollout.vllm_rollout.bucketed_weight_transfer import BucketedWeightReceiver
148+
from verl.workers.rollout.bucketed_weight_transfer import BucketedWeightReceiver
149149

150150
if current_platform.device_type == "npu" and self.device is None:
151151
self.device = torch.device(f"npu:{self.local_rank}")

verl/workers/rollout/vllm_rollout/vllm_rollout.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
from verl.utils.device import get_device_id, is_support_ipc
4242
from verl.workers.config import HFModelConfig, RolloutConfig
4343
from verl.workers.rollout.base import BaseRollout
44-
from verl.workers.rollout.vllm_rollout.bucketed_weight_transfer import BucketedWeightSender
44+
from verl.workers.rollout.bucketed_weight_transfer import BucketedWeightSender
4545
from verl.workers.rollout.vllm_rollout.utils import get_device_uuid
4646

4747
logger = logging.getLogger(__file__)

0 commit comments

Comments
 (0)