thiserror_impl/
expand.rs

1use crate::ast::{Enum, Field, Input, Struct};
2use crate::attr::Trait;
3use crate::fallback;
4use crate::generics::InferredBounds;
5use crate::private;
6use crate::unraw::MemberUnraw;
7use proc_macro2::{Ident, Span, TokenStream};
8use quote::{format_ident, quote, quote_spanned, ToTokens};
9use std::collections::BTreeSet as Set;
10use syn::{DeriveInput, GenericArgument, PathArguments, Result, Token, Type};
11
12pub fn derive(input: &DeriveInput) -> TokenStream {
13    match try_expand(input) {
14        Ok(expanded) => expanded,
15        // If there are invalid attributes in the input, expand to an Error impl
16        // anyway to minimize spurious secondary errors in other code that uses
17        // this type as an Error.
18        Err(error) => fallback::expand(input, error),
19    }
20}
21
22fn try_expand(input: &DeriveInput) -> Result<TokenStream> {
23    let input = Input::from_syn(input)?;
24    input.validate()?;
25    Ok(match input {
26        Input::Struct(input) => impl_struct(input),
27        Input::Enum(input) => impl_enum(input),
28    })
29}
30
31fn impl_struct(input: Struct) -> TokenStream {
32    let ty = call_site_ident(&input.ident);
33    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
34    let mut error_inferred_bounds = InferredBounds::new();
35
36    let source_body = if let Some(transparent_attr) = &input.attrs.transparent {
37        let only_field = &input.fields[0];
38        if only_field.contains_generic {
39            error_inferred_bounds.insert(only_field.ty, quote!(::thiserror::#private::Error));
40        }
41        let member = &only_field.member;
42        Some(quote_spanned! {transparent_attr.span=>
43            ::thiserror::#private::Error::source(self.#member.as_dyn_error())
44        })
45    } else if let Some(source_field) = input.source_field() {
46        let source = &source_field.member;
47        if source_field.contains_generic {
48            let ty = unoptional_type(source_field.ty);
49            error_inferred_bounds.insert(ty, quote!(::thiserror::#private::Error + 'static));
50        }
51        let asref = if type_is_option(source_field.ty) {
52            Some(quote_spanned!(source.span()=> .as_ref()?))
53        } else {
54            None
55        };
56        let dyn_error = quote_spanned! {source_field.source_span()=>
57            self.#source #asref.as_dyn_error()
58        };
59        Some(quote! {
60            ::core::option::Option::Some(#dyn_error)
61        })
62    } else {
63        None
64    };
65    let source_method = source_body.map(|body| {
66        quote! {
67            fn source(&self) -> ::core::option::Option<&(dyn ::thiserror::#private::Error + 'static)> {
68                use ::thiserror::#private::AsDynError as _;
69                #body
70            }
71        }
72    });
73
74    let provide_method = input.backtrace_field().map(|backtrace_field| {
75        let request = quote!(request);
76        let backtrace = &backtrace_field.member;
77        let body = if let Some(source_field) = input.source_field() {
78            let source = &source_field.member;
79            let source_provide = if type_is_option(source_field.ty) {
80                quote_spanned! {source.span()=>
81                    if let ::core::option::Option::Some(source) = &self.#source {
82                        source.thiserror_provide(#request);
83                    }
84                }
85            } else {
86                quote_spanned! {source.span()=>
87                    self.#source.thiserror_provide(#request);
88                }
89            };
90            let self_provide = if source == backtrace {
91                None
92            } else if type_is_option(backtrace_field.ty) {
93                Some(quote! {
94                    if let ::core::option::Option::Some(backtrace) = &self.#backtrace {
95                        #request.provide_ref::<::thiserror::#private::Backtrace>(backtrace);
96                    }
97                })
98            } else {
99                Some(quote! {
100                    #request.provide_ref::<::thiserror::#private::Backtrace>(&self.#backtrace);
101                })
102            };
103            quote! {
104                use ::thiserror::#private::ThiserrorProvide as _;
105                #source_provide
106                #self_provide
107            }
108        } else if type_is_option(backtrace_field.ty) {
109            quote! {
110                if let ::core::option::Option::Some(backtrace) = &self.#backtrace {
111                    #request.provide_ref::<::thiserror::#private::Backtrace>(backtrace);
112                }
113            }
114        } else {
115            quote! {
116                #request.provide_ref::<::thiserror::#private::Backtrace>(&self.#backtrace);
117            }
118        };
119        quote! {
120            fn provide<'_request>(&'_request self, #request: &mut ::core::error::Request<'_request>) {
121                #body
122            }
123        }
124    });
125
126    let mut display_implied_bounds = Set::new();
127    let display_body = if input.attrs.transparent.is_some() {
128        let only_field = &input.fields[0].member;
129        display_implied_bounds.insert((0, Trait::Display));
130        Some(quote! {
131            ::core::fmt::Display::fmt(&self.#only_field, __formatter)
132        })
133    } else if let Some(display) = &input.attrs.display {
134        display_implied_bounds.clone_from(&display.implied_bounds);
135        let use_as_display = use_as_display(display.has_bonus_display);
136        let pat = fields_pat(&input.fields);
137        Some(quote! {
138            #use_as_display
139            #[allow(unused_variables, deprecated)]
140            let Self #pat = self;
141            #display
142        })
143    } else {
144        None
145    };
146    let display_impl = display_body.map(|body| {
147        let mut display_inferred_bounds = InferredBounds::new();
148        for (field, bound) in display_implied_bounds {
149            let field = &input.fields[field];
150            if field.contains_generic {
151                display_inferred_bounds.insert(field.ty, bound);
152            }
153        }
154        let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
155        quote! {
156            #[allow(unused_qualifications)]
157            #[automatically_derived]
158            impl #impl_generics ::core::fmt::Display for #ty #ty_generics #display_where_clause {
159                #[allow(clippy::used_underscore_binding)]
160                fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
161                    #body
162                }
163            }
164        }
165    });
166
167    let from_impl = input.from_field().map(|from_field| {
168        let span = from_field.attrs.from.unwrap().span;
169        let backtrace_field = input.distinct_backtrace_field();
170        let from = unoptional_type(from_field.ty);
171        let source_var = Ident::new("source", span);
172        let body = from_initializer(from_field, backtrace_field, &source_var);
173        let from_function = quote! {
174            fn from(#source_var: #from) -> Self {
175                #ty #body
176            }
177        };
178        let from_impl = quote_spanned! {span=>
179            #[automatically_derived]
180            impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause {
181                #from_function
182            }
183        };
184        Some(quote! {
185            #[allow(
186                deprecated,
187                unused_qualifications,
188                clippy::elidable_lifetime_names,
189                clippy::needless_lifetimes,
190            )]
191            #from_impl
192        })
193    });
194
195    if input.generics.type_params().next().is_some() {
196        let self_token = <Token![Self]>::default();
197        error_inferred_bounds.insert(self_token, Trait::Debug);
198        error_inferred_bounds.insert(self_token, Trait::Display);
199    }
200    let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);
201
202    quote! {
203        #[allow(unused_qualifications)]
204        #[automatically_derived]
205        impl #impl_generics ::thiserror::#private::Error for #ty #ty_generics #error_where_clause {
206            #source_method
207            #provide_method
208        }
209        #display_impl
210        #from_impl
211    }
212}
213
/// Expands `#[derive(Error)]` for an enum.
///
/// Always emits an `Error` impl. Each generated method matches on `self` with
/// one arm per variant:
/// - `Error::source`, only when `input.has_source()`;
/// - `Error::provide`, only when `input.has_backtrace()`;
/// - a `Display` impl, only when `input.has_display()`;
/// - one `From<T>` impl per variant that has a `#[from]` field.
///
/// Bounds required by generic field types are collected in `InferredBounds`
/// and merged into the corresponding impl's where-clause.
fn impl_enum(input: Enum) -> TokenStream {
    let ty = call_site_ident(&input.ident);
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
    let mut error_inferred_bounds = InferredBounds::new();

    let source_method = if input.has_source() {
        let arms = input.variants.iter().map(|variant| {
            let ident = &variant.ident;
            if let Some(transparent_attr) = &variant.attrs.transparent {
                // Transparent variant: delegate to the wrapped field's own
                // `source()` rather than returning the field itself.
                let only_field = &variant.fields[0];
                if only_field.contains_generic {
                    error_inferred_bounds.insert(only_field.ty, quote!(::thiserror::#private::Error));
                }
                let member = &only_field.member;
                let source = quote_spanned! {transparent_attr.span=>
                    ::thiserror::#private::Error::source(transparent.as_dyn_error())
                };
                quote! {
                    #ty::#ident {#member: transparent} => #source,
                }
            } else if let Some(source_field) = variant.source_field() {
                let source = &source_field.member;
                if source_field.contains_generic {
                    // For an `Option<T>` source field the bound belongs on `T`.
                    let ty = unoptional_type(source_field.ty);
                    error_inferred_bounds.insert(ty, quote!(::thiserror::#private::Error + 'static));
                }
                // An empty `Option` source makes `source()` return `None` via `?`.
                let asref = if type_is_option(source_field.ty) {
                    Some(quote_spanned!(source.span()=> .as_ref()?))
                } else {
                    None
                };
                let varsource = quote!(source);
                let dyn_error = quote_spanned! {source_field.source_span()=>
                    #varsource #asref.as_dyn_error()
                };
                quote! {
                    #ty::#ident {#source: #varsource, ..} => ::core::option::Option::Some(#dyn_error),
                }
            } else {
                // Variant without a source.
                quote! {
                    #ty::#ident {..} => ::core::option::Option::None,
                }
            }
        });
        Some(quote! {
            fn source(&self) -> ::core::option::Option<&(dyn ::thiserror::#private::Error + 'static)> {
                use ::thiserror::#private::AsDynError as _;
                #[allow(deprecated)]
                match self {
                    #(#arms)*
                }
            }
        })
    } else {
        None
    };

    let provide_method = if input.has_backtrace() {
        let request = quote!(request);
        let arms = input.variants.iter().map(|variant| {
            let ident = &variant.ident;
            match (variant.backtrace_field(), variant.source_field()) {
                // Source plus a backtrace field with no explicit `#[backtrace]`
                // attribute: forward to the source, then offer own backtrace.
                (Some(backtrace_field), Some(source_field))
                    if backtrace_field.attrs.backtrace.is_none() =>
                {
                    let backtrace = &backtrace_field.member;
                    let source = &source_field.member;
                    let varsource = quote!(source);
                    let source_provide = if type_is_option(source_field.ty) {
                        quote_spanned! {source.span()=>
                            if let ::core::option::Option::Some(source) = #varsource {
                                source.thiserror_provide(#request);
                            }
                        }
                    } else {
                        quote_spanned! {source.span()=>
                            #varsource.thiserror_provide(#request);
                        }
                    };
                    let self_provide = if type_is_option(backtrace_field.ty) {
                        quote! {
                            if let ::core::option::Option::Some(backtrace) = backtrace {
                                #request.provide_ref::<::thiserror::#private::Backtrace>(backtrace);
                            }
                        }
                    } else {
                        quote! {
                            #request.provide_ref::<::thiserror::#private::Backtrace>(backtrace);
                        }
                    };
                    quote! {
                        #ty::#ident {
                            #backtrace: backtrace,
                            #source: #varsource,
                            ..
                        } => {
                            use ::thiserror::#private::ThiserrorProvide as _;
                            #source_provide
                            #self_provide
                        }
                    }
                }
                // The backtrace field is the source field itself: only
                // forward the request to the source.
                (Some(backtrace_field), Some(source_field))
                    if backtrace_field.member == source_field.member =>
                {
                    let backtrace = &backtrace_field.member;
                    let varsource = quote!(source);
                    let source_provide = if type_is_option(source_field.ty) {
                        quote_spanned! {backtrace.span()=>
                            if let ::core::option::Option::Some(source) = #varsource {
                                source.thiserror_provide(#request);
                            }
                        }
                    } else {
                        quote_spanned! {backtrace.span()=>
                            #varsource.thiserror_provide(#request);
                        }
                    };
                    quote! {
                        #ty::#ident {#backtrace: #varsource, ..} => {
                            use ::thiserror::#private::ThiserrorProvide as _;
                            #source_provide
                        }
                    }
                }
                // Backtrace only. This arm is also reached when an explicit
                // `#[backtrace]` field is distinct from an existing source field,
                // in which case the source is NOT forwarded the request.
                // NOTE(review): impl_struct does forward to the source in that
                // situation; confirm the difference is intentional.
                (Some(backtrace_field), _) => {
                    let backtrace = &backtrace_field.member;
                    let body = if type_is_option(backtrace_field.ty) {
                        quote! {
                            if let ::core::option::Option::Some(backtrace) = backtrace {
                                #request.provide_ref::<::thiserror::#private::Backtrace>(backtrace);
                            }
                        }
                    } else {
                        quote! {
                            #request.provide_ref::<::thiserror::#private::Backtrace>(backtrace);
                        }
                    };
                    quote! {
                        #ty::#ident {#backtrace: backtrace, ..} => {
                            #body
                        }
                    }
                }
                // No backtrace field: provide nothing for this variant.
                (None, _) => quote! {
                    #ty::#ident {..} => {}
                },
            }
        });
        Some(quote! {
            fn provide<'_request>(&'_request self, #request: &mut ::core::error::Request<'_request>) {
                #[allow(deprecated)]
                match self {
                    #(#arms)*
                }
            }
        })
    } else {
        None
    };

    let display_impl = if input.has_display() {
        let mut display_inferred_bounds = InferredBounds::new();
        // Import the `AsDisplay` helper once if any variant's display needs it.
        let has_bonus_display = input.variants.iter().any(|v| {
            v.attrs
                .display
                .as_ref()
                .map_or(false, |display| display.has_bonus_display)
        });
        let use_as_display = use_as_display(has_bonus_display);
        // For an enum with no variants, match on `*self` so that the match
        // with zero arms is exhaustive.
        let void_deref = if input.variants.is_empty() {
            Some(quote!(*))
        } else {
            None
        };
        let arms = input.variants.iter().map(|variant| {
            let mut display_implied_bounds = Set::new();
            let display = if let Some(display) = &variant.attrs.display {
                // Explicit display attribute.
                display_implied_bounds.clone_from(&display.implied_bounds);
                display.to_token_stream()
            } else if let Some(fmt) = &variant.attrs.fmt {
                // `fmt` path: called with every field binding followed by the
                // formatter.
                let fmt_path = &fmt.path;
                let vars = variant.fields.iter().map(|field| match &field.member {
                    MemberUnraw::Named(ident) => ident.to_local(),
                    MemberUnraw::Unnamed(index) => format_ident!("_{}", index),
                });
                quote!(#fmt_path(#(#vars,)* __formatter))
            } else {
                // Neither attribute: forward to the first field's `Display`
                // (the transparent case).
                let only_field = match &variant.fields[0].member {
                    MemberUnraw::Named(ident) => ident.to_local(),
                    MemberUnraw::Unnamed(index) => format_ident!("_{}", index),
                };
                display_implied_bounds.insert((0, Trait::Display));
                quote!(::core::fmt::Display::fmt(#only_field, __formatter))
            };
            for (field, bound) in display_implied_bounds {
                let field = &variant.fields[field];
                if field.contains_generic {
                    display_inferred_bounds.insert(field.ty, bound);
                }
            }
            let ident = &variant.ident;
            let pat = fields_pat(&variant.fields);
            quote! {
                #ty::#ident #pat => #display
            }
        });
        // The iterator is lazy; collect it now so `display_inferred_bounds`
        // is fully populated before building the where-clause below.
        let arms = arms.collect::<Vec<_>>();
        let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
        Some(quote! {
            #[allow(unused_qualifications)]
            #[automatically_derived]
            impl #impl_generics ::core::fmt::Display for #ty #ty_generics #display_where_clause {
                fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                    #use_as_display
                    #[allow(unused_variables, deprecated, clippy::used_underscore_binding)]
                    match #void_deref self {
                        #(#arms,)*
                    }
                }
            }
        })
    } else {
        None
    };

    // One `From<T>` impl per variant with a `#[from]` field, where `T` is the
    // field type with any `Option<..>` stripped.
    let from_impls = input.variants.iter().filter_map(|variant| {
        let from_field = variant.from_field()?;
        let span = from_field.attrs.from.unwrap().span;
        let backtrace_field = variant.distinct_backtrace_field();
        let variant = &variant.ident;
        let from = unoptional_type(from_field.ty);
        let source_var = Ident::new("source", span);
        let body = from_initializer(from_field, backtrace_field, &source_var);
        let from_function = quote! {
            fn from(#source_var: #from) -> Self {
                #ty::#variant #body
            }
        };
        let from_impl = quote_spanned! {span=>
            #[automatically_derived]
            impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause {
                #from_function
            }
        };
        Some(quote! {
            #[allow(
                deprecated,
                unused_qualifications,
                clippy::elidable_lifetime_names,
                clippy::needless_lifetimes,
            )]
            #from_impl
        })
    });

    // With type parameters present, the `Error` impl additionally requires
    // `Self: Debug + Display`.
    if input.generics.type_params().next().is_some() {
        let self_token = <Token![Self]>::default();
        error_inferred_bounds.insert(self_token, Trait::Debug);
        error_inferred_bounds.insert(self_token, Trait::Display);
    }
    let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);

    quote! {
        #[allow(unused_qualifications)]
        #[automatically_derived]
        impl #impl_generics ::thiserror::#private::Error for #ty #ty_generics #error_where_clause {
            #source_method
            #provide_method
        }
        #display_impl
        #(#from_impls)*
    }
}
488
489// Create an ident with which we can expand `impl Trait for #ident {}` on a
490// deprecated type without triggering deprecation warning on the generated impl.
491pub(crate) fn call_site_ident(ident: &Ident) -> Ident {
492    let mut ident = ident.clone();
493    ident.set_span(ident.span().resolved_at(Span::call_site()));
494    ident
495}
496
497fn fields_pat(fields: &[Field]) -> TokenStream {
498    let mut members = fields.iter().map(|field| &field.member).peekable();
499    match members.peek() {
500        Some(MemberUnraw::Named(_)) => quote!({ #(#members),* }),
501        Some(MemberUnraw::Unnamed(_)) => {
502            let vars = members.map(|member| match member {
503                MemberUnraw::Unnamed(index) => format_ident!("_{}", index),
504                MemberUnraw::Named(_) => unreachable!(),
505            });
506            quote!((#(#vars),*))
507        }
508        None => quote!({}),
509    }
510}
511
512fn use_as_display(needs_as_display: bool) -> Option<TokenStream> {
513    if needs_as_display {
514        Some(quote! {
515            use ::thiserror::#private::AsDisplay as _;
516        })
517    } else {
518        None
519    }
520}
521
522fn from_initializer(
523    from_field: &Field,
524    backtrace_field: Option<&Field>,
525    source_var: &Ident,
526) -> TokenStream {
527    let from_member = &from_field.member;
528    let some_source = if type_is_option(from_field.ty) {
529        quote!(::core::option::Option::Some(#source_var))
530    } else {
531        quote!(#source_var)
532    };
533    let backtrace = backtrace_field.map(|backtrace_field| {
534        let backtrace_member = &backtrace_field.member;
535        if type_is_option(backtrace_field.ty) {
536            quote! {
537                #backtrace_member: ::core::option::Option::Some(::thiserror::#private::Backtrace::capture()),
538            }
539        } else {
540            quote! {
541                #backtrace_member: ::core::convert::From::from(::thiserror::#private::Backtrace::capture()),
542            }
543        }
544    });
545    quote!({
546        #from_member: #some_source,
547        #backtrace
548    })
549}
550
551fn type_is_option(ty: &Type) -> bool {
552    type_parameter_of_option(ty).is_some()
553}
554
555fn unoptional_type(ty: &Type) -> TokenStream {
556    let unoptional = type_parameter_of_option(ty).unwrap_or(ty);
557    quote!(#unoptional)
558}
559
560fn type_parameter_of_option(ty: &Type) -> Option<&Type> {
561    let path = match ty {
562        Type::Path(ty) => &ty.path,
563        _ => return None,
564    };
565
566    let last = path.segments.last().unwrap();
567    if last.ident != "Option" {
568        return None;
569    }
570
571    let bracketed = match &last.arguments {
572        PathArguments::AngleBracketed(bracketed) => bracketed,
573        _ => return None,
574    };
575
576    if bracketed.args.len() != 1 {
577        return None;
578    }
579
580    match &bracketed.args[0] {
581        GenericArgument::Type(arg) => Some(arg),
582        _ => None,
583    }
584}