aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAleksey Kladov <[email protected]>2018-09-12 19:50:15 +0100
committerAleksey Kladov <[email protected]>2018-09-15 22:00:05 +0100
commit8cf9c2719652d298006d51bc82a32908ab4e5335 (patch)
treec74d3c63b2b2d0463e557ce25dca9d0230f8f00e
parent0e493160c0cdbaa71f61af64fd7c439410e8c8b1 (diff)
generic salsa algo
-rw-r--r--crates/salsa/Cargo.toml8
-rw-r--r--crates/salsa/src/lib.rs238
-rw-r--r--crates/salsa/tests/integration.rs153
3 files changed, 399 insertions, 0 deletions
diff --git a/crates/salsa/Cargo.toml b/crates/salsa/Cargo.toml
new file mode 100644
index 000000000..9eb83234f
--- /dev/null
+++ b/crates/salsa/Cargo.toml
@@ -0,0 +1,8 @@
1[package]
2name = "salsa"
3version = "0.1.0"
4authors = ["Aleksey Kladov <[email protected]>"]
5
6[dependencies]
7parking_lot = "0.6.3"
8im = "12.0.0"
diff --git a/crates/salsa/src/lib.rs b/crates/salsa/src/lib.rs
new file mode 100644
index 000000000..69c7b35fa
--- /dev/null
+++ b/crates/salsa/src/lib.rs
@@ -0,0 +1,238 @@
1extern crate im;
2extern crate parking_lot;
3
4use std::{
5 sync::Arc,
6 any::Any,
7 collections::HashMap,
8 cell::RefCell,
9};
10use parking_lot::Mutex;
11
12type GroundQueryFn<T> = fn(&T, &(Any + Send + Sync + 'static)) -> (Box<Any + Send + Sync + 'static>, OutputFingerprint);
13type QueryFn<T> = fn(&QueryCtx<T>, &(Any + Send + Sync + 'static)) -> (Box<Any + Send + Sync + 'static>, OutputFingerprint);
14
15#[derive(Debug)]
16pub struct Db<T> {
17 db: Arc<DbState<T>>,
18 query_config: Arc<QueryConfig<T>>,
19}
20
21pub struct QueryConfig<T> {
22 ground_fn: HashMap<QueryTypeId, GroundQueryFn<T>>,
23 query_fn: HashMap<QueryTypeId, QueryFn<T>>,
24}
25
26impl<T> ::std::fmt::Debug for QueryConfig<T> {
27 fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
28 ::std::fmt::Display::fmt("QueryConfig { ... }", f)
29 }
30}
31
32#[derive(Debug)]
33struct DbState<T> {
34 ground_data: T,
35 gen: Gen,
36 graph: Mutex<im::HashMap<QueryId, (Gen, Arc<QueryRecord>)>>,
37}
38
39#[derive(Debug)]
40struct QueryRecord {
41 params: Arc<Any + Send + Sync + 'static>,
42 output: Arc<Any + Send + Sync + 'static>,
43 output_fingerprint: OutputFingerprint,
44 deps: Vec<(QueryId, OutputFingerprint)>,
45}
46
47impl<T> DbState<T> {
48 fn record(
49 &self,
50 query_id: QueryId,
51 params: Arc<Any + Send + Sync + 'static>,
52 output: Arc<Any + Send + Sync + 'static>,
53 output_fingerprint: OutputFingerprint,
54 deps: Vec<(QueryId, OutputFingerprint)>,
55 ) {
56 let gen = self.gen;
57 let record = QueryRecord {
58 params,
59 output,
60 output_fingerprint,
61 deps,
62 };
63 self.graph.lock().insert(query_id, (gen, Arc::new(record)));
64 }
65}
66
67impl<T> QueryConfig<T> {
68 pub fn new() -> Self {
69 QueryConfig {
70 ground_fn: HashMap::new(),
71 query_fn: HashMap::new(),
72 }
73 }
74 pub fn with_ground_query(
75 mut self,
76 query_type: QueryTypeId,
77 query_fn: GroundQueryFn<T>
78 ) -> Self {
79 let prev = self.ground_fn.insert(query_type, query_fn);
80 assert!(prev.is_none());
81 self
82 }
83 pub fn with_query(
84 mut self,
85 query_type: QueryTypeId,
86 query_fn: QueryFn<T>,
87 ) -> Self {
88 let prev = self.query_fn.insert(query_type, query_fn);
89 assert!(prev.is_none());
90 self
91 }
92}
93
94pub struct QueryCtx<T> {
95 db: Arc<DbState<T>>,
96 query_config: Arc<QueryConfig<T>>,
97 stack: RefCell<Vec<Vec<(QueryId, OutputFingerprint)>>>,
98 executed: RefCell<Vec<QueryTypeId>>,
99}
100
101impl<T> QueryCtx<T> {
102 fn new(db: &Db<T>) -> QueryCtx<T> {
103 QueryCtx {
104 db: Arc::clone(&db.db),
105 query_config: Arc::clone(&db.query_config),
106 stack: RefCell::new(vec![Vec::new()]),
107 executed: RefCell::new(Vec::new()),
108 }
109 }
110 pub fn get(
111 &self,
112 query_id: QueryId,
113 params: Arc<Any + Send + Sync + 'static>,
114 ) -> Arc<Any + Send + Sync + 'static> {
115 let (res, output_fingerprint) = self.get_inner(query_id, params);
116 self.record_dep(query_id, output_fingerprint);
117 res
118 }
119
120 pub fn get_inner(
121 &self,
122 query_id: QueryId,
123 params: Arc<Any + Send + Sync + 'static>,
124 ) -> (Arc<Any + Send + Sync + 'static>, OutputFingerprint) {
125 let (gen, record) = {
126 let guard = self.db.graph.lock();
127 match guard.get(&query_id).map(|it| it.clone()){
128 None => {
129 drop(guard);
130 return self.force(query_id, params);
131 },
132 Some(it) => it,
133 }
134 };
135 if gen == self.db.gen {
136 return (record.output.clone(), record.output_fingerprint)
137 }
138 if self.query_config.ground_fn.contains_key(&query_id.0) {
139 return self.force(query_id, params);
140 }
141 for (dep_query_id, prev_fingerprint) in record.deps.iter().cloned() {
142 let dep_params: Arc<Any + Send + Sync + 'static> = {
143 let guard = self.db.graph.lock();
144 guard[&dep_query_id]
145 .1
146 .params
147 .clone()
148 };
149 if prev_fingerprint != self.get_inner(dep_query_id, dep_params).1 {
150 return self.force(query_id, params)
151 }
152 }
153 let gen = self.db.gen;
154 {
155 let mut guard = self.db.graph.lock();
156 guard[&query_id].0 = gen;
157 }
158 (record.output.clone(), record.output_fingerprint)
159 }
160 fn force(
161 &self,
162 query_id: QueryId,
163 params: Arc<Any + Send + Sync + 'static>,
164 ) -> (Arc<Any + Send + Sync + 'static>, OutputFingerprint) {
165 self.executed.borrow_mut().push(query_id.0);
166 self.stack.borrow_mut().push(Vec::new());
167
168 let (res, output_fingerprint) = if let Some(f) = self.ground_query_fn_by_type(query_id.0) {
169 f(&self.db.ground_data, &*params)
170 } else if let Some(f) = self.query_fn_by_type(query_id.0) {
171 f(self, &*params)
172 } else {
173 panic!("unknown query type: {:?}", query_id.0);
174 };
175
176 let res: Arc<Any + Send + Sync + 'static> = res.into();
177
178 let deps = self.stack.borrow_mut().pop().unwrap();
179 self.db.record(query_id, params, res.clone(), output_fingerprint, deps);
180 (res, output_fingerprint)
181 }
182 fn ground_query_fn_by_type(&self, query_type: QueryTypeId) -> Option<GroundQueryFn<T>> {
183 self.query_config.ground_fn.get(&query_type).map(|&it| it)
184 }
185 fn query_fn_by_type(&self, query_type: QueryTypeId) -> Option<QueryFn<T>> {
186 self.query_config.query_fn.get(&query_type).map(|&it| it)
187 }
188 fn record_dep(
189 &self,
190 query_id: QueryId,
191 output_fingerprint: OutputFingerprint,
192 ) -> () {
193 let mut stack = self.stack.borrow_mut();
194 let deps = stack.last_mut().unwrap();
195 deps.push((query_id, output_fingerprint))
196 }
197}
198
199impl<T> Db<T> {
200 pub fn new(query_config: QueryConfig<T>, ground_data: T) -> Db<T> {
201 Db {
202 db: Arc::new(DbState { ground_data, gen: Gen(0), graph: Default::default() }),
203 query_config: Arc::new(query_config),
204 }
205 }
206
207 pub fn with_ground_data(&self, ground_data: T) -> Db<T> {
208 let gen = Gen(self.db.gen.0 + 1);
209 let graph = self.db.graph.lock().clone();
210 let graph = Mutex::new(graph);
211 Db {
212 db: Arc::new(DbState { ground_data, gen, graph }),
213 query_config: Arc::clone(&self.query_config)
214 }
215 }
216 pub fn get(
217 &self,
218 query_id: QueryId,
219 params: Box<Any + Send + Sync + 'static>,
220 ) -> (Arc<Any + Send + Sync + 'static>, Vec<QueryTypeId>) {
221 let ctx = QueryCtx::new(self);
222 let res = ctx.get(query_id, params.into());
223 let executed = ::std::mem::replace(&mut *ctx.executed.borrow_mut(), Vec::new());
224 (res, executed)
225 }
226}
227
228#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
229struct Gen(u64);
230#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
231pub struct InputFingerprint(pub u64);
232#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
233pub struct OutputFingerprint(pub u64);
234#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
235pub struct QueryTypeId(pub u16);
236#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
237pub struct QueryId(pub QueryTypeId, pub InputFingerprint);
238
diff --git a/crates/salsa/tests/integration.rs b/crates/salsa/tests/integration.rs
new file mode 100644
index 000000000..7241eca38
--- /dev/null
+++ b/crates/salsa/tests/integration.rs
@@ -0,0 +1,153 @@
1extern crate salsa;
2use std::{
3 sync::Arc,
4 collections::hash_map::{HashMap, DefaultHasher},
5 any::Any,
6 hash::{Hash, Hasher},
7};
8
9type State = HashMap<u32, String>;
10const GET_TEXT: salsa::QueryTypeId = salsa::QueryTypeId(1);
11const GET_FILES: salsa::QueryTypeId = salsa::QueryTypeId(2);
12const FILE_NEWLINES: salsa::QueryTypeId = salsa::QueryTypeId(3);
13const TOTAL_NEWLINES: salsa::QueryTypeId = salsa::QueryTypeId(4);
14
15fn mk_ground_query<T, R>(
16 state: &State,
17 params: &(Any + Send + Sync + 'static),
18 f: fn(&State, &T) -> R,
19) -> (Box<Any + Send + Sync + 'static>, salsa::OutputFingerprint)
20where
21 T: 'static,
22 R: Hash + Send + Sync + 'static,
23{
24 let params = params.downcast_ref().unwrap();
25 let result = f(state, params);
26 let fingerprint = o_print(&result);
27 (Box::new(result), fingerprint)
28}
29
30fn get<T, R>(db: &salsa::Db<State>, query_type: salsa::QueryTypeId, param: T) -> (Arc<R>, Vec<salsa::QueryTypeId>)
31where
32 T: Hash + Send + Sync + 'static,
33 R: Send + Sync + 'static,
34{
35 let i_print = i_print(&param);
36 let param = Box::new(param);
37 let (res, trace) = db.get(salsa::QueryId(query_type, i_print), param);
38 (res.downcast().unwrap(), trace)
39}
40
41struct QueryCtx<'a>(&'a salsa::QueryCtx<State>);
42
43impl<'a> QueryCtx<'a> {
44 fn get_text(&self, id: u32) -> Arc<String> {
45 let i_print = i_print(&id);
46 let text = self.0.get(salsa::QueryId(GET_TEXT, i_print), Arc::new(id));
47 text.downcast().unwrap()
48 }
49 fn get_files(&self) -> Arc<Vec<u32>> {
50 let i_print = i_print(&());
51 let files = self.0.get(salsa::QueryId(GET_FILES, i_print), Arc::new(()));
52 let res = files.downcast().unwrap();
53 res
54 }
55 fn get_n_lines(&self, id: u32) -> usize {
56 let i_print = i_print(&id);
57 let n_lines = self.0.get(salsa::QueryId(FILE_NEWLINES, i_print), Arc::new(id));
58 *n_lines.downcast().unwrap()
59 }
60}
61
62fn mk_query<T, R>(
63 query_ctx: &salsa::QueryCtx<State>,
64 params: &(Any + Send + Sync + 'static),
65 f: fn(QueryCtx, &T) -> R,
66) -> (Box<Any + Send + Sync + 'static>, salsa::OutputFingerprint)
67where
68 T: 'static,
69 R: Hash + Send + Sync + 'static,
70{
71 let params: &T = params.downcast_ref().unwrap();
72 let query_ctx = QueryCtx(query_ctx);
73 let result = f(query_ctx, params);
74 let fingerprint = o_print(&result);
75 (Box::new(result), fingerprint)
76}
77
78fn mk_queries() -> salsa::QueryConfig<State> {
79 salsa::QueryConfig::<State>::new()
80 .with_ground_query(GET_TEXT, |state, id| {
81 mk_ground_query::<u32, String>(state, id, |state, id| state[id].clone())
82 })
83 .with_ground_query(GET_FILES, |state, id| {
84 mk_ground_query::<(), Vec<u32>>(state, id, |state, &()| state.keys().cloned().collect())
85 })
86 .with_query(FILE_NEWLINES, |query_ctx, id| {
87 mk_query(query_ctx, id, |query_ctx, &id| {
88 let text = query_ctx.get_text(id);
89 text.lines().count()
90 })
91 })
92 .with_query(TOTAL_NEWLINES, |query_ctx, id| {
93 mk_query(query_ctx, id, |query_ctx, &()| {
94 let mut total = 0;
95 for &id in query_ctx.get_files().iter() {
96 total += query_ctx.get_n_lines(id)
97 }
98 total
99 })
100 })
101}
102
103#[test]
104fn test_number_of_lines() {
105 let mut state = State::new();
106 let db = salsa::Db::new(mk_queries(), state.clone());
107 let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ());
108 assert_eq!(*newlines, 0);
109 assert_eq!(trace.len(), 2);
110 let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ());
111 assert_eq!(*newlines, 0);
112 assert_eq!(trace.len(), 0);
113
114 state.insert(1, "hello\nworld".to_string());
115 let db = db.with_ground_data(state.clone());
116 let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ());
117 assert_eq!(*newlines, 2);
118 assert_eq!(trace.len(), 4);
119
120 state.insert(2, "spam\neggs".to_string());
121 let db = db.with_ground_data(state.clone());
122 let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ());
123 assert_eq!(*newlines, 4);
124 assert_eq!(trace.len(), 5);
125
126 for i in 0..10 {
127 state.insert(i + 10, "spam".to_string());
128 }
129 let db = db.with_ground_data(state.clone());
130 let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ());
131 assert_eq!(*newlines, 14);
132 assert_eq!(trace.len(), 24);
133
134 state.insert(15, String::new());
135 let db = db.with_ground_data(state.clone());
136 let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ());
137 assert_eq!(*newlines, 13);
138 assert_eq!(trace.len(), 15);
139}
140
141fn o_print<T: Hash>(x: &T) -> salsa::OutputFingerprint {
142 let mut hasher = DefaultHasher::new();
143 x.hash(&mut hasher);
144 let hash = hasher.finish();
145 salsa::OutputFingerprint(hash)
146}
147
148fn i_print<T: Hash>(x: &T) -> salsa::InputFingerprint {
149 let mut hasher = DefaultHasher::new();
150 x.hash(&mut hasher);
151 let hash = hasher.finish();
152 salsa::InputFingerprint(hash)
153}