From 6a77ec7bbe6ddbf663dce9529d11d1bb56c5489a Mon Sep 17 00:00:00 2001 From: Aleksey Kladov Date: Thu, 13 Aug 2020 16:35:29 +0200 Subject: Rename ra_hir_ty -> hir_ty --- crates/hir_ty/src/infer/coerce.rs | 197 +++++++++ crates/hir_ty/src/infer/expr.rs | 873 ++++++++++++++++++++++++++++++++++++++ crates/hir_ty/src/infer/pat.rs | 241 +++++++++++ crates/hir_ty/src/infer/path.rs | 287 +++++++++++++ crates/hir_ty/src/infer/unify.rs | 474 +++++++++++++++++++++ 5 files changed, 2072 insertions(+) create mode 100644 crates/hir_ty/src/infer/coerce.rs create mode 100644 crates/hir_ty/src/infer/expr.rs create mode 100644 crates/hir_ty/src/infer/pat.rs create mode 100644 crates/hir_ty/src/infer/path.rs create mode 100644 crates/hir_ty/src/infer/unify.rs (limited to 'crates/hir_ty/src/infer') diff --git a/crates/hir_ty/src/infer/coerce.rs b/crates/hir_ty/src/infer/coerce.rs new file mode 100644 index 000000000..32c7c57cd --- /dev/null +++ b/crates/hir_ty/src/infer/coerce.rs @@ -0,0 +1,197 @@ +//! Coercion logic. Coercions are certain type conversions that can implicitly +//! happen in certain places, e.g. weakening `&mut` to `&` or deref coercions +//! like going from `&Vec` to `&[T]`. +//! +//! See: https://doc.rust-lang.org/nomicon/coercions.html + +use hir_def::{lang_item::LangItemTarget, type_ref::Mutability}; +use test_utils::mark; + +use crate::{autoderef, traits::Solution, Obligation, Substs, TraitRef, Ty, TypeCtor}; + +use super::{unify::TypeVarValue, InEnvironment, InferTy, InferenceContext}; + +impl<'a> InferenceContext<'a> { + /// Unify two types, but may coerce the first one to the second one + /// using "implicit coercion rules" if needed. + pub(super) fn coerce(&mut self, from_ty: &Ty, to_ty: &Ty) -> bool { + let from_ty = self.resolve_ty_shallow(from_ty).into_owned(); + let to_ty = self.resolve_ty_shallow(to_ty); + self.coerce_inner(from_ty, &to_ty) + } + + /// Merge two types from different branches, with possible coercion. + /// + /// Mostly this means trying to coerce one to the other, but + /// - if we have two function types for different functions, we need to + /// coerce both to function pointers; + /// - if we were concerned with lifetime subtyping, we'd need to look for a + /// least upper bound. + pub(super) fn coerce_merge_branch(&mut self, ty1: &Ty, ty2: &Ty) -> Ty { + if self.coerce(ty1, ty2) { + ty2.clone() + } else if self.coerce(ty2, ty1) { + ty1.clone() + } else { + if let (ty_app!(TypeCtor::FnDef(_)), ty_app!(TypeCtor::FnDef(_))) = (ty1, ty2) { + mark::hit!(coerce_fn_reification); + // Special case: two function types. Try to coerce both to + // pointers to have a chance at getting a match. See + // https://github.com/rust-lang/rust/blob/7b805396bf46dce972692a6846ce2ad8481c5f85/src/librustc_typeck/check/coercion.rs#L877-L916 + let sig1 = ty1.callable_sig(self.db).expect("FnDef without callable sig"); + let sig2 = ty2.callable_sig(self.db).expect("FnDef without callable sig"); + let ptr_ty1 = Ty::fn_ptr(sig1); + let ptr_ty2 = Ty::fn_ptr(sig2); + self.coerce_merge_branch(&ptr_ty1, &ptr_ty2) + } else { + mark::hit!(coerce_merge_fail_fallback); + ty1.clone() + } + } + } + + fn coerce_inner(&mut self, mut from_ty: Ty, to_ty: &Ty) -> bool { + match (&from_ty, to_ty) { + // Never type will make type variable to fallback to Never Type instead of Unknown. + (ty_app!(TypeCtor::Never), Ty::Infer(InferTy::TypeVar(tv))) => { + let var = self.table.new_maybe_never_type_var(); + self.table.var_unification_table.union_value(*tv, TypeVarValue::Known(var)); + return true; + } + (ty_app!(TypeCtor::Never), _) => return true, + + // Trivial cases, this should go after `never` check to + // avoid infer result type to be never + _ => { + if self.table.unify_inner_trivial(&from_ty, &to_ty, 0) { + return true; + } + } + } + + // Pointer weakening and function to pointer + match (&mut from_ty, to_ty) { + // `*mut T`, `&mut T, `&T`` -> `*const T` + // `&mut T` -> `&T` + // `&mut T` -> `*mut T` + (ty_app!(c1@TypeCtor::RawPtr(_)), ty_app!(c2@TypeCtor::RawPtr(Mutability::Shared))) + | (ty_app!(c1@TypeCtor::Ref(_)), ty_app!(c2@TypeCtor::RawPtr(Mutability::Shared))) + | (ty_app!(c1@TypeCtor::Ref(_)), ty_app!(c2@TypeCtor::Ref(Mutability::Shared))) + | (ty_app!(c1@TypeCtor::Ref(Mutability::Mut)), ty_app!(c2@TypeCtor::RawPtr(_))) => { + *c1 = *c2; + } + + // Illegal mutablity conversion + ( + ty_app!(TypeCtor::RawPtr(Mutability::Shared)), + ty_app!(TypeCtor::RawPtr(Mutability::Mut)), + ) + | ( + ty_app!(TypeCtor::Ref(Mutability::Shared)), + ty_app!(TypeCtor::Ref(Mutability::Mut)), + ) => return false, + + // `{function_type}` -> `fn()` + (ty_app!(TypeCtor::FnDef(_)), ty_app!(TypeCtor::FnPtr { .. })) => { + match from_ty.callable_sig(self.db) { + None => return false, + Some(sig) => { + from_ty = Ty::fn_ptr(sig); + } + } + } + + (ty_app!(TypeCtor::Closure { .. }, params), ty_app!(TypeCtor::FnPtr { .. })) => { + from_ty = params[0].clone(); + } + + _ => {} + } + + if let Some(ret) = self.try_coerce_unsized(&from_ty, &to_ty) { + return ret; + } + + // Auto Deref if cannot coerce + match (&from_ty, to_ty) { + // FIXME: DerefMut + (ty_app!(TypeCtor::Ref(_), st1), ty_app!(TypeCtor::Ref(_), st2)) => { + self.unify_autoderef_behind_ref(&st1[0], &st2[0]) + } + + // Otherwise, normal unify + _ => self.unify(&from_ty, to_ty), + } + } + + /// Coerce a type using `from_ty: CoerceUnsized` + /// + /// See: https://doc.rust-lang.org/nightly/std/marker/trait.CoerceUnsized.html + fn try_coerce_unsized(&mut self, from_ty: &Ty, to_ty: &Ty) -> Option { + let krate = self.resolver.krate().unwrap(); + let coerce_unsized_trait = match self.db.lang_item(krate, "coerce_unsized".into()) { + Some(LangItemTarget::TraitId(trait_)) => trait_, + _ => return None, + }; + + let generic_params = crate::utils::generics(self.db.upcast(), coerce_unsized_trait.into()); + if generic_params.len() != 2 { + // The CoerceUnsized trait should have two generic params: Self and T. + return None; + } + + let substs = Substs::build_for_generics(&generic_params) + .push(from_ty.clone()) + .push(to_ty.clone()) + .build(); + let trait_ref = TraitRef { trait_: coerce_unsized_trait, substs }; + let goal = InEnvironment::new(self.trait_env.clone(), Obligation::Trait(trait_ref)); + + let canonicalizer = self.canonicalizer(); + let canonicalized = canonicalizer.canonicalize_obligation(goal); + + let solution = self.db.trait_solve(krate, canonicalized.value.clone())?; + + match solution { + Solution::Unique(v) => { + canonicalized.apply_solution(self, v.0); + } + _ => return None, + }; + + Some(true) + } + + /// Unify `from_ty` to `to_ty` with optional auto Deref + /// + /// Note that the parameters are already stripped the outer reference. + fn unify_autoderef_behind_ref(&mut self, from_ty: &Ty, to_ty: &Ty) -> bool { + let canonicalized = self.canonicalizer().canonicalize_ty(from_ty.clone()); + let to_ty = self.resolve_ty_shallow(&to_ty); + // FIXME: Auto DerefMut + for derefed_ty in autoderef::autoderef( + self.db, + self.resolver.krate(), + InEnvironment { + value: canonicalized.value.clone(), + environment: self.trait_env.clone(), + }, + ) { + let derefed_ty = canonicalized.decanonicalize_ty(derefed_ty.value); + match (&*self.resolve_ty_shallow(&derefed_ty), &*to_ty) { + // Stop when constructor matches. + (ty_app!(from_ctor, st1), ty_app!(to_ctor, st2)) if from_ctor == to_ctor => { + // It will not recurse to `coerce`. + return self.table.unify_substs(st1, st2, 0); + } + _ => { + if self.table.unify_inner_trivial(&derefed_ty, &to_ty, 0) { + return true; + } + } + } + } + + false + } +} diff --git a/crates/hir_ty/src/infer/expr.rs b/crates/hir_ty/src/infer/expr.rs new file mode 100644 index 000000000..a2f849d02 --- /dev/null +++ b/crates/hir_ty/src/infer/expr.rs @@ -0,0 +1,873 @@ +//! Type inference for expressions. + +use std::iter::{repeat, repeat_with}; +use std::{mem, sync::Arc}; + +use hir_def::{ + builtin_type::Signedness, + expr::{Array, BinaryOp, Expr, ExprId, Literal, Statement, UnaryOp}, + path::{GenericArg, GenericArgs}, + resolver::resolver_for_expr, + AdtId, AssocContainerId, FieldId, Lookup, +}; +use hir_expand::name::{name, Name}; +use syntax::ast::RangeOp; + +use crate::{ + autoderef, method_resolution, op, + traits::{FnTrait, InEnvironment}, + utils::{generics, variant_data, Generics}, + ApplicationTy, Binders, CallableDefId, InferTy, IntTy, Mutability, Obligation, Rawness, Substs, + TraitRef, Ty, TypeCtor, +}; + +use super::{ + find_breakable, BindingMode, BreakableContext, Diverges, Expectation, InferenceContext, + InferenceDiagnostic, TypeMismatch, +}; + +impl<'a> InferenceContext<'a> { + pub(super) fn infer_expr(&mut self, tgt_expr: ExprId, expected: &Expectation) -> Ty { + let ty = self.infer_expr_inner(tgt_expr, expected); + if ty.is_never() { + // Any expression that produces a value of type `!` must have diverged + self.diverges = Diverges::Always; + } + let could_unify = self.unify(&ty, &expected.ty); + if !could_unify { + self.result.type_mismatches.insert( + tgt_expr, + TypeMismatch { expected: expected.ty.clone(), actual: ty.clone() }, + ); + } + self.resolve_ty_as_possible(ty) + } + + /// Infer type of expression with possibly implicit coerce to the expected type. + /// Return the type after possible coercion. + pub(super) fn infer_expr_coerce(&mut self, expr: ExprId, expected: &Expectation) -> Ty { + let ty = self.infer_expr_inner(expr, &expected); + let ty = if !self.coerce(&ty, &expected.coercion_target()) { + self.result + .type_mismatches + .insert(expr, TypeMismatch { expected: expected.ty.clone(), actual: ty.clone() }); + // Return actual type when type mismatch. + // This is needed for diagnostic when return type mismatch. + ty + } else if expected.coercion_target() == &Ty::Unknown { + ty + } else { + expected.ty.clone() + }; + + self.resolve_ty_as_possible(ty) + } + + fn callable_sig_from_fn_trait(&mut self, ty: &Ty, num_args: usize) -> Option<(Vec, Ty)> { + let krate = self.resolver.krate()?; + let fn_once_trait = FnTrait::FnOnce.get_id(self.db, krate)?; + let output_assoc_type = + self.db.trait_data(fn_once_trait).associated_type_by_name(&name![Output])?; + let generic_params = generics(self.db.upcast(), fn_once_trait.into()); + if generic_params.len() != 2 { + return None; + } + + let mut param_builder = Substs::builder(num_args); + let mut arg_tys = vec![]; + for _ in 0..num_args { + let arg = self.table.new_type_var(); + param_builder = param_builder.push(arg.clone()); + arg_tys.push(arg); + } + let parameters = param_builder.build(); + let arg_ty = Ty::Apply(ApplicationTy { + ctor: TypeCtor::Tuple { cardinality: num_args as u16 }, + parameters, + }); + let substs = + Substs::build_for_generics(&generic_params).push(ty.clone()).push(arg_ty).build(); + + let trait_env = Arc::clone(&self.trait_env); + let implements_fn_trait = + Obligation::Trait(TraitRef { trait_: fn_once_trait, substs: substs.clone() }); + let goal = self.canonicalizer().canonicalize_obligation(InEnvironment { + value: implements_fn_trait.clone(), + environment: trait_env, + }); + if self.db.trait_solve(krate, goal.value).is_some() { + self.obligations.push(implements_fn_trait); + let output_proj_ty = + crate::ProjectionTy { associated_ty: output_assoc_type, parameters: substs }; + let return_ty = self.normalize_projection_ty(output_proj_ty); + Some((arg_tys, return_ty)) + } else { + None + } + } + + pub fn callable_sig(&mut self, ty: &Ty, num_args: usize) -> Option<(Vec, Ty)> { + match ty.callable_sig(self.db) { + Some(sig) => Some((sig.params().to_vec(), sig.ret().clone())), + None => self.callable_sig_from_fn_trait(ty, num_args), + } + } + + fn infer_expr_inner(&mut self, tgt_expr: ExprId, expected: &Expectation) -> Ty { + let body = Arc::clone(&self.body); // avoid borrow checker problem + let ty = match &body[tgt_expr] { + Expr::Missing => Ty::Unknown, + Expr::If { condition, then_branch, else_branch } => { + // if let is desugared to match, so this is always simple if + self.infer_expr(*condition, &Expectation::has_type(Ty::simple(TypeCtor::Bool))); + + let condition_diverges = mem::replace(&mut self.diverges, Diverges::Maybe); + let mut both_arms_diverge = Diverges::Always; + + let then_ty = self.infer_expr_inner(*then_branch, &expected); + both_arms_diverge &= mem::replace(&mut self.diverges, Diverges::Maybe); + let else_ty = match else_branch { + Some(else_branch) => self.infer_expr_inner(*else_branch, &expected), + None => Ty::unit(), + }; + both_arms_diverge &= self.diverges; + + self.diverges = condition_diverges | both_arms_diverge; + + self.coerce_merge_branch(&then_ty, &else_ty) + } + Expr::Block { statements, tail, .. } => { + // FIXME: Breakable block inference + self.infer_block(statements, *tail, expected) + } + Expr::Unsafe { body } => self.infer_expr(*body, expected), + Expr::TryBlock { body } => { + let _inner = self.infer_expr(*body, expected); + // FIXME should be std::result::Result<{inner}, _> + Ty::Unknown + } + Expr::Loop { body, label } => { + self.breakables.push(BreakableContext { + may_break: false, + break_ty: self.table.new_type_var(), + label: label.clone(), + }); + self.infer_expr(*body, &Expectation::has_type(Ty::unit())); + + let ctxt = self.breakables.pop().expect("breakable stack broken"); + if ctxt.may_break { + self.diverges = Diverges::Maybe; + } + + if ctxt.may_break { + ctxt.break_ty + } else { + Ty::simple(TypeCtor::Never) + } + } + Expr::While { condition, body, label } => { + self.breakables.push(BreakableContext { + may_break: false, + break_ty: Ty::Unknown, + label: label.clone(), + }); + // while let is desugared to a match loop, so this is always simple while + self.infer_expr(*condition, &Expectation::has_type(Ty::simple(TypeCtor::Bool))); + self.infer_expr(*body, &Expectation::has_type(Ty::unit())); + let _ctxt = self.breakables.pop().expect("breakable stack broken"); + // the body may not run, so it diverging doesn't mean we diverge + self.diverges = Diverges::Maybe; + Ty::unit() + } + Expr::For { iterable, body, pat, label } => { + let iterable_ty = self.infer_expr(*iterable, &Expectation::none()); + + self.breakables.push(BreakableContext { + may_break: false, + break_ty: Ty::Unknown, + label: label.clone(), + }); + let pat_ty = + self.resolve_associated_type(iterable_ty, self.resolve_into_iter_item()); + + self.infer_pat(*pat, &pat_ty, BindingMode::default()); + + self.infer_expr(*body, &Expectation::has_type(Ty::unit())); + let _ctxt = self.breakables.pop().expect("breakable stack broken"); + // the body may not run, so it diverging doesn't mean we diverge + self.diverges = Diverges::Maybe; + Ty::unit() + } + Expr::Lambda { body, args, ret_type, arg_types } => { + assert_eq!(args.len(), arg_types.len()); + + let mut sig_tys = Vec::new(); + + // collect explicitly written argument types + for arg_type in arg_types.iter() { + let arg_ty = if let Some(type_ref) = arg_type { + self.make_ty(type_ref) + } else { + self.table.new_type_var() + }; + sig_tys.push(arg_ty); + } + + // add return type + let ret_ty = match ret_type { + Some(type_ref) => self.make_ty(type_ref), + None => self.table.new_type_var(), + }; + sig_tys.push(ret_ty.clone()); + let sig_ty = Ty::apply( + TypeCtor::FnPtr { num_args: sig_tys.len() as u16 - 1, is_varargs: false }, + Substs(sig_tys.clone().into()), + ); + let closure_ty = + Ty::apply_one(TypeCtor::Closure { def: self.owner, expr: tgt_expr }, sig_ty); + + // Eagerly try to relate the closure type with the expected + // type, otherwise we often won't have enough information to + // infer the body. + self.coerce(&closure_ty, &expected.ty); + + // Now go through the argument patterns + for (arg_pat, arg_ty) in args.iter().zip(sig_tys) { + let resolved = self.resolve_ty_as_possible(arg_ty); + self.infer_pat(*arg_pat, &resolved, BindingMode::default()); + } + + let prev_diverges = mem::replace(&mut self.diverges, Diverges::Maybe); + let prev_ret_ty = mem::replace(&mut self.return_ty, ret_ty.clone()); + + self.infer_expr_coerce(*body, &Expectation::has_type(ret_ty)); + + self.diverges = prev_diverges; + self.return_ty = prev_ret_ty; + + closure_ty + } + Expr::Call { callee, args } => { + let callee_ty = self.infer_expr(*callee, &Expectation::none()); + let canonicalized = self.canonicalizer().canonicalize_ty(callee_ty.clone()); + let mut derefs = autoderef( + self.db, + self.resolver.krate(), + InEnvironment { + value: canonicalized.value.clone(), + environment: self.trait_env.clone(), + }, + ); + let (param_tys, ret_ty): (Vec, Ty) = derefs + .find_map(|callee_deref_ty| { + self.callable_sig( + &canonicalized.decanonicalize_ty(callee_deref_ty.value), + args.len(), + ) + }) + .unwrap_or((Vec::new(), Ty::Unknown)); + self.register_obligations_for_call(&callee_ty); + self.check_call_arguments(args, ¶m_tys); + self.normalize_associated_types_in(ret_ty) + } + Expr::MethodCall { receiver, args, method_name, generic_args } => self + .infer_method_call(tgt_expr, *receiver, &args, &method_name, generic_args.as_ref()), + Expr::Match { expr, arms } => { + let input_ty = self.infer_expr(*expr, &Expectation::none()); + + let mut result_ty = if arms.is_empty() { + Ty::simple(TypeCtor::Never) + } else { + self.table.new_type_var() + }; + + let matchee_diverges = self.diverges; + let mut all_arms_diverge = Diverges::Always; + + for arm in arms { + self.diverges = Diverges::Maybe; + let _pat_ty = self.infer_pat(arm.pat, &input_ty, BindingMode::default()); + if let Some(guard_expr) = arm.guard { + self.infer_expr( + guard_expr, + &Expectation::has_type(Ty::simple(TypeCtor::Bool)), + ); + } + + let arm_ty = self.infer_expr_inner(arm.expr, &expected); + all_arms_diverge &= self.diverges; + result_ty = self.coerce_merge_branch(&result_ty, &arm_ty); + } + + self.diverges = matchee_diverges | all_arms_diverge; + + result_ty + } + Expr::Path(p) => { + // FIXME this could be more efficient... + let resolver = resolver_for_expr(self.db.upcast(), self.owner, tgt_expr); + self.infer_path(&resolver, p, tgt_expr.into()).unwrap_or(Ty::Unknown) + } + Expr::Continue { .. } => Ty::simple(TypeCtor::Never), + Expr::Break { expr, label } => { + let val_ty = if let Some(expr) = expr { + self.infer_expr(*expr, &Expectation::none()) + } else { + Ty::unit() + }; + + let last_ty = + if let Some(ctxt) = find_breakable(&mut self.breakables, label.as_ref()) { + ctxt.break_ty.clone() + } else { + Ty::Unknown + }; + + let merged_type = self.coerce_merge_branch(&last_ty, &val_ty); + + if let Some(ctxt) = find_breakable(&mut self.breakables, label.as_ref()) { + ctxt.break_ty = merged_type; + ctxt.may_break = true; + } else { + self.push_diagnostic(InferenceDiagnostic::BreakOutsideOfLoop { + expr: tgt_expr, + }); + } + + Ty::simple(TypeCtor::Never) + } + Expr::Return { expr } => { + if let Some(expr) = expr { + self.infer_expr_coerce(*expr, &Expectation::has_type(self.return_ty.clone())); + } else { + let unit = Ty::unit(); + self.coerce(&unit, &self.return_ty.clone()); + } + Ty::simple(TypeCtor::Never) + } + Expr::RecordLit { path, fields, spread } => { + let (ty, def_id) = self.resolve_variant(path.as_ref()); + if let Some(variant) = def_id { + self.write_variant_resolution(tgt_expr.into(), variant); + } + + self.unify(&ty, &expected.ty); + + let substs = ty.substs().unwrap_or_else(Substs::empty); + let field_types = def_id.map(|it| self.db.field_types(it)).unwrap_or_default(); + let variant_data = def_id.map(|it| variant_data(self.db.upcast(), it)); + for (field_idx, field) in fields.iter().enumerate() { + let field_def = + variant_data.as_ref().and_then(|it| match it.field(&field.name) { + Some(local_id) => Some(FieldId { parent: def_id.unwrap(), local_id }), + None => { + self.push_diagnostic(InferenceDiagnostic::NoSuchField { + expr: tgt_expr, + field: field_idx, + }); + None + } + }); + if let Some(field_def) = field_def { + self.result.record_field_resolutions.insert(field.expr, field_def); + } + let field_ty = field_def + .map_or(Ty::Unknown, |it| field_types[it.local_id].clone().subst(&substs)); + self.infer_expr_coerce(field.expr, &Expectation::has_type(field_ty)); + } + if let Some(expr) = spread { + self.infer_expr(*expr, &Expectation::has_type(ty.clone())); + } + ty + } + Expr::Field { expr, name } => { + let receiver_ty = self.infer_expr_inner(*expr, &Expectation::none()); + let canonicalized = self.canonicalizer().canonicalize_ty(receiver_ty); + let ty = autoderef::autoderef( + self.db, + self.resolver.krate(), + InEnvironment { + value: canonicalized.value.clone(), + environment: self.trait_env.clone(), + }, + ) + .find_map(|derefed_ty| match canonicalized.decanonicalize_ty(derefed_ty.value) { + Ty::Apply(a_ty) => match a_ty.ctor { + TypeCtor::Tuple { .. } => name + .as_tuple_index() + .and_then(|idx| a_ty.parameters.0.get(idx).cloned()), + TypeCtor::Adt(AdtId::StructId(s)) => { + self.db.struct_data(s).variant_data.field(name).map(|local_id| { + let field = FieldId { parent: s.into(), local_id }; + self.write_field_resolution(tgt_expr, field); + self.db.field_types(s.into())[field.local_id] + .clone() + .subst(&a_ty.parameters) + }) + } + TypeCtor::Adt(AdtId::UnionId(u)) => { + self.db.union_data(u).variant_data.field(name).map(|local_id| { + let field = FieldId { parent: u.into(), local_id }; + self.write_field_resolution(tgt_expr, field); + self.db.field_types(u.into())[field.local_id] + .clone() + .subst(&a_ty.parameters) + }) + } + _ => None, + }, + _ => None, + }) + .unwrap_or(Ty::Unknown); + let ty = self.insert_type_vars(ty); + self.normalize_associated_types_in(ty) + } + Expr::Await { expr } => { + let inner_ty = self.infer_expr_inner(*expr, &Expectation::none()); + self.resolve_associated_type(inner_ty, self.resolve_future_future_output()) + } + Expr::Try { expr } => { + let inner_ty = self.infer_expr_inner(*expr, &Expectation::none()); + self.resolve_associated_type(inner_ty, self.resolve_ops_try_ok()) + } + Expr::Cast { expr, type_ref } => { + let _inner_ty = self.infer_expr_inner(*expr, &Expectation::none()); + let cast_ty = self.make_ty(type_ref); + // FIXME check the cast... + cast_ty + } + Expr::Ref { expr, rawness, mutability } => { + let expectation = if let Some((exp_inner, exp_rawness, exp_mutability)) = + &expected.ty.as_reference_or_ptr() + { + if *exp_mutability == Mutability::Mut && *mutability == Mutability::Shared { + // FIXME: throw type error - expected mut reference but found shared ref, + // which cannot be coerced + } + if *exp_rawness == Rawness::Ref && *rawness == Rawness::RawPtr { + // FIXME: throw type error - expected reference but found ptr, + // which cannot be coerced + } + Expectation::rvalue_hint(Ty::clone(exp_inner)) + } else { + Expectation::none() + }; + let inner_ty = self.infer_expr_inner(*expr, &expectation); + let ty = match rawness { + Rawness::RawPtr => TypeCtor::RawPtr(*mutability), + Rawness::Ref => TypeCtor::Ref(*mutability), + }; + Ty::apply_one(ty, inner_ty) + } + Expr::Box { expr } => { + let inner_ty = self.infer_expr_inner(*expr, &Expectation::none()); + if let Some(box_) = self.resolve_boxed_box() { + Ty::apply_one(TypeCtor::Adt(box_), inner_ty) + } else { + Ty::Unknown + } + } + Expr::UnaryOp { expr, op } => { + let inner_ty = self.infer_expr_inner(*expr, &Expectation::none()); + match op { + UnaryOp::Deref => match self.resolver.krate() { + Some(krate) => { + let canonicalized = self.canonicalizer().canonicalize_ty(inner_ty); + match autoderef::deref( + self.db, + krate, + InEnvironment { + value: &canonicalized.value, + environment: self.trait_env.clone(), + }, + ) { + Some(derefed_ty) => { + canonicalized.decanonicalize_ty(derefed_ty.value) + } + None => Ty::Unknown, + } + } + None => Ty::Unknown, + }, + UnaryOp::Neg => { + match &inner_ty { + // Fast path for builtins + Ty::Apply(ApplicationTy { + ctor: TypeCtor::Int(IntTy { signedness: Signedness::Signed, .. }), + .. + }) + | Ty::Apply(ApplicationTy { ctor: TypeCtor::Float(_), .. }) + | Ty::Infer(InferTy::IntVar(..)) + | Ty::Infer(InferTy::FloatVar(..)) => inner_ty, + // Otherwise we resolve via the std::ops::Neg trait + _ => self + .resolve_associated_type(inner_ty, self.resolve_ops_neg_output()), + } + } + UnaryOp::Not => { + match &inner_ty { + // Fast path for builtins + Ty::Apply(ApplicationTy { ctor: TypeCtor::Bool, .. }) + | Ty::Apply(ApplicationTy { ctor: TypeCtor::Int(_), .. }) + | Ty::Infer(InferTy::IntVar(..)) => inner_ty, + // Otherwise we resolve via the std::ops::Not trait + _ => self + .resolve_associated_type(inner_ty, self.resolve_ops_not_output()), + } + } + } + } + Expr::BinaryOp { lhs, rhs, op } => match op { + Some(op) => { + let lhs_expectation = match op { + BinaryOp::LogicOp(..) => Expectation::has_type(Ty::simple(TypeCtor::Bool)), + _ => Expectation::none(), + }; + let lhs_ty = self.infer_expr(*lhs, &lhs_expectation); + // FIXME: find implementation of trait corresponding to operation + // symbol and resolve associated `Output` type + let rhs_expectation = op::binary_op_rhs_expectation(*op, lhs_ty.clone()); + let rhs_ty = self.infer_expr(*rhs, &Expectation::has_type(rhs_expectation)); + + // FIXME: similar as above, return ty is often associated trait type + op::binary_op_return_ty(*op, lhs_ty, rhs_ty) + } + _ => Ty::Unknown, + }, + Expr::Range { lhs, rhs, range_type } => { + let lhs_ty = lhs.map(|e| self.infer_expr_inner(e, &Expectation::none())); + let rhs_expect = lhs_ty + .as_ref() + .map_or_else(Expectation::none, |ty| Expectation::has_type(ty.clone())); + let rhs_ty = rhs.map(|e| self.infer_expr(e, &rhs_expect)); + match (range_type, lhs_ty, rhs_ty) { + (RangeOp::Exclusive, None, None) => match self.resolve_range_full() { + Some(adt) => Ty::simple(TypeCtor::Adt(adt)), + None => Ty::Unknown, + }, + (RangeOp::Exclusive, None, Some(ty)) => match self.resolve_range_to() { + Some(adt) => Ty::apply_one(TypeCtor::Adt(adt), ty), + None => Ty::Unknown, + }, + (RangeOp::Inclusive, None, Some(ty)) => { + match self.resolve_range_to_inclusive() { + Some(adt) => Ty::apply_one(TypeCtor::Adt(adt), ty), + None => Ty::Unknown, + } + } + (RangeOp::Exclusive, Some(_), Some(ty)) => match self.resolve_range() { + Some(adt) => Ty::apply_one(TypeCtor::Adt(adt), ty), + None => Ty::Unknown, + }, + (RangeOp::Inclusive, Some(_), Some(ty)) => { + match self.resolve_range_inclusive() { + Some(adt) => Ty::apply_one(TypeCtor::Adt(adt), ty), + None => Ty::Unknown, + } + } + (RangeOp::Exclusive, Some(ty), None) => match self.resolve_range_from() { + Some(adt) => Ty::apply_one(TypeCtor::Adt(adt), ty), + None => Ty::Unknown, + }, + (RangeOp::Inclusive, _, None) => Ty::Unknown, + } + } + Expr::Index { base, index } => { + let base_ty = self.infer_expr_inner(*base, &Expectation::none()); + let index_ty = self.infer_expr(*index, &Expectation::none()); + + if let (Some(index_trait), Some(krate)) = + (self.resolve_ops_index(), self.resolver.krate()) + { + let canonicalized = self.canonicalizer().canonicalize_ty(base_ty); + let self_ty = method_resolution::resolve_indexing_op( + self.db, + &canonicalized.value, + self.trait_env.clone(), + krate, + index_trait, + ); + let self_ty = + self_ty.map_or(Ty::Unknown, |t| canonicalized.decanonicalize_ty(t.value)); + self.resolve_associated_type_with_params( + self_ty, + self.resolve_ops_index_output(), + &[index_ty], + ) + } else { + Ty::Unknown + } + } + Expr::Tuple { exprs } => { + let mut tys = match &expected.ty { + ty_app!(TypeCtor::Tuple { .. }, st) => st + .iter() + .cloned() + .chain(repeat_with(|| self.table.new_type_var())) + .take(exprs.len()) + .collect::>(), + _ => (0..exprs.len()).map(|_| self.table.new_type_var()).collect(), + }; + + for (expr, ty) in exprs.iter().zip(tys.iter_mut()) { + self.infer_expr_coerce(*expr, &Expectation::has_type(ty.clone())); + } + + Ty::apply(TypeCtor::Tuple { cardinality: tys.len() as u16 }, Substs(tys.into())) + } + Expr::Array(array) => { + let elem_ty = match &expected.ty { + ty_app!(TypeCtor::Array, st) | ty_app!(TypeCtor::Slice, st) => { + st.as_single().clone() + } + _ => self.table.new_type_var(), + }; + + match array { + Array::ElementList(items) => { + for expr in items.iter() { + self.infer_expr_coerce(*expr, &Expectation::has_type(elem_ty.clone())); + } + } + Array::Repeat { initializer, repeat } => { + self.infer_expr_coerce( + *initializer, + &Expectation::has_type(elem_ty.clone()), + ); + self.infer_expr( + *repeat, + &Expectation::has_type(Ty::simple(TypeCtor::Int(IntTy::usize()))), + ); + } + } + + Ty::apply_one(TypeCtor::Array, elem_ty) + } + Expr::Literal(lit) => match lit { + Literal::Bool(..) => Ty::simple(TypeCtor::Bool), + Literal::String(..) => { + Ty::apply_one(TypeCtor::Ref(Mutability::Shared), Ty::simple(TypeCtor::Str)) + } + Literal::ByteString(..) => { + let byte_type = Ty::simple(TypeCtor::Int(IntTy::u8())); + let array_type = Ty::apply_one(TypeCtor::Array, byte_type); + Ty::apply_one(TypeCtor::Ref(Mutability::Shared), array_type) + } + Literal::Char(..) => Ty::simple(TypeCtor::Char), + Literal::Int(_v, ty) => match ty { + Some(int_ty) => Ty::simple(TypeCtor::Int((*int_ty).into())), + None => self.table.new_integer_var(), + }, + Literal::Float(_v, ty) => match ty { + Some(float_ty) => Ty::simple(TypeCtor::Float((*float_ty).into())), + None => self.table.new_float_var(), + }, + }, + }; + // use a new type variable if we got Ty::Unknown here + let ty = self.insert_type_vars_shallow(ty); + let ty = self.resolve_ty_as_possible(ty); + self.write_expr_ty(tgt_expr, ty.clone()); + ty + } + + fn infer_block( + &mut self, + statements: &[Statement], + tail: Option, + expected: &Expectation, + ) -> Ty { + for stmt in statements { + match stmt { + Statement::Let { pat, type_ref, initializer } => { + let decl_ty = + type_ref.as_ref().map(|tr| self.make_ty(tr)).unwrap_or(Ty::Unknown); + + // Always use the declared type when specified + let mut ty = decl_ty.clone(); + + if let Some(expr) = initializer { + let actual_ty = + self.infer_expr_coerce(*expr, &Expectation::has_type(decl_ty.clone())); + if decl_ty == Ty::Unknown { + ty = actual_ty; + } + } + + let ty = self.resolve_ty_as_possible(ty); + self.infer_pat(*pat, &ty, BindingMode::default()); + } + Statement::Expr(expr) => { + self.infer_expr(*expr, &Expectation::none()); + } + } + } + + let ty = if let Some(expr) = tail { + self.infer_expr_coerce(expr, expected) + } else { + // Citing rustc: if there is no explicit tail expression, + // that is typically equivalent to a tail expression + // of `()` -- except if the block diverges. In that + // case, there is no value supplied from the tail + // expression (assuming there are no other breaks, + // this implies that the type of the block will be + // `!`). + if self.diverges.is_always() { + // we don't even make an attempt at coercion + self.table.new_maybe_never_type_var() + } else { + self.coerce(&Ty::unit(), expected.coercion_target()); + Ty::unit() + } + }; + ty + } + + fn infer_method_call( + &mut self, + tgt_expr: ExprId, + receiver: ExprId, + args: &[ExprId], + method_name: &Name, + generic_args: Option<&GenericArgs>, + ) -> Ty { + let receiver_ty = self.infer_expr(receiver, &Expectation::none()); + let canonicalized_receiver = self.canonicalizer().canonicalize_ty(receiver_ty.clone()); + + let traits_in_scope = self.resolver.traits_in_scope(self.db.upcast()); + + let resolved = self.resolver.krate().and_then(|krate| { + method_resolution::lookup_method( + &canonicalized_receiver.value, + self.db, + self.trait_env.clone(), + krate, + &traits_in_scope, + method_name, + ) + }); + let (derefed_receiver_ty, method_ty, def_generics) = match resolved { + Some((ty, func)) => { + let ty = canonicalized_receiver.decanonicalize_ty(ty); + self.write_method_resolution(tgt_expr, func); + (ty, self.db.value_ty(func.into()), Some(generics(self.db.upcast(), func.into()))) + } + None => (receiver_ty, Binders::new(0, Ty::Unknown), None), + }; + let substs = self.substs_for_method_call(def_generics, generic_args, &derefed_receiver_ty); + let method_ty = method_ty.subst(&substs); + let method_ty = self.insert_type_vars(method_ty); + self.register_obligations_for_call(&method_ty); + let (expected_receiver_ty, param_tys, ret_ty) = match method_ty.callable_sig(self.db) { + Some(sig) => { + if !sig.params().is_empty() { + (sig.params()[0].clone(), sig.params()[1..].to_vec(), sig.ret().clone()) + } else { + (Ty::Unknown, Vec::new(), sig.ret().clone()) + } + } + None => (Ty::Unknown, Vec::new(), Ty::Unknown), + }; + // Apply autoref so the below unification works correctly + // FIXME: return correct autorefs from lookup_method + let actual_receiver_ty = match expected_receiver_ty.as_reference() { + Some((_, mutability)) => Ty::apply_one(TypeCtor::Ref(mutability), derefed_receiver_ty), + _ => derefed_receiver_ty, + }; + self.unify(&expected_receiver_ty, &actual_receiver_ty); + + self.check_call_arguments(args, ¶m_tys); + self.normalize_associated_types_in(ret_ty) + } + + fn check_call_arguments(&mut self, args: &[ExprId], param_tys: &[Ty]) { + // Quoting https://github.com/rust-lang/rust/blob/6ef275e6c3cb1384ec78128eceeb4963ff788dca/src/librustc_typeck/check/mod.rs#L3325 -- + // We do this in a pretty awful way: first we type-check any arguments + // that are not closures, then we type-check the closures. This is so + // that we have more information about the types of arguments when we + // type-check the functions. This isn't really the right way to do this. + for &check_closures in &[false, true] { + let param_iter = param_tys.iter().cloned().chain(repeat(Ty::Unknown)); + for (&arg, param_ty) in args.iter().zip(param_iter) { + let is_closure = matches!(&self.body[arg], Expr::Lambda { .. }); + if is_closure != check_closures { + continue; + } + + let param_ty = self.normalize_associated_types_in(param_ty); + self.infer_expr_coerce(arg, &Expectation::has_type(param_ty.clone())); + } + } + } + + fn substs_for_method_call( + &mut self, + def_generics: Option, + generic_args: Option<&GenericArgs>, + receiver_ty: &Ty, + ) -> Substs { + let (parent_params, self_params, type_params, impl_trait_params) = + def_generics.as_ref().map_or((0, 0, 0, 0), |g| g.provenance_split()); + assert_eq!(self_params, 0); // method shouldn't have another Self param + let total_len = parent_params + type_params + impl_trait_params; + let mut substs = Vec::with_capacity(total_len); + // Parent arguments are unknown, except for the receiver type + if let Some(parent_generics) = def_generics.as_ref().map(|p| p.iter_parent()) { + for (_id, param) in parent_generics { + if param.provenance == hir_def::generics::TypeParamProvenance::TraitSelf { + substs.push(receiver_ty.clone()); + } else { + substs.push(Ty::Unknown); + } + } + } + // handle provided type arguments + if let Some(generic_args) = generic_args { + // if args are provided, it should be all of them, but we can't rely on that + for arg in generic_args.args.iter().take(type_params) { + match arg { + GenericArg::Type(type_ref) => { + let ty = self.make_ty(type_ref); + substs.push(ty); + } + } + } + }; + let supplied_params = substs.len(); + for _ in supplied_params..total_len { + substs.push(Ty::Unknown); + } + assert_eq!(substs.len(), total_len); + Substs(substs.into()) + } + + fn register_obligations_for_call(&mut self, callable_ty: &Ty) { + if let Ty::Apply(a_ty) = callable_ty { + if let TypeCtor::FnDef(def) = a_ty.ctor { + let generic_predicates = self.db.generic_predicates(def.into()); + for predicate in generic_predicates.iter() { + let predicate = predicate.clone().subst(&a_ty.parameters); + if let Some(obligation) = Obligation::from_predicate(predicate) { + self.obligations.push(obligation); + } + } + // add obligation for trait implementation, if this is a trait method + match def { + CallableDefId::FunctionId(f) => { + if let AssocContainerId::TraitId(trait_) = + f.lookup(self.db.upcast()).container + { + // construct a TraitDef + let substs = a_ty + .parameters + .prefix(generics(self.db.upcast(), trait_.into()).len()); + self.obligations.push(Obligation::Trait(TraitRef { trait_, substs })); + } + } + CallableDefId::StructId(_) | CallableDefId::EnumVariantId(_) => {} + } + } + } + } +} diff --git a/crates/hir_ty/src/infer/pat.rs b/crates/hir_ty/src/infer/pat.rs new file mode 100644 index 000000000..4dd4f9802 --- /dev/null +++ b/crates/hir_ty/src/infer/pat.rs @@ -0,0 +1,241 @@ +//! Type inference for patterns. + +use std::iter::repeat; +use std::sync::Arc; + +use hir_def::{ + expr::{BindingAnnotation, Expr, Literal, Pat, PatId, RecordFieldPat}, + path::Path, + type_ref::Mutability, + FieldId, +}; +use hir_expand::name::Name; +use test_utils::mark; + +use super::{BindingMode, Expectation, InferenceContext}; +use crate::{utils::variant_data, Substs, Ty, TypeCtor}; + +impl<'a> InferenceContext<'a> { + fn infer_tuple_struct_pat( + &mut self, + path: Option<&Path>, + subpats: &[PatId], + expected: &Ty, + default_bm: BindingMode, + id: PatId, + ) -> Ty { + let (ty, def) = self.resolve_variant(path); + let var_data = def.map(|it| variant_data(self.db.upcast(), it)); + if let Some(variant) = def { + self.write_variant_resolution(id.into(), variant); + } + self.unify(&ty, expected); + + let substs = ty.substs().unwrap_or_else(Substs::empty); + + let field_tys = def.map(|it| self.db.field_types(it)).unwrap_or_default(); + + for (i, &subpat) in subpats.iter().enumerate() { + let expected_ty = var_data + .as_ref() + .and_then(|d| d.field(&Name::new_tuple_field(i))) + .map_or(Ty::Unknown, |field| field_tys[field].clone().subst(&substs)); + let expected_ty = self.normalize_associated_types_in(expected_ty); + self.infer_pat(subpat, &expected_ty, default_bm); + } + + ty + } + + fn infer_record_pat( + &mut self, + path: Option<&Path>, + subpats: &[RecordFieldPat], + expected: &Ty, + default_bm: BindingMode, + id: PatId, + ) -> Ty { + let (ty, def) = self.resolve_variant(path); + let var_data = def.map(|it| variant_data(self.db.upcast(), it)); + if let Some(variant) = def { + self.write_variant_resolution(id.into(), variant); + } + + self.unify(&ty, expected); + + let substs = ty.substs().unwrap_or_else(Substs::empty); + + let field_tys = def.map(|it| self.db.field_types(it)).unwrap_or_default(); + for subpat in subpats { + let matching_field = var_data.as_ref().and_then(|it| it.field(&subpat.name)); + if let Some(local_id) = matching_field { + let field_def = FieldId { parent: def.unwrap(), local_id }; + self.result.record_field_pat_resolutions.insert(subpat.pat, field_def); + } + + let expected_ty = + matching_field.map_or(Ty::Unknown, |field| field_tys[field].clone().subst(&substs)); + let expected_ty = self.normalize_associated_types_in(expected_ty); + self.infer_pat(subpat.pat, &expected_ty, default_bm); + } + + ty + } + + pub(super) fn infer_pat( + &mut self, + pat: PatId, + mut expected: &Ty, + mut default_bm: BindingMode, + ) -> Ty { + let body = Arc::clone(&self.body); // avoid borrow checker problem + + if is_non_ref_pat(&body, pat) { + while let Some((inner, mutability)) = expected.as_reference() { + expected = inner; + default_bm = match default_bm { + BindingMode::Move => BindingMode::Ref(mutability), + BindingMode::Ref(Mutability::Shared) => BindingMode::Ref(Mutability::Shared), + BindingMode::Ref(Mutability::Mut) => BindingMode::Ref(mutability), + } + } + } else if let Pat::Ref { .. } = &body[pat] { + mark::hit!(match_ergonomics_ref); + // When you encounter a `&pat` pattern, reset to Move. + // This is so that `w` is by value: `let (_, &w) = &(1, &2);` + default_bm = BindingMode::Move; + } + + // Lose mutability. + let default_bm = default_bm; + let expected = expected; + + let ty = match &body[pat] { + Pat::Tuple { ref args, .. } => { + let expectations = match expected.as_tuple() { + Some(parameters) => &*parameters.0, + _ => &[], + }; + let expectations_iter = expectations.iter().chain(repeat(&Ty::Unknown)); + + let inner_tys = args + .iter() + .zip(expectations_iter) + .map(|(&pat, ty)| self.infer_pat(pat, ty, default_bm)) + .collect(); + + Ty::apply(TypeCtor::Tuple { cardinality: args.len() as u16 }, Substs(inner_tys)) + } + Pat::Or(ref pats) => { + if let Some((first_pat, rest)) = pats.split_first() { + let ty = self.infer_pat(*first_pat, expected, default_bm); + for pat in rest { + self.infer_pat(*pat, expected, default_bm); + } + ty + } else { + Ty::Unknown + } + } + Pat::Ref { pat, mutability } => { + let expectation = match expected.as_reference() { + Some((inner_ty, exp_mut)) => { + if *mutability != exp_mut { + // FIXME: emit type error? + } + inner_ty + } + _ => &Ty::Unknown, + }; + let subty = self.infer_pat(*pat, expectation, default_bm); + Ty::apply_one(TypeCtor::Ref(*mutability), subty) + } + Pat::TupleStruct { path: p, args: subpats, .. } => { + self.infer_tuple_struct_pat(p.as_ref(), subpats, expected, default_bm, pat) + } + Pat::Record { path: p, args: fields, ellipsis: _ } => { + self.infer_record_pat(p.as_ref(), fields, expected, default_bm, pat) + } + Pat::Path(path) => { + // FIXME use correct resolver for the surrounding expression + let resolver = self.resolver.clone(); + self.infer_path(&resolver, &path, pat.into()).unwrap_or(Ty::Unknown) + } + Pat::Bind { mode, name: _, subpat } => { + let mode = if mode == &BindingAnnotation::Unannotated { + default_bm + } else { + BindingMode::convert(*mode) + }; + let inner_ty = if let Some(subpat) = subpat { + self.infer_pat(*subpat, expected, default_bm) + } else { + expected.clone() + }; + let inner_ty = self.insert_type_vars_shallow(inner_ty); + + let bound_ty = match mode { + BindingMode::Ref(mutability) => { + Ty::apply_one(TypeCtor::Ref(mutability), inner_ty.clone()) + } + BindingMode::Move => inner_ty.clone(), + }; + let bound_ty = self.resolve_ty_as_possible(bound_ty); + self.write_pat_ty(pat, bound_ty); + return inner_ty; + } + Pat::Slice { prefix, slice, suffix } => { + let (container_ty, elem_ty) = match &expected { + ty_app!(TypeCtor::Array, st) => (TypeCtor::Array, st.as_single().clone()), + ty_app!(TypeCtor::Slice, st) => (TypeCtor::Slice, st.as_single().clone()), + _ => (TypeCtor::Slice, Ty::Unknown), + }; + + for pat_id in prefix.iter().chain(suffix) { + self.infer_pat(*pat_id, &elem_ty, default_bm); + } + + let pat_ty = Ty::apply_one(container_ty, elem_ty); + if let Some(slice_pat_id) = slice { + self.infer_pat(*slice_pat_id, &pat_ty, default_bm); + } + + pat_ty + } + Pat::Wild => expected.clone(), + Pat::Range { start, end } => { + let start_ty = self.infer_expr(*start, &Expectation::has_type(expected.clone())); + let end_ty = self.infer_expr(*end, &Expectation::has_type(start_ty)); + end_ty + } + Pat::Lit(expr) => self.infer_expr(*expr, &Expectation::has_type(expected.clone())), + Pat::Missing => Ty::Unknown, + }; + // use a new type variable if we got Ty::Unknown here + let ty = self.insert_type_vars_shallow(ty); + if !self.unify(&ty, expected) { + // FIXME record mismatch, we need to change the type of self.type_mismatches for that + } + let ty = self.resolve_ty_as_possible(ty); + self.write_pat_ty(pat, ty.clone()); + ty + } +} + +fn is_non_ref_pat(body: &hir_def::body::Body, pat: PatId) -> bool { + match &body[pat] { + Pat::Tuple { .. } + | Pat::TupleStruct { .. } + | Pat::Record { .. } + | Pat::Range { .. } + | Pat::Slice { .. } => true, + Pat::Or(pats) => pats.iter().all(|p| is_non_ref_pat(body, *p)), + // FIXME: Path/Lit might actually evaluate to ref, but inference is unimplemented. + Pat::Path(..) => true, + Pat::Lit(expr) => match body[*expr] { + Expr::Literal(Literal::String(..)) => false, + _ => true, + }, + Pat::Wild | Pat::Bind { .. } | Pat::Ref { .. } | Pat::Missing => false, + } +} diff --git a/crates/hir_ty/src/infer/path.rs b/crates/hir_ty/src/infer/path.rs new file mode 100644 index 000000000..80d7ed10e --- /dev/null +++ b/crates/hir_ty/src/infer/path.rs @@ -0,0 +1,287 @@ +//! Path expression resolution. + +use std::iter; + +use hir_def::{ + path::{Path, PathSegment}, + resolver::{ResolveValueResult, Resolver, TypeNs, ValueNs}, + AdtId, AssocContainerId, AssocItemId, EnumVariantId, Lookup, +}; +use hir_expand::name::Name; + +use crate::{method_resolution, Substs, Ty, ValueTyDefId}; + +use super::{ExprOrPatId, InferenceContext, TraitRef}; + +impl<'a> InferenceContext<'a> { + pub(super) fn infer_path( + &mut self, + resolver: &Resolver, + path: &Path, + id: ExprOrPatId, + ) -> Option { + let ty = self.resolve_value_path(resolver, path, id)?; + let ty = self.insert_type_vars(ty); + let ty = self.normalize_associated_types_in(ty); + Some(ty) + } + + fn resolve_value_path( + &mut self, + resolver: &Resolver, + path: &Path, + id: ExprOrPatId, + ) -> Option { + let (value, self_subst) = if let Some(type_ref) = path.type_anchor() { + if path.segments().is_empty() { + // This can't actually happen syntax-wise + return None; + } + let ty = self.make_ty(type_ref); + let remaining_segments_for_ty = path.segments().take(path.segments().len() - 1); + let ctx = crate::lower::TyLoweringContext::new(self.db, &resolver); + let (ty, _) = Ty::from_type_relative_path(&ctx, ty, None, remaining_segments_for_ty); + self.resolve_ty_assoc_item( + ty, + &path.segments().last().expect("path had at least one segment").name, + id, + )? + } else { + let value_or_partial = + resolver.resolve_path_in_value_ns(self.db.upcast(), path.mod_path())?; + + match value_or_partial { + ResolveValueResult::ValueNs(it) => (it, None), + ResolveValueResult::Partial(def, remaining_index) => { + self.resolve_assoc_item(def, path, remaining_index, id)? + } + } + }; + + let typable: ValueTyDefId = match value { + ValueNs::LocalBinding(pat) => { + let ty = self.result.type_of_pat.get(pat)?.clone(); + let ty = self.resolve_ty_as_possible(ty); + return Some(ty); + } + ValueNs::FunctionId(it) => it.into(), + ValueNs::ConstId(it) => it.into(), + ValueNs::StaticId(it) => it.into(), + ValueNs::StructId(it) => { + self.write_variant_resolution(id, it.into()); + + it.into() + } + ValueNs::EnumVariantId(it) => { + self.write_variant_resolution(id, it.into()); + + it.into() + } + ValueNs::ImplSelf(impl_id) => { + let generics = crate::utils::generics(self.db.upcast(), impl_id.into()); + let substs = Substs::type_params_for_generics(&generics); + let ty = self.db.impl_self_ty(impl_id).subst(&substs); + if let Some((AdtId::StructId(struct_id), substs)) = ty.as_adt() { + let ty = self.db.value_ty(struct_id.into()).subst(&substs); + return Some(ty); + } else { + // FIXME: diagnostic, invalid Self reference + return None; + } + } + }; + + let ty = self.db.value_ty(typable); + // self_subst is just for the parent + let parent_substs = self_subst.unwrap_or_else(Substs::empty); + let ctx = crate::lower::TyLoweringContext::new(self.db, &self.resolver); + let substs = Ty::substs_from_path(&ctx, path, typable, true); + let full_substs = Substs::builder(substs.len()) + .use_parent_substs(&parent_substs) + .fill(substs.0[parent_substs.len()..].iter().cloned()) + .build(); + let ty = ty.subst(&full_substs); + Some(ty) + } + + fn resolve_assoc_item( + &mut self, + def: TypeNs, + path: &Path, + remaining_index: usize, + id: ExprOrPatId, + ) -> Option<(ValueNs, Option)> { + assert!(remaining_index < path.segments().len()); + // there may be more intermediate segments between the resolved one and + // the end. Only the last segment needs to be resolved to a value; from + // the segments before that, we need to get either a type or a trait ref. + + let resolved_segment = path.segments().get(remaining_index - 1).unwrap(); + let remaining_segments = path.segments().skip(remaining_index); + let is_before_last = remaining_segments.len() == 1; + + match (def, is_before_last) { + (TypeNs::TraitId(trait_), true) => { + let segment = + remaining_segments.last().expect("there should be at least one segment here"); + let ctx = crate::lower::TyLoweringContext::new(self.db, &self.resolver); + let trait_ref = TraitRef::from_resolved_path(&ctx, trait_, resolved_segment, None); + self.resolve_trait_assoc_item(trait_ref, segment, id) + } + (def, _) => { + // Either we already have a type (e.g. `Vec::new`), or we have a + // trait but it's not the last segment, so the next segment + // should resolve to an associated type of that trait (e.g. `::Item::default`) + let remaining_segments_for_ty = + remaining_segments.take(remaining_segments.len() - 1); + let ctx = crate::lower::TyLoweringContext::new(self.db, &self.resolver); + let (ty, _) = Ty::from_partly_resolved_hir_path( + &ctx, + def, + resolved_segment, + remaining_segments_for_ty, + true, + ); + if let Ty::Unknown = ty { + return None; + } + + let ty = self.insert_type_vars(ty); + let ty = self.normalize_associated_types_in(ty); + + let segment = + remaining_segments.last().expect("there should be at least one segment here"); + + self.resolve_ty_assoc_item(ty, &segment.name, id) + } + } + } + + fn resolve_trait_assoc_item( + &mut self, + trait_ref: TraitRef, + segment: PathSegment<'_>, + id: ExprOrPatId, + ) -> Option<(ValueNs, Option)> { + let trait_ = trait_ref.trait_; + let item = + self.db.trait_data(trait_).items.iter().map(|(_name, id)| (*id)).find_map(|item| { + match item { + AssocItemId::FunctionId(func) => { + if segment.name == &self.db.function_data(func).name { + Some(AssocItemId::FunctionId(func)) + } else { + None + } + } + + AssocItemId::ConstId(konst) => { + if self + .db + .const_data(konst) + .name + .as_ref() + .map_or(false, |n| n == segment.name) + { + Some(AssocItemId::ConstId(konst)) + } else { + None + } + } + AssocItemId::TypeAliasId(_) => None, + } + })?; + let def = match item { + AssocItemId::FunctionId(f) => ValueNs::FunctionId(f), + AssocItemId::ConstId(c) => ValueNs::ConstId(c), + AssocItemId::TypeAliasId(_) => unreachable!(), + }; + + self.write_assoc_resolution(id, item); + Some((def, Some(trait_ref.substs))) + } + + fn resolve_ty_assoc_item( + &mut self, + ty: Ty, + name: &Name, + id: ExprOrPatId, + ) -> Option<(ValueNs, Option)> { + if let Ty::Unknown = ty { + return None; + } + + if let Some(result) = self.resolve_enum_variant_on_ty(&ty, name, id) { + return Some(result); + } + + let canonical_ty = self.canonicalizer().canonicalize_ty(ty.clone()); + let krate = self.resolver.krate()?; + let traits_in_scope = self.resolver.traits_in_scope(self.db.upcast()); + + method_resolution::iterate_method_candidates( + &canonical_ty.value, + self.db, + self.trait_env.clone(), + krate, + &traits_in_scope, + Some(name), + method_resolution::LookupMode::Path, + move |_ty, item| { + let (def, container) = match item { + AssocItemId::FunctionId(f) => { + (ValueNs::FunctionId(f), f.lookup(self.db.upcast()).container) + } + AssocItemId::ConstId(c) => { + (ValueNs::ConstId(c), c.lookup(self.db.upcast()).container) + } + AssocItemId::TypeAliasId(_) => unreachable!(), + }; + let substs = match container { + AssocContainerId::ImplId(impl_id) => { + let impl_substs = Substs::build_for_def(self.db, impl_id) + .fill(iter::repeat_with(|| self.table.new_type_var())) + .build(); + let impl_self_ty = self.db.impl_self_ty(impl_id).subst(&impl_substs); + self.unify(&impl_self_ty, &ty); + Some(impl_substs) + } + AssocContainerId::TraitId(trait_) => { + // we're picking this method + let trait_substs = Substs::build_for_def(self.db, trait_) + .push(ty.clone()) + .fill(std::iter::repeat_with(|| self.table.new_type_var())) + .build(); + self.obligations.push(super::Obligation::Trait(TraitRef { + trait_, + substs: trait_substs.clone(), + })); + Some(trait_substs) + } + AssocContainerId::ContainerId(_) => None, + }; + + self.write_assoc_resolution(id, item); + Some((def, substs)) + }, + ) + } + + fn resolve_enum_variant_on_ty( + &mut self, + ty: &Ty, + name: &Name, + id: ExprOrPatId, + ) -> Option<(ValueNs, Option)> { + let (enum_id, subst) = match ty.as_adt() { + Some((AdtId::EnumId(e), subst)) => (e, subst), + _ => return None, + }; + let enum_data = self.db.enum_data(enum_id); + let local_id = enum_data.variant(name)?; + let variant = EnumVariantId { parent: enum_id, local_id }; + self.write_variant_resolution(id, variant.into()); + Some((ValueNs::EnumVariantId(variant), Some(subst.clone()))) + } +} diff --git a/crates/hir_ty/src/infer/unify.rs b/crates/hir_ty/src/infer/unify.rs new file mode 100644 index 000000000..2e895d911 --- /dev/null +++ b/crates/hir_ty/src/infer/unify.rs @@ -0,0 +1,474 @@ +//! Unification and canonicalization logic. + +use std::borrow::Cow; + +use ena::unify::{InPlaceUnificationTable, NoError, UnifyKey, UnifyValue}; + +use test_utils::mark; + +use super::{InferenceContext, Obligation}; +use crate::{ + BoundVar, Canonical, DebruijnIndex, GenericPredicate, InEnvironment, InferTy, Substs, Ty, + TyKind, TypeCtor, TypeWalk, +}; + +impl<'a> InferenceContext<'a> { + pub(super) fn canonicalizer<'b>(&'b mut self) -> Canonicalizer<'a, 'b> + where + 'a: 'b, + { + Canonicalizer { ctx: self, free_vars: Vec::new(), var_stack: Vec::new() } + } +} + +pub(super) struct Canonicalizer<'a, 'b> +where + 'a: 'b, +{ + ctx: &'b mut InferenceContext<'a>, + free_vars: Vec, + /// A stack of type variables that is used to detect recursive types (which + /// are an error, but we need to protect against them to avoid stack + /// overflows). + var_stack: Vec, +} + +#[derive(Debug)] +pub(super) struct Canonicalized { + pub value: Canonical, + free_vars: Vec, +} + +impl<'a, 'b> Canonicalizer<'a, 'b> +where + 'a: 'b, +{ + fn add(&mut self, free_var: InferTy) -> usize { + self.free_vars.iter().position(|&v| v == free_var).unwrap_or_else(|| { + let next_index = self.free_vars.len(); + self.free_vars.push(free_var); + next_index + }) + } + + fn do_canonicalize(&mut self, t: T, binders: DebruijnIndex) -> T { + t.fold_binders( + &mut |ty, binders| match ty { + Ty::Infer(tv) => { + let inner = tv.to_inner(); + if self.var_stack.contains(&inner) { + // recursive type + return tv.fallback_value(); + } + if let Some(known_ty) = + self.ctx.table.var_unification_table.inlined_probe_value(inner).known() + { + self.var_stack.push(inner); + let result = self.do_canonicalize(known_ty.clone(), binders); + self.var_stack.pop(); + result + } else { + let root = self.ctx.table.var_unification_table.find(inner); + let free_var = match tv { + InferTy::TypeVar(_) => InferTy::TypeVar(root), + InferTy::IntVar(_) => InferTy::IntVar(root), + InferTy::FloatVar(_) => InferTy::FloatVar(root), + InferTy::MaybeNeverTypeVar(_) => InferTy::MaybeNeverTypeVar(root), + }; + let position = self.add(free_var); + Ty::Bound(BoundVar::new(binders, position)) + } + } + _ => ty, + }, + binders, + ) + } + + fn into_canonicalized(self, result: T) -> Canonicalized { + let kinds = self + .free_vars + .iter() + .map(|v| match v { + // mapping MaybeNeverTypeVar to the same kind as general ones + // should be fine, because as opposed to int or float type vars, + // they don't restrict what kind of type can go into them, they + // just affect fallback. + InferTy::TypeVar(_) | InferTy::MaybeNeverTypeVar(_) => TyKind::General, + InferTy::IntVar(_) => TyKind::Integer, + InferTy::FloatVar(_) => TyKind::Float, + }) + .collect(); + Canonicalized { value: Canonical { value: result, kinds }, free_vars: self.free_vars } + } + + pub(crate) fn canonicalize_ty(mut self, ty: Ty) -> Canonicalized { + let result = self.do_canonicalize(ty, DebruijnIndex::INNERMOST); + self.into_canonicalized(result) + } + + pub(crate) fn canonicalize_obligation( + mut self, + obligation: InEnvironment, + ) -> Canonicalized> { + let result = match obligation.value { + Obligation::Trait(tr) => { + Obligation::Trait(self.do_canonicalize(tr, DebruijnIndex::INNERMOST)) + } + Obligation::Projection(pr) => { + Obligation::Projection(self.do_canonicalize(pr, DebruijnIndex::INNERMOST)) + } + }; + self.into_canonicalized(InEnvironment { + value: result, + environment: obligation.environment, + }) + } +} + +impl Canonicalized { + pub fn decanonicalize_ty(&self, mut ty: Ty) -> Ty { + ty.walk_mut_binders( + &mut |ty, binders| { + if let &mut Ty::Bound(bound) = ty { + if bound.debruijn >= binders { + *ty = Ty::Infer(self.free_vars[bound.index]); + } + } + }, + DebruijnIndex::INNERMOST, + ); + ty + } + + pub fn apply_solution(&self, ctx: &mut InferenceContext<'_>, solution: Canonical) { + // the solution may contain new variables, which we need to convert to new inference vars + let new_vars = Substs( + solution + .kinds + .iter() + .map(|k| match k { + TyKind::General => ctx.table.new_type_var(), + TyKind::Integer => ctx.table.new_integer_var(), + TyKind::Float => ctx.table.new_float_var(), + }) + .collect(), + ); + for (i, ty) in solution.value.into_iter().enumerate() { + let var = self.free_vars[i]; + // eagerly replace projections in the type; we may be getting types + // e.g. from where clauses where this hasn't happened yet + let ty = ctx.normalize_associated_types_in(ty.clone().subst_bound_vars(&new_vars)); + ctx.table.unify(&Ty::Infer(var), &ty); + } + } +} + +pub fn unify(tys: &Canonical<(Ty, Ty)>) -> Option { + let mut table = InferenceTable::new(); + let vars = Substs( + tys.kinds + .iter() + // we always use type vars here because we want everything to + // fallback to Unknown in the end (kind of hacky, as below) + .map(|_| table.new_type_var()) + .collect(), + ); + let ty1_with_vars = tys.value.0.clone().subst_bound_vars(&vars); + let ty2_with_vars = tys.value.1.clone().subst_bound_vars(&vars); + if !table.unify(&ty1_with_vars, &ty2_with_vars) { + return None; + } + // default any type vars that weren't unified back to their original bound vars + // (kind of hacky) + for (i, var) in vars.iter().enumerate() { + if &*table.resolve_ty_shallow(var) == var { + table.unify(var, &Ty::Bound(BoundVar::new(DebruijnIndex::INNERMOST, i))); + } + } + Some( + Substs::builder(tys.kinds.len()) + .fill(vars.iter().map(|v| table.resolve_ty_completely(v.clone()))) + .build(), + ) +} + +#[derive(Clone, Debug)] +pub(crate) struct InferenceTable { + pub(super) var_unification_table: InPlaceUnificationTable, +} + +impl InferenceTable { + pub fn new() -> Self { + InferenceTable { var_unification_table: InPlaceUnificationTable::new() } + } + + pub fn new_type_var(&mut self) -> Ty { + Ty::Infer(InferTy::TypeVar(self.var_unification_table.new_key(TypeVarValue::Unknown))) + } + + pub fn new_integer_var(&mut self) -> Ty { + Ty::Infer(InferTy::IntVar(self.var_unification_table.new_key(TypeVarValue::Unknown))) + } + + pub fn new_float_var(&mut self) -> Ty { + Ty::Infer(InferTy::FloatVar(self.var_unification_table.new_key(TypeVarValue::Unknown))) + } + + pub fn new_maybe_never_type_var(&mut self) -> Ty { + Ty::Infer(InferTy::MaybeNeverTypeVar( + self.var_unification_table.new_key(TypeVarValue::Unknown), + )) + } + + pub fn resolve_ty_completely(&mut self, ty: Ty) -> Ty { + self.resolve_ty_completely_inner(&mut Vec::new(), ty) + } + + pub fn resolve_ty_as_possible(&mut self, ty: Ty) -> Ty { + self.resolve_ty_as_possible_inner(&mut Vec::new(), ty) + } + + pub fn unify(&mut self, ty1: &Ty, ty2: &Ty) -> bool { + self.unify_inner(ty1, ty2, 0) + } + + pub fn unify_substs(&mut self, substs1: &Substs, substs2: &Substs, depth: usize) -> bool { + substs1.0.iter().zip(substs2.0.iter()).all(|(t1, t2)| self.unify_inner(t1, t2, depth)) + } + + fn unify_inner(&mut self, ty1: &Ty, ty2: &Ty, depth: usize) -> bool { + if depth > 1000 { + // prevent stackoverflows + panic!("infinite recursion in unification"); + } + if ty1 == ty2 { + return true; + } + // try to resolve type vars first + let ty1 = self.resolve_ty_shallow(ty1); + let ty2 = self.resolve_ty_shallow(ty2); + match (&*ty1, &*ty2) { + (Ty::Apply(a_ty1), Ty::Apply(a_ty2)) if a_ty1.ctor == a_ty2.ctor => { + self.unify_substs(&a_ty1.parameters, &a_ty2.parameters, depth + 1) + } + + _ => self.unify_inner_trivial(&ty1, &ty2, depth), + } + } + + pub(super) fn unify_inner_trivial(&mut self, ty1: &Ty, ty2: &Ty, depth: usize) -> bool { + match (ty1, ty2) { + (Ty::Unknown, _) | (_, Ty::Unknown) => true, + + (Ty::Placeholder(p1), Ty::Placeholder(p2)) if *p1 == *p2 => true, + + (Ty::Dyn(dyn1), Ty::Dyn(dyn2)) if dyn1.len() == dyn2.len() => { + for (pred1, pred2) in dyn1.iter().zip(dyn2.iter()) { + if !self.unify_preds(pred1, pred2, depth + 1) { + return false; + } + } + true + } + + (Ty::Infer(InferTy::TypeVar(tv1)), Ty::Infer(InferTy::TypeVar(tv2))) + | (Ty::Infer(InferTy::IntVar(tv1)), Ty::Infer(InferTy::IntVar(tv2))) + | (Ty::Infer(InferTy::FloatVar(tv1)), Ty::Infer(InferTy::FloatVar(tv2))) + | ( + Ty::Infer(InferTy::MaybeNeverTypeVar(tv1)), + Ty::Infer(InferTy::MaybeNeverTypeVar(tv2)), + ) => { + // both type vars are unknown since we tried to resolve them + self.var_unification_table.union(*tv1, *tv2); + true + } + + // The order of MaybeNeverTypeVar matters here. + // Unifying MaybeNeverTypeVar and TypeVar will let the latter become MaybeNeverTypeVar. + // Unifying MaybeNeverTypeVar and other concrete type will let the former become it. + (Ty::Infer(InferTy::TypeVar(tv)), other) + | (other, Ty::Infer(InferTy::TypeVar(tv))) + | (Ty::Infer(InferTy::MaybeNeverTypeVar(tv)), other) + | (other, Ty::Infer(InferTy::MaybeNeverTypeVar(tv))) + | (Ty::Infer(InferTy::IntVar(tv)), other @ ty_app!(TypeCtor::Int(_))) + | (other @ ty_app!(TypeCtor::Int(_)), Ty::Infer(InferTy::IntVar(tv))) + | (Ty::Infer(InferTy::FloatVar(tv)), other @ ty_app!(TypeCtor::Float(_))) + | (other @ ty_app!(TypeCtor::Float(_)), Ty::Infer(InferTy::FloatVar(tv))) => { + // the type var is unknown since we tried to resolve it + self.var_unification_table.union_value(*tv, TypeVarValue::Known(other.clone())); + true + } + + _ => false, + } + } + + fn unify_preds( + &mut self, + pred1: &GenericPredicate, + pred2: &GenericPredicate, + depth: usize, + ) -> bool { + match (pred1, pred2) { + (GenericPredicate::Implemented(tr1), GenericPredicate::Implemented(tr2)) + if tr1.trait_ == tr2.trait_ => + { + self.unify_substs(&tr1.substs, &tr2.substs, depth + 1) + } + (GenericPredicate::Projection(proj1), GenericPredicate::Projection(proj2)) + if proj1.projection_ty.associated_ty == proj2.projection_ty.associated_ty => + { + self.unify_substs( + &proj1.projection_ty.parameters, + &proj2.projection_ty.parameters, + depth + 1, + ) && self.unify_inner(&proj1.ty, &proj2.ty, depth + 1) + } + _ => false, + } + } + + /// If `ty` is a type variable with known type, returns that type; + /// otherwise, return ty. + pub fn resolve_ty_shallow<'b>(&mut self, ty: &'b Ty) -> Cow<'b, Ty> { + let mut ty = Cow::Borrowed(ty); + // The type variable could resolve to a int/float variable. Hence try + // resolving up to three times; each type of variable shouldn't occur + // more than once + for i in 0..3 { + if i > 0 { + mark::hit!(type_var_resolves_to_int_var); + } + match &*ty { + Ty::Infer(tv) => { + let inner = tv.to_inner(); + match self.var_unification_table.inlined_probe_value(inner).known() { + Some(known_ty) => { + // The known_ty can't be a type var itself + ty = Cow::Owned(known_ty.clone()); + } + _ => return ty, + } + } + _ => return ty, + } + } + log::error!("Inference variable still not resolved: {:?}", ty); + ty + } + + /// Resolves the type as far as currently possible, replacing type variables + /// by their known types. All types returned by the infer_* functions should + /// be resolved as far as possible, i.e. contain no type variables with + /// known type. + fn resolve_ty_as_possible_inner(&mut self, tv_stack: &mut Vec, ty: Ty) -> Ty { + ty.fold(&mut |ty| match ty { + Ty::Infer(tv) => { + let inner = tv.to_inner(); + if tv_stack.contains(&inner) { + mark::hit!(type_var_cycles_resolve_as_possible); + // recursive type + return tv.fallback_value(); + } + if let Some(known_ty) = + self.var_unification_table.inlined_probe_value(inner).known() + { + // known_ty may contain other variables that are known by now + tv_stack.push(inner); + let result = self.resolve_ty_as_possible_inner(tv_stack, known_ty.clone()); + tv_stack.pop(); + result + } else { + ty + } + } + _ => ty, + }) + } + + /// Resolves the type completely; type variables without known type are + /// replaced by Ty::Unknown. + fn resolve_ty_completely_inner(&mut self, tv_stack: &mut Vec, ty: Ty) -> Ty { + ty.fold(&mut |ty| match ty { + Ty::Infer(tv) => { + let inner = tv.to_inner(); + if tv_stack.contains(&inner) { + mark::hit!(type_var_cycles_resolve_completely); + // recursive type + return tv.fallback_value(); + } + if let Some(known_ty) = + self.var_unification_table.inlined_probe_value(inner).known() + { + // known_ty may contain other variables that are known by now + tv_stack.push(inner); + let result = self.resolve_ty_completely_inner(tv_stack, known_ty.clone()); + tv_stack.pop(); + result + } else { + tv.fallback_value() + } + } + _ => ty, + }) + } +} + +/// The ID of a type variable. +#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)] +pub struct TypeVarId(pub(super) u32); + +impl UnifyKey for TypeVarId { + type Value = TypeVarValue; + + fn index(&self) -> u32 { + self.0 + } + + fn from_index(i: u32) -> Self { + TypeVarId(i) + } + + fn tag() -> &'static str { + "TypeVarId" + } +} + +/// The value of a type variable: either we already know the type, or we don't +/// know it yet. +#[derive(Clone, PartialEq, Eq, Debug)] +pub enum TypeVarValue { + Known(Ty), + Unknown, +} + +impl TypeVarValue { + fn known(&self) -> Option<&Ty> { + match self { + TypeVarValue::Known(ty) => Some(ty), + TypeVarValue::Unknown => None, + } + } +} + +impl UnifyValue for TypeVarValue { + type Error = NoError; + + fn unify_values(value1: &Self, value2: &Self) -> Result { + match (value1, value2) { + // We should never equate two type variables, both of which have + // known types. Instead, we recursively equate those types. + (TypeVarValue::Known(t1), TypeVarValue::Known(t2)) => panic!( + "equating two type variables, both of which have known types: {:?} and {:?}", + t1, t2 + ), + + // If one side is known, prefer that one. + (TypeVarValue::Known(..), TypeVarValue::Unknown) => Ok(value1.clone()), + (TypeVarValue::Unknown, TypeVarValue::Known(..)) => Ok(value2.clone()), + + (TypeVarValue::Unknown, TypeVarValue::Unknown) => Ok(TypeVarValue::Unknown), + } + } +} -- cgit v1.2.3