From 503cbd3f4b54f3be224d7a4221fa023f0e35d228 Mon Sep 17 00:00:00 2001 From: Edwin Cheng Date: Fri, 27 Mar 2020 04:26:34 +0800 Subject: Implement ra_proc_macro client logic --- crates/ra_proc_macro/src/lib.rs | 83 +++++++++--- crates/ra_proc_macro/src/msg.rs | 218 ++++++++++++++++++++++++++++++ crates/ra_proc_macro/src/process.rs | 202 ++++++++++++++++++++++++++++ crates/ra_proc_macro/src/rpc.rs | 260 ++++++++++++++++++++++++++++++++++++ 4 files changed, 745 insertions(+), 18 deletions(-) create mode 100644 crates/ra_proc_macro/src/msg.rs create mode 100644 crates/ra_proc_macro/src/process.rs create mode 100644 crates/ra_proc_macro/src/rpc.rs (limited to 'crates/ra_proc_macro/src') diff --git a/crates/ra_proc_macro/src/lib.rs b/crates/ra_proc_macro/src/lib.rs index 5e21dd487..a0a478dc8 100644 --- a/crates/ra_proc_macro/src/lib.rs +++ b/crates/ra_proc_macro/src/lib.rs @@ -5,55 +5,102 @@ //! is used to provide basic infrastructure for communication between two //! processes: Client (RA itself), Server (the external program) +mod rpc; +mod process; +pub mod msg; + +use process::ProcMacroProcessSrv; use ra_tt::{SmolStr, Subtree}; +use rpc::ProcMacroKind; use std::{ path::{Path, PathBuf}, sync::Arc, }; -#[derive(Debug, Clone, PartialEq, Eq)] +pub use rpc::{ExpansionResult, ExpansionTask}; + +#[derive(Debug, Clone)] pub struct ProcMacroProcessExpander { process: Arc, + dylib_path: PathBuf, name: SmolStr, } +impl Eq for ProcMacroProcessExpander {} +impl PartialEq for ProcMacroProcessExpander { + fn eq(&self, other: &Self) -> bool { + self.name == other.name + && self.dylib_path == other.dylib_path + && Arc::ptr_eq(&self.process, &other.process) + } +} + impl ra_tt::TokenExpander for ProcMacroProcessExpander { fn expand( &self, - _subtree: &Subtree, + subtree: &Subtree, _attr: Option<&Subtree>, ) -> Result { - // FIXME: do nothing for now - Ok(Subtree::default()) + self.process.custom_derive(&self.dylib_path, subtree, &self.name) } } -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct ProcMacroProcessSrv { - path: PathBuf, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum ProcMacroClient { +#[derive(Debug, Clone)] +enum ProcMacroClientKind { Process { process: Arc }, Dummy, } +#[derive(Debug, Clone)] +pub struct ProcMacroClient { + kind: ProcMacroClientKind, +} + impl ProcMacroClient { - pub fn extern_process(process_path: &Path) -> ProcMacroClient { - let process = ProcMacroProcessSrv { path: process_path.into() }; - ProcMacroClient::Process { process: Arc::new(process) } + pub fn extern_process(process_path: &Path) -> Result { + let process = ProcMacroProcessSrv::run(process_path)?; + Ok(ProcMacroClient { kind: ProcMacroClientKind::Process { process: Arc::new(process) } }) } pub fn dummy() -> ProcMacroClient { - ProcMacroClient::Dummy + ProcMacroClient { kind: ProcMacroClientKind::Dummy } } pub fn by_dylib_path( &self, - _dylib_path: &Path, + dylib_path: &Path, ) -> Vec<(SmolStr, Arc)> { - // FIXME: return empty for now - vec![] + match &self.kind { + ProcMacroClientKind::Dummy => vec![], + ProcMacroClientKind::Process { process } => { + let macros = match process.find_proc_macros(dylib_path) { + Err(err) => { + eprintln!("Fail to find proc macro. Error: {:#?}", err); + return vec![]; + } + Ok(macros) => macros, + }; + + macros + .into_iter() + .filter_map(|(name, kind)| { + // FIXME: Support custom derive only for now. + match kind { + ProcMacroKind::CustomDerive => { + let name = SmolStr::new(&name); + let expander: Arc = + Arc::new(ProcMacroProcessExpander { + process: process.clone(), + name: name.clone(), + dylib_path: dylib_path.into(), + }); + Some((name, expander)) + } + _ => None, + } + }) + .collect() + } + } } } diff --git a/crates/ra_proc_macro/src/msg.rs b/crates/ra_proc_macro/src/msg.rs new file mode 100644 index 000000000..2fb065d32 --- /dev/null +++ b/crates/ra_proc_macro/src/msg.rs @@ -0,0 +1,218 @@ +//! A simplified version of lsp base protocol for rpc + +use std::{ + fmt, + io::{self, BufRead, Write}, +}; + +use serde::{de::DeserializeOwned, Deserialize, Serialize}; + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(untagged)] +pub enum Message { + Request(Request), + Response(Response), +} + +impl From for Message { + fn from(request: Request) -> Message { + Message::Request(request) + } +} + +impl From for Message { + fn from(response: Response) -> Message { + Message::Response(response) + } +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[serde(transparent)] +pub struct RequestId(IdRepr); + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[serde(untagged)] +enum IdRepr { + U64(u64), + String(String), +} + +impl From for RequestId { + fn from(id: u64) -> RequestId { + RequestId(IdRepr::U64(id)) + } +} + +impl From for RequestId { + fn from(id: String) -> RequestId { + RequestId(IdRepr::String(id)) + } +} + +impl fmt::Display for RequestId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.0 { + IdRepr::U64(it) => fmt::Display::fmt(it, f), + IdRepr::String(it) => fmt::Display::fmt(it, f), + } + } +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Request { + pub id: RequestId, + pub method: String, + pub params: serde_json::Value, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Response { + // JSON RPC allows this to be null if it was impossible + // to decode the request's id. Ignore this special case + // and just die horribly. + pub id: RequestId, + #[serde(skip_serializing_if = "Option::is_none")] + pub result: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ResponseError { + pub code: i32, + pub message: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option, +} + +#[derive(Clone, Copy, Debug)] +#[allow(unused)] +pub enum ErrorCode { + // Defined by JSON RPC + ParseError = -32700, + InvalidRequest = -32600, + MethodNotFound = -32601, + InvalidParams = -32602, + InternalError = -32603, + ServerErrorStart = -32099, + ServerErrorEnd = -32000, + ServerNotInitialized = -32002, + UnknownErrorCode = -32001, + + // Defined by protocol + ExpansionError = -32900, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Notification { + pub method: String, + pub params: serde_json::Value, +} + +impl Message { + pub fn read(r: &mut impl BufRead) -> io::Result> { + let text = match read_msg_text(r)? { + None => return Ok(None), + Some(text) => text, + }; + let msg = serde_json::from_str(&text)?; + Ok(Some(msg)) + } + pub fn write(self, w: &mut impl Write) -> io::Result<()> { + #[derive(Serialize)] + struct JsonRpc { + jsonrpc: &'static str, + #[serde(flatten)] + msg: Message, + } + let text = serde_json::to_string(&JsonRpc { jsonrpc: "2.0", msg: self })?; + write_msg_text(w, &text) + } +} + +impl Response { + pub fn new_ok(id: RequestId, result: R) -> Response { + Response { id, result: Some(serde_json::to_value(result).unwrap()), error: None } + } + pub fn new_err(id: RequestId, code: i32, message: String) -> Response { + let error = ResponseError { code, message, data: None }; + Response { id, result: None, error: Some(error) } + } +} + +impl Request { + pub fn new(id: RequestId, method: String, params: P) -> Request { + Request { id, method, params: serde_json::to_value(params).unwrap() } + } + pub fn extract(self, method: &str) -> Result<(RequestId, P), Request> { + if self.method == method { + let params = serde_json::from_value(self.params).unwrap_or_else(|err| { + panic!("Invalid request\nMethod: {}\n error: {}", method, err) + }); + Ok((self.id, params)) + } else { + Err(self) + } + } +} + +impl Notification { + pub fn new(method: String, params: impl Serialize) -> Notification { + Notification { method, params: serde_json::to_value(params).unwrap() } + } + pub fn extract(self, method: &str) -> Result { + if self.method == method { + let params = serde_json::from_value(self.params).unwrap(); + Ok(params) + } else { + Err(self) + } + } +} + +fn read_msg_text(inp: &mut impl BufRead) -> io::Result> { + fn invalid_data(error: impl Into>) -> io::Error { + io::Error::new(io::ErrorKind::InvalidData, error) + } + macro_rules! invalid_data { + ($($tt:tt)*) => (invalid_data(format!($($tt)*))) + } + + let mut size = None; + let mut buf = String::new(); + loop { + buf.clear(); + if inp.read_line(&mut buf)? == 0 { + return Ok(None); + } + if !buf.ends_with("\r\n") { + return Err(invalid_data!("malformed header: {:?}", buf)); + } + let buf = &buf[..buf.len() - 2]; + if buf.is_empty() { + break; + } + let mut parts = buf.splitn(2, ": "); + let header_name = parts.next().unwrap(); + let header_value = + parts.next().ok_or_else(|| invalid_data!("malformed header: {:?}", buf))?; + if header_name == "Content-Length" { + size = Some(header_value.parse::().map_err(invalid_data)?); + } + } + let size: usize = size.ok_or_else(|| invalid_data!("no Content-Length"))?; + let mut buf = buf.into_bytes(); + buf.resize(size, 0); + inp.read_exact(&mut buf)?; + let buf = String::from_utf8(buf).map_err(invalid_data)?; + log::debug!("< {}", buf); + Ok(Some(buf)) +} + +fn write_msg_text(out: &mut impl Write, msg: &str) -> io::Result<()> { + log::debug!("> {}", msg); + write!(out, "Content-Length: {}\r\n\r\n", msg.len())?; + out.write_all(msg.as_bytes())?; + out.flush()?; + Ok(()) +} diff --git a/crates/ra_proc_macro/src/process.rs b/crates/ra_proc_macro/src/process.rs new file mode 100644 index 000000000..a9095af11 --- /dev/null +++ b/crates/ra_proc_macro/src/process.rs @@ -0,0 +1,202 @@ +use crossbeam_channel::{bounded, Receiver, Sender}; +use ra_tt::Subtree; + +use crate::msg::{ErrorCode, Message, Request, Response, ResponseError}; +use crate::rpc::{ExpansionResult, ExpansionTask, ListMacrosResult, ListMacrosTask, ProcMacroKind}; + +use io::{BufRead, BufReader}; +use std::{ + io::{self, Write}, + path::{Path, PathBuf}, + process::{Child, Command, Stdio}, + thread::spawn, +}; + +#[derive(Debug, Default)] +pub(crate) struct ProcMacroProcessSrv { + inner: Option, +} + +struct Task { + req: Message, + result_tx: Sender, +} + +#[derive(Debug)] +struct Handle { + sender: Sender, +} + +struct Process { + path: PathBuf, + child: Child, +} + +impl Process { + fn run(process_path: &Path) -> Result { + let child = Command::new(process_path.clone()) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .spawn()?; + + Ok(Process { path: process_path.into(), child }) + } + + fn restart(&mut self) -> Result<(), io::Error> { + let _ = self.child.kill(); + self.child = + Command::new(self.path.clone()).stdin(Stdio::piped()).stdout(Stdio::piped()).spawn()?; + Ok(()) + } + + fn stdio(&mut self) -> Option<(impl Write, impl BufRead)> { + let stdin = self.child.stdin.take()?; + let stdout = self.child.stdout.take()?; + let read = BufReader::new(stdout); + + Some((stdin, read)) + } +} + +impl ProcMacroProcessSrv { + pub fn run(process_path: &Path) -> Result { + let process = Process::run(process_path)?; + + let (task_tx, task_rx) = bounded(0); + + let _ = spawn(move || { + client_loop(task_rx, process); + }); + Ok(ProcMacroProcessSrv { inner: Some(Handle { sender: task_tx }) }) + } + + pub fn find_proc_macros( + &self, + dylib_path: &Path, + ) -> Result, ra_tt::ExpansionError> { + let task = ListMacrosTask { lib: dylib_path.to_path_buf() }; + + let result: ListMacrosResult = self.send_task("list_macros", task)?; + Ok(result.macros) + } + + pub fn custom_derive( + &self, + dylib_path: &Path, + subtree: &Subtree, + derive_name: &str, + ) -> Result { + let task = ExpansionTask { + macro_body: subtree.clone(), + macro_name: derive_name.to_string(), + attributes: None, + lib: dylib_path.to_path_buf(), + }; + + let result: ExpansionResult = self.send_task("custom_derive", task)?; + Ok(result.expansion) + } + + pub fn send_task<'a, T, R>(&self, method: &str, task: T) -> Result + where + T: serde::Serialize, + R: serde::de::DeserializeOwned + Default, + { + let handle = match &self.inner { + None => return Err(ra_tt::ExpansionError::Unknown("No handle is found.".to_string())), + Some(it) => it, + }; + + let msg = serde_json::to_value(task).unwrap(); + + // FIXME: use a proper request id + let id = 0; + let req = Request { id: id.into(), method: method.into(), params: msg }; + + let (result_tx, result_rx) = bounded(0); + + handle.sender.send(Task { req: req.into(), result_tx }).unwrap(); + let response = result_rx.recv().unwrap(); + + match response { + Message::Request(_) => { + return Err(ra_tt::ExpansionError::Unknown( + "Return request from ra_proc_srv".into(), + )) + } + Message::Response(res) => { + if let Some(err) = res.error { + return Err(ra_tt::ExpansionError::ExpansionError(err.message)); + } + match res.result { + None => Ok(R::default()), + Some(res) => { + let result: R = serde_json::from_value(res) + .map_err(|err| ra_tt::ExpansionError::JsonError(err.to_string()))?; + Ok(result) + } + } + } + } + } +} + +fn client_loop(task_rx: Receiver, mut process: Process) { + let (mut stdin, mut stdout) = match process.stdio() { + None => return, + Some(it) => it, + }; + + loop { + let task = match task_rx.recv() { + Ok(task) => task, + Err(_) => break, + }; + + let res = match send_message(&mut stdin, &mut stdout, task.req) { + Ok(res) => res, + Err(_err) => { + let res = Response { + id: 0.into(), + result: None, + error: Some(ResponseError { + code: ErrorCode::ServerErrorEnd as i32, + message: "Server closed".into(), + data: None, + }), + }; + if task.result_tx.send(res.into()).is_err() { + break; + } + // Restart the process + if process.restart().is_err() { + break; + } + let stdio = match process.stdio() { + None => break, + Some(it) => it, + }; + stdin = stdio.0; + stdout = stdio.1; + continue; + } + }; + + if let Some(res) = res { + if task.result_tx.send(res).is_err() { + break; + } + } + } + + let _ = process.child.kill(); +} + +fn send_message( + mut writer: &mut impl Write, + mut reader: &mut impl BufRead, + msg: Message, +) -> Result, io::Error> { + msg.write(&mut writer)?; + Ok(Message::read(&mut reader)?) +} diff --git a/crates/ra_proc_macro/src/rpc.rs b/crates/ra_proc_macro/src/rpc.rs new file mode 100644 index 000000000..e7eaf7c15 --- /dev/null +++ b/crates/ra_proc_macro/src/rpc.rs @@ -0,0 +1,260 @@ +//! Data struture serialization related stuffs for RPC + +use ra_tt::{ + Delimiter, DelimiterKind, Ident, Leaf, Literal, Punct, SmolStr, Spacing, Subtree, TokenId, + TokenTree, +}; +use serde::{Deserialize, Serialize}; +use std::path::PathBuf; + +#[derive(Clone, Eq, PartialEq, Debug, Serialize, Deserialize)] +pub struct ListMacrosTask { + pub lib: PathBuf, +} + +#[derive(Clone, Eq, PartialEq, Debug, Serialize, Deserialize)] +pub enum ProcMacroKind { + CustomDerive, + FuncLike, + Attr, +} + +#[derive(Clone, Eq, PartialEq, Debug, Default, Serialize, Deserialize)] +pub struct ListMacrosResult { + pub macros: Vec<(String, ProcMacroKind)>, +} + +#[derive(Clone, Eq, PartialEq, Debug, Serialize, Deserialize)] +pub struct ExpansionTask { + /// Argument of macro call. + /// + /// In custom derive that would be a struct or enum; in attribute-like macro - underlying + /// item; in function-like macro - the macro body. + #[serde(with = "SubtreeDef")] + pub macro_body: Subtree, + + /// Names of macros to expand. + /// + /// In custom derive those are names of derived traits (`Serialize`, `Getters`, etc.). In + /// attribute-like and functiona-like macros - single name of macro itself (`show_streams`). + pub macro_name: String, + + /// Possible attributes for the attribute-like macros. + #[serde(with = "opt_subtree_def")] + pub attributes: Option, + + pub lib: PathBuf, +} + +#[derive(Clone, Eq, PartialEq, Debug, Default, Serialize, Deserialize)] +pub struct ExpansionResult { + #[serde(with = "SubtreeDef")] + pub expansion: Subtree, +} + +#[derive(Serialize, Deserialize)] +#[serde(remote = "DelimiterKind")] +enum DelimiterKindDef { + Parenthesis, + Brace, + Bracket, +} + +#[derive(Serialize, Deserialize)] +#[serde(remote = "TokenId")] +struct TokenIdDef(u32); + +#[derive(Serialize, Deserialize)] +#[serde(remote = "Delimiter")] +struct DelimiterDef { + #[serde(with = "TokenIdDef")] + pub id: TokenId, + #[serde(with = "DelimiterKindDef")] + pub kind: DelimiterKind, +} + +#[derive(Serialize, Deserialize)] +#[serde(remote = "Subtree")] +struct SubtreeDef { + #[serde(default, with = "opt_delimiter_def")] + pub delimiter: Option, + #[serde(with = "vec_token_tree")] + pub token_trees: Vec, +} + +#[derive(Serialize, Deserialize)] +#[serde(remote = "TokenTree")] +enum TokenTreeDef { + #[serde(with = "LeafDef")] + Leaf(Leaf), + #[serde(with = "SubtreeDef")] + Subtree(Subtree), +} + +#[derive(Serialize, Deserialize)] +#[serde(remote = "Leaf")] +enum LeafDef { + #[serde(with = "LiteralDef")] + Literal(Literal), + #[serde(with = "PunctDef")] + Punct(Punct), + #[serde(with = "IdentDef")] + Ident(Ident), +} + +#[derive(Serialize, Deserialize)] +#[serde(remote = "Literal")] +struct LiteralDef { + pub text: SmolStr, + #[serde(with = "TokenIdDef")] + pub id: TokenId, +} + +#[derive(Serialize, Deserialize)] +#[serde(remote = "Punct")] +struct PunctDef { + pub char: char, + #[serde(with = "SpacingDef")] + pub spacing: Spacing, + #[serde(with = "TokenIdDef")] + pub id: TokenId, +} + +#[derive(Serialize, Deserialize)] +#[serde(remote = "Spacing")] +enum SpacingDef { + Alone, + Joint, +} + +#[derive(Serialize, Deserialize)] +#[serde(remote = "Ident")] +struct IdentDef { + pub text: SmolStr, + #[serde(with = "TokenIdDef")] + pub id: TokenId, +} + +mod opt_delimiter_def { + use super::{Delimiter, DelimiterDef}; + use serde::{Deserialize, Deserializer, Serialize, Serializer}; + + pub fn serialize(value: &Option, serializer: S) -> Result + where + S: Serializer, + { + #[derive(Serialize)] + struct Helper<'a>(#[serde(with = "DelimiterDef")] &'a Delimiter); + value.as_ref().map(Helper).serialize(serializer) + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> + where + D: Deserializer<'de>, + { + #[derive(Deserialize)] + struct Helper(#[serde(with = "DelimiterDef")] Delimiter); + let helper = Option::deserialize(deserializer)?; + Ok(helper.map(|Helper(external)| external)) + } +} + +mod opt_subtree_def { + use super::{Subtree, SubtreeDef}; + use serde::{Deserialize, Deserializer, Serialize, Serializer}; + + pub fn serialize(value: &Option, serializer: S) -> Result + where + S: Serializer, + { + #[derive(Serialize)] + struct Helper<'a>(#[serde(with = "SubtreeDef")] &'a Subtree); + value.as_ref().map(Helper).serialize(serializer) + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> + where + D: Deserializer<'de>, + { + #[derive(Deserialize)] + struct Helper(#[serde(with = "SubtreeDef")] Subtree); + let helper = Option::deserialize(deserializer)?; + Ok(helper.map(|Helper(external)| external)) + } +} + +mod vec_token_tree { + use super::{TokenTree, TokenTreeDef}; + use serde::{ser::SerializeSeq, Deserialize, Deserializer, Serialize, Serializer}; + + pub fn serialize(value: &Vec, serializer: S) -> Result + where + S: Serializer, + { + #[derive(Serialize)] + struct Helper<'a>(#[serde(with = "TokenTreeDef")] &'a TokenTree); + + let items: Vec<_> = value.iter().map(Helper).collect(); + let mut seq = serializer.serialize_seq(Some(items.len()))?; + for element in items { + seq.serialize_element(&element)?; + } + seq.end() + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> + where + D: Deserializer<'de>, + { + #[derive(Deserialize)] + struct Helper(#[serde(with = "TokenTreeDef")] TokenTree); + + let helper = Vec::deserialize(deserializer)?; + Ok(helper.into_iter().map(|Helper(external)| external).collect()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn fixture_token_tree() -> Subtree { + let mut subtree = Subtree::default(); + subtree + .token_trees + .push(TokenTree::Leaf(Ident { text: "struct".into(), id: TokenId(0) }.into())); + subtree + .token_trees + .push(TokenTree::Leaf(Ident { text: "Foo".into(), id: TokenId(1) }.into())); + subtree.token_trees.push(TokenTree::Subtree( + Subtree { + delimiter: Some(Delimiter { id: TokenId(2), kind: DelimiterKind::Brace }), + token_trees: vec![], + } + .into(), + )); + subtree + } + + #[test] + fn test_proc_macro_rpc_works() { + let tt = fixture_token_tree(); + let task = ExpansionTask { + macro_body: tt.clone(), + macro_name: Default::default(), + attributes: None, + lib: Default::default(), + }; + + let json = serde_json::to_string(&task).unwrap(); + let back: ExpansionTask = serde_json::from_str(&json).unwrap(); + + assert_eq!(task.macro_body, back.macro_body); + + let result = ExpansionResult { expansion: tt.clone() }; + let json = serde_json::to_string(&task).unwrap(); + let back: ExpansionResult = serde_json::from_str(&json).unwrap(); + + assert_eq!(result, back); + } +} -- cgit v1.2.3