pyo3_macros_backend/
attributes.rs

1use proc_macro2::TokenStream;
2use quote::{quote, ToTokens};
3use syn::parse::Parser;
4use syn::{
5    ext::IdentExt,
6    parse::{Parse, ParseStream},
7    punctuated::Punctuated,
8    spanned::Spanned,
9    token::Comma,
10    Attribute, Expr, ExprPath, Ident, Index, LitBool, LitStr, Member, Path, Result, Token,
11};
12
13pub mod kw {
14    syn::custom_keyword!(annotation);
15    syn::custom_keyword!(attribute);
16    syn::custom_keyword!(cancel_handle);
17    syn::custom_keyword!(constructor);
18    syn::custom_keyword!(dict);
19    syn::custom_keyword!(eq);
20    syn::custom_keyword!(eq_int);
21    syn::custom_keyword!(extends);
22    syn::custom_keyword!(freelist);
23    syn::custom_keyword!(from_py_with);
24    syn::custom_keyword!(frozen);
25    syn::custom_keyword!(get);
26    syn::custom_keyword!(get_all);
27    syn::custom_keyword!(hash);
28    syn::custom_keyword!(item);
29    syn::custom_keyword!(from_item_all);
30    syn::custom_keyword!(mapping);
31    syn::custom_keyword!(module);
32    syn::custom_keyword!(name);
33    syn::custom_keyword!(ord);
34    syn::custom_keyword!(pass_module);
35    syn::custom_keyword!(rename_all);
36    syn::custom_keyword!(sequence);
37    syn::custom_keyword!(set);
38    syn::custom_keyword!(set_all);
39    syn::custom_keyword!(signature);
40    syn::custom_keyword!(str);
41    syn::custom_keyword!(subclass);
42    syn::custom_keyword!(submodule);
43    syn::custom_keyword!(text_signature);
44    syn::custom_keyword!(transparent);
45    syn::custom_keyword!(unsendable);
46    syn::custom_keyword!(weakref);
47    syn::custom_keyword!(gil_used);
48}
49
/// Consumes the leading run of ASCII digits from `read` and returns it.
///
/// `read` is advanced past the digits. This also happens when the digits run all the way to
/// the end of the input, in which case `read` becomes empty. `tracker` is incremented by the
/// number of bytes consumed. If `read` does not start with a digit, nothing is consumed and an
/// empty string is returned.
fn take_int(read: &mut &str, tracker: &mut usize) -> String {
    // Copy the `&str` out so the split halves borrow the underlying text, not `read` itself.
    let input = *read;
    let len = input
        .find(|ch: char| !ch.is_ascii_digit())
        .unwrap_or(input.len());
    let (digits, rest) = input.split_at(len);
    *read = rest;
    // Digits are ASCII, so the byte length equals the character count.
    *tracker += len;
    digits.to_owned()
}
66
67fn take_ident(read: &mut &str, tracker: &mut usize) -> Ident {
68    let mut ident = String::new();
69    if read.starts_with("r#") {
70        ident.push_str("r#");
71        *tracker += 2;
72        *read = &read[2..];
73    }
74    for (i, ch) in read.char_indices() {
75        match ch {
76            'a'..='z' | 'A'..='Z' | '0'..='9' | '_' => {
77                *tracker += 1;
78                ident.push(ch)
79            }
80            _ => {
81                *read = &read[i..];
82                break;
83            }
84        }
85    }
86    Ident::parse_any.parse_str(&ident).unwrap()
87}
88
89// shorthand parsing logic inspiration taken from https://github.com/dtolnay/thiserror/blob/master/impl/src/fmt.rs
90fn parse_shorthand_format(fmt: LitStr) -> Result<(LitStr, Vec<Member>)> {
91    let span = fmt.span();
92    let token = fmt.token();
93    let value = fmt.value();
94    let mut read = value.as_str();
95    let mut out = String::new();
96    let mut members = Vec::new();
97    let mut tracker = 1;
98    while let Some(brace) = read.find('{') {
99        tracker += brace;
100        out += &read[..brace + 1];
101        read = &read[brace + 1..];
102        if read.starts_with('{') {
103            out.push('{');
104            read = &read[1..];
105            tracker += 2;
106            continue;
107        }
108        let next = match read.chars().next() {
109            Some(next) => next,
110            None => break,
111        };
112        tracker += 1;
113        let member = match next {
114            '0'..='9' => {
115                let start = tracker;
116                let index = take_int(&mut read, &mut tracker).parse::<u32>().unwrap();
117                let end = tracker;
118                let subspan = token.subspan(start..end).unwrap_or(span);
119                let idx = Index {
120                    index,
121                    span: subspan,
122                };
123                Member::Unnamed(idx)
124            }
125            'a'..='z' | 'A'..='Z' | '_' => {
126                let start = tracker;
127                let mut ident = take_ident(&mut read, &mut tracker);
128                let end = tracker;
129                let subspan = token.subspan(start..end).unwrap_or(span);
130                ident.set_span(subspan);
131                Member::Named(ident)
132            }
133            '}' | ':' => {
134                let start = tracker;
135                tracker += 1;
136                let end = tracker;
137                let subspan = token.subspan(start..end).unwrap_or(span);
138                // we found a closing bracket or formatting ':' without finding a member, we assume the user wants the instance formatted here
139                bail_spanned!(subspan.span() => "No member found, you must provide a named or positionally specified member.")
140            }
141            _ => continue,
142        };
143        members.push(member);
144    }
145    out += read;
146    Ok((LitStr::new(&out, span), members))
147}
148
149#[derive(Clone, Debug)]
150pub struct StringFormatter {
151    pub fmt: LitStr,
152    pub args: Vec<Member>,
153}
154
155impl Parse for crate::attributes::StringFormatter {
156    fn parse(input: ParseStream<'_>) -> Result<Self> {
157        let (fmt, args) = parse_shorthand_format(input.parse()?)?;
158        Ok(Self { fmt, args })
159    }
160}
161
162impl ToTokens for crate::attributes::StringFormatter {
163    fn to_tokens(&self, tokens: &mut TokenStream) {
164        self.fmt.to_tokens(tokens);
165        tokens.extend(quote! {self.args})
166    }
167}
168
/// An attribute option written as `keyword = value`.
#[derive(Clone, Debug)]
pub struct KeywordAttribute<K, V> {
    /// The keyword token, e.g. `kw::name` or `Token![crate]`.
    pub kw: K,
    /// The value that follows the `=`.
    pub value: V,
}

