Skip to content

Commit c563650

Browse files
committed
prefer declared type of generic classes
1 parent 304ac22 commit c563650

File tree

7 files changed

+158
-39
lines changed

7 files changed

+158
-39
lines changed

crates/ty_python_semantic/resources/mdtest/assignment/annotations.md

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -427,14 +427,13 @@ a = f("a")
427427
reveal_type(a) # revealed: list[Literal["a"]]
428428

429429
b: list[int | Literal["a"]] = f("a")
430-
reveal_type(b) # revealed: list[Literal["a"] | int]
430+
reveal_type(b) # revealed: list[int | Literal["a"]]
431431

432432
c: list[int | str] = f("a")
433-
reveal_type(c) # revealed: list[str | int]
433+
reveal_type(c) # revealed: list[int | str]
434434

435435
d: list[int | tuple[int, int]] = f((1, 2))
436-
# TODO: We could avoid reordering the union elements here.
437-
reveal_type(d) # revealed: list[tuple[int, int] | int]
436+
reveal_type(d) # revealed: list[int | tuple[int, int]]
438437

439438
e: list[int] = f(True)
440439
reveal_type(e) # revealed: list[int]
@@ -455,7 +454,54 @@ j: int | str = f2(True)
455454
reveal_type(j) # revealed: Literal[True]
456455
```
457456

458-
Types are not widened unnecessarily:
457+
## Prefer the declared type of generic classes
458+
459+
```toml
460+
[environment]
461+
python-version = "3.12"
462+
```
463+
464+
```py
465+
from typing import Any
466+
467+
def f[T](x: T) -> list[T]:
468+
return [x]
469+
470+
def f2[T](x: T) -> list[T] | None:
471+
return [x]
472+
473+
def f3[T](x: T) -> list[T] | dict[T, T]:
474+
return [x]
475+
476+
a = f(1)
477+
reveal_type(a) # revealed: list[Literal[1]]
478+
479+
b: list[Any] = f(1)
480+
reveal_type(b) # revealed: list[Any]
481+
482+
c: list[Any] = [1]
483+
reveal_type(c) # revealed: list[Any]
484+
485+
d: list[Any] | None = f(1)
486+
reveal_type(d) # revealed: list[Any]
487+
488+
e: list[Any] | None = [1]
489+
reveal_type(e) # revealed: list[Any]
490+
491+
f: list[Any] | None = f2(1)
492+
reveal_type(f) # revealed: list[Any] | None
493+
494+
g: list[Any] | dict[Any, Any] = f3(1)
495+
# TODO: Better constraint solver.
496+
reveal_type(g) # revealed: list[Literal[1]] | dict[Literal[1], Literal[1]]
497+
```
498+
499+
## Prefer the inferred type of non-generic classes
500+
501+
```toml
502+
[environment]
503+
python-version = "3.12"
504+
```
459505

460506
```py
461507
def id[T](x: T) -> T:

crates/ty_python_semantic/resources/mdtest/bidirectional.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,6 @@ def _(l: list[int] | None = None):
5050
def f[T](x: T, cond: bool) -> T | list[T]:
5151
return x if cond else [x]
5252

