@@ -258,13 +258,17 @@ def patch_forward(
258258 # Get the query, key and value tensors based on the type of attention -
259259 # self or cross attn.
260260 # query: [96, 1, 16, 128], key:[96, 1, 16, 128], value:[96, 1, 16, 128]
261- query , key , value = self .get_query_key_value_tensors (
261+ qkv = self .get_query_key_value_tensors (
262262 hidden_states ,
263263 key_value_states ,
264264 position_ids ,
265265 packed_seq_params ,
266266 inference_context = inference_context ,
267267 )
268+ query , key , value = qkv [:3 ]
269+ q_compressed = None
270+ if len (qkv ) > 3 :
271+ q_compressed = qkv [3 ]
268272
269273 # ===================================================
270274 # Adjust key, value for inference
@@ -298,13 +302,20 @@ def patch_forward(
298302 query , key , value , attention_mask , packed_seq_params = packed_seq_params
299303 )
300304 else :
305+ extra_kwargs = {}
306+ if getattr (self .config , "experimental_attention_variant" , None ) == "dsa" :
307+ # For dsa we need to pass in the original hidden states and the compressed
308+ # query representation.
309+ extra_kwargs ["x" ] = hidden_states
310+ extra_kwargs ["qr" ] = q_compressed
301311 core_attn_out = self .core_attention (
302312 query ,
303313 key ,
304314 value ,
305315 attention_mask ,
306316 packed_seq_params = packed_seq_params ,
307317 attn_mask_type = attn_mask_type ,
318+ ** extra_kwargs ,
308319 )
309320 if thd_qkv_format :
310321 if core_attn_out .ndim == 2 :
@@ -329,7 +340,11 @@ def patch_forward(
329340
330341 return output , bias
331342
332- MLASelfAttention .get_query_key_value_tensors = patch_get_query_key_value_tensors
343+ # This patch targets mcore 0.12 MLA behavior only.
344+ # For newer mcore, upstream MLA already has packed-seq + CP handling and
345+ # overriding it with the legacy implementation can break RoPE shapes.
346+ if not mcore_ge_013 :
347+ MLASelfAttention .get_query_key_value_tensors = patch_get_query_key_value_tensors
333348
334349 MultiLatentAttention .forward = patch_forward
335350
0 commit comments