aboutsummaryrefslogtreecommitdiff
path: root/crates/salsa/tests
diff options
context:
space:
mode:
Diffstat (limited to 'crates/salsa/tests')
-rw-r--r--crates/salsa/tests/integration.rs170
1 files changed, 170 insertions, 0 deletions
diff --git a/crates/salsa/tests/integration.rs b/crates/salsa/tests/integration.rs
new file mode 100644
index 000000000..aed9219be
--- /dev/null
+++ b/crates/salsa/tests/integration.rs
@@ -0,0 +1,170 @@
1extern crate salsa;
2use std::{
3 iter::once,
4 sync::Arc,
5 collections::hash_map::{HashMap, DefaultHasher},
6 any::Any,
7 hash::{Hash, Hasher},
8};
9
10type State = HashMap<u32, String>;
11type Data = Arc<Any + Send + Sync + 'static>;
12const GET_TEXT: salsa::QueryTypeId = salsa::QueryTypeId(1);
13const GET_FILES: salsa::QueryTypeId = salsa::QueryTypeId(2);
14const FILE_NEWLINES: salsa::QueryTypeId = salsa::QueryTypeId(3);
15const TOTAL_NEWLINES: salsa::QueryTypeId = salsa::QueryTypeId(4);
16
17fn mk_ground_query<T, R>(
18 state: &State,
19 params: &Data,
20 f: fn(&State, &T) -> R,
21) -> (Data, salsa::OutputFingerprint)
22where
23 T: 'static,
24 R: Hash + Send + Sync + 'static,
25{
26 let params = params.downcast_ref().unwrap();
27 let result = f(state, params);
28 let fingerprint = o_print(&result);
29 (Arc::new(result), fingerprint)
30}
31
32fn get<T, R>(db: &salsa::Db<State, Data>, query_type: salsa::QueryTypeId, param: T) -> (Arc<R>, Vec<salsa::QueryTypeId>)
33where
34 T: Hash + Send + Sync + 'static,
35 R: Send + Sync + 'static,
36{
37 let i_print = i_print(&param);
38 let param = Arc::new(param);
39 let (res, trace) = db.get(salsa::QueryId(query_type, i_print), param);
40 (res.downcast().unwrap(), trace)
41}
42
43struct QueryCtx<'a>(&'a salsa::QueryCtx<State, Data>);
44
45impl<'a> QueryCtx<'a> {
46 fn get_text(&self, id: u32) -> Arc<String> {
47 let i_print = i_print(&id);
48 let text = self.0.get(salsa::QueryId(GET_TEXT, i_print), Arc::new(id));
49 text.downcast().unwrap()
50 }
51 fn get_files(&self) -> Arc<Vec<u32>> {
52 let i_print = i_print(&());
53 let files = self.0.get(salsa::QueryId(GET_FILES, i_print), Arc::new(()));
54 let res = files.downcast().unwrap();
55 res
56 }
57 fn get_n_lines(&self, id: u32) -> usize {
58 let i_print = i_print(&id);
59 let n_lines = self.0.get(salsa::QueryId(FILE_NEWLINES, i_print), Arc::new(id));
60 *n_lines.downcast().unwrap()
61 }
62}
63
64fn mk_query<T, R>(
65 query_ctx: &salsa::QueryCtx<State, Data>,
66 params: &Data,
67 f: fn(QueryCtx, &T) -> R,
68) -> (Data, salsa::OutputFingerprint)
69where
70 T: 'static,
71 R: Hash + Send + Sync + 'static,
72{
73 let params: &T = params.downcast_ref().unwrap();
74 let query_ctx = QueryCtx(query_ctx);
75 let result = f(query_ctx, params);
76 let fingerprint = o_print(&result);
77 (Arc::new(result), fingerprint)
78}
79
80fn mk_queries() -> salsa::QueryConfig<State, Data> {
81 salsa::QueryConfig::<State, Data>::new()
82 .with_ground_query(GET_TEXT, Box::new(|state, id| {
83 mk_ground_query::<u32, String>(state, id, |state, id| state[id].clone())
84 }))
85 .with_ground_query(GET_FILES, Box::new(|state, id| {
86 mk_ground_query::<(), Vec<u32>>(state, id, |state, &()| state.keys().cloned().collect())
87 }))
88 .with_query(FILE_NEWLINES, Box::new(|query_ctx, id| {
89 mk_query(query_ctx, id, |query_ctx, &id| {
90 let text = query_ctx.get_text(id);
91 text.lines().count()
92 })
93 }))
94 .with_query(TOTAL_NEWLINES, Box::new(|query_ctx, id| {
95 mk_query(query_ctx, id, |query_ctx, &()| {
96 let mut total = 0;
97 for &id in query_ctx.get_files().iter() {
98 total += query_ctx.get_n_lines(id)
99 }
100 total
101 })
102 }))
103}
104
105#[test]
106fn test_number_of_lines() {
107 let mut state = State::new();
108 let db = salsa::Db::new(mk_queries(), state.clone());
109 let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ());
110 assert_eq!(*newlines, 0);
111 assert_eq!(trace.len(), 2);
112 let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ());
113 assert_eq!(*newlines, 0);
114 assert_eq!(trace.len(), 0);
115
116 state.insert(1, "hello\nworld".to_string());
117 let mut inv = salsa::Invalidations::new();
118 inv.invalidate(GET_TEXT, once(i_print(&1u32)));
119 inv.invalidate(GET_FILES, once(i_print(&())));
120 let db = db.with_ground_data(state.clone(), inv);
121 let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ());
122 assert_eq!(*newlines, 2);
123 assert_eq!(trace.len(), 4);
124
125 state.insert(2, "spam\neggs".to_string());
126 let mut inv = salsa::Invalidations::new();
127 inv.invalidate(GET_TEXT, once(i_print(&2u32)));
128 inv.invalidate(GET_FILES, once(i_print(&())));
129 let db = db.with_ground_data(state.clone(), inv);
130 let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ());
131 assert_eq!(*newlines, 4);
132 assert_eq!(trace.len(), 4);
133
134 let mut invs = vec![];
135 for i in 0..10 {
136 let id = i + 10;
137 invs.push(i_print(&id));
138 state.insert(id, "spam".to_string());
139 }
140 let mut inv = salsa::Invalidations::new();
141 inv.invalidate(GET_TEXT, invs.into_iter());
142 inv.invalidate(GET_FILES, once(i_print(&())));
143 let db = db.with_ground_data(state.clone(), inv);
144 let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ());
145 assert_eq!(*newlines, 14);
146 assert_eq!(trace.len(), 22);
147
148 state.insert(15, String::new());
149 let mut inv = salsa::Invalidations::new();
150 inv.invalidate(GET_TEXT, once(i_print(&15u32)));
151 inv.invalidate(GET_FILES, once(i_print(&())));
152 let db = db.with_ground_data(state.clone(), inv);
153 let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ());
154 assert_eq!(*newlines, 13);
155 assert_eq!(trace.len(), 4);
156}
157
158fn o_print<T: Hash>(x: &T) -> salsa::OutputFingerprint {
159 let mut hasher = DefaultHasher::new();
160 x.hash(&mut hasher);
161 let hash = hasher.finish();
162 salsa::OutputFingerprint(hash)
163}
164
165fn i_print<T: Hash>(x: &T) -> salsa::InputFingerprint {
166 let mut hasher = DefaultHasher::new();
167 x.hash(&mut hasher);
168 let hash = hasher.finish();
169 salsa::InputFingerprint(hash)
170}