diff options
Diffstat (limited to 'crates/salsa')
-rw-r--r-- | crates/salsa/src/lib.rs | 91 | ||||
-rw-r--r-- | crates/salsa/tests/integration.rs | 25 |
2 files changed, 61 insertions, 55 deletions
diff --git a/crates/salsa/src/lib.rs b/crates/salsa/src/lib.rs index 69c7b35fa..a54f2a06f 100644 --- a/crates/salsa/src/lib.rs +++ b/crates/salsa/src/lib.rs | |||
@@ -3,53 +3,52 @@ extern crate parking_lot; | |||
3 | 3 | ||
4 | use std::{ | 4 | use std::{ |
5 | sync::Arc, | 5 | sync::Arc, |
6 | any::Any, | ||
7 | collections::HashMap, | 6 | collections::HashMap, |
8 | cell::RefCell, | 7 | cell::RefCell, |
9 | }; | 8 | }; |
10 | use parking_lot::Mutex; | 9 | use parking_lot::Mutex; |
11 | 10 | ||
12 | type GroundQueryFn<T> = fn(&T, &(Any + Send + Sync + 'static)) -> (Box<Any + Send + Sync + 'static>, OutputFingerprint); | 11 | type GroundQueryFn<T, D> = fn(&T, &D) -> (D, OutputFingerprint); |
13 | type QueryFn<T> = fn(&QueryCtx<T>, &(Any + Send + Sync + 'static)) -> (Box<Any + Send + Sync + 'static>, OutputFingerprint); | 12 | type QueryFn<T, D> = fn(&QueryCtx<T, D>, &D) -> (D, OutputFingerprint); |
14 | 13 | ||
15 | #[derive(Debug)] | 14 | #[derive(Debug)] |
16 | pub struct Db<T> { | 15 | pub struct Db<T, D> { |
17 | db: Arc<DbState<T>>, | 16 | db: Arc<DbState<T, D>>, |
18 | query_config: Arc<QueryConfig<T>>, | 17 | query_config: Arc<QueryConfig<T, D>>, |
19 | } | 18 | } |
20 | 19 | ||
21 | pub struct QueryConfig<T> { | 20 | pub struct QueryConfig<T, D> { |
22 | ground_fn: HashMap<QueryTypeId, GroundQueryFn<T>>, | 21 | ground_fn: HashMap<QueryTypeId, GroundQueryFn<T, D>>, |
23 | query_fn: HashMap<QueryTypeId, QueryFn<T>>, | 22 | query_fn: HashMap<QueryTypeId, QueryFn<T, D>>, |
24 | } | 23 | } |
25 | 24 | ||
26 | impl<T> ::std::fmt::Debug for QueryConfig<T> { | 25 | impl<T, D> ::std::fmt::Debug for QueryConfig<T, D> { |
27 | fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { | 26 | fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { |
28 | ::std::fmt::Display::fmt("QueryConfig { ... }", f) | 27 | ::std::fmt::Display::fmt("QueryConfig { ... }", f) |
29 | } | 28 | } |
30 | } | 29 | } |
31 | 30 | ||
32 | #[derive(Debug)] | 31 | #[derive(Debug)] |
33 | struct DbState<T> { | 32 | struct DbState<T, D> { |
34 | ground_data: T, | 33 | ground_data: T, |
35 | gen: Gen, | 34 | gen: Gen, |
36 | graph: Mutex<im::HashMap<QueryId, (Gen, Arc<QueryRecord>)>>, | 35 | graph: Mutex<im::HashMap<QueryId, (Gen, Arc<QueryRecord<D>>)>>, |
37 | } | 36 | } |
38 | 37 | ||
39 | #[derive(Debug)] | 38 | #[derive(Debug)] |
40 | struct QueryRecord { | 39 | struct QueryRecord<D> { |
41 | params: Arc<Any + Send + Sync + 'static>, | 40 | params: D, |
42 | output: Arc<Any + Send + Sync + 'static>, | 41 | output: D, |
43 | output_fingerprint: OutputFingerprint, | 42 | output_fingerprint: OutputFingerprint, |
44 | deps: Vec<(QueryId, OutputFingerprint)>, | 43 | deps: Vec<(QueryId, OutputFingerprint)>, |
45 | } | 44 | } |
46 | 45 | ||
47 | impl<T> DbState<T> { | 46 | impl<T, D> DbState<T, D> { |
48 | fn record( | 47 | fn record( |
49 | &self, | 48 | &self, |
50 | query_id: QueryId, | 49 | query_id: QueryId, |
51 | params: Arc<Any + Send + Sync + 'static>, | 50 | params: D, |
52 | output: Arc<Any + Send + Sync + 'static>, | 51 | output: D, |
53 | output_fingerprint: OutputFingerprint, | 52 | output_fingerprint: OutputFingerprint, |
54 | deps: Vec<(QueryId, OutputFingerprint)>, | 53 | deps: Vec<(QueryId, OutputFingerprint)>, |
55 | ) { | 54 | ) { |
@@ -64,7 +63,7 @@ impl<T> DbState<T> { | |||
64 | } | 63 | } |
65 | } | 64 | } |
66 | 65 | ||
67 | impl<T> QueryConfig<T> { | 66 | impl<T, D> QueryConfig<T, D> { |
68 | pub fn new() -> Self { | 67 | pub fn new() -> Self { |
69 | QueryConfig { | 68 | QueryConfig { |
70 | ground_fn: HashMap::new(), | 69 | ground_fn: HashMap::new(), |
@@ -74,7 +73,7 @@ impl<T> QueryConfig<T> { | |||
74 | pub fn with_ground_query( | 73 | pub fn with_ground_query( |
75 | mut self, | 74 | mut self, |
76 | query_type: QueryTypeId, | 75 | query_type: QueryTypeId, |
77 | query_fn: GroundQueryFn<T> | 76 | query_fn: GroundQueryFn<T, D> |
78 | ) -> Self { | 77 | ) -> Self { |
79 | let prev = self.ground_fn.insert(query_type, query_fn); | 78 | let prev = self.ground_fn.insert(query_type, query_fn); |
80 | assert!(prev.is_none()); | 79 | assert!(prev.is_none()); |
@@ -83,7 +82,7 @@ impl<T> QueryConfig<T> { | |||
83 | pub fn with_query( | 82 | pub fn with_query( |
84 | mut self, | 83 | mut self, |
85 | query_type: QueryTypeId, | 84 | query_type: QueryTypeId, |
86 | query_fn: QueryFn<T>, | 85 | query_fn: QueryFn<T, D>, |
87 | ) -> Self { | 86 | ) -> Self { |
88 | let prev = self.query_fn.insert(query_type, query_fn); | 87 | let prev = self.query_fn.insert(query_type, query_fn); |
89 | assert!(prev.is_none()); | 88 | assert!(prev.is_none()); |
@@ -91,15 +90,18 @@ impl<T> QueryConfig<T> { | |||
91 | } | 90 | } |
92 | } | 91 | } |
93 | 92 | ||
94 | pub struct QueryCtx<T> { | 93 | pub struct QueryCtx<T, D> { |
95 | db: Arc<DbState<T>>, | 94 | db: Arc<DbState<T, D>>, |
96 | query_config: Arc<QueryConfig<T>>, | 95 | query_config: Arc<QueryConfig<T, D>>, |
97 | stack: RefCell<Vec<Vec<(QueryId, OutputFingerprint)>>>, | 96 | stack: RefCell<Vec<Vec<(QueryId, OutputFingerprint)>>>, |
98 | executed: RefCell<Vec<QueryTypeId>>, | 97 | executed: RefCell<Vec<QueryTypeId>>, |
99 | } | 98 | } |
100 | 99 | ||
101 | impl<T> QueryCtx<T> { | 100 | impl<T, D> QueryCtx<T, D> |
102 | fn new(db: &Db<T>) -> QueryCtx<T> { | 101 | where |
102 | D: Clone | ||
103 | { | ||
104 | fn new(db: &Db<T, D>) -> QueryCtx<T, D> { | ||
103 | QueryCtx { | 105 | QueryCtx { |
104 | db: Arc::clone(&db.db), | 106 | db: Arc::clone(&db.db), |
105 | query_config: Arc::clone(&db.query_config), | 107 | query_config: Arc::clone(&db.query_config), |
@@ -110,8 +112,8 @@ impl<T> QueryCtx<T> { | |||
110 | pub fn get( | 112 | pub fn get( |
111 | &self, | 113 | &self, |
112 | query_id: QueryId, | 114 | query_id: QueryId, |
113 | params: Arc<Any + Send + Sync + 'static>, | 115 | params: D, |
114 | ) -> Arc<Any + Send + Sync + 'static> { | 116 | ) -> D { |
115 | let (res, output_fingerprint) = self.get_inner(query_id, params); | 117 | let (res, output_fingerprint) = self.get_inner(query_id, params); |
116 | self.record_dep(query_id, output_fingerprint); | 118 | self.record_dep(query_id, output_fingerprint); |
117 | res | 119 | res |
@@ -120,8 +122,8 @@ impl<T> QueryCtx<T> { | |||
120 | pub fn get_inner( | 122 | pub fn get_inner( |
121 | &self, | 123 | &self, |
122 | query_id: QueryId, | 124 | query_id: QueryId, |
123 | params: Arc<Any + Send + Sync + 'static>, | 125 | params: D, |
124 | ) -> (Arc<Any + Send + Sync + 'static>, OutputFingerprint) { | 126 | ) -> (D, OutputFingerprint) { |
125 | let (gen, record) = { | 127 | let (gen, record) = { |
126 | let guard = self.db.graph.lock(); | 128 | let guard = self.db.graph.lock(); |
127 | match guard.get(&query_id).map(|it| it.clone()){ | 129 | match guard.get(&query_id).map(|it| it.clone()){ |
@@ -139,7 +141,7 @@ impl<T> QueryCtx<T> { | |||
139 | return self.force(query_id, params); | 141 | return self.force(query_id, params); |
140 | } | 142 | } |
141 | for (dep_query_id, prev_fingerprint) in record.deps.iter().cloned() { | 143 | for (dep_query_id, prev_fingerprint) in record.deps.iter().cloned() { |
142 | let dep_params: Arc<Any + Send + Sync + 'static> = { | 144 | let dep_params: D = { |
143 | let guard = self.db.graph.lock(); | 145 | let guard = self.db.graph.lock(); |
144 | guard[&dep_query_id] | 146 | guard[&dep_query_id] |
145 | .1 | 147 | .1 |
@@ -160,29 +162,29 @@ impl<T> QueryCtx<T> { | |||
160 | fn force( | 162 | fn force( |
161 | &self, | 163 | &self, |
162 | query_id: QueryId, | 164 | query_id: QueryId, |
163 | params: Arc<Any + Send + Sync + 'static>, | 165 | params: D, |
164 | ) -> (Arc<Any + Send + Sync + 'static>, OutputFingerprint) { | 166 | ) -> (D, OutputFingerprint) { |
165 | self.executed.borrow_mut().push(query_id.0); | 167 | self.executed.borrow_mut().push(query_id.0); |
166 | self.stack.borrow_mut().push(Vec::new()); | 168 | self.stack.borrow_mut().push(Vec::new()); |
167 | 169 | ||
168 | let (res, output_fingerprint) = if let Some(f) = self.ground_query_fn_by_type(query_id.0) { | 170 | let (res, output_fingerprint) = if let Some(f) = self.ground_query_fn_by_type(query_id.0) { |
169 | f(&self.db.ground_data, &*params) | 171 | f(&self.db.ground_data, ¶ms) |
170 | } else if let Some(f) = self.query_fn_by_type(query_id.0) { | 172 | } else if let Some(f) = self.query_fn_by_type(query_id.0) { |
171 | f(self, &*params) | 173 | f(self, ¶ms) |
172 | } else { | 174 | } else { |
173 | panic!("unknown query type: {:?}", query_id.0); | 175 | panic!("unknown query type: {:?}", query_id.0); |
174 | }; | 176 | }; |
175 | 177 | ||
176 | let res: Arc<Any + Send + Sync + 'static> = res.into(); | 178 | let res: D = res.into(); |
177 | 179 | ||
178 | let deps = self.stack.borrow_mut().pop().unwrap(); | 180 | let deps = self.stack.borrow_mut().pop().unwrap(); |
179 | self.db.record(query_id, params, res.clone(), output_fingerprint, deps); | 181 | self.db.record(query_id, params, res.clone(), output_fingerprint, deps); |
180 | (res, output_fingerprint) | 182 | (res, output_fingerprint) |
181 | } | 183 | } |
182 | fn ground_query_fn_by_type(&self, query_type: QueryTypeId) -> Option<GroundQueryFn<T>> { | 184 | fn ground_query_fn_by_type(&self, query_type: QueryTypeId) -> Option<GroundQueryFn<T, D>> { |
183 | self.query_config.ground_fn.get(&query_type).map(|&it| it) | 185 | self.query_config.ground_fn.get(&query_type).map(|&it| it) |
184 | } | 186 | } |
185 | fn query_fn_by_type(&self, query_type: QueryTypeId) -> Option<QueryFn<T>> { | 187 | fn query_fn_by_type(&self, query_type: QueryTypeId) -> Option<QueryFn<T, D>> { |
186 | self.query_config.query_fn.get(&query_type).map(|&it| it) | 188 | self.query_config.query_fn.get(&query_type).map(|&it| it) |
187 | } | 189 | } |
188 | fn record_dep( | 190 | fn record_dep( |
@@ -196,15 +198,18 @@ impl<T> QueryCtx<T> { | |||
196 | } | 198 | } |
197 | } | 199 | } |
198 | 200 | ||
199 | impl<T> Db<T> { | 201 | impl<T, D> Db<T, D> |
200 | pub fn new(query_config: QueryConfig<T>, ground_data: T) -> Db<T> { | 202 | where |
203 | D: Clone | ||
204 | { | ||
205 | pub fn new(query_config: QueryConfig<T, D>, ground_data: T) -> Db<T, D> { | ||
201 | Db { | 206 | Db { |
202 | db: Arc::new(DbState { ground_data, gen: Gen(0), graph: Default::default() }), | 207 | db: Arc::new(DbState { ground_data, gen: Gen(0), graph: Default::default() }), |
203 | query_config: Arc::new(query_config), | 208 | query_config: Arc::new(query_config), |
204 | } | 209 | } |
205 | } | 210 | } |
206 | 211 | ||
207 | pub fn with_ground_data(&self, ground_data: T) -> Db<T> { | 212 | pub fn with_ground_data(&self, ground_data: T) -> Db<T, D> { |
208 | let gen = Gen(self.db.gen.0 + 1); | 213 | let gen = Gen(self.db.gen.0 + 1); |
209 | let graph = self.db.graph.lock().clone(); | 214 | let graph = self.db.graph.lock().clone(); |
210 | let graph = Mutex::new(graph); | 215 | let graph = Mutex::new(graph); |
@@ -216,8 +221,8 @@ impl<T> Db<T> { | |||
216 | pub fn get( | 221 | pub fn get( |
217 | &self, | 222 | &self, |
218 | query_id: QueryId, | 223 | query_id: QueryId, |
219 | params: Box<Any + Send + Sync + 'static>, | 224 | params: D, |
220 | ) -> (Arc<Any + Send + Sync + 'static>, Vec<QueryTypeId>) { | 225 | ) -> (D, Vec<QueryTypeId>) { |
221 | let ctx = QueryCtx::new(self); | 226 | let ctx = QueryCtx::new(self); |
222 | let res = ctx.get(query_id, params.into()); | 227 | let res = ctx.get(query_id, params.into()); |
223 | let executed = ::std::mem::replace(&mut *ctx.executed.borrow_mut(), Vec::new()); | 228 | let executed = ::std::mem::replace(&mut *ctx.executed.borrow_mut(), Vec::new()); |
diff --git a/crates/salsa/tests/integration.rs b/crates/salsa/tests/integration.rs index 7241eca38..2872d3913 100644 --- a/crates/salsa/tests/integration.rs +++ b/crates/salsa/tests/integration.rs | |||
@@ -7,6 +7,7 @@ use std::{ | |||
7 | }; | 7 | }; |
8 | 8 | ||
9 | type State = HashMap<u32, String>; | 9 | type State = HashMap<u32, String>; |
10 | type Data = Arc<Any + Send + Sync + 'static>; | ||
10 | const GET_TEXT: salsa::QueryTypeId = salsa::QueryTypeId(1); | 11 | const GET_TEXT: salsa::QueryTypeId = salsa::QueryTypeId(1); |
11 | const GET_FILES: salsa::QueryTypeId = salsa::QueryTypeId(2); | 12 | const GET_FILES: salsa::QueryTypeId = salsa::QueryTypeId(2); |
12 | const FILE_NEWLINES: salsa::QueryTypeId = salsa::QueryTypeId(3); | 13 | const FILE_NEWLINES: salsa::QueryTypeId = salsa::QueryTypeId(3); |
@@ -14,9 +15,9 @@ const TOTAL_NEWLINES: salsa::QueryTypeId = salsa::QueryTypeId(4); | |||
14 | 15 | ||
15 | fn mk_ground_query<T, R>( | 16 | fn mk_ground_query<T, R>( |
16 | state: &State, | 17 | state: &State, |
17 | params: &(Any + Send + Sync + 'static), | 18 | params: &Data, |
18 | f: fn(&State, &T) -> R, | 19 | f: fn(&State, &T) -> R, |
19 | ) -> (Box<Any + Send + Sync + 'static>, salsa::OutputFingerprint) | 20 | ) -> (Data, salsa::OutputFingerprint) |
20 | where | 21 | where |
21 | T: 'static, | 22 | T: 'static, |
22 | R: Hash + Send + Sync + 'static, | 23 | R: Hash + Send + Sync + 'static, |
@@ -24,21 +25,21 @@ where | |||
24 | let params = params.downcast_ref().unwrap(); | 25 | let params = params.downcast_ref().unwrap(); |
25 | let result = f(state, params); | 26 | let result = f(state, params); |
26 | let fingerprint = o_print(&result); | 27 | let fingerprint = o_print(&result); |
27 | (Box::new(result), fingerprint) | 28 | (Arc::new(result), fingerprint) |
28 | } | 29 | } |
29 | 30 | ||
30 | fn get<T, R>(db: &salsa::Db<State>, query_type: salsa::QueryTypeId, param: T) -> (Arc<R>, Vec<salsa::QueryTypeId>) | 31 | fn get<T, R>(db: &salsa::Db<State, Data>, query_type: salsa::QueryTypeId, param: T) -> (Arc<R>, Vec<salsa::QueryTypeId>) |
31 | where | 32 | where |
32 | T: Hash + Send + Sync + 'static, | 33 | T: Hash + Send + Sync + 'static, |
33 | R: Send + Sync + 'static, | 34 | R: Send + Sync + 'static, |
34 | { | 35 | { |
35 | let i_print = i_print(¶m); | 36 | let i_print = i_print(¶m); |
36 | let param = Box::new(param); | 37 | let param = Arc::new(param); |
37 | let (res, trace) = db.get(salsa::QueryId(query_type, i_print), param); | 38 | let (res, trace) = db.get(salsa::QueryId(query_type, i_print), param); |
38 | (res.downcast().unwrap(), trace) | 39 | (res.downcast().unwrap(), trace) |
39 | } | 40 | } |
40 | 41 | ||
41 | struct QueryCtx<'a>(&'a salsa::QueryCtx<State>); | 42 | struct QueryCtx<'a>(&'a salsa::QueryCtx<State, Data>); |
42 | 43 | ||
43 | impl<'a> QueryCtx<'a> { | 44 | impl<'a> QueryCtx<'a> { |
44 | fn get_text(&self, id: u32) -> Arc<String> { | 45 | fn get_text(&self, id: u32) -> Arc<String> { |
@@ -60,10 +61,10 @@ impl<'a> QueryCtx<'a> { | |||
60 | } | 61 | } |
61 | 62 | ||
62 | fn mk_query<T, R>( | 63 | fn mk_query<T, R>( |
63 | query_ctx: &salsa::QueryCtx<State>, | 64 | query_ctx: &salsa::QueryCtx<State, Data>, |
64 | params: &(Any + Send + Sync + 'static), | 65 | params: &Data, |
65 | f: fn(QueryCtx, &T) -> R, | 66 | f: fn(QueryCtx, &T) -> R, |
66 | ) -> (Box<Any + Send + Sync + 'static>, salsa::OutputFingerprint) | 67 | ) -> (Data, salsa::OutputFingerprint) |
67 | where | 68 | where |
68 | T: 'static, | 69 | T: 'static, |
69 | R: Hash + Send + Sync + 'static, | 70 | R: Hash + Send + Sync + 'static, |
@@ -72,11 +73,11 @@ where | |||
72 | let query_ctx = QueryCtx(query_ctx); | 73 | let query_ctx = QueryCtx(query_ctx); |
73 | let result = f(query_ctx, params); | 74 | let result = f(query_ctx, params); |
74 | let fingerprint = o_print(&result); | 75 | let fingerprint = o_print(&result); |
75 | (Box::new(result), fingerprint) | 76 | (Arc::new(result), fingerprint) |
76 | } | 77 | } |
77 | 78 | ||
78 | fn mk_queries() -> salsa::QueryConfig<State> { | 79 | fn mk_queries() -> salsa::QueryConfig<State, Data> { |
79 | salsa::QueryConfig::<State>::new() | 80 | salsa::QueryConfig::<State, Data>::new() |
80 | .with_ground_query(GET_TEXT, |state, id| { | 81 | .with_ground_query(GET_TEXT, |state, id| { |
81 | mk_ground_query::<u32, String>(state, id, |state, id| state[id].clone()) | 82 | mk_ground_query::<u32, String>(state, id, |state, id| state[id].clone()) |
82 | }) | 83 | }) |