
πŸ›[BUG]: Backward pass fails with DDP/LightningΒ #12

@maitim

Describe the bug
When I wrap an nn.Module that uses iSHTCUDA internally with DDP, the backward pass in iSHTFunction fails with RuntimeError: CUFFT error: 6. I'm not sure whether I'm missing something here, but the exact same code without the DDP wrapper works just fine.
My original issue was training with Lightning, where I only got a segmentation fault without any detailed traceback during the first backward pass. After removing everything Lightning-related, I found that the plain DDP wrapper (which I assume Lightning uses internally in its ddp strategy) raises the error above, which is presumably the underlying root cause of the Lightning segfault.
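
For context, the original Lightning setup was roughly equivalent to the following sketch. The exact LightningModule isn't shown here, so LitNet, the random dataset, and the optimizer choice are illustrative; Net is the module from the reproduction script below, assumed to be saved as minimal_example.py.

# Rough, illustrative sketch of the Lightning setup that segfaulted during the
# first backward pass; names here are placeholders, not the original code.
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset

from minimal_example import Net  # Net as defined in the reproduction script below


class LitNet(pl.LightningModule):
    def __init__(self, nside, lmax):
        super().__init__()
        self.net = Net(nside, lmax, mmax=lmax)
        self.criterion = torch.nn.MSELoss()

    def training_step(self, batch, batch_idx):
        (x,) = batch  # TensorDataset yields one-element tuples
        pred = self.net(x)
        return self.criterion(pred, x)  # same autoencoding-style loss as below

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


if __name__ == "__main__":
    nside = 32
    lmax = 2 * nside + 1
    npix = 12 * nside**2
    loader = DataLoader(TensorDataset(torch.randn(64, npix)), batch_size=8)
    trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="ddp", max_epochs=1)
    # With strategy="ddp" this crashes with a segfault in the first backward pass
    trainer.fit(LitNet(nside, lmax), loader)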

To Reproduce
Setting distributed = True in the script below reproduces the issue. I'm running it with python -m torch.distributed.run --nproc_per_node=2 minimal_example.py on a node with two CUDA devices.

import os

import torch
import torch.distributed as dist
import torch.nn as nn
from cuhpx import SHTCUDA, iSHTCUDA
from torch.nn.parallel import DistributedDataParallel as DDP


class Net(nn.Module):
    def __init__(self, nside, lmax, mmax=None, quad_weights="ring"):
        super().__init__()

        npix = 12 * nside**2
        self.mlp = nn.Linear(npix, npix)
        self.sht = SHTCUDA(nside, lmax, mmax, quad_weights=quad_weights)
        self.isht = iSHTCUDA(nside, lmax, mmax)

    def forward(self, x):
        x = self.mlp(x)
        x = self.sht(x)
        x = self.isht(x)
        return x


if __name__ == "__main__":
    distributed = True

    batch_size = 8
    nside = 32
    lmax = 2 * nside + 1
    npix = 12 * nside**2
    model = Net(nside, lmax, mmax=lmax)
    sample = torch.randn([batch_size, npix], dtype=torch.float32)

    if distributed:
        rank = int(os.environ["LOCAL_RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        dist.init_process_group("nccl", rank=rank, world_size=world_size)
        torch.cuda.set_device(rank)
        device = torch.device(f"cuda:{rank}")
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    sample = sample.to(device)
    model.to(device)

    if distributed:
        model = DDP(model, device_ids=[rank], output_device=rank)

    criterion = nn.MSELoss()
    try:
        pred = model(sample)
        loss = criterion(pred, sample)
        loss.backward()
        print(loss.item())

    finally:
        if distributed:
            dist.destroy_process_group()

Expected behavior
A successful backward pass, the same as in the single-GPU case.

Traceback

[rank1]: Traceback (most recent call last):
[rank1]:   File "/home/b/b309298/flows/scripts/bug_minimal_example.py", line 64, in <module>
[rank1]:     loss.backward()
[rank1]:   File "/sw/spack-levante/miniforge3-24.11.3-2-Linux-x86_64-hbhytx/lib/python3.12/site-packages/torch/_tensor.py", line 581, in backward
[rank1]:     torch.autograd.backward(
[rank1]:   File "/sw/spack-levante/miniforge3-24.11.3-2-Linux-x86_64-hbhytx/lib/python3.12/site-packages/torch/autograd/__init__.py", line 347, in backward
[rank1]:     _engine_run_backward(
[rank1]:   File "/sw/spack-levante/miniforge3-24.11.3-2-Linux-x86_64-hbhytx/lib/python3.12/site-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
[rank1]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/sw/spack-levante/miniforge3-24.11.3-2-Linux-x86_64-hbhytx/lib/python3.12/site-packages/torch/autograd/function.py", line 307, in apply
[rank1]:     return user_fn(self, *args)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/work/bd1179/b309298/flows/sys_torch251_cuda126/lib/python3.12/site-packages/cuhpx/hpx_sht.py", line 596, in backward
[rank1]:     x = cuhpx_fft.healpix_rfft_batch(x, mmax, nside)
[rank1]:         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: RuntimeError: CUFFT error: 6

Environment (please complete the following information):

I'm working on an HPC system where I use a PyTorch/CUDA module provided by the admins; the CUDA toolkit is installed.

  • Python 3.12.9
  • PyTorch 2.5.1
  • CUDA 12.6

Any help on this would be greatly appreciated! Thanks for this great work.
