
Commit ce31bbd

Closed the perf gap of resnet and enabled refit
1 parent b0d5787 commit ce31bbd

File tree

3 files changed: +165 -82 lines changed

py/torch_tensorrt/dynamo/_refit.py

Lines changed: 15 additions & 4 deletions
@@ -78,8 +78,9 @@ def construct_refit_mapping(
         compilation_settings=settings,
     )
     interpreter._construct_trt_network_def()
+    weight_refit_map: dict[str, torch.Tensor] = interpreter.ctx.weight_refit_map
 
-    return interpreter.ctx.weight_refit_map
+    return weight_refit_map
 
 
 @needs_refit
@@ -90,7 +91,18 @@ def construct_refit_mapping_from_weight_name_map(
 ) -> dict[Any, Any]:
     engine_weight_map = {}
     for engine_weight_name, (sd_weight_name, np_weight_type) in weight_name_map.items():
-        if sd_weight_name not in state_dict:
+        if engine_weight_name.split(" ")[-1] in ["SCALE", "SHIFT"]:
+            # Batch norm layer: recompute the fused scale/shift from the state_dict
+            params = {}
+            for w in sd_weight_name:
+                params[w.split(".")[-1]] = state_dict[w].cuda()
+            scale = params["weight"] / torch.sqrt(params["running_var"] + 1e-7)
+            shift = params["bias"] - params["running_mean"] * scale
+            # Pick `scale` for SCALE entries and `shift` for SHIFT entries
+            engine_weight_map[engine_weight_name] = eval(
+                engine_weight_name.split(" ")[-1].lower()
+            )
+        elif sd_weight_name not in state_dict:
             # If weights is not in sd, we can leave it unchanged
             continue
         else:
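
The SCALE/SHIFT branch above rebuilds the fused batch norm weights from the state_dict instead of copying a single tensor, because after this commit the engine stores the folded scale/shift rather than the raw weight/bias/running_mean/running_var. A minimal standalone sketch of that folding step (the fold_batch_norm_params helper and its prefix argument are hypothetical illustrations; the commit itself walks the names carried in sd_weight_name):

import torch

def fold_batch_norm_params(
    state_dict: dict[str, torch.Tensor],
    prefix: str,  # hypothetical module path, e.g. "layer1.0.bn1"
    eps: float = 1e-7,  # the constant hardcoded in the hunk above
) -> tuple[torch.Tensor, torch.Tensor]:
    # Fold BN statistics into a per-channel affine transform:
    # y = scale * x + shift, with scale = weight / sqrt(var + eps)
    weight = state_dict[f"{prefix}.weight"]
    bias = state_dict[f"{prefix}.bias"]
    running_mean = state_dict[f"{prefix}.running_mean"]
    running_var = state_dict[f"{prefix}.running_var"]
    scale = weight / torch.sqrt(running_var + eps)
    shift = bias - running_mean * scale
    return scale, shift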
@@ -300,7 +312,7 @@ def refit_module_weights(
300312

301313
# Check the number of supported operations in the graph
302314
num_supported_ops, total_ops = partitioning.get_graph_converter_support(
303-
new_gm, settings.debug, settings.torch_executed_ops
315+
new_gm, settings.torch_executed_ops
304316
)
305317

306318
if num_supported_ops == 0 or (
@@ -363,7 +375,6 @@ def refit_module_weights(
 
     # Iterate over all components that can be accelerated
     # Generate the corresponding TRT Module for those
-    new_weight_module.module().to(CPU_DEVICE)
     for name, new_submodule in new_partitioned_module.named_children():
         # Refit each submodule
         # Extract engine from the submodule
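
For context, refit is typically driven through the entry point modified above. A sketch under stated assumptions: MyResNet is a placeholder model, and the refit-enabling compile flag has been spelled make_refittable in some torch_tensorrt releases and immutable_weights=False in newer ones, so treat the exact keyword as an assumption:

import torch
import torch_tensorrt as torch_trt

model = MyResNet().eval().cuda()  # hypothetical ResNet-style model
inputs = (torch.randn(1, 3, 224, 224).cuda(),)
exp_program = torch.export.export(model, inputs)

# Build an engine whose weights may be swapped later.
trt_gm = torch_trt.dynamo.compile(exp_program, inputs, make_refittable=True)

# Later: refit with retrained weights without rebuilding the engine.
model2 = MyResNet().eval().cuda()  # same architecture, new weights
exp_program2 = torch.export.export(model2, inputs)
trt_gm = torch_trt.dynamo.refit_module_weights(trt_gm, exp_program2)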

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 4 additions & 4 deletions
@@ -335,8 +335,8 @@ def to_trt_weights(
     ctx: ConversionContext,
     value: torch.Tensor,
     name: str,
-    layer_type_name: Literal["CONVOLUTION", "DECONVOLUTION", "CONSTANT"],
-    weight_type_name: Literal["KERNEL", "BIAS", "CONSTANT"],
+    layer_type_name: Literal["CONVOLUTION", "DECONVOLUTION", "CONSTANT", "SCALE"],
+    weight_type_name: Literal["KERNEL", "BIAS", "CONSTANT", "SCALE", "SHIFT", "POWER"],
     target: Optional[Union[Target, str]] = None,
     source_ir: Optional[SourceIR] = None,
     target_quantized_type: Optional[trt.DataType] = None,
@@ -362,8 +362,8 @@ def to_trt_weights(
     )
 
     # Weight Recording
-    supported_layer_types = ["CONVOLUTION", "DECONVOLUTION", "CONSTANT"]
-    supported_weight_types = ["KERNEL", "BIAS", "CONSTANT"]
+    supported_layer_types = ["CONVOLUTION", "DECONVOLUTION", "CONSTANT", "SCALE"]
+    supported_weight_types = ["KERNEL", "BIAS", "CONSTANT", "SCALE", "SHIFT", "POWER"]
     assert (
         layer_type_name in supported_layer_types
     ), f"Encountered unsupported layer type: {layer_type_name}. Supported types are: {supported_layer_types}. Manually calling to_trt_weights with a custom layer type is not intended for general use."

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py

Lines changed: 146 additions & 74 deletions
@@ -16,6 +16,7 @@
     get_trt_tensor,
     has_dynamic_shape,
     set_layer_name,
+    to_trt_weights,
 )
 from torch_tensorrt.dynamo.conversion.impl.cat import cat
 from torch_tensorrt.dynamo.conversion.impl.elementwise.ops import ge
@@ -48,89 +49,160 @@ def batch_norm(
     # Save the original output shape for later use
     output_shape = input.shape
 
-    # We name the weight here according to the state_dict name
-    weight = (
-        get_trt_tensor(ctx, 1.0, f"{name}_weight")
-        if weight is None
-        else get_trt_tensor(ctx, weight, f"{name}_weight")
-    )
-    bias = (
-        get_trt_tensor(ctx, 0.0, f"{name}_bias")
-        if bias is None
-        else get_trt_tensor(ctx, bias, f"{name}_bias")
-    )
-    running_mean = (
-        get_trt_tensor(ctx, 0.0, f"{name}_running_mean")
-        if running_mean is None
-        else get_trt_tensor(ctx, running_mean, f"{name}_running_mean")
-    )
-    running_var = (
-        get_trt_tensor(ctx, 1.0, f"{name}_running_var")
-        if running_var is None
-        else get_trt_tensor(ctx, running_var, f"{name}_running_var")
-    )
+    if all(
+        [
+            isinstance(weight, torch.Tensor),
+            isinstance(bias, torch.Tensor),
+            isinstance(running_mean, torch.Tensor),
+            isinstance(running_var, torch.Tensor),
+        ]
+    ):
+        if weight is None:
+            weight = 1.0
+
+        if bias is None:
+            bias = 0.0
+
+        if running_mean is None:
+            running_mean = 0.0
+
+        if running_var is None:
+            running_var = 1.0
+        adjusted_scale = weight / torch.sqrt(running_var + eps)
+        adjusted_bias = bias - running_mean * adjusted_scale
+        power = torch.ones_like(adjusted_scale)
+        adjusted_scale = to_trt_weights(
+            ctx,
+            adjusted_scale,
+            name,
+            layer_type_name="SCALE",
+            weight_type_name="SCALE",
+            target=target,
+            source_ir=source_ir,
+        )
+        adjusted_bias = to_trt_weights(
+            ctx,
+            adjusted_bias,
+            name,
+            layer_type_name="SCALE",
+            weight_type_name="SHIFT",
+            target=target,
+            source_ir=source_ir,
+        )
 
-    # eps_tensor for numerical stability
-    eps_tensor = get_trt_tensor(ctx, eps, f"{name}_eps")
+        power = to_trt_weights(
+            ctx,
+            power,
+            name,
+            layer_type_name="SCALE",
+            weight_type_name="POWER",
+            target=target,
+            source_ir=source_ir,
+        )
 
-    # adjusted_var = running_var + eps
-    adjusted_var = impl.elementwise.add(
-        ctx, target, source_ir, f"{name}_adjusted_var", running_var, eps_tensor
-    )
+        output_shape = input.shape
+        if len(input.shape) < 4:
 
-    # sqrt_adjusted_var = sqrt(adjusted_var)
-    sqrt_adjusted_var = impl.unary.sqrt(
-        ctx, target, source_ir, f"{name}_sqrt", adjusted_var
-    )
+            new_shape = (
+                (input.shape[0], input.shape[1], 1, 1)
+                if len(input.shape) == 2
+                else (input.shape[0], input.shape[1], input.shape[2], 1)
+            )
+            input = impl.shuffle.reshape(
+                ctx, target, source_ir, f"{name}_reshape_2d", input, new_shape
+            )
 
-    # scale = weight / sqrt_adjusted_var
-    scale = impl.elementwise.div(
-        ctx, target, source_ir, f"{name}_scale", weight, sqrt_adjusted_var
-    )
+        layer = ctx.net.add_scale_nd(
+            input, trt.ScaleMode.CHANNEL, adjusted_bias, adjusted_scale, power, 1
+        )
+        set_layer_name(layer, target, name, source_ir)
+        output = layer.get_output(0)
 
-    # scaled_running_mean = running_mean * scale
-    scaled_running_mean = impl.elementwise.mul(
-        ctx, target, source_ir, f"{name}_scaled_running_mean", running_mean, scale
-    )
+    else:
 
-    # bias_adjusted = bias - scaled_running_mean
-    bias_adjusted = impl.elementwise.sub(
-        ctx, target, source_ir, f"{name}_bias_adjusted", bias, scaled_running_mean
-    )
+        # We name the weight here according to the state_dict name
+        weight = (
+            get_trt_tensor(ctx, 1.0, f"{name}_weight")
+            if weight is None
+            else get_trt_tensor(ctx, weight, f"{name}_weight")
+        )
+        bias = (
+            get_trt_tensor(ctx, 0.0, f"{name}_bias")
+            if bias is None
+            else get_trt_tensor(ctx, bias, f"{name}_bias")
+        )
+        running_mean = (
+            get_trt_tensor(ctx, 0.0, f"{name}_running_mean")
+            if running_mean is None
+            else get_trt_tensor(ctx, running_mean, f"{name}_running_mean")
+        )
+        running_var = (
+            get_trt_tensor(ctx, 1.0, f"{name}_running_var")
+            if running_var is None
+            else get_trt_tensor(ctx, running_var, f"{name}_running_var")
+        )
 
-    # Reshape scale and bias_adjusted to match input shape for broadcasting
-    expanded_shape = [1] * len(output_shape)
-    expanded_shape[1] = output_shape[1] # Set channel dimension
+        # eps_tensor for numerical stability
+        eps_tensor = get_trt_tensor(ctx, eps, f"{name}_eps")
 
-    scale_reshape = impl.shuffle.reshape(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_reshape_scale",
-        scale,
-        tuple(expanded_shape),
-    )
-    bias_adjusted_reshape = impl.shuffle.reshape(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_reshape_bias",
-        bias_adjusted,
-        tuple(expanded_shape),
-    )
+        # adjusted_var = running_var + eps
+        adjusted_var = impl.elementwise.add(
+            ctx, target, source_ir, f"{name}_adjusted_var", running_var, eps_tensor
+        )
 
-    # Apply the scale and bias to the input
-    scaled_input = impl.elementwise.mul(
-        ctx, target, source_ir, f"{name}_scaled_input", input, scale_reshape
-    )
-    output = impl.elementwise.add(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_output",
-        scaled_input,
-        bias_adjusted_reshape,
-    )
+        # sqrt_adjusted_var = sqrt(adjusted_var)
+        sqrt_adjusted_var = impl.unary.sqrt(
+            ctx, target, source_ir, f"{name}_sqrt", adjusted_var
+        )
+
+        # scale = weight / sqrt_adjusted_var
+        scale = impl.elementwise.div(
+            ctx, target, source_ir, f"{name}_scale", weight, sqrt_adjusted_var
+        )
+
+        # scaled_running_mean = running_mean * scale
+        scaled_running_mean = impl.elementwise.mul(
+            ctx, target, source_ir, f"{name}_scaled_running_mean", running_mean, scale
+        )
+
+        # bias_adjusted = bias - scaled_running_mean
+        bias_adjusted = impl.elementwise.sub(
+            ctx, target, source_ir, f"{name}_bias_adjusted", bias, scaled_running_mean
+        )
+
+        # Reshape scale and bias_adjusted to match input shape for broadcasting
+        expanded_shape = [1] * len(output_shape)
+        expanded_shape[1] = output_shape[1] # Set channel dimension
+
+        scale_reshape = impl.shuffle.reshape(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_reshape_scale",
+            scale,
+            tuple(expanded_shape),
+        )
+        bias_adjusted_reshape = impl.shuffle.reshape(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_reshape_bias",
+            bias_adjusted,
+            tuple(expanded_shape),
+        )
+
+        # Apply the scale and bias to the input
+        scaled_input = impl.elementwise.mul(
+            ctx, target, source_ir, f"{name}_scaled_input", input, scale_reshape
+        )
+        output = impl.elementwise.add(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_output",
+            scaled_input,
+            bias_adjusted_reshape,
+        )
 
     # For BatchNorm1d, reshape output back to original shape if necessary
     if len(output_shape) < 4:
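
As a sanity check on the fast path above, this standalone sketch (plain PyTorch, no TensorRT) confirms that the folded per-channel scale * x + shift reproduces BatchNorm2d in eval mode, which is what add_scale_nd with trt.ScaleMode.CHANNEL and an all-ones power computes:

import torch

torch.manual_seed(0)
bn = torch.nn.BatchNorm2d(8).eval()
bn.running_mean.uniform_(-1.0, 1.0)
bn.running_var.uniform_(0.5, 2.0)
bn.weight.data.uniform_(0.5, 1.5)
bn.bias.data.uniform_(-1.0, 1.0)

x = torch.randn(2, 8, 16, 16)

# Fold the BN statistics exactly as the converter does.
scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
shift = bn.bias - bn.running_mean * scale

# Per-channel affine transform, broadcast over N, H, W.
fused = x * scale.view(1, -1, 1, 1) + shift.view(1, -1, 1, 1)

assert torch.allclose(bn(x), fused, atol=1e-5)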
