Skip to content

Gating on Datapipes - prototype #861

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 41 additions & 2 deletions physicsnemo/datapipes/cae/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,44 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .domino_datapipe import DoMINODataPipe
from .mesh_datapipe import MeshDatapipe
from physicsnemo.utils.version_check import check_package_installed

# This could and should be reoragnized into meta-data level requirements on pipelines.
domino_datapipe_requirements = ["warp", "scipy"]
mesh_datapipe_requirements = ["vtk"]


def __getattr__(name):
"""
This file is meant to provide information
"""
if name == "DoMINODataPipe":
missing = [
p for p in domino_datapipe_requirements if not check_package_installed(p)
]

if missing:
raise ImportError(
f"Cannot import DoMINODataPipe: Missing required packages: {', '.join(missing)}"
)
else:
from .domino_datapipe import DoMINODataPipe

return DoMINODataPipe

if name == "MeshDatapipe":
missing = [
p for p in mesh_datapipe_requirements if not check_package_installed(p)
]
if missing:
raise ImportError(
f"Cannot import MeshDatapipe: Missing required packages: {', '.join(missing)}"
)
else:
from .mesh_datapipe import MeshDatapipe

return MeshDatapipe

raise AttributeError(
f"module 'physicsnemo.datapipes.cae' has no attribute '{name}'"
)
71 changes: 69 additions & 2 deletions physicsnemo/utils/version_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,32 @@
}


def check_package_installed(package_name: str, error_msg: Optional[str] = None) -> bool:
"""
Check if a package is installed.

Args:
package_name: Name of the package to check
error_msg: Optional custom error message

Returns:
True if package is installed

Raises:
ImportError: If package is not installed
"""
try:
importlib.import_module(package_name)
return True
except ImportError:
msg = (
error_msg
or f"Package {package_name} is required but not installed, or broken."
)
print(msg)
return False


def check_min_version(
package_name: str, min_version: str, error_msg: Optional[str] = None
) -> bool:
Expand Down Expand Up @@ -86,14 +112,52 @@ def check_module_requirements(module_path: str) -> None:
check_min_version(package, min_version)


def require_version(package_name: str, min_version: str):
def require_installed_package(package_name: str, error_msg: Optional[str] = None):
"""
Decorator that prevents a function from being called unless the
specified package is installed.

Args:
package_name: Name of the package to check
error_msg: Optional custom error message

Returns:
Decorator function that checks package installation before execution

Example:
@require_installed_package("torch")
def my_function():
# This function will only execute if torch is installed
pass
"""

def decorator(func):
import functools

@functools.wraps(func)
def wrapper(*args, **kwargs):
# Verify the package is installed before executing
check_package_installed(package_name, error_msg)

# If we get here, package is installed
return func(*args, **kwargs)

return wrapper

return decorator


def require_version(
package_name: str, min_version: str, error_msg: Optional[str] = None
):
"""
Decorator that prevents a function from being called unless the
specified package meets the minimum version requirement.

Args:
package_name: Name of the package to check
min_version: Minimum required version string (e.g. '2.3')
error_msg: Optional custom error message

Returns:
Decorator function that checks version requirement before execution
Expand All @@ -110,8 +174,11 @@ def decorator(func):

@functools.wraps(func)
def wrapper(*args, **kwargs):
# Verify the package is installed
check_package_installed(package_name)

# Verify the package meets minimum version before executing
check_min_version(package_name, min_version)
check_min_version(package_name, min_version, error_msg)

# If we get here, version check passed
return func(*args, **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ authors = [
description = "A deep learning framework for AI-driven multi-physics systems"
readme = "README.md"
requires-python = ">=3.10"
license = "Apache-2.0"
license = {file="LICENSE.txt"}
dependencies = [
"certifi>=2023.7.22",
"fsspec>=2023.1.0",
Expand Down