1616import os
1717
1818import torch
19+ import torch .distributed as dist
1920import torch .nn as nn
2021import torch .nn .functional as F
2122from transformers .activations import ACT2FN
3031
3132try :
3233 from fla .modules import FusedRMSNormGated , ShortConvolution
34+ from fla .ops .cp import build_cp_context
3335 from fla .ops .gated_delta_rule import chunk_gated_delta_rule
3436except ImportError as err :
3537 raise ImportError ("Please install flash-linear-attention for Qwen3-Next" ) from err
3840logger .setLevel (os .getenv ("VERL_LOGGING_LEVEL" , "WARN" ))
3941
4042
class _AllGather(torch.autograd.Function):
    """Differentiable all-gather along dim 0 over a process group.

    Forward concatenates every rank's local shard (via the project helper
    ``all_gather_tensor``) into one tensor; backward reduce-scatters the
    incoming gradient so each rank receives the summed gradient slice that
    corresponds to its original shard.
    """

    @staticmethod
    def forward(ctx, local_tensor: torch.Tensor, group):
        # Stash the group and this rank's shard length so backward can
        # rebuild a gradient tensor of the original local shape.
        ctx.group = group
        ctx.part_size = local_tensor.size(0)
        return all_gather_tensor(local_tensor, group=group)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        grad_local = torch.empty(
            (ctx.part_size, *grad_output.shape[1:]),
            dtype=grad_output.dtype,
            device=grad_output.device,
        )
        # NCCL/Gloo collectives require contiguous inputs. The gradient
        # arriving here is typically non-contiguous (the caller transposes
        # the gathered tensor, so the transpose backward hands us a
        # transposed view) — make it contiguous before the collective.
        dist.reduce_scatter_tensor(
            grad_local,
            grad_output.contiguous(),
            op=dist.ReduceOp.SUM,
            group=ctx.group,
        )
        # Second return is the (non-differentiable) `group` argument.
        return grad_local, None
