diff --git a/crates/ide-assists/src/handlers/extract_to_default_generic.rs b/crates/ide-assists/src/handlers/extract_to_default_generic.rs new file mode 100644 index 000000000000..d47d85c9b357 --- /dev/null +++ b/crates/ide-assists/src/handlers/extract_to_default_generic.rs @@ -0,0 +1,245 @@ +use ast::Name; +use either::Either::{self, Left, Right}; +use ide_db::{source_change::SourceChangeBuilder, syntax_helpers::suggest_name::NameGenerator}; +use syntax::{ + ast::{self, AstNode, HasGenericParams, HasName, make}, + syntax_editor::{Position, SyntaxEditor}, +}; + +use crate::{AssistContext, AssistId, Assists}; + +// Assist: extract_to_default_generic +// +// Extracts selected type to default generic parameter. +// +// ``` +// struct Foo(u32, $0String$0); +// ``` +// -> +// ``` +// struct Foo(u32, T); +// ``` +pub(crate) fn extract_to_default_generic(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + if ctx.has_empty_selection() { + return None; + } + + let ty: Either = ctx.find_node_at_range()?; + let adt: Either, ast::Fn> = + ty.syntax().ancestors().find_map(AstNode::cast)?; + + extract_to_default_generic_impl(acc, ctx, adt, ty) +} + +fn extract_to_default_generic_impl( + acc: &mut Assists, + ctx: &AssistContext<'_>, + adt: Either, + ty: Either, +) -> Option<()> { + let name = adt.name()?; + let has_default = adt.is_left(); + + let label = if has_default { + "Extract type as default generic parameter" + } else { + "Extract type as generic parameter" + }; + let target = ty.syntax().text_range(); + acc.add(AssistId::refactor_extract("extract_to_default_generic"), label, target, |edit| { + let mut editor = edit.make_editor(adt.syntax()); + let generic_list = get_or_create_generic_param_list(&name, &adt, &mut editor, edit); + + let generic_name = generic_name(&generic_list, ty.is_right()); + + editor.replace(ty.syntax(), generic_name.syntax()); + + match ty { + Left(ty) => { + let param = if has_default { + make::type_default_param(generic_name, None, ty) + } else { + make::type_param(generic_name, None) + } + .clone_for_update(); + + generic_list.add_generic_param(param.into()); + } + Right(n) => { + let param = if has_default { + make::const_default_param(generic_name, const_ty(ctx, &n), n) + } else { + make::const_param(generic_name, const_ty(ctx, &n)) + } + .clone_for_update(); + + generic_list.add_generic_param(param.into()); + } + } + + if let Some(cap) = ctx.config.snippet_cap + && let Some(last) = generic_list.generic_params().last() + { + if let ast::GenericParam::TypeParam(param) = &last + && let Some(name) = param.name() + { + let annotation = edit.make_tabstop_after(cap); + editor.add_annotation(name.syntax(), annotation); + } else if let ast::GenericParam::ConstParam(param) = &last + && let Some(ast::Type::InferType(ty)) = param.ty() + { + let annotation = edit.make_placeholder_snippet(cap); + editor.add_annotation(ty.syntax(), annotation); + } + } + + edit.add_file_edits(ctx.vfs_file_id(), editor); + }) +} + +fn array_index_type(n: &ast::ConstArg) -> Option { + let kind = n.syntax().parent()?.kind(); + + if ast::ArrayType::can_cast(kind) || ast::ArrayExpr::can_cast(kind) { + Some(make::ty("usize")) + } else { + None + } +} + +fn generic_name(generic_list: &ast::GenericParamList, is_const_param: bool) -> Name { + let exist_names = generic_list + .generic_params() + .filter_map(|it| match it { + ast::GenericParam::ConstParam(const_param) => const_param.name(), + ast::GenericParam::TypeParam(type_param) => type_param.name(), + ast::GenericParam::LifetimeParam(_) => None, + }) + .map(|name| name.to_string()) + .collect::>(); + + let mut name_gen = NameGenerator::new_with_names(exist_names.iter().map(|name| name.as_str())); + + make::name(&if is_const_param { + name_gen.suggest_name("N") + } else { + name_gen.suggest_name("T") + }) + .clone_for_update() +} + +fn const_ty(ctx: &AssistContext<'_>, n: &ast::ConstArg) -> ast::Type { + if let Some(expr) = n.expr() + && let Some(ty_info) = ctx.sema.type_of_expr(&expr) + && let Some(builtin) = ty_info.adjusted().as_builtin() + { + make::ty(builtin.name().as_str()) + } else if let Some(array_index_ty) = array_index_type(n) { + array_index_ty + } else { + make::ty_placeholder() + } +} + +fn get_or_create_generic_param_list( + name: &ast::Name, + adt: &impl HasGenericParams, + editor: &mut SyntaxEditor, + edit: &mut SourceChangeBuilder, +) -> ast::GenericParamList { + if let Some(list) = adt.generic_param_list() { + edit.make_mut(list) + } else { + let generic = make::generic_param_list([]).clone_for_update(); + editor.insert(Position::after(name.syntax()), generic.syntax()); + generic + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::tests::check_assist; + + #[test] + fn test_extract_to_default_generic() { + check_assist( + extract_to_default_generic, + r#"type X = ($0i32$0, i64);"#, + r#"type X = (T, i64);"#, + ); + + check_assist( + extract_to_default_generic, + r#"type X = ($0i32$0, T);"#, + r#"type X = (T1, T);"#, + ); + } + + #[test] + fn test_extract_to_default_generic_on_adt() { + check_assist( + extract_to_default_generic, + r#"struct Foo($0i32$0);"#, + r#"struct Foo(T);"#, + ); + + check_assist( + extract_to_default_generic, + r#"struct Foo(T, $0i32$0);"#, + r#"struct Foo(T, T1);"#, + ); + + check_assist( + extract_to_default_generic, + r#"enum Foo { A($0i32$0), B, C(i64) };"#, + r#"enum Foo { A(T), B, C(i64) };"#, + ); + } + + #[test] + fn test_extract_to_generic_on_fn() { + check_assist( + extract_to_default_generic, + r#"fn foo(x: $0i32$0) {}"#, + r#"fn foo(x: T) {}"#, + ); + + check_assist( + extract_to_default_generic, + r#"fn foo(x: [i32; $02$0]) {}"#, + r#"fn foo(x: [i32; N]) {}"#, + ); + } + + #[test] + fn test_extract_to_default_generic_const() { + check_assist( + extract_to_default_generic, + r#"type A = [i32; $08$0];"#, + r#"type A = [i32; N];"#, + ); + + check_assist( + extract_to_default_generic, + r#"type A = [T; $08$0];"#, + r#"type A = [T; N];"#, + ); + } + + #[test] + fn test_extract_to_default_generic_const_non_array() { + check_assist( + extract_to_default_generic, + r#" + struct Foo([(); N]); + type A = Foo<$08$0>; + "#, + r#" + struct Foo([(); N]); + type A = Foo; + "#, + ); + } +} diff --git a/crates/ide-assists/src/lib.rs b/crates/ide-assists/src/lib.rs index 4682c0473238..1e6ab33a4a02 100644 --- a/crates/ide-assists/src/lib.rs +++ b/crates/ide-assists/src/lib.rs @@ -146,6 +146,7 @@ mod handlers { mod extract_function; mod extract_module; mod extract_struct_from_enum_variant; + mod extract_to_default_generic; mod extract_type_alias; mod extract_variable; mod fix_visibility; @@ -281,6 +282,7 @@ mod handlers { extract_expressions_from_format_string::extract_expressions_from_format_string, extract_struct_from_enum_variant::extract_struct_from_enum_variant, extract_type_alias::extract_type_alias, + extract_to_default_generic::extract_to_default_generic, fix_visibility::fix_visibility, flip_binexpr::flip_binexpr, flip_comma::flip_comma, diff --git a/crates/ide-assists/src/tests/generated.rs b/crates/ide-assists/src/tests/generated.rs index 91348be97eb7..267e13daf750 100644 --- a/crates/ide-assists/src/tests/generated.rs +++ b/crates/ide-assists/src/tests/generated.rs @@ -1187,6 +1187,19 @@ enum A { One(One) } ) } +#[test] +fn doctest_extract_to_default_generic() { + check_doc_test( + "extract_to_default_generic", + r#####" +struct Foo(u32, $0String$0); +"#####, + r#####" +struct Foo(u32, T); +"#####, + ) +} + #[test] fn doctest_extract_type_alias() { check_doc_test( diff --git a/crates/syntax/src/ast/make.rs b/crates/syntax/src/ast/make.rs index c5ca60976018..f29490344fe0 100644 --- a/crates/syntax/src/ast/make.rs +++ b/crates/syntax/src/ast/make.rs @@ -1085,10 +1085,27 @@ pub fn type_param(name: ast::Name, bounds: Option) -> ast::T ast_from_text(&format!("fn f<{name}{bounds}>() {{ }}")) } +pub fn type_default_param( + name: ast::Name, + bounds: Option, + default: ast::Type, +) -> ast::TypeParam { + let bounds = bounds.map_or_else(String::new, |it| format!(": {it}")); + ast_from_text(&format!("fn f<{name}{bounds} = {default}>() {{ }}")) +} + pub fn const_param(name: ast::Name, ty: ast::Type) -> ast::ConstParam { ast_from_text(&format!("fn f() {{ }}")) } +pub fn const_default_param( + name: ast::Name, + ty: ast::Type, + default: ast::ConstArg, +) -> ast::ConstParam { + ast_from_text(&format!("fn f() {{ }}")) +} + pub fn lifetime_param(lifetime: ast::Lifetime) -> ast::LifetimeParam { ast_from_text(&format!("fn f<{lifetime}>() {{ }}")) } diff --git a/crates/syntax/src/ast/traits.rs b/crates/syntax/src/ast/traits.rs index 5290f32dd27d..057accec2f2e 100644 --- a/crates/syntax/src/ast/traits.rs +++ b/crates/syntax/src/ast/traits.rs @@ -161,3 +161,4 @@ impl Iterator for AttrDocCommentIter { } impl HasName for Either {} +impl HasGenericParams for Either {}