diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c25779f..4d75210 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -88,11 +88,15 @@ jobs: python-version: "3.11" activate-environment: true - - name: get major numpy version - id: numpy-major + - name: get major.minor numpy version + id: numpy-version run: | - version=$(echo ${{ matrix.numpy-version }} | cut -c 1) - echo "::set-output name=version::$version" + version="${{ matrix.numpy-version }}" + major=$(echo "$version" | cut -d. -f1) + minor=$(echo "$version" | cut -d. -f2) + + echo "major=$major" >> $GITHUB_OUTPUT + echo "minor=$minor" >> $GITHUB_OUTPUT - name: install deps run: | @@ -101,10 +105,29 @@ jobs: # NOTE: `uv run --with=...` will be ignored by mypy (and `--isolated` does not help) - name: mypy - run: > - uv run --no-sync --active - mypy --tb --no-incremental --cache-dir=/dev/null - tests/integration/test_numpy${{ steps.numpy-major.outputs.version }}.pyi + run: | + major="${{ steps.numpy-version.outputs.major }}" + minor="${{ steps.numpy-version.outputs.minor }}" + + # Directory containing versioned test files + prefix="tests/integration" + files="" + + # Find all test files matching the current major version + for path in $(find "$prefix" -name "test_numpy${major}p*.pyi"); do + # Extract file name + fname=$(basename "$path") + # Parse the minor version from the filename + fminor=$(echo "$fname" | sed -E "s/test_numpy${major}p([0-9]+)\.pyi/\1/") + # Include files where minor version ≤ NumPy's minor + if [ "$fminor" -le "$minor" ]; then + files="$files $path" + fi + done + + uv run --no-sync --active \ + mypy --tb --no-incremental --cache-dir=/dev/null \ + $files # TODO: (based)pyright diff --git a/pyproject.toml b/pyproject.toml index 02118fe..2f9f0bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -121,11 +121,16 @@ ignore = [ "D107", # Missing docstring in __init__ "D203", # 1 blank line required before class docstring "D213", # Multi-line docstring summary should start at the second line + "D401", # First line should be in imperative mood "FBT", # flake8-boolean-trap "FIX", # flake8-fixme "ISC001", # Conflicts with formatter + "PYI041", # Use `float` instead of `int | float` ] +[tool.ruff.lint.pydocstyle] +convention = "google" + [tool.ruff.lint.pylint] allow-dunder-method-names = [ "__array_api_version__", diff --git a/src/array_api_typing/__init__.py b/src/array_api_typing/__init__.py index 3532743..aea21af 100644 --- a/src/array_api_typing/__init__.py +++ b/src/array_api_typing/__init__.py @@ -1,10 +1,12 @@ """Static typing support for the array API standard.""" __all__ = ( + "Array", "HasArrayNamespace", + "HasDType", "__version__", "__version_tuple__", ) -from ._namespace import HasArrayNamespace +from ._array import Array, HasArrayNamespace, HasDType from ._version import version as __version__, version_tuple as __version_tuple__ diff --git a/src/array_api_typing/_array.py b/src/array_api_typing/_array.py new file mode 100644 index 0000000..f4614f2 --- /dev/null +++ b/src/array_api_typing/_array.py @@ -0,0 +1,94 @@ +__all__ = ( + "Array", + "HasArrayNamespace", +) + +from types import ModuleType +from typing import Literal, Protocol +from typing_extensions import TypeVar + +NamespaceT_co = TypeVar("NamespaceT_co", covariant=True, default=ModuleType) +DTypeT_co = TypeVar("DTypeT_co", covariant=True) + + +class HasArrayNamespace(Protocol[NamespaceT_co]): + """Protocol for classes that have an `__array_namespace__` method. + + This `Protocol` is intended for use in static typing to ensure that an + object has an `__array_namespace__` method that returns a namespace for + array operations. This `Protocol` should not be used at runtime for type + checking or as a base class. + + Example: + >>> import array_api_typing as xpt + >>> + >>> class MyArray: + ... def __array_namespace__(self): + ... return object() + >>> + >>> x = MyArray() + >>> def has_array_namespace(x: xpt.HasArrayNamespace) -> bool: + ... return hasattr(x, "__array_namespace__") + >>> has_array_namespace(x) + True + + """ + + def __array_namespace__( + self, /, *, api_version: Literal["2021.12"] | None = None + ) -> NamespaceT_co: + """Returns an object that has all the array API functions on it. + + Args: + api_version: string representing the version of the array API + specification to be returned, in 'YYYY.MM' form, for example, + '2020.10'. If it is `None`, it should return the namespace + corresponding to latest version of the array API specification. + If the given version is invalid or not implemented for the given + module, an error should be raised. Default: `None`. + + Returns: + NamespaceT_co: An object representing the array API namespace. It + should have every top-level function defined in the + specification as an attribute. It may contain other public names + as well, but it is recommended to only include those names that + are part of the specification. + + """ + ... + + +class HasDType(Protocol[DTypeT_co]): + """Protocol for array classes that have a data type attribute.""" + + @property + def dtype(self, /) -> DTypeT_co: + """Data type of the array elements.""" + ... + + +class Array( + HasArrayNamespace[NamespaceT_co], + # ------ Attributes ------- + HasDType[DTypeT_co], + # ------------------------- + Protocol[DTypeT_co, NamespaceT_co], +): + """Array API specification for array object attributes and methods. + + The type is: ``Array[+DTypeT, +NamespaceT = ModuleType] = Array[DTypeT, + NamespaceT]`` where: + + - `DTypeT` is the data type of the array elements. + - `NamespaceT` is the type of the array namespace. It defaults to + `ModuleType`, which is the most common form of array namespace (e.g., + `numpy`, `cupy`, etc.). However, it can be any type, e.g. a + `types.SimpleNamespace`, to allow for wrapper libraries to + semi-dynamically define their own array namespaces based on the wrapped + array type. + + This type is intended for use in static typing to ensure that an object has + the attributes and methods defined in the array API specification. It should + not be used at runtime for type checking or as a base class. + + """ diff --git a/src/array_api_typing/_namespace.py b/src/array_api_typing/_namespace.py deleted file mode 100644 index 2074f4e..0000000 --- a/src/array_api_typing/_namespace.py +++ /dev/null @@ -1,30 +0,0 @@ -__all__ = ("HasArrayNamespace",) - -from types import ModuleType -from typing import Literal, Protocol -from typing_extensions import TypeVar - -T_co = TypeVar("T_co", covariant=True, default=ModuleType) - - -class HasArrayNamespace(Protocol[T_co]): - """Protocol for classes that have an `__array_namespace__` method. - - Example: - >>> import array_api_typing as xpt - >>> - >>> class MyArray: - ... def __array_namespace__(self): - ... return object() - >>> - >>> x = MyArray() - >>> def has_array_namespace(x: xpt.HasArrayNamespace) -> bool: - ... return hasattr(x, "__array_namespace__") - >>> has_array_namespace(x) - True - - """ - - def __array_namespace__( - self, /, *, api_version: Literal["2021.12"] | None = None - ) -> T_co: ... diff --git a/tests/integration/test_numpy1.pyi b/tests/integration/test_numpy1.pyi deleted file mode 100644 index 9367379..0000000 --- a/tests/integration/test_numpy1.pyi +++ /dev/null @@ -1,12 +0,0 @@ -from typing import Any - -# requires numpy < 2 -import numpy.array_api as np - -import array_api_typing as xpt - -### -# Ensure that `np.ndarray` instances are assignable to `xpt.HasArrayNamespace`. - -arr = np.eye(2) -arr_namespace: xpt.HasArrayNamespace[Any] = arr diff --git a/tests/integration/test_numpy1p0.pyi b/tests/integration/test_numpy1p0.pyi new file mode 100644 index 0000000..efa859c --- /dev/null +++ b/tests/integration/test_numpy1p0.pyi @@ -0,0 +1,53 @@ +# mypy: disable-error-code="no-redef" + +from types import ModuleType +from typing import Any + +import numpy.array_api as np # type: ignore[import-not-found, unused-ignore] +from numpy import dtype + +import array_api_typing as xpt + +# Define NDArrays against which we can test the protocols +# Note that `np.array_api` doesn't support boolean arrays. +nparr = np.eye(2) +nparr_i32 = np.asarray([1], dtype=np.int32) +nparr_f32 = np.asarray([1.0], dtype=np.float32) + +# ========================================================= +# `xpt.HasArrayNamespace` + +_: xpt.HasArrayNamespace[ModuleType] = nparr +_: xpt.HasArrayNamespace[ModuleType] = nparr_i32 +_: xpt.HasArrayNamespace[ModuleType] = nparr_f32 + +# Check `__array_namespace__` method +a_ns: xpt.HasArrayNamespace[ModuleType] = nparr +ns: ModuleType = a_ns.__array_namespace__() + +# Incorrect values are caught when using `__array_namespace__` and +# backpropagated to the type of `a_ns` +_: xpt.HasArrayNamespace[dict[str, int]] = nparr # not caught + +# ========================================================= +# `xpt.HasDType` + +# Note that `np.array_api` uses dtype objects, not dtype classes, so we can't +# type annotate specific dtypes like `np.float32` or `np.int32`. + +_: xpt.HasDType[dtype[Any]] = nparr +_: xpt.HasDType[dtype[Any]] = nparr_i32 +_: xpt.HasDType[dtype[Any]] = nparr_f32 + +# ========================================================= +# `xpt.Array` + +# Check NamespaceT_co assignment +a_ns: xpt.Array[Any, ModuleType] = nparr + +# Check DTypeT_co assignment +# Note that `np.array_api` uses dtype objects, not dtype classes, so we can't +# type annotate specific dtypes like `np.float32` or `np.int32`. +_: xpt.Array[dtype[Any]] = nparr +_: xpt.Array[dtype[Any]] = nparr_i32 +_: xpt.Array[dtype[Any]] = nparr_f32 diff --git a/tests/integration/test_numpy2.pyi b/tests/integration/test_numpy2.pyi deleted file mode 100644 index 64bed8c..0000000 --- a/tests/integration/test_numpy2.pyi +++ /dev/null @@ -1,11 +0,0 @@ -from typing import Any - -import numpy.typing as npt - -import array_api_typing as xpt - -### -# Ensure that `np.ndarray` instances are assignable to `xpt.HasArrayNamespace`. - -arr: npt.NDArray[Any] -arr_namespace: xpt.HasArrayNamespace[Any] = arr diff --git a/tests/integration/test_numpy2p0.pyi b/tests/integration/test_numpy2p0.pyi new file mode 100644 index 0000000..4f2ec16 --- /dev/null +++ b/tests/integration/test_numpy2p0.pyi @@ -0,0 +1,57 @@ +# mypy: disable-error-code="no-redef" + +from types import ModuleType +from typing import Any, TypeAlias + +import numpy as np +import numpy.typing as npt + +import array_api_typing as xpt + +# DType aliases +F32: TypeAlias = np.float32 +I32: TypeAlias = np.int32 + +# Define NDArrays against which we can test the protocols +nparr: npt.NDArray[Any] +nparr_i32: npt.NDArray[I32] +nparr_f32: npt.NDArray[F32] +nparr_b: npt.NDArray[np.bool_] + +# ========================================================= +# `xpt.HasArrayNamespace` + +# Check assignment +_: xpt.HasArrayNamespace[ModuleType] = nparr +_: xpt.HasArrayNamespace[ModuleType] = nparr_i32 +_: xpt.HasArrayNamespace[ModuleType] = nparr_f32 +_: xpt.HasArrayNamespace[ModuleType] = nparr_b + +# Check `__array_namespace__` method +a_ns: xpt.HasArrayNamespace[ModuleType] = nparr +ns: ModuleType = a_ns.__array_namespace__() + +# Incorrect values are caught when using `__array_namespace__` and +# backpropagated to the type of `a_ns` +_: xpt.HasArrayNamespace[dict[str, int]] = nparr # not caught + +# ========================================================= +# `xpt.HasDType` + +# Check DTypeT_co assignment +_: xpt.HasDType[Any] = nparr +_: xpt.HasDType[np.dtype[I32]] = nparr_i32 +_: xpt.HasDType[np.dtype[F32]] = nparr_f32 +_: xpt.HasDType[np.dtype[np.bool_]] = nparr_b + +# ========================================================= +# `xpt.Array` + +# Check NamespaceT_co assignment +a_ns: xpt.Array[Any, ModuleType] = nparr + +# Check DTypeT_co assignment +_: xpt.Array[Any] = nparr +_: xpt.Array[np.dtype[I32]] = nparr_i32 +_: xpt.Array[np.dtype[F32]] = nparr_f32 +_: xpt.Array[np.dtype[np.bool_]] = nparr_b diff --git a/tests/integration/test_numpy2p2.pyi b/tests/integration/test_numpy2p2.pyi new file mode 100644 index 0000000..e667f6e --- /dev/null +++ b/tests/integration/test_numpy2p2.pyi @@ -0,0 +1,5 @@ +from test_numpy2p0 import nparr + +import array_api_typing as xpt + +_: xpt.HasArrayNamespace[dict[str, int]] = nparr # type: ignore[assignment]