aboutsummaryrefslogtreecommitdiff
path: root/crates/hir_ty/src/infer/unify.rs
diff options
context:
space:
mode:
Diffstat (limited to 'crates/hir_ty/src/infer/unify.rs')
-rw-r--r--crates/hir_ty/src/infer/unify.rs184
1 files changed, 110 insertions, 74 deletions
diff --git a/crates/hir_ty/src/infer/unify.rs b/crates/hir_ty/src/infer/unify.rs
index 2852ad5bf..b481aa1b3 100644
--- a/crates/hir_ty/src/infer/unify.rs
+++ b/crates/hir_ty/src/infer/unify.rs
@@ -2,14 +2,15 @@
2 2
3use std::borrow::Cow; 3use std::borrow::Cow;
4 4
5use chalk_ir::{FloatTy, IntTy, TyVariableKind};
5use ena::unify::{InPlaceUnificationTable, NoError, UnifyKey, UnifyValue}; 6use ena::unify::{InPlaceUnificationTable, NoError, UnifyKey, UnifyValue};
6 7
7use test_utils::mark; 8use test_utils::mark;
8 9
9use super::{InferenceContext, Obligation}; 10use super::{InferenceContext, Obligation};
10use crate::{ 11use crate::{
11 BoundVar, Canonical, DebruijnIndex, GenericPredicate, InEnvironment, InferTy, Scalar, Substs, 12 BoundVar, Canonical, DebruijnIndex, GenericPredicate, InEnvironment, InferenceVar, Scalar,
12 Ty, TyKind, TypeWalk, 13 Substs, Ty, TypeWalk,
13}; 14};
14 15
15impl<'a> InferenceContext<'a> { 16impl<'a> InferenceContext<'a> {
@@ -26,7 +27,7 @@ where
26 'a: 'b, 27 'a: 'b,
27{ 28{
28 ctx: &'b mut InferenceContext<'a>, 29 ctx: &'b mut InferenceContext<'a>,
29 free_vars: Vec<InferTy>, 30 free_vars: Vec<(InferenceVar, TyVariableKind)>,
30 /// A stack of type variables that is used to detect recursive types (which 31 /// A stack of type variables that is used to detect recursive types (which
31 /// are an error, but we need to protect against them to avoid stack 32 /// are an error, but we need to protect against them to avoid stack
32 /// overflows). 33 /// overflows).
@@ -36,17 +37,14 @@ where
36#[derive(Debug)] 37#[derive(Debug)]
37pub(super) struct Canonicalized<T> { 38pub(super) struct Canonicalized<T> {
38 pub(super) value: Canonical<T>, 39 pub(super) value: Canonical<T>,
39 free_vars: Vec<InferTy>, 40 free_vars: Vec<(InferenceVar, TyVariableKind)>,
40} 41}
41 42
42impl<'a, 'b> Canonicalizer<'a, 'b> 43impl<'a, 'b> Canonicalizer<'a, 'b> {
43where 44 fn add(&mut self, free_var: InferenceVar, kind: TyVariableKind) -> usize {
44 'a: 'b, 45 self.free_vars.iter().position(|&(v, _)| v == free_var).unwrap_or_else(|| {
45{
46 fn add(&mut self, free_var: InferTy) -> usize {
47 self.free_vars.iter().position(|&v| v == free_var).unwrap_or_else(|| {
48 let next_index = self.free_vars.len(); 46 let next_index = self.free_vars.len();
49 self.free_vars.push(free_var); 47 self.free_vars.push((free_var, kind));
50 next_index 48 next_index
51 }) 49 })
52 } 50 }
@@ -54,11 +52,11 @@ where
54 fn do_canonicalize<T: TypeWalk>(&mut self, t: T, binders: DebruijnIndex) -> T { 52 fn do_canonicalize<T: TypeWalk>(&mut self, t: T, binders: DebruijnIndex) -> T {
55 t.fold_binders( 53 t.fold_binders(
56 &mut |ty, binders| match ty { 54 &mut |ty, binders| match ty {
57 Ty::Infer(tv) => { 55 Ty::InferenceVar(var, kind) => {
58 let inner = tv.to_inner(); 56 let inner = var.to_inner();
59 if self.var_stack.contains(&inner) { 57 if self.var_stack.contains(&inner) {
60 // recursive type 58 // recursive type
61 return tv.fallback_value(); 59 return self.ctx.table.type_variable_table.fallback_value(var, kind);
62 } 60 }
63 if let Some(known_ty) = 61 if let Some(known_ty) =
64 self.ctx.table.var_unification_table.inlined_probe_value(inner).known() 62 self.ctx.table.var_unification_table.inlined_probe_value(inner).known()
@@ -69,13 +67,7 @@ where
69 result 67 result
70 } else { 68 } else {
71 let root = self.ctx.table.var_unification_table.find(inner); 69 let root = self.ctx.table.var_unification_table.find(inner);
72 let free_var = match tv { 70 let position = self.add(InferenceVar::from_inner(root), kind);
73 InferTy::TypeVar(_) => InferTy::TypeVar(root),
74 InferTy::IntVar(_) => InferTy::IntVar(root),
75 InferTy::FloatVar(_) => InferTy::FloatVar(root),
76 InferTy::MaybeNeverTypeVar(_) => InferTy::MaybeNeverTypeVar(root),
77 };
78 let position = self.add(free_var);
79 Ty::Bound(BoundVar::new(binders, position)) 71 Ty::Bound(BoundVar::new(binders, position))
80 } 72 }
81 } 73 }
@@ -86,19 +78,7 @@ where
86 } 78 }
87 79
88 fn into_canonicalized<T>(self, result: T) -> Canonicalized<T> { 80 fn into_canonicalized<T>(self, result: T) -> Canonicalized<T> {
89 let kinds = self 81 let kinds = self.free_vars.iter().map(|&(_, k)| k).collect();
90 .free_vars
91 .iter()
92 .map(|v| match v {
93 // mapping MaybeNeverTypeVar to the same kind as general ones
94 // should be fine, because as opposed to int or float type vars,
95 // they don't restrict what kind of type can go into them, they
96 // just affect fallback.
97 InferTy::TypeVar(_) | InferTy::MaybeNeverTypeVar(_) => TyKind::General,
98 InferTy::IntVar(_) => TyKind::Integer,
99 InferTy::FloatVar(_) => TyKind::Float,
100 })
101 .collect();
102 Canonicalized { value: Canonical { value: result, kinds }, free_vars: self.free_vars } 82 Canonicalized { value: Canonical { value: result, kinds }, free_vars: self.free_vars }
103 } 83 }
104 84
@@ -132,7 +112,8 @@ impl<T> Canonicalized<T> {
132 &mut |ty, binders| { 112 &mut |ty, binders| {
133 if let &mut Ty::Bound(bound) = ty { 113 if let &mut Ty::Bound(bound) = ty {
134 if bound.debruijn >= binders { 114 if bound.debruijn >= binders {
135 *ty = Ty::Infer(self.free_vars[bound.index]); 115 let (v, k) = self.free_vars[bound.index];
116 *ty = Ty::InferenceVar(v, k);
136 } 117 }
137 } 118 }
138 }, 119 },
@@ -152,18 +133,18 @@ impl<T> Canonicalized<T> {
152 .kinds 133 .kinds
153 .iter() 134 .iter()
154 .map(|k| match k { 135 .map(|k| match k {
155 TyKind::General => ctx.table.new_type_var(), 136 TyVariableKind::General => ctx.table.new_type_var(),
156 TyKind::Integer => ctx.table.new_integer_var(), 137 TyVariableKind::Integer => ctx.table.new_integer_var(),
157 TyKind::Float => ctx.table.new_float_var(), 138 TyVariableKind::Float => ctx.table.new_float_var(),
158 }) 139 })
159 .collect(), 140 .collect(),
160 ); 141 );
161 for (i, ty) in solution.value.into_iter().enumerate() { 142 for (i, ty) in solution.value.into_iter().enumerate() {
162 let var = self.free_vars[i]; 143 let (v, k) = self.free_vars[i];
163 // eagerly replace projections in the type; we may be getting types 144 // eagerly replace projections in the type; we may be getting types
164 // e.g. from where clauses where this hasn't happened yet 145 // e.g. from where clauses where this hasn't happened yet
165 let ty = ctx.normalize_associated_types_in(ty.clone().subst_bound_vars(&new_vars)); 146 let ty = ctx.normalize_associated_types_in(ty.clone().subst_bound_vars(&new_vars));
166 ctx.table.unify(&Ty::Infer(var), &ty); 147 ctx.table.unify(&Ty::InferenceVar(v, k), &ty);
167 } 148 }
168 } 149 }
169} 150}
@@ -198,31 +179,73 @@ pub(crate) fn unify(tys: &Canonical<(Ty, Ty)>) -> Option<Substs> {
198} 179}
199 180
200#[derive(Clone, Debug)] 181#[derive(Clone, Debug)]
182pub(super) struct TypeVariableTable {
183 inner: Vec<TypeVariableData>,
184}
185
186impl TypeVariableTable {
187 fn push(&mut self, data: TypeVariableData) {
188 self.inner.push(data);
189 }
190
191 pub(super) fn set_diverging(&mut self, iv: InferenceVar, diverging: bool) {
192 self.inner[iv.to_inner().0 as usize].diverging = diverging;
193 }
194
195 fn is_diverging(&mut self, iv: InferenceVar) -> bool {
196 self.inner[iv.to_inner().0 as usize].diverging
197 }
198
199 fn fallback_value(&self, iv: InferenceVar, kind: TyVariableKind) -> Ty {
200 match kind {
201 _ if self.inner[iv.to_inner().0 as usize].diverging => Ty::Never,
202 TyVariableKind::General => Ty::Unknown,
203 TyVariableKind::Integer => Ty::Scalar(Scalar::Int(IntTy::I32)),
204 TyVariableKind::Float => Ty::Scalar(Scalar::Float(FloatTy::F64)),
205 }
206 }
207}
208
209#[derive(Copy, Clone, Debug)]
210pub(crate) struct TypeVariableData {
211 diverging: bool,
212}
213
214#[derive(Clone, Debug)]
201pub(crate) struct InferenceTable { 215pub(crate) struct InferenceTable {
202 pub(super) var_unification_table: InPlaceUnificationTable<TypeVarId>, 216 pub(super) var_unification_table: InPlaceUnificationTable<TypeVarId>,
217 pub(super) type_variable_table: TypeVariableTable,
203} 218}
204 219
205impl InferenceTable { 220impl InferenceTable {
206 pub(crate) fn new() -> Self { 221 pub(crate) fn new() -> Self {
207 InferenceTable { var_unification_table: InPlaceUnificationTable::new() } 222 InferenceTable {
223 var_unification_table: InPlaceUnificationTable::new(),
224 type_variable_table: TypeVariableTable { inner: Vec::new() },
225 }
226 }
227
228 fn new_var(&mut self, kind: TyVariableKind, diverging: bool) -> Ty {
229 self.type_variable_table.push(TypeVariableData { diverging });
230 let key = self.var_unification_table.new_key(TypeVarValue::Unknown);
231 assert_eq!(key.0 as usize, self.type_variable_table.inner.len() - 1);
232 Ty::InferenceVar(InferenceVar::from_inner(key), kind)
208 } 233 }
209 234
210 pub(crate) fn new_type_var(&mut self) -> Ty { 235 pub(crate) fn new_type_var(&mut self) -> Ty {
211 Ty::Infer(InferTy::TypeVar(self.var_unification_table.new_key(TypeVarValue::Unknown))) 236 self.new_var(TyVariableKind::General, false)
212 } 237 }
213 238
214 pub(crate) fn new_integer_var(&mut self) -> Ty { 239 pub(crate) fn new_integer_var(&mut self) -> Ty {
215 Ty::Infer(InferTy::IntVar(self.var_unification_table.new_key(TypeVarValue::Unknown))) 240 self.new_var(TyVariableKind::Integer, false)
216 } 241 }
217 242
218 pub(crate) fn new_float_var(&mut self) -> Ty { 243 pub(crate) fn new_float_var(&mut self) -> Ty {
219 Ty::Infer(InferTy::FloatVar(self.var_unification_table.new_key(TypeVarValue::Unknown))) 244 self.new_var(TyVariableKind::Float, false)
220 } 245 }
221 246
222 pub(crate) fn new_maybe_never_type_var(&mut self) -> Ty { 247 pub(crate) fn new_maybe_never_var(&mut self) -> Ty {
223 Ty::Infer(InferTy::MaybeNeverTypeVar( 248 self.new_var(TyVariableKind::General, true)
224 self.var_unification_table.new_key(TypeVarValue::Unknown),
225 ))
226 } 249 }
227 250
228 pub(crate) fn resolve_ty_completely(&mut self, ty: Ty) -> Ty { 251 pub(crate) fn resolve_ty_completely(&mut self, ty: Ty) -> Ty {
@@ -283,33 +306,46 @@ impl InferenceTable {
283 true 306 true
284 } 307 }
285 308
286 (Ty::Infer(InferTy::TypeVar(tv1)), Ty::Infer(InferTy::TypeVar(tv2))) 309 (
287 | (Ty::Infer(InferTy::IntVar(tv1)), Ty::Infer(InferTy::IntVar(tv2))) 310 Ty::InferenceVar(tv1, TyVariableKind::General),
288 | (Ty::Infer(InferTy::FloatVar(tv1)), Ty::Infer(InferTy::FloatVar(tv2))) 311 Ty::InferenceVar(tv2, TyVariableKind::General),
312 )
289 | ( 313 | (
290 Ty::Infer(InferTy::MaybeNeverTypeVar(tv1)), 314 Ty::InferenceVar(tv1, TyVariableKind::Integer),
291 Ty::Infer(InferTy::MaybeNeverTypeVar(tv2)), 315 Ty::InferenceVar(tv2, TyVariableKind::Integer),
292 ) => { 316 )
317 | (
318 Ty::InferenceVar(tv1, TyVariableKind::Float),
319 Ty::InferenceVar(tv2, TyVariableKind::Float),
320 ) if self.type_variable_table.is_diverging(*tv1)
321 == self.type_variable_table.is_diverging(*tv2) =>
322 {
293 // both type vars are unknown since we tried to resolve them 323 // both type vars are unknown since we tried to resolve them
294 self.var_unification_table.union(*tv1, *tv2); 324 self.var_unification_table.union(tv1.to_inner(), tv2.to_inner());
295 true 325 true
296 } 326 }
297 327
298 // The order of MaybeNeverTypeVar matters here. 328 // The order of MaybeNeverTypeVar matters here.
299 // Unifying MaybeNeverTypeVar and TypeVar will let the latter become MaybeNeverTypeVar. 329 // Unifying MaybeNeverTypeVar and TypeVar will let the latter become MaybeNeverTypeVar.
300 // Unifying MaybeNeverTypeVar and other concrete type will let the former become it. 330 // Unifying MaybeNeverTypeVar and other concrete type will let the former become it.
301 (Ty::Infer(InferTy::TypeVar(tv)), other) 331 (Ty::InferenceVar(tv, TyVariableKind::General), other)
302 | (other, Ty::Infer(InferTy::TypeVar(tv))) 332 | (other, Ty::InferenceVar(tv, TyVariableKind::General))
303 | (Ty::Infer(InferTy::MaybeNeverTypeVar(tv)), other) 333 | (Ty::InferenceVar(tv, TyVariableKind::Integer), other @ Ty::Scalar(Scalar::Int(_)))
304 | (other, Ty::Infer(InferTy::MaybeNeverTypeVar(tv))) 334 | (other @ Ty::Scalar(Scalar::Int(_)), Ty::InferenceVar(tv, TyVariableKind::Integer))
305 | (Ty::Infer(InferTy::IntVar(tv)), other @ Ty::Scalar(Scalar::Int(_))) 335 | (
306 | (other @ Ty::Scalar(Scalar::Int(_)), Ty::Infer(InferTy::IntVar(tv))) 336 Ty::InferenceVar(tv, TyVariableKind::Integer),
307 | (Ty::Infer(InferTy::IntVar(tv)), other @ Ty::Scalar(Scalar::Uint(_))) 337 other @ Ty::Scalar(Scalar::Uint(_)),
308 | (other @ Ty::Scalar(Scalar::Uint(_)), Ty::Infer(InferTy::IntVar(tv))) 338 )
309 | (Ty::Infer(InferTy::FloatVar(tv)), other @ Ty::Scalar(Scalar::Float(_))) 339 | (
310 | (other @ Ty::Scalar(Scalar::Float(_)), Ty::Infer(InferTy::FloatVar(tv))) => { 340 other @ Ty::Scalar(Scalar::Uint(_)),
341 Ty::InferenceVar(tv, TyVariableKind::Integer),
342 )
343 | (Ty::InferenceVar(tv, TyVariableKind::Float), other @ Ty::Scalar(Scalar::Float(_)))
344 | (other @ Ty::Scalar(Scalar::Float(_)), Ty::InferenceVar(tv, TyVariableKind::Float)) =>
345 {
311 // the type var is unknown since we tried to resolve it 346 // the type var is unknown since we tried to resolve it
312 self.var_unification_table.union_value(*tv, TypeVarValue::Known(other.clone())); 347 self.var_unification_table
348 .union_value(tv.to_inner(), TypeVarValue::Known(other.clone()));
313 true 349 true
314 } 350 }
315 351
@@ -354,7 +390,7 @@ impl InferenceTable {
354 mark::hit!(type_var_resolves_to_int_var); 390 mark::hit!(type_var_resolves_to_int_var);
355 } 391 }
356 match &*ty { 392 match &*ty {
357 Ty::Infer(tv) => { 393 Ty::InferenceVar(tv, _) => {
358 let inner = tv.to_inner(); 394 let inner = tv.to_inner();
359 match self.var_unification_table.inlined_probe_value(inner).known() { 395 match self.var_unification_table.inlined_probe_value(inner).known() {
360 Some(known_ty) => { 396 Some(known_ty) => {
@@ -377,12 +413,12 @@ impl InferenceTable {
377 /// known type. 413 /// known type.
378 fn resolve_ty_as_possible_inner(&mut self, tv_stack: &mut Vec<TypeVarId>, ty: Ty) -> Ty { 414 fn resolve_ty_as_possible_inner(&mut self, tv_stack: &mut Vec<TypeVarId>, ty: Ty) -> Ty {
379 ty.fold(&mut |ty| match ty { 415 ty.fold(&mut |ty| match ty {
380 Ty::Infer(tv) => { 416 Ty::InferenceVar(tv, kind) => {
381 let inner = tv.to_inner(); 417 let inner = tv.to_inner();
382 if tv_stack.contains(&inner) { 418 if tv_stack.contains(&inner) {
383 mark::hit!(type_var_cycles_resolve_as_possible); 419 mark::hit!(type_var_cycles_resolve_as_possible);
384 // recursive type 420 // recursive type
385 return tv.fallback_value(); 421 return self.type_variable_table.fallback_value(tv, kind);
386 } 422 }
387 if let Some(known_ty) = 423 if let Some(known_ty) =
388 self.var_unification_table.inlined_probe_value(inner).known() 424 self.var_unification_table.inlined_probe_value(inner).known()
@@ -404,12 +440,12 @@ impl InferenceTable {
404 /// replaced by Ty::Unknown. 440 /// replaced by Ty::Unknown.
405 fn resolve_ty_completely_inner(&mut self, tv_stack: &mut Vec<TypeVarId>, ty: Ty) -> Ty { 441 fn resolve_ty_completely_inner(&mut self, tv_stack: &mut Vec<TypeVarId>, ty: Ty) -> Ty {
406 ty.fold(&mut |ty| match ty { 442 ty.fold(&mut |ty| match ty {
407 Ty::Infer(tv) => { 443 Ty::InferenceVar(tv, kind) => {
408 let inner = tv.to_inner(); 444 let inner = tv.to_inner();
409 if tv_stack.contains(&inner) { 445 if tv_stack.contains(&inner) {
410 mark::hit!(type_var_cycles_resolve_completely); 446 mark::hit!(type_var_cycles_resolve_completely);
411 // recursive type 447 // recursive type
412 return tv.fallback_value(); 448 return self.type_variable_table.fallback_value(tv, kind);
413 } 449 }
414 if let Some(known_ty) = 450 if let Some(known_ty) =
415 self.var_unification_table.inlined_probe_value(inner).known() 451 self.var_unification_table.inlined_probe_value(inner).known()
@@ -420,7 +456,7 @@ impl InferenceTable {
420 tv_stack.pop(); 456 tv_stack.pop();
421 result 457 result
422 } else { 458 } else {
423 tv.fallback_value() 459 self.type_variable_table.fallback_value(tv, kind)
424 } 460 }
425 } 461 }
426 _ => ty, 462 _ => ty,
@@ -430,7 +466,7 @@ impl InferenceTable {
430 466
431/// The ID of a type variable. 467/// The ID of a type variable.
432#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)] 468#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
433pub struct TypeVarId(pub(super) u32); 469pub(super) struct TypeVarId(pub(super) u32);
434 470
435impl UnifyKey for TypeVarId { 471impl UnifyKey for TypeVarId {
436 type Value = TypeVarValue; 472 type Value = TypeVarValue;
@@ -451,7 +487,7 @@ impl UnifyKey for TypeVarId {
451/// The value of a type variable: either we already know the type, or we don't 487/// The value of a type variable: either we already know the type, or we don't
452/// know it yet. 488/// know it yet.
453#[derive(Clone, PartialEq, Eq, Debug)] 489#[derive(Clone, PartialEq, Eq, Debug)]
454pub enum TypeVarValue { 490pub(super) enum TypeVarValue {
455 Known(Ty), 491 Known(Ty),
456 Unknown, 492 Unknown,
457} 493}