@@ -105,75 +105,6 @@ AscendType get_dtype_from_torch(at::ScalarType scalarType)
105105 }
106106}
107107
// NOTE(review): the lines below are the *removed* side of a diff hunk — the
// leading "NNN-" tokens are old-file line numbers plus the '-' deletion marker,
// not code. Kept byte-identical for the record; this text is not compilable as-is.
//
// rotary_embedding: applies GPT-NeoX-style rotary position embedding to query
// and key on Ascend NPU via the custom kernel rotary_embedding_impl.
//   positions      — int64 position ids, shape [num_tokens] or [batch, seq_len]
//   query/key      — activation tensors whose last dims flatten to
//                    num_(kv_)heads * head_size per token
//   head_size      — per-head dim; must be divisible by 32 (kernel constraint)
//   cos_sin_cache  — precomputed cos/sin table; rot_dim taken from its dim 1
//   is_neox        — must be true; neox=false is rejected below
// Returns freshly allocated (query_dst, key_dst) of shape
// [num_tokens, num_(kv_)heads, head_size]; inputs are read, not written,
// in the visible code.
108- std::tuple<at::Tensor, at::Tensor> rotary_embedding (at::Tensor &positions, at::Tensor &query, at::Tensor &key,
109- int64_t head_size, at::Tensor &cos_sin_cache, bool is_neox)
110- {
// NOTE(review): deviceId is never read in the visible body (the handler uses
// its own local device_id) — looks like dead code; confirm before removal.
111- int32_t deviceId = 0 ;
112- int64_t num_tokens = positions.numel ();
113- int positions_ndim = positions.dim ();
// Validate positions rank and that query/key token (and batch) dims match it.
114- TORCH_CHECK (
115- positions_ndim == 1 || positions_ndim == 2 ,
116- " positions must have shape [num_tokens] or [batch_size, seq_len]" );
117- if (positions_ndim == 1 ) {
118- TORCH_CHECK (
119- query.size (0 ) == positions.size (0 ) && key.size (0 ) == positions.size (0 ),
120- " query, key and positions must have the same number of tokens" );
121- }
122- if (positions_ndim == 2 ) {
123- TORCH_CHECK (
124- query.size (0 ) == positions.size (0 ) &&
125- key.size (0 ) == positions.size (0 ) &&
126- query.size (1 ) == positions.size (1 ) &&
127- key.size (1 ) == positions.size (1 ),
128- " query, key and positions must have the same batch_size and seq_len" );
129- }
130- TORCH_CHECK (head_size % 32 == 0 , " rotary_embedding: headSize should be divisible by 32" );
// Per-token hidden sizes must decompose evenly into heads of head_size.
131- int query_hidden_size = query.numel () / num_tokens;
132- int key_hidden_size = key.numel () / num_tokens;
133- TORCH_CHECK (query_hidden_size % head_size == 0 );
134- TORCH_CHECK (key_hidden_size % head_size == 0 );
135- TORCH_CHECK (is_neox == true , " rotary_embedding: neox=false is not supported as custom kernel in vllm-ascend" );
136-
137- // Make sure query and key have consistent number of heads
138- int num_heads = query_hidden_size / head_size;
139- int num_kv_heads = key_hidden_size / head_size;
140- TORCH_CHECK (num_heads % num_kv_heads == 0 );
// Outputs are new tensors (same dtype/device options as the inputs).
141- at::Tensor query_dst = at::empty ({num_tokens, num_heads, head_size}, query.options ());
142- at::Tensor key_dst = at::empty ({num_tokens, num_kv_heads, head_size}, key.options ());
143-
// Gather raw pointers and strides for the kernel; seq_dim_idx selects the
// token axis (last dim of positions), so strides are per-token element counts.
144- int rot_dim = cos_sin_cache.size (1 );
145- int seq_dim_idx = positions_ndim - 1 ;
146- int64_t *position_ids_ptr = positions.data_ptr <int64_t >();
147- void *query_dst_ptr = query_dst.data_ptr ();
148- void *key_dst_ptr = key_dst.data_ptr ();
149- void *query_ptr = query.data_ptr ();
150- void *key_ptr = key.data_ptr ();
151- void *cos_sin_cache_ptr = cos_sin_cache.data_ptr ();
152- int64_t query_stride = query.stride (seq_dim_idx);
153- int64_t key_stride = key.stride (seq_dim_idx);
154- int64_t dst_query_stride = query_dst.stride (0 );
155- int64_t dst_key_stride = key_dst.stride (0 );
156- at::ScalarType scalar_type = query.scalar_type ();
// Enqueue the custom kernel on the current NPU stream via OpCommand; the
// lambda captures everything by value so it stays valid when run later.
157- aclrtStream stream = c10_npu::getCurrentNPUStream ().stream ();
158- at_npu::native::OpCommand cmd;
159- cmd.Name (" rotary_embedding" );
160- cmd.SetCustomHandler ([scalar_type, is_neox, num_tokens, stream, position_ids_ptr, query_dst_ptr, key_dst_ptr,
161- query_ptr, key_ptr, cos_sin_cache_ptr, rot_dim, query_stride, key_stride,
162- dst_query_stride, dst_key_stride, num_heads, num_kv_heads, head_size]() -> int {
163- auto dtype_num = get_dtype_from_torch (scalar_type);
164- int device_id = 0 ;
// Query the device's vector-core count and split tokens across cores:
// loop_cnt = ceil(num_tokens / aiv_num).
165- int64_t aiv_num = 0 ;
166- TORCH_CHECK (aclGetDeviceCapability (device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS);
167- uint32_t loop_cnt = (num_tokens + aiv_num - 1 ) / aiv_num;
168- rotary_embedding_impl (dtype_num, is_neox, stream, position_ids_ptr, query_dst_ptr, key_dst_ptr, query_ptr,
169- key_ptr, cos_sin_cache_ptr, rot_dim, query_stride, key_stride, dst_query_stride,
170- dst_key_stride, num_heads, num_kv_heads, head_size, num_tokens, loop_cnt, aiv_num);
171- return 0 ;
172- });
173- cmd.Run ();
174- return {query_dst, key_dst};
175- }
176-
177108std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preprocess (
178109 const at::Tensor &hiddenState, const at::Tensor &wdqkv,
179110 const c10::optional<at::Tensor> &descale0, const at::Tensor &gamma1, const c10::optional<at::Tensor> &beta1, const at::Tensor &wuq,
@@ -1368,14 +1299,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
13681299 ops.def (" weak_ref_tensor(Tensor input) -> Tensor" );
13691300 ops.impl (" weak_ref_tensor" , torch::kPrivateUse1 , &vllm_ascend::weak_ref_tensor);
13701301
1371- // Rotary embedding
1372- // Apply GPT-NeoX style rotary embedding to query and key.
1373- ops.def (
1374- " rotary_embedding(Tensor positions, Tensor! query,"
1375- " Tensor! key, int head_size,"
1376- " Tensor cos_sin_cache, bool is_neox) -> (Tensor query, Tensor key)" );
1377- ops.impl (" rotary_embedding" , torch::kPrivateUse1 , &vllm_ascend::rotary_embedding);
1378-
13791302 ops.def (
13801303 " get_masked_input_and_mask(Tensor input, "
13811304 " int org_vocab_start_index, "
0 commit comments  (GitHub page footer captured along with the diff; not part of the patch itself)