diff options
Diffstat (limited to 'crates/test_utils/src/assert_linear.rs')
-rw-r--r-- | crates/test_utils/src/assert_linear.rs | 112 |
1 files changed, 112 insertions, 0 deletions
diff --git a/crates/test_utils/src/assert_linear.rs b/crates/test_utils/src/assert_linear.rs new file mode 100644 index 000000000..6ecc232e1 --- /dev/null +++ b/crates/test_utils/src/assert_linear.rs | |||
@@ -0,0 +1,112 @@ | |||
1 | //! Checks that a set of measurements looks like a linear function rather than | ||
2 | //! like a quadratic function. Algorithm: | ||
3 | //! | ||
4 | //! 1. Linearly scale input to be in [0; 1) | ||
5 | //! 2. Using linear regression, compute the best linear function approximating | ||
6 | //! the input. | ||
7 | //! 3. Compute RMSE and maximal absolute error. | ||
8 | //! 4. Check that errors are within tolerances and that the constant term is not | ||
9 | //! too negative. | ||
10 | //! | ||
11 | //! Ideally, we should use a proper "model selection" to directly compare | ||
12 | //! quadratic and linear models, but that sounds rather complicated: | ||
13 | //! | ||
14 | //! https://stats.stackexchange.com/questions/21844/selecting-best-model-based-on-linear-quadratic-and-cubic-fit-of-data | ||
15 | //! | ||
16 | //! We might get false positives on a VM, but never false negatives. So, if the | ||
17 | //! first round fails, we repeat the ordeal three more times and fail only if | ||
18 | //! every time there's a fault. | ||
19 | use stdx::format_to; | ||
20 | |||
21 | #[derive(Default)] | ||
22 | pub struct AssertLinear { | ||
23 | rounds: Vec<Round>, | ||
24 | } | ||
25 | |||
26 | #[derive(Default)] | ||
27 | struct Round { | ||
28 | samples: Vec<(f64, f64)>, | ||
29 | plot: String, | ||
30 | linear: bool, | ||
31 | } | ||
32 | |||
33 | impl AssertLinear { | ||
34 | pub fn next_round(&mut self) -> bool { | ||
35 | if let Some(round) = self.rounds.last_mut() { | ||
36 | round.finish(); | ||
37 | } | ||
38 | if self.rounds.iter().any(|it| it.linear) || self.rounds.len() == 4 { | ||
39 | return false; | ||
40 | } | ||
41 | self.rounds.push(Round::default()); | ||
42 | true | ||
43 | } | ||
44 | |||
45 | pub fn sample(&mut self, x: f64, y: f64) { | ||
46 | self.rounds.last_mut().unwrap().samples.push((x, y)) | ||
47 | } | ||
48 | } | ||
49 | |||
50 | impl Drop for AssertLinear { | ||
51 | fn drop(&mut self) { | ||
52 | assert!(!self.rounds.is_empty()); | ||
53 | if self.rounds.iter().all(|it| !it.linear) { | ||
54 | for round in &self.rounds { | ||
55 | eprintln!("\n{}", round.plot); | ||
56 | } | ||
57 | panic!("Doesn't look linear!") | ||
58 | } | ||
59 | } | ||
60 | } | ||
61 | |||
62 | impl Round { | ||
63 | fn finish(&mut self) { | ||
64 | let (mut xs, mut ys): (Vec<_>, Vec<_>) = self.samples.iter().copied().unzip(); | ||
65 | normalize(&mut xs); | ||
66 | normalize(&mut ys); | ||
67 | let xy = xs.iter().copied().zip(ys.iter().copied()); | ||
68 | |||
69 | // Linear regression: finding a and b to fit y = a + b*x. | ||
70 | |||
71 | let mean_x = mean(&xs); | ||
72 | let mean_y = mean(&ys); | ||
73 | |||
74 | let b = { | ||
75 | let mut num = 0.0; | ||
76 | let mut denom = 0.0; | ||
77 | for (x, y) in xy.clone() { | ||
78 | num += (x - mean_x) * (y - mean_y); | ||
79 | denom += (x - mean_x).powi(2); | ||
80 | } | ||
81 | num / denom | ||
82 | }; | ||
83 | |||
84 | let a = mean_y - b * mean_x; | ||
85 | |||
86 | self.plot = format!("y_pred = {:.3} + {:.3} * x\n\nx y y_pred\n", a, b); | ||
87 | |||
88 | let mut se = 0.0; | ||
89 | let mut max_error = 0.0f64; | ||
90 | for (x, y) in xy { | ||
91 | let y_pred = a + b * x; | ||
92 | se += (y - y_pred).powi(2); | ||
93 | max_error = max_error.max((y_pred - y).abs()); | ||
94 | |||
95 | format_to!(self.plot, "{:.3} {:.3} {:.3}\n", x, y, y_pred); | ||
96 | } | ||
97 | |||
98 | let rmse = (se / xs.len() as f64).sqrt(); | ||
99 | format_to!(self.plot, "\nrmse = {:.3} max error = {:.3}", rmse, max_error); | ||
100 | |||
101 | self.linear = rmse < 0.05 && max_error < 0.1 && a > -0.1; | ||
102 | |||
103 | fn normalize(xs: &mut Vec<f64>) { | ||
104 | let max = xs.iter().copied().max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap(); | ||
105 | xs.iter_mut().for_each(|it| *it /= max); | ||
106 | } | ||
107 | |||
108 | fn mean(xs: &[f64]) -> f64 { | ||
109 | xs.iter().copied().sum::<f64>() / (xs.len() as f64) | ||
110 | } | ||
111 | } | ||
112 | } | ||