Skip to main content

pin_init_internal/
init.rs

1// SPDX-License-Identifier: Apache-2.0 OR MIT
2
3use proc_macro2::{Span, TokenStream};
4use quote::{format_ident, quote, quote_spanned};
5use syn::{
6    braced,
7    parse::{End, Parse},
8    parse_quote,
9    punctuated::Punctuated,
10    spanned::Spanned,
11    token, Attribute, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type,
12};
13
14use crate::diagnostics::{DiagCtxt, ErrorGuaranteed};
15
16pub(crate) struct Initializer {
17    attrs: Vec<InitializerAttribute>,
18    this: Option<This>,
19    path: Path,
20    brace_token: token::Brace,
21    fields: Punctuated<InitializerField, Token![,]>,
22    rest: Option<(Token![..], Expr)>,
23    error: Option<(Token![?], Type)>,
24}
25
26struct This {
27    _and_token: Token![&],
28    ident: Ident,
29    _in_token: Token![in],
30}
31
32struct InitializerField {
33    attrs: Vec<Attribute>,
34    kind: InitializerKind,
35}
36
37enum InitializerKind {
38    Value {
39        ident: Ident,
40        value: Option<(Token![:], Expr)>,
41    },
42    Init {
43        ident: Ident,
44        _left_arrow_token: Token![<-],
45        value: Expr,
46    },
47    Code {
48        _underscore_token: Token![_],
49        _colon_token: Token![:],
50        block: Block,
51    },
52}
53
54impl InitializerKind {
55    fn ident(&self) -> Option<&Ident> {
56        match self {
57            Self::Value { ident, .. } | Self::Init { ident, .. } => Some(ident),
58            Self::Code { .. } => None,
59        }
60    }
61}
62
63enum InitializerAttribute {
64    DefaultError(DefaultErrorAttribute),
65    DisableInitializedFieldAccess,
66}
67
68struct DefaultErrorAttribute {
69    ty: Box<Type>,
70}
71
72pub(crate) fn expand(
73    Initializer {
74        attrs,
75        this,
76        path,
77        brace_token,
78        fields,
79        rest,
80        error,
81    }: Initializer,
82    default_error: Option<&'static str>,
83    pinned: bool,
84    dcx: &mut DiagCtxt,
85) -> Result<TokenStream, ErrorGuaranteed> {
86    let error = error.map_or_else(
87        || {
88            if let Some(default_error) = attrs.iter().fold(None, |acc, attr| {
89                if let InitializerAttribute::DefaultError(DefaultErrorAttribute { ty }) = attr {
90                    Some(ty.clone())
91                } else {
92                    acc
93                }
94            }) {
95                default_error
96            } else if let Some(default_error) = default_error {
97                syn::parse_str(default_error).unwrap()
98            } else {
99                dcx.error(brace_token.span.close(), "expected `? <type>` after `}`");
100                parse_quote!(::core::convert::Infallible)
101            }
102        },
103        |(_, err)| Box::new(err),
104    );
105    let slot = format_ident!("slot");
106    let (has_data_trait, data_trait, get_data, init_from_closure) = if pinned {
107        (
108            format_ident!("HasPinData"),
109            format_ident!("PinData"),
110            format_ident!("__pin_data"),
111            format_ident!("pin_init_from_closure"),
112        )
113    } else {
114        (
115            format_ident!("HasInitData"),
116            format_ident!("InitData"),
117            format_ident!("__init_data"),
118            format_ident!("init_from_closure"),
119        )
120    };
121    let init_kind = get_init_kind(rest, dcx);
122    let zeroable_check = match init_kind {
123        InitKind::Normal => quote!(),
124        InitKind::Zeroing => quote! {
125            // The user specified `..Zeroable::zeroed()` at the end of the list of fields.
126            // Therefore we check if the struct implements `Zeroable` and then zero the memory.
127            // This allows us to also remove the check that all fields are present (since we
128            // already set the memory to zero and that is a valid bit pattern).
129            fn assert_zeroable<T: ?::core::marker::Sized>(_: *mut T)
130            where T: ::pin_init::Zeroable
131            {}
132            // Ensure that the struct is indeed `Zeroable`.
133            assert_zeroable(#slot);
134            // SAFETY: The type implements `Zeroable` by the check above.
135            unsafe { ::core::ptr::write_bytes(#slot, 0, 1) };
136        },
137    };
138    let this = match this {
139        None => quote!(),
140        Some(This { ident, .. }) => quote! {
141            // Create the `this` so it can be referenced by the user inside of the
142            // expressions creating the individual fields.
143            let #ident = unsafe { ::core::ptr::NonNull::new_unchecked(slot) };
144        },
145    };
146    // `mixed_site` ensures that the data is not accessible to the user-controlled code.
147    let data = Ident::new("__data", Span::mixed_site());
148    let init_fields = init_fields(
149        &fields,
150        pinned,
151        !attrs
152            .iter()
153            .any(|attr| matches!(attr, InitializerAttribute::DisableInitializedFieldAccess)),
154        &data,
155        &slot,
156    );
157    let field_check = make_field_check(&fields, init_kind, &path);
158    Ok(quote! {{
159        // We do not want to allow arbitrary returns, so we declare this type as the `Ok` return
160        // type and shadow it later when we insert the arbitrary user code. That way there will be
161        // no possibility of returning without `unsafe`.
162        struct __InitOk;
163
164        // Get the data about fields from the supplied type.
165        // SAFETY: TODO
166        let #data = unsafe {
167            use ::pin_init::__internal::#has_data_trait;
168            // Can't use `<#path as #has_data_trait>::#get_data`, since the user is able to omit
169            // generics (which need to be present with that syntax).
170            #path::#get_data()
171        };
172        // Ensure that `#data` really is of type `#data` and help with type inference:
173        let init = ::pin_init::__internal::#data_trait::make_closure::<_, __InitOk, #error>(
174            #data,
175            move |slot| {
176                {
177                    // Shadow the structure so it cannot be used to return early.
178                    struct __InitOk;
179                    #zeroable_check
180                    #this
181                    #init_fields
182                    #field_check
183                }
184                Ok(__InitOk)
185            }
186        );
187        let init = move |slot| -> ::core::result::Result<(), #error> {
188            init(slot).map(|__InitOk| ())
189        };
190        // SAFETY: TODO
191        let init = unsafe { ::pin_init::#init_from_closure::<_, #error>(init) };
192        init
193    }})
194}
195
/// How the fields not listed in the initializer are handled.
enum InitKind {
    /// No `..rest`: every field must be listed exactly once.
    Normal,
    /// `..Zeroable::init_zeroed()`: the memory is zeroed first, and unlisted fields stay zero.
    Zeroing,
}
200
201fn get_init_kind(rest: Option<(Token![..], Expr)>, dcx: &mut DiagCtxt) -> InitKind {
202    let Some((dotdot, expr)) = rest else {
203        return InitKind::Normal;
204    };
205    match &expr {
206        Expr::Call(ExprCall { func, args, .. }) if args.is_empty() => match &**func {
207            Expr::Path(ExprPath {
208                attrs,
209                qself: None,
210                path:
211                    Path {
212                        leading_colon: None,
213                        segments,
214                    },
215            }) if attrs.is_empty()
216                && segments.len() == 2
217                && segments[0].ident == "Zeroable"
218                && segments[0].arguments.is_none()
219                && segments[1].ident == "init_zeroed"
220                && segments[1].arguments.is_none() =>
221            {
222                return InitKind::Zeroing;
223            }
224            _ => {}
225        },
226        _ => {}
227    }
228    dcx.error(
229        dotdot.span().join(expr.span()).unwrap_or(expr.span()),
230        "expected nothing or `..Zeroable::init_zeroed()`.",
231    );
232    InitKind::Normal
233}
234
235/// Generate the code that initializes the fields of the struct using the initializers in `field`.
236fn init_fields(
237    fields: &Punctuated<InitializerField, Token![,]>,
238    pinned: bool,
239    generate_initialized_accessors: bool,
240    data: &Ident,
241    slot: &Ident,
242) -> TokenStream {
243    let mut guards = vec![];
244    let mut guard_attrs = vec![];
245    let mut res = TokenStream::new();
246    for InitializerField { attrs, kind } in fields {
247        let cfgs = {
248            let mut cfgs = attrs.clone();
249            cfgs.retain(|attr| attr.path().is_ident("cfg"));
250            cfgs
251        };
252        let init = match kind {
253            InitializerKind::Value { ident, value } => {
254                let mut value_ident = ident.clone();
255                let value_prep = value.as_ref().map(|value| &value.1).map(|value| {
256                    // Setting the span of `value_ident` to `value`'s span improves error messages
257                    // when the type of `value` is wrong.
258                    value_ident.set_span(value.span());
259                    quote!(let #value_ident = #value;)
260                });
261                // Again span for better diagnostics
262                let write = quote_spanned!(ident.span()=> ::core::ptr::write);
263                let accessor = if pinned {
264                    let project_ident = format_ident!("__project_{ident}");
265                    quote! {
266                        // SAFETY: TODO
267                        unsafe { #data.#project_ident(&mut (*#slot).#ident) }
268                    }
269                } else {
270                    quote! {
271                        // SAFETY: TODO
272                        unsafe { &mut (*#slot).#ident }
273                    }
274                };
275                let accessor = generate_initialized_accessors.then(|| {
276                    quote! {
277                        #(#cfgs)*
278                        #[allow(unused_variables)]
279                        let #ident = #accessor;
280                    }
281                });
282                quote! {
283                    #(#attrs)*
284                    {
285                        #value_prep
286                        // SAFETY: TODO
287                        unsafe { #write(::core::ptr::addr_of_mut!((*#slot).#ident), #value_ident) };
288                    }
289                    #accessor
290                }
291            }
292            InitializerKind::Init { ident, value, .. } => {
293                // Again span for better diagnostics
294                let init = format_ident!("init", span = value.span());
295                let (value_init, accessor) = if pinned {
296                    let project_ident = format_ident!("__project_{ident}");
297                    (
298                        quote! {
299                            // SAFETY:
300                            // - `slot` is valid, because we are inside of an initializer closure, we
301                            //   return when an error/panic occurs.
302                            // - We also use `#data` to require the correct trait (`Init` or `PinInit`)
303                            //   for `#ident`.
304                            unsafe { #data.#ident(::core::ptr::addr_of_mut!((*#slot).#ident), #init)? };
305                        },
306                        quote! {
307                            // SAFETY: TODO
308                            unsafe { #data.#project_ident(&mut (*#slot).#ident) }
309                        },
310                    )
311                } else {
312                    (
313                        quote! {
314                            // SAFETY: `slot` is valid, because we are inside of an initializer
315                            // closure, we return when an error/panic occurs.
316                            unsafe {
317                                ::pin_init::Init::__init(
318                                    #init,
319                                    ::core::ptr::addr_of_mut!((*#slot).#ident),
320                                )?
321                            };
322                        },
323                        quote! {
324                            // SAFETY: TODO
325                            unsafe { &mut (*#slot).#ident }
326                        },
327                    )
328                };
329                let accessor = generate_initialized_accessors.then(|| {
330                    quote! {
331                        #(#cfgs)*
332                        #[allow(unused_variables)]
333                        let #ident = #accessor;
334                    }
335                });
336                quote! {
337                    #(#attrs)*
338                    {
339                        let #init = #value;
340                        #value_init
341                    }
342                    #accessor
343                }
344            }
345            InitializerKind::Code { block: value, .. } => quote! {
346                #(#attrs)*
347                #[allow(unused_braces)]
348                #value
349            },
350        };
351        res.extend(init);
352        if let Some(ident) = kind.ident() {
353            // `mixed_site` ensures that the guard is not accessible to the user-controlled code.
354            let guard = format_ident!("__{ident}_guard", span = Span::mixed_site());
355            res.extend(quote! {
356                #(#cfgs)*
357                // Create the drop guard:
358                //
359                // We rely on macro hygiene to make it impossible for users to access this local
360                // variable.
361                // SAFETY: We forget the guard later when initialization has succeeded.
362                let #guard = unsafe {
363                    ::pin_init::__internal::DropGuard::new(
364                        ::core::ptr::addr_of_mut!((*slot).#ident)
365                    )
366                };
367            });
368            guards.push(guard);
369            guard_attrs.push(cfgs);
370        }
371    }
372    quote! {
373        #res
374        // If execution reaches this point, all fields have been initialized. Therefore we can now
375        // dismiss the guards by forgetting them.
376        #(
377            #(#guard_attrs)*
378            ::core::mem::forget(#guards);
379        )*
380    }
381}
382
383/// Generate the check for ensuring that every field has been initialized.
384fn make_field_check(
385    fields: &Punctuated<InitializerField, Token![,]>,
386    init_kind: InitKind,
387    path: &Path,
388) -> TokenStream {
389    let field_attrs = fields
390        .iter()
391        .filter_map(|f| f.kind.ident().map(|_| &f.attrs));
392    let field_name = fields.iter().filter_map(|f| f.kind.ident());
393    match init_kind {
394        InitKind::Normal => quote! {
395            // We use unreachable code to ensure that all fields have been mentioned exactly once,
396            // this struct initializer will still be type-checked and complain with a very natural
397            // error message if a field is forgotten/mentioned more than once.
398            #[allow(unreachable_code, clippy::diverging_sub_expression)]
399            // SAFETY: this code is never executed.
400            let _ = || unsafe {
401                ::core::ptr::write(slot, #path {
402                    #(
403                        #(#field_attrs)*
404                        #field_name: ::core::panic!(),
405                    )*
406                })
407            };
408        },
409        InitKind::Zeroing => quote! {
410            // We use unreachable code to ensure that all fields have been mentioned at most once.
411            // Since the user specified `..Zeroable::zeroed()` at the end, all missing fields will
412            // be zeroed. This struct initializer will still be type-checked and complain with a
413            // very natural error message if a field is mentioned more than once, or doesn't exist.
414            #[allow(unreachable_code, clippy::diverging_sub_expression, unused_assignments)]
415            // SAFETY: this code is never executed.
416            let _ = || unsafe {
417                ::core::ptr::write(slot, #path {
418                    #(
419                        #(#field_attrs)*
420                        #field_name: ::core::panic!(),
421                    )*
422                    ..::core::mem::zeroed()
423                })
424            };
425        },
426    }
427}
428
429impl Parse for Initializer {
430    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
431        let attrs = input.call(Attribute::parse_outer)?;
432        let this = input.peek(Token![&]).then(|| input.parse()).transpose()?;
433        let path = input.parse()?;
434        let content;
435        let brace_token = braced!(content in input);
436        let mut fields = Punctuated::new();
437        loop {
438            let lh = content.lookahead1();
439            if lh.peek(End) || lh.peek(Token![..]) {
440                break;
441            } else if lh.peek(Ident) || lh.peek(Token![_]) || lh.peek(Token![#]) {
442                fields.push_value(content.parse()?);
443                let lh = content.lookahead1();
444                if lh.peek(End) {
445                    break;
446                } else if lh.peek(Token![,]) {
447                    fields.push_punct(content.parse()?);
448                } else {
449                    return Err(lh.error());
450                }
451            } else {
452                return Err(lh.error());
453            }
454        }
455        let rest = content
456            .peek(Token![..])
457            .then(|| Ok::<_, syn::Error>((content.parse()?, content.parse()?)))
458            .transpose()?;
459        let error = input
460            .peek(Token![?])
461            .then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?)))
462            .transpose()?;
463        let attrs = attrs
464            .into_iter()
465            .map(|a| {
466                if a.path().is_ident("default_error") {
467                    a.parse_args::<DefaultErrorAttribute>()
468                        .map(InitializerAttribute::DefaultError)
469                } else if a.path().is_ident("disable_initialized_field_access") {
470                    a.meta
471                        .require_path_only()
472                        .map(|_| InitializerAttribute::DisableInitializedFieldAccess)
473                } else {
474                    Err(syn::Error::new_spanned(a, "unknown initializer attribute"))
475                }
476            })
477            .collect::<Result<Vec<_>, _>>()?;
478        Ok(Self {
479            attrs,
480            this,
481            path,
482            brace_token,
483            fields,
484            rest,
485            error,
486        })
487    }
488}
489
490impl Parse for DefaultErrorAttribute {
491    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
492        Ok(Self { ty: input.parse()? })
493    }
494}
495
496impl Parse for This {
497    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
498        Ok(Self {
499            _and_token: input.parse()?,
500            ident: input.parse()?,
501            _in_token: input.parse()?,
502        })
503    }
504}
505
506impl Parse for InitializerField {
507    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
508        let attrs = input.call(Attribute::parse_outer)?;
509        Ok(Self {
510            attrs,
511            kind: input.parse()?,
512        })
513    }
514}
515
516impl Parse for InitializerKind {
517    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
518        let lh = input.lookahead1();
519        if lh.peek(Token![_]) {
520            Ok(Self::Code {
521                _underscore_token: input.parse()?,
522                _colon_token: input.parse()?,
523                block: input.parse()?,
524            })
525        } else if lh.peek(Ident) {
526            let ident = input.parse()?;
527            let lh = input.lookahead1();
528            if lh.peek(Token![<-]) {
529                Ok(Self::Init {
530                    ident,
531                    _left_arrow_token: input.parse()?,
532                    value: input.parse()?,
533                })
534            } else if lh.peek(Token![:]) {
535                Ok(Self::Value {
536                    ident,
537                    value: Some((input.parse()?, input.parse()?)),
538                })
539            } else if lh.peek(Token![,]) || lh.peek(End) {
540                Ok(Self::Value { ident, value: None })
541            } else {
542                Err(lh.error())
543            }
544        } else {
545            Err(lh.error())
546        }
547    }
548}