diff options
Diffstat (limited to 'crates/hir_ty/src/infer/unify.rs')
-rw-r--r-- | crates/hir_ty/src/infer/unify.rs | 474 |
1 files changed, 474 insertions, 0 deletions
diff --git a/crates/hir_ty/src/infer/unify.rs b/crates/hir_ty/src/infer/unify.rs new file mode 100644 index 000000000..2e895d911 --- /dev/null +++ b/crates/hir_ty/src/infer/unify.rs | |||
@@ -0,0 +1,474 @@ | |||
1 | //! Unification and canonicalization logic. | ||
2 | |||
3 | use std::borrow::Cow; | ||
4 | |||
5 | use ena::unify::{InPlaceUnificationTable, NoError, UnifyKey, UnifyValue}; | ||
6 | |||
7 | use test_utils::mark; | ||
8 | |||
9 | use super::{InferenceContext, Obligation}; | ||
10 | use crate::{ | ||
11 | BoundVar, Canonical, DebruijnIndex, GenericPredicate, InEnvironment, InferTy, Substs, Ty, | ||
12 | TyKind, TypeCtor, TypeWalk, | ||
13 | }; | ||
14 | |||
15 | impl<'a> InferenceContext<'a> { | ||
16 | pub(super) fn canonicalizer<'b>(&'b mut self) -> Canonicalizer<'a, 'b> | ||
17 | where | ||
18 | 'a: 'b, | ||
19 | { | ||
20 | Canonicalizer { ctx: self, free_vars: Vec::new(), var_stack: Vec::new() } | ||
21 | } | ||
22 | } | ||
23 | |||
24 | pub(super) struct Canonicalizer<'a, 'b> | ||
25 | where | ||
26 | 'a: 'b, | ||
27 | { | ||
28 | ctx: &'b mut InferenceContext<'a>, | ||
29 | free_vars: Vec<InferTy>, | ||
30 | /// A stack of type variables that is used to detect recursive types (which | ||
31 | /// are an error, but we need to protect against them to avoid stack | ||
32 | /// overflows). | ||
33 | var_stack: Vec<TypeVarId>, | ||
34 | } | ||
35 | |||
36 | #[derive(Debug)] | ||
37 | pub(super) struct Canonicalized<T> { | ||
38 | pub value: Canonical<T>, | ||
39 | free_vars: Vec<InferTy>, | ||
40 | } | ||
41 | |||
42 | impl<'a, 'b> Canonicalizer<'a, 'b> | ||
43 | where | ||
44 | 'a: 'b, | ||
45 | { | ||
46 | fn add(&mut self, free_var: InferTy) -> usize { | ||
47 | self.free_vars.iter().position(|&v| v == free_var).unwrap_or_else(|| { | ||
48 | let next_index = self.free_vars.len(); | ||
49 | self.free_vars.push(free_var); | ||
50 | next_index | ||
51 | }) | ||
52 | } | ||
53 | |||
54 | fn do_canonicalize<T: TypeWalk>(&mut self, t: T, binders: DebruijnIndex) -> T { | ||
55 | t.fold_binders( | ||
56 | &mut |ty, binders| match ty { | ||
57 | Ty::Infer(tv) => { | ||
58 | let inner = tv.to_inner(); | ||
59 | if self.var_stack.contains(&inner) { | ||
60 | // recursive type | ||
61 | return tv.fallback_value(); | ||
62 | } | ||
63 | if let Some(known_ty) = | ||
64 | self.ctx.table.var_unification_table.inlined_probe_value(inner).known() | ||
65 | { | ||
66 | self.var_stack.push(inner); | ||
67 | let result = self.do_canonicalize(known_ty.clone(), binders); | ||
68 | self.var_stack.pop(); | ||
69 | result | ||
70 | } else { | ||
71 | let root = self.ctx.table.var_unification_table.find(inner); | ||
72 | let free_var = match tv { | ||
73 | InferTy::TypeVar(_) => InferTy::TypeVar(root), | ||
74 | InferTy::IntVar(_) => InferTy::IntVar(root), | ||
75 | InferTy::FloatVar(_) => InferTy::FloatVar(root), | ||
76 | InferTy::MaybeNeverTypeVar(_) => InferTy::MaybeNeverTypeVar(root), | ||
77 | }; | ||
78 | let position = self.add(free_var); | ||
79 | Ty::Bound(BoundVar::new(binders, position)) | ||
80 | } | ||
81 | } | ||
82 | _ => ty, | ||
83 | }, | ||
84 | binders, | ||
85 | ) | ||
86 | } | ||
87 | |||
88 | fn into_canonicalized<T>(self, result: T) -> Canonicalized<T> { | ||
89 | let kinds = self | ||
90 | .free_vars | ||
91 | .iter() | ||
92 | .map(|v| match v { | ||
93 | // mapping MaybeNeverTypeVar to the same kind as general ones | ||
94 | // should be fine, because as opposed to int or float type vars, | ||
95 | // they don't restrict what kind of type can go into them, they | ||
96 | // just affect fallback. | ||
97 | InferTy::TypeVar(_) | InferTy::MaybeNeverTypeVar(_) => TyKind::General, | ||
98 | InferTy::IntVar(_) => TyKind::Integer, | ||
99 | InferTy::FloatVar(_) => TyKind::Float, | ||
100 | }) | ||
101 | .collect(); | ||
102 | Canonicalized { value: Canonical { value: result, kinds }, free_vars: self.free_vars } | ||
103 | } | ||
104 | |||
105 | pub(crate) fn canonicalize_ty(mut self, ty: Ty) -> Canonicalized<Ty> { | ||
106 | let result = self.do_canonicalize(ty, DebruijnIndex::INNERMOST); | ||
107 | self.into_canonicalized(result) | ||
108 | } | ||
109 | |||
110 | pub(crate) fn canonicalize_obligation( | ||
111 | mut self, | ||
112 | obligation: InEnvironment<Obligation>, | ||
113 | ) -> Canonicalized<InEnvironment<Obligation>> { | ||
114 | let result = match obligation.value { | ||
115 | Obligation::Trait(tr) => { | ||
116 | Obligation::Trait(self.do_canonicalize(tr, DebruijnIndex::INNERMOST)) | ||
117 | } | ||
118 | Obligation::Projection(pr) => { | ||
119 | Obligation::Projection(self.do_canonicalize(pr, DebruijnIndex::INNERMOST)) | ||
120 | } | ||
121 | }; | ||
122 | self.into_canonicalized(InEnvironment { | ||
123 | value: result, | ||
124 | environment: obligation.environment, | ||
125 | }) | ||
126 | } | ||
127 | } | ||
128 | |||
129 | impl<T> Canonicalized<T> { | ||
130 | pub fn decanonicalize_ty(&self, mut ty: Ty) -> Ty { | ||
131 | ty.walk_mut_binders( | ||
132 | &mut |ty, binders| { | ||
133 | if let &mut Ty::Bound(bound) = ty { | ||
134 | if bound.debruijn >= binders { | ||
135 | *ty = Ty::Infer(self.free_vars[bound.index]); | ||
136 | } | ||
137 | } | ||
138 | }, | ||
139 | DebruijnIndex::INNERMOST, | ||
140 | ); | ||
141 | ty | ||
142 | } | ||
143 | |||
144 | pub fn apply_solution(&self, ctx: &mut InferenceContext<'_>, solution: Canonical<Substs>) { | ||
145 | // the solution may contain new variables, which we need to convert to new inference vars | ||
146 | let new_vars = Substs( | ||
147 | solution | ||
148 | .kinds | ||
149 | .iter() | ||
150 | .map(|k| match k { | ||
151 | TyKind::General => ctx.table.new_type_var(), | ||
152 | TyKind::Integer => ctx.table.new_integer_var(), | ||
153 | TyKind::Float => ctx.table.new_float_var(), | ||
154 | }) | ||
155 | .collect(), | ||
156 | ); | ||
157 | for (i, ty) in solution.value.into_iter().enumerate() { | ||
158 | let var = self.free_vars[i]; | ||
159 | // eagerly replace projections in the type; we may be getting types | ||
160 | // e.g. from where clauses where this hasn't happened yet | ||
161 | let ty = ctx.normalize_associated_types_in(ty.clone().subst_bound_vars(&new_vars)); | ||
162 | ctx.table.unify(&Ty::Infer(var), &ty); | ||
163 | } | ||
164 | } | ||
165 | } | ||
166 | |||
167 | pub fn unify(tys: &Canonical<(Ty, Ty)>) -> Option<Substs> { | ||
168 | let mut table = InferenceTable::new(); | ||
169 | let vars = Substs( | ||
170 | tys.kinds | ||
171 | .iter() | ||
172 | // we always use type vars here because we want everything to | ||
173 | // fallback to Unknown in the end (kind of hacky, as below) | ||
174 | .map(|_| table.new_type_var()) | ||
175 | .collect(), | ||
176 | ); | ||
177 | let ty1_with_vars = tys.value.0.clone().subst_bound_vars(&vars); | ||
178 | let ty2_with_vars = tys.value.1.clone().subst_bound_vars(&vars); | ||
179 | if !table.unify(&ty1_with_vars, &ty2_with_vars) { | ||
180 | return None; | ||
181 | } | ||
182 | // default any type vars that weren't unified back to their original bound vars | ||
183 | // (kind of hacky) | ||
184 | for (i, var) in vars.iter().enumerate() { | ||
185 | if &*table.resolve_ty_shallow(var) == var { | ||
186 | table.unify(var, &Ty::Bound(BoundVar::new(DebruijnIndex::INNERMOST, i))); | ||
187 | } | ||
188 | } | ||
189 | Some( | ||
190 | Substs::builder(tys.kinds.len()) | ||
191 | .fill(vars.iter().map(|v| table.resolve_ty_completely(v.clone()))) | ||
192 | .build(), | ||
193 | ) | ||
194 | } | ||
195 | |||
196 | #[derive(Clone, Debug)] | ||
197 | pub(crate) struct InferenceTable { | ||
198 | pub(super) var_unification_table: InPlaceUnificationTable<TypeVarId>, | ||
199 | } | ||
200 | |||
201 | impl InferenceTable { | ||
202 | pub fn new() -> Self { | ||
203 | InferenceTable { var_unification_table: InPlaceUnificationTable::new() } | ||
204 | } | ||
205 | |||
206 | pub fn new_type_var(&mut self) -> Ty { | ||
207 | Ty::Infer(InferTy::TypeVar(self.var_unification_table.new_key(TypeVarValue::Unknown))) | ||
208 | } | ||
209 | |||
210 | pub fn new_integer_var(&mut self) -> Ty { | ||
211 | Ty::Infer(InferTy::IntVar(self.var_unification_table.new_key(TypeVarValue::Unknown))) | ||
212 | } | ||
213 | |||
214 | pub fn new_float_var(&mut self) -> Ty { | ||
215 | Ty::Infer(InferTy::FloatVar(self.var_unification_table.new_key(TypeVarValue::Unknown))) | ||
216 | } | ||
217 | |||
218 | pub fn new_maybe_never_type_var(&mut self) -> Ty { | ||
219 | Ty::Infer(InferTy::MaybeNeverTypeVar( | ||
220 | self.var_unification_table.new_key(TypeVarValue::Unknown), | ||
221 | )) | ||
222 | } | ||
223 | |||
224 | pub fn resolve_ty_completely(&mut self, ty: Ty) -> Ty { | ||
225 | self.resolve_ty_completely_inner(&mut Vec::new(), ty) | ||
226 | } | ||
227 | |||
228 | pub fn resolve_ty_as_possible(&mut self, ty: Ty) -> Ty { | ||
229 | self.resolve_ty_as_possible_inner(&mut Vec::new(), ty) | ||
230 | } | ||
231 | |||
232 | pub fn unify(&mut self, ty1: &Ty, ty2: &Ty) -> bool { | ||
233 | self.unify_inner(ty1, ty2, 0) | ||
234 | } | ||
235 | |||
236 | pub fn unify_substs(&mut self, substs1: &Substs, substs2: &Substs, depth: usize) -> bool { | ||
237 | substs1.0.iter().zip(substs2.0.iter()).all(|(t1, t2)| self.unify_inner(t1, t2, depth)) | ||
238 | } | ||
239 | |||
240 | fn unify_inner(&mut self, ty1: &Ty, ty2: &Ty, depth: usize) -> bool { | ||
241 | if depth > 1000 { | ||
242 | // prevent stackoverflows | ||
243 | panic!("infinite recursion in unification"); | ||
244 | } | ||
245 | if ty1 == ty2 { | ||
246 | return true; | ||
247 | } | ||
248 | // try to resolve type vars first | ||
249 | let ty1 = self.resolve_ty_shallow(ty1); | ||
250 | let ty2 = self.resolve_ty_shallow(ty2); | ||
251 | match (&*ty1, &*ty2) { | ||
252 | (Ty::Apply(a_ty1), Ty::Apply(a_ty2)) if a_ty1.ctor == a_ty2.ctor => { | ||
253 | self.unify_substs(&a_ty1.parameters, &a_ty2.parameters, depth + 1) | ||
254 | } | ||
255 | |||
256 | _ => self.unify_inner_trivial(&ty1, &ty2, depth), | ||
257 | } | ||
258 | } | ||
259 | |||
260 | pub(super) fn unify_inner_trivial(&mut self, ty1: &Ty, ty2: &Ty, depth: usize) -> bool { | ||
261 | match (ty1, ty2) { | ||
262 | (Ty::Unknown, _) | (_, Ty::Unknown) => true, | ||
263 | |||
264 | (Ty::Placeholder(p1), Ty::Placeholder(p2)) if *p1 == *p2 => true, | ||
265 | |||
266 | (Ty::Dyn(dyn1), Ty::Dyn(dyn2)) if dyn1.len() == dyn2.len() => { | ||
267 | for (pred1, pred2) in dyn1.iter().zip(dyn2.iter()) { | ||
268 | if !self.unify_preds(pred1, pred2, depth + 1) { | ||
269 | return false; | ||
270 | } | ||
271 | } | ||
272 | true | ||
273 | } | ||
274 | |||
275 | (Ty::Infer(InferTy::TypeVar(tv1)), Ty::Infer(InferTy::TypeVar(tv2))) | ||
276 | | (Ty::Infer(InferTy::IntVar(tv1)), Ty::Infer(InferTy::IntVar(tv2))) | ||
277 | | (Ty::Infer(InferTy::FloatVar(tv1)), Ty::Infer(InferTy::FloatVar(tv2))) | ||
278 | | ( | ||
279 | Ty::Infer(InferTy::MaybeNeverTypeVar(tv1)), | ||
280 | Ty::Infer(InferTy::MaybeNeverTypeVar(tv2)), | ||
281 | ) => { | ||
282 | // both type vars are unknown since we tried to resolve them | ||
283 | self.var_unification_table.union(*tv1, *tv2); | ||
284 | true | ||
285 | } | ||
286 | |||
287 | // The order of MaybeNeverTypeVar matters here. | ||
288 | // Unifying MaybeNeverTypeVar and TypeVar will let the latter become MaybeNeverTypeVar. | ||
289 | // Unifying MaybeNeverTypeVar and other concrete type will let the former become it. | ||
290 | (Ty::Infer(InferTy::TypeVar(tv)), other) | ||
291 | | (other, Ty::Infer(InferTy::TypeVar(tv))) | ||
292 | | (Ty::Infer(InferTy::MaybeNeverTypeVar(tv)), other) | ||
293 | | (other, Ty::Infer(InferTy::MaybeNeverTypeVar(tv))) | ||
294 | | (Ty::Infer(InferTy::IntVar(tv)), other @ ty_app!(TypeCtor::Int(_))) | ||
295 | | (other @ ty_app!(TypeCtor::Int(_)), Ty::Infer(InferTy::IntVar(tv))) | ||
296 | | (Ty::Infer(InferTy::FloatVar(tv)), other @ ty_app!(TypeCtor::Float(_))) | ||
297 | | (other @ ty_app!(TypeCtor::Float(_)), Ty::Infer(InferTy::FloatVar(tv))) => { | ||
298 | // the type var is unknown since we tried to resolve it | ||
299 | self.var_unification_table.union_value(*tv, TypeVarValue::Known(other.clone())); | ||
300 | true | ||
301 | } | ||
302 | |||
303 | _ => false, | ||
304 | } | ||
305 | } | ||
306 | |||
307 | fn unify_preds( | ||
308 | &mut self, | ||
309 | pred1: &GenericPredicate, | ||
310 | pred2: &GenericPredicate, | ||
311 | depth: usize, | ||
312 | ) -> bool { | ||
313 | match (pred1, pred2) { | ||
314 | (GenericPredicate::Implemented(tr1), GenericPredicate::Implemented(tr2)) | ||
315 | if tr1.trait_ == tr2.trait_ => | ||
316 | { | ||
317 | self.unify_substs(&tr1.substs, &tr2.substs, depth + 1) | ||
318 | } | ||
319 | (GenericPredicate::Projection(proj1), GenericPredicate::Projection(proj2)) | ||
320 | if proj1.projection_ty.associated_ty == proj2.projection_ty.associated_ty => | ||
321 | { | ||
322 | self.unify_substs( | ||
323 | &proj1.projection_ty.parameters, | ||
324 | &proj2.projection_ty.parameters, | ||
325 | depth + 1, | ||
326 | ) && self.unify_inner(&proj1.ty, &proj2.ty, depth + 1) | ||
327 | } | ||
328 | _ => false, | ||
329 | } | ||
330 | } | ||
331 | |||
332 | /// If `ty` is a type variable with known type, returns that type; | ||
333 | /// otherwise, return ty. | ||
334 | pub fn resolve_ty_shallow<'b>(&mut self, ty: &'b Ty) -> Cow<'b, Ty> { | ||
335 | let mut ty = Cow::Borrowed(ty); | ||
336 | // The type variable could resolve to a int/float variable. Hence try | ||
337 | // resolving up to three times; each type of variable shouldn't occur | ||
338 | // more than once | ||
339 | for i in 0..3 { | ||
340 | if i > 0 { | ||
341 | mark::hit!(type_var_resolves_to_int_var); | ||
342 | } | ||
343 | match &*ty { | ||
344 | Ty::Infer(tv) => { | ||
345 | let inner = tv.to_inner(); | ||
346 | match self.var_unification_table.inlined_probe_value(inner).known() { | ||
347 | Some(known_ty) => { | ||
348 | // The known_ty can't be a type var itself | ||
349 | ty = Cow::Owned(known_ty.clone()); | ||
350 | } | ||
351 | _ => return ty, | ||
352 | } | ||
353 | } | ||
354 | _ => return ty, | ||
355 | } | ||
356 | } | ||
357 | log::error!("Inference variable still not resolved: {:?}", ty); | ||
358 | ty | ||
359 | } | ||
360 | |||
361 | /// Resolves the type as far as currently possible, replacing type variables | ||
362 | /// by their known types. All types returned by the infer_* functions should | ||
363 | /// be resolved as far as possible, i.e. contain no type variables with | ||
364 | /// known type. | ||
365 | fn resolve_ty_as_possible_inner(&mut self, tv_stack: &mut Vec<TypeVarId>, ty: Ty) -> Ty { | ||
366 | ty.fold(&mut |ty| match ty { | ||
367 | Ty::Infer(tv) => { | ||
368 | let inner = tv.to_inner(); | ||
369 | if tv_stack.contains(&inner) { | ||
370 | mark::hit!(type_var_cycles_resolve_as_possible); | ||
371 | // recursive type | ||
372 | return tv.fallback_value(); | ||
373 | } | ||
374 | if let Some(known_ty) = | ||
375 | self.var_unification_table.inlined_probe_value(inner).known() | ||
376 | { | ||
377 | // known_ty may contain other variables that are known by now | ||
378 | tv_stack.push(inner); | ||
379 | let result = self.resolve_ty_as_possible_inner(tv_stack, known_ty.clone()); | ||
380 | tv_stack.pop(); | ||
381 | result | ||
382 | } else { | ||
383 | ty | ||
384 | } | ||
385 | } | ||
386 | _ => ty, | ||
387 | }) | ||
388 | } | ||
389 | |||
390 | /// Resolves the type completely; type variables without known type are | ||
391 | /// replaced by Ty::Unknown. | ||
392 | fn resolve_ty_completely_inner(&mut self, tv_stack: &mut Vec<TypeVarId>, ty: Ty) -> Ty { | ||
393 | ty.fold(&mut |ty| match ty { | ||
394 | Ty::Infer(tv) => { | ||
395 | let inner = tv.to_inner(); | ||
396 | if tv_stack.contains(&inner) { | ||
397 | mark::hit!(type_var_cycles_resolve_completely); | ||
398 | // recursive type | ||
399 | return tv.fallback_value(); | ||
400 | } | ||
401 | if let Some(known_ty) = | ||
402 | self.var_unification_table.inlined_probe_value(inner).known() | ||
403 | { | ||
404 | // known_ty may contain other variables that are known by now | ||
405 | tv_stack.push(inner); | ||
406 | let result = self.resolve_ty_completely_inner(tv_stack, known_ty.clone()); | ||
407 | tv_stack.pop(); | ||
408 | result | ||
409 | } else { | ||
410 | tv.fallback_value() | ||
411 | } | ||
412 | } | ||
413 | _ => ty, | ||
414 | }) | ||
415 | } | ||
416 | } | ||
417 | |||
418 | /// The ID of a type variable. | ||
419 | #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)] | ||
420 | pub struct TypeVarId(pub(super) u32); | ||
421 | |||
422 | impl UnifyKey for TypeVarId { | ||
423 | type Value = TypeVarValue; | ||
424 | |||
425 | fn index(&self) -> u32 { | ||
426 | self.0 | ||
427 | } | ||
428 | |||
429 | fn from_index(i: u32) -> Self { | ||
430 | TypeVarId(i) | ||
431 | } | ||
432 | |||
433 | fn tag() -> &'static str { | ||
434 | "TypeVarId" | ||
435 | } | ||
436 | } | ||
437 | |||
438 | /// The value of a type variable: either we already know the type, or we don't | ||
439 | /// know it yet. | ||
440 | #[derive(Clone, PartialEq, Eq, Debug)] | ||
441 | pub enum TypeVarValue { | ||
442 | Known(Ty), | ||
443 | Unknown, | ||
444 | } | ||
445 | |||
446 | impl TypeVarValue { | ||
447 | fn known(&self) -> Option<&Ty> { | ||
448 | match self { | ||
449 | TypeVarValue::Known(ty) => Some(ty), | ||
450 | TypeVarValue::Unknown => None, | ||
451 | } | ||
452 | } | ||
453 | } | ||
454 | |||
455 | impl UnifyValue for TypeVarValue { | ||
456 | type Error = NoError; | ||
457 | |||
458 | fn unify_values(value1: &Self, value2: &Self) -> Result<Self, NoError> { | ||
459 | match (value1, value2) { | ||
460 | // We should never equate two type variables, both of which have | ||
461 | // known types. Instead, we recursively equate those types. | ||
462 | (TypeVarValue::Known(t1), TypeVarValue::Known(t2)) => panic!( | ||
463 | "equating two type variables, both of which have known types: {:?} and {:?}", | ||
464 | t1, t2 | ||
465 | ), | ||
466 | |||
467 | // If one side is known, prefer that one. | ||
468 | (TypeVarValue::Known(..), TypeVarValue::Unknown) => Ok(value1.clone()), | ||
469 | (TypeVarValue::Unknown, TypeVarValue::Known(..)) => Ok(value2.clone()), | ||
470 | |||
471 | (TypeVarValue::Unknown, TypeVarValue::Unknown) => Ok(TypeVarValue::Unknown), | ||
472 | } | ||
473 | } | ||
474 | } | ||