Skip to content

Commit 2f4dcbf

Browse files
committed
prefer declared type of generic classes
1 parent 1d6ae85 commit 2f4dcbf

File tree

7 files changed

+160
-41
lines changed

7 files changed

+160
-41
lines changed

crates/ty_python_semantic/resources/mdtest/assignment/annotations.md

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -427,14 +427,13 @@ a = f("a")
427427
reveal_type(a) # revealed: list[Literal["a"]]
428428

429429
b: list[int | Literal["a"]] = f("a")
430-
reveal_type(b) # revealed: list[Literal["a"] | int]
430+
reveal_type(b) # revealed: list[int | Literal["a"]]
431431

432432
c: list[int | str] = f("a")
433-
reveal_type(c) # revealed: list[str | int]
433+
reveal_type(c) # revealed: list[int | str]
434434

435435
d: list[int | tuple[int, int]] = f((1, 2))
436-
# TODO: We could avoid reordering the union elements here.
437-
reveal_type(d) # revealed: list[tuple[int, int] | int]
436+
reveal_type(d) # revealed: list[int | tuple[int, int]]
438437

439438
e: list[int] = f(True)
440439
reveal_type(e) # revealed: list[int]
@@ -455,7 +454,54 @@ j: int | str = f2(True)
455454
reveal_type(j) # revealed: Literal[True]
456455
```
457456

458-
Types are not widened unnecessarily:
457+
## Prefer the declared type of generic classes
458+
459+
```toml
460+
[environment]
461+
python-version = "3.12"
462+
```
463+
464+
```py
465+
from typing import Any
466+
467+
def f[T](x: T) -> list[T]:
468+
return [x]
469+
470+
def f2[T](x: T) -> list[T] | None:
471+
return [x]
472+
473+
def f3[T](x: T) -> list[T] | dict[T, T]:
474+
return [x]
475+
476+
a = f(1)
477+
reveal_type(a) # revealed: list[Literal[1]]
478+
479+
b: list[Any] = f(1)
480+
reveal_type(b) # revealed: list[Any]
481+
482+
c: list[Any] = [1]
483+
reveal_type(c) # revealed: list[Any]
484+
485+
d: list[Any] | None = f(1)
486+
reveal_type(d) # revealed: list[Any]
487+
488+
e: list[Any] | None = [1]
489+
reveal_type(e) # revealed: list[Any]
490+
491+
f: list[Any] | None = f2(1)
492+
reveal_type(f) # revealed: list[Any] | None
493+
494+
g: list[Any] | dict[Any, Any] = f3(1)
495+
# TODO: Better constraint solver.
496+
reveal_type(g) # revealed: list[Literal[1]] | dict[Literal[1], Literal[1]]
497+
```
498+
499+
## Prefer the inferred type of non-generic classes
500+
501+
```toml
502+
[environment]
503+
python-version = "3.12"
504+
```
459505

460506
```py
461507
def id[T](x: T) -> T:

crates/ty_python_semantic/resources/mdtest/bidirectional.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,6 @@ def _(l: list[int] | None = None):
5050
def f[T](x: T, cond: bool) -> T | list[T]:
5151
return x if cond else [x]
5252

