aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAleksey Kladov <[email protected]>2018-09-12 20:11:26 +0100
committerAleksey Kladov <[email protected]>2018-09-15 22:00:05 +0100
commitcecc7ad5b20e693cb8d962187bd83b9ac234de97 (patch)
tree4bb388adae050ba0457cb0ca3b2bfcc9af353871
parent8cf9c2719652d298006d51bc82a32908ab4e5335 (diff)
be generic over data
-rw-r--r--crates/salsa/src/lib.rs91
-rw-r--r--crates/salsa/tests/integration.rs25
2 files changed, 61 insertions, 55 deletions
diff --git a/crates/salsa/src/lib.rs b/crates/salsa/src/lib.rs
index 69c7b35fa..a54f2a06f 100644
--- a/crates/salsa/src/lib.rs
+++ b/crates/salsa/src/lib.rs
@@ -3,53 +3,52 @@ extern crate parking_lot;
3 3
4use std::{ 4use std::{
5 sync::Arc, 5 sync::Arc,
6 any::Any,
7 collections::HashMap, 6 collections::HashMap,
8 cell::RefCell, 7 cell::RefCell,
9}; 8};
10use parking_lot::Mutex; 9use parking_lot::Mutex;
11 10
12type GroundQueryFn<T> = fn(&T, &(Any + Send + Sync + 'static)) -> (Box<Any + Send + Sync + 'static>, OutputFingerprint); 11type GroundQueryFn<T, D> = fn(&T, &D) -> (D, OutputFingerprint);
13type QueryFn<T> = fn(&QueryCtx<T>, &(Any + Send + Sync + 'static)) -> (Box<Any + Send + Sync + 'static>, OutputFingerprint); 12type QueryFn<T, D> = fn(&QueryCtx<T, D>, &D) -> (D, OutputFingerprint);
14 13
15#[derive(Debug)] 14#[derive(Debug)]
16pub struct Db<T> { 15pub struct Db<T, D> {
17 db: Arc<DbState<T>>, 16 db: Arc<DbState<T, D>>,
18 query_config: Arc<QueryConfig<T>>, 17 query_config: Arc<QueryConfig<T, D>>,
19} 18}
20 19
21pub struct QueryConfig<T> { 20pub struct QueryConfig<T, D> {
22 ground_fn: HashMap<QueryTypeId, GroundQueryFn<T>>, 21 ground_fn: HashMap<QueryTypeId, GroundQueryFn<T, D>>,
23 query_fn: HashMap<QueryTypeId, QueryFn<T>>, 22 query_fn: HashMap<QueryTypeId, QueryFn<T, D>>,
24} 23}
25 24
26impl<T> ::std::fmt::Debug for QueryConfig<T> { 25impl<T, D> ::std::fmt::Debug for QueryConfig<T, D> {
27 fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { 26 fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
28 ::std::fmt::Display::fmt("QueryConfig { ... }", f) 27 ::std::fmt::Display::fmt("QueryConfig { ... }", f)
29 } 28 }
30} 29}
31 30
32#[derive(Debug)] 31#[derive(Debug)]
33struct DbState<T> { 32struct DbState<T, D> {
34 ground_data: T, 33 ground_data: T,
35 gen: Gen, 34 gen: Gen,
36 graph: Mutex<im::HashMap<QueryId, (Gen, Arc<QueryRecord>)>>, 35 graph: Mutex<im::HashMap<QueryId, (Gen, Arc<QueryRecord<D>>)>>,
37} 36}
38 37
39#[derive(Debug)] 38#[derive(Debug)]
40struct QueryRecord { 39struct QueryRecord<D> {
41 params: Arc<Any + Send + Sync + 'static>, 40 params: D,
42 output: Arc<Any + Send + Sync + 'static>, 41 output: D,
43 output_fingerprint: OutputFingerprint, 42 output_fingerprint: OutputFingerprint,
44 deps: Vec<(QueryId, OutputFingerprint)>, 43 deps: Vec<(QueryId, OutputFingerprint)>,
45} 44}
46 45
47impl<T> DbState<T> { 46impl<T, D> DbState<T, D> {
48 fn record( 47 fn record(
49 &self, 48 &self,
50 query_id: QueryId, 49 query_id: QueryId,
51 params: Arc<Any + Send + Sync + 'static>, 50 params: D,
52 output: Arc<Any + Send + Sync + 'static>, 51 output: D,
53 output_fingerprint: OutputFingerprint, 52 output_fingerprint: OutputFingerprint,
54 deps: Vec<(QueryId, OutputFingerprint)>, 53 deps: Vec<(QueryId, OutputFingerprint)>,
55 ) { 54 ) {
@@ -64,7 +63,7 @@ impl<T> DbState<T> {
64 } 63 }
65} 64}
66 65
67impl<T> QueryConfig<T> { 66impl<T, D> QueryConfig<T, D> {
68 pub fn new() -> Self { 67 pub fn new() -> Self {
69 QueryConfig { 68 QueryConfig {
70 ground_fn: HashMap::new(), 69 ground_fn: HashMap::new(),
@@ -74,7 +73,7 @@ impl<T> QueryConfig<T> {
74 pub fn with_ground_query( 73 pub fn with_ground_query(
75 mut self, 74 mut self,
76 query_type: QueryTypeId, 75 query_type: QueryTypeId,
77 query_fn: GroundQueryFn<T> 76 query_fn: GroundQueryFn<T, D>
78 ) -> Self { 77 ) -> Self {
79 let prev = self.ground_fn.insert(query_type, query_fn); 78 let prev = self.ground_fn.insert(query_type, query_fn);
80 assert!(prev.is_none()); 79 assert!(prev.is_none());
@@ -83,7 +82,7 @@ impl<T> QueryConfig<T> {
83 pub fn with_query( 82 pub fn with_query(
84 mut self, 83 mut self,
85 query_type: QueryTypeId, 84 query_type: QueryTypeId,
86 query_fn: QueryFn<T>, 85 query_fn: QueryFn<T, D>,
87 ) -> Self { 86 ) -> Self {
88 let prev = self.query_fn.insert(query_type, query_fn); 87 let prev = self.query_fn.insert(query_type, query_fn);
89 assert!(prev.is_none()); 88 assert!(prev.is_none());
@@ -91,15 +90,18 @@ impl<T> QueryConfig<T> {
91 } 90 }
92} 91}
93 92
94pub struct QueryCtx<T> { 93pub struct QueryCtx<T, D> {
95 db: Arc<DbState<T>>, 94 db: Arc<DbState<T, D>>,
96 query_config: Arc<QueryConfig<T>>, 95 query_config: Arc<QueryConfig<T, D>>,
97 stack: RefCell<Vec<Vec<(QueryId, OutputFingerprint)>>>, 96 stack: RefCell<Vec<Vec<(QueryId, OutputFingerprint)>>>,
98 executed: RefCell<Vec<QueryTypeId>>, 97 executed: RefCell<Vec<QueryTypeId>>,
99} 98}
100 99
101impl<T> QueryCtx<T> { 100impl<T, D> QueryCtx<T, D>
102 fn new(db: &Db<T>) -> QueryCtx<T> { 101where
102 D: Clone
103{
104 fn new(db: &Db<T, D>) -> QueryCtx<T, D> {
103 QueryCtx { 105 QueryCtx {
104 db: Arc::clone(&db.db), 106 db: Arc::clone(&db.db),
105 query_config: Arc::clone(&db.query_config), 107 query_config: Arc::clone(&db.query_config),
@@ -110,8 +112,8 @@ impl<T> QueryCtx<T> {
110 pub fn get( 112 pub fn get(
111 &self, 113 &self,
112 query_id: QueryId, 114 query_id: QueryId,
113 params: Arc<Any + Send + Sync + 'static>, 115 params: D,
114 ) -> Arc<Any + Send + Sync + 'static> { 116 ) -> D {
115 let (res, output_fingerprint) = self.get_inner(query_id, params); 117 let (res, output_fingerprint) = self.get_inner(query_id, params);
116 self.record_dep(query_id, output_fingerprint); 118 self.record_dep(query_id, output_fingerprint);
117 res 119 res
@@ -120,8 +122,8 @@ impl<T> QueryCtx<T> {
120 pub fn get_inner( 122 pub fn get_inner(
121 &self, 123 &self,
122 query_id: QueryId, 124 query_id: QueryId,
123 params: Arc<Any + Send + Sync + 'static>, 125 params: D,
124 ) -> (Arc<Any + Send + Sync + 'static>, OutputFingerprint) { 126 ) -> (D, OutputFingerprint) {
125 let (gen, record) = { 127 let (gen, record) = {
126 let guard = self.db.graph.lock(); 128 let guard = self.db.graph.lock();
127 match guard.get(&query_id).map(|it| it.clone()){ 129 match guard.get(&query_id).map(|it| it.clone()){
@@ -139,7 +141,7 @@ impl<T> QueryCtx<T> {
139 return self.force(query_id, params); 141 return self.force(query_id, params);
140 } 142 }
141 for (dep_query_id, prev_fingerprint) in record.deps.iter().cloned() { 143 for (dep_query_id, prev_fingerprint) in record.deps.iter().cloned() {
142 let dep_params: Arc<Any + Send + Sync + 'static> = { 144 let dep_params: D = {
143 let guard = self.db.graph.lock(); 145 let guard = self.db.graph.lock();
144 guard[&dep_query_id] 146 guard[&dep_query_id]
145 .1 147 .1
@@ -160,29 +162,29 @@ impl<T> QueryCtx<T> {
160 fn force( 162 fn force(
161 &self, 163 &self,
162 query_id: QueryId, 164 query_id: QueryId,
163 params: Arc<Any + Send + Sync + 'static>, 165 params: D,
164 ) -> (Arc<Any + Send + Sync + 'static>, OutputFingerprint) { 166 ) -> (D, OutputFingerprint) {
165 self.executed.borrow_mut().push(query_id.0); 167 self.executed.borrow_mut().push(query_id.0);
166 self.stack.borrow_mut().push(Vec::new()); 168 self.stack.borrow_mut().push(Vec::new());
167 169
168 let (res, output_fingerprint) = if let Some(f) = self.ground_query_fn_by_type(query_id.0) { 170 let (res, output_fingerprint) = if let Some(f) = self.ground_query_fn_by_type(query_id.0) {
169 f(&self.db.ground_data, &*params) 171 f(&self.db.ground_data, &params)
170 } else if let Some(f) = self.query_fn_by_type(query_id.0) { 172 } else if let Some(f) = self.query_fn_by_type(query_id.0) {
171 f(self, &*params) 173 f(self, &params)
172 } else { 174 } else {
173 panic!("unknown query type: {:?}", query_id.0); 175 panic!("unknown query type: {:?}", query_id.0);
174 }; 176 };
175 177
176 let res: Arc<Any + Send + Sync + 'static> = res.into(); 178 let res: D = res.into();
177 179
178 let deps = self.stack.borrow_mut().pop().unwrap(); 180 let deps = self.stack.borrow_mut().pop().unwrap();
179 self.db.record(query_id, params, res.clone(), output_fingerprint, deps); 181 self.db.record(query_id, params, res.clone(), output_fingerprint, deps);
180 (res, output_fingerprint) 182 (res, output_fingerprint)
181 } 183 }
182 fn ground_query_fn_by_type(&self, query_type: QueryTypeId) -> Option<GroundQueryFn<T>> { 184 fn ground_query_fn_by_type(&self, query_type: QueryTypeId) -> Option<GroundQueryFn<T, D>> {
183 self.query_config.ground_fn.get(&query_type).map(|&it| it) 185 self.query_config.ground_fn.get(&query_type).map(|&it| it)
184 } 186 }
185 fn query_fn_by_type(&self, query_type: QueryTypeId) -> Option<QueryFn<T>> { 187 fn query_fn_by_type(&self, query_type: QueryTypeId) -> Option<QueryFn<T, D>> {
186 self.query_config.query_fn.get(&query_type).map(|&it| it) 188 self.query_config.query_fn.get(&query_type).map(|&it| it)
187 } 189 }
188 fn record_dep( 190 fn record_dep(
@@ -196,15 +198,18 @@ impl<T> QueryCtx<T> {
196 } 198 }
197} 199}
198 200
199impl<T> Db<T> { 201impl<T, D> Db<T, D>
200 pub fn new(query_config: QueryConfig<T>, ground_data: T) -> Db<T> { 202where
203 D: Clone
204{
205 pub fn new(query_config: QueryConfig<T, D>, ground_data: T) -> Db<T, D> {
201 Db { 206 Db {
202 db: Arc::new(DbState { ground_data, gen: Gen(0), graph: Default::default() }), 207 db: Arc::new(DbState { ground_data, gen: Gen(0), graph: Default::default() }),
203 query_config: Arc::new(query_config), 208 query_config: Arc::new(query_config),
204 } 209 }
205 } 210 }
206 211
207 pub fn with_ground_data(&self, ground_data: T) -> Db<T> { 212 pub fn with_ground_data(&self, ground_data: T) -> Db<T, D> {
208 let gen = Gen(self.db.gen.0 + 1); 213 let gen = Gen(self.db.gen.0 + 1);
209 let graph = self.db.graph.lock().clone(); 214 let graph = self.db.graph.lock().clone();
210 let graph = Mutex::new(graph); 215 let graph = Mutex::new(graph);
@@ -216,8 +221,8 @@ impl<T> Db<T> {
216 pub fn get( 221 pub fn get(
217 &self, 222 &self,
218 query_id: QueryId, 223 query_id: QueryId,
219 params: Box<Any + Send + Sync + 'static>, 224 params: D,
220 ) -> (Arc<Any + Send + Sync + 'static>, Vec<QueryTypeId>) { 225 ) -> (D, Vec<QueryTypeId>) {
221 let ctx = QueryCtx::new(self); 226 let ctx = QueryCtx::new(self);
222 let res = ctx.get(query_id, params.into()); 227 let res = ctx.get(query_id, params.into());
223 let executed = ::std::mem::replace(&mut *ctx.executed.borrow_mut(), Vec::new()); 228 let executed = ::std::mem::replace(&mut *ctx.executed.borrow_mut(), Vec::new());
diff --git a/crates/salsa/tests/integration.rs b/crates/salsa/tests/integration.rs
index 7241eca38..2872d3913 100644
--- a/crates/salsa/tests/integration.rs
+++ b/crates/salsa/tests/integration.rs
@@ -7,6 +7,7 @@ use std::{
7}; 7};
8 8
9type State = HashMap<u32, String>; 9type State = HashMap<u32, String>;
10type Data = Arc<Any + Send + Sync + 'static>;
10const GET_TEXT: salsa::QueryTypeId = salsa::QueryTypeId(1); 11const GET_TEXT: salsa::QueryTypeId = salsa::QueryTypeId(1);
11const GET_FILES: salsa::QueryTypeId = salsa::QueryTypeId(2); 12const GET_FILES: salsa::QueryTypeId = salsa::QueryTypeId(2);
12const FILE_NEWLINES: salsa::QueryTypeId = salsa::QueryTypeId(3); 13const FILE_NEWLINES: salsa::QueryTypeId = salsa::QueryTypeId(3);
@@ -14,9 +15,9 @@ const TOTAL_NEWLINES: salsa::QueryTypeId = salsa::QueryTypeId(4);
14 15
15fn mk_ground_query<T, R>( 16fn mk_ground_query<T, R>(
16 state: &State, 17 state: &State,
17 params: &(Any + Send + Sync + 'static), 18 params: &Data,
18 f: fn(&State, &T) -> R, 19 f: fn(&State, &T) -> R,
19) -> (Box<Any + Send + Sync + 'static>, salsa::OutputFingerprint) 20) -> (Data, salsa::OutputFingerprint)
20where 21where
21 T: 'static, 22 T: 'static,
22 R: Hash + Send + Sync + 'static, 23 R: Hash + Send + Sync + 'static,
@@ -24,21 +25,21 @@ where
24 let params = params.downcast_ref().unwrap(); 25 let params = params.downcast_ref().unwrap();
25 let result = f(state, params); 26 let result = f(state, params);
26 let fingerprint = o_print(&result); 27 let fingerprint = o_print(&result);
27 (Box::new(result), fingerprint) 28 (Arc::new(result), fingerprint)
28} 29}
29 30
30fn get<T, R>(db: &salsa::Db<State>, query_type: salsa::QueryTypeId, param: T) -> (Arc<R>, Vec<salsa::QueryTypeId>) 31fn get<T, R>(db: &salsa::Db<State, Data>, query_type: salsa::QueryTypeId, param: T) -> (Arc<R>, Vec<salsa::QueryTypeId>)
31where 32where
32 T: Hash + Send + Sync + 'static, 33 T: Hash + Send + Sync + 'static,
33 R: Send + Sync + 'static, 34 R: Send + Sync + 'static,
34{ 35{
35 let i_print = i_print(&param); 36 let i_print = i_print(&param);
36 let param = Box::new(param); 37 let param = Arc::new(param);
37 let (res, trace) = db.get(salsa::QueryId(query_type, i_print), param); 38 let (res, trace) = db.get(salsa::QueryId(query_type, i_print), param);
38 (res.downcast().unwrap(), trace) 39 (res.downcast().unwrap(), trace)
39} 40}
40 41
41struct QueryCtx<'a>(&'a salsa::QueryCtx<State>); 42struct QueryCtx<'a>(&'a salsa::QueryCtx<State, Data>);
42 43
43impl<'a> QueryCtx<'a> { 44impl<'a> QueryCtx<'a> {
44 fn get_text(&self, id: u32) -> Arc<String> { 45 fn get_text(&self, id: u32) -> Arc<String> {
@@ -60,10 +61,10 @@ impl<'a> QueryCtx<'a> {
60} 61}
61 62
62fn mk_query<T, R>( 63fn mk_query<T, R>(
63 query_ctx: &salsa::QueryCtx<State>, 64 query_ctx: &salsa::QueryCtx<State, Data>,
64 params: &(Any + Send + Sync + 'static), 65 params: &Data,
65 f: fn(QueryCtx, &T) -> R, 66 f: fn(QueryCtx, &T) -> R,
66) -> (Box<Any + Send + Sync + 'static>, salsa::OutputFingerprint) 67) -> (Data, salsa::OutputFingerprint)
67where 68where
68 T: 'static, 69 T: 'static,
69 R: Hash + Send + Sync + 'static, 70 R: Hash + Send + Sync + 'static,
@@ -72,11 +73,11 @@ where
72 let query_ctx = QueryCtx(query_ctx); 73 let query_ctx = QueryCtx(query_ctx);
73 let result = f(query_ctx, params); 74 let result = f(query_ctx, params);
74 let fingerprint = o_print(&result); 75 let fingerprint = o_print(&result);
75 (Box::new(result), fingerprint) 76 (Arc::new(result), fingerprint)
76} 77}
77 78
78fn mk_queries() -> salsa::QueryConfig<State> { 79fn mk_queries() -> salsa::QueryConfig<State, Data> {
79 salsa::QueryConfig::<State>::new() 80 salsa::QueryConfig::<State, Data>::new()
80 .with_ground_query(GET_TEXT, |state, id| { 81 .with_ground_query(GET_TEXT, |state, id| {
81 mk_ground_query::<u32, String>(state, id, |state, id| state[id].clone()) 82 mk_ground_query::<u32, String>(state, id, |state, id| state[id].clone())
82 }) 83 })