.pre-commit-config.yaml (2 changes: 1 addition & 1 deletion)

@@ -40,6 +40,6 @@ repos:
     hooks:
       - id: compileall
         name: Compile all python files
-        entry: sh -c 'PYTHONWARNINGS=error python3 -m compileall -q .'
+        entry: sh -c 'PYTHONWARNINGS=error python3 -m compileall -q . -x "\.venv|venv|\.git"'
         language: python
         pass_filenames: false
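For context on this fix: compileall's -x flag takes a regular expression that is searched against each file's path, and matching paths are skipped, so the hook no longer tries to byte-compile files under .venv/, venv/, or .git/. A minimal sketch of the equivalent call through the standard-library API (the target directory "." is illustrative):

    import re
    import compileall

    # Same pattern the hook now passes via -x: any path containing
    # .venv, venv, or .git is excluded from byte-compilation.
    exclude = re.compile(r"\.venv|venv|\.git")

    # quiet=1 mirrors -q; compile_dir returns a false value if any file
    # failed to compile, which is what makes the pre-commit hook fail.
    ok = compileall.compile_dir(".", rx=exclude, quiet=1)
    print("ok" if ok else "compilation errors found")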
verl/experimental/vla/fsdp_workers.py (1 change: 1 addition & 0 deletions)

@@ -305,6 +305,7 @@ def init_model(self):
             lr_scheduler=self.actor_lr_scheduler,
             processing_class=self.processor if self.processor is not None else self.tokenizer,
             checkpoint_config=self.config.actor.checkpoint,
+            trust_remote_code=self.config.model.trust_remote_code,
         )

         torch.distributed.barrier()
verl/trainer/fsdp_sft_trainer.py (1 change: 1 addition & 0 deletions)

@@ -609,6 +609,7 @@ def _init_checkpoint_manager(self):
             lr_scheduler=self.lr_scheduler,
             processing_class=self.tokenizer,
             checkpoint_config=checkpoint_config_dict,
+            trust_remote_code=self.config.model.trust_remote_code,
         )

     def load_checkpoint(self):
verl/utils/checkpoint/fsdp_checkpoint_manager.py (8 changes: 7 additions & 1 deletion)

@@ -70,6 +70,7 @@ class FSDPCheckpointManager(BaseCheckpointManager):
         checkpoint_contents DictConfig: Configuration for checkpoint contents.
             - 'load': Components to load; must contain 'model'. Defaults to ['model', 'optimizer', 'extra'].
             - 'save': Components to save; must contain 'model'. Defaults to ['model', 'optimizer', 'extra'].
+        trust_remote_code: Whether to trust remote code when loading the model configuration.
     """

     def __init__(

@@ -79,6 +80,7 @@ def __init__(
         lr_scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None,
         processing_class: PreTrainedTokenizer | ProcessorMixin = None,
         checkpoint_config: DictConfig = None,
+        trust_remote_code: bool = False,
         **kwargs,
     ):
         if processing_class is None and "tokenizer" in kwargs:

@@ -94,6 +96,7 @@ def __init__(
             processing_class=processing_class,
             checkpoint_config=checkpoint_config,
         )
+        self.trust_remote_code = trust_remote_code

     def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_after_load=False):
         """

@@ -333,7 +336,10 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: int
             raise NotImplementedError(f"Unknown architecture {model_config['architectures']}")

         with init_empty_weights():
-            save_model = auto_model_cls.from_config(model_config, torch_dtype=torch.bfloat16)
+            save_model = auto_model_cls.from_config(
+                model_config, torch_dtype=torch.bfloat16, trust_remote_code=self.trust_remote_code
+            )

         save_model.to_empty(device="cpu")

         if save_model.can_generate():
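The flag matters at exactly the from_config call patched above: when a checkpoint's model config points at a custom architecture whose modeling code ships with the model repository, transformers only imports and executes that code if trust_remote_code=True is passed through. A minimal sketch of the behavior outside verl (the repo id is a placeholder, not a real model):

    from transformers import AutoConfig, AutoModelForCausalLM

    # Loading the config of a custom architecture already requires the flag.
    config = AutoConfig.from_pretrained("some-org/custom-arch", trust_remote_code=True)

    # Mirrors the patched save_checkpoint: without trust_remote_code=True,
    # from_config refuses to run the repository's custom modeling code and raises.
    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)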
verl/workers/engine/fsdp/transformer_impl.py (1 change: 1 addition & 0 deletions)

@@ -164,6 +164,7 @@ def initialize(self):
             lr_scheduler=self.lr_scheduler,
             processing_class=self.model_config.get_processor(),
             checkpoint_config=self.checkpoint_config,
+            trust_remote_code=self.model_config.trust_remote_code,
         )

         self.to(
verl/workers/engine/veomni/transformer_impl.py (1 change: 1 addition & 0 deletions)

@@ -134,6 +134,7 @@ def initialize(self):
             lr_scheduler=self.lr_scheduler,
             processing_class=self.model_config.get_processor(),
             checkpoint_config=self.checkpoint_config,
+            trust_remote_code=self.model_config.trust_remote_code,
         )

         self.to(
verl/workers/fsdp_workers.py (4 changes: 3 additions & 1 deletion)

@@ -873,6 +873,7 @@ def init_model(self):
             lr_scheduler=self.actor_lr_scheduler,
             processing_class=self.processor if self.processor is not None else self.tokenizer,
             checkpoint_config=self.config.actor.checkpoint,
+            trust_remote_code=self.config.model.get("trust_remote_code", False),
         )

         if not self._is_actor and self._is_rollout:

@@ -1255,7 +1256,7 @@ def __init__(self, config: FSDPCriticConfig):
         )
         self.use_orig_params = self.config.model.fsdp_config.get("use_orig_params", False)

-    def _build_critic_model_optimizer(self, config):
+    def _build_critic_model_optimizer(self, config: FSDPCriticConfig):
         # the following line is necessary
         from torch.distributed.fsdp import MixedPrecision

@@ -1533,6 +1534,7 @@ def init_model(self):
             lr_scheduler=self.critic_lr_scheduler,
             processing_class=self.processor if self.processor is not None else self.tokenizer,
             checkpoint_config=self.config.checkpoint,
+            trust_remote_code=self.config.model.get("trust_remote_code", False),
         )

     @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="critic"))
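One detail worth noting across the call sites: the worker code reads the flag with self.config.model.get("trust_remote_code", False) rather than attribute access, so configs written before this change keep working. A small sketch of the difference, assuming an OmegaConf DictConfig (the type the diff already uses) with struct mode enabled, as Hydra-composed configs typically are; the model path is a placeholder:

    from omegaconf import OmegaConf

    # A config dating from before this change, with no trust_remote_code key.
    model_cfg = OmegaConf.create({"path": "some-org/custom-arch"})
    OmegaConf.set_struct(model_cfg, True)

    # .get falls back to the default instead of raising, so old configs still load.
    print(model_cfg.get("trust_remote_code", False))  # -> False

    # Attribute access on the missing key raises in struct mode, which is why
    # the worker call sites read the flag with .get and an explicit default.
    try:
        _ = model_cfg.trust_remote_code
    except Exception as exc:
        print(type(exc).__name__)  # omegaconf's ConfigAttributeError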