1
+ from torch2trt .torch2trt import *
2
+ from torch2trt .module_test import add_module_test
3
+
4
+
5
@tensorrt_converter('torch.roll')
@tensorrt_converter('torch.Tensor.roll')
def convert_roll(ctx):
    """Convert ``torch.roll`` / ``Tensor.roll`` to a TensorRT WRAP-mode slice.

    A roll by ``s`` along dimension ``d`` is equivalent to a slice of the
    full shape that starts at ``(-s) % size[d]`` with ``SliceMode.WRAP``:
    the slice reads past the end of the tensor and wraps to the beginning.

    Limitations: ``dims`` must be given — ``torch.roll(x, s)`` with
    ``dims=None`` flattens the tensor first, which a single slice layer
    cannot express.
    """
    input_tensor = get_arg(ctx, 'input', 0, None)  # renamed: don't shadow builtin `input`
    shifts = get_arg(ctx, 'shifts', 1, None)
    dims = get_arg(ctx, 'dims', 2, None)
    output = ctx.method_return

    assert dims is not None, "roll converter only supports roll when dims is specified"

    ndim = input_tensor.ndim

    input_trt = add_missing_trt_tensors(ctx.network, [input_tensor])[0]

    # Normalize the scalar form roll(x, s, d) to the tuple form.
    # Only a non-iterable `shifts` should fall through here.
    try:
        iter(shifts)
    except TypeError:
        shifts = (shifts,)
        dims = (dims,)

    start = [0] * ndim
    shape = tuple(int(d) for d in input_tensor.shape)
    stride = [1] * ndim

    for s, d in zip(shifts, dims):
        # Rolling forward by s == slicing that starts s elements from the
        # end (modulo handles negative and oversized shifts).  A negative
        # `d` relies on Python's negative list indexing into `start`.
        start[d] = (-s) % shape[d]

    # [1:] excludes the batch dimension, which this network builds implicitly.
    start = tuple(start[1:])
    shape = tuple(shape[1:])
    stride = tuple(stride[1:])

    layer = ctx.network.add_slice(
        input_trt,
        start,
        shape,
        stride
    )
    layer.mode = trt.SliceMode.WRAP

    output._trt = layer.get_output(0)
+
48
class Roll(torch.nn.Module):
    """Test-fixture module that applies ``torch.roll`` with fixed arguments.

    Whatever positional and keyword arguments are passed to the constructor
    are stored and replayed verbatim on every forward call.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        # Remember the roll arguments so forward() can forward them as-is.
        self.args, self.kwargs = args, kwargs

    def forward(self, x):
        # Delegate straight to torch.roll with the stored arguments.
        return torch.roll(x, *self.args, **self.kwargs)
58
+
59
@add_module_test(torch.float32, torch.device('cuda'), [(1, 4)])
@add_module_test(torch.float32, torch.device('cuda'), [(1, 4, 5)])
@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 5)])
def test_roll_int():
    """Scalar shift of 1 along dim 1 (first non-batch dimension)."""
    return Roll(shifts=1, dims=1)
+
66
@add_module_test(torch.float32, torch.device('cuda'), [(1, 4, 5)])
@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 5)])
def test_roll_int_dim():
    """Scalar shift of 1 along a negative dimension index."""
    return Roll(shifts=1, dims=-2)
71
+
72
@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 5)])
def test_roll_tuple():
    """Tuple shifts (2, 3) along dims (1, 3)."""
    return Roll(shifts=(2, 3), dims=(1, 3))
0 commit comments