Skip to content

Commit 18564b2

Browse files
committed
add methods to project MemoRef
1 parent 74bb1a2 commit 18564b2

File tree

11 files changed

+376
-50
lines changed

11 files changed

+376
-50
lines changed

crates/graphql_network_protocol/src/graphql_network_protocol.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ impl NetworkProtocol for GraphQLNetworkProtocol {
7474
let mut graphql_root_types = None;
7575

7676
let (type_system_document, type_system_extension_documents) =
77-
parse_graphql_schema(db).to_owned()?;
77+
parse_graphql_schema(db).try_ok()?.split();
7878

7979
let (mut result, mut directives, mut refetch_fields) =
8080
process_graphql_type_system_document(

crates/graphql_network_protocol/src/read_schema.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ pub fn parse_graphql_schema<TNetworkProtocol: NetworkProtocol + 'static>(
1313
db: &IsographDatabase<TNetworkProtocol>,
1414
) -> Result<
1515
(
16-
MemoRef<GraphQLTypeSystemDocument>,
16+
GraphQLTypeSystemDocument,
1717
BTreeMap<RelativePathToSourceFile, MemoRef<GraphQLTypeSystemExtensionDocument>>,
1818
),
1919
WithLocation<SchemaParseError>,
@@ -35,18 +35,18 @@ pub fn parse_graphql_schema<TNetworkProtocol: NetworkProtocol + 'static>(
3535
.iter()
3636
{
3737
let extensions_document =
38-
parse_schema_extensions_file(db, *schema_extension_source_id).to_owned()?;
38+
parse_schema_extensions_file(db, *schema_extension_source_id).try_ok()?;
3939
schema_extensions.insert(*relative_path, extensions_document);
4040
}
4141

42-
Ok((db.intern(schema), schema_extensions))
42+
Ok((schema, schema_extensions))
4343
}
4444

4545
#[memo]
4646
pub fn parse_schema_extensions_file<TNetworkProtocol: NetworkProtocol + 'static>(
4747
db: &IsographDatabase<TNetworkProtocol>,
4848
schema_extension_source_id: SourceId<SchemaSource>,
49-
) -> Result<MemoRef<GraphQLTypeSystemExtensionDocument>, WithLocation<SchemaParseError>> {
49+
) -> Result<GraphQLTypeSystemExtensionDocument, WithLocation<SchemaParseError>> {
5050
let SchemaSource {
5151
content,
5252
text_source,
@@ -55,5 +55,5 @@ pub fn parse_schema_extensions_file<TNetworkProtocol: NetworkProtocol + 'static>
5555
let schema_extensions = parse_schema_extensions(content, *text_source)
5656
.map_err(|with_span| with_span.to_with_location(*text_source))?;
5757

58-
Ok(db.intern(schema_extensions))
58+
Ok(schema_extensions)
5959
}

crates/pico/src/view.rs renamed to crates/pico/src/field_view.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,16 @@ pub trait Counter: Singleton + Default + Copy + Eq + 'static {
66
fn increment(self) -> Self;
77
}
88

9-
pub type Projector<Db, T> = for<'a> fn(&'a Db) -> &'a T;
9+
pub type FieldProjector<Db, T> = for<'a> fn(&'a Db) -> &'a T;
1010

11-
pub struct View<'a, Db: Database, T, C: Counter> {
11+
pub struct FieldView<'a, Db: Database, T, C: Counter> {
1212
db: &'a Db,
13-
projector: Projector<Db, T>,
13+
projector: FieldProjector<Db, T>,
1414
phantom: PhantomData<C>,
1515
}
1616

