Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions verl/experimental/agent_loop/agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def __init__(
tokenizer: AutoTokenizer,
processor: AutoProcessor,
dataset_cls: type[RLHFDataset],
dataset_config: DictConfig,
dataset_config: DictConfigWrap,
**kwargs,
):
"""Initialize agent loop, each sample will have its own loop instance.
Expand All @@ -211,15 +211,15 @@ def __init__(
tokenizer (AutoTokenizer): Tokenizer for tokenize messages.
processor (AutoProcessor): Processor for process messages.
dataset_cls (type[Dataset]): Dataset class for creating dataset, Defaults to RLHFDataset.
dataset_config (DictConfig): Dataset config.
dataset_config (DictConfigWrap): Dataset config.
"""
self.config = trainer_config.config
self.server_manager = server_manager
self.tokenizer = tokenizer
self.processor = processor
self.dataset_cls = dataset_cls
self.dataset_config = dataset_config
self.apply_chat_template_kwargs = dataset_config.get("apply_chat_template_kwargs", {})
self.dataset_config = dataset_config.config
self.apply_chat_template_kwargs = self.dataset_config.get("apply_chat_template_kwargs", {})
self.system_prompt = initialize_system_prompt(self.tokenizer, **self.apply_chat_template_kwargs)
self.loop = get_event_loop()

Expand Down Expand Up @@ -513,7 +513,7 @@ async def _run_agent_loop(
tokenizer=self.tokenizer,
processor=self.processor,
dataset_cls=self.dataset_cls,
dataset_config=self.config.data,
dataset_config=DictConfigWrap(self.config.data),
)
output: AgentLoopOutput = await agent_loop.run(sampling_params, **kwargs)
return await self._agent_loop_postprocess(output, **kwargs)
Expand Down
Loading