Skip to main content

macros/
kunit.rs

// SPDX-License-Identifier: GPL-2.0

//! Procedural macro to run KUnit tests using a user-space-like syntax.
//!
//! Copyright (c) 2023 José Expósito <jose.exposito89@gmail.com>

7use std::ffi::CString;
8
9use proc_macro2::TokenStream;
10use quote::{
11    format_ident,
12    quote,
13    ToTokens, //
14};
15use syn::{
16    parse_quote,
17    Error,
18    Ident,
19    Item,
20    ItemMod,
21    LitCStr,
22    Result, //
23};
24
25pub(crate) fn kunit_tests(test_suite: Ident, mut module: ItemMod) -> Result<TokenStream> {
26    if test_suite.to_string().len() > 255 {
27        return Err(Error::new_spanned(
28            test_suite,
29            "test suite names cannot exceed the maximum length of 255 bytes",
30        ));
31    }
32
33    // We cannot handle modules that defer to another file (e.g. `mod foo;`).
34    let Some((module_brace, module_items)) = module.content.take() else {
35        Err(Error::new_spanned(
36            module,
37            "`#[kunit_tests(test_name)]` attribute should only be applied to inline modules",
38        ))?
39    };
40
41    // Make the entire module gated behind `CONFIG_KUNIT`.
42    module
43        .attrs
44        .insert(0, parse_quote!(#[cfg(CONFIG_KUNIT="y")]));
45
46    let mut processed_items = Vec::new();
47    let mut test_cases = Vec::new();
48
49    // Generate the test KUnit test suite and a test case for each `#[test]`.
50    //
51    // The code generated for the following test module:
52    //
53    // ```
54    // #[kunit_tests(kunit_test_suit_name)]
55    // mod tests {
56    //     #[test]
57    //     fn foo() {
58    //         assert_eq!(1, 1);
59    //     }
60    //
61    //     #[test]
62    //     fn bar() {
63    //         assert_eq!(2, 2);
64    //     }
65    // }
66    // ```
67    //
68    // Looks like:
69    //
70    // ```
71    // unsafe extern "C" fn kunit_rust_wrapper_foo(_test: *mut ::kernel::bindings::kunit) { foo(); }
72    // unsafe extern "C" fn kunit_rust_wrapper_bar(_test: *mut ::kernel::bindings::kunit) { bar(); }
73    //
74    // static mut TEST_CASES: [::kernel::bindings::kunit_case; 3] = [
75    //     ::kernel::kunit::kunit_case(c"foo", kunit_rust_wrapper_foo),
76    //     ::kernel::kunit::kunit_case(c"bar", kunit_rust_wrapper_bar),
77    //     ::pin_init::zeroed(),
78    // ];
79    //
80    // ::kernel::kunit_unsafe_test_suite!(kunit_test_suit_name, TEST_CASES);
81    // ```
82    //
83    // Non-function items (e.g. imports) are preserved.
84    for item in module_items {
85        let Item::Fn(mut f) = item else {
86            processed_items.push(item);
87            continue;
88        };
89
90        // TODO: Replace below with `extract_if` when MSRV is bumped above 1.85.
91        let before_len = f.attrs.len();
92        f.attrs.retain(|attr| !attr.path().is_ident("test"));
93        if f.attrs.len() == before_len {
94            processed_items.push(Item::Fn(f));
95            continue;
96        }
97
98        let test = f.sig.ident.clone();
99
100        // Retrieve `#[cfg]` applied on the function which needs to be present on derived items too.
101        let cfg_attrs: Vec<_> = f
102            .attrs
103            .iter()
104            .filter(|attr| attr.path().is_ident("cfg"))
105            .cloned()
106            .collect();
107
108        // Before the test, override usual `assert!` and `assert_eq!` macros with ones that call
109        // KUnit instead.
110        let test_str = test.to_string();
111        let path = CString::new(crate::helpers::file()).expect("file path cannot contain NUL");
112        processed_items.push(parse_quote! {
113            #[allow(unused)]
114            macro_rules! assert {
115                ($cond:expr $(,)?) => {{
116                    kernel::kunit_assert!(#test_str, #path, 0, $cond);
117                }}
118            }
119        });
120        processed_items.push(parse_quote! {
121            #[allow(unused)]
122            macro_rules! assert_eq {
123                ($left:expr, $right:expr $(,)?) => {{
124                    kernel::kunit_assert_eq!(#test_str, #path, 0, $left, $right);
125                }}
126            }
127        });
128
129        // Add back the test item.
130        processed_items.push(Item::Fn(f));
131
132        let kunit_wrapper_fn_name = format_ident!("kunit_rust_wrapper_{test}");
133        let test_cstr = LitCStr::new(
134            &CString::new(test_str.as_str()).expect("identifier cannot contain NUL"),
135            test.span(),
136        );
137        processed_items.push(parse_quote! {
138            unsafe extern "C" fn #kunit_wrapper_fn_name(_test: *mut ::kernel::bindings::kunit) {
139                (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SKIPPED;
140
141                // Append any `cfg` attributes the user might have written on their tests so we
142                // don't attempt to call them when they are `cfg`'d out. An extra `use` is used
143                // here to reduce the length of the assert message.
144                #(#cfg_attrs)*
145                {
146                    (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SUCCESS;
147                    use ::kernel::kunit::is_test_result_ok;
148                    assert!(is_test_result_ok(#test()));
149                }
150            }
151        });
152
153        test_cases.push(quote!(
154            ::kernel::kunit::kunit_case(#test_cstr, #kunit_wrapper_fn_name)
155        ));
156    }
157
158    let num_tests_plus_1 = test_cases.len() + 1;
159    processed_items.push(parse_quote! {
160        static mut TEST_CASES: [::kernel::bindings::kunit_case; #num_tests_plus_1] = [
161            #(#test_cases,)*
162            ::pin_init::zeroed(),
163        ];
164    });
165    processed_items.push(parse_quote! {
166        ::kernel::kunit_unsafe_test_suite!(#test_suite, TEST_CASES);
167    });
168
169    module.content = Some((module_brace, processed_items));
170    Ok(module.to_token_stream())
171}