Skip to main content

pin_init_internal/
pin_data.rs

1// SPDX-License-Identifier: Apache-2.0 OR MIT
2
3use proc_macro2::TokenStream;
4use quote::{format_ident, quote};
5use syn::{
6    parse::{End, Nothing, Parse},
7    parse_quote, parse_quote_spanned,
8    spanned::Spanned,
9    visit_mut::VisitMut,
10    Field, Generics, Ident, Item, PathSegment, Type, TypePath, Visibility, WhereClause,
11};
12
13use crate::diagnostics::{DiagCtxt, ErrorGuaranteed};
14
15pub(crate) mod kw {
16    syn::custom_keyword!(PinnedDrop);
17}
18
19pub(crate) enum Args {
20    Nothing(Nothing),
21    #[allow(dead_code)]
22    PinnedDrop(kw::PinnedDrop),
23}
24
25impl Parse for Args {
26    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
27        let lh = input.lookahead1();
28        if lh.peek(End) {
29            input.parse().map(Self::Nothing)
30        } else if lh.peek(kw::PinnedDrop) {
31            input.parse().map(Self::PinnedDrop)
32        } else {
33            Err(lh.error())
34        }
35    }
36}
37
38pub(crate) fn pin_data(
39    args: Args,
40    input: Item,
41    dcx: &mut DiagCtxt,
42) -> Result<TokenStream, ErrorGuaranteed> {
43    let mut struct_ = match input {
44        Item::Struct(struct_) => struct_,
45        Item::Enum(enum_) => {
46            return Err(dcx.error(
47                enum_.enum_token,
48                "`#[pin_data]` only supports structs for now",
49            ));
50        }
51        Item::Union(union) => {
52            return Err(dcx.error(
53                union.union_token,
54                "`#[pin_data]` only supports structs for now",
55            ));
56        }
57        rest => {
58            return Err(dcx.error(
59                rest,
60                "`#[pin_data]` can only be applied to struct, enum and union definitions",
61            ));
62        }
63    };
64
65    // The generics might contain the `Self` type. Since this macro will define a new type with the
66    // same generics and bounds, this poses a problem: `Self` will refer to the new type as opposed
67    // to this struct definition. Therefore we have to replace `Self` with the concrete name.
68    let mut replacer = {
69        let name = &struct_.ident;
70        let (_, ty_generics, _) = struct_.generics.split_for_impl();
71        SelfReplacer(parse_quote!(#name #ty_generics))
72    };
73    replacer.visit_generics_mut(&mut struct_.generics);
74    replacer.visit_fields_mut(&mut struct_.fields);
75
76    let fields: Vec<(bool, &Field)> = struct_
77        .fields
78        .iter_mut()
79        .map(|field| {
80            let len = field.attrs.len();
81            field.attrs.retain(|a| !a.path().is_ident("pin"));
82            (len != field.attrs.len(), &*field)
83        })
84        .collect();
85
86    for (pinned, field) in &fields {
87        if !pinned && is_phantom_pinned(&field.ty) {
88            dcx.error(
89                field,
90                format!(
91                    "The field `{}` of type `PhantomPinned` only has an effect \
92                    if it has the `#[pin]` attribute",
93                    field.ident.as_ref().unwrap(),
94                ),
95            );
96        }
97    }
98
99    let unpin_impl = generate_unpin_impl(&struct_.ident, &struct_.generics, &fields);
100    let drop_impl = generate_drop_impl(&struct_.ident, &struct_.generics, args);
101    let projections =
102        generate_projections(&struct_.vis, &struct_.ident, &struct_.generics, &fields);
103    let the_pin_data =
104        generate_the_pin_data(&struct_.vis, &struct_.ident, &struct_.generics, &fields);
105
106    Ok(quote! {
107        #struct_
108        #projections
109        // We put the rest into this const item, because it then will not be accessible to anything
110        // outside.
111        const _: () = {
112            #the_pin_data
113            #unpin_impl
114            #drop_impl
115        };
116    })
117}
118
119fn is_phantom_pinned(ty: &Type) -> bool {
120    match ty {
121        Type::Path(TypePath { qself: None, path }) => {
122            // Cannot possibly refer to `PhantomPinned` (except alias, but that's on the user).
123            if path.segments.len() > 3 {
124                return false;
125            }
126            // If there is a `::`, then the path needs to be `::core::marker::PhantomPinned` or
127            // `::std::marker::PhantomPinned`.
128            if path.leading_colon.is_some() && path.segments.len() != 3 {
129                return false;
130            }
131            let expected: Vec<&[&str]> = vec![&["PhantomPinned"], &["marker"], &["core", "std"]];
132            for (actual, expected) in path.segments.iter().rev().zip(expected) {
133                if !actual.arguments.is_empty() || expected.iter().all(|e| actual.ident != e) {
134                    return false;
135                }
136            }
137            true
138        }
139        _ => false,
140    }
141}
142
143fn generate_unpin_impl(
144    ident: &Ident,
145    generics: &Generics,
146    fields: &[(bool, &Field)],
147) -> TokenStream {
148    let (_, ty_generics, _) = generics.split_for_impl();
149    let mut generics_with_pin_lt = generics.clone();
150    generics_with_pin_lt.params.insert(0, parse_quote!('__pin));
151    generics_with_pin_lt.make_where_clause();
152    let (
153        impl_generics_with_pin_lt,
154        ty_generics_with_pin_lt,
155        Some(WhereClause {
156            where_token,
157            predicates,
158        }),
159    ) = generics_with_pin_lt.split_for_impl()
160    else {
161        unreachable!()
162    };
163    let pinned_fields = fields.iter().filter_map(|(b, f)| b.then_some(f));
164    quote! {
165        // This struct will be used for the unpin analysis. It is needed, because only structurally
166        // pinned fields are relevant whether the struct should implement `Unpin`.
167        #[allow(dead_code)] // The fields below are never used.
168        struct __Unpin #generics_with_pin_lt
169        #where_token
170            #predicates
171        {
172            __phantom_pin: ::core::marker::PhantomData<fn(&'__pin ()) -> &'__pin ()>,
173            __phantom: ::core::marker::PhantomData<
174                fn(#ident #ty_generics) -> #ident #ty_generics
175            >,
176            #(#pinned_fields),*
177        }
178
179        #[doc(hidden)]
180        impl #impl_generics_with_pin_lt ::core::marker::Unpin for #ident #ty_generics
181        #where_token
182            __Unpin #ty_generics_with_pin_lt: ::core::marker::Unpin,
183            #predicates
184        {}
185    }
186}
187
188fn generate_drop_impl(ident: &Ident, generics: &Generics, args: Args) -> TokenStream {
189    let (impl_generics, ty_generics, whr) = generics.split_for_impl();
190    let has_pinned_drop = matches!(args, Args::PinnedDrop(_));
191    // We need to disallow normal `Drop` implementation, the exact behavior depends on whether
192    // `PinnedDrop` was specified in `args`.
193    if has_pinned_drop {
194        // When `PinnedDrop` was specified we just implement `Drop` and delegate.
195        quote! {
196            impl #impl_generics ::core::ops::Drop for #ident #ty_generics
197                #whr
198            {
199                fn drop(&mut self) {
200                    // SAFETY: Since this is a destructor, `self` will not move after this function
201                    // terminates, since it is inaccessible.
202                    let pinned = unsafe { ::core::pin::Pin::new_unchecked(self) };
203                    // SAFETY: Since this is a drop function, we can create this token to call the
204                    // pinned destructor of this type.
205                    let token = unsafe { ::pin_init::__internal::OnlyCallFromDrop::new() };
206                    ::pin_init::PinnedDrop::drop(pinned, token);
207                }
208            }
209        }
210    } else {
211        // When no `PinnedDrop` was specified, then we have to prevent implementing drop.
212        quote! {
213            // We prevent this by creating a trait that will be implemented for all types implementing
214            // `Drop`. Additionally we will implement this trait for the struct leading to a conflict,
215            // if it also implements `Drop`
216            trait MustNotImplDrop {}
217            #[expect(drop_bounds)]
218            impl<T: ::core::ops::Drop + ?::core::marker::Sized> MustNotImplDrop for T {}
219            impl #impl_generics MustNotImplDrop for #ident #ty_generics
220                #whr
221            {}
222            // We also take care to prevent users from writing a useless `PinnedDrop` implementation.
223            // They might implement `PinnedDrop` correctly for the struct, but forget to give
224            // `PinnedDrop` as the parameter to `#[pin_data]`.
225            #[expect(non_camel_case_types)]
226            trait UselessPinnedDropImpl_you_need_to_specify_PinnedDrop {}
227            impl<T: ::pin_init::PinnedDrop + ?::core::marker::Sized>
228                UselessPinnedDropImpl_you_need_to_specify_PinnedDrop for T {}
229            impl #impl_generics
230                UselessPinnedDropImpl_you_need_to_specify_PinnedDrop for #ident #ty_generics
231                #whr
232            {}
233        }
234    }
235}
236
237fn generate_projections(
238    vis: &Visibility,
239    ident: &Ident,
240    generics: &Generics,
241    fields: &[(bool, &Field)],
242) -> TokenStream {
243    let (impl_generics, ty_generics, _) = generics.split_for_impl();
244    let mut generics_with_pin_lt = generics.clone();
245    generics_with_pin_lt.params.insert(0, parse_quote!('__pin));
246    let (_, ty_generics_with_pin_lt, whr) = generics_with_pin_lt.split_for_impl();
247    let projection = format_ident!("{ident}Projection");
248    let this = format_ident!("this");
249
250    let (fields_decl, fields_proj) = collect_tuple(fields.iter().map(
251        |(
252            pinned,
253            Field {
254                vis,
255                ident,
256                ty,
257                attrs,
258                ..
259            },
260        )| {
261            let mut attrs = attrs.clone();
262            attrs.retain(|a| !a.path().is_ident("pin"));
263            let mut no_doc_attrs = attrs.clone();
264            no_doc_attrs.retain(|a| !a.path().is_ident("doc"));
265            let ident = ident
266                .as_ref()
267                .expect("only structs with named fields are supported");
268            if *pinned {
269                (
270                    quote!(
271                        #(#attrs)*
272                        #vis #ident: ::core::pin::Pin<&'__pin mut #ty>,
273                    ),
274                    quote!(
275                        #(#no_doc_attrs)*
276                        // SAFETY: this field is structurally pinned.
277                        #ident: unsafe { ::core::pin::Pin::new_unchecked(&mut #this.#ident) },
278                    ),
279                )
280            } else {
281                (
282                    quote!(
283                        #(#attrs)*
284                        #vis #ident: &'__pin mut #ty,
285                    ),
286                    quote!(
287                        #(#no_doc_attrs)*
288                        #ident: &mut #this.#ident,
289                    ),
290                )
291            }
292        },
293    ));
294    let structurally_pinned_fields_docs = fields
295        .iter()
296        .filter_map(|(pinned, field)| pinned.then_some(field))
297        .map(|Field { ident, .. }| format!(" - `{}`", ident.as_ref().unwrap()));
298    let not_structurally_pinned_fields_docs = fields
299        .iter()
300        .filter_map(|(pinned, field)| (!pinned).then_some(field))
301        .map(|Field { ident, .. }| format!(" - `{}`", ident.as_ref().unwrap()));
302    let docs = format!(" Pin-projections of [`{ident}`]");
303    quote! {
304        #[doc = #docs]
305        #[allow(dead_code)]
306        #[doc(hidden)]
307        #vis struct #projection #generics_with_pin_lt {
308            #(#fields_decl)*
309            ___pin_phantom_data: ::core::marker::PhantomData<&'__pin mut ()>,
310        }
311
312        impl #impl_generics #ident #ty_generics
313            #whr
314        {
315            /// Pin-projects all fields of `Self`.
316            ///
317            /// These fields are structurally pinned:
318            #(#[doc = #structurally_pinned_fields_docs])*
319            ///
320            /// These fields are **not** structurally pinned:
321            #(#[doc = #not_structurally_pinned_fields_docs])*
322            #[inline]
323            #vis fn project<'__pin>(
324                self: ::core::pin::Pin<&'__pin mut Self>,
325            ) -> #projection #ty_generics_with_pin_lt {
326                // SAFETY: we only give access to `&mut` for fields not structurally pinned.
327                let #this = unsafe { ::core::pin::Pin::get_unchecked_mut(self) };
328                #projection {
329                    #(#fields_proj)*
330                    ___pin_phantom_data: ::core::marker::PhantomData,
331                }
332            }
333        }
334    }
335}
336
337fn generate_the_pin_data(
338    vis: &Visibility,
339    ident: &Ident,
340    generics: &Generics,
341    fields: &[(bool, &Field)],
342) -> TokenStream {
343    let (impl_generics, ty_generics, whr) = generics.split_for_impl();
344
345    // For every field, we create an initializing projection function according to its projection
346    // type. If a field is structurally pinned, then it must be initialized via `PinInit`, if it is
347    // not structurally pinned, then it can be initialized via `Init`.
348    //
349    // The functions are `unsafe` to prevent accidentally calling them.
350    fn handle_field(
351        Field {
352            vis,
353            ident,
354            ty,
355            attrs,
356            ..
357        }: &Field,
358        struct_ident: &Ident,
359        pinned: bool,
360    ) -> TokenStream {
361        let mut attrs = attrs.clone();
362        attrs.retain(|a| !a.path().is_ident("pin"));
363        let ident = ident
364            .as_ref()
365            .expect("only structs with named fields are supported");
366        let project_ident = format_ident!("__project_{ident}");
367        let (init_ty, init_fn, project_ty, project_body, pin_safety) = if pinned {
368            (
369                quote!(PinInit),
370                quote!(__pinned_init),
371                quote!(::core::pin::Pin<&'__slot mut #ty>),
372                // SAFETY: this field is structurally pinned.
373                quote!(unsafe { ::core::pin::Pin::new_unchecked(slot) }),
374                quote!(
375                    /// - `slot` will not move until it is dropped, i.e. it will be pinned.
376                ),
377            )
378        } else {
379            (
380                quote!(Init),
381                quote!(__init),
382                quote!(&'__slot mut #ty),
383                quote!(slot),
384                quote!(),
385            )
386        };
387        let slot_safety = format!(
388            " `slot` points at the field `{ident}` inside of `{struct_ident}`, which is pinned.",
389        );
390        quote! {
391            /// # Safety
392            ///
393            /// - `slot` is a valid pointer to uninitialized memory.
394            /// - the caller does not touch `slot` when `Err` is returned, they are only permitted
395            ///   to deallocate.
396            #pin_safety
397            #(#attrs)*
398            #vis unsafe fn #ident<E>(
399                self,
400                slot: *mut #ty,
401                init: impl ::pin_init::#init_ty<#ty, E>,
402            ) -> ::core::result::Result<(), E> {
403                // SAFETY: this function has the same safety requirements as the __init function
404                // called below.
405                unsafe { ::pin_init::#init_ty::#init_fn(init, slot) }
406            }
407
408            /// # Safety
409            ///
410            #[doc = #slot_safety]
411            #(#attrs)*
412            #vis unsafe fn #project_ident<'__slot>(
413                self,
414                slot: &'__slot mut #ty,
415            ) -> #project_ty {
416                #project_body
417            }
418        }
419    }
420
421    let field_accessors = fields
422        .iter()
423        .map(|(pinned, field)| handle_field(field, ident, *pinned))
424        .collect::<TokenStream>();
425    quote! {
426        // We declare this struct which will host all of the projection function for our type. It
427        // will be invariant over all generic parameters which are inherited from the struct.
428        #[doc(hidden)]
429        #vis struct __ThePinData #generics
430            #whr
431        {
432            __phantom: ::core::marker::PhantomData<
433                fn(#ident #ty_generics) -> #ident #ty_generics
434            >,
435        }
436
437        impl #impl_generics ::core::clone::Clone for __ThePinData #ty_generics
438            #whr
439        {
440            fn clone(&self) -> Self { *self }
441        }
442
443        impl #impl_generics ::core::marker::Copy for __ThePinData #ty_generics
444            #whr
445        {}
446
447        #[allow(dead_code)] // Some functions might never be used and private.
448        #[expect(clippy::missing_safety_doc)]
449        impl #impl_generics __ThePinData #ty_generics
450            #whr
451        {
452            #field_accessors
453        }
454
455        // SAFETY: We have added the correct projection functions above to `__ThePinData` and
456        // we also use the least restrictive generics possible.
457        unsafe impl #impl_generics ::pin_init::__internal::HasPinData for #ident #ty_generics
458            #whr
459        {
460            type PinData = __ThePinData #ty_generics;
461
462            unsafe fn __pin_data() -> Self::PinData {
463                __ThePinData { __phantom: ::core::marker::PhantomData }
464            }
465        }
466
467        // SAFETY: TODO
468        unsafe impl #impl_generics ::pin_init::__internal::PinData for __ThePinData #ty_generics
469            #whr
470        {
471            type Datee = #ident #ty_generics;
472        }
473    }
474}
475
476struct SelfReplacer(PathSegment);
477
478impl VisitMut for SelfReplacer {
479    fn visit_path_mut(&mut self, i: &mut syn::Path) {
480        if i.is_ident("Self") {
481            let span = i.span();
482            let seg = &self.0;
483            *i = parse_quote_spanned!(span=> #seg);
484        } else {
485            syn::visit_mut::visit_path_mut(self, i);
486        }
487    }
488
489    fn visit_path_segment_mut(&mut self, seg: &mut PathSegment) {
490        if seg.ident == "Self" {
491            let span = seg.span();
492            let this = &self.0;
493            *seg = parse_quote_spanned!(span=> #this);
494        } else {
495            syn::visit_mut::visit_path_segment_mut(self, seg);
496        }
497    }
498
499    fn visit_item_mut(&mut self, _: &mut Item) {
500        // Do not descend into items, since items reset/change what `Self` refers to.
501    }
502}
503
/// Splits an iterator of pairs into a vector of the first elements and a vector of the second
/// elements, preserving order.
///
/// This is [`Iterator::unzip`] (stable since Rust 1.0) with the target collections fixed to
/// `Vec`. No MSRV bump is needed to avoid a hand-written loop.
fn collect_tuple<A, B>(iter: impl Iterator<Item = (A, B)>) -> (Vec<A>, Vec<B>) {
    iter.unzip()
}