Skip to content

✨ HasDType, Array #48

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 31 additions & 8 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,15 @@ jobs:
python-version: "3.11"
activate-environment: true

- name: get major numpy version
id: numpy-major
- name: get major.minor numpy version
id: numpy-version
run: |
version=$(echo ${{ matrix.numpy-version }} | cut -c 1)
echo "::set-output name=version::$version"
version="${{ matrix.numpy-version }}"
major=$(echo "$version" | cut -d. -f1)
minor=$(echo "$version" | cut -d. -f2)

echo "major=$major" >> $GITHUB_OUTPUT
echo "minor=$minor" >> $GITHUB_OUTPUT

- name: install deps
run: |
Expand All @@ -101,10 +105,29 @@ jobs:

# NOTE: `uv run --with=...` will be ignored by mypy (and `--isolated` does not help)
- name: mypy
run: >
uv run --no-sync --active
mypy --tb --no-incremental --cache-dir=/dev/null
tests/integration/test_numpy${{ steps.numpy-major.outputs.version }}.pyi
run: |
major="${{ steps.numpy-version.outputs.major }}"
minor="${{ steps.numpy-version.outputs.minor }}"

# Directory containing versioned test files
prefix="tests/integration"
files=""

# Find all test files matching the current major version
for path in $(find "$prefix" -name "test_numpy${major}p*.pyi"); do
# Extract file name
fname=$(basename "$path")
# Parse the minor version from the filename
fminor=$(echo "$fname" | sed -E "s/test_numpy${major}p([0-9]+)\.pyi/\1/")
# Include files where minor version ≤ NumPy's minor
if [ "$fminor" -le "$minor" ]; then
files="$files $path"
fi
done

uv run --no-sync --active \
mypy --tb --no-incremental --cache-dir=/dev/null \
$files

# TODO: (based)pyright

Expand Down
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,16 @@ ignore = [
"D107", # Missing docstring in __init__
"D203", # 1 blank line required before class docstring
"D213", # Multi-line docstring summary should start at the second line
"D401", # First line should be in imperative mood
"FBT", # flake8-boolean-trap
"FIX", # flake8-fixme
"ISC001", # Conflicts with formatter
"PYI041", # Use `float` instead of `int | float`
]

[tool.ruff.lint.pydocstyle]
convention = "google"

