2020import numpy as np
2121import ray
2222import torch
23- import torch .nn .functional as F
24- from omegaconf import DictConfig , open_dict
23+ from omegaconf import DictConfig
2524from tensordict import TensorDict
2625
2726from verl .protocol import DataProto
2827from verl .single_controller .ray .base import RayResourcePool
29- from verl .trainer .ppo .reward import load_reward_manager
30- from verl .utils import hf_tokenizer
31- from verl .utils .fs import copy_to_local
28+ from verl .trainer .distillation .losses import DistillationLossSettings , get_distillation_loss_settings
3229from verl .utils .config import omega_conf_to_dataclass
33-
34- from .teacher_model import TeacherModelManager
3530from verl .workers .config import DistillationConfig , DistillationLossConfig
3631
37- from verl . trainer . distillation . losses import get_distillation_loss_settings , DistillationLossSettings
32+ from . teacher_model import TeacherModelManager
3833
3934logger = logging .getLogger (__file__ )
4035logger .setLevel (os .getenv ("VERL_LOGGING_LEVEL" , "WARN" ))
@@ -52,7 +47,9 @@ def __init__(self, config: DictConfig, teacher_router_address: str = None):
5247 self .config = config
5348 self .distillation_config : DistillationConfig = self .config .distillation
5449 self .distillation_loss_config : DistillationLossConfig = self .distillation_config .distillation_loss
55- self .distillation_loss_settings : DistillationLossSettings = get_distillation_loss_settings (self .distillation_loss_config .loss_mode )
50+ self .distillation_loss_settings : DistillationLossSettings = get_distillation_loss_settings (
51+ self .distillation_loss_config .loss_mode
52+ )
5653 self .teacher_router_address = teacher_router_address
5754 # # Serialize teacher requests per actor to reduce pressure on the teacher vLLM router/backend.
5855 # self._request_semaphore = asyncio.Semaphore(1)
@@ -112,16 +109,16 @@ async def _post_request(self, payload: dict, endpoint: str, max_retries: int = 1
112109 raise last_exception
113110
114111 async def _compute_logprobs (self , data : DataProto ) -> dict :
115- prompt_ids = data .batch [' prompt_ids' ]
116- response_ids = data .batch [' response_ids' ]
112+ prompt_ids = data .batch [" prompt_ids" ]
113+ response_ids = data .batch [" response_ids" ]
117114 input_ids = torch .cat ([prompt_ids , response_ids ], dim = 1 ).squeeze (0 ).tolist ()
118115 engine_name = self .config .distillation .teacher_model .inference .name
119116 model_name = self .config .distillation .teacher_model .model_path
120117 if engine_name == "vllm" :
121118 if self .distillation_loss_settings .use_topk :
122119 num_logprobs = topk = self .distillation_loss_config .topk
123120 else :
124- num_logprobs = 0 # only the sampled logprob
121+ num_logprobs = 0 # only the sampled logprob
125122 payloads = {
126123 "model" : model_name ,
127124 "prompt" : input_ids ,
@@ -140,7 +137,7 @@ async def _compute_logprobs(self, data: DataProto) -> dict:
140137 for logprobs_dict in response_logprob_dicts :
141138 if num_logprobs == 0 :
142139 token_id_str = list (logprobs_dict .keys ())[0 ]
143- logprob = logprobs_dict [token_id_str ][' logprob' ]
140+ logprob = logprobs_dict [token_id_str ][" logprob" ]
144141 response_logprobs_ls .append ([logprob ])
145142 response_ids_ls .append ([int (token_id_str )])
146143 else :
@@ -149,18 +146,22 @@ async def _compute_logprobs(self, data: DataProto) -> dict:
149146 # We get either top-k logprobs or top-k plus the sampled logprob (if sampled token is not in top-k)
150147 assert len (logprobs_dict ) in [topk , topk + 1 ], len (logprobs_dict )
151148 for token_id_str , token_dict in logprobs_dict .items ():
152- if token_dict [' rank' ] > topk :
153- continue # the sampled token is not in the top-k
154- rank = token_dict [' rank' ]
155- logprob = token_dict [' logprob' ]
149+ if token_dict [" rank" ] > topk :
150+ continue # the sampled token is not in the top-k
151+ rank = token_dict [" rank" ]
152+ logprob = token_dict [" logprob" ]
156153 response_ids [rank - 1 ] = int (token_id_str )
157154 response_logprobs [rank - 1 ] = logprob
158155 response_logprobs_ls .append (response_logprobs )
159156 response_ids_ls .append (response_ids )
160- logprobs_dtype = torch .bfloat16 if self .distillation_config .teacher_model .inference .dtype == "bfloat16" else torch .float32
157+ logprobs_dtype = (
158+ torch .bfloat16
159+ if self .distillation_config .teacher_model .inference .dtype == "bfloat16"
160+ else torch .float32
161+ )
161162 response_logprobs = torch .tensor (response_logprobs_ls , dtype = logprobs_dtype ).unsqueeze (0 )
162163 response_ids = torch .tensor (response_ids_ls , dtype = torch .long ).unsqueeze (0 )
163-
164+
164165 elif engine_name == "sglang" :
165166 raise ValueError ("SGLang backend does not support distillation currently." )
166167 payloads = {
@@ -200,7 +201,9 @@ class TeacherLoopManager:
200201
201202 def __init__ (self , config : DictConfig , teacher_resource_pool : RayResourcePool = None ):
202203 self .config = config
203- self .distillation_config : DistillationConfig = omega_conf_to_dataclass (self .config .distillation ) # to dataclass for the post init to handle top-k and engine kwargs
204+ self .distillation_config : DistillationConfig = omega_conf_to_dataclass (
205+ self .config .distillation
206+ ) # to dataclass for the post init to handle top-k and engine kwargs
204207 self .teacher_model_manager = TeacherModelManager (self .distillation_config .teacher_model , teacher_resource_pool )
205208 self .teacher_router_address = self .teacher_model_manager .get_router_address ()
206209
@@ -266,6 +269,7 @@ def compute_teacher_logprobs(self, data: DataProto) -> DataProto:
266269
267270 def _run_all (self , tasks : list [asyncio .Task ]):
268271 raise NotImplementedError ("TODO:RM" )
272+
269273 async def run_all ():
270274 return await asyncio .gather (* tasks )
271275
0 commit comments