[fsdp] fix: Support trust_remote_code during FSDP HuggingFace checkpoint save #5200

Merged
HollowMan6 merged 7 commits into verl-project:main from thvasilo:fix-saving-with-remote-code on Feb 6, 2026

Conversation

thvasilo (Contributor) commented on Feb 5, 2026

What does this PR do?

Fixes #5214

Ensures that trust_remote_code is passed correctly to auto_model_cls.from_config when the model requires remote code during saving.
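
For context, here is a minimal sketch of the fixed call (the from_config invocation with torch_dtype is taken from the traceback below; the surrounding helper is a hypothetical simplification, and AutoModelForCausalLM stands in for the generic auto_model_cls):

import torch
from transformers import AutoModelForCausalLM

def build_save_model(model_config, trust_remote_code: bool = False):
    # Hypothetical helper: before the fix the flag was dropped here, so
    # transformers raised a ValueError for repositories that ship custom
    # modeling code. Forwarding it lets from_config execute that code.
    return AutoModelForCausalLM.from_config(
        model_config,
        torch_dtype=torch.bfloat16,
        trust_remote_code=trust_remote_code,
    )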

While testing GRPO full fine-tuning of the nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 model, I got:

During handling of the above exception, another exception occurred:

ray::WorkerDict.actor_rollout_save_checkpoint() (pid=8158, ip=10.0.65.133, actor_id=3b07fc0198f8a1a8292d741302000000, repr=<verl.single_controller.ray.base.WorkerDict object at 0x7ef34d430ad0>)
  File "/usr/lib/python3.12/concurrent/futures/_base.py", line 456, in result
    return self.__get_result()
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/concurrent/futures/_base.py", line 401, in __get_result
    raise self._exception
  File "/usr/local/lib/python3.12/dist-packages/verl/single_controller/ray/base.py", line 841, in func
    return getattr(self.worker_dict[key], name)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/verl/single_controller/base/decorator.py", line 456, in inner
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/verl/utils/transferqueue_utils.py", line 314, in dummy_inner
    output = func(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/verl/workers/fsdp_workers.py", line 1086, in save_checkpoint
    self.checkpoint_manager.save_checkpoint(
  File "/usr/local/lib/python3.12/dist-packages/verl/utils/checkpoint/fsdp_checkpoint_manager.py", line 345, in save_checkpoint
    save_model = auto_model_cls.from_config(model_config, torch_dtype=torch.bfloat16)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformers/models/auto/auto_factory.py", line 435, in from_config
    trust_remote_code = resolve_trust_remote_code(
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformers/dynamic_module_utils.py", line 769, in resolve_trust_remote_code
    raise ValueError(
ValueError: The repository /opt/dlami/nvme/models/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 contains custom code which must be executed to correctly load the model. You can inspect the repository content at /opt/dlami/nvme/models/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 .
 You can inspect the repository content at https://hf.co//opt/dlami/models/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16.

To replicate the issue, use an 80GB+ GPU and run:

MODEL_PATH="nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16"

python -m verl.trainer.main_ppo \
    --config-path "${HYDRA_CONFIG_PATH}" \
    --config-name "${HYDRA_CONFIG_NAME}" \
    data.train_files="${TRAIN_FILE}" \
    data.val_files="${VAL_FILE}" \
    data.max_prompt_length=${max_prompt_length} \
    data.max_response_length=${max_response_length} \
    data.train_batch_size=${train_batch_size} \
    data.prompt_key=${prompt_key} \
    algorithm.adv_estimator=${adv_estimator} \
    algorithm.kl_ctrl.kl_coef=${kl_coef} \
    actor_rollout_ref.model.path="${MODEL_PATH}" \
    actor_rollout_ref.actor.checkpoint.save_contents='[hf_model]' \
    actor_rollout_ref.actor.optim.lr=${learning_rate} \
    actor_rollout_ref.actor.ppo_mini_batch_size=${ppo_mini_batch_size} \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${ppo_micro_batch_size_per_gpu} \
    ++actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \
    ++actor_rollout_ref.actor.fsdp_config.strategy=fsdp2 \
    actor_rollout_ref.model.lora_rank=0 \
    actor_rollout_ref.model.trust_remote_code=True \
    actor_rollout_ref.model.use_shm=False \
    ++actor_rollout_ref.model.override_config.attn_implementation=eager \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
    reward_model.use_reward_loop=${use_reward_loop} \
    reward_model.reward_manager=${reward_manager} \
    reward_model.num_workers=${num_workers} \
    reward_model.enable=false \
    trainer.default_local_dir="${OUTPUT_DIR}" \
    trainer.logger='[console]' \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.total_epochs=1 \
    trainer.save_freq=20 \
    trainer.test_freq=5 \
    trainer.val_before_train=true \
    trainer.project_name=bedrock-verl-reward-loop-naive-gsm8k \
    trainer.experiment_name=reward-loop-naive-grpo-gsm8k-${TIMESTAMP}

After applying the fix, I'm able to save the model:

Training Progress:  75%|███████▌  | 3/4 [04:49<01:28, 88.40s/it]
(TaskRunner pid=7817) test_gen_batch meta info: {'eos_token_id': 11, 'pad_token_id': 11, 'recompute_log_prob': False, 'do_sample': False, 'validate': True, 'global_steps': 4}
(TaskRunner pid=7817) validation generation end
(TaskRunner pid=7817) local_global_step_folder: /fsx/ubuntu/users/thvasilo/outputs/legacy-prime-gsm8k-20260204_113408/global_step_4
(WorkerDict pid=9034) INFO:2026-02-04 19:47:05,234:[Rank 0] Saved model config and tokenizer class to /fsx/ubuntu/users/thvasilo/outputs/legacy-prime-gsm8k-20260204_113408/global_step_4/actor/huggingface
(WorkerDict pid=9034) /usr/local/lib/python3.12/dist-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:675: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
(WorkerDict pid=9034)   warnings.warn(
(WorkerDict pid=9034) INFO:2026-02-04 19:50:08,740:[Rank 0] Saved hf_model to /fsx/ubuntu/users/thvasilo/outputs/legacy-prime-gsm8k-20260204_113408/global_step_4/actor/huggingface
(WorkerDict pid=9038) /usr/local/lib/python3.12/dist-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:675: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html . [repeated 7x across cluster]
(WorkerDict pid=9038)   warnings.warn( [repeated 7x across cluster]
(TaskRunner pid=7817) step:4 - actor/entropy:0.3219437599182129 - perf/mfu/actor_infer:0 - actor/pg_loss:-0.7364279553294182 - actor/kl_loss:0.0 - actor/pg_clipfrac:0.012568664009450004 - actor/ppo_kl:-0.006911035072789673 - actor/pg_clipfrac_lower:0.0 - actor/grad_norm:0.6123046875 - perf/mfu/actor:0.0 - perf/max_memory_allocated_gb:74.05595064163208 - perf/max_memory_reserved_gb:88.890625 - perf/cpu_memory_used_gb:145.18696212768555 - actor/lr:1e-06 - val-aux/openai/gsm8k/reward/mean@1:0.4453125 - val-core/openai/gsm8k/acc/mean@1:0.4453125 - val-aux/num_turns/min:2 - val-aux/num_turns/max:2 - val-aux/num_turns/mean:2.0 - training/global_step:4 - training/epoch:0 - critic/score/mean:0.40625 - critic/score/max:1.0 - critic/score/min:0.0 - critic/rewards/mean:0.40625 - critic/rewards/max:1.0 - critic/rewards/min:0.0 - critic/advantages/mean:0.36734017729759216 - critic/advantages/max:0.9999990463256836 - critic/advantages/min:0.0 - critic/returns/mean:0.36734017729759216 - critic/returns/max:0.9999990463256836 - critic/returns/min:0.0 - response_length/mean:445.21875 - response_length/max:1024.0 - response_length/min:125.0 - response_length/clip_ratio:0.09375 - response_length_non_aborted/mean:445.21875 - response_length_non_aborted/max:1024.0 - response_length_non_aborted/min:125.0 - response_length_non_aborted/clip_ratio:0.09375 - response/aborted_ratio:0.0 - prompt_length/mean:93.078125 - prompt_length/max:175.0 - prompt_length/min:57.0 - prompt_length/clip_ratio:0.0 - num_turns/min:2 - num_turns/max:2 - num_turns/mean:2.0 - timing_s/start_profile:3.2389070838689804e-05 - timing_s/agent_loop/generate_sequences/min:1.9702747561968863 - timing_s/agent_loop/generate_sequences/max:20.55855459300801 - timing_s/agent_loop/generate_sequences/mean:6.543419518391602 - timing_s/agent_loop/tool_calls/min:0.0 - timing_s/agent_loop/tool_calls/max:0.0 - timing_s/agent_loop/tool_calls/mean:0.0 - timing_s/agent_loop/slowest/generate_sequences:20.55855459300801 - timing_s/agent_loop/slowest/tool_calls:0.0 - timing_s/agent_loop/slowest/prompt_length:82 - timing_s/agent_loop/slowest/response_length:1024 - timing_s/gen:31.25609784666449 - timing_s/reward:8.507398888468742e-05 - timing_s/old_log_prob:1.868525329977274 - timing_s/adv:0.0023601027205586433 - timing_s/update_actor:35.30974240321666 - timing_s/step:68.44496135693043 - timing_s/testing:36.04290238209069 - timing_s/save_checkpoint:184.1523506329395 - timing_s/stop_profile:7.03069381415844e-05 - timing_per_token_ms/adv:6.850607298942391e-05 - timing_per_token_ms/gen:1.0969361215225832 - timing_per_token_ms/update_actor:1.024926486987799 - perf/total_num_tokens:34451 - perf/time_per_step:68.44496135693043 - perf/throughput:62.91734138825627
(TaskRunner pid=7817) ("Final validation metrics: {'val-aux/openai/gsm8k/reward/mean@1': 0.4453125, "
(TaskRunner pid=7817)  "'val-core/openai/gsm8k/acc/mean@1': 0.4453125, 'val-aux/num_turns/min': 2, "
(TaskRunner pid=7817)  "'val-aux/num_turns/max': 2, 'val-aux/num_turns/mean': 2.0}")
(TaskRunner pid=7817)
Training Progress: 100%|██████████| 4/4 [09:38<00:00, 167.45s/it]
Training Progress: 100%|██████████| 4/4 [09:38<00:00, 144.53s/it]

Update .pre-commit config

The pre-commit checks were failing because the Python compilation check scans the entire working directory, including .venv, which can contain Python files that do not conform to the requirements. This PR adds some exclusions to avoid that.
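
For reference, the hook's behavior can be approximated with Python's own compileall module; this sketch (the exclusion regex is an assumed pattern, not the hook's literal configuration) shows how skipping env/git directories avoids the failure:

import compileall
import re

# compile_dir walks the whole tree the way the hook does; files whose path
# matches rx are skipped, so stray sources under .venv no longer abort the run.
ok = compileall.compile_dir(
    ".",
    rx=re.compile(r"(^|/)(\.venv|venv|\.git)(/|$)"),
    quiet=1,
)
print("compilation clean:", bool(ok))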

Checklist Before Starting

  • Search for similar PRs. Paste at least one query link here: ...
  • Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)
    • {modules} include fsdp, megatron, veomni, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data, cfg, reward
    • If this PR involves multiple modules, separate them with , like [megatron, fsdp, doc]
    • {type} is in feat, fix, refactor, chore, test
    • If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
    • Example: [BREAKING][fsdp, megatron] feat: dynamic batching

Test

For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

Despite my best efforts, I'm not able to produce a unit test that detects this; I can only observe the issue while training with Ray.

I've written https://gist.github.com/thvasilo/76596a638440cd7f342ba4d23a2efb2e, which replicates the process, but it does not trigger the error. My current assumption is an execution-environment discrepancy: I'm using torchrun to run the test, which shares processes, whereas my job fails when using Ray.

API and Usage Example

No API changes; we detect custom models from their contents.
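
A minimal sketch of what content-based detection can look like (the helper name is hypothetical; configs of models that ship custom code expose an auto_map entry pointing at the remote module):

from transformers import PretrainedConfig

def needs_remote_code(model_config: PretrainedConfig) -> bool:
    # Hypothetical helper: auto_map is only present when the model's auto
    # classes are implemented by custom code bundled with the checkpoint.
    return bool(getattr(model_config, "auto_map", None))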

Design & Code Changes

Demonstrate the high-level design if this PR is complex, and list the specific changes.

Checklist Before Submitting

Important

Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request addresses a crash during FSDP checkpoint saving for models requiring trust_remote_code=True by introducing logic to detect and pass this flag. However, a critical security concern arises as automatically enabling trust_remote_code bypasses an important security boundary, potentially allowing arbitrary code execution from malicious models during checkpoint saving. Furthermore, the new remote code detection logic has a potential critical issue that could lead to an AttributeError. The .pre-commit-config.yaml update is a good improvement.

thvasilo (Contributor, Author) commented on Feb 5, 2026

About the security concern raised, we could cross-verify with actor_rollout_ref.model.trust_remote_code and only enable trust_remote_code during saving if the user has set it to true for the actor model.

My thinking was that if we get to the saving point for a model that required remote code, we can assume the user has enabled trust_remote_code for the actor model, otherwise training would have failed at model loading time.

Let me know if you'd prefer the explicit approach, though.
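
For illustration, the explicit variant would reduce to something like this (the helper is hypothetical; the config key matches actor_rollout_ref.model.trust_remote_code above):

def resolve_save_trust_remote_code(model_cfg: dict) -> bool:
    # Trust remote code at save time only when the user already opted in for
    # the actor model; otherwise loading would have failed long before saving.
    return bool(model_cfg.get("trust_remote_code", False))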

thvasilo (Contributor, Author) commented on Feb 6, 2026

Hi @HollowMan6, who should I ping for a review here? Thanks!

HollowMan6 (Collaborator) left a comment

I can help review here by myself :)

thvasilo and others added 3 commits on February 6, 2026
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
thvasilo force-pushed the fix-saving-with-remote-code branch from fcc0a69 to 66a4f4d on February 6, 2026 at 04:17
thvasilo force-pushed the fix-saving-with-remote-code branch from 66a4f4d to 45d80cb on February 6, 2026 at 04:24
thvasilo (Contributor, Author) commented on Feb 6, 2026

New implementation added, @HollowMan6. Let me know if you'd like some modification to tests/special_distributed/test_fsdp_ckpt.py added to demonstrate the use of the new parameter in FSDPCheckpointManager.

HollowMan6 (Collaborator) left a comment

> let me know if you'd like some modification to tests/special_distributed/test_fsdp_ckpt.py added to demonstrate the use of the new parameter in FSDPCheckpointManager

Feel free

Copilot AI (Contributor) left a comment

Pull request overview

Propagates the Hugging Face trust_remote_code setting into FSDP checkpoint save logic so models that rely on custom code can be re-instantiated from config during hf_model checkpoint export (fixing the failure reported in #5214). Also adjusts pre-commit’s Python compilation hook to avoid scanning virtualenv directories.

Changes:

  • Thread trust_remote_code from model config into FSDPCheckpointManager across FSDP workers/engines/trainers (see the sketch after this list).
  • Pass trust_remote_code to AutoModel*.from_config(...) when building the temporary model used for hf_model save.
  • Exclude .venv/venv/.git from the pre-commit compileall scan.
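
A rough sketch of that threading (signature only; the real class in verl/utils/checkpoint/fsdp_checkpoint_manager.py takes more arguments, which are elided here, and the _build_save_model helper name is hypothetical):

import torch

class FSDPCheckpointManager:
    def __init__(self, trust_remote_code: bool = False, **kwargs):
        # Workers/engines/trainers construct the manager with the flag read
        # from the model config; it is stored for use at save time.
        self.trust_remote_code = trust_remote_code

    def _build_save_model(self, auto_model_cls, model_config):
        # Forwarded to the call that previously raised (see traceback above).
        return auto_model_cls.from_config(
            model_config,
            torch_dtype=torch.bfloat16,
            trust_remote_code=self.trust_remote_code,
        )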

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 4 comments.

Summary per file:

  • verl/workers/fsdp_workers.py: Passes trust_remote_code into FSDPCheckpointManager for actor/critic checkpointing.
  • verl/workers/engine/veomni/transformer_impl.py: Forwards trust_remote_code to checkpoint manager in VeOmni engine.
  • verl/workers/engine/fsdp/transformer_impl.py: Forwards trust_remote_code to checkpoint manager in FSDP engine.
  • verl/utils/checkpoint/fsdp_checkpoint_manager.py: Adds trust_remote_code parameter and uses it when instantiating the save-time HF model via from_config.
  • verl/trainer/fsdp_sft_trainer.py: Forwards trust_remote_code into checkpoint manager for SFT.
  • verl/experimental/vla/fsdp_workers.py: Forwards trust_remote_code into checkpoint manager for experimental VLA worker.
  • .pre-commit-config.yaml: Updates compileall hook to exclude env/git directories.


Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
HollowMan6 changed the title from "[fsdp] fix: Detect whether model needs trust_remote_code during FSDP HuggingFace checkpoint save" to "[fsdp] fix: Support trust_remote_code during FSDP HuggingFace checkpoint save" on Feb 6, 2026
HollowMan6 (Collaborator) left a comment

Thanks for your contribution! The changes now LGTM.

HollowMan6 merged commit f245e2d into verl-project:main on Feb 6, 2026

Development

Successfully merging this pull request may close the following issue:

  • Saving checkpoint fails for FSDP when model requires trust_remote_code (#5214)
