Skip to content

Commit cf064c5

Browse files
committed
Initial attempt
1 parent 1c00f0f commit cf064c5

File tree

2 files changed

+13
-19
lines changed

2 files changed

+13
-19
lines changed

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,7 @@ def _construct_trt_network_def(self) -> None:
400400
@staticmethod
401401
def find_weight(
402402
weight_name: str,
403-
np_map: dict[str, Any],
403+
weight_refit_map: dict[str, Any],
404404
state_dict: dict[str, Any],
405405
device: torch.device,
406406
) -> str:
@@ -413,7 +413,7 @@ def find_weight(
413413
state_dict: state of the graph module
414414
"""
415415
with unset_fake_temporarily():
416-
network_weight = torch.from_numpy(np_map[weight_name]).to(device)
416+
network_weight = weight_refit_map[weight_name].to(device)
417417
for sd_w_name, sd_weight in state_dict.items():
418418
if TRTInterpreter.check_weight_equal(sd_weight, network_weight, device):
419419
del state_dict[sd_w_name]
@@ -427,8 +427,8 @@ def check_weight_equal(
427427
device: torch.device,
428428
) -> Any:
429429
with unset_fake_temporarily():
430-
if not isinstance(network_weight, torch.Tensor):
431-
network_weight = torch.from_numpy(network_weight).to(device)
430+
if network_weight.device != device:
431+
network_weight = network_weight.to(device)
432432
try:
433433
return sd_weight.shape == network_weight.shape and torch.all(
434434
torch.abs(sd_weight - network_weight) < 0.01
@@ -497,8 +497,8 @@ def _save_weight_mapping(self) -> None:
497497
self.module.to(torch_device)
498498
sd = self.module.state_dict()
499499
weight_name_map: dict[str, Any] = {}
500-
np_map = self.ctx.weight_refit_map
501-
constant_mapping = {k: v for k, v in np_map.items() if v.size == 1}
500+
weight_refit_map = self.ctx.weight_refit_map
501+
constant_mapping = {k: v for k, v in weight_refit_map.items() if v.size == 1}
502502
net = self.ctx.net
503503
for i in range(net.num_layers):
504504
layer = net[i]
@@ -540,7 +540,7 @@ def _save_weight_mapping(self) -> None:
540540
else:
541541
sd_weight_name = f"{sd_weight_name}.{torch_attr}"
542542

543-
if engine_weight_name in np_map:
543+
if engine_weight_name in weight_refit_map:
544544
weight_name_map[engine_weight_name] = sd_weight_name
545545

546546
# Stage 2: Value mapping
@@ -549,10 +549,10 @@ def _save_weight_mapping(self) -> None:
549549
# There is no direct connection in batch_norm layer. So skip it
550550
pass
551551
elif sd_weight_name not in sd or not TRTInterpreter.check_weight_equal(
552-
sd[sd_weight_name], np_map[engine_weight_name], torch_device
552+
sd[sd_weight_name], weight_refit_map[engine_weight_name], torch_device
553553
):
554554
weight_name_map[engine_weight_name] = TRTInterpreter.find_weight(
555-
engine_weight_name, np_map, sd, torch_device
555+
engine_weight_name, weight_refit_map, sd, torch_device
556556
)
557557
if (
558558
weight_name_map[engine_weight_name] != ""
@@ -563,12 +563,13 @@ def _save_weight_mapping(self) -> None:
563563

564564
weight_name_map[engine_weight_name] = [
565565
weight_name_map[engine_weight_name],
566-
np_map[engine_weight_name].dtype,
566+
weight_refit_map[engine_weight_name].dtype,
567567
]
568568

569569
weight_name_map["constant_mapping"] = constant_mapping
570570
self.weight_name_map = weight_name_map
571-
del np_map, sd
571+
572+
del weight_refit_map, sd
572573
gc.collect()
573574
torch.cuda.empty_cache()
574575

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -418,15 +418,8 @@ def create_constant(
418418
ctx.cpu_weights_reference_holder[name + " FP4_CONSTANT"] = torch_value
419419
return constant.get_output(0)
420420

421-
# TODO: Refit map uses numpy arrays. Remove this once refit is updated to use torch.Tensor
422-
if torch_value.dtype == torch.bfloat16:
423-
torch_value_fp32 = torch_value.to(torch.float32)
424-
numpy_value = torch_value_fp32.numpy()
425-
else:
426-
numpy_value = torch_value.numpy()
427-
428421
# Used for refit
429-
ctx.weight_refit_map[name + " CONSTANT"] = numpy_value.reshape(-1)
422+
ctx.weight_refit_map[name + " CONSTANT"] = torch_value
430423

431424
# This is a buffer to hold the torch.Tensor so that they are alive during the course of TRT compilation.
432425
ctx.cpu_weights_reference_holder[name] = torch_value

0 commit comments

Comments
 (0)