/// An attribute option written either as a bare `keyword` or as `keyword = value`.
#[derive(Clone, Debug)]
pub struct OptionalKeywordAttribute<K, V> {
    /// The keyword token, e.g. `kw::str`.
    pub kw: K,
    /// The value after `=`, or `None` when the keyword appears on its own.
    pub value: Option<V>,
}
180
181/// A helper type which parses the inner type via a literal string
182/// e.g. `LitStrValue<Path>` -> parses "some::path" in quotes.
183#[derive(Clone, Debug, PartialEq, Eq)]
184pub struct LitStrValue<T>(pub T);
185
186impl<T: Parse> Parse for LitStrValue<T> {
187    fn parse(input: ParseStream<'_>) -> Result<Self> {
188        let lit_str: LitStr = input.parse()?;
189        lit_str.parse().map(LitStrValue)
190    }
191}
192
193impl<T: ToTokens> ToTokens for LitStrValue<T> {
194    fn to_tokens(&self, tokens: &mut TokenStream) {
195        self.0.to_tokens(tokens)
196    }
197}
198
199/// A helper type which parses a name via a literal string
200#[derive(Clone, Debug, PartialEq, Eq)]
201pub struct NameLitStr(pub Ident);
202
203impl Parse for NameLitStr {
204    fn parse(input: ParseStream<'_>) -> Result<Self> {
205        let string_literal: LitStr = input.parse()?;
206        if let Ok(ident) = string_literal.parse_with(Ident::parse_any) {
207            Ok(NameLitStr(ident))
208        } else {
209            bail_spanned!(string_literal.span() => "expected a single identifier in double quotes")
210        }
211    }
212}
213
214impl ToTokens for NameLitStr {
215    fn to_tokens(&self, tokens: &mut TokenStream) {
216        self.0.to_tokens(tokens)
217    }
218}
219
/// A renaming rule that can be selected with `rename_all = "..."`.
///
/// Each variant's doc shows the exact string that selects it.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RenamingRule {
    /// `"camelCase"`
    CamelCase,
    /// `"kebab-case"`
    KebabCase,
    /// `"lowercase"`
    Lowercase,
    /// `"PascalCase"`
    PascalCase,
    /// `"SCREAMING-KEBAB-CASE"`
    ScreamingKebabCase,
    /// `"SCREAMING_SNAKE_CASE"`
    ScreamingSnakeCase,
    /// `"snake_case"`
    SnakeCase,
    /// `"UPPERCASE"`
    Uppercase,
}
232
233/// A helper type which parses a renaming rule via a literal string
234#[derive(Clone, Debug, PartialEq, Eq)]
235pub struct RenamingRuleLitStr {
236    pub lit: LitStr,
237    pub rule: RenamingRule,
238}
239
240impl Parse for RenamingRuleLitStr {
241    fn parse(input: ParseStream<'_>) -> Result<Self> {
242        let string_literal: LitStr = input.parse()?;
243        let rule = match string_literal.value().as_ref() {
244            "camelCase" => RenamingRule::CamelCase,
245            "kebab-case" => RenamingRule::KebabCase,
246            "lowercase" => RenamingRule::Lowercase,
247            "PascalCase" => RenamingRule::PascalCase,
248            "SCREAMING-KEBAB-CASE" => RenamingRule::ScreamingKebabCase,
249            "SCREAMING_SNAKE_CASE" => RenamingRule::ScreamingSnakeCase,
250            "snake_case" => RenamingRule::SnakeCase,
251            "UPPERCASE" => RenamingRule::Uppercase,
252            _ => {
253                bail_spanned!(string_literal.span() => "expected a valid renaming rule, possible values are: \"camelCase\", \"kebab-case\", \"lowercase\", \"PascalCase\", \"SCREAMING-KEBAB-CASE\", \"SCREAMING_SNAKE_CASE\", \"snake_case\", \"UPPERCASE\"")
254            }
255        };
256        Ok(Self {
257            lit: string_literal,
258            rule,
259        })
260    }
261}
262
263impl ToTokens for RenamingRuleLitStr {
264    fn to_tokens(&self, tokens: &mut TokenStream) {
265        self.lit.to_tokens(tokens)
266    }
267}
268
269/// Text signatue can be either a literal string or opt-in/out
270#[derive(Clone, Debug, PartialEq, Eq)]
271pub enum TextSignatureAttributeValue {
272    Str(LitStr),
273    // `None` ident to disable automatic text signature generation
274    Disabled(Ident),
275}
276
277impl Parse for TextSignatureAttributeValue {
278    fn parse(input: ParseStream<'_>) -> Result<Self> {
279        if let Ok(lit_str) = input.parse::<LitStr>() {
280            return Ok(TextSignatureAttributeValue::Str(lit_str));
281        }
282
283        let err_span = match input.parse::<Ident>() {
284            Ok(ident) if ident == "None" => {
285                return Ok(TextSignatureAttributeValue::Disabled(ident));
286            }
287            Ok(other_ident) => other_ident.span(),
288            Err(e) => e.span(),
289        };
290
291        Err(err_spanned!(err_span => "expected a string literal or `None`"))
292    }
293}
294
295impl ToTokens for TextSignatureAttributeValue {
296    fn to_tokens(&self, tokens: &mut TokenStream) {
297        match self {
298            TextSignatureAttributeValue::Str(s) => s.to_tokens(tokens),
299            TextSignatureAttributeValue::Disabled(b) => b.to_tokens(tokens),
300        }
301    }
302}
303
304pub type ExtendsAttribute = KeywordAttribute<kw::extends, Path>;
305pub type FreelistAttribute = KeywordAttribute<kw::freelist, Box<Expr>>;
306pub type ModuleAttribute = KeywordAttribute<kw::module, LitStr>;
307pub type NameAttribute = KeywordAttribute<kw::name, NameLitStr>;
308pub type RenameAllAttribute = KeywordAttribute<kw::rename_all, RenamingRuleLitStr>;
309pub type StrFormatterAttribute = OptionalKeywordAttribute<kw::str, StringFormatter>;
310pub type TextSignatureAttribute = KeywordAttribute<kw::text_signature, TextSignatureAttributeValue>;
311pub type SubmoduleAttribute = kw::submodule;
312pub type GILUsedAttribute = KeywordAttribute<kw::gil_used, LitBool>;
313
314impl<K: Parse + std::fmt::Debug, V: Parse> Parse for KeywordAttribute<K, V> {
315    fn parse(input: ParseStream<'_>) -> Result<Self> {
316        let kw: K = input.parse()?;
317        let _: Token![=] = input.parse()?;
318        let value = input.parse()?;
319        Ok(KeywordAttribute { kw, value })
320    }
321}
322
323impl<K: ToTokens, V: ToTokens> ToTokens for KeywordAttribute<K, V> {
324    fn to_tokens(&self, tokens: &mut TokenStream) {
325        self.kw.to_tokens(tokens);
326        Token![=](self.kw.span()).to_tokens(tokens);
327        self.value.to_tokens(tokens);
328    }
329}
330
331impl<K: Parse + std::fmt::Debug, V: Parse> Parse for OptionalKeywordAttribute<K, V> {
332    fn parse(input: ParseStream<'_>) -> Result<Self> {
333        let kw: K = input.parse()?;
334        let value = match input.parse::<Token![=]>() {
335            Ok(_) => Some(input.parse()?),
336            Err(_) => None,
337        };
338        Ok(OptionalKeywordAttribute { kw, value })
339    }
340}
341
342impl<K: ToTokens, V: ToTokens> ToTokens for OptionalKeywordAttribute<K, V> {
343    fn to_tokens(&self, tokens: &mut TokenStream) {
344        self.kw.to_tokens(tokens);
345        if self.value.is_some() {
346            Token![=](self.kw.span()).to_tokens(tokens);
347            self.value.to_tokens(tokens);
348        }
349    }
350}
351
352pub type FromPyWithAttribute = KeywordAttribute<kw::from_py_with, LitStrValue<ExprPath>>;
353
354/// For specifying the path to the pyo3 crate.
355pub type CrateAttribute = KeywordAttribute<Token![crate], LitStrValue<Path>>;
356
357pub fn get_pyo3_options<T: Parse>(attr: &syn::Attribute) -> Result<Option<Punctuated<T, Comma>>> {
358    if attr.path().is_ident("pyo3") {
359        attr.parse_args_with(Punctuated::parse_terminated).map(Some)
360    } else {
361        Ok(None)
362    }
363}
364
365/// Takes attributes from an attribute vector.
366///
367/// For each attribute in `attrs`, `extractor` is called. If `extractor` returns `Ok(true)`, then
368/// the attribute will be removed from the vector.
369///
370/// This is similar to `Vec::retain` except the closure is fallible and the condition is reversed.
371/// (In `retain`, returning `true` keeps the element, here it removes it.)
372pub fn take_attributes(
373    attrs: &mut Vec<Attribute>,
374    mut extractor: impl FnMut(&Attribute) -> Result<bool>,
375) -> Result<()> {
376    *attrs = attrs
377        .drain(..)
378        .filter_map(|attr| {
379            extractor(&attr)
380                .map(move |attribute_handled| if attribute_handled { None } else { Some(attr) })
381                .transpose()
382        })
383        .collect::<Result<_>>()?;
384    Ok(())
385}
386
387pub fn take_pyo3_options<T: Parse>(attrs: &mut Vec<syn::Attribute>) -> Result<Vec<T>> {
388    let mut out = Vec::new();
389    let mut all_errors = ErrorCombiner(None);
390    take_attributes(attrs, |attr| match get_pyo3_options(attr) {
391        Ok(result) => {
392            if let Some(options) = result {
393                out.extend(options);
394                Ok(true)
395            } else {
396                Ok(false)
397            }
398        }
399        Err(err) => {
400            all_errors.combine(err);
401            Ok(true)
402        }
403    })?;
404    all_errors.ensure_empty()?;
405    Ok(out)
406}
407
408pub struct ErrorCombiner(pub Option<syn::Error>);
409
410impl ErrorCombiner {
411    pub fn combine(&mut self, error: syn::Error) {
412        if let Some(existing) = &mut self.0 {
413            existing.combine(error);
414        } else {
415            self.0 = Some(error);
416        }
417    }
418
419    pub fn ensure_empty(self) -> Result<()> {
420        if let Some(error) = self.0 {
421            Err(error)
422        } else {
423            Ok(())
424        }
425    }
426}