thiserror_impl/
fmt.rs

1use crate::ast::{ContainerKind, Field};
2use crate::attr::{Display, Trait};
3use crate::private;
4use crate::scan_expr::scan_expr;
5use crate::unraw::{IdentUnraw, MemberUnraw};
6use proc_macro2::{Delimiter, TokenStream, TokenTree};
7use quote::{format_ident, quote, quote_spanned, ToTokens as _};
8use std::collections::{BTreeSet, HashMap};
9use std::iter;
10use syn::ext::IdentExt;
11use syn::parse::discouraged::Speculative;
12use syn::parse::{Error, ParseStream, Parser, Result};
13use syn::{Expr, Ident, Index, LitStr, Token};
14
15impl Display<'_> {
16    pub fn expand_shorthand(&mut self, fields: &[Field], container: ContainerKind) -> Result<()> {
17        let raw_args = self.args.clone();
18        let FmtArguments {
19            named: user_named_args,
20            first_unnamed,
21        } = explicit_named_args.parse2(raw_args).unwrap();
22
23        let mut member_index = HashMap::new();
24        let mut extra_positional_arguments_allowed = true;
25        for (i, field) in fields.iter().enumerate() {
26            member_index.insert(&field.member, i);
27            extra_positional_arguments_allowed &= matches!(&field.member, MemberUnraw::Named(_));
28        }
29
30        let span = self.fmt.span();
31        let fmt = self.fmt.value();
32        let mut read = fmt.as_str();
33        let mut out = String::new();
34        let mut has_bonus_display = false;
35        let mut infinite_recursive = false;
36        let mut implied_bounds = BTreeSet::new();
37        let mut bindings = Vec::new();
38        let mut macro_named_args = BTreeSet::new();
39
40        self.requires_fmt_machinery = self.requires_fmt_machinery || fmt.contains('}');
41
42        while let Some(brace) = read.find('{') {
43            self.requires_fmt_machinery = true;
44            out += &read[..brace + 1];
45            read = &read[brace + 1..];
46            if read.starts_with('{') {
47                out.push('{');
48                read = &read[1..];
49                continue;
50            }
51            let next = match read.chars().next() {
52                Some(next) => next,
53                None => return Ok(()),
54            };
55            let member = match next {
56                '0'..='9' => {
57                    let int = take_int(&mut read);
58                    if !extra_positional_arguments_allowed {
59                        if let Some(first_unnamed) = &first_unnamed {
60                            let msg = format!("ambiguous reference to positional arguments by number in a {container}; change this to a named argument");
61                            return Err(Error::new_spanned(first_unnamed, msg));
62                        }
63                    }
64                    match int.parse::<u32>() {
65                        Ok(index) => MemberUnraw::Unnamed(Index { index, span }),
66                        Err(_) => return Ok(()),
67                    }
68                }
69                'a'..='z' | 'A'..='Z' | '_' => {
70                    if read.starts_with("r#") {
71                        continue;
72                    }
73                    let repr = take_ident(&mut read);
74                    if repr == "_" {
75                        // Invalid. Let rustc produce the diagnostic.
76                        out += repr;
77                        continue;
78                    }
79                    let ident = IdentUnraw::new(Ident::new(repr, span));
80                    if user_named_args.contains(&ident) {
81                        // Refers to a named argument written by the user, not to field.
82                        out += repr;
83                        continue;
84                    }
85                    MemberUnraw::Named(ident)
86                }
87                _ => continue,
88            };
89            let end_spec = match read.find('}') {
90                Some(end_spec) => end_spec,
91                None => return Ok(()),
92            };
93            let mut bonus_display = false;
94            let bound = match read[..end_spec].chars().next_back() {
95                Some('?') => Trait::Debug,
96                Some('o') => Trait::Octal,
97                Some('x') => Trait::LowerHex,
98                Some('X') => Trait::UpperHex,
99                Some('p') => Trait::Pointer,
100                Some('b') => Trait::Binary,
101                Some('e') => Trait::LowerExp,
102                Some('E') => Trait::UpperExp,
103                Some(_) => Trait::Display,
104                None => {
105                    bonus_display = true;
106                    has_bonus_display = true;
107                    Trait::Display
108                }
109            };
110            infinite_recursive |= member == *"self" && bound == Trait::Display;
111            let field = match member_index.get(&member) {
112                Some(&field) => field,
113                None => {
114                    out += &member.to_string();
115                    continue;
116                }
117            };
118            implied_bounds.insert((field, bound));
119            let formatvar_prefix = if bonus_display {
120                "__display"
121            } else if bound == Trait::Pointer {
122                "__pointer"
123            } else {
124                "__field"
125            };
126            let mut formatvar = IdentUnraw::new(match &member {
127                MemberUnraw::Unnamed(index) => format_ident!("{}{}", formatvar_prefix, index),
128                MemberUnraw::Named(ident) => {
129                    format_ident!("{}_{}", formatvar_prefix, ident.to_string())
130                }
131            });
132            while user_named_args.contains(&formatvar) {
133                formatvar = IdentUnraw::new(format_ident!("_{}", formatvar.to_string()));
134            }
135            formatvar.set_span(span);
136            out += &formatvar.to_string();
137            if !macro_named_args.insert(formatvar.clone()) {
138                // Already added to bindings by a previous use.
139                continue;
140            }
141            let mut binding_value = match &member {
142                MemberUnraw::Unnamed(index) => format_ident!("_{}", index),
143                MemberUnraw::Named(ident) => ident.to_local(),
144            };
145            binding_value.set_span(span.resolved_at(fields[field].member.span()));
146            let wrapped_binding_value = if bonus_display {
147                quote_spanned!(span=> #binding_value.as_display())
148            } else if bound == Trait::Pointer {
149                quote!(::thiserror::#private::Var(#binding_value))
150            } else {
151                binding_value.into_token_stream()
152            };
153            bindings.push((formatvar.to_local(), wrapped_binding_value));
154        }
155
156        out += read;
157        self.fmt = LitStr::new(&out, self.fmt.span());
158        self.has_bonus_display = has_bonus_display;
159        self.infinite_recursive = infinite_recursive;
160        self.implied_bounds = implied_bounds;
161        self.bindings = bindings;
162        Ok(())
163    }
164}
165
166struct FmtArguments {
167    named: BTreeSet<IdentUnraw>,
168    first_unnamed: Option<TokenStream>,
169}
170
171#[allow(clippy::unnecessary_wraps)]
172fn explicit_named_args(input: ParseStream) -> Result<FmtArguments> {
173    let ahead = input.fork();
174    if let Ok(set) = try_explicit_named_args(&ahead) {
175        input.advance_to(&ahead);
176        return Ok(set);
177    }
178
179    let ahead = input.fork();
180    if let Ok(set) = fallback_explicit_named_args(&ahead) {
181        input.advance_to(&ahead);
182        return Ok(set);
183    }
184
185    input.parse::<TokenStream>().unwrap();
186    Ok(FmtArguments {
187        named: BTreeSet::new(),
188        first_unnamed: None,
189    })
190}
191
192fn try_explicit_named_args(input: ParseStream) -> Result<FmtArguments> {
193    let mut syn_full = None;
194    let mut args = FmtArguments {
195        named: BTreeSet::new(),
196        first_unnamed: None,
197    };
198
199    while !input.is_empty() {
200        input.parse::<Token![,]>()?;
201        if input.is_empty() {
202            break;
203        }
204
205        let mut begin_unnamed = None;
206        if input.peek(Ident::peek_any) && input.peek2(Token![=]) && !input.peek2(Token![==]) {
207            let ident: IdentUnraw = input.parse()?;
208            input.parse::<Token![=]>()?;
209            args.named.insert(ident);
210        } else {
211            begin_unnamed = Some(input.fork());
212        }
213
214        let ahead = input.fork();
215        if *syn_full.get_or_insert_with(is_syn_full) && ahead.parse::<Expr>().is_ok() {
216            input.advance_to(&ahead);
217        } else {
218            scan_expr(input)?;
219        }
220
221        if let Some(begin_unnamed) = begin_unnamed {
222            if args.first_unnamed.is_none() {
223                args.first_unnamed = Some(between(&begin_unnamed, input));
224            }
225        }
226    }
227
228    Ok(args)
229}
230
231fn fallback_explicit_named_args(input: ParseStream) -> Result<FmtArguments> {
232    let mut args = FmtArguments {
233        named: BTreeSet::new(),
234        first_unnamed: None,
235    };
236
237    while !input.is_empty() {
238        if input.peek(Token![,])
239            && input.peek2(Ident::peek_any)
240            && input.peek3(Token![=])
241            && !input.peek3(Token![==])
242        {
243            input.parse::<Token![,]>()?;
244            let ident: IdentUnraw = input.parse()?;
245            input.parse::<Token![=]>()?;
246            args.named.insert(ident);
247        } else {
248            input.parse::<TokenTree>()?;
249        }
250    }
251
252    Ok(args)
253}
254
255fn is_syn_full() -> bool {
256    // Expr::Block contains syn::Block which contains Vec<syn::Stmt>. In the
257    // current version of Syn, syn::Stmt is exhaustive and could only plausibly
258    // represent `trait Trait {}` in Stmt::Item which contains syn::Item. Most
259    // of the point of syn's non-"full" mode is to avoid compiling Item and the
260    // entire expansive syntax tree it comprises. So the following expression
261    // being parsed to Expr::Block is a reliable indication that "full" is
262    // enabled.
263    let test = quote!({
264        trait Trait {}
265    });
266    match syn::parse2(test) {
267        Ok(Expr::Verbatim(_)) | Err(_) => false,
268        Ok(Expr::Block(_)) => true,
269        Ok(_) => unreachable!(),
270    }
271}
272
/// Splits the leading run of ASCII digits off `read`, advancing `read` past it.
///
/// Returns the digits, which are empty if `read` does not start with one.
fn take_int<'a>(read: &mut &'a str) -> &'a str {
    let digits_len = read
        .find(|ch: char| !ch.is_ascii_digit())
        .unwrap_or(read.len());
    let (int, rest) = read.split_at(digits_len);
    *read = rest;
    int
}
285
/// Splits the leading run of ASCII letters, digits and `_` off `read`,
/// advancing `read` past it, and returns that run (possibly empty).
fn take_ident<'a>(read: &mut &'a str) -> &'a str {
    // Only ASCII bytes are accepted, so the byte count is also a char boundary.
    let ident_len = read
        .bytes()
        .take_while(|&byte| byte.is_ascii_alphanumeric() || byte == b'_')
        .count();
    let (ident, rest) = read.split_at(ident_len);
    *read = rest;
    ident
}
298
299fn between<'a>(begin: ParseStream<'a>, end: ParseStream<'a>) -> TokenStream {
300    let end = end.cursor();
301    let mut cursor = begin.cursor();
302    let mut tokens = TokenStream::new();
303
304    while cursor < end {
305        let (tt, next) = cursor.token_tree().unwrap();
306
307        if end < next {
308            if let Some((inside, _span, _after)) = cursor.group(Delimiter::None) {
309                cursor = inside;
310                continue;
311            }
312            if tokens.is_empty() {
313                tokens.extend(iter::once(tt));
314            }
315            break;
316        }
317
318        tokens.extend(iter::once(tt));
319        cursor = next;
320    }
321
322    tokens
323}