pin_init_internal/
zeroable.rs

1// SPDX-License-Identifier: GPL-2.0
2
3#[cfg(not(kernel))]
4use proc_macro2 as proc_macro;
5
6use crate::helpers::{parse_generics, Generics};
7use proc_macro::{TokenStream, TokenTree};
8
9pub(crate) fn derive(input: TokenStream) -> TokenStream {
10    let (
11        Generics {
12            impl_generics,
13            decl_generics: _,
14            ty_generics,
15        },
16        mut rest,
17    ) = parse_generics(input);
18    // This should be the body of the struct `{...}`.
19    let last = rest.pop();
20    // Now we insert `Zeroable` as a bound for every generic parameter in `impl_generics`.
21    let mut new_impl_generics = Vec::with_capacity(impl_generics.len());
22    // Are we inside of a generic where we want to add `Zeroable`?
23    let mut in_generic = !impl_generics.is_empty();
24    // Have we already inserted `Zeroable`?
25    let mut inserted = false;
26    // Level of `<>` nestings.
27    let mut nested = 0;
28    for tt in impl_generics {
29        match &tt {
30            // If we find a `,`, then we have finished a generic/constant/lifetime parameter.
31            TokenTree::Punct(p) if nested == 0 && p.as_char() == ',' => {
32                if in_generic && !inserted {
33                    new_impl_generics.extend(quote! { : ::pin_init::Zeroable });
34                }
35                in_generic = true;
36                inserted = false;
37                new_impl_generics.push(tt);
38            }
39            // If we find `'`, then we are entering a lifetime.
40            TokenTree::Punct(p) if nested == 0 && p.as_char() == '\'' => {
41                in_generic = false;
42                new_impl_generics.push(tt);
43            }
44            TokenTree::Punct(p) if nested == 0 && p.as_char() == ':' => {
45                new_impl_generics.push(tt);
46                if in_generic {
47                    new_impl_generics.extend(quote! { ::pin_init::Zeroable + });
48                    inserted = true;
49                }
50            }
51            TokenTree::Punct(p) if p.as_char() == '<' => {
52                nested += 1;
53                new_impl_generics.push(tt);
54            }
55            TokenTree::Punct(p) if p.as_char() == '>' => {
56                assert!(nested > 0);
57                nested -= 1;
58                new_impl_generics.push(tt);
59            }
60            _ => new_impl_generics.push(tt),
61        }
62    }
63    assert_eq!(nested, 0);
64    if in_generic && !inserted {
65        new_impl_generics.extend(quote! { : ::pin_init::Zeroable });
66    }
67    quote! {
68        ::pin_init::__derive_zeroable!(
69            parse_input:
70                @sig(#(#rest)*),
71                @impl_generics(#(#new_impl_generics)*),
72                @ty_generics(#(#ty_generics)*),
73                @body(#last),
74        );
75    }
76}