// paste_impl/lib.rs

extern crate proc_macro;

mod enum_hack;
mod error;

use crate::error::{Error, Result};
use proc_macro::{
    token_stream, Delimiter, Group, Ident, Punct, Spacing, Span, TokenStream, TokenTree,
};
use proc_macro_hack::proc_macro_hack;
use std::iter::{self, FromIterator, Peekable};
use std::panic;

14#[proc_macro]
15pub fn item(input: TokenStream) -> TokenStream {
16    expand_paste(input)
17}
18
19#[proc_macro]
20pub fn item_with_macros(input: TokenStream) -> TokenStream {
21    enum_hack::wrap(expand_paste(input))
22}
23
24#[proc_macro_hack]
25pub fn expr(input: TokenStream) -> TokenStream {
26    TokenStream::from(TokenTree::Group(Group::new(
27        Delimiter::Brace,
28        expand_paste(input),
29    )))
30}
31
32#[doc(hidden)]
33#[proc_macro_derive(EnumHack)]
34pub fn enum_hack(input: TokenStream) -> TokenStream {
35    enum_hack::extract(input)
36}
37
38fn expand_paste(input: TokenStream) -> TokenStream {
39    let mut contains_paste = false;
40    match expand(input, &mut contains_paste) {
41        Ok(expanded) => expanded,
42        Err(err) => err.to_compile_error(),
43    }
44}
45
46fn expand(input: TokenStream, contains_paste: &mut bool) -> Result<TokenStream> {
47    let mut expanded = TokenStream::new();
48    let (mut prev_colon, mut colon) = (false, false);
49    let mut prev_none_group = None::<Group>;
50    let mut tokens = input.into_iter().peekable();
51    loop {
52        let token = tokens.next();
53        if let Some(group) = prev_none_group.take() {
54            if match (&token, tokens.peek()) {
55                (Some(TokenTree::Punct(fst)), Some(TokenTree::Punct(snd))) => {
56                    fst.as_char() == ':' && snd.as_char() == ':' && fst.spacing() == Spacing::Joint
57                }
58                _ => false,
59            } {
60                expanded.extend(group.stream());
61                *contains_paste = true;
62            } else {
63                expanded.extend(iter::once(TokenTree::Group(group)));
64            }
65        }
66        match token {
67            Some(TokenTree::Group(group)) => {
68                let delimiter = group.delimiter();
69                let content = group.stream();
70                let span = group.span();
71                if delimiter == Delimiter::Bracket && is_paste_operation(&content) {
72                    let segments = parse_bracket_as_segments(content, span)?;
73                    let pasted = paste_segments(span, &segments)?;
74                    expanded.extend(pasted);
75                    *contains_paste = true;
76                } else if is_none_delimited_flat_group(delimiter, &content) {
77                    expanded.extend(content);
78                    *contains_paste = true;
79                } else {
80                    let mut group_contains_paste = false;
81                    let nested = expand(content, &mut group_contains_paste)?;
82                    let group = if group_contains_paste {
83                        let mut group = Group::new(delimiter, nested);
84                        group.set_span(span);
85                        *contains_paste = true;
86                        group
87                    } else {
88                        group.clone()
89                    };
90                    if delimiter != Delimiter::None {
91                        expanded.extend(iter::once(TokenTree::Group(group)));
92                    } else if prev_colon {
93                        expanded.extend(group.stream());
94                        *contains_paste = true;
95                    } else {
96                        prev_none_group = Some(group);
97                    }
98                }
99                prev_colon = false;
100                colon = false;
101            }
102            Some(other) => {
103                match &other {
104                    TokenTree::Punct(punct) if punct.as_char() == ':' => {
105                        prev_colon = colon;
106                        colon = punct.spacing() == Spacing::Joint;
107                    }
108                    _ => {
109                        prev_colon = false;
110                        colon = false;
111                    }
112                }
113                expanded.extend(iter::once(other));
114            }
115            None => return Ok(expanded),
116        }
117    }
118}
119
/// Returns true if a group with `delimiter` and content `input` is an
/// invisible (`None`-delimited) group that can be spliced into the output
/// unchanged.
///
/// Accepted contents are exactly:
///
/// * a single literal;
/// * a lifetime (`'` followed by an identifier);
/// * an identifier, or a path of identifiers joined by `::` (`a::b::c`, with
///   no leading `::`).
///
/// Anything else, including an empty group, returns false.
///
/// See <https://github.com/dtolnay/paste/issues/26>.
fn is_none_delimited_flat_group(delimiter: Delimiter, input: &TokenStream) -> bool {
    if delimiter != Delimiter::None {
        return false;
    }

    // States of a small recognizer over the group's tokens.
    #[derive(PartialEq)]
    enum State {
        Init,
        Ident,
        Literal,
        Apostrophe,
        Lifetime,
        // Seen the first (joint) `:` of a `::` separator.
        Colon1,
        // Seen the complete `::` separator; an identifier must follow.
        Colon2,
    }

    let mut state = State::Init;
    for tt in input.clone() {
        state = match (state, &tt) {
            (State::Init, TokenTree::Ident(_)) => State::Ident,
            (State::Init, TokenTree::Literal(_)) => State::Literal,
            (State::Init, TokenTree::Punct(punct)) if punct.as_char() == '\'' => State::Apostrophe,
            (State::Apostrophe, TokenTree::Ident(_)) => State::Lifetime,
            (State::Ident, TokenTree::Punct(punct))
                if punct.as_char() == ':' && punct.spacing() == Spacing::Joint =>
            {
                State::Colon1
            }
            (State::Colon1, TokenTree::Punct(punct))
                if punct.as_char() == ':' && punct.spacing() == Spacing::Alone =>
            {
                State::Colon2
            }
            (State::Colon2, TokenTree::Ident(_)) => State::Ident,
            _ => return false,
        };
    }

    state == State::Ident || state == State::Literal || state == State::Lifetime
}

