Skip to content

Commit 0c15d6b

Browse files
committed
[megatron] fix: patch support newer mcore version
Tested on NVIDIA/Megatron-LM@bbbedbb. Signed-off-by: Hollow Man <hollowman@opensuse.org>
1 parent 37ff251 commit 0c15d6b

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

verl/models/mcore/patch.py

Lines changed: 7 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -258,13 +258,14 @@ def patch_forward(
258258
# Get the query, key and value tensors based on the type of attention -
259259
# self or cross attn.
260260
# query: [96, 1, 16, 128], key:[96, 1, 16, 128], value:[96, 1, 16, 128]
261-
query, key, value = self.get_query_key_value_tensors(
261+
qkv = self.get_query_key_value_tensors(
262262
hidden_states,
263263
key_value_states,
264264
position_ids,
265265
packed_seq_params,
266266
inference_context=inference_context,
267267
)
268+
query, key, value = qkv[:3]
268269

269270
# ===================================================
270271
# Adjust key, value for inference
@@ -329,7 +330,11 @@ def patch_forward(
329330

330331
return output, bias
331332

332-
MLASelfAttention.get_query_key_value_tensors = patch_get_query_key_value_tensors
333+
# This patch targets mcore 0.12 MLA behavior only.
334+
# For newer mcore, upstream MLA already has packed-seq + CP handling and
335+
# overriding it with the legacy implementation can break RoPE shapes.
336+
if not mcore_ge_013:
337+
MLASelfAttention.get_query_key_value_tensors = patch_get_query_key_value_tensors
333338

334339
MultiLatentAttention.forward = patch_forward
335340

0 commit comments

Comments (0)