@@ -321,7 +321,15 @@ def cast_int_or_float_to_bool(
 
 
 def to_trt_weights(
-    value: Any, target_quantized_type: Optional[trt.DataType] = None
+    value: Any,
+    record_weight: bool = False,
+    name: Optional[str] = None,
+    ctx: Optional[ConversionContext] = None,
+    target: Optional[Union[Target, str]] = None,
+    layer_type_name: Optional[str] = None,
+    weight_type_name: Optional[str] = None,
+    source_ir: Optional[SourceIR] = None,
+    target_quantized_type: Optional[trt.DataType] = None,
 ) -> trt.Weights:
     """
     Convert a PyTorch tensor or NumPy array to TensorRT weights.
@@ -336,6 +344,35 @@ def to_trt_weights(
         - Input tensors are made contiguous before conversion
         - Data type is preserved from the original tensor/array
     """
+    if record_weight:
+        assert name is not None, "name must be provided if record_weight is True"
+        assert ctx is not None, "ctx must be provided if record_weight is True"
+        assert target is not None, "target must be provided if record_weight is True"
+        assert (
+            layer_type_name is not None
+        ), "layer_type_name must be provided if record_weight is True"
+        assert (
+            weight_type_name is not None
+        ), "weight_type_name must be provided if record_weight is True"
+
+        supported_layer_types = ["CONVOLUTION", "DECONVOLUTION"]
+        supported_weight_types = ["KERNEL"]
+        assert (
+            layer_type_name in supported_layer_types
+        ), f"Unsupported layer type: {layer_type_name}. Please add the layer type to this function to enable refitting."
+        assert (
+            weight_type_name in supported_weight_types
+        ), f"Unsupported weight type: {weight_type_name}. Please add the weight type to this function to enable refitting."
+        source_ir = source_ir if source_ir is not None else SourceIR.UNKNOWN
+        target_name = (
+            f"{source_ir}_ops.{target}"
+            if isinstance(target, str)
+            else f"{source_ir}_ops.{target.__name__}"
+        )
+
+        name = f"[{layer_type_name}]-[{target_name}]-[{name}] {weight_type_name}"
+        record_weight_in_ctx(ctx, name, value)
+
     if isinstance(value, torch.Tensor):
         # Tensor must be contiguous before conversion
         value = value.contiguous()
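
For reference, here is a minimal sketch (not part of the diff) of the weight key that the new `record_weight` path builds before calling `record_weight_in_ctx`; the layer type, target, source IR, and tensor name below are made-up example values, not values taken from this PR:

```python
# Illustrative reconstruction of the key format added in the hunk above.
from typing import Any, Union


def build_weight_key(
    layer_type_name: str,
    source_ir: str,
    target: Union[str, Any],
    name: str,
    weight_type_name: str,
) -> str:
    # Mirrors the f-strings added above.
    target_name = (
        f"{source_ir}_ops.{target}"
        if isinstance(target, str)
        else f"{source_ir}_ops.{target.__name__}"
    )
    return f"[{layer_type_name}]-[{target_name}]-[{name}] {weight_type_name}"


# e.g. a convolution kernel recorded by an aten-level converter (example values):
print(build_weight_key("CONVOLUTION", "aten", "convolution", "conv1_weight", "KERNEL"))
# -> [CONVOLUTION]-[aten_ops.convolution]-[conv1_weight] KERNEL
```

This key is what ends up in `ctx.weight_refit_map`, presumably so refit logic can map a TensorRT layer weight back to its original tensor.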
@@ -351,6 +388,15 @@ def to_trt_weights(
         )
 
 
+def record_weight_in_ctx(
+    ctx: ConversionContext,
+    name: str,
+    value: torch.Tensor,
+) -> None:
+    ctx.weight_refit_map[name] = value
+    ctx.cpu_weights_reference_holder[name] = value
+
+
 def create_constant(
     ctx: ConversionContext,
     value: Union[int, float, bool, np.ndarray, torch.Tensor],
@@ -415,17 +461,14 @@ def create_constant(
                 weights,
             )
             constant.name = name
-            ctx.cpu_weights_reference_holder[name + " FP4_CONSTANT"] = torch_value
+            record_weight_in_ctx(ctx, name + " FP4_CONSTANT", torch_value)
            return constant.get_output(0)
 
-        # Used for refit
-        ctx.weight_refit_map[name + " CONSTANT"] = torch_value
-
-        # This is a buffer to hold the torch.Tensor so that they are alive during the course of TRT compilation.
-        ctx.cpu_weights_reference_holder[name] = torch_value
+        # Record the weight in ctx for refit and cpu memory reference
+        record_weight_in_ctx(ctx, name + " CONSTANT", torch_value)
 
         # Convert the torch.Tensor to a trt.Weights object
-        trt_weights = to_trt_weights(torch_value)
+        trt_weights = to_trt_weights(torch_value, record_weight=False)
         constant = ctx.net.add_constant(
             shape,
             trt_weights,
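
Below is a self-contained sketch (not from the PR) of what the refactored `create_constant` path records: the constant tensor is stored once under `name + " CONSTANT"` via `record_weight_in_ctx`, and `to_trt_weights` is then called with `record_weight=False`, presumably so the same tensor is not recorded twice. `FakeConversionContext` is a hypothetical stand-in for the real `ConversionContext` with only the two attributes the helper touches:

```python
import torch


class FakeConversionContext:
    """Stand-in for ConversionContext: only the two dicts record_weight_in_ctx uses."""

    def __init__(self) -> None:
        self.weight_refit_map: dict[str, torch.Tensor] = {}
        self.cpu_weights_reference_holder: dict[str, torch.Tensor] = {}


def record_weight_in_ctx(ctx: FakeConversionContext, name: str, value: torch.Tensor) -> None:
    # Same two assignments as the helper introduced in this PR.
    ctx.weight_refit_map[name] = value  # consulted when refitting engine weights
    ctx.cpu_weights_reference_holder[name] = value  # keeps the CPU tensor alive during TRT compilation


ctx = FakeConversionContext()
torch_value = torch.ones(4, 4)

# What the non-FP4 branch of create_constant now does:
record_weight_in_ctx(ctx, "my_const" + " CONSTANT", torch_value)

assert ctx.weight_refit_map["my_const CONSTANT"] is torch_value
assert ctx.cpu_weights_reference_holder["my_const CONSTANT"] is torch_value
```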