Skip to content

Commit 304ac22

Browse files
authored
[ty] Use constructor parameter types as type context (#21054)
## Summary Resolves astral-sh/ty#1408.
1 parent c3de884 commit 304ac22

File tree

4 files changed

+161
-24
lines changed

4 files changed

+161
-24
lines changed

crates/ty_python_semantic/resources/mdtest/bidirectional.md

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,3 +145,84 @@ def h[T](x: T, cond: bool) -> T | list[T]:
145145
def i[T](x: T, cond: bool) -> T | list[T]:
146146
return x if cond else [x]
147147
```
148+
149+
## Type context sources
150+
151+
Type context is sourced from various places, including annotated assignments:
152+
153+
```py
154+
from typing import Literal
155+
156+
a: list[Literal[1]] = [1]
157+
```
158+
159+
Function parameter annotations:
160+
161+
```py
162+
def b(x: list[Literal[1]]): ...
163+
164+
b([1])
165+
```
166+
167+
Bound method parameter annotations:
168+
169+
```py
170+
class C:
171+
def __init__(self, x: list[Literal[1]]): ...
172+
def foo(self, x: list[Literal[1]]): ...
173+
174+
C([1]).foo([1])
175+
```
176+
177+
Declared variable types:
178+
179+
```py
180+
d: list[Literal[1]]
181+
d = [1]
182+
```
183+
184+
Declared attribute types:
185+
186+
```py
187+
class E:
188+
e: list[Literal[1]]
189+
190+
def _(e: E):
191+
# TODO: Implement attribute type context.
192+
# error: [invalid-assignment] "Object of type `list[Unknown | int]` is not assignable to attribute `e` of type `list[Literal[1]]`"
193+
e.e = [1]
194+
```
195+
196+
Function return types:
197+
198+
```py
199+
def f() -> list[Literal[1]]:
200+
return [1]
201+
```
202+
203+
## Class constructor parameters
204+
205+
```toml
206+
[environment]
207+
python-version = "3.12"
208+
```
209+
210+
The parameters of both `__init__` and `__new__` are used as type context sources for constructor
211+
calls:
212+
213+
```py
214+
def f[T](x: T) -> list[T]:
215+
return [x]
216+
217+
class A:
218+
def __new__(cls, value: list[int | str]):
219+
return super().__new__(cls, value)
220+
221+
def __init__(self, value: list[int | None]): ...
222+
223+
A(f(1))
224+
225+
# error: [invalid-argument-type] "Argument to function `__new__` is incorrect: Expected `list[int | str]`, found `list[list[Unknown]]`"
226+
# error: [invalid-argument-type] "Argument to bound method `__init__` is incorrect: Expected `list[int | None]`, found `list[list[Unknown]]`"
227+
A(f([]))
228+
```

crates/ty_python_semantic/src/types.rs

Lines changed: 58 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6007,6 +6007,9 @@ impl<'db> Type<'db> {
60076007
/// Given a class literal or non-dynamic `SubclassOf` type, try calling it (creating an instance)
60086008
/// and return the resulting instance type.
60096009
///
6010+
/// The `infer_argument_types` closure should be invoked with the signatures of `__new__` and
6011+
/// `__init__`, such that the argument types can be inferred with the correct type context.
6012+
///
60106013
/// Models `type.__call__` behavior.
60116014
/// TODO: model metaclass `__call__`.
60126015
///
@@ -6017,10 +6020,10 @@ impl<'db> Type<'db> {
60176020
///
60186021
/// Foo()
60196022
/// ```
6020-
fn try_call_constructor(
6023+
fn try_call_constructor<'ast>(
60216024
self,
60226025
db: &'db dyn Db,
6023-
argument_types: CallArguments<'_, 'db>,
6026+
infer_argument_types: impl FnOnce(Option<Bindings<'db>>) -> CallArguments<'ast, 'db>,
60246027
tcx: TypeContext<'db>,
60256028
) -> Result<Type<'db>, ConstructorCallError<'db>> {
60266029
debug_assert!(matches!(
@@ -6076,11 +6079,63 @@ impl<'db> Type<'db> {
60766079
// easy to check if that's the one we found?
60776080
// Note that `__new__` is a static method, so we must inject the `cls` argument.
60786081
let new_method = self_type.lookup_dunder_new(db, ());
6082+
6083+
// Construct an instance type that we can use to look up the `__init__` instance method.
6084+
// This performs the same logic as `Type::to_instance`, except for generic class literals.
6085+
// TODO: we should use the actual return type of `__new__` to determine the instance type
6086+
let init_ty = self_type
6087+
.to_instance(db)
6088+
.expect("type should be convertible to instance type");
6089+
6090+
// Lookup the `__init__` instance method in the MRO.
6091+
let init_method = init_ty.member_lookup_with_policy(
6092+
db,
6093+
"__init__".into(),
6094+
MemberLookupPolicy::NO_INSTANCE_FALLBACK | MemberLookupPolicy::MRO_NO_OBJECT_FALLBACK,
6095+
);
6096+
6097+
// Infer the call argument types, using both `__new__` and `__init__` for type-context.
6098+
let bindings = match (
6099+
new_method.as_ref().map(|method| &method.place),
6100+
&init_method.place,
6101+
) {
6102+
(Some(Place::Defined(new_method, ..)), Place::Undefined) => Some(
6103+
new_method
6104+
.bindings(db)
6105+
.map(|binding| binding.with_bound_type(self_type)),
6106+
),
6107+
6108+
(Some(Place::Undefined) | None, Place::Defined(init_method, ..)) => {
6109+
Some(init_method.bindings(db))
6110+
}
6111+
6112+
(Some(Place::Defined(new_method, ..)), Place::Defined(init_method, ..)) => {
6113+
let callable = UnionBuilder::new(db)
6114+
.add(*new_method)
6115+
.add(*init_method)
6116+
.build();
6117+
6118+
let new_method_bindings = new_method
6119+
.bindings(db)
6120+
.map(|binding| binding.with_bound_type(self_type));
6121+
6122+
Some(Bindings::from_union(
6123+
callable,
6124+
[new_method_bindings, init_method.bindings(db)],
6125+
))
6126+
}
6127+
6128+
_ => None,
6129+
};
6130+
6131+
let argument_types = infer_argument_types(bindings);
6132+
60796133
let new_call_outcome = new_method.and_then(|new_method| {
60806134
match new_method.place.try_call_dunder_get(db, self_type) {
60816135
Place::Defined(new_method, _, boundness) => {
60826136
let result =
60836137
new_method.try_call(db, argument_types.with_self(Some(self_type)).as_ref());
6138+
60846139
if boundness == Definedness::PossiblyUndefined {
60856140
Some(Err(DunderNewCallError::PossiblyUnbound(result.err())))
60866141
} else {
@@ -6091,24 +6146,7 @@ impl<'db> Type<'db> {
60916146
}
60926147
});
60936148

6094-
// Construct an instance type that we can use to look up the `__init__` instance method.
6095-
// This performs the same logic as `Type::to_instance`, except for generic class literals.
6096-
// TODO: we should use the actual return type of `__new__` to determine the instance type
6097-
let init_ty = self_type
6098-
.to_instance(db)
6099-
.expect("type should be convertible to instance type");
6100-
6101-
let init_call_outcome = if new_call_outcome.is_none()
6102-
|| !init_ty
6103-
.member_lookup_with_policy(
6104-
db,
6105-
"__init__".into(),
6106-
MemberLookupPolicy::NO_INSTANCE_FALLBACK
6107-
| MemberLookupPolicy::MRO_NO_OBJECT_FALLBACK,
6108-
)
6109-
.place
6110-
.is_undefined()
6111-
{
6149+
let init_call_outcome = if new_call_outcome.is_none() || !init_method.is_undefined() {
61126150
Some(init_ty.try_call_dunder(db, "__init__", argument_types, tcx))
61136151
} else {
61146152
None

crates/ty_python_semantic/src/types/call/bind.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,14 @@ impl<'db> Bindings<'db> {
100100
self.elements.iter()
101101
}
102102

103+
pub(crate) fn map(self, f: impl Fn(CallableBinding<'db>) -> CallableBinding<'db>) -> Self {
104+
Self {
105+
callable_type: self.callable_type,
106+
argument_forms: self.argument_forms,
107+
elements: self.elements.into_iter().map(f).collect(),
108+
}
109+
}
110+
103111
/// Match the arguments of a call site against the parameters of a collection of possibly
104112
/// unioned, possibly overloaded signatures.
105113
///

crates/ty_python_semantic/src/types/infer/builder.rs

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6798,9 +6798,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
67986798
.to_class_type(self.db())
67996799
.is_none_or(|enum_class| !class.is_subclass_of(self.db(), enum_class))
68006800
{
6801-
let argument_forms = vec![Some(ParameterForm::Value); call_arguments.len()];
6802-
self.infer_argument_types(arguments, &mut call_arguments, &argument_forms);
6803-
68046801
if matches!(
68056802
class.known(self.db()),
68066803
Some(KnownClass::TypeVar | KnownClass::ExtensionsTypeVar)
@@ -6819,8 +6816,21 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
68196816
}
68206817
}
68216818

6819+
let db = self.db();
6820+
let infer_call_arguments = |bindings: Option<Bindings<'db>>| {
6821+
if let Some(bindings) = bindings {
6822+
let bindings = bindings.match_parameters(self.db(), &call_arguments);
6823+
self.infer_all_argument_types(arguments, &mut call_arguments, &bindings);
6824+
} else {
6825+
let argument_forms = vec![Some(ParameterForm::Value); call_arguments.len()];
6826+
self.infer_argument_types(arguments, &mut call_arguments, &argument_forms);
6827+
}
6828+
6829+
call_arguments
6830+
};
6831+
68226832
return callable_type
6823-
.try_call_constructor(self.db(), call_arguments, tcx)
6833+
.try_call_constructor(db, infer_call_arguments, tcx)
68246834
.unwrap_or_else(|err| {
68256835
err.report_diagnostic(&self.context, callable_type, call_expression.into());
68266836
err.return_type()

0 commit comments

Comments
 (0)