Describe the bug
After calling `model.transcribe()` on a Parakeet-TDT model with CUDA graphs enabled (the default), `torch.load()` becomes permanently broken for the remainder of the process. Any attempt to deserialize a payload containing 4 or more tensors fails with:

```
TypeError: 'str' object is not callable    (CUDA tensors)
TypeError: 'tuple' object is not callable  (CPU tensors)
```
The corruption is process-wide and irreversible within the same Python process: once a single `model.transcribe()` call runs with CUDA graphs enabled, `torch.load()` is broken until the process is restarted. This affects any workflow that calls `torch.load()` after inference — including model reloading, checkpoint loading, weight swapping, or any library that internally uses `torch.load()`.
Key diagnostic findings:
- Corruption threshold is exactly 4 tensors. 3 or fewer always work; 4 or more always fail, independent of tensor size: 4 tensors of shape `(1,)` fail, while 3 tensors of shape `(4096, 4096)` succeed.
- Raw pickle is unaffected. Standard `pickle.dumps()`/`pickle.loads()` work normally; the corruption is in PyTorch's custom unpickler or `persistent_load` machinery.
- Irreversible. No combination of `gc.collect()`, `torch.cuda.empty_cache()`, `importlib.reload()`, or model deletion fixes it.
- Disabling CUDA graphs prevents it entirely. Calling `model.decoding.decoding.decoding_computer.disable_cuda_graphs()` before inference prevents the issue.
- Disabling `torch._dynamo` has no effect.
- No performance penalty. Disabling CUDA graphs actually resulted in ~15% faster wall-clock time for single-file transcription.
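The threshold claim is easy to re-check with a small probe run before and after inference. A minimal sketch (the helper name `load_roundtrip_ok` is mine), using CPU tensors since both CPU and CUDA payloads are affected:

```python
import io

import torch

def load_roundtrip_ok(n_tensors: int) -> bool:
    """Round-trip n small CPU tensors through torch.save/torch.load.

    Returns False if deserialization raises TypeError (the failure mode
    described above). Tensor size is deliberately tiny, since the
    threshold depends only on the tensor count.
    """
    state = {f'layer.{i}': torch.randn(1) for i in range(n_tensors)}
    buf = io.BytesIO()
    torch.save(state, buf)
    buf.seek(0)
    try:
        torch.load(buf, weights_only=False)
        return True
    except TypeError:
        return False
```

In a healthy interpreter every count succeeds; after the corrupting `transcribe()` call, `load_roundtrip_ok(3)` still returns True while `load_roundtrip_ok(4)` returns False.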
Steps/Code to reproduce bug
```python
import io, wave
from collections import OrderedDict

import numpy as np
import torch
import nemo.collections.asr as nemo_asr

# 1. Create minimal test audio (5 seconds of silence)
samples = np.zeros(16000 * 5, dtype=np.int16)
with wave.open('/tmp/test.wav', 'wb') as wf:
    wf.setnchannels(1)
    wf.setsampwidth(2)
    wf.setframerate(16000)
    wf.writeframes(samples.tobytes())

# 2. Load model (CUDA graphs enabled by default in TDT decoder)
model = nemo_asr.models.ASRModel.from_pretrained('nvidia/parakeet-tdt-0.6b-v3')
model = model.cuda().eval()

# 3. Verify torch.load works BEFORE inference
sd = OrderedDict((f'layer.{i}', torch.randn(32, 32, device='cuda')) for i in range(10))
buf = io.BytesIO()
torch.save(sd, buf)
buf.seek(0)
loaded = torch.load(buf, weights_only=False)
print(f"Before inference: torch.load OK, loaded {len(loaded)} tensors")

# 4. Run inference (triggers CUDA graph capture in the TDT decoder)
model.transcribe(['/tmp/test.wav'])

# 5. torch.load is now BROKEN for 4+ tensors
sd = OrderedDict((f'layer.{i}', torch.randn(32, 32, device='cuda')) for i in range(10))
buf = io.BytesIO()
torch.save(sd, buf)
buf.seek(0)
try:
    torch.load(buf, weights_only=False)
    print("After inference: torch.load OK")
except TypeError as e:
    print(f"After inference: torch.load FAILED — {e}")
    # TypeError: 'str' object is not callable
```

Workaround — disable CUDA graphs before inference:
```python
model = nemo_asr.models.ASRModel.from_pretrained('nvidia/parakeet-tdt-0.6b-v3')
model = model.cuda().eval()
model.decoding.decoding.decoding_computer.disable_cuda_graphs()
model.transcribe(['/tmp/test.wav'])  # torch.load works fine after this
```

Expected behavior
`torch.load()` should work normally after `model.transcribe()`. CUDA graph capture/replay in the TDT decoder should not have side effects on unrelated PyTorch subsystems like serialization.
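Until a fix lands, another defensive option (besides `disable_cuda_graphs()`) is to confine inference to a short-lived child process, so the parent interpreter's `torch.load` stays healthy. A minimal sketch; `transcribe_isolated` and `_worker` are names I've made up, and the worker body is a placeholder for where real code would import NeMo and call `model.transcribe`:

```python
import multiprocessing as mp

def _worker(audio_paths, out_queue):
    # Placeholder body. In real use, import NeMo and run
    # model.transcribe(audio_paths) here, so any CUDA-graph side
    # effects die with the child process.
    out_queue.put([f"transcript for {p}" for p in audio_paths])

def transcribe_isolated(audio_paths):
    """Run transcription in a child process; whatever the child does to
    PyTorch's serialization state cannot affect the parent."""
    ctx = mp.get_context('fork')  # 'fork' assumed (POSIX); 'spawn' elsewhere
    out_queue = ctx.Queue()
    proc = ctx.Process(target=_worker, args=(audio_paths, out_queue))
    proc.start()
    result = out_queue.get()  # read before join to avoid queue deadlock
    proc.join()
    return result
```

The cost is model load time on every call, so this only makes sense for batch-style workflows that must call `torch.load` afterwards.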
Environment overview (please complete the following information)
- Environment location: Docker (NVIDIA DGX Spark, Grace Blackwell aarch64)
- Method of NeMo install: Docker
- Docker image: `nvcr.io/nvidia/nemo:25.11.nemotron_3_nano` (`docker pull nvcr.io/nvidia/nemo:25.11.nemotron_3_nano`)
Environment details
NVIDIA docker image is used. Additional details for reference:
- NeMo: 2.6.0rc0
- PyTorch: 2.9.0a0+50eac811a6.nv25.09
- CUDA Toolkit: 13.0
- NVIDIA Driver: 580.126.09
- Python: 3.12.3
- Platform: Linux aarch64
Additional context
- GPU: NVIDIA GB10 (DGX Spark, Grace Blackwell, unified memory, compute capability 12.1)
- The corruption is triggered specifically by the CUDA graph capture/replay in `GreedyBatchedTDTLabelLoopingComputer`. The code path is: `model.transcribe()` → `GreedyBatchedTDTInfer.forward()` → `GreedyBatchedTDTLabelLoopingComputer.__call__()` → `cuda_graphs_impl()`.
- The CUDA graph code itself never touches pickle or `torch.load` — the corruption is an indirect side effect of graph capture on PyTorch's internal tensor reconstruction dispatch.
- A corresponding issue will be filed with PyTorch, as the root cause may be in PyTorch's serialization internals being corrupted by CUDA graph state changes.
- This may be related to PyTorch PR #165474 (merged in 2.10), which fixes AOTAutogradCache serialization; not yet tested on 2.10+.
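For context on the suspected machinery: `torch.save` diverts tensor storages out of the pickle stream via pickle's persistent-ID hooks, and `torch.load` reconstructs them in a custom `Unpickler.persistent_load`. A standalone illustration of that protocol using plain `pickle` only (the `Saver`/`Loader` classes and `STORE` side table are mine, not PyTorch code) — a bug in this dispatch would surface exactly as a `TypeError` during load, matching the symptom above:

```python
import io
import pickle

# Out-of-band side table standing in for torch's storage area.
STORE = {}

class Saver(pickle.Pickler):
    def persistent_id(self, obj):
        # Divert bytes payloads out of the pickle stream, analogous to
        # what torch.save does for tensor storages.
        if isinstance(obj, bytes):
            key = f"storage_{len(STORE)}"
            STORE[key] = obj
            return key
        return None  # everything else is pickled inline

class Loader(pickle.Unpickler):
    def persistent_load(self, pid):
        # torch.load's counterpart: rebuild the object from the side table.
        return STORE[pid]

payload = {"weights": b"\x00\x01\x02", "name": "layer0"}
buf = io.BytesIO()
Saver(buf).dump(payload)
buf.seek(0)
restored = Loader(buf).load()
assert restored == payload
```

Raw `pickle.loads()` never invokes these hooks, which is consistent with the finding that plain pickle keeps working while `torch.load` breaks.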