Skip to content

Commit a32d5b8

Browse files
authored
[ty] Improve exhaustiveness analysis for type variables with bounds or constraints (#21172)
1 parent 6337e22 commit a32d5b8

File tree

2 files changed

+85
-31
lines changed

2 files changed

+85
-31
lines changed

crates/ty_python_semantic/resources/mdtest/exhaustiveness_checking.md

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,3 +417,55 @@ class Answer(Enum):
417417
case Answer.NO:
418418
return False
419419
```
420+
421+
## Exhaustiveness checking for type variables with bounds or constraints
422+
423+
```toml
424+
[environment]
425+
python-version = "3.12"
426+
```
427+
428+
```py
429+
from typing import assert_never, Literal
430+
431+
def f[T: bool](x: T) -> T:
432+
match x:
433+
case True:
434+
return x
435+
case False:
436+
return x
437+
case _:
438+
reveal_type(x) # revealed: Never
439+
assert_never(x)
440+
441+
def g[T: Literal["foo", "bar"]](x: T) -> T:
442+
match x:
443+
case "foo":
444+
return x
445+
case "bar":
446+
return x
447+
case _:
448+
reveal_type(x) # revealed: Never
449+
assert_never(x)
450+
451+
def h[T: int | str](x: T) -> T:
452+
if isinstance(x, int):
453+
return x
454+
elif isinstance(x, str):
455+
return x
456+
else:
457+
reveal_type(x) # revealed: Never
458+
assert_never(x)
459+
460+
def i[T: (int, str)](x: T) -> T:
461+
match x:
462+
case int():
463+
pass
464+
case str():
465+
pass
466+
case _:
467+
reveal_type(x) # revealed: Never
468+
assert_never(x)
469+
470+
return x
471+
```

crates/ty_python_semantic/src/types/builder.rs

Lines changed: 33 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -781,37 +781,6 @@ impl<'db> IntersectionBuilder<'db> {
781781
seen_aliases,
782782
)
783783
}
784-
Type::EnumLiteral(enum_literal) => {
785-
let enum_class = enum_literal.enum_class(self.db);
786-
let metadata =
787-
enum_metadata(self.db, enum_class).expect("Class of enum literal is an enum");
788-
789-
let enum_members_in_negative_part = self
790-
.intersections
791-
.iter()
792-
.flat_map(|intersection| &intersection.negative)
793-
.filter_map(|ty| ty.as_enum_literal())
794-
.filter(|lit| lit.enum_class(self.db) == enum_class)
795-
.map(|lit| lit.name(self.db))
796-
.chain(std::iter::once(enum_literal.name(self.db)))
797-
.collect::<FxHashSet<_>>();
798-
799-
let all_members_are_in_negative_part = metadata
800-
.members
801-
.keys()
802-
.all(|name| enum_members_in_negative_part.contains(name));
803-
804-
if all_members_are_in_negative_part {
805-
for inner in &mut self.intersections {
806-
inner.add_negative(self.db, enum_literal.enum_class_instance(self.db));
807-
}
808-
} else {
809-
for inner in &mut self.intersections {
810-
inner.add_negative(self.db, ty);
811-
}
812-
}
813-
self
814-
}
815784
_ => {
816785
for inner in &mut self.intersections {
817786
inner.add_negative(self.db, ty);
@@ -1177,6 +1146,39 @@ impl<'db> InnerIntersectionBuilder<'db> {
11771146

11781147
fn build(mut self, db: &'db dyn Db) -> Type<'db> {
11791148
self.simplify_constrained_typevars(db);
1149+
1150+
// If any typevars are in `self.positive`, speculatively solve all bounded type variables
1151+
// to their upper bound and all constrained type variables to the union of their constraints.
1152+
// If that speculative intersection simplifies to `Never`, this intersection must also simplify
1153+
// to `Never`.
1154+
if self.positive.iter().any(|ty| ty.is_type_var()) {
1155+
let mut speculative = IntersectionBuilder::new(db);
1156+
for pos in &self.positive {
1157+
match pos {
1158+
Type::TypeVar(type_var) => {
1159+
match type_var.typevar(db).bound_or_constraints(db) {
1160+
Some(TypeVarBoundOrConstraints::UpperBound(bound)) => {
1161+
speculative = speculative.add_positive(bound);
1162+
}
1163+
Some(TypeVarBoundOrConstraints::Constraints(constraints)) => {
1164+
speculative = speculative.add_positive(Type::Union(constraints));
1165+
}
1166+
// TypeVars without a bound or constraint implicitly have `object` as their
1167+
// upper bound, and it is always a no-op to add `object` to an intersection.
1168+
None => {}
1169+
}
1170+
}
1171+
_ => speculative = speculative.add_positive(*pos),
1172+
}
1173+
}
1174+
for neg in &self.negative {
1175+
speculative = speculative.add_negative(*neg);
1176+
}
1177+
if speculative.build().is_never() {
1178+
return Type::Never;
1179+
}
1180+
}
1181+
11801182
match (self.positive.len(), self.negative.len()) {
11811183
(0, 0) => Type::object(),
11821184
(1, 0) => self.positive[0],

0 commit comments

Comments
 (0)