Skip to main content

macros/
module.rs

1// SPDX-License-Identifier: GPL-2.0
2
3use std::ffi::CString;
4
5use proc_macro2::{
6    Literal,
7    TokenStream, //
8};
9use quote::{
10    format_ident,
11    quote, //
12};
13use syn::{
14    braced,
15    bracketed,
16    ext::IdentExt,
17    parse::{
18        Parse,
19        ParseStream, //
20    },
21    parse_quote,
22    punctuated::Punctuated,
23    Error,
24    Expr,
25    Ident,
26    LitStr,
27    Path,
28    Result,
29    Token,
30    Type, //
31};
32
33use crate::helpers::*;
34
35struct ModInfoBuilder<'a> {
36    module: &'a str,
37    counter: usize,
38    ts: TokenStream,
39    param_ts: TokenStream,
40}
41
42impl<'a> ModInfoBuilder<'a> {
43    fn new(module: &'a str) -> Self {
44        ModInfoBuilder {
45            module,
46            counter: 0,
47            ts: TokenStream::new(),
48            param_ts: TokenStream::new(),
49        }
50    }
51
52    fn emit_base(&mut self, field: &str, content: &str, builtin: bool, param: bool) {
53        let string = if builtin {
54            // Built-in modules prefix their modinfo strings by `module.`.
55            format!(
56                "{module}.{field}={content}\0",
57                module = self.module,
58                field = field,
59                content = content
60            )
61        } else {
62            // Loadable modules' modinfo strings go as-is.
63            format!("{field}={content}\0")
64        };
65        let length = string.len();
66        let string = Literal::byte_string(string.as_bytes());
67        let cfg = if builtin {
68            quote!(#[cfg(not(MODULE))])
69        } else {
70            quote!(#[cfg(MODULE)])
71        };
72
73        let counter = format_ident!(
74            "__{module}_{counter}",
75            module = self.module.to_uppercase(),
76            counter = self.counter
77        );
78        let item = quote! {
79            #cfg
80            #[cfg_attr(not(target_os = "macos"), link_section = ".modinfo")]
81            #[used(compiler)]
82            pub static #counter: [u8; #length] = *#string;
83        };
84
85        if param {
86            self.param_ts.extend(item);
87        } else {
88            self.ts.extend(item);
89        }
90
91        self.counter += 1;
92    }
93
94    fn emit_only_builtin(&mut self, field: &str, content: &str, param: bool) {
95        self.emit_base(field, content, true, param)
96    }
97
98    fn emit_only_loadable(&mut self, field: &str, content: &str, param: bool) {
99        self.emit_base(field, content, false, param)
100    }
101
102    fn emit(&mut self, field: &str, content: &str) {
103        self.emit_internal(field, content, false);
104    }
105
106    fn emit_internal(&mut self, field: &str, content: &str, param: bool) {
107        self.emit_only_builtin(field, content, param);
108        self.emit_only_loadable(field, content, param);
109    }
110
111    fn emit_param(&mut self, field: &str, param: &str, content: &str) {
112        let content = format!("{param}:{content}", param = param, content = content);
113        self.emit_internal(field, &content, true);
114    }
115
116    fn emit_params(&mut self, info: &ModuleInfo) {
117        let Some(params) = &info.params else {
118            return;
119        };
120
121        for param in params {
122            let param_name_str = param.name.to_string();
123            let param_type_str = param.ptype.to_string();
124
125            let ops = param_ops_path(&param_type_str);
126
127            // Note: The spelling of these fields is dictated by the user space
128            // tool `modinfo`.
129            self.emit_param("parmtype", &param_name_str, &param_type_str);
130            self.emit_param("parm", &param_name_str, &param.description.value());
131
132            let static_name = format_ident!("__{}_{}_struct", self.module, param.name);
133            let param_name_cstr =
134                CString::new(param_name_str).expect("name contains NUL-terminator");
135            let param_name_cstr_with_module =
136                CString::new(format!("{}.{}", self.module, param.name))
137                    .expect("name contains NUL-terminator");
138
139            let param_name = &param.name;
140            let param_type = &param.ptype;
141            let param_default = &param.default;
142
143            self.param_ts.extend(quote! {
144                #[allow(non_upper_case_globals)]
145                pub(crate) static #param_name:
146                    ::kernel::module_param::ModuleParamAccess<#param_type> =
147                        ::kernel::module_param::ModuleParamAccess::new(#param_default);
148
149                const _: () = {
150                    #[allow(non_upper_case_globals)]
151                    #[link_section = "__param"]
152                    #[used(compiler)]
153                    static #static_name:
154                        ::kernel::module_param::KernelParam =
155                        ::kernel::module_param::KernelParam::new(
156                            ::kernel::bindings::kernel_param {
157                                name: kernel::str::as_char_ptr_in_const_context(
158                                    if ::core::cfg!(MODULE) {
159                                        #param_name_cstr
160                                    } else {
161                                        #param_name_cstr_with_module
162                                    }
163                                ),
164                                // SAFETY: `__this_module` is constructed by the kernel at load
165                                // time and will not be freed until the module is unloaded.
166                                #[cfg(MODULE)]
167                                mod_: unsafe {
168                                    core::ptr::from_ref(&::kernel::bindings::__this_module)
169                                        .cast_mut()
170                                },
171                                #[cfg(not(MODULE))]
172                                mod_: ::core::ptr::null_mut(),
173                                ops: core::ptr::from_ref(&#ops),
174                                perm: 0, // Will not appear in sysfs
175                                level: -1,
176                                flags: 0,
177                                __bindgen_anon_1: ::kernel::bindings::kernel_param__bindgen_ty_1 {
178                                    arg: #param_name.as_void_ptr()
179                                },
180                            }
181                        );
182                };
183            });
184        }
185    }
186}
187
188fn param_ops_path(param_type: &str) -> Path {
189    match param_type {
190        "i8" => parse_quote!(::kernel::module_param::PARAM_OPS_I8),
191        "u8" => parse_quote!(::kernel::module_param::PARAM_OPS_U8),
192        "i16" => parse_quote!(::kernel::module_param::PARAM_OPS_I16),
193        "u16" => parse_quote!(::kernel::module_param::PARAM_OPS_U16),
194        "i32" => parse_quote!(::kernel::module_param::PARAM_OPS_I32),
195        "u32" => parse_quote!(::kernel::module_param::PARAM_OPS_U32),
196        "i64" => parse_quote!(::kernel::module_param::PARAM_OPS_I64),
197        "u64" => parse_quote!(::kernel::module_param::PARAM_OPS_U64),
198        "isize" => parse_quote!(::kernel::module_param::PARAM_OPS_ISIZE),
199        "usize" => parse_quote!(::kernel::module_param::PARAM_OPS_USIZE),
200        t => panic!("Unsupported parameter type {}", t),
201    }
202}
203
/// Parse fields that are required to use a specific order.
///
/// As fields must follow a specific order, we *could* just parse fields one by one by peeking.
/// However, the error messages produced that way are not very friendly.
///
/// So instead we parse fields in an arbitrary order and only enforce the ordering after parsing.
/// If the wrong order is used, the error points at the first misplaced key and tells the user
/// the expected order.
///
/// Every field, including the last one, must be followed by a comma. Errors (unknown key,
/// duplicated key, missing required key, wrong order) are propagated with `?` from the
/// enclosing function.
///
/// Usage looks like this:
/// ```ignore
/// parse_ordered_fields! {
///     from input;
///
///     // This will extract "foo: <field>" into a variable named "foo".
///     // The variable will have type `Option<_>`.
///     foo => <expression that parses the field>,
///
///     // If you need the variable name to be different than the key name.
///     // This extracts "baz: <field>" into a variable named "bar".
///     // You might want this if "baz" is a keyword.
///     baz as bar => <expression that parses the field>,
///
///     // You can mark a key as required, and the variable will no longer be `Option`.
///     // foobar will be of type `Expr` instead of `Option<Expr>`.
///     foobar [required] => input.parse::<Expr>()?,
/// }
/// ```
macro_rules! parse_ordered_fields {
    // Terminal rule: every field specification has been moved into the accumulators.
    //
    // - First accumulator: `[variable; key; parser]` for every field, in declaration order.
    // - Second accumulator: `[variable; key]` for the required fields only.
    (@gen
        [$input:expr]
        [$([$name:ident; $key:ident; $parser:expr])*]
        [$([$req_name:ident; $req_key:ident])*]
    ) => {
        $(let mut $name = None;)*

        const EXPECTED_KEYS: &[&str] = &[$(stringify!($key),)*];
        const REQUIRED_KEYS: &[&str] = &[$(stringify!($req_key),)*];

        let span = $input.span();
        let mut seen_keys = Vec::new();

        while !$input.is_empty() {
            // `parse_any` so that keywords such as `type` are accepted as keys.
            let key = $input.call(Ident::parse_any)?;

            if seen_keys.contains(&key) {
                Err(Error::new_spanned(
                    &key,
                    format!(r#"duplicated key "{key}". Keys can only be specified once."#),
                ))?
            }

            $input.parse::<Token![:]>()?;

            match &*key.to_string() {
                $(
                    stringify!($key) => $name = Some($parser),
                )*
                _ => {
                    Err(Error::new_spanned(
                        &key,
                        format!(r#"unknown key "{key}". Valid keys are: {EXPECTED_KEYS:?}."#),
                    ))?
                }
            }

            $input.parse::<Token![,]>()?;
            seen_keys.push(key);
        }

        for key in REQUIRED_KEYS {
            if !seen_keys.iter().any(|e| e == key) {
                Err(Error::new(span, format!(r#"missing required key "{key}""#)))?
            }
        }

        // The keys that were actually given, listed in the order they are expected in.
        let mut ordered_keys: Vec<&str> = Vec::new();
        for key in EXPECTED_KEYS {
            if seen_keys.iter().any(|e| e == key) {
                ordered_keys.push(key);
            }
        }

        if seen_keys != ordered_keys {
            // Report the error at the first key that deviates from the expected order rather
            // than at the start of the input.
            let misplaced = seen_keys
                .iter()
                .zip(&ordered_keys)
                .find(|(seen, expected)| seen != expected)
                .map_or(span, |(seen, _)| seen.span());
            Err(Error::new(
                misplaced,
                format!(r#"keys are not ordered as expected. Order them like: {ordered_keys:?}."#),
            ))?
        }

        $(let $req_name = $req_name.expect("required field");)*
    };

    // Handle required fields.
    (@gen
        [$input:expr] [$($tok:tt)*] [$($req:tt)*]
        $key:ident as $name:ident [required] => $parser:expr,
        $($rest:tt)*
    ) => {
        parse_ordered_fields!(
            @gen [$input] [$($tok)* [$name; $key; $parser]] [$($req)* [$name; $key]] $($rest)*
        )
    };
    (@gen
        [$input:expr] [$($tok:tt)*] [$($req:tt)*]
        $name:ident [required] => $parser:expr,
        $($rest:tt)*
    ) => {
        parse_ordered_fields!(
            @gen [$input] [$($tok)* [$name; $name; $parser]] [$($req)* [$name; $name]] $($rest)*
        )
    };

    // Handle optional fields.
    (@gen
        [$input:expr] [$($tok:tt)*] [$($req:tt)*]
        $key:ident as $name:ident => $parser:expr,
        $($rest:tt)*
    ) => {
        parse_ordered_fields!(
            @gen [$input] [$($tok)* [$name; $key; $parser]] [$($req)*] $($rest)*
        )
    };
    (@gen
        [$input:expr] [$($tok:tt)*] [$($req:tt)*]
        $name:ident => $parser:expr,
        $($rest:tt)*
    ) => {
        parse_ordered_fields!(
            @gen [$input] [$($tok)* [$name; $name; $parser]] [$($req)*] $($rest)*
        )
    };

    // Entry point: start with empty accumulators.
    (from $input:expr; $($tok:tt)*) => {
        parse_ordered_fields!(@gen [$input] [] [] $($tok)*)
    }
}
335
336    (from $input:expr; $($tok:tt)*) => {
337        parse_ordered_fields!(@gen [$input] [] [] $($tok)*)
338    }
339}
340
341struct Parameter {
342    name: Ident,
343    ptype: Ident,
344    default: Expr,
345    description: LitStr,
346}
347
348impl Parse for Parameter {
349    fn parse(input: ParseStream<'_>) -> Result<Self> {
350        let name = input.parse()?;
351        input.parse::<Token![:]>()?;
352        let ptype = input.parse()?;
353
354        let fields;
355        braced!(fields in input);
356
357        parse_ordered_fields! {
358            from fields;
359            default [required] => fields.parse()?,
360            description [required] => fields.parse()?,
361        }
362
363        Ok(Self {
364            name,
365            ptype,
366            default,
367            description,
368        })
369    }
370}
371
372pub(crate) struct ModuleInfo {
373    type_: Type,
374    license: AsciiLitStr,
375    name: AsciiLitStr,
376    authors: Option<Punctuated<AsciiLitStr, Token![,]>>,
377    description: Option<LitStr>,
378    alias: Option<Punctuated<AsciiLitStr, Token![,]>>,
379    firmware: Option<Punctuated<AsciiLitStr, Token![,]>>,
380    imports_ns: Option<Punctuated<AsciiLitStr, Token![,]>>,
381    params: Option<Punctuated<Parameter, Token![,]>>,
382}
383
384impl Parse for ModuleInfo {
385    fn parse(input: ParseStream<'_>) -> Result<Self> {
386        parse_ordered_fields!(
387            from input;
388            type as type_ [required] => input.parse()?,
389            name [required] => input.parse()?,
390            authors => {
391                let list;
392                bracketed!(list in input);
393                Punctuated::parse_terminated(&list)?
394            },
395            description => input.parse()?,
396            license [required] => input.parse()?,
397            alias => {
398                let list;
399                bracketed!(list in input);
400                Punctuated::parse_terminated(&list)?
401            },
402            firmware => {
403                let list;
404                bracketed!(list in input);
405                Punctuated::parse_terminated(&list)?
406            },
407            imports_ns => {
408                let list;
409                bracketed!(list in input);
410                Punctuated::parse_terminated(&list)?
411            },
412            params => {
413                let list;
414                braced!(list in input);
415                Punctuated::parse_terminated(&list)?
416            },
417        );
418
419        Ok(ModuleInfo {
420            type_,
421            license,
422            name,
423            authors,
424            description,
425            alias,
426            firmware,
427            imports_ns,
428            params,
429        })
430    }
431}
432
433pub(crate) fn module(info: ModuleInfo) -> Result<TokenStream> {
434    let ModuleInfo {
435        type_,
436        license,
437        name,
438        authors,
439        description,
440        alias,
441        firmware,
442        imports_ns,
443        params: _,
444    } = &info;
445
446    // Rust does not allow hyphens in identifiers, use underscore instead.
447    let ident = name.value().replace('-', "_");
448    let mut modinfo = ModInfoBuilder::new(ident.as_ref());
449    if let Some(authors) = authors {
450        for author in authors {
451            modinfo.emit("author", &author.value());
452        }
453    }
454    if let Some(description) = description {
455        modinfo.emit("description", &description.value());
456    }
457    modinfo.emit("license", &license.value());
458    if let Some(aliases) = alias {
459        for alias in aliases {
460            modinfo.emit("alias", &alias.value());
461        }
462    }
463    if let Some(firmware) = firmware {
464        for fw in firmware {
465            modinfo.emit("firmware", &fw.value());
466        }
467    }
468    if let Some(imports) = imports_ns {
469        for ns in imports {
470            modinfo.emit("import_ns", &ns.value());
471        }
472    }
473
474    // Built-in modules also export the `file` modinfo string.
475    let file =
476        std::env::var("RUST_MODFILE").expect("Unable to fetch RUST_MODFILE environmental variable");
477    modinfo.emit_only_builtin("file", &file, false);
478
479    modinfo.emit_params(&info);
480
481    let modinfo_ts = modinfo.ts;
482    let params_ts = modinfo.param_ts;
483
484    let ident_init = format_ident!("__{ident}_init");
485    let ident_exit = format_ident!("__{ident}_exit");
486    let ident_initcall = format_ident!("__{ident}_initcall");
487    let initcall_section = ".initcall6.init";
488
489    let global_asm = format!(
490        r#".section "{initcall_section}", "a"
491        __{ident}_initcall:
492            .long   __{ident}_init - .
493            .previous
494        "#
495    );
496
497    let name_cstr = CString::new(name.value()).expect("name contains NUL-terminator");
498
499    Ok(quote! {
500        /// The module name.
501        ///
502        /// Used by the printing macros, e.g. [`info!`].
503        const __LOG_PREFIX: &[u8] = #name_cstr.to_bytes_with_nul();
504
505        // SAFETY: `__this_module` is constructed by the kernel at load time and will not be
506        // freed until the module is unloaded.
507        #[cfg(MODULE)]
508        static THIS_MODULE: ::kernel::ThisModule = unsafe {
509            extern "C" {
510                static __this_module: ::kernel::types::Opaque<::kernel::bindings::module>;
511            };
512
513            ::kernel::ThisModule::from_ptr(__this_module.get())
514        };
515
516        #[cfg(not(MODULE))]
517        static THIS_MODULE: ::kernel::ThisModule = unsafe {
518            ::kernel::ThisModule::from_ptr(::core::ptr::null_mut())
519        };
520
521        /// The `LocalModule` type is the type of the module created by `module!`,
522        /// `module_pci_driver!`, `module_platform_driver!`, etc.
523        type LocalModule = #type_;
524
525        impl ::kernel::ModuleMetadata for #type_ {
526            const NAME: &'static ::kernel::str::CStr = #name_cstr;
527        }
528
529        // Double nested modules, since then nobody can access the public items inside.
530        #[doc(hidden)]
531        mod __module_init {
532            mod __module_init {
533                use pin_init::PinInit;
534
535                /// The "Rust loadable module" mark.
536                //
537                // This may be best done another way later on, e.g. as a new modinfo
538                // key or a new section. For the moment, keep it simple.
539                #[cfg(MODULE)]
540                #[used(compiler)]
541                static __IS_RUST_MODULE: () = ();
542
543                static mut __MOD: ::core::mem::MaybeUninit<super::super::LocalModule> =
544                    ::core::mem::MaybeUninit::uninit();
545
546                // Loadable modules need to export the `{init,cleanup}_module` identifiers.
547                /// # Safety
548                ///
549                /// This function must not be called after module initialization, because it may be
550                /// freed after that completes.
551                #[cfg(MODULE)]
552                #[no_mangle]
553                #[link_section = ".init.text"]
554                pub unsafe extern "C" fn init_module() -> ::kernel::ffi::c_int {
555                    // SAFETY: This function is inaccessible to the outside due to the double
556                    // module wrapping it. It is called exactly once by the C side via its
557                    // unique name.
558                    unsafe { __init() }
559                }
560
561                #[cfg(MODULE)]
562                #[used(compiler)]
563                #[link_section = ".init.data"]
564                static __UNIQUE_ID___addressable_init_module: unsafe extern "C" fn() -> i32 =
565                    init_module;
566
567                #[cfg(MODULE)]
568                #[no_mangle]
569                #[link_section = ".exit.text"]
570                pub extern "C" fn cleanup_module() {
571                    // SAFETY:
572                    // - This function is inaccessible to the outside due to the double
573                    //   module wrapping it. It is called exactly once by the C side via its
574                    //   unique name,
575                    // - furthermore it is only called after `init_module` has returned `0`
576                    //   (which delegates to `__init`).
577                    unsafe { __exit() }
578                }
579
580                #[cfg(MODULE)]
581                #[used(compiler)]
582                #[link_section = ".exit.data"]
583                static __UNIQUE_ID___addressable_cleanup_module: extern "C" fn() = cleanup_module;
584
585                // Built-in modules are initialized through an initcall pointer
586                // and the identifiers need to be unique.
587                #[cfg(not(MODULE))]
588                #[cfg(not(CONFIG_HAVE_ARCH_PREL32_RELOCATIONS))]
589                #[link_section = #initcall_section]
590                #[used(compiler)]
591                pub static #ident_initcall: extern "C" fn() ->
592                    ::kernel::ffi::c_int = #ident_init;
593
594                #[cfg(not(MODULE))]
595                #[cfg(CONFIG_HAVE_ARCH_PREL32_RELOCATIONS)]
596                ::core::arch::global_asm!(#global_asm);
597
598                #[cfg(not(MODULE))]
599                #[no_mangle]
600                pub extern "C" fn #ident_init() -> ::kernel::ffi::c_int {
601                    // SAFETY: This function is inaccessible to the outside due to the double
602                    // module wrapping it. It is called exactly once by the C side via its
603                    // placement above in the initcall section.
604                    unsafe { __init() }
605                }
606
607                #[cfg(not(MODULE))]
608                #[no_mangle]
609                pub extern "C" fn #ident_exit() {
610                    // SAFETY:
611                    // - This function is inaccessible to the outside due to the double
612                    //   module wrapping it. It is called exactly once by the C side via its
613                    //   unique name,
614                    // - furthermore it is only called after `#ident_init` has
615                    //   returned `0` (which delegates to `__init`).
616                    unsafe { __exit() }
617                }
618
619                /// # Safety
620                ///
621                /// This function must only be called once.
622                unsafe fn __init() -> ::kernel::ffi::c_int {
623                    let initer = <super::super::LocalModule as ::kernel::InPlaceModule>::init(
624                        &super::super::THIS_MODULE
625                    );
626                    // SAFETY: No data race, since `__MOD` can only be accessed by this module
627                    // and there only `__init` and `__exit` access it. These functions are only
628                    // called once and `__exit` cannot be called before or during `__init`.
629                    match unsafe { initer.__pinned_init(__MOD.as_mut_ptr()) } {
630                        Ok(m) => 0,
631                        Err(e) => e.to_errno(),
632                    }
633                }
634
635                /// # Safety
636                ///
637                /// This function must
638                /// - only be called once,
639                /// - be called after `__init` has been called and returned `0`.
640                unsafe fn __exit() {
641                    // SAFETY: No data race, since `__MOD` can only be accessed by this module
642                    // and there only `__init` and `__exit` access it. These functions are only
643                    // called once and `__init` was already called.
644                    unsafe {
645                        // Invokes `drop()` on `__MOD`, which should be used for cleanup.
646                        __MOD.assume_init_drop();
647                    }
648                }
649
650                #modinfo_ts
651            }
652        }
653
654        mod module_parameters {
655            #params_ts
656        }
657    })
658}