Skip to content

Commit eaa42ce

Browse files
committed
✨Array class
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
1 parent 255f453 commit eaa42ce

File tree

4 files changed

+25
-2
lines changed

4 files changed

+25
-2
lines changed

src/array_api_typing/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
"""Static typing support for the array API standard."""
22

33
__all__ = (
4+
"Array",
45
"HasArrayNamespace",
56
"__version__",
67
"__version_tuple__",
78
)
89

9-
from ._array import HasArrayNamespace
10+
from ._array import Array, HasArrayNamespace
1011
from ._version import version as __version__, version_tuple as __version_tuple__

src/array_api_typing/_array.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
__all__ = ("HasArrayNamespace",)
1+
__all__ = (
2+
"Array",
3+
"HasArrayNamespace",
4+
)
25

36
from types import ModuleType
47
from typing import Literal, Protocol
@@ -52,3 +55,10 @@ def __array_namespace__(
5255
5356
"""
5457
...
58+
59+
60+
class Array(
61+
HasArrayNamespace[NamespaceT_co],
62+
Protocol[NamespaceT_co],
63+
):
64+
"""Array API specification for array object attributes and methods."""

tests/integration/test_numpy1p0.pyi

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,9 @@ ns: ModuleType = a_ns.__array_namespace__()
2727
# Incorrect values are caught when using `__array_namespace__` and
2828
# backpropagated to the type of `a_ns`
2929
_: xpt.HasArrayNamespace[dict[str, int]] = nparr # not caught
30+
31+
# =========================================================
32+
# `xpt.Array`
33+
34+
# Check NamespaceT_co assignment
35+
a_ns: xpt.Array[ModuleType] = nparr

tests/integration/test_numpy2p0.pyi

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,9 @@ ns: ModuleType = a_ns.__array_namespace__()
3434
# Incorrect values are caught when using `__array_namespace__` and
3535
# backpropagated to the type of `a_ns`
3636
_: xpt.HasArrayNamespace[dict[str, int]] = nparr # not caught
37+
38+
# =========================================================
39+
# `xpt.Array`
40+
41+
# Check NamespaceT_co assignment
42+
a_ns: xpt.Array[ModuleType] = nparr

0 commit comments

Comments
 (0)