Skip to content

Commit 5383cb2

Browse files
committed
Ruff
1 parent: 99e294f; commit: 5383cb2

File tree

15 files changed, with 71 additions and 55 deletions.

verl/experimental/agent_loop/agent_loop.py

Lines changed: 16 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -478,7 +478,9 @@ async def generate_sequences(self, batch: DataProto) -> DataProto:
478478
)
479479
outputs = await asyncio.gather(*tasks)
480480

481-
output = self._postprocess(outputs, input_non_tensor_batch=batch.non_tensor_batch, validate=batch.meta_info.get("validate", False))
481+
output = self._postprocess(
482+
outputs, input_non_tensor_batch=batch.non_tensor_batch, validate=batch.meta_info.get("validate", False)
483+
)
482484
return output
483485

484486
async def _run_agent_loop(
@@ -736,10 +738,18 @@ async def _compute_score(self, output, prompts, responses, attention_mask, input
736738
async def _compute_teacher_logprobs(self, output, prompt_ids, response_ids, validate):
737739
"""Compute teacher logprobs for single sample."""
738740
if self.distillation_enabled and not validate:
739-
data = DataProto(batch=TensorDict({"prompt_ids": torch.tensor([prompt_ids]), "response_ids": torch.tensor([response_ids])}, batch_size=1))
741+
data = DataProto(
742+
batch=TensorDict(
743+
{"prompt_ids": torch.tensor([prompt_ids]), "response_ids": torch.tensor([response_ids])},
744+
batch_size=1,
745+
)
746+
)
740747
selected_teacher_loop_worker_handle = random.choice(self.teacher_loop_worker_handles)
741748
result = await selected_teacher_loop_worker_handle.compute_logprobs.remote(data)
742-
response_ids, response_logprobs = result["response_ids"], result["response_logprobs"] # (1, S, K), S=sequence length, K=topk/1
749+
response_ids, response_logprobs = (
750+
result["response_ids"],
751+
result["response_logprobs"],
752+
) # (1, S, K), S=sequence length, K=topk/1
743753

744754
pad_size = self.config.actor_rollout_ref.rollout.response_length - response_ids.shape[1]
745755
padding = (0, 0, 0, pad_size) # pad the sequence dimension
@@ -976,7 +986,9 @@ def _init_agent_loop_workers(self):
976986
scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
977987
node_id=node_id, soft=True
978988
),
979-
).remote(self.config, self.server_handles, self.reward_loop_worker_handles, self.teacher_loop_worker_handles)
989+
).remote(
990+
self.config, self.server_handles, self.reward_loop_worker_handles, self.teacher_loop_worker_handles
991+
)
980992
)
981993

982994
def generate_sequences(self, prompts: DataProto) -> DataProto:

verl/experimental/teacher_loop/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from .teacher_loop import TeacherLoopManager, TeacherLoopWorker
16-
from .teacher_loop import TeacherModelManager
15+
from .teacher_loop import TeacherLoopManager, TeacherLoopWorker, TeacherModelManager
1716

1817
__all__ = ["TeacherModelManager", "TeacherLoopWorker", "TeacherLoopManager"]

verl/experimental/teacher_loop/teacher_loop.py

Lines changed: 24 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -20,21 +20,16 @@
2020
import numpy as np
2121
import ray
2222
import torch
23-
import torch.nn.functional as F
24-
from omegaconf import DictConfig, open_dict
23+
from omegaconf import DictConfig
2524
from tensordict import TensorDict
2625

2726
from verl.protocol import DataProto
2827
from verl.single_controller.ray.base import RayResourcePool
29-
from verl.trainer.ppo.reward import load_reward_manager
30-
from verl.utils import hf_tokenizer
31-
from verl.utils.fs import copy_to_local
28+
from verl.trainer.distillation.losses import DistillationLossSettings, get_distillation_loss_settings
3229
from verl.utils.config import omega_conf_to_dataclass
33-
34-
from .teacher_model import TeacherModelManager
3530
from verl.workers.config import DistillationConfig, DistillationLossConfig
3631

