Describe the bug
When I wrap an nn.Module that uses `iSHTCUDA` internally in DistributedDataParallel (DDP), the backward pass in `iSHTFunction` fails with `RuntimeError: CUFFT error: 6`. I'm not sure if I'm missing something here, but the exact same model works just fine without the DDP wrapper.
My original problem was training with Lightning, where I only got a segmentation fault without any detailed traceback during the first backward pass. After stripping out everything Lightning-related, I found that the plain DDP wrapper (which I'm guessing Lightning uses under the hood for its ddp strategy) raises this error, which I assume is the underlying root cause of the Lightning segfault. A rough sketch of the original Lightning setup is included below for context.
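For context, the original Lightning setup looked roughly like this. It is a simplified sketch, not the exact code I ran: `LitNet`, the dummy data, and the `minimal_example` import are placeholders, and `Net` is the same module as in the reproduction script further down.

```python
# Rough sketch of the original Lightning setup (placeholder names, not the exact code).
import pytorch_lightning as pl  # or lightning.pytorch, depending on the install
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from minimal_example import Net  # assumed import of the Net defined in the script below


class LitNet(pl.LightningModule):
    def __init__(self, net):
        super().__init__()
        self.net = net

    def training_step(self, batch, batch_idx):
        (x,) = batch
        # Same autoencoder-style MSE objective as in the minimal example.
        return nn.functional.mse_loss(self.net(x), x)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


if __name__ == "__main__":
    nside, batch_size = 32, 8
    data = TensorDataset(torch.randn(64, 12 * nside**2))
    # strategy="ddp" is where I suspect the DDP wrapper comes in;
    # this is the setup that segfaulted during the first backward pass.
    trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="ddp", max_epochs=1)
    trainer.fit(LitNet(Net(nside, 2 * nside + 1, mmax=2 * nside + 1)),
                DataLoader(data, batch_size=batch_size))
```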
To Reproduce
Setting `distributed = True` in the script below reproduces the issue. I'm launching it with `python -m torch.distributed.run --nproc_per_node=2 minimal_example.py` on a node with two CUDA devices.
```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from cuhpx import SHTCUDA, iSHTCUDA
from torch.nn.parallel import DistributedDataParallel as DDP


class Net(nn.Module):
    def __init__(self, nside, lmax, mmax=None, quad_weights="ring"):
        super().__init__()
        npix = 12 * nside**2
        self.mlp = nn.Linear(npix, npix)
        self.sht = SHTCUDA(nside, lmax, mmax, quad_weights=quad_weights)
        self.isht = iSHTCUDA(nside, lmax, mmax)

    def forward(self, x):
        x = self.mlp(x)
        x = self.sht(x)   # forward spherical harmonic transform
        x = self.isht(x)  # inverse transform back to HEALPix pixels
        return x


if __name__ == "__main__":
    distributed = True  # False -> single-GPU run, which works fine
    batch_size = 8
    nside = 32
    lmax = 2 * nside + 1
    npix = 12 * nside**2

    model = Net(nside, lmax, mmax=lmax)
    sample = torch.randn([batch_size, npix], dtype=torch.float32)

    if distributed:
        rank = int(os.environ["LOCAL_RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        dist.init_process_group("nccl", rank=rank, world_size=world_size)
        torch.cuda.set_device(rank)
        device = torch.device(f"cuda:{rank}")
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    sample = sample.to(device)
    model.to(device)
    if distributed:
        model = DDP(model, device_ids=[rank], output_device=rank)

    criterion = nn.MSELoss()
    try:
        pred = model(sample)
        loss = criterion(pred, sample)
        loss.backward()  # fails here with CUFFT error: 6 when distributed=True
        print(loss.item())
    finally:
        if distributed:
            dist.destroy_process_group()
```
Expected behavior
A successful backward pass, same as in the single-GPU case (`distributed = False`).
Traceback
```
[rank1]: Traceback (most recent call last):
[rank1]:   File "/home/b/b309298/flows/scripts/bug_minimal_example.py", line 64, in <module>
[rank1]:     loss.backward()
[rank1]:   File "/sw/spack-levante/miniforge3-24.11.3-2-Linux-x86_64-hbhytx/lib/python3.12/site-packages/torch/_tensor.py", line 581, in backward
[rank1]:     torch.autograd.backward(
[rank1]:   File "/sw/spack-levante/miniforge3-24.11.3-2-Linux-x86_64-hbhytx/lib/python3.12/site-packages/torch/autograd/__init__.py", line 347, in backward
[rank1]:     _engine_run_backward(
[rank1]:   File "/sw/spack-levante/miniforge3-24.11.3-2-Linux-x86_64-hbhytx/lib/python3.12/site-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
[rank1]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/sw/spack-levante/miniforge3-24.11.3-2-Linux-x86_64-hbhytx/lib/python3.12/site-packages/torch/autograd/function.py", line 307, in apply
[rank1]:     return user_fn(self, *args)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/work/bd1179/b309298/flows/sys_torch251_cuda126/lib/python3.12/site-packages/cuhpx/hpx_sht.py", line 596, in backward
[rank1]:     x = cuhpx_fft.healpix_rfft_batch(x, mmax, nside)
[rank1]:         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: RuntimeError: CUFFT error: 6
```
Environment (please complete the following information):
I'm working on an HPC system, using a PyTorch/CUDA module provided by the admins; the CUDA toolkit is installed.
- Python 3.12.9
- PyTorch 2.5.1
- CUDA 12.6
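For reference, these are the versions as reported from within that environment (a quick check, not exhaustive):

```python
import torch

print(torch.__version__)          # 2.5.1
print(torch.version.cuda)         # 12.6
print(torch.cuda.is_available())  # True
print(torch.cuda.device_count())  # 2 on the node I'm testing on
```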
Any help with this would be greatly appreciated! Thanks for the great work.