Skip to content

Commit e967176

Browse files
committed
improve generic call expression inference
1 parent 2aa777e commit e967176

File tree

10 files changed

+564
-141
lines changed

10 files changed

+564
-141
lines changed

crates/ty_python_semantic/resources/mdtest/assignment/annotations.md

Lines changed: 154 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -427,14 +427,13 @@ a = f("a")
427427
reveal_type(a) # revealed: list[Literal["a"]]
428428

429429
b: list[int | Literal["a"]] = f("a")
430-
reveal_type(b) # revealed: list[Literal["a"] | int]
430+
reveal_type(b) # revealed: list[int | Literal["a"]]
431431

432432
c: list[int | str] = f("a")
433-
reveal_type(c) # revealed: list[str | int]
433+
reveal_type(c) # revealed: list[int | str]
434434

435435
d: list[int | tuple[int, int]] = f((1, 2))
436-
# TODO: We could avoid reordering the union elements here.
437-
reveal_type(d) # revealed: list[tuple[int, int] | int]
436+
reveal_type(d) # revealed: list[int | tuple[int, int]]
438437

439438
e: list[int] = f(True)
440439
reveal_type(e) # revealed: list[int]
@@ -455,7 +454,120 @@ j: int | str = f2(True)
455454
reveal_type(j) # revealed: Literal[True]
456455
```
457456

458-
Types are not widened unnecessarily:
457+
The function arguments are also inferred using the type context:
458+
459+
```py
460+
from typing import TypedDict
461+
462+
class TD(TypedDict):
463+
x: int
464+
465+
def f[T](x: list[T]) -> T:
466+
return x[0]
467+
468+
a: TD = f([{"x": 0}, {"x": 1}])
469+
reveal_type(a) # revealed: TD
470+
471+
b: TD | None = f([{"x": 0}, {"x": 1}])
472+
reveal_type(b) # revealed: TD
473+
474+
# error: [missing-typed-dict-key] "Missing required key 'x' in TypedDict `TD` constructor"
475+
# error: [invalid-key] "Invalid key for TypedDict `TD`: Unknown key "y""
476+
# error: [invalid-assignment] "Object of type `Unknown | dict[Unknown | str, Unknown | int]` is not assignable to `TD`"
477+
c: TD = f([{"y": 0}, {"x": 1}])
478+
479+
# error: [missing-typed-dict-key] "Missing required key 'x' in TypedDict `TD` constructor"
480+
# error: [invalid-key] "Invalid key for TypedDict `TD`: Unknown key "y""
481+
# error: [invalid-assignment] "Object of type `Unknown | dict[Unknown | str, Unknown | int]` is not assignable to `TD | None`"
482+
c: TD | None = f([{"y": 0}, {"x": 1}])
483+
```
484+
485+
## Prefer the declared type of generic classes
486+
487+
```toml
488+
[environment]
489+
python-version = "3.12"
490+
```
491+
492+
```py
493+
from typing import Any
494+
495+
def f[T](x: T) -> list[T]:
496+
return [x]
497+
498+
def f2[T](x: T) -> list[T] | None:
499+
return [x]
500+
501+
def f3[T](x: T) -> list[T] | dict[T, T]:
502+
return [x]
503+
504+
a = f(1)
505+
reveal_type(a) # revealed: list[Literal[1]]
506+
507+
b: list[Any] = f(1)
508+
reveal_type(b) # revealed: list[Any]
509+
510+
c: list[Any] = [1]
511+
reveal_type(c) # revealed: list[Any]
512+
513+
d: list[Any] | None = f(1)
514+
reveal_type(d) # revealed: list[Any]
515+
516+
e: list[Any] | None = [1]
517+
reveal_type(e) # revealed: list[Any]
518+
519+
f: list[Any] | None = f2(1)
520+
# TODO: Better constraint solver.
521+
reveal_type(f) # revealed: list[Literal[1]] | None
522+
523+
g: list[Any] | dict[Any, Any] = f3(1)
524+
# TODO: Better constraint solver.
525+
reveal_type(g) # revealed: list[Literal[1]] | dict[Literal[1], Literal[1]]
526+
```
527+
528+
## Narrow generic unions
529+
530+
```toml
531+
[environment]
532+
python-version = "3.12"
533+
```
534+
535+
```py
536+
from typing import reveal_type, TypedDict
537+
538+
def id[T](x: T) -> T:
539+
return x
540+
541+
def _(narrow: dict[str, str], target: list[str] | dict[str, str] | None):
542+
target = id(narrow)
543+
reveal_type(target) # revealed: dict[str, str]
544+
545+
def _(narrow: list[str], target: list[str] | dict[str, str] | None):
546+
target = id(narrow)
547+
reveal_type(target) # revealed: list[str]
548+
549+
def _(narrow: list[str] | dict[str, str], target: list[str] | dict[str, str] | None):
550+
target = id(narrow)
551+
reveal_type(target) # revealed: list[str] | dict[str, str]
552+
553+
class TD(TypedDict):
554+
x: int
555+
556+
def _(target: list[TD] | dict[str, TD] | None):
557+
target = id([{"x": 1}])
558+
reveal_type(target) # revealed: list[TD]
559+
560+
def _(target: list[TD] | dict[str, TD] | None):
561+
target = id({"x": {"x": 1}})
562+
reveal_type(target) # revealed: dict[str, TD]
563+
```
564+
565+
## Prefer the inferred type of non-generic classes
566+
567+
```toml
568+
[environment]
569+
python-version = "3.12"
570+
```
459571

460572
```py
461573
def id[T](x: T) -> T:
@@ -476,10 +588,8 @@ def _(i: int):
476588
b: list[int | None] | None = id([i])
477589
c: list[int | None] | int | None = id([i])
478590
reveal_type(a) # revealed: list[int | None]
479-
# TODO: these should reveal `list[int | None]`
480-
# we currently do not use the call expression annotation as type context for argument inference
481-
reveal_type(b) # revealed: list[Unknown | int]
482-
reveal_type(c) # revealed: list[Unknown | int]
591+
reveal_type(b) # revealed: list[int | None]
592+
reveal_type(c) # revealed: list[int | None]
483593

