pyo3_macros_backend/
intopyobject.rs

1use crate::attributes::{IntoPyWithAttribute, RenamingRule};
2use crate::derive_attributes::{ContainerAttributes, FieldAttributes};
3use crate::utils::{self, Ctx};
4use proc_macro2::{Span, TokenStream};
5use quote::{format_ident, quote, quote_spanned, ToTokens};
6use syn::ext::IdentExt;
7use syn::spanned::Spanned as _;
8use syn::{parse_quote, DataEnum, DeriveInput, Fields, Ident, Index, Result};
9
/// Payload of an `item` field attribute: the explicit dict key literal, if one was given.
///
/// When the literal is absent, the field's (possibly renamed) Rust name is used as the key.
struct ItemOption(Option<syn::Lit>);
11
12enum IntoPyObjectTypes {
13    Transparent(syn::Type),
14    Opaque {
15        target: TokenStream,
16        output: TokenStream,
17        error: TokenStream,
18    },
19}
20
21struct IntoPyObjectImpl {
22    types: IntoPyObjectTypes,
23    body: TokenStream,
24}
25
26struct NamedStructField<'a> {
27    ident: &'a syn::Ident,
28    field: &'a syn::Field,
29    item: Option<ItemOption>,
30    into_py_with: Option<IntoPyWithAttribute>,
31}
32
33struct TupleStructField<'a> {
34    field: &'a syn::Field,
35    into_py_with: Option<IntoPyWithAttribute>,
36}
37
38/// Container Style
39///
40/// Covers Structs, Tuplestructs and corresponding Newtypes.
41enum ContainerType<'a> {
42    /// Struct Container, e.g. `struct Foo { a: String }`
43    ///
44    /// Variant contains the list of field identifiers and the corresponding extraction call.
45    Struct(Vec<NamedStructField<'a>>),
46    /// Newtype struct container, e.g. `#[transparent] struct Foo { a: String }`
47    ///
48    /// The field specified by the identifier is extracted directly from the object.
49    StructNewtype(&'a syn::Field),
50    /// Tuple struct, e.g. `struct Foo(String)`.
51    ///
52    /// Variant contains a list of conversion methods for each of the fields that are directly
53    ///  extracted from the tuple.
54    Tuple(Vec<TupleStructField<'a>>),
55    /// Tuple newtype, e.g. `#[transparent] struct Foo(String)`
56    ///
57    /// The wrapped field is directly extracted from the object.
58    TupleNewtype(&'a syn::Field),
59}
60
61/// Data container
62///
63/// Either describes a struct or an enum variant.
64struct Container<'a, const REF: bool> {
65    path: syn::Path,
66    receiver: Option<Ident>,
67    ty: ContainerType<'a>,
68    rename_rule: Option<RenamingRule>,
69}
70
71/// Construct a container based on fields, identifier and attributes.
72impl<'a, const REF: bool> Container<'a, REF> {
73    ///
74    /// Fails if the variant has no fields or incompatible attributes.
75    fn new(
76        receiver: Option<Ident>,
77        fields: &'a Fields,
78        path: syn::Path,
79        options: ContainerAttributes,
80    ) -> Result<Self> {
81        let style = match fields {
82            Fields::Unnamed(unnamed) if !unnamed.unnamed.is_empty() => {
83                ensure_spanned!(
84                    options.rename_all.is_none(),
85                    options.rename_all.span() => "`rename_all` is useless on tuple structs and variants."
86                );
87                let mut tuple_fields = unnamed
88                    .unnamed
89                    .iter()
90                    .map(|field| {
91                        let attrs = FieldAttributes::from_attrs(&field.attrs)?;
92                        ensure_spanned!(
93                            attrs.getter.is_none(),
94                            attrs.getter.unwrap().span() => "`item` and `attribute` are not permitted on tuple struct elements."
95                        );
96                        Ok(TupleStructField {
97                            field,
98                            into_py_with: attrs.into_py_with,
99                        })
100                    })
101                    .collect::<Result<Vec<_>>>()?;
102                if tuple_fields.len() == 1 {
103                    // Always treat a 1-length tuple struct as "transparent", even without the
104                    // explicit annotation.
105                    let TupleStructField {
106                        field,
107                        into_py_with,
108                    } = tuple_fields.pop().unwrap();
109                    ensure_spanned!(
110                        into_py_with.is_none(),
111                        into_py_with.span() => "`into_py_with` is not permitted on `transparent` structs"
112                    );
113                    ContainerType::TupleNewtype(field)
114                } else if options.transparent.is_some() {
115                    bail_spanned!(
116                        fields.span() => "transparent structs and variants can only have 1 field"
117                    );
118                } else {
119                    ContainerType::Tuple(tuple_fields)
120                }
121            }
122            Fields::Named(named) if !named.named.is_empty() => {
123                if options.transparent.is_some() {
124                    ensure_spanned!(
125                        named.named.iter().count() == 1,
126                        fields.span() => "transparent structs and variants can only have 1 field"
127                    );
128
129                    let field = named.named.iter().next().unwrap();
130                    let attrs = FieldAttributes::from_attrs(&field.attrs)?;
131                    ensure_spanned!(
132                        attrs.getter.is_none(),
133                        attrs.getter.unwrap().span() => "`transparent` structs may not have `item` nor `attribute` for the inner field"
134                    );
135                    ensure_spanned!(
136                        options.rename_all.is_none(),
137                        options.rename_all.span() => "`rename_all` is not permitted on `transparent` structs and variants"
138                    );
139                    ensure_spanned!(
140                        attrs.into_py_with.is_none(),
141                        attrs.into_py_with.span() => "`into_py_with` is not permitted on `transparent` structs or variants"
142                    );
143                    ContainerType::StructNewtype(field)
144                } else {
145                    let struct_fields = named
146                        .named
147                        .iter()
148                        .map(|field| {
149                            let ident = field
150                                .ident
151                                .as_ref()
152                                .expect("Named fields should have identifiers");
153
154                            let attrs = FieldAttributes::from_attrs(&field.attrs)?;
155
156                            Ok(NamedStructField {
157                                ident,
158                                field,
159                                item: attrs.getter.and_then(|getter| match getter {
160                                    crate::derive_attributes::FieldGetter::GetItem(_, lit) => {
161                                        Some(ItemOption(lit))
162                                    }
163                                    crate::derive_attributes::FieldGetter::GetAttr(_, _) => None,
164                                }),
165                                into_py_with: attrs.into_py_with,
166                            })
167                        })
168                        .collect::<Result<Vec<_>>>()?;
169                    ContainerType::Struct(struct_fields)
170                }
171            }
172            _ => bail_spanned!(
173                fields.span() => "cannot derive `IntoPyObject` for empty structs"
174            ),
175        };
176
177        let v = Container {
178            path,
179            receiver,
180            ty: style,
181            rename_rule: options.rename_all.map(|v| v.value.rule),
182        };
183        Ok(v)
184    }
185
186    fn match_pattern(&self) -> TokenStream {
187        let path = &self.path;
188        let pattern = match &self.ty {
189            ContainerType::Struct(fields) => fields
190                .iter()
191                .enumerate()
192                .map(|(i, f)| {
193                    let ident = f.ident;
194                    let new_ident = format_ident!("arg{i}");
195                    quote! {#ident: #new_ident,}
196                })
197                .collect::<TokenStream>(),
198            ContainerType::StructNewtype(field) => {
199                let ident = field.ident.as_ref().unwrap();
200                quote!(#ident: arg0)
201            }
202            ContainerType::Tuple(fields) => {
203                let i = (0..fields.len()).map(Index::from);
204                let idents = (0..fields.len()).map(|i| format_ident!("arg{i}"));
205                quote! { #(#i: #idents,)* }
206            }
207            ContainerType::TupleNewtype(_) => quote!(0: arg0),
208        };
209
210        quote! { #path{ #pattern } }
211    }
212
213    /// Build derivation body for a struct.
214    fn build(&self, ctx: &Ctx) -> IntoPyObjectImpl {
215        match &self.ty {
216            ContainerType::StructNewtype(field) | ContainerType::TupleNewtype(field) => {
217                self.build_newtype_struct(field, ctx)
218            }
219            ContainerType::Tuple(fields) => self.build_tuple_struct(fields, ctx),
220            ContainerType::Struct(fields) => self.build_struct(fields, ctx),
221        }
222    }
223
224    fn build_newtype_struct(&self, field: &syn::Field, ctx: &Ctx) -> IntoPyObjectImpl {
225        let Ctx { pyo3_path, .. } = ctx;
226        let ty = &field.ty;
227
228        let unpack = self
229            .receiver
230            .as_ref()
231            .map(|i| {
232                let pattern = self.match_pattern();
233                quote! { let #pattern = #i;}
234            })
235            .unwrap_or_default();
236
237        IntoPyObjectImpl {
238            types: IntoPyObjectTypes::Transparent(ty.clone()),
239            body: quote_spanned! { ty.span() =>
240                #unpack
241                #pyo3_path::conversion::IntoPyObject::into_pyobject(arg0, py)
242            },
243        }
244    }
245
246    fn build_struct(&self, fields: &[NamedStructField<'_>], ctx: &Ctx) -> IntoPyObjectImpl {
247        let Ctx { pyo3_path, .. } = ctx;
248
249        let unpack = self
250            .receiver
251            .as_ref()
252            .map(|i| {
253                let pattern = self.match_pattern();
254                quote! { let #pattern = #i;}
255            })
256            .unwrap_or_default();
257
258        let setter = fields
259            .iter()
260            .enumerate()
261            .map(|(i, f)| {
262                let key = f
263                    .item
264                    .as_ref()
265                    .and_then(|item| item.0.as_ref())
266                    .map(|item| item.into_token_stream())
267                    .unwrap_or_else(|| {
268                        let name = f.ident.unraw().to_string();
269                        self.rename_rule.map(|rule| utils::apply_renaming_rule(rule, &name)).unwrap_or(name).into_token_stream()
270                    });
271                let value = Ident::new(&format!("arg{i}"), f.field.ty.span());
272
273                if let Some(expr_path) = f.into_py_with.as_ref().map(|i|&i.value) {
274                    let cow = if REF {
275                        quote!(::std::borrow::Cow::Borrowed(#value))
276                    } else {
277                        quote!(::std::borrow::Cow::Owned(#value))
278                    };
279                    quote! {
280                        let into_py_with: fn(::std::borrow::Cow<'_, _>, #pyo3_path::Python<'py>) -> #pyo3_path::PyResult<#pyo3_path::Bound<'py, #pyo3_path::PyAny>> = #expr_path;
281                        #pyo3_path::types::PyDictMethods::set_item(&dict, #key, into_py_with(#cow, py)?)?;
282                    }
283                } else {
284                    quote! {
285                        #pyo3_path::types::PyDictMethods::set_item(&dict, #key, #value)?;
286                    }
287                }
288            })
289            .collect::<TokenStream>();
290
291        IntoPyObjectImpl {
292            types: IntoPyObjectTypes::Opaque {
293                target: quote!(#pyo3_path::types::PyDict),
294                output: quote!(#pyo3_path::Bound<'py, Self::Target>),
295                error: quote!(#pyo3_path::PyErr),
296            },
297            body: quote! {
298                #unpack
299                let dict = #pyo3_path::types::PyDict::new(py);
300                #setter
301                ::std::result::Result::Ok::<_, Self::Error>(dict)
302            },
303        }
304    }
305
306    fn build_tuple_struct(&self, fields: &[TupleStructField<'_>], ctx: &Ctx) -> IntoPyObjectImpl {
307        let Ctx { pyo3_path, .. } = ctx;
308
309        let unpack = self
310            .receiver
311            .as_ref()
312            .map(|i| {
313                let pattern = self.match_pattern();
314                quote! { let #pattern = #i;}
315            })
316            .unwrap_or_default();
317
318        let setter = fields
319            .iter()
320            .enumerate()
321            .map(|(i, f)| {
322                let ty = &f.field.ty;
323                let value = Ident::new(&format!("arg{i}"), f.field.ty.span());
324
325                if let Some(expr_path) = f.into_py_with.as_ref().map(|i|&i.value) {
326                    let cow = if REF {
327                        quote!(::std::borrow::Cow::Borrowed(#value))
328                    } else {
329                        quote!(::std::borrow::Cow::Owned(#value))
330                    };
331                    quote_spanned! { ty.span() =>
332                        {
333                            let into_py_with: fn(::std::borrow::Cow<'_, _>, #pyo3_path::Python<'py>) -> #pyo3_path::PyResult<#pyo3_path::Bound<'py, #pyo3_path::PyAny>> = #expr_path;
334                            into_py_with(#cow, py)?
335                        },
336                    }
337                } else {
338                    quote_spanned! { ty.span() =>
339                        #pyo3_path::conversion::IntoPyObject::into_pyobject(#value, py)
340                            .map(#pyo3_path::BoundObject::into_any)
341                            .map(#pyo3_path::BoundObject::into_bound)?,
342                    }
343                }
344            })
345            .collect::<TokenStream>();
346
347        IntoPyObjectImpl {
348            types: IntoPyObjectTypes::Opaque {
349                target: quote!(#pyo3_path::types::PyTuple),
350                output: quote!(#pyo3_path::Bound<'py, Self::Target>),
351                error: quote!(#pyo3_path::PyErr),
352            },
353            body: quote! {
354                #unpack
355                #pyo3_path::types::PyTuple::new(py, [#setter])
356            },
357        }
358    }
359}
360
361/// Describes derivation input of an enum.
362struct Enum<'a, const REF: bool> {
363    variants: Vec<Container<'a, REF>>,
364}
365
366impl<'a, const REF: bool> Enum<'a, REF> {
367    /// Construct a new enum representation.
368    ///
369    /// `data_enum` is the `syn` representation of the input enum, `ident` is the
370    /// `Identifier` of the enum.
371    fn new(data_enum: &'a DataEnum, ident: &'a Ident) -> Result<Self> {
372        ensure_spanned!(
373            !data_enum.variants.is_empty(),
374            ident.span() => "cannot derive `IntoPyObject` for empty enum"
375        );
376        let variants = data_enum
377            .variants
378            .iter()
379            .map(|variant| {
380                let attrs = ContainerAttributes::from_attrs(&variant.attrs)?;
381                let var_ident = &variant.ident;
382
383                ensure_spanned!(
384                    !variant.fields.is_empty(),
385                    variant.ident.span() => "cannot derive `IntoPyObject` for empty variants"
386                );
387
388                Container::new(
389                    None,
390                    &variant.fields,
391                    parse_quote!(#ident::#var_ident),
392                    attrs,
393                )
394            })
395            .collect::<Result<Vec<_>>>()?;
396
397        Ok(Enum { variants })
398    }
399
400    /// Build derivation body for enums.
401    fn build(&self, ctx: &Ctx) -> IntoPyObjectImpl {
402        let Ctx { pyo3_path, .. } = ctx;
403
404        let variants = self
405            .variants
406            .iter()
407            .map(|v| {
408                let IntoPyObjectImpl { body, .. } = v.build(ctx);
409                let pattern = v.match_pattern();
410                quote! {
411                    #pattern => {
412                        {#body}
413                            .map(#pyo3_path::BoundObject::into_any)
414                            .map(#pyo3_path::BoundObject::into_bound)
415                            .map_err(::std::convert::Into::<#pyo3_path::PyErr>::into)
416                    }
417                }
418            })
419            .collect::<TokenStream>();
420
421        IntoPyObjectImpl {
422            types: IntoPyObjectTypes::Opaque {
423                target: quote!(#pyo3_path::types::PyAny),
424                output: quote!(#pyo3_path::Bound<'py, <Self as #pyo3_path::conversion::IntoPyObject<'py>>::Target>),
425                error: quote!(#pyo3_path::PyErr),
426            },
427            body: quote! {
428                match self {
429                    #variants
430                }
431            },
432        }
433    }
434}
435
436// if there is a `'py` lifetime, we treat it as the `Python<'py>` lifetime
437fn verify_and_get_lifetime(generics: &syn::Generics) -> Option<&syn::LifetimeParam> {
438    let mut lifetimes = generics.lifetimes();
439    lifetimes.find(|l| l.lifetime.ident == "py")
440}
441
442pub fn build_derive_into_pyobject<const REF: bool>(tokens: &DeriveInput) -> Result<TokenStream> {
443    let options = ContainerAttributes::from_attrs(&tokens.attrs)?;
444    let ctx = &Ctx::new(&options.krate, None);
445    let Ctx { pyo3_path, .. } = &ctx;
446
447    let (_, ty_generics, _) = tokens.generics.split_for_impl();
448    let mut trait_generics = tokens.generics.clone();
449    if REF {
450        trait_generics.params.push(parse_quote!('_a));
451    }
452    let lt_param = if let Some(lt) = verify_and_get_lifetime(&trait_generics) {
453        lt.clone()
454    } else {
455        trait_generics.params.push(parse_quote!('py));
456        parse_quote!('py)
457    };
458    let (impl_generics, _, where_clause) = trait_generics.split_for_impl();
459
460    let mut where_clause = where_clause.cloned().unwrap_or_else(|| parse_quote!(where));
461    for param in trait_generics.type_params() {
462        let gen_ident = &param.ident;
463        where_clause.predicates.push(if REF {
464            parse_quote!(&'_a #gen_ident: #pyo3_path::conversion::IntoPyObject<'py>)
465        } else {
466            parse_quote!(#gen_ident: #pyo3_path::conversion::IntoPyObject<'py>)
467        })
468    }
469
470    let IntoPyObjectImpl { types, body } = match &tokens.data {
471        syn::Data::Enum(en) => {
472            if options.transparent.is_some() {
473                bail_spanned!(tokens.span() => "`transparent` is not supported at top level for enums");
474            }
475            if let Some(rename_all) = options.rename_all {
476                bail_spanned!(rename_all.span() => "`rename_all` is not supported at top level for enums");
477            }
478            let en = Enum::<REF>::new(en, &tokens.ident)?;
479            en.build(ctx)
480        }
481        syn::Data::Struct(st) => {
482            let ident = &tokens.ident;
483            let st = Container::<REF>::new(
484                Some(Ident::new("self", Span::call_site())),
485                &st.fields,
486                parse_quote!(#ident),
487                options,
488            )?;
489            st.build(ctx)
490        }
491        syn::Data::Union(_) => bail_spanned!(
492            tokens.span() => "#[derive(`IntoPyObject`)] is not supported for unions"
493        ),
494    };
495
496    let (target, output, error) = match types {
497        IntoPyObjectTypes::Transparent(ty) => {
498            if REF {
499                (
500                    quote! { <&'_a #ty as #pyo3_path::IntoPyObject<'py>>::Target },
501                    quote! { <&'_a #ty as #pyo3_path::IntoPyObject<'py>>::Output },
502                    quote! { <&'_a #ty as #pyo3_path::IntoPyObject<'py>>::Error },
503                )
504            } else {
505                (
506                    quote! { <#ty as #pyo3_path::IntoPyObject<'py>>::Target },
507                    quote! { <#ty as #pyo3_path::IntoPyObject<'py>>::Output },
508                    quote! { <#ty as #pyo3_path::IntoPyObject<'py>>::Error },
509                )
510            }
511        }
512        IntoPyObjectTypes::Opaque {
513            target,
514            output,
515            error,
516        } => (target, output, error),
517    };
518
519    let ident = &tokens.ident;
520    let ident = if REF {
521        quote! { &'_a #ident}
522    } else {
523        quote! { #ident }
524    };
525    Ok(quote!(
526        #[automatically_derived]
527        impl #impl_generics #pyo3_path::conversion::IntoPyObject<#lt_param> for #ident #ty_generics #where_clause {
528            type Target = #target;
529            type Output = #output;
530            type Error = #error;
531
532            fn into_pyobject(self, py: #pyo3_path::Python<#lt_param>) -> ::std::result::Result<
533                <Self as #pyo3_path::conversion::IntoPyObject<#lt_param>>::Output,
534                <Self as #pyo3_path::conversion::IntoPyObject<#lt_param>>::Error,
535            > {
536                #body
537            }
538        }
539    ))
540}