
Commit dffac6d

[Refactor] Add expert processed token count output for DispatchFFNCombine/DispatchFFNCombineBF16 (#6402)
### What this PR does / why we need it?

Add a new output for the expert token count. An additional output tensor `expert_token_nums` is added to both operators to track the token distribution among experts:

- Tensor name: `expert_token_nums`
- Dimension: 1-D tensor
- Shape: `(local_expert_num,)`
- Data type: int32
- Semantics: the number of tokens actually received by each expert on the current card.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.14.1
- vLLM main: vllm-project/vllm@dc917cc

---------

Signed-off-by: guanguan0308 <1546542263@qq.com>
Signed-off-by: guanguan0308 <162653673+guanguan0308@users.noreply.github.com>
1 parent 26b83f8 commit dffac6d

18 files changed: +97 / -84 lines
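Before the per-file diffs, a minimal host-side sketch of how a caller might allocate the new required output described above. The helper name, headers, and buffer handling are illustrative assumptions, not code from this PR:

```cpp
#include "acl/acl.h"
#include "aclnn/acl_meta.h"

// Hypothetical helper: creates the 1-D int32 expert_token_nums output,
// shape (local_expert_num,), ND format, backed by freshly allocated device memory.
aclTensor* MakeExpertTokenNums(int64_t localExpertNum, void** devBuf) {
    int64_t dims[] = {localExpertNum};
    int64_t strides[] = {1};
    aclrtMalloc(devBuf, localExpertNum * sizeof(int32_t), ACL_MEM_MALLOC_HUGE_FIRST);
    return aclCreateTensor(dims, 1, ACL_INT32, strides, 0, ACL_FORMAT_ND,
                           dims, 1, *devBuf);
}
```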

csrc/dispatch_ffn_combine/op_host/aclnn_dispatch_ffn_combine.cpp

Lines changed: 3 additions & 3 deletions
@@ -47,7 +47,7 @@ extern aclnnStatus aclnnInnerDispatchFFNCombineGetWorkspaceSize(const aclTensor*
     const aclTensor* probs,
     const char* group, int64_t maxOutputSize,
     bool transB, bool weightNz,
-    const aclTensor* out,
+    const aclTensor* out, const aclTensor* expertTokenNums,
     uint64_t* workspaceSize, aclOpExecutor** executor);
 extern aclnnStatus aclnnInnerDispatchFFNCombine(void *workspace, uint64_t workspaceSize,
     aclOpExecutor *executor, aclrtStream stream);
@@ -59,15 +59,15 @@ aclnnStatus aclnnDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const ac
     const aclTensor* expertId, const aclTensorList* scale1, const aclTensorList* scale2,
     const aclTensor* probs,
     const char* group, int64_t maxOutputSize,
-    const aclTensor* out,
+    const aclTensor* out, const aclTensor* expertTokenNums,
     uint64_t* workspaceSize, aclOpExecutor** executor)
 {
     bool transB = false;
     bool weightNz = true;

     aclnnStatus ret = aclnnInnerDispatchFFNCombineGetWorkspaceSize(x, weight1, weight2, expertId, scale1, scale2, probs, group,
         maxOutputSize, transB, weightNz,
-        out, workspaceSize, executor);
+        out, expertTokenNums, workspaceSize, executor);
     return ret;
 }
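With the updated signature, callers append the expert token count tensor after `out` in the two-phase call. A hedged sketch, assuming the input handles (x, weight1, weight2, expertId, scale1, scale2, probs, out), the group string, maxOutputSize, the stream, and an expertTokenNums tensor (as sketched earlier) already exist, and assuming the execute entry point aclnnDispatchFFNCombine pairs with the GetWorkspaceSize wrapper above; error handling omitted:

```cpp
uint64_t workspaceSize = 0;
aclOpExecutor* executor = nullptr;
// Phase 1: query workspace; expertTokenNums is the new required output after `out`.
aclnnStatus ret = aclnnDispatchFFNCombineGetWorkspaceSize(
    x, weight1, weight2, expertId, scale1, scale2, probs,
    group, maxOutputSize, out, expertTokenNums,
    &workspaceSize, &executor);
void* workspace = nullptr;
if (ret == ACL_SUCCESS && workspaceSize > 0) {
    aclrtMalloc(&workspace, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
}
// Phase 2: launch; expert_token_nums is filled alongside the FFN output.
ret = aclnnDispatchFFNCombine(workspace, workspaceSize, executor, stream);
```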

csrc/dispatch_ffn_combine/op_host/aclnn_dispatch_ffn_combine.h

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ __attribute__((visibility("default"))) aclnnStatus aclnnDispatchFFNCombineGetWor
     const aclTensor* expertId, const aclTensorList* scale1, const aclTensorList* scale2,
     const aclTensor* probs,
     const char* group, int64_t maxOutputSize,
-    const aclTensor* out,
+    const aclTensor* out, const aclTensor* expertTokenNums,
     uint64_t* workspaceSize, aclOpExecutor** executor);

 /**

csrc/dispatch_ffn_combine/op_host/dispatch_ffn_combine_def.cpp

Lines changed: 5 additions & 0 deletions
@@ -62,6 +62,11 @@ class DispatchFFNCombine : public OpDef {
         .DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_BF16})
         .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
         .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND,ge::FORMAT_ND});
+    this->Output("expert_token_nums")
+        .ParamType(REQUIRED)
+        .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
+        .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
+        .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});

     this->Attr("group").AttrType(REQUIRED).String();
     this->Attr("M").AttrType(OPTIONAL).Int();

csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.cpp

Lines changed: 5 additions & 5 deletions
@@ -20,32 +20,32 @@
 using namespace AscendC;
 using namespace DispatchFFNCombineImpl;
 extern "C" __global__ __aicore__ void dispatch_ffn_combine(GM_ADDR x, GM_ADDR w1, GM_ADDR w2, GM_ADDR expertId, GM_ADDR scale1, GM_ADDR scale2, GM_ADDR probs,
-    GM_ADDR c, GM_ADDR workspaceGM, GM_ADDR tilingGM)
+    GM_ADDR c, GM_ADDR expertTokenNums, GM_ADDR workspaceGM, GM_ADDR tilingGM)
 {
     REGISTER_TILING_DEFAULT(DispatchFFNCombineTilingData);
     if (TILING_KEY_IS(1000000)) {
         KERNEL_TASK_TYPE(1000000, KERNEL_TYPE_MIX_AIC_1_2);
         GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM);
         DispatchFFNCombine<int8_t, DTYPE_W1, DTYPE_OUT, false, true> op;
-        op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM);
+        op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM);
         op.Process();
     } else if (TILING_KEY_IS(1000001)) {
         KERNEL_TASK_TYPE(1000001, KERNEL_TYPE_MIX_AIC_1_2);
         GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM);
         DispatchFFNCombine<int8_t, DTYPE_W1, DTYPE_OUT, true, false> op;
-        op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM);
+        op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM);
         op.Process();
     } else if (TILING_KEY_IS(1000010)) {
         KERNEL_TASK_TYPE(1000010, KERNEL_TYPE_MIX_AIC_1_2);
         GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM);
         DispatchFFNCombine<int8_t, DTYPE_W1, DTYPE_OUT, false, true> op;
-        op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM);
+        op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM);
         op.Process();
     } else if (TILING_KEY_IS(1000011)) {
         KERNEL_TASK_TYPE(1000011, KERNEL_TYPE_MIX_AIC_1_2);
         GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM);
         DispatchFFNCombine<int8_t, DTYPE_W1, DTYPE_OUT, true, true> op;
-        op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM);
+        op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM);
         op.Process();
     }
 }

csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.h

Lines changed: 5 additions & 3 deletions
@@ -55,7 +55,7 @@ class DispatchFFNCombine {
 public:
     __aicore__ inline DispatchFFNCombine() {};
     __aicore__ inline void Init(GM_ADDR xGM, GM_ADDR weight1GM, GM_ADDR weight2GM, GM_ADDR expertIdGM, GM_ADDR scale1GM, GM_ADDR scale2GM,
-        GM_ADDR probs, GM_ADDR outGM, GM_ADDR workspaceGM, GM_ADDR tilingGM);
+        GM_ADDR probs, GM_ADDR outGM, GM_ADDR expertTokenNums, GM_ADDR workspaceGM, GM_ADDR tilingGM);
     __aicore__ inline void Process();

@@ -68,6 +68,7 @@ class DispatchFFNCombine {
     GM_ADDR scale2GM_;
     GM_ADDR probs_;
     GM_ADDR outGM_;
+    GM_ADDR gmExpertTokenNums_;
     GM_ADDR workspaceGM_;

     GM_ADDR moeInitRoutingQuantV2Scale = nullptr;
@@ -112,7 +113,7 @@ class DispatchFFNCombine {

 template <TemplateMMA2AClass>
 __aicore__ inline void DispatchFFNCombine<TemplateMMA2ACFunc>::Init(GM_ADDR xGM, GM_ADDR weight1GM, GM_ADDR weight2GM, GM_ADDR expertIdGM, GM_ADDR scale1GM, GM_ADDR scale2GM,
-    GM_ADDR probs, GM_ADDR outGM, GM_ADDR workspaceGM, GM_ADDR tilingGM)
+    GM_ADDR probs, GM_ADDR outGM, GM_ADDR expertTokenNums, GM_ADDR workspaceGM, GM_ADDR tilingGM)
 {
     REGISTER_TILING_DEFAULT(DispatchFFNCombineTilingData);
     auto tiling = (__gm__ DispatchFFNCombineTilingData*)tilingGM;
@@ -127,6 +128,7 @@ __aicore__ inline void DispatchFFNCombine<TemplateMMA2ACFunc>::Init(GM_ADDR xGM,
     probs_ = probs;

     outGM_ = outGM;
+    gmExpertTokenNums_ = expertTokenNums;

     workspaceGM_ = workspaceGM;

@@ -268,7 +270,7 @@ __aicore__ inline void DispatchFFNCombine<TemplateMMA2ACFunc>::Process()
         outGM_, layoutD1, layoutD2,
         expertIdGM_, moeInitRoutingQuantV2Scale, moeInitRoutingQuantV2Offset,
         expertTokensBeforeCapacity, probs_,
-        workspaceGM_, ubMoveNum, moeInitRoutingQuantV2TilingData};
+        workspaceGM_, gmExpertTokenNums_, ubMoveNum, moeInitRoutingQuantV2TilingData};
     //Call kernel
     MatmulKernel kernel(params);
     kernel(params);

csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp

Lines changed: 15 additions & 25 deletions
@@ -96,6 +96,7 @@ class DispatchFFNCombineKernel {
     LayoutD1 layoutD1;
     LayoutD2 layoutD2;
     GM_ADDR ptrWorkspace;
+    GM_ADDR ptrExpertTokenNums;
     int32_t EP;
     int32_t listLen;
     int32_t expertPerRank;
@@ -139,7 +140,7 @@ class DispatchFFNCombineKernel {
         GM_ADDR expertIdx_, GM_ADDR moeInitRoutingQuantV2Scale_,
         GM_ADDR moeInitRoutingQuantV2Offset_,
         GM_ADDR expertTokensBeforeCapacity_, GM_ADDR probs_,
-        GM_ADDR ptrWorkspace_, int32_t ubMoveNum_,
+        GM_ADDR ptrWorkspace_, GM_ADDR gmExpertTokenNums_, int32_t ubMoveNum_,
         optiling::MoeInitRoutingQuantV2TilingData moeInitRoutingQuantV2TilingData_
     ) : problemShape(problemShape_),
         EP(EP_), listLen(listLen_), expertPerRank(expertPerRank_), maxOutputSize(maxOutputSize_),
@@ -155,7 +156,7 @@ class DispatchFFNCombineKernel {
         expertIdx(expertIdx_), moeInitRoutingQuantV2Scale(moeInitRoutingQuantV2Scale_),
         moeInitRoutingQuantV2Offset(moeInitRoutingQuantV2Offset_),
         expertTokensBeforeCapacity(expertTokensBeforeCapacity_), probs(probs_),
-        ptrWorkspace(ptrWorkspace_), ubMoveNum(ubMoveNum_),
+        ptrWorkspace(ptrWorkspace_), ptrExpertTokenNums(gmExpertTokenNums_), ubMoveNum(ubMoveNum_),
         moeInitRoutingQuantV2TilingData(moeInitRoutingQuantV2TilingData_)
     {
     }
@@ -228,7 +229,7 @@ class DispatchFFNCombineKernel {

     tokenPerExpert.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(shmem() + peermemInfo.offsetPeerTokenPerExpert));

-    tokenPerExpertLayout = Layout3D( AlignUp(params.EP * params.expertPerRank, ALIGN_128), params.expertPerRank);
+    tokenPerExpertLayout = Layout3D(AlignUp(params.EP * params.expertPerRank, ALIGN_128), params.expertPerRank);
 }

 template<typename T>
@@ -335,15 +336,6 @@ class DispatchFFNCombineKernel {
     AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT); // Wait for AIV to finish cumsum for matmul
     syncgmmIdx++;

-    constexpr uint32_t MAX_EXPERTS_PER_RANK = 32;
-    __gm__ ElementB* weight1Array[MAX_EXPERTS_PER_RANK];
-    __gm__ ElementScale * scale1Array[MAX_EXPERTS_PER_RANK];
-
-    int32_t loopCount = params.listLen == 1 ? 1 : params.expertPerRank;
-    for (uint32_t loopIdx = 0; loopIdx < loopCount; ++loopIdx) {
-        weight1Array[loopIdx] = reinterpret_cast<__gm__ ElementB*>(GetTensorAddr<int8_t>(loopIdx, params.ptrB1));
-        scale1Array[loopIdx] = reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr<int64_t>(loopIdx, params.ptrScale1));
-    }
     AscendC::PipeBarrier<PIPE_ALL>();

     for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
@@ -356,8 +348,8 @@ class DispatchFFNCombineKernel {
     AscendC::GlobalTensor<ElementB> gmB1;
     AscendC::GlobalTensor<ElementScale> gmS;
     int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx;
-    gmB1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(weight1Array[arrayGroupIdx]));
-    gmS.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(scale1Array[arrayGroupIdx]));
+    gmB1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(GetTensorAddr<int8_t>(arrayGroupIdx, params.ptrB1)));
+    gmS.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr<int64_t>(arrayGroupIdx, params.ptrScale1)));

     AscendC::PipeBarrier<PIPE_ALL>();

@@ -455,14 +447,6 @@ class DispatchFFNCombineKernel {
         lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity;
     }

-    constexpr uint32_t MAX_EXPERTS_PER_RANK = 8;
-    __gm__ ElementB* weight2Array[MAX_EXPERTS_PER_RANK];
-    __gm__ ElementScale * scale2Array[MAX_EXPERTS_PER_RANK];
-    int32_t loopCount = params.listLen == 1 ? 1 : params.expertPerRank;
-    for (uint32_t loopIdx = 0; loopIdx < loopCount; ++loopIdx) {
-        weight2Array[loopIdx] = reinterpret_cast<__gm__ ElementB *>(GetTensorAddr<int8_t>(loopIdx, params.ptrB2));
-        scale2Array[loopIdx] = reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr<int64_t>(loopIdx, params.ptrScale2));
-    }
     AscendC::PipeBarrier<PIPE_ALL>();

     for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
@@ -476,8 +460,8 @@ class DispatchFFNCombineKernel {
     AscendC::GlobalTensor<ElementB> gmB2;
     AscendC::GlobalTensor<ElementScale> gmS2;
     AscendC::PipeBarrier<PIPE_ALL>();
     int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx;
-    gmB2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(weight2Array[arrayGroupIdx]));
-    gmS2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(scale2Array[arrayGroupIdx]));
+    gmB2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(GetTensorAddr<int8_t>(arrayGroupIdx, params.ptrB2)));
+    gmS2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr<int64_t>(arrayGroupIdx, params.ptrScale2)));

     if (currentM <= L1TileShape::M) {
         gmB2.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE);
@@ -596,7 +580,6 @@ class DispatchFFNCombineKernel {
         AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
         AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
     }
-
     for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
         if (dstEpIdx == params.rank) {
             continue;
@@ -639,6 +622,13 @@ class DispatchFFNCombineKernel {
         GetCumsumForMMAIV(tokenPerExpert, cumsumMM, params.expertPerRank, params.rank, params.EP);
     }
     AscendC::SyncAll<true>();
+
+    AscendC::GlobalTensor<int32_t> ExpertTokenNums;
+    ExpertTokenNums.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(params.ptrExpertTokenNums));
+    AscendC::GlobalTensor<int32_t> LcalCumsumMM;
+    LcalCumsumMM.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(workspaceInfo.ptrcumsumMM + (params.EP - 1) * params.expertPerRank * sizeof(int32_t)));
+    CopyGMToGM(ExpertTokenNums, LcalCumsumMM, params.expertPerRank, params.ubMoveNum);
+    AscendC::SyncAll<true>();
     uint16_t syncgmm1Idx = 0;
     AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT);
     syncgmm1Idx++;
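The added block in the last hunk above is where the new output gets filled: the cumsum workspace is laid out as EP rows of expertPerRank counters, accumulated over ranks, so row EP - 1 already holds the total tokens each local expert received, and CopyGMToGM simply copies that row into expert_token_nums. A host-side illustration of the layout this indexing implies (all names here are invented for the sketch, not kernel code):

```cpp
// tokenPerExpert[ep][e]: tokens sent by rank `ep` to local expert `e`.
// After a running sum over ranks, cumsumMM[EP - 1][e] equals the total
// token count for expert `e`, which is exactly what expert_token_nums reports.
void CumsumOverRanks(const int32_t* tokenPerExpert, int32_t* cumsumMM,
                     int EP, int expertPerRank) {
    for (int e = 0; e < expertPerRank; ++e) {
        int32_t running = 0;
        for (int ep = 0; ep < EP; ++ep) {
            running += tokenPerExpert[ep * expertPerRank + e];
            cumsumMM[ep * expertPerRank + e] = running;
        }
    }
    // expert_token_nums[e] == cumsumMM[(EP - 1) * expertPerRank + e]
}
```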

csrc/dispatch_ffn_combine_bf16/op_host/aclnn_dispatch_ffn_combine_bf16.cpp

Lines changed: 3 additions & 3 deletions
@@ -47,7 +47,7 @@ extern aclnnStatus aclnnInnerDispatchFFNCombineBF16GetWorkspaceSize(const aclTen
     const aclTensor* probs,
     const char* group, int64_t maxOutputSize,
     bool transB, bool weightNz,
-    const aclTensor* out,
+    const aclTensor* out, const aclTensor* expertTokenNums,
     uint64_t* workspaceSize, aclOpExecutor** executor);
 extern aclnnStatus aclnnInnerDispatchFFNCombineBF16(void *workspace, uint64_t workspaceSize,
     aclOpExecutor *executor, aclrtStream stream);
@@ -59,15 +59,15 @@ aclnnStatus aclnnDispatchFFNCombineBF16GetWorkspaceSize(const aclTensor* x, cons
     const aclTensor* expertId, const aclTensorList* scale1, const aclTensorList* scale2,
     const aclTensor* probs,
     const char* group, int64_t maxOutputSize,
-    const aclTensor* out,
+    const aclTensor* out, const aclTensor* expertTokenNums,
     uint64_t* workspaceSize, aclOpExecutor** executor)
 {
     bool transB = false;
     bool weightNz = true;

     aclnnStatus ret = aclnnInnerDispatchFFNCombineBF16GetWorkspaceSize(x, weight1, weight2, expertId, scale1, scale2, probs, group,
         maxOutputSize, transB, weightNz,
-        out, workspaceSize, executor);
+        out, expertTokenNums, workspaceSize, executor);
     return ret;
 }
}
7373

csrc/dispatch_ffn_combine_bf16/op_host/aclnn_dispatch_ffn_combine_bf16.h

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ __attribute__((visibility("default"))) aclnnStatus aclnnDispatchFFNCombineBF16Ge
     const aclTensor* expertId, const aclTensorList* scale1, const aclTensorList* scale2,
     const aclTensor* probs,
     const char* group, int64_t maxOutputSize,
-    const aclTensor* out,
+    const aclTensor* out, const aclTensor* expertTokenNums,
     uint64_t* workspaceSize, aclOpExecutor** executor);

csrc/dispatch_ffn_combine_bf16/op_host/dispatch_ffn_combine_bf16_def.cpp

Lines changed: 5 additions & 0 deletions
@@ -62,6 +62,11 @@ class DispatchFFNCombineBF16 : public OpDef {
         .DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16, ge::DT_BF16})
         .Format({ ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
         .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
+    this->Output("expert_token_nums")
+        .ParamType(REQUIRED)
+        .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
+        .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
+        .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});

     this->Attr("group").AttrType(REQUIRED).String();
     this->Attr("M").AttrType(OPTIONAL).Int();
