Skip to content

Commit 54cd758

Browse files
committed
upd
1 parent b132d94 commit 54cd758

File tree

1 file changed

+78
-24
lines changed

1 file changed

+78
-24
lines changed

verl/models/transformers/qwen3_next.py

Lines changed: 78 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import os
1717

1818
import torch
19+
import torch.distributed as dist
1920
import torch.nn as nn
2021
import torch.nn.functional as F
2122
from transformers.activations import ACT2FN
@@ -30,6 +31,7 @@
3031

3132
try:
3233
from fla.modules import FusedRMSNormGated, ShortConvolution
34+
from fla.ops.cp import build_cp_context
3335
from fla.ops.gated_delta_rule import chunk_gated_delta_rule
3436
except ImportError as err:
3537
raise ImportError("Please install flash-linear-attention for Qwen3-Next") from err
@@ -38,6 +40,24 @@
3840
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
3941

4042

43+
class _AllGather(torch.autograd.Function):
    """Differentiable all-gather along dim 0 over a process group.

    Forward gathers each rank's shard into the full tensor via the project
    helper ``all_gather_tensor``; backward returns this rank's slice of the
    incoming gradient via a sum reduce-scatter (the adjoint of all-gather).
    """

    @staticmethod
    def forward(ctx, local_tensor: torch.Tensor, group):
        # Remember the group and this rank's shard length so backward can
        # scatter the gradient back to the matching local slice.
        ctx.group = group
        ctx.part_size = local_tensor.size(0)
        return all_gather_tensor(local_tensor, group=group)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        grad_local = torch.empty(
            (ctx.part_size, *grad_output.shape[1:]),
            dtype=grad_output.dtype,
            device=grad_output.device,
        )
        # NCCL collectives require contiguous tensors; autograd may hand us a
        # non-contiguous grad (e.g. produced through a transpose), so make it
        # contiguous before the reduce-scatter.
        dist.reduce_scatter_tensor(
            grad_local, grad_output.contiguous(), op=dist.ReduceOp.SUM, group=ctx.group
        )
        # Second output (``group``) is a non-tensor argument: no gradient.
        return grad_local, None
59+
60+
4161
# Adapted from https://github.com/huggingface/transformers/blob/c9ea365a7b56326418769a4ba4682864d407ed63/src/transformers/models/qwen3_next/modular_qwen3_next.py#L428
4262
class PatchQwen3NextGatedDeltaNet(nn.Module):
4363
def __init__(self, config, layer_idx: int):
@@ -120,12 +140,19 @@ def forward(
120140
self,
121141
hidden_states,
122142
cu_seqlens=None,
143+
cp_context=None,
123144
):
124-
# NOTE: when using ulysses sequence parallelism, batch size is always 1
125-
# pre-process: [bsz, seq, h] -> [seq, bsz, h] -> [seq * sp, bsz, h] -> [bsz, seq * sp, h]
126-
hidden_states = hidden_states.transpose(0, 1).contiguous()
127-
hidden_states = all_gather_tensor(hidden_states)
128-
hidden_states = hidden_states.transpose(0, 1).contiguous()
145+
if cp_context is not None:
146+
use_cp_mode = True
147+
cu_seqlens = cp_context.cu_seqlens
148+
elif cu_seqlens is not None:
149+
# pre-process: [bsz, seq, h] -> [seq, bsz, h] -> [seq * sp, bsz, h] -> [bsz, seq * sp, h]
150+
use_cp_mode = False
151+
hidden_states = hidden_states.transpose(0, 1).contiguous()
152+
hidden_states = _AllGather.apply(hidden_states, get_ulysses_sequence_parallel_group())
153+
hidden_states = hidden_states.transpose(0, 1).contiguous()
154+
else:
155+
raise ValueError("cu_seqlens or cp_context is required")
129156

