Commit 4071c57

[Misc] Remove custom op rotary_embedding
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
1 parent: bfcc372

File tree

8 files changed: +62 -1397 lines

csrc/kernels/pos_encoding_kernels.cpp

Lines changed: 0 additions & 372 deletions
This file was deleted.

csrc/ops.h

Lines changed: 0 additions & 7 deletions
@@ -24,13 +24,6 @@
 #include "torch_npu/csrc/aten/common/from_blob.h"

 namespace vllm_ascend {
-extern void rotary_embedding_impl(AscendType type, bool isNeox, void *stream, int64_t *positions, void *queryDst,
-                                  void *keyDst, void *query, void *key, void *cosSinCache, const int rotDim,
-                                  const int64_t queryStride, const int64_t keyStride, const int64_t dstQueryStride,
-                                  const int64_t dstKeyStride, const int numHeads, const int numKvHeads,
-                                  const int headSize, const int64_t numTokens, const uint32_t loopCnt,
-                                  uint32_t aivNum);
-
 extern void get_masked_input_and_mask_impl(
     void* stream,
     void* input,

csrc/torch_binding.cpp

Lines changed: 0 additions & 77 deletions
@@ -105,75 +105,6 @@ AscendType get_dtype_from_torch(at::ScalarType scalarType)
     }
 }

-std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::Tensor &query, at::Tensor &key,
-                                                    int64_t head_size, at::Tensor &cos_sin_cache, bool is_neox)
-{
-    int32_t deviceId = 0;
-    int64_t num_tokens = positions.numel();
-    int positions_ndim = positions.dim();
-    TORCH_CHECK(
-        positions_ndim == 1 || positions_ndim == 2,
-        "positions must have shape [num_tokens] or [batch_size, seq_len]");
-    if (positions_ndim == 1) {
-        TORCH_CHECK(
-            query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
-            "query, key and positions must have the same number of tokens");
-    }
-    if (positions_ndim == 2) {
-        TORCH_CHECK(
-            query.size(0) == positions.size(0) &&
-            key.size(0) == positions.size(0) &&
-            query.size(1) == positions.size(1) &&
-            key.size(1) == positions.size(1),
-            "query, key and positions must have the same batch_size and seq_len");
-    }
-    TORCH_CHECK(head_size % 32 == 0, "rotary_embedding: headSize should be divisible by 32");
-    int query_hidden_size = query.numel() / num_tokens;
-    int key_hidden_size = key.numel() / num_tokens;
-    TORCH_CHECK(query_hidden_size % head_size == 0);
-    TORCH_CHECK(key_hidden_size % head_size == 0);
-    TORCH_CHECK(is_neox == true, "rotary_embedding: neox=false is not supported as custom kernel in vllm-ascend");
-
-    // Make sure query and key have consistent number of heads
-    int num_heads = query_hidden_size / head_size;
-    int num_kv_heads = key_hidden_size / head_size;
-    TORCH_CHECK(num_heads % num_kv_heads == 0);
-    at::Tensor query_dst = at::empty({num_tokens, num_heads, head_size}, query.options());
-    at::Tensor key_dst = at::empty({num_tokens, num_kv_heads, head_size}, key.options());
-
-    int rot_dim = cos_sin_cache.size(1);
-    int seq_dim_idx = positions_ndim - 1;
-    int64_t *position_ids_ptr = positions.data_ptr<int64_t>();
-    void *query_dst_ptr = query_dst.data_ptr();
-    void *key_dst_ptr = key_dst.data_ptr();
-    void *query_ptr = query.data_ptr();
-    void *key_ptr = key.data_ptr();
-    void *cos_sin_cache_ptr = cos_sin_cache.data_ptr();
-    int64_t query_stride = query.stride(seq_dim_idx);
-    int64_t key_stride = key.stride(seq_dim_idx);
-    int64_t dst_query_stride = query_dst.stride(0);
-    int64_t dst_key_stride = key_dst.stride(0);
-    at::ScalarType scalar_type = query.scalar_type();
-    aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
-    at_npu::native::OpCommand cmd;
-    cmd.Name("rotary_embedding");
-    cmd.SetCustomHandler([scalar_type, is_neox, num_tokens, stream, position_ids_ptr, query_dst_ptr, key_dst_ptr,
-                          query_ptr, key_ptr, cos_sin_cache_ptr, rot_dim, query_stride, key_stride,
-                          dst_query_stride, dst_key_stride, num_heads, num_kv_heads, head_size]() -> int {
-        auto dtype_num = get_dtype_from_torch(scalar_type);
-        int device_id = 0;
-        int64_t aiv_num = 0;
-        TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS);
-        uint32_t loop_cnt = (num_tokens + aiv_num - 1) / aiv_num;
-        rotary_embedding_impl(dtype_num, is_neox, stream, position_ids_ptr, query_dst_ptr, key_dst_ptr, query_ptr,
-                              key_ptr, cos_sin_cache_ptr, rot_dim, query_stride, key_stride, dst_query_stride,
-                              dst_key_stride, num_heads, num_kv_heads, head_size, num_tokens, loop_cnt, aiv_num);
-        return 0;
-    });
-    cmd.Run();
-    return {query_dst, key_dst};
-}
-
 std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preprocess(
     const at::Tensor &hiddenState, const at::Tensor &wdqkv,
     const c10::optional<at::Tensor> &descale0, const at::Tensor &gamma1, const c10::optional<at::Tensor> &beta1, const at::Tensor &wuq,
@@ -1368,14 +1299,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
     ops.def("weak_ref_tensor(Tensor input) -> Tensor");
     ops.impl("weak_ref_tensor", torch::kPrivateUse1, &vllm_ascend::weak_ref_tensor);

-    // Rotary embedding
-    // Apply GPT-NeoX style rotary embedding to query and key.
-    ops.def(
-        "rotary_embedding(Tensor positions, Tensor! query,"
-        " Tensor! key, int head_size,"
-        " Tensor cos_sin_cache, bool is_neox) -> (Tensor query, Tensor key)");
-    ops.impl("rotary_embedding", torch::kPrivateUse1, &vllm_ascend::rotary_embedding);
-
     ops.def(
         "get_masked_input_and_mask(Tensor input, "
         " int org_vocab_start_index, "

csrc/torch_binding_meta.cpp

Lines changed: 0 additions & 20 deletions
@@ -36,24 +36,6 @@
 namespace vllm_ascend {
 namespace meta {
 const int64_t INT4_NUMS_IN_INT32 = 8;
-std::tuple<at::Tensor, at::Tensor> rotary_embedding_meta(
-    at::Tensor &positions,
-    at::Tensor &query,
-    at::Tensor &key,
-    int64_t head_size,
-    at::Tensor &cos_sin_cache,
-    bool is_neox) {
-    auto num_tokens = positions.sym_numel();
-    auto query_hidden_size = query.sym_numel() / num_tokens;
-    auto key_hidden_size = key.sym_numel() / num_tokens;
-
-    auto num_heads = query_hidden_size / head_size;
-    auto num_kv_heads = key_hidden_size / head_size;
-    at::Tensor query_dst = at::empty_symint({num_tokens, num_heads, head_size}, query.options());
-    at::Tensor key_dst = at::empty_symint({num_tokens, num_kv_heads, head_size}, key.options());
-
-    return {query_dst, key_dst};
-}

 std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask_meta(
     at::Tensor &input,
@@ -457,8 +439,6 @@ namespace {
 // the custom kernel been captured into aclgraph
 TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {

-    // Rotary embedding meta implementation
-    ops.impl("rotary_embedding", &vllm_ascend::meta::rotary_embedding_meta);
     // Masked input and mask meta implementation
     ops.impl("get_masked_input_and_mask", &vllm_ascend::meta::get_masked_input_and_mask_meta);
     // Bgmv expand
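Note: the removed Meta-key registration existed so that `rotary_embedding` could be shape-inferred without touching the device, which is what lets a custom op be captured into an aclgraph (per the comment in the context lines above). A Meta implementation only allocates outputs of the correct symbolic shape and dtype. Below is a minimal sketch for the hypothetical `my_ext::my_identity` op introduced in the earlier example; it is an illustration, not code from this repo.

// Sketch only: Meta-key counterpart for the hypothetical my_identity op.
// It computes output shape/dtype and performs no device work.
#include <ATen/ATen.h>
#include <torch/library.h>

namespace {
at::Tensor my_identity_meta(const at::Tensor &input) {
    // Symbolic sizes keep the op traceable under dynamic shapes, mirroring
    // how the removed rotary_embedding_meta used at::empty_symint.
    return at::empty_symint(input.sym_sizes(), input.options());
}
} // namespace

TORCH_LIBRARY_IMPL(my_ext, Meta, ops) {
    ops.impl("my_identity", &my_identity_meta);
}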
