 from torch_tensorrt.dynamo.conversion.impl.cat import cat
 from torch_tensorrt.dynamo.conversion.impl.elementwise.ops import ge
 from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape
-from torch_tensorrt.dynamo.types import TRTTensor
 from torch_tensorrt.dynamo.utils import DYNAMIC_DIM

 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -32,17 +31,17 @@ def batch_norm(
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
-    input: TRTTensor,
-    weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
-    bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
-    running_mean: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
-    running_var: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
+    input: trt.ITensor,
+    weight: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]],
+    bias: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]],
+    running_mean: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]],
+    running_var: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]],
     training: bool,
     momentum: float,
     eps: float,
     cudnn_enabled: bool,
     return_mean_rstd: bool,
-) -> Union[TRTTensor, Tuple[TRTTensor, torch.Tensor, torch.Tensor]]:
+) -> Union[trt.ITensor, Tuple[trt.ITensor, torch.Tensor, torch.Tensor]]:
     if has_dynamic_shape(input.shape):
         assert input.shape[1] != -1, "Channel dim can't be dynamic for batch norm."

@@ -51,77 +50,14 @@ def batch_norm(
     # We perform constant folding for batch norm when the weight, bias, running_mean, and running_var are all tensors.
     # Batch norm operation can be fused into a single layer, which is more efficient than the original implementation.
     # In this way, the batch norm layer will be fused with the Convolution layer and get a performance boost.
-    if all(
+    if any(
         [
-            isinstance(weight, torch.Tensor),
-            isinstance(bias, torch.Tensor),
-            isinstance(running_mean, torch.Tensor),
-            isinstance(running_var, torch.Tensor),
+            isinstance(weight, trt.ITensor),
+            isinstance(bias, trt.ITensor),
+            isinstance(running_mean, trt.ITensor),
+            isinstance(running_var, trt.ITensor),
         ]
     ):
-        if weight is None:
-            weight = 1.0
-
-        if bias is None:
-            bias = 0.0
-
-        if running_mean is None:
-            running_mean = 0.0
-
-        if running_var is None:
-            running_var = 1.0
-        adjusted_scale = weight / torch.sqrt(running_var + eps)
-        adjusted_bias = bias - running_mean * adjusted_scale
-        power = torch.ones_like(adjusted_scale)
-        adjusted_scale = to_trt_weights(
-            ctx,
-            adjusted_scale,
-            name,
-            layer_type_name="SCALE",
-            weight_type_name="SCALE",
-            target=target,
-            source_ir=source_ir,
-        )
-        adjusted_bias = to_trt_weights(
-            ctx,
-            adjusted_bias,
-            name,
-            layer_type_name="SCALE",
-            weight_type_name="SHIFT",
-            target=target,
-            source_ir=source_ir,
-        )
-
-        power = to_trt_weights(
-            ctx,
-            power,
-            name,
-            layer_type_name="SCALE",
-            weight_type_name="POWER",
-            target=target,
-            source_ir=source_ir,
-        )
-
-        output_shape = input.shape
-        if len(input.shape) < 4:
-
-            new_shape = (
-                (input.shape[0], input.shape[1], 1, 1)
-                if len(input.shape) == 2
-                else (input.shape[0], input.shape[1], input.shape[2], 1)
-            )
-            input = impl.shuffle.reshape(
-                ctx, target, source_ir, f"{name}_reshape_2d", input, new_shape
-            )
-
-        layer = ctx.net.add_scale_nd(
-            input, trt.ScaleMode.CHANNEL, adjusted_bias, adjusted_scale, power, 1
-        )
-        set_layer_name(layer, target, name, source_ir)
-        output = layer.get_output(0)
-
-    else:
-
         # We name the weight here according to the state_dict name
         weight = (
             get_trt_tensor(ctx, 1.0, f"{name}_weight")
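The constant folding the comments above describe collapses eval-mode batch norm into a single per-channel scale and shift; this hunk only inverts the guard (fold unless some parameter is already a live trt.ITensor) and the arithmetic moves into a new batch_norm_constant_folding helper later in the diff. A minimal standalone sketch of that identity, checked against torch.nn.functional.batch_norm; the fold_bn helper below is illustrative only and not part of the patch:

import torch
import torch.nn.functional as F

def fold_bn(weight, bias, running_mean, running_var, eps):
    # y = (x - mean) / sqrt(var + eps) * weight + bias  ==  x * scale + shift
    adjusted_scale = weight / torch.sqrt(running_var + eps)
    adjusted_bias = bias - running_mean * adjusted_scale
    return adjusted_scale, adjusted_bias

C = 8
x = torch.randn(4, C, 16, 16)
weight, bias = torch.randn(C), torch.randn(C)
mean, var = torch.randn(C), torch.rand(C) + 0.5

scale, shift = fold_bn(weight, bias, mean, var, eps=1e-5)
folded = x * scale.view(1, C, 1, 1) + shift.view(1, C, 1, 1)
reference = F.batch_norm(x, mean, var, weight, bias, training=False, eps=1e-5)
assert torch.allclose(folded, reference, atol=1e-5)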
@@ -206,6 +142,70 @@ def batch_norm(
             bias_adjusted_reshape,
         )

+    else:
+        if weight is None:
+            weight = 1.0
+
+        if bias is None:
+            bias = 0.0
+
+        if running_mean is None:
+            running_mean = 0.0
+
+        if running_var is None:
+            running_var = 1.0
+        adjusted_scale, adjusted_bias = batch_norm_constant_folding(
+            weight, bias, running_mean, running_var, eps
+        )
+        power = torch.ones_like(adjusted_scale)
+
+        adjusted_scale = to_trt_weights(
+            ctx,
+            adjusted_scale,
+            name,
+            layer_type_name="SCALE",
+            weight_type_name="SCALE",
+            target=target,
+            source_ir=source_ir,
+        )
+        adjusted_bias = to_trt_weights(
+            ctx,
+            adjusted_bias,
+            name,
+            layer_type_name="SCALE",
+            weight_type_name="SHIFT",
+            target=target,
+            source_ir=source_ir,
+        )
+
+        power = to_trt_weights(
+            ctx,
+            power,
+            name,
+            layer_type_name="SCALE",
+            weight_type_name="POWER",
+            target=target,
+            source_ir=source_ir,
+        )
+
+        output_shape = input.shape
+        if len(input.shape) < 4:
+
+            new_shape = (
+                (input.shape[0], input.shape[1], 1, 1)
+                if len(input.shape) == 2
+                else (input.shape[0], input.shape[1], input.shape[2], 1)
+            )
+            input = impl.shuffle.reshape(
+                ctx, target, source_ir, f"{name}_reshape_2d", input, new_shape
+            )
+
+        layer = ctx.net.add_scale_nd(
+            input, trt.ScaleMode.CHANNEL, adjusted_bias, adjusted_scale, power, 1
+        )
+        set_layer_name(layer, target, name, source_ir)
+        output = layer.get_output(0)
+
     # For BatchNorm1d, reshape output back to original shape if necessary
     if len(output_shape) < 4:
         output = impl.shuffle.reshape(
@@ -224,17 +224,29 @@ def batch_norm(
     return output


+def batch_norm_constant_folding(
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    running_mean: torch.Tensor,
+    running_var: torch.Tensor,
+    eps: float,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    adjusted_scale = weight / torch.sqrt(running_var + eps)
+    adjusted_bias = bias - running_mean * adjusted_scale
+    return adjusted_scale, adjusted_bias
+
+
 def native_layer_norm(
     ctx: ConversionContext,
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
-    input: TRTTensor,
+    input: trt.ITensor,
     normalized_shape: List[int],
-    weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
-    bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
+    weight: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]],
+    bias: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]],
     eps: float,
-) -> Tuple[TRTTensor, torch.Tensor, torch.Tensor]:
+) -> Tuple[trt.ITensor, torch.Tensor, torch.Tensor]:
     dims = list(range(len(input.shape) - len(normalized_shape), len(input.shape)))
     axes = get_axes_for_reduce_op(dims)

@@ -274,15 +286,15 @@ def native_group_norm(
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
-    input: TRTTensor,
-    weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
-    bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
+    input: trt.ITensor,
+    weight: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]],
+    bias: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]],
     N: int,
     C: int,
     HxW: int,
     group: int,
     eps: float,
-) -> Tuple[TRTTensor, torch.Tensor, torch.Tensor]:
+) -> Tuple[trt.ITensor, torch.Tensor, torch.Tensor]:
     rank = len(input.shape)

     assert rank >= 3, f"Expected at least 3 dimensions for input tensor but got {rank}"
@@ -303,7 +315,7 @@ def native_group_norm(
         ctx, target, source_ir, f"{name}_expand_bias_zero", bias_zero, shape
     )

-    axes = get_axes_for_reduce_op([i for i in range(1 if group == 1 else 2, rank)])
+    axes = get_axes_for_reduce_op(list(range(1 if group == 1 else 2, rank)))

     # INormalizationLayer scales the normalized output per-group, but PyTorch scales the normalized output per-channel,
     # hence causing diverse result. Let TensorRT does no-op for scaling here, and do scaling ourselves later
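The comment above explains why the converter disables INormalizationLayer's own affine step: TensorRT would scale per group, while PyTorch's group norm normalizes per group but applies weight and bias per channel. A small PyTorch-only sketch of that per-channel affine, checked against torch.nn.functional.group_norm; shapes and names here are illustrative and not taken from the patch:

import torch
import torch.nn.functional as F

N, C, H, W, G = 2, 6, 4, 4, 3
x = torch.randn(N, C, H, W)
weight, bias = torch.randn(C), torch.randn(C)
eps = 1e-5

# Normalize over each group's channels and spatial positions...
xg = x.view(N, G, (C // G) * H * W)
normed = (xg - xg.mean(-1, keepdim=True)) / torch.sqrt(
    xg.var(-1, unbiased=False, keepdim=True) + eps
)
# ...then scale and shift per channel, not per group.
out = normed.view(N, C, H, W) * weight.view(1, C, 1, 1) + bias.view(1, C, 1, 1)

assert torch.allclose(out, F.group_norm(x, G, weight, bias, eps), atol=1e-5)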
@@ -348,10 +360,10 @@ def softmax(
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
-    input: TRTTensor,
+    input: trt.ITensor,
     dim: int,
     half_to_float: bool,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
+) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
     dim = get_positive_dim(dim, len(input.shape))

     if half_to_float:
@@ -368,9 +380,9 @@ def pdist(
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
-    input: TRTTensor,
+    input: trt.ITensor,
     p: float = 2,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
+) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
     shape = input.shape
     # Extend input from shape [N, D] to [N, 1, D]
     extend_input = impl.unsqueeze.unsqueeze(
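The converter builds pdist by broadcasting: unsqueezing the [N, D] input to [N, 1, D] and subtracting the original [N, D] tensor yields every pairwise difference, whose p-norm over the feature axis gives the distance matrix. A short eager-mode sketch of the same idea, checked against torch.nn.functional.pdist; the variable names are illustrative only:

import torch

N, D, p = 5, 3, 2.0
x = torch.randn(N, D)

diff = x.unsqueeze(1) - x.unsqueeze(0)         # [N, N, D] pairwise differences
dist = diff.abs().pow(p).sum(-1).pow(1.0 / p)  # [N, N] p-norm over features

# pdist returns only the strict upper triangle, row by row.
rows, cols = torch.triu_indices(N, N, offset=1)
assert torch.allclose(dist[rows, cols], torch.nn.functional.pdist(x, p=p), atol=1e-5)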
@@ -464,8 +476,8 @@ def tri_upper_indices(
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
-    size_tensor: TRTTensor,
-) -> TRTTensor:
+    size_tensor: trt.ITensor,
+) -> trt.ITensor:
     """
     Return the indices for the upper-triangle part of a square size of matrix in a N-by-2 Tensor,
     where the diagonal offset = 1. One loop is used to calculate the indices like below.
@@ -484,7 +496,7 @@ def tri_upper_indices(
         target (Target): Target of calling node.
         source_ir (Optional[SourceIR]): SourceIR of calling converter.
         name (str): Name of the calling layer.
-        size_tensor (TRTTensor): number of rows in the 2-D square matrix. scalar tensor.
+        size_tensor (trt.ITensor): number of rows in the 2-D square matrix. scalar tensor.

     Example:
         if size_tensor is 4, it will return [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]
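Because size_tensor is only known at runtime, the converter builds these indices with a TensorRT loop, but the target output is simply the offset-1 upper-triangle index pairs. The docstring's example can be reproduced eagerly with torch.triu_indices, which is a convenient reference when reading the loop (sketch for illustration, not part of the converter):

import torch

size = 4
# 2 x M tensor of (row, col) pairs strictly above the diagonal, in row-major order.
pairs = torch.triu_indices(size, size, offset=1).t()
assert pairs.tolist() == [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]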
@@ -634,11 +646,11 @@ def cdist_forward(
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
-    x1: TRTTensor,
-    x2: TRTTensor,
+    x1: trt.ITensor,
+    x2: trt.ITensor,
     p: float,
     compute_mode: Optional[int],
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
+) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
     """
     Computes pairwise distances between sets of vectors in tensors x1 and x2 using the p-norm. The function treats the last dimension
     of x1 and x2 as feature dimensions, which must be identical for both inputs. The second-to-last dimensions can differ, reflecting
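The contract the docstring describes matches torch.cdist: with x1 of shape [..., M, D] and x2 of shape [..., N, D], the result has shape [..., M, N] and entry (i, j) is the p-norm of x1[i] - x2[j]. A short eager-mode sketch of that semantics, checked against torch.cdist; names are illustrative, not converter code:

import torch

M, N, D, p = 4, 5, 3, 2.0
x1 = torch.randn(M, D)
x2 = torch.randn(N, D)

# Broadcast to [M, N, D], then reduce the shared feature dimension with the p-norm.
manual = (x1.unsqueeze(1) - x2.unsqueeze(0)).abs().pow(p).sum(-1).pow(1.0 / p)
assert manual.shape == (M, N)
assert torch.allclose(manual, torch.cdist(x1, x2, p=p), atol=1e-5)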