1#![doc = include_str!("../README.md")]
2#![allow(
3 clippy::derive_partial_eq_without_eq,
4 clippy::doc_markdown,
5 clippy::match_same_arms,
6 clippy::module_name_repetitions,
7 clippy::needless_doctest_main,
8 clippy::too_many_lines
9)]
10
11extern crate proc_macro;
12
13mod attr;
14mod error;
15mod segment;
16
17use crate::attr::expand_attr;
18use crate::error::{Error, Result};
19use crate::segment::Segment;
20use proc_macro::{
21 Delimiter, Group, Ident, LexError, Literal, Punct, Spacing, Span, TokenStream, TokenTree,
22};
23use std::char;
24use std::iter;
25use std::panic;
26use std::str::FromStr;
27
28#[proc_macro]
29pub fn paste(input: TokenStream) -> TokenStream {
30 let mut contains_paste = false;
31 let flatten_single_interpolation = true;
32 match expand(
33 input.clone(),
34 &mut contains_paste,
35 flatten_single_interpolation,
36 ) {
37 Ok(expanded) => {
38 if contains_paste {
39 expanded
40 } else {
41 input
42 }
43 }
44 Err(err) => err.to_compile_error(),
45 }
46}
47
48#[doc(hidden)]
49#[proc_macro]
50pub fn item(input: TokenStream) -> TokenStream {
51 paste(input)
52}
53
54#[doc(hidden)]
55#[proc_macro]
56pub fn expr(input: TokenStream) -> TokenStream {
57 paste(input)
58}
59
/// Recursively expands every `[< ... >]` paste group found in `input`.
///
/// * `contains_paste` is set to `true` whenever this call (including any
///   nested group) pasted, flattened or rebuilt something. It is only ever
///   written with `true`, so callers must initialize it themselves.
/// * `flatten_single_interpolation` allows a `Delimiter::None` group whose
///   contents satisfy `is_single_interpolation_group` to be spliced in
///   without its invisible delimiters. It is forced off when recursing into
///   the bracket group of an attribute (`#[...]` / `#![...]`).
///
/// # Errors
///
/// Propagates any error from parsing or pasting a `[< ... >]` group, or from
/// `expand_attr`.
fn expand(
    input: TokenStream,
    contains_paste: &mut bool,
    flatten_single_interpolation: bool,
) -> Result<TokenStream> {
    let mut expanded = TokenStream::new();
    // Punctuation context preceding the current token (`::`, `#`, `#!`).
    let mut lookbehind = Lookbehind::Other;
    // A `Delimiter::None` group held back for one iteration, so we can see
    // whether a `::` follows it before deciding how to emit it.
    let mut prev_none_group = None::<Group>;
    let mut tokens = input.into_iter().peekable();
    loop {
        let token = tokens.next();
        if let Some(group) = prev_none_group.take() {
            // Deferred None-group followed by `::` (a joint `:` then `:`):
            // splice its tokens in directly so they sit right before the
            // `::`. Otherwise emit the group unchanged.
            if match (&token, tokens.peek()) {
                (Some(TokenTree::Punct(fst)), Some(TokenTree::Punct(snd))) => {
                    fst.as_char() == ':' && snd.as_char() == ':' && fst.spacing() == Spacing::Joint
                }
                _ => false,
            } {
                expanded.extend(group.stream());
                *contains_paste = true;
            } else {
                expanded.extend(iter::once(TokenTree::Group(group)));
            }
        }
        match token {
            Some(TokenTree::Group(group)) => {
                let delimiter = group.delimiter();
                let content = group.stream();
                let span = group.span();
                if delimiter == Delimiter::Bracket && is_paste_operation(&content) {
                    // `[< ... >]`: concatenate the segments into new tokens.
                    let segments = parse_bracket_as_segments(content, span)?;
                    let pasted = segment::paste(&segments)?;
                    let tokens = pasted_to_tokens(pasted, span)?;
                    expanded.extend(tokens);
                    *contains_paste = true;
                } else if flatten_single_interpolation
                    && delimiter == Delimiter::None
                    && is_single_interpolation_group(&content)
                {
                    // Invisible group around a lone ident / path / literal /
                    // lifetime: drop the invisible delimiters.
                    expanded.extend(content);
                    *contains_paste = true;
                } else {
                    let mut group_contains_paste = false;
                    // A bracket group right after `#` or `#!` is an
                    // attribute body.
                    let is_attribute = delimiter == Delimiter::Bracket
                        && (lookbehind == Lookbehind::Pound || lookbehind == Lookbehind::PoundBang);
                    let mut nested = expand(
                        content,
                        &mut group_contains_paste,
                        flatten_single_interpolation && !is_attribute,
                    )?;
                    if is_attribute {
                        nested = expand_attr(nested, span, &mut group_contains_paste)?;
                    }
                    // Only rebuild the group when something inside changed;
                    // otherwise reuse the original group as-is.
                    let group = if group_contains_paste {
                        let mut group = Group::new(delimiter, nested);
                        group.set_span(span);
                        *contains_paste = true;
                        group
                    } else {
                        group.clone()
                    };
                    if delimiter != Delimiter::None {
                        expanded.extend(iter::once(TokenTree::Group(group)));
                    } else if lookbehind == Lookbehind::DoubleColon {
                        // None-group directly after `::`: splice its tokens in.
                        expanded.extend(group.stream());
                        *contains_paste = true;
                    } else {
                        // Decide on the next iteration, once we can see
                        // whether a `::` follows.
                        prev_none_group = Some(group);
                    }
                }
                lookbehind = Lookbehind::Other;
            }
            Some(TokenTree::Punct(punct)) => {
                lookbehind = match punct.as_char() {
                    ':' if lookbehind == Lookbehind::JointColon => Lookbehind::DoubleColon,
                    ':' if punct.spacing() == Spacing::Joint => Lookbehind::JointColon,
                    '#' => Lookbehind::Pound,
                    '!' if lookbehind == Lookbehind::Pound => Lookbehind::PoundBang,
                    _ => Lookbehind::Other,
                };
                expanded.extend(iter::once(TokenTree::Punct(punct)));
            }
            Some(other) => {
                lookbehind = Lookbehind::Other;
                expanded.extend(iter::once(other));
            }
            None => return Ok(expanded),
        }
    }
}
150
/// The punctuation context immediately preceding the token that `expand` is
/// currently looking at.
#[derive(PartialEq)]
enum Lookbehind {
    /// No interesting punctuation precedes the token.
    Other,
    /// A `#`, possibly the start of an outer attribute `#[...]`.
    Pound,
    /// `#!`, possibly the start of an inner attribute `#![...]`.
    PoundBang,
    /// A `:` with joint spacing, possibly the first half of `::`.
    JointColon,
    /// A complete `::` path separator.
    DoubleColon,
}
159
/// Returns whether `input` is exactly one of:
///
/// * an identifier, or a `::`-separated path of identifiers (`a::b::c`),
/// * a single literal,
/// * a lifetime (`'a`).
///
/// Any other shape, including an empty stream, yields `false`.
fn is_single_interpolation_group(input: &TokenStream) -> bool {
    // Position within the tiny grammar
    //     Ident ( '::' Ident )*  |  Literal  |  '\'' Ident
    #[derive(PartialEq)]
    enum State {
        Init,
        Ident,
        Literal,
        Apostrophe,
        Lifetime,
        Colon1,
        Colon2,
    }

    // Advances the automaton by one token; `None` means the input can no
    // longer match.
    fn step(state: State, tt: TokenTree) -> Option<State> {
        match tt {
            TokenTree::Ident(_) => match state {
                State::Init | State::Colon2 => Some(State::Ident),
                State::Apostrophe => Some(State::Lifetime),
                _ => None,
            },
            TokenTree::Literal(_) => {
                if state == State::Init {
                    Some(State::Literal)
                } else {
                    None
                }
            }
            TokenTree::Punct(punct) => match (state, punct.as_char(), punct.spacing()) {
                (State::Init, '\'', _) => Some(State::Apostrophe),
                (State::Ident, ':', Spacing::Joint) => Some(State::Colon1),
                (State::Colon1, ':', Spacing::Alone) => Some(State::Colon2),
                _ => None,
            },
            _ => None,
        }
    }

    input
        .clone()
        .into_iter()
        .try_fold(State::Init, step)
        .map_or(false, |end| {
            end == State::Ident || end == State::Literal || end == State::Lifetime
        })
}
197
/// Returns `true` when `input` has the shape `< tokens... >`: a leading `<`,
/// at least one token, and then a `>` that is the very last token.
///
/// Scanning stops at the first `>` after the opening `<`. Any token after it
/// makes the result `false`.
fn is_paste_operation(input: &TokenStream) -> bool {
    let mut iter = input.clone().into_iter();

    let opens_with_lt = match iter.next() {
        Some(TokenTree::Punct(ref open)) => open.as_char() == '<',
        _ => false,
    };
    if !opens_with_lt {
        return false;
    }

    let mut saw_inner = false;
    while let Some(tt) = iter.next() {
        let is_close = match tt {
            TokenTree::Punct(ref close) => close.as_char() == '>',
            _ => false,
        };
        if is_close {
            return saw_inner && iter.next().is_none();
        }
        saw_inner = true;
    }
    false
}
217
218fn parse_bracket_as_segments(input: TokenStream, scope: Span) -> Result<Vec<Segment>> {
219 let mut tokens = input.into_iter().peekable();
220
221 match &tokens.next() {
222 Some(TokenTree::Punct(punct)) if punct.as_char() == '<' => {}
223 Some(wrong) => return Err(Error::new(wrong.span(), "expected `<`")),
224 None => return Err(Error::new(scope, "expected `[< ... >]`")),
225 }
226
227 let mut segments = segment::parse(&mut tokens)?;
228
229 match &tokens.next() {
230 Some(TokenTree::Punct(punct)) if punct.as_char() == '>' => {}
231 Some(wrong) => return Err(Error::new(wrong.span(), "expected `>`")),
232 None => return Err(Error::new(scope, "expected `[< ... >]`")),
233 }
234
235 if let Some(unexpected) = tokens.next() {
236 return Err(Error::new(
237 unexpected.span(),
238 "unexpected input, expected `[< ... >]`",
239 ));
240 }
241
242 for segment in &mut segments {
243 if let Segment::String(string) = segment {
244 if string.value.starts_with("'\\u{") {
245 let hex = &string.value[4..string.value.len() - 2];
246 if let Ok(unsigned) = u32::from_str_radix(hex, 16) {
247 if let Some(ch) = char::from_u32(unsigned) {
248 string.value.clear();
249 string.value.push(ch);
250 continue;
251 }
252 }
253 }
254 if string.value.contains(&['\\', '.', '+'][..])
255 || string.value.starts_with("b'")
256 || string.value.starts_with("b\"")
257 || string.value.starts_with("br\"")
258 {
259 return Err(Error::new(string.span, "unsupported literal"));
260 }
261 let mut range = 0..string.value.len();
262 if string.value.starts_with("r\"") {
263 range.start += 2;
264 range.end -= 1;
265 } else if string.value.starts_with(&['"', '\''][..]) {
266 range.start += 1;
267 range.end -= 1;
268 }
269 string.value = string.value[range].replace('-', "_");
270 }
271 }
272
273 Ok(segments)
274}
275
276fn pasted_to_tokens(mut pasted: String, span: Span) -> Result<TokenStream> {
277 let mut raw_mode = false;
278 let mut tokens = TokenStream::new();
279
280 if pasted.starts_with(|ch: char| ch.is_ascii_digit()) {
281 let literal = match panic::catch_unwind(|| Literal::from_str(&pasted)) {
282 Ok(Ok(literal)) => TokenTree::Literal(literal),
283 Ok(Err(LexError { .. })) | Err(_) => {
284 return Err(Error::new(
285 span,
286 &format!("`{:?}` is not a valid literal", pasted),
287 ));
288 }
289 };
290 tokens.extend(iter::once(literal));
291 return Ok(tokens);
292 }
293
294 if pasted.starts_with('\'') {
295 let mut apostrophe = TokenTree::Punct(Punct::new('\'', Spacing::Joint));
296 apostrophe.set_span(span);
297 tokens.extend(iter::once(apostrophe));
298 pasted.remove(0);
299 }
300
301 if pasted.starts_with("r#") {
302 raw_mode = true;
303 }
304
305 let ident = match panic::catch_unwind(|| {
306 if raw_mode {
307 let mut spasted = pasted.clone();
308 spasted.remove(0);
309 spasted.remove(0);
310 Ident::new_raw(&spasted, span)
311 } else {
312 Ident::new(&pasted, span)
313 }
314 }) {
315 Ok(ident) => TokenTree::Ident(ident),
316 Err(_) => {
317 return Err(Error::new(
318 span,
319 &format!("`{:?}` is not a valid identifier", pasted),
320 ));
321 }
322 };
323
324 tokens.extend(iter::once(ident));
325 Ok(tokens)
326}