1use std::ffi::CString;
8
9use proc_macro2::TokenStream;
10use quote::{
11 format_ident,
12 quote,
13 ToTokens, };
15use syn::{
16 parse_quote,
17 Error,
18 Ident,
19 Item,
20 ItemMod,
21 LitCStr,
22 Result, };
24
25pub(crate) fn kunit_tests(test_suite: Ident, mut module: ItemMod) -> Result<TokenStream> {
26 if test_suite.to_string().len() > 255 {
27 return Err(Error::new_spanned(
28 test_suite,
29 "test suite names cannot exceed the maximum length of 255 bytes",
30 ));
31 }
32
33 let Some((module_brace, module_items)) = module.content.take() else {
35 Err(Error::new_spanned(
36 module,
37 "`#[kunit_tests(test_name)]` attribute should only be applied to inline modules",
38 ))?
39 };
40
41 module
43 .attrs
44 .insert(0, parse_quote!(#[cfg(CONFIG_KUNIT="y")]));
45
46 let mut processed_items = Vec::new();
47 let mut test_cases = Vec::new();
48
49 for item in module_items {
85 let Item::Fn(mut f) = item else {
86 processed_items.push(item);
87 continue;
88 };
89
90 let before_len = f.attrs.len();
92 f.attrs.retain(|attr| !attr.path().is_ident("test"));
93 if f.attrs.len() == before_len {
94 processed_items.push(Item::Fn(f));
95 continue;
96 }
97
98 let test = f.sig.ident.clone();
99
100 let cfg_attrs: Vec<_> = f
102 .attrs
103 .iter()
104 .filter(|attr| attr.path().is_ident("cfg"))
105 .cloned()
106 .collect();
107
108 let test_str = test.to_string();
111 let path = CString::new(crate::helpers::file()).expect("file path cannot contain NUL");
112 processed_items.push(parse_quote! {
113 #[allow(unused)]
114 macro_rules! assert {
115 ($cond:expr $(,)?) => {{
116 kernel::kunit_assert!(#test_str, #path, 0, $cond);
117 }}
118 }
119 });
120 processed_items.push(parse_quote! {
121 #[allow(unused)]
122 macro_rules! assert_eq {
123 ($left:expr, $right:expr $(,)?) => {{
124 kernel::kunit_assert_eq!(#test_str, #path, 0, $left, $right);
125 }}
126 }
127 });
128
129 processed_items.push(Item::Fn(f));
131
132 let kunit_wrapper_fn_name = format_ident!("kunit_rust_wrapper_{test}");
133 let test_cstr = LitCStr::new(
134 &CString::new(test_str.as_str()).expect("identifier cannot contain NUL"),
135 test.span(),
136 );
137 processed_items.push(parse_quote! {
138 unsafe extern "C" fn #kunit_wrapper_fn_name(_test: *mut ::kernel::bindings::kunit) {
139 (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SKIPPED;
140
141 #(#cfg_attrs)*
145 {
146 (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SUCCESS;
147 use ::kernel::kunit::is_test_result_ok;
148 assert!(is_test_result_ok(#test()));
149 }
150 }
151 });
152
153 test_cases.push(quote!(
154 ::kernel::kunit::kunit_case(#test_cstr, #kunit_wrapper_fn_name)
155 ));
156 }
157
158 let num_tests_plus_1 = test_cases.len() + 1;
159 processed_items.push(parse_quote! {
160 static mut TEST_CASES: [::kernel::bindings::kunit_case; #num_tests_plus_1] = [
161 #(#test_cases,)*
162 ::pin_init::zeroed(),
163 ];
164 });
165 processed_items.push(parse_quote! {
166 ::kernel::kunit_unsafe_test_suite!(#test_suite, TEST_CASES);
167 });
168
169 module.content = Some((module_brace, processed_items));
170 Ok(module.to_token_stream())
171}