Commit 5453033

revert TND modify when dcp pcp (#3948)
### What this PR does / why we need it?

1. Revert the TND modification for dcp/pcp that was introduced by f57bdb0.
2. Handle an aclgraph padding boundary issue.

- vLLM version: v0.11.0
- vLLM main: vllm-project/vllm@83f478b

Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
Parent commit: cc2cd42
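As a quick layout refresher (illustrative only, not part of the diff): in decode each request contributes a single token, so a TND query of shape [num_tokens, num_heads, head_dim] can be handed to a BSND kernel as [batch, seq=1, num_heads, head_dim] with a zero-copy view, and the output viewed back to 3-D afterwards. A minimal torch sketch with made-up sizes:

```python
import torch

num_tokens, num_heads, head_dim = 4, 8, 128
query_tnd = torch.randn(num_tokens, num_heads, head_dim)  # TND: one decode token per request

# Same zero-copy view as the added q_nope line in attention_v1.py: TND -> BSND with seq_len 1.
q_bsnd = query_tnd.view(num_tokens, 1, num_heads, head_dim)

# The fused op presumably returns BSND as well; the diff views it back to 3-D.
attn_out_bsnd = torch.randn(num_tokens, 1, num_heads, head_dim)
attn_out = attn_out_bsnd.view(attn_out_bsnd.shape[0], attn_out_bsnd.shape[2],
                              attn_out_bsnd.shape[3])
print(q_bsnd.shape, attn_out.shape)  # [4, 1, 8, 128] [4, 8, 128]
```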

File tree: 3 files changed, +27 / -17 lines

- vllm_ascend/attention/attention_v1.py
- vllm_ascend/compilation/acl_graph.py
- vllm_ascend/worker/model_runner_v1.py

vllm_ascend/attention/attention_v1.py

Lines changed: 20 additions & 14 deletions
@@ -869,6 +869,7 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
         else:
             num_heads = self.num_heads

+        q_nope = query.view(query.shape[0], 1, query.shape[1], query.shape[2])
         k_nope = self.key_cache.view(self.key_cache.shape[0],
                                      self.key_cache.shape[1], -1)
         value = self.value_cache.view(self.key_cache.shape[0],
@@ -879,7 +880,7 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
             'num_key_value_heads':
             self.num_kv_heads,
             'input_layout':
-            "TND",
+            "BSND",
             'atten_mask':
             None,
             'scale':
@@ -895,14 +896,12 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
             'block_size':
             self.key_cache.shape[1],
             'actual_seq_lengths_kv':
-            attn_metadata.seq_lens_list[:attn_metadata.num_decode_tokens],
-            'actual_seq_lengths':
-            attn_metadata.actual_seq_lengths_q[:attn_metadata.
-                                               num_decode_tokens]
+            attn_metadata.decode_meta.
+            num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank],
         }
         graph_params = get_graph_params()
         forward_context: ForwardContext = get_forward_context()
-        num_tokens = query.shape[0]
+        num_tokens = q_nope.shape[0]
         if forward_context.capturing:
             stream = torch_npu.npu.current_stream()

@@ -914,16 +913,16 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
             workspace = graph_params.workspaces.get(num_tokens)
             if workspace is None:
                 workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
-                    query, k_nope, value, **common_kwargs)
+                    q_nope, k_nope, value, **common_kwargs)
                 update_graph_params_workspaces(num_tokens,
                                                weak_ref_tensors(workspace))
-            attn_out = torch.empty_like(query)
+            attn_out = torch.empty_like(q_nope)
             attn_lse = torch.empty((num_tokens, num_heads, 1, 1),
                                    dtype=torch.float,
-                                   device=query.device)
+                                   device=q_nope.device)

             graph_params.attn_params[num_tokens].append(
-                (weak_ref_tensors(query), weak_ref_tensors(k_nope),
+                (weak_ref_tensors(q_nope), weak_ref_tensors(k_nope),
                  weak_ref_tensors(value), self.num_heads, self.num_kv_heads,
                  self.scale, attn_metadata.block_tables,
                  self.key_cache.shape[1], attn_metadata.decode_meta.
@@ -933,7 +932,7 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
                  self.pcp_rank, self.dcp_rank, self.dcp_size))
             torch.npu.graph_task_group_begin(stream)
             torch_npu.npu_fused_infer_attention_score.out(
-                query,
+                q_nope,
                 k_nope,
                 value,
                 **common_kwargs,
@@ -943,7 +942,11 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
             graph_params.handles[num_tokens].append(handle)
         else:
             attn_out, attn_lse = torch_npu.npu_fused_infer_attention_score(
-                query, k_nope, value, **common_kwargs)
+                q_nope, k_nope, value, **common_kwargs)
+
+        attn_out = attn_out.view(attn_out.shape[0], attn_out.shape[2],
+                                 attn_out.shape[3])
+        attn_lse = attn_lse.view(attn_lse.shape[0], attn_lse.shape[1], 1)

         attn_out_lse_list = []
         # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
@@ -1017,7 +1020,8 @@ def _forward_pcp_dcp(self, query: torch.Tensor, key: torch.Tensor,
                 prefill_query, key, value, attn_metadata,
                 output[num_decode_tokens:], prefill_query.shape[0])
             attn_metadata.seq_lens = seq_lens_back
-            output[num_decode_tokens:] = output_prefill
+            output[num_decode_tokens:output_prefill.shape[0] +
+                   num_decode_tokens] = output_prefill
         return output

     def forward(
@@ -1089,7 +1093,9 @@ def forward(
         if has_prefill:
             if self.pcp_size > 1:
                 kv = torch.cat([key, value], dim=-1)
-                all_kv = get_pcp_group().all_gather(kv, dim=0)
+                num_actual_tokens_pcp_padded = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size
+                all_kv = get_pcp_group().all_gather(
+                    kv[:num_actual_tokens_pcp_padded].contiguous(), dim=0)
                 pcp_allgather_restore_idx = attn_metadata.prefill.pcp_allgather_restore_idx if attn_metadata.prefill else None
                 all_kv = torch.index_select(all_kv, 0,
                                             pcp_allgather_restore_idx)
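The last two hunks are the padding-boundary part: when the output buffer is padded (for graph capture / pcp padding) it can be longer than the prefill result, so the write has to be bounded by output_prefill.shape[0], and each rank only all-gathers its share of the padded KV. A standalone sketch of the bounded write, with invented shapes:

```python
import torch

# Sketch of the bounded write in _forward_pcp_dcp; all sizes here are invented.
num_decode_tokens = 3
output = torch.zeros(16, 8)           # padded output buffer (e.g. sized for graph capture)
output_prefill = torch.ones(10, 8)    # actual prefill result, shorter than the padded tail

# Old: output[num_decode_tokens:] = output_prefill  -> shape mismatch when the buffer is padded.
# New: write only the rows that prefill actually produced.
output[num_decode_tokens:output_prefill.shape[0] + num_decode_tokens] = output_prefill
```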

vllm_ascend/compilation/acl_graph.py

Lines changed: 2 additions & 2 deletions
@@ -301,9 +301,9 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
     ):
         (q_nope, k_nope, value, num_heads, num_kv_heads, scale,
          block_table, block_size, actual_seq_lengths_kv, attn_output,
-         softmax_lse, cp_rank, dcp_rank, dcp_size) = param
+         softmax_lse, pcp_rank, dcp_rank, dcp_size) = param
         actual_seq_lengths_kv = forward_context.attn_metadata[
-            key].decode_meta.num_computed_tokens_of_pcp_dcp[:, cp_rank,
+            key].decode_meta.num_computed_tokens_of_pcp_dcp[:, pcp_rank,
                                                             dcp_rank]
         pad_length = runtime_shape - len(actual_seq_lengths_kv)
         pad_tensor = np.zeros(pad_length,
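For context, the unchanged lines right after this rename pad the per-request KV lengths up to the captured graph's runtime_shape, which is the other half of the boundary handling. Roughly, with invented values and an assumed dtype:

```python
import numpy as np

# Rough sketch of the padding done around these lines in update_attn_dcp_pcp_params.
runtime_shape = 8                                             # batch size the ACL graph was captured with
actual_seq_lengths_kv = np.array([5, 7, 3], dtype=np.int32)   # lengths for the live requests only

pad_length = runtime_shape - len(actual_seq_lengths_kv)
pad_tensor = np.zeros(pad_length, dtype=np.int32)
padded = np.concatenate([actual_seq_lengths_kv, pad_tensor])
assert len(padded) == runtime_shape
```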

vllm_ascend/worker/model_runner_v1.py

Lines changed: 5 additions & 1 deletion
@@ -476,6 +476,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.pcp_padded_slot_mapping = torch.zeros(self.max_num_tokens,
                                                    dtype=torch.int32,
                                                    device=self.device)
+        self.num_actual_tokens_pcp_padded = 0
         if self.speculative_config and self.pcp_size > 1:
             self.input_ids_pcp_full = torch.zeros(self.max_num_tokens,
                                                   dtype=torch.int32,
@@ -1915,7 +1916,9 @@ def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
             hidden_states = hidden_states[:-pad_size, :]

         if self.pcp_size > 1:
-            hidden_states = get_pcp_group().all_gather(hidden_states, 0)
+            hidden_states = get_pcp_group().all_gather(
+                hidden_states[:self.num_actual_tokens_pcp_padded //
+                              self.pcp_size], 0)
             hidden_states = torch.index_select(
                 hidden_states, 0,
                 self.pcp_allgather_restore_idx[:hidden_states.shape[0]])
@@ -4304,6 +4307,7 @@ def _generate_pcp_metadata(self, total_num_scheduled_tokens, seq_lens):
         num_decodes = sum(self.input_batch.num_computed_tokens_cpu[:num_reqs]
                           >= self.input_batch.num_prompt_tokens[:num_reqs])
         num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_size
+        self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
         long_seq_metadata = None
         if self.pcp_size * self.dcp_size > 1:
             long_seq_metadata = AscendPrefillContextParallelMetadata(
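Both all_gather call sites now follow the same pattern: record num_actual_tokens_pcp_padded when the pcp metadata is built, then gather only each rank's share of it. A toy sketch of just the slicing arithmetic (no collective is issued, sizes invented):

```python
import torch

pcp_size = 2
total_num_scheduled_tokens = 6
num_actual_tokens_pcp_padded = total_num_scheduled_tokens * pcp_size  # as in _generate_pcp_metadata

hidden_states = torch.randn(8, 16)                   # local buffer, possibly longer than needed
per_rank = num_actual_tokens_pcp_padded // pcp_size  # this rank's share of the padded tokens
local_share = hidden_states[:per_rank]               # what would be fed into the pcp all_gather
print(local_share.shape)                             # torch.Size([6, 16])
```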