60+
4161# Adapted from https://github.com/huggingface/transformers/blob/c9ea365a7b56326418769a4ba4682864d407ed63/src/transformers/models/qwen3_next/modular_qwen3_next.py#L428
4262class PatchQwen3NextGatedDeltaNet (nn .Module ):
4363 def __init__ (self , config , layer_idx : int ):
@@ -120,12 +140,19 @@ def forward(
120140 self ,
121141 hidden_states ,
122142 cu_seqlens = None ,
143+ cp_context = None ,
123144 ):
124- # NOTE: when using ulysses sequence parallelism, batch size is always 1
125- # pre-process: [bsz, seq, h] -> [seq, bsz, h] -> [seq * sp, bsz, h] -> [bsz, seq * sp, h]
126- hidden_states = hidden_states .transpose (0 , 1 ).contiguous ()
127- hidden_states = all_gather_tensor (hidden_states )
128- hidden_states = hidden_states .transpose (0 , 1 ).contiguous ()
145+ if cp_context is not None :
146+ use_cp_mode = True
147+ cu_seqlens = cp_context .cu_seqlens
148+ elif cu_seqlens is not None :
149+ # pre-process: [bsz, seq, h] -> [seq, bsz, h] -> [seq * sp, bsz, h] -> [bsz, seq * sp, h]
150+ use_cp_mode = False
151+ hidden_states = hidden_states .transpose (0 , 1 ).contiguous ()
152+ hidden_states = _AllGather .apply (hidden_states , get_ulysses_sequence_parallel_group ())
153+ hidden_states = hidden_states .transpose (0 , 1 ).contiguous ()
154+ else :
155+ raise ValueError ("cu_seqlens or cp_context is required" )
129156
130157 projected_states_qkvz = self .in_proj_qkvz (hidden_states )
131158 projected_states_ba = self .in_proj_ba (hidden_states )
@@ -137,6 +164,7 @@ def forward(
137164 mixed_qkv , _ = self .conv1d (
138165 x = mixed_qkv ,
139166 cu_seqlens = cu_seqlens ,
167+ cp_context = cp_context if use_cp_mode else None ,
140168 )
141169
142170 query , key , value = torch .split (
@@ -159,17 +187,30 @@ def forward(
159187 query = query .repeat_interleave (self .num_v_heads // self .num_k_heads , dim = 2 )
160188 key = key .repeat_interleave (self .num_v_heads // self .num_k_heads , dim = 2 )
161189
162- core_attn_out , _ = chunk_gated_delta_rule (
163- query ,
164- key ,
165- value ,
166- g = g ,
167- beta = beta ,
168- initial_state = None ,
169- output_final_state = False ,
170- use_qk_l2norm_in_kernel = True ,
171- cu_seqlens = cu_seqlens ,
172- )
190+ if use_cp_mode :
191+ core_attn_out , _ = chunk_gated_delta_rule (
192+ query ,
193+ key ,
194+ value ,
195+ g = g ,
196+ beta = beta ,
197+ initial_state = None ,
198+ output_final_state = False ,
199+ use_qk_l2norm_in_kernel = True ,
200+ cp_context = cp_context ,
201+ )
202+ else :
203+ core_attn_out , _ = chunk_gated_delta_rule (
204+ query ,
205+ key ,
206+ value ,
207+ g = g ,
208+ beta = beta ,
209+ initial_state = None ,
210+ output_final_state = False ,
211+ use_qk_l2norm_in_kernel = True ,
212+ cu_seqlens = cu_seqlens ,
213+ )
173214
174215 z_shape_og = z .shape
175216 # reshape input data into 2D tensor
@@ -180,9 +221,9 @@ def forward(
180221 core_attn_out = core_attn_out .reshape (core_attn_out .shape [0 ], core_attn_out .shape [1 ], - 1 )
181222
182223 output = self .out_proj (core_attn_out )
183-
184- # post-process: [bsz, seq * sp, h] -> [bsz, seq, h]
185- output = slice_input_tensor (output , dim = 1 , padding = False )
224+ if not use_cp_mode :
225+ # post-process: [bsz, seq * sp, h] -> [bsz, seq, h]
226+ output = slice_input_tensor (output , dim = 1 , padding = False )
186227 return output
187228
188229
@@ -195,6 +236,7 @@ def patch_qwen3_next_decoder_layer_forward(
195236 position_ids = None ,
196237 past_key_values = None ,
197238 cache_position = None ,
239+ gdn_use_cp : bool = True ,
198240 ** kwargs ,
199241):
200242 residual = hidden_states
@@ -205,15 +247,27 @@ def patch_qwen3_next_decoder_layer_forward(
205247 if self .layer_type == "linear_attention" :
206248 # 1. Get the global position ids
207249 ulysses_sp_size = get_ulysses_sequence_parallel_world_size ()
250+ ulysses_sp_group = get_ulysses_sequence_parallel_group ()
208251 position_ids_list = [torch .empty_like (position_ids ) for _ in range (ulysses_sp_size )]
209- torch .distributed .all_gather (position_ids_list , position_ids , group = get_ulysses_sequence_parallel_group () )
252+ torch .distributed .all_gather (position_ids_list , position_ids , group = ulysses_sp_group )
210253 position_ids = torch .concat (position_ids_list , dim = - 1 )
211254 # 2. Get the cu_seqlens by position_ids
212255 (cu_seqlens_q , _ ), _ = prepare_fa_kwargs_from_position_ids (position_ids )
213- hidden_states = self .linear_attn (
214- hidden_states = hidden_states ,
215- cu_seqlens = cu_seqlens_q ,
216- )
256+ if gdn_use_cp :
257+ cp_context = build_cp_context (
258+ cu_seqlens_q ,
259+ group = ulysses_sp_group ,
260+ conv1d_kernel_size = self .linear_attn .conv_kernel_size ,
261+ )
262+ hidden_states = self .linear_attn (
263+ hidden_states = hidden_states ,
264+ cp_context = cp_context ,
265+ )
266+ else :
267+ hidden_states = self .linear_attn (
268+ hidden_states = hidden_states ,
269+ cu_seqlens = cu_seqlens_q ,
270+ )
217271 elif self .layer_type == "full_attention" :
218272 # Self Attention
219273 hidden_states , _ = self .self_attn (
0 commit comments