pyo3_macros_backend/
attributes.rs

1use proc_macro2::TokenStream;
2use quote::{quote, ToTokens};
3use syn::parse::Parser;
4use syn::{
5    ext::IdentExt,
6    parse::{Parse, ParseStream},
7    punctuated::Punctuated,
8    spanned::Spanned,
9    token::Comma,
10    Attribute, Expr, ExprPath, Ident, Index, LitBool, LitStr, Member, Path, Result, Token,
11};
12
13use crate::combine_errors::CombineErrors;
14
15pub mod kw {
16    syn::custom_keyword!(annotation);
17    syn::custom_keyword!(attribute);
18    syn::custom_keyword!(cancel_handle);
19    syn::custom_keyword!(constructor);
20    syn::custom_keyword!(dict);
21    syn::custom_keyword!(eq);
22    syn::custom_keyword!(eq_int);
23    syn::custom_keyword!(extends);
24    syn::custom_keyword!(freelist);
25    syn::custom_keyword!(from_py_with);
26    syn::custom_keyword!(frozen);
27    syn::custom_keyword!(get);
28    syn::custom_keyword!(get_all);
29    syn::custom_keyword!(hash);
30    syn::custom_keyword!(into_py_with);
31    syn::custom_keyword!(item);
32    syn::custom_keyword!(immutable_type);
33    syn::custom_keyword!(from_item_all);
34    syn::custom_keyword!(mapping);
35    syn::custom_keyword!(module);
36    syn::custom_keyword!(name);
37    syn::custom_keyword!(ord);
38    syn::custom_keyword!(pass_module);
39    syn::custom_keyword!(rename_all);
40    syn::custom_keyword!(sequence);
41    syn::custom_keyword!(set);
42    syn::custom_keyword!(set_all);
43    syn::custom_keyword!(signature);
44    syn::custom_keyword!(str);
45    syn::custom_keyword!(subclass);
46    syn::custom_keyword!(submodule);
47    syn::custom_keyword!(text_signature);
48    syn::custom_keyword!(transparent);
49    syn::custom_keyword!(unsendable);
50    syn::custom_keyword!(weakref);
51    syn::custom_keyword!(generic);
52    syn::custom_keyword!(gil_used);
53    syn::custom_keyword!(warn);
54    syn::custom_keyword!(message);
55    syn::custom_keyword!(category);
56}
57
/// Consumes the leading run of ASCII decimal digits from `read`.
///
/// Returns the digits (an empty string if `read` does not start with a digit), advances `read`
/// past them, and adds the number of consumed bytes to `tracker`. The advance also happens when
/// the digits run to the very end of the input, in which case `read` becomes empty.
fn take_int(read: &mut &str, tracker: &mut usize) -> String {
    let s = *read;
    // Digits are single-byte ASCII, so the byte offset equals the number of characters taken.
    let end = s.find(|ch: char| !ch.is_ascii_digit()).unwrap_or(s.len());
    *tracker += end;
    *read = &s[end..];
    s[..end].to_owned()
}
74
75fn take_ident(read: &mut &str, tracker: &mut usize) -> Ident {
76    let mut ident = String::new();
77    if read.starts_with("r#") {
78        ident.push_str("r#");
79        *tracker += 2;
80        *read = &read[2..];
81    }
82    for (i, ch) in read.char_indices() {
83        match ch {
84            'a'..='z' | 'A'..='Z' | '0'..='9' | '_' => {
85                *tracker += 1;
86                ident.push(ch)
87            }
88            _ => {
89                *read = &read[i..];
90                break;
91            }
92        }
93    }
94    Ident::parse_any.parse_str(&ident).unwrap()
95}
96
97// shorthand parsing logic inspiration taken from https://github.com/dtolnay/thiserror/blob/master/impl/src/fmt.rs
98fn parse_shorthand_format(fmt: LitStr) -> Result<(LitStr, Vec<Member>)> {
99    let span = fmt.span();
100    let token = fmt.token();
101    let value = fmt.value();
102    let mut read = value.as_str();
103    let mut out = String::new();
104    let mut members = Vec::new();
105    let mut tracker = 1;
106    while let Some(brace) = read.find('{') {
107        tracker += brace;
108        out += &read[..brace + 1];
109        read = &read[brace + 1..];
110        if read.starts_with('{') {
111            out.push('{');
112            read = &read[1..];
113            tracker += 2;
114            continue;
115        }
116        let next = match read.chars().next() {
117            Some(next) => next,
118            None => break,
119        };
120        tracker += 1;
121        let member = match next {
122            '0'..='9' => {
123                let start = tracker;
124                let index = take_int(&mut read, &mut tracker).parse::<u32>().unwrap();
125                let end = tracker;
126                let subspan = token.subspan(start..end).unwrap_or(span);
127                let idx = Index {
128                    index,
129                    span: subspan,
130                };
131                Member::Unnamed(idx)
132            }
133            'a'..='z' | 'A'..='Z' | '_' => {
134                let start = tracker;
135                let mut ident = take_ident(&mut read, &mut tracker);
136                let end = tracker;
137                let subspan = token.subspan(start..end).unwrap_or(span);
138                ident.set_span(subspan);
139                Member::Named(ident)
140            }
141            '}' | ':' => {
142                let start = tracker;
143                tracker += 1;
144                let end = tracker;
145                let subspan = token.subspan(start..end).unwrap_or(span);
146                // we found a closing bracket or formatting ':' without finding a member, we assume the user wants the instance formatted here
147                bail_spanned!(subspan.span() => "No member found, you must provide a named or positionally specified member.")
148            }
149            _ => continue,
150        };
151        members.push(member);
152    }
153    out += read;
154    Ok((LitStr::new(&out, span), members))
155}
156
157#[derive(Clone, Debug)]
158pub struct StringFormatter {
159    pub fmt: LitStr,
160    pub args: Vec<Member>,
161}
162
163impl Parse for crate::attributes::StringFormatter {
164    fn parse(input: ParseStream<'_>) -> Result<Self> {
165        let (fmt, args) = parse_shorthand_format(input.parse()?)?;
166        Ok(Self { fmt, args })
167    }
168}
169
170impl ToTokens for crate::attributes::StringFormatter {
171    fn to_tokens(&self, tokens: &mut TokenStream) {
172        self.fmt.to_tokens(tokens);
173        tokens.extend(quote! {self.args})
174    }
175}
176
/// An attribute option of the form `kw = value`.
#[derive(Debug, Clone)]
pub struct KeywordAttribute<K, V> {
    /// The keyword token.
    pub kw: K,
    /// The value that follows `=`.
    pub value: V,
}

