193 changes: 86 additions & 107 deletions paddleformers/trainer/trainer.py
@@ -82,6 +82,9 @@
except:
_obtain_optimizer_parameters_list = None

from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
DygraphShardingOptimizerV2,
)
from paddle.distributed.fleet.utils.hybrid_parallel_util import (
fused_allreduce_gradients,
)
@@ -1009,130 +1012,100 @@ def get_metadata_file_name(path):
opt_states_path = os.path.join(resume_from_checkpoint, OPTIMIZER_STATE_DIC)
model_states_path = os.path.join(resume_from_checkpoint, MODEL_STATE_DIC)

if self.args.load_from_hf:
hf_aoa_config = self.model._gen_aoa_config(self.model.config)
hcg = dist.fleet.get_hybrid_communicate_group()
assert (
self.args.ignore_load_lr_and_optim
), "Loading from HuggingFace format is only allowed when learning rate and optimizer state are ignored."
hcg = dist.fleet.get_hybrid_communicate_group()
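# For the "parallel_broadcast" communication method, the pipeline, expert and
# moe-sharding parallel groups are collected into `worker_groups` and passed to
# flex_checkpoint_load_func below; for other methods `worker_groups` stays None.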
if self.args.flex_ckpt_comm_method == "parallel_broadcast":
try:
pp_group = hcg.get_pipe_parallel_group()
if pp_group is None or pp_group.nranks < 1:
raise NotImplementedError("Only support when pp_group is not None.")
except Exception:
raise RuntimeError("Only support when pp_group is not None.")

try:
moe_group = hcg.get_expert_parallel_group()
if moe_group is None or moe_group.nranks < 1:
raise NotImplementedError("Only support when moe_group is not None.")
except Exception:
raise RuntimeError("Only support when moe_group is not None.")

try:
moe_sharding_group = hcg.get_moe_sharding_parallel_group()
except Exception:
moe_sharding_group = None

if moe_sharding_group is None or moe_sharding_group.nranks <= 1:
# when moe_sharding_group is None, we use the default process_group
logger.info(f"Loading model weights from '{resume_from_checkpoint}' in safetensors format.")
flex_checkpoint_load_func(
model_sharded_state_dict,
resume_from_checkpoint,
aoa_config=hf_aoa_config,
offload=self.args.load_via_cpu,
safetensors=True,
process_group=None,
comm_method=self.args.flex_ckpt_comm_method,
)
else:
try:
pp_group = hcg.get_pipe_parallel_group()
if pp_group is None or pp_group.nranks < 1:
raise NotImplementedError("Only support when pp_group is not None.")
except Exception:
raise RuntimeError("Only support when pp_group is not None.")
worker_groups = [moe_group, pp_group, moe_sharding_group]
else:
worker_groups = None

try:
moe_group = hcg.get_expert_parallel_group()
if moe_group is None or moe_group.nranks < 1:
raise NotImplementedError("Only support when moe_group is not None.")
except Exception:
raise RuntimeError("Only support when moe_group is not None.")
moe_sharding_rank = moe_sharding_group.rank
cur_rank = dist.get_rank()
if moe_sharding_rank == 0:
moe_group_ranks = []
dist.all_gather_object(moe_group_ranks, cur_rank, group=moe_group)
pp_group_ranks = []
dist.all_gather_object(pp_group_ranks, moe_group_ranks, group=pp_group)
process_group_ranks = [rank for ranks in pp_group_ranks for rank in ranks]
else:
process_group_ranks = [0] * (pp_group.nranks * moe_group.nranks)
src_rank = hcg.get_moe_sharding_parallel_group_src_rank()
dist.broadcast_object_list(process_group_ranks, src=src_rank, group=moe_sharding_group)
assert any(process_group_ranks), "process_group_ranks should not be all 0"
logger.info(f"Creating a temporary process group with ranks: {process_group_ranks}")
process_group = dist.new_group(process_group_ranks)

if moe_sharding_rank == 0:
logger.info(f"Loading model weights from '{resume_from_checkpoint}' in safetensors format.")
# Only the first moe_sharding process is allowed to load the model weights.
flex_checkpoint_load_func(
model_sharded_state_dict,
resume_from_checkpoint,
aoa_config=hf_aoa_config,
offload=self.args.load_via_cpu,
safetensors=True,
process_group=process_group,
comm_method=self.args.flex_ckpt_comm_method,
)
if self.args.load_from_hf:
hf_aoa_config = self.model._gen_aoa_config(self.model.config)
assert (
self.args.ignore_load_lr_and_optim
), "Loading from HuggingFace format is only allowed when learning rate and optimizer state are ignored."

dist.barrier()
logger.info("Destroying the temporary process group.")
dist.destroy_process_group(process_group)
# The first moe_sharding group loads the model weights and then broadcasts them to all other moe_sharding groups.
logger.info(
"First shard (moe_sharding_group) has loaded safetensors weights, starting broadcast on moe_sharding_groups."
)
for param_name, param in self.model.state_dict().items():
dist.broadcast(param, src=src_rank, group=moe_sharding_group)
logger.info(f"Loading model weights from '{resume_from_checkpoint}' in safetensors format.")
flex_checkpoint_load_func(
model_sharded_state_dict,
resume_from_checkpoint,
aoa_config=hf_aoa_config,
offload=self.args.load_via_cpu,
safetensors=True,
process_group=None,
comm_method=self.args.flex_ckpt_comm_method,
worker_groups=worker_groups,
)
return

if not self.args.ignore_load_lr_and_optim:
state_dict_metadata = {}
metadata_paths = [
os.path.join(model_states_path, get_metadata_file_name(model_states_path)),
os.path.join(opt_states_path, get_metadata_file_name(opt_states_path)),
os.path.join(master_weights_path, get_metadata_file_name(master_weights_path)),
]

for metadata_file in metadata_paths:
if not os.path.exists(metadata_file):
raise FileNotFoundError(f"Metadata file not found: {metadata_file}")
metadata = paddle.load(metadata_file)
state_dict_metadata.update(metadata.state_dict_metadata)

if not self.args.sharded_model_from_ema:
init_optimizer(self.optimizer, model_sharded_state_dict, state_dict_metadata)

optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)

opt_states = {}
master_weights = {}
for k, v in optimizer_sharded_state_dict.items():
if k.endswith(".w_0"):
master_weights[k] = v
else:
opt_states[k] = v
state_dict_metadata = {}
metadata_paths = [
os.path.join(model_states_path, get_metadata_file_name(model_states_path)),
os.path.join(opt_states_path, get_metadata_file_name(opt_states_path)),
os.path.join(master_weights_path, get_metadata_file_name(master_weights_path)),
]

for metadata_file in metadata_paths:
if not os.path.exists(metadata_file):
raise FileNotFoundError(f"Metadata file not found: {metadata_file}")
metadata = paddle.load(metadata_file)
state_dict_metadata.update(metadata.state_dict_metadata)

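# Unless the sharded model comes from EMA, initialize the optimizer's states from
# the checkpoint metadata before requesting its sharded state dict, so the loader
# has entries to fill.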
if not self.args.sharded_model_from_ema:
init_optimizer(self.optimizer, model_sharded_state_dict, state_dict_metadata)

optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)

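# Split the optimizer's sharded state into master weights (keys ending in ".w_0")
# and the remaining optimizer states; each part is loaded from its own checkpoint
# directory below.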
opt_states = {}
master_weights = {}
for k, v in optimizer_sharded_state_dict.items():
if k.endswith(".w_0"):
master_weights[k] = v
else:
opt_states[k] = v

flex_checkpoint_load_func(
master_weights,
master_weights_path,
aoa_config=self.args.aoa_config,
offload=self.args.load_via_cpu,
comm_method=self.args.flex_ckpt_comm_method,
worker_groups=worker_groups,
)

if not self.args.ignore_load_lr_and_optim:
flex_checkpoint_load_func(
opt_states,
opt_states_path,
aoa_config=self.args.aoa_config,
offload=self.args.load_via_cpu,
comm_method=self.args.flex_ckpt_comm_method,
worker_groups=worker_groups,
)

flex_checkpoint_load_func(
master_weights,
master_weights_path,
aoa_config=self.args.aoa_config,
offload=self.args.load_via_cpu,
comm_method=self.args.flex_ckpt_comm_method,
)

self._load_scheduler(resume_from_checkpoint)
self._load_scheduler(resume_from_checkpoint)

enable_bf16_opt = (
not isinstance(self.model, LoRAModel) and self.args.enable_zero_cost_checkpoint and self.args.bf16
not isinstance(self.model, LoRAModel)
and self.args.bf16
and isinstance(self.optimizer._inner_opt, DygraphShardingOptimizerV2)
)
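# enable_bf16_opt: for bf16 runs on DygraphShardingOptimizerV2 (and not LoRA), the
# model sharded state dict is filtered before loading, AOA remapping is disabled,
# and parameters are later recovered from the optimizer's master weights.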
logger.debug(f"sharded_model_from_ema: {self.args.sharded_model_from_ema}")
logger.debug(f"enable_bf16_opt: {enable_bf16_opt}")
@@ -1155,18 +1128,24 @@ def bf16_filtered_sharded_state_dict(sharded_state_dict):
new_state_dict[k] = v
return new_state_dict

# NOTE(xingmingyyj) When saving model states only in float32 format, we assume that users
# will not use AOA to change the mapping relationships among these float32 weights.
if enable_bf16_opt:
model_sharded_state_dict = bf16_filtered_sharded_state_dict(model_sharded_state_dict)
aoa_config = None
else:
aoa_config = self.args.aoa_config

flex_checkpoint_load_func(
model_sharded_state_dict,
model_states_path,
aoa_config=self.args.aoa_config,
aoa_config=aoa_config,
offload=self.args.load_via_cpu,
comm_method=self.args.flex_ckpt_comm_method,
worker_groups=worker_groups,
)

if enable_bf16_opt and (not self.args.ignore_load_lr_and_optim):
if enable_bf16_opt:
opt_state_dict = self.optimizer.state_dict()

def recover_params_from_master_weight(opt_state_dict, group):