aboutsummaryrefslogtreecommitdiff
path: root/crates/ra_hir/src/ty/utils.rs
diff options
context:
space:
mode:
Diffstat (limited to 'crates/ra_hir/src/ty/utils.rs')
-rw-r--r--crates/ra_hir/src/ty/utils.rs50
1 files changed, 50 insertions, 0 deletions
diff --git a/crates/ra_hir/src/ty/utils.rs b/crates/ra_hir/src/ty/utils.rs
new file mode 100644
index 000000000..345fa9430
--- /dev/null
+++ b/crates/ra_hir/src/ty/utils.rs
@@ -0,0 +1,50 @@
1use hir_def::{
2 db::DefDatabase,
3 resolver::{HasResolver, TypeNs},
4 type_ref::TypeRef,
5 TraitId,
6};
7use hir_expand::name;
8
9// FIXME: this is wrong, b/c it can't express `trait T: PartialEq<()>`.
10// We should return a `TraitREf` here.
11fn direct_super_traits(db: &impl DefDatabase, trait_: TraitId) -> Vec<TraitId> {
12 let resolver = trait_.resolver(db);
13 // returning the iterator directly doesn't easily work because of
14 // lifetime problems, but since there usually shouldn't be more than a
15 // few direct traits this should be fine (we could even use some kind of
16 // SmallVec if performance is a concern)
17 db.generic_params(trait_.into())
18 .where_predicates
19 .iter()
20 .filter_map(|pred| match &pred.type_ref {
21 TypeRef::Path(p) if p.as_ident() == Some(&name::SELF_TYPE) => pred.bound.as_path(),
22 _ => None,
23 })
24 .filter_map(|path| match resolver.resolve_path_in_type_ns_fully(db, path) {
25 Some(TypeNs::TraitId(t)) => Some(t),
26 _ => None,
27 })
28 .collect()
29}
30
31/// Returns an iterator over the whole super trait hierarchy (including the
32/// trait itself).
33pub(crate) fn all_super_traits(db: &impl DefDatabase, trait_: TraitId) -> Vec<TraitId> {
34 // we need to take care a bit here to avoid infinite loops in case of cycles
35 // (i.e. if we have `trait A: B; trait B: A;`)
36 let mut result = vec![trait_];
37 let mut i = 0;
38 while i < result.len() {
39 let t = result[i];
40 // yeah this is quadratic, but trait hierarchies should be flat
41 // enough that this doesn't matter
42 for tt in direct_super_traits(db, t) {
43 if !result.contains(&tt) {
44 result.push(tt);
45 }
46 }
47 i += 1;
48 }
49 result
50}