TDT decoder CUDA graph capture corrupts torch.load() — process-wide, irreversible after first model.transcribe() #15423

@rsclafani

Description

Describe the bug

After calling model.transcribe() on a Parakeet-TDT model with CUDA graphs enabled (the default), torch.load() becomes permanently broken for the remainder of the process. Any attempt to deserialize a payload containing 4 or more tensors fails with:

  • TypeError: 'str' object is not callable (CUDA tensors)
  • TypeError: 'tuple' object is not callable (CPU tensors)

The corruption is process-wide and irreversible within the same Python process. Once a single model.transcribe() call runs with CUDA graphs enabled, torch.load() is broken until the process is restarted. This affects any workflow that calls torch.load() after inference — including model reloading, checkpoint loading, weight swapping, or any library that internally uses torch.load().

Key diagnostic findings:

  1. Corruption threshold is exactly 4 tensors. 3 or fewer always works; 4 or more always fails. Independent of tensor size — 4 tensors of shape (1,) fail; 3 tensors of shape (4096, 4096) succeed.
  2. Raw pickle is unaffected. Standard pickle.loads() / pickle.dumps() work normally. The corruption is in PyTorch's custom unpickler or persistent_load machinery.
  3. Irreversible. No combination of gc.collect(), torch.cuda.empty_cache(), importlib.reload(), or model deletion fixes it.
  4. Disabling CUDA graphs prevents it entirely. model.decoding.decoding.decoding_computer.disable_cuda_graphs() before inference prevents the issue.
  5. Disabling torch._dynamo has no effect.
  6. The workaround carries no performance penalty. Disabling CUDA graphs actually resulted in ~15% faster wall-clock time for single-file transcription.
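Findings 1 and 2 can be checked together with a small probe. This is a sketch, not part of the original repro; `torch_load_ok` is a hypothetical helper name. In a healthy process every count succeeds; after the buggy transcribe() call, counts of 4 or more should fail while 1–3 still pass, and raw pickle should keep working either way.

```python
import io
import pickle

import torch

def torch_load_ok(n_tensors: int, device: str = "cpu") -> bool:
    """Round-trip n_tensors through torch.save/torch.load; False on TypeError."""
    payload = {f"t{i}": torch.randn(1, device=device) for i in range(n_tensors)}
    buf = io.BytesIO()
    torch.save(payload, buf)
    buf.seek(0)
    try:
        torch.load(buf, weights_only=False)
        return True
    except TypeError:
        return False

# Finding 1: probe the tensor-count threshold (all True before the bug triggers).
print({n: torch_load_ok(n) for n in range(1, 6)})

# Finding 2: raw pickle is unaffected in all cases.
blob = pickle.dumps({f"k{i}": [float(i)] for i in range(10)})
print(len(pickle.loads(blob)))  # 10
```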

Steps/Code to reproduce bug

import io, wave
from collections import OrderedDict
import numpy as np
import torch
import nemo.collections.asr as nemo_asr

# 1. Create minimal test audio (5 seconds of silence)
samples = np.zeros(16000 * 5, dtype=np.int16)
with wave.open('/tmp/test.wav', 'wb') as wf:
    wf.setnchannels(1)
    wf.setsampwidth(2)
    wf.setframerate(16000)
    wf.writeframes(samples.tobytes())

# 2. Load model (CUDA graphs enabled by default in TDT decoder)
model = nemo_asr.models.ASRModel.from_pretrained('nvidia/parakeet-tdt-0.6b-v3')
model = model.cuda().eval()

# 3. Verify torch.load works BEFORE inference
sd = OrderedDict((f'layer.{i}', torch.randn(32, 32, device='cuda')) for i in range(10))
buf = io.BytesIO()
torch.save(sd, buf)
buf.seek(0)
loaded = torch.load(buf, weights_only=False)
print(f"Before inference: torch.load OK, loaded {len(loaded)} tensors")

# 4. Run inference (triggers CUDA graph capture in the TDT decoder)
model.transcribe(['/tmp/test.wav'])

# 5. torch.load is now BROKEN for 4+ tensors
sd = OrderedDict((f'layer.{i}', torch.randn(32, 32, device='cuda')) for i in range(10))
buf = io.BytesIO()
torch.save(sd, buf)
buf.seek(0)
try:
    torch.load(buf, weights_only=False)
    print("After inference: torch.load OK")
except TypeError as e:
    print(f"After inference: torch.load FAILED — {e}")
    # TypeError: 'str' object is not callable

Workaround — disable CUDA graphs before inference:

model = nemo_asr.models.ASRModel.from_pretrained('nvidia/parakeet-tdt-0.6b-v3')
model = model.cuda().eval()
model.decoding.decoding.decoding_computer.disable_cuda_graphs()
model.transcribe(['/tmp/test.wav'])  # torch.load works fine after this
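Since the decoding.decoding.decoding_computer attribute chain may differ across NeMo versions and decoder configurations, a guarded variant of the workaround is safer in shared code. This is a sketch; `disable_tdt_cuda_graphs` is a hypothetical helper, and the SimpleNamespace object below only mimics the attribute chain for illustration.

```python
from types import SimpleNamespace

def disable_tdt_cuda_graphs(model) -> bool:
    """Best-effort: disable CUDA graphs only if the attribute chain exists."""
    computer = getattr(
        getattr(getattr(model, "decoding", None), "decoding", None),
        "decoding_computer", None)
    if computer is not None and hasattr(computer, "disable_cuda_graphs"):
        computer.disable_cuda_graphs()
        return True
    return False

# Stand-in object mirroring the NeMo attribute chain (illustrative only).
calls = []
fake = SimpleNamespace(decoding=SimpleNamespace(decoding=SimpleNamespace(
    decoding_computer=SimpleNamespace(
        disable_cuda_graphs=lambda: calls.append("off")))))

print(disable_tdt_cuda_graphs(fake))      # True, and disable was invoked
print(disable_tdt_cuda_graphs(object()))  # False, silently skipped
```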

Expected behavior

torch.load() should work normally after model.transcribe(). CUDA graph capture/replay in the TDT decoder should not have side effects on unrelated PyTorch subsystems like serialization.

Environment overview (please complete the following information)

  • Environment location: Docker (NVIDIA DGX Spark, Grace Blackwell aarch64)
  • Method of NeMo install: Docker
  • Docker image: nvcr.io/nvidia/nemo:25.11.nemotron_3_nano (docker pull nvcr.io/nvidia/nemo:25.11.nemotron_3_nano)

Environment details

The NVIDIA Docker image above is used. Additional details for reference:

  • NeMo: 2.6.0rc0
  • PyTorch: 2.9.0a0+50eac811a6.nv25.09
  • CUDA Toolkit: 13.0
  • NVIDIA Driver: 580.126.09
  • Python: 3.12.3
  • Platform: Linux aarch64

Additional context

  • GPU: NVIDIA GB10 (DGX Spark, Grace Blackwell, unified memory, compute capability 12.1)
  • The corruption is triggered specifically by the CUDA graph capture/replay in GreedyBatchedTDTLabelLoopingComputer. The code path is: model.transcribe() → GreedyBatchedTDTInfer.forward() → GreedyBatchedTDTLabelLoopingComputer.__call__() → cuda_graphs_impl().
  • The CUDA graph code itself never touches pickle or torch.load — the corruption is an indirect side effect of graph capture on PyTorch's internal tensor reconstruction dispatch.
  • A corresponding issue will be filed with PyTorch, as the root cause may be in PyTorch's serialization internals being corrupted by CUDA graph state changes.
  • This may be related to PyTorch PR #165474 (merged in 2.10), which fixes AOTAutogradCache serialization — not yet tested on 2.10+.
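For context on the "custom unpickler or persistent_load machinery" mentioned in the findings: torch.save writes tensor storages out of band via pickle's persistent-ID hook, and torch.load's unpickler resolves them in persistent_load. A minimal, torch-free sketch of that machinery (TensorStub and the Demo* classes are illustrative stand-ins, not torch internals):

```python
import io
import pickle

class TensorStub:
    """Stand-in for a tensor whose storage is serialized out of band."""
    def __init__(self, name):
        self.name = name

class DemoPickler(pickle.Pickler):
    def persistent_id(self, obj):
        # Emit an out-of-band reference instead of pickling the object itself,
        # analogous to how torch.save handles tensor storages.
        if isinstance(obj, TensorStub):
            return ("storage", obj.name)
        return None  # everything else is pickled normally

class DemoUnpickler(pickle.Unpickler):
    def persistent_load(self, pid):
        # Resolve the out-of-band reference back into an object; a bug here
        # breaks deserialization while raw pickle.loads stays unaffected.
        kind, name = pid
        return TensorStub(name)

buf = io.BytesIO()
DemoPickler(buf).dump([TensorStub(f"t{i}") for i in range(4)])
buf.seek(0)
out = DemoUnpickler(buf).load()
print([t.name for t in out])  # ['t0', 't1', 't2', 't3']
```

This illustrates why finding 2 (raw pickle unaffected) points the blame at the persistent-load layer rather than at pickle itself.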