37-
from verl.trainer.distillation.losses import get_distillation_loss_settings, DistillationLossSettings
32+
from .teacher_model import TeacherModelManager
3833

3934
logger = logging.getLogger(__file__)
4035
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
@@ -52,7 +47,9 @@ def __init__(self, config: DictConfig, teacher_router_address: str = None):
5247
self.config = config
5348
self.distillation_config: DistillationConfig = self.config.distillation
5449
self.distillation_loss_config: DistillationLossConfig = self.distillation_config.distillation_loss
55-
self.distillation_loss_settings: DistillationLossSettings = get_distillation_loss_settings(self.distillation_loss_config.loss_mode)
50+
self.distillation_loss_settings: DistillationLossSettings = get_distillation_loss_settings(
51+
self.distillation_loss_config.loss_mode
52+
)
5653
self.teacher_router_address = teacher_router_address
5754
# # Serialize teacher requests per actor to reduce pressure on the teacher vLLM router/backend.
5855
# self._request_semaphore = asyncio.Semaphore(1)
@@ -112,16 +109,16 @@ async def _post_request(self, payload: dict, endpoint: str, max_retries: int = 1
112109
raise last_exception
113110

114111
async def _compute_logprobs(self, data: DataProto) -> dict:
115-
prompt_ids = data.batch['prompt_ids']
116-
response_ids = data.batch['response_ids']
112+
prompt_ids = data.batch["prompt_ids"]
113+
response_ids = data.batch["response_ids"]
117114
input_ids = torch.cat([prompt_ids, response_ids], dim=1).squeeze(0).tolist()
118115
engine_name = self.config.distillation.teacher_model.inference.name
119116
model_name = self.config.distillation.teacher_model.model_path
120117
if engine_name == "vllm":
121118
if self.distillation_loss_settings.use_topk:
122119
num_logprobs = topk = self.distillation_loss_config.topk
123120
else:
124-
num_logprobs = 0 # only the sampled logprob
121+
num_logprobs = 0 # only the sampled logprob
125122
payloads = {
126123
"model": model_name,
127124
"prompt": input_ids,
@@ -140,7 +137,7 @@ async def _compute_logprobs(self, data: DataProto) -> dict:
140137
for logprobs_dict in response_logprob_dicts:
141138
if num_logprobs == 0:
142139
token_id_str = list(logprobs_dict.keys())[0]
143-
logprob = logprobs_dict[token_id_str]['logprob']
140+
logprob = logprobs_dict[token_id_str]["logprob"]
144141
response_logprobs_ls.append([logprob])
145142
response_ids_ls.append([int(token_id_str)])
146143
else:
@@ -149,18 +146,22 @@ async def _compute_logprobs(self, data: DataProto) -> dict:
149146
# We get either top-k logprobs or top-k plus the sampled logprob (if sampled token is not in top-k)
150147
assert len(logprobs_dict) in [topk, topk + 1], len(logprobs_dict)
151148
for token_id_str, token_dict in logprobs_dict.items():
152-
if token_dict['rank'] > topk:
153-
continue # the sampled token is not in the top-k
154-
rank = token_dict['rank']
155-
logprob = token_dict['logprob']
149+
if token_dict["rank"] > topk:
150+
continue # the sampled token is not in the top-k
151+
rank = token_dict["rank"]
152+
logprob = token_dict["logprob"]
156153
response_ids[rank - 1] = int(token_id_str)
157154
response_logprobs[rank - 1] = logprob
158155
response_logprobs_ls.append(response_logprobs)
159156
response_ids_ls.append(response_ids)
160-
logprobs_dtype = torch.bfloat16 if self.distillation_config.teacher_model.inference.dtype == "bfloat16" else torch.float32
157+
logprobs_dtype = (
158+
torch.bfloat16
159+
if self.distillation_config.teacher_model.inference.dtype == "bfloat16"
160+
else torch.float32
161+
)
161162
response_logprobs = torch.tensor(response_logprobs_ls, dtype=logprobs_dtype).unsqueeze(0)
162163
response_ids = torch.tensor(response_ids_ls, dtype=torch.long).unsqueeze(0)
163-
164+
164165
elif engine_name == "sglang":
165166
raise ValueError("SGLang backend does not support distillation currently.")
166167
payloads = {
@@ -200,7 +201,9 @@ class TeacherLoopManager:
200201

201202
def __init__(self, config: DictConfig, teacher_resource_pool: RayResourcePool = None):
202203
self.config = config
203-
self.distillation_config: DistillationConfig = omega_conf_to_dataclass(self.config.distillation) # to dataclass for the post init to handle top-k and engine kwargs
204+
self.distillation_config: DistillationConfig = omega_conf_to_dataclass(
205+
self.config.distillation
206+
) # to dataclass for the post init to handle top-k and engine kwargs
204207
self.teacher_model_manager = TeacherModelManager(self.distillation_config.teacher_model, teacher_resource_pool)
205208
self.teacher_router_address = self.teacher_model_manager.get_router_address()
206209

@@ -266,6 +269,7 @@ def compute_teacher_logprobs(self, data: DataProto) -> DataProto:
266269

267270
def _run_all(self, tasks: list[asyncio.Task]):
268271
raise NotImplementedError("TODO:RM")
272+
269273
async def run_all():
270274
return await asyncio.gather(*tasks)
271275

verl/experimental/teacher_loop/teacher_model.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717
import os
1818

1919
from verl.single_controller.ray.base import RayResourcePool, split_resource_pool
20-
from verl.workers.config import HFModelConfig, DistillationTeacherModelConfig
20+
from verl.workers.config import DistillationTeacherModelConfig, HFModelConfig
2121
from verl.workers.rollout.replica import get_rollout_replica_class
2222

2323
logger = logging.getLogger(__file__)
@@ -91,6 +91,7 @@ def _initialize_router(self):
9191
worker_urls = [f"http://{server_address}" for server_address in self.server_addresses]
9292

9393
from ..reward_loop.router.naive_router import launch_router_process
94+
9495
self.router_address, _ = launch_router_process(worker_urls=worker_urls)
9596

9697
def get_router_address(self):

verl/trainer/distillation/fsdp/losses.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@ def kl_divergence(log_q: torch.Tensor, log_p: torch.Tensor) -> torch.Tensor:
2727
kld = p * (log_p - log_q)
2828
return kld.sum(dim=-1)
2929

30+
3031
def compute_forward_kl_topk(
3132
student_logits: torch.Tensor,
3233
teacher_topk_log_probs: torch.Tensor,

verl/trainer/distillation/losses.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@
2222
from verl.trainer.distillation.types import DistillationLossInputs
2323
from verl.trainer.ppo.core_algos import agg_loss, kl_penalty
2424
from verl.utils.metric import AggregationType, Metric
25-
from verl.workers.config import DistillationConfig, DistillationLossConfig, ActorConfig
25+
from verl.workers.config import ActorConfig, DistillationConfig, DistillationLossConfig
2626

2727
DistillationLossFn = Callable[
2828
[
@@ -56,8 +56,7 @@ def __post_init__(self):
5656
self.names = [self.names] if isinstance(self.names, str) else self.names
5757
if sum([self.use_topk, self.use_estimator]) > 1:
5858
raise ValueError(
59-
f"Expected only one of use_estimator, use_topk, but got "
60-
f"{self.use_estimator=}, {self.use_topk=}."
59+
f"Expected only one of use_estimator, use_topk, but got {self.use_estimator=}, {self.use_topk=}."
6160
)
6261

6362

@@ -206,7 +205,11 @@ def compute_forward_kl_topk(
206205
teacher_topk_ids=teacher_topk_ids,
207206
config=distillation_config,
208207
)
209-
distillation_losses, student_mass, teacher_mass = outputs["distillation_losses"], outputs["student_mass"], outputs["teacher_mass"]
208+
distillation_losses, student_mass, teacher_mass = (
209+
outputs["distillation_losses"],
210+
outputs["student_mass"],
211+
outputs["teacher_mass"],
212+
)
210213

211214
# Log amount of mass in the top-k log probabilities for both student and teacher.
212215
student_mass = student_mass[response_mask]

verl/trainer/main_ppo.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -242,7 +242,9 @@ def init_resource_pool_mgr(self, config):
242242
if distillation_config.teacher_model.nnodes <= 0:
243243
raise ValueError("config.distillation.teacher_model.nnodes must be greater than 0")
244244

245-
teacher_pool = [distillation_config.teacher_model.n_gpus_per_node] * distillation_config.teacher_model.nnodes
245+
teacher_pool = [
246+
distillation_config.teacher_model.n_gpus_per_node
247+
] * distillation_config.teacher_model.nnodes
246248
resource_pool_spec["teacher_pool"] = teacher_pool
247249

248250
from verl.trainer.ppo.ray_trainer import ResourcePoolManager
@@ -274,7 +276,6 @@ def add_teacher_model_resource_pool(self, config):
274276
else:
275277
self.mapping[Role.TeacherModel] = "global_pool"
276278

277-
278279
def add_ref_policy_worker(self, config, ref_policy_cls):
279280
"""Add reference policy worker if KL loss or KL reward is used."""
280281
from verl.trainer.ppo.ray_trainer import Role
@@ -462,4 +463,4 @@ def create_rl_sampler(data_config, dataset):
462463

463464

464465
if __name__ == "__main__":
465-
main()
466+
main()

verl/trainer/ppo/ray_trainer.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -38,7 +38,7 @@
3838
from verl.experimental.dataset.sampler import AbstractCurriculumSampler
3939
from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
4040
from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup, ResourcePoolManager
41-
from verl.single_controller.ray.base import create_colocated_worker_cls, split_resource_pool
41+
from verl.single_controller.ray.base import create_colocated_worker_cls
4242
from verl.trainer.config import AlgoConfig
4343
from verl.trainer.ppo import core_algos
4444
from verl.trainer.ppo.core_algos import AdvantageEstimator, agg_loss
@@ -54,9 +54,9 @@
5454
Role,
5555
WorkerType,
5656
need_critic,
57-
need_teacher_policy,
5857
need_reference_policy,
5958
need_reward_model,
59+
need_teacher_policy,
6060
)
6161
from verl.utils import tensordict_utils as tu
6262
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, should_save_ckpt_esi
@@ -744,7 +744,6 @@ def init_workers(self):
744744
)
745745
self.resource_pool_to_cls[resource_pool][str(Role.RefPolicy)] = ref_policy_cls
746746

747-
748747
# initialize WorkerGroup
749748
# NOTE: if you want to use a different resource pool for each role, which can support different parallel size,
750749
# you should not use `create_colocated_worker_cls`.
@@ -1150,7 +1149,6 @@ def _compute_ref_log_prob(self, batch: DataProto) -> DataProto:
11501149

11511150
return ref_log_prob
11521151

1153-
11541152
def _compute_old_log_prob(self, batch: DataProto):
11551153
if self.use_legacy_worker_impl == "disable":
11561154
# TODO: remove step 1, 2, 4 after we make the whole training tensordict and padding free

verl/utils/config.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -199,4 +199,4 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str):
199199

200200
get_vllm_max_lora_rank(lora_rank)
201201

202-
print("[validate_config] All configuration checks passed successfully!")
202+
print("[validate_config] All configuration checks passed successfully!")

verl/utils/stages.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -20,4 +20,4 @@ class Stage(Enum):
2020

2121
OLD_LOG_PROB = "old_log_prob"
2222
REF_LOG_PROB = "ref_log_prob"
23-
ACTOR_UPDATE = "actor_update"
23+
ACTOR_UPDATE = "actor_update"

0 commit comments

Comments
 (0)