/// The argument of an `env!("...")` inside a paste operation.
struct LitStr {
    /// Text between the surrounding quotes; escape sequences are not processed.
    value: String,
    /// Span of the string literal, used when reporting a missing variable.
    span: Span,
}

/// The `:` that introduces a case modifier such as `:lower`. Only its span is
/// kept, for error reporting.
struct Colon {
    span: Span,
}

171enum Segment {
172    String(String),
173    Apostrophe(Span),
174    Env(LitStr),
175    Modifier(Colon, Ident),
176}
177
/// Returns true if `input`, the content of a `[...]` group, has the shape
/// `< tokens... >`.
///
/// That is: the first token is `<`, at least one token precedes the first
/// `>`, and that `>` is the last token. `[<>]` and `[< a > b]` are rejected.
fn is_paste_operation(input: &TokenStream) -> bool {
    let mut tokens = input.clone().into_iter();

    match &tokens.next() {
        Some(TokenTree::Punct(punct)) if punct.as_char() == '<' => {}
        _ => return false,
    }

    let mut has_token = false;
    loop {
        match &tokens.next() {
            Some(TokenTree::Punct(punct)) if punct.as_char() == '>' => {
                // The closing `>` must end the group.
                return has_token && tokens.next().is_none();
            }
            Some(_) => has_token = true,
            None => return false,
        }
    }
}

