Skip to content

Commit 280d810

Browse files
authored
Typed queries 2 (reduced) (#149)
This PR includes a reduced version of the changes proposed in #81 . I removed the plugin method hook and its tests to reduce the scope of the PR, to make it easier to manage and merge, as just the type annotations are a great improvement.
1 parent 1c36a4e commit 280d810

File tree

2 files changed

+46
-35
lines changed

2 files changed

+46
-35
lines changed

sqlalchemy-stubs/orm/query.pyi

Lines changed: 43 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,44 @@
1-
from typing import Any, Optional, Union
1+
from typing import Any, Optional, Union, TypeVar, Generic, List, Iterator
22
from . import interfaces
33
from .base import InspectionAttr
4-
from ..sql.selectable import ForUpdateArg
4+
from ..sql.selectable import ForUpdateArg, Alias, CTE
5+
from ..sql.elements import Label
6+
from .session import Session
57

6-
class Query(object):
7-
session: Any = ...
8-
def __init__(self, entities, session: Optional[Any] = ...) -> None: ...
8+
9+
_T = TypeVar('_T')
10+
_Q = TypeVar('_Q', bound="Query")
11+
12+
13+
class Query(Generic[_T]):
14+
session: Session = ...
15+
def __init__(self, entities, session: Optional[Session] = ...) -> None: ...
16+
17+
# TODO: is "statement" always of type sqlalchemy.sql.selectable.Select ?
918
@property
1019
def statement(self): ...
11-
def subquery(self, name: Optional[Any] = ..., with_labels: bool = ..., reduce_columns: bool = ...): ...
12-
def cte(self, name: Optional[Any] = ..., recursive: bool = ...): ...
13-
def label(self, name): ...
20+
def subquery(self, name: Optional[str] = ..., with_labels: bool = ..., reduce_columns: bool = ...) -> Alias: ...
21+
def cte(self, name: Optional[str] = ..., recursive: bool = ...) -> CTE: ...
22+
def label(self, name: str) -> Label: ...
1423
def as_scalar(self): ...
1524
@property
1625
def selectable(self): ...
1726
def __clause_element__(self): ...
18-
def enable_eagerloads(self, value): ...
19-
def with_labels(self): ...
20-
def enable_assertions(self, value): ...
27+
def enable_eagerloads(self: _Q, value: bool) -> _Q: ...
28+
def with_labels(self: _Q) -> _Q: ...
29+
def enable_assertions(self: _Q, value: bool) -> _Q: ...
2130
@property
2231
def whereclause(self): ...
2332
def with_polymorphic(self, cls_or_mappers, selectable: Optional[Any] = ...,
2433
polymorphic_on: Optional[Any] = ...): ...
25-
def yield_per(self, count): ...
26-
def get(self, ident): ...
34+
def yield_per(self: _Q, count: int) -> _Q: ...
35+
def get(self, ident) -> Optional[_T]: ...
2736
def correlate(self, *args): ...
28-
def autoflush(self, setting): ...
29-
def populate_existing(self): ...
37+
def autoflush(self: _Q, setting: bool) -> _Q: ...
38+
def populate_existing(self: _Q) -> _Q: ...
3039
def with_parent(self, instance, property: Optional[Any] = ...): ...
3140
def add_entity(self, entity, alias: Optional[Any] = ...): ...
32-
def with_session(self, session): ...
41+
def with_session(self: _Q, session: Optional[Session]) -> _Q: ...
3342
def from_self(self, *entities): ...
3443
def values(self, *columns): ...
3544
def value(self, column): ...
@@ -42,14 +51,14 @@ class Query(object):
4251
def with_statement_hint(self, text, dialect_name: str = ...): ...
4352
def execution_options(self, **kwargs): ...
4453
def with_lockmode(self, mode): ...
45-
def with_for_update(self, read: bool = ..., nowait: bool = ..., of: Optional[Any] = ...,
46-
skip_locked: bool = ..., key_share: bool = ...): ...
47-
def params(self, *args, **kwargs): ...
48-
def filter(self, *criterion): ...
49-
def filter_by(self, **kwargs): ...
50-
def order_by(self, *criterion): ...
51-
def group_by(self, *criterion): ...
52-
def having(self, criterion): ...
54+
def with_for_update(self: _Q, read: bool = ..., nowait: bool = ..., of: Optional[Any] = ...,
55+
skip_locked: bool = ..., key_share: bool = ...) -> _Q: ...
56+
def params(self: _Q, *args, **kwargs) -> _Q: ...
57+
def filter(self: _Q, *criterion) -> _Q: ...
58+
def filter_by(self: _Q, **kwargs) -> _Q: ...
59+
def order_by(self: _Q, *criterion) -> _Q: ...
60+
def group_by(self: _Q, *criterion) -> _Q: ...
61+
def having(self: _Q, criterion) -> _Q: ...
5362
def union(self, *q): ...
5463
def union_all(self, *q): ...
5564
def intersect(self, *q): ...
@@ -62,26 +71,26 @@ class Query(object):
6271
def select_from(self, *from_obj): ...
6372
def select_entity_from(self, from_obj): ...
6473
def __getitem__(self, item): ...
65-
def slice(self, start, stop): ...
66-
def limit(self, limit): ...
67-
def offset(self, offset): ...
74+
def slice(self: _Q, start: Optional[int], stop: Optional[int]) -> _Q: ...
75+
def limit(self: _Q, limit: Optional[int]) -> _Q: ...
76+
def offset(self: _Q, offset: Optional[int]) -> _Q: ...
6877
def distinct(self, *criterion): ...
6978
def prefix_with(self, *prefixes): ...
7079
def suffix_with(self, *suffixes): ...
71-
def all(self): ...
80+
def all(self) -> List[_T]: ...
7281
def from_statement(self, statement): ...
73-
def first(self): ...
74-
def one_or_none(self): ...
75-
def one(self): ...
82+
def first(self) -> Optional[_T]: ...
83+
def one_or_none(self) -> Optional[_T]: ...
84+
def one(self) -> _T: ...
7685
def scalar(self): ...
77-
def __iter__(self): ...
86+
def __iter__(self) -> Iterator[_T]: ...
7887
@property
7988
def column_descriptions(self): ...
8089
def instances(self, cursor, __context: Optional[Any] = ...): ...
8190
def merge_result(self, iterator, load: bool = ...): ...
8291
def exists(self): ...
83-
def count(self): ...
84-
def delete(self, synchronize_session: Union[bool, str] = ...): ...
92+
def count(self) -> int: ...
93+
def delete(self, synchronize_session: Union[bool, str] = ...) -> int: ...
8594
def update(self, values, synchronize_session: Union[bool, str] = ..., update_args: Optional[Any] = ...): ...
8695

8796
class LockmodeArg(ForUpdateArg):

sqlalchemy-stubs/orm/session.pyi

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from typing import Any, Optional
22

3+
from sqlalchemy.orm.query import Query
4+
35
class _SessionClassMethods(object):
46
@classmethod
57
def close_all(cls): ...
@@ -60,7 +62,7 @@ class Session(_SessionClassMethods):
6062
def bind_mapper(self, mapper, bind): ...
6163
def bind_table(self, table, bind): ...
6264
def get_bind(self, mapper: Optional[Any] = ..., clause: Optional[Any] = ...): ...
63-
def query(self, *entities, **kwargs): ...
65+
def query(self, *entities, **kwargs) -> Query[Any]: ...
6466
@property
6567
def no_autoflush(self): ...
6668
def refresh(self, instance, attribute_names: Optional[Any] = ..., lockmode: Optional[Any] = ...): ...

0 commit comments

Comments
 (0)