Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -427,14 +427,13 @@ a = f("a")
reveal_type(a) # revealed: list[Literal["a"]]

b: list[int | Literal["a"]] = f("a")
reveal_type(b) # revealed: list[Literal["a"] | int]
reveal_type(b) # revealed: list[int | Literal["a"]]

c: list[int | str] = f("a")
reveal_type(c) # revealed: list[str | int]
reveal_type(c) # revealed: list[int | str]

d: list[int | tuple[int, int]] = f((1, 2))
# TODO: We could avoid reordering the union elements here.
reveal_type(d) # revealed: list[tuple[int, int] | int]
reveal_type(d) # revealed: list[int | tuple[int, int]]

e: list[int] = f(True)
reveal_type(e) # revealed: list[int]
Expand All @@ -455,7 +454,54 @@ j: int | str = f2(True)
reveal_type(j) # revealed: Literal[True]
```

Types are not widened unnecessarily:
## Prefer the declared type of generic classes

```toml
[environment]
python-version = "3.12"
```

```py
from typing import Any

def f[T](x: T) -> list[T]:
return [x]

def f2[T](x: T) -> list[T] | None:
return [x]

def f3[T](x: T) -> list[T] | dict[T, T]:
return [x]

a = f(1)
reveal_type(a) # revealed: list[Literal[1]]

b: list[Any] = f(1)
reveal_type(b) # revealed: list[Any]

c: list[Any] = [1]
reveal_type(c) # revealed: list[Any]

d: list[Any] | None = f(1)
reveal_type(d) # revealed: list[Any]

e: list[Any] | None = [1]
reveal_type(e) # revealed: list[Any]

f: list[Any] | None = f2(1)
reveal_type(f) # revealed: list[Any] | None

g: list[Any] | dict[Any, Any] = f3(1)
# TODO: Better constraint solver.
reveal_type(g) # revealed: list[Literal[1]] | dict[Literal[1], Literal[1]]
```

## Prefer the inferred type of non-generic classes

```toml
[environment]
python-version = "3.12"
```

```py
def id[T](x: T) -> T:
Expand Down
2 changes: 0 additions & 2 deletions crates/ty_python_semantic/resources/mdtest/bidirectional.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,6 @@ def _(l: list[int] | None = None):
def f[T](x: T, cond: bool) -> T | list[T]:
return x if cond else [x]

