aboutsummaryrefslogtreecommitdiff
path: root/crates/salsa/tests/integration.rs
diff options
context:
space:
mode:
Diffstat (limited to 'crates/salsa/tests/integration.rs')
-rw-r--r--crates/salsa/tests/integration.rs153
1 files changed, 153 insertions, 0 deletions
diff --git a/crates/salsa/tests/integration.rs b/crates/salsa/tests/integration.rs
new file mode 100644
index 000000000..7241eca38
--- /dev/null
+++ b/crates/salsa/tests/integration.rs
@@ -0,0 +1,153 @@
1extern crate salsa;
2use std::{
3 sync::Arc,
4 collections::hash_map::{HashMap, DefaultHasher},
5 any::Any,
6 hash::{Hash, Hasher},
7};
8
9type State = HashMap<u32, String>;
10const GET_TEXT: salsa::QueryTypeId = salsa::QueryTypeId(1);
11const GET_FILES: salsa::QueryTypeId = salsa::QueryTypeId(2);
12const FILE_NEWLINES: salsa::QueryTypeId = salsa::QueryTypeId(3);
13const TOTAL_NEWLINES: salsa::QueryTypeId = salsa::QueryTypeId(4);
14
15fn mk_ground_query<T, R>(
16 state: &State,
17 params: &(Any + Send + Sync + 'static),
18 f: fn(&State, &T) -> R,
19) -> (Box<Any + Send + Sync + 'static>, salsa::OutputFingerprint)
20where
21 T: 'static,
22 R: Hash + Send + Sync + 'static,
23{
24 let params = params.downcast_ref().unwrap();
25 let result = f(state, params);
26 let fingerprint = o_print(&result);
27 (Box::new(result), fingerprint)
28}
29
30fn get<T, R>(db: &salsa::Db<State>, query_type: salsa::QueryTypeId, param: T) -> (Arc<R>, Vec<salsa::QueryTypeId>)
31where
32 T: Hash + Send + Sync + 'static,
33 R: Send + Sync + 'static,
34{
35 let i_print = i_print(&param);
36 let param = Box::new(param);
37 let (res, trace) = db.get(salsa::QueryId(query_type, i_print), param);
38 (res.downcast().unwrap(), trace)
39}
40
41struct QueryCtx<'a>(&'a salsa::QueryCtx<State>);
42
43impl<'a> QueryCtx<'a> {
44 fn get_text(&self, id: u32) -> Arc<String> {
45 let i_print = i_print(&id);
46 let text = self.0.get(salsa::QueryId(GET_TEXT, i_print), Arc::new(id));
47 text.downcast().unwrap()
48 }
49 fn get_files(&self) -> Arc<Vec<u32>> {
50 let i_print = i_print(&());
51 let files = self.0.get(salsa::QueryId(GET_FILES, i_print), Arc::new(()));
52 let res = files.downcast().unwrap();
53 res
54 }
55 fn get_n_lines(&self, id: u32) -> usize {
56 let i_print = i_print(&id);
57 let n_lines = self.0.get(salsa::QueryId(FILE_NEWLINES, i_print), Arc::new(id));
58 *n_lines.downcast().unwrap()
59 }
60}
61
62fn mk_query<T, R>(
63 query_ctx: &salsa::QueryCtx<State>,
64 params: &(Any + Send + Sync + 'static),
65 f: fn(QueryCtx, &T) -> R,
66) -> (Box<Any + Send + Sync + 'static>, salsa::OutputFingerprint)
67where
68 T: 'static,
69 R: Hash + Send + Sync + 'static,
70{
71 let params: &T = params.downcast_ref().unwrap();
72 let query_ctx = QueryCtx(query_ctx);
73 let result = f(query_ctx, params);
74 let fingerprint = o_print(&result);
75 (Box::new(result), fingerprint)
76}
77
78fn mk_queries() -> salsa::QueryConfig<State> {
79 salsa::QueryConfig::<State>::new()
80 .with_ground_query(GET_TEXT, |state, id| {
81 mk_ground_query::<u32, String>(state, id, |state, id| state[id].clone())
82 })
83 .with_ground_query(GET_FILES, |state, id| {
84 mk_ground_query::<(), Vec<u32>>(state, id, |state, &()| state.keys().cloned().collect())
85 })
86 .with_query(FILE_NEWLINES, |query_ctx, id| {
87 mk_query(query_ctx, id, |query_ctx, &id| {
88 let text = query_ctx.get_text(id);
89 text.lines().count()
90 })
91 })
92 .with_query(TOTAL_NEWLINES, |query_ctx, id| {
93 mk_query(query_ctx, id, |query_ctx, &()| {
94 let mut total = 0;
95 for &id in query_ctx.get_files().iter() {
96 total += query_ctx.get_n_lines(id)
97 }
98 total
99 })
100 })
101}
102
103#[test]
104fn test_number_of_lines() {
105 let mut state = State::new();
106 let db = salsa::Db::new(mk_queries(), state.clone());
107 let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ());
108 assert_eq!(*newlines, 0);
109 assert_eq!(trace.len(), 2);
110 let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ());
111 assert_eq!(*newlines, 0);
112 assert_eq!(trace.len(), 0);
113
114 state.insert(1, "hello\nworld".to_string());
115 let db = db.with_ground_data(state.clone());
116 let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ());
117 assert_eq!(*newlines, 2);
118 assert_eq!(trace.len(), 4);
119
120 state.insert(2, "spam\neggs".to_string());
121 let db = db.with_ground_data(state.clone());
122 let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ());
123 assert_eq!(*newlines, 4);
124 assert_eq!(trace.len(), 5);
125
126 for i in 0..10 {
127 state.insert(i + 10, "spam".to_string());
128 }
129 let db = db.with_ground_data(state.clone());
130 let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ());
131 assert_eq!(*newlines, 14);
132 assert_eq!(trace.len(), 24);
133
134 state.insert(15, String::new());
135 let db = db.with_ground_data(state.clone());
136 let (newlines, trace) = get::<(), usize>(&db, TOTAL_NEWLINES, ());
137 assert_eq!(*newlines, 13);
138 assert_eq!(trace.len(), 15);
139}
140
141fn o_print<T: Hash>(x: &T) -> salsa::OutputFingerprint {
142 let mut hasher = DefaultHasher::new();
143 x.hash(&mut hasher);
144 let hash = hasher.finish();
145 salsa::OutputFingerprint(hash)
146}
147
148fn i_print<T: Hash>(x: &T) -> salsa::InputFingerprint {
149 let mut hasher = DefaultHasher::new();
150 x.hash(&mut hasher);
151 let hash = hasher.finish();
152 salsa::InputFingerprint(hash)
153}