pastey/
lib.rs

1#![doc = include_str!("../README.md")]
2#![allow(
3    clippy::derive_partial_eq_without_eq,
4    clippy::doc_markdown,
5    clippy::match_same_arms,
6    clippy::module_name_repetitions,
7    clippy::needless_doctest_main,
8    clippy::too_many_lines
9)]
10
11extern crate proc_macro;
12
13mod attr;
14mod error;
15mod segment;
16
17use crate::attr::expand_attr;
18use crate::error::{Error, Result};
19use crate::segment::Segment;
20use proc_macro::{
21    Delimiter, Group, Ident, LexError, Literal, Punct, Spacing, Span, TokenStream, TokenTree,
22};
23use std::char;
24use std::iter;
25use std::panic;
26use std::str::FromStr;
27
28#[proc_macro]
29pub fn paste(input: TokenStream) -> TokenStream {
30    let mut contains_paste = false;
31    let flatten_single_interpolation = true;
32    match expand(
33        input.clone(),
34        &mut contains_paste,
35        flatten_single_interpolation,
36    ) {
37        Ok(expanded) => {
38            if contains_paste {
39                expanded
40            } else {
41                input
42            }
43        }
44        Err(err) => err.to_compile_error(),
45    }
46}
47
48#[doc(hidden)]
49#[proc_macro]
50pub fn item(input: TokenStream) -> TokenStream {
51    paste(input)
52}
53
54#[doc(hidden)]
55#[proc_macro]
56pub fn expr(input: TokenStream) -> TokenStream {
57    paste(input)
58}
59
60fn expand(
61    input: TokenStream,
62    contains_paste: &mut bool,
63    flatten_single_interpolation: bool,
64) -> Result<TokenStream> {
65    let mut expanded = TokenStream::new();
66    let mut lookbehind = Lookbehind::Other;
67    let mut prev_none_group = None::<Group>;
68    let mut tokens = input.into_iter().peekable();
69    loop {
70        let token = tokens.next();
71        if let Some(group) = prev_none_group.take() {
72            if match (&token, tokens.peek()) {
73                (Some(TokenTree::Punct(fst)), Some(TokenTree::Punct(snd))) => {
74                    fst.as_char() == ':' && snd.as_char() == ':' && fst.spacing() == Spacing::Joint
75                }
76                _ => false,
77            } {
78                expanded.extend(group.stream());
79                *contains_paste = true;
80            } else {
81                expanded.extend(iter::once(TokenTree::Group(group)));
82            }
83        }
84        match token {
85            Some(TokenTree::Group(group)) => {
86                let delimiter = group.delimiter();
87                let content = group.stream();
88                let span = group.span();
89                if delimiter == Delimiter::Bracket && is_paste_operation(&content) {
90                    let segments = parse_bracket_as_segments(content, span)?;
91                    let pasted = segment::paste(&segments)?;
92                    let tokens = pasted_to_tokens(pasted, span)?;
93                    expanded.extend(tokens);
94                    *contains_paste = true;
95                } else if flatten_single_interpolation
96                    && delimiter == Delimiter::None
97                    && is_single_interpolation_group(&content)
98                {
99                    expanded.extend(content);
100                    *contains_paste = true;
101                } else {
102                    let mut group_contains_paste = false;
103                    let is_attribute = delimiter == Delimiter::Bracket
104                        && (lookbehind == Lookbehind::Pound || lookbehind == Lookbehind::PoundBang);
105                    let mut nested = expand(
106                        content,
107                        &mut group_contains_paste,
108                        flatten_single_interpolation && !is_attribute,
109                    )?;
110                    if is_attribute {
111                        nested = expand_attr(nested, span, &mut group_contains_paste)?;
112                    }
113                    let group = if group_contains_paste {
114                        let mut group = Group::new(delimiter, nested);
115                        group.set_span(span);
116                        *contains_paste = true;
117                        group
118                    } else {
119                        group.clone()
120                    };
121                    if delimiter != Delimiter::None {
122                        expanded.extend(iter::once(TokenTree::Group(group)));
123                    } else if lookbehind == Lookbehind::DoubleColon {
124                        expanded.extend(group.stream());
125                        *contains_paste = true;
126                    } else {
127                        prev_none_group = Some(group);
128                    }
129                }
130                lookbehind = Lookbehind::Other;
131            }
132            Some(TokenTree::Punct(punct)) => {
133                lookbehind = match punct.as_char() {
134                    ':' if lookbehind == Lookbehind::JointColon => Lookbehind::DoubleColon,
135                    ':' if punct.spacing() == Spacing::Joint => Lookbehind::JointColon,
136                    '#' => Lookbehind::Pound,
137                    '!' if lookbehind == Lookbehind::Pound => Lookbehind::PoundBang,
138                    _ => Lookbehind::Other,
139                };
140                expanded.extend(iter::once(TokenTree::Punct(punct)));
141            }
142            Some(other) => {
143                lookbehind = Lookbehind::Other;
144                expanded.extend(iter::once(other));
145            }
146            None => return Ok(expanded),
147        }
148    }
149}
150
/// Summary of the token(s) `expand` processed just before the current one,
/// used to recognise `::` paths and `#[...]` / `#![...]` attributes.
#[derive(PartialEq)]
enum Lookbehind {
    /// A `:` with joint spacing — possibly the first half of `::`.
    JointColon,
    /// A complete `::` separator.
    DoubleColon,
    /// A `#`, which may introduce an outer attribute.
    Pound,
    /// The pair `#!`, which may introduce an inner attribute.
    PoundBang,
    /// Any other token, including groups.
    Other,
}
159
/// Returns whether `input` consists of exactly one identifier, one literal,
/// one lifetime (`'` followed by an identifier), or an `a::b::c` path of
/// identifiers. Such `None`-delimited groups are safe to flatten.
///
/// See https://github.com/dtolnay/paste/issues/26
fn is_single_interpolation_group(input: &TokenStream) -> bool {
    enum Step {
        Start,
        AfterIdent,
        AfterLiteral,
        AfterQuote,
        AfterLifetime,
        AfterFirstColon,
        AfterSecondColon,
    }

    let punct_is = |tt: &TokenTree, ch: char, spacing: Spacing| {
        matches!(tt, TokenTree::Punct(p) if p.as_char() == ch && p.spacing() == spacing)
    };

    let mut step = Step::Start;
    for tt in input.clone() {
        step = match step {
            Step::Start => match &tt {
                TokenTree::Ident(_) => Step::AfterIdent,
                TokenTree::Literal(_) => Step::AfterLiteral,
                TokenTree::Punct(p) if p.as_char() == '\'' => Step::AfterQuote,
                _ => return false,
            },
            Step::AfterQuote if matches!(tt, TokenTree::Ident(_)) => Step::AfterLifetime,
            // `::` arrives as a joint `:` followed by an alone `:`.
            Step::AfterIdent if punct_is(&tt, ':', Spacing::Joint) => Step::AfterFirstColon,
            Step::AfterFirstColon if punct_is(&tt, ':', Spacing::Alone) => {
                Step::AfterSecondColon
            }
            Step::AfterSecondColon if matches!(tt, TokenTree::Ident(_)) => Step::AfterIdent,
            _ => return false,
        };
    }

    matches!(
        step,
        Step::AfterIdent | Step::AfterLiteral | Step::AfterLifetime
    )
}
197
/// Returns whether `input` (the contents of a `[...]` group) has the shape
/// `< tokens... >`: a leading `<`, at least one token, then a `>` that is the
/// last token. The first `>` found after the `<` is taken as the closer.
fn is_paste_operation(input: &TokenStream) -> bool {
    let is_punct = |tt: &TokenTree, ch: char| match tt {
        TokenTree::Punct(punct) => punct.as_char() == ch,
        _ => false,
    };

    let mut rest = input.clone().into_iter();
    let opened = match rest.next() {
        Some(first) => is_punct(&first, '<'),
        None => false,
    };
    if !opened {
        return false;
    }

    // `position` consumes tokens up to and including the first `>`; its index
    // is the number of tokens between `<` and `>`.
    match rest.position(|tt| is_punct(&tt, '>')) {
        Some(inner_len) => inner_len > 0 && rest.next().is_none(),
        None => false,
    }
}
217
218fn parse_bracket_as_segments(input: TokenStream, scope: Span) -> Result<Vec<Segment>> {
219    let mut tokens = input.into_iter().peekable();
220
221    match &tokens.next() {
222        Some(TokenTree::Punct(punct)) if punct.as_char() == '<' => {}
223        Some(wrong) => return Err(Error::new(wrong.span(), "expected `<`")),
224        None => return Err(Error::new(scope, "expected `[< ... >]`")),
225    }
226
227    let mut segments = segment::parse(&mut tokens)?;
228
229    match &tokens.next() {
230        Some(TokenTree::Punct(punct)) if punct.as_char() == '>' => {}
231        Some(wrong) => return Err(Error::new(wrong.span(), "expected `>`")),
232        None => return Err(Error::new(scope, "expected `[< ... >]`")),
233    }
234
235    if let Some(unexpected) = tokens.next() {
236        return Err(Error::new(
237            unexpected.span(),
238            "unexpected input, expected `[< ... >]`",
239        ));
240    }
241
242    for segment in &mut segments {
243        if let Segment::String(string) = segment {
244            if string.value.starts_with("'\\u{") {
245                let hex = &string.value[4..string.value.len() - 2];
246                if let Ok(unsigned) = u32::from_str_radix(hex, 16) {
247                    if let Some(ch) = char::from_u32(unsigned) {
248                        string.value.clear();
249                        string.value.push(ch);
250                        continue;
251                    }
252                }
253            }
254            if string.value.contains(&['\\', '.', '+'][..])
255                || string.value.starts_with("b'")
256                || string.value.starts_with("b\"")
257                || string.value.starts_with("br\"")
258            {
259                return Err(Error::new(string.span, "unsupported literal"));
260            }
261            let mut range = 0..string.value.len();
262            if string.value.starts_with("r\"") {
263                range.start += 2;
264                range.end -= 1;
265            } else if string.value.starts_with(&['"', '\''][..]) {
266                range.start += 1;
267                range.end -= 1;
268            }
269            string.value = string.value[range].replace('-', "_");
270        }
271    }
272
273    Ok(segments)
274}
275
276fn pasted_to_tokens(mut pasted: String, span: Span) -> Result<TokenStream> {
277    let mut raw_mode = false;
278    let mut tokens = TokenStream::new();
279
280    if pasted.starts_with(|ch: char| ch.is_ascii_digit()) {
281        let literal = match panic::catch_unwind(|| Literal::from_str(&pasted)) {
282            Ok(Ok(literal)) => TokenTree::Literal(literal),
283            Ok(Err(LexError { .. })) | Err(_) => {
284                return Err(Error::new(
285                    span,
286                    &format!("`{:?}` is not a valid literal", pasted),
287                ));
288            }
289        };
290        tokens.extend(iter::once(literal));
291        return Ok(tokens);
292    }
293
294    if pasted.starts_with('\'') {
295        let mut apostrophe = TokenTree::Punct(Punct::new('\'', Spacing::Joint));
296        apostrophe.set_span(span);
297        tokens.extend(iter::once(apostrophe));
298        pasted.remove(0);
299    }
300
301    if pasted.starts_with("r#") {
302        raw_mode = true;
303    }
304
305    let ident = match panic::catch_unwind(|| {
306        if raw_mode {
307            let mut spasted = pasted.clone();
308            spasted.remove(0);
309            spasted.remove(0);
310            Ident::new_raw(&spasted, span)
311        } else {
312            Ident::new(&pasted, span)
313        }
314    }) {
315        Ok(ident) => TokenTree::Ident(ident),
316        Err(_) => {
317            return Err(Error::new(
318                span,
319                &format!("`{:?}` is not a valid identifier", pasted),
320            ));
321        }
322    };
323
324    tokens.extend(iter::once(ident));
325    Ok(tokens)
326}