Skip to content

Commit 0e90681

Browse files
committed
✨HasDType
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
1 parent 65256af commit 0e90681

File tree

4 files changed

+54
-6
lines changed

4 files changed

+54
-6
lines changed

src/array_api_typing/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
__all__ = (
44
"Array",
55
"HasArrayNamespace",
6+
"HasDType",
67
"__version__",
78
"__version_tuple__",
89
)
910

10-
from ._array import Array, HasArrayNamespace
11+
from ._array import Array, HasArrayNamespace, HasDType
1112
from ._version import version as __version__, version_tuple as __version_tuple__

src/array_api_typing/_array.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing_extensions import TypeVar
99

1010
NamespaceT_co = TypeVar("NamespaceT_co", covariant=True, default=ModuleType)
11+
DTypeT_co = TypeVar("DTypeT_co", covariant=True)
1112

1213

1314
class HasArrayNamespace(Protocol[NamespaceT_co]):
@@ -57,15 +58,28 @@ def __array_namespace__(
5758
...
5859

5960

61+
class HasDType(Protocol[DTypeT_co]):
62+
"""Protocol for array classes that have a data type attribute."""
63+
64+
@property
65+
def dtype(self, /) -> DTypeT_co:
66+
"""Data type of the array elements."""
67+
...
68+
69+
6070
class Array(
6171
HasArrayNamespace[NamespaceT_co],
72+
# ------ Attributes -------
73+
HasDType[DTypeT_co],
6274
# -------------------------
63-
Protocol[NamespaceT_co],
75+
Protocol[DTypeT_co, NamespaceT_co],
6476
):
6577
"""Array API specification for array object attributes and methods.
6678
67-
The type is: ``Array[+NamespaceT = ModuleType] = Array[NamespaceT]`` where:
79+
The type is: ``Array[+DTypeT, +NamespaceT = ModuleType] = Array[DTypeT,
80+
NamespaceT]`` where:
6881
82+
- `DTypeT` is the data type of the array elements.
6983
- `NamespaceT` is the type of the array namespace. It defaults to
7084
`ModuleType`, which is the most common form of array namespace (e.g.,
7185
`numpy`, `cupy`, etc.). However, it can be any type, e.g. a

tests/integration/test_numpy1p0.pyi

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
# mypy: disable-error-code="no-redef"
22

33
from types import ModuleType
4-
from typing import TypeAlias
4+
from typing import Any
55

66
import numpy.array_api as np # type: ignore[import-not-found, unused-ignore]
7+
from numpy import dtype
78

89
import array_api_typing as xpt
910

@@ -28,8 +29,25 @@ ns: ModuleType = a_ns.__array_namespace__()
2829
# backpropagated to the type of `a_ns`
2930
_: xpt.HasArrayNamespace[dict[str, int]] = nparr # not caught
3031

32+
# =========================================================
33+
# `xpt.HasDType`
34+
35+
# Note that `np.array_api` uses dtype objects, not dtype classes, so we can't
36+
# type annotate specific dtypes like `np.float32` or `np.int32`.
37+
38+
_: xpt.HasDType[dtype[Any]] = nparr
39+
_: xpt.HasDType[dtype[Any]] = nparr_i32
40+
_: xpt.HasDType[dtype[Any]] = nparr_f32
41+
3142
# =========================================================
3243
# `xpt.Array`
3344

3445
# Check NamespaceT_co assignment
35-
a_ns: xpt.Array[ModuleType] = nparr
46+
a_ns: xpt.Array[Any, ModuleType] = nparr
47+
48+
# Check DTypeT_co assignment
49+
# Note that `np.array_api` uses dtype objects, not dtype classes, so we can't
50+
# type annotate specific dtypes like `np.float32` or `np.int32`.
51+
_: xpt.Array[dtype[Any]] = nparr
52+
_: xpt.Array[dtype[Any]] = nparr_i32
53+
_: xpt.Array[dtype[Any]] = nparr_f32

tests/integration/test_numpy2p0.pyi

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,23 @@ ns: ModuleType = a_ns.__array_namespace__()
3535
# backpropagated to the type of `a_ns`
3636
_: xpt.HasArrayNamespace[dict[str, int]] = nparr # not caught
3737

38+
# =========================================================
39+
# `xpt.HasDType`
40+
41+
# Check DTypeT_co assignment
42+
_: xpt.HasDType[Any] = nparr
43+
_: xpt.HasDType[np.dtype[I32]] = nparr_i32
44+
_: xpt.HasDType[np.dtype[F32]] = nparr_f32
45+
_: xpt.HasDType[np.dtype[np.bool_]] = nparr_b
46+
3847
# =========================================================
3948
# `xpt.Array`
4049

4150
# Check NamespaceT_co assignment
42-
a_ns: xpt.Array[ModuleType] = nparr
51+
a_ns: xpt.Array[Any, ModuleType] = nparr
52+
53+
# Check DTypeT_co assignment
54+
_: xpt.Array[Any] = nparr
55+
_: xpt.Array[np.dtype[I32]] = nparr_i32
56+
_: xpt.Array[np.dtype[F32]] = nparr_f32
57+
_: xpt.Array[np.dtype[np.bool_]] = nparr_b

0 commit comments

Comments
 (0)