@@ -841,6 +841,29 @@ def is_tegra_platform() -> bool:
841
841
return False
842
842
843
843
844
+ def is_platform_supported_for_trtllm (platform : str ) -> bool :
845
+ """
846
+ Checks if the current platform supports TensorRT-LLM plugins for NCCL backend
847
+ Returns:
848
+ bool: True if the platform supports TensorRT-LLM plugins for NCCL backend, False otherwise.
849
+ Note:
850
+ TensorRT-LLM plugins for NCCL backend are not supported on:
851
+ - Windows platforms
852
+ - Jetson devices (aarch64 architecture)
853
+ """
854
+ if "windows" in platform :
855
+ logger .info (
856
+ "TensorRT-LLM plugins for NCCL backend are not supported on Windows"
857
+ )
858
+ return False
859
+ if "aarch64" in platform :
860
+ logger .info (
861
+ "TensorRT-LLM plugins for NCCL backend are not supported on Jetson devices (aarch64)"
862
+ )
863
+ return False
864
+ return True
865
+
866
+
844
867
@contextmanager
845
868
def download_plugin_lib_path (platform : str ) -> Iterator [str ]:
846
869
"""
@@ -891,6 +914,7 @@ def download_plugin_lib_path(platform: str) -> Iterator[str]:
891
914
if "linux" in platform :
892
915
lib_filename = "libnvinfer_plugin_tensorrt_llm.so"
893
916
else :
917
+ # This condition is never met though
894
918
lib_filename = "libnvinfer_plugin_tensorrt_llm.dll"
895
919
896
920
with tempfile .TemporaryDirectory () as tmpdir :
@@ -923,7 +947,7 @@ def download_plugin_lib_path(platform: str) -> Iterator[str]:
923
947
yield plugin_lib_path
924
948
925
949
926
- def load_and_initialize_trtllm_plugin (plugin_lib_path : str , platform : str ) -> bool :
950
+ def load_and_initialize_trtllm_plugin (plugin_lib_path : str ) -> bool :
927
951
"""
928
952
Loads and initializes the TensorRT-LLM plugin from the given shared library path.
929
953
@@ -933,9 +957,6 @@ def load_and_initialize_trtllm_plugin(plugin_lib_path: str, platform: str) -> bo
933
957
Returns:
934
958
bool: True if successful, False otherwise.
935
959
"""
936
- if "windows" in platform :
937
- logger .info ("NCCL backend is not supported on Windows" )
938
- return False
939
960
try :
940
961
handle = ctypes .CDLL (plugin_lib_path )
941
962
logger .info (f"Successfully loaded plugin library: { plugin_lib_path } " )
@@ -981,7 +1002,7 @@ def load_and_initialize_trtllm_plugin(plugin_lib_path: str, platform: str) -> bo
981
1002
return False
982
1003
983
1004
984
- def load_tensorrt_llm () -> bool :
1005
+ def load_tensorrt_llm_for_nccl () -> bool :
985
1006
"""
986
1007
Attempts to load the TensorRT-LLM plugin and initialize it.
987
1008
Either the env variable TRTLLM_PLUGINS_PATH can specify the path
@@ -990,11 +1011,15 @@ def load_tensorrt_llm() -> bool:
990
1011
Returns:
991
1012
bool: True if the plugin was successfully loaded and initialized, False otherwise.
992
1013
"""
993
- plugin_lib_path = os . environ . get ( "TRTLLM_PLUGINS_PATH" )
1014
+ # Check platform compatibility first
994
1015
platform = Platform .current_platform ()
995
1016
platform = str (platform ).lower ()
1017
+ if not is_platform_supported_for_trtllm (platform ):
1018
+ return False
1019
+ plugin_lib_path = os .environ .get ("TRTLLM_PLUGINS_PATH" )
1020
+
996
1021
if plugin_lib_path :
997
- return load_and_initialize_trtllm_plugin (plugin_lib_path , platform )
1022
+ return load_and_initialize_trtllm_plugin (plugin_lib_path )
998
1023
else :
999
1024
# this option can be used by user if TRTLLM_PLUGINS_PATH is not set by user
1000
1025
use_trtllm_plugin = os .environ .get ("USE_TRTLLM_PLUGINS" , "0" ).lower () in (
@@ -1010,5 +1035,5 @@ def load_tensorrt_llm() -> bool:
1010
1035
return False
1011
1036
1012
1037
with download_plugin_lib_path (platform ) as plugin_lib_path :
1013
- return load_and_initialize_trtllm_plugin (plugin_lib_path , platform )
1038
+ return load_and_initialize_trtllm_plugin (plugin_lib_path )
1014
1039
return False
0 commit comments