diff options
Diffstat (limited to 'crates')
-rw-r--r-- | crates/salsa/Cargo.toml | 8 | ||||
-rw-r--r-- | crates/salsa/src/lib.rs | 238 | ||||
-rw-r--r-- | crates/salsa/tests/integration.rs | 153 |
3 files changed, 399 insertions, 0 deletions
diff --git a/crates/salsa/Cargo.toml b/crates/salsa/Cargo.toml new file mode 100644 index 000000000..9eb83234f --- /dev/null +++ b/crates/salsa/Cargo.toml | |||
@@ -0,0 +1,8 @@ | |||
1 | [package] | ||
2 | name = "salsa" | ||
3 | version = "0.1.0" | ||
4 | authors = ["Aleksey Kladov <[email protected]>"] | ||
5 | |||
6 | [dependencies] | ||
7 | parking_lot = "0.6.3" | ||
8 | im = "12.0.0" | ||
diff --git a/crates/salsa/src/lib.rs b/crates/salsa/src/lib.rs new file mode 100644 index 000000000..69c7b35fa --- /dev/null +++ b/crates/salsa/src/lib.rs | |||
@@ -0,0 +1,238 @@ | |||
1 | extern crate im; | ||
2 | extern crate parking_lot; | ||
3 | |||
4 | use std::{ | ||
5 | sync::Arc, | ||
6 | any::Any, | ||
7 | collections::HashMap, | ||
8 | cell::RefCell, | ||
9 | }; | ||
10 | use parking_lot::Mutex; | ||
11 | |||
12 | type GroundQueryFn<T> = fn(&T, &(Any + Send + Sync + 'static)) -> (Box<Any + Send + Sync + 'static>, OutputFingerprint); | ||
13 | type QueryFn<T> = fn(&QueryCtx<T>, &(Any + Send + Sync + 'static)) -> (Box<Any + Send + Sync + 'static>, OutputFingerprint); | ||
14 | |||
15 | #[derive(Debug)] | ||
16 | pub struct Db<T> { | ||
17 | db: Arc<DbState<T>>, | ||
18 | query_config: Arc<QueryConfig<T>>, | ||
19 | } | ||
20 | |||
21 | pub struct QueryConfig<T> { | ||
22 | ground_fn: HashMap<QueryTypeId, GroundQueryFn<T>>, | ||
23 | query_fn: HashMap<QueryTypeId, QueryFn<T>>, | ||
24 | } | ||
25 | |||
26 | impl<T> ::std::fmt::Debug for QueryConfig<T> { | ||
27 | fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { | ||
28 | ::std::fmt::Display::fmt("QueryConfig { ... }", f) | ||
29 | } | ||
30 | } | ||
31 | |||
32 | #[derive(Debug)] | ||
33 | struct DbState<T> { | ||
34 | ground_data: T, | ||
35 | gen: Gen, | ||
36 | graph: Mutex<im::HashMap<QueryId, (Gen, Arc<QueryRecord>)>>, | ||
37 | } | ||
38 | |||
39 | #[derive(Debug)] | ||
40 | struct QueryRecord { | ||
41 | params: Arc<Any + Send + Sync + 'static>, | ||
42 | output: Arc<Any + Send + Sync + 'static>, | ||
43 | output_fingerprint: OutputFingerprint, | ||
44 | deps: Vec<(QueryId, OutputFingerprint)>, | ||
45 | } | ||
46 | |||
47 | impl<T> DbState<T> { | ||
48 | fn record( | ||
49 | &self, | ||
50 | query_id: QueryId, | ||
51 | params: Arc<Any + Send + Sync + 'static>, | ||
52 | output: Arc<Any + Send + Sync + 'static>, | ||
53 | output_fingerprint: OutputFingerprint, | ||
54 | deps: Vec<(QueryId, OutputFingerprint)>, | ||
55 | ) { | ||
56 | let gen = self.gen; | ||
57 | let record = QueryRecord { | ||
58 | params, | ||
59 | output, | ||
60 | output_fingerprint, | ||
61 | deps, | ||
62 | }; | ||
63 | self.graph.lock().insert(query_id, (gen, Arc::new(record))); | ||
64 | } | ||
65 | } | ||
66 | |||
67 | impl<T> QueryConfig<T> { | ||
68 | pub fn new() -> Self { | ||
69 | QueryConfig { | ||
70 | ground_fn: HashMap::new(), | ||
71 | query_fn: HashMap::new(), | ||
72 | } | ||
73 | } | ||
74 | pub fn with_ground_query( | ||
75 | mut self, | ||
76 | query_type: QueryTypeId, | ||
77 | query_fn: GroundQueryFn<T> | ||
78 | ) -> Self { | ||
79 | let prev = self.ground_fn.insert(query_type, query_fn); | ||
80 | assert!(prev.is_none()); | ||
81 | self | ||
82 | } | ||
83 | pub fn with_query( | ||
84 | mut self, | ||
85 | query_type: QueryTypeId, | ||
86 | query_fn: QueryFn<T>, | ||
87 | ) -> Self { | ||
88 | let prev = self.query_fn.insert(query_type, query_fn); | ||
89 | assert!(prev.is_none()); | ||
90 | self | ||
91 | } | ||
92 | } | ||
93 | |||
94 | pub struct QueryCtx<T> { | ||
95 | db: Arc<DbState<T>>, | ||
96 | query_config: Arc<QueryConfig<T>>, | ||
97 | stack: RefCell<Vec<Vec<(QueryId, OutputFingerprint)>>>, | ||
98 | executed: RefCell<Vec<QueryTypeId>>, | ||
99 | } | ||
100 | |||
101 | impl<T> QueryCtx<T> { | ||
102 | fn new(db: &Db<T>) -> QueryCtx<T> { | ||
103 | QueryCtx { | ||
104 | db: Arc::clone(&db.db), | ||
105 | query_config: Arc::clone(&db.query_config), | ||
106 | stack: RefCell::new(vec![Vec::new()]), | ||
107 | executed: RefCell::new(Vec::new()), | ||
108 | } | ||
109 | } | ||
110 | pub fn get( | ||
111 | &self, | ||
112 | query_id: QueryId, | ||
113 | params: Arc<Any + Send + Sync + 'static>, | ||
114 | ) -> Arc<Any + Send + Sync + 'static> { | ||
115 | let (res, output_fingerprint) = self.get_inner(query_id, params); | ||
116 | self.record_dep(query_id, output_fingerprint); | ||
117 | res | ||
118 | } | ||
119 | |||
120 | pub fn get_inner( | ||
121 | &self, | ||
122 | query_id: QueryId, | ||
123 | params: Arc<Any + Send + Sync + 'static>, | ||
124 | ) -> (Arc<Any + Send + Sync + 'static>, OutputFingerprint) { | ||
125 | let (gen, record) = { | ||
126 | let guard = self.db.graph.lock(); | ||
127 | match guard.get(&query_id).map(|it| it.clone()){ | ||
128 | None => { | ||
129 | drop(guard); | ||
130 | return self.force(query_id, params); | ||
131 | }, | ||
132 | Some(it) => it, | ||
133 | } | ||
134 | }; | ||
135 | if gen == self.db.gen { | ||
136 | return (record.output.clone(), record.output_fingerprint) | ||
137 | } | ||
138 | if self.query_config.ground_fn.contains_key(&query_id.0) { | ||
139 | return self.force(query_id, params); | ||
140 | } | ||
141 | for (dep_query_id, prev_fingerprint) in record.deps.iter().cloned() { | ||
142 | let dep_params: Arc<Any + Send + Sync + 'static> = { | ||
143 | let guard = self.db.graph.lock(); | ||
144 | guard[&dep_query_id] | ||
145 | .1 | ||
146 | .params | ||
147 | .clone() | ||
148 | }; | ||
149 | if prev_fingerprint != self.get_inner(dep_query_id, dep_params).1 { | ||
150 | return self.force(query_id, params) | ||
151 | } | ||
152 | } | ||
153 | let gen = self.db.gen; | ||
154 | { | ||
155 | let mut guard = self.db.graph.lock(); | ||
156 | guard[&query_id].0 = gen; | ||
157 | } | ||
158 | (record.output.clone(), record.output_fingerprint) | ||
159 | } | ||
160 | fn force( | ||
161 | &self, | ||
162 | query_id: QueryId, | ||
163 | params: Arc<Any + Send + Sync + 'static>, | ||
164 | ) -> (Arc<Any + Send + Sync + 'static>, OutputFingerprint) { | ||
165 | self.executed.borrow_mut().push(query_id.0); | ||
166 | self.stack.borrow_mut().push(Vec::new()); | ||
167 | |||
168 | let (res, output_fingerprint) = if let Some(f) = self.ground_query_fn_by_type(query_id.0) { | ||
169 | f(&self.db.ground_data, &*params) | ||
170 | } else if let Some(f) = self.query_fn_by_type(query_id.0) { | ||
171 | f(self, &*params) | ||
172 | } else { | ||
173 | panic!("unknown query type: {:?}", query_id.0); | ||
174 | }; | ||
175 | |||
176 | let res: Arc<Any + Send + Sync + 'static> = res.into(); | ||
177 | |||
178 | let deps = self.stack.borrow_mut().pop().unwrap(); | ||
179 | self.db.record(query_id, params, res.clone(), output_fingerprint, deps); | ||
180 | (res, output_fingerprint) | ||
181 | } | ||
182 | fn ground_query_fn_by_type(&self, query_type: QueryTypeId) -> Option<GroundQueryFn<T>> { | ||
183 | self.query_config.ground_fn.get(&query_type).map(|&it| it) | ||
184 | } | ||
185 | fn query_fn_by_type(&self, query_type: QueryTypeId) -> Option<QueryFn<T>> { | ||
186 | self.query_config.query_fn.get(&query_type).map(|&it| it) | ||
187 | } | ||
188 | fn record_dep( | ||
189 | &self, | ||
190 | query_id: QueryId, | ||
191 | output_fingerprint: OutputFingerprint, | ||
192 | ) -> () { | ||
193 | let mut stack = self.stack.borrow_mut(); | ||
194 | let deps = stack.last_mut().unwrap(); | ||
195 | deps.push((query_id, output_fingerprint)) | ||
196 | } | ||
197 | } | ||
198 | |||
199 | impl<T> Db<T> { | ||
200 | pub fn new(query_config: QueryConfig<T>, ground_data: T) -> Db<T> { | ||
201 | Db { | ||
202 | db: Arc::new(DbState { ground_data, gen: Gen(0), graph: Default::default() }), | ||
203 | query_config: Arc::new(query_config), | ||
204 | } | ||
205 | } | ||
206 | |||
207 | pub fn with_ground_data(&self, ground_data: T) -> Db<T> { | ||
208 | let gen = Gen(self.db.gen.0 + 1); | ||
209 | let graph = self.db.graph.lock().clone(); | ||
210 | let graph = Mutex::new(graph); | ||
211 | Db { | ||
212 | db: Arc::new(DbState { ground_data, gen, graph }), | ||
213 | query_config: Arc::clone(&self.query_config) | ||
214 | } | ||
215 | } | ||
216 | pub fn get( | ||
217 | &self, | ||
218 | query_id: QueryId, | ||
219 | params: Box<Any + Send + Sync + 'static>, | ||
220 | ) -> (Arc<Any + Send + Sync + 'static>, Vec<QueryTypeId>) { | ||
221 | let ctx = QueryCtx::new(self); | ||
222 | let res = ctx.get(query_id, params.into()); | ||
223 | let executed = ::std::mem::replace(&mut *ctx.executed.borrow_mut(), Vec::new()); | ||
224 | (res, executed) | ||
225 | } | ||
226 | } | ||
227 | |||
228 | #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] | ||
229 | struct Gen(u64); | ||
230 | #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] | ||
231 | pub struct InputFingerprint(pub u64); | ||
232 | #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] | ||
233 | pub struct OutputFingerprint(pub u64); | ||
234 | #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] | ||
235 | pub struct QueryTypeId(pub u16); | ||
236 | #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] | ||
237 | pub struct QueryId(pub QueryTypeId, pub InputFingerprint); | ||
238 | |||
diff --git a/crates/salsa/tests/integration.rs b/crates/salsa/tests/integration.rs new file mode 100644 index 000000000..7241eca38 --- /dev/null +++ b/crates/salsa/tests/integration.rs | |||
@@ -0,0 +1,153 @@ | |||
1 | extern crate salsa; | ||
2 | use std::{ | ||
3 | sync::Arc, | ||
4 | collections::hash_map::{HashMap, DefaultHasher}, | ||
5 | any::Any, | ||
6 | hash::{Hash, Hasher}, | ||
7 | }; | ||
8 | |||
9 | type State = HashMap<u32, String>; | ||
10 | const GET_TEXT: salsa::QueryTypeId = salsa::QueryTypeId(1); | ||
11 | const GET_FILES: salsa::QueryTypeId = salsa::QueryTypeId(2); | ||
12 | const FILE_NEWLINES: salsa::QueryTypeId = salsa::QueryTypeId(3); | ||
13 | const TOTAL_NEWLINES: salsa::QueryTypeId = salsa::QueryTypeId(4); | ||
14 | |||
15 | fn mk_ground_query<T, R>( | ||
16 | state: &State, | ||
17 | params: &(Any + Send + Sync + 'static), | ||
18 | f: fn(&State, &T) -> R, | ||
19 | ) -> (Box<Any + Send + Sync + 'static>, salsa::OutputFingerprint) | ||
20 | where | ||
21 | T: 'static, | ||
22 | R: Hash + Send + Sync + 'static, | ||
23 | { | ||
24 | let params = params.downcast_ref().unwrap(); | ||
25 | let result = f(state, params); | ||
26 | let fingerprint = o_print(&result); | ||
27 | (Box::new(result), fingerprint) | ||
28 | } | ||
29 | |||
30 | fn get<T, R>(db: &salsa::Db<State>, query_type: salsa::QueryTypeId, param: T) -> (Arc<R>, Vec<salsa::QueryTypeId>) | ||
31 | where | ||
32 | T: Hash + Send + Sync + 'static, | ||
33 | R: Send + Sync + 'static, | ||
34 | { | ||
35 | let i_print = i_print(¶m); | ||
36 | let param = Box::new(param); | ||
37 | let (res, trace) = db.get(salsa::QueryId(query_type, i_print), param); | ||
38 | (res.downcast().unwrap(), trace) | ||
39 | } | ||
40 | |||
41 | struct QueryCtx<'a>(&'a salsa::QueryCtx<State>); | ||
42 | |||
43 | impl<'a> QueryCtx<'a> { | ||
44 | fn get_text(&self, id: u32) -> Arc<String> { | ||
45 | let i_print = i_print(&id); | ||
46 | let text = self.0.get(salsa::QueryId(GET_TEXT, i_print), Arc::new(id)); | ||
47 | text.downcast().unwrap() | ||
48 | } | ||
49 | fn get_files(&self) -> Arc<Vec<u32>> { | ||
50 | let i_print = i_print(&()); | ||
51 | let files = self.0.get(salsa::QueryId(GET_FILES, i_print), Arc::new(())); | ||
52 | let res = files.downcast().unwrap(); | ||
53 | res | ||
54 | } | ||
55 | fn get_n_lines(&self, id: u32) -> usize { | ||
56 | let i_print = i_print(&id); | ||
57 | let n_lines = self.0.get(salsa::QueryId(FILE_NEWLINES, i_print), Arc::new(id)); | ||
58 | *n_lines.downcast().unwrap() | ||
59 | } | ||
60 | } | ||
61 | |||
62 | fn mk_query<T, R>( | ||
63 | query_ctx: &salsa::QueryCtx<State>, | ||
64 | params: &(Any + Send + Sync + 'static), | ||
65 | f: fn(QueryCtx, &T) -> R, | ||
66 | ) -> (Box<Any + Send + Sync + 'static>, salsa::OutputFingerprint) | ||
67 | where | ||
68 | T: 'static, | ||
69 | R: Hash + Send + Sync + 'static, | ||
70 | { | ||
71 | let params: &T = params.downcast_ref().unwrap(); | ||
72 | let query_ctx = QueryCtx(query_ctx); | ||
73 | let result = f(query_ctx, params); | ||
74 | let fingerprint = o_print(&result); | ||
75 | (Box::new(result), fingerprint) | ||
76 | } | ||
77 | |||
78 | fn mk_queries() -> salsa::QueryConfig<State> { | ||
79 | salsa::QueryConfig::<State>::new() | ||
80 | .with_ground_query(GET_TEXT, |state, id| { | ||
81 | mk_ground_query::<u32, String>(state, id, |state, id| state[id].clone()) | ||
82 | }) | ||
83 | .with_ground_query(GET_FILES, |state, id| { | ||
84 | mk_ground_query::<(), Vec<u32>>(state, id, |state, &()| state.keys().cloned().collect()) | ||
85 | }) | ||
86 | .with_query(FILE_NEWLINES, |query_ctx, id| { | ||
87 | mk_query(query_ctx, id, |query_ctx, &id| { | ||
88 | let text = query_ctx.get_text(id); | ||
89 | text.lines().count() | ||
90 | }) | ||
91 | }) | ||
92 | .with_query(TOTAL_NEWLINES, |query_ctx, id| { | ||
93 | mk_query(query_ctx, id, |query_ctx, &()| { | ||
94 | let mut total = 0; | ||
95 | for &id in query_ctx.get_files().iter() { | ||
96 | total += query_ctx.get_n_lines(id) | ||
97 | } | ||
98 | total | ||
99 | }) | ||
100 | }) | ||
101 | } | ||
102 | |||
103 | #[test] | ||
104 | fn test_number_of_lines() { | ||
105 | let mut state = State::new(); | ||
106 | let db = salsa::Db::new(mk_queries(), state.clone()); | ||
107 | let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ()); | ||
108 | assert_eq!(*newlines, 0); | ||
109 | assert_eq!(trace.len(), 2); | ||
110 | let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ()); | ||
111 | assert_eq!(*newlines, 0); | ||
112 | assert_eq!(trace.len(), 0); | ||
113 | |||
114 | state.insert(1, "hello\nworld".to_string()); | ||
115 | let db = db.with_ground_data(state.clone()); | ||
116 | let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ()); | ||
117 | assert_eq!(*newlines, 2); | ||
118 | assert_eq!(trace.len(), 4); | ||
119 | |||
120 | state.insert(2, "spam\neggs".to_string()); | ||
121 | let db = db.with_ground_data(state.clone()); | ||
122 | let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ()); | ||
123 | assert_eq!(*newlines, 4); | ||
124 | assert_eq!(trace.len(), 5); | ||
125 | |||
126 | for i in 0..10 { | ||
127 | state.insert(i + 10, "spam".to_string()); | ||
128 | } | ||
129 | let db = db.with_ground_data(state.clone()); | ||
130 | let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ()); | ||
131 | assert_eq!(*newlines, 14); | ||
132 | assert_eq!(trace.len(), 24); | ||
133 | |||
134 | state.insert(15, String::new()); | ||
135 | let db = db.with_ground_data(state.clone()); | ||
136 | let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ()); | ||
137 | assert_eq!(*newlines, 13); | ||
138 | assert_eq!(trace.len(), 15); | ||
139 | } | ||
140 | |||
141 | fn o_print<T: Hash>(x: &T) -> salsa::OutputFingerprint { | ||
142 | let mut hasher = DefaultHasher::new(); | ||
143 | x.hash(&mut hasher); | ||
144 | let hash = hasher.finish(); | ||
145 | salsa::OutputFingerprint(hash) | ||
146 | } | ||
147 | |||
148 | fn i_print<T: Hash>(x: &T) -> salsa::InputFingerprint { | ||
149 | let mut hasher = DefaultHasher::new(); | ||
150 | x.hash(&mut hasher); | ||
151 | let hash = hasher.finish(); | ||
152 | salsa::InputFingerprint(hash) | ||
153 | } | ||