Commit 4f47e1b

dynamo: robust expand rank check to avoid symbolic len() errors; add regression test. Fixes #3972
1 parent: 22b0e5f

File tree: 2 files changed, +60 -2 lines

py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py

Lines changed: 13 additions & 2 deletions
@@ -241,8 +241,19 @@ def expand(
         "Cannot expand to shape with rank smaller than original tensor."
     )
 
-    # After the above padding, the shape and tensor rank must be equal
-    assert len(input_t.shape) == shape_rank
+    # After the above padding, the shape and tensor rank must be equal.
+    # Safely check rank (len(...) may fail on symbolic shapes).
+    try:
+        current_rank = len(input_t.shape)
+    except Exception:
+        current_rank = initial_tensor_rank
+
+    if current_rank != shape_rank:
+        raise RuntimeError(
+            f"expand lowering: expected input rank {shape_rank} after padding, but got {current_rank}. "
+            "This may indicate symbolic or dynamic dimensions causing a rank mismatch."
+        )
+
 
     # Configure the start, strides and output shape tensors
     start = tuple([0] * shape_rank)
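
For readers skimming the diff: the change swaps a bare assert for a guard-then-validate pattern, i.e. attempt the cheap len() rank query, fall back to the rank recorded before padding if that raises, and report a descriptive RuntimeError on mismatch. Below is a minimal standalone sketch of that pattern using only standard Python; checked_rank and OpaqueShape are illustrative names for this note, not Torch-TensorRT APIs or part of the patch.

# Minimal sketch of the guarded rank check; illustrative names only.
def checked_rank(shape, fallback_rank):
    """Return len(shape) when it is computable, otherwise a previously recorded rank."""
    try:
        return len(shape)
    except Exception:
        # Symbolic/dynamic shape objects may refuse len(); fall back to the
        # rank captured before any padding was applied.
        return fallback_rank


class OpaqueShape:
    """Stand-in for a shape whose rank cannot be queried eagerly."""
    def __len__(self):
        raise TypeError("rank is not known at trace time")


assert checked_rank((1, 3, 224, 224), fallback_rank=4) == 4
assert checked_rank(OpaqueShape(), fallback_rank=5) == 5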
New regression test file — Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
+import pytest
+import torch
+import torch.nn as nn
+
+try:
+    import torch_tensorrt
+except Exception:
+    torch_tensorrt = None
+
+REQUIRES_TRT = torch.cuda.is_available() and (torch_tensorrt is not None)
+
+pytestmark = pytest.mark.skipif(not REQUIRES_TRT, reason="requires CUDA + Torch-TensorRT runtime")
+
+class CosmosLearnablePositionalEmbed(nn.Module):
+    def __init__(self, hidden_size, max_size, patch_size):
+        super().__init__()
+        self.patch_size = patch_size
+        self.pos_emb_t = nn.Parameter(torch.zeros(max_size[0] // patch_size[0], hidden_size))
+        self.pos_emb_h = nn.Parameter(torch.zeros(max_size[1] // patch_size[1], hidden_size))
+        self.pos_emb_w = nn.Parameter(torch.zeros(max_size[2] // patch_size[2], hidden_size))
+
+    def forward(self, hidden_states):
+        batch_size, _, num_frames, height, width = hidden_states.shape
+        pe_size = [num_frames // self.patch_size[0], height // self.patch_size[1], width // self.patch_size[2]]
+        emb_t = self.pos_emb_t[:pe_size[0]][None, :, None, None, :].repeat(batch_size, 1, pe_size[1], pe_size[2], 1)
+        emb_h = self.pos_emb_h[:pe_size[1]][None, None, :, None, :].repeat(batch_size, pe_size[0], 1, pe_size[2], 1)
+        emb_w = self.pos_emb_w[:pe_size[2]][None, None, None, :, :].repeat(batch_size, pe_size[0], pe_size[1], 1, 1)
+        emb = emb_t + emb_h + emb_w
+        emb = emb.flatten(1, 3)
+        return emb
+
+def test_repeat_expand_lowering_repro():
+    device = torch.device("cuda")
+    hidden_size = 4096
+    model = CosmosLearnablePositionalEmbed(hidden_size=hidden_size, max_size=(128, 240, 240), patch_size=(1, 2, 2)).to(device).eval()
+    hidden_states = torch.randn(1, 17, 16, 88, 160, dtype=torch.bfloat16, device=device)
+
+    with torch.no_grad():
+        pyt_out = model(hidden_states)
+
+    ep = torch.export.export(model, args=(hidden_states,), strict=False)
+    trt_mod = torch_tensorrt.dynamo.compile(ep, inputs=[hidden_states], enabled_precisions={torch.bfloat16}, use_python_runtime=True)
+    trt_out = trt_mod(hidden_states)
+
+    assert pyt_out.shape == trt_out.shape
+    maxdiff = (pyt_out.float() - trt_out.float()).abs().max().item()
+    assert maxdiff < 1e-2
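
Assuming the new file lands somewhere under the repository's pytest-discovered test tree and a CUDA-capable Torch-TensorRT build is installed (otherwise the module-level skip marker skips it), the regression test can be selected by name. A minimal programmatic invocation might look like this; the -k expression is simply the test name above:

# Illustrative pytest invocation for the regression test added in this commit.
import pytest

if __name__ == "__main__":
    # Run only the repro test quietly; exits nonzero on failure.
    raise SystemExit(pytest.main(["-q", "-k", "test_repeat_expand_lowering_repro"]))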
