use crate::{
attributes::{
self, kw, take_attributes, take_pyo3_options, CrateAttribute, ModuleAttribute,
NameAttribute, SubmoduleAttribute,
},
get_doc,
pyclass::PyClassPyO3Option,
pyfunction::{impl_wrap_pyfunction, PyFunctionOptions},
utils::{Ctx, LitCStr, PyO3CratePath},
};
use proc_macro2::{Span, TokenStream};
use quote::quote;
use std::ffi::CString;
use syn::{
ext::IdentExt,
parse::{Parse, ParseStream},
parse_quote, parse_quote_spanned,
punctuated::Punctuated,
spanned::Spanned,
token::Comma,
Item, Meta, Path, Result,
};
/// Options accepted by `#[pymodule]`.
///
/// They are collected both from the macro's own arguments (see the `Parse`
/// impl) and from `#[pyo3(...)]` attributes on the annotated item (see
/// `take_pyo3_options`). Every option starts out unset and may be specified
/// at most once.
#[derive(Default)]
pub struct PyModuleOptions {
// `crate = ...`: handed to `Ctx::new` to resolve the path of the pyo3 crate.
krate: Option<CrateAttribute>,
// `name = ...`: overrides the module name otherwise derived from the Rust identifier.
name: Option<NameAttribute>,
// `module = ...`: parent module path, prefixed to the name to form the dotted full name.
module: Option<ModuleAttribute>,
// `submodule`: when present, no `PyInit_*` entry point is generated for the module.
submodule: Option<kw::submodule>,
}
impl Parse for PyModuleOptions {
fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
let mut options: PyModuleOptions = Default::default();
options.add_attributes(
Punctuated::<PyModulePyO3Option, syn::Token![,]>::parse_terminated(input)?,
)?;
Ok(options)
}
}
impl PyModuleOptions {
fn take_pyo3_options(&mut self, attrs: &mut Vec<syn::Attribute>) -> Result<()> {
self.add_attributes(take_pyo3_options(attrs)?)
}
fn add_attributes(
&mut self,
attrs: impl IntoIterator<Item = PyModulePyO3Option>,
) -> Result<()> {
macro_rules! set_option {
($key:ident) => {
{
ensure_spanned!(
self.$key.is_none(),
$key.span() => concat!("`", stringify!($key), "` may only be specified once")
);
self.$key = Some($key);
}
};
}
for attr in attrs {
match attr {
PyModulePyO3Option::Crate(krate) => set_option!(krate),
PyModulePyO3Option::Name(name) => set_option!(name),
PyModulePyO3Option::Module(module) => set_option!(module),
PyModulePyO3Option::Submodule(submodule) => set_option!(submodule),
}
}
Ok(())
}
}
/// Expands `#[pymodule]` applied to an inline `mod` (a "declarative" module).
///
/// The module's items are scanned and, where necessary, rewritten in place:
///
/// * `use` items marked `#[pymodule_export]` have that marker removed and the
///   names they bring into scope are registered with the module (glob imports
///   are rejected; a renamed import is registered under its new name);
/// * a function marked `#[pymodule_init]` (at most one) has the marker removed
///   and is called with the module after all items have been registered;
/// * functions marked `#[pyfunction]`, structs/enums marked `#[pyclass]` and
///   nested modules marked `#[pymodule]` are registered. The classes and
///   nested modules additionally receive `#[pyo3(module = "<full name>")]`
///   unless they already declare a `module` option;
/// * `#[pymodule_export]` on anything other than a `use` item is an error.
///
/// An attribute counts as present when it is written bare (`#[pyclass]`),
/// through the pyo3 crate path (`#[pyo3::pyclass]`) or through its prelude
/// (`#[pyo3::prelude::pyclass]`).
///
/// The returned tokens are the module re-emitted with its (modified) items,
/// the module definition produced by `module_initialization`, and a
/// `__pyo3_pymodule` function that adds every registered item to the module.
///
/// # Errors
///
/// Returns an error if the module is not inline (`mod foo;`), if an option is
/// invalid or repeated, or if any of the rules above is violated.
pub fn pymodule_module_impl(
module: &mut syn::ItemMod,
mut options: PyModuleOptions,
) -> Result<TokenStream> {
let syn::ItemMod {
attrs,
vis,
unsafety: _,
ident,
mod_token,
content,
semi: _,
} = module;
// Only `mod foo { ... }` carries its items; `mod foo;` has no content to inspect.
let items = if let Some((_, items)) = content {
items
} else {
bail_spanned!(mod_token.span() => "`#[pymodule]` can only be used on inline modules")
};
// Merge `#[pyo3(...)]` options found on the module with the macro arguments.
options.take_pyo3_options(attrs)?;
let ctx = &Ctx::new(&options.krate, None);
let Ctx { pyo3_path, .. } = ctx;
let doc = get_doc(attrs, None, ctx);
// Module name: the explicit `name` option, or else the Rust identifier
// without any `r#` prefix.
let name = options
.name
.map_or_else(|| ident.unraw(), |name| name.value.0);
// Dotted name (`<module option>.<name>`) written into the `module` option
// of the classes and submodules declared inside this module.
let full_name = if let Some(module) = &options.module {
format!("{}.{}", module.value.value(), name)
} else {
name.to_string()
};
// Parallel vectors: the identifier of each item to register and the
// `#[cfg]` attributes that must guard its registration.
let mut module_items = Vec::new();
let mut module_items_cfg_attrs = Vec::new();
// Walks a `use` tree and records every name it brings into scope, pairing
// each one with a copy of `cfg_attrs`. Glob imports are rejected; renamed
// imports are recorded under the new name.
fn extract_use_items(
source: &syn::UseTree,
cfg_attrs: &[syn::Attribute],
target_items: &mut Vec<syn::Ident>,
target_cfg_attrs: &mut Vec<Vec<syn::Attribute>>,
) -> Result<()> {
match source {
syn::UseTree::Name(name) => {
target_items.push(name.ident.clone());
target_cfg_attrs.push(cfg_attrs.to_vec());
}
syn::UseTree::Path(path) => {
extract_use_items(&path.tree, cfg_attrs, target_items, target_cfg_attrs)?
}
syn::UseTree::Group(group) => {
for tree in &group.items {
extract_use_items(tree, cfg_attrs, target_items, target_cfg_attrs)?
}
}
syn::UseTree::Glob(glob) => {
bail_spanned!(glob.span() => "#[pymodule] cannot import glob statements")
}
syn::UseTree::Rename(rename) => {
target_items.push(rename.rename.clone());
target_cfg_attrs.push(cfg_attrs.to_vec());
}
}
Ok(())
}
// Call to the `#[pymodule_init]` function, if the module declares one.
let mut pymodule_init = None;
for item in &mut *items {
match item {
Item::Use(item_use) => {
// `#[pymodule_export]` is a marker consumed by this macro, so it
// is stripped from the re-emitted item.
let is_pymodule_export =
find_and_remove_attribute(&mut item_use.attrs, "pymodule_export");
if is_pymodule_export {
let cfg_attrs = get_cfg_attributes(&item_use.attrs);
extract_use_items(
&item_use.tree,
&cfg_attrs,
&mut module_items,
&mut module_items_cfg_attrs,
)?;
}
}
Item::Fn(item_fn) => {
ensure_spanned!(
!has_attribute(&item_fn.attrs, "pymodule_export"),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
// Like `#[pymodule_export]`, `#[pymodule_init]` is stripped from the output.
let is_pymodule_init =
find_and_remove_attribute(&mut item_fn.attrs, "pymodule_init");
let ident = &item_fn.sig.ident;
if is_pymodule_init {
ensure_spanned!(
!has_attribute(&item_fn.attrs, "pyfunction"),
item_fn.span() => "`#[pyfunction]` cannot be used alongside `#[pymodule_init]`"
);
ensure_spanned!(pymodule_init.is_none(), item_fn.span() => "only one `#[pymodule_init]` may be specified");
pymodule_init = Some(quote! { #ident(module)?; });
} else if has_attribute(&item_fn.attrs, "pyfunction")
|| has_attribute_with_namespace(
&item_fn.attrs,
Some(pyo3_path),
&["pyfunction"],
)
|| has_attribute_with_namespace(
&item_fn.attrs,
Some(pyo3_path),
&["prelude", "pyfunction"],
)
{
module_items.push(ident.clone());
module_items_cfg_attrs.push(get_cfg_attributes(&item_fn.attrs));
}
}
Item::Struct(item_struct) => {
ensure_spanned!(
!has_attribute(&item_struct.attrs, "pymodule_export"),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
if has_attribute(&item_struct.attrs, "pyclass")
|| has_attribute_with_namespace(
&item_struct.attrs,
Some(pyo3_path),
&["pyclass"],
)
|| has_attribute_with_namespace(
&item_struct.attrs,
Some(pyo3_path),
&["prelude", "pyclass"],
)
{
module_items.push(item_struct.ident.clone());
module_items_cfg_attrs.push(get_cfg_attributes(&item_struct.attrs));
// Default the class's `module` option to this module's full
// name unless the user already chose one.
if !has_pyo3_module_declared::<PyClassPyO3Option>(
&item_struct.attrs,
"pyclass",
|option| matches!(option, PyClassPyO3Option::Module(_)),
)? {
set_module_attribute(&mut item_struct.attrs, &full_name);
}
}
}
Item::Enum(item_enum) => {
ensure_spanned!(
!has_attribute(&item_enum.attrs, "pymodule_export"),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
if has_attribute(&item_enum.attrs, "pyclass")
|| has_attribute_with_namespace(&item_enum.attrs, Some(pyo3_path), &["pyclass"])
|| has_attribute_with_namespace(
&item_enum.attrs,
Some(pyo3_path),
&["prelude", "pyclass"],
)
{
module_items.push(item_enum.ident.clone());
module_items_cfg_attrs.push(get_cfg_attributes(&item_enum.attrs));
// Same `module` defaulting as for structs.
if !has_pyo3_module_declared::<PyClassPyO3Option>(
&item_enum.attrs,
"pyclass",
|option| matches!(option, PyClassPyO3Option::Module(_)),
)? {
set_module_attribute(&mut item_enum.attrs, &full_name);
}
}
}
Item::Mod(item_mod) => {
ensure_spanned!(
!has_attribute(&item_mod.attrs, "pymodule_export"),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
if has_attribute(&item_mod.attrs, "pymodule")
|| has_attribute_with_namespace(&item_mod.attrs, Some(pyo3_path), &["pymodule"])
|| has_attribute_with_namespace(
&item_mod.attrs,
Some(pyo3_path),
&["prelude", "pymodule"],
)
{
module_items.push(item_mod.ident.clone());
module_items_cfg_attrs.push(get_cfg_attributes(&item_mod.attrs));
// Nested modules get this module's full name as their
// `module` option unless they declare one themselves.
if !has_pyo3_module_declared::<PyModulePyO3Option>(
&item_mod.attrs,
"pymodule",
|option| matches!(option, PyModulePyO3Option::Module(_)),
)? {
set_module_attribute(&mut item_mod.attrs, &full_name);
}
}
}
// The remaining item kinds are never registered with the module; they
// are only checked for a misplaced `#[pymodule_export]`.
Item::ForeignMod(item) => {
ensure_spanned!(
!has_attribute(&item.attrs, "pymodule_export"),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
}
Item::Trait(item) => {
ensure_spanned!(
!has_attribute(&item.attrs, "pymodule_export"),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
}
Item::Const(item) => {
ensure_spanned!(
!has_attribute(&item.attrs, "pymodule_export"),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
}
Item::Static(item) => {
ensure_spanned!(
!has_attribute(&item.attrs, "pymodule_export"),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
}
Item::Macro(item) => {
ensure_spanned!(
!has_attribute(&item.attrs, "pymodule_export"),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
}
Item::ExternCrate(item) => {
ensure_spanned!(
!has_attribute(&item.attrs, "pymodule_export"),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
}
Item::Impl(item) => {
ensure_spanned!(
!has_attribute(&item.attrs, "pymodule_export"),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
}
Item::TraitAlias(item) => {
ensure_spanned!(
!has_attribute(&item.attrs, "pymodule_export"),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
}
Item::Type(item) => {
ensure_spanned!(
!has_attribute(&item.attrs, "pymodule_export"),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
}
Item::Union(item) => {
ensure_spanned!(
!has_attribute(&item.attrs, "pymodule_export"),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
}
_ => (),
}
}
// Initializer expression for the `_PYO3_DEF` static emitted by
// `module_initialization`. It refers to `__PYO3_NAME` (emitted alongside
// the static) and to the `__pyo3_pymodule` function generated below.
let module_def = quote! {{
use #pyo3_path::impl_::pymodule as impl_;
const INITIALIZER: impl_::ModuleInitializer = impl_::ModuleInitializer(__pyo3_pymodule);
unsafe {
impl_::ModuleDef::new(
__PYO3_NAME,
#doc,
INITIALIZER
)
}
}};
let initialization = module_initialization(&name, ctx, module_def, options.submodule.is_some());
// Re-emit the module with its items, followed by the generated definitions
// and the function that registers every collected item (each guarded by its
// own `#[cfg]` attributes) before running the optional init function.
Ok(quote!(
#(#attrs)*
#vis #mod_token #ident {
#(#items)*
#initialization
fn __pyo3_pymodule(module: &#pyo3_path::Bound<'_, #pyo3_path::types::PyModule>) -> #pyo3_path::PyResult<()> {
use #pyo3_path::impl_::pymodule::PyAddToModule;
#(
#(#module_items_cfg_attrs)*
#module_items::_PYO3_DEF.add_to_module(module)?;
)*
#pymodule_init
Ok(())
}
}
))
}
/// Expands `#[pymodule]` applied to a function.
///
/// `function` is modified in place:
///
/// * `#[pyo3(...)]` options are removed from its attributes and merged into
///   `options`;
/// * nested functions marked `#[pyfn(...)]` are rewritten by
///   `process_functions_in_module`;
/// * for every parameter bound by a plain identifier pattern, three
///   statements from `impl_::deprecations` (`GilRefs::new`, `inspect_type`,
///   `function_arg`) are inserted ahead of the original body;
/// * `#[allow(clippy::used_underscore_binding)]` is appended to its
///   attributes.
///
/// The returned tokens do NOT contain the function itself. They consist of a
/// hidden module named after the function, holding the definitions produced
/// by `module_initialization`, plus an `impl` of `MakeDef::make_def` that
/// builds the `ModuleDef` whose initializer calls the function.
/// NOTE(review): the caller is presumably responsible for emitting the
/// modified `function` next to these tokens; confirm at the call site.
///
/// # Errors
///
/// Returns an error if an option is invalid or repeated, or if processing a
/// nested `#[pyfn]` function fails.
pub fn pymodule_function_impl(
function: &mut syn::ItemFn,
mut options: PyModuleOptions,
) -> Result<TokenStream> {
options.take_pyo3_options(&mut function.attrs)?;
process_functions_in_module(&options, function)?;
let ctx = &Ctx::new(&options.krate, None);
// Move the body out so the generated statements can be placed in front of it.
let stmts = std::mem::take(&mut function.block.stmts);
let Ctx { pyo3_path, .. } = ctx;
let ident = &function.sig.ident;
// Module name: the explicit `name` option, or else the function identifier
// without any `r#` prefix.
let name = options
.name
.map_or_else(|| ident.unraw(), |name| name.value.0);
let vis = &function.vis;
let doc = get_doc(&function.attrs, None, ctx);
// `is_submodule` is always `false` here, so the `PyInit_*` entry point is
// always generated for function-style modules.
let initialization = module_initialization(&name, ctx, quote! { MakeDef::make_def() }, false);
// Arguments for the call to the user's function: a two-parameter function
// receives `module.py()` first, followed by the module itself; otherwise
// only the module is passed.
let mut module_args = Vec::new();
if function.sig.inputs.len() == 2 {
module_args.push(quote!(module.py()));
}
module_args
.push(quote!(::std::convert::Into::into(#pyo3_path::impl_::pymethods::BoundRef(module))));
// Three statements per identifier-pattern parameter; receivers and
// parameters bound by any other pattern are skipped. The last statement is
// spanned at the parameter so that anything it reports points there.
let extractors = function
.sig
.inputs
.iter()
.filter_map(|param| {
if let syn::FnArg::Typed(pat_type) = param {
if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
let ident: &syn::Ident = &pat_ident.ident;
return Some([
parse_quote!{ let check_gil_refs = #pyo3_path::impl_::deprecations::GilRefs::new(); },
parse_quote! { let #ident = #pyo3_path::impl_::deprecations::inspect_type(#ident, &check_gil_refs); },
parse_quote_spanned! { pat_type.span() => check_gil_refs.function_arg(); },
]);
}
}
None
})
.flatten();
// New body: generated checks first, then the original statements.
function.block.stmts = extractors.chain(stmts).collect();
function
.attrs
.push(parse_quote!(#[allow(clippy::used_underscore_binding)]));
// The hidden module shares the function's identifier; `MakeDef` is declared
// inside it by `module_initialization` and implemented here, next to the
// user's function, so `make_def` can call that function by name.
Ok(quote! {
#[doc(hidden)]
#vis mod #ident {
#initialization
}
#[allow(unknown_lints, non_local_definitions)]
impl #ident::MakeDef {
const fn make_def() -> #pyo3_path::impl_::pymodule::ModuleDef {
fn __pyo3_pymodule(module: &#pyo3_path::Bound<'_, #pyo3_path::types::PyModule>) -> #pyo3_path::PyResult<()> {
#ident(#(#module_args),*)
}
const INITIALIZER: #pyo3_path::impl_::pymodule::ModuleInitializer = #pyo3_path::impl_::pymodule::ModuleInitializer(__pyo3_pymodule);
unsafe {
#pyo3_path::impl_::pymodule::ModuleDef::new(
#ident::__PYO3_NAME,
#doc,
INITIALIZER
)
}
}
}
})
}
/// Generates the definitions shared by both `#[pymodule]` flavours.
///
/// The output always contains the `__PYO3_NAME` constant (the module name as
/// a C string), the `MakeDef` marker struct and the `_PYO3_DEF` static, which
/// is initialised with `module_def`. Unless `is_submodule` is set, it also
/// contains the exported `PyInit_<name>` entry point that creates the module
/// from `_PYO3_DEF`.
fn module_initialization(
    name: &syn::Ident,
    ctx: &Ctx,
    module_def: TokenStream,
    is_submodule: bool,
) -> TokenStream {
    let Ctx { pyo3_path, .. } = ctx;
    let module_name = name.to_string();
    let pyinit_symbol = format!("PyInit_{}", module_name);
    let pyo3_name = LitCStr::new(CString::new(module_name).unwrap(), Span::call_site(), ctx);

    let mut tokens = quote! {
        #[doc(hidden)]
        pub const __PYO3_NAME: &'static ::std::ffi::CStr = #pyo3_name;

        pub(super) struct MakeDef;
        #[doc(hidden)]
        pub static _PYO3_DEF: #pyo3_path::impl_::pymodule::ModuleDef = #module_def;
    };

    // Submodules get no entry point of their own.
    if is_submodule {
        return tokens;
    }

    tokens.extend(quote! {
        #[doc(hidden)]
        #[export_name = #pyinit_symbol]
        pub unsafe extern "C" fn __pyo3_init() -> *mut #pyo3_path::ffi::PyObject {
            #pyo3_path::impl_::trampoline::module_init(|py| _PYO3_DEF.make_module(py))
        }
    });
    tokens
}
fn process_functions_in_module(options: &PyModuleOptions, func: &mut syn::ItemFn) -> Result<()> {
let ctx = &Ctx::new(&options.krate, None);
let Ctx { pyo3_path, .. } = ctx;
let mut stmts: Vec<syn::Stmt> = Vec::new();
#[cfg(feature = "gil-refs")]
let imports = quote!(use #pyo3_path::{PyNativeType, types::PyModuleMethods};);
#[cfg(not(feature = "gil-refs"))]
let imports = quote!(use #pyo3_path::types::PyModuleMethods;);
for mut stmt in func.block.stmts.drain(..) {
if let syn::Stmt::Item(Item::Fn(func)) = &mut stmt {
if let Some(pyfn_args) = get_pyfn_attr(&mut func.attrs)? {
let module_name = pyfn_args.modname;
let wrapped_function = impl_wrap_pyfunction(func, pyfn_args.options)?;
let name = &func.sig.ident;
let statements: Vec<syn::Stmt> = syn::parse_quote! {
#wrapped_function
{
#[allow(unknown_lints, unused_imports, redundant_imports)]
#imports
#module_name.as_borrowed().add_function(#pyo3_path::wrap_pyfunction!(#name, #module_name.as_borrowed())?)?;
}
};
stmts.extend(statements);
}
};
stmts.push(stmt);
}
func.block.stmts = stmts;
Ok(())
}
/// Parsed arguments of a `#[pyfn(module, options...)]` attribute.
pub struct PyFnArgs {
// Path naming the module the function is added to (the first, mandatory argument).
modname: Path,
// Function options given after the module path; defaulted when there are none.
options: PyFunctionOptions,
}
impl Parse for PyFnArgs {
    /// Parses `module` or `module, <function options>`.
    ///
    /// A missing or malformed module path is reported with a dedicated
    /// message at the location of the underlying parse error.
    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
        let modname = match input.parse() {
            Ok(path) => path,
            Err(e) => {
                return Err(
                    err_spanned!(e.span() => "expected module as first argument to #[pyfn()]"),
                )
            }
        };

        let options = if input.is_empty() {
            Default::default()
        } else {
            let _: Comma = input.parse()?;
            input.parse()?
        };

        Ok(Self { modname, options })
    }
}
/// Extracts and parses the `#[pyfn(...)]` attribute from `attrs`, if any.
///
/// The attribute list is filtered through `take_attributes`, with the closure
/// claiming (`Ok(true)`) every `#[pyfn]` attribute it sees. When a `#[pyfn]`
/// was found, the `#[pyo3(...)]` options still present on the item are taken
/// as well and merged into the returned options.
///
/// Returns `Ok(None)` when the item carries no `#[pyfn]` attribute; in that
/// case `#[pyo3(...)]` attributes are left untouched.
///
/// # Errors
///
/// Fails if `#[pyfn]` is specified more than once, if its arguments do not
/// parse as [`PyFnArgs`], or if merging the `#[pyo3(...)]` options fails.
fn get_pyfn_attr(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Option<PyFnArgs>> {
    let mut pyfn_args: Option<PyFnArgs> = None;

    take_attributes(attrs, |attr| {
        if attr.path().is_ident("pyfn") {
            // The closing backtick after `#[pyfn]` was missing, which left
            // the inline-code markup in the diagnostic unbalanced.
            ensure_spanned!(
                pyfn_args.is_none(),
                attr.span() => "`#[pyfn]` may only be specified once"
            );
            pyfn_args = Some(attr.parse_args()?);
            Ok(true)
        } else {
            Ok(false)
        }
    })?;

    // Only items that actually are `#[pyfn]`s have their pyo3 options consumed.
    if let Some(pyfn_args) = &mut pyfn_args {
        pyfn_args
            .options
            .add_attributes(take_pyo3_options(attrs)?)?;
    }

    Ok(pyfn_args)
}
/// Returns copies of all `#[cfg(...)]` attributes in `attrs`, in their
/// original order.
fn get_cfg_attributes(attrs: &[syn::Attribute]) -> Vec<syn::Attribute> {
    let mut cfg_attrs = Vec::new();
    for attr in attrs {
        if attr.path().is_ident("cfg") {
            cfg_attrs.push(attr.clone());
        }
    }
    cfg_attrs
}
/// Removes every attribute whose path is exactly the single identifier
/// `ident` and reports whether at least one was removed.
fn find_and_remove_attribute(attrs: &mut Vec<syn::Attribute>, ident: &str) -> bool {
    let count_before = attrs.len();
    attrs.retain(|attr| !attr.path().is_ident(ident));
    attrs.len() != count_before
}
/// One expected segment of an attribute path, comparable against a
/// `syn::Ident`.
///
/// This lets an expected path mix identifiers taken from a user-supplied
/// crate path with fixed segment names.
enum IdentOrStr<'a> {
// A fixed segment name such as "pyo3" or "pyclass".
Str(&'a str),
// A segment cloned from an existing path.
Ident(syn::Ident),
}
impl PartialEq<syn::Ident> for IdentOrStr<'_> {
    /// Compares the expected segment with an actual identifier, delegating to
    /// the identifier's own equality against strings or other identifiers.
    fn eq(&self, other: &syn::Ident) -> bool {
        match self {
            Self::Str(expected) => other == expected,
            Self::Ident(expected) => other == expected,
        }
    }
}
/// Returns `true` if `attrs` contains an attribute whose path is exactly the
/// bare identifier `ident` (no crate or module prefix).
fn has_attribute(attrs: &[syn::Attribute], ident: &str) -> bool {
    let bare_path = [ident];
    has_attribute_with_namespace(attrs, None, &bare_path)
}
/// Returns `true` if any attribute in `attrs` has a path equal, segment by
/// segment, to the optional crate path followed by `idents`.
///
/// With `crate_path` set to `None` only `idents` is expected. A default crate
/// path contributes the single segment `pyo3`; a user-supplied one contributes
/// all of its segments.
fn has_attribute_with_namespace(
    attrs: &[syn::Attribute],
    crate_path: Option<&PyO3CratePath>,
    idents: &[&str],
) -> bool {
    let mut expected = Vec::new();
    match crate_path {
        Some(PyO3CratePath::Given(path)) => expected.extend(
            path.segments
                .iter()
                .map(|segment| IdentOrStr::Ident(segment.ident.clone())),
        ),
        Some(PyO3CratePath::Default) => expected.push(IdentOrStr::Str("pyo3")),
        None => {}
    }
    expected.extend(idents.iter().copied().map(IdentOrStr::Str));

    attrs.iter().any(|attr| {
        let actual = attr.path().segments.iter().map(|segment| &segment.ident);
        expected.iter().eq(actual)
    })
}
/// Appends `#[pyo3(module = "<module_name>")]` to `attrs`.
fn set_module_attribute(attrs: &mut Vec<syn::Attribute>, module_name: &str) {
    let module_attr: syn::Attribute = parse_quote!(#[pyo3(module = #module_name)]);
    attrs.push(module_attr);
}
/// Reports whether any attribute in `attrs` already declares a `module`
/// option.
///
/// Only list-style attributes (`#[name(...)]`) whose path is the bare
/// identifier `pyo3` or `root_attribute_name` are inspected. Their arguments
/// are parsed as a comma-separated list of `T`, and `is_module_option`
/// decides which parsed options count as a module declaration.
///
/// # Errors
///
/// Fails if the arguments of an inspected attribute do not parse as `T`s.
fn has_pyo3_module_declared<T: Parse>(
    attrs: &[syn::Attribute],
    root_attribute_name: &str,
    is_module_option: impl Fn(&T) -> bool + Copy,
) -> Result<bool> {
    for attr in attrs {
        let path = attr.path();
        let is_candidate = matches!(attr.meta, Meta::List(_))
            && (path.is_ident("pyo3") || path.is_ident(root_attribute_name));
        if !is_candidate {
            continue;
        }

        let parsed = attr.parse_args_with(Punctuated::<T, Comma>::parse_terminated)?;
        if parsed.iter().any(|option| is_module_option(option)) {
            return Ok(true);
        }
    }
    Ok(false)
}
/// A single option accepted by `#[pymodule(...)]`, or by `#[pyo3(...)]` on a
/// `#[pymodule]` item.
enum PyModulePyO3Option {
// `submodule`
Submodule(SubmoduleAttribute),
// `crate = ...`
Crate(CrateAttribute),
// `name = ...`
Name(NameAttribute),
// `module = ...`
Module(ModuleAttribute),
}
impl Parse for PyModulePyO3Option {
    /// Parses one module option, picking the variant from the leading keyword
    /// (`name`, `crate`, `module` or `submodule`).
    ///
    /// Any other token produces the lookahead's "expected one of ..." error.
    fn parse(input: ParseStream<'_>) -> Result<Self> {
        let lookahead = input.lookahead1();

        if lookahead.peek(attributes::kw::name) {
            return Ok(Self::Name(input.parse()?));
        }
        if lookahead.peek(syn::Token![crate]) {
            return Ok(Self::Crate(input.parse()?));
        }
        if lookahead.peek(attributes::kw::module) {
            return Ok(Self::Module(input.parse()?));
        }
        if lookahead.peek(attributes::kw::submodule) {
            return Ok(Self::Submodule(input.parse()?));
        }

        Err(lookahead.error())
    }
}