53-
# TODO: no error
54-
# error: [invalid-assignment] "Object of type `Literal[1] | list[Literal[1]]` is not assignable to `int | list[int]`"
5553
l5: int | list[int] = f(1, True)
5654
```
5755

crates/ty_python_semantic/src/types.rs

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -886,13 +886,24 @@ impl<'db> Type<'db> {
886886
known_class: KnownClass,
887887
) -> Option<Specialization<'db>> {
888888
let class_literal = known_class.try_to_class_literal(db)?;
889-
self.specialization_of(db, Some(class_literal))
889+
self.specialization_of(db, class_literal)
890+
}
891+
892+
// If this type is a class instance, returns its specialization.
893+
pub(crate) fn class_specialization(self, db: &'db dyn Db) -> Option<Specialization<'db>> {
894+
self.specialization_of_optional(db, None)
890895
}
891896

892897
// If the type is a specialized instance of the given class, returns the specialization.
893-
//
894-
// If no class is provided, returns the specialization of any class instance.
895898
pub(crate) fn specialization_of(
899+
self,
900+
db: &'db dyn Db,
901+
expected_class: ClassLiteral<'_>,
902+
) -> Option<Specialization<'db>> {
903+
self.specialization_of_optional(db, Some(expected_class))
904+
}
905+
906+
fn specialization_of_optional(
896907
self,
897908
db: &'db dyn Db,
898909
expected_class: Option<ClassLiteral<'_>>,
@@ -1209,22 +1220,28 @@ impl<'db> Type<'db> {
12091220

12101221
/// If the type is a union, filters union elements based on the provided predicate.
12111222
///
1212-
/// Otherwise, returns the type unchanged.
1223+
/// Otherwise, considers the type to be the sole inhabitant of a single-valued union,
1224+
/// and filters it, returning `Never` if the predicate returns `false`, or the type
1225+
/// unchanged if `true`.
12131226
pub(crate) fn filter_union(
12141227
self,
12151228
db: &'db dyn Db,
1216-
f: impl FnMut(&Type<'db>) -> bool,
1229+
mut f: impl FnMut(&Type<'db>) -> bool,
12171230
) -> Type<'db> {
12181231
if let Type::Union(union) = self {
12191232
union.filter(db, f)
1220-
} else {
1233+
} else if f(&self) {
12211234
self
1235+
} else {
1236+
Type::Never
12221237
}
12231238
}
12241239

12251240
/// If the type is a union, removes union elements that are disjoint from `target`.
12261241
///
1227-
/// Otherwise, returns the type unchanged.
1242+
/// Otherwise, considers the type to be the sole inhabitant of a single-valued union,
1243+
/// and filters it, returning `Never` if it is disjoint from `target`, or the type
1244+
/// unchanged if `true`.
12281245
pub(crate) fn filter_disjoint_elements(
12291246
self,
12301247
db: &'db dyn Db,

crates/ty_python_semantic/src/types/call/bind.rs

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,11 @@ use crate::types::generics::{
3333
use crate::types::signatures::{Parameter, ParameterForm, ParameterKind, Parameters};
3434
use crate::types::tuple::{TupleLength, TupleType};
3535
use crate::types::{
36-
BoundMethodType, ClassLiteral, DataclassFlags, DataclassParams, FieldInstance,
37-
KnownBoundMethodType, KnownClass, KnownInstanceType, MemberLookupPolicy, PropertyInstanceType,
38-
SpecialFormType, TrackedConstraintSet, TypeAliasType, TypeContext, UnionBuilder, UnionType,
39-
WrapperDescriptorKind, enums, ide_support, infer_isolated_expression, todo_type,
36+
BoundMethodType, BoundTypeVarIdentity, ClassLiteral, DataclassFlags, DataclassParams,
37+
FieldInstance, KnownBoundMethodType, KnownClass, KnownInstanceType, MemberLookupPolicy,
38+
PropertyInstanceType, SpecialFormType, TrackedConstraintSet, TypeAliasType, TypeContext,
39+
UnionBuilder, UnionType, WrapperDescriptorKind, enums, ide_support, infer_isolated_expression,
40+
todo_type,
4041
};
4142
use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity};
4243
use ruff_python_ast::{self as ast, ArgOrKeyword, PythonVersion};
@@ -2627,9 +2628,25 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
26272628
return;
26282629
};
26292630

2631+
let return_with_tcx = self
2632+
.signature
2633+
.return_ty
2634+
.zip(self.call_expression_tcx.annotation);
2635+
26302636
self.inferable_typevars = generic_context.inferable_typevars(self.db);
26312637
let mut builder = SpecializationBuilder::new(self.db, self.inferable_typevars);
26322638

2639+
// Prefer the declared type of generic classes.
2640+
let preferred_type_mappings = return_with_tcx.and_then(|(return_ty, tcx)| {
2641+
let preferred_return_ty =
2642+
tcx.filter_union(self.db, |ty| ty.class_specialization(self.db).is_some());
2643+
let return_ty =
2644+
return_ty.filter_union(self.db, |ty| ty.class_specialization(self.db).is_some());
2645+
2646+
builder.infer(return_ty, preferred_return_ty).ok()?;
2647+
Some(builder.type_mappings().clone())
2648+
});
2649+
26332650
let parameters = self.signature.parameters();
26342651
for (argument_index, adjusted_argument_index, _, argument_type) in
26352652
self.enumerate_argument_types()
@@ -2642,9 +2659,20 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
26422659
continue;
26432660
};
26442661

2645-
if let Err(error) = builder.infer(
2662+
let filter = |declared: BoundTypeVarIdentity<'_>, new: Type<'_>| {
2663+
// Avoid widening the inferred type if it is already assignable to the
2664+
// preferred declared type.
2665+
preferred_type_mappings
2666+
.as_ref()
2667+
.is_none_or(|preferred_type_mappings| {
2668+
!new.is_assignable_to(self.db, preferred_type_mappings[&declared])
2669+
})
2670+
};
2671+
2672+
if let Err(error) = builder.infer_filter(
26462673
expected_type,
26472674
variadic_argument_type.unwrap_or(argument_type),
2675+
filter,
26482676
) {
26492677
self.errors.push(BindingError::SpecializationError {
26502678
error,
@@ -2654,15 +2682,14 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
26542682
}
26552683
}
26562684

2657-
// Build the specialization first without inferring the type context.
2685+
// Build the specialization first without inferring the complete type context.
26582686
let isolated_specialization = builder.build(generic_context, *self.call_expression_tcx);
26592687
let isolated_return_ty = self
26602688
.return_ty
26612689
.apply_specialization(self.db, isolated_specialization);
26622690

26632691
let mut try_infer_tcx = || {
2664-
let return_ty = self.signature.return_ty?;
2665-
let call_expression_tcx = self.call_expression_tcx.annotation?;
2692+
let (return_ty, call_expression_tcx) = return_with_tcx?;
26662693

26672694
// A type variable is not a useful type-context for expression inference, and applying it
26682695
// to the return type can lead to confusing unions in nested generic calls.
@@ -2671,7 +2698,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
26712698
}
26722699

26732700
// If the return type is already assignable to the annotated type, we can ignore the
2674-
// type context and prefer the narrower inferred type.
2701+
// rest of the type context and prefer the narrower inferred type.
26752702
if isolated_return_ty.is_assignable_to(self.db, call_expression_tcx) {
26762703
return None;
26772704
}
@@ -2680,7 +2707,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
26802707
// annotated assignment, to closer match the order of any unions written in the type annotation.
26812708
builder.infer(return_ty, call_expression_tcx).ok()?;
26822709

2683-
// Otherwise, build the specialization again after inferring the type context.
2710+
// Otherwise, build the specialization again after inferring the complete type context.
26842711
let specialization = builder.build(generic_context, *self.call_expression_tcx);
26852712
let return_ty = return_ty.apply_specialization(self.db, specialization);
26862713

crates/ty_python_semantic/src/types/class.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ impl<'db> GenericAlias<'db> {
248248
) -> Self {
249249
let tcx = tcx
250250
.annotation
251-
.and_then(|ty| ty.specialization_of(db, Some(self.origin(db))))
251+
.and_then(|ty| ty.specialization_of(db, self.origin(db)))
252252
.map(|specialization| specialization.types(db))
253253
.unwrap_or(&[]);
254254

crates/ty_python_semantic/src/types/generics.rs

Lines changed: 43 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use std::cell::RefCell;
2+
use std::collections::hash_map::Entry;
23
use std::fmt::Display;
34

45
use itertools::Itertools;
@@ -1318,14 +1319,19 @@ impl<'db> SpecializationBuilder<'db> {
13181319
}
13191320
}
13201321

1322+
/// Returns the current set of type mappings for this specialization.
1323+
pub(crate) fn type_mappings(&self) -> &FxHashMap<BoundTypeVarIdentity<'db>, Type<'db>> {
1324+
&self.types
1325+
}
1326+
13211327
pub(crate) fn build(
13221328
&mut self,
13231329
generic_context: GenericContext<'db>,
13241330
tcx: TypeContext<'db>,
13251331
) -> Specialization<'db> {
13261332
let tcx_specialization = tcx
13271333
.annotation
1328-
.and_then(|annotation| annotation.specialization_of(self.db, None));
1334+
.and_then(|annotation| annotation.class_specialization(self.db));
13291335

13301336
let types =
13311337
(generic_context.variables_inner(self.db).iter()).map(|(identity, variable)| {
@@ -1348,19 +1354,43 @@ impl<'db> SpecializationBuilder<'db> {
13481354
generic_context.specialize_partial(self.db, types)
13491355
}
13501356

1351-
fn add_type_mapping(&mut self, bound_typevar: BoundTypeVarInstance<'db>, ty: Type<'db>) {
1352-
self.types
1353-
.entry(bound_typevar.identity(self.db))
1354-
.and_modify(|existing| {
1355-
*existing = UnionType::from_elements(self.db, [*existing, ty]);
1356-
})
1357-
.or_insert(ty);
1357+
fn add_type_mapping(
1358+
&mut self,
1359+
bound_typevar: BoundTypeVarInstance<'db>,
1360+
ty: Type<'db>,
1361+
filter: impl Fn(BoundTypeVarIdentity<'db>, Type<'db>) -> bool,
1362+
) {
1363+
let identity = bound_typevar.identity(self.db);
1364+
match self.types.entry(identity) {
1365+
Entry::Occupied(mut entry) => {
1366+
if filter(identity, ty) {
1367+
*entry.get_mut() = UnionType::from_elements(self.db, [*entry.get(), ty]);
1368+
}
1369+
}
1370+
Entry::Vacant(entry) => {
1371+
entry.insert(ty);
1372+
}
1373+
}
13581374
}
13591375

1376+
/// Infer type mappings for the specialization based on a given type and its declared type.
13601377
pub(crate) fn infer(
13611378
&mut self,
13621379
formal: Type<'db>,
13631380
actual: Type<'db>,
1381+
) -> Result<(), SpecializationError<'db>> {
1382+
self.infer_filter(formal, actual, |_, _| true)
1383+
}
1384+
1385+
/// Infer type mappings for the specialization based on a given type and its declared type.
1386+
///
1387+
/// The filter predicate is provided with a type variable and the type being mapped to it. Type
1388+
/// mappings to which the predicate returns `false` will be ignored.
1389+
pub(crate) fn infer_filter(
1390+
&mut self,
1391+
formal: Type<'db>,
1392+
actual: Type<'db>,
1393+
filter: impl Fn(BoundTypeVarIdentity<'db>, Type<'db>) -> bool,
13641394
) -> Result<(), SpecializationError<'db>> {
13651395
if formal == actual {
13661396
return Ok(());
@@ -1441,7 +1471,7 @@ impl<'db> SpecializationBuilder<'db> {
14411471
if remaining_actual.is_never() {
14421472
return Ok(());
14431473
}
1444-
self.add_type_mapping(*formal_bound_typevar, remaining_actual);
1474+
self.add_type_mapping(*formal_bound_typevar, remaining_actual, filter);
14451475
}
14461476
(Type::Union(formal), _) => {
14471477
// Second, if the formal is a union, and precisely one union element _is_ a typevar (not
@@ -1451,7 +1481,7 @@ impl<'db> SpecializationBuilder<'db> {
14511481
let bound_typevars =
14521482
(formal.elements(self.db).iter()).filter_map(|ty| ty.as_typevar());
14531483
if let Ok(bound_typevar) = bound_typevars.exactly_one() {
1454-
self.add_type_mapping(bound_typevar, actual);
1484+
self.add_type_mapping(bound_typevar, actual, filter);
14551485
}
14561486
}
14571487

@@ -1479,15 +1509,15 @@ impl<'db> SpecializationBuilder<'db> {
14791509
argument: ty,
14801510
});
14811511
}
1482-
self.add_type_mapping(bound_typevar, ty);
1512+
self.add_type_mapping(bound_typevar, ty, filter);
14831513
}
14841514
Some(TypeVarBoundOrConstraints::Constraints(constraints)) => {
14851515
for constraint in constraints.elements(self.db) {
14861516
if ty
14871517
.when_assignable_to(self.db, *constraint, self.inferable)
14881518
.is_always_satisfied(self.db)
14891519
{
1490-
self.add_type_mapping(bound_typevar, *constraint);
1520+
self.add_type_mapping(bound_typevar, *constraint, filter);
14911521
return Ok(());
14921522
}
14931523
}
@@ -1497,7 +1527,7 @@ impl<'db> SpecializationBuilder<'db> {
14971527
});
14981528
}
14991529
_ => {
1500-
self.add_type_mapping(bound_typevar, ty);
1530+
self.add_type_mapping(bound_typevar, ty, filter);
15011531
}
15021532
}
15031533
}

crates/ty_python_semantic/src/types/infer/builder.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6250,7 +6250,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
62506250

62516251
let inferred_elt_ty = self.get_or_infer_expression(elt, elt_tcx);
62526252

6253-
// Simplify the inference based on the declared type of the element.
6253+
// Avoid widening the inferred type if it is already assignable to the preferred
6254+
// declared type.
62546255
if let Some(elt_tcx) = elt_tcx.annotation {
62556256
if inferred_elt_ty.is_assignable_to(self.db(), elt_tcx) {
62566257
continue;

0 commit comments

Comments
 (0)