Skip to content

Commit 9eb95fb

Browse files
committed
add BaseManager.create() typechecking
1 parent 7aafca2 commit 9eb95fb

File tree

19 files changed

+557
-267
lines changed

19 files changed

+557
-267
lines changed

django-stubs/contrib/admin/models.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ CHANGE: int
1111
DELETION: int
1212
ACTION_FLAG_CHOICES: Any
1313

14-
class LogEntryManager(models.Manager):
14+
class LogEntryManager(models.Manager["LogEntry"]):
1515
def log_action(
1616
self,
1717
user_id: int,

django-stubs/contrib/auth/models.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class Permission(models.Model):
1616
content_type_id: int
1717
name: models.CharField = ...
1818
content_type: models.ForeignKey[ContentType] = ...
19-
codename: str = ...
19+
codename: models.CharField = ...
2020
def natural_key(self) -> Tuple[str, str, str]: ...
2121

2222
class GroupManager(models.Manager):
Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
1-
from typing import Any, Optional
2-
31
from django.db import models
42

53
class Redirect(models.Model):
6-
id: None
7-
site_id: int
8-
site: Any = ...
9-
old_path: str = ...
10-
new_path: str = ...
4+
site: models.ForeignKey = ...
5+
old_path: models.CharField = ...
6+
new_path: models.CharField = ...

django-stubs/contrib/sessions/backends/db.pyi

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
1-
from typing import Any, Dict, Optional, Type, Union
1+
from typing import Dict, Optional, Type, Union
22

33
from django.contrib.sessions.backends.base import SessionBase
44
from django.contrib.sessions.base_session import AbstractBaseSession
55
from django.contrib.sessions.models import Session
6+
from django.core.signing import Serializer
67
from django.db.models.base import Model
78

89
class SessionStore(SessionBase):
910
accessed: bool
10-
serializer: Type[django.core.signing.JSONSerializer]
11+
serializer: Type[Serializer]
1112
def __init__(self, session_key: Optional[str] = ...) -> None: ...
1213
@classmethod
1314
def get_model_class(cls) -> Type[Session]: ...

django-stubs/db/models/base.pyi

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, List, Optional, Set, Tuple, TypeVar, Union, ClassVar, Sequence
1+
from typing import Any, Dict, List, Optional, Set, Tuple, TypeVar, Union, ClassVar, Sequence, Generic
22

33
from django.db.models.manager import Manager
44

@@ -10,9 +10,9 @@ class Model(metaclass=ModelBase):
1010
class DoesNotExist(Exception): ...
1111
class Meta: ...
1212
_meta: Any
13+
_default_manager: Manager[Model]
1314
pk: Any = ...
14-
objects: Manager[Model]
15-
def __init__(self, *args, **kwargs) -> None: ...
15+
def __init__(self: _Self, *args, **kwargs) -> None: ...
1616
def delete(self, using: Any = ..., keep_parents: bool = ...) -> Tuple[int, Dict[str, int]]: ...
1717
def full_clean(self, exclude: Optional[List[str]] = ..., validate_unique: bool = ...) -> None: ...
1818
def clean_fields(self, exclude: List[str] = ...) -> None: ...

django-stubs/db/models/fields/__init__.pyi

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import uuid
22
from datetime import date, time, datetime, timedelta
3-
from typing import Any, Optional, Tuple, Iterable, Callable, Dict, Union, Type, TypeVar
3+
from typing import Any, Optional, Tuple, Iterable, Callable, Dict, Union, Type, TypeVar, Generic
44
import decimal
55

66
from typing_extensions import Literal
@@ -14,7 +14,7 @@ from django.forms import Widget, Field as FormField
1414
from .mixins import NOT_PROVIDED as NOT_PROVIDED
1515

1616
_Choice = Tuple[Any, Any]
17-
_ChoiceNamedGroup = Union[Tuple[str, Iterable[_Choice]], Tuple[str, Any]]
17+
_ChoiceNamedGroup = Tuple[str, Iterable[_Choice]]
1818
_FieldChoices = Iterable[Union[_Choice, _ChoiceNamedGroup]]
1919

2020
_ValidatorCallable = Callable[..., None]
@@ -76,7 +76,7 @@ class SmallIntegerField(IntegerField): ...
7676
class BigIntegerField(IntegerField): ...
7777

7878
class FloatField(Field):
79-
def __set__(self, instance, value: Union[float, int, Combinable]) -> float: ...
79+
def __set__(self, instance, value: Union[float, int, str, Combinable]) -> float: ...
8080
def __get__(self, instance, owner) -> float: ...
8181

8282
class DecimalField(Field):
@@ -102,7 +102,7 @@ class DecimalField(Field):
102102
validators: Iterable[_ValidatorCallable] = ...,
103103
error_messages: Optional[_ErrorMessagesToOverride] = ...,
104104
): ...
105-
def __set__(self, instance, value: Union[str, Combinable]) -> decimal.Decimal: ...
105+
def __set__(self, instance, value: Union[str, float, decimal.Decimal, Combinable]) -> decimal.Decimal: ...
106106
def __get__(self, instance, owner) -> decimal.Decimal: ...
107107

