diff options
Diffstat (limited to 'crates/ra_proc_macro/src/msg.rs')
-rw-r--r-- | crates/ra_proc_macro/src/msg.rs | 231 |
1 files changed, 53 insertions, 178 deletions
diff --git a/crates/ra_proc_macro/src/msg.rs b/crates/ra_proc_macro/src/msg.rs index 2fb065d32..aa95bcc8f 100644 --- a/crates/ra_proc_macro/src/msg.rs +++ b/crates/ra_proc_macro/src/msg.rs | |||
@@ -1,218 +1,93 @@ | |||
1 | //! A simplified version of lsp base protocol for rpc | 1 | //! Defines messages for cross-process message based on `ndjson` wire protocol |
2 | 2 | ||
3 | use std::{ | 3 | use std::{ |
4 | fmt, | 4 | convert::TryFrom, |
5 | io::{self, BufRead, Write}, | 5 | io::{self, BufRead, Write}, |
6 | }; | 6 | }; |
7 | 7 | ||
8 | use crate::{ | ||
9 | rpc::{ListMacrosResult, ListMacrosTask}, | ||
10 | ExpansionResult, ExpansionTask, | ||
11 | }; | ||
8 | use serde::{de::DeserializeOwned, Deserialize, Serialize}; | 12 | use serde::{de::DeserializeOwned, Deserialize, Serialize}; |
9 | 13 | ||
10 | #[derive(Serialize, Deserialize, Debug, Clone)] | ||
11 | #[serde(untagged)] | ||
12 | pub enum Message { | ||
13 | Request(Request), | ||
14 | Response(Response), | ||
15 | } | ||
16 | |||
17 | impl From<Request> for Message { | ||
18 | fn from(request: Request) -> Message { | ||
19 | Message::Request(request) | ||
20 | } | ||
21 | } | ||
22 | |||
23 | impl From<Response> for Message { | ||
24 | fn from(response: Response) -> Message { | ||
25 | Message::Response(response) | ||
26 | } | ||
27 | } | ||
28 | |||
29 | #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] | ||
30 | #[serde(transparent)] | ||
31 | pub struct RequestId(IdRepr); | ||
32 | |||
33 | #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] | ||
34 | #[serde(untagged)] | ||
35 | enum IdRepr { | ||
36 | U64(u64), | ||
37 | String(String), | ||
38 | } | ||
39 | |||
40 | impl From<u64> for RequestId { | ||
41 | fn from(id: u64) -> RequestId { | ||
42 | RequestId(IdRepr::U64(id)) | ||
43 | } | ||
44 | } | ||
45 | |||
46 | impl From<String> for RequestId { | ||
47 | fn from(id: String) -> RequestId { | ||
48 | RequestId(IdRepr::String(id)) | ||
49 | } | ||
50 | } | ||
51 | |||
52 | impl fmt::Display for RequestId { | ||
53 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||
54 | match &self.0 { | ||
55 | IdRepr::U64(it) => fmt::Display::fmt(it, f), | ||
56 | IdRepr::String(it) => fmt::Display::fmt(it, f), | ||
57 | } | ||
58 | } | ||
59 | } | ||
60 | |||
61 | #[derive(Debug, Serialize, Deserialize, Clone)] | 14 | #[derive(Debug, Serialize, Deserialize, Clone)] |
62 | pub struct Request { | 15 | pub enum Request { |
63 | pub id: RequestId, | 16 | ListMacro(ListMacrosTask), |
64 | pub method: String, | 17 | ExpansionMacro(ExpansionTask), |
65 | pub params: serde_json::Value, | ||
66 | } | 18 | } |
67 | 19 | ||
68 | #[derive(Debug, Serialize, Deserialize, Clone)] | 20 | #[derive(Debug, Serialize, Deserialize, Clone)] |
69 | pub struct Response { | 21 | pub enum Response { |
70 | // JSON RPC allows this to be null if it was impossible | 22 | Error(ResponseError), |
71 | // to decode the request's id. Ignore this special case | 23 | ListMacro(ListMacrosResult), |
72 | // and just die horribly. | 24 | ExpansionMacro(ExpansionResult), |
73 | pub id: RequestId, | 25 | } |
74 | #[serde(skip_serializing_if = "Option::is_none")] | 26 | |
75 | pub result: Option<serde_json::Value>, | 27 | macro_rules! impl_try_from_response { |
76 | #[serde(skip_serializing_if = "Option::is_none")] | 28 | ($ty:ty, $tag:ident) => { |
77 | pub error: Option<ResponseError>, | 29 | impl TryFrom<Response> for $ty { |
30 | type Error = &'static str; | ||
31 | fn try_from(value: Response) -> Result<Self, Self::Error> { | ||
32 | match value { | ||
33 | Response::$tag(res) => Ok(res), | ||
34 | _ => Err("Fail to convert from response"), | ||
35 | } | ||
36 | } | ||
37 | } | ||
38 | }; | ||
78 | } | 39 | } |
79 | 40 | ||
41 | impl_try_from_response!(ListMacrosResult, ListMacro); | ||
42 | impl_try_from_response!(ExpansionResult, ExpansionMacro); | ||
43 | |||
80 | #[derive(Debug, Serialize, Deserialize, Clone)] | 44 | #[derive(Debug, Serialize, Deserialize, Clone)] |
81 | pub struct ResponseError { | 45 | pub struct ResponseError { |
82 | pub code: i32, | 46 | pub code: ErrorCode, |
83 | pub message: String, | 47 | pub message: String, |
84 | #[serde(skip_serializing_if = "Option::is_none")] | ||
85 | pub data: Option<serde_json::Value>, | ||
86 | } | ||
87 | |||
88 | #[derive(Clone, Copy, Debug)] | ||
89 | #[allow(unused)] | ||
90 | pub enum ErrorCode { | ||
91 | // Defined by JSON RPC | ||
92 | ParseError = -32700, | ||
93 | InvalidRequest = -32600, | ||
94 | MethodNotFound = -32601, | ||
95 | InvalidParams = -32602, | ||
96 | InternalError = -32603, | ||
97 | ServerErrorStart = -32099, | ||
98 | ServerErrorEnd = -32000, | ||
99 | ServerNotInitialized = -32002, | ||
100 | UnknownErrorCode = -32001, | ||
101 | |||
102 | // Defined by protocol | ||
103 | ExpansionError = -32900, | ||
104 | } | 48 | } |
105 | 49 | ||
106 | #[derive(Debug, Serialize, Deserialize, Clone)] | 50 | #[derive(Debug, Serialize, Deserialize, Clone)] |
107 | pub struct Notification { | 51 | pub enum ErrorCode { |
108 | pub method: String, | 52 | ServerErrorEnd, |
109 | pub params: serde_json::Value, | 53 | ExpansionError, |
110 | } | 54 | } |
111 | 55 | ||
112 | impl Message { | 56 | pub trait Message: Sized + Serialize + DeserializeOwned { |
113 | pub fn read(r: &mut impl BufRead) -> io::Result<Option<Message>> { | 57 | fn read(r: &mut impl BufRead) -> io::Result<Option<Self>> { |
114 | let text = match read_msg_text(r)? { | 58 | let text = match read_json(r)? { |
115 | None => return Ok(None), | 59 | None => return Ok(None), |
116 | Some(text) => text, | 60 | Some(text) => text, |
117 | }; | 61 | }; |
118 | let msg = serde_json::from_str(&text)?; | 62 | let msg = serde_json::from_str(&text)?; |
119 | Ok(Some(msg)) | 63 | Ok(Some(msg)) |
120 | } | 64 | } |
121 | pub fn write(self, w: &mut impl Write) -> io::Result<()> { | 65 | fn write(self, w: &mut impl Write) -> io::Result<()> { |
122 | #[derive(Serialize)] | 66 | let text = serde_json::to_string(&self)?; |
123 | struct JsonRpc { | 67 | write_json(w, &text) |
124 | jsonrpc: &'static str, | ||
125 | #[serde(flatten)] | ||
126 | msg: Message, | ||
127 | } | ||
128 | let text = serde_json::to_string(&JsonRpc { jsonrpc: "2.0", msg: self })?; | ||
129 | write_msg_text(w, &text) | ||
130 | } | 68 | } |
131 | } | 69 | } |
132 | 70 | ||
133 | impl Response { | 71 | impl Message for Request {} |
134 | pub fn new_ok<R: Serialize>(id: RequestId, result: R) -> Response { | 72 | impl Message for Response {} |
135 | Response { id, result: Some(serde_json::to_value(result).unwrap()), error: None } | ||
136 | } | ||
137 | pub fn new_err(id: RequestId, code: i32, message: String) -> Response { | ||
138 | let error = ResponseError { code, message, data: None }; | ||
139 | Response { id, result: None, error: Some(error) } | ||
140 | } | ||
141 | } | ||
142 | 73 | ||
143 | impl Request { | 74 | fn read_json(inp: &mut impl BufRead) -> io::Result<Option<String>> { |
144 | pub fn new<P: Serialize>(id: RequestId, method: String, params: P) -> Request { | ||
145 | Request { id, method, params: serde_json::to_value(params).unwrap() } | ||
146 | } | ||
147 | pub fn extract<P: DeserializeOwned>(self, method: &str) -> Result<(RequestId, P), Request> { | ||
148 | if self.method == method { | ||
149 | let params = serde_json::from_value(self.params).unwrap_or_else(|err| { | ||
150 | panic!("Invalid request\nMethod: {}\n error: {}", method, err) | ||
151 | }); | ||
152 | Ok((self.id, params)) | ||
153 | } else { | ||
154 | Err(self) | ||
155 | } | ||
156 | } | ||
157 | } | ||
158 | |||
159 | impl Notification { | ||
160 | pub fn new(method: String, params: impl Serialize) -> Notification { | ||
161 | Notification { method, params: serde_json::to_value(params).unwrap() } | ||
162 | } | ||
163 | pub fn extract<P: DeserializeOwned>(self, method: &str) -> Result<P, Notification> { | ||
164 | if self.method == method { | ||
165 | let params = serde_json::from_value(self.params).unwrap(); | ||
166 | Ok(params) | ||
167 | } else { | ||
168 | Err(self) | ||
169 | } | ||
170 | } | ||
171 | } | ||
172 | |||
173 | fn read_msg_text(inp: &mut impl BufRead) -> io::Result<Option<String>> { | ||
174 | fn invalid_data(error: impl Into<Box<dyn std::error::Error + Send + Sync>>) -> io::Error { | ||
175 | io::Error::new(io::ErrorKind::InvalidData, error) | ||
176 | } | ||
177 | macro_rules! invalid_data { | ||
178 | ($($tt:tt)*) => (invalid_data(format!($($tt)*))) | ||
179 | } | ||
180 | |||
181 | let mut size = None; | ||
182 | let mut buf = String::new(); | 75 | let mut buf = String::new(); |
183 | loop { | 76 | if inp.read_line(&mut buf)? == 0 { |
184 | buf.clear(); | 77 | return Ok(None); |
185 | if inp.read_line(&mut buf)? == 0 { | 78 | } |
186 | return Ok(None); | 79 | // Remove ending '\n' |
187 | } | 80 | let buf = &buf[..buf.len() - 1]; |
188 | if !buf.ends_with("\r\n") { | 81 | if buf.is_empty() { |
189 | return Err(invalid_data!("malformed header: {:?}", buf)); | 82 | return Ok(None); |
190 | } | ||
191 | let buf = &buf[..buf.len() - 2]; | ||
192 | if buf.is_empty() { | ||
193 | break; | ||
194 | } | ||
195 | let mut parts = buf.splitn(2, ": "); | ||
196 | let header_name = parts.next().unwrap(); | ||
197 | let header_value = | ||
198 | parts.next().ok_or_else(|| invalid_data!("malformed header: {:?}", buf))?; | ||
199 | if header_name == "Content-Length" { | ||
200 | size = Some(header_value.parse::<usize>().map_err(invalid_data)?); | ||
201 | } | ||
202 | } | 83 | } |
203 | let size: usize = size.ok_or_else(|| invalid_data!("no Content-Length"))?; | 84 | Ok(Some(buf.to_string())) |
204 | let mut buf = buf.into_bytes(); | ||
205 | buf.resize(size, 0); | ||
206 | inp.read_exact(&mut buf)?; | ||
207 | let buf = String::from_utf8(buf).map_err(invalid_data)?; | ||
208 | log::debug!("< {}", buf); | ||
209 | Ok(Some(buf)) | ||
210 | } | 85 | } |
211 | 86 | ||
212 | fn write_msg_text(out: &mut impl Write, msg: &str) -> io::Result<()> { | 87 | fn write_json(out: &mut impl Write, msg: &str) -> io::Result<()> { |
213 | log::debug!("> {}", msg); | 88 | log::debug!("> {}", msg); |
214 | write!(out, "Content-Length: {}\r\n\r\n", msg.len())?; | ||
215 | out.write_all(msg.as_bytes())?; | 89 | out.write_all(msg.as_bytes())?; |
90 | out.write_all(b"\n")?; | ||
216 | out.flush()?; | 91 | out.flush()?; |
217 | Ok(()) | 92 | Ok(()) |
218 | } | 93 | } |