Skip to content

Commit bacf0f9

Browse files
authored
✨ HasDType, Array (#48)
* 🚚 move HasArrayNamespace * ✨Array class * ✨HasDType Signed-off-by: nstarman <nstarman@users.noreply.github.com>
1 parent 6d54ce7 commit bacf0f9

File tree

10 files changed

+248
-62
lines changed

10 files changed

+248
-62
lines changed

.github/workflows/ci.yml

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -88,11 +88,15 @@ jobs:
8888
python-version: "3.11"
8989
activate-environment: true
9090

91-
- name: get major numpy version
92-
id: numpy-major
91+
- name: get major.minor numpy version
92+
id: numpy-version
9393
run: |
94-
version=$(echo ${{ matrix.numpy-version }} | cut -c 1)
95-
echo "::set-output name=version::$version"
94+
version="${{ matrix.numpy-version }}"
95+
major=$(echo "$version" | cut -d. -f1)
96+
minor=$(echo "$version" | cut -d. -f2)
97+
98+
echo "major=$major" >> $GITHUB_OUTPUT
99+
echo "minor=$minor" >> $GITHUB_OUTPUT
96100
97101
- name: install deps
98102
run: |
@@ -101,10 +105,29 @@ jobs:
101105
102106
# NOTE: `uv run --with=...` will be ignored by mypy (and `--isolated` does not help)
103107
- name: mypy
104-
run: >
105-
uv run --no-sync --active
106-
mypy --tb --no-incremental --cache-dir=/dev/null
107-
tests/integration/test_numpy${{ steps.numpy-major.outputs.version }}.pyi
108+
run: |
109+
major="${{ steps.numpy-version.outputs.major }}"
110+
minor="${{ steps.numpy-version.outputs.minor }}"
111+
112+
# Directory containing versioned test files
113+
prefix="tests/integration"
114+
files=""
115+
116+
# Find all test files matching the current major version
117+
for path in $(find "$prefix" -name "test_numpy${major}p*.pyi"); do
118+
# Extract file name
119+
fname=$(basename "$path")
120+
# Parse the minor version from the filename
121+
fminor=$(echo "$fname" | sed -E "s/test_numpy${major}p([0-9]+)\.pyi/\1/")
122+
# Include files where minor version ≤ NumPy's minor
123+
if [ "$fminor" -le "$minor" ]; then
124+
files="$files $path"
125+
fi
126+
done
127+
128+
uv run --no-sync --active \
129+
mypy --tb --no-incremental --cache-dir=/dev/null \
130+
$files
108131
109132
# TODO: (based)pyright
110133

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,11 +121,16 @@ ignore = [
121121
"D107", # Missing docstring in __init__
122122
"D203", # 1 blank line required before class docstring
123123
"D213", # Multi-line docstring summary should start at the second line
124+
"D401", # First line should be in imperative mood
124125
"FBT", # flake8-boolean-trap
125126
"FIX", # flake8-fixme
126127
"ISC001", # Conflicts with formatter
128+
"PYI041", # Use `float` instead of `int | float`
127129
]
128130

131+
[tool.ruff.lint.pydocstyle]
132+
convention = "google"
133+
129134
[tool.ruff.lint.pylint]
130135
allow-dunder-method-names = [
131136
"__array_api_version__",

src/array_api_typing/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
"""Static typing support for the array API standard."""
22

33
__all__ = (
4+
"Array",
45
"HasArrayNamespace",
6+
"HasDType",
57
"__version__",
68
"__version_tuple__",
79
)
810

9-
from ._namespace import HasArrayNamespace
11+
from ._array import Array, HasArrayNamespace, HasDType
1012
from ._version import version as __version__, version_tuple as __version_tuple__

src/array_api_typing/_array.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
__all__ = (
2+
"Array",
3+
"HasArrayNamespace",
4+
)
5+
6+
from types import ModuleType
7+
from typing import Literal, Protocol
8+
from typing_extensions import TypeVar
9+
10+
NamespaceT_co = TypeVar("NamespaceT_co", covariant=True, default=ModuleType)
11+
DTypeT_co = TypeVar("DTypeT_co", covariant=True)
12+
13+
14+
class HasArrayNamespace(Protocol[NamespaceT_co]):
15+
"""Protocol for classes that have an `__array_namespace__` method.
16+
17+
This `Protocol` is intended for use in static typing to ensure that an
18+
object has an `__array_namespace__` method that returns a namespace for
19+
array operations. This `Protocol` should not be used at runtime for type
20+
checking or as a base class.
21+
22+
Example:
23+
>>> import array_api_typing as xpt
24+
>>>
25+
>>> class MyArray:
26+
... def __array_namespace__(self):
27+
... return object()
28+
>>>
29+
>>> x = MyArray()
30+
>>> def has_array_namespace(x: xpt.HasArrayNamespace) -> bool:
31+
... return hasattr(x, "__array_namespace__")
32+
>>> has_array_namespace(x)
33+
True
34+
35+
"""
36+
37+
def __array_namespace__(
38+
self, /, *, api_version: Literal["2021.12"] | None = None
39+
) -> NamespaceT_co:
40+
"""Returns an object that has all the array API functions on it.
41+
42+
Args:
43+
api_version: string representing the version of the array API
44+
specification to be returned, in 'YYYY.MM' form, for example,
45+
'2020.10'. If it is `None`, it should return the namespace
46+
corresponding to latest version of the array API specification.
47+
If the given version is invalid or not implemented for the given
48+
module, an error should be raised. Default: `None`.
49+
50+
Returns:
51+
NamespaceT_co: An object representing the array API namespace. It
52+
should have every top-level function defined in the
53+
specification as an attribute. It may contain other public names
54+
as well, but it is recommended to only include those names that
55+
are part of the specification.
56+
57+
"""
58+
...
59+
60+
61+
class HasDType(Protocol[DTypeT_co]):
62+
"""Protocol for array classes that have a data type attribute."""
63+
64+
@property
65+
def dtype(self, /) -> DTypeT_co:
66+
"""Data type of the array elements."""
67+
...
68+
69+
70+
class Array(
71+
HasArrayNamespace[NamespaceT_co],
72+
# ------ Attributes -------
73+
HasDType[DTypeT_co],
74+
# -------------------------
75+
Protocol[DTypeT_co, NamespaceT_co],
76+
):
77+
"""Array API specification for array object attributes and methods.
78+
79+
The type is: ``Array[+DTypeT, +NamespaceT = ModuleType] = Array[DTypeT,
80+
NamespaceT]`` where:
81+
82+
- `DTypeT` is the data type of the array elements.
83+
- `NamespaceT` is the type of the array namespace. It defaults to
84+
`ModuleType`, which is the most common form of array namespace (e.g.,
85+
`numpy`, `cupy`, etc.). However, it can be any type, e.g. a
86+
`types.SimpleNamespace`, to allow for wrapper libraries to
87+
semi-dynamically define their own array namespaces based on the wrapped
88+
array type.
89+
90+
This type is intended for use in static typing to ensure that an object has
91+
the attributes and methods defined in the array API specification. It should
92+
not be used at runtime for type checking or as a base class.
93+
94+
"""

src/array_api_typing/_namespace.py

Lines changed: 0 additions & 30 deletions
This file was deleted.

tests/integration/test_numpy1.pyi

Lines changed: 0 additions & 12 deletions
This file was deleted.

tests/integration/test_numpy1p0.pyi

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# mypy: disable-error-code="no-redef"
2+
3+
from types import ModuleType
4+
from typing import Any
5+
6+
import numpy.array_api as np # type: ignore[import-not-found, unused-ignore]
7+
from numpy import dtype
8+
9+
import array_api_typing as xpt
10+
11+
# Define NDArrays against which we can test the protocols
12+
# Note that `np.array_api` doesn't support boolean arrays.
13+
nparr = np.eye(2)
14+
nparr_i32 = np.asarray([1], dtype=np.int32)
15+
nparr_f32 = np.asarray([1.0], dtype=np.float32)
16+
17+
# =========================================================
18+
# `xpt.HasArrayNamespace`
19+
20+
_: xpt.HasArrayNamespace[ModuleType] = nparr
21+
_: xpt.HasArrayNamespace[ModuleType] = nparr_i32
22+
_: xpt.HasArrayNamespace[ModuleType] = nparr_f32
23+
24+
# Check `__array_namespace__` method
25+
a_ns: xpt.HasArrayNamespace[ModuleType] = nparr
26+
ns: ModuleType = a_ns.__array_namespace__()
27+
28+
# Incorrect values are caught when using `__array_namespace__` and
29+
# backpropagated to the type of `a_ns`
30+
_: xpt.HasArrayNamespace[dict[str, int]] = nparr # not caught
31+
32+
# =========================================================
33+
# `xpt.HasDType`
34+
35+
# Note that `np.array_api` uses dtype objects, not dtype classes, so we can't
36+
# type annotate specific dtypes like `np.float32` or `np.int32`.
37+
38+
_: xpt.HasDType[dtype[Any]] = nparr
39+
_: xpt.HasDType[dtype[Any]] = nparr_i32
40+
_: xpt.HasDType[dtype[Any]] = nparr_f32
41+
42+
# =========================================================
43+
# `xpt.Array`
44+
45+
# Check NamespaceT_co assignment
46+
a_ns: xpt.Array[Any, ModuleType] = nparr
47+
48+
# Check DTypeT_co assignment
49+
# Note that `np.array_api` uses dtype objects, not dtype classes, so we can't
50+
# type annotate specific dtypes like `np.float32` or `np.int32`.
51+
_: xpt.Array[dtype[Any]] = nparr
52+
_: xpt.Array[dtype[Any]] = nparr_i32
53+
_: xpt.Array[dtype[Any]] = nparr_f32

tests/integration/test_numpy2.pyi

Lines changed: 0 additions & 11 deletions
This file was deleted.

tests/integration/test_numpy2p0.pyi

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# mypy: disable-error-code="no-redef"
2+
3+
from types import ModuleType
4+
from typing import Any, TypeAlias
5+
6+
import numpy as np
7+
import numpy.typing as npt
8+
9+
import array_api_typing as xpt
10+
11+
# DType aliases
12+
F32: TypeAlias = np.float32
13+
I32: TypeAlias = np.int32
14+
15+
# Define NDArrays against which we can test the protocols
16+
nparr: npt.NDArray[Any]
17+
nparr_i32: npt.NDArray[I32]
18+
nparr_f32: npt.NDArray[F32]
19+
nparr_b: npt.NDArray[np.bool_]
20+
21+
# =========================================================
22+
# `xpt.HasArrayNamespace`
23+
24+
# Check assignment
25+
_: xpt.HasArrayNamespace[ModuleType] = nparr
26+
_: xpt.HasArrayNamespace[ModuleType] = nparr_i32
27+
_: xpt.HasArrayNamespace[ModuleType] = nparr_f32
28+
_: xpt.HasArrayNamespace[ModuleType] = nparr_b
29+
30+
# Check `__array_namespace__` method
31+
a_ns: xpt.HasArrayNamespace[ModuleType] = nparr
32+
ns: ModuleType = a_ns.__array_namespace__()
33+
34+
# Incorrect values are caught when using `__array_namespace__` and
35+
# backpropagated to the type of `a_ns`
36+
_: xpt.HasArrayNamespace[dict[str, int]] = nparr # not caught
37+
38+
# =========================================================
39+
# `xpt.HasDType`
40+
41+
# Check DTypeT_co assignment
42+
_: xpt.HasDType[Any] = nparr
43+
_: xpt.HasDType[np.dtype[I32]] = nparr_i32
44+
_: xpt.HasDType[np.dtype[F32]] = nparr_f32
45+
_: xpt.HasDType[np.dtype[np.bool_]] = nparr_b
46+
47+
# =========================================================
48+
# `xpt.Array`
49+
50+
# Check NamespaceT_co assignment
51+
a_ns: xpt.Array[Any, ModuleType] = nparr
52+
53+
# Check DTypeT_co assignment
54+
_: xpt.Array[Any] = nparr
55+
_: xpt.Array[np.dtype[I32]] = nparr_i32
56+
_: xpt.Array[np.dtype[F32]] = nparr_f32
57+
_: xpt.Array[np.dtype[np.bool_]] = nparr_b

tests/integration/test_numpy2p2.pyi

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from test_numpy2p0 import nparr
2+
3+
import array_api_typing as xpt
4+
5+
_: xpt.HasArrayNamespace[dict[str, int]] = nparr # type: ignore[assignment]

0 commit comments

Comments
 (0)