|
| 1 | +# Copyright 2024 Bytedance Ltd. and/or its affiliates |
| 2 | +# Copyright Amazon.com |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | +""" |
| 16 | +Patch for NemotronH models to enable flash_attention_2 support. |
| 17 | +
|
| 18 | +The HuggingFace NemotronH model doesn't declare _supports_flash_attn_2 = True, |
| 19 | +but the model architecture does support it. This patch enables flash attention 2 |
| 20 | +support by patching the model class before instantiation. |
| 21 | +
|
| 22 | +Reference: https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16/discussions |
| 23 | +""" |
| 24 | + |
| 25 | +import sys |
| 26 | + |
| 27 | + |
def patch_nemotron_h_flash_attention_support(model_config):
    """
    Patch NemotronH model to support flash_attention_2.

    The upstream ``NemotronHPreTrainedModel`` class does not declare
    ``_supports_flash_attn_2 = True`` even though the architecture supports
    flash attention 2, so transformers refuses to instantiate the model with
    ``attn_implementation="flash_attention_2"``. This function sets that flag
    on the class in place.

    Must be called AFTER loading the config (which registers the remote-code
    package) but BEFORE calling ``from_pretrained()``.

    Args:
        model_config: The model config object from AutoConfig.from_pretrained().

    Returns:
        None. All failures are printed and swallowed so model loading can
        continue (and fail naturally if flash attention is truly unsupported).
    """
    try:
        # Loading the config only imports the configuration module; the
        # modeling module that defines NemotronHPreTrainedModel may not be in
        # sys.modules yet. Trigger its import through transformers' dynamic
        # module machinery when the repo declares a remote causal-LM class.
        if hasattr(model_config, "auto_map") and "AutoModelForCausalLM" in model_config.auto_map:
            from transformers.dynamic_module_utils import get_class_from_dynamic_module

            module_path = model_config.auto_map["AutoModelForCausalLM"]
            try:
                # The class itself is not needed — resolving it loads
                # modeling_nemotron_h into sys.modules as a side effect.
                get_class_from_dynamic_module(
                    class_reference=module_path,
                    pretrained_model_name_or_path=model_config.name_or_path,
                )
            except Exception as e:
                print(f"[NemotronH Patch] Warning: Error loading modeling module: {e}")

        # Locate the dynamically imported modeling module and patch the base
        # class. Snapshot sys.modules first: iterating the live mapping can
        # raise RuntimeError if another import mutates it mid-loop. Matching
        # on the module name keeps this robust to the hashed package names
        # transformers uses for remote code.
        for module_name, module in list(sys.modules.items()):
            if (
                "transformers_modules" in module_name
                and "nemotron" in module_name.lower()
                and "modeling" in module_name
                and hasattr(module, "NemotronHPreTrainedModel")
            ):
                module.NemotronHPreTrainedModel._supports_flash_attn_2 = True
                return

        print("[NemotronH Patch] Warning: Could not find NemotronH modeling module to patch")

    except Exception as e:
        print(f"[NemotronH Patch] Warning: Failed to patch NemotronH for flash attention support: {e}")
        # Don't raise — let the model loading continue and fail naturally if
        # flash attention is truly unsupported.