[fsdp, model]{fix} Allow flash attention 2 to be used for NemotronH model on FSDP#5419

Open
thvasilo wants to merge 1 commit into verl-project:main from thvasilo:enable-nemotron-h-fa2

Conversation

@thvasilo
Contributor

What does this PR do?

The nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 model supports Flash Attention 2; however, the model artifact does not include that annotation.

Verl uses FA2 by default for models under FSDP, so trying to train this model without overriding attn_implementation leads to an error.

This PR patches the NemotronHPreTrainedModel class to include the _supports_flash_attn_2 = True annotation that enables flash attention to be used with the model.

(TaskRunner pid=7922)   File "/usr/lib/python3.12/concurrent/futures/_base.py", line 449, in result
(TaskRunner pid=7922)     return self.__get_result()
(TaskRunner pid=7922)            ^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=7922)   File "/usr/lib/python3.12/concurrent/futures/_base.py", line 401, in __get_result
(TaskRunner pid=7922)     raise self._exception
(TaskRunner pid=7922)   File "/usr/local/lib/python3.12/dist-packages/verl/single_controller/ray/base.py", line 841, in func
(TaskRunner pid=7922)     return getattr(self.worker_dict[key], name)(*args, **kwargs)
(TaskRunner pid=7922)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=7922)   File "/usr/local/lib/python3.12/dist-packages/verl/single_controller/base/decorator.py", line 456, in inner
(TaskRunner pid=7922)     return func(*args, **kwargs)
(TaskRunner pid=7922)            ^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=7922)   File "/usr/local/lib/python3.12/dist-packages/verl/utils/transferqueue_utils.py", line 314, in dummy_inner
(TaskRunner pid=7922)     output = func(*args, **kwargs)
(TaskRunner pid=7922)              ^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=7922)   File "/usr/local/lib/python3.12/dist-packages/verl/workers/fsdp_workers.py", line 809, in init_model
(TaskRunner pid=7922)     ) = self._build_model_optimizer(
(TaskRunner pid=7922)         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=7922)   File "/usr/local/lib/python3.12/dist-packages/verl/workers/fsdp_workers.py", line 399, in _build_model_optimizer
(TaskRunner pid=7922)     actor_module = actor_module_class.from_pretrained(
(TaskRunner pid=7922)                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=7922)   File "/usr/local/lib/python3.12/dist-packages/transformers/models/auto/auto_factory.py", line 597, in from_pretrained
(TaskRunner pid=7922)     return model_class.from_pretrained(
(TaskRunner pid=7922)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=7922)   File "/usr/local/lib/python3.12/dist-packages/transformers/modeling_utils.py", line 277, in _wrapper
(TaskRunner pid=7922)     return func(*args, **kwargs)
(TaskRunner pid=7922)            ^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=7922)   File "/usr/local/lib/python3.12/dist-packages/transformers/modeling_utils.py", line 4971, in from_pretrained
(TaskRunner pid=7922)     model = cls(config, *model_args, **model_kwargs)
(TaskRunner pid=7922)             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=7922)   File "/root/.cache/huggingface/modules/transformers_modules/NVIDIA_hyphen_Nemotron_hyphen_3_hyphen_Nano_hyphen_30B_hyphen_A3B_hyphen_BF16/modeling_nemotron_h.py", line 1580, in __init__
(TaskRunner pid=7922)     super().__init__(config)
(TaskRunner pid=7922)   File "/usr/local/lib/python3.12/dist-packages/transformers/modeling_utils.py", line 2076, in __init__
(TaskRunner pid=7922)     self.config._attn_implementation_internal = self._check_and_adjust_attn_implementation(
(TaskRunner pid=7922)                                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=7922)   File "/usr/local/lib/python3.12/dist-packages/transformers/modeling_utils.py", line 2686, in _check_and_adjust_attn_implementation
(TaskRunner pid=7922)     applicable_attn_implementation = self.get_correct_attn_implementation(
(TaskRunner pid=7922)                                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=7922)   File "/usr/local/lib/python3.12/dist-packages/transformers/modeling_utils.py", line 2714, in get_correct_attn_implementation
(TaskRunner pid=7922)     self._flash_attn_2_can_dispatch(is_init_check)
(TaskRunner pid=7922)   File "/usr/local/lib/python3.12/dist-packages/transformers/modeling_utils.py", line 2406, in _flash_attn_2_can_dispatch
(TaskRunner pid=7922)     raise ValueError(
(TaskRunner pid=7922) ValueError: NemotronHForCausalLM does not support Flash Attention 2.0 yet. Please request to add support where the model is hosted, on its model hub page: https://huggingface.co//dev/shm/verl-cache/af0d1029355e0e35afdcdcc2eedfe46d/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/discussions/new or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new
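The fix itself reduces to flipping one class-level flag. A minimal, self-contained sketch of the idea (the class names here are stand-ins, not the real transformers classes; the actual patch targets the dynamically loaded NemotronHPreTrainedModel):

```python
# Illustrative stand-ins for the real classes. transformers consults the
# class-level _supports_flash_attn_2 flag before dispatching to FA2.
class PreTrainedModelStandIn:
    _supports_flash_attn_2 = False


class NemotronHStandIn(PreTrainedModelStandIn):
    pass


# The patch simply sets the flag on the base class before from_pretrained()
# runs; all subclasses then see the updated value through inheritance.
PreTrainedModelStandIn._supports_flash_attn_2 = True
print(NemotronHStandIn._supports_flash_attn_2)  # True
```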

Checklist Before Starting

  • Search for similar PRs. Paste at least one query link here: ...
  • Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)
    • {modules} include fsdp, megatron, veomni, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data, cfg, reward, fully_async, one_step_off
    • If this PR involves multiple modules, separate them with , like [megatron, fsdp, doc]
    • {type} is in feat, fix, refactor, chore, test
    • If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
    • Example: [BREAKING][fsdp, megatron] feat: dynamic batching

Test

For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

Tested through full fine-tuning (FFT) with GRPO.

API and Usage Example

Demonstrate how the API changes if any, and provide usage example(s) if possible.

# ==============================================================================
# Training Configuration
# ==============================================================================

train_batch_size=64
ppo_mini_batch_size=32
ppo_micro_batch_size_per_gpu=2
learning_rate=1e-6 # Use 1e-5 for LoRA, 1e-6 for FFT

# ==============================================================================
# LoRA Configuration
# ==============================================================================

lora_rank=${LORA_RANK:-0}
lora_alpha=32

# ==============================================================================
# Algorithm Configuration (GRPO)
# ==============================================================================

adv_estimator=grpo
kl_coef=0.001

# Launch training:
python3 -m verl.trainer.main_ppo \
    data.train_files="${TRAIN_FILE}" \
    data.val_files="${VAL_FILE}" \
    data.max_prompt_length=${max_prompt_length} \
    data.max_response_length=${max_response_length} \
    data.train_batch_size=${train_batch_size} \
    data.prompt_key=${prompt_key} \
    algorithm.adv_estimator=${adv_estimator} \
    algorithm.kl_ctrl.kl_coef=${kl_coef} \
    actor_rollout_ref.model.path="${MODEL_PATH}" \
    actor_rollout_ref.actor.checkpoint.save_contents='[hf_model]' \
    actor_rollout_ref.actor.optim.lr=${learning_rate} \
    actor_rollout_ref.actor.ppo_mini_batch_size=${ppo_mini_batch_size} \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${ppo_micro_batch_size_per_gpu} \
    ++actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \
    ++actor_rollout_ref.actor.fsdp_config.strategy=fsdp2 \
    actor_rollout_ref.model.lora_rank=${lora_rank} \
    actor_rollout_ref.model.lora_alpha=${lora_alpha} \
    +actor_rollout_ref.model.lora.merge=True \
    actor_rollout_ref.model.target_modules='["k_proj","o_proj","down_proj","q_proj","up_proj","out_proj","in_proj","v_proj"]' \
    actor_rollout_ref.model.trust_remote_code=True \
    actor_rollout_ref.model.use_shm=True \
    ++actor_rollout_ref.model.override_config.attn_implementation=flash_attention_2 \
    actor_rollout_ref.rollout.enforce_eager=False \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.tensor_model_parallel_size=8 \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
    reward_model.use_reward_loop=${use_reward_loop} \
    reward_model.reward_manager=${reward_manager} \
    reward_model.num_workers=${num_workers} \
    reward_model.enable=false \
    trainer.default_local_dir="${OUTPUT_DIR}" \
    trainer.logger='[console]' \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=$NUM_NODES \
    trainer.total_epochs=1 \
    trainer.save_freq=20 \
    trainer.test_freq=5 \
    ++trainer.use_legacy_worker_impl=disable \
    trainer.val_before_train=true \
    trainer.project_name=verl-grpo-tests \
    trainer.experiment_name=nemotron-fsdp-${TIMESTAMP}

Design & Code Changes

Demonstrate the high-level design if this PR is complex, and list the specific changes.

Checklist Before Submitting

Important

Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

@thvasilo
Contributor Author

Ping @ISEEKYAN could you help assign a reviewer? Thanks!

Contributor

@gemini-code-assist bot left a comment

Code Review

This pull request introduces a patch to enable Flash Attention 2 for NemotronH models, which is a necessary workaround for an issue in the upstream model artifacts. The approach of monkey-patching is reasonable for this situation. My review includes a critical suggestion to improve the robustness of the patching mechanism, making it less dependent on the internal implementation details of the transformers library.

Comment on lines +25 to +81
import sys


def patch_nemotron_h_flash_attention_support(model_config):
    """
    Patch NemotronH model to support flash_attention_2.

    This function patches the NemotronHPreTrainedModel class to declare
    flash attention 2 support. Must be called AFTER loading the config
    (which imports the model module) but BEFORE calling from_pretrained().

    Args:
        model_config: The model config object from AutoConfig.from_pretrained()
    """
    try:
        # Force-load the modeling module using transformers' dynamic module utilities.
        # At this point, only the config module is loaded; we need to load the modeling module.
        if hasattr(model_config, "auto_map") and "AutoModelForCausalLM" in model_config.auto_map:
            from transformers.dynamic_module_utils import get_class_from_dynamic_module

            module_path = model_config.auto_map["AutoModelForCausalLM"]

            # Force import the modeling module by getting the class.
            # This will load modeling_nemotron_h into sys.modules.
            try:
                # We don't actually need the class, just need to trigger the import
                _ = get_class_from_dynamic_module(
                    class_reference=module_path,
                    pretrained_model_name_or_path=model_config.name_or_path,
                )
            except Exception as e:
                print(f"Error loading modeling module: {e}")

            # Now search for the modeling module, which should be loaded
            nemotron_module = None
            for module_name, module in sys.modules.items():
                if (
                    "transformers_modules" in module_name
                    and "nemotron" in module_name.lower()
                    and "modeling" in module_name
                ):
                    if hasattr(module, "NemotronHPreTrainedModel"):
                        nemotron_module = module
                        break

            if nemotron_module is not None:
                # Patch the base class to support flash attention 2
                nemotron_module.NemotronHPreTrainedModel._supports_flash_attn_2 = True
            else:
                print("[NemotronH Patch] Warning: Could not find NemotronH modeling module to patch")

    except Exception as e:
        print(f"[NemotronH Patch] Warning: Failed to patch NemotronH for flash attention support: {e}")
        # Don't raise - let the model loading continue and fail naturally if flash attention is truly unsupported
Contributor

critical

The current method of finding the nemotron_module by iterating through sys.modules is fragile. It relies on the internal implementation details of the transformers library's dynamic module loading, which can change without notice. A more robust approach is to use the model class's Method Resolution Order (MRO) to find the NemotronHPreTrainedModel base class and patch it directly. This avoids relying on module naming conventions. Additionally, using the logging module is preferred over print for warnings and informational messages.

import logging
from transformers.dynamic_module_utils import get_class_from_dynamic_module

logger = logging.getLogger(__name__)


def patch_nemotron_h_flash_attention_support(model_config):
    """
    Patch NemotronH model to support flash_attention_2.

    This function patches the NemotronHPreTrainedModel class to declare
    flash attention 2 support. Must be called AFTER loading the config
    (which imports the model module) but BEFORE calling from_pretrained().

    Args:
        model_config: The model config object from AutoConfig.from_pretrained()
    """
    try:
        if not (hasattr(model_config, "auto_map") and "AutoModelForCausalLM" in model_config.auto_map):
            return

        module_path = model_config.auto_map["AutoModelForCausalLM"]
        model_class = get_class_from_dynamic_module(
            class_reference=module_path,
            pretrained_model_name_or_path=model_config.name_or_path,
        )

        base_model_to_patch = None
        for base_class in model_class.__mro__:
            if base_class.__name__ == "NemotronHPreTrainedModel":
                base_model_to_patch = base_class
                break

        if base_model_to_patch:
            base_model_to_patch._supports_flash_attn_2 = True
            logger.info("[NemotronH Patch] Successfully patched NemotronHPreTrainedModel for flash attention 2 support.")
        else:
            logger.warning("[NemotronH Patch] Warning: Could not find NemotronHPreTrainedModel in MRO to patch.")

    except Exception as e:
        logger.warning(f"[NemotronH Patch] Warning: Failed to patch NemotronH for flash attention support: {e}")
        # Don't raise - let the model loading continue and fail naturally if flash attention is truly unsupported

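The MRO-based lookup suggested above can be illustrated in isolation (the classes below are hypothetical stand-ins, not the real NemotronH classes): setting the attribute on the shared base class found via `__mro__` is visible to every subclass through normal attribute inheritance, which is exactly why the patch takes effect.

```python
# Hypothetical classes illustrating the MRO-based patching approach.
class BaseModel:
    _supports_flash_attn_2 = False


class ConcreteModel(BaseModel):
    pass


# Walk the concrete class's Method Resolution Order to find the base class
# by name, as the review suggests, instead of scanning sys.modules.
target = next((c for c in ConcreteModel.__mro__ if c.__name__ == "BaseModel"), None)

# Patching the base class is inherited by the concrete class.
target._supports_flash_attn_2 = True
print(ConcreteModel._supports_flash_attn_2)  # True
```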