@@ -400,7 +400,7 @@ def _construct_trt_network_def(self) -> None:
     @staticmethod
     def find_weight(
         weight_name: str,
-        np_map: dict[str, Any],
+        weight_refit_map: dict[str, Any],
         state_dict: dict[str, Any],
         device: torch.device,
     ) -> str:
@@ -413,7 +413,7 @@ def find_weight(
         state_dict: state of the graph module
         """
         with unset_fake_temporarily():
-            network_weight = torch.from_numpy(np_map[weight_name]).to(device)
+            network_weight = weight_refit_map[weight_name].to(device)
             for sd_w_name, sd_weight in state_dict.items():
                 if TRTInterpreter.check_weight_equal(sd_weight, network_weight, device):
                     del state_dict[sd_w_name]
@@ -427,8 +427,8 @@ def check_weight_equal(
         device: torch.device,
     ) -> Any:
         with unset_fake_temporarily():
-            if not isinstance(network_weight, torch.Tensor):
-                network_weight = torch.from_numpy(network_weight).to(device)
+            if network_weight.device != device:
+                network_weight = network_weight.to(device)
             try:
                 return sd_weight.shape == network_weight.shape and torch.all(
                     torch.abs(sd_weight - network_weight) < 0.01
@@ -497,8 +497,8 @@ def _save_weight_mapping(self) -> None:
         self.module.to(torch_device)
         sd = self.module.state_dict()
         weight_name_map: dict[str, Any] = {}
-        np_map = self.ctx.weight_refit_map
-        constant_mapping = {k: v for k, v in np_map.items() if v.size == 1}
+        weight_refit_map = self.ctx.weight_refit_map
+        constant_mapping = {k: v for k, v in weight_refit_map.items() if v.size == 1}
         net = self.ctx.net
         for i in range(net.num_layers):
             layer = net[i]
@@ -540,7 +540,7 @@ def _save_weight_mapping(self) -> None:
                 else:
                     sd_weight_name = f"{sd_weight_name}.{torch_attr}"

-                if engine_weight_name in np_map:
+                if engine_weight_name in weight_refit_map:
                     weight_name_map[engine_weight_name] = sd_weight_name

         # Stage 2: Value mapping
@@ -549,10 +549,10 @@ def _save_weight_mapping(self) -> None:
                 # There is no direct connection in batch_norm layer. So skip it
                 pass
             elif sd_weight_name not in sd or not TRTInterpreter.check_weight_equal(
-                sd[sd_weight_name], np_map[engine_weight_name], torch_device
+                sd[sd_weight_name], weight_refit_map[engine_weight_name], torch_device
             ):
                 weight_name_map[engine_weight_name] = TRTInterpreter.find_weight(
-                    engine_weight_name, np_map, sd, torch_device
+                    engine_weight_name, weight_refit_map, sd, torch_device
                 )
                 if (
                     weight_name_map[engine_weight_name] != ""
@@ -563,12 +563,13 @@ def _save_weight_mapping(self) -> None:

                     weight_name_map[engine_weight_name] = [
                         weight_name_map[engine_weight_name],
-                        np_map[engine_weight_name].dtype,
+                        weight_refit_map[engine_weight_name].dtype,
                     ]

         weight_name_map["constant_mapping"] = constant_mapping
         self.weight_name_map = weight_name_map
-        del np_map, sd
+
+        del weight_refit_map, sd
         gc.collect()
         torch.cuda.empty_cache()

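Net effect of the diff: `self.ctx.weight_refit_map` now holds `torch.Tensor` values rather than numpy arrays, so `find_weight` and `check_weight_equal` can drop the `torch.from_numpy` conversion and at most move a tensor across devices. A minimal standalone sketch of the new comparison path, using the same shape check and 0.01 elementwise tolerance as the patch (the map entries and names below are hypothetical, not from the patch):

```python
import torch

def check_weight_equal(
    sd_weight: torch.Tensor, network_weight: torch.Tensor, device: torch.device
) -> bool:
    # Weights arrive as torch.Tensor already; only a device move may be needed.
    if network_weight.device != device:
        network_weight = network_weight.to(device)
    return sd_weight.shape == network_weight.shape and bool(
        torch.all(torch.abs(sd_weight - network_weight) < 0.01)
    )

device = torch.device("cpu")
weight_refit_map = {"linear_1 [CONSTANT]": torch.ones(4, 4)}  # hypothetical entry
state_dict = {"fc.weight": torch.ones(4, 4)}  # hypothetical module state

# find_weight-style scan: locate the state_dict entry matching an engine weight
matches = [
    name
    for name, weight in state_dict.items()
    if check_weight_equal(weight, weight_refit_map["linear_1 [CONSTANT]"], device)
]
print(matches)  # ['fc.weight']
```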