@@ -869,6 +869,7 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
869869 else :
870870 num_heads = self .num_heads
871871
872+ q_nope = query .view (query .shape [0 ], 1 , query .shape [1 ], query .shape [2 ])
872873 k_nope = self .key_cache .view (self .key_cache .shape [0 ],
873874 self .key_cache .shape [1 ], - 1 )
874875 value = self .value_cache .view (self .key_cache .shape [0 ],
@@ -879,7 +880,7 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
879880 'num_key_value_heads' :
880881 self .num_kv_heads ,
881882 'input_layout' :
882- "TND " ,
883+ "BSND " ,
883884 'atten_mask' :
884885 None ,
885886 'scale' :
@@ -895,14 +896,12 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
895896 'block_size' :
896897 self .key_cache .shape [1 ],
897898 'actual_seq_lengths_kv' :
898- attn_metadata .seq_lens_list [:attn_metadata .num_decode_tokens ],
899- 'actual_seq_lengths' :
900- attn_metadata .actual_seq_lengths_q [:attn_metadata .
901- num_decode_tokens ]
899+ attn_metadata .decode_meta .
900+ num_computed_tokens_of_pcp_dcp [:, self .pcp_rank , self .dcp_rank ],
902901 }
903902 graph_params = get_graph_params ()
904903 forward_context : ForwardContext = get_forward_context ()
905- num_tokens = query .shape [0 ]
904+ num_tokens = q_nope .shape [0 ]
906905 if forward_context .capturing :
907906 stream = torch_npu .npu .current_stream ()
908907
@@ -914,16 +913,16 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
914913 workspace = graph_params .workspaces .get (num_tokens )
915914 if workspace is None :
916915 workspace = torch_npu ._npu_fused_infer_attention_score_get_max_workspace (
917- query , k_nope , value , ** common_kwargs )
916+ q_nope , k_nope , value , ** common_kwargs )
918917 update_graph_params_workspaces (num_tokens ,
919918 weak_ref_tensors (workspace ))
920- attn_out = torch .empty_like (query )
919+ attn_out = torch .empty_like (q_nope )
921920 attn_lse = torch .empty ((num_tokens , num_heads , 1 , 1 ),
922921 dtype = torch .float ,
923- device = query .device )
922+ device = q_nope .device )
924923
925924 graph_params .attn_params [num_tokens ].append (
926- (weak_ref_tensors (query ), weak_ref_tensors (k_nope ),
925+ (weak_ref_tensors (q_nope ), weak_ref_tensors (k_nope ),
927926 weak_ref_tensors (value ), self .num_heads , self .num_kv_heads ,
928927 self .scale , attn_metadata .block_tables ,
929928 self .key_cache .shape [1 ], attn_metadata .decode_meta .
@@ -933,7 +932,7 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
933932 self .pcp_rank , self .dcp_rank , self .dcp_size ))
934933 torch .npu .graph_task_group_begin (stream )
935934 torch_npu .npu_fused_infer_attention_score .out (
936- query ,
935+ q_nope ,
937936 k_nope ,
938937 value ,
939938 ** common_kwargs ,
@@ -943,7 +942,11 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
943942 graph_params .handles [num_tokens ].append (handle )
944943 else :
945944 attn_out , attn_lse = torch_npu .npu_fused_infer_attention_score (
946- query , k_nope , value , ** common_kwargs )
945+ q_nope , k_nope , value , ** common_kwargs )
946+
947+ attn_out = attn_out .view (attn_out .shape [0 ], attn_out .shape [2 ],
948+ attn_out .shape [3 ])
949+ attn_lse = attn_lse .view (attn_lse .shape [0 ], attn_lse .shape [1 ], 1 )
947950
948951 attn_out_lse_list = []
949952 # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
@@ -1017,7 +1020,8 @@ def _forward_pcp_dcp(self, query: torch.Tensor, key: torch.Tensor,
10171020 prefill_query , key , value , attn_metadata ,
10181021 output [num_decode_tokens :], prefill_query .shape [0 ])
10191022 attn_metadata .seq_lens = seq_lens_back
1020- output [num_decode_tokens :] = output_prefill
1023+ output [num_decode_tokens :output_prefill .shape [0 ] +
1024+ num_decode_tokens ] = output_prefill
10211025 return output
10221026
10231027 def forward (
@@ -1089,7 +1093,9 @@ def forward(
10891093 if has_prefill :
10901094 if self .pcp_size > 1 :
10911095 kv = torch .cat ([key , value ], dim = - 1 )
1092- all_kv = get_pcp_group ().all_gather (kv , dim = 0 )
1096+ num_actual_tokens_pcp_padded = attn_metadata .num_actual_tokens_pcp_padded // self .pcp_size
1097+ all_kv = get_pcp_group ().all_gather (
1098+ kv [:num_actual_tokens_pcp_padded ].contiguous (), dim = 0 )
10931099 pcp_allgather_restore_idx = attn_metadata .prefill .pcp_allgather_restore_idx if attn_metadata .prefill else None
10941100 all_kv = torch .index_select (all_kv , 0 ,
10951101 pcp_allgather_restore_idx )
0 commit comments