1extern crate proc_macro;
2
3mod enum_hack;
4mod error;
5
6use crate::error::{Error, Result};
7use proc_macro::{
8 token_stream, Delimiter, Group, Ident, Punct, Spacing, Span, TokenStream, TokenTree,
9};
10use proc_macro_hack::proc_macro_hack;
11use std::iter::{self, FromIterator, Peekable};
12use std::panic;
13
14#[proc_macro]
15pub fn item(input: TokenStream) -> TokenStream {
16 expand_paste(input)
17}
18
19#[proc_macro]
20pub fn item_with_macros(input: TokenStream) -> TokenStream {
21 enum_hack::wrap(expand_paste(input))
22}
23
24#[proc_macro_hack]
25pub fn expr(input: TokenStream) -> TokenStream {
26 TokenStream::from(TokenTree::Group(Group::new(
27 Delimiter::Brace,
28 expand_paste(input),
29 )))
30}
31
32#[doc(hidden)]
33#[proc_macro_derive(EnumHack)]
34pub fn enum_hack(input: TokenStream) -> TokenStream {
35 enum_hack::extract(input)
36}
37
38fn expand_paste(input: TokenStream) -> TokenStream {
39 let mut contains_paste = false;
40 match expand(input, &mut contains_paste) {
41 Ok(expanded) => expanded,
42 Err(err) => err.to_compile_error(),
43 }
44}
45
/// Recursively expands paste operations in `input`.
///
/// Walks the top-level token trees of `input` and:
/// - replaces each bracket group accepted by `is_paste_operation` with the
///   tokens returned by `paste_segments`;
/// - replaces a group accepted by `is_none_delimited_flat_group` with its
///   bare contents;
/// - for any other invisible (`Delimiter::None`) group, splices its contents
///   in place when it is directly preceded by `::` or directly followed by
///   `::` (first `:` joint), and otherwise keeps it as a group;
/// - recurses into every remaining group. The group is rebuilt (keeping its
///   original span) only if something inside it changed; otherwise the
///   original group is reused.
///
/// `contains_paste` is set to `true` whenever this call, or a nested one,
/// pasted or flattened anything. Callers use it to decide whether an
/// enclosing group must be rebuilt.
///
/// # Errors
///
/// Propagates any error from `parse_bracket_as_segments` or `paste_segments`.
fn expand(input: TokenStream, contains_paste: &mut bool) -> Result<TokenStream> {
    let mut expanded = TokenStream::new();
    // `colon`: the previous token was a `:` with `Spacing::Joint`.
    // `prev_colon`: the previous two tokens were a joint `:` followed by
    // another `:`, i.e. a `::` path separator.
    let (mut prev_colon, mut colon) = (false, false);
    // An invisible group whose handling depends on the token after it. It is
    // held back for one iteration so that a following `::` can be detected.
    let mut prev_none_group = None::<Group>;
    let mut tokens = input.into_iter().peekable();
    loop {
        let token = tokens.next();
        // Resolve a held-back invisible group now that its successors are
        // known. If `token` and the token after it form `::` (first `:`
        // joint), splice the group's contents. Otherwise emit it unchanged.
        if let Some(group) = prev_none_group.take() {
            if match (&token, tokens.peek()) {
                (Some(TokenTree::Punct(fst)), Some(TokenTree::Punct(snd))) => {
                    fst.as_char() == ':' && snd.as_char() == ':' && fst.spacing() == Spacing::Joint
                }
                _ => false,
            } {
                expanded.extend(group.stream());
                *contains_paste = true;
            } else {
                expanded.extend(iter::once(TokenTree::Group(group)));
            }
        }
        match token {
            Some(TokenTree::Group(group)) => {
                let delimiter = group.delimiter();
                let content = group.stream();
                let span = group.span();
                if delimiter == Delimiter::Bracket && is_paste_operation(&content) {
                    // `[< ... >]`: replace the whole bracket group with the
                    // pasted token(s).
                    let segments = parse_bracket_as_segments(content, span)?;
                    let pasted = paste_segments(span, &segments)?;
                    expanded.extend(pasted);
                    *contains_paste = true;
                } else if is_none_delimited_flat_group(delimiter, &content) {
                    // Invisible group holding a simple ident/literal/lifetime/
                    // path: drop the invisible delimiters.
                    expanded.extend(content);
                    *contains_paste = true;
                } else {
                    let mut group_contains_paste = false;
                    let nested = expand(content, &mut group_contains_paste)?;
                    // Rebuild the group only if its contents changed, and keep
                    // the original span. Otherwise reuse the original group.
                    let group = if group_contains_paste {
                        let mut group = Group::new(delimiter, nested);
                        group.set_span(span);
                        *contains_paste = true;
                        group
                    } else {
                        group.clone()
                    };
                    if delimiter != Delimiter::None {
                        expanded.extend(iter::once(TokenTree::Group(group)));
                    } else if prev_colon {
                        // Invisible group right after `::`: splice its contents.
                        // NOTE(review): presumably so the tokens join the
                        // surrounding path.
                        expanded.extend(group.stream());
                        *contains_paste = true;
                    } else {
                        // Decide on the next iteration, once we can see whether
                        // `::` follows.
                        prev_none_group = Some(group);
                    }
                }
                prev_colon = false;
                colon = false;
            }
            Some(other) => {
                // Track `::` so that a following invisible group can be spliced.
                match &other {
                    TokenTree::Punct(punct) if punct.as_char() == ':' => {
                        prev_colon = colon;
                        colon = punct.spacing() == Spacing::Joint;
                    }
                    _ => {
                        prev_colon = false;
                        colon = false;
                    }
                }
                expanded.extend(iter::once(other));
            }
            None => return Ok(expanded),
        }
    }
}
119
/// Decides whether a group can be replaced by its bare contents.
///
/// Only invisible groups (`Delimiter::None`) qualify, and only when their
/// contents are exactly one of:
/// - a single literal;
/// - a lifetime, i.e. a `'` followed by an identifier;
/// - an identifier, or identifiers joined by `::` (`a::b::c`), where the
///   first `:` is joint and the second is alone.
///
/// Any other delimiter, empty contents, or any other token shape yields
/// `false`.
fn is_none_delimited_flat_group(delimiter: Delimiter, input: &TokenStream) -> bool {
    if delimiter != Delimiter::None {
        return false;
    }

    // Position within the grammar `literal | ' ident | ident (:: ident)*`.
    enum Progress {
        Start,
        AfterIdent,
        AfterLiteral,
        AfterQuote,
        AfterLifetime,
        AfterFirstColon,
        AfterPathSep,
    }

    let mut progress = Progress::Start;
    for tree in input.clone() {
        progress = match (progress, &tree) {
            (Progress::Start, TokenTree::Ident(_))
            | (Progress::AfterPathSep, TokenTree::Ident(_)) => Progress::AfterIdent,
            (Progress::Start, TokenTree::Literal(_)) => Progress::AfterLiteral,
            (Progress::AfterQuote, TokenTree::Ident(_)) => Progress::AfterLifetime,
            (Progress::Start, TokenTree::Punct(p)) if p.as_char() == '\'' => Progress::AfterQuote,
            (Progress::AfterIdent, TokenTree::Punct(p))
                if p.as_char() == ':' && p.spacing() == Spacing::Joint =>
            {
                Progress::AfterFirstColon
            }
            (Progress::AfterFirstColon, TokenTree::Punct(p))
                if p.as_char() == ':' && p.spacing() == Spacing::Alone =>
            {
                Progress::AfterPathSep
            }
            _ => return false,
        };
    }

    // Only complete productions are accepted; a dangling `'` or `::` is not.
    match progress {
        Progress::AfterIdent | Progress::AfterLiteral | Progress::AfterLifetime => true,
        _ => false,
    }
}
161
/// The text of a string literal argument, stored without its surrounding
/// quotes, plus the literal's span for error reporting.
struct LitStr {
    span: Span,
    value: String,
}

