Skip to content

Commit 42d4862

Browse files
committed
adding checks for windows and jetson devices
1 parent 1e2148d commit 42d4862

File tree

3 files changed

+44
-13
lines changed

3 files changed

+44
-13
lines changed

py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515
tensorrt_fused_nccl_all_gather_op,
1616
tensorrt_fused_nccl_reduce_scatter_op,
1717
)
18-
from torch_tensorrt.dynamo.utils import load_tensorrt_llm
18+
from torch_tensorrt.dynamo.utils import load_tensorrt_llm_for_nccl
1919

2020
_LOGGER: logging.Logger = logging.getLogger(__name__)
2121

22-
if load_tensorrt_llm():
22+
if load_tensorrt_llm_for_nccl():
2323

2424
@dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op)
2525
def fused_nccl_gather(

py/torch_tensorrt/dynamo/utils.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -841,6 +841,29 @@ def is_tegra_platform() -> bool:
841841
return False
842842

843843

844+
def is_platform_supported_for_trtllm(platform: str) -> bool:
845+
"""
846+
Checks if the current platform supports TensorRT-LLM plugins for NCCL backend
847+
Returns:
848+
bool: True if the platform supports TensorRT-LLM plugins for NCCL backend, False otherwise.
849+
Note:
850+
TensorRT-LLM plugins for NCCL backend are not supported on:
851+
- Windows platforms
852+
- Jetson devices (aarch64 architecture)
853+
"""
854+
if "windows" in platform:
855+
logger.info(
856+
"TensorRT-LLM plugins for NCCL backend are not supported on Windows"
857+
)
858+
return False
859+
if "aarch64" in platform:
860+
logger.info(
861+
"TensorRT-LLM plugins for NCCL backend are not supported on Jetson devices (aarch64)"
862+
)
863+
return False
864+
return True
865+
866+
844867
@contextmanager
845868
def download_plugin_lib_path(platform: str) -> Iterator[str]:
846869
"""
@@ -891,6 +914,7 @@ def download_plugin_lib_path(platform: str) -> Iterator[str]:
891914
if "linux" in platform:
892915
lib_filename = "libnvinfer_plugin_tensorrt_llm.so"
893916
else:
917+
# This condition is never met though
894918
lib_filename = "libnvinfer_plugin_tensorrt_llm.dll"
895919

896920
with tempfile.TemporaryDirectory() as tmpdir:
@@ -923,7 +947,7 @@ def download_plugin_lib_path(platform: str) -> Iterator[str]:
923947
yield plugin_lib_path
924948

925949

926-
def load_and_initialize_trtllm_plugin(plugin_lib_path: str, platform: str) -> bool:
950+
def load_and_initialize_trtllm_plugin(plugin_lib_path: str) -> bool:
927951
"""
928952
Loads and initializes the TensorRT-LLM plugin from the given shared library path.
929953
@@ -933,9 +957,6 @@ def load_and_initialize_trtllm_plugin(plugin_lib_path: str, platform: str) -> bo
933957
Returns:
934958
bool: True if successful, False otherwise.
935959
"""
936-
if "windows" in platform:
937-
logger.info("NCCL backend is not supported on Windows")
938-
return False
939960
try:
940961
handle = ctypes.CDLL(plugin_lib_path)
941962
logger.info(f"Successfully loaded plugin library: {plugin_lib_path}")
@@ -981,7 +1002,7 @@ def load_and_initialize_trtllm_plugin(plugin_lib_path: str, platform: str) -> bo
9811002
return False
9821003

9831004

984-
def load_tensorrt_llm() -> bool:
1005+
def load_tensorrt_llm_for_nccl() -> bool:
9851006
"""
9861007
Attempts to load the TensorRT-LLM plugin and initialize it.
9871008
Either the env variable TRTLLM_PLUGINS_PATH can specify the path
@@ -990,11 +1011,15 @@ def load_tensorrt_llm() -> bool:
9901011
Returns:
9911012
bool: True if the plugin was successfully loaded and initialized, False otherwise.
9921013
"""
993-
plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")
1014+
# Check platform compatibility first
9941015
platform = Platform.current_platform()
9951016
platform = str(platform).lower()
1017+
if not is_platform_supported_for_trtllm(platform):
1018+
return False
1019+
plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")
1020+
9961021
if plugin_lib_path:
997-
return load_and_initialize_trtllm_plugin(plugin_lib_path, platform)
1022+
return load_and_initialize_trtllm_plugin(plugin_lib_path)
9981023
else:
9991024
# this option can be used by user if TRTLLM_PLUGINS_PATH is not set by user
10001025
use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in (
@@ -1010,5 +1035,5 @@ def load_tensorrt_llm() -> bool:
10101035
return False
10111036

10121037
with download_plugin_lib_path(platform) as plugin_lib_path:
1013-
return load_and_initialize_trtllm_plugin(plugin_lib_path, platform)
1038+
return load_and_initialize_trtllm_plugin(plugin_lib_path)
10141039
return False

tests/py/dynamo/distributed/test_nccl_ops.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import unittest
23

34
import torch
45
import torch.distributed as dist
@@ -19,10 +20,11 @@
1920
platform_str = str(Platform.current_platform()).lower()
2021

2122

22-
@unittest.skipIf(
23-
"win" in platform_str, "Skipped on Windows: NCCL backend is not supported."
24-
)
2523
class TestGatherNcclOpsConverter(DispatchTestCase):
24+
@unittest.skipIf(
25+
"win" or "aarch64" in platform_str,
26+
"Skipped on Windows and Jetson: NCCL backend is not supported.",
27+
)
2628
@parameterized.expand([8])
2729
def test_nccl_ops(self, linear_layer_dim):
2830
class DistributedGatherModel(nn.Module):
@@ -48,6 +50,10 @@ def forward(self, x):
4850
enable_passes=True,
4951
)
5052

53+
@unittest.skipIf(
54+
"win" or "aarch64" in platform_str,
55+
"Skipped on Windows and Jetson: NCCL backend is not supported.",
56+
)
5157
@parameterized.expand([8])
5258
def test_nccl_ops_scatter(self, linear_layer_dim):
5359

0 commit comments

Comments
 (0)