From 1f95774c3107713256933b2c7e4cc38f27df1777 Mon Sep 17 00:00:00 2001 From: Hollow Man Date: Mon, 23 Feb 2026 22:14:44 +0800 Subject: [PATCH] [megatron] fix: patch support newer mcore version Tested on https://github.com/NVIDIA/Megatron-LM/commit/bbbedbb9f53343762e4dc70abc771b813a83d817 Signed-off-by: Hollow Man --- verl/models/mcore/patch.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/verl/models/mcore/patch.py b/verl/models/mcore/patch.py index 2968b3daace..0edc8693394 100644 --- a/verl/models/mcore/patch.py +++ b/verl/models/mcore/patch.py @@ -258,13 +258,17 @@ def patch_forward( # Get the query, key and value tensors based on the type of attention - # self or cross attn. # query: [96, 1, 16, 128], key:[96, 1, 16, 128], value:[96, 1, 16, 128] - query, key, value = self.get_query_key_value_tensors( + qkv = self.get_query_key_value_tensors( hidden_states, key_value_states, position_ids, packed_seq_params, inference_context=inference_context, ) + query, key, value = qkv[:3] + q_compressed = None + if len(qkv) > 3: + q_compressed = qkv[3] # =================================================== # Adjust key, value for inference @@ -298,6 +302,12 @@ def patch_forward( query, key, value, attention_mask, packed_seq_params=packed_seq_params ) else: + extra_kwargs = {} + if getattr(self.config, "experimental_attention_variant", None) == "dsa": + # For dsa we need to pass in the original hidden states and the compressed + # query representation. + extra_kwargs["x"] = hidden_states + extra_kwargs["qr"] = q_compressed core_attn_out = self.core_attention( query, key, @@ -305,6 +315,7 @@ def patch_forward( attention_mask, packed_seq_params=packed_seq_params, attn_mask_type=attn_mask_type, + **extra_kwargs, ) if thd_qkv_format: if core_attn_out.ndim == 2: @@ -329,7 +340,11 @@ def patch_forward( return output, bias - MLASelfAttention.get_query_key_value_tensors = patch_get_query_key_value_tensors + # This patch targets mcore 0.12 MLA behavior only. + # For newer mcore, upstream MLA already has packed-seq + CP handling and + # overriding it with the legacy implementation can break RoPE shapes. + if not mcore_ge_013: + MLASelfAttention.get_query_key_value_tensors = patch_get_query_key_value_tensors MultiLatentAttention.forward = patch_forward