19 changes: 17 additions & 2 deletions verl/models/mcore/patch.py
@@ -258,13 +258,17 @@ def patch_forward(
 # Get the query, key and value tensors based on the type of attention -
 # self or cross attn.
 # query: [96, 1, 16, 128], key: [96, 1, 16, 128], value: [96, 1, 16, 128]
-query, key, value = self.get_query_key_value_tensors(
+qkv = self.get_query_key_value_tensors(
     hidden_states,
     key_value_states,
     position_ids,
     packed_seq_params,
     inference_context=inference_context,
 )
+query, key, value = qkv[:3]
Collaborator:

What's the meaning of this change?

Collaborator (Author):

In NVIDIA/Megatron-LM@382eeea#diff-d7e453d550e7f78221ea3445018264015c1e83c0db892a41253bde71d069a202L247, the return value of get_query_key_value_tensors changed from (query, key, value) to (query, key, value, q_compressed, kv_compressed). kv_compressed is not used in the forward pass, so it is simply ignored; taking only the first three returned values here keeps backward compatibility with both signatures.

I just found that q_compressed is actually used for DSA (e.g. DeepSeek V3.2), so I pushed a modification to fix that.
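The backward-compatible unpacking described above can be sketched as follows. The two stub functions below are illustrative stand-ins for the old and new Megatron-LM return signatures, not the real API:

```python
def get_qkv_old():
    # Stand-in for the pre-change signature: three return values.
    return "q", "k", "v"


def get_qkv_new():
    # Stand-in for the post-change signature: q_compressed and
    # kv_compressed are appended; q_compressed is needed for DSA.
    return "q", "k", "v", "q_comp", "kv_comp"


def unpack_qkv(qkv):
    # Slice off the first three values so the caller works against
    # either signature; keep q_compressed when it is present.
    query, key, value = qkv[:3]
    q_compressed = qkv[3] if len(qkv) > 3 else None
    return query, key, value, q_compressed
```

Because the extra values are appended rather than reordered, slicing `qkv[:3]` is safe under both signatures, while `q_compressed` is picked up only when the newer mcore provides it.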

Collaborator:

Do we still need this whole patch for newer versions of Megatron?

Collaborator (Author) — HollowMan6, Feb 23, 2026:

Yes, unfortunately NVIDIA/TransformerEngine#2629 hasn't been merged into TE yet.

Collaborator (Author) — HollowMan6, Feb 23, 2026:

Nor has NVIDIA/Megatron-LM#3003 for mcore. Once either one is merged, we will no longer need this patch to support flash attention for MLA.
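The version gate the patch relies on (the `mcore_ge_013` flag in the diff below) can be sketched as follows. The version-string parsing here is illustrative; the real code would derive the flag from the installed megatron.core package rather than a hard-coded string:

```python
import re


def parse_mcore_version(version: str) -> tuple:
    # "0.13.0rc1" -> (0, 13, 0): keep the leading digits of each of
    # the first three dot-separated components, dropping any suffix.
    parts = []
    for piece in version.split(".")[:3]:
        match = re.match(r"\d+", piece)
        parts.append(int(match.group()) if match else 0)
    return tuple(parts)


# Only monkey-patch on mcore 0.12.x; newer releases are expected to
# handle packed-seq + CP in upstream MLA (hypothetical version shown).
mcore_ge_013 = parse_mcore_version("0.12.2") >= (0, 13)
```

Comparing tuples of integers avoids the classic string-comparison pitfall where "0.9" sorts after "0.13".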

+q_compressed = None
+if len(qkv) > 3:
+    q_compressed = qkv[3]

 # ===================================================
 # Adjust key, value for inference
@@ -298,13 +302,20 @@ def patch_forward(
         query, key, value, attention_mask, packed_seq_params=packed_seq_params
     )
 else:
+    extra_kwargs = {}
+    if getattr(self.config, "experimental_attention_variant", None) == "dsa":
+        # For dsa we need to pass in the original hidden states and the
+        # compressed query representation.
+        extra_kwargs["x"] = hidden_states
+        extra_kwargs["qr"] = q_compressed
     core_attn_out = self.core_attention(
         query,
         key,
         value,
         attention_mask,
         packed_seq_params=packed_seq_params,
         attn_mask_type=attn_mask_type,
+        **extra_kwargs,
     )
 if thd_qkv_format:
     if core_attn_out.ndim == 2:
@@ -329,7 +340,11 @@ def patch_forward(

     return output, bias

-MLASelfAttention.get_query_key_value_tensors = patch_get_query_key_value_tensors
+# This patch targets mcore 0.12 MLA behavior only.
+# For newer mcore, upstream MLA already has packed-seq + CP handling, and
+# overriding it with the legacy implementation can break RoPE shapes.
+if not mcore_ge_013:
+    MLASelfAttention.get_query_key_value_tensors = patch_get_query_key_value_tensors

 MultiLatentAttention.forward = patch_forward