# TODO: no error
# error: [invalid-assignment] "Object of type `Literal[1] | list[Literal[1]]` is not assignable to `int | list[int]`"
l5: int | list[int] = f(1, True)
```

Expand Down
31 changes: 24 additions & 7 deletions crates/ty_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -891,13 +891,24 @@ impl<'db> Type<'db> {
known_class: KnownClass,
) -> Option<Specialization<'db>> {
let class_literal = known_class.try_to_class_literal(db)?;
self.specialization_of(db, Some(class_literal))
self.specialization_of(db, class_literal)
}

// If this type is a class instance, returns its specialization.
pub(crate) fn class_specialization(self, db: &'db dyn Db) -> Option<Specialization<'db>> {
self.specialization_of_optional(db, None)
}

// If the type is a specialized instance of the given class, returns the specialization.
//
// If no class is provided, returns the specialization of any class instance.
pub(crate) fn specialization_of(
self,
db: &'db dyn Db,
expected_class: ClassLiteral<'_>,
) -> Option<Specialization<'db>> {
self.specialization_of_optional(db, Some(expected_class))
}

fn specialization_of_optional(
self,
db: &'db dyn Db,
expected_class: Option<ClassLiteral<'_>>,
Expand Down Expand Up @@ -1214,22 +1225,28 @@ impl<'db> Type<'db> {

/// If the type is a union, filters union elements based on the provided predicate.
///
/// Otherwise, returns the type unchanged.
/// Otherwise, considers the type to be the sole inhabitant of a single-valued union,
/// and filters it, returning `Never` if the predicate returns `false`, or the type
/// unchanged if `true`.
pub(crate) fn filter_union(
self,
db: &'db dyn Db,
f: impl FnMut(&Type<'db>) -> bool,
mut f: impl FnMut(&Type<'db>) -> bool,
) -> Type<'db> {
if let Type::Union(union) = self {
union.filter(db, f)
} else {
} else if f(&self) {
self
} else {
Type::Never
}
}

/// If the type is a union, removes union elements that are disjoint from `target`.
///
/// Otherwise, returns the type unchanged.
/// Otherwise, considers the type to be the sole inhabitant of a single-valued union,
/// and filters it, returning `Never` if it is disjoint from `target`, or the type
/// unchanged if `true`.
pub(crate) fn filter_disjoint_elements(
self,
db: &'db dyn Db,
Expand Down
49 changes: 38 additions & 11 deletions crates/ty_python_semantic/src/types/call/bind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ use crate::types::generics::{
use crate::types::signatures::{Parameter, ParameterForm, ParameterKind, Parameters};
use crate::types::tuple::{TupleLength, TupleType};
use crate::types::{
BoundMethodType, ClassLiteral, DataclassFlags, DataclassParams, FieldInstance,
KnownBoundMethodType, KnownClass, KnownInstanceType, MemberLookupPolicy, NominalInstanceType,
PropertyInstanceType, SpecialFormType, TrackedConstraintSet, TypeAliasType, TypeContext,
UnionBuilder, UnionType, WrapperDescriptorKind, enums, ide_support, infer_isolated_expression,
todo_type,
BoundMethodType, BoundTypeVarIdentity, ClassLiteral, DataclassFlags, DataclassParams,
FieldInstance, KnownBoundMethodType, KnownClass, KnownInstanceType, MemberLookupPolicy,
NominalInstanceType, PropertyInstanceType, SpecialFormType, TrackedConstraintSet,
TypeAliasType, TypeContext, UnionBuilder, UnionType, WrapperDescriptorKind, enums, ide_support,
infer_isolated_expression, todo_type,
};
use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity};
use ruff_python_ast::{self as ast, ArgOrKeyword, PythonVersion};
Expand Down Expand Up @@ -2718,9 +2718,25 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
return;
};

let return_with_tcx = self
.signature
.return_ty
.zip(self.call_expression_tcx.annotation);

self.inferable_typevars = generic_context.inferable_typevars(self.db);
let mut builder = SpecializationBuilder::new(self.db, self.inferable_typevars);

// Prefer the declared type of generic classes.
let preferred_type_mappings = return_with_tcx.and_then(|(return_ty, tcx)| {
let preferred_return_ty =
tcx.filter_union(self.db, |ty| ty.class_specialization(self.db).is_some());
let return_ty =
return_ty.filter_union(self.db, |ty| ty.class_specialization(self.db).is_some());

builder.infer(return_ty, preferred_return_ty).ok()?;
Some(builder.type_mappings().clone())
});

let parameters = self.signature.parameters();
for (argument_index, adjusted_argument_index, _, argument_type) in
self.enumerate_argument_types()
Expand All @@ -2733,9 +2749,21 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
continue;
};

if let Err(error) = builder.infer(
let filter = |declared_ty: BoundTypeVarIdentity<'_>, inferred_ty: Type<'_>| {
// Avoid widening the inferred type if it is already assignable to the
// preferred declared type.
preferred_type_mappings
.as_ref()
.and_then(|types| types.get(&declared_ty))
.is_none_or(|preferred_ty| {
!inferred_ty.is_assignable_to(self.db, *preferred_ty)
})
};

if let Err(error) = builder.infer_filter(
expected_type,
variadic_argument_type.unwrap_or(argument_type),
filter,
) {
self.errors.push(BindingError::SpecializationError {
error,
Expand All @@ -2745,15 +2773,14 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
}
}

// Build the specialization first without inferring the type context.
// Build the specialization first without inferring the complete type context.
let isolated_specialization = builder.build(generic_context, *self.call_expression_tcx);
let isolated_return_ty = self
.return_ty
.apply_specialization(self.db, isolated_specialization);

let mut try_infer_tcx = || {
let return_ty = self.signature.return_ty?;
let call_expression_tcx = self.call_expression_tcx.annotation?;
let (return_ty, call_expression_tcx) = return_with_tcx?;

// A type variable is not a useful type-context for expression inference, and applying it
// to the return type can lead to confusing unions in nested generic calls.
Expand All @@ -2762,7 +2789,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
}

// If the return type is already assignable to the annotated type, we can ignore the
// type context and prefer the narrower inferred type.
// rest of the type context and prefer the narrower inferred type.
if isolated_return_ty.is_assignable_to(self.db, call_expression_tcx) {
return None;
}
Expand All @@ -2771,7 +2798,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
// annotated assignment, to closer match the order of any unions written in the type annotation.
builder.infer(return_ty, call_expression_tcx).ok()?;

// Otherwise, build the specialization again after inferring the type context.
// Otherwise, build the specialization again after inferring the complete type context.
let specialization = builder.build(generic_context, *self.call_expression_tcx);
let return_ty = return_ty.apply_specialization(self.db, specialization);

Expand Down
2 changes: 1 addition & 1 deletion crates/ty_python_semantic/src/types/class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ impl<'db> GenericAlias<'db> {
) -> Self {
let tcx = tcx
.annotation
.and_then(|ty| ty.specialization_of(db, Some(self.origin(db))))
.and_then(|ty| ty.specialization_of(db, self.origin(db)))
.map(|specialization| specialization.types(db))
.unwrap_or(&[]);

Expand Down
58 changes: 44 additions & 14 deletions crates/ty_python_semantic/src/types/generics.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::cell::RefCell;
use std::collections::hash_map::Entry;
use std::fmt::Display;

use itertools::Itertools;
Expand Down Expand Up @@ -1319,14 +1320,19 @@ impl<'db> SpecializationBuilder<'db> {
}
}

/// Returns the current set of type mappings for this specialization.
pub(crate) fn type_mappings(&self) -> &FxHashMap<BoundTypeVarIdentity<'db>, Type<'db>> {
&self.types
}

pub(crate) fn build(
&mut self,
generic_context: GenericContext<'db>,
tcx: TypeContext<'db>,
) -> Specialization<'db> {
let tcx_specialization = tcx
.annotation
.and_then(|annotation| annotation.specialization_of(self.db, None));
.and_then(|annotation| annotation.class_specialization(self.db));

let types =
(generic_context.variables_inner(self.db).iter()).map(|(identity, variable)| {
Expand All @@ -1349,19 +1355,43 @@ impl<'db> SpecializationBuilder<'db> {
generic_context.specialize_partial(self.db, types)
}

fn add_type_mapping(&mut self, bound_typevar: BoundTypeVarInstance<'db>, ty: Type<'db>) {
self.types
.entry(bound_typevar.identity(self.db))
.and_modify(|existing| {
*existing = UnionType::from_elements(self.db, [*existing, ty]);
})
.or_insert(ty);
fn add_type_mapping(
&mut self,
bound_typevar: BoundTypeVarInstance<'db>,
ty: Type<'db>,
filter: impl Fn(BoundTypeVarIdentity<'db>, Type<'db>) -> bool,
) {
let identity = bound_typevar.identity(self.db);
match self.types.entry(identity) {
Entry::Occupied(mut entry) => {
if filter(identity, ty) {
*entry.get_mut() = UnionType::from_elements(self.db, [*entry.get(), ty]);
}
}
Entry::Vacant(entry) => {
entry.insert(ty);
}
}
}

/// Infer type mappings for the specialization based on a given type and its declared type.
pub(crate) fn infer(
&mut self,
formal: Type<'db>,
actual: Type<'db>,
) -> Result<(), SpecializationError<'db>> {
self.infer_filter(formal, actual, |_, _| true)
}

/// Infer type mappings for the specialization based on a given type and its declared type.
///
/// The filter predicate is provided with a type variable and the type being mapped to it. Type
/// mappings to which the predicate returns `false` will be ignored.
pub(crate) fn infer_filter(
&mut self,
formal: Type<'db>,
actual: Type<'db>,
filter: impl Fn(BoundTypeVarIdentity<'db>, Type<'db>) -> bool,
) -> Result<(), SpecializationError<'db>> {
if formal == actual {
return Ok(());
Expand Down Expand Up @@ -1442,7 +1472,7 @@ impl<'db> SpecializationBuilder<'db> {
if remaining_actual.is_never() {
return Ok(());
}
self.add_type_mapping(*formal_bound_typevar, remaining_actual);
self.add_type_mapping(*formal_bound_typevar, remaining_actual, filter);
}
(Type::Union(formal), _) => {
// Second, if the formal is a union, and precisely one union element _is_ a typevar (not
Expand All @@ -1452,7 +1482,7 @@ impl<'db> SpecializationBuilder<'db> {
let bound_typevars =
(formal.elements(self.db).iter()).filter_map(|ty| ty.as_typevar());
if let Ok(bound_typevar) = bound_typevars.exactly_one() {
self.add_type_mapping(bound_typevar, actual);
self.add_type_mapping(bound_typevar, actual, filter);
}
}

Expand Down Expand Up @@ -1480,13 +1510,13 @@ impl<'db> SpecializationBuilder<'db> {
argument: ty,
});
}
self.add_type_mapping(bound_typevar, ty);
self.add_type_mapping(bound_typevar, ty, filter);
}
Some(TypeVarBoundOrConstraints::Constraints(constraints)) => {
// Prefer an exact match first.
for constraint in constraints.elements(self.db) {
if ty == *constraint {
self.add_type_mapping(bound_typevar, ty);
self.add_type_mapping(bound_typevar, ty, filter);
return Ok(());
}
}
Expand All @@ -1496,7 +1526,7 @@ impl<'db> SpecializationBuilder<'db> {
.when_assignable_to(self.db, *constraint, self.inferable)
.is_always_satisfied(self.db)
{
self.add_type_mapping(bound_typevar, *constraint);
self.add_type_mapping(bound_typevar, *constraint, filter);
return Ok(());
}
}
Expand All @@ -1506,7 +1536,7 @@ impl<'db> SpecializationBuilder<'db> {
});
}
_ => {
self.add_type_mapping(bound_typevar, ty);
self.add_type_mapping(bound_typevar, ty, filter);
}
}
}
Expand Down
3 changes: 2 additions & 1 deletion crates/ty_python_semantic/src/types/infer/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6260,7 +6260,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {

let inferred_elt_ty = self.get_or_infer_expression(elt, elt_tcx);

// Simplify the inference based on the declared type of the element.
// Avoid widening the inferred type if it is already assignable to the preferred
// declared type.
if let Some(elt_tcx) = elt_tcx.annotation {
if inferred_elt_ty.is_assignable_to(self.db(), elt_tcx) {
continue;
Expand Down
Loading