Commit dbfd7ee

Keep the extracted library, delete the downloaded wheel file, and restructure the test
1 parent 9cb3cab commit dbfd7ee

4 files changed: +136 -159 lines

py/torch_tensorrt/dynamo/utils.py

Lines changed: 79 additions & 65 deletions
@@ -8,15 +8,13 @@
 import tempfile
 import urllib.request
 import warnings
-from contextlib import contextmanager
 from dataclasses import fields, replace
 from enum import Enum
 from pathlib import Path
 from typing import (
     Any,
     Callable,
     Dict,
-    Iterator,
     List,
     Optional,
     Sequence,
@@ -864,40 +862,52 @@ def is_platform_supported_for_trtllm(platform: str) -> bool:
     return True


-@contextmanager
-def download_plugin_lib_path(platform: str) -> Iterator[str]:
-    """
-    Downloads (if needed) and extracts the TensorRT-LLM plugin wheel for the specified platform,
-    then yields the path to the extracted shared library (.so or .dll).
+def _cache_root() -> Path:
+    username = getpass.getuser()
+    return Path(tempfile.gettempdir()) / f"torch_tensorrt_{username}"

-    The wheel file is cached in a user-specific temporary directory to avoid repeated downloads.
-    Extraction happens in a temporary directory that is cleaned up after use.

-    Args:
-        platform (str): The platform identifier string (e.g., 'linux_x86_64') to select the correct wheel.
+def _extracted_dir_trtllm(platform: str) -> Path:
+    return _cache_root() / "trtllm" / f"{__tensorrt_llm_version__}_{platform}"

-    Yields:
-        str: The full path to the extracted TensorRT-LLM shared library file.

-    Raises:
-        ImportError: If the 'zipfile' module is not available.
-        RuntimeError: If the wheel file is missing, corrupted, or extraction fails.
+def download_and_get_plugin_lib_path(platform: str) -> Optional[str]:
     """
-    plugin_lib_path = None
-    username = getpass.getuser()
-    torchtrt_cache_dir = Path(tempfile.gettempdir()) / f"torch_tensorrt_{username}"
-    torchtrt_cache_dir.mkdir(parents=True, exist_ok=True)
-    file_name = f"tensorrt_llm-{__tensorrt_llm_version__}-{_WHL_CPYTHON_VERSION}-{_WHL_CPYTHON_VERSION}-{platform}.whl"
-    torchtrt_cache_trtllm_whl = torchtrt_cache_dir / file_name
-    downloaded_file_path = torchtrt_cache_trtllm_whl
-
-    if not torchtrt_cache_trtllm_whl.exists():
-        # Downloading TRT-LLM lib
+    Returns the path to the TensorRT-LLM shared library, downloading and extracting it if necessary.
+
+    Args:
+        platform (str): Platform identifier (e.g., 'linux_x86_64')
+
+    Returns:
+        Optional[str]: Path to the shared library, or None if the operation fails.
+    """
+    wheel_filename = (
+        f"tensorrt_llm-{__tensorrt_llm_version__}-{_WHL_CPYTHON_VERSION}-"
+        f"{_WHL_CPYTHON_VERSION}-{platform}.whl"
+    )
+    wheel_path = _cache_root() / wheel_filename
+    extract_dir = _extracted_dir_trtllm(platform)
+    # Note: the .dll branch below is currently never taken, since only Linux platforms are supported
+    lib_filename = (
+        "libnvinfer_plugin_tensorrt_llm.so"
+        if "linux" in platform
+        else "libnvinfer_plugin_tensorrt_llm.dll"
+    )
+    # e.g. /tmp/torch_tensorrt_<username>/trtllm/0.17.0.post1_linux_x86_64/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so
+    plugin_lib_path = extract_dir / "tensorrt_llm" / "libs" / lib_filename
+
+    if plugin_lib_path.exists():
+        return str(plugin_lib_path)
+
+    wheel_path.parent.mkdir(parents=True, exist_ok=True)
+    extract_dir.mkdir(parents=True, exist_ok=True)
+
+    if not wheel_path.exists():
         base_url = "https://pypi.nvidia.com/tensorrt-llm/"
-        download_url = base_url + file_name
+        download_url = base_url + wheel_filename
         try:
-            logger.debug(f"Downloading {download_url} ...")
-            urllib.request.urlretrieve(download_url, downloaded_file_path)
+            logger.debug("Downloading %s ...", download_url)
+            urllib.request.urlretrieve(download_url, wheel_path)
             logger.debug("Download succeeded and TRT-LLM wheel is now present")
         except urllib.error.HTTPError as e:
             logger.error(
@@ -910,41 +920,45 @@ def download_plugin_lib_path(platform: str) -> Iterator[str]:
         except OSError as e:
             logger.error(f"Local file write error: {e}")

-    # Proceeding with the unzip of the wheel file in tmpdir
-    if "linux" in platform:
-        lib_filename = "libnvinfer_plugin_tensorrt_llm.so"
-    else:
-        # This condition is never met though
-        lib_filename = "libnvinfer_plugin_tensorrt_llm.dll"
+    try:
+        import zipfile
+    except ImportError as e:
+        raise ImportError(
+            "zipfile module is required but not found. Please install zipfile"
+        )
+    try:
+        with zipfile.ZipFile(wheel_path) as zip_ref:
+            zip_ref.extractall(extract_dir)
+        logger.debug(f"Extracted wheel to {extract_dir}")
+    except FileNotFoundError as e:
+        # This should capture the errors in the download failure above
+        logger.error(f"Wheel file not found at {wheel_path}: {e}")
+        raise RuntimeError(
+            f"Failed to find downloaded wheel file at {wheel_path}"
+        ) from e
+    except zipfile.BadZipFile as e:
+        logger.error(f"Invalid or corrupted wheel file: {e}")
+        raise RuntimeError(
+            "Downloaded wheel file is corrupted or not a valid zip archive"
+        ) from e
+    except Exception as e:
+        logger.error(f"Unexpected error while extracting wheel: {e}")
+        raise RuntimeError(
+            "Unexpected error during extraction of TensorRT-LLM wheel"
+        ) from e

-    with tempfile.TemporaryDirectory() as tmpdir:
-        try:
-            import zipfile
-        except ImportError:
-            raise ImportError(
-                "zipfile module is required but not found. Please install zipfile"
-            )
-        try:
-            with zipfile.ZipFile(downloaded_file_path, "r") as zip_ref:
-                zip_ref.extractall(tmpdir)  # Extract to a folder named 'tensorrt_llm'
-        except FileNotFoundError as e:
-            # This should capture the errors in the download failure above
-            logger.error(f"Wheel file not found at {downloaded_file_path}: {e}")
-            raise RuntimeError(
-                f"Failed to find downloaded wheel file at {downloaded_file_path}"
-            ) from e
-        except zipfile.BadZipFile as e:
-            logger.error(f"Invalid or corrupted wheel file: {e}")
-            raise RuntimeError(
-                "Downloaded wheel file is corrupted or not a valid zip archive"
-            ) from e
-        except Exception as e:
-            logger.error(f"Unexpected error while extracting wheel: {e}")
-            raise RuntimeError(
-                "Unexpected error during extraction of TensorRT-LLM wheel"
-            ) from e
-        plugin_lib_path = os.path.join(tmpdir, "tensorrt_llm/libs", lib_filename)
-        yield plugin_lib_path
+    try:
+        wheel_path.unlink(missing_ok=True)
+        logger.debug(f"Deleted wheel file: {wheel_path}")
+    except Exception as e:
+        logger.warning(f"Could not delete wheel file {wheel_path}: {e}")
+    if not plugin_lib_path.exists():
+        logger.error(
+            f"Plugin library not found at expected location: {plugin_lib_path}"
+        )
+        return None
+
+    return str(plugin_lib_path)


 def load_and_initialize_trtllm_plugin(plugin_lib_path: str) -> bool:
@@ -1034,6 +1048,6 @@ def load_tensorrt_llm_for_nccl() -> bool:
         )
         return False

-    with download_plugin_lib_path(platform) as plugin_lib_path:
-        return load_and_initialize_trtllm_plugin(plugin_lib_path)
+    plugin_lib_path = download_and_get_plugin_lib_path(platform)
+    return load_and_initialize_trtllm_plugin(plugin_lib_path)  # type: ignore[arg-type]
     return False
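
Note: a minimal usage sketch of the new caching behavior (not part of the diff; the platform string and paths are illustrative):

    from pathlib import Path

    from torch_tensorrt.dynamo.utils import download_and_get_plugin_lib_path

    # First call: downloads the wheel into /tmp/torch_tensorrt_<user>/, extracts it
    # under /tmp/torch_tensorrt_<user>/trtllm/<version>_linux_x86_64/, deletes the
    # wheel, and returns the extracted .so path (or None on failure).
    lib = download_and_get_plugin_lib_path("linux_x86_64")

    # Second call: the extracted library already exists, so the function returns
    # early without re-downloading or re-extracting anything.
    assert lib is None or Path(lib).exists()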

tests/py/dynamo/distributed/distributed_utils.py

Lines changed: 0 additions & 1 deletion
@@ -13,7 +13,6 @@ def set_environment_variables_pytest():
     os.environ["RANK"] = str(0)
     os.environ["MASTER_ADDR"] = "127.0.0.1"
     os.environ["MASTER_PORT"] = str(29500)
-    os.environ["USE_TRTLLM_PLUGINS"] = "1"


 def initialize_logger(rank, logger_file_name):

tests/py/dynamo/distributed/test_nccl_ops.py

Lines changed: 56 additions & 47 deletions
@@ -4,18 +4,42 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from conversion.harness import DispatchTestCase
 from distributed_utils import set_environment_variables_pytest
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
 from torch_tensorrt._enums import Platform

-set_environment_variables_pytest()
-dist.init_process_group(backend="nccl", init_method="env://")
-group = dist.new_group(ranks=[0])
-group_name = group.group_name
-world_size = 1

-from conversion.harness import DispatchTestCase
+class DistributedGatherModel(nn.Module):
+    def __init__(self, input_dim, world_size, group_name):
+        super().__init__()
+        self.fc = nn.Linear(input_dim, input_dim)
+        self.world_size = world_size
+        self.group_name = group_name
+
+    def forward(self, x):
+        x = self.fc(x)
+        gathered_tensor = torch.ops._c10d_functional.all_gather_into_tensor(
+            x, self.world_size, self.group_name
+        )
+        return torch.ops._c10d_functional.wait_tensor(gathered_tensor)
+
+
+class DistributedReduceScatterModel(nn.Module):
+    def __init__(self, input_dim, world_size, group_name):
+        super().__init__()
+        self.fc = nn.Linear(input_dim, input_dim)
+        self.world_size = world_size
+        self.group_name = group_name
+
+    def forward(self, x):
+        x = self.fc(x)
+        out = torch.ops._c10d_functional.reduce_scatter_tensor(
+            x, "sum", self.world_size, self.group_name
+        )
+        return torch.ops._c10d_functional.wait_tensor(out)
+

 platform_str = str(Platform.current_platform()).lower()

@@ -25,64 +49,49 @@ class TestGatherNcclOpsConverter(DispatchTestCase):
         "win" or "aarch64" in platform_str,
         "Skipped on Windows and Jetson: NCCL backend is not supported.",
     )
+    @classmethod
+    def setUpClass(cls):
+        set_environment_variables_pytest()
+        print("USE_TRTLLM_PLUGINS =", os.environ.get("USE_TRTLLM_PLUGINS"))
+        cls.world_size = 1
+        if not dist.is_initialized():
+            dist.init_process_group(
+                backend="nccl",
+                init_method="env://",
+                world_size=cls.world_size,
+                rank=0,  # or read from env
+            )
+        cls.group = dist.new_group(ranks=[0])
+        cls.group_name = cls.group.group_name
+
+    @classmethod
+    def tearDownClass(cls):
+        if dist.is_initialized():
+            dist.destroy_process_group()
+
     @parameterized.expand([8])
     def test_nccl_ops_gather(self, linear_layer_dim):
-        class DistributedGatherModel(nn.Module):
-            def __init__(self, input_dim):
-                super().__init__()
-                self.fc = torch.nn.Linear(input_dim, input_dim)
-
-            def forward(self, x):
-                x = self.fc(x)
-                gathered_tensor = torch.ops._c10d_functional.all_gather_into_tensor(
-                    x, world_size, group_name
-                )
-                gathered_tensor = torch.ops._c10d_functional.wait_tensor(
-                    gathered_tensor
-                )
-                return gathered_tensor
-
         inputs = [torch.randn(1, linear_layer_dim).to("cuda")]
         self.run_test(
-            DistributedGatherModel(linear_layer_dim).cuda(),
+            DistributedGatherModel(
+                linear_layer_dim, self.world_size, self.group_name
+            ).cuda(),
             inputs,
             use_dynamo_tracer=True,
             enable_passes=True,
         )

-    @unittest.skipIf(
-        "win" or "aarch64" in platform_str,
-        "Skipped on Windows and Jetson: NCCL backend is not supported.",
-    )
     @parameterized.expand([8])
     def test_nccl_ops_scatter(self, linear_layer_dim):
-
-        class DistributedReduceScatterModel(nn.Module):
-            def __init__(self, input_dim):
-                super().__init__()
-                self.fc = torch.nn.Linear(input_dim, input_dim)
-
-            def forward(self, x):
-                x = self.fc(x)
-                scatter_reduce_tensor = (
-                    torch.ops._c10d_functional.reduce_scatter_tensor(
-                        x, "sum", world_size, group_name
-                    )
-                )
-                scatter_reduce_tensor = torch.ops._c10d_functional.wait_tensor(
-                    scatter_reduce_tensor
-                )
-                return scatter_reduce_tensor
-
         inputs = [torch.zeros(1, linear_layer_dim).to("cuda")]
-
         self.run_test(
-            DistributedReduceScatterModel(linear_layer_dim).cuda(),
+            DistributedReduceScatterModel(
+                linear_layer_dim, self.world_size, self.group_name
+            ).cuda(),
             inputs,
             use_dynamo_tracer=True,
             enable_passes=True,
         )
-        dist.destroy_process_group()


 if __name__ == "__main__":
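
Note: with the process group now managed in setUpClass/tearDownClass, the models are plain module-level classes. A hedged usage sketch (single rank; assumes RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT are set, a CUDA device and an NCCL build of PyTorch are available; the import path is illustrative):

    import torch
    import torch.distributed as dist

    from test_nccl_ops import DistributedGatherModel

    if not dist.is_initialized():
        dist.init_process_group(backend="nccl", init_method="env://", world_size=1, rank=0)
    group = dist.new_group(ranks=[0])

    # Group metadata is passed in explicitly instead of read from module globals.
    model = DistributedGatherModel(8, world_size=1, group_name=group.group_name).cuda()
    out = model(torch.randn(1, 8, device="cuda"))  # all-gather over a 1-rank group
    assert out.shape == (1, 8)

    dist.destroy_process_group()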

tests/py/dynamo/distributed/test_nccl_ops.sh

Lines changed: 1 addition & 46 deletions
@@ -70,51 +70,6 @@ ensure_pytest_installed(){
 
 echo "Setting up the environment"
 
-OS="$(uname -s)"
-ARCH="$(uname -m)"
-
-
-#getting the file name for TensorRT-LLM download
-if [[ "$OS" == "Linux" && "$ARCH" == "x86_64"]]; then
-    FILE="tensorrt_llm-0.17.0.post1-cp312-cp312-linux_x86_64.whl"
-elif [[ "$OS" == "Linux" && "$ARCH" == "aarch64"]]; then
-    FILE="tensorrt_llm-0.17.0.post1-cp312-cp312-linux_aarch64.whl"
-else:
-    echo "Unsupported platform: OS=$OS ARCH=$ARCH
-    exit 1
-fi
-
-# Download the selected file
-URL="https://pypi.nvidia.com/tensorrt-llm/$FILE"
-echo "Downloading $FILE from $URL..."
-
-#Installing wget
-ensure_installed wget
-
-#Downloading the file
-filename=$(basename "$URL")
-if [ -f "$filename" ]; then
-    echo "File already exists: $filename"
-else
-    wget "$URL"
-fi
-echo "Download complete: $FILE"
-
-UNZIP_DIR="tensorrt_llm_unzip"
-if [[ ! -d "$UNZIP_DIR" ]]; then
-    echo "Creating directory: $UNZIP_DIR"
-    mkdir -p "$UNZIP_DIR"
-    echo "extracting $FILE to $UNZIP_DIR ..."
-    #Installing unzip
-    ensure_installed unzip
-    #unzip the TensorRT-LLM package
-    unzip -q "$FILE" -d "$UNZIP_DIR"
-    echo "Unzip complete"
-fi
-
-
-export TRTLLM_PLUGINS_PATH="$(pwd)/${UNZIP_DIR}/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so"
-echo ${TRTLLM_PLUGINS_PATH}
 
 ensure_mpi_installed libmpich-dev
 ensure_mpi_installed libopenmpi-dev
@@ -123,7 +78,7 @@ run_tests() {
     cd ..
     export PYTHONPATH=$(pwd)
     echo "Running pytest on distributed/test_nccl_ops.py..."
-    pytest distributed/test_nccl_ops.py
+    USE_TRTLLM_PLUGINS=1 pytest distributed/test_nccl_ops.py
 }

 run_mpi_tests(){
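
Note: the same opt-in can be expressed from Python, mirroring the USE_TRTLLM_PLUGINS=1 pytest line above (a sketch; the working directory is assumed to match the script's):

    import os
    import subprocess

    env = dict(os.environ, USE_TRTLLM_PLUGINS="1")  # opt in only for this run
    subprocess.run(
        ["pytest", "distributed/test_nccl_ops.py"],
        env=env,
        check=True,  # propagate a non-zero pytest exit code
    )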
