Bug description
I am experiencing numerical instability during training with Qwen3. I tracked the issue down to the model initialization, specifically the use of torch.nn.init.trunc_normal_ on bfloat16 tensors. The code snippet below illustrates the problem: the function does not produce the intended truncated normal distribution, and the samples are heavily skewed toward the lower bound a, in this case -2. This function appears to have a history of being problematic (see pytorch/pytorch#155588 and pytorch/pytorch#145498).
In [20]: import torch
    ...: x = torch.zeros(10, 500, dtype=torch.bfloat16)
    ...: torch.nn.init.trunc_normal_(x, mean=0, std=0.02, a=-2, b=2)
    ...: print("min", x.min(-1)[0])
    ...: print("max", x.max(-1)[0])
    ...: print("mean", x.mean(-1)[0])
    ...: print("sum", x.sum(-1)[0])
min tensor([-2.0000, -2.0000, -2.0000, -2.0000, -2.0000, -0.0532, -2.0000, -2.0000, -2.0000, -2.0000], dtype=torch.bfloat16)
max tensor([0.0532, 0.0532, 0.0532, 0.0532, 0.0532, 0.0532, 0.0532, 0.0532, 0.0532, 0.0483], dtype=torch.bfloat16)
mean tensor(-0.0084, dtype=torch.bfloat16)
sum tensor(-4.1875, dtype=torch.bfloat16)
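For comparison, here is a minimal sketch of the same initialization run in float32 and cast to bfloat16 afterwards, which produces the expected statistics (this is the premise of the workaround below; the approximate values in the comment are what I observe, not exact):

import torch

x = torch.zeros(10, 500, dtype=torch.float32)
torch.nn.init.trunc_normal_(x, mean=0, std=0.02, a=-2, b=2)
x = x.to(torch.bfloat16)
# Extremes of ~5000 draws from N(0, 0.02) land around +/-3.5 sigma,
# i.e. roughly +/-0.07 -- nowhere near the -2 saturation seen above.
print("min", x.min().item(), "max", x.max().item(), "mean", x.mean().item())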
My current workaround is to cast to float32, call trunc_normal_ there, and cast back to bfloat16. Here is the related code snippet:
import torch
from torch.distributed.tensor import DTensor

@torch.no_grad()
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    orig_dtype = tensor.dtype
    if isinstance(tensor, DTensor):
        # Convert DTensor -> float32 DTensor
        tensor_fp32 = tensor.to(dtype=torch.float32)
        # Local init (safe because trunc_normal_ is elementwise)
        torch.nn.init.trunc_normal_(tensor_fp32._local_tensor, mean=mean, std=std, a=a, b=b)
        # Cast back to original dtype and copy into original tensor
        tensor.copy_(tensor_fp32.to(dtype=orig_dtype))
    else:
        # Regular Tensor
        tensor_fp32 = tensor.to(dtype=torch.float32)
        torch.nn.init.trunc_normal_(tensor_fp32, mean=mean, std=std, a=a, b=b)
        tensor.copy_(tensor_fp32.to(dtype=orig_dtype))
    del tensor_fp32
    return tensor

Please let me know if you have any suggestions regarding this issue. I can also submit a PR with the proposed workaround.
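For illustration, here is a minimal sketch of how such a wrapper could be applied during model initialization (the module types and hyperparameters here are hypothetical examples, not Qwen3's actual init code):

import torch
import torch.nn as nn

@torch.no_grad()
def init_weights(module: nn.Module):
    # Route weight matrices through the float32-backed trunc_normal_
    # defined above; the module selection here is just an example.
    if isinstance(module, (nn.Linear, nn.Embedding)):
        trunc_normal_(module.weight, mean=0.0, std=0.02, a=-2.0, b=2.0)

model = nn.Linear(256, 256, dtype=torch.bfloat16)
model.apply(init_weights)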
Versions
torch: 2.11.0.dev20260107+cu126