/// An attribute option of the form `kw` or `kw = value`.
#[derive(Debug, Clone)]
pub struct OptionalKeywordAttribute<K, V> {
    /// The keyword token.
    pub kw: K,
    /// The value that follows `=`, or `None` when the keyword appears on its own.
    pub value: Option<V>,
}
188
189/// A helper type which parses the inner type via a literal string
190/// e.g. `LitStrValue<Path>` -> parses "some::path" in quotes.
191#[derive(Clone, Debug, PartialEq, Eq)]
192pub struct LitStrValue<T>(pub T);
193
194impl<T: Parse> Parse for LitStrValue<T> {
195    fn parse(input: ParseStream<'_>) -> Result<Self> {
196        let lit_str: LitStr = input.parse()?;
197        lit_str.parse().map(LitStrValue)
198    }
199}
200
201impl<T: ToTokens> ToTokens for LitStrValue<T> {
202    fn to_tokens(&self, tokens: &mut TokenStream) {
203        self.0.to_tokens(tokens)
204    }
205}
206
207/// A helper type which parses a name via a literal string
208#[derive(Clone, Debug, PartialEq, Eq)]
209pub struct NameLitStr(pub Ident);
210
211impl Parse for NameLitStr {
212    fn parse(input: ParseStream<'_>) -> Result<Self> {
213        let string_literal: LitStr = input.parse()?;
214        if let Ok(ident) = string_literal.parse_with(Ident::parse_any) {
215            Ok(NameLitStr(ident))
216        } else {
217            bail_spanned!(string_literal.span() => "expected a single identifier in double quotes")
218        }
219    }
220}
221
222impl ToTokens for NameLitStr {
223    fn to_tokens(&self, tokens: &mut TokenStream) {
224        self.0.to_tokens(tokens)
225    }
226}
227
/// The available renaming rules.
///
/// Each variant's doc shows the string literal that selects it (see `RenamingRuleLitStr`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RenamingRule {
    /// `"camelCase"`
    CamelCase,
    /// `"kebab-case"`
    KebabCase,
    /// `"lowercase"`
    Lowercase,
    /// `"PascalCase"`
    PascalCase,
    /// `"SCREAMING-KEBAB-CASE"`
    ScreamingKebabCase,
    /// `"SCREAMING_SNAKE_CASE"`
    ScreamingSnakeCase,
    /// `"snake_case"`
    SnakeCase,
    /// `"UPPERCASE"`
    Uppercase,
}
240
241/// A helper type which parses a renaming rule via a literal string
242#[derive(Clone, Debug, PartialEq, Eq)]
243pub struct RenamingRuleLitStr {
244    pub lit: LitStr,
245    pub rule: RenamingRule,
246}
247
248impl Parse for RenamingRuleLitStr {
249    fn parse(input: ParseStream<'_>) -> Result<Self> {
250        let string_literal: LitStr = input.parse()?;
251        let rule = match string_literal.value().as_ref() {
252            "camelCase" => RenamingRule::CamelCase,
253            "kebab-case" => RenamingRule::KebabCase,
254            "lowercase" => RenamingRule::Lowercase,
255            "PascalCase" => RenamingRule::PascalCase,
256            "SCREAMING-KEBAB-CASE" => RenamingRule::ScreamingKebabCase,
257            "SCREAMING_SNAKE_CASE" => RenamingRule::ScreamingSnakeCase,
258            "snake_case" => RenamingRule::SnakeCase,
259            "UPPERCASE" => RenamingRule::Uppercase,
260            _ => {
261                bail_spanned!(string_literal.span() => "expected a valid renaming rule, possible values are: \"camelCase\", \"kebab-case\", \"lowercase\", \"PascalCase\", \"SCREAMING-KEBAB-CASE\", \"SCREAMING_SNAKE_CASE\", \"snake_case\", \"UPPERCASE\"")
262            }
263        };
264        Ok(Self {
265            lit: string_literal,
266            rule,
267        })
268    }
269}
270
271impl ToTokens for RenamingRuleLitStr {
272    fn to_tokens(&self, tokens: &mut TokenStream) {
273        self.lit.to_tokens(tokens)
274    }
275}
276
277/// Text signatue can be either a literal string or opt-in/out
278#[derive(Clone, Debug, PartialEq, Eq)]
279pub enum TextSignatureAttributeValue {
280    Str(LitStr),
281    // `None` ident to disable automatic text signature generation
282    Disabled(Ident),
283}
284
285impl Parse for TextSignatureAttributeValue {
286    fn parse(input: ParseStream<'_>) -> Result<Self> {
287        if let Ok(lit_str) = input.parse::<LitStr>() {
288            return Ok(TextSignatureAttributeValue::Str(lit_str));
289        }
290
291        let err_span = match input.parse::<Ident>() {
292            Ok(ident) if ident == "None" => {
293                return Ok(TextSignatureAttributeValue::Disabled(ident));
294            }
295            Ok(other_ident) => other_ident.span(),
296            Err(e) => e.span(),
297        };
298
299        Err(err_spanned!(err_span => "expected a string literal or `None`"))
300    }
301}
302
303impl ToTokens for TextSignatureAttributeValue {
304    fn to_tokens(&self, tokens: &mut TokenStream) {
305        match self {
306            TextSignatureAttributeValue::Str(s) => s.to_tokens(tokens),
307            TextSignatureAttributeValue::Disabled(b) => b.to_tokens(tokens),
308        }
309    }
310}
311
312pub type ExtendsAttribute = KeywordAttribute<kw::extends, Path>;
313pub type FreelistAttribute = KeywordAttribute<kw::freelist, Box<Expr>>;
314pub type ModuleAttribute = KeywordAttribute<kw::module, LitStr>;
315pub type NameAttribute = KeywordAttribute<kw::name, NameLitStr>;
316pub type RenameAllAttribute = KeywordAttribute<kw::rename_all, RenamingRuleLitStr>;
317pub type StrFormatterAttribute = OptionalKeywordAttribute<kw::str, StringFormatter>;
318pub type TextSignatureAttribute = KeywordAttribute<kw::text_signature, TextSignatureAttributeValue>;
319pub type SubmoduleAttribute = kw::submodule;
320pub type GILUsedAttribute = KeywordAttribute<kw::gil_used, LitBool>;
321
322impl<K: Parse + std::fmt::Debug, V: Parse> Parse for KeywordAttribute<K, V> {
323    fn parse(input: ParseStream<'_>) -> Result<Self> {
324        let kw: K = input.parse()?;
325        let _: Token![=] = input.parse()?;
326        let value = input.parse()?;
327        Ok(KeywordAttribute { kw, value })
328    }
329}
330
331impl<K: ToTokens, V: ToTokens> ToTokens for KeywordAttribute<K, V> {
332    fn to_tokens(&self, tokens: &mut TokenStream) {
333        self.kw.to_tokens(tokens);
334        Token![=](self.kw.span()).to_tokens(tokens);
335        self.value.to_tokens(tokens);
336    }
337}
338
339impl<K: Parse + std::fmt::Debug, V: Parse> Parse for OptionalKeywordAttribute<K, V> {
340    fn parse(input: ParseStream<'_>) -> Result<Self> {
341        let kw: K = input.parse()?;
342        let value = match input.parse::<Token![=]>() {
343            Ok(_) => Some(input.parse()?),
344            Err(_) => None,
345        };
346        Ok(OptionalKeywordAttribute { kw, value })
347    }
348}
349
350impl<K: ToTokens, V: ToTokens> ToTokens for OptionalKeywordAttribute<K, V> {
351    fn to_tokens(&self, tokens: &mut TokenStream) {
352        self.kw.to_tokens(tokens);
353        if self.value.is_some() {
354            Token![=](self.kw.span()).to_tokens(tokens);
355            self.value.to_tokens(tokens);
356        }
357    }
358}
359
360pub type FromPyWithAttribute = KeywordAttribute<kw::from_py_with, ExprPath>;
361pub type IntoPyWithAttribute = KeywordAttribute<kw::into_py_with, ExprPath>;
362
363pub type DefaultAttribute = OptionalKeywordAttribute<Token![default], Expr>;
364
365/// For specifying the path to the pyo3 crate.
366pub type CrateAttribute = KeywordAttribute<Token![crate], LitStrValue<Path>>;
367
368pub fn get_pyo3_options<T: Parse>(attr: &syn::Attribute) -> Result<Option<Punctuated<T, Comma>>> {
369    if attr.path().is_ident("pyo3") {
370        attr.parse_args_with(Punctuated::parse_terminated).map(Some)
371    } else {
372        Ok(None)
373    }
374}
375
376/// Takes attributes from an attribute vector.
377///
378/// For each attribute in `attrs`, `extractor` is called. If `extractor` returns `Ok(true)`, then
379/// the attribute will be removed from the vector.
380///
381/// This is similar to `Vec::retain` except the closure is fallible and the condition is reversed.
382/// (In `retain`, returning `true` keeps the element, here it removes it.)
383pub fn take_attributes(
384    attrs: &mut Vec<Attribute>,
385    mut extractor: impl FnMut(&Attribute) -> Result<bool>,
386) -> Result<()> {
387    *attrs = attrs
388        .drain(..)
389        .filter_map(|attr| {
390            extractor(&attr)
391                .map(move |attribute_handled| if attribute_handled { None } else { Some(attr) })
392                .transpose()
393        })
394        .collect::<Result<_>>()?;
395    Ok(())
396}
397
398pub fn take_pyo3_options<T: Parse>(attrs: &mut Vec<syn::Attribute>) -> Result<Vec<T>> {
399    let mut out = Vec::new();
400
401    take_attributes(attrs, |attr| match get_pyo3_options(attr) {
402        Ok(result) => {
403            if let Some(options) = result {
404                out.extend(options.into_iter().map(|a| Ok(a)));
405                Ok(true)
406            } else {
407                Ok(false)
408            }
409        }
410        Err(err) => {
411            out.push(Err(err));
412            Ok(true)
413        }
414    })?;
415
416    let out: Vec<T> = out.into_iter().try_combine_syn_errors()?;
417
418    Ok(out)
419}