484594
a: list[int | None] | None = [i]
485595
b: list[int | None] | None = lst(i)
@@ -494,4 +604,39 @@ def _(i: int):
494604
reveal_type(a) # revealed: list[Unknown]
495605
reveal_type(b) # revealed: list[Unknown]
496606
reveal_type(c) # revealed: list[Unknown]
607+
608+
def f[T](x: list[T]) -> T:
609+
return x[0]
610+
611+
def _(a: int, b: str, c: int | str):
612+
x1: int = f(lst(a))
613+
reveal_type(x1) # revealed: int
614+
615+
x2: int | str = f(lst(a))
616+
reveal_type(x2) # revealed: int
617+
618+
x3: int | None = f(lst(a))
619+
reveal_type(x3) # revealed: int
620+
621+
x4: str = f(lst(b))
622+
reveal_type(x4) # revealed: str
623+
624+
x5: int | str = f(lst(b))
625+
reveal_type(x5) # revealed: str
626+
627+
x6: str | None = f(lst(b))
628+
reveal_type(x6) # revealed: str
629+
630+
x7: int | str = f(lst(c))
631+
reveal_type(x7) # revealed: int | str
632+
633+
x8: int | str = f(lst(c))
634+
reveal_type(x8) # revealed: int | str
635+
636+
# TODO: Ideally this would reveal `int | str`. This is a known limitation of our
637+
# call inference solver, and would # require an extra inference attempt without type
638+
# context, or with type context # of subsets of the union, both of which are impractical
639+
# for performance reasons.
640+
x9: int | str | None = f(lst(c))
641+
reveal_type(x9) # revealed: int | str | None
497642
```

crates/ty_python_semantic/resources/mdtest/bidirectional.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ def _(l: list[int] | None = None):
5050
def f[T](x: T, cond: bool) -> T | list[T]:
5151
return x if cond else [x]
5252

53-
# TODO: no error
54-
# error: [invalid-assignment] "Object of type `Literal[1] | list[Literal[1]]` is not assignable to `int | list[int]`"
53+
# TODO: Better constraint solver.
54+
# error: [invalid-assignment]
5555
l5: int | list[int] = f(1, True)
5656
```
5757

