diff options
-rw-r--r-- | crates/salsa/src/lib.rs | 56 | ||||
-rw-r--r-- | crates/salsa/tests/integration.rs | 32 |
2 files changed, 76 insertions, 12 deletions
diff --git a/crates/salsa/src/lib.rs b/crates/salsa/src/lib.rs index a54f2a06f..5de3c7774 100644 --- a/crates/salsa/src/lib.rs +++ b/crates/salsa/src/lib.rs | |||
@@ -3,7 +3,7 @@ extern crate parking_lot; | |||
3 | 3 | ||
4 | use std::{ | 4 | use std::{ |
5 | sync::Arc, | 5 | sync::Arc, |
6 | collections::HashMap, | 6 | collections::{HashSet, HashMap}, |
7 | cell::RefCell, | 7 | cell::RefCell, |
8 | }; | 8 | }; |
9 | use parking_lot::Mutex; | 9 | use parking_lot::Mutex; |
@@ -138,7 +138,16 @@ where | |||
138 | return (record.output.clone(), record.output_fingerprint) | 138 | return (record.output.clone(), record.output_fingerprint) |
139 | } | 139 | } |
140 | if self.query_config.ground_fn.contains_key(&query_id.0) { | 140 | if self.query_config.ground_fn.contains_key(&query_id.0) { |
141 | return self.force(query_id, params); | 141 | let (invalidated, record) = { |
142 | let guard = self.db.graph.lock(); | ||
143 | let (gen, ref record) = guard[&query_id]; | ||
144 | (gen == INVALIDATED, record.clone()) | ||
145 | }; | ||
146 | if invalidated { | ||
147 | return self.force(query_id, params); | ||
148 | } else { | ||
149 | return (record.output.clone(), record.output_fingerprint); | ||
150 | } | ||
142 | } | 151 | } |
143 | for (dep_query_id, prev_fingerprint) in record.deps.iter().cloned() { | 152 | for (dep_query_id, prev_fingerprint) in record.deps.iter().cloned() { |
144 | let dep_params: D = { | 153 | let dep_params: D = { |
@@ -198,6 +207,28 @@ where | |||
198 | } | 207 | } |
199 | } | 208 | } |
200 | 209 | ||
210 | pub struct Invalidations { | ||
211 | types: HashSet<QueryTypeId>, | ||
212 | ids: Vec<QueryId>, | ||
213 | } | ||
214 | |||
215 | impl Invalidations { | ||
216 | pub fn new() -> Invalidations { | ||
217 | Invalidations { | ||
218 | types: HashSet::new(), | ||
219 | ids: Vec::new(), | ||
220 | } | ||
221 | } | ||
222 | pub fn invalidate( | ||
223 | &mut self, | ||
224 | query_type: QueryTypeId, | ||
225 | params: impl Iterator<Item=InputFingerprint>, | ||
226 | ) { | ||
227 | self.types.insert(query_type); | ||
228 | self.ids.extend(params.map(|it| QueryId(query_type, it))) | ||
229 | } | ||
230 | } | ||
231 | |||
201 | impl<T, D> Db<T, D> | 232 | impl<T, D> Db<T, D> |
202 | where | 233 | where |
203 | D: Clone | 234 | D: Clone |
@@ -209,9 +240,25 @@ where | |||
209 | } | 240 | } |
210 | } | 241 | } |
211 | 242 | ||
212 | pub fn with_ground_data(&self, ground_data: T) -> Db<T, D> { | 243 | pub fn with_ground_data( |
244 | &self, | ||
245 | ground_data: T, | ||
246 | invalidations: Invalidations, | ||
247 | ) -> Db<T, D> { | ||
248 | for id in self.query_config.ground_fn.keys() { | ||
249 | assert!( | ||
250 | invalidations.types.contains(id), | ||
251 | "all ground queries must be invalidated" | ||
252 | ); | ||
253 | } | ||
254 | |||
213 | let gen = Gen(self.db.gen.0 + 1); | 255 | let gen = Gen(self.db.gen.0 + 1); |
214 | let graph = self.db.graph.lock().clone(); | 256 | let mut graph = self.db.graph.lock().clone(); |
257 | for id in invalidations.ids { | ||
258 | if let Some((gen, _)) = graph.get_mut(&id) { | ||
259 | *gen = INVALIDATED; | ||
260 | } | ||
261 | } | ||
215 | let graph = Mutex::new(graph); | 262 | let graph = Mutex::new(graph); |
216 | Db { | 263 | Db { |
217 | db: Arc::new(DbState { ground_data, gen, graph }), | 264 | db: Arc::new(DbState { ground_data, gen, graph }), |
@@ -232,6 +279,7 @@ where | |||
232 | 279 | ||
233 | #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] | 280 | #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] |
234 | struct Gen(u64); | 281 | struct Gen(u64); |
282 | const INVALIDATED: Gen = Gen(!0); | ||
235 | #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] | 283 | #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] |
236 | pub struct InputFingerprint(pub u64); | 284 | pub struct InputFingerprint(pub u64); |
237 | #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] | 285 | #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] |
diff --git a/crates/salsa/tests/integration.rs b/crates/salsa/tests/integration.rs index 2872d3913..3cec330e6 100644 --- a/crates/salsa/tests/integration.rs +++ b/crates/salsa/tests/integration.rs | |||
@@ -1,5 +1,6 @@ | |||
1 | extern crate salsa; | 1 | extern crate salsa; |
2 | use std::{ | 2 | use std::{ |
3 | iter::once, | ||
3 | sync::Arc, | 4 | sync::Arc, |
4 | collections::hash_map::{HashMap, DefaultHasher}, | 5 | collections::hash_map::{HashMap, DefaultHasher}, |
5 | any::Any, | 6 | any::Any, |
@@ -113,30 +114,45 @@ fn test_number_of_lines() { | |||
113 | assert_eq!(trace.len(), 0); | 114 | assert_eq!(trace.len(), 0); |
114 | 115 | ||
115 | state.insert(1, "hello\nworld".to_string()); | 116 | state.insert(1, "hello\nworld".to_string()); |
116 | let db = db.with_ground_data(state.clone()); | 117 | let mut inv = salsa::Invalidations::new(); |
118 | inv.invalidate(GET_TEXT, once(i_print(&1u32))); | ||
119 | inv.invalidate(GET_FILES, once(i_print(&()))); | ||
120 | let db = db.with_ground_data(state.clone(), inv); | ||
117 | let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ()); | 121 | let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ()); |
118 | assert_eq!(*newlines, 2); | 122 | assert_eq!(*newlines, 2); |
119 | assert_eq!(trace.len(), 4); | 123 | assert_eq!(trace.len(), 4); |
120 | 124 | ||
121 | state.insert(2, "spam\neggs".to_string()); | 125 | state.insert(2, "spam\neggs".to_string()); |
122 | let db = db.with_ground_data(state.clone()); | 126 | let mut inv = salsa::Invalidations::new(); |
127 | inv.invalidate(GET_TEXT, once(i_print(&2u32))); | ||
128 | inv.invalidate(GET_FILES, once(i_print(&()))); | ||
129 | let db = db.with_ground_data(state.clone(), inv); | ||
123 | let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ()); | 130 | let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ()); |
124 | assert_eq!(*newlines, 4); | 131 | assert_eq!(*newlines, 4); |
125 | assert_eq!(trace.len(), 5); | 132 | assert_eq!(trace.len(), 4); |
126 | 133 | ||
134 | let mut invs = vec![]; | ||
127 | for i in 0..10 { | 135 | for i in 0..10 { |
128 | state.insert(i + 10, "spam".to_string()); | 136 | let id = i + 10; |
137 | invs.push(i_print(&id)); | ||
138 | state.insert(id, "spam".to_string()); | ||
129 | } | 139 | } |
130 | let db = db.with_ground_data(state.clone()); | 140 | let mut inv = salsa::Invalidations::new(); |
141 | inv.invalidate(GET_TEXT, invs.into_iter()); | ||
142 | inv.invalidate(GET_FILES, once(i_print(&()))); | ||
143 | let db = db.with_ground_data(state.clone(), inv); | ||
131 | let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ()); | 144 | let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ()); |
132 | assert_eq!(*newlines, 14); | 145 | assert_eq!(*newlines, 14); |
133 | assert_eq!(trace.len(), 24); | 146 | assert_eq!(trace.len(), 22); |
134 | 147 | ||
135 | state.insert(15, String::new()); | 148 | state.insert(15, String::new()); |
136 | let db = db.with_ground_data(state.clone()); | 149 | let mut inv = salsa::Invalidations::new(); |
150 | inv.invalidate(GET_TEXT, once(i_print(&15u32))); | ||
151 | inv.invalidate(GET_FILES, once(i_print(&()))); | ||
152 | let db = db.with_ground_data(state.clone(), inv); | ||
137 | let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ()); | 153 | let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ()); |
138 | assert_eq!(*newlines, 13); | 154 | assert_eq!(*newlines, 13); |
139 | assert_eq!(trace.len(), 15); | 155 | assert_eq!(trace.len(), 4); |
140 | } | 156 | } |
141 | 157 | ||
142 | fn o_print<T: Hash>(x: &T) -> salsa::OutputFingerprint { | 158 | fn o_print<T: Hash>(x: &T) -> salsa::OutputFingerprint { |