130157
projected_states_qkvz = self.in_proj_qkvz(hidden_states)
131158
projected_states_ba = self.in_proj_ba(hidden_states)
@@ -137,6 +164,7 @@ def forward(
137164
mixed_qkv, _ = self.conv1d(
138165
x=mixed_qkv,
139166
cu_seqlens=cu_seqlens,
167+
cp_context=cp_context if use_cp_mode else None,
140168
)
141169

142170
query, key, value = torch.split(
@@ -159,17 +187,30 @@ def forward(
159187
query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
160188
key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
161189

162-
core_attn_out, _ = chunk_gated_delta_rule(
163-
query,
164-
key,
165-
value,
166-
g=g,
167-
beta=beta,
168-
initial_state=None,
169-
output_final_state=False,
170-
use_qk_l2norm_in_kernel=True,
171-
cu_seqlens=cu_seqlens,
172-
)
190+
if use_cp_mode:
191+
core_attn_out, _ = chunk_gated_delta_rule(
192+
query,
193+
key,
194+
value,
195+
g=g,
196+
beta=beta,
197+
initial_state=None,
198+
output_final_state=False,
199+
use_qk_l2norm_in_kernel=True,
200+
cp_context=cp_context,
201+
)
202+
else:
203+
core_attn_out, _ = chunk_gated_delta_rule(
204+
query,
205+
key,
206+
value,
207+
g=g,
208+
beta=beta,
209+
initial_state=None,
210+
output_final_state=False,
211+
use_qk_l2norm_in_kernel=True,
212+
cu_seqlens=cu_seqlens,
213+
)
173214

174215
z_shape_og = z.shape
175216
# reshape input data into 2D tensor
@@ -180,9 +221,9 @@ def forward(
180221
core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1)
181222

182223
output = self.out_proj(core_attn_out)
183-
184-
# post-process: [bsz, seq * sp, h] -> [bsz, seq, h]
185-
output = slice_input_tensor(output, dim=1, padding=False)
224+
if not use_cp_mode:
225+
# post-process: [bsz, seq * sp, h] -> [bsz, seq, h]
226+
output = slice_input_tensor(output, dim=1, padding=False)
186227
return output
187228

188229

@@ -195,6 +236,7 @@ def patch_qwen3_next_decoder_layer_forward(
195236
position_ids=None,
196237
past_key_values=None,
197238
cache_position=None,
239+
gdn_use_cp: bool = True,
198240
**kwargs,
199241
):
200242
residual = hidden_states
@@ -205,15 +247,27 @@ def patch_qwen3_next_decoder_layer_forward(
205247
if self.layer_type == "linear_attention":
206248
# 1. Get the global position ids
207249
ulysses_sp_size = get_ulysses_sequence_parallel_world_size()
250+
ulysses_sp_group = get_ulysses_sequence_parallel_group()
208251
position_ids_list = [torch.empty_like(position_ids) for _ in range(ulysses_sp_size)]
209-
torch.distributed.all_gather(position_ids_list, position_ids, group=get_ulysses_sequence_parallel_group())
252+
torch.distributed.all_gather(position_ids_list, position_ids, group=ulysses_sp_group)
210253
position_ids = torch.concat(position_ids_list, dim=-1)
211254
# 2. Get the cu_seqlens by position_ids
212255
(cu_seqlens_q, _), _ = prepare_fa_kwargs_from_position_ids(position_ids)
213-
hidden_states = self.linear_attn(
214-
hidden_states=hidden_states,
215-
cu_seqlens=cu_seqlens_q,
216-
)
256+
if gdn_use_cp:
257+
cp_context = build_cp_context(
258+
cu_seqlens_q,
259+
group=ulysses_sp_group,
260+
conv1d_kernel_size=self.linear_attn.conv_kernel_size,
261+
)
262+
hidden_states = self.linear_attn(
263+
hidden_states=hidden_states,
264+
cp_context=cp_context,
265+
)
266+
else:
267+
hidden_states = self.linear_attn(
268+
hidden_states=hidden_states,
269+
cu_seqlens=cu_seqlens_q,
270+
)
217271
elif self.layer_type == "full_attention":
218272
# Self Attention
219273
hidden_states, _ = self.self_attn(

0 commit comments

Comments
 (0)