Skip to content

Commit 2520a68

Browse files
committed
Added weight recording mechanism
1 parent e34fb80 commit 2520a68

File tree

3 files changed

+72
-16
lines changed

3 files changed

+72
-16
lines changed

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,15 @@ def cast_int_or_float_to_bool(
321321

322322

323323
def to_trt_weights(
324-
value: Any, target_quantized_type: Optional[trt.DataType] = None
324+
value: Any,
325+
record_weight: bool = False,
326+
name: Optional[str] = None,
327+
ctx: Optional[ConversionContext] = None,
328+
target: Optional[Union[Target, str]] = None,
329+
layer_type_name: Optional[str] = None,
330+
weight_type_name: Optional[str] = None,
331+
source_ir: Optional[SourceIR] = None,
332+
target_quantized_type: Optional[trt.DataType] = None,
325333
) -> trt.Weights:
326334
"""
327335
Convert a PyTorch tensor or NumPy array to TensorRT weights.
@@ -336,6 +344,35 @@ def to_trt_weights(
336344
- Input tensors are made contiguous before conversion
337345
- Data type is preserved from the original tensor/array
338346
"""
347+
if record_weight:
348+
assert name is not None, "name must be provided if record_weight is True"
349+
assert ctx is not None, "ctx must be provided if record_weight is True"
350+
assert target is not None, "target must be provided if record_weight is True"
351+
assert (
352+
layer_type_name is not None
353+
), "layer_type_name must be provided if record_weight is True"
354+
assert (
355+
weight_type_name is not None
356+
), "weight_type_name must be provided if record_weight is True"
357+
358+
supported_layer_types = ["CONVOLUTION", "DECONVOLUTION"]
359+
supported_weight_types = ["KERNEL"]
360+
assert (
361+
layer_type_name in supported_layer_types
362+
), f"Unsupported layer type: {layer_type_name}. Please add the layer type to this function to enable refitting."
363+
assert (
364+
weight_type_name in supported_weight_types
365+
), f"Unsupported weight type: {weight_type_name}. Please add the weight type to this function to enable refitting."
366+
source_ir = source_ir if source_ir is not None else SourceIR.UNKNOWN
367+
target_name = (
368+
f"{source_ir}_ops.{target}"
369+
if isinstance(target, str)
370+
else f"{source_ir}_ops.{target.__name__}"
371+
)
372+
373+
name = f"[{layer_type_name}]-[{target_name}]-[{name}] {weight_type_name}"
374+
record_weight_in_ctx(ctx, name, value)
375+
339376
if isinstance(value, torch.Tensor):
340377
# Tensor must be contiguous before conversion
341378
value = value.contiguous()
@@ -351,6 +388,15 @@ def to_trt_weights(
351388
)
352389

353390

391+
def record_weight_in_ctx(
392+
ctx: ConversionContext,
393+
name: str,
394+
value: torch.Tensor,
395+
) -> None:
396+
ctx.weight_refit_map[name] = value
397+
ctx.cpu_weights_reference_holder[name] = value
398+
399+
354400
def create_constant(
355401
ctx: ConversionContext,
356402
value: Union[int, float, bool, np.ndarray, torch.Tensor],
@@ -415,17 +461,14 @@ def create_constant(
415461
weights,
416462
)
417463
constant.name = name
418-
ctx.cpu_weights_reference_holder[name + " FP4_CONSTANT"] = torch_value
464+
record_weight_in_ctx(ctx, name + " FP4_CONSTANT", torch_value)
419465
return constant.get_output(0)
420466

421-
# Used for refit
422-
ctx.weight_refit_map[name + " CONSTANT"] = torch_value
423-
424-
# This is a buffer to hold the torch.Tensor so that they are alive during the course of TRT compilation.
425-
ctx.cpu_weights_reference_holder[name] = torch_value
467+
# Record the weight in ctx for refit and cpu memory reference
468+
record_weight_in_ctx(ctx, name + " CONSTANT", torch_value)
426469

427470
# Convert the torch.Tensor to a trt.Weights object
428-
trt_weights = to_trt_weights(torch_value)
471+
trt_weights = to_trt_weights(torch_value, record_weight=False)
429472
constant = ctx.net.add_constant(
430473
shape,
431474
trt_weights,

py/torch_tensorrt/dynamo/conversion/impl/conv.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,14 +79,22 @@ def convNd(
7979
kernel_shape = weight.shape[2:]
8080
elif isinstance(weight, (torch.Tensor, np.ndarray)):
8181
weight = to_torch(weight, dtype=input.dtype)
82-
torch_weight = weight
8382
# Append new dimension (unsqueeze) if the convolution is 1d
8483
if is_conv1d:
8584
weight = torch.unsqueeze(weight, -1)
8685

8786
num_output_maps = weight.shape[0]
8887
kernel_shape = weight.shape[2:]
89-
weight = to_trt_weights(weight)
88+
weight = to_trt_weights(
89+
weight,
90+
record_weight=True,
91+
name=name,
92+
ctx=ctx,
93+
target=target,
94+
layer_type_name="CONVOLUTION",
95+
weight_type_name="KERNEL",
96+
source_ir=source_ir,
97+
)
9098

9199
else:
92100
raise RuntimeError(
@@ -113,8 +121,6 @@ def convNd(
113121
if isinstance(weight, TRTTensor):
114122
weight = cast_trt_tensor(ctx, weight, input.dtype, name)
115123
conv_layer.set_input(1, weight)
116-
elif weight is not None:
117-
ctx.weight_refit_map[f"{conv_layer.name} KERNEL"] = torch_weight
118124

119125
# If the bias is a TRTTensor, set it as an input of the layer
120126
if isinstance(bias, TRTTensor):

py/torch_tensorrt/dynamo/conversion/impl/deconv.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,13 +80,21 @@ def deconvNd(
8080

8181
elif isinstance(weight, (torch.Tensor, np.ndarray)):
8282
weight = to_torch(weight, dtype=input.dtype)
83-
torch_weight = weight
8483
# Append new dimension (unsqueeze) if the deconvolution is 1d
8584
if is_deconv1d:
8685
weight = torch.unsqueeze(weight, -1)
8786
num_output_maps = weight.shape[1]
8887
kernel_shape = weight.shape[2:]
89-
weight = to_trt_weights(weight)
88+
weight = to_trt_weights(
89+
weight,
90+
record_weight=True,
91+
name=name,
92+
ctx=ctx,
93+
target=target,
94+
layer_type_name="DECONVOLUTION",
95+
weight_type_name="KERNEL",
96+
source_ir=source_ir,
97+
)
9098

9199
else:
92100
raise RuntimeError(
@@ -111,8 +119,7 @@ def deconvNd(
111119
# If the weight is a TRTTensor, set it as an input of the layer
112120
if isinstance(weight, TRTTensor):
113121
deconv_layer.set_input(1, weight)
114-
elif weight is not None:
115-
ctx.weight_refit_map[f"{deconv_layer.name} KERNEL"] = torch_weight
122+
116123
# If the bias is a TRTTensor, set it as an input of the layer
117124
if isinstance(bias, TRTTensor):
118125
deconv_layer.set_input(2, bias)

0 commit comments

Comments (0)