
Numerical instability in truncated normal initialization with bfloat16 #2269

@francesco-bertolotti

Bug description

I am experiencing numerical instability during training with Qwen3. I tracked the issue down to the model initialization, specifically the use of torch.nn.init.trunc_normal_ on bfloat16 tensors. The issue is illustrated in the following code snippet: the function clearly does not produce the intended distribution, and the samples are heavily skewed toward the lower bound a, in this case -2. This function has a history of being problematic (see pytorch/pytorch#155588 and pytorch/pytorch#145498).

In [20]: import torch
    ...: x = torch.zeros(10,500, dtype=torch.bfloat16)
    ...: torch.nn.init.trunc_normal_(x, mean=0, std=0.02, a=-2, b=2)
    ...: print("min", x.min(-1)[0])
    ...: print("max", x.max(-1)[0])
    ...: print("mean", x.mean(-1)[0])
    ...: print("sum", x.sum(-1)[0])
min tensor([-2.0000, -2.0000, -2.0000, -2.0000, -2.0000, -0.0532, -2.0000, -2.0000, -2.0000, -2.0000], dtype=torch.bfloat16)
max tensor([0.0532, 0.0532, 0.0532, 0.0532, 0.0532, 0.0532, 0.0532, 0.0532, 0.0532, 0.0483], dtype=torch.bfloat16)
mean tensor(-0.0084, dtype=torch.bfloat16)
sum tensor(-4.1875, dtype=torch.bfloat16)
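For contrast, here is what I would expect: a minimal sketch running the same initialization in float32, where with std=0.02 the bounds at ±2 are 100 standard deviations away, so no sample should land anywhere near them. (The seed and the 0.2 sanity threshold are my own choices for illustration.)

```python
import torch

torch.manual_seed(0)

# Same parameters as the bfloat16 repro, but in float32.
x32 = torch.zeros(10, 500, dtype=torch.float32)
torch.nn.init.trunc_normal_(x32, mean=0.0, std=0.02, a=-2.0, b=2.0)

# With std=0.02, every sample should sit well inside (-2, 2);
# nothing should be pinned at the truncation bound.
print("min", x32.min().item())
print("max", x32.max().item())
assert x32.abs().max().item() < 0.2
```

In float32 the min/max stay around ±0.08 as expected; only the bfloat16 path collapses values onto the bound a.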

My current workaround is to cast to float32, run trunc_normal_ there, and cast back to bfloat16. Here is the related code snippet:

@torch.no_grad()
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    orig_dtype = tensor.dtype

    if isinstance(tensor, torch.distributed.tensor.DTensor):
        # Convert DTensor -> float32 DTensor
        tensor_fp32 = tensor.to(dtype=torch.float32)

        # Local init (safe because trunc_normal_ is elementwise)
        torch.nn.init.trunc_normal_(tensor_fp32._local_tensor, mean=mean, std=std, a=a, b=b)

        # Cast back to original dtype and copy into original tensor
        tensor.copy_(tensor_fp32.to(dtype=orig_dtype))
    else:
        # Regular Tensor
        tensor_fp32 = tensor.to(dtype=torch.float32)
        torch.nn.init.trunc_normal_(tensor_fp32, mean=mean, std=std, a=a, b=b)
        tensor.copy_(tensor_fp32.to(dtype=orig_dtype))

    del tensor_fp32

    return tensor
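A minimal usage sketch of the workaround on a regular (non-DTensor) tensor, showing that the resulting bfloat16 values are no longer pinned at the bound. The simplified function below reproduces only the regular-tensor branch from above; the 0.2 threshold is an illustrative sanity check, not part of the original code.

```python
import torch

@torch.no_grad()
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    # Regular-tensor path of the workaround: sample in float32,
    # then cast back into the original dtype in place.
    tensor_fp32 = tensor.to(dtype=torch.float32)
    torch.nn.init.trunc_normal_(tensor_fp32, mean=mean, std=std, a=a, b=b)
    tensor.copy_(tensor_fp32.to(dtype=tensor.dtype))
    return tensor

torch.manual_seed(0)
w = torch.empty(10, 500, dtype=torch.bfloat16)
trunc_normal_(w, mean=0.0, std=0.02, a=-2.0, b=2.0)

# min/max now stay near ±0.08 instead of collapsing to -2.
print("min", w.min().item())
print("max", w.max().item())
assert w.dtype == torch.bfloat16
assert w.abs().max().item() < 0.2
```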

Please let me know if you have any suggestions regarding this issue. I can also submit a PR with the proposed workaround.

Versions

torch: 2.11.0.dev20260107+cu126
