Skip to content

Commit 747e38e

Browse files
committed
checks for windows where NCCL backend is not supported
1 parent 61c93b2 commit 747e38e

File tree

2 files changed

+14
-6
lines changed

2 files changed

+14
-6
lines changed

py/torch_tensorrt/dynamo/utils.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -937,7 +937,7 @@ def download_plugin_lib_path(platform: str) -> Iterator[str]:
937937
yield plugin_lib_path
938938

939939

940-
def load_and_initialize_trtllm_plugin(plugin_lib_path: str) -> bool:
940+
def load_and_initialize_trtllm_plugin(plugin_lib_path: str, platform: str) -> bool:
941941
"""
942942
Loads and initializes the TensorRT-LLM plugin from the given shared library path.
943943
@@ -947,6 +947,9 @@ def load_and_initialize_trtllm_plugin(plugin_lib_path: str) -> bool:
947947
Returns:
948948
bool: True if successful, False otherwise.
949949
"""
950+
if "windows" in platform:
951+
logger.info("NCCL backend is not supported on Windows")
952+
return False
950953
try:
951954
handle = ctypes.CDLL(plugin_lib_path)
952955
logger.info(f"Successfully loaded plugin library: {plugin_lib_path}")
@@ -1002,8 +1005,10 @@ def load_tensorrt_llm() -> bool:
10021005
bool: True if the plugin was successfully loaded and initialized, False otherwise.
10031006
"""
10041007
plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")
1008+
platform = Platform.current_platform()
1009+
platform = str(platform).lower()
10051010
if plugin_lib_path:
1006-
return load_and_initialize_trtllm_plugin(plugin_lib_path)
1011+
return load_and_initialize_trtllm_plugin(plugin_lib_path, platform)
10071012
else:
10081013
# this option can be used by user if TRTLLM_PLUGINS_PATH is not set by user
10091014
use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in (
@@ -1017,10 +1022,7 @@ def load_tensorrt_llm() -> bool:
10171022
"Neither TRTLLM_PLUGIN_PATH is set nor is it directed to download the shared library. Please set either of the two to use TRT-LLM libraries in torchTRT"
10181023
)
10191024
return False
1020-
else:
1021-
platform = Platform.current_platform()
1022-
platform = str(platform).lower()
10231025

10241026
with download_plugin_lib_path(platform) as plugin_lib_path:
1025-
return load_and_initialize_trtllm_plugin(plugin_lib_path)
1027+
return load_and_initialize_trtllm_plugin(plugin_lib_path, platform)
10261028
return False

tests/py/dynamo/distributed/test_nccl_ops.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from distributed_utils import set_environment_variables_pytest
77
from parameterized import parameterized
88
from torch.testing._internal.common_utils import run_tests
9+
from torch_tensorrt._enums import Platform
910

1011
set_environment_variables_pytest()
1112
dist.init_process_group(backend="nccl", init_method="env://")
@@ -15,7 +16,12 @@
1516

1617
from conversion.harness import DispatchTestCase
1718

19+
platform_str = str(Platform.current_platform()).lower()
1820

21+
22+
@unittest.skipIf(
23+
"win" in platform_str, "Skipped on Windows: NCCL backend is not supported."
24+
)
1925
class TestGatherNcclOpsConverter(DispatchTestCase):
2026
@parameterized.expand([8])
2127
def test_nccl_ops(self, linear_layer_dim):

0 commit comments

Comments
 (0)