diff --git a/crates/hir-def/src/hir/type_ref.rs b/crates/hir-def/src/hir/type_ref.rs index da0f058a9cb5..3eb67e89da12 100644 --- a/crates/hir-def/src/hir/type_ref.rs +++ b/crates/hir-def/src/hir/type_ref.rs @@ -195,12 +195,12 @@ impl TypeRef { TypeRef::Tuple(ThinVec::new()) } - pub fn walk(this: TypeRefId, map: &ExpressionStore, f: &mut impl FnMut(&TypeRef)) { + pub fn walk(this: TypeRefId, map: &ExpressionStore, f: &mut impl FnMut(TypeRefId, &TypeRef)) { go(this, f, map); - fn go(type_ref: TypeRefId, f: &mut impl FnMut(&TypeRef), map: &ExpressionStore) { - let type_ref = &map[type_ref]; - f(type_ref); + fn go(type_ref_id: TypeRefId, f: &mut impl FnMut(TypeRefId, &TypeRef), map: &ExpressionStore) { + let type_ref = &map[type_ref_id]; + f(type_ref_id, type_ref); match type_ref { TypeRef::Fn(fn_) => { fn_.params.iter().for_each(|&(_, param_type)| go(param_type, f, map)) @@ -224,7 +224,7 @@ impl TypeRef { }; } - fn go_path(path: &Path, f: &mut impl FnMut(&TypeRef), map: &ExpressionStore) { + fn go_path(path: &Path, f: &mut impl FnMut(TypeRefId, &TypeRef), map: &ExpressionStore) { if let Some(type_ref) = path.type_anchor() { go(type_ref, f, map); } diff --git a/crates/hir-ty/src/infer.rs b/crates/hir-ty/src/infer.rs index d778cc0e30ed..db2537345cc6 100644 --- a/crates/hir-ty/src/infer.rs +++ b/crates/hir-ty/src/infer.rs @@ -44,7 +44,7 @@ use hir_def::{ layout::Integer, resolver::{HasResolver, ResolveValueResult, Resolver, TypeNs, ValueNs}, signatures::{ConstSignature, StaticSignature}, - type_ref::{ConstRef, LifetimeRefId, TypeRefId}, + type_ref::{ConstRef, LifetimeRefId, TypeRef, TypeRefId}, }; use hir_expand::{mod_path::ModPath, name::Name}; use indexmap::IndexSet; @@ -59,6 +59,7 @@ use crate::{ ImplTraitIdx, InEnvironment, IncorrectGenericsLenKind, Interner, Lifetime, OpaqueTyId, ParamLoweringMode, PathLoweringDiagnostic, ProjectionTy, Substitution, TraitEnvironment, Ty, TyBuilder, TyExt, + collect_type_inference_vars, db::HirDatabase, fold_tys, generics::Generics, @@ -473,6 +474,7 @@ pub struct InferenceResult { /// unresolved or missing subpatterns or subpatterns of mismatched types. pub(crate) type_of_pat: ArenaMap, pub(crate) type_of_binding: ArenaMap, + pub(crate) type_of_type_placeholder: ArenaMap, pub(crate) type_of_rpit: ArenaMap, type_mismatches: FxHashMap, /// Whether there are any type-mismatching errors in the result. @@ -555,6 +557,15 @@ impl InferenceResult { _ => None, }) } + + pub fn placeholder_types(&self) -> impl Iterator { + self.type_of_type_placeholder.iter() + } + + pub fn type_of_type_placeholder(&self, type_ref: TypeRefId) -> &Ty { + self.type_of_type_placeholder.get(type_ref).unwrap_or(&self.standard_types.unknown) + } + pub fn closure_info(&self, closure: &ClosureId) -> &(Vec, FnTrait) { self.closure_info.get(closure).unwrap() } @@ -818,6 +829,7 @@ impl<'db> InferenceContext<'db> { type_of_expr, type_of_pat, type_of_binding, + type_of_type_placeholder, type_of_rpit, type_mismatches, has_errors, @@ -873,6 +885,11 @@ impl<'db> InferenceContext<'db> { *has_errors = *has_errors || ty.contains_unknown(); } type_of_binding.shrink_to_fit(); + for ty in type_of_type_placeholder.values_mut() { + *ty = table.resolve_completely(ty.clone()); + *has_errors = *has_errors || ty.contains_unknown(); + } + type_of_type_placeholder.shrink_to_fit(); for ty in type_of_rpit.values_mut() { *ty = table.resolve_completely(ty.clone()); *has_errors = *has_errors || ty.contains_unknown(); @@ -1371,6 +1388,10 @@ impl<'db> InferenceContext<'db> { self.result.type_of_pat.insert(pat, ty); } + fn write_type_placeholder_ty(&mut self, type_ref: TypeRefId, ty: Ty) { + self.result.type_of_type_placeholder.insert(type_ref, ty); + } + fn write_binding_ty(&mut self, id: BindingId, ty: Ty) { self.result.type_of_binding.insert(id, ty); } @@ -1417,7 +1438,22 @@ impl<'db> InferenceContext<'db> { let ty = self .with_ty_lowering(store, type_source, lifetime_elision, |ctx| ctx.lower_ty(type_ref)); let ty = self.insert_type_vars(ty); - self.normalize_associated_types_in(ty) + let ty = self.normalize_associated_types_in(ty); + + // Record the association from placeholders' TypeRefId to type variables + // Current assumptions: + // - same number of type variables in ty and type placeholders in type_ref + // - same order + // Better way to achieve this? + let type_variables = collect_type_inference_vars(&ty); + let mut type_variables_iter = type_variables.into_iter(); + TypeRef::walk(type_ref, store, &mut |type_ref_id, type_ref| { + if matches!(type_ref, TypeRef::Placeholder) { + self.write_type_placeholder_ty(type_ref_id, type_variables_iter.next().unwrap()); + } + }); + + ty } fn make_body_ty(&mut self, type_ref: TypeRefId) -> Ty { diff --git a/crates/hir-ty/src/lib.rs b/crates/hir-ty/src/lib.rs index 7fdfb205721e..eb32987acd0e 100644 --- a/crates/hir-ty/src/lib.rs +++ b/crates/hir-ty/src/lib.rs @@ -720,7 +720,7 @@ pub(crate) fn fold_free_vars + TypeFoldable< t.fold_with(&mut FreeVarFolder(for_ty, for_const), DebruijnIndex::INNERMOST) } -pub(crate) fn fold_tys + TypeFoldable>( +pub fn fold_tys + TypeFoldable>( t: T, mut for_ty: impl FnMut(Ty, DebruijnIndex) -> Ty, binders: DebruijnIndex, @@ -1065,6 +1065,48 @@ where collector.placeholders.into_iter().collect() } +struct TypeInferenceVarCollector { + type_inference_vars: Vec, +} + +impl TypeVisitor for TypeInferenceVarCollector { + type BreakTy = (); + + fn as_dyn(&mut self) -> &mut dyn TypeVisitor { + self + } + + fn interner(&self) -> Interner { + Interner + } + + fn visit_ty( + &mut self, + ty: &Ty, + outer_binder: DebruijnIndex, + ) -> std::ops::ControlFlow { + if ty.is_ty_var() { + self.type_inference_vars.push(ty.clone()); + } else if ty.data(Interner).flags.intersects(TypeFlags::HAS_TY_INFER) { + return ty.super_visit_with(self, outer_binder); + } else { + // Fast path: don't visit inner types (e.g. generic arguments) when `flags` indicate + // that there are no placeholders. + } + + std::ops::ControlFlow::Continue(()) + } +} + +pub fn collect_type_inference_vars(value: &T) -> Vec +where + T: ?Sized + TypeVisitable, +{ + let mut collector = TypeInferenceVarCollector { type_inference_vars: vec![] }; + _ = value.visit_with(&mut collector, DebruijnIndex::INNERMOST); + collector.type_inference_vars +} + pub fn known_const_to_ast( konst: &Const, db: &dyn HirDatabase, diff --git a/crates/hir-ty/src/tests.rs b/crates/hir-ty/src/tests.rs index 1c3da438cb36..88019ab77750 100644 --- a/crates/hir-ty/src/tests.rs +++ b/crates/hir-ty/src/tests.rs @@ -23,6 +23,7 @@ use hir_def::{ item_scope::ItemScope, nameres::DefMap, src::HasSource, + type_ref::TypeRefId, }; use hir_expand::{FileRange, InFile, db::ExpandDatabase}; use itertools::Itertools; @@ -219,6 +220,24 @@ fn check_impl( None => format_to!(unexpected_type_mismatches, "{:?}: {}\n", range.range, actual), } } + + for (type_ref, ty) in inference_result.placeholder_types() { + let node = match type_node(&body_source_map, type_ref, &db) { + Some(value) => value, + None => continue, + }; + let range = node.as_ref().original_file_range_rooted(&db); + if let Some(expected) = types.remove(&range) { + let actual = salsa::attach(&db, || { + if display_source { + ty.display_source_code(&db, def.module(&db), true).unwrap() + } else { + ty.display_test(&db, display_target).to_string() + } + }); + assert_eq!(actual, expected, "type annotation differs at {:#?}", range.range); + } + } } let mut buf = String::new(); @@ -274,6 +293,20 @@ fn pat_node( }) } +fn type_node( + body_source_map: &BodySourceMap, + type_ref: TypeRefId, + db: &TestDB, +) -> Option> { + Some(match body_source_map.type_syntax(type_ref) { + Ok(sp) => { + let root = db.parse_or_expand(sp.file_id); + sp.map(|ptr| ptr.to_node(&root).syntax().clone()) + } + Err(SyntheticSyntax) => return None, + }) +} + fn infer(#[rust_analyzer::rust_fixture] ra_fixture: &str) -> String { infer_with_mismatches(ra_fixture, false) } diff --git a/crates/hir-ty/src/tests/display_source_code.rs b/crates/hir-ty/src/tests/display_source_code.rs index 6e3faa05a629..420d90c0c5f1 100644 --- a/crates/hir-ty/src/tests/display_source_code.rs +++ b/crates/hir-ty/src/tests/display_source_code.rs @@ -246,3 +246,22 @@ fn test() { "#, ); } + +#[test] +fn type_placeholder_type() { + check_types_source_code( + r#" +struct S(T); +fn test() { + let f: S<_> = S(3); + //^ i32 + let f: [_; _] = [4_u32, 5, 6]; + //^ u32 + let f: (_, _, _) = (1_u32, 1_i32, false); + //^ u32 + //^ i32 + //^ bool +} +"#, + ); +} diff --git a/crates/hir/src/source_analyzer.rs b/crates/hir/src/source_analyzer.rs index e3070c5f74ce..803561499adc 100644 --- a/crates/hir/src/source_analyzer.rs +++ b/crates/hir/src/source_analyzer.rs @@ -29,7 +29,7 @@ use hir_def::{ lang_item::LangItem, nameres::MacroSubNs, resolver::{HasResolver, Resolver, TypeNs, ValueNs, resolver_for_scope}, - type_ref::{Mutability, TypeRefId}, + type_ref::{Mutability, TypeRef, TypeRefId}, }; use hir_expand::{ HirFileId, InFile, @@ -37,7 +37,7 @@ use hir_expand::{ name::{AsName, Name}, }; use hir_ty::{ - Adjustment, AliasTy, InferenceResult, Interner, LifetimeElisionKind, ProjectionTy, + Adjustment, AliasTy, DebruijnIndex, InferenceResult, Interner, LifetimeElisionKind, ProjectionTy, Substitution, ToChalk, TraitEnvironment, Ty, TyExt, TyKind, TyLoweringContext, diagnostics::{ InsideUnsafeBlock, record_literal_missing_fields, record_pattern_missing_fields, @@ -260,17 +260,43 @@ impl<'db> SourceAnalyzer<'db> { ty: &ast::Type, ) -> Option> { let type_ref = self.type_id(ty)?; + let ty = TyLoweringContext::new( - db, - &self.resolver, - self.store()?, - self.resolver.generic_def()?, - // FIXME: Is this correct here? Anyway that should impact mostly diagnostics, which we don't emit here - // (this can impact the lifetimes generated, e.g. in `const` they won't be `'static`, but this seems like a - // small problem). - LifetimeElisionKind::Infer, - ) - .lower_ty(type_ref); + db, + &self.resolver, + self.store()?, + self.resolver.generic_def()?, + // FIXME: Is this correct here? Anyway that should impact mostly diagnostics, which we don't emit here + // (this can impact the lifetimes generated, e.g. in `const` they won't be `'static`, but this seems like a + // small problem). + LifetimeElisionKind::Infer, + ) + .lower_ty(type_ref); + + let infer = self.infer(); + let mut infered_types = vec![]; + let ty = if let Some(infer) = infer && let Some(store) = self.store() { + TypeRef::walk(type_ref, store, &mut |type_ref_id, type_ref| { + if matches!(type_ref, TypeRef::Placeholder) { + infered_types.push(infer.type_of_type_placeholder(type_ref_id).clone()); + } + }); + + let mut infered_types = infered_types.into_iter(); + + hir_ty::fold_tys( + ty, + |ty, _| if ty.is_unknown() { + infered_types.next().unwrap() + } else { + ty + }, + DebruijnIndex::INNERMOST, + ) + } else { + ty + }; + Some(Type::new_with_resolver(db, &self.resolver, ty)) } diff --git a/crates/ide-assists/src/handlers/extract_type_alias.rs b/crates/ide-assists/src/handlers/extract_type_alias.rs index 7f93506685e1..252b897014a7 100644 --- a/crates/ide-assists/src/handlers/extract_type_alias.rs +++ b/crates/ide-assists/src/handlers/extract_type_alias.rs @@ -1,4 +1,5 @@ use either::Either; +use hir::HirDisplay; use ide_db::syntax_helpers::node_ext::walk_ty; use syntax::{ ast::{self, AstNode, HasGenericArgs, HasGenericParams, HasName, edit::IndentLevel, make}, @@ -39,6 +40,15 @@ pub(crate) fn extract_type_alias(acc: &mut Assists, ctx: &AssistContext<'_>) -> ); let target = ty.syntax().text_range(); + let module = ctx.sema.scope(ty.syntax())?.module(); + let resolved_ty = ctx.sema.resolve_type(&ty)?; + let resolved_ty = if !resolved_ty.contains_unknown() { + let resolved_ty = resolved_ty.display_source_code(ctx.db(), module.into(), false).ok()?; + make::ty(&resolved_ty) + } else { + ty.clone() + }; + acc.add( AssistId::refactor_extract("extract_type_alias"), "Extract type as type alias", @@ -70,7 +80,7 @@ pub(crate) fn extract_type_alias(acc: &mut Assists, ctx: &AssistContext<'_>) -> // Insert new alias let ty_alias = - make::ty_alias(None, "Type", generic_params, None, None, Some((ty, None))) + make::ty_alias(None, "Type", generic_params, None, None, Some((resolved_ty, None))) .clone_for_update(); if let Some(cap) = ctx.config.snippet_cap @@ -389,4 +399,50 @@ where "#, ); } + + #[test] + fn inferred_generic_type_parameter() { + check_assist( + extract_type_alias, + r#" +struct Wrap(T); + +fn main() { + let wrap: $0Wrap<_>$0 = Wrap::<_>(3i32); +} + "#, + r#" +struct Wrap(T); + +type $0Type = Wrap; + +fn main() { + let wrap: Type = Wrap::<_>(3i32); +} + "#, + ) + } + + #[test] + fn inferred_type() { + check_assist( + extract_type_alias, + r#" +struct Wrap(T); + +fn main() { + let wrap: Wrap<$0_$0> = Wrap::<_>(3i32); +} + "#, + r#" +struct Wrap(T); + +type $0Type = i32; + +fn main() { + let wrap: Wrap = Wrap::<_>(3i32); +} + "#, + ) + } } diff --git a/crates/ide/src/inlay_hints.rs b/crates/ide/src/inlay_hints.rs index d7d317166494..3472f5fa35e7 100644 --- a/crates/ide/src/inlay_hints.rs +++ b/crates/ide/src/inlay_hints.rs @@ -288,6 +288,25 @@ fn hints( implied_dyn_trait::hints(hints, famous_defs, config, Either::Right(dyn_)); Some(()) }, + ast::Type::InferType(ref infer) => { + (|| { + let module= sema.scope(infer.syntax())?.module(); + let ty = sema.resolve_type(&ty)?; + let label = ty.display_source_code(sema.db, module.into(), false).ok()?; + hints.push(InlayHint { + range: infer.syntax().text_range(), + kind: InlayKind::Type, + label: format!("= {label}").into(), + text_edit: None, + position: InlayHintPosition::After, + pad_left: true, + pad_right: false, + resolve_parent: None, + }); + Some(()) + })(); + Some(()) + }, _ => Some(()), }, ast::GenericParamList(it) => bounds::hints(hints, famous_defs, config, it), @@ -1101,4 +1120,21 @@ where "#, ); } + + #[test] + fn inferred_types() { + check( + r#" +struct S(T); + +fn foo() { + let t: (_, _, [_; _]) = (1_u32, S(2), [false] as _); + //^ = u32 + //^ = S + //^ = bool + //^ = [bool; 1] +} +"#, + ); + } } diff --git a/crates/ide/src/inlay_hints/adjustment.rs b/crates/ide/src/inlay_hints/adjustment.rs index e39a5f8889a2..48178ebac6c7 100644 --- a/crates/ide/src/inlay_hints/adjustment.rs +++ b/crates/ide/src/inlay_hints/adjustment.rs @@ -366,6 +366,7 @@ fn main() { //^^^^^^^^^^^^^^^^^^^^^) //^^^^^^^^^&raw mut * let _: &mut [_] = &mut [0; 0]; + //^ = i32 //^^^^^^^^^^^&mut * Struct.consume(); diff --git a/crates/ide/src/inlay_hints/param_name.rs b/crates/ide/src/inlay_hints/param_name.rs index 754707784055..673528be2641 100644 --- a/crates/ide/src/inlay_hints/param_name.rs +++ b/crates/ide/src/inlay_hints/param_name.rs @@ -568,6 +568,7 @@ fn main() { let param = 0; foo(param); foo(param as _); + //^ = u32 let param_end = 0; foo(param_end); let start_param = 0;