[tool.ruff.lint.pylint]
allow-dunder-method-names = [
"__array_api_version__",
Expand Down
4 changes: 3 additions & 1 deletion src/array_api_typing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Static typing support for the array API standard."""

__all__ = (
"Array",
"HasArrayNamespace",
"HasDType",
"__version__",
"__version_tuple__",
)

from ._namespace import HasArrayNamespace
from ._array import Array, HasArrayNamespace, HasDType
from ._version import version as __version__, version_tuple as __version_tuple__
94 changes: 94 additions & 0 deletions src/array_api_typing/_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
__all__ = (
"Array",
"HasArrayNamespace",
)

from types import ModuleType
from typing import Literal, Protocol
from typing_extensions import TypeVar

NamespaceT_co = TypeVar("NamespaceT_co", covariant=True, default=ModuleType)
DTypeT_co = TypeVar("DTypeT_co", covariant=True)


class HasArrayNamespace(Protocol[NamespaceT_co]):
"""Protocol for classes that have an `__array_namespace__` method.

This `Protocol` is intended for use in static typing to ensure that an
object has an `__array_namespace__` method that returns a namespace for
array operations. This `Protocol` should not be used at runtime for type
checking or as a base class.

Example:
>>> import array_api_typing as xpt
>>>
>>> class MyArray:
... def __array_namespace__(self):
... return object()
>>>
>>> x = MyArray()
>>> def has_array_namespace(x: xpt.HasArrayNamespace) -> bool:
... return hasattr(x, "__array_namespace__")
>>> has_array_namespace(x)
True

"""

def __array_namespace__(
self, /, *, api_version: Literal["2021.12"] | None = None
) -> NamespaceT_co:
"""Returns an object that has all the array API functions on it.

Args:
api_version: string representing the version of the array API
specification to be returned, in 'YYYY.MM' form, for example,
'2020.10'. If it is `None`, it should return the namespace
corresponding to latest version of the array API specification.
If the given version is invalid or not implemented for the given
module, an error should be raised. Default: `None`.

Returns:
NamespaceT_co: An object representing the array API namespace. It
should have every top-level function defined in the
specification as an attribute. It may contain other public names
as well, but it is recommended to only include those names that
are part of the specification.

"""
...


class HasDType(Protocol[DTypeT_co]):
"""Protocol for array classes that have a data type attribute."""

@property
def dtype(self, /) -> DTypeT_co:
"""Data type of the array elements."""
...


class Array(
HasArrayNamespace[NamespaceT_co],
# ------ Attributes -------
HasDType[DTypeT_co],
# -------------------------
Protocol[DTypeT_co, NamespaceT_co],
):
"""Array API specification for array object attributes and methods.

The type is: ``Array[+DTypeT, +NamespaceT = ModuleType] = Array[DTypeT,
NamespaceT]`` where:

- `DTypeT` is the data type of the array elements.
- `NamespaceT` is the type of the array namespace. It defaults to
`ModuleType`, which is the most common form of array namespace (e.g.,
`numpy`, `cupy`, etc.). However, it can be any type, e.g. a
`types.SimpleNamespace`, to allow for wrapper libraries to
semi-dynamically define their own array namespaces based on the wrapped
array type.

This type is intended for use in static typing to ensure that an object has
the attributes and methods defined in the array API specification. It should
not be used at runtime for type checking or as a base class.

"""
30 changes: 0 additions & 30 deletions src/array_api_typing/_namespace.py

This file was deleted.

12 changes: 0 additions & 12 deletions tests/integration/test_numpy1.pyi

This file was deleted.

53 changes: 53 additions & 0 deletions tests/integration/test_numpy1p0.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# mypy: disable-error-code="no-redef"

from types import ModuleType
from typing import Any

import numpy.array_api as np # type: ignore[import-not-found, unused-ignore]
from numpy import dtype

import array_api_typing as xpt

# Define NDArrays against which we can test the protocols
# Note that `np.array_api` doesn't support boolean arrays.
nparr = np.eye(2)
nparr_i32 = np.asarray([1], dtype=np.int32)
nparr_f32 = np.asarray([1.0], dtype=np.float32)

# =========================================================
# `xpt.HasArrayNamespace`

_: xpt.HasArrayNamespace[ModuleType] = nparr
_: xpt.HasArrayNamespace[ModuleType] = nparr_i32
_: xpt.HasArrayNamespace[ModuleType] = nparr_f32

# Check `__array_namespace__` method
a_ns: xpt.HasArrayNamespace[ModuleType] = nparr
ns: ModuleType = a_ns.__array_namespace__()

# Incorrect values are caught when using `__array_namespace__` and
# backpropagated to the type of `a_ns`
_: xpt.HasArrayNamespace[dict[str, int]] = nparr # not caught

# =========================================================
# `xpt.HasDType`

# Note that `np.array_api` uses dtype objects, not dtype classes, so we can't
# type annotate specific dtypes like `np.float32` or `np.int32`.

_: xpt.HasDType[dtype[Any]] = nparr
_: xpt.HasDType[dtype[Any]] = nparr_i32
_: xpt.HasDType[dtype[Any]] = nparr_f32

# =========================================================
# `xpt.Array`

# Check NamespaceT_co assignment
a_ns: xpt.Array[Any, ModuleType] = nparr

# Check DTypeT_co assignment
# Note that `np.array_api` uses dtype objects, not dtype classes, so we can't
# type annotate specific dtypes like `np.float32` or `np.int32`.
_: xpt.Array[dtype[Any]] = nparr
_: xpt.Array[dtype[Any]] = nparr_i32
_: xpt.Array[dtype[Any]] = nparr_f32
11 changes: 0 additions & 11 deletions tests/integration/test_numpy2.pyi

This file was deleted.

57 changes: 57 additions & 0 deletions tests/integration/test_numpy2p0.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# mypy: disable-error-code="no-redef"

from types import ModuleType
from typing import Any, TypeAlias

import numpy as np
import numpy.typing as npt

import array_api_typing as xpt

# DType aliases
F32: TypeAlias = np.float32
I32: TypeAlias = np.int32

# Define NDArrays against which we can test the protocols
nparr: npt.NDArray[Any]
nparr_i32: npt.NDArray[I32]
nparr_f32: npt.NDArray[F32]
nparr_b: npt.NDArray[np.bool_]

# =========================================================
# `xpt.HasArrayNamespace`

# Check assignment
_: xpt.HasArrayNamespace[ModuleType] = nparr
_: xpt.HasArrayNamespace[ModuleType] = nparr_i32
_: xpt.HasArrayNamespace[ModuleType] = nparr_f32
_: xpt.HasArrayNamespace[ModuleType] = nparr_b

# Check `__array_namespace__` method
a_ns: xpt.HasArrayNamespace[ModuleType] = nparr
ns: ModuleType = a_ns.__array_namespace__()

# Incorrect values are caught when using `__array_namespace__` and
# backpropagated to the type of `a_ns`
_: xpt.HasArrayNamespace[dict[str, int]] = nparr # not caught

# =========================================================
# `xpt.HasDType`

# Check DTypeT_co assignment
_: xpt.HasDType[Any] = nparr
_: xpt.HasDType[np.dtype[I32]] = nparr_i32
_: xpt.HasDType[np.dtype[F32]] = nparr_f32
_: xpt.HasDType[np.dtype[np.bool_]] = nparr_b

# =========================================================
# `xpt.Array`

# Check NamespaceT_co assignment
a_ns: xpt.Array[Any, ModuleType] = nparr

# Check DTypeT_co assignment
_: xpt.Array[Any] = nparr
_: xpt.Array[np.dtype[I32]] = nparr_i32
_: xpt.Array[np.dtype[F32]] = nparr_f32
_: xpt.Array[np.dtype[np.bool_]] = nparr_b
5 changes: 5 additions & 0 deletions tests/integration/test_numpy2p2.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from test_numpy2p0 import nparr

import array_api_typing as xpt

_: xpt.HasArrayNamespace[dict[str, int]] = nparr # type: ignore[assignment]
Loading