    DYNAMO_CONVERTERS as CONVERTERS,
)
from torch_tensorrt.dynamo.conversion._TRTInterpreter import TRTInterpreter
+from torch_tensorrt.dynamo.conversion.impl.normalization.ops import (
+    batch_norm_constant_folding,
+)
from torch_tensorrt.dynamo.conversion.truncate_double import repair_double_inputs
from torch_tensorrt.dynamo.lowering import (
    get_decompositions,
@@ -78,8 +81,9 @@ def construct_refit_mapping(
        compilation_settings=settings,
    )
    interpreter._construct_trt_network_def()
+    weight_refit_map: dict[str, torch.Tensor] = interpreter.ctx.weight_refit_map

-    return interpreter.ctx.weight_refit_map
+    return weight_refit_map


@needs_refit
@@ -90,7 +94,20 @@ def construct_refit_mapping_from_weight_name_map(
) -> dict[Any, Any]:
    engine_weight_map = {}
    for engine_weight_name, (sd_weight_name, np_weight_type) in weight_name_map.items():
-        if sd_weight_name not in state_dict:
+        # Add more constant folding converters here
+        if engine_weight_name.split(" ")[-1] in ["SCALE", "SHIFT"]:
+            # Batch Norm Layer
+            params = {}
+            for w in sd_weight_name:
+                params[w.split(".")[-1]] = state_dict[w].cuda()
+            # Batch norm constant folding
+
+            scale, shift = batch_norm_constant_folding(**params, eps=1e-7)
+            # Assign the folded scale or shift tensor based on the engine weight suffix
+            engine_weight_map[engine_weight_name] = eval(
+                engine_weight_name.split(" ")[-1].lower()
+            )
+        elif sd_weight_name not in state_dict:
            # If weights is not in sd, we can leave it unchanged
            continue
        else:
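For reference, the SCALE and SHIFT engine weights appear to correspond to TensorRT's scale-layer representation of batch norm, so the state-dict tensors have to be folded into a per-channel scale and shift before they can be refit. Below is a minimal sketch of that folding, assuming the state-dict keys follow the usual torch.nn.BatchNorm naming (weight, bias, running_mean, running_var) and that the imported batch_norm_constant_folding helper accepts the same keyword arguments; fold_batch_norm is a hypothetical stand-in for illustration, not part of this patch.

```python
import torch


def fold_batch_norm(
    weight: torch.Tensor,       # gamma
    bias: torch.Tensor,         # beta
    running_mean: torch.Tensor,
    running_var: torch.Tensor,
    eps: float = 1e-7,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fold batch-norm statistics into a per-channel scale and shift.

    y = gamma * (x - mean) / sqrt(var + eps)  + beta  ==  scale * x + shift
    """
    scale = weight / torch.sqrt(running_var + eps)
    shift = bias - running_mean * scale
    return scale, shift
```

The loop over sd_weight_name builds exactly such a keyword dict by taking the last component of each state-dict key, and the eval(...) then selects the local scale or shift variable depending on whether the engine weight name ends in SCALE or SHIFT.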
@@ -178,10 +195,12 @@ def _refit_single_trt_engine_with_gm(
        for layer_name in weight_list:
            if layer_name not in mapping:
                raise AssertionError(f"{layer_name} is not found in weight mapping")
-            # Use Numpy to create weights
+            # Use Tensor to create weights
            weight = mapping[layer_name]
            trt_dtype = dtype._from(weight.dtype).to(trt.DataType)
-            trt_wt_tensor = trt.Weights(trt_dtype, weight.ctypes.data, weight.size)
+            trt_wt_tensor = trt.Weights(
+                trt_dtype, weight.data_ptr(), torch.numel(weight)
+            )
            refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
            refitted.add(layer_name)

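Because the refit weights are now torch tensors rather than NumPy arrays, trt.Weights is built from the tensor's raw pointer and element count, and trt_wt_location has to match where that memory lives. A rough sketch of the pattern, assuming contiguous tensors; to_trt_weights is a hypothetical helper for illustration only.

```python
import tensorrt as trt
import torch
from torch_tensorrt import dtype


def to_trt_weights(weight: torch.Tensor) -> tuple[trt.Weights, trt.TensorLocation]:
    # trt.Weights only borrows the pointer, so the tensor must stay alive
    # (and contiguous) until refitter.refit_cuda_engine() has consumed it.
    assert weight.is_contiguous()
    trt_dtype = dtype._from(weight.dtype).to(trt.DataType)
    location = (
        trt.TensorLocation.DEVICE if weight.is_cuda else trt.TensorLocation.HOST
    )
    return trt.Weights(trt_dtype, weight.data_ptr(), weight.numel()), location
```

The location is then what gets passed as the third argument to refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location), as in the hunk above.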
@@ -300,7 +319,7 @@ def refit_module_weights(

    # Check the number of supported operations in the graph
    num_supported_ops, total_ops = partitioning.get_graph_converter_support(
-        new_gm, settings.debug, settings.torch_executed_ops
+        new_gm, settings.torch_executed_ops
    )

    if num_supported_ops == 0 or (
@@ -363,7 +382,6 @@ def refit_module_weights(

    # Iterate over all components that can be accelerated
    # Generate the corresponding TRT Module for those
-    new_weight_module.module().to(CPU_DEVICE)
    for name, new_submodule in new_partitioned_module.named_children():
        # Refit each submodule
        # Extract engine from the submodule