198fn parse_bracket_as_segments(input: TokenStream, scope: Span) -> Result<Vec<Segment>> {
199    let mut tokens = input.into_iter().peekable();
200
201    match &tokens.next() {
202        Some(TokenTree::Punct(punct)) if punct.as_char() == '<' => {}
203        Some(wrong) => return Err(Error::new(wrong.span(), "expected `<`")),
204        None => return Err(Error::new(scope, "expected `[< ... >]`")),
205    }
206
207    let segments = parse_segments(&mut tokens, scope)?;
208
209    match &tokens.next() {
210        Some(TokenTree::Punct(punct)) if punct.as_char() == '>' => {}
211        Some(wrong) => return Err(Error::new(wrong.span(), "expected `>`")),
212        None => return Err(Error::new(scope, "expected `[< ... >]`")),
213    }
214
215    match tokens.next() {
216        Some(unexpected) => Err(Error::new(
217            unexpected.span(),
218            "unexpected input, expected `[< ... >]`",
219        )),
220        None => Ok(segments),
221    }
222}
223
224fn parse_segments(
225    tokens: &mut Peekable<token_stream::IntoIter>,
226    scope: Span,
227) -> Result<Vec<Segment>> {
228    let mut segments = Vec::new();
229    while match tokens.peek() {
230        None => false,
231        Some(TokenTree::Punct(punct)) => punct.as_char() != '>',
232        Some(_) => true,
233    } {
234        match tokens.next().unwrap() {
235            TokenTree::Ident(ident) => {
236                let mut fragment = ident.to_string();
237                if fragment.starts_with("r#") {
238                    fragment = fragment.split_off(2);
239                }
240                if fragment == "env"
241                    && match tokens.peek() {
242                        Some(TokenTree::Punct(punct)) => punct.as_char() == '!',
243                        _ => false,
244                    }
245                {
246                    tokens.next().unwrap(); // `!`
247                    let expect_group = tokens.next();
248                    let parenthesized = match &expect_group {
249                        Some(TokenTree::Group(group))
250                            if group.delimiter() == Delimiter::Parenthesis =>
251                        {
252                            group
253                        }
254                        Some(wrong) => return Err(Error::new(wrong.span(), "expected `(`")),
255                        None => return Err(Error::new(scope, "expected `(` after `env!`")),
256                    };
257                    let mut inner = parenthesized.stream().into_iter();
258                    let lit = match inner.next() {
259                        Some(TokenTree::Literal(lit)) => lit,
260                        Some(wrong) => {
261                            return Err(Error::new(wrong.span(), "expected string literal"))
262                        }
263                        None => {
264                            return Err(Error::new2(
265                                ident.span(),
266                                parenthesized.span(),
267                                "expected string literal as argument to env! macro",
268                            ))
269                        }
270                    };
271                    let lit_string = lit.to_string();
272                    if lit_string.starts_with('"')
273                        && lit_string.ends_with('"')
274                        && lit_string.len() >= 2
275                    {
276                        // TODO: maybe handle escape sequences in the string if
277                        // someone has a use case.
278                        segments.push(Segment::Env(LitStr {
279                            value: lit_string[1..lit_string.len() - 1].to_owned(),
280                            span: lit.span(),
281                        }));
282                    } else {
283                        return Err(Error::new(lit.span(), "expected string literal"));
284                    }
285                    if let Some(unexpected) = inner.next() {
286                        return Err(Error::new(
287                            unexpected.span(),
288                            "unexpected token in env! macro",
289                        ));
290                    }
291                } else {
292                    segments.push(Segment::String(fragment));
293                }
294            }
295            TokenTree::Literal(lit) => {
296                let mut lit_string = lit.to_string();
297                if lit_string.contains(&['#', '\\', '.', '+'][..]) {
298                    return Err(Error::new(lit.span(), "unsupported literal"));
299                }
300                lit_string = lit_string
301                    .replace('"', "")
302                    .replace('\'', "")
303                    .replace('-', "_");
304                segments.push(Segment::String(lit_string));
305            }
306            TokenTree::Punct(punct) => match punct.as_char() {
307                '_' => segments.push(Segment::String("_".to_owned())),
308                '\'' => segments.push(Segment::Apostrophe(punct.span())),
309                ':' => {
310                    let colon = Colon { span: punct.span() };
311                    let ident = match tokens.next() {
312                        Some(TokenTree::Ident(ident)) => ident,
313                        wrong => {
314                            let span = wrong.as_ref().map_or(scope, TokenTree::span);
315                            return Err(Error::new(span, "expected identifier after `:`"));
316                        }
317                    };
318                    segments.push(Segment::Modifier(colon, ident));
319                }
320                _ => return Err(Error::new(punct.span(), "unexpected punct")),
321            },
322            TokenTree::Group(group) => {
323                if group.delimiter() == Delimiter::None {
324                    let mut inner = group.stream().into_iter().peekable();
325                    let nested = parse_segments(&mut inner, group.span())?;
326                    if let Some(unexpected) = inner.next() {
327                        return Err(Error::new(unexpected.span(), "unexpected token"));
328                    }
329                    segments.extend(nested);
330                } else {
331                    return Err(Error::new(group.span(), "unexpected token"));
332                }
333            }
334        }
335    }
336    Ok(segments)
337}
338
339fn paste_segments(span: Span, segments: &[Segment]) -> Result<TokenStream> {
340    let mut evaluated = Vec::new();
341    let mut is_lifetime = false;
342
343    for segment in segments {
344        match segment {
345            Segment::String(segment) => {
346                evaluated.push(segment.clone());
347            }
348            Segment::Apostrophe(span) => {
349                if is_lifetime {
350                    return Err(Error::new(*span, "unexpected lifetime"));
351                }
352                is_lifetime = true;
353            }
354            Segment::Env(var) => {
355                let resolved = match std::env::var(&var.value) {
356                    Ok(resolved) => resolved,
357                    Err(_) => {
358                        return Err(Error::new(
359                            var.span,
360                            &format!("no such env var: {:?}", var.value),
361                        ));
362                    }
363                };
364                let resolved = resolved.replace('-', "_");
365                evaluated.push(resolved);
366            }
367            Segment::Modifier(colon, ident) => {
368                let last = match evaluated.pop() {
369                    Some(last) => last,
370                    None => {
371                        return Err(Error::new2(colon.span, ident.span(), "unexpected modifier"))
372                    }
373                };
374                match ident.to_string().as_str() {
375                    "lower" => {
376                        evaluated.push(last.to_lowercase());
377                    }
378                    "upper" => {
379                        evaluated.push(last.to_uppercase());
380                    }
381                    "snake" => {
382                        let mut acc = String::new();
383                        let mut prev = '_';
384                        for ch in last.chars() {
385                            if ch.is_uppercase() && prev != '_' {
386                                acc.push('_');
387                            }
388                            acc.push(ch);
389                            prev = ch;
390                        }
391                        evaluated.push(acc.to_lowercase());
392                    }
393                    "camel" => {
394                        let mut acc = String::new();
395                        let mut prev = '_';
396                        for ch in last.chars() {
397                            if ch != '_' {
398                                if prev == '_' {
399                                    for chu in ch.to_uppercase() {
400                                        acc.push(chu);
401                                    }
402                                } else if prev.is_uppercase() {
403                                    for chl in ch.to_lowercase() {
404                                        acc.push(chl);
405                                    }
406                                } else {
407                                    acc.push(ch);
408                                }
409                            }
410                            prev = ch;
411                        }
412                        evaluated.push(acc);
413                    }
414                    _ => {
415                        return Err(Error::new2(
416                            colon.span,
417                            ident.span(),
418                            "unsupported modifier",
419                        ));
420                    }
421                }
422            }
423        }
424    }
425
426    let pasted = evaluated.into_iter().collect::<String>();
427    let ident = match panic::catch_unwind(|| Ident::new(&pasted, span)) {
428        Ok(ident) => TokenTree::Ident(ident),
429        Err(_) => {
430            return Err(Error::new(
431                span,
432                &format!("`{:?}` is not a valid identifier", pasted),
433            ));
434        }
435    };
436    let tokens = if is_lifetime {
437        let apostrophe = TokenTree::Punct(Punct::new('\'', Spacing::Joint));
438        vec![apostrophe, ident]
439    } else {
440        vec![ident]
441    };
442    Ok(TokenStream::from_iter(tokens))
443}