Commit fb66d21

Closed the perf gap of resnet and enabled refit (#3629)
1 parent 99ffe1a commit fb66d21

4 files changed (+344 -110 lines)


py/torch_tensorrt/dynamo/_refit.py

Lines changed: 24 additions & 6 deletions
@@ -22,6 +22,9 @@
     DYNAMO_CONVERTERS as CONVERTERS,
 )
 from torch_tensorrt.dynamo.conversion._TRTInterpreter import TRTInterpreter
+from torch_tensorrt.dynamo.conversion.impl.normalization.ops import (
+    batch_norm_constant_folding,
+)
 from torch_tensorrt.dynamo.conversion.truncate_double import repair_double_inputs
 from torch_tensorrt.dynamo.lowering import (
     get_decompositions,
@@ -78,8 +81,9 @@ def construct_refit_mapping(
         compilation_settings=settings,
     )
     interpreter._construct_trt_network_def()
+    weight_refit_map: dict[str, torch.Tensor] = interpreter.ctx.weight_refit_map

-    return interpreter.ctx.weight_refit_map
+    return weight_refit_map


 @needs_refit
@@ -90,7 +94,20 @@ def construct_refit_mapping_from_weight_name_map(
 ) -> dict[Any, Any]:
     engine_weight_map = {}
     for engine_weight_name, (sd_weight_name, np_weight_type) in weight_name_map.items():
-        if sd_weight_name not in state_dict:
+        # Add more constant folding converters here
+        if engine_weight_name.split(" ")[-1] in ["SCALE", "SHIFT"]:
+            # Batch Norm Layer
+            params = {}
+            for w in sd_weight_name:
+                params[w.split(".")[-1]] = state_dict[w].cuda()
+            # Batch norm constant folding
+
+            scale, shift = batch_norm_constant_folding(**params, eps=1e-7)
+            # Set scale to scale or shift to shift
+            engine_weight_map[engine_weight_name] = eval(
+                engine_weight_name.split(" ")[-1].lower()
+            )
+        elif sd_weight_name not in state_dict:
             # If weights is not in sd, we can leave it unchanged
             continue
         else:
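Note: `batch_norm_constant_folding` (imported above) folds a batch norm's `weight`, `bias`, `running_mean`, and `running_var` into a single per-channel scale and shift, and the `eval(...)` call then selects whichever of the two locals the engine weight name's suffix asks for. A minimal sketch of the standard folding arithmetic, assuming the usual inference-time batch-norm formula; the actual helper lives in `torch_tensorrt.dynamo.conversion.impl.normalization.ops` and may differ in detail:

```python
import torch

# Sketch only (assumed, not the library implementation): an inference-time batch norm
#   y = weight * (x - running_mean) / sqrt(running_var + eps) + bias
# collapses into a per-channel affine  y = scale * x + shift.
# The keyword names match the state-dict suffixes gathered into `params` above.
def fold_batch_norm(
    weight: torch.Tensor,
    bias: torch.Tensor,
    running_mean: torch.Tensor,
    running_var: torch.Tensor,
    eps: float = 1e-7,
) -> tuple[torch.Tensor, torch.Tensor]:
    scale = weight / torch.sqrt(running_var + eps)
    shift = bias - running_mean * scale
    return scale, shift
```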
@@ -178,10 +195,12 @@ def _refit_single_trt_engine_with_gm(
         for layer_name in weight_list:
             if layer_name not in mapping:
                 raise AssertionError(f"{layer_name} is not found in weight mapping")
-            # Use Numpy to create weights
+            # Use Tensor to create weights
             weight = mapping[layer_name]
             trt_dtype = dtype._from(weight.dtype).to(trt.DataType)
-            trt_wt_tensor = trt.Weights(trt_dtype, weight.ctypes.data, weight.size)
+            trt_wt_tensor = trt.Weights(
+                trt_dtype, weight.data_ptr(), torch.numel(weight)
+            )
             refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
             refitted.add(layer_name)

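Note: the change above swaps the NumPy-backed refit path for torch tensors, so `trt.Weights` now wraps the tensor's raw pointer directly. A rough standalone sketch of that pattern, assuming a contiguous tensor and the stock `tensorrt` Python bindings; the tensor must outlive the refit call, since `trt.Weights` stores only a pointer and an element count:

```python
import tensorrt as trt
import torch

# Hypothetical, standalone illustration (not part of the commit):
# wrap a torch tensor's memory as TensorRT refit weights.
weight = torch.randn(64, 3, 7, 7, dtype=torch.float32).contiguous()

trt_wt_tensor = trt.Weights(
    trt.DataType.FLOAT,    # must match the tensor's dtype
    weight.data_ptr(),     # raw pointer owned by the torch tensor
    torch.numel(weight),   # element count, not byte count
)
# `weight` must stay alive (and unmoved) until the refitter has consumed it;
# the three-argument refitter.set_named_weights(name, weights, location) call
# in the hunk above tells TensorRT whether that pointer is host or device memory.
```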
@@ -300,7 +319,7 @@ def refit_module_weights(

     # Check the number of supported operations in the graph
     num_supported_ops, total_ops = partitioning.get_graph_converter_support(
-        new_gm, settings.debug, settings.torch_executed_ops
+        new_gm, settings.torch_executed_ops
     )

     if num_supported_ops == 0 or (
@@ -363,7 +382,6 @@ def refit_module_weights(

     # Iterate over all components that can be accelerated
     # Generate the corresponding TRT Module for those
-    new_weight_module.module().to(CPU_DEVICE)
     for name, new_submodule in new_partitioned_module.named_children():
         # Refit each submodule
         # Extract engine from the submodule

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 4 additions & 4 deletions
@@ -335,8 +335,8 @@ def to_trt_weights(
     ctx: ConversionContext,
     value: torch.Tensor,
     name: str,
-    layer_type_name: Literal["CONVOLUTION", "DECONVOLUTION", "CONSTANT"],
-    weight_type_name: Literal["KERNEL", "BIAS", "CONSTANT"],
+    layer_type_name: Literal["CONVOLUTION", "DECONVOLUTION", "CONSTANT", "SCALE"],
+    weight_type_name: Literal["KERNEL", "BIAS", "CONSTANT", "SCALE", "SHIFT", "POWER"],
     target: Optional[Union[Target, str]] = None,
     source_ir: Optional[SourceIR] = None,
     target_quantized_type: Optional[trt.DataType] = None,
@@ -362,8 +362,8 @@ def to_trt_weights(
     )

     # Weight Recording
-    supported_layer_types = ["CONVOLUTION", "DECONVOLUTION", "CONSTANT"]
-    supported_weight_types = ["KERNEL", "BIAS", "CONSTANT"]
+    supported_layer_types = ["CONVOLUTION", "DECONVOLUTION", "CONSTANT", "SCALE"]
+    supported_weight_types = ["KERNEL", "BIAS", "CONSTANT", "SCALE", "SHIFT", "POWER"]
     assert (
         layer_type_name in supported_layer_types
     ), f"Encountered unsupported layer type: {layer_type_name}. Supported types are: {supported_layer_types}. Manually calling to_trt_weights with a custom layer type is not intended for general use."
