aboutsummaryrefslogtreecommitdiff
path: root/crates/salsa/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'crates/salsa/src/lib.rs')
-rw-r--r--crates/salsa/src/lib.rs293
1 files changed, 0 insertions, 293 deletions
diff --git a/crates/salsa/src/lib.rs b/crates/salsa/src/lib.rs
deleted file mode 100644
index 35deed374..000000000
--- a/crates/salsa/src/lib.rs
+++ /dev/null
@@ -1,293 +0,0 @@
1extern crate im;
2extern crate parking_lot;
3
4use std::{
5 sync::Arc,
6 collections::{HashSet, HashMap},
7 cell::RefCell,
8};
9use parking_lot::Mutex;
10
11pub type GroundQueryFn<T, D> = Box<Fn(&T, &D) -> (D, OutputFingerprint) + Send + Sync + 'static>;
12pub type QueryFn<T, D> = Box<Fn(&QueryCtx<T, D>, &D) -> (D, OutputFingerprint) + Send + Sync + 'static>;
13
14#[derive(Debug)]
15pub struct Db<T, D> {
16 db: Arc<DbState<T, D>>,
17 query_config: Arc<QueryConfig<T, D>>,
18}
19
20pub struct QueryConfig<T, D> {
21 ground_fn: HashMap<QueryTypeId, GroundQueryFn<T, D>>,
22 query_fn: HashMap<QueryTypeId, QueryFn<T, D>>,
23}
24
25impl<T, D> ::std::fmt::Debug for QueryConfig<T, D> {
26 fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
27 ::std::fmt::Display::fmt("QueryConfig { ... }", f)
28 }
29}
30
31#[derive(Debug)]
32struct DbState<T, D> {
33 ground_data: T,
34 gen: Gen,
35 graph: Mutex<im::HashMap<QueryId, (Gen, Arc<QueryRecord<D>>)>>,
36}
37
38#[derive(Debug)]
39struct QueryRecord<D> {
40 params: D,
41 output: D,
42 output_fingerprint: OutputFingerprint,
43 deps: Vec<(QueryId, OutputFingerprint)>,
44}
45
46impl<T, D> DbState<T, D> {
47 fn record(
48 &self,
49 query_id: QueryId,
50 params: D,
51 output: D,
52 output_fingerprint: OutputFingerprint,
53 deps: Vec<(QueryId, OutputFingerprint)>,
54 ) {
55 let gen = self.gen;
56 let record = QueryRecord {
57 params,
58 output,
59 output_fingerprint,
60 deps,
61 };
62 self.graph.lock().insert(query_id, (gen, Arc::new(record)));
63 }
64}
65
66impl<T, D> QueryConfig<T, D> {
67 pub fn new() -> Self {
68 QueryConfig {
69 ground_fn: HashMap::new(),
70 query_fn: HashMap::new(),
71 }
72 }
73 pub fn with_ground_query(
74 mut self,
75 query_type: QueryTypeId,
76 query_fn: GroundQueryFn<T, D>
77 ) -> Self {
78 let prev = self.ground_fn.insert(query_type, query_fn);
79 assert!(prev.is_none());
80 self
81 }
82 pub fn with_query(
83 mut self,
84 query_type: QueryTypeId,
85 query_fn: QueryFn<T, D>,
86 ) -> Self {
87 let prev = self.query_fn.insert(query_type, query_fn);
88 assert!(prev.is_none());
89 self
90 }
91}
92
93pub struct QueryCtx<T, D> {
94 db: Arc<DbState<T, D>>,
95 query_config: Arc<QueryConfig<T, D>>,
96 stack: RefCell<Vec<Vec<(QueryId, OutputFingerprint)>>>,
97 executed: RefCell<Vec<QueryTypeId>>,
98}
99
100impl<T, D> QueryCtx<T, D>
101where
102 D: Clone
103{
104 fn new(db: &Db<T, D>) -> QueryCtx<T, D> {
105 QueryCtx {
106 db: Arc::clone(&db.db),
107 query_config: Arc::clone(&db.query_config),
108 stack: RefCell::new(vec![Vec::new()]),
109 executed: RefCell::new(Vec::new()),
110 }
111 }
112 pub fn get(
113 &self,
114 query_id: QueryId,
115 params: D,
116 ) -> D {
117 let (res, output_fingerprint) = self.get_inner(query_id, params);
118 self.record_dep(query_id, output_fingerprint);
119 res
120 }
121 pub fn trace(&self) -> Vec<QueryTypeId> {
122 ::std::mem::replace(&mut *self.executed.borrow_mut(), Vec::new())
123 }
124
125 fn get_inner(
126 &self,
127 query_id: QueryId,
128 params: D,
129 ) -> (D, OutputFingerprint) {
130 let (gen, record) = {
131 let guard = self.db.graph.lock();
132 match guard.get(&query_id).map(|it| it.clone()){
133 None => {
134 drop(guard);
135 return self.force(query_id, params);
136 },
137 Some(it) => it,
138 }
139 };
140 if gen == self.db.gen {
141 return (record.output.clone(), record.output_fingerprint)
142 }
143 if self.query_config.ground_fn.contains_key(&query_id.0) {
144 let (invalidated, record) = {
145 let guard = self.db.graph.lock();
146 let (gen, ref record) = guard[&query_id];
147 (gen == INVALIDATED, record.clone())
148 };
149 if invalidated {
150 return self.force(query_id, params);
151 } else {
152 return (record.output.clone(), record.output_fingerprint);
153 }
154 }
155 for (dep_query_id, prev_fingerprint) in record.deps.iter().cloned() {
156 let dep_params: D = {
157 let guard = self.db.graph.lock();
158 guard[&dep_query_id]
159 .1
160 .params
161 .clone()
162 };
163 if prev_fingerprint != self.get_inner(dep_query_id, dep_params).1 {
164 return self.force(query_id, params)
165 }
166 }
167 let gen = self.db.gen;
168 {
169 let mut guard = self.db.graph.lock();
170 guard[&query_id].0 = gen;
171 }
172 (record.output.clone(), record.output_fingerprint)
173 }
174 fn force(
175 &self,
176 query_id: QueryId,
177 params: D,
178 ) -> (D, OutputFingerprint) {
179 self.executed.borrow_mut().push(query_id.0);
180 self.stack.borrow_mut().push(Vec::new());
181
182 let (res, output_fingerprint) = if let Some(f) = self.query_config.ground_fn.get(&query_id.0) {
183 f(&self.db.ground_data, &params)
184 } else if let Some(f) = self.query_config.query_fn.get(&query_id.0) {
185 f(self, &params)
186 } else {
187 panic!("unknown query type: {:?}", query_id.0);
188 };
189
190 let res: D = res.into();
191
192 let deps = self.stack.borrow_mut().pop().unwrap();
193 self.db.record(query_id, params, res.clone(), output_fingerprint, deps);
194 (res, output_fingerprint)
195 }
196 fn record_dep(
197 &self,
198 query_id: QueryId,
199 output_fingerprint: OutputFingerprint,
200 ) -> () {
201 let mut stack = self.stack.borrow_mut();
202 let deps = stack.last_mut().unwrap();
203 deps.push((query_id, output_fingerprint))
204 }
205}
206
207pub struct Invalidations {
208 types: HashSet<QueryTypeId>,
209 ids: Vec<QueryId>,
210}
211
212impl Invalidations {
213 pub fn new() -> Invalidations {
214 Invalidations {
215 types: HashSet::new(),
216 ids: Vec::new(),
217 }
218 }
219 pub fn invalidate(
220 &mut self,
221 query_type: QueryTypeId,
222 params: impl Iterator<Item=InputFingerprint>,
223 ) {
224 self.types.insert(query_type);
225 self.ids.extend(params.map(|it| QueryId(query_type, it)))
226 }
227}
228
229impl<T, D> Db<T, D>
230where
231 D: Clone
232{
233 pub fn new(query_config: QueryConfig<T, D>, ground_data: T) -> Db<T, D> {
234 Db {
235 db: Arc::new(DbState { ground_data, gen: Gen(0), graph: Default::default() }),
236 query_config: Arc::new(query_config),
237 }
238 }
239 pub fn ground_data(&self) -> &T {
240 &self.db.ground_data
241 }
242 pub fn with_ground_data(
243 &self,
244 ground_data: T,
245 invalidations: Invalidations,
246 ) -> Db<T, D> {
247 for id in self.query_config.ground_fn.keys() {
248 assert!(
249 invalidations.types.contains(id),
250 "all ground queries must be invalidated"
251 );
252 }
253
254 let gen = Gen(self.db.gen.0 + 1);
255 let mut graph = self.db.graph.lock().clone();
256 for id in invalidations.ids {
257 if let Some((gen, _)) = graph.get_mut(&id) {
258 *gen = INVALIDATED;
259 }
260 }
261 let graph = Mutex::new(graph);
262 Db {
263 db: Arc::new(DbState { ground_data, gen, graph }),
264 query_config: Arc::clone(&self.query_config)
265 }
266 }
267 pub fn query_ctx(&self) -> QueryCtx<T, D> {
268 QueryCtx::new(self)
269 }
270 pub fn get(
271 &self,
272 query_id: QueryId,
273 params: D,
274 ) -> (D, Vec<QueryTypeId>) {
275 let ctx = self.query_ctx();
276 let res = ctx.get(query_id, params.into());
277 let executed = ::std::mem::replace(&mut *ctx.executed.borrow_mut(), Vec::new());
278 (res, executed)
279 }
280}
281
282#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
283struct Gen(u64);
284const INVALIDATED: Gen = Gen(!0);
285#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
286pub struct InputFingerprint(pub u64);
287#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
288pub struct OutputFingerprint(pub u64);
289#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
290pub struct QueryTypeId(pub u16);
291#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
292pub struct QueryId(pub QueryTypeId, pub InputFingerprint);
293