Skip to content

Commit 78239df

Browse files
committed
fix: fix mypy errors and add mypy as pre-commit
1 parent 1810177 commit 78239df

File tree

6 files changed

+34
-12
lines changed

6 files changed

+34
-12
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,5 @@ tmp/
3535
*.egg
3636
dist/
3737
.DS_STORE
38+
venv
39+
.venv

.pre-commit-config.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,8 @@ repos:
4848
rev: 23.7.0
4949
hooks:
5050
- id: black
51+
- repo: https://github.com/pre-commit/mirrors-mypy
52+
rev: v1.13.0
53+
hooks:
54+
- id: mypy
55+
additional_dependencies: []

pyproject.toml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,17 @@ build-backend = "setuptools.build_meta"
3232

3333
[tool.black]
3434
line-length = 88
35+
36+
[tool.mypy]
37+
mypy_path = "src/array_api_stubs/_draft"
38+
exclude = [
39+
"docs/",
40+
"spec/",
41+
"venv/",
42+
".venv/",
43+
"src/array_api_stubs/_2021_12/",
44+
"src/array_api_stubs/_2022_12/",
45+
"src/array_api_stubs/_2023_12/",
46+
"src/array_api_conf.py"
47+
]
48+
disable_error_code = "empty-body,type-var"

src/array_api_stubs/_draft/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Function stubs and API documentation for the array API standard."""
22

3-
from .array_object import *
43
from .constants import *
54
from .creation_functions import *
65
from .data_type_functions import *

src/array_api_stubs/_draft/_types.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,11 @@
4646
)
4747
from enum import Enum
4848

49-
array = TypeVar("array", bound="array_")
49+
array = TypeVar("array", bound="_array")
5050
device = TypeVar("device")
5151
dtype = TypeVar("dtype")
52+
Device = TypeVar("Device")
53+
Dtype = TypeVar("Dtype")
5254
SupportsDLPack = TypeVar("SupportsDLPack")
5355
SupportsBufferProtocol = TypeVar("SupportsBufferProtocol")
5456
PyCapsule = TypeVar("PyCapsule")
@@ -88,7 +90,7 @@ def __len__(self, /) -> int:
8890
...
8991

9092

91-
class Info(Protocol):
93+
class Info(Protocol[device]):
9294
"""Namespace returned by `__array_namespace_info__`."""
9395

9496
def capabilities(self) -> Capabilities:
@@ -147,12 +149,12 @@ def dtypes(
147149
)
148150

149151

150-
class _array(Protocol[array, dtype, device]):
152+
class _array(Protocol[array, Dtype, Device, PyCapsule]): # type: ignore
151153
def __init__(self: array) -> None:
152154
"""Initialize the attributes for the array object class."""
153155

154156
@property
155-
def dtype(self: array) -> dtype:
157+
def dtype(self: array) -> Dtype:
156158
"""
157159
Data type of the array elements.
158160
@@ -163,7 +165,7 @@ def dtype(self: array) -> dtype:
163165
"""
164166

165167
@property
166-
def device(self: array) -> device:
168+
def device(self: array) -> Device:
167169
"""
168170
Hardware device the array data resides on.
169171
@@ -625,7 +627,7 @@ def __dlpack_device__(self: array, /) -> Tuple[Enum, int]:
625627
ONE_API = 14
626628
"""
627629

628-
def __eq__(self: array, other: Union[int, float, bool, array], /) -> array:
630+
def __eq__(self: array, other: Union[int, float, bool, array], /) -> array: # type: ignore
629631
r"""
630632
Computes the truth value of ``self_i == other_i`` for each element of an array instance with the respective element of the array ``other``.
631633
@@ -1072,7 +1074,7 @@ def __mul__(self: array, other: Union[int, float, array], /) -> array:
10721074
Added complex data type support.
10731075
"""
10741076

1075-
def __ne__(self: array, other: Union[int, float, bool, array], /) -> array:
1077+
def __ne__(self: array, other: Union[int, float, bool, array], /) -> array: # type: ignore
10761078
"""
10771079
Computes the truth value of ``self_i != other_i`` for each element of an array instance with the respective element of the array ``other``.
10781080
@@ -1342,7 +1344,7 @@ def __xor__(self: array, other: Union[int, bool, array], /) -> array:
13421344
"""
13431345

13441346
def to_device(
1345-
self: array, device: device, /, *, stream: Optional[Union[int, Any]] = None
1347+
self: array, device: Device, /, *, stream: Optional[Union[int, Any]] = None
13461348
) -> array:
13471349
"""
13481350
Copy the array from the device on which it currently resides to the specified ``device``.

src/array_api_stubs/_draft/linalg.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ def matrix_norm(
301301
/,
302302
*,
303303
keepdims: bool = False,
304-
ord: Optional[Union[int, float, Literal[inf, -inf, "fro", "nuc"]]] = "fro",
304+
ord: Optional[Union[int, float, Literal[inf, -inf, "fro", "nuc"]]] = "fro", # type: ignore
305305
) -> array:
306306
"""
307307
Computes the matrix norm of a matrix (or a stack of matrices) ``x``.
@@ -781,7 +781,7 @@ def trace(x: array, /, *, offset: int = 0, dtype: Optional[dtype] = None) -> arr
781781
"""
782782

783783

784-
def vecdot(x1: array, x2: array, /, *, axis: int = None) -> array:
784+
def vecdot(x1: array, x2: array, /, *, axis: int | None = None) -> array:
785785
"""Alias for :func:`~array_api.vecdot`."""
786786

787787

@@ -791,7 +791,7 @@ def vector_norm(
791791
*,
792792
axis: Optional[Union[int, Tuple[int, ...]]] = None,
793793
keepdims: bool = False,
794-
ord: Union[int, float, Literal[inf, -inf]] = 2,
794+
ord: Union[int, float, Literal[inf, -inf]] = 2, # type: ignore
795795
) -> array:
796796
r"""
797797
Computes the vector norm of a vector (or batch of vectors) ``x``.

0 commit comments

Comments
 (0)