@@ -33,10 +33,11 @@ use crate::types::generics::{
3333use crate :: types:: signatures:: { Parameter , ParameterForm , ParameterKind , Parameters } ;
3434use crate :: types:: tuple:: { TupleLength , TupleType } ;
3535use crate :: types:: {
36- BoundMethodType , ClassLiteral , DataclassFlags , DataclassParams , FieldInstance ,
37- KnownBoundMethodType , KnownClass , KnownInstanceType , MemberLookupPolicy , PropertyInstanceType ,
38- SpecialFormType , TrackedConstraintSet , TypeAliasType , TypeContext , UnionBuilder , UnionType ,
39- WrapperDescriptorKind , enums, ide_support, infer_isolated_expression, todo_type,
36+ BoundMethodType , BoundTypeVarIdentity , ClassLiteral , DataclassFlags , DataclassParams ,
37+ FieldInstance , KnownBoundMethodType , KnownClass , KnownInstanceType , MemberLookupPolicy ,
38+ PropertyInstanceType , SpecialFormType , TrackedConstraintSet , TypeAliasType , TypeContext ,
39+ UnionBuilder , UnionType , WrapperDescriptorKind , enums, ide_support, infer_isolated_expression,
40+ todo_type,
4041} ;
4142use ruff_db:: diagnostic:: { Annotation , Diagnostic , SubDiagnostic , SubDiagnosticSeverity } ;
4243use ruff_python_ast:: { self as ast, ArgOrKeyword , PythonVersion } ;
@@ -2627,9 +2628,25 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
26272628 return ;
26282629 } ;
26292630
2631+ let return_with_tcx = self
2632+ . signature
2633+ . return_ty
2634+ . zip ( self . call_expression_tcx . annotation ) ;
2635+
26302636 self . inferable_typevars = generic_context. inferable_typevars ( self . db ) ;
26312637 let mut builder = SpecializationBuilder :: new ( self . db , self . inferable_typevars ) ;
26322638
2639+ // Prefer the declared type of generic classes.
2640+ let preferred_type_mappings = return_with_tcx. and_then ( |( return_ty, tcx) | {
2641+ let preferred_return_ty =
2642+ tcx. filter_union ( self . db , |ty| ty. class_specialization ( self . db ) . is_some ( ) ) ;
2643+ let return_ty =
2644+ return_ty. filter_union ( self . db , |ty| ty. class_specialization ( self . db ) . is_some ( ) ) ;
2645+
2646+ builder. infer ( return_ty, preferred_return_ty) . ok ( ) ?;
2647+ Some ( builder. type_mappings ( ) . clone ( ) )
2648+ } ) ;
2649+
26332650 let parameters = self . signature . parameters ( ) ;
26342651 for ( argument_index, adjusted_argument_index, _, argument_type) in
26352652 self . enumerate_argument_types ( )
@@ -2642,9 +2659,20 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
26422659 continue ;
26432660 } ;
26442661
2645- if let Err ( error) = builder. infer (
2662+ let filter = |declared : BoundTypeVarIdentity < ' _ > , new : Type < ' _ > | {
2663+ // Avoid widening the inferred type if it is already assignable to the
2664+ // preferred declared type.
2665+ preferred_type_mappings
2666+ . as_ref ( )
2667+ . is_none_or ( |preferred_type_mappings| {
2668+ !new. is_assignable_to ( self . db , preferred_type_mappings[ & declared] )
2669+ } )
2670+ } ;
2671+
2672+ if let Err ( error) = builder. infer_filter (
26462673 expected_type,
26472674 variadic_argument_type. unwrap_or ( argument_type) ,
2675+ filter,
26482676 ) {
26492677 self . errors . push ( BindingError :: SpecializationError {
26502678 error,
@@ -2654,15 +2682,14 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
26542682 }
26552683 }
26562684
2657- // Build the specialization first without inferring the type context.
2685+ // Build the specialization first without inferring the complete type context.
26582686 let isolated_specialization = builder. build ( generic_context, * self . call_expression_tcx ) ;
26592687 let isolated_return_ty = self
26602688 . return_ty
26612689 . apply_specialization ( self . db , isolated_specialization) ;
26622690
26632691 let mut try_infer_tcx = || {
2664- let return_ty = self . signature . return_ty ?;
2665- let call_expression_tcx = self . call_expression_tcx . annotation ?;
2692+ let ( return_ty, call_expression_tcx) = return_with_tcx?;
26662693
26672694 // A type variable is not a useful type-context for expression inference, and applying it
26682695 // to the return type can lead to confusing unions in nested generic calls.
@@ -2671,7 +2698,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
26712698 }
26722699
26732700 // If the return type is already assignable to the annotated type, we can ignore the
2674- // type context and prefer the narrower inferred type.
2701+ // rest of the type context and prefer the narrower inferred type.
26752702 if isolated_return_ty. is_assignable_to ( self . db , call_expression_tcx) {
26762703 return None ;
26772704 }
@@ -2680,7 +2707,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
26802707 // annotated assignment, to closer match the order of any unions written in the type annotation.
26812708 builder. infer ( return_ty, call_expression_tcx) . ok ( ) ?;
26822709
2683- // Otherwise, build the specialization again after inferring the type context.
2710+ // Otherwise, build the specialization again after inferring the complete type context.
26842711 let specialization = builder. build ( generic_context, * self . call_expression_tcx ) ;
26852712 let return_ty = return_ty. apply_specialization ( self . db , specialization) ;
26862713
0 commit comments