Skip to main content

pin_init_internal/
pin_data.rs

1// SPDX-License-Identifier: Apache-2.0 OR MIT
2
3use proc_macro2::TokenStream;
4use quote::{format_ident, quote};
5use syn::{
6    parse::{End, Nothing, Parse},
7    parse_quote, parse_quote_spanned,
8    spanned::Spanned,
9    visit_mut::VisitMut,
10    Field, Generics, Ident, Item, PathSegment, Type, TypePath, Visibility, WhereClause,
11};
12
13use crate::diagnostics::{DiagCtxt, ErrorGuaranteed};
14
15pub(crate) mod kw {
16    syn::custom_keyword!(PinnedDrop);
17}
18
19pub(crate) enum Args {
20    Nothing(Nothing),
21    #[allow(dead_code)]
22    PinnedDrop(kw::PinnedDrop),
23}
24
25impl Parse for Args {
26    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
27        let lh = input.lookahead1();
28        if lh.peek(End) {
29            input.parse().map(Self::Nothing)
30        } else if lh.peek(kw::PinnedDrop) {
31            input.parse().map(Self::PinnedDrop)
32        } else {
33            Err(lh.error())
34        }
35    }
36}
37
38struct FieldInfo<'a> {
39    field: &'a Field,
40    pinned: bool,
41}
42
43pub(crate) fn pin_data(
44    args: Args,
45    input: Item,
46    dcx: &mut DiagCtxt,
47) -> Result<TokenStream, ErrorGuaranteed> {
48    let mut struct_ = match input {
49        Item::Struct(struct_) => struct_,
50        Item::Enum(enum_) => {
51            return Err(dcx.error(
52                enum_.enum_token,
53                "`#[pin_data]` only supports structs for now",
54            ));
55        }
56        Item::Union(union) => {
57            return Err(dcx.error(
58                union.union_token,
59                "`#[pin_data]` only supports structs for now",
60            ));
61        }
62        rest => {
63            return Err(dcx.error(
64                rest,
65                "`#[pin_data]` can only be applied to struct, enum and union definitions",
66            ));
67        }
68    };
69
70    // The generics might contain the `Self` type. Since this macro will define a new type with the
71    // same generics and bounds, this poses a problem: `Self` will refer to the new type as opposed
72    // to this struct definition. Therefore we have to replace `Self` with the concrete name.
73    let mut replacer = {
74        let name = &struct_.ident;
75        let (_, ty_generics, _) = struct_.generics.split_for_impl();
76        SelfReplacer(parse_quote!(#name #ty_generics))
77    };
78    replacer.visit_generics_mut(&mut struct_.generics);
79    replacer.visit_fields_mut(&mut struct_.fields);
80
81    let fields: Vec<FieldInfo<'_>> = struct_
82        .fields
83        .iter_mut()
84        .map(|field| {
85            let len = field.attrs.len();
86            field.attrs.retain(|a| !a.path().is_ident("pin"));
87            let pinned = len != field.attrs.len();
88
89            FieldInfo {
90                field: &*field,
91                pinned,
92            }
93        })
94        .collect();
95
96    for field in &fields {
97        let ident = field.field.ident.as_ref().unwrap();
98
99        if !field.pinned && is_phantom_pinned(&field.field.ty) {
100            dcx.warn(
101                field.field,
102                format!(
103                    "The field `{ident}` of type `PhantomPinned` only has an effect \
104                    if it has the `#[pin]` attribute",
105                ),
106            );
107        }
108    }
109
110    let unpin_impl = generate_unpin_impl(&struct_.ident, &struct_.generics, &fields);
111    let drop_impl = generate_drop_impl(&struct_.ident, &struct_.generics, args);
112    let projections =
113        generate_projections(&struct_.vis, &struct_.ident, &struct_.generics, &fields);
114    let the_pin_data =
115        generate_the_pin_data(&struct_.vis, &struct_.ident, &struct_.generics, &fields);
116
117    Ok(quote! {
118        #struct_
119        #projections
120        // We put the rest into this const item, because it then will not be accessible to anything
121        // outside.
122        const _: () = {
123            #the_pin_data
124            #unpin_impl
125            #drop_impl
126        };
127    })
128}
129
130fn is_phantom_pinned(ty: &Type) -> bool {
131    match ty {
132        Type::Path(TypePath { qself: None, path }) => {
133            // Cannot possibly refer to `PhantomPinned` (except alias, but that's on the user).
134            if path.segments.len() > 3 {
135                return false;
136            }
137            // If there is a `::`, then the path needs to be `::core::marker::PhantomPinned` or
138            // `::std::marker::PhantomPinned`.
139            if path.leading_colon.is_some() && path.segments.len() != 3 {
140                return false;
141            }
142            let expected: Vec<&[&str]> = vec![&["PhantomPinned"], &["marker"], &["core", "std"]];
143            for (actual, expected) in path.segments.iter().rev().zip(expected) {
144                if !actual.arguments.is_empty() || expected.iter().all(|e| actual.ident != e) {
145                    return false;
146                }
147            }
148            true
149        }
150        _ => false,
151    }
152}
153
154fn generate_unpin_impl(
155    ident: &Ident,
156    generics: &Generics,
157    fields: &[FieldInfo<'_>],
158) -> TokenStream {
159    let (_, ty_generics, _) = generics.split_for_impl();
160    let mut generics_with_pin_lt = generics.clone();
161    generics_with_pin_lt.params.insert(0, parse_quote!('__pin));
162    generics_with_pin_lt.make_where_clause();
163    let (
164        impl_generics_with_pin_lt,
165        ty_generics_with_pin_lt,
166        Some(WhereClause {
167            where_token,
168            predicates,
169        }),
170    ) = generics_with_pin_lt.split_for_impl()
171    else {
172        unreachable!()
173    };
174    let pinned_fields = fields.iter().filter(|f| f.pinned).map(|f| f.field);
175    quote! {
176        // This struct will be used for the unpin analysis. It is needed, because only structurally
177        // pinned fields are relevant whether the struct should implement `Unpin`.
178        #[allow(dead_code)] // The fields below are never used.
179        struct __Unpin #generics_with_pin_lt
180        #where_token
181            #predicates
182        {
183            __phantom_pin: ::pin_init::__internal::PhantomInvariantLifetime<'__pin>,
184            __phantom: ::pin_init::__internal::PhantomInvariant<#ident #ty_generics>,
185            #(#pinned_fields),*
186        }
187
188        #[doc(hidden)]
189        impl #impl_generics_with_pin_lt ::core::marker::Unpin for #ident #ty_generics
190        #where_token
191            __Unpin #ty_generics_with_pin_lt: ::core::marker::Unpin,
192            #predicates
193        {}
194    }
195}
196
197fn generate_drop_impl(ident: &Ident, generics: &Generics, args: Args) -> TokenStream {
198    let (impl_generics, ty_generics, whr) = generics.split_for_impl();
199    let has_pinned_drop = matches!(args, Args::PinnedDrop(_));
200    // We need to disallow normal `Drop` implementation, the exact behavior depends on whether
201    // `PinnedDrop` was specified in `args`.
202    if has_pinned_drop {
203        // When `PinnedDrop` was specified we just implement `Drop` and delegate.
204        quote! {
205            impl #impl_generics ::core::ops::Drop for #ident #ty_generics
206                #whr
207            {
208                fn drop(&mut self) {
209                    // SAFETY: Since this is a destructor, `self` will not move after this function
210                    // terminates, since it is inaccessible.
211                    let pinned = unsafe { ::core::pin::Pin::new_unchecked(self) };
212                    // SAFETY: Since this is a drop function, we can create this token to call the
213                    // pinned destructor of this type.
214                    let token = unsafe { ::pin_init::__internal::OnlyCallFromDrop::new() };
215                    ::pin_init::PinnedDrop::drop(pinned, token);
216                }
217            }
218        }
219    } else {
220        // When no `PinnedDrop` was specified, then we have to prevent implementing drop.
221        quote! {
222            // We prevent this by creating a trait that will be implemented for all types implementing
223            // `Drop`. Additionally we will implement this trait for the struct leading to a conflict,
224            // if it also implements `Drop`
225            trait MustNotImplDrop {}
226            #[expect(drop_bounds)]
227            impl<T: ::core::ops::Drop + ?::core::marker::Sized> MustNotImplDrop for T {}
228            impl #impl_generics MustNotImplDrop for #ident #ty_generics
229                #whr
230            {}
231            // We also take care to prevent users from writing a useless `PinnedDrop` implementation.
232            // They might implement `PinnedDrop` correctly for the struct, but forget to give
233            // `PinnedDrop` as the parameter to `#[pin_data]`.
234            #[expect(non_camel_case_types)]
235            trait UselessPinnedDropImpl_you_need_to_specify_PinnedDrop {}
236            impl<T: ::pin_init::PinnedDrop + ?::core::marker::Sized>
237                UselessPinnedDropImpl_you_need_to_specify_PinnedDrop for T {}
238            impl #impl_generics
239                UselessPinnedDropImpl_you_need_to_specify_PinnedDrop for #ident #ty_generics
240                #whr
241            {}
242        }
243    }
244}
245
246fn generate_projections(
247    vis: &Visibility,
248    ident: &Ident,
249    generics: &Generics,
250    fields: &[FieldInfo<'_>],
251) -> TokenStream {
252    let (impl_generics, ty_generics, _) = generics.split_for_impl();
253    let mut generics_with_pin_lt = generics.clone();
254    generics_with_pin_lt.params.insert(0, parse_quote!('__pin));
255    let (_, ty_generics_with_pin_lt, whr) = generics_with_pin_lt.split_for_impl();
256    let projection = format_ident!("{ident}Projection");
257    let this = format_ident!("this");
258
259    let (fields_decl, fields_proj): (Vec<_>, Vec<_>) = fields
260        .iter()
261        .map(|field| {
262            let Field {
263                vis,
264                ident,
265                ty,
266                attrs,
267                ..
268            } = &field.field;
269
270            let mut no_doc_attrs = attrs.clone();
271            no_doc_attrs.retain(|a| !a.path().is_ident("doc"));
272            let ident = ident
273                .as_ref()
274                .expect("only structs with named fields are supported");
275            if field.pinned {
276                (
277                    quote!(
278                        #(#attrs)*
279                        #vis #ident: ::core::pin::Pin<&'__pin mut #ty>,
280                    ),
281                    quote!(
282                        #(#no_doc_attrs)*
283                        // SAFETY: this field is structurally pinned.
284                        #ident: unsafe { ::core::pin::Pin::new_unchecked(&mut #this.#ident) },
285                    ),
286                )
287            } else {
288                (
289                    quote!(
290                        #(#attrs)*
291                        #vis #ident: &'__pin mut #ty,
292                    ),
293                    quote!(
294                        #(#no_doc_attrs)*
295                        #ident: &mut #this.#ident,
296                    ),
297                )
298            }
299        })
300        .collect();
301    let structurally_pinned_fields_docs = fields
302        .iter()
303        .filter(|f| f.pinned)
304        .map(|f| format!(" - `{}`", f.field.ident.as_ref().unwrap()));
305    let not_structurally_pinned_fields_docs = fields
306        .iter()
307        .filter(|f| !f.pinned)
308        .map(|f| format!(" - `{}`", f.field.ident.as_ref().unwrap()));
309    let docs = format!(" Pin-projections of [`{ident}`]");
310    quote! {
311        #[doc = #docs]
312        #[allow(dead_code)]
313        #[doc(hidden)]
314        #vis struct #projection #generics_with_pin_lt
315            #whr
316        {
317            #(#fields_decl)*
318            ___pin_phantom_data: ::core::marker::PhantomData<&'__pin mut ()>,
319        }
320
321        impl #impl_generics #ident #ty_generics
322            #whr
323        {
324            /// Pin-projects all fields of `Self`.
325            ///
326            /// These fields are structurally pinned:
327            #(#[doc = #structurally_pinned_fields_docs])*
328            ///
329            /// These fields are **not** structurally pinned:
330            #(#[doc = #not_structurally_pinned_fields_docs])*
331            #[inline]
332            #vis fn project<'__pin>(
333                self: ::core::pin::Pin<&'__pin mut Self>,
334            ) -> #projection #ty_generics_with_pin_lt {
335                // SAFETY: we only give access to `&mut` for fields not structurally pinned.
336                let #this = unsafe { ::core::pin::Pin::get_unchecked_mut(self) };
337                #projection {
338                    #(#fields_proj)*
339                    ___pin_phantom_data: ::core::marker::PhantomData,
340                }
341            }
342        }
343    }
344}
345
346fn generate_the_pin_data(
347    vis: &Visibility,
348    struct_name: &Ident,
349    generics: &Generics,
350    fields: &[FieldInfo<'_>],
351) -> TokenStream {
352    let (impl_generics, ty_generics, whr) = generics.split_for_impl();
353
354    // For every field, we create an initializing projection function according to its projection
355    // type. If a field is structurally pinned, then it must be initialized via `PinInit`, if it is
356    // not structurally pinned, then it can be initialized via `Init`.
357    //
358    // The functions are `unsafe` to prevent accidentally calling them.
359    let field_accessors = fields
360        .iter()
361        .map(|f| {
362            let Field {
363                vis,
364                ident,
365                ty,
366                attrs,
367                ..
368            } = f.field;
369
370            let field_name = ident
371                .as_ref()
372                .expect("only structs with named fields are supported");
373            let project_ident = format_ident!("__project_{field_name}");
374            let (init_ty, init_fn, project_ty, project_body, pin_safety) = if f.pinned {
375                (
376                    quote!(PinInit),
377                    quote!(__pinned_init),
378                    quote!(::core::pin::Pin<&'__slot mut #ty>),
379                    // SAFETY: this field is structurally pinned.
380                    quote!(unsafe { ::core::pin::Pin::new_unchecked(slot) }),
381                    quote!(
382                        /// - `slot` will not move until it is dropped, i.e. it will be pinned.
383                    ),
384                )
385            } else {
386                (
387                    quote!(Init),
388                    quote!(__init),
389                    quote!(&'__slot mut #ty),
390                    quote!(slot),
391                    quote!(),
392                )
393            };
394            let slot_safety = format!(
395                " `slot` points at the field `{field_name}` inside of `{struct_name}`, which is pinned.",
396            );
397            quote! {
398                /// # Safety
399                ///
400                /// - `slot` is a valid pointer to uninitialized memory.
401                /// - the caller does not touch `slot` when `Err` is returned, they are only
402                ///   permitted to deallocate.
403                #pin_safety
404                #(#attrs)*
405                #vis unsafe fn #field_name<E>(
406                    self,
407                    slot: *mut #ty,
408                    init: impl ::pin_init::#init_ty<#ty, E>,
409                ) -> ::core::result::Result<(), E> {
410                    // SAFETY: this function has the same safety requirements as the __init function
411                    // called below.
412                    unsafe { ::pin_init::#init_ty::#init_fn(init, slot) }
413                }
414
415                /// # Safety
416                ///
417                #[doc = #slot_safety]
418                #(#attrs)*
419                #vis unsafe fn #project_ident<'__slot>(
420                    self,
421                    slot: &'__slot mut #ty,
422                ) -> #project_ty {
423                    #project_body
424                }
425            }
426        })
427        .collect::<TokenStream>();
428    quote! {
429        // We declare this struct which will host all of the projection function for our type. It
430        // will be invariant over all generic parameters which are inherited from the struct.
431        #[doc(hidden)]
432        #vis struct __ThePinData #generics
433            #whr
434        {
435            __phantom: ::pin_init::__internal::PhantomInvariant<#struct_name #ty_generics>,
436        }
437
438        impl #impl_generics ::core::clone::Clone for __ThePinData #ty_generics
439            #whr
440        {
441            fn clone(&self) -> Self { *self }
442        }
443
444        impl #impl_generics ::core::marker::Copy for __ThePinData #ty_generics
445            #whr
446        {}
447
448        #[allow(dead_code)] // Some functions might never be used and private.
449        #[expect(clippy::missing_safety_doc)]
450        impl #impl_generics __ThePinData #ty_generics
451            #whr
452        {
453            #field_accessors
454        }
455
456        // SAFETY: We have added the correct projection functions above to `__ThePinData` and
457        // we also use the least restrictive generics possible.
458        unsafe impl #impl_generics ::pin_init::__internal::HasPinData for #struct_name #ty_generics
459            #whr
460        {
461            type PinData = __ThePinData #ty_generics;
462
463            unsafe fn __pin_data() -> Self::PinData {
464                __ThePinData { __phantom: ::pin_init::__internal::PhantomInvariant::new() }
465            }
466        }
467
468        // SAFETY: TODO
469        unsafe impl #impl_generics ::pin_init::__internal::PinData for __ThePinData #ty_generics
470            #whr
471        {
472            type Datee = #struct_name #ty_generics;
473        }
474    }
475}
476
477struct SelfReplacer(PathSegment);
478
479impl VisitMut for SelfReplacer {
480    fn visit_path_mut(&mut self, i: &mut syn::Path) {
481        if i.is_ident("Self") {
482            let span = i.span();
483            let seg = &self.0;
484            *i = parse_quote_spanned!(span=> #seg);
485        } else {
486            syn::visit_mut::visit_path_mut(self, i);
487        }
488    }
489
490    fn visit_path_segment_mut(&mut self, seg: &mut PathSegment) {
491        if seg.ident == "Self" {
492            let span = seg.span();
493            let this = &self.0;
494            *seg = parse_quote_spanned!(span=> #this);
495        } else {
496            syn::visit_mut::visit_path_segment_mut(self, seg);
497        }
498    }
499
500    fn visit_item_mut(&mut self, _: &mut Item) {
501        // Do not descend into items, since items reset/change what `Self` refers to.
502    }
503}