pyo3_macros_backend/
frompyobject.rs

1use crate::attributes::{DefaultAttribute, FromPyWithAttribute, RenamingRule};
2use crate::derive_attributes::{ContainerAttributes, FieldAttributes, FieldGetter};
3#[cfg(feature = "experimental-inspect")]
4use crate::introspection::ConcatenationBuilder;
5#[cfg(feature = "experimental-inspect")]
6use crate::utils::TypeExt;
7use crate::utils::{self, Ctx};
8use proc_macro2::TokenStream;
9use quote::{format_ident, quote, quote_spanned, ToTokens};
10use syn::{
11    ext::IdentExt, parse_quote, punctuated::Punctuated, spanned::Spanned, DataEnum, DeriveInput,
12    Fields, Ident, Result, Token,
13};
14
/// Describes derivation input of an enum.
struct Enum<'a> {
    /// Identifier of the enum; its name is passed to the enum extraction error helper.
    enum_ident: &'a Ident,
    /// One container per variant, in declaration order.
    variants: Vec<Container<'a>>,
}
20
21impl<'a> Enum<'a> {
22    /// Construct a new enum representation.
23    ///
24    /// `data_enum` is the `syn` representation of the input enum, `ident` is the
25    /// `Identifier` of the enum.
26    fn new(
27        data_enum: &'a DataEnum,
28        ident: &'a Ident,
29        options: ContainerAttributes,
30    ) -> Result<Self> {
31        ensure_spanned!(
32            !data_enum.variants.is_empty(),
33            ident.span() => "cannot derive FromPyObject for empty enum"
34        );
35        let variants = data_enum
36            .variants
37            .iter()
38            .map(|variant| {
39                let mut variant_options = ContainerAttributes::from_attrs(&variant.attrs)?;
40                if let Some(rename_all) = &options.rename_all {
41                    ensure_spanned!(
42                        variant_options.rename_all.is_none(),
43                        variant_options.rename_all.span() => "Useless variant `rename_all` - enum is already annotated with `rename_all"
44                    );
45                    variant_options.rename_all = Some(rename_all.clone());
46
47                }
48                let var_ident = &variant.ident;
49                Container::new(
50                    &variant.fields,
51                    parse_quote!(#ident::#var_ident),
52                    variant_options,
53                )
54            })
55            .collect::<Result<Vec<_>>>()?;
56
57        Ok(Enum {
58            enum_ident: ident,
59            variants,
60        })
61    }
62
63    /// Build derivation body for enums.
64    fn build(&self, ctx: &Ctx) -> TokenStream {
65        let Ctx { pyo3_path, .. } = ctx;
66        let mut var_extracts = Vec::new();
67        let mut variant_names = Vec::new();
68        let mut error_names = Vec::new();
69
70        for var in &self.variants {
71            let struct_derive = var.build(ctx);
72            let ext = quote!({
73                let maybe_ret = || -> #pyo3_path::PyResult<Self> {
74                    #struct_derive
75                }();
76
77                match maybe_ret {
78                    ok @ ::std::result::Result::Ok(_) => return ok,
79                    ::std::result::Result::Err(err) => err
80                }
81            });
82
83            var_extracts.push(ext);
84            variant_names.push(var.path.segments.last().unwrap().ident.to_string());
85            error_names.push(&var.err_name);
86        }
87        let ty_name = self.enum_ident.to_string();
88        quote!(
89            let errors = [
90                #(#var_extracts),*
91            ];
92            ::std::result::Result::Err(
93                #pyo3_path::impl_::frompyobject::failed_to_extract_enum(
94                    obj.py(),
95                    #ty_name,
96                    &[#(#variant_names),*],
97                    &[#(#error_names),*],
98                    &errors
99                )
100            )
101        )
102    }
103
104    #[cfg(feature = "experimental-inspect")]
105    fn write_input_type(&self, builder: &mut ConcatenationBuilder, ctx: &Ctx) {
106        for (i, var) in self.variants.iter().enumerate() {
107            if i > 0 {
108                builder.push_str(" | ");
109            }
110            var.write_input_type(builder, ctx);
111        }
112    }
113}
114
/// A named field of a struct (or struct-like enum variant) together with its parsed
/// field-level attributes.
struct NamedStructField<'a> {
    /// Rust field identifier; may be a raw identifier such as `r#type`.
    ident: &'a syn::Ident,
    /// How the value is looked up on the Python object. `None` means attribute access
    /// using the (possibly renamed) field name.
    getter: Option<FieldGetter>,
    /// Optional user function used instead of `FromPyObject` to convert the looked-up value.
    from_py_with: Option<FromPyWithAttribute>,
    /// Fallback used when the lookup fails; without an explicit expression the fallback is
    /// `Default::default()`.
    default: Option<DefaultAttribute>,
    /// Declared type of the field.
    ty: &'a syn::Type,
}
122
/// An element of a tuple struct (or tuple enum variant).
struct TupleStructField {
    /// Optional user function used instead of `FromPyObject` to convert this element.
    from_py_with: Option<FromPyWithAttribute>,
    /// Declared type of the element.
    ty: syn::Type,
}
127
/// Container Style
///
/// Covers structs, tuple structs and their newtype forms (and the equivalent enum variants).
enum ContainerType<'a> {
    /// Struct container, e.g. `struct Foo { a: String }`
    ///
    /// Holds the named fields together with their parsed field attributes.
    Struct(Vec<NamedStructField<'a>>),
    /// Newtype struct container, i.e. a named struct with exactly one field marked
    /// `transparent`, e.g. `struct Foo { a: String }` with the `transparent` attribute.
    ///
    /// The field specified by the identifier is extracted directly from the object.
    // The field type is only read by `write_input_type`, which exists only with the
    // `experimental-inspect` feature.
    #[cfg_attr(not(feature = "experimental-inspect"), allow(unused))]
    StructNewtype(&'a syn::Ident, Option<FromPyWithAttribute>, &'a syn::Type),
    /// Tuple struct with two or more elements, e.g. `struct Foo(String, u32)`.
    ///
    /// The object is extracted as a Python tuple, and each element is then converted
    /// (optionally through its `from_py_with` function).
    Tuple(Vec<TupleStructField>),
    /// Tuple newtype, e.g. `struct Foo(String)`.
    ///
    /// Every 1-element tuple struct is treated as a newtype, with or without `transparent`;
    /// the wrapped field is directly extracted from the object.
    // The field type is only read by `write_input_type` (feature `experimental-inspect`).
    #[cfg_attr(not(feature = "experimental-inspect"), allow(unused))]
    TupleNewtype(Option<FromPyWithAttribute>, Box<syn::Type>),
}
152
/// Data container
///
/// Either describes a struct or an enum variant.
struct Container<'a> {
    /// Path used to construct the value: `Foo` for a struct, `Enum::Variant` for a variant.
    path: syn::Path,
    /// Shape of the container and its fields.
    ty: ContainerType<'a>,
    /// Name reported for this container when enum extraction fails: the `annotation`
    /// attribute if given, otherwise the last segment of `path`.
    err_name: String,
    /// `rename_all` rule applied to fields that do not specify an explicit Python name.
    rename_rule: Option<RenamingRule>,
}
162
163impl<'a> Container<'a> {
164    /// Construct a container based on fields, identifier and attributes.
165    ///
166    /// Fails if the variant has no fields or incompatible attributes.
167    fn new(fields: &'a Fields, path: syn::Path, options: ContainerAttributes) -> Result<Self> {
168        let style = match fields {
169            Fields::Unnamed(unnamed) if !unnamed.unnamed.is_empty() => {
170                ensure_spanned!(
171                    options.rename_all.is_none(),
172                    options.rename_all.span() => "`rename_all` is useless on tuple structs and variants."
173                );
174                let mut tuple_fields = unnamed
175                    .unnamed
176                    .iter()
177                    .map(|field| {
178                        let attrs = FieldAttributes::from_attrs(&field.attrs)?;
179                        ensure_spanned!(
180                            attrs.getter.is_none(),
181                            field.span() => "`getter` is not permitted on tuple struct elements."
182                        );
183                        ensure_spanned!(
184                            attrs.default.is_none(),
185                            field.span() => "`default` is not permitted on tuple struct elements."
186                        );
187                        Ok(TupleStructField {
188                            from_py_with: attrs.from_py_with,
189                            ty: field.ty.clone(),
190                        })
191                    })
192                    .collect::<Result<Vec<_>>>()?;
193
194                if tuple_fields.len() == 1 {
195                    // Always treat a 1-length tuple struct as "transparent", even without the
196                    // explicit annotation.
197                    let field = tuple_fields.pop().unwrap();
198                    ContainerType::TupleNewtype(field.from_py_with, Box::new(field.ty))
199                } else if options.transparent.is_some() {
200                    bail_spanned!(
201                        fields.span() => "transparent structs and variants can only have 1 field"
202                    );
203                } else {
204                    ContainerType::Tuple(tuple_fields)
205                }
206            }
207            Fields::Named(named) if !named.named.is_empty() => {
208                let mut struct_fields = named
209                    .named
210                    .iter()
211                    .map(|field| {
212                        let ident = field
213                            .ident
214                            .as_ref()
215                            .expect("Named fields should have identifiers");
216                        let mut attrs = FieldAttributes::from_attrs(&field.attrs)?;
217
218                        if let Some(ref from_item_all) = options.from_item_all {
219                            if let Some(replaced) = attrs.getter.replace(FieldGetter::GetItem(parse_quote!(item), None))
220                            {
221                                match replaced {
222                                    FieldGetter::GetItem(item, Some(item_name)) => {
223                                        attrs.getter = Some(FieldGetter::GetItem(item, Some(item_name)));
224                                    }
225                                    FieldGetter::GetItem(_, None) => bail_spanned!(from_item_all.span() => "Useless `item` - the struct is already annotated with `from_item_all`"),
226                                    FieldGetter::GetAttr(_, _) => bail_spanned!(
227                                        from_item_all.span() => "The struct is already annotated with `from_item_all`, `attribute` is not allowed"
228                                    ),
229                                }
230                            }
231                        }
232
233                        Ok(NamedStructField {
234                            ident,
235                            getter: attrs.getter,
236                            from_py_with: attrs.from_py_with,
237                            default: attrs.default,
238                            ty: &field.ty,
239                        })
240                    })
241                    .collect::<Result<Vec<_>>>()?;
242                if struct_fields.iter().all(|field| field.default.is_some()) {
243                    bail_spanned!(
244                        fields.span() => "cannot derive FromPyObject for structs and variants with only default values"
245                    )
246                } else if options.transparent.is_some() {
247                    ensure_spanned!(
248                        struct_fields.len() == 1,
249                        fields.span() => "transparent structs and variants can only have 1 field"
250                    );
251                    ensure_spanned!(
252                        options.rename_all.is_none(),
253                        options.rename_all.span() => "`rename_all` is not permitted on `transparent` structs and variants"
254                    );
255                    let field = struct_fields.pop().unwrap();
256                    ensure_spanned!(
257                        field.getter.is_none(),
258                        field.ident.span() => "`transparent` structs may not have a `getter` for the inner field"
259                    );
260                    ContainerType::StructNewtype(field.ident, field.from_py_with, field.ty)
261                } else {
262                    ContainerType::Struct(struct_fields)
263                }
264            }
265            _ => bail_spanned!(
266                fields.span() => "cannot derive FromPyObject for empty structs and variants"
267            ),
268        };
269        let err_name = options.annotation.map_or_else(
270            || path.segments.last().unwrap().ident.to_string(),
271            |lit_str| lit_str.value(),
272        );
273
274        let v = Container {
275            path,
276            ty: style,
277            err_name,
278            rename_rule: options.rename_all.map(|v| v.value.rule),
279        };
280        Ok(v)
281    }
282
283    fn name(&self) -> String {
284        let mut value = String::new();
285        for segment in &self.path.segments {
286            if !value.is_empty() {
287                value.push_str("::");
288            }
289            value.push_str(&segment.ident.to_string());
290        }
291        value
292    }
293
294    /// Build derivation body for a struct.
295    fn build(&self, ctx: &Ctx) -> TokenStream {
296        match &self.ty {
297            ContainerType::StructNewtype(ident, from_py_with, _) => {
298                self.build_newtype_struct(Some(ident), from_py_with, ctx)
299            }
300            ContainerType::TupleNewtype(from_py_with, _) => {
301                self.build_newtype_struct(None, from_py_with, ctx)
302            }
303            ContainerType::Tuple(tups) => self.build_tuple_struct(tups, ctx),
304            ContainerType::Struct(tups) => self.build_struct(tups, ctx),
305        }
306    }
307
308    fn build_newtype_struct(
309        &self,
310        field_ident: Option<&Ident>,
311        from_py_with: &Option<FromPyWithAttribute>,
312        ctx: &Ctx,
313    ) -> TokenStream {
314        let Ctx { pyo3_path, .. } = ctx;
315        let self_ty = &self.path;
316        let struct_name = self.name();
317        if let Some(ident) = field_ident {
318            let field_name = ident.to_string();
319            if let Some(FromPyWithAttribute {
320                kw,
321                value: expr_path,
322            }) = from_py_with
323            {
324                let extractor = quote_spanned! { kw.span =>
325                    { let from_py_with: fn(_) -> _ = #expr_path; from_py_with }
326                };
327                quote! {
328                    Ok(#self_ty {
329                        #ident: #pyo3_path::impl_::frompyobject::extract_struct_field_with(#extractor, obj, #struct_name, #field_name)?
330                    })
331                }
332            } else {
333                quote! {
334                    Ok(#self_ty {
335                        #ident: #pyo3_path::impl_::frompyobject::extract_struct_field(obj, #struct_name, #field_name)?
336                    })
337                }
338            }
339        } else if let Some(FromPyWithAttribute {
340            kw,
341            value: expr_path,
342        }) = from_py_with
343        {
344            let extractor = quote_spanned! { kw.span =>
345                { let from_py_with: fn(_) -> _ = #expr_path; from_py_with }
346            };
347            quote! {
348                #pyo3_path::impl_::frompyobject::extract_tuple_struct_field_with(#extractor, obj, #struct_name, 0).map(#self_ty)
349            }
350        } else {
351            quote! {
352                #pyo3_path::impl_::frompyobject::extract_tuple_struct_field(obj, #struct_name, 0).map(#self_ty)
353            }
354        }
355    }
356
357    fn build_tuple_struct(&self, struct_fields: &[TupleStructField], ctx: &Ctx) -> TokenStream {
358        let Ctx { pyo3_path, .. } = ctx;
359        let self_ty = &self.path;
360        let struct_name = &self.name();
361        let field_idents: Vec<_> = (0..struct_fields.len())
362            .map(|i| format_ident!("arg{}", i))
363            .collect();
364        let fields = struct_fields.iter().zip(&field_idents).enumerate().map(|(index, (field, ident))| {
365            if let Some(FromPyWithAttribute {
366                kw,
367                value: expr_path, ..
368            }) = &field.from_py_with {
369                let extractor = quote_spanned! { kw.span =>
370                    { let from_py_with: fn(_) -> _ = #expr_path; from_py_with }
371                };
372               quote! {
373                    #pyo3_path::impl_::frompyobject::extract_tuple_struct_field_with(#extractor, &#ident, #struct_name, #index)?
374               }
375            } else {
376                quote!{
377                    #pyo3_path::impl_::frompyobject::extract_tuple_struct_field(&#ident, #struct_name, #index)?
378            }}
379        });
380
381        quote!(
382            match #pyo3_path::types::PyAnyMethods::extract(obj) {
383                ::std::result::Result::Ok((#(#field_idents),*)) => ::std::result::Result::Ok(#self_ty(#(#fields),*)),
384                ::std::result::Result::Err(err) => ::std::result::Result::Err(err),
385            }
386        )
387    }
388
389    fn build_struct(&self, struct_fields: &[NamedStructField<'_>], ctx: &Ctx) -> TokenStream {
390        let Ctx { pyo3_path, .. } = ctx;
391        let self_ty = &self.path;
392        let struct_name = self.name();
393        let mut fields: Punctuated<TokenStream, Token![,]> = Punctuated::new();
394        for field in struct_fields {
395            let ident = field.ident;
396            let field_name = ident.unraw().to_string();
397            let getter = match field
398                .getter
399                .as_ref()
400                .unwrap_or(&FieldGetter::GetAttr(parse_quote!(attribute), None))
401            {
402                FieldGetter::GetAttr(_, Some(name)) => {
403                    quote!(#pyo3_path::types::PyAnyMethods::getattr(obj, #pyo3_path::intern!(obj.py(), #name)))
404                }
405                FieldGetter::GetAttr(_, None) => {
406                    let name = self
407                        .rename_rule
408                        .map(|rule| utils::apply_renaming_rule(rule, &field_name));
409                    let name = name.as_deref().unwrap_or(&field_name);
410                    quote!(#pyo3_path::types::PyAnyMethods::getattr(obj, #pyo3_path::intern!(obj.py(), #name)))
411                }
412                FieldGetter::GetItem(_, Some(syn::Lit::Str(key))) => {
413                    quote!(#pyo3_path::types::PyAnyMethods::get_item(obj, #pyo3_path::intern!(obj.py(), #key)))
414                }
415                FieldGetter::GetItem(_, Some(key)) => {
416                    quote!(#pyo3_path::types::PyAnyMethods::get_item(obj, #key))
417                }
418                FieldGetter::GetItem(_, None) => {
419                    let name = self
420                        .rename_rule
421                        .map(|rule| utils::apply_renaming_rule(rule, &field_name));
422                    let name = name.as_deref().unwrap_or(&field_name);
423                    quote!(#pyo3_path::types::PyAnyMethods::get_item(obj, #pyo3_path::intern!(obj.py(), #name)))
424                }
425            };
426            let extractor = if let Some(FromPyWithAttribute {
427                kw,
428                value: expr_path,
429            }) = &field.from_py_with
430            {
431                let extractor = quote_spanned! { kw.span =>
432                    { let from_py_with: fn(_) -> _ = #expr_path; from_py_with }
433                };
434                quote! (#pyo3_path::impl_::frompyobject::extract_struct_field_with(#extractor, &#getter?, #struct_name, #field_name)?)
435            } else {
436                quote!(#pyo3_path::impl_::frompyobject::extract_struct_field(&value, #struct_name, #field_name)?)
437            };
438            let extracted = if let Some(default) = &field.default {
439                let default_expr = if let Some(default_expr) = &default.value {
440                    default_expr.to_token_stream()
441                } else {
442                    quote!(::std::default::Default::default())
443                };
444                quote!(if let ::std::result::Result::Ok(value) = #getter {
445                    #extractor
446                } else {
447                    #default_expr
448                })
449            } else {
450                quote!({
451                    let value = #getter?;
452                    #extractor
453                })
454            };
455
456            fields.push(quote!(#ident: #extracted));
457        }
458
459        quote!(::std::result::Result::Ok(#self_ty{#fields}))
460    }
461
462    #[cfg(feature = "experimental-inspect")]
463    fn write_input_type(&self, builder: &mut ConcatenationBuilder, ctx: &Ctx) {
464        match &self.ty {
465            ContainerType::StructNewtype(_, from_py_with, ty) => {
466                Self::write_field_input_type(from_py_with, ty, builder, ctx);
467            }
468            ContainerType::TupleNewtype(from_py_with, ty) => {
469                Self::write_field_input_type(from_py_with, ty, builder, ctx);
470            }
471            ContainerType::Tuple(tups) => {
472                builder.push_str("tuple[");
473                for (i, TupleStructField { from_py_with, ty }) in tups.iter().enumerate() {
474                    if i > 0 {
475                        builder.push_str(", ");
476                    }
477                    Self::write_field_input_type(from_py_with, ty, builder, ctx);
478                }
479                builder.push_str("]");
480            }
481            ContainerType::Struct(_) => {
482                // TODO: implement using a Protocol?
483                builder.push_str("_typeshed.Incomplete")
484            }
485        }
486    }
487
488    #[cfg(feature = "experimental-inspect")]
489    fn write_field_input_type(
490        from_py_with: &Option<FromPyWithAttribute>,
491        ty: &syn::Type,
492        builder: &mut ConcatenationBuilder,
493        ctx: &Ctx,
494    ) {
495        if from_py_with.is_some() {
496            // We don't know what from_py_with is doing
497            builder.push_str("_typeshed.Incomplete")
498        } else {
499            let ty = ty.clone().elide_lifetimes();
500            let pyo3_crate_path = &ctx.pyo3_path;
501            builder.push_tokens(
502                quote! { <#ty as #pyo3_crate_path::FromPyObject<'_>>::INPUT_TYPE.as_bytes() },
503            )
504        }
505    }
506}
507
508fn verify_and_get_lifetime(generics: &syn::Generics) -> Result<Option<&syn::LifetimeParam>> {
509    let mut lifetimes = generics.lifetimes();
510    let lifetime = lifetimes.next();
511    ensure_spanned!(
512        lifetimes.next().is_none(),
513        generics.span() => "FromPyObject can be derived with at most one lifetime parameter"
514    );
515    Ok(lifetime)
516}
517
518/// Derive FromPyObject for enums and structs.
519///
520///   * Max 1 lifetime specifier, will be tied to `FromPyObject`'s specifier
521///   * At least one field, in case of `#[transparent]`, exactly one field
522///   * At least one variant for enums.
523///   * Fields of input structs and enums must implement `FromPyObject` or be annotated with `from_py_with`
524///   * Derivation for structs with generic fields like `struct<T> Foo(T)`
525///     adds `T: FromPyObject` on the derived implementation.
526pub fn build_derive_from_pyobject(tokens: &DeriveInput) -> Result<TokenStream> {
527    let options = ContainerAttributes::from_attrs(&tokens.attrs)?;
528    let ctx = &Ctx::new(&options.krate, None);
529    let Ctx { pyo3_path, .. } = &ctx;
530
531    let (_, ty_generics, _) = tokens.generics.split_for_impl();
532    let mut trait_generics = tokens.generics.clone();
533    let lt_param = if let Some(lt) = verify_and_get_lifetime(&trait_generics)? {
534        lt.clone()
535    } else {
536        trait_generics.params.push(parse_quote!('py));
537        parse_quote!('py)
538    };
539    let (impl_generics, _, where_clause) = trait_generics.split_for_impl();
540
541    let mut where_clause = where_clause.cloned().unwrap_or_else(|| parse_quote!(where));
542    for param in trait_generics.type_params() {
543        let gen_ident = &param.ident;
544        where_clause
545            .predicates
546            .push(parse_quote!(#gen_ident: #pyo3_path::FromPyObject<'py>))
547    }
548
549    let derives = match &tokens.data {
550        syn::Data::Enum(en) => {
551            if options.transparent.is_some() || options.annotation.is_some() {
552                bail_spanned!(tokens.span() => "`transparent` or `annotation` is not supported \
553                                                at top level for enums");
554            }
555            let en = Enum::new(en, &tokens.ident, options.clone())?;
556            en.build(ctx)
557        }
558        syn::Data::Struct(st) => {
559            if let Some(lit_str) = &options.annotation {
560                bail_spanned!(lit_str.span() => "`annotation` is unsupported for structs");
561            }
562            let ident = &tokens.ident;
563            let st = Container::new(&st.fields, parse_quote!(#ident), options.clone())?;
564            st.build(ctx)
565        }
566        syn::Data::Union(_) => bail_spanned!(
567            tokens.span() => "#[derive(FromPyObject)] is not supported for unions"
568        ),
569    };
570
571    #[cfg(feature = "experimental-inspect")]
572    let input_type = {
573        let mut builder = ConcatenationBuilder::default();
574        if tokens
575            .generics
576            .params
577            .iter()
578            .all(|p| matches!(p, syn::GenericParam::Lifetime(_)))
579        {
580            match &tokens.data {
581                syn::Data::Enum(en) => {
582                    Enum::new(en, &tokens.ident, options)?.write_input_type(&mut builder, ctx)
583                }
584                syn::Data::Struct(st) => {
585                    let ident = &tokens.ident;
586                    Container::new(&st.fields, parse_quote!(#ident), options.clone())?
587                        .write_input_type(&mut builder, ctx)
588                }
589                syn::Data::Union(_) => {
590                    // Not supported at this point
591                    builder.push_str("_typeshed.Incomplete")
592                }
593            }
594        } else {
595            // We don't know how to deal with generic parameters
596            // Blocked by https://github.com/rust-lang/rust/issues/76560
597            builder.push_str("_typeshed.Incomplete")
598        };
599        let input_type = builder.into_token_stream(&ctx.pyo3_path);
600        quote! { const INPUT_TYPE: &'static str = unsafe { ::std::str::from_utf8_unchecked(#input_type) }; }
601    };
602    #[cfg(not(feature = "experimental-inspect"))]
603    let input_type = quote! {};
604
605    let ident = &tokens.ident;
606    Ok(quote!(
607        #[automatically_derived]
608        impl #impl_generics #pyo3_path::FromPyObject<#lt_param> for #ident #ty_generics #where_clause {
609            fn extract_bound(obj: &#pyo3_path::Bound<#lt_param, #pyo3_path::PyAny>) -> #pyo3_path::PyResult<Self>  {
610                #derives
611            }
612            #input_type
613        }
614    ))
615}