aboutsummaryrefslogtreecommitdiff
path: root/crates/ra_hir_ty/src/tests.rs
diff options
context:
space:
mode:
Diffstat (limited to 'crates/ra_hir_ty/src/tests.rs')
-rw-r--r--crates/ra_hir_ty/src/tests.rs49
1 files changed, 36 insertions, 13 deletions
diff --git a/crates/ra_hir_ty/src/tests.rs b/crates/ra_hir_ty/src/tests.rs
index c520bb375..f1b67555f 100644
--- a/crates/ra_hir_ty/src/tests.rs
+++ b/crates/ra_hir_ty/src/tests.rs
@@ -11,8 +11,8 @@ use std::fmt::Write;
11use std::sync::Arc; 11use std::sync::Arc;
12 12
13use hir_def::{ 13use hir_def::{
14 body::BodySourceMap, db::DefDatabase, nameres::CrateDefMap, AssocItemId, DefWithBodyId, 14 body::BodySourceMap, child_from_source::ChildFromSource, db::DefDatabase, nameres::CrateDefMap,
15 LocalModuleId, Lookup, ModuleDefId, 15 AssocItemId, DefWithBodyId, LocalModuleId, Lookup, ModuleDefId,
16}; 16};
17use hir_expand::InFile; 17use hir_expand::InFile;
18use insta::assert_snapshot; 18use insta::assert_snapshot;
@@ -31,18 +31,15 @@ use crate::{db::HirDatabase, display::HirDisplay, test_db::TestDB, InferenceResu
31fn type_at_pos(db: &TestDB, pos: FilePosition) -> String { 31fn type_at_pos(db: &TestDB, pos: FilePosition) -> String {
32 let file = db.parse(pos.file_id).ok().unwrap(); 32 let file = db.parse(pos.file_id).ok().unwrap();
33 let expr = algo::find_node_at_offset::<ast::Expr>(file.syntax(), pos.offset).unwrap(); 33 let expr = algo::find_node_at_offset::<ast::Expr>(file.syntax(), pos.offset).unwrap();
34 34 let fn_def = expr.syntax().ancestors().find_map(ast::FnDef::cast).unwrap();
35 let module = db.module_for_file(pos.file_id); 35 let module = db.module_for_file(pos.file_id);
36 let crate_def_map = db.crate_def_map(module.krate); 36 let func = module.child_from_source(db, InFile::new(pos.file_id.into(), fn_def)).unwrap();
37 for decl in crate_def_map[module.local_id].scope.declarations() { 37
38 if let ModuleDefId::FunctionId(func) = decl { 38 let (_body, source_map) = db.body_with_source_map(func.into());
39 let (_body, source_map) = db.body_with_source_map(func.into()); 39 if let Some(expr_id) = source_map.node_expr(InFile::new(pos.file_id.into(), &expr)) {
40 if let Some(expr_id) = source_map.node_expr(InFile::new(pos.file_id.into(), &expr)) { 40 let infer = db.infer(func.into());
41 let infer = db.infer(func.into()); 41 let ty = &infer[expr_id];
42 let ty = &infer[expr_id]; 42 return ty.display(db).to_string();
43 return ty.display(db).to_string();
44 }
45 }
46 } 43 }
47 panic!("Can't find expression") 44 panic!("Can't find expression")
48} 45}
@@ -53,6 +50,10 @@ fn type_at(content: &str) -> String {
53} 50}
54 51
55fn infer(content: &str) -> String { 52fn infer(content: &str) -> String {
53 infer_with_mismatches(content, false)
54}
55
56fn infer_with_mismatches(content: &str, include_mismatches: bool) -> String {
56 let (db, file_id) = TestDB::with_single_file(content); 57 let (db, file_id) = TestDB::with_single_file(content);
57 58
58 let mut acc = String::new(); 59 let mut acc = String::new();
@@ -60,6 +61,7 @@ fn infer(content: &str) -> String {
60 let mut infer_def = |inference_result: Arc<InferenceResult>, 61 let mut infer_def = |inference_result: Arc<InferenceResult>,
61 body_source_map: Arc<BodySourceMap>| { 62 body_source_map: Arc<BodySourceMap>| {
62 let mut types = Vec::new(); 63 let mut types = Vec::new();
64 let mut mismatches = Vec::new();
63 65
64 for (pat, ty) in inference_result.type_of_pat.iter() { 66 for (pat, ty) in inference_result.type_of_pat.iter() {
65 let syntax_ptr = match body_source_map.pat_syntax(pat) { 67 let syntax_ptr = match body_source_map.pat_syntax(pat) {
@@ -79,6 +81,9 @@ fn infer(content: &str) -> String {
79 None => continue, 81 None => continue,
80 }; 82 };
81 types.push((syntax_ptr, ty)); 83 types.push((syntax_ptr, ty));
84 if let Some(mismatch) = inference_result.type_mismatch_for_expr(expr) {
85 mismatches.push((syntax_ptr, mismatch));
86 }
82 } 87 }
83 88
84 // sort ranges for consistency 89 // sort ranges for consistency
@@ -104,6 +109,24 @@ fn infer(content: &str) -> String {
104 ) 109 )
105 .unwrap(); 110 .unwrap();
106 } 111 }
112 if include_mismatches {
113 mismatches.sort_by_key(|(src_ptr, _)| {
114 (src_ptr.value.range().start(), src_ptr.value.range().end())
115 });
116 for (src_ptr, mismatch) in &mismatches {
117 let range = src_ptr.value.range();
118 let macro_prefix = if src_ptr.file_id != file_id.into() { "!" } else { "" };
119 write!(
120 acc,
121 "{}{}: expected {}, got {}\n",
122 macro_prefix,
123 range,
124 mismatch.expected.display(&db),
125 mismatch.actual.display(&db),
126 )
127 .unwrap();
128 }
129 }
107 }; 130 };
108 131
109 let module = db.module_for_file(file_id); 132 let module = db.module_for_file(file_id);