From ef80f44170fb5bebd44b811d2e572d9645a2c5e9 Mon Sep 17 00:00:00 2001 From: Corey Adams Date: Wed, 30 Apr 2025 07:51:48 -0700 Subject: [PATCH] This commit enables __init__ level gating of imports. - check_package_installed will return a simple bool if a package is installed or not. - use check_min_version to require a minimum installed version. - the __init__.py file in physicsnemo/datapipes/cae now wraps both datapipes in a function. - python 3.7 enabled __getattr__ at module level. (https://peps.python.org/pep-0562/) - this function now performs requirement checking on these imports before returning. - If an import fails, it will print an informative message about it in the ImportError. --- physicsnemo/datapipes/cae/__init__.py | 43 +++++++++++++++- physicsnemo/utils/version_check.py | 71 ++++++++++++++++++++++++++- pyproject.toml | 2 +- 3 files changed, 111 insertions(+), 5 deletions(-) diff --git a/physicsnemo/datapipes/cae/__init__.py b/physicsnemo/datapipes/cae/__init__.py index c0d17ff723..659ec7a30e 100644 --- a/physicsnemo/datapipes/cae/__init__.py +++ b/physicsnemo/datapipes/cae/__init__.py @@ -14,5 +14,44 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .domino_datapipe import DoMINODataPipe -from .mesh_datapipe import MeshDatapipe +from physicsnemo.utils.version_check import check_package_installed + +# This could and should be reoragnized into meta-data level requirements on pipelines. +domino_datapipe_requirements = ["warp", "scipy"] +mesh_datapipe_requirements = ["vtk"] + + +def __getattr__(name): + """ + This file is meant to provide information + """ + if name == "DoMINODataPipe": + missing = [ + p for p in domino_datapipe_requirements if not check_package_installed(p) + ] + + if missing: + raise ImportError( + f"Cannot import DoMINODataPipe: Missing required packages: {', '.join(missing)}" + ) + else: + from .domino_datapipe import DoMINODataPipe + + return DoMINODataPipe + + if name == "MeshDatapipe": + missing = [ + p for p in mesh_datapipe_requirements if not check_package_installed(p) + ] + if missing: + raise ImportError( + f"Cannot import MeshDatapipe: Missing required packages: {', '.join(missing)}" + ) + else: + from .mesh_datapipe import MeshDatapipe + + return MeshDatapipe + + raise AttributeError( + f"module 'physicsnemo.datapipes.cae' has no attribute '{name}'" + ) diff --git a/physicsnemo/utils/version_check.py b/physicsnemo/utils/version_check.py index 9964aa6e7a..bf3f8df93a 100644 --- a/physicsnemo/utils/version_check.py +++ b/physicsnemo/utils/version_check.py @@ -36,6 +36,32 @@ } +def check_package_installed(package_name: str, error_msg: Optional[str] = None) -> bool: + """ + Check if a package is installed. + + Args: + package_name: Name of the package to check + error_msg: Optional custom error message + + Returns: + True if package is installed + + Raises: + ImportError: If package is not installed + """ + try: + importlib.import_module(package_name) + return True + except ImportError: + msg = ( + error_msg + or f"Package {package_name} is required but not installed, or broken." + ) + print(msg) + return False + + def check_min_version( package_name: str, min_version: str, error_msg: Optional[str] = None ) -> bool: @@ -86,7 +112,44 @@ def check_module_requirements(module_path: str) -> None: check_min_version(package, min_version) -def require_version(package_name: str, min_version: str): +def require_installed_package(package_name: str, error_msg: Optional[str] = None): + """ + Decorator that prevents a function from being called unless the + specified package is installed. + + Args: + package_name: Name of the package to check + error_msg: Optional custom error message + + Returns: + Decorator function that checks package installation before execution + + Example: + @require_installed_package("torch") + def my_function(): + # This function will only execute if torch is installed + pass + """ + + def decorator(func): + import functools + + @functools.wraps(func) + def wrapper(*args, **kwargs): + # Verify the package is installed before executing + check_package_installed(package_name, error_msg) + + # If we get here, package is installed + return func(*args, **kwargs) + + return wrapper + + return decorator + + +def require_version( + package_name: str, min_version: str, error_msg: Optional[str] = None +): """ Decorator that prevents a function from being called unless the specified package meets the minimum version requirement. @@ -94,6 +157,7 @@ def require_version(package_name: str, min_version: str): Args: package_name: Name of the package to check min_version: Minimum required version string (e.g. '2.3') + error_msg: Optional custom error message Returns: Decorator function that checks version requirement before execution @@ -110,8 +174,11 @@ def decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): + # Verify the package is installed + check_package_installed(package_name) + # Verify the package meets minimum version before executing - check_min_version(package_name, min_version) + check_min_version(package_name, min_version, error_msg) # If we get here, version check passed return func(*args, **kwargs) diff --git a/pyproject.toml b/pyproject.toml index d86393f4e9..6765752387 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ authors = [ description = "A deep learning framework for AI-driven multi-physics systems" readme = "README.md" requires-python = ">=3.10" -license = "Apache-2.0" +license = {file="LICENSE.txt"} dependencies = [ "certifi>=2023.7.22", "fsspec>=2023.1.0",