Bug description
It appears that weight tying is not functioning correctly for the Qwen3 model. The relevant code is here:
torchtitan/torchtitan/models/qwen3/infra/parallelize.py
Lines 176 to 180 in 09c6d74
```python
# Enable weight tying after applying parallelisms
# pyrefly: ignore [missing-attribute]
if model.model_args.enable_weight_tying:
    # pyrefly: ignore [missing-attribute]
    model.output.weight = model.tok_embeddings.weight
```
The weight tying logic would work as expected if the model were already materialized on a device. However, at this point the model is still on the meta device. When materialization later occurs, both model.output and model.tok_embeddings are independently materialized, which breaks the intended weight tying.
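To make the failure mode concrete, here is a minimal standalone sketch (independent of torchtitan; the module names merely mirror tok_embeddings and output) of how an assignment made on the meta device does not survive materialization via to_empty:

```python
import torch
import torch.nn as nn

# Two modules created on the meta device, mirroring tok_embeddings and output.
with torch.device("meta"):
    model = nn.ModuleDict({
        "tok_embeddings": nn.Embedding(128, 16),
        "output": nn.Linear(16, 128, bias=False),
    })

# Tie the weights while the parameters are still meta tensors,
# analogous to what parallelize.py does.
model.output.weight = model.tok_embeddings.weight
assert model.output.weight is model.tok_embeddings.weight  # tied: same object

# Materialize afterwards, analogous to model.to_empty(device=init_device) in train.py.
model.to_empty(device="cpu")

# Each module's parameter is re-created independently during materialization,
# so the tie is expected to be gone at this point.
print(model.output.weight is model.tok_embeddings.weight)  # expected: False
```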
I verified this by printing self.output.weight and self.tok_embeddings.weight during the forward pass and confirmed that the two tensors are different:
torchtitan/torchtitan/models/qwen3/model/model.py
Lines 583 to 592 in 09c6d74
```python
h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens

for layer in self.layers.values():
    h = layer(h, self.rope_cache, attention_masks, positions)

# pyrefly: ignore[not-callable, invalid-argument]
h = self.norm(h) if self.norm else h
# pyrefly: ignore[not-callable, invalid-argument]
output = self.output(h) if self.output else h
return output
```
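A probe along those lines would look roughly like this (a sketch only; the exact placement inside forward is an assumption):

```python
# Debug probe inside forward(): print both weights and compare them directly.
print(self.tok_embeddings.weight)
print(self.output.weight)
# When tying works, this is True (same parameter object); here it came out False.
print(self.output.weight is self.tok_embeddings.weight)
```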
As an experiment, I moved the application of parallelization to occur after model materialization (the current ordering in train.py is shown below). With this change, the expected behavior is restored, and self.output.weight is equal to self.tok_embeddings.weight:
torchtitan/torchtitan/train.py
Lines 256 to 261 in 09c6d74
```python
model = self.train_spec.parallelize_fn(model, parallel_dims, job_config)

model.to_empty(device=init_device)
with torch.no_grad():
    # pyrefly: ignore [not-callable]
    model.init_weights(buffer_device=buffer_device)
```
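Concretely, the experiment reorders the snippet above roughly as follows (a sketch only; whether parallelize_fn should run before or after init_weights is an assumption here):

```python
# Experimental reordering: materialize and initialize the model first...
model.to_empty(device=init_device)
with torch.no_grad():
    # pyrefly: ignore [not-callable]
    model.init_weights(buffer_device=buffer_device)

# ...and only then apply parallelisms, so the weight tying in parallelize.py
# assigns real (non-meta) tensors and therefore persists.
model = self.train_spec.parallelize_fn(model, parallel_dims, job_config)
```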
Versions
This issue can be reproduced on the latest PyTorch with the following training config:
```toml
[job]
dump_folder = "./outputs"
description = "Qwen 3 0.6B training"
[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 100
[metrics]
log_freq = 1
enable_tensorboard = true
save_tb_folder = "tb"
[model]
name = "qwen3"
flavor = "0.6B"
# hf_assets_path = "./assets/hf/Qwen3-0.6B"
# converters = ["float8"]
[optimizer]
name = "AdamW"
lr = 3e-4
eps = 1e-8
[lr_scheduler]
warmup_steps = 2 # lr scheduler warm up, 20% total steps
[training]
local_batch_size = 4
seq_len = 4096
max_norm = 1.0 # grad norm clipping
steps = 10
dataset = "c4"
[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
fsdp_reshard_after_forward = "default" # default / never / always
tensor_parallel_degree = 1
context_parallel_degree = 1
[checkpoint]
enable = false
folder = "checkpoint"
interval = 500
last_save_model_only = false
export_dtype = "float16"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
[activation_checkpoint]
mode = "selective" # ["none", "selective", "full"]
selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac based on ops policy
[compile]
enable = false
components = ["model", "loss"]
[quantize.linear.float8]
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false
filter_fqns = ["output"]
```