aboutsummaryrefslogtreecommitdiff
path: root/crates/salsa
diff options
context:
space:
mode:
authorbors[bot] <bors[bot]@users.noreply.github.com>2018-09-15 22:11:25 +0100
committerbors[bot] <bors[bot]@users.noreply.github.com>2018-09-15 22:11:25 +0100
commit3993bb4de95af407e5edc1fe551bec0f001a3f0f (patch)
tree31893552cd739187080048df24a629d416174305 /crates/salsa
parent2a56b5c4f096736d6795eecb835cc2dc14b00107 (diff)
parentfcdf3a52b4b61a39474950486ea0edf5ebf33bea (diff)
Merge #67
67: Salsa r=matklad a=matklad The aim of this PR is to transition from rather ad-hock FileData and ModuleMap caching strategy to something resembling a general-purpose red-green engine. Ideally, we shouldn't recompute ModuleMap at all, unless the set of mod decls or files changes. Co-authored-by: Aleksey Kladov <[email protected]>
Diffstat (limited to 'crates/salsa')
-rw-r--r--crates/salsa/Cargo.toml8
-rw-r--r--crates/salsa/src/lib.rs293
-rw-r--r--crates/salsa/tests/integration.rs170
3 files changed, 471 insertions, 0 deletions
diff --git a/crates/salsa/Cargo.toml b/crates/salsa/Cargo.toml
new file mode 100644
index 000000000..9eb83234f
--- /dev/null
+++ b/crates/salsa/Cargo.toml
@@ -0,0 +1,8 @@
1[package]
2name = "salsa"
3version = "0.1.0"
4authors = ["Aleksey Kladov <[email protected]>"]
5
6[dependencies]
7parking_lot = "0.6.3"
8im = "12.0.0"
diff --git a/crates/salsa/src/lib.rs b/crates/salsa/src/lib.rs
new file mode 100644
index 000000000..35deed374
--- /dev/null
+++ b/crates/salsa/src/lib.rs
@@ -0,0 +1,293 @@
1extern crate im;
2extern crate parking_lot;
3
4use std::{
5 sync::Arc,
6 collections::{HashSet, HashMap},
7 cell::RefCell,
8};
9use parking_lot::Mutex;
10
11pub type GroundQueryFn<T, D> = Box<Fn(&T, &D) -> (D, OutputFingerprint) + Send + Sync + 'static>;
12pub type QueryFn<T, D> = Box<Fn(&QueryCtx<T, D>, &D) -> (D, OutputFingerprint) + Send + Sync + 'static>;
13
14#[derive(Debug)]
15pub struct Db<T, D> {
16 db: Arc<DbState<T, D>>,
17 query_config: Arc<QueryConfig<T, D>>,
18}
19
20pub struct QueryConfig<T, D> {
21 ground_fn: HashMap<QueryTypeId, GroundQueryFn<T, D>>,
22 query_fn: HashMap<QueryTypeId, QueryFn<T, D>>,
23}
24
25impl<T, D> ::std::fmt::Debug for QueryConfig<T, D> {
26 fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
27 ::std::fmt::Display::fmt("QueryConfig { ... }", f)
28 }
29}
30
31#[derive(Debug)]
32struct DbState<T, D> {
33 ground_data: T,
34 gen: Gen,
35 graph: Mutex<im::HashMap<QueryId, (Gen, Arc<QueryRecord<D>>)>>,
36}
37
38#[derive(Debug)]
39struct QueryRecord<D> {
40 params: D,
41 output: D,
42 output_fingerprint: OutputFingerprint,
43 deps: Vec<(QueryId, OutputFingerprint)>,
44}
45
46impl<T, D> DbState<T, D> {
47 fn record(
48 &self,
49 query_id: QueryId,
50 params: D,
51 output: D,
52 output_fingerprint: OutputFingerprint,
53 deps: Vec<(QueryId, OutputFingerprint)>,
54 ) {
55 let gen = self.gen;
56 let record = QueryRecord {
57 params,
58 output,
59 output_fingerprint,
60 deps,
61 };
62 self.graph.lock().insert(query_id, (gen, Arc::new(record)));
63 }
64}
65
66impl<T, D> QueryConfig<T, D> {
67 pub fn new() -> Self {
68 QueryConfig {
69 ground_fn: HashMap::new(),
70 query_fn: HashMap::new(),
71 }
72 }
73 pub fn with_ground_query(
74 mut self,
75 query_type: QueryTypeId,
76 query_fn: GroundQueryFn<T, D>
77 ) -> Self {
78 let prev = self.ground_fn.insert(query_type, query_fn);
79 assert!(prev.is_none());
80 self
81 }
82 pub fn with_query(
83 mut self,
84 query_type: QueryTypeId,
85 query_fn: QueryFn<T, D>,
86 ) -> Self {
87 let prev = self.query_fn.insert(query_type, query_fn);
88 assert!(prev.is_none());
89 self
90 }
91}
92
93pub struct QueryCtx<T, D> {
94 db: Arc<DbState<T, D>>,
95 query_config: Arc<QueryConfig<T, D>>,
96 stack: RefCell<Vec<Vec<(QueryId, OutputFingerprint)>>>,
97 executed: RefCell<Vec<QueryTypeId>>,
98}
99
100impl<T, D> QueryCtx<T, D>
101where
102 D: Clone
103{
104 fn new(db: &Db<T, D>) -> QueryCtx<T, D> {
105 QueryCtx {
106 db: Arc::clone(&db.db),
107 query_config: Arc::clone(&db.query_config),
108 stack: RefCell::new(vec![Vec::new()]),
109 executed: RefCell::new(Vec::new()),
110 }
111 }
112 pub fn get(
113 &self,
114 query_id: QueryId,
115 params: D,
116 ) -> D {
117 let (res, output_fingerprint) = self.get_inner(query_id, params);
118 self.record_dep(query_id, output_fingerprint);
119 res
120 }
121 pub fn trace(&self) -> Vec<QueryTypeId> {
122 ::std::mem::replace(&mut *self.executed.borrow_mut(), Vec::new())
123 }
124
125 fn get_inner(
126 &self,
127 query_id: QueryId,
128 params: D,
129 ) -> (D, OutputFingerprint) {
130 let (gen, record) = {
131 let guard = self.db.graph.lock();
132 match guard.get(&query_id).map(|it| it.clone()){
133 None => {
134 drop(guard);
135 return self.force(query_id, params);
136 },
137 Some(it) => it,
138 }
139 };
140 if gen == self.db.gen {
141 return (record.output.clone(), record.output_fingerprint)
142 }
143 if self.query_config.ground_fn.contains_key(&query_id.0) {
144 let (invalidated, record) = {
145 let guard = self.db.graph.lock();
146 let (gen, ref record) = guard[&query_id];
147 (gen == INVALIDATED, record.clone())
148 };
149 if invalidated {
150 return self.force(query_id, params);
151 } else {
152 return (record.output.clone(), record.output_fingerprint);
153 }
154 }
155 for (dep_query_id, prev_fingerprint) in record.deps.iter().cloned() {
156 let dep_params: D = {
157 let guard = self.db.graph.lock();
158 guard[&dep_query_id]
159 .1
160 .params
161 .clone()
162 };
163 if prev_fingerprint != self.get_inner(dep_query_id, dep_params).1 {
164 return self.force(query_id, params)
165 }
166 }
167 let gen = self.db.gen;
168 {
169 let mut guard = self.db.graph.lock();
170 guard[&query_id].0 = gen;
171 }
172 (record.output.clone(), record.output_fingerprint)
173 }
174 fn force(
175 &self,
176 query_id: QueryId,
177 params: D,
178 ) -> (D, OutputFingerprint) {
179 self.executed.borrow_mut().push(query_id.0);
180 self.stack.borrow_mut().push(Vec::new());
181
182 let (res, output_fingerprint) = if let Some(f) = self.query_config.ground_fn.get(&query_id.0) {
183 f(&self.db.ground_data, &params)
184 } else if let Some(f) = self.query_config.query_fn.get(&query_id.0) {
185 f(self, &params)
186 } else {
187 panic!("unknown query type: {:?}", query_id.0);
188 };
189
190 let res: D = res.into();
191
192 let deps = self.stack.borrow_mut().pop().unwrap();
193 self.db.record(query_id, params, res.clone(), output_fingerprint, deps);
194 (res, output_fingerprint)
195 }
196 fn record_dep(
197 &self,
198 query_id: QueryId,
199 output_fingerprint: OutputFingerprint,
200 ) -> () {
201 let mut stack = self.stack.borrow_mut();
202 let deps = stack.last_mut().unwrap();
203 deps.push((query_id, output_fingerprint))
204 }
205}
206
207pub struct Invalidations {
208 types: HashSet<QueryTypeId>,
209 ids: Vec<QueryId>,
210}
211
212impl Invalidations {
213 pub fn new() -> Invalidations {
214 Invalidations {
215 types: HashSet::new(),
216 ids: Vec::new(),
217 }
218 }
219 pub fn invalidate(
220 &mut self,
221 query_type: QueryTypeId,
222 params: impl Iterator<Item=InputFingerprint>,
223 ) {
224 self.types.insert(query_type);
225 self.ids.extend(params.map(|it| QueryId(query_type, it)))
226 }
227}
228
229impl<T, D> Db<T, D>
230where
231 D: Clone
232{
233 pub fn new(query_config: QueryConfig<T, D>, ground_data: T) -> Db<T, D> {
234 Db {
235 db: Arc::new(DbState { ground_data, gen: Gen(0), graph: Default::default() }),
236 query_config: Arc::new(query_config),
237 }
238 }
239 pub fn ground_data(&self) -> &T {
240 &self.db.ground_data
241 }
242 pub fn with_ground_data(
243 &self,
244 ground_data: T,
245 invalidations: Invalidations,
246 ) -> Db<T, D> {
247 for id in self.query_config.ground_fn.keys() {
248 assert!(
249 invalidations.types.contains(id),
250 "all ground queries must be invalidated"
251 );
252 }
253
254 let gen = Gen(self.db.gen.0 + 1);
255 let mut graph = self.db.graph.lock().clone();
256 for id in invalidations.ids {
257 if let Some((gen, _)) = graph.get_mut(&id) {
258 *gen = INVALIDATED;
259 }
260 }
261 let graph = Mutex::new(graph);
262 Db {
263 db: Arc::new(DbState { ground_data, gen, graph }),
264 query_config: Arc::clone(&self.query_config)
265 }
266 }
267 pub fn query_ctx(&self) -> QueryCtx<T, D> {
268 QueryCtx::new(self)
269 }
270 pub fn get(
271 &self,
272 query_id: QueryId,
273 params: D,
274 ) -> (D, Vec<QueryTypeId>) {
275 let ctx = self.query_ctx();
276 let res = ctx.get(query_id, params.into());
277 let executed = ::std::mem::replace(&mut *ctx.executed.borrow_mut(), Vec::new());
278 (res, executed)
279 }
280}
281
282#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
283struct Gen(u64);
284const INVALIDATED: Gen = Gen(!0);
285#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
286pub struct InputFingerprint(pub u64);
287#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
288pub struct OutputFingerprint(pub u64);
289#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
290pub struct QueryTypeId(pub u16);
291#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
292pub struct QueryId(pub QueryTypeId, pub InputFingerprint);
293
diff --git a/crates/salsa/tests/integration.rs b/crates/salsa/tests/integration.rs
new file mode 100644
index 000000000..aed9219be
--- /dev/null
+++ b/crates/salsa/tests/integration.rs
@@ -0,0 +1,170 @@
1extern crate salsa;
2use std::{
3 iter::once,
4 sync::Arc,
5 collections::hash_map::{HashMap, DefaultHasher},
6 any::Any,
7 hash::{Hash, Hasher},
8};
9
10type State = HashMap<u32, String>;
11type Data = Arc<Any + Send + Sync + 'static>;
12const GET_TEXT: salsa::QueryTypeId = salsa::QueryTypeId(1);
13const GET_FILES: salsa::QueryTypeId = salsa::QueryTypeId(2);
14const FILE_NEWLINES: salsa::QueryTypeId = salsa::QueryTypeId(3);
15const TOTAL_NEWLINES: salsa::QueryTypeId = salsa::QueryTypeId(4);
16
17fn mk_ground_query<T, R>(
18 state: &State,
19 params: &Data,
20 f: fn(&State, &T) -> R,
21) -> (Data, salsa::OutputFingerprint)
22where
23 T: 'static,
24 R: Hash + Send + Sync + 'static,
25{
26 let params = params.downcast_ref().unwrap();
27 let result = f(state, params);
28 let fingerprint = o_print(&result);
29 (Arc::new(result), fingerprint)
30}
31
32fn get<T, R>(db: &salsa::Db<State, Data>, query_type: salsa::QueryTypeId, param: T) -> (Arc<R>, Vec<salsa::QueryTypeId>)
33where
34 T: Hash + Send + Sync + 'static,
35 R: Send + Sync + 'static,
36{
37 let i_print = i_print(&param);
38 let param = Arc::new(param);
39 let (res, trace) = db.get(salsa::QueryId(query_type, i_print), param);
40 (res.downcast().unwrap(), trace)
41}
42
43struct QueryCtx<'a>(&'a salsa::QueryCtx<State, Data>);
44
45impl<'a> QueryCtx<'a> {
46 fn get_text(&self, id: u32) -> Arc<String> {
47 let i_print = i_print(&id);
48 let text = self.0.get(salsa::QueryId(GET_TEXT, i_print), Arc::new(id));
49 text.downcast().unwrap()
50 }
51 fn get_files(&self) -> Arc<Vec<u32>> {
52 let i_print = i_print(&());
53 let files = self.0.get(salsa::QueryId(GET_FILES, i_print), Arc::new(()));
54 let res = files.downcast().unwrap();
55 res
56 }
57 fn get_n_lines(&self, id: u32) -> usize {
58 let i_print = i_print(&id);
59 let n_lines = self.0.get(salsa::QueryId(FILE_NEWLINES, i_print), Arc::new(id));
60 *n_lines.downcast().unwrap()
61 }
62}
63
64fn mk_query<T, R>(
65 query_ctx: &salsa::QueryCtx<State, Data>,
66 params: &Data,
67 f: fn(QueryCtx, &T) -> R,
68) -> (Data, salsa::OutputFingerprint)
69where
70 T: 'static,
71 R: Hash + Send + Sync + 'static,
72{
73 let params: &T = params.downcast_ref().unwrap();
74 let query_ctx = QueryCtx(query_ctx);
75 let result = f(query_ctx, params);
76 let fingerprint = o_print(&result);
77 (Arc::new(result), fingerprint)
78}
79
80fn mk_queries() -> salsa::QueryConfig<State, Data> {
81 salsa::QueryConfig::<State, Data>::new()
82 .with_ground_query(GET_TEXT, Box::new(|state, id| {
83 mk_ground_query::<u32, String>(state, id, |state, id| state[id].clone())
84 }))
85 .with_ground_query(GET_FILES, Box::new(|state, id| {
86 mk_ground_query::<(), Vec<u32>>(state, id, |state, &()| state.keys().cloned().collect())
87 }))
88 .with_query(FILE_NEWLINES, Box::new(|query_ctx, id| {
89 mk_query(query_ctx, id, |query_ctx, &id| {
90 let text = query_ctx.get_text(id);
91 text.lines().count()
92 })
93 }))
94 .with_query(TOTAL_NEWLINES, Box::new(|query_ctx, id| {
95 mk_query(query_ctx, id, |query_ctx, &()| {
96 let mut total = 0;
97 for &id in query_ctx.get_files().iter() {
98 total += query_ctx.get_n_lines(id)
99 }
100 total
101 })
102 }))
103}
104
105#[test]
106fn test_number_of_lines() {
107 let mut state = State::new();
108 let db = salsa::Db::new(mk_queries(), state.clone());
109 let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ());
110 assert_eq!(*newlines, 0);
111 assert_eq!(trace.len(), 2);
112 let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ());
113 assert_eq!(*newlines, 0);
114 assert_eq!(trace.len(), 0);
115
116 state.insert(1, "hello\nworld".to_string());
117 let mut inv = salsa::Invalidations::new();
118 inv.invalidate(GET_TEXT, once(i_print(&1u32)));
119 inv.invalidate(GET_FILES, once(i_print(&())));
120 let db = db.with_ground_data(state.clone(), inv);
121 let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ());
122 assert_eq!(*newlines, 2);
123 assert_eq!(trace.len(), 4);
124
125 state.insert(2, "spam\neggs".to_string());
126 let mut inv = salsa::Invalidations::new();
127 inv.invalidate(GET_TEXT, once(i_print(&2u32)));
128 inv.invalidate(GET_FILES, once(i_print(&())));
129 let db = db.with_ground_data(state.clone(), inv);
130 let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ());
131 assert_eq!(*newlines, 4);
132 assert_eq!(trace.len(), 4);
133
134 let mut invs = vec![];
135 for i in 0..10 {
136 let id = i + 10;
137 invs.push(i_print(&id));
138 state.insert(id, "spam".to_string());
139 }
140 let mut inv = salsa::Invalidations::new();
141 inv.invalidate(GET_TEXT, invs.into_iter());
142 inv.invalidate(GET_FILES, once(i_print(&())));
143 let db = db.with_ground_data(state.clone(), inv);
144 let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ());
145 assert_eq!(*newlines, 14);
146 assert_eq!(trace.len(), 22);
147
148 state.insert(15, String::new());
149 let mut inv = salsa::Invalidations::new();
150 inv.invalidate(GET_TEXT, once(i_print(&15u32)));
151 inv.invalidate(GET_FILES, once(i_print(&())));
152 let db = db.with_ground_data(state.clone(), inv);
153 let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ());
154 assert_eq!(*newlines, 13);
155 assert_eq!(trace.len(), 4);
156}
157
158fn o_print<T: Hash>(x: &T) -> salsa::OutputFingerprint {
159 let mut hasher = DefaultHasher::new();
160 x.hash(&mut hasher);
161 let hash = hasher.finish();
162 salsa::OutputFingerprint(hash)
163}
164
165fn i_print<T: Hash>(x: &T) -> salsa::InputFingerprint {
166 let mut hasher = DefaultHasher::new();
167 x.hash(&mut hasher);
168 let hash = hasher.finish();
169 salsa::InputFingerprint(hash)
170}