@@ -35,11 +35,11 @@ use crate::types::generics::{
3535use crate :: types:: signatures:: { Parameter , ParameterForm , ParameterKind , Parameters } ;
3636use crate :: types:: tuple:: { TupleLength , TupleType } ;
3737use crate :: types:: {
38- BoundMethodType , ClassLiteral , DataclassFlags , DataclassParams , FieldInstance ,
39- KnownBoundMethodType , KnownClass , KnownInstanceType , MemberLookupPolicy , NominalInstanceType ,
40- PropertyInstanceType , SpecialFormType , TrackedConstraintSet , TypeAliasType , TypeContext ,
41- UnionBuilder , UnionType , WrapperDescriptorKind , enums, ide_support, infer_isolated_expression ,
42- todo_type,
38+ BoundMethodType , BoundTypeVarIdentity , ClassLiteral , DataclassFlags , DataclassParams ,
39+ FieldInstance , KnownBoundMethodType , KnownClass , KnownInstanceType , MemberLookupPolicy ,
40+ NominalInstanceType , PropertyInstanceType , SpecialFormType , TrackedConstraintSet ,
41+ TypeAliasType , TypeContext , UnionBuilder , UnionType , WrapperDescriptorKind , enums, ide_support,
42+ infer_isolated_expression , todo_type,
4343} ;
4444use ruff_db:: diagnostic:: { Annotation , Diagnostic , SubDiagnostic , SubDiagnosticSeverity } ;
4545use ruff_python_ast:: { self as ast, ArgOrKeyword , PythonVersion } ;
@@ -2718,9 +2718,25 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
27182718 return ;
27192719 } ;
27202720
2721+ let return_with_tcx = self
2722+ . signature
2723+ . return_ty
2724+ . zip ( self . call_expression_tcx . annotation ) ;
2725+
27212726 self . inferable_typevars = generic_context. inferable_typevars ( self . db ) ;
27222727 let mut builder = SpecializationBuilder :: new ( self . db , self . inferable_typevars ) ;
27232728
2729+ // Prefer the declared type of generic classes.
2730+ let preferred_type_mappings = return_with_tcx. and_then ( |( return_ty, tcx) | {
2731+ let preferred_return_ty =
2732+ tcx. filter_union ( self . db , |ty| ty. class_specialization ( self . db ) . is_some ( ) ) ;
2733+ let return_ty =
2734+ return_ty. filter_union ( self . db , |ty| ty. class_specialization ( self . db ) . is_some ( ) ) ;
2735+
2736+ builder. infer ( return_ty, preferred_return_ty) . ok ( ) ?;
2737+ Some ( builder. type_mappings ( ) . clone ( ) )
2738+ } ) ;
2739+
27242740 let parameters = self . signature . parameters ( ) ;
27252741 for ( argument_index, adjusted_argument_index, _, argument_type) in
27262742 self . enumerate_argument_types ( )
@@ -2733,9 +2749,21 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
27332749 continue ;
27342750 } ;
27352751
2736- if let Err ( error) = builder. infer (
2752+ let filter = |declared_ty : BoundTypeVarIdentity < ' _ > , inferred_ty : Type < ' _ > | {
2753+ // Avoid widening the inferred type if it is already assignable to the
2754+ // preferred declared type.
2755+ preferred_type_mappings
2756+ . as_ref ( )
2757+ . and_then ( |types| types. get ( & declared_ty) )
2758+ . is_none_or ( |preferred_ty| {
2759+ !inferred_ty. is_assignable_to ( self . db , * preferred_ty)
2760+ } )
2761+ } ;
2762+
2763+ if let Err ( error) = builder. infer_filter (
27372764 expected_type,
27382765 variadic_argument_type. unwrap_or ( argument_type) ,
2766+ filter,
27392767 ) {
27402768 self . errors . push ( BindingError :: SpecializationError {
27412769 error,
@@ -2745,15 +2773,14 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
27452773 }
27462774 }
27472775
2748- // Build the specialization first without inferring the type context.
2776+ // Build the specialization first without inferring the complete type context.
27492777 let isolated_specialization = builder. build ( generic_context, * self . call_expression_tcx ) ;
27502778 let isolated_return_ty = self
27512779 . return_ty
27522780 . apply_specialization ( self . db , isolated_specialization) ;
27532781
27542782 let mut try_infer_tcx = || {
2755- let return_ty = self . signature . return_ty ?;
2756- let call_expression_tcx = self . call_expression_tcx . annotation ?;
2783+ let ( return_ty, call_expression_tcx) = return_with_tcx?;
27572784
27582785 // A type variable is not a useful type-context for expression inference, and applying it
27592786 // to the return type can lead to confusing unions in nested generic calls.
@@ -2762,7 +2789,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
27622789 }
27632790
27642791 // If the return type is already assignable to the annotated type, we can ignore the
2765- // type context and prefer the narrower inferred type.
2792+ // rest of the type context and prefer the narrower inferred type.
27662793 if isolated_return_ty. is_assignable_to ( self . db , call_expression_tcx) {
27672794 return None ;
27682795 }
@@ -2771,7 +2798,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
27712798 // annotated assignment, to closer match the order of any unions written in the type annotation.
27722799 builder. infer ( return_ty, call_expression_tcx) . ok ( ) ?;
27732800
2774- // Otherwise, build the specialization again after inferring the type context.
2801+ // Otherwise, build the specialization again after inferring the complete type context.
27752802 let specialization = builder. build ( generic_context, * self . call_expression_tcx ) ;
27762803 let return_ty = return_ty. apply_specialization ( self . db , specialization) ;
27772804
0 commit comments