diff options
Diffstat (limited to 'crates/salsa')
-rw-r--r-- | crates/salsa/Cargo.toml | 8 | ||||
-rw-r--r-- | crates/salsa/src/lib.rs | 293 | ||||
-rw-r--r-- | crates/salsa/tests/integration.rs | 170 |
3 files changed, 471 insertions, 0 deletions
diff --git a/crates/salsa/Cargo.toml b/crates/salsa/Cargo.toml new file mode 100644 index 000000000..9eb83234f --- /dev/null +++ b/crates/salsa/Cargo.toml | |||
@@ -0,0 +1,8 @@ | |||
1 | [package] | ||
2 | name = "salsa" | ||
3 | version = "0.1.0" | ||
4 | authors = ["Aleksey Kladov <[email protected]>"] | ||
5 | |||
6 | [dependencies] | ||
7 | parking_lot = "0.6.3" | ||
8 | im = "12.0.0" | ||
diff --git a/crates/salsa/src/lib.rs b/crates/salsa/src/lib.rs new file mode 100644 index 000000000..35deed374 --- /dev/null +++ b/crates/salsa/src/lib.rs | |||
@@ -0,0 +1,293 @@ | |||
1 | extern crate im; | ||
2 | extern crate parking_lot; | ||
3 | |||
4 | use std::{ | ||
5 | sync::Arc, | ||
6 | collections::{HashSet, HashMap}, | ||
7 | cell::RefCell, | ||
8 | }; | ||
9 | use parking_lot::Mutex; | ||
10 | |||
11 | pub type GroundQueryFn<T, D> = Box<Fn(&T, &D) -> (D, OutputFingerprint) + Send + Sync + 'static>; | ||
12 | pub type QueryFn<T, D> = Box<Fn(&QueryCtx<T, D>, &D) -> (D, OutputFingerprint) + Send + Sync + 'static>; | ||
13 | |||
14 | #[derive(Debug)] | ||
15 | pub struct Db<T, D> { | ||
16 | db: Arc<DbState<T, D>>, | ||
17 | query_config: Arc<QueryConfig<T, D>>, | ||
18 | } | ||
19 | |||
20 | pub struct QueryConfig<T, D> { | ||
21 | ground_fn: HashMap<QueryTypeId, GroundQueryFn<T, D>>, | ||
22 | query_fn: HashMap<QueryTypeId, QueryFn<T, D>>, | ||
23 | } | ||
24 | |||
25 | impl<T, D> ::std::fmt::Debug for QueryConfig<T, D> { | ||
26 | fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { | ||
27 | ::std::fmt::Display::fmt("QueryConfig { ... }", f) | ||
28 | } | ||
29 | } | ||
30 | |||
31 | #[derive(Debug)] | ||
32 | struct DbState<T, D> { | ||
33 | ground_data: T, | ||
34 | gen: Gen, | ||
35 | graph: Mutex<im::HashMap<QueryId, (Gen, Arc<QueryRecord<D>>)>>, | ||
36 | } | ||
37 | |||
38 | #[derive(Debug)] | ||
39 | struct QueryRecord<D> { | ||
40 | params: D, | ||
41 | output: D, | ||
42 | output_fingerprint: OutputFingerprint, | ||
43 | deps: Vec<(QueryId, OutputFingerprint)>, | ||
44 | } | ||
45 | |||
46 | impl<T, D> DbState<T, D> { | ||
47 | fn record( | ||
48 | &self, | ||
49 | query_id: QueryId, | ||
50 | params: D, | ||
51 | output: D, | ||
52 | output_fingerprint: OutputFingerprint, | ||
53 | deps: Vec<(QueryId, OutputFingerprint)>, | ||
54 | ) { | ||
55 | let gen = self.gen; | ||
56 | let record = QueryRecord { | ||
57 | params, | ||
58 | output, | ||
59 | output_fingerprint, | ||
60 | deps, | ||
61 | }; | ||
62 | self.graph.lock().insert(query_id, (gen, Arc::new(record))); | ||
63 | } | ||
64 | } | ||
65 | |||
66 | impl<T, D> QueryConfig<T, D> { | ||
67 | pub fn new() -> Self { | ||
68 | QueryConfig { | ||
69 | ground_fn: HashMap::new(), | ||
70 | query_fn: HashMap::new(), | ||
71 | } | ||
72 | } | ||
73 | pub fn with_ground_query( | ||
74 | mut self, | ||
75 | query_type: QueryTypeId, | ||
76 | query_fn: GroundQueryFn<T, D> | ||
77 | ) -> Self { | ||
78 | let prev = self.ground_fn.insert(query_type, query_fn); | ||
79 | assert!(prev.is_none()); | ||
80 | self | ||
81 | } | ||
82 | pub fn with_query( | ||
83 | mut self, | ||
84 | query_type: QueryTypeId, | ||
85 | query_fn: QueryFn<T, D>, | ||
86 | ) -> Self { | ||
87 | let prev = self.query_fn.insert(query_type, query_fn); | ||
88 | assert!(prev.is_none()); | ||
89 | self | ||
90 | } | ||
91 | } | ||
92 | |||
93 | pub struct QueryCtx<T, D> { | ||
94 | db: Arc<DbState<T, D>>, | ||
95 | query_config: Arc<QueryConfig<T, D>>, | ||
96 | stack: RefCell<Vec<Vec<(QueryId, OutputFingerprint)>>>, | ||
97 | executed: RefCell<Vec<QueryTypeId>>, | ||
98 | } | ||
99 | |||
100 | impl<T, D> QueryCtx<T, D> | ||
101 | where | ||
102 | D: Clone | ||
103 | { | ||
104 | fn new(db: &Db<T, D>) -> QueryCtx<T, D> { | ||
105 | QueryCtx { | ||
106 | db: Arc::clone(&db.db), | ||
107 | query_config: Arc::clone(&db.query_config), | ||
108 | stack: RefCell::new(vec![Vec::new()]), | ||
109 | executed: RefCell::new(Vec::new()), | ||
110 | } | ||
111 | } | ||
112 | pub fn get( | ||
113 | &self, | ||
114 | query_id: QueryId, | ||
115 | params: D, | ||
116 | ) -> D { | ||
117 | let (res, output_fingerprint) = self.get_inner(query_id, params); | ||
118 | self.record_dep(query_id, output_fingerprint); | ||
119 | res | ||
120 | } | ||
121 | pub fn trace(&self) -> Vec<QueryTypeId> { | ||
122 | ::std::mem::replace(&mut *self.executed.borrow_mut(), Vec::new()) | ||
123 | } | ||
124 | |||
125 | fn get_inner( | ||
126 | &self, | ||
127 | query_id: QueryId, | ||
128 | params: D, | ||
129 | ) -> (D, OutputFingerprint) { | ||
130 | let (gen, record) = { | ||
131 | let guard = self.db.graph.lock(); | ||
132 | match guard.get(&query_id).map(|it| it.clone()){ | ||
133 | None => { | ||
134 | drop(guard); | ||
135 | return self.force(query_id, params); | ||
136 | }, | ||
137 | Some(it) => it, | ||
138 | } | ||
139 | }; | ||
140 | if gen == self.db.gen { | ||
141 | return (record.output.clone(), record.output_fingerprint) | ||
142 | } | ||
143 | if self.query_config.ground_fn.contains_key(&query_id.0) { | ||
144 | let (invalidated, record) = { | ||
145 | let guard = self.db.graph.lock(); | ||
146 | let (gen, ref record) = guard[&query_id]; | ||
147 | (gen == INVALIDATED, record.clone()) | ||
148 | }; | ||
149 | if invalidated { | ||
150 | return self.force(query_id, params); | ||
151 | } else { | ||
152 | return (record.output.clone(), record.output_fingerprint); | ||
153 | } | ||
154 | } | ||
155 | for (dep_query_id, prev_fingerprint) in record.deps.iter().cloned() { | ||
156 | let dep_params: D = { | ||
157 | let guard = self.db.graph.lock(); | ||
158 | guard[&dep_query_id] | ||
159 | .1 | ||
160 | .params | ||
161 | .clone() | ||
162 | }; | ||
163 | if prev_fingerprint != self.get_inner(dep_query_id, dep_params).1 { | ||
164 | return self.force(query_id, params) | ||
165 | } | ||
166 | } | ||
167 | let gen = self.db.gen; | ||
168 | { | ||
169 | let mut guard = self.db.graph.lock(); | ||
170 | guard[&query_id].0 = gen; | ||
171 | } | ||
172 | (record.output.clone(), record.output_fingerprint) | ||
173 | } | ||
174 | fn force( | ||
175 | &self, | ||
176 | query_id: QueryId, | ||
177 | params: D, | ||
178 | ) -> (D, OutputFingerprint) { | ||
179 | self.executed.borrow_mut().push(query_id.0); | ||
180 | self.stack.borrow_mut().push(Vec::new()); | ||
181 | |||
182 | let (res, output_fingerprint) = if let Some(f) = self.query_config.ground_fn.get(&query_id.0) { | ||
183 | f(&self.db.ground_data, ¶ms) | ||
184 | } else if let Some(f) = self.query_config.query_fn.get(&query_id.0) { | ||
185 | f(self, ¶ms) | ||
186 | } else { | ||
187 | panic!("unknown query type: {:?}", query_id.0); | ||
188 | }; | ||
189 | |||
190 | let res: D = res.into(); | ||
191 | |||
192 | let deps = self.stack.borrow_mut().pop().unwrap(); | ||
193 | self.db.record(query_id, params, res.clone(), output_fingerprint, deps); | ||
194 | (res, output_fingerprint) | ||
195 | } | ||
196 | fn record_dep( | ||
197 | &self, | ||
198 | query_id: QueryId, | ||
199 | output_fingerprint: OutputFingerprint, | ||
200 | ) -> () { | ||
201 | let mut stack = self.stack.borrow_mut(); | ||
202 | let deps = stack.last_mut().unwrap(); | ||
203 | deps.push((query_id, output_fingerprint)) | ||
204 | } | ||
205 | } | ||
206 | |||
207 | pub struct Invalidations { | ||
208 | types: HashSet<QueryTypeId>, | ||
209 | ids: Vec<QueryId>, | ||
210 | } | ||
211 | |||
212 | impl Invalidations { | ||
213 | pub fn new() -> Invalidations { | ||
214 | Invalidations { | ||
215 | types: HashSet::new(), | ||
216 | ids: Vec::new(), | ||
217 | } | ||
218 | } | ||
219 | pub fn invalidate( | ||
220 | &mut self, | ||
221 | query_type: QueryTypeId, | ||
222 | params: impl Iterator<Item=InputFingerprint>, | ||
223 | ) { | ||
224 | self.types.insert(query_type); | ||
225 | self.ids.extend(params.map(|it| QueryId(query_type, it))) | ||
226 | } | ||
227 | } | ||
228 | |||
229 | impl<T, D> Db<T, D> | ||
230 | where | ||
231 | D: Clone | ||
232 | { | ||
233 | pub fn new(query_config: QueryConfig<T, D>, ground_data: T) -> Db<T, D> { | ||
234 | Db { | ||
235 | db: Arc::new(DbState { ground_data, gen: Gen(0), graph: Default::default() }), | ||
236 | query_config: Arc::new(query_config), | ||
237 | } | ||
238 | } | ||
239 | pub fn ground_data(&self) -> &T { | ||
240 | &self.db.ground_data | ||
241 | } | ||
242 | pub fn with_ground_data( | ||
243 | &self, | ||
244 | ground_data: T, | ||
245 | invalidations: Invalidations, | ||
246 | ) -> Db<T, D> { | ||
247 | for id in self.query_config.ground_fn.keys() { | ||
248 | assert!( | ||
249 | invalidations.types.contains(id), | ||
250 | "all ground queries must be invalidated" | ||
251 | ); | ||
252 | } | ||
253 | |||
254 | let gen = Gen(self.db.gen.0 + 1); | ||
255 | let mut graph = self.db.graph.lock().clone(); | ||
256 | for id in invalidations.ids { | ||
257 | if let Some((gen, _)) = graph.get_mut(&id) { | ||
258 | *gen = INVALIDATED; | ||
259 | } | ||
260 | } | ||
261 | let graph = Mutex::new(graph); | ||
262 | Db { | ||
263 | db: Arc::new(DbState { ground_data, gen, graph }), | ||
264 | query_config: Arc::clone(&self.query_config) | ||
265 | } | ||
266 | } | ||
267 | pub fn query_ctx(&self) -> QueryCtx<T, D> { | ||
268 | QueryCtx::new(self) | ||
269 | } | ||
270 | pub fn get( | ||
271 | &self, | ||
272 | query_id: QueryId, | ||
273 | params: D, | ||
274 | ) -> (D, Vec<QueryTypeId>) { | ||
275 | let ctx = self.query_ctx(); | ||
276 | let res = ctx.get(query_id, params.into()); | ||
277 | let executed = ::std::mem::replace(&mut *ctx.executed.borrow_mut(), Vec::new()); | ||
278 | (res, executed) | ||
279 | } | ||
280 | } | ||
281 | |||
282 | #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] | ||
283 | struct Gen(u64); | ||
284 | const INVALIDATED: Gen = Gen(!0); | ||
285 | #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] | ||
286 | pub struct InputFingerprint(pub u64); | ||
287 | #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] | ||
288 | pub struct OutputFingerprint(pub u64); | ||
289 | #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] | ||
290 | pub struct QueryTypeId(pub u16); | ||
291 | #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] | ||
292 | pub struct QueryId(pub QueryTypeId, pub InputFingerprint); | ||
293 | |||
diff --git a/crates/salsa/tests/integration.rs b/crates/salsa/tests/integration.rs new file mode 100644 index 000000000..aed9219be --- /dev/null +++ b/crates/salsa/tests/integration.rs | |||
@@ -0,0 +1,170 @@ | |||
1 | extern crate salsa; | ||
2 | use std::{ | ||
3 | iter::once, | ||
4 | sync::Arc, | ||
5 | collections::hash_map::{HashMap, DefaultHasher}, | ||
6 | any::Any, | ||
7 | hash::{Hash, Hasher}, | ||
8 | }; | ||
9 | |||
10 | type State = HashMap<u32, String>; | ||
11 | type Data = Arc<Any + Send + Sync + 'static>; | ||
12 | const GET_TEXT: salsa::QueryTypeId = salsa::QueryTypeId(1); | ||
13 | const GET_FILES: salsa::QueryTypeId = salsa::QueryTypeId(2); | ||
14 | const FILE_NEWLINES: salsa::QueryTypeId = salsa::QueryTypeId(3); | ||
15 | const TOTAL_NEWLINES: salsa::QueryTypeId = salsa::QueryTypeId(4); | ||
16 | |||
17 | fn mk_ground_query<T, R>( | ||
18 | state: &State, | ||
19 | params: &Data, | ||
20 | f: fn(&State, &T) -> R, | ||
21 | ) -> (Data, salsa::OutputFingerprint) | ||
22 | where | ||
23 | T: 'static, | ||
24 | R: Hash + Send + Sync + 'static, | ||
25 | { | ||
26 | let params = params.downcast_ref().unwrap(); | ||
27 | let result = f(state, params); | ||
28 | let fingerprint = o_print(&result); | ||
29 | (Arc::new(result), fingerprint) | ||
30 | } | ||
31 | |||
32 | fn get<T, R>(db: &salsa::Db<State, Data>, query_type: salsa::QueryTypeId, param: T) -> (Arc<R>, Vec<salsa::QueryTypeId>) | ||
33 | where | ||
34 | T: Hash + Send + Sync + 'static, | ||
35 | R: Send + Sync + 'static, | ||
36 | { | ||
37 | let i_print = i_print(¶m); | ||
38 | let param = Arc::new(param); | ||
39 | let (res, trace) = db.get(salsa::QueryId(query_type, i_print), param); | ||
40 | (res.downcast().unwrap(), trace) | ||
41 | } | ||
42 | |||
43 | struct QueryCtx<'a>(&'a salsa::QueryCtx<State, Data>); | ||
44 | |||
45 | impl<'a> QueryCtx<'a> { | ||
46 | fn get_text(&self, id: u32) -> Arc<String> { | ||
47 | let i_print = i_print(&id); | ||
48 | let text = self.0.get(salsa::QueryId(GET_TEXT, i_print), Arc::new(id)); | ||
49 | text.downcast().unwrap() | ||
50 | } | ||
51 | fn get_files(&self) -> Arc<Vec<u32>> { | ||
52 | let i_print = i_print(&()); | ||
53 | let files = self.0.get(salsa::QueryId(GET_FILES, i_print), Arc::new(())); | ||
54 | let res = files.downcast().unwrap(); | ||
55 | res | ||
56 | } | ||
57 | fn get_n_lines(&self, id: u32) -> usize { | ||
58 | let i_print = i_print(&id); | ||
59 | let n_lines = self.0.get(salsa::QueryId(FILE_NEWLINES, i_print), Arc::new(id)); | ||
60 | *n_lines.downcast().unwrap() | ||
61 | } | ||
62 | } | ||
63 | |||
64 | fn mk_query<T, R>( | ||
65 | query_ctx: &salsa::QueryCtx<State, Data>, | ||
66 | params: &Data, | ||
67 | f: fn(QueryCtx, &T) -> R, | ||
68 | ) -> (Data, salsa::OutputFingerprint) | ||
69 | where | ||
70 | T: 'static, | ||
71 | R: Hash + Send + Sync + 'static, | ||
72 | { | ||
73 | let params: &T = params.downcast_ref().unwrap(); | ||
74 | let query_ctx = QueryCtx(query_ctx); | ||
75 | let result = f(query_ctx, params); | ||
76 | let fingerprint = o_print(&result); | ||
77 | (Arc::new(result), fingerprint) | ||
78 | } | ||
79 | |||
80 | fn mk_queries() -> salsa::QueryConfig<State, Data> { | ||
81 | salsa::QueryConfig::<State, Data>::new() | ||
82 | .with_ground_query(GET_TEXT, Box::new(|state, id| { | ||
83 | mk_ground_query::<u32, String>(state, id, |state, id| state[id].clone()) | ||
84 | })) | ||
85 | .with_ground_query(GET_FILES, Box::new(|state, id| { | ||
86 | mk_ground_query::<(), Vec<u32>>(state, id, |state, &()| state.keys().cloned().collect()) | ||
87 | })) | ||
88 | .with_query(FILE_NEWLINES, Box::new(|query_ctx, id| { | ||
89 | mk_query(query_ctx, id, |query_ctx, &id| { | ||
90 | let text = query_ctx.get_text(id); | ||
91 | text.lines().count() | ||
92 | }) | ||
93 | })) | ||
94 | .with_query(TOTAL_NEWLINES, Box::new(|query_ctx, id| { | ||
95 | mk_query(query_ctx, id, |query_ctx, &()| { | ||
96 | let mut total = 0; | ||
97 | for &id in query_ctx.get_files().iter() { | ||
98 | total += query_ctx.get_n_lines(id) | ||
99 | } | ||
100 | total | ||
101 | }) | ||
102 | })) | ||
103 | } | ||
104 | |||
105 | #[test] | ||
106 | fn test_number_of_lines() { | ||
107 | let mut state = State::new(); | ||
108 | let db = salsa::Db::new(mk_queries(), state.clone()); | ||
109 | let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ()); | ||
110 | assert_eq!(*newlines, 0); | ||
111 | assert_eq!(trace.len(), 2); | ||
112 | let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ()); | ||
113 | assert_eq!(*newlines, 0); | ||
114 | assert_eq!(trace.len(), 0); | ||
115 | |||
116 | state.insert(1, "hello\nworld".to_string()); | ||
117 | let mut inv = salsa::Invalidations::new(); | ||
118 | inv.invalidate(GET_TEXT, once(i_print(&1u32))); | ||
119 | inv.invalidate(GET_FILES, once(i_print(&()))); | ||
120 | let db = db.with_ground_data(state.clone(), inv); | ||
121 | let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ()); | ||
122 | assert_eq!(*newlines, 2); | ||
123 | assert_eq!(trace.len(), 4); | ||
124 | |||
125 | state.insert(2, "spam\neggs".to_string()); | ||
126 | let mut inv = salsa::Invalidations::new(); | ||
127 | inv.invalidate(GET_TEXT, once(i_print(&2u32))); | ||
128 | inv.invalidate(GET_FILES, once(i_print(&()))); | ||
129 | let db = db.with_ground_data(state.clone(), inv); | ||
130 | let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ()); | ||
131 | assert_eq!(*newlines, 4); | ||
132 | assert_eq!(trace.len(), 4); | ||
133 | |||
134 | let mut invs = vec![]; | ||
135 | for i in 0..10 { | ||
136 | let id = i + 10; | ||
137 | invs.push(i_print(&id)); | ||
138 | state.insert(id, "spam".to_string()); | ||
139 | } | ||
140 | let mut inv = salsa::Invalidations::new(); | ||
141 | inv.invalidate(GET_TEXT, invs.into_iter()); | ||
142 | inv.invalidate(GET_FILES, once(i_print(&()))); | ||
143 | let db = db.with_ground_data(state.clone(), inv); | ||
144 | let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ()); | ||
145 | assert_eq!(*newlines, 14); | ||
146 | assert_eq!(trace.len(), 22); | ||
147 | |||
148 | state.insert(15, String::new()); | ||
149 | let mut inv = salsa::Invalidations::new(); | ||
150 | inv.invalidate(GET_TEXT, once(i_print(&15u32))); | ||
151 | inv.invalidate(GET_FILES, once(i_print(&()))); | ||
152 | let db = db.with_ground_data(state.clone(), inv); | ||
153 | let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ()); | ||
154 | assert_eq!(*newlines, 13); | ||
155 | assert_eq!(trace.len(), 4); | ||
156 | } | ||
157 | |||
158 | fn o_print<T: Hash>(x: &T) -> salsa::OutputFingerprint { | ||
159 | let mut hasher = DefaultHasher::new(); | ||
160 | x.hash(&mut hasher); | ||
161 | let hash = hasher.finish(); | ||
162 | salsa::OutputFingerprint(hash) | ||
163 | } | ||
164 | |||
165 | fn i_print<T: Hash>(x: &T) -> salsa::InputFingerprint { | ||
166 | let mut hasher = DefaultHasher::new(); | ||
167 | x.hash(&mut hasher); | ||
168 | let hash = hasher.finish(); | ||
169 | salsa::InputFingerprint(hash) | ||
170 | } | ||