Skip to content

Commit 75e9ca7

Browse files
committed
Allow flash attention 2 to be used for NemotronH model on FSDP
1 parent 28550a7 commit 75e9ca7

File tree

3 files changed

+94
-0
lines changed

3 files changed

+94
-0
lines changed
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
2+
# Copyright Amazon.com
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""
16+
Patch for NemotronH models to enable flash_attention_2 support.
17+
18+
The HuggingFace NemotronH model doesn't declare _supports_flash_attn_2 = True,
19+
but the model architecture does support it. This patch enables flash attention 2
20+
support by patching the model class before instantiation.
21+
22+
Reference: https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/discussions
23+
"""
24+
25+
import sys
26+
27+
28+
def patch_nemotron_h_flash_attention_support(model_config):
    """
    Patch NemotronH model to support flash_attention_2.

    This function patches the NemotronHPreTrainedModel class to declare
    flash attention 2 support. Must be called AFTER loading the config
    (which imports the config module) but BEFORE calling from_pretrained().

    Args:
        model_config: The model config object from AutoConfig.from_pretrained()
    """
    try:
        # Force-load the modeling module using transformers' dynamic module utilities.
        # At this point only the config module is loaded; we need the modeling module
        # present in sys.modules so we can patch the class it defines.
        if hasattr(model_config, "auto_map") and "AutoModelForCausalLM" in model_config.auto_map:
            from transformers.dynamic_module_utils import get_class_from_dynamic_module

            module_path = model_config.auto_map["AutoModelForCausalLM"]

            try:
                # We don't actually need the returned class; fetching it triggers
                # the import of modeling_nemotron_h into sys.modules.
                _ = get_class_from_dynamic_module(
                    class_reference=module_path,
                    pretrained_model_name_or_path=model_config.name_or_path,
                )
            except Exception as e:
                print(f"[NemotronH Patch] Warning: Error loading modeling module: {e}")

        # Search for the dynamically-loaded modeling module. Snapshot the items
        # first: an import triggered elsewhere can mutate sys.modules and would
        # raise RuntimeError if we iterated the live dict. Entries may also be
        # None (import-machinery placeholders), so guard before hasattr().
        nemotron_module = None
        for module_name, module in list(sys.modules.items()):
            if (
                module is not None
                and "transformers_modules" in module_name
                and "nemotron" in module_name.lower()
                and "modeling" in module_name
                and hasattr(module, "NemotronHPreTrainedModel")
            ):
                nemotron_module = module
                break

        if nemotron_module is not None:
            # Declare flash attention 2 support on the shared base class so all
            # NemotronH model classes inherit it before from_pretrained() checks it.
            nemotron_module.NemotronHPreTrainedModel._supports_flash_attn_2 = True
        else:
            print("[NemotronH Patch] Warning: Could not find NemotronH modeling module to patch")

    except Exception as e:
        print(f"[NemotronH Patch] Warning: Failed to patch NemotronH for flash attention support: {e}")
        # Don't raise - let the model loading continue and fail naturally if flash attention is truly unsupported

verl/workers/engine/fsdp/transformer_impl.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,13 @@ def _build_module(self):
227227

228228
auto_class = get_hf_auto_model_class(hf_config=self.model_config.hf_config)
229229

230+
# patch for nemotron-h: enable flash_attention_2 support
231+
model_type = getattr(self.model_config.hf_config, "model_type", None)
232+
if model_type == "nemotron_h":
233+
from verl.models.transformers.nemotron_h import patch_nemotron_h_flash_attention_support
234+
235+
patch_nemotron_h_flash_attention_support(self.model_config.hf_config)
236+
230237
module = auto_class.from_pretrained(
231238
pretrained_model_name_or_path=self.model_config.local_path,
232239
torch_dtype=torch_dtype,

verl/workers/fsdp_workers.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,12 @@ def _build_model_optimizer(
404404
if getattr(actor_model_config, "model_type", None) == "kimi_vl":
405405
actor_model_config.text_config.topk_method = "greedy"
406406

407+
# patch for nemotron-h: enable flash_attention_2 support
408+
if getattr(actor_model_config, "model_type", None) == "nemotron_h":
409+
from verl.models.transformers.nemotron_h import patch_nemotron_h_flash_attention_support
410+
411+
patch_nemotron_h_flash_attention_support(actor_model_config)
412+
407413
self.generation_config = get_generation_config(local_path, trust_remote_code=trust_remote_code)
408414

409415
override_config_kwargs = {

0 commit comments

Comments
 (0)