108108
class AutoField(Field):
@@ -167,15 +167,15 @@ class EmailField(CharField): ...
167167
class URLField(CharField): ...
168168

169169
class TextField(Field):
170-
def __set__(self, instance, value: str) -> None: ...
170+
def __set__(self, instance, value: Union[str, Combinable]) -> None: ...
171171
def __get__(self, instance, owner) -> str: ...
172172

173173
class BooleanField(Field):
174-
def __set__(self, instance, value: bool) -> None: ...
174+
def __set__(self, instance, value: Union[bool, Combinable]) -> None: ...
175175
def __get__(self, instance, owner) -> bool: ...
176176

177177
class NullBooleanField(Field):
178-
def __set__(self, instance, value: Optional[bool]) -> None: ...
178+
def __set__(self, instance, value: Optional[Union[bool, Combinable]]) -> None: ...
179179
def __get__(self, instance, owner) -> Optional[bool]: ...
180180

181181
class IPAddressField(Field):

django-stubs/views/debug.pyi

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
from importlib.abc import SourceLoader
2-
from typing import Any, Callable, Dict, List, Optional, Type, Union
32
from types import TracebackType
3+
from typing import Any, Callable, Dict, List, MutableMapping, Optional, Type, Union
44

5-
from django.core.handlers.wsgi import WSGIRequest
6-
from django.http.request import QueryDict
5+
from django.http.request import HttpRequest, QueryDict
76
from django.http.response import Http404, HttpResponse
87
from django.utils.safestring import SafeText
98

@@ -19,21 +18,21 @@ def cleanse_setting(key: Union[int, str], value: Any) -> Any: ...
1918
def get_safe_settings() -> Dict[str, Any]: ...
2019
def technical_500_response(request: Any, exc_type: Any, exc_value: Any, tb: Any, status_code: int = ...): ...
2120
def get_default_exception_reporter_filter() -> ExceptionReporterFilter: ...
22-
def get_exception_reporter_filter(request: Optional[WSGIRequest]) -> ExceptionReporterFilter: ...
21+
def get_exception_reporter_filter(request: Optional[HttpRequest]) -> ExceptionReporterFilter: ...
2322

2423
class ExceptionReporterFilter:
2524
def get_post_parameters(self, request: Any): ...
2625
def get_traceback_frame_variables(self, request: Any, tb_frame: Any): ...
2726

2827
class SafeExceptionReporterFilter(ExceptionReporterFilter):
29-
def is_active(self, request: Optional[WSGIRequest]) -> bool: ...
30-
def get_cleansed_multivaluedict(self, request: WSGIRequest, multivaluedict: QueryDict) -> QueryDict: ...
31-
def get_post_parameters(self, request: Optional[WSGIRequest]) -> Union[Dict[Any, Any], QueryDict]: ...
32-
def cleanse_special_types(self, request: Optional[WSGIRequest], value: Any) -> Any: ...
28+
def is_active(self, request: Optional[HttpRequest]) -> bool: ...
29+
def get_cleansed_multivaluedict(self, request: HttpRequest, multivaluedict: QueryDict) -> QueryDict: ...
30+
def get_post_parameters(self, request: Optional[HttpRequest]) -> MutableMapping[str, Any]: ...
31+
def cleanse_special_types(self, request: Optional[HttpRequest], value: Any) -> Any: ...
3332
def get_traceback_frame_variables(self, request: Any, tb_frame: Any): ...
3433

3534
class ExceptionReporter:
36-
request: Optional[WSGIRequest] = ...
35+
request: Optional[HttpRequest] = ...
3736
filter: ExceptionReporterFilter = ...
3837
exc_type: None = ...
3938
exc_value: Optional[str] = ...
@@ -44,7 +43,7 @@ class ExceptionReporter:
4443
postmortem: None = ...
4544
def __init__(
4645
self,
47-
request: Optional[WSGIRequest],
46+
request: Optional[HttpRequest],
4847
exc_type: Optional[Type[BaseException]],
4948
exc_value: Optional[Union[str, BaseException]],
5049
tb: Optional[TracebackType],
@@ -63,5 +62,5 @@ class ExceptionReporter:
6362
module_name: Optional[str] = None,
6463
): ...
6564

66-
def technical_404_response(request: WSGIRequest, exception: Http404) -> HttpResponse: ...
67-
def default_urlconf(request: WSGIRequest) -> HttpResponse: ...
65+
def technical_404_response(request: HttpRequest, exception: Http404) -> HttpResponse: ...
66+
def default_urlconf(request: HttpRequest) -> HttpResponse: ...

mypy_django_plugin/config.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from configparser import ConfigParser
2+
from typing import Optional
3+
4+
from dataclasses import dataclass
5+
6+
7+
@dataclass
8+
class Config:
9+
django_settings_module: Optional[str] = None
10+
ignore_missing_settings: bool = False
11+
12+
@classmethod
13+
def from_config_file(self, fpath: str) -> 'Config':
14+
ini_config = ConfigParser()
15+
ini_config.read(fpath)
16+
if not ini_config.has_section('mypy_django_plugin'):
17+
raise ValueError('Invalid config file: no [mypy_django_plugin] section')
18+
return Config(django_settings_module=ini_config.get('mypy_django_plugin', 'django_settings',
19+
fallback=None),
20+
ignore_missing_settings=ini_config.get('mypy_django_plugin', 'ignore_missing_settings',
21+
fallback=False))

