From 4e49b2f73144460cde5ada8140964d96166f41fd Mon Sep 17 00:00:00 2001 From: Roland Ruckerbauer Date: Tue, 13 Oct 2020 20:48:08 +0200 Subject: Implement binary operator overloading type inference --- crates/hir_ty/src/infer.rs | 24 ++++++++++- crates/hir_ty/src/infer/expr.rs | 15 +++++-- crates/hir_ty/src/tests/simple.rs | 86 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 120 insertions(+), 5 deletions(-) diff --git a/crates/hir_ty/src/infer.rs b/crates/hir_ty/src/infer.rs index 9a7785c76..644ebd42d 100644 --- a/crates/hir_ty/src/infer.rs +++ b/crates/hir_ty/src/infer.rs @@ -22,7 +22,7 @@ use arena::map::ArenaMap; use hir_def::{ body::Body, data::{ConstData, FunctionData, StaticData}, - expr::{BindingAnnotation, ExprId, PatId}, + expr::{ArithOp, BinaryOp, BindingAnnotation, ExprId, PatId}, lang_item::LangItemTarget, path::{path, Path}, resolver::{HasResolver, Resolver, TypeNs}, @@ -586,6 +586,28 @@ impl<'a> InferenceContext<'a> { self.db.trait_data(trait_).associated_type_by_name(&name![Output]) } + fn resolve_binary_op_output(&self, bop: &BinaryOp) -> Option { + let lang_item = match bop { + BinaryOp::ArithOp(aop) => match aop { + ArithOp::Add => "add", + ArithOp::Sub => "sub", + ArithOp::Mul => "mul", + ArithOp::Div => "div", + ArithOp::Shl => "shl", + ArithOp::Shr => "shr", + ArithOp::Rem => "rem", + ArithOp::BitXor => "bitxor", + ArithOp::BitOr => "bitor", + ArithOp::BitAnd => "bitand", + }, + _ => return None, + }; + + let trait_ = self.resolve_lang_item(lang_item)?.as_trait(); + + self.db.trait_data(trait_?).associated_type_by_name(&name![Output]) + } + fn resolve_boxed_box(&self) -> Option { let struct_ = self.resolve_lang_item("owned_box")?.as_struct()?; Some(struct_.into()) diff --git a/crates/hir_ty/src/infer/expr.rs b/crates/hir_ty/src/infer/expr.rs index 0a141b9cb..8cc0d56d3 100644 --- a/crates/hir_ty/src/infer/expr.rs +++ b/crates/hir_ty/src/infer/expr.rs @@ -531,13 +531,20 @@ impl<'a> InferenceContext<'a> { _ => Expectation::none(), }; let lhs_ty = self.infer_expr(*lhs, &lhs_expectation); - // FIXME: find implementation of trait corresponding to operation - // symbol and resolve associated `Output` type let rhs_expectation = op::binary_op_rhs_expectation(*op, lhs_ty.clone()); let rhs_ty = self.infer_expr(*rhs, &Expectation::has_type(rhs_expectation)); - // FIXME: similar as above, return ty is often associated trait type - op::binary_op_return_ty(*op, lhs_ty, rhs_ty) + let ret = op::binary_op_return_ty(*op, lhs_ty.clone(), rhs_ty.clone()); + + if ret == Ty::Unknown { + self.resolve_associated_type_with_params( + lhs_ty, + self.resolve_binary_op_output(op), + &[rhs_ty], + ) + } else { + ret + } } _ => Ty::Unknown, }, diff --git a/crates/hir_ty/src/tests/simple.rs b/crates/hir_ty/src/tests/simple.rs index 5b07948f3..a3ae304a1 100644 --- a/crates/hir_ty/src/tests/simple.rs +++ b/crates/hir_ty/src/tests/simple.rs @@ -2225,3 +2225,89 @@ fn generic_default_depending_on_other_type_arg_forward() { "#]], ); } + +#[test] +fn infer_operator_overload() { + check_infer( + r#" + struct V2([f32; 2]); + + #[lang = "add"] + pub trait Add { + /// The resulting type after applying the `+` operator. + type Output; + + /// Performs the `+` operation. + #[must_use] + fn add(self, rhs: Rhs) -> Self::Output; + } + + impl Add for V2 { + type Output = V2; + + fn add(self, rhs: V2) -> V2 { + let x = self.0[0] + rhs.0[0]; + let y = self.0[1] + rhs.0[1]; + V2([x, y]) + } + } + + fn test() { + let va = V2([0.0, 1.0]); + let vb = V2([0.0, 1.0]); + + let r = va + vb; + } + + "#, + expect![[r#" + 207..211 'self': Self + 213..216 'rhs': Rhs + 299..303 'self': V2 + 305..308 'rhs': V2 + 320..422 '{ ... }': V2 + 334..335 'x': f32 + 338..342 'self': V2 + 338..344 'self.0': [f32; _] + 338..347 'self.0[0]': {unknown} + 338..358 'self.0...s.0[0]': f32 + 345..346 '0': i32 + 350..353 'rhs': V2 + 350..355 'rhs.0': [f32; _] + 350..358 'rhs.0[0]': {unknown} + 356..357 '0': i32 + 372..373 'y': f32 + 376..380 'self': V2 + 376..382 'self.0': [f32; _] + 376..385 'self.0[1]': {unknown} + 376..396 'self.0...s.0[1]': f32 + 383..384 '1': i32 + 388..391 'rhs': V2 + 388..393 'rhs.0': [f32; _] + 388..396 'rhs.0[1]': {unknown} + 394..395 '1': i32 + 406..408 'V2': V2([f32; _]) -> V2 + 406..416 'V2([x, y])': V2 + 409..415 '[x, y]': [f32; _] + 410..411 'x': f32 + 413..414 'y': f32 + 436..519 '{ ... vb; }': () + 446..448 'va': V2 + 451..453 'V2': V2([f32; _]) -> V2 + 451..465 'V2([0.0, 1.0])': V2 + 454..464 '[0.0, 1.0]': [f32; _] + 455..458 '0.0': f32 + 460..463 '1.0': f32 + 475..477 'vb': V2 + 480..482 'V2': V2([f32; _]) -> V2 + 480..494 'V2([0.0, 1.0])': V2 + 483..493 '[0.0, 1.0]': [f32; _] + 484..487 '0.0': f32 + 489..492 '1.0': f32 + 505..506 'r': V2 + 509..511 'va': V2 + 509..516 'va + vb': V2 + 514..516 'vb': V2 + "#]], + ); +} -- cgit v1.2.3