+import pytest
+import torch
+import torch.nn as nn
+
+try:
+    import torch_tensorrt
+except Exception:
+    torch_tensorrt = None
+
+# Skip the whole module unless both a CUDA device and the Torch-TensorRT runtime are available.
+TRT_AVAILABLE = torch.cuda.is_available() and (torch_tensorrt is not None)
+
+pytestmark = pytest.mark.skipif(not TRT_AVAILABLE, reason="requires CUDA + Torch-TensorRT runtime")
+
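+# Factorized learnable positional embedding over the (frames, height, width) axes of a
+# patchified video latent, mirroring the Cosmos transformer's positional-embedding module.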
+class CosmosLearnablePositionalEmbed(nn.Module):
+    def __init__(self, hidden_size, max_size, patch_size):
+        super().__init__()
+        self.patch_size = patch_size
+        self.pos_emb_t = nn.Parameter(torch.zeros(max_size[0] // patch_size[0], hidden_size))
+        self.pos_emb_h = nn.Parameter(torch.zeros(max_size[1] // patch_size[1], hidden_size))
+        self.pos_emb_w = nn.Parameter(torch.zeros(max_size[2] // patch_size[2], hidden_size))
+
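+    # Slice each axis embedding to the input's patched extent, then broadcast with repeat();
+    # these repeat/expand patterns are what the Torch-TensorRT lowering repro exercises.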
+    def forward(self, hidden_states):
+        batch_size, _, num_frames, height, width = hidden_states.shape
+        pe_size = [num_frames // self.patch_size[0], height // self.patch_size[1], width // self.patch_size[2]]
+        emb_t = self.pos_emb_t[:pe_size[0]][None, :, None, None, :].repeat(batch_size, 1, pe_size[1], pe_size[2], 1)
+        emb_h = self.pos_emb_h[:pe_size[1]][None, None, :, None, :].repeat(batch_size, pe_size[0], 1, pe_size[2], 1)
+        emb_w = self.pos_emb_w[:pe_size[2]][None, None, None, :, :].repeat(batch_size, pe_size[0], pe_size[1], 1, 1)
+        emb = emb_t + emb_h + emb_w
+        emb = emb.flatten(1, 3)
+        return emb
+
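+# Repro: run the module eagerly, export it with torch.export, compile the ExportedProgram
+# through the Torch-TensorRT dynamo frontend, and compare the two outputs.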
+def test_repeat_expand_lowering_repro():
+    device = torch.device("cuda")
+    hidden_size = 4096
+    model = CosmosLearnablePositionalEmbed(hidden_size=hidden_size, max_size=(128, 240, 240), patch_size=(1, 2, 2)).to(device).eval()
+    # The embeddings are zero-initialized; give them small non-zero values so the
+    # accuracy comparison below is not trivially satisfied.
+    with torch.no_grad():
+        for p in model.parameters():
+            p.normal_(mean=0.0, std=0.02)
+    hidden_states = torch.randn(1, 17, 16, 88, 160, dtype=torch.bfloat16, device=device)
+
+    with torch.no_grad():
+        pyt_out = model(hidden_states)
+
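+    # Export non-strictly and compile via the dynamo frontend, allowing bf16 kernels and
+    # using the Python TensorRT runtime.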
+    ep = torch.export.export(model, args=(hidden_states,), strict=False)
+    trt_mod = torch_tensorrt.dynamo.compile(ep, inputs=[hidden_states], enabled_precisions={torch.bfloat16}, use_python_runtime=True)
+    trt_out = trt_mod(hidden_states)
+
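+    # Parity check: identical shape and element-wise agreement within a loose tolerance.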
+    assert pyt_out.shape == trt_out.shape
+    maxdiff = (pyt_out.float() - trt_out.float()).abs().max().item()
+    assert maxdiff < 1e-2