crates/ty_python_semantic/resources/mdtest/dataclasses/fields.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class Data:
3737
content: list[int] = field(default_factory=list)
3838
timestamp: datetime = field(default_factory=datetime.now, init=False)
3939

40-
# revealed: (self: Data, content: list[int] = Unknown) -> None
40+
# revealed: (self: Data, content: list[int] = list[int]) -> None
4141
reveal_type(Data.__init__)
4242

4343
data = Data([1, 2, 3])
@@ -63,7 +63,6 @@ class Person:
6363
age: int | None = field(default=None, kw_only=True)
6464
role: str = field(default="user", kw_only=True)
6565

66-
# TODO: this would ideally show a default value of `None` for `age`
6766
# revealed: (self: Person, name: str, *, age: int | None = None, role: str = Literal["user"]) -> None
6867
reveal_type(Person.__init__)
6968

crates/ty_python_semantic/src/types.rs

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -883,21 +883,31 @@ impl<'db> Type<'db> {
883883
_ => false,
884884
}
885885
}
886-
887886
// If the type is a specialized instance of the given `KnownClass`, returns the specialization.
888887
pub(crate) fn known_specialization(
889888
&self,
890889
db: &'db dyn Db,
891890
known_class: KnownClass,
892891
) -> Option<Specialization<'db>> {
893892
let class_literal = known_class.try_to_class_literal(db)?;
894-
self.specialization_of(db, Some(class_literal))
893+
self.specialization_of(db, class_literal)
894+
}
895+
896+
// If this type is a class instance, returns its specialization.
897+
pub(crate) fn class_specialization(self, db: &'db dyn Db) -> Option<Specialization<'db>> {
898+
self.specialization_of_optional(db, None)
895899
}
896900

897901
// If the type is a specialized instance of the given class, returns the specialization.
898-
//
899-
// If no class is provided, returns the specialization of any class instance.
900902
pub(crate) fn specialization_of(
903+
self,
904+
db: &'db dyn Db,
905+
expected_class: ClassLiteral<'_>,
906+
) -> Option<Specialization<'db>> {
907+
self.specialization_of_optional(db, Some(expected_class))
908+
}
909+
910+
fn specialization_of_optional(
901911
self,
902912
db: &'db dyn Db,
903913
expected_class: Option<ClassLiteral<'_>>,
@@ -5551,7 +5561,7 @@ impl<'db> Type<'db> {
55515561
) -> Result<Bindings<'db>, CallError<'db>> {
55525562
self.bindings(db)
55535563
.match_parameters(db, argument_types)
5554-
.check_types(db, argument_types, &TypeContext::default(), &[])
5564+
.check_types(db, argument_types, TypeContext::default(), &[])
55555565
}
55565566

55575567
/// Look up a dunder method on the meta-type of `self` and call it.
@@ -5603,7 +5613,8 @@ impl<'db> Type<'db> {
56035613
let bindings = dunder_callable
56045614
.bindings(db)
56055615
.match_parameters(db, argument_types)
5606-
.check_types(db, argument_types, &tcx, &[])?;
5616+
.check_types(db, argument_types, tcx, &[])?;
5617+
56075618
if boundness == Definedness::PossiblyUndefined {
56085619
return Err(CallDunderError::PossiblyUnbound(Box::new(bindings)));
56095620
}

0 commit comments

Comments
 (0)