aboutsummaryrefslogtreecommitdiff
path: root/crates/hir_ty/src/infer/unify.rs
diff options
context:
space:
mode:
Diffstat (limited to 'crates/hir_ty/src/infer/unify.rs')
-rw-r--r--crates/hir_ty/src/infer/unify.rs474
1 files changed, 474 insertions, 0 deletions
diff --git a/crates/hir_ty/src/infer/unify.rs b/crates/hir_ty/src/infer/unify.rs
new file mode 100644
index 000000000..2e895d911
--- /dev/null
+++ b/crates/hir_ty/src/infer/unify.rs
@@ -0,0 +1,474 @@
1//! Unification and canonicalization logic.
2
3use std::borrow::Cow;
4
5use ena::unify::{InPlaceUnificationTable, NoError, UnifyKey, UnifyValue};
6
7use test_utils::mark;
8
9use super::{InferenceContext, Obligation};
10use crate::{
11 BoundVar, Canonical, DebruijnIndex, GenericPredicate, InEnvironment, InferTy, Substs, Ty,
12 TyKind, TypeCtor, TypeWalk,
13};
14
15impl<'a> InferenceContext<'a> {
16 pub(super) fn canonicalizer<'b>(&'b mut self) -> Canonicalizer<'a, 'b>
17 where
18 'a: 'b,
19 {
20 Canonicalizer { ctx: self, free_vars: Vec::new(), var_stack: Vec::new() }
21 }
22}
23
24pub(super) struct Canonicalizer<'a, 'b>
25where
26 'a: 'b,
27{
28 ctx: &'b mut InferenceContext<'a>,
29 free_vars: Vec<InferTy>,
30 /// A stack of type variables that is used to detect recursive types (which
31 /// are an error, but we need to protect against them to avoid stack
32 /// overflows).
33 var_stack: Vec<TypeVarId>,
34}
35
36#[derive(Debug)]
37pub(super) struct Canonicalized<T> {
38 pub value: Canonical<T>,
39 free_vars: Vec<InferTy>,
40}
41
42impl<'a, 'b> Canonicalizer<'a, 'b>
43where
44 'a: 'b,
45{
46 fn add(&mut self, free_var: InferTy) -> usize {
47 self.free_vars.iter().position(|&v| v == free_var).unwrap_or_else(|| {
48 let next_index = self.free_vars.len();
49 self.free_vars.push(free_var);
50 next_index
51 })
52 }
53
54 fn do_canonicalize<T: TypeWalk>(&mut self, t: T, binders: DebruijnIndex) -> T {
55 t.fold_binders(
56 &mut |ty, binders| match ty {
57 Ty::Infer(tv) => {
58 let inner = tv.to_inner();
59 if self.var_stack.contains(&inner) {
60 // recursive type
61 return tv.fallback_value();
62 }
63 if let Some(known_ty) =
64 self.ctx.table.var_unification_table.inlined_probe_value(inner).known()
65 {
66 self.var_stack.push(inner);
67 let result = self.do_canonicalize(known_ty.clone(), binders);
68 self.var_stack.pop();
69 result
70 } else {
71 let root = self.ctx.table.var_unification_table.find(inner);
72 let free_var = match tv {
73 InferTy::TypeVar(_) => InferTy::TypeVar(root),
74 InferTy::IntVar(_) => InferTy::IntVar(root),
75 InferTy::FloatVar(_) => InferTy::FloatVar(root),
76 InferTy::MaybeNeverTypeVar(_) => InferTy::MaybeNeverTypeVar(root),
77 };
78 let position = self.add(free_var);
79 Ty::Bound(BoundVar::new(binders, position))
80 }
81 }
82 _ => ty,
83 },
84 binders,
85 )
86 }
87
88 fn into_canonicalized<T>(self, result: T) -> Canonicalized<T> {
89 let kinds = self
90 .free_vars
91 .iter()
92 .map(|v| match v {
93 // mapping MaybeNeverTypeVar to the same kind as general ones
94 // should be fine, because as opposed to int or float type vars,
95 // they don't restrict what kind of type can go into them, they
96 // just affect fallback.
97 InferTy::TypeVar(_) | InferTy::MaybeNeverTypeVar(_) => TyKind::General,
98 InferTy::IntVar(_) => TyKind::Integer,
99 InferTy::FloatVar(_) => TyKind::Float,
100 })
101 .collect();
102 Canonicalized { value: Canonical { value: result, kinds }, free_vars: self.free_vars }
103 }
104
105 pub(crate) fn canonicalize_ty(mut self, ty: Ty) -> Canonicalized<Ty> {
106 let result = self.do_canonicalize(ty, DebruijnIndex::INNERMOST);
107 self.into_canonicalized(result)
108 }
109
110 pub(crate) fn canonicalize_obligation(
111 mut self,
112 obligation: InEnvironment<Obligation>,
113 ) -> Canonicalized<InEnvironment<Obligation>> {
114 let result = match obligation.value {
115 Obligation::Trait(tr) => {
116 Obligation::Trait(self.do_canonicalize(tr, DebruijnIndex::INNERMOST))
117 }
118 Obligation::Projection(pr) => {
119 Obligation::Projection(self.do_canonicalize(pr, DebruijnIndex::INNERMOST))
120 }
121 };
122 self.into_canonicalized(InEnvironment {
123 value: result,
124 environment: obligation.environment,
125 })
126 }
127}
128
129impl<T> Canonicalized<T> {
130 pub fn decanonicalize_ty(&self, mut ty: Ty) -> Ty {
131 ty.walk_mut_binders(
132 &mut |ty, binders| {
133 if let &mut Ty::Bound(bound) = ty {
134 if bound.debruijn >= binders {
135 *ty = Ty::Infer(self.free_vars[bound.index]);
136 }
137 }
138 },
139 DebruijnIndex::INNERMOST,
140 );
141 ty
142 }
143
144 pub fn apply_solution(&self, ctx: &mut InferenceContext<'_>, solution: Canonical<Substs>) {
145 // the solution may contain new variables, which we need to convert to new inference vars
146 let new_vars = Substs(
147 solution
148 .kinds
149 .iter()
150 .map(|k| match k {
151 TyKind::General => ctx.table.new_type_var(),
152 TyKind::Integer => ctx.table.new_integer_var(),
153 TyKind::Float => ctx.table.new_float_var(),
154 })
155 .collect(),
156 );
157 for (i, ty) in solution.value.into_iter().enumerate() {
158 let var = self.free_vars[i];
159 // eagerly replace projections in the type; we may be getting types
160 // e.g. from where clauses where this hasn't happened yet
161 let ty = ctx.normalize_associated_types_in(ty.clone().subst_bound_vars(&new_vars));
162 ctx.table.unify(&Ty::Infer(var), &ty);
163 }
164 }
165}
166
167pub fn unify(tys: &Canonical<(Ty, Ty)>) -> Option<Substs> {
168 let mut table = InferenceTable::new();
169 let vars = Substs(
170 tys.kinds
171 .iter()
172 // we always use type vars here because we want everything to
173 // fallback to Unknown in the end (kind of hacky, as below)
174 .map(|_| table.new_type_var())
175 .collect(),
176 );
177 let ty1_with_vars = tys.value.0.clone().subst_bound_vars(&vars);
178 let ty2_with_vars = tys.value.1.clone().subst_bound_vars(&vars);
179 if !table.unify(&ty1_with_vars, &ty2_with_vars) {
180 return None;
181 }
182 // default any type vars that weren't unified back to their original bound vars
183 // (kind of hacky)
184 for (i, var) in vars.iter().enumerate() {
185 if &*table.resolve_ty_shallow(var) == var {
186 table.unify(var, &Ty::Bound(BoundVar::new(DebruijnIndex::INNERMOST, i)));
187 }
188 }
189 Some(
190 Substs::builder(tys.kinds.len())
191 .fill(vars.iter().map(|v| table.resolve_ty_completely(v.clone())))
192 .build(),
193 )
194}
195
196#[derive(Clone, Debug)]
197pub(crate) struct InferenceTable {
198 pub(super) var_unification_table: InPlaceUnificationTable<TypeVarId>,
199}
200
201impl InferenceTable {
202 pub fn new() -> Self {
203 InferenceTable { var_unification_table: InPlaceUnificationTable::new() }
204 }
205
206 pub fn new_type_var(&mut self) -> Ty {
207 Ty::Infer(InferTy::TypeVar(self.var_unification_table.new_key(TypeVarValue::Unknown)))
208 }
209
210 pub fn new_integer_var(&mut self) -> Ty {
211 Ty::Infer(InferTy::IntVar(self.var_unification_table.new_key(TypeVarValue::Unknown)))
212 }
213
214 pub fn new_float_var(&mut self) -> Ty {
215 Ty::Infer(InferTy::FloatVar(self.var_unification_table.new_key(TypeVarValue::Unknown)))
216 }
217
218 pub fn new_maybe_never_type_var(&mut self) -> Ty {
219 Ty::Infer(InferTy::MaybeNeverTypeVar(
220 self.var_unification_table.new_key(TypeVarValue::Unknown),
221 ))
222 }
223
224 pub fn resolve_ty_completely(&mut self, ty: Ty) -> Ty {
225 self.resolve_ty_completely_inner(&mut Vec::new(), ty)
226 }
227
228 pub fn resolve_ty_as_possible(&mut self, ty: Ty) -> Ty {
229 self.resolve_ty_as_possible_inner(&mut Vec::new(), ty)
230 }
231
232 pub fn unify(&mut self, ty1: &Ty, ty2: &Ty) -> bool {
233 self.unify_inner(ty1, ty2, 0)
234 }
235
236 pub fn unify_substs(&mut self, substs1: &Substs, substs2: &Substs, depth: usize) -> bool {
237 substs1.0.iter().zip(substs2.0.iter()).all(|(t1, t2)| self.unify_inner(t1, t2, depth))
238 }
239
240 fn unify_inner(&mut self, ty1: &Ty, ty2: &Ty, depth: usize) -> bool {
241 if depth > 1000 {
242 // prevent stackoverflows
243 panic!("infinite recursion in unification");
244 }
245 if ty1 == ty2 {
246 return true;
247 }
248 // try to resolve type vars first
249 let ty1 = self.resolve_ty_shallow(ty1);
250 let ty2 = self.resolve_ty_shallow(ty2);
251 match (&*ty1, &*ty2) {
252 (Ty::Apply(a_ty1), Ty::Apply(a_ty2)) if a_ty1.ctor == a_ty2.ctor => {
253 self.unify_substs(&a_ty1.parameters, &a_ty2.parameters, depth + 1)
254 }
255
256 _ => self.unify_inner_trivial(&ty1, &ty2, depth),
257 }
258 }
259
260 pub(super) fn unify_inner_trivial(&mut self, ty1: &Ty, ty2: &Ty, depth: usize) -> bool {
261 match (ty1, ty2) {
262 (Ty::Unknown, _) | (_, Ty::Unknown) => true,
263
264 (Ty::Placeholder(p1), Ty::Placeholder(p2)) if *p1 == *p2 => true,
265
266 (Ty::Dyn(dyn1), Ty::Dyn(dyn2)) if dyn1.len() == dyn2.len() => {
267 for (pred1, pred2) in dyn1.iter().zip(dyn2.iter()) {
268 if !self.unify_preds(pred1, pred2, depth + 1) {
269 return false;
270 }
271 }
272 true
273 }
274
275 (Ty::Infer(InferTy::TypeVar(tv1)), Ty::Infer(InferTy::TypeVar(tv2)))
276 | (Ty::Infer(InferTy::IntVar(tv1)), Ty::Infer(InferTy::IntVar(tv2)))
277 | (Ty::Infer(InferTy::FloatVar(tv1)), Ty::Infer(InferTy::FloatVar(tv2)))
278 | (
279 Ty::Infer(InferTy::MaybeNeverTypeVar(tv1)),
280 Ty::Infer(InferTy::MaybeNeverTypeVar(tv2)),
281 ) => {
282 // both type vars are unknown since we tried to resolve them
283 self.var_unification_table.union(*tv1, *tv2);
284 true
285 }
286
287 // The order of MaybeNeverTypeVar matters here.
288 // Unifying MaybeNeverTypeVar and TypeVar will let the latter become MaybeNeverTypeVar.
289 // Unifying MaybeNeverTypeVar and other concrete type will let the former become it.
290 (Ty::Infer(InferTy::TypeVar(tv)), other)
291 | (other, Ty::Infer(InferTy::TypeVar(tv)))
292 | (Ty::Infer(InferTy::MaybeNeverTypeVar(tv)), other)
293 | (other, Ty::Infer(InferTy::MaybeNeverTypeVar(tv)))
294 | (Ty::Infer(InferTy::IntVar(tv)), other @ ty_app!(TypeCtor::Int(_)))
295 | (other @ ty_app!(TypeCtor::Int(_)), Ty::Infer(InferTy::IntVar(tv)))
296 | (Ty::Infer(InferTy::FloatVar(tv)), other @ ty_app!(TypeCtor::Float(_)))
297 | (other @ ty_app!(TypeCtor::Float(_)), Ty::Infer(InferTy::FloatVar(tv))) => {
298 // the type var is unknown since we tried to resolve it
299 self.var_unification_table.union_value(*tv, TypeVarValue::Known(other.clone()));
300 true
301 }
302
303 _ => false,
304 }
305 }
306
307 fn unify_preds(
308 &mut self,
309 pred1: &GenericPredicate,
310 pred2: &GenericPredicate,
311 depth: usize,
312 ) -> bool {
313 match (pred1, pred2) {
314 (GenericPredicate::Implemented(tr1), GenericPredicate::Implemented(tr2))
315 if tr1.trait_ == tr2.trait_ =>
316 {
317 self.unify_substs(&tr1.substs, &tr2.substs, depth + 1)
318 }
319 (GenericPredicate::Projection(proj1), GenericPredicate::Projection(proj2))
320 if proj1.projection_ty.associated_ty == proj2.projection_ty.associated_ty =>
321 {
322 self.unify_substs(
323 &proj1.projection_ty.parameters,
324 &proj2.projection_ty.parameters,
325 depth + 1,
326 ) && self.unify_inner(&proj1.ty, &proj2.ty, depth + 1)
327 }
328 _ => false,
329 }
330 }
331
332 /// If `ty` is a type variable with known type, returns that type;
333 /// otherwise, return ty.
334 pub fn resolve_ty_shallow<'b>(&mut self, ty: &'b Ty) -> Cow<'b, Ty> {
335 let mut ty = Cow::Borrowed(ty);
336 // The type variable could resolve to a int/float variable. Hence try
337 // resolving up to three times; each type of variable shouldn't occur
338 // more than once
339 for i in 0..3 {
340 if i > 0 {
341 mark::hit!(type_var_resolves_to_int_var);
342 }
343 match &*ty {
344 Ty::Infer(tv) => {
345 let inner = tv.to_inner();
346 match self.var_unification_table.inlined_probe_value(inner).known() {
347 Some(known_ty) => {
348 // The known_ty can't be a type var itself
349 ty = Cow::Owned(known_ty.clone());
350 }
351 _ => return ty,
352 }
353 }
354 _ => return ty,
355 }
356 }
357 log::error!("Inference variable still not resolved: {:?}", ty);
358 ty
359 }
360
361 /// Resolves the type as far as currently possible, replacing type variables
362 /// by their known types. All types returned by the infer_* functions should
363 /// be resolved as far as possible, i.e. contain no type variables with
364 /// known type.
365 fn resolve_ty_as_possible_inner(&mut self, tv_stack: &mut Vec<TypeVarId>, ty: Ty) -> Ty {
366 ty.fold(&mut |ty| match ty {
367 Ty::Infer(tv) => {
368 let inner = tv.to_inner();
369 if tv_stack.contains(&inner) {
370 mark::hit!(type_var_cycles_resolve_as_possible);
371 // recursive type
372 return tv.fallback_value();
373 }
374 if let Some(known_ty) =
375 self.var_unification_table.inlined_probe_value(inner).known()
376 {
377 // known_ty may contain other variables that are known by now
378 tv_stack.push(inner);
379 let result = self.resolve_ty_as_possible_inner(tv_stack, known_ty.clone());
380 tv_stack.pop();
381 result
382 } else {
383 ty
384 }
385 }
386 _ => ty,
387 })
388 }
389
390 /// Resolves the type completely; type variables without known type are
391 /// replaced by Ty::Unknown.
392 fn resolve_ty_completely_inner(&mut self, tv_stack: &mut Vec<TypeVarId>, ty: Ty) -> Ty {
393 ty.fold(&mut |ty| match ty {
394 Ty::Infer(tv) => {
395 let inner = tv.to_inner();
396 if tv_stack.contains(&inner) {
397 mark::hit!(type_var_cycles_resolve_completely);
398 // recursive type
399 return tv.fallback_value();
400 }
401 if let Some(known_ty) =
402 self.var_unification_table.inlined_probe_value(inner).known()
403 {
404 // known_ty may contain other variables that are known by now
405 tv_stack.push(inner);
406 let result = self.resolve_ty_completely_inner(tv_stack, known_ty.clone());
407 tv_stack.pop();
408 result
409 } else {
410 tv.fallback_value()
411 }
412 }
413 _ => ty,
414 })
415 }
416}
417
418/// The ID of a type variable.
419#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
420pub struct TypeVarId(pub(super) u32);
421
422impl UnifyKey for TypeVarId {
423 type Value = TypeVarValue;
424
425 fn index(&self) -> u32 {
426 self.0
427 }
428
429 fn from_index(i: u32) -> Self {
430 TypeVarId(i)
431 }
432
433 fn tag() -> &'static str {
434 "TypeVarId"
435 }
436}
437
438/// The value of a type variable: either we already know the type, or we don't
439/// know it yet.
440#[derive(Clone, PartialEq, Eq, Debug)]
441pub enum TypeVarValue {
442 Known(Ty),
443 Unknown,
444}
445
446impl TypeVarValue {
447 fn known(&self) -> Option<&Ty> {
448 match self {
449 TypeVarValue::Known(ty) => Some(ty),
450 TypeVarValue::Unknown => None,
451 }
452 }
453}
454
455impl UnifyValue for TypeVarValue {
456 type Error = NoError;
457
458 fn unify_values(value1: &Self, value2: &Self) -> Result<Self, NoError> {
459 match (value1, value2) {
460 // We should never equate two type variables, both of which have
461 // known types. Instead, we recursively equate those types.
462 (TypeVarValue::Known(t1), TypeVarValue::Known(t2)) => panic!(
463 "equating two type variables, both of which have known types: {:?} and {:?}",
464 t1, t2
465 ),
466
467 // If one side is known, prefer that one.
468 (TypeVarValue::Known(..), TypeVarValue::Unknown) => Ok(value1.clone()),
469 (TypeVarValue::Unknown, TypeVarValue::Known(..)) => Ok(value2.clone()),
470
471 (TypeVarValue::Unknown, TypeVarValue::Unknown) => Ok(TypeVarValue::Unknown),
472 }
473 }
474}