diff options
author | Aleksey Kladov <[email protected]> | 2018-09-12 19:50:15 +0100 |
---|---|---|
committer | Aleksey Kladov <[email protected]> | 2018-09-15 22:00:05 +0100 |
commit | 8cf9c2719652d298006d51bc82a32908ab4e5335 (patch) | |
tree | c74d3c63b2b2d0463e557ce25dca9d0230f8f00e /crates/salsa/src | |
parent | 0e493160c0cdbaa71f61af64fd7c439410e8c8b1 (diff) |
generic salsa algo
Diffstat (limited to 'crates/salsa/src')
-rw-r--r-- | crates/salsa/src/lib.rs | 238 |
1 files changed, 238 insertions, 0 deletions
diff --git a/crates/salsa/src/lib.rs b/crates/salsa/src/lib.rs new file mode 100644 index 000000000..69c7b35fa --- /dev/null +++ b/crates/salsa/src/lib.rs | |||
@@ -0,0 +1,238 @@ | |||
1 | extern crate im; | ||
2 | extern crate parking_lot; | ||
3 | |||
4 | use std::{ | ||
5 | sync::Arc, | ||
6 | any::Any, | ||
7 | collections::HashMap, | ||
8 | cell::RefCell, | ||
9 | }; | ||
10 | use parking_lot::Mutex; | ||
11 | |||
12 | type GroundQueryFn<T> = fn(&T, &(Any + Send + Sync + 'static)) -> (Box<Any + Send + Sync + 'static>, OutputFingerprint); | ||
13 | type QueryFn<T> = fn(&QueryCtx<T>, &(Any + Send + Sync + 'static)) -> (Box<Any + Send + Sync + 'static>, OutputFingerprint); | ||
14 | |||
15 | #[derive(Debug)] | ||
16 | pub struct Db<T> { | ||
17 | db: Arc<DbState<T>>, | ||
18 | query_config: Arc<QueryConfig<T>>, | ||
19 | } | ||
20 | |||
21 | pub struct QueryConfig<T> { | ||
22 | ground_fn: HashMap<QueryTypeId, GroundQueryFn<T>>, | ||
23 | query_fn: HashMap<QueryTypeId, QueryFn<T>>, | ||
24 | } | ||
25 | |||
26 | impl<T> ::std::fmt::Debug for QueryConfig<T> { | ||
27 | fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { | ||
28 | ::std::fmt::Display::fmt("QueryConfig { ... }", f) | ||
29 | } | ||
30 | } | ||
31 | |||
32 | #[derive(Debug)] | ||
33 | struct DbState<T> { | ||
34 | ground_data: T, | ||
35 | gen: Gen, | ||
36 | graph: Mutex<im::HashMap<QueryId, (Gen, Arc<QueryRecord>)>>, | ||
37 | } | ||
38 | |||
39 | #[derive(Debug)] | ||
40 | struct QueryRecord { | ||
41 | params: Arc<Any + Send + Sync + 'static>, | ||
42 | output: Arc<Any + Send + Sync + 'static>, | ||
43 | output_fingerprint: OutputFingerprint, | ||
44 | deps: Vec<(QueryId, OutputFingerprint)>, | ||
45 | } | ||
46 | |||
47 | impl<T> DbState<T> { | ||
48 | fn record( | ||
49 | &self, | ||
50 | query_id: QueryId, | ||
51 | params: Arc<Any + Send + Sync + 'static>, | ||
52 | output: Arc<Any + Send + Sync + 'static>, | ||
53 | output_fingerprint: OutputFingerprint, | ||
54 | deps: Vec<(QueryId, OutputFingerprint)>, | ||
55 | ) { | ||
56 | let gen = self.gen; | ||
57 | let record = QueryRecord { | ||
58 | params, | ||
59 | output, | ||
60 | output_fingerprint, | ||
61 | deps, | ||
62 | }; | ||
63 | self.graph.lock().insert(query_id, (gen, Arc::new(record))); | ||
64 | } | ||
65 | } | ||
66 | |||
67 | impl<T> QueryConfig<T> { | ||
68 | pub fn new() -> Self { | ||
69 | QueryConfig { | ||
70 | ground_fn: HashMap::new(), | ||
71 | query_fn: HashMap::new(), | ||
72 | } | ||
73 | } | ||
74 | pub fn with_ground_query( | ||
75 | mut self, | ||
76 | query_type: QueryTypeId, | ||
77 | query_fn: GroundQueryFn<T> | ||
78 | ) -> Self { | ||
79 | let prev = self.ground_fn.insert(query_type, query_fn); | ||
80 | assert!(prev.is_none()); | ||
81 | self | ||
82 | } | ||
83 | pub fn with_query( | ||
84 | mut self, | ||
85 | query_type: QueryTypeId, | ||
86 | query_fn: QueryFn<T>, | ||
87 | ) -> Self { | ||
88 | let prev = self.query_fn.insert(query_type, query_fn); | ||
89 | assert!(prev.is_none()); | ||
90 | self | ||
91 | } | ||
92 | } | ||
93 | |||
94 | pub struct QueryCtx<T> { | ||
95 | db: Arc<DbState<T>>, | ||
96 | query_config: Arc<QueryConfig<T>>, | ||
97 | stack: RefCell<Vec<Vec<(QueryId, OutputFingerprint)>>>, | ||
98 | executed: RefCell<Vec<QueryTypeId>>, | ||
99 | } | ||
100 | |||
101 | impl<T> QueryCtx<T> { | ||
102 | fn new(db: &Db<T>) -> QueryCtx<T> { | ||
103 | QueryCtx { | ||
104 | db: Arc::clone(&db.db), | ||
105 | query_config: Arc::clone(&db.query_config), | ||
106 | stack: RefCell::new(vec![Vec::new()]), | ||
107 | executed: RefCell::new(Vec::new()), | ||
108 | } | ||
109 | } | ||
110 | pub fn get( | ||
111 | &self, | ||
112 | query_id: QueryId, | ||
113 | params: Arc<Any + Send + Sync + 'static>, | ||
114 | ) -> Arc<Any + Send + Sync + 'static> { | ||
115 | let (res, output_fingerprint) = self.get_inner(query_id, params); | ||
116 | self.record_dep(query_id, output_fingerprint); | ||
117 | res | ||
118 | } | ||
119 | |||
120 | pub fn get_inner( | ||
121 | &self, | ||
122 | query_id: QueryId, | ||
123 | params: Arc<Any + Send + Sync + 'static>, | ||
124 | ) -> (Arc<Any + Send + Sync + 'static>, OutputFingerprint) { | ||
125 | let (gen, record) = { | ||
126 | let guard = self.db.graph.lock(); | ||
127 | match guard.get(&query_id).map(|it| it.clone()){ | ||
128 | None => { | ||
129 | drop(guard); | ||
130 | return self.force(query_id, params); | ||
131 | }, | ||
132 | Some(it) => it, | ||
133 | } | ||
134 | }; | ||
135 | if gen == self.db.gen { | ||
136 | return (record.output.clone(), record.output_fingerprint) | ||
137 | } | ||
138 | if self.query_config.ground_fn.contains_key(&query_id.0) { | ||
139 | return self.force(query_id, params); | ||
140 | } | ||
141 | for (dep_query_id, prev_fingerprint) in record.deps.iter().cloned() { | ||
142 | let dep_params: Arc<Any + Send + Sync + 'static> = { | ||
143 | let guard = self.db.graph.lock(); | ||
144 | guard[&dep_query_id] | ||
145 | .1 | ||
146 | .params | ||
147 | .clone() | ||
148 | }; | ||
149 | if prev_fingerprint != self.get_inner(dep_query_id, dep_params).1 { | ||
150 | return self.force(query_id, params) | ||
151 | } | ||
152 | } | ||
153 | let gen = self.db.gen; | ||
154 | { | ||
155 | let mut guard = self.db.graph.lock(); | ||
156 | guard[&query_id].0 = gen; | ||
157 | } | ||
158 | (record.output.clone(), record.output_fingerprint) | ||
159 | } | ||
160 | fn force( | ||
161 | &self, | ||
162 | query_id: QueryId, | ||
163 | params: Arc<Any + Send + Sync + 'static>, | ||
164 | ) -> (Arc<Any + Send + Sync + 'static>, OutputFingerprint) { | ||
165 | self.executed.borrow_mut().push(query_id.0); | ||
166 | self.stack.borrow_mut().push(Vec::new()); | ||
167 | |||
168 | let (res, output_fingerprint) = if let Some(f) = self.ground_query_fn_by_type(query_id.0) { | ||
169 | f(&self.db.ground_data, &*params) | ||
170 | } else if let Some(f) = self.query_fn_by_type(query_id.0) { | ||
171 | f(self, &*params) | ||
172 | } else { | ||
173 | panic!("unknown query type: {:?}", query_id.0); | ||
174 | }; | ||
175 | |||
176 | let res: Arc<Any + Send + Sync + 'static> = res.into(); | ||
177 | |||
178 | let deps = self.stack.borrow_mut().pop().unwrap(); | ||
179 | self.db.record(query_id, params, res.clone(), output_fingerprint, deps); | ||
180 | (res, output_fingerprint) | ||
181 | } | ||
182 | fn ground_query_fn_by_type(&self, query_type: QueryTypeId) -> Option<GroundQueryFn<T>> { | ||
183 | self.query_config.ground_fn.get(&query_type).map(|&it| it) | ||
184 | } | ||
185 | fn query_fn_by_type(&self, query_type: QueryTypeId) -> Option<QueryFn<T>> { | ||
186 | self.query_config.query_fn.get(&query_type).map(|&it| it) | ||
187 | } | ||
188 | fn record_dep( | ||
189 | &self, | ||
190 | query_id: QueryId, | ||
191 | output_fingerprint: OutputFingerprint, | ||
192 | ) -> () { | ||
193 | let mut stack = self.stack.borrow_mut(); | ||
194 | let deps = stack.last_mut().unwrap(); | ||
195 | deps.push((query_id, output_fingerprint)) | ||
196 | } | ||
197 | } | ||
198 | |||
199 | impl<T> Db<T> { | ||
200 | pub fn new(query_config: QueryConfig<T>, ground_data: T) -> Db<T> { | ||
201 | Db { | ||
202 | db: Arc::new(DbState { ground_data, gen: Gen(0), graph: Default::default() }), | ||
203 | query_config: Arc::new(query_config), | ||
204 | } | ||
205 | } | ||
206 | |||
207 | pub fn with_ground_data(&self, ground_data: T) -> Db<T> { | ||
208 | let gen = Gen(self.db.gen.0 + 1); | ||
209 | let graph = self.db.graph.lock().clone(); | ||
210 | let graph = Mutex::new(graph); | ||
211 | Db { | ||
212 | db: Arc::new(DbState { ground_data, gen, graph }), | ||
213 | query_config: Arc::clone(&self.query_config) | ||
214 | } | ||
215 | } | ||
216 | pub fn get( | ||
217 | &self, | ||
218 | query_id: QueryId, | ||
219 | params: Box<Any + Send + Sync + 'static>, | ||
220 | ) -> (Arc<Any + Send + Sync + 'static>, Vec<QueryTypeId>) { | ||
221 | let ctx = QueryCtx::new(self); | ||
222 | let res = ctx.get(query_id, params.into()); | ||
223 | let executed = ::std::mem::replace(&mut *ctx.executed.borrow_mut(), Vec::new()); | ||
224 | (res, executed) | ||
225 | } | ||
226 | } | ||
227 | |||
228 | #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] | ||
229 | struct Gen(u64); | ||
230 | #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] | ||
231 | pub struct InputFingerprint(pub u64); | ||
232 | #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] | ||
233 | pub struct OutputFingerprint(pub u64); | ||
234 | #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] | ||
235 | pub struct QueryTypeId(pub u16); | ||
236 | #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] | ||
237 | pub struct QueryId(pub QueryTypeId, pub InputFingerprint); | ||
238 | |||