@@ -44,25 +44,14 @@ def apply_swiglu_sharded_factory_patch():
     mlp_module._swiglu_patched = True
     mlp_module._original_apply_swiglu_sharded_factory = mlp_module.apply_swiglu_sharded_factory

-    def patched_apply_swiglu_sharded_factory(
-        original_sh_ten, sharded_offsets, singleton_local_shards: bool = False
-    ):
+    def patched_apply_swiglu_sharded_factory(original_sh_ten, sharded_offsets, singleton_local_shards: bool = False):
         swiglu_shard_axis = 0
         prepend_axis_num = len(sharded_offsets)
         original_shape = original_sh_ten.local_shape
         local_axis_size = original_shape[swiglu_shard_axis]
-        assert (
-            original_sh_ten.global_offset[swiglu_shard_axis + prepend_axis_num]
-            % local_axis_size
-            == 0
-        )
-        rank_offset = (
-            original_sh_ten.global_offset[swiglu_shard_axis + prepend_axis_num]
-            // local_axis_size
-        )
-        axis_frag = original_sh_ten.axis_fragmentations[
-            swiglu_shard_axis + prepend_axis_num
-        ]
+        assert original_sh_ten.global_offset[swiglu_shard_axis + prepend_axis_num] % local_axis_size == 0
+        rank_offset = original_sh_ten.global_offset[swiglu_shard_axis + prepend_axis_num] // local_axis_size
+        axis_frag = original_sh_ten.axis_fragmentations[swiglu_shard_axis + prepend_axis_num]

         @torch.no_grad()
         def sh_ten_build_fn(
@@ -89,12 +78,20 @@ def sh_ten_build_fn(
             tensor_w, tensor_v = torch.chunk(t, 2, dim=swiglu_shard_axis)
             return [
                 ShardedTensor.from_rank_offsets(
-                    w_key, tensor_w, *sharded_offsets, offset_w,
-                    replica_id=replica_id, prepend_axis_num=prepend_axis_num,
+                    w_key,
+                    tensor_w,
+                    *sharded_offsets,
+                    offset_w,
+                    replica_id=replica_id,
+                    prepend_axis_num=prepend_axis_num,
                 ),
                 ShardedTensor.from_rank_offsets(
-                    v_key, tensor_v, *sharded_offsets, offset_v,
-                    replica_id=replica_id, prepend_axis_num=prepend_axis_num,
+                    v_key,
+                    tensor_v,
+                    *sharded_offsets,
+                    offset_v,
+                    replica_id=replica_id,
+                    prepend_axis_num=prepend_axis_num,
                 ),
             ]

@@ -104,7 +101,8 @@ def sh_ten_merge_fn(sub_state_dict):
                 return torch.cat(sub_state_dict)
             except (RuntimeError, torch.cuda.OutOfMemoryError) as e:
                 logger.warning(
-                    "CUDA OOM during tensor merge – falling back to CPU. (Error: %s)", e,
+                    "CUDA OOM during tensor merge – falling back to CPU. (Error: %s)",
+                    e,
                 )
                 merged = torch.cat([t.cpu() for t in sub_state_dict])
                 gc.collect()
@@ -156,9 +154,7 @@ def _patched_gather_from_ep_ranks(
         model_config = self._get_config(megatron_module)
         num_experts = model_config.num_moe_experts
         num_experts_per_rank = num_experts // self.ep_size
-        num_experts_per_rank = self.broadcast_obj_from_pp_rank(
-            num_experts_per_rank, "num_experts_per_rank"
-        )
+        num_experts_per_rank = self.broadcast_obj_from_pp_rank(num_experts_per_rank, "num_experts_per_rank")

         local_expert_number = None

@@ -212,10 +208,7 @@ def _patched_gather_from_ep_ranks(
         return weights_dict

     MegatronParamMapping.gather_from_ep_ranks = _patched_gather_from_ep_ranks
-    logger.info(
-        "Applied QAT patch: MegatronParamMapping.gather_from_ep_ranks "
-        "now supports SequentialMLP pattern."
-    )
+    logger.info("Applied QAT patch: MegatronParamMapping.gather_from_ep_ranks now supports SequentialMLP pattern.")


 def revert_ep_gather_patch():
@@ -231,8 +224,8 @@ def revert_ep_gather_patch():

 def apply_extract_sort_key_patch():
     """Patch ``extract_sort_key`` to support SequentialMLP naming pattern."""
-    import megatron.bridge.models.conversion.utils as utils_module
     import megatron.bridge.models.conversion.model_bridge as bridge_module
+    import megatron.bridge.models.conversion.utils as utils_module

     if getattr(utils_module, "_sort_key_patched", False):
         return
@@ -270,15 +263,13 @@ def _patched_extract_sort_key(param_name: str):

     utils_module.extract_sort_key = _patched_extract_sort_key
     bridge_module.extract_sort_key = _patched_extract_sort_key
-    logger.info(
-        "Applied QAT patch: extract_sort_key now supports SequentialMLP pattern."
-    )
+    logger.info("Applied QAT patch: extract_sort_key now supports SequentialMLP pattern.")


 def revert_extract_sort_key_patch():
     """Revert :func:`apply_extract_sort_key_patch`."""
-    import megatron.bridge.models.conversion.utils as utils_module
     import megatron.bridge.models.conversion.model_bridge as bridge_module
+    import megatron.bridge.models.conversion.utils as utils_module

     if not getattr(utils_module, "_sort_key_patched", False):
         return
@@ -307,11 +298,7 @@ def _patched_megatron_local_name_to_global(models, config, param_name, vp_stage=
         param_name = _orig_fn(models, config, param_name, vp_stage)

         ep_group = parallel_state.get_expert_model_parallel_group()
-        if (
-            ".mlp.experts.local_experts." in param_name
-            and get_pg_size(ep_group) > 1
-            and ".adapter." not in param_name
-        ):
+        if ".mlp.experts.local_experts." in param_name and get_pg_size(ep_group) > 1 and ".adapter." not in param_name:
             num_experts = config.num_moe_experts
             num_experts_per_rank = num_experts // ep_group.size()
             local_experts_match = re.search(r"\.local_experts\.(\d+)\.", param_name)
@@ -326,10 +313,7 @@ def _patched_megatron_local_name_to_global(models, config, param_name, vp_stage=
         return param_name

     bridge_module._megatron_local_name_to_global = _patched_megatron_local_name_to_global
-    logger.info(
-        "Applied QAT patch: _megatron_local_name_to_global "
-        "now supports SequentialMLP pattern."
-    )
+    logger.info("Applied QAT patch: _megatron_local_name_to_global now supports SequentialMLP pattern.")


 def revert_local_name_to_global_patch():
@@ -363,9 +347,7 @@ def apply_build_conversion_tasks_patch():
     if getattr(MegatronModelBridge, "_build_tasks_patched", False):
         return
     MegatronModelBridge._build_tasks_patched = True
-    MegatronModelBridge._original_build_conversion_tasks = (
-        MegatronModelBridge.build_conversion_tasks
-    )
+    MegatronModelBridge._original_build_conversion_tasks = MegatronModelBridge.build_conversion_tasks

     def _patched_build_conversion_tasks(self, hf_pretrained, megatron_model):
         if not (hasattr(hf_pretrained, "state") and hasattr(hf_pretrained.state, "source")):
@@ -378,24 +360,18 @@ def _patched_build_conversion_tasks(self, hf_pretrained, megatron_model):
         model_config = unwrapped_model.config
         embeddings_are_tied = self._share_embeddings_and_output_weights(model_config, unwrapped_model)
         pp_rank = parallel_state.get_pipeline_model_parallel_rank()
-        sorted_global_param_names_all_pp_ranks = self._megatron_global_param_names_all_pp_ranks(
-            megatron_model
-        )
+        sorted_global_param_names_all_pp_ranks = self._megatron_global_param_names_all_pp_ranks(megatron_model)

         if embeddings_are_tied:
             sorted_global_param_names_all_pp_ranks = [
                 name for name in sorted_global_param_names_all_pp_ranks if "output_layer" not in name
             ]

-        global_names_index_dict = {
-            name: idx for idx, name in enumerate(sorted_global_param_names_all_pp_ranks)
-        }
+        global_names_index_dict = {name: idx for idx, name in enumerate(sorted_global_param_names_all_pp_ranks)}

         tasks = [None] * len(sorted_global_param_names_all_pp_ranks)
         for vp_stage, model in enumerate(megatron_model):
-            for local_name, _ in itertools.chain(
-                model.named_parameters(), persistent_buffers(model)
-            ):
+            for local_name, _ in itertools.chain(model.named_parameters(), persistent_buffers(model)):
                 if "_extra_state" in local_name or self._is_adapter_param_name(local_name):
                     continue

@@ -407,9 +383,7 @@ def _patched_build_conversion_tasks(self, hf_pretrained, megatron_model):
                     print_rank_0(f"WARNING: {global_name} not in global_names_index_dict")
                     continue
                 global_name_idx = global_names_index_dict[global_name]
-                mapping = mapping_registry.megatron_to_hf_lookup(
-                    self._get_lora_unwrapped_name(global_name)
-                )
+                mapping = mapping_registry.megatron_to_hf_lookup(self._get_lora_unwrapped_name(global_name))

                 if not mapping:
                     logger.warning(f"WARNING: No mapping found for megatron_param: {global_name}")
@@ -421,23 +395,16 @@ def _patched_build_conversion_tasks(self, hf_pretrained, megatron_model):
                         logger.warning(f"WARNING: Can't find {mapping.hf_param} in hf_keys")
                         continue
                 else:
-                    missing_params = [
-                        hf_param
-                        for hf_param in mapping.hf_param.values()
-                        if hf_param not in hf_keys
-                    ]
+                    missing_params = [hf_param for hf_param in mapping.hf_param.values() if hf_param not in hf_keys]
                     if missing_params:
                         logger.warning(
-                            f"WARNING: Can't find the following HF parameters in hf_keys: "
-                            f"{missing_params}"
+                            f"WARNING: Can't find the following HF parameters in hf_keys: {missing_params}"
                         )
                         continue

-                local_module, local_weights = get_module_and_param_from_name(
-                    megatron_model, local_name, vp_stage
-                )
+                local_module, local_weights = get_module_and_param_from_name(megatron_model, local_name, vp_stage)
                 if local_module is not None and not hasattr(local_module, "config"):
-                    setattr(local_module, "config", model_config)
+                    local_module.config = model_config

                 tasks[global_name_idx] = WeightConversionTask(
                     pp_rank=pp_rank,
@@ -451,9 +418,7 @@ def _patched_build_conversion_tasks(self, hf_pretrained, megatron_model):

         for idx, global_name in enumerate(sorted_global_param_names_all_pp_ranks):
             if tasks[idx] is None:
-                mapping = mapping_registry.megatron_to_hf_lookup(
-                    self._get_lora_unwrapped_name(global_name)
-                )
+                mapping = mapping_registry.megatron_to_hf_lookup(self._get_lora_unwrapped_name(global_name))
                 if mapping is None:
                     continue
                 tasks[idx] = WeightConversionTask(
@@ -470,10 +435,7 @@ def _patched_build_conversion_tasks(self, hf_pretrained, megatron_model):
         return tasks

     MegatronModelBridge.build_conversion_tasks = _patched_build_conversion_tasks
-    logger.info(
-        "Applied QAT patch: MegatronModelBridge.build_conversion_tasks "
-        "now filters out None entries."
-    )
+    logger.info("Applied QAT patch: MegatronModelBridge.build_conversion_tasks now filters out None entries.")


 def revert_build_conversion_tasks_patch():
@@ -482,9 +444,7 @@ def revert_build_conversion_tasks_patch():

     if not getattr(MegatronModelBridge, "_build_tasks_patched", False):
         return
-    MegatronModelBridge.build_conversion_tasks = (
-        MegatronModelBridge._original_build_conversion_tasks
-    )
+    MegatronModelBridge.build_conversion_tasks = MegatronModelBridge._original_build_conversion_tasks
     MegatronModelBridge._build_tasks_patched = False
     logger.info("Reverted QAT patch: MegatronModelBridge.build_conversion_tasks.")

@@ -503,8 +463,7 @@ def _patched_detect_parallelism_type(self, module):
         module_type = type(module).__name__
         if "LayerNormColumnParallelLinear" in module_type:
             if self.megatron_param and (
-                self.megatron_param.endswith("layer_norm_weight")
-                or self.megatron_param.endswith("layer_norm_bias")
+                self.megatron_param.endswith("layer_norm_weight") or self.megatron_param.endswith("layer_norm_bias")
             ):
                 return "replicated"
             return "column"
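
A minimal usage sketch, not part of this commit: it shows how the apply/revert helpers in this diff might be paired around a bridge conversion. The qat_patches module name and the context-manager wrapper are assumptions for illustration only; only the apply_*/revert_* function names come from the diff above.

from contextlib import contextmanager

import qat_patches  # assumed: the module that defines the patch helpers shown in this diff


@contextmanager
def sequential_mlp_bridge_patches():
    """Apply the SequentialMLP-aware QAT patches for the duration of a conversion."""
    qat_patches.apply_extract_sort_key_patch()
    qat_patches.apply_build_conversion_tasks_patch()
    try:
        yield
    finally:
        # Revert in reverse order so the original bridge behaviour is restored.
        qat_patches.revert_build_conversion_tasks_patch()
        qat_patches.revert_extract_sort_key_patch()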