17-
impl<'a, Db: Database, T, C: Counter> View<'a, Db, T, C> {
18-
pub fn new(db: &'a Db, projector: Projector<Db, T>) -> Self {
17+
impl<'a, Db: Database, T, C: Counter> FieldView<'a, Db, T, C> {
18+
pub fn new(db: &'a Db, projector: FieldProjector<Db, T>) -> Self {
1919
Self {
2020
db,
2121
projector,
@@ -44,16 +44,16 @@ impl<'a, Db: Database, T, C: Counter> View<'a, Db, T, C> {
4444
}
4545
}
4646

47-
pub type ProjectorMut<Db, T> = for<'a> fn(&'a mut Db) -> &'a mut T;
47+
pub type FieldProjectorMut<Db, T> = for<'a> fn(&'a mut Db) -> &'a mut T;
4848

49-
pub struct MutView<'a, Db: Database, T, C: Counter> {
49+
pub struct FieldViewMut<'a, Db: Database, T, C: Counter> {
5050
db: &'a mut Db,
51-
projector: ProjectorMut<Db, T>,
51+
projector: FieldProjectorMut<Db, T>,
5252
phantom: PhantomData<C>,
5353
}
5454

55-
impl<'a, Db: Database, T, C: Counter> MutView<'a, Db, T, C> {
56-
pub fn new(db: &'a mut Db, projector: ProjectorMut<Db, T>) -> Self {
55+
impl<'a, Db: Database, T, C: Counter> FieldViewMut<'a, Db, T, C> {
56+
pub fn new(db: &'a mut Db, projector: FieldProjectorMut<Db, T>) -> Self {
5757
Self {
5858
db,
5959
projector,

crates/pico/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,21 @@ mod derived_node;
44
mod dyn_eq;
55
mod epoch;
66
mod execute_memoized_function;
7+
mod field_view;
78
mod garbage_collection;
89
mod index;
910
mod intern;
1011
pub mod macro_fns;
1112
mod memo_ref;
1213
mod retained_query;
1314
mod source;
14-
mod view;
1515

1616
pub use database::*;
1717
pub use derived_node::*;
1818
pub use dyn_eq::*;
1919
pub use execute_memoized_function::*;
20+
pub use field_view::*;
2021
pub use intern::*;
2122
pub use memo_ref::*;
2223
pub use retained_query::*;
2324
pub use source::*;
24-
pub use view::*;

crates/pico/src/memo_ref.rs

Lines changed: 114 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,65 @@
1-
use std::{marker::PhantomData, ops::Deref};
1+
use std::{
2+
any::Any,
3+
hash::{Hash, Hasher},
4+
marker::PhantomData,
5+
ops::Deref,
6+
};
27

8+
use crate::{DatabaseDyn, DerivedNodeId, ParamId, dependency::NodeKind};
39
use intern::InternId;
410

5-
use crate::{DatabaseDyn, DerivedNodeId, ParamId, dependency::NodeKind};
11+
type MemoRefProjector<T> = for<'a> fn(&'a dyn Any) -> &'a T;
12+
13+
#[inline(always)]
14+
fn project_identity<T: 'static>(value: &dyn Any) -> &T {
15+
value
16+
.downcast_ref::<T>()
17+
.expect("MemoRef: underlying value has unexpected type")
18+
}
19+
20+
#[inline(always)]
21+
fn project_result_ok<T: 'static, E: 'static>(value: &dyn Any) -> &T {
22+
match value
23+
.downcast_ref::<Result<T, E>>()
24+
.expect("MemoRef<Result<..>>: underlying value has unexpected type")
25+
{
26+
Ok(t) => t,
27+
Err(_) => unreachable!("Ok projection used only after Ok check"),
28+
}
29+
}
30+
31+
#[inline(always)]
32+
fn project_option_some<T: 'static>(value: &dyn Any) -> &T {
33+
match value
34+
.downcast_ref::<Option<T>>()
35+
.expect("MemoRef<Option<..>>: underlying value has unexpected type")
36+
{
37+
Some(t) => t,
38+
None => unreachable!("Some projection used only after Some check"),
39+
}
40+
}
41+
42+
#[inline(always)]
43+
fn project_tuple_0<T0: 'static, T1: 'static>(value: &dyn Any) -> &T0 {
44+
let (t0, _t1) = value
45+
.downcast_ref::<(T0, T1)>()
46+
.expect("MemoRef<(..)>: underlying value has unexpected type");
47+
t0
48+
}
49+
50+
#[inline(always)]
51+
fn project_tuple_1<T0: 'static, T1: 'static>(value: &dyn Any) -> &T1 {
52+
let (_t0, t1) = value
53+
.downcast_ref::<(T0, T1)>()
54+
.expect("MemoRef<(..)>: underlying value has unexpected type");
55+
t1
56+
}
657

758
#[derive(Debug)]
859
pub struct MemoRef<T> {
960
pub(crate) db: *const dyn DatabaseDyn,
1061
pub(crate) derived_node_id: DerivedNodeId,
62+
projector: MemoRefProjector<T>,
1163
phantom: PhantomData<T>,
1264
}
1365

@@ -27,12 +79,22 @@ impl<T> PartialEq for MemoRef<T> {
2779

2880
impl<T> Eq for MemoRef<T> {}
2981

82+
#[allow(clippy::unnecessary_cast)]
83+
impl<T> Hash for MemoRef<T> {
84+
fn hash<H: Hasher>(&self, state: &mut H) {
85+
let data_ptr = self.db as *const dyn DatabaseDyn as *const ();
86+
data_ptr.hash(state);
87+
self.derived_node_id.hash(state);
88+
}
89+
}
90+
3091
#[allow(clippy::unnecessary_cast)]
3192
impl<T: 'static + Clone> MemoRef<T> {
3293
pub fn new(db: &dyn DatabaseDyn, derived_node_id: DerivedNodeId) -> Self {
3394
Self {
3495
db: db as *const _ as *const dyn DatabaseDyn,
3596
derived_node_id,
97+
projector: project_identity::<T>,
3698
phantom: PhantomData,
3799
}
38100
}
@@ -63,6 +125,55 @@ impl<T: 'static> Deref for MemoRef<T> {
63125
NodeKind::Derived(self.derived_node_id),
64126
revision.time_updated,
65127
);
66-
value.downcast_ref::<T>().unwrap()
128+
(self.projector)(value)
129+
}
130+
}
131+
132+
impl<T: 'static, E: 'static + Clone> MemoRef<Result<T, E>> {
133+
pub fn try_ok(self) -> Result<MemoRef<T>, E> {
134+
match self.deref() {
135+
Ok(_) => Ok(MemoRef {
136+
db: self.db,
137+
derived_node_id: self.derived_node_id,
138+
projector: project_result_ok::<T, E>,
139+
phantom: PhantomData,
140+
}),
141+
Err(err) => Err(err.clone()),
142+
}
143+
}
144+
}
145+
146+
impl<T: 'static> MemoRef<Option<T>> {
147+
pub fn try_some(self) -> Option<MemoRef<T>> {
148+
match self.deref() {
149+
Some(_) => Some(MemoRef {
150+
db: self.db,
151+
derived_node_id: self.derived_node_id,
152+
projector: project_option_some::<T>,
153+
phantom: PhantomData,
154+
}),
155+
None => None,
156+
}
157+
}
158+
}
159+
160+
impl<T0: 'static, T1: 'static> MemoRef<(T0, T1)> {
161+
/// Splits a `MemoRef<(T0, T1)>` into memo references for each element
162+
/// without cloning the underlying tuple elements.
163+
pub fn split(self) -> (MemoRef<T0>, MemoRef<T1>) {
164+
(
165+
MemoRef {
166+
db: self.db,
167+
derived_node_id: self.derived_node_id,
168+
projector: project_tuple_0::<T0, T1>,
169+
phantom: PhantomData,
170+
},
171+
MemoRef {
172+
db: self.db,
173+
derived_node_id: self.derived_node_id,
174+
projector: project_tuple_1::<T0, T1>,
175+
phantom: PhantomData,
176+
},
177+
)
67178
}
68179
}

crates/pico/tests/intern.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,23 +39,19 @@ enum ProcessInputError {
3939
}
4040

4141
#[memo]
42-
fn first_letter(
43-
db: &TestDatabase,
44-
input_id: SourceId<Input>,
45-
) -> Result<MemoRef<char>, FirstLetterError> {
42+
fn first_letter(db: &TestDatabase, input_id: SourceId<Input>) -> Result<char, FirstLetterError> {
4643
db.get(input_id)
4744
.value
4845
.chars()
4946
.next()
5047
.ok_or(FirstLetterError::EmptyString)
51-
.map(|v| db.intern(v))
5248
}
5349

5450
#[memo]
5551
fn process_input(
5652
db: &TestDatabase,
5753
input_id: SourceId<Input>,
5854
) -> Result<MemoRef<char>, ProcessInputError> {
59-
let result = first_letter(db, input_id).to_owned()?;
55+
let result = first_letter(db, input_id).try_ok()?;
6056
Ok(result)
6157
}

0 commit comments

Comments
 (0)