Skip to content

Commit 2332d2d

Browse files
Hsieh, KevinGitHub Enterprise
authored andcommitted
Add support for grid_sample in model_preparer
Signed-off-by: Kevin Hsieh <quic_klhsieh@quicinc.com>
1 parent 96d61db commit 2332d2d

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

TrainingExtensions/torch/src/python/aimet_torch/model_preparer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,7 @@
217217
'masked_fill' : aimet_modules.MaskedFill,
218218
'square' : aimet_modules.Square,
219219
'rsqrt' : aimet_modules.RSqrt,
220+
'grid_sample' : aimet_modules.GridSample,
220221
}
221222

222223

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import pytest
2+
import torch
3+
from aimet_torch.model_preparer import prepare_model
4+
5+
@pytest.mark.parametrize('mode', ['bilinear', 'nearest', 'bicubic'])
6+
@pytest.mark.parametrize('padding_mode', ['zeros', 'border', 'reflection'])
7+
@pytest.mark.parametrize('align_corners', [True, False])
8+
def test_grid_sample(mode, padding_mode, align_corners):
9+
torch.manual_seed(0)
10+
class Model(torch.nn.Module):
11+
def __init__(self):
12+
super(Model, self).__init__()
13+
14+
def forward(self, input, grid):
15+
return torch.nn.functional.grid_sample(input,
16+
grid,
17+
mode=mode,
18+
padding_mode=padding_mode,
19+
align_corners=align_corners)
20+
21+
model = Model()
22+
dummy_input = (torch.randn(1, 3, 8, 8), torch.randn(1, 5, 5, 2))
23+
24+
original_out = model(*dummy_input)
25+
print(original_out)
26+
prepared_model = prepare_model(model)
27+
prepared_out = prepared_model(*dummy_input)
28+
29+
assert torch.equal(original_out, prepared_out)
30+
assert len([module for module in prepared_model.modules()]) == 2

0 commit comments

Comments
 (0)