1use proc_macro2::{Span, TokenStream};
4use quote::{format_ident, quote, quote_spanned};
5use syn::{
6 braced,
7 parse::{End, Parse},
8 parse_quote,
9 punctuated::Punctuated,
10 spanned::Spanned,
11 token, Attribute, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type,
12};
13
14use crate::diagnostics::{DiagCtxt, ErrorGuaranteed};
15
/// Parsed form of an initializer macro invocation.
///
/// Accepted shape (see the `Parse` impl for `Initializer` in this file):
/// `#[attr]* (&this in)? Path { fields (, ..rest)? } (? ErrorType)?`
pub(crate) struct Initializer {
    /// Outer attributes on the whole initializer (`#[default_error(..)]`,
    /// `#[disable_initialized_field_access]`).
    attrs: Vec<InitializerAttribute>,
    /// Optional `&ident in` prefix; `ident` is bound to a `NonNull` pointer to the slot.
    this: Option<This>,
    /// Path of the struct being initialized.
    path: Path,
    /// The braces around the field list; their closing span is used for diagnostics.
    brace_token: token::Brace,
    /// Comma-separated field initializers.
    fields: Punctuated<InitializerField, Token![,]>,
    /// Optional `..expr` tail; only `..Zeroable::init_zeroed()` is accepted during expansion.
    rest: Option<(Token![..], Expr)>,
    /// Optional explicit error type written as `? Type` after the closing brace.
    error: Option<(Token![?], Type)>,
}
25
/// The `&ident in` prefix of an initializer, naming a binding through which the initializer
/// can refer to (a pointer to) the value being initialized.
struct This {
    _and_token: Token![&],
    /// Name of the binding that receives the slot pointer.
    ident: Ident,
    _in_token: Token![in],
}
31
/// One entry of the initializer's field list, together with the outer attributes in front of it.
struct InitializerField {
    /// All outer attributes of the entry; the `#[cfg(..)]` ones are additionally copied onto the
    /// generated accessor binding and drop guard of the field.
    attrs: Vec<Attribute>,
    /// What the entry does (plain value, in-place initializer, or code block).
    kind: InitializerKind,
}
36
/// The syntactic form of a single field-list entry.
enum InitializerKind {
    /// `ident: expr`, or the shorthand `ident` (then `value` is `None` and the binding named
    /// `ident` that is in scope is used); the value is moved into the field with `ptr::write`.
    Value {
        ident: Ident,
        value: Option<(Token![:], Expr)>,
    },
    /// `ident <- expr`: the field is initialized in place by the initializer `expr`.
    Init {
        ident: Ident,
        _left_arrow_token: Token![<-],
        value: Expr,
    },
    /// `_: { .. }`: a block of user code executed at this point; it initializes no field.
    Code {
        _underscore_token: Token![_],
        _colon_token: Token![:],
        block: Block,
    },
}
53
54impl InitializerKind {
55 fn ident(&self) -> Option<&Ident> {
56 match self {
57 Self::Value { ident, .. } | Self::Init { ident, .. } => Some(ident),
58 Self::Code { .. } => None,
59 }
60 }
61}
62
/// Attributes accepted on an initializer as a whole.
enum InitializerAttribute {
    /// `#[default_error(Type)]`: error type used when no `? Type` suffix is written.
    DefaultError(DefaultErrorAttribute),
    /// `#[disable_initialized_field_access]`: do not emit the `let field = &mut ..;` bindings
    /// that expose already-initialized fields to later entries.
    DisableInitializedFieldAccess,
}
67
/// Argument of a `#[default_error(Type)]` attribute.
struct DefaultErrorAttribute {
    /// The error type written inside the parentheses.
    ty: Box<Type>,
}
71
72pub(crate) fn expand(
73 Initializer {
74 attrs,
75 this,
76 path,
77 brace_token,
78 fields,
79 rest,
80 error,
81 }: Initializer,
82 default_error: Option<&'static str>,
83 pinned: bool,
84 dcx: &mut DiagCtxt,
85) -> Result<TokenStream, ErrorGuaranteed> {
86 let error = error.map_or_else(
87 || {
88 if let Some(default_error) = attrs.iter().fold(None, |acc, attr| {
89 if let InitializerAttribute::DefaultError(DefaultErrorAttribute { ty }) = attr {
90 Some(ty.clone())
91 } else {
92 acc
93 }
94 }) {
95 default_error
96 } else if let Some(default_error) = default_error {
97 syn::parse_str(default_error).unwrap()
98 } else {
99 dcx.error(brace_token.span.close(), "expected `? <type>` after `}`");
100 parse_quote!(::core::convert::Infallible)
101 }
102 },
103 |(_, err)| Box::new(err),
104 );
105 let slot = format_ident!("slot");
106 let (has_data_trait, data_trait, get_data, init_from_closure) = if pinned {
107 (
108 format_ident!("HasPinData"),
109 format_ident!("PinData"),
110 format_ident!("__pin_data"),
111 format_ident!("pin_init_from_closure"),
112 )
113 } else {
114 (
115 format_ident!("HasInitData"),
116 format_ident!("InitData"),
117 format_ident!("__init_data"),
118 format_ident!("init_from_closure"),
119 )
120 };
121 let init_kind = get_init_kind(rest, dcx);
122 let zeroable_check = match init_kind {
123 InitKind::Normal => quote!(),
124 InitKind::Zeroing => quote! {
125 fn assert_zeroable<T: ?::core::marker::Sized>(_: *mut T)
130 where T: ::pin_init::Zeroable
131 {}
132 assert_zeroable(#slot);
134 unsafe { ::core::ptr::write_bytes(#slot, 0, 1) };
136 },
137 };
138 let this = match this {
139 None => quote!(),
140 Some(This { ident, .. }) => quote! {
141 let #ident = unsafe { ::core::ptr::NonNull::new_unchecked(slot) };
144 },
145 };
146 let data = Ident::new("__data", Span::mixed_site());
148 let init_fields = init_fields(
149 &fields,
150 pinned,
151 !attrs
152 .iter()
153 .any(|attr| matches!(attr, InitializerAttribute::DisableInitializedFieldAccess)),
154 &data,
155 &slot,
156 );
157 let field_check = make_field_check(&fields, init_kind, &path);
158 Ok(quote! {{
159 struct __InitOk;
163
164 let #data = unsafe {
167 use ::pin_init::__internal::#has_data_trait;
168 #path::#get_data()
171 };
172 let init = ::pin_init::__internal::#data_trait::make_closure::<_, __InitOk, #error>(
174 #data,
175 move |slot| {
176 {
177 struct __InitOk;
179 #zeroable_check
180 #this
181 #init_fields
182 #field_check
183 }
184 Ok(__InitOk)
185 }
186 );
187 let init = move |slot| -> ::core::result::Result<(), #error> {
188 init(slot).map(|__InitOk| ())
189 };
190 let init = unsafe { ::pin_init::#init_from_closure::<_, #error>(init) };
192 init
193 }})
194}
195
/// How fields that are not listed in the initializer are handled.
enum InitKind {
    /// No `..rest`: every field must be listed explicitly.
    Normal,
    /// `..Zeroable::init_zeroed()`: the whole value is zeroed first, so fields may be omitted.
    Zeroing,
}
200
201fn get_init_kind(rest: Option<(Token![..], Expr)>, dcx: &mut DiagCtxt) -> InitKind {
202 let Some((dotdot, expr)) = rest else {
203 return InitKind::Normal;
204 };
205 match &expr {
206 Expr::Call(ExprCall { func, args, .. }) if args.is_empty() => match &**func {
207 Expr::Path(ExprPath {
208 attrs,
209 qself: None,
210 path:
211 Path {
212 leading_colon: None,
213 segments,
214 },
215 }) if attrs.is_empty()
216 && segments.len() == 2
217 && segments[0].ident == "Zeroable"
218 && segments[0].arguments.is_none()
219 && segments[1].ident == "init_zeroed"
220 && segments[1].arguments.is_none() =>
221 {
222 return InitKind::Zeroing;
223 }
224 _ => {}
225 },
226 _ => {}
227 }
228 dcx.error(
229 dotdot.span().join(expr.span()).unwrap_or(expr.span()),
230 "expected nothing or `..Zeroable::init_zeroed()`.",
231 );
232 InitKind::Normal
233}
234
235fn init_fields(
237 fields: &Punctuated<InitializerField, Token![,]>,
238 pinned: bool,
239 generate_initialized_accessors: bool,
240 data: &Ident,
241 slot: &Ident,
242) -> TokenStream {
243 let mut guards = vec![];
244 let mut guard_attrs = vec![];
245 let mut res = TokenStream::new();
246 for InitializerField { attrs, kind } in fields {
247 let cfgs = {
248 let mut cfgs = attrs.clone();
249 cfgs.retain(|attr| attr.path().is_ident("cfg"));
250 cfgs
251 };
252 let init = match kind {
253 InitializerKind::Value { ident, value } => {
254 let mut value_ident = ident.clone();
255 let value_prep = value.as_ref().map(|value| &value.1).map(|value| {
256 value_ident.set_span(value.span());
259 quote!(let #value_ident = #value;)
260 });
261 let write = quote_spanned!(ident.span()=> ::core::ptr::write);
263 let accessor = if pinned {
264 let project_ident = format_ident!("__project_{ident}");
265 quote! {
266 unsafe { #data.#project_ident(&mut (*#slot).#ident) }
268 }
269 } else {
270 quote! {
271 unsafe { &mut (*#slot).#ident }
273 }
274 };
275 let accessor = generate_initialized_accessors.then(|| {
276 quote! {
277 #(#cfgs)*
278 #[allow(unused_variables)]
279 let #ident = #accessor;
280 }
281 });
282 quote! {
283 #(#attrs)*
284 {
285 #value_prep
286 unsafe { #write(::core::ptr::addr_of_mut!((*#slot).#ident), #value_ident) };
288 }
289 #accessor
290 }
291 }
292 InitializerKind::Init { ident, value, .. } => {
293 let init = format_ident!("init", span = value.span());
295 let (value_init, accessor) = if pinned {
296 let project_ident = format_ident!("__project_{ident}");
297 (
298 quote! {
299 unsafe { #data.#ident(::core::ptr::addr_of_mut!((*#slot).#ident), #init)? };
305 },
306 quote! {
307 unsafe { #data.#project_ident(&mut (*#slot).#ident) }
309 },
310 )
311 } else {
312 (
313 quote! {
314 unsafe {
317 ::pin_init::Init::__init(
318 #init,
319 ::core::ptr::addr_of_mut!((*#slot).#ident),
320 )?
321 };
322 },
323 quote! {
324 unsafe { &mut (*#slot).#ident }
326 },
327 )
328 };
329 let accessor = generate_initialized_accessors.then(|| {
330 quote! {
331 #(#cfgs)*
332 #[allow(unused_variables)]
333 let #ident = #accessor;
334 }
335 });
336 quote! {
337 #(#attrs)*
338 {
339 let #init = #value;
340 #value_init
341 }
342 #accessor
343 }
344 }
345 InitializerKind::Code { block: value, .. } => quote! {
346 #(#attrs)*
347 #[allow(unused_braces)]
348 #value
349 },
350 };
351 res.extend(init);
352 if let Some(ident) = kind.ident() {
353 let guard = format_ident!("__{ident}_guard", span = Span::mixed_site());
355 res.extend(quote! {
356 #(#cfgs)*
357 let #guard = unsafe {
363 ::pin_init::__internal::DropGuard::new(
364 ::core::ptr::addr_of_mut!((*slot).#ident)
365 )
366 };
367 });
368 guards.push(guard);
369 guard_attrs.push(cfgs);
370 }
371 }
372 quote! {
373 #res
374 #(
377 #(#guard_attrs)*
378 ::core::mem::forget(#guards);
379 )*
380 }
381}
382
383fn make_field_check(
385 fields: &Punctuated<InitializerField, Token![,]>,
386 init_kind: InitKind,
387 path: &Path,
388) -> TokenStream {
389 let field_attrs = fields
390 .iter()
391 .filter_map(|f| f.kind.ident().map(|_| &f.attrs));
392 let field_name = fields.iter().filter_map(|f| f.kind.ident());
393 match init_kind {
394 InitKind::Normal => quote! {
395 #[allow(unreachable_code, clippy::diverging_sub_expression)]
399 let _ = || unsafe {
401 ::core::ptr::write(slot, #path {
402 #(
403 #(#field_attrs)*
404 #field_name: ::core::panic!(),
405 )*
406 })
407 };
408 },
409 InitKind::Zeroing => quote! {
410 #[allow(unreachable_code, clippy::diverging_sub_expression, unused_assignments)]
415 let _ = || unsafe {
417 ::core::ptr::write(slot, #path {
418 #(
419 #(#field_attrs)*
420 #field_name: ::core::panic!(),
421 )*
422 ..::core::mem::zeroed()
423 })
424 };
425 },
426 }
427}
428
429impl Parse for Initializer {
430 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
431 let attrs = input.call(Attribute::parse_outer)?;
432 let this = input.peek(Token![&]).then(|| input.parse()).transpose()?;
433 let path = input.parse()?;
434 let content;
435 let brace_token = braced!(content in input);
436 let mut fields = Punctuated::new();
437 loop {
438 let lh = content.lookahead1();
439 if lh.peek(End) || lh.peek(Token![..]) {
440 break;
441 } else if lh.peek(Ident) || lh.peek(Token![_]) || lh.peek(Token![#]) {
442 fields.push_value(content.parse()?);
443 let lh = content.lookahead1();
444 if lh.peek(End) {
445 break;
446 } else if lh.peek(Token![,]) {
447 fields.push_punct(content.parse()?);
448 } else {
449 return Err(lh.error());
450 }
451 } else {
452 return Err(lh.error());
453 }
454 }
455 let rest = content
456 .peek(Token![..])
457 .then(|| Ok::<_, syn::Error>((content.parse()?, content.parse()?)))
458 .transpose()?;
459 let error = input
460 .peek(Token![?])
461 .then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?)))
462 .transpose()?;
463 let attrs = attrs
464 .into_iter()
465 .map(|a| {
466 if a.path().is_ident("default_error") {
467 a.parse_args::<DefaultErrorAttribute>()
468 .map(InitializerAttribute::DefaultError)
469 } else if a.path().is_ident("disable_initialized_field_access") {
470 a.meta
471 .require_path_only()
472 .map(|_| InitializerAttribute::DisableInitializedFieldAccess)
473 } else {
474 Err(syn::Error::new_spanned(a, "unknown initializer attribute"))
475 }
476 })
477 .collect::<Result<Vec<_>, _>>()?;
478 Ok(Self {
479 attrs,
480 this,
481 path,
482 brace_token,
483 fields,
484 rest,
485 error,
486 })
487 }
488}
489
490impl Parse for DefaultErrorAttribute {
491 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
492 Ok(Self { ty: input.parse()? })
493 }
494}
495
496impl Parse for This {
497 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
498 Ok(Self {
499 _and_token: input.parse()?,
500 ident: input.parse()?,
501 _in_token: input.parse()?,
502 })
503 }
504}
505
506impl Parse for InitializerField {
507 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
508 let attrs = input.call(Attribute::parse_outer)?;
509 Ok(Self {
510 attrs,
511 kind: input.parse()?,
512 })
513 }
514}
515
516impl Parse for InitializerKind {
517 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
518 let lh = input.lookahead1();
519 if lh.peek(Token![_]) {
520 Ok(Self::Code {
521 _underscore_token: input.parse()?,
522 _colon_token: input.parse()?,
523 block: input.parse()?,
524 })
525 } else if lh.peek(Ident) {
526 let ident = input.parse()?;
527 let lh = input.lookahead1();
528 if lh.peek(Token![<-]) {
529 Ok(Self::Init {
530 ident,
531 _left_arrow_token: input.parse()?,
532 value: input.parse()?,
533 })
534 } else if lh.peek(Token![:]) {
535 Ok(Self::Value {
536 ident,
537 value: Some((input.parse()?, input.parse()?)),
538 })
539 } else if lh.peek(Token![,]) || lh.peek(End) {
540 Ok(Self::Value { ident, value: None })
541 } else {
542 Err(lh.error())
543 }
544 } else {
545 Err(lh.error())
546 }
547 }
548}