
Commit 1ce1eeb

Merge pull request #574 from jaybdub/roll_converter
added converter for torch.roll
2 parents adbd5db + e70833e commit 1ce1eeb

File tree

3 files changed: +76, -0 lines changed

CHANGELOG.md (+1)

@@ -2,6 +2,7 @@
 
 ## [Master]
 
+- Added converter for ``torch.roll``
 - Added converter for ``torch.nn.functional.layer_norm``
 - Added converter for ``torch.nn.functional.gelu``
 - Added converter for ``torch.nn.functional.linear``

torch2trt/converters/__init__.py (+1)

@@ -51,6 +51,7 @@
 from .prod import *
 from .relu import *
 from .relu6 import *
+from .roll import *
 from .sigmoid import *
 from .silu import *
 from .softmax import *

torch2trt/converters/roll.py (new file, +74)

from torch2trt.torch2trt import *
from torch2trt.module_test import add_module_test


@tensorrt_converter('torch.roll')
@tensorrt_converter('torch.Tensor.roll')
def convert_roll(ctx):
    input = get_arg(ctx, 'input', 0, None)
    shifts = get_arg(ctx, 'shifts', 1, None)
    dims = get_arg(ctx, 'dims', 2, None)
    output = ctx.method_return

    assert dims is not None, "roll converter only supports roll when dims is specified"

    ndim = input.ndim

    input_trt = add_missing_trt_tensors(ctx.network, [input])[0]

    try:
        iter(shifts)
    except:
        shifts = (shifts,)
        dims = (dims,)

    start = [0] * ndim
    shape = tuple([int(d) for d in input.shape])
    stride = [1] * ndim

    for s, d in zip(shifts, dims):
        start[d] = (-s) % shape[d]

    start = tuple(start[1:])
    shape = tuple(shape[1:])
    stride = tuple(stride[1:])


    layer = ctx.network.add_slice(
        input_trt,
        start,  # [1:] to exclude batch
        shape,
        stride
    )
    layer.mode = trt.SliceMode.WRAP

    output._trt = layer.get_output(0)


class Roll(torch.nn.Module):

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.args = args
        self.kwargs = kwargs

    def forward(self, x):
        return torch.roll(x, *self.args, **self.kwargs)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 4)])
@add_module_test(torch.float32, torch.device('cuda'), [(1, 4, 5)])
@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 5)])
def test_roll_int():
    return Roll(1, 1)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 4, 5)])
@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 5)])
def test_roll_int_dim():
    return Roll(1, -2)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 5)])
def test_roll_tuple():
    return Roll((2, 3), (1, 3))
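
Usage note (not part of the diff): the converter maps a circular shift of s along dim d to a TensorRT slice layer in WRAP mode whose start offset is (-s) % size, so reading past the end wraps around to the front, which is exactly torch.roll. Below is a minimal sketch of how a model using torch.roll could be converted once this commit is installed, assuming the standard torch2trt workflow and a CUDA device; ShiftBlock is a hypothetical module used only for illustration.

import torch
from torch2trt import torch2trt

class ShiftBlock(torch.nn.Module):
    # Hypothetical example module: circularly shifts the input along the
    # channel and width dimensions. dims must be passed explicitly, since
    # the converter asserts that dims is not None.
    def forward(self, x):
        return torch.roll(x, shifts=(2, 3), dims=(1, 3))

model = ShiftBlock().eval().cuda()
x = torch.randn(1, 3, 4, 5).cuda()

# Build a TensorRT engine; torch.roll is handled by the newly registered
# converter (an ISliceLayer with WRAP mode).
model_trt = torch2trt(model, [x])

# The TensorRT output should match PyTorch to within floating-point tolerance.
print(torch.max(torch.abs(model(x) - model_trt(x))))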
