1use crate::ast::{Enum, Field, Input, Struct};
2use crate::attr::Trait;
3use crate::fallback;
4use crate::generics::InferredBounds;
5use crate::private;
6use crate::unraw::MemberUnraw;
7use proc_macro2::{Ident, Span, TokenStream};
8use quote::{format_ident, quote, quote_spanned, ToTokens};
9use std::collections::BTreeSet as Set;
10use syn::{DeriveInput, GenericArgument, PathArguments, Result, Token, Type};
11
12pub fn derive(input: &DeriveInput) -> TokenStream {
13 match try_expand(input) {
14 Ok(expanded) => expanded,
15 Err(error) => fallback::expand(input, error),
19 }
20}
21
22fn try_expand(input: &DeriveInput) -> Result<TokenStream> {
23 let input = Input::from_syn(input)?;
24 input.validate()?;
25 Ok(match input {
26 Input::Struct(input) => impl_struct(input),
27 Input::Enum(input) => impl_enum(input),
28 })
29}
30
/// Generates the trait impls for a struct.
///
/// The returned tokens always contain an `Error` impl. A `Display` impl is
/// added when the struct is transparent or carries a display attribute, and a
/// `From` impl is added when the struct has a `from` field.
fn impl_struct(input: Struct) -> TokenStream {
    let ty = call_site_ident(&input.ident);
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
    // Extra where-clause bounds for the `Error` impl, collected while the
    // method bodies below are generated.
    let mut error_inferred_bounds = InferredBounds::new();

    // Body of `Error::source`, if the struct has anything to report.
    let source_body = if let Some(transparent_attr) = &input.attrs.transparent {
        // Transparent: forward `source()` to the first (index 0) field.
        let only_field = &input.fields[0];
        if only_field.contains_generic {
            error_inferred_bounds.insert(only_field.ty, quote!(::thiserror::#private::Error));
        }
        let member = &only_field.member;
        Some(quote_spanned! {transparent_attr.span=>
            ::thiserror::#private::Error::source(self.#member.as_dyn_error())
        })
    } else if let Some(source_field) = input.source_field() {
        let source = &source_field.member;
        if source_field.contains_generic {
            // The bound is placed on the type inside `Option<..>` when the
            // field is optional, otherwise on the field type itself.
            let ty = unoptional_type(source_field.ty);
            error_inferred_bounds.insert(ty, quote!(::thiserror::#private::Error + 'static));
        }
        // For an `Option` source, `.as_ref()?` makes the generated `source()`
        // return `None` early when the field is `None`.
        let asref = if type_is_option(source_field.ty) {
            Some(quote_spanned!(source.span()=> .as_ref()?))
        } else {
            None
        };
        let dyn_error = quote_spanned! {source_field.source_span()=>
            self.#source #asref.as_dyn_error()
        };
        Some(quote! {
            ::core::option::Option::Some(#dyn_error)
        })
    } else {
        None
    };
    // Wrap the body in the method signature; no method is emitted when there
    // is no body.
    let source_method = source_body.map(|body| {
        quote! {
            fn source(&self) -> ::core::option::Option<&(dyn ::thiserror::#private::Error + 'static)> {
                use ::thiserror::#private::AsDynError as _;
                #body
            }
        }
    });

    // `Error::provide` is only generated when the struct has a backtrace field.
    let provide_method = input.backtrace_field().map(|backtrace_field| {
        let request = quote!(request);
        let backtrace = &backtrace_field.member;
        let body = if let Some(source_field) = input.source_field() {
            // With a source field, first let the source provide into the
            // request via `thiserror_provide`.
            let source = &source_field.member;
            let source_provide = if type_is_option(source_field.ty) {
                quote_spanned! {source.span()=>
                    if let ::core::option::Option::Some(source) = &self.#source {
                        source.thiserror_provide(#request);
                    }
                }
            } else {
                quote_spanned! {source.span()=>
                    self.#source.thiserror_provide(#request);
                }
            };
            // When the source and backtrace are the same member, only the
            // forwarding above is emitted; otherwise the struct's own
            // backtrace field is provided as well.
            let self_provide = if source == backtrace {
                None
            } else if type_is_option(backtrace_field.ty) {
                Some(quote! {
                    if let ::core::option::Option::Some(backtrace) = &self.#backtrace {
                        #request.provide_ref::<::thiserror::#private::Backtrace>(backtrace);
                    }
                })
            } else {
                Some(quote! {
                    #request.provide_ref::<::thiserror::#private::Backtrace>(&self.#backtrace);
                })
            };
            quote! {
                use ::thiserror::#private::ThiserrorProvide as _;
                #source_provide
                #self_provide
            }
        } else if type_is_option(backtrace_field.ty) {
            // No source field: provide the optional backtrace if present.
            quote! {
                if let ::core::option::Option::Some(backtrace) = &self.#backtrace {
                    #request.provide_ref::<::thiserror::#private::Backtrace>(backtrace);
                }
            }
        } else {
            quote! {
                #request.provide_ref::<::thiserror::#private::Backtrace>(&self.#backtrace);
            }
        };
        quote! {
            fn provide<'_request>(&'_request self, #request: &mut ::core::error::Request<'_request>) {
                #body
            }
        }
    });

    // Set of (field index, trait) pairs that the `Display` body requires.
    let mut display_implied_bounds = Set::new();
    let display_body = if input.attrs.transparent.is_some() {
        // Transparent: delegate formatting to the first field.
        let only_field = &input.fields[0].member;
        display_implied_bounds.insert((0, Trait::Display));
        Some(quote! {
            ::core::fmt::Display::fmt(&self.#only_field, __formatter)
        })
    } else if let Some(display) = &input.attrs.display {
        display_implied_bounds.clone_from(&display.implied_bounds);
        let use_as_display = use_as_display(display.has_bonus_display);
        // Destructure `self` so the display tokens can refer to the fields
        // through local bindings.
        let pat = fields_pat(&input.fields);
        Some(quote! {
            #use_as_display
            #[allow(unused_variables, deprecated)]
            let Self #pat = self;
            #display
        })
    } else {
        None
    };
    let display_impl = display_body.map(|body| {
        // Only fields flagged `contains_generic` contribute where-clause
        // bounds to the `Display` impl.
        let mut display_inferred_bounds = InferredBounds::new();
        for (field, bound) in display_implied_bounds {
            let field = &input.fields[field];
            if field.contains_generic {
                display_inferred_bounds.insert(field.ty, bound);
            }
        }
        let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
        quote! {
            #[allow(unused_qualifications)]
            #[automatically_derived]
            impl #impl_generics ::core::fmt::Display for #ty #ty_generics #display_where_clause {
                #[allow(clippy::used_underscore_binding)]
                fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                    #body
                }
            }
        }
    });

    // `From<SourceType>` impl, generated when a field carries a `from` attribute.
    let from_impl = input.from_field().map(|from_field| {
        // NOTE(review): the `unwrap` relies on `from_field()` only returning
        // fields whose `from` attribute is set — confirm against `ast`.
        let span = from_field.attrs.from.unwrap().span;
        let backtrace_field = input.distinct_backtrace_field();
        // The conversion takes the type inside `Option<..>` when the field is
        // optional.
        let from = unoptional_type(from_field.ty);
        let source_var = Ident::new("source", span);
        let body = from_initializer(from_field, backtrace_field, &source_var);
        let from_function = quote! {
            fn from(#source_var: #from) -> Self {
                #ty #body
            }
        };
        // The impl header is spanned at the `from` attribute, while the
        // function itself is built with a plain `quote!`.
        let from_impl = quote_spanned! {span=>
            #[automatically_derived]
            impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause {
                #from_function
            }
        };
        Some(quote! {
            #[allow(
                deprecated,
                unused_qualifications,
                clippy::elidable_lifetime_names,
                clippy::needless_lifetimes,
            )]
            #from_impl
        })
    });

    // For structs with type parameters, the `Error` impl additionally
    // requires `Self: Debug + Display`.
    if input.generics.type_params().next().is_some() {
        let self_token = <Token![Self]>::default();
        error_inferred_bounds.insert(self_token, Trait::Debug);
        error_inferred_bounds.insert(self_token, Trait::Display);
    }
    let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);

    quote! {
        #[allow(unused_qualifications)]
        #[automatically_derived]
        impl #impl_generics ::thiserror::#private::Error for #ty #ty_generics #error_where_clause {
            #source_method
            #provide_method
        }
        #display_impl
        #from_impl
    }
}
213
/// Generates the trait impls for an enum.
///
/// The returned tokens always contain an `Error` impl. A `Display` impl is
/// added when `input.has_display()` is true, and one `From` impl is added for
/// each variant that has a `from` field.
fn impl_enum(input: Enum) -> TokenStream {
    let ty = call_site_ident(&input.ident);
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
    // Extra where-clause bounds for the `Error` impl, collected while the
    // `source` match arms are generated.
    let mut error_inferred_bounds = InferredBounds::new();

    let source_method = if input.has_source() {
        // One match arm per variant. The iterator is lazy: the closure (and
        // its inserts into `error_inferred_bounds`) runs when `#(#arms)*` is
        // expanded in the `quote!` below.
        let arms = input.variants.iter().map(|variant| {
            let ident = &variant.ident;
            if let Some(transparent_attr) = &variant.attrs.transparent {
                // Transparent variant: forward `source()` to its first field.
                let only_field = &variant.fields[0];
                if only_field.contains_generic {
                    error_inferred_bounds.insert(only_field.ty, quote!(::thiserror::#private::Error));
                }
                let member = &only_field.member;
                let source = quote_spanned! {transparent_attr.span=>
                    ::thiserror::#private::Error::source(transparent.as_dyn_error())
                };
                quote! {
                    #ty::#ident {#member: transparent} => #source,
                }
            } else if let Some(source_field) = variant.source_field() {
                let source = &source_field.member;
                if source_field.contains_generic {
                    // Bound the type inside `Option<..>` for optional sources.
                    let ty = unoptional_type(source_field.ty);
                    error_inferred_bounds.insert(ty, quote!(::thiserror::#private::Error + 'static));
                }
                // For an `Option` source, `.as_ref()?` makes the generated
                // `source()` return `None` early when the field is `None`.
                let asref = if type_is_option(source_field.ty) {
                    Some(quote_spanned!(source.span()=> .as_ref()?))
                } else {
                    None
                };
                let varsource = quote!(source);
                let dyn_error = quote_spanned! {source_field.source_span()=>
                    #varsource #asref.as_dyn_error()
                };
                quote! {
                    #ty::#ident {#source: #varsource, ..} => ::core::option::Option::Some(#dyn_error),
                }
            } else {
                // Variant without a source.
                quote! {
                    #ty::#ident {..} => ::core::option::Option::None,
                }
            }
        });
        Some(quote! {
            fn source(&self) -> ::core::option::Option<&(dyn ::thiserror::#private::Error + 'static)> {
                use ::thiserror::#private::AsDynError as _;
                #[allow(deprecated)]
                match self {
                    #(#arms)*
                }
            }
        })
    } else {
        None
    };

    let provide_method = if input.has_backtrace() {
        let request = quote!(request);
        let arms = input.variants.iter().map(|variant| {
            let ident = &variant.ident;
            match (variant.backtrace_field(), variant.source_field()) {
                // Both fields exist and the backtrace field has no `backtrace`
                // attribute: let the source provide first, then provide the
                // variant's own backtrace.
                (Some(backtrace_field), Some(source_field))
                    if backtrace_field.attrs.backtrace.is_none() =>
                {
                    let backtrace = &backtrace_field.member;
                    let source = &source_field.member;
                    let varsource = quote!(source);
                    let source_provide = if type_is_option(source_field.ty) {
                        quote_spanned! {source.span()=>
                            if let ::core::option::Option::Some(source) = #varsource {
                                source.thiserror_provide(#request);
                            }
                        }
                    } else {
                        quote_spanned! {source.span()=>
                            #varsource.thiserror_provide(#request);
                        }
                    };
                    let self_provide = if type_is_option(backtrace_field.ty) {
                        quote! {
                            if let ::core::option::Option::Some(backtrace) = backtrace {
                                #request.provide_ref::<::thiserror::#private::Backtrace>(backtrace);
                            }
                        }
                    } else {
                        quote! {
                            #request.provide_ref::<::thiserror::#private::Backtrace>(backtrace);
                        }
                    };
                    quote! {
                        #ty::#ident {
                            #backtrace: backtrace,
                            #source: #varsource,
                            ..
                        } => {
                            use ::thiserror::#private::ThiserrorProvide as _;
                            #source_provide
                            #self_provide
                        }
                    }
                }
                // The backtrace field and the source field are the same
                // member: only forward the request to the source.
                (Some(backtrace_field), Some(source_field))
                    if backtrace_field.member == source_field.member =>
                {
                    let backtrace = &backtrace_field.member;
                    let varsource = quote!(source);
                    let source_provide = if type_is_option(source_field.ty) {
                        quote_spanned! {backtrace.span()=>
                            if let ::core::option::Option::Some(source) = #varsource {
                                source.thiserror_provide(#request);
                            }
                        }
                    } else {
                        quote_spanned! {backtrace.span()=>
                            #varsource.thiserror_provide(#request);
                        }
                    };
                    quote! {
                        #ty::#ident {#backtrace: #varsource, ..} => {
                            use ::thiserror::#private::ThiserrorProvide as _;
                            #source_provide
                        }
                    }
                }
                // Every remaining variant with a backtrace field: provide that
                // backtrace directly.
                (Some(backtrace_field), _) => {
                    let backtrace = &backtrace_field.member;
                    let body = if type_is_option(backtrace_field.ty) {
                        quote! {
                            if let ::core::option::Option::Some(backtrace) = backtrace {
                                #request.provide_ref::<::thiserror::#private::Backtrace>(backtrace);
                            }
                        }
                    } else {
                        quote! {
                            #request.provide_ref::<::thiserror::#private::Backtrace>(backtrace);
                        }
                    };
                    quote! {
                        #ty::#ident {#backtrace: backtrace, ..} => {
                            #body
                        }
                    }
                }
                // No backtrace field: this variant provides nothing.
                (None, _) => quote! {
                    #ty::#ident {..} => {}
                },
            }
        });
        Some(quote! {
            fn provide<'_request>(&'_request self, #request: &mut ::core::error::Request<'_request>) {
                #[allow(deprecated)]
                match self {
                    #(#arms)*
                }
            }
        })
    } else {
        None
    };

    let display_impl = if input.has_display() {
        let mut display_inferred_bounds = InferredBounds::new();
        // The `AsDisplay` import is emitted once for the whole `fmt` body if
        // any variant's display attribute has `has_bonus_display` set.
        let has_bonus_display = input.variants.iter().any(|v| {
            v.attrs
                .display
                .as_ref()
                .map_or(false, |display| display.has_bonus_display)
        });
        let use_as_display = use_as_display(has_bonus_display);
        // For an enum without variants the generated code matches on `*self`
        // instead of `self` — presumably because an arm-less `match` is only
        // accepted on the uninhabited value itself, not on a reference to it.
        let void_deref = if input.variants.is_empty() {
            Some(quote!(*))
        } else {
            None
        };
        let arms = input.variants.iter().map(|variant| {
            // Set of (field index, trait) pairs required by this variant's arm.
            let mut display_implied_bounds = Set::new();
            let display = if let Some(display) = &variant.attrs.display {
                display_implied_bounds.clone_from(&display.implied_bounds);
                display.to_token_stream()
            } else if let Some(fmt) = &variant.attrs.fmt {
                // `fmt` attribute: call the given path with every field
                // binding followed by the formatter.
                let fmt_path = &fmt.path;
                let vars = variant.fields.iter().map(|field| match &field.member {
                    MemberUnraw::Named(ident) => ident.to_local(),
                    MemberUnraw::Unnamed(index) => format_ident!("_{}", index),
                });
                quote!(#fmt_path(#(#vars,)* __formatter))
            } else {
                // Neither attribute: delegate formatting to the variant's
                // first field.
                let only_field = match &variant.fields[0].member {
                    MemberUnraw::Named(ident) => ident.to_local(),
                    MemberUnraw::Unnamed(index) => format_ident!("_{}", index),
                };
                display_implied_bounds.insert((0, Trait::Display));
                quote!(::core::fmt::Display::fmt(#only_field, __formatter))
            };
            // Only fields flagged `contains_generic` contribute where-clause
            // bounds to the `Display` impl.
            for (field, bound) in display_implied_bounds {
                let field = &variant.fields[field];
                if field.contains_generic {
                    display_inferred_bounds.insert(field.ty, bound);
                }
            }
            let ident = &variant.ident;
            let pat = fields_pat(&variant.fields);
            quote! {
                #ty::#ident #pat => #display
            }
        });
        // Collect eagerly so the closure's inserts into
        // `display_inferred_bounds` have happened before the where clause is
        // built on the next line.
        let arms = arms.collect::<Vec<_>>();
        let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
        Some(quote! {
            #[allow(unused_qualifications)]
            #[automatically_derived]
            impl #impl_generics ::core::fmt::Display for #ty #ty_generics #display_where_clause {
                fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                    #use_as_display
                    #[allow(unused_variables, deprecated, clippy::used_underscore_binding)]
                    match #void_deref self {
                        #(#arms,)*
                    }
                }
            }
        })
    } else {
        None
    };

    // One `From<SourceType>` impl per variant that has a `from` field;
    // variants without one are skipped by the `?` below.
    let from_impls = input.variants.iter().filter_map(|variant| {
        let from_field = variant.from_field()?;
        // NOTE(review): the `unwrap` relies on `from_field()` only returning
        // fields whose `from` attribute is set — confirm against `ast`.
        let span = from_field.attrs.from.unwrap().span;
        let backtrace_field = variant.distinct_backtrace_field();
        let variant = &variant.ident;
        // The conversion takes the type inside `Option<..>` when the field is
        // optional.
        let from = unoptional_type(from_field.ty);
        let source_var = Ident::new("source", span);
        let body = from_initializer(from_field, backtrace_field, &source_var);
        let from_function = quote! {
            fn from(#source_var: #from) -> Self {
                #ty::#variant #body
            }
        };
        // The impl header is spanned at the `from` attribute, while the
        // function itself is built with a plain `quote!`.
        let from_impl = quote_spanned! {span=>
            #[automatically_derived]
            impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause {
                #from_function
            }
        };
        Some(quote! {
            #[allow(
                deprecated,
                unused_qualifications,
                clippy::elidable_lifetime_names,
                clippy::needless_lifetimes,
            )]
            #from_impl
        })
    });

    // For enums with type parameters, the `Error` impl additionally requires
    // `Self: Debug + Display`.
    if input.generics.type_params().next().is_some() {
        let self_token = <Token![Self]>::default();
        error_inferred_bounds.insert(self_token, Trait::Debug);
        error_inferred_bounds.insert(self_token, Trait::Display);
    }
    let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);

    quote! {
        #[allow(unused_qualifications)]
        #[automatically_derived]
        impl #impl_generics ::thiserror::#private::Error for #ty #ty_generics #error_where_clause {
            #source_method
            #provide_method
        }
        #display_impl
        #(#from_impls)*
    }
}
488
489pub(crate) fn call_site_ident(ident: &Ident) -> Ident {
492 let mut ident = ident.clone();
493 ident.set_span(ident.span().resolved_at(Span::call_site()));
494 ident
495}
496
497fn fields_pat(fields: &[Field]) -> TokenStream {
498 let mut members = fields.iter().map(|field| &field.member).peekable();
499 match members.peek() {
500 Some(MemberUnraw::Named(_)) => quote!({ #(#members),* }),
501 Some(MemberUnraw::Unnamed(_)) => {
502 let vars = members.map(|member| match member {
503 MemberUnraw::Unnamed(index) => format_ident!("_{}", index),
504 MemberUnraw::Named(_) => unreachable!(),
505 });
506 quote!((#(#vars),*))
507 }
508 None => quote!({}),
509 }
510}
511
512fn use_as_display(needs_as_display: bool) -> Option<TokenStream> {
513 if needs_as_display {
514 Some(quote! {
515 use ::thiserror::#private::AsDisplay as _;
516 })
517 } else {
518 None
519 }
520}
521
522fn from_initializer(
523 from_field: &Field,
524 backtrace_field: Option<&Field>,
525 source_var: &Ident,
526) -> TokenStream {
527 let from_member = &from_field.member;
528 let some_source = if type_is_option(from_field.ty) {
529 quote!(::core::option::Option::Some(#source_var))
530 } else {
531 quote!(#source_var)
532 };
533 let backtrace = backtrace_field.map(|backtrace_field| {
534 let backtrace_member = &backtrace_field.member;
535 if type_is_option(backtrace_field.ty) {
536 quote! {
537 #backtrace_member: ::core::option::Option::Some(::thiserror::#private::Backtrace::capture()),
538 }
539 } else {
540 quote! {
541 #backtrace_member: ::core::convert::From::from(::thiserror::#private::Backtrace::capture()),
542 }
543 }
544 });
545 quote!({
546 #from_member: #some_source,
547 #backtrace
548 })
549}
550
551fn type_is_option(ty: &Type) -> bool {
552 type_parameter_of_option(ty).is_some()
553}
554
555fn unoptional_type(ty: &Type) -> TokenStream {
556 let unoptional = type_parameter_of_option(ty).unwrap_or(ty);
557 quote!(#unoptional)
558}
559
560fn type_parameter_of_option(ty: &Type) -> Option<&Type> {
561 let path = match ty {
562 Type::Path(ty) => &ty.path,
563 _ => return None,
564 };
565
566 let last = path.segments.last().unwrap();
567 if last.ident != "Option" {
568 return None;
569 }
570
571 let bracketed = match &last.arguments {
572 PathArguments::AngleBracketed(bracketed) => bracketed,
573 _ => return None,
574 };
575
576 if bracketed.args.len() != 1 {
577 return None;
578 }
579
580 match &bracketed.args[0] {
581 GenericArgument::Type(arg) => Some(arg),
582 _ => None,
583 }
584}