1use crate::attributes::{DefaultAttribute, FromPyWithAttribute, RenamingRule};
2use crate::derive_attributes::{ContainerAttributes, FieldAttributes, FieldGetter};
3#[cfg(feature = "experimental-inspect")]
4use crate::introspection::ConcatenationBuilder;
5#[cfg(feature = "experimental-inspect")]
6use crate::utils::TypeExt;
7use crate::utils::{self, Ctx};
8use proc_macro2::TokenStream;
9use quote::{format_ident, quote, quote_spanned, ToTokens};
10use syn::{
11 ext::IdentExt, parse_quote, punctuated::Punctuated, spanned::Spanned, DataEnum, DeriveInput,
12 Fields, Ident, Result, Token,
13};
14
15struct Enum<'a> {
17 enum_ident: &'a Ident,
18 variants: Vec<Container<'a>>,
19}
20
21impl<'a> Enum<'a> {
22 fn new(
27 data_enum: &'a DataEnum,
28 ident: &'a Ident,
29 options: ContainerAttributes,
30 ) -> Result<Self> {
31 ensure_spanned!(
32 !data_enum.variants.is_empty(),
33 ident.span() => "cannot derive FromPyObject for empty enum"
34 );
35 let variants = data_enum
36 .variants
37 .iter()
38 .map(|variant| {
39 let mut variant_options = ContainerAttributes::from_attrs(&variant.attrs)?;
40 if let Some(rename_all) = &options.rename_all {
41 ensure_spanned!(
42 variant_options.rename_all.is_none(),
43 variant_options.rename_all.span() => "Useless variant `rename_all` - enum is already annotated with `rename_all"
44 );
45 variant_options.rename_all = Some(rename_all.clone());
46
47 }
48 let var_ident = &variant.ident;
49 Container::new(
50 &variant.fields,
51 parse_quote!(#ident::#var_ident),
52 variant_options,
53 )
54 })
55 .collect::<Result<Vec<_>>>()?;
56
57 Ok(Enum {
58 enum_ident: ident,
59 variants,
60 })
61 }
62
63 fn build(&self, ctx: &Ctx) -> TokenStream {
65 let Ctx { pyo3_path, .. } = ctx;
66 let mut var_extracts = Vec::new();
67 let mut variant_names = Vec::new();
68 let mut error_names = Vec::new();
69
70 for var in &self.variants {
71 let struct_derive = var.build(ctx);
72 let ext = quote!({
73 let maybe_ret = || -> #pyo3_path::PyResult<Self> {
74 #struct_derive
75 }();
76
77 match maybe_ret {
78 ok @ ::std::result::Result::Ok(_) => return ok,
79 ::std::result::Result::Err(err) => err
80 }
81 });
82
83 var_extracts.push(ext);
84 variant_names.push(var.path.segments.last().unwrap().ident.to_string());
85 error_names.push(&var.err_name);
86 }
87 let ty_name = self.enum_ident.to_string();
88 quote!(
89 let errors = [
90 #(#var_extracts),*
91 ];
92 ::std::result::Result::Err(
93 #pyo3_path::impl_::frompyobject::failed_to_extract_enum(
94 obj.py(),
95 #ty_name,
96 &[#(#variant_names),*],
97 &[#(#error_names),*],
98 &errors
99 )
100 )
101 )
102 }
103
104 #[cfg(feature = "experimental-inspect")]
105 fn write_input_type(&self, builder: &mut ConcatenationBuilder, ctx: &Ctx) {
106 for (i, var) in self.variants.iter().enumerate() {
107 if i > 0 {
108 builder.push_str(" | ");
109 }
110 var.write_input_type(builder, ctx);
111 }
112 }
113}
114
115struct NamedStructField<'a> {
116 ident: &'a syn::Ident,
117 getter: Option<FieldGetter>,
118 from_py_with: Option<FromPyWithAttribute>,
119 default: Option<DefaultAttribute>,
120 ty: &'a syn::Type,
121}
122
123struct TupleStructField {
124 from_py_with: Option<FromPyWithAttribute>,
125 ty: syn::Type,
126}
127
128enum ContainerType<'a> {
132 Struct(Vec<NamedStructField<'a>>),
136 #[cfg_attr(not(feature = "experimental-inspect"), allow(unused))]
140 StructNewtype(&'a syn::Ident, Option<FromPyWithAttribute>, &'a syn::Type),
141 Tuple(Vec<TupleStructField>),
146 #[cfg_attr(not(feature = "experimental-inspect"), allow(unused))]
150 TupleNewtype(Option<FromPyWithAttribute>, Box<syn::Type>),
151}
152
153struct Container<'a> {
157 path: syn::Path,
158 ty: ContainerType<'a>,
159 err_name: String,
160 rename_rule: Option<RenamingRule>,
161}
162
163impl<'a> Container<'a> {
164 fn new(fields: &'a Fields, path: syn::Path, options: ContainerAttributes) -> Result<Self> {
168 let style = match fields {
169 Fields::Unnamed(unnamed) if !unnamed.unnamed.is_empty() => {
170 ensure_spanned!(
171 options.rename_all.is_none(),
172 options.rename_all.span() => "`rename_all` is useless on tuple structs and variants."
173 );
174 let mut tuple_fields = unnamed
175 .unnamed
176 .iter()
177 .map(|field| {
178 let attrs = FieldAttributes::from_attrs(&field.attrs)?;
179 ensure_spanned!(
180 attrs.getter.is_none(),
181 field.span() => "`getter` is not permitted on tuple struct elements."
182 );
183 ensure_spanned!(
184 attrs.default.is_none(),
185 field.span() => "`default` is not permitted on tuple struct elements."
186 );
187 Ok(TupleStructField {
188 from_py_with: attrs.from_py_with,
189 ty: field.ty.clone(),
190 })
191 })
192 .collect::<Result<Vec<_>>>()?;
193
194 if tuple_fields.len() == 1 {
195 let field = tuple_fields.pop().unwrap();
198 ContainerType::TupleNewtype(field.from_py_with, Box::new(field.ty))
199 } else if options.transparent.is_some() {
200 bail_spanned!(
201 fields.span() => "transparent structs and variants can only have 1 field"
202 );
203 } else {
204 ContainerType::Tuple(tuple_fields)
205 }
206 }
207 Fields::Named(named) if !named.named.is_empty() => {
208 let mut struct_fields = named
209 .named
210 .iter()
211 .map(|field| {
212 let ident = field
213 .ident
214 .as_ref()
215 .expect("Named fields should have identifiers");
216 let mut attrs = FieldAttributes::from_attrs(&field.attrs)?;
217
218 if let Some(ref from_item_all) = options.from_item_all {
219 if let Some(replaced) = attrs.getter.replace(FieldGetter::GetItem(parse_quote!(item), None))
220 {
221 match replaced {
222 FieldGetter::GetItem(item, Some(item_name)) => {
223 attrs.getter = Some(FieldGetter::GetItem(item, Some(item_name)));
224 }
225 FieldGetter::GetItem(_, None) => bail_spanned!(from_item_all.span() => "Useless `item` - the struct is already annotated with `from_item_all`"),
226 FieldGetter::GetAttr(_, _) => bail_spanned!(
227 from_item_all.span() => "The struct is already annotated with `from_item_all`, `attribute` is not allowed"
228 ),
229 }
230 }
231 }
232
233 Ok(NamedStructField {
234 ident,
235 getter: attrs.getter,
236 from_py_with: attrs.from_py_with,
237 default: attrs.default,
238 ty: &field.ty,
239 })
240 })
241 .collect::<Result<Vec<_>>>()?;
242 if struct_fields.iter().all(|field| field.default.is_some()) {
243 bail_spanned!(
244 fields.span() => "cannot derive FromPyObject for structs and variants with only default values"
245 )
246 } else if options.transparent.is_some() {
247 ensure_spanned!(
248 struct_fields.len() == 1,
249 fields.span() => "transparent structs and variants can only have 1 field"
250 );
251 ensure_spanned!(
252 options.rename_all.is_none(),
253 options.rename_all.span() => "`rename_all` is not permitted on `transparent` structs and variants"
254 );
255 let field = struct_fields.pop().unwrap();
256 ensure_spanned!(
257 field.getter.is_none(),
258 field.ident.span() => "`transparent` structs may not have a `getter` for the inner field"
259 );
260 ContainerType::StructNewtype(field.ident, field.from_py_with, field.ty)
261 } else {
262 ContainerType::Struct(struct_fields)
263 }
264 }
265 _ => bail_spanned!(
266 fields.span() => "cannot derive FromPyObject for empty structs and variants"
267 ),
268 };
269 let err_name = options.annotation.map_or_else(
270 || path.segments.last().unwrap().ident.to_string(),
271 |lit_str| lit_str.value(),
272 );
273
274 let v = Container {
275 path,
276 ty: style,
277 err_name,
278 rename_rule: options.rename_all.map(|v| v.value.rule),
279 };
280 Ok(v)
281 }
282
283 fn name(&self) -> String {
284 let mut value = String::new();
285 for segment in &self.path.segments {
286 if !value.is_empty() {
287 value.push_str("::");
288 }
289 value.push_str(&segment.ident.to_string());
290 }
291 value
292 }
293
294 fn build(&self, ctx: &Ctx) -> TokenStream {
296 match &self.ty {
297 ContainerType::StructNewtype(ident, from_py_with, _) => {
298 self.build_newtype_struct(Some(ident), from_py_with, ctx)
299 }
300 ContainerType::TupleNewtype(from_py_with, _) => {
301 self.build_newtype_struct(None, from_py_with, ctx)
302 }
303 ContainerType::Tuple(tups) => self.build_tuple_struct(tups, ctx),
304 ContainerType::Struct(tups) => self.build_struct(tups, ctx),
305 }
306 }
307
308 fn build_newtype_struct(
309 &self,
310 field_ident: Option<&Ident>,
311 from_py_with: &Option<FromPyWithAttribute>,
312 ctx: &Ctx,
313 ) -> TokenStream {
314 let Ctx { pyo3_path, .. } = ctx;
315 let self_ty = &self.path;
316 let struct_name = self.name();
317 if let Some(ident) = field_ident {
318 let field_name = ident.to_string();
319 if let Some(FromPyWithAttribute {
320 kw,
321 value: expr_path,
322 }) = from_py_with
323 {
324 let extractor = quote_spanned! { kw.span =>
325 { let from_py_with: fn(_) -> _ = #expr_path; from_py_with }
326 };
327 quote! {
328 Ok(#self_ty {
329 #ident: #pyo3_path::impl_::frompyobject::extract_struct_field_with(#extractor, obj, #struct_name, #field_name)?
330 })
331 }
332 } else {
333 quote! {
334 Ok(#self_ty {
335 #ident: #pyo3_path::impl_::frompyobject::extract_struct_field(obj, #struct_name, #field_name)?
336 })
337 }
338 }
339 } else if let Some(FromPyWithAttribute {
340 kw,
341 value: expr_path,
342 }) = from_py_with
343 {
344 let extractor = quote_spanned! { kw.span =>
345 { let from_py_with: fn(_) -> _ = #expr_path; from_py_with }
346 };
347 quote! {
348 #pyo3_path::impl_::frompyobject::extract_tuple_struct_field_with(#extractor, obj, #struct_name, 0).map(#self_ty)
349 }
350 } else {
351 quote! {
352 #pyo3_path::impl_::frompyobject::extract_tuple_struct_field(obj, #struct_name, 0).map(#self_ty)
353 }
354 }
355 }
356
357 fn build_tuple_struct(&self, struct_fields: &[TupleStructField], ctx: &Ctx) -> TokenStream {
358 let Ctx { pyo3_path, .. } = ctx;
359 let self_ty = &self.path;
360 let struct_name = &self.name();
361 let field_idents: Vec<_> = (0..struct_fields.len())
362 .map(|i| format_ident!("arg{}", i))
363 .collect();
364 let fields = struct_fields.iter().zip(&field_idents).enumerate().map(|(index, (field, ident))| {
365 if let Some(FromPyWithAttribute {
366 kw,
367 value: expr_path, ..
368 }) = &field.from_py_with {
369 let extractor = quote_spanned! { kw.span =>
370 { let from_py_with: fn(_) -> _ = #expr_path; from_py_with }
371 };
372 quote! {
373 #pyo3_path::impl_::frompyobject::extract_tuple_struct_field_with(#extractor, &#ident, #struct_name, #index)?
374 }
375 } else {
376 quote!{
377 #pyo3_path::impl_::frompyobject::extract_tuple_struct_field(&#ident, #struct_name, #index)?
378 }}
379 });
380
381 quote!(
382 match #pyo3_path::types::PyAnyMethods::extract(obj) {
383 ::std::result::Result::Ok((#(#field_idents),*)) => ::std::result::Result::Ok(#self_ty(#(#fields),*)),
384 ::std::result::Result::Err(err) => ::std::result::Result::Err(err),
385 }
386 )
387 }
388
389 fn build_struct(&self, struct_fields: &[NamedStructField<'_>], ctx: &Ctx) -> TokenStream {
390 let Ctx { pyo3_path, .. } = ctx;
391 let self_ty = &self.path;
392 let struct_name = self.name();
393 let mut fields: Punctuated<TokenStream, Token![,]> = Punctuated::new();
394 for field in struct_fields {
395 let ident = field.ident;
396 let field_name = ident.unraw().to_string();
397 let getter = match field
398 .getter
399 .as_ref()
400 .unwrap_or(&FieldGetter::GetAttr(parse_quote!(attribute), None))
401 {
402 FieldGetter::GetAttr(_, Some(name)) => {
403 quote!(#pyo3_path::types::PyAnyMethods::getattr(obj, #pyo3_path::intern!(obj.py(), #name)))
404 }
405 FieldGetter::GetAttr(_, None) => {
406 let name = self
407 .rename_rule
408 .map(|rule| utils::apply_renaming_rule(rule, &field_name));
409 let name = name.as_deref().unwrap_or(&field_name);
410 quote!(#pyo3_path::types::PyAnyMethods::getattr(obj, #pyo3_path::intern!(obj.py(), #name)))
411 }
412 FieldGetter::GetItem(_, Some(syn::Lit::Str(key))) => {
413 quote!(#pyo3_path::types::PyAnyMethods::get_item(obj, #pyo3_path::intern!(obj.py(), #key)))
414 }
415 FieldGetter::GetItem(_, Some(key)) => {
416 quote!(#pyo3_path::types::PyAnyMethods::get_item(obj, #key))
417 }
418 FieldGetter::GetItem(_, None) => {
419 let name = self
420 .rename_rule
421 .map(|rule| utils::apply_renaming_rule(rule, &field_name));
422 let name = name.as_deref().unwrap_or(&field_name);
423 quote!(#pyo3_path::types::PyAnyMethods::get_item(obj, #pyo3_path::intern!(obj.py(), #name)))
424 }
425 };
426 let extractor = if let Some(FromPyWithAttribute {
427 kw,
428 value: expr_path,
429 }) = &field.from_py_with
430 {
431 let extractor = quote_spanned! { kw.span =>
432 { let from_py_with: fn(_) -> _ = #expr_path; from_py_with }
433 };
434 quote! (#pyo3_path::impl_::frompyobject::extract_struct_field_with(#extractor, &#getter?, #struct_name, #field_name)?)
435 } else {
436 quote!(#pyo3_path::impl_::frompyobject::extract_struct_field(&value, #struct_name, #field_name)?)
437 };
438 let extracted = if let Some(default) = &field.default {
439 let default_expr = if let Some(default_expr) = &default.value {
440 default_expr.to_token_stream()
441 } else {
442 quote!(::std::default::Default::default())
443 };
444 quote!(if let ::std::result::Result::Ok(value) = #getter {
445 #extractor
446 } else {
447 #default_expr
448 })
449 } else {
450 quote!({
451 let value = #getter?;
452 #extractor
453 })
454 };
455
456 fields.push(quote!(#ident: #extracted));
457 }
458
459 quote!(::std::result::Result::Ok(#self_ty{#fields}))
460 }
461
462 #[cfg(feature = "experimental-inspect")]
463 fn write_input_type(&self, builder: &mut ConcatenationBuilder, ctx: &Ctx) {
464 match &self.ty {
465 ContainerType::StructNewtype(_, from_py_with, ty) => {
466 Self::write_field_input_type(from_py_with, ty, builder, ctx);
467 }
468 ContainerType::TupleNewtype(from_py_with, ty) => {
469 Self::write_field_input_type(from_py_with, ty, builder, ctx);
470 }
471 ContainerType::Tuple(tups) => {
472 builder.push_str("tuple[");
473 for (i, TupleStructField { from_py_with, ty }) in tups.iter().enumerate() {
474 if i > 0 {
475 builder.push_str(", ");
476 }
477 Self::write_field_input_type(from_py_with, ty, builder, ctx);
478 }
479 builder.push_str("]");
480 }
481 ContainerType::Struct(_) => {
482 builder.push_str("_typeshed.Incomplete")
484 }
485 }
486 }
487
488 #[cfg(feature = "experimental-inspect")]
489 fn write_field_input_type(
490 from_py_with: &Option<FromPyWithAttribute>,
491 ty: &syn::Type,
492 builder: &mut ConcatenationBuilder,
493 ctx: &Ctx,
494 ) {
495 if from_py_with.is_some() {
496 builder.push_str("_typeshed.Incomplete")
498 } else {
499 let ty = ty.clone().elide_lifetimes();
500 let pyo3_crate_path = &ctx.pyo3_path;
501 builder.push_tokens(
502 quote! { <#ty as #pyo3_crate_path::FromPyObject<'_>>::INPUT_TYPE.as_bytes() },
503 )
504 }
505 }
506}
507
508fn verify_and_get_lifetime(generics: &syn::Generics) -> Result<Option<&syn::LifetimeParam>> {
509 let mut lifetimes = generics.lifetimes();
510 let lifetime = lifetimes.next();
511 ensure_spanned!(
512 lifetimes.next().is_none(),
513 generics.span() => "FromPyObject can be derived with at most one lifetime parameter"
514 );
515 Ok(lifetime)
516}
517
518pub fn build_derive_from_pyobject(tokens: &DeriveInput) -> Result<TokenStream> {
527 let options = ContainerAttributes::from_attrs(&tokens.attrs)?;
528 let ctx = &Ctx::new(&options.krate, None);
529 let Ctx { pyo3_path, .. } = &ctx;
530
531 let (_, ty_generics, _) = tokens.generics.split_for_impl();
532 let mut trait_generics = tokens.generics.clone();
533 let lt_param = if let Some(lt) = verify_and_get_lifetime(&trait_generics)? {
534 lt.clone()
535 } else {
536 trait_generics.params.push(parse_quote!('py));
537 parse_quote!('py)
538 };
539 let (impl_generics, _, where_clause) = trait_generics.split_for_impl();
540
541 let mut where_clause = where_clause.cloned().unwrap_or_else(|| parse_quote!(where));
542 for param in trait_generics.type_params() {
543 let gen_ident = ¶m.ident;
544 where_clause
545 .predicates
546 .push(parse_quote!(#gen_ident: #pyo3_path::FromPyObject<'py>))
547 }
548
549 let derives = match &tokens.data {
550 syn::Data::Enum(en) => {
551 if options.transparent.is_some() || options.annotation.is_some() {
552 bail_spanned!(tokens.span() => "`transparent` or `annotation` is not supported \
553 at top level for enums");
554 }
555 let en = Enum::new(en, &tokens.ident, options.clone())?;
556 en.build(ctx)
557 }
558 syn::Data::Struct(st) => {
559 if let Some(lit_str) = &options.annotation {
560 bail_spanned!(lit_str.span() => "`annotation` is unsupported for structs");
561 }
562 let ident = &tokens.ident;
563 let st = Container::new(&st.fields, parse_quote!(#ident), options.clone())?;
564 st.build(ctx)
565 }
566 syn::Data::Union(_) => bail_spanned!(
567 tokens.span() => "#[derive(FromPyObject)] is not supported for unions"
568 ),
569 };
570
571 #[cfg(feature = "experimental-inspect")]
572 let input_type = {
573 let mut builder = ConcatenationBuilder::default();
574 if tokens
575 .generics
576 .params
577 .iter()
578 .all(|p| matches!(p, syn::GenericParam::Lifetime(_)))
579 {
580 match &tokens.data {
581 syn::Data::Enum(en) => {
582 Enum::new(en, &tokens.ident, options)?.write_input_type(&mut builder, ctx)
583 }
584 syn::Data::Struct(st) => {
585 let ident = &tokens.ident;
586 Container::new(&st.fields, parse_quote!(#ident), options.clone())?
587 .write_input_type(&mut builder, ctx)
588 }
589 syn::Data::Union(_) => {
590 builder.push_str("_typeshed.Incomplete")
592 }
593 }
594 } else {
595 builder.push_str("_typeshed.Incomplete")
598 };
599 let input_type = builder.into_token_stream(&ctx.pyo3_path);
600 quote! { const INPUT_TYPE: &'static str = unsafe { ::std::str::from_utf8_unchecked(#input_type) }; }
601 };
602 #[cfg(not(feature = "experimental-inspect"))]
603 let input_type = quote! {};
604
605 let ident = &tokens.ident;
606 Ok(quote!(
607 #[automatically_derived]
608 impl #impl_generics #pyo3_path::FromPyObject<#lt_param> for #ident #ty_generics #where_clause {
609 fn extract_bound(obj: &#pyo3_path::Bound<#lt_param, #pyo3_path::PyAny>) -> #pyo3_path::PyResult<Self> {
610 #derives
611 }
612 #input_type
613 }
614 ))
615}