/// The `:` that introduces a modifier such as `:lower`; only its span is
/// kept, for error reporting.
struct Colon {
    span: Span,
}

/// One parsed piece of a `[< ... >]` paste group.
enum Segment {
    /// A `:name` modifier, applied to the piece immediately before it.
    Modifier(Colon, Ident),
    /// An `env!("NAME")` call, resolved when the segments are pasted.
    Env(LitStr),
    /// A `'`, which makes the pasted result a lifetime.
    Apostrophe(Span),
    /// Text to be concatenated into the pasted identifier.
    String(String),
}
177
/// Reports whether a bracket group's contents have the `< ... >` shape of a
/// paste operation.
///
/// The contents must start with `<` and contain at least one token before
/// the first `>`, and that `>` must be the final token. Only the top level of
/// `input` is inspected; a nested group counts as a single token.
fn is_paste_operation(input: &TokenStream) -> bool {
    let mut rest = input.clone().into_iter();

    let opens_with_lt = match rest.next() {
        Some(TokenTree::Punct(open)) => open.as_char() == '<',
        _ => false,
    };
    if !opens_with_lt {
        return false;
    }

    // `position` consumes everything up to and including the first `>`, so
    // whatever is left in `rest` afterwards follows the closing bracket.
    let close_at = rest.position(|tree| match tree {
        TokenTree::Punct(close) => close.as_char() == '>',
        _ => false,
    });
    match close_at {
        Some(tokens_between) => tokens_between > 0 && rest.next().is_none(),
        None => false,
    }
}
197
198fn parse_bracket_as_segments(input: TokenStream, scope: Span) -> Result<Vec<Segment>> {
199 let mut tokens = input.into_iter().peekable();
200
201 match &tokens.next() {
202 Some(TokenTree::Punct(punct)) if punct.as_char() == '<' => {}
203 Some(wrong) => return Err(Error::new(wrong.span(), "expected `<`")),
204 None => return Err(Error::new(scope, "expected `[< ... >]`")),
205 }
206
207 let segments = parse_segments(&mut tokens, scope)?;
208
209 match &tokens.next() {
210 Some(TokenTree::Punct(punct)) if punct.as_char() == '>' => {}
211 Some(wrong) => return Err(Error::new(wrong.span(), "expected `>`")),
212 None => return Err(Error::new(scope, "expected `[< ... >]`")),
213 }
214
215 match tokens.next() {
216 Some(unexpected) => Err(Error::new(
217 unexpected.span(),
218 "unexpected input, expected `[< ... >]`",
219 )),
220 None => Ok(segments),
221 }
222}
223
224fn parse_segments(
225 tokens: &mut Peekable<token_stream::IntoIter>,
226 scope: Span,
227) -> Result<Vec<Segment>> {
228 let mut segments = Vec::new();
229 while match tokens.peek() {
230 None => false,
231 Some(TokenTree::Punct(punct)) => punct.as_char() != '>',
232 Some(_) => true,
233 } {
234 match tokens.next().unwrap() {
235 TokenTree::Ident(ident) => {
236 let mut fragment = ident.to_string();
237 if fragment.starts_with("r#") {
238 fragment = fragment.split_off(2);
239 }
240 if fragment == "env"
241 && match tokens.peek() {
242 Some(TokenTree::Punct(punct)) => punct.as_char() == '!',
243 _ => false,
244 }
245 {
246 tokens.next().unwrap(); let expect_group = tokens.next();
248 let parenthesized = match &expect_group {
249 Some(TokenTree::Group(group))
250 if group.delimiter() == Delimiter::Parenthesis =>
251 {
252 group
253 }
254 Some(wrong) => return Err(Error::new(wrong.span(), "expected `(`")),
255 None => return Err(Error::new(scope, "expected `(` after `env!`")),
256 };
257 let mut inner = parenthesized.stream().into_iter();
258 let lit = match inner.next() {
259 Some(TokenTree::Literal(lit)) => lit,
260 Some(wrong) => {
261 return Err(Error::new(wrong.span(), "expected string literal"))
262 }
263 None => {
264 return Err(Error::new2(
265 ident.span(),
266 parenthesized.span(),
267 "expected string literal as argument to env! macro",
268 ))
269 }
270 };
271 let lit_string = lit.to_string();
272 if lit_string.starts_with('"')
273 && lit_string.ends_with('"')
274 && lit_string.len() >= 2
275 {
276 segments.push(Segment::Env(LitStr {
279 value: lit_string[1..lit_string.len() - 1].to_owned(),
280 span: lit.span(),
281 }));
282 } else {
283 return Err(Error::new(lit.span(), "expected string literal"));
284 }
285 if let Some(unexpected) = inner.next() {
286 return Err(Error::new(
287 unexpected.span(),
288 "unexpected token in env! macro",
289 ));
290 }
291 } else {
292 segments.push(Segment::String(fragment));
293 }
294 }
295 TokenTree::Literal(lit) => {
296 let mut lit_string = lit.to_string();
297 if lit_string.contains(&['#', '\\', '.', '+'][..]) {
298 return Err(Error::new(lit.span(), "unsupported literal"));
299 }
300 lit_string = lit_string
301 .replace('"', "")
302 .replace('\'', "")
303 .replace('-', "_");
304 segments.push(Segment::String(lit_string));
305 }
306 TokenTree::Punct(punct) => match punct.as_char() {
307 '_' => segments.push(Segment::String("_".to_owned())),
308 '\'' => segments.push(Segment::Apostrophe(punct.span())),
309 ':' => {
310 let colon = Colon { span: punct.span() };
311 let ident = match tokens.next() {
312 Some(TokenTree::Ident(ident)) => ident,
313 wrong => {
314 let span = wrong.as_ref().map_or(scope, TokenTree::span);
315 return Err(Error::new(span, "expected identifier after `:`"));
316 }
317 };
318 segments.push(Segment::Modifier(colon, ident));
319 }
320 _ => return Err(Error::new(punct.span(), "unexpected punct")),
321 },
322 TokenTree::Group(group) => {
323 if group.delimiter() == Delimiter::None {
324 let mut inner = group.stream().into_iter().peekable();
325 let nested = parse_segments(&mut inner, group.span())?;
326 if let Some(unexpected) = inner.next() {
327 return Err(Error::new(unexpected.span(), "unexpected token"));
328 }
329 segments.extend(nested);
330 } else {
331 return Err(Error::new(group.span(), "unexpected token"));
332 }
333 }
334 }
335 }
336 Ok(segments)
337}
338
339fn paste_segments(span: Span, segments: &[Segment]) -> Result<TokenStream> {
340 let mut evaluated = Vec::new();
341 let mut is_lifetime = false;
342
343 for segment in segments {
344 match segment {
345 Segment::String(segment) => {
346 evaluated.push(segment.clone());
347 }
348 Segment::Apostrophe(span) => {
349 if is_lifetime {
350 return Err(Error::new(*span, "unexpected lifetime"));
351 }
352 is_lifetime = true;
353 }
354 Segment::Env(var) => {
355 let resolved = match std::env::var(&var.value) {
356 Ok(resolved) => resolved,
357 Err(_) => {
358 return Err(Error::new(
359 var.span,
360 &format!("no such env var: {:?}", var.value),
361 ));
362 }
363 };
364 let resolved = resolved.replace('-', "_");
365 evaluated.push(resolved);
366 }
367 Segment::Modifier(colon, ident) => {
368 let last = match evaluated.pop() {
369 Some(last) => last,
370 None => {
371 return Err(Error::new2(colon.span, ident.span(), "unexpected modifier"))
372 }
373 };
374 match ident.to_string().as_str() {
375 "lower" => {
376 evaluated.push(last.to_lowercase());
377 }
378 "upper" => {
379 evaluated.push(last.to_uppercase());
380 }
381 "snake" => {
382 let mut acc = String::new();
383 let mut prev = '_';
384 for ch in last.chars() {
385 if ch.is_uppercase() && prev != '_' {
386 acc.push('_');
387 }
388 acc.push(ch);
389 prev = ch;
390 }
391 evaluated.push(acc.to_lowercase());
392 }
393 "camel" => {
394 let mut acc = String::new();
395 let mut prev = '_';
396 for ch in last.chars() {
397 if ch != '_' {
398 if prev == '_' {
399 for chu in ch.to_uppercase() {
400 acc.push(chu);
401 }
402 } else if prev.is_uppercase() {
403 for chl in ch.to_lowercase() {
404 acc.push(chl);
405 }
406 } else {
407 acc.push(ch);
408 }
409 }
410 prev = ch;
411 }
412 evaluated.push(acc);
413 }
414 _ => {
415 return Err(Error::new2(
416 colon.span,
417 ident.span(),
418 "unsupported modifier",
419 ));
420 }
421 }
422 }
423 }
424 }
425
426 let pasted = evaluated.into_iter().collect::<String>();
427 let ident = match panic::catch_unwind(|| Ident::new(&pasted, span)) {
428 Ok(ident) => TokenTree::Ident(ident),
429 Err(_) => {
430 return Err(Error::new(
431 span,
432 &format!("`{:?}` is not a valid identifier", pasted),
433 ));
434 }
435 };
436 let tokens = if is_lifetime {
437 let apostrophe = TokenTree::Punct(Punct::new('\'', Spacing::Joint));
438 vec![apostrophe, ident]
439 } else {
440 vec![ident]
441 };
442 Ok(TokenStream::from_iter(tokens))
443}