mypy_django_plugin/helpers.py

Lines changed: 2 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
import typing
22
from typing import Dict, Optional
33

4-
from mypy.nodes import Expression, FuncDef, ImportedName, MypyFile, NameExpr, SymbolNode, TypeInfo, Var, AssignmentStmt, \
5-
CallExpr
4+
from mypy.nodes import Expression, ImportedName, MypyFile, NameExpr, SymbolNode, TypeInfo
65
from mypy.plugin import FunctionContext
7-
from mypy.types import AnyType, CallableType, Instance, Type, TypeOfAny, TypeVarType, UnionType
6+
from mypy.types import AnyType, Instance, Type, TypeOfAny, TypeVarType
87

98
MODEL_CLASS_FULLNAME = 'django.db.models.base.Model'
109
FIELD_FULLNAME = 'django.db.models.fields.Field'
@@ -119,74 +118,6 @@ def fill_typevars(tp: Instance, type_to_fill: Instance) -> Instance:
119118
return reparametrize_with(type_to_fill, typevar_values)
120119

121120

122-
def extract_field_setter_type(tp: Instance) -> Optional[Type]:
123-
if tp.type.has_base(FIELD_FULLNAME):
124-
set_method = tp.type.get_method('__set__')
125-
if isinstance(set_method, FuncDef) and isinstance(set_method.type, CallableType):
126-
if 'value' in set_method.type.arg_names:
127-
set_value_type = set_method.type.arg_types[set_method.type.arg_names.index('value')]
128-
if isinstance(set_value_type, Instance):
129-
set_value_type = fill_typevars(tp, set_value_type)
130-
return set_value_type
131-
elif isinstance(set_value_type, UnionType):
132-
items_no_typevars = []
133-
for item in set_value_type.items:
134-
if isinstance(item, Instance):
135-
item = fill_typevars(tp, item)
136-
items_no_typevars.append(item)
137-
return UnionType(items_no_typevars)
138-
139-
get_method = tp.type.get_method('__get__')
140-
if isinstance(get_method, FuncDef) and isinstance(get_method.type, CallableType):
141-
return get_method.type.ret_type
142-
# GenericForeignKey
143-
if tp.type.has_base(GENERIC_FOREIGN_KEY_FULLNAME):
144-
return AnyType(TypeOfAny.special_form)
145-
return None
146-
147-
148-
def extract_primary_key_type(model: TypeInfo) -> Optional[Type]:
149-
# only primary keys defined in current class for now
150-
for stmt in model.defn.defs.body:
151-
if isinstance(stmt, AssignmentStmt) and isinstance(stmt.rvalue, CallExpr):
152-
name_expr = stmt.lvalues[0]
153-
if isinstance(name_expr, NameExpr):
154-
name = name_expr.name
155-
if 'primary_key' in stmt.rvalue.arg_names:
156-
is_primary_key = stmt.rvalue.args[stmt.rvalue.arg_names.index('primary_key')]
157-
if is_primary_key:
158-
return extract_field_setter_type(model.names[name].type)
159-
return None
160-
161-
162-
def extract_expected_types(ctx: FunctionContext, model: TypeInfo) -> Dict[str, Type]:
163-
expected_types: Dict[str, Type] = {}
164-
165-
primary_key_type = extract_primary_key_type(model)
166-
if not primary_key_type:
167-
# no explicit primary key, set pk to Any and add id
168-
primary_key_type = AnyType(TypeOfAny.special_form)
169-
expected_types['id'] = ctx.api.named_generic_type('builtins.int', [])
170-
171-
expected_types['pk'] = primary_key_type
172-
173-
for base in model.mro:
174-
for name, sym in base.names.items():
175-
if isinstance(sym.node, Var) and isinstance(sym.node.type, Instance):
176-
tp = sym.node.type
177-
field_type = extract_field_setter_type(tp)
178-
if tp.type.fullname() in {FOREIGN_KEY_FULLNAME, ONETOONE_FIELD_FULLNAME}:
179-
ref_to_model = tp.args[0]
180-
if isinstance(ref_to_model, Instance) and ref_to_model.type.has_base(MODEL_CLASS_FULLNAME):
181-
primary_key_type = extract_primary_key_type(ref_to_model.type)
182-
if not primary_key_type:
183-
primary_key_type = AnyType(TypeOfAny.special_form)
184-
expected_types[name + '_id'] = primary_key_type
185-
if field_type:
186-
expected_types[name] = field_type
187-
return expected_types
188-
189-
190121
def get_argument_by_name(ctx: FunctionContext, name: str) -> Optional[Expression]:
191122
"""Return the expression for the specific argument.
192123

0 commit comments

Comments
 (0)