Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions verl/models/transformers/nemotron_h.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright Amazon.com
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Patch for NemotronH models to enable flash_attention_2 support.

The HuggingFace NemotronH model doesn't declare _supports_flash_attn_2 = True,
but the model architecture does support it. This patch enables flash attention 2
support by patching the model class before instantiation.

Reference: https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/discussions
"""

import sys


def patch_nemotron_h_flash_attention_support(model_config):
    """
    Patch NemotronH model to declare flash_attention_2 support.

    The upstream NemotronH implementation does not set
    ``_supports_flash_attn_2 = True`` even though the architecture supports
    it. This function loads the remote-code model class via transformers'
    dynamic module utilities and flips the flag on the
    ``NemotronHPreTrainedModel`` base class found in the class's MRO. Using
    the MRO avoids scanning ``sys.modules`` for module-name substrings,
    which depends on transformers' internal dynamic-module naming scheme
    and can break without notice.

    Must be called AFTER loading the config (so ``auto_map`` is populated)
    but BEFORE calling ``from_pretrained()``.

    Args:
        model_config: The model config object from AutoConfig.from_pretrained()

    Returns:
        None. Failures are logged as warnings, never raised, so model
        loading can continue and fail naturally if flash attention is
        truly unsupported.
    """
    # Local imports keep this patch self-contained; transformers is only
    # touched when the config actually routes through remote code.
    import logging

    logger = logging.getLogger(__name__)

    try:
        # auto_map is only present for trust_remote_code models; without it
        # there is no dynamic NemotronH class to patch, so do nothing.
        if not (hasattr(model_config, "auto_map") and "AutoModelForCausalLM" in model_config.auto_map):
            return

        from transformers.dynamic_module_utils import get_class_from_dynamic_module

        module_path = model_config.auto_map["AutoModelForCausalLM"]
        model_class = get_class_from_dynamic_module(
            class_reference=module_path,
            pretrained_model_name_or_path=model_config.name_or_path,
        )

        # Walk the class's MRO to find the base class to patch. This is
        # robust against changes in how transformers names the dynamically
        # loaded modules in sys.modules.
        base_to_patch = next(
            (base for base in model_class.__mro__ if base.__name__ == "NemotronHPreTrainedModel"),
            None,
        )

        if base_to_patch is not None:
            base_to_patch._supports_flash_attn_2 = True
            logger.info("[NemotronH Patch] Patched NemotronHPreTrainedModel for flash attention 2 support.")
        else:
            logger.warning("[NemotronH Patch] Could not find NemotronHPreTrainedModel in MRO to patch.")

    except Exception as e:
        # Best-effort by design: don't raise — let the model loading
        # continue and fail naturally if flash attention is truly unsupported.
        logger.warning("[NemotronH Patch] Failed to patch NemotronH for flash attention support: %s", e)
Comment on lines +25 to +81
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The current method of finding the nemotron_module by iterating through sys.modules is fragile. It relies on the internal implementation details of the transformers library's dynamic module loading, which can change without notice. A more robust approach is to use the model class's Method Resolution Order (MRO) to find the NemotronHPreTrainedModel base class and patch it directly. This avoids relying on module naming conventions. Additionally, using the logging module is preferred over print for warnings and informational messages.

import logging
from transformers.dynamic_module_utils import get_class_from_dynamic_module

logger = logging.getLogger(__name__)


def patch_nemotron_h_flash_attention_support(model_config):
    """
    Patch NemotronH model to support flash_attention_2.

    This function patches the NemotronHPreTrainedModel class to declare
    flash attention 2 support. Must be called AFTER loading the config
    (which imports the model module) but BEFORE calling from_pretrained().

    Args:
        model_config: The model config object from AutoConfig.from_pretrained()
    """
    try:
        if not (hasattr(model_config, "auto_map") and "AutoModelForCausalLM" in model_config.auto_map):
            return

        module_path = model_config.auto_map["AutoModelForCausalLM"]
        model_class = get_class_from_dynamic_module(
            class_reference=module_path,
            pretrained_model_name_or_path=model_config.name_or_path,
        )

        base_model_to_patch = None
        for base_class in model_class.__mro__:
            if base_class.__name__ == "NemotronHPreTrainedModel":
                base_model_to_patch = base_class
                break

        if base_model_to_patch:
            base_model_to_patch._supports_flash_attn_2 = True
            logger.info("[NemotronH Patch] Successfully patched NemotronHPreTrainedModel for flash attention 2 support.")
        else:
            logger.warning("[NemotronH Patch] Warning: Could not find NemotronHPreTrainedModel in MRO to patch.")

    except Exception as e:
        logger.warning(f"[NemotronH Patch] Warning: Failed to patch NemotronH for flash attention support: {e}")
        # Don't raise - let the model loading continue and fail naturally if flash attention is truly unsupported

7 changes: 7 additions & 0 deletions verl/workers/engine/fsdp/transformer_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,13 @@ def _build_module(self):

auto_class = get_hf_auto_model_class(hf_config=self.model_config.hf_config)

# patch for nemotron-h: enable flash_attention_2 support
model_type = getattr(self.model_config.hf_config, "model_type", None)
if model_type == "nemotron_h":
from verl.models.transformers.nemotron_h import patch_nemotron_h_flash_attention_support

patch_nemotron_h_flash_attention_support(self.model_config.hf_config)

module = auto_class.from_pretrained(
pretrained_model_name_or_path=self.model_config.local_path,
torch_dtype=torch_dtype,
Expand Down
6 changes: 6 additions & 0 deletions verl/workers/fsdp_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,12 @@ def _build_model_optimizer(
if getattr(actor_model_config, "model_type", None) == "kimi_vl":
actor_model_config.text_config.topk_method = "greedy"

# patch for nemotron-h: enable flash_attention_2 support
if getattr(actor_model_config, "model_type", None) == "nemotron_h":
from verl.models.transformers.nemotron_h import patch_nemotron_h_flash_attention_support

patch_nemotron_h_flash_attention_support(actor_model_config)

self.generation_config = get_generation_config(local_path, trust_remote_code=trust_remote_code)

override_config_kwargs = {
Expand Down