From 3ac605e6876056fa56098231cc2f96553faab8f0 Mon Sep 17 00:00:00 2001 From: Florian Diebold Date: Thu, 20 Dec 2018 21:56:28 +0100 Subject: Add beginnings of type infrastructure --- crates/ra_hir/src/db.rs | 6 + crates/ra_hir/src/function.rs | 6 +- crates/ra_hir/src/lib.rs | 1 + crates/ra_hir/src/mock.rs | 3 +- crates/ra_hir/src/query_definitions.rs | 8 + crates/ra_hir/src/ty.rs | 478 +++++++++++++++++++++++++++++++++ crates/ra_hir/src/ty/primitive.rs | 98 +++++++ crates/ra_hir/src/ty/tests.rs | 45 ++++ 8 files changed, 643 insertions(+), 2 deletions(-) create mode 100644 crates/ra_hir/src/ty.rs create mode 100644 crates/ra_hir/src/ty/primitive.rs create mode 100644 crates/ra_hir/src/ty/tests.rs (limited to 'crates/ra_hir') diff --git a/crates/ra_hir/src/db.rs b/crates/ra_hir/src/db.rs index 62cf9ab17..f0bff3c02 100644 --- a/crates/ra_hir/src/db.rs +++ b/crates/ra_hir/src/db.rs @@ -14,6 +14,7 @@ use crate::{ function::FnId, module::{ModuleId, ModuleTree, ModuleSource, nameres::{ItemMap, InputModuleItems}}, + ty::InferenceResult, }; salsa::query_group! 
{ @@ -30,6 +31,11 @@ pub trait HirDatabase: SyntaxDatabase use fn query_definitions::fn_syntax; } + fn infer(fn_id: FnId) -> Arc { + type InferQuery; + use fn query_definitions::infer; + } + fn file_items(file_id: FileId) -> Arc { type SourceFileItemsQuery; use fn query_definitions::file_items; diff --git a/crates/ra_hir/src/function.rs b/crates/ra_hir/src/function.rs index 2925beb16..360e9e9a0 100644 --- a/crates/ra_hir/src/function.rs +++ b/crates/ra_hir/src/function.rs @@ -10,7 +10,7 @@ use ra_syntax::{ ast::{self, AstNode, DocCommentsOwner, NameOwner}, }; -use crate::{ DefId, HirDatabase }; +use crate::{ DefId, HirDatabase, ty::InferenceResult }; pub use self::scope::FnScopes; @@ -35,6 +35,10 @@ impl Function { let syntax = db.fn_syntax(self.fn_id); FnSignatureInfo::new(syntax.borrowed()) } + + pub fn infer(&self, db: &impl HirDatabase) -> Arc { + db.infer(self.fn_id) + } } #[derive(Debug, Clone)] diff --git a/crates/ra_hir/src/lib.rs b/crates/ra_hir/src/lib.rs index f56214b47..e84f44675 100644 --- a/crates/ra_hir/src/lib.rs +++ b/crates/ra_hir/src/lib.rs @@ -25,6 +25,7 @@ pub mod source_binder; mod krate; mod module; mod function; +mod ty; use std::ops::Index; diff --git a/crates/ra_hir/src/mock.rs b/crates/ra_hir/src/mock.rs index 9423e6571..a9fa540d5 100644 --- a/crates/ra_hir/src/mock.rs +++ b/crates/ra_hir/src/mock.rs @@ -8,7 +8,7 @@ use test_utils::{parse_fixture, CURSOR_MARKER, extract_offset}; use crate::{db, DefId, DefLoc}; -const WORKSPACE: SourceRootId = SourceRootId(0); +pub const WORKSPACE: SourceRootId = SourceRootId(0); #[derive(Debug)] pub(crate) struct MockDatabase { @@ -182,6 +182,7 @@ salsa::database_storage! 
{ fn item_map() for db::ItemMapQuery; fn fn_syntax() for db::FnSyntaxQuery; fn submodules() for db::SubmodulesQuery; + fn infer() for db::InferQuery; } } } diff --git a/crates/ra_hir/src/query_definitions.rs b/crates/ra_hir/src/query_definitions.rs index efaeb1525..ccbfdf028 100644 --- a/crates/ra_hir/src/query_definitions.rs +++ b/crates/ra_hir/src/query_definitions.rs @@ -19,6 +19,7 @@ use crate::{ imp::Submodule, nameres::{InputModuleItems, ItemMap, Resolver}, }, + ty::{self, InferenceResult} }; /// Resolve `FnId` to the corresponding `SyntaxNode` @@ -35,6 +36,13 @@ pub(super) fn fn_scopes(db: &impl HirDatabase, fn_id: FnId) -> Arc { Arc::new(res) } +pub(super) fn infer(db: &impl HirDatabase, fn_id: FnId) -> Arc { + let syntax = db.fn_syntax(fn_id); + let scopes = db.fn_scopes(fn_id); + let res = ty::infer(db, syntax.borrowed(), scopes); + Arc::new(res) +} + pub(super) fn file_items(db: &impl HirDatabase, file_id: FileId) -> Arc { let mut res = SourceFileItems::new(file_id); let source_file = db.source_file(file_id); diff --git a/crates/ra_hir/src/ty.rs b/crates/ra_hir/src/ty.rs new file mode 100644 index 000000000..36dc5d137 --- /dev/null +++ b/crates/ra_hir/src/ty.rs @@ -0,0 +1,478 @@ +mod primitive; +#[cfg(test)] +mod tests; + +use rustc_hash::{FxHashMap, FxHashSet}; + +use std::sync::Arc; +use std::collections::HashMap; + +use ra_db::LocalSyntaxPtr; +use ra_syntax::{ + TextRange, TextUnit, + algo::visit::{visitor, Visitor}, + ast::{self, AstNode, DocCommentsOwner, NameOwner, LoopBodyOwner, ArgListOwner}, + SyntaxNodeRef +}; + +use crate::{ + FnScopes, + db::HirDatabase, + arena::{Arena, Id}, +}; + +// pub(crate) type TypeId = Id; + +#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] +pub enum Ty { + /// The primitive boolean type. Written as `bool`. + Bool, + + /// The primitive character type; holds a Unicode scalar value + /// (a non-surrogate code point). Written as `char`. + Char, + + /// A primitive signed integer type. For example, `i32`. 
+ Int(primitive::IntTy), + + /// A primitive unsigned integer type. For example, `u32`. + Uint(primitive::UintTy), + + /// A primitive floating-point type. For example, `f64`. + Float(primitive::FloatTy), + + /// Structures, enumerations and unions. + /// + /// Substs here, possibly against intuition, *may* contain `Param`s. + /// That is, even after substitution it is possible that there are type + /// variables. This happens when the `Adt` corresponds to an ADT + /// definition and not a concrete use of it. + // Adt(&'tcx AdtDef, &'tcx Substs<'tcx>), + + // Foreign(DefId), + + /// The pointee of a string slice. Written as `str`. + Str, + + /// An array with the given length. Written as `[T; n]`. + // Array(Ty<'tcx>, &'tcx ty::Const<'tcx>), + + /// The pointee of an array slice. Written as `[T]`. + Slice(TyRef), + + /// A raw pointer. Written as `*mut T` or `*const T`. + // RawPtr(TypeAndMut<'tcx>), + + /// A reference; a pointer with an associated lifetime. Written as + /// `&'a mut T` or `&'a T`. + // Ref(Region<'tcx>, Ty<'tcx>, hir::Mutability), + + /// The anonymous type of a function declaration/definition. Each + /// function has a unique type, which is output (for a function + /// named `foo` returning an `i32`) as `fn() -> i32 {foo}`. + /// + /// For example the type of `bar` here: + /// + /// ```rust + /// fn foo() -> i32 { 1 } + /// let bar = foo; // bar: fn() -> i32 {foo} + /// ``` + // FnDef(DefId, &'tcx Substs<'tcx>), + + /// A pointer to a function. Written as `fn() -> i32`. + /// + /// For example the type of `bar` here: + /// + /// ```rust + /// fn foo() -> i32 { 1 } + /// let bar: fn() -> i32 = foo; + /// ``` + // FnPtr(PolyFnSig<'tcx>), + + /// A trait, defined with `trait`. + // Dynamic(Binder<&'tcx List<ExistentialPredicate<'tcx>>>, ty::Region<'tcx>), + + /// The anonymous type of a closure. Used to represent the type of + /// `|a| a`. + // Closure(DefId, ClosureSubsts<'tcx>), + + /// The anonymous type of a generator. Used to represent the type of + /// `|a| yield a`. 
+ // Generator(DefId, GeneratorSubsts<'tcx>, hir::GeneratorMovability), + + /// A type representing the types stored inside a generator. + /// This should only appear in GeneratorInteriors. + // GeneratorWitness(Binder<&'tcx List<Ty<'tcx>>>), + + /// The never type `!` + Never, + + /// A tuple type. For example, `(i32, bool)`. + Tuple(Vec), + + /// The projection of an associated type. For example, + /// `<T as Trait<..>>::N`. + // Projection(ProjectionTy<'tcx>), + + /// Opaque (`impl Trait`) type found in a return type. + /// The `DefId` comes either from + /// * the `impl Trait` ast::Ty node, + /// * or the `existential type` declaration + /// The substitutions are for the generics of the function in question. + /// After typeck, the concrete type can be found in the `types` map. + // Opaque(DefId, &'tcx Substs<'tcx>), + + /// A type parameter; for example, `T` in `fn f<T>(x: T) {}`. + // Param(ParamTy), + + /// Bound type variable, used only when preparing a trait query. + // Bound(ty::DebruijnIndex, BoundTy), + + /// A placeholder type - universally quantified higher-ranked type. + // Placeholder(ty::PlaceholderType), + + /// A type variable used during type checking. + // Infer(InferTy), + + /// A placeholder for a type which could not be computed; this is + /// propagated to avoid useless error messages. + Unknown, +} + +type TyRef = Arc; + +impl Ty { + pub fn new(node: ast::TypeRef) -> Self { + use ra_syntax::ast::TypeRef::*; + match node { + ParenType(_inner) => Ty::Unknown, // TODO + TupleType(_inner) => Ty::Unknown, // TODO + NeverType(..) 
=> Ty::Never, + PathType(_inner) => Ty::Unknown, // TODO + PointerType(_inner) => Ty::Unknown, // TODO + ArrayType(_inner) => Ty::Unknown, // TODO + SliceType(_inner) => Ty::Unknown, // TODO + ReferenceType(_inner) => Ty::Unknown, // TODO + PlaceholderType(_inner) => Ty::Unknown, // TODO + FnPointerType(_inner) => Ty::Unknown, // TODO + ForType(_inner) => Ty::Unknown, // TODO + ImplTraitType(_inner) => Ty::Unknown, // TODO + DynTraitType(_inner) => Ty::Unknown, // TODO + } + } + + pub fn unit() -> Self { + Ty::Tuple(Vec::new()) + } +} + +#[derive(Clone, PartialEq, Eq, Debug)] +pub struct InferenceResult { + type_for: FxHashMap, +} + +#[derive(Clone, PartialEq, Eq, Debug)] +pub struct InferenceContext { + scopes: Arc, + // TODO unification tables... + type_for: FxHashMap, +} + +impl InferenceContext { + fn new(scopes: Arc) -> Self { + InferenceContext { + type_for: FxHashMap::default(), + scopes + } + } + + fn write_ty(&mut self, node: SyntaxNodeRef, ty: Ty) { + self.type_for.insert(LocalSyntaxPtr::new(node), ty); + } + + fn unify(&mut self, ty1: &Ty, ty2: &Ty) -> bool { + unimplemented!() + } + + fn infer_expr(&mut self, expr: ast::Expr) -> Ty { + let ty = match expr { + ast::Expr::IfExpr(e) => { + if let Some(condition) = e.condition() { + if let Some(e) = condition.expr() { + // TODO if no pat, this should be bool + self.infer_expr(e); + } + // TODO write type for pat + }; + let if_ty = if let Some(block) = e.then_branch() { + self.infer_block(block) + } else { + Ty::Unknown + }; + let else_ty = if let Some(block) = e.else_branch() { + self.infer_block(block) + } else { + Ty::Unknown + }; + if self.unify(&if_ty, &else_ty) { + // TODO actually, need to take the 'more specific' type (not unknown, never, ...) 
+ if_ty + } else { + // TODO report diagnostic + Ty::Unknown + } + } + ast::Expr::BlockExpr(e) => { + if let Some(block) = e.block() { + self.infer_block(block) + } else { + Ty::Unknown + } + } + ast::Expr::LoopExpr(e) => { + if let Some(block) = e.loop_body() { + self.infer_block(block); + }; + // TODO never, or the type of the break param + Ty::Unknown + } + ast::Expr::WhileExpr(e) => { + if let Some(condition) = e.condition() { + if let Some(e) = condition.expr() { + // TODO if no pat, this should be bool + self.infer_expr(e); + } + // TODO write type for pat + }; + if let Some(block) = e.loop_body() { + // TODO + self.infer_block(block); + }; + // TODO always unit? + Ty::Unknown + } + ast::Expr::ForExpr(e) => { + if let Some(expr) = e.iterable() { + self.infer_expr(expr); + } + if let Some(pat) = e.pat() { + // TODO write type for pat + } + if let Some(block) = e.loop_body() { + self.infer_block(block); + } + // TODO always unit? + Ty::Unknown + } + ast::Expr::LambdaExpr(e) => { + let body_ty = if let Some(body) = e.body() { + self.infer_expr(body) + } else { + Ty::Unknown + }; + Ty::Unknown + } + ast::Expr::CallExpr(e) => { + if let Some(arg_list) = e.arg_list() { + for arg in arg_list.args() { + // TODO unify / expect argument type + self.infer_expr(arg); + } + } + Ty::Unknown + } + ast::Expr::MethodCallExpr(e) => { + if let Some(arg_list) = e.arg_list() { + for arg in arg_list.args() { + // TODO unify / expect argument type + self.infer_expr(arg); + } + } + Ty::Unknown + } + ast::Expr::MatchExpr(e) => { + let ty = if let Some(match_expr) = e.expr() { + self.infer_expr(match_expr) + } else { + Ty::Unknown + }; + if let Some(match_arm_list) = e.match_arm_list() { + for arm in match_arm_list.arms() { + // TODO type the bindings in pat + // TODO type the guard + let ty = if let Some(e) = arm.expr() { + self.infer_expr(e) + } else { + Ty::Unknown + }; + } + // TODO unify all the match arm types + Ty::Unknown + } else { + Ty::Unknown + } + } + 
ast::Expr::TupleExpr(e) => { + Ty::Unknown + } + ast::Expr::ArrayExpr(e) => { + Ty::Unknown + } + ast::Expr::PathExpr(e) => { + if let Some(p) = e.path() { + if p.qualifier().is_none() { + if let Some(name) = p.segment().and_then(|s| s.name_ref()) { + let s = self.scopes.resolve_local_name(name); + if let Some(scope_entry) = s { + if let Some(ty) = self.type_for.get(&scope_entry.ptr()) { + ty.clone() + } else { + // TODO introduce type variable? + Ty::Unknown + } + } else { + Ty::Unknown + } + } else { + Ty::Unknown + } + } else { + // TODO resolve path + Ty::Unknown + } + } else { + Ty::Unknown + } + } + ast::Expr::ContinueExpr(e) => { + Ty::Never + } + ast::Expr::BreakExpr(e) => { + Ty::Never + } + ast::Expr::ParenExpr(e) => { + if let Some(e) = e.expr() { + self.infer_expr(e) + } else { + Ty::Unknown + } + } + ast::Expr::Label(e) => { + Ty::Unknown + } + ast::Expr::ReturnExpr(e) => { + if let Some(e) = e.expr() { + // TODO unify with return type + self.infer_expr(e); + }; + Ty::Never + } + ast::Expr::MatchArmList(_) | ast::Expr::MatchArm(_) | ast::Expr::MatchGuard(_) => { + // Can this even occur outside of a match expression? + Ty::Unknown + } + ast::Expr::StructLit(e) => { + Ty::Unknown + } + ast::Expr::NamedFieldList(_) | ast::Expr::NamedField(_) => { + // Can this even occur outside of a struct literal? + Ty::Unknown + } + ast::Expr::IndexExpr(e) => { + Ty::Unknown + } + ast::Expr::FieldExpr(e) => { + Ty::Unknown + } + ast::Expr::TryExpr(e) => { + let inner_ty = if let Some(e) = e.expr() { + self.infer_expr(e) + } else { + Ty::Unknown + }; + Ty::Unknown + } + ast::Expr::CastExpr(e) => { + let inner_ty = if let Some(e) = e.expr() { + self.infer_expr(e) + } else { + Ty::Unknown + }; + let cast_ty = e.type_ref().map(Ty::new).unwrap_or(Ty::Unknown); + // TODO do the coercion... 
+ cast_ty + } + ast::Expr::RefExpr(e) => { + let inner_ty = if let Some(e) = e.expr() { + self.infer_expr(e) + } else { + Ty::Unknown + }; + Ty::Unknown + } + ast::Expr::PrefixExpr(e) => { + let inner_ty = if let Some(e) = e.expr() { + self.infer_expr(e) + } else { + Ty::Unknown + }; + Ty::Unknown + } + ast::Expr::RangeExpr(e) => { + Ty::Unknown + } + ast::Expr::BinExpr(e) => { + Ty::Unknown + } + ast::Expr::Literal(e) => { + Ty::Unknown + } + }; + self.write_ty(expr.syntax(), ty.clone()); + ty + } + + fn infer_block(&mut self, node: ast::Block) -> Ty { + for stmt in node.statements() { + match stmt { + ast::Stmt::LetStmt(stmt) => { + if let Some(expr) = stmt.initializer() { + self.infer_expr(expr); + } + } + ast::Stmt::ExprStmt(expr_stmt) => { + if let Some(expr) = expr_stmt.expr() { + self.infer_expr(expr); + } + } + } + } + let ty = if let Some(expr) = node.expr() { + self.infer_expr(expr) + } else { + Ty::unit() + }; + self.write_ty(node.syntax(), ty.clone()); + ty + } +} + +pub fn infer(db: &impl HirDatabase, node: ast::FnDef, scopes: Arc) -> InferenceResult { + let mut ctx = InferenceContext::new(scopes); + + for param in node.param_list().unwrap().params() { + let pat = param.pat().unwrap(); + let type_ref = param.type_ref().unwrap(); + let ty = Ty::new(type_ref); + ctx.type_for.insert(LocalSyntaxPtr::new(pat.syntax()), ty); + } + + // TODO get Ty for node.ret_type() and pass that to infer_block as expectation + // (see Expectation in rustc_typeck) + + ctx.infer_block(node.body().unwrap()); + + // TODO 'resolve' the types: replace inference variables by their inferred results + + InferenceResult { type_for: ctx.type_for } +} diff --git a/crates/ra_hir/src/ty/primitive.rs b/crates/ra_hir/src/ty/primitive.rs new file mode 100644 index 000000000..4a5ce5a97 --- /dev/null +++ b/crates/ra_hir/src/ty/primitive.rs @@ -0,0 +1,98 @@ +use std::fmt; + +#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Copy)] +pub enum IntTy { + Isize, + I8, + I16, + I32, + I64, + 
I128, +} + +impl fmt::Debug for IntTy { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::Display::fmt(self, f) + } +} + +impl fmt::Display for IntTy { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.ty_to_string()) + } +} + +impl IntTy { + pub fn ty_to_string(&self) -> &'static str { + match *self { + IntTy::Isize => "isize", + IntTy::I8 => "i8", + IntTy::I16 => "i16", + IntTy::I32 => "i32", + IntTy::I64 => "i64", + IntTy::I128 => "i128", + } + } +} + +#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Copy)] +pub enum UintTy { + Usize, + U8, + U16, + U32, + U64, + U128, +} + +impl UintTy { + pub fn ty_to_string(&self) -> &'static str { + match *self { + UintTy::Usize => "usize", + UintTy::U8 => "u8", + UintTy::U16 => "u16", + UintTy::U32 => "u32", + UintTy::U64 => "u64", + UintTy::U128 => "u128", + } + } +} + +impl fmt::Debug for UintTy { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::Display::fmt(self, f) + } +} + +impl fmt::Display for UintTy { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.ty_to_string()) + } +} + +#[derive(Clone, PartialEq, Eq, Hash, Copy, PartialOrd, Ord)] +pub enum FloatTy { + F32, + F64, +} + +impl fmt::Debug for FloatTy { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::Display::fmt(self, f) + } +} + +impl fmt::Display for FloatTy { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.ty_to_string()) + } +} + +impl FloatTy { + pub fn ty_to_string(self) -> &'static str { + match self { + FloatTy::F32 => "f32", + FloatTy::F64 => "f64", + } + } +} diff --git a/crates/ra_hir/src/ty/tests.rs b/crates/ra_hir/src/ty/tests.rs new file mode 100644 index 000000000..f2466dd51 --- /dev/null +++ b/crates/ra_hir/src/ty/tests.rs @@ -0,0 +1,45 @@ +use std::sync::Arc; + +use salsa::Database; +use ra_db::{FilesDatabase, CrateGraph, SyntaxDatabase}; +use ra_syntax::{SmolStr, algo::visit::{visitor, Visitor}, ast::{self, 
AstNode}}; +use relative_path::RelativePath; + +use crate::{source_binder, mock::WORKSPACE, module::ModuleSourceNode}; + +use crate::{ + self as hir, + db::HirDatabase, + mock::MockDatabase, +}; + +fn infer_all_fns(fixture: &str) -> () { + let (db, source_root) = MockDatabase::with_files(fixture); + for &file_id in source_root.files.values() { + let source_file = db.source_file(file_id); + for fn_def in source_file.syntax().descendants().filter_map(ast::FnDef::cast) { + let func = source_binder::function_from_source(&db, file_id, fn_def).unwrap().unwrap(); + let inference_result = func.infer(&db); + for (syntax_ptr, ty) in &inference_result.type_for { + let node = syntax_ptr.resolve(&source_file); + eprintln!("{} '{}': {:?}", syntax_ptr.range(), node.text(), ty); + } + } + } +} + +#[test] +fn infer_smoke_test() { + let text = " + //- /lib.rs + fn foo(x: u32, y: !) -> i128 { + x; + y; + return 1; + \"hello\"; + 0 + } + "; + + infer_all_fns(text); +} -- cgit v1.2.3