[fsdp, model]{fix} Allow flash attention 2 to be used for NemotronH model on FSDP #5419
thvasilo wants to merge 1 commit into verl-project:main
Conversation
Ping @ISEEKYAN could you help assign a reviewer? Thanks!
Code Review
This pull request introduces a patch to enable Flash Attention 2 for NemotronH models, which is a necessary workaround for an issue in the upstream model artifacts. The approach of monkey-patching is reasonable for this situation. My review includes a critical suggestion to improve the robustness of the patching mechanism, making it less dependent on the internal implementation details of the transformers library.
import sys


def patch_nemotron_h_flash_attention_support(model_config):
    """
    Patch NemotronH model to support flash_attention_2.

    This function patches the NemotronHPreTrainedModel class to declare
    flash attention 2 support. Must be called AFTER loading the config
    (which imports the model module) but BEFORE calling from_pretrained().

    Args:
        model_config: The model config object from AutoConfig.from_pretrained()
    """
    try:
        # Force-load the modeling module using transformers' dynamic module utilities.
        # At this point, only the config module is loaded; we need to load the modeling module.
        if hasattr(model_config, "auto_map") and "AutoModelForCausalLM" in model_config.auto_map:
            from transformers.dynamic_module_utils import get_class_from_dynamic_module

            module_path = model_config.auto_map["AutoModelForCausalLM"]

            # Force import the modeling module by getting the class.
            # This will load modeling_nemotron_h into sys.modules.
            try:
                # We don't actually need the class, just need to trigger the import
                _ = get_class_from_dynamic_module(
                    class_reference=module_path,
                    pretrained_model_name_or_path=model_config.name_or_path,
                )
            except Exception as e:
                print(f"Error loading modeling module: {e}")

            # Now search for the modeling module, which should be loaded
            nemotron_module = None
            for module_name, module in sys.modules.items():
                if (
                    "transformers_modules" in module_name
                    and "nemotron" in module_name.lower()
                    and "modeling" in module_name
                ):
                    if hasattr(module, "NemotronHPreTrainedModel"):
                        nemotron_module = module
                        break

            if nemotron_module is not None:
                # Patch the base class to support flash attention 2
                if hasattr(nemotron_module, "NemotronHPreTrainedModel"):
                    nemotron_module.NemotronHPreTrainedModel._supports_flash_attn_2 = True
                else:
                    print("[NemotronH Patch] Warning: Could not find NemotronHPreTrainedModel class to patch")
            else:
                print("[NemotronH Patch] Warning: Could not find NemotronH modeling module to patch")

    except Exception as e:
        print(f"[NemotronH Patch] Warning: Failed to patch NemotronH for flash attention support: {e}")
        # Don't raise - let the model loading continue and fail naturally if flash attention is truly unsupported
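For intuition, the sys.modules scan above can be exercised in isolation with a stand-in module. This is a minimal sketch: the fake module name only mimics the `transformers_modules.<repo>.modeling_nemotron_h` naming convention used by transformers' dynamic loading, and the stand-in class is not the real NemotronH implementation.

```python
import sys
import types

# Stand-in for the dynamically loaded modeling module (name is an assumption
# mimicking the transformers_modules convention, not taken from this PR).
fake = types.ModuleType("transformers_modules.demo.modeling_nemotron_h")


class NemotronHPreTrainedModel:  # minimal stand-in base class
    _supports_flash_attn_2 = False


fake.NemotronHPreTrainedModel = NemotronHPreTrainedModel
sys.modules[fake.__name__] = fake

# The same scan the patch performs over sys.modules
for module_name, module in sys.modules.items():
    if (
        "transformers_modules" in module_name
        and "nemotron" in module_name.lower()
        and "modeling" in module_name
        and hasattr(module, "NemotronHPreTrainedModel")
    ):
        module.NemotronHPreTrainedModel._supports_flash_attn_2 = True
        break

print(NemotronHPreTrainedModel._supports_flash_attn_2)  # True
```

This also shows why the review calls the approach fragile: the match depends entirely on the module-name substrings, which transformers does not guarantee to keep stable.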
The current method of finding the nemotron_module by iterating through sys.modules is fragile. It relies on the internal implementation details of the transformers library's dynamic module loading, which can change without notice. A more robust approach is to use the model class's Method Resolution Order (MRO) to find the NemotronHPreTrainedModel base class and patch it directly. This avoids relying on module naming conventions. Additionally, using the logging module is preferred over print for warnings and informational messages.
import logging

from transformers.dynamic_module_utils import get_class_from_dynamic_module

logger = logging.getLogger(__name__)


def patch_nemotron_h_flash_attention_support(model_config):
    """
    Patch NemotronH model to support flash_attention_2.

    This function patches the NemotronHPreTrainedModel class to declare
    flash attention 2 support. Must be called AFTER loading the config
    (which imports the model module) but BEFORE calling from_pretrained().

    Args:
        model_config: The model config object from AutoConfig.from_pretrained()
    """
    try:
        if not (hasattr(model_config, "auto_map") and "AutoModelForCausalLM" in model_config.auto_map):
            return
        module_path = model_config.auto_map["AutoModelForCausalLM"]
        model_class = get_class_from_dynamic_module(
            class_reference=module_path,
            pretrained_model_name_or_path=model_config.name_or_path,
        )
        base_model_to_patch = None
        for base_class in model_class.__mro__:
            if base_class.__name__ == "NemotronHPreTrainedModel":
                base_model_to_patch = base_class
                break
        if base_model_to_patch:
            base_model_to_patch._supports_flash_attn_2 = True
            logger.info("[NemotronH Patch] Successfully patched NemotronHPreTrainedModel for flash attention 2 support.")
        else:
            logger.warning("[NemotronH Patch] Could not find NemotronHPreTrainedModel in MRO to patch.")
    except Exception as e:
        logger.warning(f"[NemotronH Patch] Failed to patch NemotronH for flash attention support: {e}")
        # Don't raise - let the model loading continue and fail naturally if flash attention is truly unsupported
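The MRO-based lookup the review suggests can be demonstrated with plain classes. The class names below are stand-ins chosen to echo the NemotronH hierarchy; no transformers code is involved.

```python
class NemotronHPreTrainedModel:  # stand-in for the remote-code base class
    _supports_flash_attn_2 = False


class NemotronHForCausalLM(NemotronHPreTrainedModel):  # stand-in concrete model class
    pass


# Walk the method resolution order by class name, as the suggested patch does
base_to_patch = None
for base_class in NemotronHForCausalLM.__mro__:
    if base_class.__name__ == "NemotronHPreTrainedModel":
        base_to_patch = base_class
        break

assert base_to_patch is NemotronHPreTrainedModel
base_to_patch._supports_flash_attn_2 = True

# The flag set on the base class is visible on every subclass via attribute lookup
print(NemotronHForCausalLM._supports_flash_attn_2)  # True
```

Matching by `__name__` rather than by module path is what makes this version robust to changes in how transformers names its dynamically loaded modules.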
What does this PR do?
The nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 models support Flash Attention 2; however, the model artifact does not include that annotation.
Verl uses FA2 by default for models in FSDP, which means that trying to train this model without overriding the attn_implementation would lead to an error.
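The failure mode can be sketched with a toy support check. `FakePreTrainedModel` and `check_attn_implementation` are illustrative names, not transformers APIs; the check only loosely mirrors how transformers rejects `flash_attention_2` for classes that do not declare `_supports_flash_attn_2`.

```python
class FakePreTrainedModel:
    # What the unpatched NemotronH class effectively declares
    _supports_flash_attn_2 = False

    @classmethod
    def check_attn_implementation(cls, attn_implementation: str) -> str:
        # Toy version of the library's guard: refuse FA2 unless the class opts in
        if attn_implementation == "flash_attention_2" and not cls._supports_flash_attn_2:
            raise ValueError(f"{cls.__name__} does not support Flash Attention 2.0 yet.")
        return attn_implementation


# Unpatched: verl's FSDP default of flash_attention_2 is rejected
try:
    FakePreTrainedModel.check_attn_implementation("flash_attention_2")
except ValueError as e:
    print("error:", e)

# After setting the flag (what this PR's patch does), the same request succeeds
FakePreTrainedModel._supports_flash_attn_2 = True
print(FakePreTrainedModel.check_attn_implementation("flash_attention_2"))
```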
This PR patches the NemotronHPreTrainedModel class to include the _supports_flash_attn_2 = True annotation that enables flash attention to be used with the model.

Checklist Before Starting
Title follows the format [{modules}] {type}: {description} (This will be checked by the CI)
- {modules} include fsdp, megatron, veomni, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data, cfg, reward, fully_async, one_step_off, like [megatron, fsdp, doc]
- {type} is in feat, fix, refactor, chore, test
- If this PR breaks any API, add [BREAKING] to the beginning of the title, e.g. [BREAKING][fsdp, megatron] feat: dynamic batching

Test
Test through FFT training with GRPO
API and Usage Example
Design & Code Changes
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
- Run pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always
- Request CI through the ci-request channel in the verl Slack workspace. (If not accessible, please try the Feishu group (飞书群).)
- If this PR changes the recipe submodule, please also update the reference to the submodule commit via git submodule update --remote or cd recipe && git pull origin main.