@@ -211,7 +211,9 @@ async def set_max_required_samples(self):
211211 / (self .required_samples * self .config .async_training .trigger_parameter_sync_step )
212212 )
213213
214- self .max_concurrent_samples = len (self .async_rollout_manager .server_handles ) * self .config .rollout .get ("max_concurrent_samples_per_replica" , 16 )
214+ self .max_concurrent_samples = len (self .async_rollout_manager .server_handles ) * self .config .rollout .get (
215+ "max_concurrent_samples_per_replica" , 16
216+ )
215217 self .max_concurrent_samples = min (self .max_concurrent_samples , self .max_required_samples )
216218 self .max_queue_size = self .max_required_samples
217219
@@ -548,9 +550,7 @@ async def _processor_worker(self):
548550 tasks_to_wait = set (self .active_tasks ) if self .active_tasks else set ()
549551
550552 if tasks_to_wait :
551- done_tasks , _ = await asyncio .wait (
552- tasks_to_wait , return_when = asyncio .FIRST_COMPLETED
553- )
553+ done_tasks , _ = await asyncio .wait (tasks_to_wait , return_when = asyncio .FIRST_COMPLETED )
554554 for task in done_tasks :
555555 await task
556556
@@ -564,9 +564,7 @@ async def _processor_worker(self):
564564 tasks_to_wait = set (self .active_tasks ) if self .active_tasks else set ()
565565
566566 if tasks_to_wait :
567- done_tasks , _ = await asyncio .wait (
568- tasks_to_wait , return_when = asyncio .FIRST_COMPLETED
569- )
567+ done_tasks , _ = await asyncio .wait (tasks_to_wait , return_when = asyncio .FIRST_COMPLETED )
570568 for task in done_tasks :
571569 await task
572570
0 commit comments