diff --git a/Cargo.lock b/Cargo.lock index d2091aebb..70ed98f0c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -563,6 +563,41 @@ dependencies = [ "typenum", ] +[[package]] +name = "darling" +version = "0.21.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6b136475da5ef7b6ac596c0e956e37bad51b85b987ff3d5e230e964936736b2" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.21.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b44ad32f92b75fb438b04b68547e521a548be8acc339a6dacc4a7121488f53e6" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim 0.11.1", + "syn 2.0.98", +] + +[[package]] +name = "darling_macro" +version = "0.21.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b5be8a7a562d315a5b92a630c30cec6bcf663e6673f00fbb69cca66a6f521b9" +dependencies = [ + "darling_core", + "quote", + "syn 2.0.98", +] + [[package]] name = "dashmap" version = "6.1.0" @@ -1082,6 +1117,12 @@ dependencies = [ "syn 2.0.98", ] +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + [[package]] name = "idna" version = "1.0.3" @@ -1861,6 +1902,7 @@ dependencies = [ name = "pico_macros" version = "0.3.1" dependencies = [ + "darling", "proc-macro2", "quote", "syn 2.0.98", diff --git a/Cargo.toml b/Cargo.toml index 20e6d50fd..ab0aecda6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ clap = { version = "4.5.18", features = ["derive"] } colored = "2.0.4" colorize = "0.1.0" crossbeam = "0.8" +darling = "0.21.1" dashmap = "6.0.1" lazy_static = "1.4" log = { version = "0.4.17", features = ["kv_unstable", "kv_unstable_std"] } diff --git a/crates/pico/tests/memo_on_db_method.rs b/crates/pico/tests/memo_on_db_method.rs new file mode 100644 index 000000000..a06fe911b --- /dev/null +++ b/crates/pico/tests/memo_on_db_method.rs @@ -0,0 +1,48 @@ +use std::sync::atomic::{AtomicUsize, Ordering}; + +use pico::{Database, SourceId, Storage}; +use pico_macros::{memo, Db, Source}; + +static FIRST_LETTER_COUNTER: AtomicUsize = AtomicUsize::new(0); + +#[derive(Db, Default)] +struct TestDatabase { + pub storage: Storage, +} + +impl TestDatabase { + #[memo] + fn first_letter(&self, input_id: SourceId) -> char { + FIRST_LETTER_COUNTER.fetch_add(1, Ordering::SeqCst); + let input = self.get(input_id); + input.value.chars().next().unwrap() + } +} + +#[test] +fn memo_on_db_method() { + let mut db = TestDatabase::default(); + + let input_id = db.set(Input { + key: "key", + value: "asdf".to_string(), + }); + + assert_eq!(*db.first_letter(input_id), 'a'); + assert_eq!(FIRST_LETTER_COUNTER.load(Ordering::SeqCst), 1); + + db.set(Input { + key: "key", + value: "qwer".to_string(), + }); + + assert_eq!(*db.first_letter(input_id), 'q'); + assert_eq!(FIRST_LETTER_COUNTER.load(Ordering::SeqCst), 2); +} + +#[derive(Debug, Clone, PartialEq, Eq, Source)] +struct Input { + #[key] + pub key: &'static str, + pub value: String, +} diff --git a/crates/pico/tests/memo_on_struct_method.rs b/crates/pico/tests/memo_on_struct_method.rs new file mode 100644 index 000000000..7ad1a3210 --- /dev/null +++ b/crates/pico/tests/memo_on_struct_method.rs @@ -0,0 +1,53 @@ +use std::sync::atomic::{AtomicUsize, Ordering}; + +use pico::{Database, SourceId, Storage}; +use pico_macros::{memo, Db, Source}; + +static FIRST_LETTER_COUNTER: AtomicUsize = AtomicUsize::new(0); + +#[derive(Db, Default)] +struct TestDatabase { + pub storage: Storage, +} + +#[derive(Clone, PartialEq, Eq, Hash)] +struct TestStruct; + +impl TestStruct { + #[memo(db = test_db)] + fn first_letter(&self, test_db: &TestDatabase, input_id: SourceId) -> char { + FIRST_LETTER_COUNTER.fetch_add(1, Ordering::SeqCst); + let input = test_db.get(input_id); + input.value.chars().next().unwrap() + } +} + +#[test] +fn memo_on_struct_method() { + let mut db = TestDatabase::default(); + + let test_struct = TestStruct {}; + + let input_id = db.set(Input { + key: "key", + value: "asdf".to_string(), + }); + + assert_eq!(*test_struct.first_letter(&db, input_id), 'a'); + assert_eq!(FIRST_LETTER_COUNTER.load(Ordering::SeqCst), 1); + + db.set(Input { + key: "key", + value: "qwer".to_string(), + }); + + assert_eq!(*test_struct.first_letter(&db, input_id), 'q'); + assert_eq!(FIRST_LETTER_COUNTER.load(Ordering::SeqCst), 2); +} + +#[derive(Debug, Clone, PartialEq, Eq, Source)] +struct Input { + #[key] + pub key: &'static str, + pub value: String, +} diff --git a/crates/pico/tests/params/memo_ref_never_cloned.rs b/crates/pico/tests/params/memo_ref_never_cloned.rs index 36c2fa96e..e36b8ebc4 100644 --- a/crates/pico/tests/params/memo_ref_never_cloned.rs +++ b/crates/pico/tests/params/memo_ref_never_cloned.rs @@ -39,4 +39,4 @@ fn get_output(_db: &TestDatabase) -> Output { } #[memo] -fn consume_output(db: &TestDatabase, _output: MemoRef) {} +fn consume_output(_db: &TestDatabase, _output: MemoRef) {} diff --git a/crates/pico_macros/Cargo.toml b/crates/pico_macros/Cargo.toml index c470d23cd..523ea4c5d 100644 --- a/crates/pico_macros/Cargo.toml +++ b/crates/pico_macros/Cargo.toml @@ -5,6 +5,7 @@ edition.workspace = true license.workspace = true [dependencies] +darling = { workspace = true } proc-macro2 = { workspace = true } quote = { workspace = true } syn = { workspace = true } diff --git a/crates/pico_macros/src/memo_macro.rs b/crates/pico_macros/src/memo_macro.rs index 91d664ee1..045b22efb 100644 --- a/crates/pico_macros/src/memo_macro.rs +++ b/crates/pico_macros/src/memo_macro.rs @@ -1,14 +1,73 @@ use std::hash::{DefaultHasher, Hash, Hasher}; +use darling::{Error as DarlingError, FromMeta}; use proc_macro::TokenStream; +use proc_macro2::Span; use quote::{quote, ToTokens}; -use syn::{parse_macro_input, parse_quote, Error, FnArg, ItemFn, PatType, ReturnType, Signature}; +use syn::{ + parse, parse_macro_input, parse_quote, visit_mut::VisitMut, Error, Expr, FnArg, Ident, ItemFn, + Lit, Meta, Pat, PatIdent, PatType, ReturnType, Signature, Type, +}; + +#[derive(Debug)] +struct DbArg(pub Ident); + +impl FromMeta for DbArg { + fn from_meta(item: &Meta) -> darling::Result { + match item { + Meta::Path(path) => { + // bare identifier + if let Some(ident) = path.get_ident() { + Ok(DbArg(ident.clone())) + } else { + Err(DarlingError::custom("Expected identifier").with_span(path)) + } + } + Meta::NameValue(nv) => match &nv.value { + Expr::Lit(expr_lit) => { + if let Lit::Str(litstr) = &expr_lit.lit { + let ident = Ident::new(&litstr.value(), litstr.span()); + Ok(DbArg(ident)) + } else { + Err(DarlingError::custom("Expected string literal") + .with_span(&expr_lit.lit)) + } + } + Expr::Path(expr_path) => { + if let Some(segment) = expr_path.path.segments.last() { + Ok(DbArg(segment.ident.clone())) + } else { + Err(DarlingError::custom("Empty path for db").with_span(&expr_path.path)) + } + } + other => { + Err(DarlingError::custom("Unsupported expression for db").with_span(other)) + } + }, + _ => Err(DarlingError::custom("Unsupported meta for db").with_span(item)), + } + } +} + +#[derive(Debug, FromMeta)] +#[darling(derive_syn_parse)] +struct MemoArgs { + #[darling(default)] + db: Option, +} + +pub(crate) fn memo_macro(attr: TokenStream, item: TokenStream) -> TokenStream { + let args_: MemoArgs = match parse(attr) { + Ok(v) => v, + Err(e) => { + return e.to_compile_error().into(); + } + }; -pub(crate) fn memo_macro(_args: TokenStream, item: TokenStream) -> TokenStream { let ItemFn { - sig, + mut sig, vis, - block, + mut block, attrs, } = parse_macro_input!(item as ItemFn); @@ -17,35 +76,80 @@ pub(crate) fn memo_macro(_args: TokenStream, item: TokenStream) -> TokenStream { if sig.inputs.is_empty() { return Error::new_spanned( &sig, - "Memoized function must have at least one argument (&Database)", + "Memoized function must have at least one argument (db or &self)", ) .to_compile_error() .into(); } - let db_arg = match &sig.inputs[0] { - FnArg::Typed(PatType { pat, .. }) => pat, - _ => unreachable!(), + let db_pos = get_db_position(&sig, &args_); + + let (db_arg, closure_db_arg) = match &sig.inputs[db_pos] { + FnArg::Receiver(rcv) => { + if rcv.reference.is_none() { + return Error::new_spanned(rcv, "Receiver must be a reference") + .to_compile_error() + .into(); + } + if rcv.mutability.is_some() { + return Error::new_spanned(rcv, "Receiver should not be mutable") + .to_compile_error() + .into(); + } + (quote!(self), quote!(__self)) + } + FnArg::Typed(PatType { pat, .. }) => { + let tok = pat.to_token_stream(); + (tok.clone(), tok) + } }; - let args = sig.inputs.iter().skip(1).map(|arg| match arg { - FnArg::Typed(PatType { pat, ty, .. }) => (pat, ty), - _ => unreachable!(), - }); + let args = sig + .inputs + .iter() + .cloned() + .enumerate() + .filter_map(|(i, arg)| if db_pos == i { None } else { Some(arg) }) + .map(|arg| match arg { + FnArg::Typed(PatType { pat, ty, .. }) => (pat, ty), + // hack to transform `self`` to fake `__self: &Self`` argument and use it as regular parameter + FnArg::Receiver(_) => { + let pat_ident = Pat::Ident(PatIdent { + attrs: Vec::new(), + by_ref: None, + mutability: None, + ident: Ident::new("__self", Span::call_site()), + subpat: None, + }); + (Box::new(pat_ident), Box::new(parse_quote!(&Self))) + } + }); - let param_ids_blocks = args.clone().map(|(arg, ty)| match ArgType::parse(ty) { + let param_ids_blocks = args.clone().map(|(arg, ty)| match ArgType::parse(&ty) { ArgType::Source | ArgType::MemoRef => { - let param_arg = match **ty { - syn::Type::Reference(_) => quote!((*(#arg))), + let param_arg = match *ty { + Type::Reference(_) => quote!((*(#arg))), _ => quote!(#arg), }; quote! { param_ids.push(#param_arg.into()); } } + ArgType::Receiver => { + let intern_param = match *ty { + Type::Reference(_) => { + quote!(::pico::macro_fns::intern_borrowed_param(#db_arg, self)) + } + _ => unreachable!(), + }; + quote! { + let param_id = #intern_param; + param_ids.push(param_id); + } + } ArgType::Other => { - let intern_param = match **ty { - syn::Type::Reference(_) => { + let intern_param = match *ty { + Type::Reference(_) => { quote!(::pico::macro_fns::intern_borrowed_param(#db_arg, #arg)) } _ => quote!(::pico::macro_fns::intern_owned_param(#db_arg, #arg)), @@ -62,8 +166,7 @@ pub(crate) fn memo_macro(_args: TokenStream, item: TokenStream) -> TokenStream { ReturnType::Default => parse_quote!(()), }; - let mut new_sig = sig.clone(); - new_sig.output = ReturnType::Type( + sig.output = ReturnType::Type( parse_quote!(->), Box::new(parse_quote!(::pico::MemoRef<#return_type>)), ); @@ -71,10 +174,10 @@ pub(crate) fn memo_macro(_args: TokenStream, item: TokenStream) -> TokenStream { let extract_parameters = args .enumerate() .map(|(i, (arg, ty))| { - match ArgType::parse(ty) { + match ArgType::parse(&ty) { ArgType::Source => { - let binding_expr = match **ty { - syn::Type::Reference(_) => quote!(¶m_id.into()), + let binding_expr = match *ty { + Type::Reference(_) => quote!(¶m_id.into()), _ => quote!(param_id.into()), }; quote! { @@ -85,9 +188,9 @@ pub(crate) fn memo_macro(_args: TokenStream, item: TokenStream) -> TokenStream { } } ArgType::MemoRef => { - let binding_expr = match **ty { - syn::Type::Reference(_) => quote!(&::pico::MemoRef::new(#db_arg, param_id.into())), - _ => quote!(::pico::MemoRef::new(#db_arg, param_id.into())), + let binding_expr = match *ty { + Type::Reference(_) => quote!(&::pico::MemoRef::new(#db_arg, param_id.into())), + _ => quote!(::pico::MemoRef::new(#closure_db_arg, param_id.into())), }; quote! { let #arg: #ty = { @@ -96,14 +199,14 @@ pub(crate) fn memo_macro(_args: TokenStream, item: TokenStream) -> TokenStream { }; } } - ArgType::Other => { - let (target_type, binding_expr) = match **ty { - syn::Type::Reference(ref reference) => (&reference.elem, quote!(inner)), - _ => (ty, quote!(inner.clone())), + ArgType::Other | ArgType::Receiver => { + let (target_type, binding_expr) = match *ty { + Type::Reference(ref reference) => (&reference.elem, quote!(inner)), + _ => (&ty, quote!(inner.clone())), }; quote! { let #arg: #ty = { - let param_ref = ::pico::macro_fns::get_param(#db_arg, derived_node_id.params[#i])?; + let param_ref = ::pico::macro_fns::get_param(#closure_db_arg, derived_node_id.params[#i])?; let inner = param_ref .downcast_ref::<#target_type>() .expect("Unexpected param type. This is indicative of a bug in Pico."); @@ -114,9 +217,12 @@ pub(crate) fn memo_macro(_args: TokenStream, item: TokenStream) -> TokenStream { } }); + let mut replacer = IdentReplacer::new("self", "__self"); + replacer.visit_block_mut(&mut block); + let output = quote! { #(#attrs)* - #vis #new_sig { + #vis #sig { let mut param_ids = ::pico::macro_fns::init_param_vec(); #( #param_ids_blocks @@ -125,7 +231,7 @@ pub(crate) fn memo_macro(_args: TokenStream, item: TokenStream) -> TokenStream { let did_recalculate = ::pico::execute_memoized_function( #db_arg, derived_node_id, - ::pico::InnerFn::new(|#db_arg, derived_node_id| { + ::pico::InnerFn::new(|#closure_db_arg, derived_node_id| { use ::pico::Database; #( #extract_parameters @@ -154,30 +260,67 @@ fn hash(input: &Signature) -> u64 { enum ArgType { Source, MemoRef, + Receiver, Other, } impl ArgType { - pub fn parse(ty: &syn::Type) -> Self { + pub fn parse(ty: &Type) -> Self { if type_is(ty, "SourceId") { return ArgType::Source; } if type_is(ty, "MemoRef") { return ArgType::MemoRef; } + if type_is(ty, "Self") { + return ArgType::Receiver; + } ArgType::Other } } -fn type_is(ty: &syn::Type, target: &'static str) -> bool { +fn type_is(ty: &Type, target: &'static str) -> bool { let inner = match ty { - syn::Type::Reference(r) => &*r.elem, + Type::Reference(r) => &*r.elem, _ => ty, }; - if let syn::Type::Path(type_path) = inner { + if let Type::Path(type_path) = inner { if let Some(segment) = type_path.path.segments.last() { return segment.ident == target; } } false } + +fn get_db_position(sig: &Signature, args: &MemoArgs) -> usize { + args.db + .as_ref() + .and_then(|db_arg| { + sig.inputs.iter().position(|arg| match arg { + FnArg::Typed(PatType { pat, .. }) => { + matches!(&**pat, Pat::Ident(pi) if pi.ident == db_arg.0) + } + _ => false, + }) + }) + .unwrap_or(0) +} + +struct IdentReplacer { + pub from: &'static str, + pub to: &'static str, +} + +impl IdentReplacer { + pub fn new(from: &'static str, to: &'static str) -> Self { + Self { from, to } + } +} + +impl VisitMut for IdentReplacer { + fn visit_ident_mut(&mut self, ident: &mut Ident) { + if ident == self.from { + *ident = Ident::new(self.to, ident.span()); + } + } +}