Skip to content

Commit 1f95774

Browse files
committed
[megatron] fix: patch supports newer mcore versions
Tested on NVIDIA/Megatron-LM@bbbedbb. Signed-off-by: Hollow Man <hollowman@opensuse.org>
1 parent b5979db commit 1f95774

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

verl/models/mcore/patch.py

Lines changed: 17 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -258,13 +258,17 @@ def patch_forward(
258258
# Get the query, key and value tensors based on the type of attention -
259259
# self or cross attn.
260260
# query: [96, 1, 16, 128], key:[96, 1, 16, 128], value:[96, 1, 16, 128]
261-
query, key, value = self.get_query_key_value_tensors(
261+
qkv = self.get_query_key_value_tensors(
262262
hidden_states,
263263
key_value_states,
264264
position_ids,
265265
packed_seq_params,
266266
inference_context=inference_context,
267267
)
268+
query, key, value = qkv[:3]
269+
q_compressed = None
270+
if len(qkv) > 3:
271+
q_compressed = qkv[3]
268272

269273
# ===================================================
270274
# Adjust key, value for inference
@@ -298,13 +302,20 @@ def patch_forward(
298302
query, key, value, attention_mask, packed_seq_params=packed_seq_params
299303
)
300304
else:
305+
extra_kwargs = {}
306+
if getattr(self.config, "experimental_attention_variant", None) == "dsa":
307+
# For dsa we need to pass in the original hidden states and the compressed
308+
# query representation.
309+
extra_kwargs["x"] = hidden_states
310+
extra_kwargs["qr"] = q_compressed
301311
core_attn_out = self.core_attention(
302312
query,
303313
key,
304314
value,
305315
attention_mask,
306316
packed_seq_params=packed_seq_params,
307317
attn_mask_type=attn_mask_type,
318+
**extra_kwargs,
308319
)
309320
if thd_qkv_format:
310321
if core_attn_out.ndim == 2:
@@ -329,7 +340,11 @@ def patch_forward(
329340

330341
return output, bias
331342

332-
MLASelfAttention.get_query_key_value_tensors = patch_get_query_key_value_tensors
343+
# This patch targets mcore 0.12 MLA behavior only.
344+
# For newer mcore, upstream MLA already has packed-seq + CP handling and
345+
# overriding it with the legacy implementation can break RoPE shapes.
346+
if not mcore_ge_013:
347+
MLASelfAttention.get_query_key_value_tensors = patch_get_query_key_value_tensors
333348

334349
MultiLatentAttention.forward = patch_forward
335350

0 commit comments

Comments (0)