53-
# TODO: no error
54-
# error: [invalid-assignment] "Object of type `Literal[1] | list[Literal[1]]` is not assignable to `int | list[int]`"
5553
l5: int | list[int] = f(1, True)
5654
```
5755

crates/ty_python_semantic/src/types.rs

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -891,13 +891,24 @@ impl<'db> Type<'db> {
891891
known_class: KnownClass,
892892
) -> Option<Specialization<'db>> {
893893
let class_literal = known_class.try_to_class_literal(db)?;
894-
self.specialization_of(db, Some(class_literal))
894+
self.specialization_of(db, class_literal)
895+
}
896+
897+
// If this type is a class instance, returns its specialization.
898+
pub(crate) fn class_specialization(self, db: &'db dyn Db) -> Option<Specialization<'db>> {
899+
self.specialization_of_optional(db, None)
895900
}
896901

897902
// If the type is a specialized instance of the given class, returns the specialization.
898-
//
899-
// If no class is provided, returns the specialization of any class instance.
900903
pub(crate) fn specialization_of(
904+
self,
905+
db: &'db dyn Db,
906+
expected_class: ClassLiteral<'_>,
907+
) -> Option<Specialization<'db>> {
908+
self.specialization_of_optional(db, Some(expected_class))
909+
}
910+
911+
fn specialization_of_optional(
901912
self,
902913
db: &'db dyn Db,
903914
expected_class: Option<ClassLiteral<'_>>,
@@ -1214,22 +1225,28 @@ impl<'db> Type<'db> {
12141225

12151226
/// If the type is a union, filters union elements based on the provided predicate.
12161227
///
1217-
/// Otherwise, returns the type unchanged.
1228+
/// Otherwise, considers the type to be the sole inhabitant of a single-valued union,
1229+
/// and filters it, returning `Never` if the predicate returns `false`, or the type
1230+
/// unchanged if `true`.
12181231
pub(crate) fn filter_union(
12191232
self,
12201233
db: &'db dyn Db,
1221-
f: impl FnMut(&Type<'db>) -> bool,
1234+
mut f: impl FnMut(&Type<'db>) -> bool,
12221235
) -> Type<'db> {
12231236
if let Type::Union(union) = self {
12241237
union.filter(db, f)
1225-
} else {
1238+
} else if f(&self) {
12261239
self
1240+
} else {
1241+
Type::Never
12271242
}
12281243
}
12291244

12301245
/// If the type is a union, removes union elements that are disjoint from `target`.
12311246
///
1232-
/// Otherwise, returns the type unchanged.
1247+
/// Otherwise, considers the type to be the sole inhabitant of a single-valued union,
1248+
/// and filters it, returning `Never` if it is disjoint from `target`, or the type
1249+
/// unchanged if `true`.
12331250
pub(crate) fn filter_disjoint_elements(
12341251
self,
12351252
db: &'db dyn Db,

crates/ty_python_semantic/src/types/call/bind.rs

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,11 @@ use crate::types::generics::{
3535
use crate::types::signatures::{Parameter, ParameterForm, ParameterKind, Parameters};
3636
use crate::types::tuple::{TupleLength, TupleType};
3737
use crate::types::{
38-
BoundMethodType, ClassLiteral, DataclassFlags, DataclassParams, FieldInstance,
39-
KnownBoundMethodType, KnownClass, KnownInstanceType, MemberLookupPolicy, NominalInstanceType,
40-
PropertyInstanceType, SpecialFormType, TrackedConstraintSet, TypeAliasType, TypeContext,
41-
UnionBuilder, UnionType, WrapperDescriptorKind, enums, ide_support, infer_isolated_expression,
42-
todo_type,
38+
BoundMethodType, BoundTypeVarIdentity, ClassLiteral, DataclassFlags, DataclassParams,
39+
FieldInstance, KnownBoundMethodType, KnownClass, KnownInstanceType, MemberLookupPolicy,
40+
NominalInstanceType, PropertyInstanceType, SpecialFormType, TrackedConstraintSet,
41+
TypeAliasType, TypeContext, UnionBuilder, UnionType, WrapperDescriptorKind, enums, ide_support,
42+
infer_isolated_expression, todo_type,
4343
};
4444
use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity};
4545
use ruff_python_ast::{self as ast, ArgOrKeyword, PythonVersion};
@@ -2718,9 +2718,25 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
27182718
return;
27192719
};
27202720

2721+
let return_with_tcx = self
2722+
.signature
2723+
.return_ty
2724+
.zip(self.call_expression_tcx.annotation);
2725+
27212726
self.inferable_typevars = generic_context.inferable_typevars(self.db);
27222727
let mut builder = SpecializationBuilder::new(self.db, self.inferable_typevars);
27232728

2729+
// Prefer the declared type of generic classes.
2730+
let preferred_type_mappings = return_with_tcx.and_then(|(return_ty, tcx)| {
2731+
let preferred_return_ty =
2732+
tcx.filter_union(self.db, |ty| ty.class_specialization(self.db).is_some());
2733+
let return_ty =
2734+
return_ty.filter_union(self.db, |ty| ty.class_specialization(self.db).is_some());
2735+
2736+
builder.infer(return_ty, preferred_return_ty).ok()?;
2737+
Some(builder.type_mappings().clone())
2738+
});
2739+
27242740
let parameters = self.signature.parameters();
27252741
for (argument_index, adjusted_argument_index, _, argument_type) in
27262742
self.enumerate_argument_types()
@@ -2733,9 +2749,21 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
27332749
continue;
27342750
};
27352751

2736-
if let Err(error) = builder.infer(
2752+
let filter = |declared_ty: BoundTypeVarIdentity<'_>, inferred_ty: Type<'_>| {
2753+
// Avoid widening the inferred type if it is already assignable to the
2754+
// preferred declared type.
2755+
preferred_type_mappings
2756+
.as_ref()
2757+
.and_then(|types| types.get(&declared_ty))
2758+
.is_none_or(|preferred_ty| {
2759+
!inferred_ty.is_assignable_to(self.db, *preferred_ty)
2760+
})
2761+
};
2762+
2763+
if let Err(error) = builder.infer_filter(
27372764
expected_type,
27382765
variadic_argument_type.unwrap_or(argument_type),
2766+
filter,
27392767
) {
27402768
self.errors.push(BindingError::SpecializationError {
27412769
error,
@@ -2745,15 +2773,14 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
27452773
}
27462774
}
27472775

2748-
// Build the specialization first without inferring the type context.
2776+
// Build the specialization first without inferring the complete type context.
27492777
let isolated_specialization = builder.build(generic_context, *self.call_expression_tcx);
27502778
let isolated_return_ty = self
27512779
.return_ty
27522780
.apply_specialization(self.db, isolated_specialization);
27532781

27542782
let mut try_infer_tcx = || {
2755-
let return_ty = self.signature.return_ty?;
2756-
let call_expression_tcx = self.call_expression_tcx.annotation?;
2783+
let (return_ty, call_expression_tcx) = return_with_tcx?;
27572784

27582785
// A type variable is not a useful type-context for expression inference, and applying it
27592786
// to the return type can lead to confusing unions in nested generic calls.
@@ -2762,7 +2789,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
27622789
}
27632790

27642791
// If the return type is already assignable to the annotated type, we can ignore the
2765-
// type context and prefer the narrower inferred type.
2792+
// rest of the type context and prefer the narrower inferred type.
27662793
if isolated_return_ty.is_assignable_to(self.db, call_expression_tcx) {
27672794
return None;
27682795
}
@@ -2771,7 +2798,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
27712798
// annotated assignment, to closer match the order of any unions written in the type annotation.
27722799
builder.infer(return_ty, call_expression_tcx).ok()?;
27732800

2774-
// Otherwise, build the specialization again after inferring the type context.
2801+
// Otherwise, build the specialization again after inferring the complete type context.
27752802
let specialization = builder.build(generic_context, *self.call_expression_tcx);
27762803
let return_ty = return_ty.apply_specialization(self.db, specialization);
27772804

crates/ty_python_semantic/src/types/class.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ impl<'db> GenericAlias<'db> {
258258
) -> Self {
259259
let tcx = tcx
260260
.annotation
261-
.and_then(|ty| ty.specialization_of(db, Some(self.origin(db))))
261+
.and_then(|ty| ty.specialization_of(db, self.origin(db)))
262262
.map(|specialization| specialization.types(db))
263263
.unwrap_or(&[]);
264264

crates/ty_python_semantic/src/types/generics.rs

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use std::cell::RefCell;
2+
use std::collections::hash_map::Entry;
23
use std::fmt::Display;
34

45
use itertools::Itertools;
@@ -1319,14 +1320,19 @@ impl<'db> SpecializationBuilder<'db> {
13191320
}
13201321
}
13211322

1323+
/// Returns the current set of type mappings for this specialization.
1324+
pub(crate) fn type_mappings(&self) -> &FxHashMap<BoundTypeVarIdentity<'db>, Type<'db>> {
1325+
&self.types
1326+
}
1327+
13221328
pub(crate) fn build(
13231329
&mut self,
13241330
generic_context: GenericContext<'db>,
13251331
tcx: TypeContext<'db>,
13261332
) -> Specialization<'db> {
13271333
let tcx_specialization = tcx
13281334
.annotation
1329-
.and_then(|annotation| annotation.specialization_of(self.db, None));
1335+
.and_then(|annotation| annotation.class_specialization(self.db));
13301336

13311337
let types =
13321338
(generic_context.variables_inner(self.db).iter()).map(|(identity, variable)| {
@@ -1349,19 +1355,43 @@ impl<'db> SpecializationBuilder<'db> {
13491355
generic_context.specialize_partial(self.db, types)
13501356
}
13511357

1352-
fn add_type_mapping(&mut self, bound_typevar: BoundTypeVarInstance<'db>, ty: Type<'db>) {
1353-
self.types
1354-
.entry(bound_typevar.identity(self.db))
1355-
.and_modify(|existing| {
1356-
*existing = UnionType::from_elements(self.db, [*existing, ty]);
1357-
})
1358-
.or_insert(ty);
1358+
fn add_type_mapping(
1359+
&mut self,
1360+
bound_typevar: BoundTypeVarInstance<'db>,
1361+
ty: Type<'db>,
1362+
filter: impl Fn(BoundTypeVarIdentity<'db>, Type<'db>) -> bool,
1363+
) {
1364+
let identity = bound_typevar.identity(self.db);
1365+
match self.types.entry(identity) {
1366+
Entry::Occupied(mut entry) => {
1367+
if filter(identity, ty) {
1368+
*entry.get_mut() = UnionType::from_elements(self.db, [*entry.get(), ty]);
1369+
}
1370+
}
1371+
Entry::Vacant(entry) => {
1372+
entry.insert(ty);
1373+
}
1374+
}
13591375
}
13601376

1377+
/// Infer type mappings for the specialization based on a given type and its declared type.
13611378
pub(crate) fn infer(
13621379
&mut self,
13631380
formal: Type<'db>,
13641381
actual: Type<'db>,
1382+
) -> Result<(), SpecializationError<'db>> {
1383+
self.infer_filter(formal, actual, |_, _| true)
1384+
}
1385+
1386+
/// Infer type mappings for the specialization based on a given type and its declared type.
1387+
///
1388+
/// The filter predicate is provided with a type variable and the type being mapped to it. Type
1389+
/// mappings to which the predicate returns `false` will be ignored.
1390+
pub(crate) fn infer_filter(
1391+
&mut self,
1392+
formal: Type<'db>,
1393+
actual: Type<'db>,
1394+
filter: impl Fn(BoundTypeVarIdentity<'db>, Type<'db>) -> bool,
13651395
) -> Result<(), SpecializationError<'db>> {
13661396
if formal == actual {
13671397
return Ok(());
@@ -1442,7 +1472,7 @@ impl<'db> SpecializationBuilder<'db> {
14421472
if remaining_actual.is_never() {
14431473
return Ok(());
14441474
}
1445-
self.add_type_mapping(*formal_bound_typevar, remaining_actual);
1475+
self.add_type_mapping(*formal_bound_typevar, remaining_actual, filter);
14461476
}
14471477
(Type::Union(formal), _) => {
14481478
// Second, if the formal is a union, and precisely one union element _is_ a typevar (not
@@ -1452,7 +1482,7 @@ impl<'db> SpecializationBuilder<'db> {
14521482
let bound_typevars =
14531483
(formal.elements(self.db).iter()).filter_map(|ty| ty.as_typevar());
14541484
if let Ok(bound_typevar) = bound_typevars.exactly_one() {
1455-
self.add_type_mapping(bound_typevar, actual);
1485+
self.add_type_mapping(bound_typevar, actual, filter);
14561486
}
14571487
}
14581488

@@ -1480,13 +1510,13 @@ impl<'db> SpecializationBuilder<'db> {
14801510
argument: ty,
14811511
});
14821512
}
1483-
self.add_type_mapping(bound_typevar, ty);
1513+
self.add_type_mapping(bound_typevar, ty, filter);
14841514
}
14851515
Some(TypeVarBoundOrConstraints::Constraints(constraints)) => {
14861516
// Prefer an exact match first.
14871517
for constraint in constraints.elements(self.db) {
14881518
if ty == *constraint {
1489-
self.add_type_mapping(bound_typevar, ty);
1519+
self.add_type_mapping(bound_typevar, ty, filter);
14901520
return Ok(());
14911521
}
14921522
}
@@ -1496,7 +1526,7 @@ impl<'db> SpecializationBuilder<'db> {
14961526
.when_assignable_to(self.db, *constraint, self.inferable)
14971527
.is_always_satisfied(self.db)
14981528
{
1499-
self.add_type_mapping(bound_typevar, *constraint);
1529+
self.add_type_mapping(bound_typevar, *constraint, filter);
15001530
return Ok(());
15011531
}
15021532
}
@@ -1506,7 +1536,7 @@ impl<'db> SpecializationBuilder<'db> {
15061536
});
15071537
}
15081538
_ => {
1509-
self.add_type_mapping(bound_typevar, ty);
1539+
self.add_type_mapping(bound_typevar, ty, filter);
15101540
}
15111541
}
15121542
}

crates/ty_python_semantic/src/types/infer/builder.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6260,7 +6260,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
62606260

62616261
let inferred_elt_ty = self.get_or_infer_expression(elt, elt_tcx);
62626262

6263-
// Simplify the inference based on the declared type of the element.
6263+
// Avoid widening the inferred type if it is already assignable to the preferred
6264+
// declared type.
62646265
if let Some(elt_tcx) = elt_tcx.annotation {
62656266
if inferred_elt_ty.is_assignable_to(self.db(), elt_tcx) {
62666267
continue;

0 commit comments

Comments
 (0)