// linearize_derive/lib.rs

1use {
2    proc_macro2::{Ident, Span, TokenStream, TokenTree},
3    quote::{quote, quote_spanned},
4    syn::{
5        parse::{Parse, ParseStream},
6        parse_macro_input, parse_quote,
7        spanned::Spanned,
8        Attribute, Error, Generics, Item, ItemEnum, ItemStruct, LitInt, Path, Token, Type,
9    },
10};
11
12/// A proc macro to derive the `Linearize` trait.
13///
14/// This macro can be used to derive the `Linearize` trait for structs and enums.
15///
16/// The structure of these types can be arbitrary except that all contained fields must
17/// also implement the `Linearize` trait.
18///
19/// # Using different crate names
20///
21/// If you use the linearize crate under a name other than `linearize`, you can use the
22/// `crate` attribute to have the proc macro reference the correct crate. For example,
23/// if you import the linearize crate like this:
24///
25/// ```toml
26/// linearize-0_1 = { package = "linearize", version = "0.1" }
27/// ```
28///
29/// Then you can use this attribute as follows:
30///
31/// ```rust,ignore
32/// #[derive(Linearize)]
33/// #[linearize(crate = linearize_0_1)]
34/// struct S;
35/// ```
36///
37/// <div class="warning">
38///
39/// If you import the linearize crate under a name other than `linearize` or use the crate
40/// attribute, you must ensure that these two names are in sync. Otherwise the macro
41/// might not uphold the invariants of the `Linearize` trait.
42///
43/// </div>
44///
45/// # Implementing const functions
46///
47/// If you want to use the forms of the `static_map` and `static_copy_map` macros that
48/// work in constants and statics, you must enable the `const` attribute:
49///
50/// ```rust,ignore
51/// #[derive(Linearize)]
52/// #[linearize(const)]
53/// struct S;
54/// ```
55///
56/// In this case, your type must only contain fields that also enabled this attribute. In
57/// particular, you cannot use any of the standard types `u8`, `bool`, etc.
58///
59/// # Performance
60///
61/// If the type is a C-style enum with default discriminants, the derived functions will
62/// be compiled to a jump table in debug mode and will be completely optimized away in
63/// release mode.
64///
65/// If the type contains fields, the generated code will still be reasonably efficient.
66///
67/// # Limitations
68///
69/// While this macro fully supports types with generics, the generated output will not
70/// compile. This is due to limitations of the rust type system. If a future version of
71/// the rust compiler lifts these limitations, this macro will automatically start working
72/// for generic types.
73#[proc_macro_derive(Linearize, attributes(linearize))]
74pub fn derive_linearize(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
75    let mut input: Input = parse_macro_input!(input as Input);
76    let crate_name = &input.attributes.crate_name;
77    let FullyLinearized {
78        linearize,
79        delinearize,
80        const_linearize,
81        const_delinearize,
82        const_names,
83        consts,
84        max_len,
85    } = input.build_linearize();
86    let where_clause = input.generics.make_where_clause();
87    for ty in &input.critical_types {
88        where_clause
89            .predicates
90            .push(parse_quote!(#ty: #crate_name::Linearize));
91    }
92    let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
93    let ident = input.ident;
94    let mut const_impl = quote! {};
95    if input.attributes.enable_const {
96        const_impl = quote! {
97            #[doc(hidden)]
98            impl #impl_generics #ident #type_generics #where_clause {
99                #[inline]
100                pub const fn __linearize_d66aa8fa_6974_4651_b2b7_75291a9e7105(&self) -> usize {
101                    #const_linearize
102                }
103
104                #[inline]
105                pub const unsafe fn __from_linear_unchecked_fb2f0b31_5b5a_48b4_9264_39d0bdf94f1d(linear: usize) -> Self {
106                    #const_delinearize
107                }
108            }
109        };
110    }
111    let res = quote_spanned! { input.span =>
112        #[allow(clippy::modulo_one, clippy::manual_range_contains)]
113        const _: () = {
114            trait __C {
115                #(const #const_names: usize;)*
116            }
117
118            impl #impl_generics __C for #ident #type_generics #where_clause {
119                #(#consts)*
120            }
121
122            // SAFETY:
123            //
124            // Storage and CopyStorage obviously are the required type.
125            //
126            // The bodies if `linearize` and `from_linear_unchecked` are generated as follows:
127            //
128            // First, consider a struct s = { a1: T1, ..., an: Tn }. The calculated LENGTH
129            // is the product of the lengths of the Ti. We write |T| for the LENGTH of T.
130            // Write Bi = |T{i+1}| * ... * |Tn|, the product of the LENGTHs of the later types.
131            // Write linear(v) for the linearization of v. Then we define
132            // linear(s) = \sum_{i} linear(ai) * Bi.
133            // It is easy to see that linear(s) / Bi % Ti = linear(ai).
134            // Therefore we have created a bijection between the struct and [0, B0).
135            //
136            // Now consider an enum e = { V1, ..., Vn } where each variant can have fields.
137            // Each Vi can be treated like a struct and we can define a bijection between
138            // the enum and [0, |V1| + ... + |Vn|) by mapping V1 to [0, |V1|), V2 to
139            // [|V1|, |V1| + |V2|), and so on.
140            #[automatically_derived]
141            unsafe impl #impl_generics
142            #crate_name::Linearize for #ident #type_generics
143            #where_clause
144            {
145                type Storage<__T> = [__T; <Self as #crate_name::Linearize>::LENGTH];
146
147                type CopyStorage<__T> = [__T; <Self as #crate_name::Linearize>::LENGTH] where __T: Copy;
148
149                const LENGTH: usize = <Self as __C>::#max_len;
150
151                #[inline]
152                fn linearize(&self) -> usize {
153                    #linearize
154                }
155
156                #[inline]
157                unsafe fn from_linear_unchecked(linear: usize) -> Self {
158                    #delinearize
159                }
160            }
161
162            #const_impl
163        };
164    };
165    res.into()
166}
167
168struct Input {
169    span: Span,
170    ident: Ident,
171    generics: Generics,
172    critical_types: Vec<Type>,
173    kind: Kind,
174    attributes: InputAttributes,
175}
176
177struct InputAttributes {
178    crate_name: Path,
179    enable_const: bool,
180}
181
182#[derive(Default)]
183struct InputAttributesOpt {
184    crate_name: Option<Path>,
185    enable_const: bool,
186}
187
188enum Kind {
189    Struct(StructInput),
190    Enum(EnumInput),
191}
192
193struct StructInput {
194    fields: Vec<StructField>,
195}
196
197struct EnumInput {
198    variants: Vec<EnumVariant>,
199}
200
201struct EnumVariant {
202    ident: Ident,
203    fields: Vec<StructField>,
204}
205
206struct PartialLinearized {
207    linearize: TokenStream,
208    delinearize: TokenStream,
209    const_linearize: TokenStream,
210    const_delinearize: TokenStream,
211    max_len: TokenStream,
212}
213
214struct FullyLinearized {
215    linearize: TokenStream,
216    delinearize: TokenStream,
217    const_linearize: TokenStream,
218    const_delinearize: TokenStream,
219    const_names: Vec<Ident>,
220    consts: Vec<TokenStream>,
221    max_len: Ident,
222}
223
224struct StructField {
225    original_name: Option<Ident>,
226    generated_name: Option<Ident>,
227    ty: Type,
228}
229
230fn build_linearize_struct(
231    input: &Input,
232    fields: &[StructField],
233    base: &Ident,
234) -> PartialLinearized {
235    let crate_name = &input.attributes.crate_name;
236    let mut linearize_parts = vec![];
237    let mut delinearize_parts = vec![];
238    let mut const_linearize_parts = vec![];
239    let mut const_delinearize_parts = vec![];
240    let mut max_len = quote!(1usize);
241    for (idx, field) in fields.iter().enumerate().rev() {
242        let idx = LitInt::new(&idx.to_string(), Span::call_site());
243        let ref_name = match &field.generated_name {
244            Some(i) => quote! {#i},
245            None => match &field.original_name {
246                Some(i) => quote! { &self.#i },
247                None => quote! { &self.#idx },
248            },
249        };
250        let mut_name = match &field.original_name {
251            Some(i) => quote! { #i },
252            None => quote! { #idx },
253        };
254        let ty = &field.ty;
255        linearize_parts.push(quote! {
256            res = res.wrapping_add(<#ty as #crate_name::Linearize>::linearize(#ref_name).wrapping_mul(const { #max_len }));
257        });
258        delinearize_parts.push(quote! {
259            #mut_name: {
260                let idx = (linear / const { #max_len }) % <#ty as #crate_name::Linearize>::LENGTH;
261                <#ty as #crate_name::Linearize>::from_linear_unchecked(idx)
262            },
263        });
264        if input.attributes.enable_const {
265            const_linearize_parts.push(quote! {
266                res = res.wrapping_add(<#ty>::__linearize_d66aa8fa_6974_4651_b2b7_75291a9e7105(#ref_name).wrapping_mul(const { #max_len }));
267            });
268            const_delinearize_parts.push(quote! {
269                #mut_name: {
270                    let idx = (linear / const { #max_len }) % <#ty as #crate_name::Linearize>::LENGTH;
271                    <#ty>::__from_linear_unchecked_fb2f0b31_5b5a_48b4_9264_39d0bdf94f1d(idx)
272                },
273            });
274        }
275        max_len = quote! {
276            #max_len * <#ty as #crate_name::Linearize>::LENGTH
277        };
278    }
279    delinearize_parts.reverse();
280    const_delinearize_parts.reverse();
281    let make_linearize = |parts: &[TokenStream]| {
282        if fields.is_empty() {
283            quote! { <Self as __C>::#base }
284        } else {
285            quote! {
286                let mut res = <Self as __C>::#base;
287                #(#parts)*
288                res
289            }
290        }
291    };
292    let make_delinearize = |parts: &[TokenStream]| {
293        quote! {
294            { #(#parts)* }
295        }
296    };
297    PartialLinearized {
298        linearize: make_linearize(&linearize_parts),
299        delinearize: make_delinearize(&delinearize_parts),
300        const_linearize: make_linearize(&const_linearize_parts),
301        const_delinearize: make_delinearize(&const_delinearize_parts),
302        max_len,
303    }
304}
305
306impl StructInput {
307    fn build_linearize(&self, input: &Input) -> FullyLinearized {
308        let b0 = Ident::new("B0", Span::mixed_site());
309        let b1 = Ident::new("B1", Span::mixed_site());
310        let PartialLinearized {
311            linearize,
312            delinearize,
313            const_linearize,
314            const_delinearize,
315            max_len,
316        } = build_linearize_struct(input, &self.fields, &b0);
317        let mut consts = vec![];
318        consts.push(quote! { const B0: usize = 0; });
319        consts.push(quote! { const B1: usize = #max_len; });
320        FullyLinearized {
321            linearize,
322            delinearize: quote! { Self #delinearize },
323            const_linearize,
324            const_delinearize: quote! { Self #const_delinearize },
325            max_len: b1.clone(),
326            consts,
327            const_names: vec![b0, b1],
328        }
329    }
330}
331
332impl EnumInput {
333    fn build_linearize(&self, input: &Input) -> FullyLinearized {
334        let mut linearize_cases = vec![];
335        let mut delinearize_cases = vec![];
336        let mut const_linearize_cases = vec![];
337        let mut const_delinearize_cases = vec![];
338        let mut consts = vec![];
339        consts.push(quote! { const B0: usize = 0; });
340        let mut prev_const_name = Ident::new("B0", Span::mixed_site());
341        let mut const_names = vec![prev_const_name.clone()];
342        for (variant_idx, variant) in self.variants.iter().enumerate() {
343            let mut exposition = vec![];
344            for (idx, field) in variant.fields.iter().enumerate() {
345                let idx = LitInt::new(&idx.to_string(), Span::call_site());
346                let generated_name = field.generated_name.as_ref().unwrap();
347                match &field.original_name {
348                    None => exposition.push(quote! { #idx: #generated_name }),
349                    Some(i) => exposition.push(quote! { #i: #generated_name }),
350                }
351            }
352            let exposition = quote! {
353                { #(#exposition),* }
354            };
355            let PartialLinearized {
356                linearize,
357                delinearize,
358                const_linearize,
359                const_delinearize,
360                max_len,
361            } = build_linearize_struct(input, &variant.fields, &prev_const_name);
362            let next_base = quote! { <Self as __C>::#prev_const_name + #max_len };
363            let ident = &variant.ident;
364            linearize_cases.push(quote! {
365                Self::#ident #exposition => {
366                    #linearize
367                }
368            });
369            if input.attributes.enable_const {
370                const_linearize_cases.push(quote! {
371                    Self::#ident #exposition => {
372                        #const_linearize
373                    }
374                });
375            }
376            let const_name = Ident::new(&format!("B{}", variant_idx + 1), Span::mixed_site());
377            consts.push(quote! { const #const_name: usize = #next_base; });
378            if variant.fields.is_empty() {
379                let guard = if input.generics.params.is_empty() {
380                    quote! {
381                        <Self as __C>::#prev_const_name
382                    }
383                } else {
384                    quote! {
385                        n if n == <Self as __C>::#prev_const_name
386                    }
387                };
388                delinearize_cases.push(quote! {
389                    #guard => Self::#ident { },
390                });
391                if input.attributes.enable_const {
392                    const_delinearize_cases.push(quote! {
393                        #guard => Self::#ident { },
394                    });
395                }
396            } else {
397                let make_case = |delinearize: &TokenStream| {
398                    quote! {
399                        #[allow(clippy::impossible_comparisons)]
400                        n if n >= <Self as __C>::#prev_const_name && n < <Self as __C>::#const_name => {
401                            let linear = linear.wrapping_sub(<Self as __C>::#prev_const_name);
402                            Self::#ident #delinearize
403                        },
404                    }
405                };
406                delinearize_cases.push(make_case(&delinearize));
407                if input.attributes.enable_const {
408                    const_delinearize_cases.push(make_case(&const_delinearize));
409                }
410            }
411            prev_const_name = const_name;
412            const_names.push(prev_const_name.clone());
413        }
414        let make_linearize = |cases: &[TokenStream]| {
415            if self.variants.is_empty() {
416                quote! {
417                    #[cold]
418                    const fn unreachable() -> ! {
419                        unsafe { core::hint::unreachable_unchecked() }
420                    }
421                    unreachable()
422                }
423            } else {
424                quote! {
425                    match self {
426                        #(#cases)*
427                    }
428                }
429            }
430        };
431        let make_delinearize = |cases: &[TokenStream]| {
432            quote! {
433                match linear {
434                    #(#cases)*
435                    _ => {
436                        #[cold]
437                        const fn unreachable() -> ! {
438                            unsafe { core::hint::unreachable_unchecked() }
439                        }
440                        unreachable()
441                    },
442                }
443            }
444        };
445        FullyLinearized {
446            linearize: make_linearize(&linearize_cases),
447            const_linearize: make_linearize(&const_linearize_cases),
448            delinearize: make_delinearize(&delinearize_cases),
449            const_delinearize: make_delinearize(&const_delinearize_cases),
450            max_len: prev_const_name,
451            const_names,
452            consts,
453        }
454    }
455}
456
457impl Input {
458    fn parse_enum(input: ItemEnum) -> syn::Result<Self> {
459        let span = input.span();
460        let mut critical_types = Vec::new();
461        let mut variants = vec![];
462        let mut i = 0;
463        for variant in input.variants {
464            let mut fields = vec![];
465            for field in variant.fields {
466                critical_types.push(field.ty.clone());
467                let name = Ident::new(&format!("f{i}"), Span::mixed_site());
468                i += 1;
469                fields.push(StructField {
470                    original_name: field.ident,
471                    generated_name: Some(name),
472                    ty: field.ty,
473                })
474            }
475            variants.push(EnumVariant {
476                ident: variant.ident,
477                fields,
478            });
479        }
480        Ok(Self {
481            span,
482            ident: input.ident,
483            generics: input.generics,
484            critical_types,
485            kind: Kind::Enum(EnumInput { variants }),
486            attributes: parse_attributes(&input.attrs)?,
487        })
488    }
489
490    fn parse_struct(input: ItemStruct) -> syn::Result<Self> {
491        let span = input.span();
492        let mut critical_types = Vec::new();
493        let mut fields = vec![];
494        for field in input.fields {
495            critical_types.push(field.ty.clone());
496            fields.push(StructField {
497                original_name: field.ident,
498                generated_name: None,
499                ty: field.ty,
500            });
501        }
502        Ok(Self {
503            span,
504            ident: input.ident,
505            generics: input.generics,
506            critical_types,
507            kind: Kind::Struct(StructInput { fields }),
508            attributes: parse_attributes(&input.attrs)?,
509        })
510    }
511
512    fn build_linearize(&self) -> FullyLinearized {
513        match &self.kind {
514            Kind::Struct(s) => s.build_linearize(self),
515            Kind::Enum(e) => e.build_linearize(self),
516        }
517    }
518}
519
520fn parse_attributes(attrs: &[Attribute]) -> syn::Result<InputAttributes> {
521    let mut res = InputAttributesOpt::default();
522    for attr in attrs {
523        if !attr.meta.path().is_ident("linearize") {
524            continue;
525        }
526        let new: InputAttributesOpt = attr.meta.require_list()?.parse_args()?;
527        res.enable_const |= new.enable_const;
528        macro_rules! opt {
529            ($name:ident) => {
530                if new.$name.is_some() {
531                    res.$name = new.$name;
532                }
533            };
534        }
535        opt!(crate_name);
536    }
537    Ok(InputAttributes {
538        crate_name: res.crate_name.unwrap_or_else(|| parse_quote!(::linearize)),
539        enable_const: res.enable_const,
540    })
541}
542
543impl Parse for InputAttributesOpt {
544    fn parse(input: ParseStream) -> syn::Result<Self> {
545        let mut res = Self::default();
546        while !input.is_empty() {
547            let key: TokenTree = input.parse()?;
548            match key.to_string().as_str() {
549                "crate" => {
550                    let _: Token![=] = input.parse()?;
551                    let path: Path = input.parse()?;
552                    res.crate_name = Some(path);
553                }
554                "const" => {
555                    res.enable_const = true;
556                }
557                _ => {
558                    return Err(Error::new(
559                        key.span(),
560                        format!("Unknown attribute: {}", key),
561                    ))
562                }
563            }
564            if !input.is_empty() {
565                let _: Token![,] = input.parse()?;
566            }
567        }
568        Ok(res)
569    }
570}
571
572impl Parse for Input {
573    fn parse(input: ParseStream) -> syn::Result<Self> {
574        let item: Item = input.parse()?;
575        match item {
576            Item::Enum(e) => Self::parse_enum(e),
577            Item::Struct(s) => Self::parse_struct(s),
578            _ => Err(Error::new(item.span(), "expected enum or struct")),
579        }
580    }
581}