Skip to content

Commit e34fb80

Browse files
committed
Fixed the refitting bug, but a more systematic approach is needed
1 parent cf064c5 commit e34fb80

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

py/torch_tensorrt/dynamo/conversion/impl/conv.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def convNd(
7979
kernel_shape = weight.shape[2:]
8080
elif isinstance(weight, (torch.Tensor, np.ndarray)):
8181
weight = to_torch(weight, dtype=input.dtype)
82+
torch_weight = weight
8283
# Append new dimension (unsqueeze) if the convolution is 1d
8384
if is_conv1d:
8485
weight = torch.unsqueeze(weight, -1)
@@ -105,10 +106,15 @@ def convNd(
105106
kernel=trt.Weights() if isinstance(weight, TRTTensor) else weight,
106107
bias=trt.Weights() if isinstance(bias, TRTTensor) else bias,
107108
)
109+
110+
set_layer_name(conv_layer, target, name, source_ir)
111+
108112
# If the weight is a TRTTensor, set it as an input of the layer
109113
if isinstance(weight, TRTTensor):
110114
weight = cast_trt_tensor(ctx, weight, input.dtype, name)
111115
conv_layer.set_input(1, weight)
116+
elif weight is not None:
117+
ctx.weight_refit_map[f"{conv_layer.name} KERNEL"] = torch_weight
112118

113119
# If the bias is a TRTTensor, set it as an input of the layer
114120
if isinstance(bias, TRTTensor):
@@ -145,8 +151,6 @@ def convNd(
145151
extend_attr_to_tuple(dilation, 2) if dilation is not None else dilation
146152
)
147153

148-
set_layer_name(conv_layer, target, name, source_ir)
149-
150154
# Set relevant attributes of convolution layer
151155
if padding is not None:
152156
conv_layer.padding_nd = padding

py/torch_tensorrt/dynamo/conversion/impl/deconv.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def deconvNd(
8080

8181
elif isinstance(weight, (torch.Tensor, np.ndarray)):
8282
weight = to_torch(weight, dtype=input.dtype)
83+
torch_weight = weight
8384
# Append new dimension (unsqueeze) if the deconvolution is 1d
8485
if is_deconv1d:
8586
weight = torch.unsqueeze(weight, -1)
@@ -105,11 +106,13 @@ def deconvNd(
105106
kernel=trt.Weights() if isinstance(weight, TRTTensor) else weight,
106107
bias=trt.Weights() if isinstance(bias, TRTTensor) else bias,
107108
)
109+
set_layer_name(deconv_layer, target, name, source_ir)
108110

109111
# If the weight is a TRTTensor, set it as an input of the layer
110112
if isinstance(weight, TRTTensor):
111113
deconv_layer.set_input(1, weight)
112-
114+
elif weight is not None:
115+
ctx.weight_refit_map[f"{deconv_layer.name} KERNEL"] = torch_weight
113116
# If the bias is a TRTTensor, set it as an input of the layer
114117
if isinstance(bias, TRTTensor):
115118
deconv_layer.set_input(2, bias)
@@ -135,8 +138,6 @@ def deconvNd(
135138
else output_padding
136139
)
137140

138-
set_layer_name(deconv_layer, target, name, source_ir)
139-
140141
# Set relevant attributes of deconvolution layer
141142
if padding is not None:
142143
deconv_layer.padding_nd = padding

0 commit comments

Comments (0)