Commit acf4a50

Address the comments
1 parent 7f05316 commit acf4a50

File tree: 2 files changed (+113, -96 lines)

py/torch_tensorrt/dynamo/_refit.py

Lines changed: 7 additions & 2 deletions
```diff
@@ -22,6 +22,9 @@
     DYNAMO_CONVERTERS as CONVERTERS,
 )
 from torch_tensorrt.dynamo.conversion._TRTInterpreter import TRTInterpreter
+from torch_tensorrt.dynamo.conversion.impl.normalization.ops import (
+    batch_norm_constant_folding,
+)
 from torch_tensorrt.dynamo.conversion.truncate_double import repair_double_inputs
 from torch_tensorrt.dynamo.lowering import (
     get_decompositions,
@@ -91,13 +94,15 @@ def construct_refit_mapping_from_weight_name_map(
 ) -> dict[Any, Any]:
     engine_weight_map = {}
     for engine_weight_name, (sd_weight_name, np_weight_type) in weight_name_map.items():
+        # Add more constant folding converters here
         if engine_weight_name.split(" ")[-1] in ["SCALE", "SHIFT"]:
             # Batch Norm Layer
             params = {}
             for w in sd_weight_name:
                 params[w.split(".")[-1]] = state_dict[w].cuda()
-            scale = params["weight"] / torch.sqrt(params["running_var"] + 1e-7)
-            shift = params["bias"] - params["running_mean"] * scale
+            # Batch norm constant folding
+
+            scale, shift = batch_norm_constant_folding(**params, eps=1e-7)
             # Set scale to scale or shift to shift
             engine_weight_map[engine_weight_name] = eval(
                 engine_weight_name.split(" ")[-1].lower()
```
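The refit path previously recomputed the batch-norm fold inline; it now calls the shared `batch_norm_constant_folding` helper added in `normalization/ops.py` (second file below), so the refitter and the converter can no longer drift apart. The algebra being shared: eval-mode batch norm computes `y = weight * (x - running_mean) / sqrt(running_var + eps) + bias`, which collapses to the affine `y = scale * x + shift` with `scale = weight / sqrt(running_var + eps)` and `shift = bias - running_mean * scale`. A minimal, self-contained sketch of that identity, checked against `torch.nn.functional.batch_norm` (the `fold_bn` name and the random shapes are illustrative, not from the commit):

```python
import torch
import torch.nn.functional as F

def fold_bn(weight, bias, running_mean, running_var, eps):
    # Same algebra as batch_norm_constant_folding:
    # weight * (x - mean) / sqrt(var + eps) + bias  ==  scale * x + shift
    scale = weight / torch.sqrt(running_var + eps)
    shift = bias - running_mean * scale
    return scale, shift

C = 8
x = torch.randn(4, C, 16, 16)
weight, bias = torch.randn(C), torch.randn(C)
mean, var = torch.randn(C), torch.rand(C) + 0.5  # running variance is positive

scale, shift = fold_bn(weight, bias, mean, var, eps=1e-7)  # eps matches the refit code
folded = x * scale.view(1, C, 1, 1) + shift.view(1, C, 1, 1)
ref = F.batch_norm(x, mean, var, weight, bias, training=False, eps=1e-7)
assert torch.allclose(folded, ref, atol=1e-5)
```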

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py

Lines changed: 106 additions & 94 deletions
```diff
@@ -21,7 +21,6 @@
 from torch_tensorrt.dynamo.conversion.impl.cat import cat
 from torch_tensorrt.dynamo.conversion.impl.elementwise.ops import ge
 from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape
-from torch_tensorrt.dynamo.types import TRTTensor
 from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -32,17 +31,17 @@ def batch_norm(
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
-    input: TRTTensor,
-    weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
-    bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
-    running_mean: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
-    running_var: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
+    input: trt.ITensor,
+    weight: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]],
+    bias: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]],
+    running_mean: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]],
+    running_var: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]],
     training: bool,
     momentum: float,
     eps: float,
     cudnn_enabled: bool,
     return_mean_rstd: bool,
-) -> Union[TRTTensor, Tuple[TRTTensor, torch.Tensor, torch.Tensor]]:
+) -> Union[trt.ITensor, Tuple[trt.ITensor, torch.Tensor, torch.Tensor]]:
     if has_dynamic_shape(input.shape):
         assert input.shape[1] != -1, "Channel dim can't be dynamic for batch norm."
 
@@ -51,77 +50,14 @@ def batch_norm(
     # We perform constant folding for batch norm when the weight, bias, running_mean, and running_var are all tensors.
     # Batch norm operation can be fused into a single layer, which is more efficient than the original implementation.
     # In this way, the batch norm layer will be fused with the Convolution layer and get a performance boost.
-    if all(
+    if any(
         [
-            isinstance(weight, torch.Tensor),
-            isinstance(bias, torch.Tensor),
-            isinstance(running_mean, torch.Tensor),
-            isinstance(running_var, torch.Tensor),
+            isinstance(weight, trt.ITensor),
+            isinstance(bias, trt.ITensor),
+            isinstance(running_mean, trt.ITensor),
+            isinstance(running_var, trt.ITensor),
         ]
     ):
-        if weight is None:
-            weight = 1.0
-
-        if bias is None:
-            bias = 0.0
-
-        if running_mean is None:
-            running_mean = 0.0
-
-        if running_var is None:
-            running_var = 1.0
-        adjusted_scale = weight / torch.sqrt(running_var + eps)
-        adjusted_bias = bias - running_mean * adjusted_scale
-        power = torch.ones_like(adjusted_scale)
-        adjusted_scale = to_trt_weights(
-            ctx,
-            adjusted_scale,
-            name,
-            layer_type_name="SCALE",
-            weight_type_name="SCALE",
-            target=target,
-            source_ir=source_ir,
-        )
-        adjusted_bias = to_trt_weights(
-            ctx,
-            adjusted_bias,
-            name,
-            layer_type_name="SCALE",
-            weight_type_name="SHIFT",
-            target=target,
-            source_ir=source_ir,
-        )
-
-        power = to_trt_weights(
-            ctx,
-            power,
-            name,
-            layer_type_name="SCALE",
-            weight_type_name="POWER",
-            target=target,
-            source_ir=source_ir,
-        )
-
-        output_shape = input.shape
-        if len(input.shape) < 4:
-
-            new_shape = (
-                (input.shape[0], input.shape[1], 1, 1)
-                if len(input.shape) == 2
-                else (input.shape[0], input.shape[1], input.shape[2], 1)
-            )
-            input = impl.shuffle.reshape(
-                ctx, target, source_ir, f"{name}_reshape_2d", input, new_shape
-            )
-
-        layer = ctx.net.add_scale_nd(
-            input, trt.ScaleMode.CHANNEL, adjusted_bias, adjusted_scale, power, 1
-        )
-        set_layer_name(layer, target, name, source_ir)
-        output = layer.get_output(0)
-
-    else:
         # We name the weight here according to the state_dict name
         weight = (
             get_trt_tensor(ctx, 1.0, f"{name}_weight")
@@ -206,6 +142,70 @@ def batch_norm(
             bias_adjusted_reshape,
         )
 
+    else:
+        if weight is None:
+            weight = 1.0
+
+        if bias is None:
+            bias = 0.0
+
+        if running_mean is None:
+            running_mean = 0.0
+
+        if running_var is None:
+            running_var = 1.0
+        adjusted_scale, adjusted_bias = batch_norm_constant_folding(
+            weight, bias, running_mean, running_var, eps
+        )
+        power = torch.ones_like(adjusted_scale)
+
+        adjusted_scale = to_trt_weights(
+            ctx,
+            adjusted_scale,
+            name,
+            layer_type_name="SCALE",
+            weight_type_name="SCALE",
+            target=target,
+            source_ir=source_ir,
+        )
+        adjusted_bias = to_trt_weights(
+            ctx,
+            adjusted_bias,
+            name,
+            layer_type_name="SCALE",
+            weight_type_name="SHIFT",
+            target=target,
+            source_ir=source_ir,
+        )
+
+        power = to_trt_weights(
+            ctx,
+            power,
+            name,
+            layer_type_name="SCALE",
+            weight_type_name="POWER",
+            target=target,
+            source_ir=source_ir,
+        )
+
+        output_shape = input.shape
+        if len(input.shape) < 4:
+
+            new_shape = (
+                (input.shape[0], input.shape[1], 1, 1)
+                if len(input.shape) == 2
+                else (input.shape[0], input.shape[1], input.shape[2], 1)
+            )
+            input = impl.shuffle.reshape(
+                ctx, target, source_ir, f"{name}_reshape_2d", input, new_shape
+            )
+
+        layer = ctx.net.add_scale_nd(
+            input, trt.ScaleMode.CHANNEL, adjusted_bias, adjusted_scale, power, 1
+        )
+        set_layer_name(layer, target, name, source_ir)
+        output = layer.get_output(0)
+
     # For BatchNorm1d, reshape output back to original shape if necessary
     if len(output_shape) < 4:
         output = impl.shuffle.reshape(
@@ -224,17 +224,29 @@ def batch_norm(
     return output
 
 
+def batch_norm_constant_folding(
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    running_mean: torch.Tensor,
+    running_var: torch.Tensor,
+    eps: float,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    adjusted_scale = weight / torch.sqrt(running_var + eps)
+    adjusted_bias = bias - running_mean * adjusted_scale
+    return adjusted_scale, adjusted_bias
+
+
 def native_layer_norm(
     ctx: ConversionContext,
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
-    input: TRTTensor,
+    input: trt.ITensor,
     normalized_shape: List[int],
-    weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
-    bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
+    weight: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]],
+    bias: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]],
     eps: float,
-) -> Tuple[TRTTensor, torch.Tensor, torch.Tensor]:
+) -> Tuple[trt.ITensor, torch.Tensor, torch.Tensor]:
     dims = list(range(len(input.shape) - len(normalized_shape), len(input.shape)))
     axes = get_axes_for_reduce_op(dims)
 
@@ -274,15 +286,15 @@ def native_group_norm(
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
-    input: TRTTensor,
-    weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
-    bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
+    input: trt.ITensor,
+    weight: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]],
+    bias: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]],
     N: int,
     C: int,
     HxW: int,
     group: int,
     eps: float,
-) -> Tuple[TRTTensor, torch.Tensor, torch.Tensor]:
+) -> Tuple[trt.ITensor, torch.Tensor, torch.Tensor]:
     rank = len(input.shape)
 
     assert rank >= 3, f"Expected at least 3 dimensions for input tensor but got {rank}"
@@ -303,7 +315,7 @@ def native_group_norm(
         ctx, target, source_ir, f"{name}_expand_bias_zero", bias_zero, shape
     )
 
-    axes = get_axes_for_reduce_op([i for i in range(1 if group == 1 else 2, rank)])
+    axes = get_axes_for_reduce_op(list(range(1 if group == 1 else 2, rank)))
 
     # INormalizationLayer scales the normalized output per-group, but PyTorch scales the normalized output per-channel,
     # hence causing diverse result. Let TensorRT does no-op for scaling here, and do scaling ourselves later
@@ -348,10 +360,10 @@ def softmax(
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
-    input: TRTTensor,
+    input: trt.ITensor,
     dim: int,
     half_to_float: bool,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
+) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
     dim = get_positive_dim(dim, len(input.shape))
 
     if half_to_float:
@@ -368,9 +380,9 @@ def pdist(
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
-    input: TRTTensor,
+    input: trt.ITensor,
     p: float = 2,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
+) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
     shape = input.shape
     # Extend input from shape [N, D] to [N, 1, D]
     extend_input = impl.unsqueeze.unsqueeze(
@@ -464,8 +476,8 @@ def tri_upper_indices(
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
-    size_tensor: TRTTensor,
-) -> TRTTensor:
+    size_tensor: trt.ITensor,
+) -> trt.ITensor:
     """
     Return the indices for the upper-triangle part of a square size of matrix in a N-by-2 Tensor,
     where the diagonal offset = 1. One loop is used to calculate the indices like below.
@@ -484,7 +496,7 @@ def tri_upper_indices(
        target (Target): Target of calling node.
        source_ir (Optional[SourceIR]): SourceIR of calling converter.
        name (str): Name of the calling layer.
-       size_tensor (TRTTensor): number of rows in the 2-D square matrix. scalar tensor.
+       size_tensor (trt.ITensor): number of rows in the 2-D square matrix. scalar tensor.
 
     Example:
        if size_tensor is 4, it will return [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]
@@ -634,11 +646,11 @@ def cdist_forward(
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
-    x1: TRTTensor,
-    x2: TRTTensor,
+    x1: trt.ITensor,
+    x2: trt.ITensor,
     p: float,
     compute_mode: Optional[int],
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
+) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
     """
     Computes pairwise distances between sets of vectors in tensors x1 and x2 using the p-norm. The function treats the last dimension
     of x1 and x2 as feature dimensions, which must be identical for both inputs. The second-to-last dimensions can differ, reflecting
```
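Two details of the rewritten converter are easy to miss. The guard flipped from `all(isinstance(..., torch.Tensor))` to `any(isinstance(..., trt.ITensor))`: constant folding is now the default and is skipped only when some statistic is a live TensorRT tensor, so numpy arrays and `None` defaults also take the folded path. The folded branch then lowers to a single scale layer via `ctx.net.add_scale_nd(input, trt.ScaleMode.CHANNEL, adjusted_bias, adjusted_scale, power, 1)`; in CHANNEL mode TensorRT's IScaleLayer computes `output = (input * scale + shift) ** power` broadcast along the channel axis (axis 1), which with `power` all ones is exactly the folded affine. A rough pure-PyTorch emulation of that layer, assuming the CHANNEL-mode formula above (names and shapes are illustrative, not from the commit):

```python
import torch

def emulate_scale_channel(x, shift, scale, power, channel_axis=1):
    # IScaleLayer, CHANNEL mode: (input * scale + shift) ** power,
    # with scale/shift/power broadcast over the channel axis.
    shape = [1] * x.dim()
    shape[channel_axis] = -1
    return (x * scale.view(shape) + shift.view(shape)) ** power.view(shape)

# batch_norm reshapes rank-2/3 inputs to 4-D before add_scale_nd; emulate that here.
x = torch.randn(2, 3)                  # BatchNorm1d-style input
x4d = x.reshape(2, 3, 1, 1)
scale = torch.tensor([2.0, 0.5, 1.0])
shift = torch.tensor([0.1, -0.2, 0.0])
power = torch.ones(3)                  # no-op exponent, as in the converter
y = emulate_scale_channel(x4d, shift, scale, power).reshape(2, 3)
```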
