diff options
Diffstat (limited to 'crates/salsa/src/lib.rs')
-rw-r--r-- | crates/salsa/src/lib.rs | 293 |
1 files changed, 0 insertions, 293 deletions
diff --git a/crates/salsa/src/lib.rs b/crates/salsa/src/lib.rs deleted file mode 100644 index 35deed374..000000000 --- a/crates/salsa/src/lib.rs +++ /dev/null | |||
@@ -1,293 +0,0 @@ | |||
1 | extern crate im; | ||
2 | extern crate parking_lot; | ||
3 | |||
4 | use std::{ | ||
5 | sync::Arc, | ||
6 | collections::{HashSet, HashMap}, | ||
7 | cell::RefCell, | ||
8 | }; | ||
9 | use parking_lot::Mutex; | ||
10 | |||
11 | pub type GroundQueryFn<T, D> = Box<Fn(&T, &D) -> (D, OutputFingerprint) + Send + Sync + 'static>; | ||
12 | pub type QueryFn<T, D> = Box<Fn(&QueryCtx<T, D>, &D) -> (D, OutputFingerprint) + Send + Sync + 'static>; | ||
13 | |||
14 | #[derive(Debug)] | ||
15 | pub struct Db<T, D> { | ||
16 | db: Arc<DbState<T, D>>, | ||
17 | query_config: Arc<QueryConfig<T, D>>, | ||
18 | } | ||
19 | |||
20 | pub struct QueryConfig<T, D> { | ||
21 | ground_fn: HashMap<QueryTypeId, GroundQueryFn<T, D>>, | ||
22 | query_fn: HashMap<QueryTypeId, QueryFn<T, D>>, | ||
23 | } | ||
24 | |||
25 | impl<T, D> ::std::fmt::Debug for QueryConfig<T, D> { | ||
26 | fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { | ||
27 | ::std::fmt::Display::fmt("QueryConfig { ... }", f) | ||
28 | } | ||
29 | } | ||
30 | |||
31 | #[derive(Debug)] | ||
32 | struct DbState<T, D> { | ||
33 | ground_data: T, | ||
34 | gen: Gen, | ||
35 | graph: Mutex<im::HashMap<QueryId, (Gen, Arc<QueryRecord<D>>)>>, | ||
36 | } | ||
37 | |||
38 | #[derive(Debug)] | ||
39 | struct QueryRecord<D> { | ||
40 | params: D, | ||
41 | output: D, | ||
42 | output_fingerprint: OutputFingerprint, | ||
43 | deps: Vec<(QueryId, OutputFingerprint)>, | ||
44 | } | ||
45 | |||
46 | impl<T, D> DbState<T, D> { | ||
47 | fn record( | ||
48 | &self, | ||
49 | query_id: QueryId, | ||
50 | params: D, | ||
51 | output: D, | ||
52 | output_fingerprint: OutputFingerprint, | ||
53 | deps: Vec<(QueryId, OutputFingerprint)>, | ||
54 | ) { | ||
55 | let gen = self.gen; | ||
56 | let record = QueryRecord { | ||
57 | params, | ||
58 | output, | ||
59 | output_fingerprint, | ||
60 | deps, | ||
61 | }; | ||
62 | self.graph.lock().insert(query_id, (gen, Arc::new(record))); | ||
63 | } | ||
64 | } | ||
65 | |||
66 | impl<T, D> QueryConfig<T, D> { | ||
67 | pub fn new() -> Self { | ||
68 | QueryConfig { | ||
69 | ground_fn: HashMap::new(), | ||
70 | query_fn: HashMap::new(), | ||
71 | } | ||
72 | } | ||
73 | pub fn with_ground_query( | ||
74 | mut self, | ||
75 | query_type: QueryTypeId, | ||
76 | query_fn: GroundQueryFn<T, D> | ||
77 | ) -> Self { | ||
78 | let prev = self.ground_fn.insert(query_type, query_fn); | ||
79 | assert!(prev.is_none()); | ||
80 | self | ||
81 | } | ||
82 | pub fn with_query( | ||
83 | mut self, | ||
84 | query_type: QueryTypeId, | ||
85 | query_fn: QueryFn<T, D>, | ||
86 | ) -> Self { | ||
87 | let prev = self.query_fn.insert(query_type, query_fn); | ||
88 | assert!(prev.is_none()); | ||
89 | self | ||
90 | } | ||
91 | } | ||
92 | |||
93 | pub struct QueryCtx<T, D> { | ||
94 | db: Arc<DbState<T, D>>, | ||
95 | query_config: Arc<QueryConfig<T, D>>, | ||
96 | stack: RefCell<Vec<Vec<(QueryId, OutputFingerprint)>>>, | ||
97 | executed: RefCell<Vec<QueryTypeId>>, | ||
98 | } | ||
99 | |||
100 | impl<T, D> QueryCtx<T, D> | ||
101 | where | ||
102 | D: Clone | ||
103 | { | ||
104 | fn new(db: &Db<T, D>) -> QueryCtx<T, D> { | ||
105 | QueryCtx { | ||
106 | db: Arc::clone(&db.db), | ||
107 | query_config: Arc::clone(&db.query_config), | ||
108 | stack: RefCell::new(vec![Vec::new()]), | ||
109 | executed: RefCell::new(Vec::new()), | ||
110 | } | ||
111 | } | ||
112 | pub fn get( | ||
113 | &self, | ||
114 | query_id: QueryId, | ||
115 | params: D, | ||
116 | ) -> D { | ||
117 | let (res, output_fingerprint) = self.get_inner(query_id, params); | ||
118 | self.record_dep(query_id, output_fingerprint); | ||
119 | res | ||
120 | } | ||
121 | pub fn trace(&self) -> Vec<QueryTypeId> { | ||
122 | ::std::mem::replace(&mut *self.executed.borrow_mut(), Vec::new()) | ||
123 | } | ||
124 | |||
125 | fn get_inner( | ||
126 | &self, | ||
127 | query_id: QueryId, | ||
128 | params: D, | ||
129 | ) -> (D, OutputFingerprint) { | ||
130 | let (gen, record) = { | ||
131 | let guard = self.db.graph.lock(); | ||
132 | match guard.get(&query_id).map(|it| it.clone()){ | ||
133 | None => { | ||
134 | drop(guard); | ||
135 | return self.force(query_id, params); | ||
136 | }, | ||
137 | Some(it) => it, | ||
138 | } | ||
139 | }; | ||
140 | if gen == self.db.gen { | ||
141 | return (record.output.clone(), record.output_fingerprint) | ||
142 | } | ||
143 | if self.query_config.ground_fn.contains_key(&query_id.0) { | ||
144 | let (invalidated, record) = { | ||
145 | let guard = self.db.graph.lock(); | ||
146 | let (gen, ref record) = guard[&query_id]; | ||
147 | (gen == INVALIDATED, record.clone()) | ||
148 | }; | ||
149 | if invalidated { | ||
150 | return self.force(query_id, params); | ||
151 | } else { | ||
152 | return (record.output.clone(), record.output_fingerprint); | ||
153 | } | ||
154 | } | ||
155 | for (dep_query_id, prev_fingerprint) in record.deps.iter().cloned() { | ||
156 | let dep_params: D = { | ||
157 | let guard = self.db.graph.lock(); | ||
158 | guard[&dep_query_id] | ||
159 | .1 | ||
160 | .params | ||
161 | .clone() | ||
162 | }; | ||
163 | if prev_fingerprint != self.get_inner(dep_query_id, dep_params).1 { | ||
164 | return self.force(query_id, params) | ||
165 | } | ||
166 | } | ||
167 | let gen = self.db.gen; | ||
168 | { | ||
169 | let mut guard = self.db.graph.lock(); | ||
170 | guard[&query_id].0 = gen; | ||
171 | } | ||
172 | (record.output.clone(), record.output_fingerprint) | ||
173 | } | ||
174 | fn force( | ||
175 | &self, | ||
176 | query_id: QueryId, | ||
177 | params: D, | ||
178 | ) -> (D, OutputFingerprint) { | ||
179 | self.executed.borrow_mut().push(query_id.0); | ||
180 | self.stack.borrow_mut().push(Vec::new()); | ||
181 | |||
182 | let (res, output_fingerprint) = if let Some(f) = self.query_config.ground_fn.get(&query_id.0) { | ||
183 | f(&self.db.ground_data, ¶ms) | ||
184 | } else if let Some(f) = self.query_config.query_fn.get(&query_id.0) { | ||
185 | f(self, ¶ms) | ||
186 | } else { | ||
187 | panic!("unknown query type: {:?}", query_id.0); | ||
188 | }; | ||
189 | |||
190 | let res: D = res.into(); | ||
191 | |||
192 | let deps = self.stack.borrow_mut().pop().unwrap(); | ||
193 | self.db.record(query_id, params, res.clone(), output_fingerprint, deps); | ||
194 | (res, output_fingerprint) | ||
195 | } | ||
196 | fn record_dep( | ||
197 | &self, | ||
198 | query_id: QueryId, | ||
199 | output_fingerprint: OutputFingerprint, | ||
200 | ) -> () { | ||
201 | let mut stack = self.stack.borrow_mut(); | ||
202 | let deps = stack.last_mut().unwrap(); | ||
203 | deps.push((query_id, output_fingerprint)) | ||
204 | } | ||
205 | } | ||
206 | |||
207 | pub struct Invalidations { | ||
208 | types: HashSet<QueryTypeId>, | ||
209 | ids: Vec<QueryId>, | ||
210 | } | ||
211 | |||
212 | impl Invalidations { | ||
213 | pub fn new() -> Invalidations { | ||
214 | Invalidations { | ||
215 | types: HashSet::new(), | ||
216 | ids: Vec::new(), | ||
217 | } | ||
218 | } | ||
219 | pub fn invalidate( | ||
220 | &mut self, | ||
221 | query_type: QueryTypeId, | ||
222 | params: impl Iterator<Item=InputFingerprint>, | ||
223 | ) { | ||
224 | self.types.insert(query_type); | ||
225 | self.ids.extend(params.map(|it| QueryId(query_type, it))) | ||
226 | } | ||
227 | } | ||
228 | |||
229 | impl<T, D> Db<T, D> | ||
230 | where | ||
231 | D: Clone | ||
232 | { | ||
233 | pub fn new(query_config: QueryConfig<T, D>, ground_data: T) -> Db<T, D> { | ||
234 | Db { | ||
235 | db: Arc::new(DbState { ground_data, gen: Gen(0), graph: Default::default() }), | ||
236 | query_config: Arc::new(query_config), | ||
237 | } | ||
238 | } | ||
239 | pub fn ground_data(&self) -> &T { | ||
240 | &self.db.ground_data | ||
241 | } | ||
242 | pub fn with_ground_data( | ||
243 | &self, | ||
244 | ground_data: T, | ||
245 | invalidations: Invalidations, | ||
246 | ) -> Db<T, D> { | ||
247 | for id in self.query_config.ground_fn.keys() { | ||
248 | assert!( | ||
249 | invalidations.types.contains(id), | ||
250 | "all ground queries must be invalidated" | ||
251 | ); | ||
252 | } | ||
253 | |||
254 | let gen = Gen(self.db.gen.0 + 1); | ||
255 | let mut graph = self.db.graph.lock().clone(); | ||
256 | for id in invalidations.ids { | ||
257 | if let Some((gen, _)) = graph.get_mut(&id) { | ||
258 | *gen = INVALIDATED; | ||
259 | } | ||
260 | } | ||
261 | let graph = Mutex::new(graph); | ||
262 | Db { | ||
263 | db: Arc::new(DbState { ground_data, gen, graph }), | ||
264 | query_config: Arc::clone(&self.query_config) | ||
265 | } | ||
266 | } | ||
267 | pub fn query_ctx(&self) -> QueryCtx<T, D> { | ||
268 | QueryCtx::new(self) | ||
269 | } | ||
270 | pub fn get( | ||
271 | &self, | ||
272 | query_id: QueryId, | ||
273 | params: D, | ||
274 | ) -> (D, Vec<QueryTypeId>) { | ||
275 | let ctx = self.query_ctx(); | ||
276 | let res = ctx.get(query_id, params.into()); | ||
277 | let executed = ::std::mem::replace(&mut *ctx.executed.borrow_mut(), Vec::new()); | ||
278 | (res, executed) | ||
279 | } | ||
280 | } | ||
281 | |||
282 | #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] | ||
283 | struct Gen(u64); | ||
284 | const INVALIDATED: Gen = Gen(!0); | ||
285 | #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] | ||
286 | pub struct InputFingerprint(pub u64); | ||
287 | #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] | ||
288 | pub struct OutputFingerprint(pub u64); | ||
289 | #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] | ||
290 | pub struct QueryTypeId(pub u16); | ||
291 | #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] | ||
292 | pub struct QueryId(pub QueryTypeId, pub InputFingerprint); | ||
293 | |||