hypervisor_test_macro/
lib.rs

1// Copyright 2024 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#![warn(missing_docs)]
6#![recursion_limit = "128"]
7
8//! Macros for hypervisor tests
9
10use std::collections::hash_map::DefaultHasher;
11use std::hash::Hash;
12use std::hash::Hasher;
13use std::sync::atomic::AtomicU64;
14
15use proc_macro::TokenStream;
16use proc_macro2::Span;
17use proc_macro2::TokenStream as TokenStream2;
18use quote::quote;
19use syn::parse::Parse;
20use syn::parse_macro_input;
21use syn::Error;
22use syn::Ident;
23use syn::LitStr;
24use syn::Token;
25use syn::Visibility;
26
27/// Embed the compiled assembly as an array.
28///
/// This macro generates a module with the given name. The module provides a `data` function that
/// returns the compiled machine code as a byte slice.
31///
32/// Note that this macro uses [`std::arch::global_asm`], so we can only use this macro in a global
33/// scope, outside a function.
34///
35/// # Example
36///
37/// Given the following x86 assembly:
38/// ```Text
39/// 0:  01 d8                   add    eax,ebx
40/// 2:  f4                      hlt
41/// ```
42///
43/// ```rust
44/// # use hypervisor_test_macro::global_asm_data;
45/// global_asm_data!(
46///     my_code,
47///     ".code64",
48///     "add eax, ebx",
49///     "hlt",
50/// );
51/// # fn main() {
52/// assert_eq!([0x01, 0xd8, 0xf4], my_code::data());
53/// # }
54/// ```
55///
/// Arbitrary operands and options accepted by [`std::arch::global_asm`] can also be passed.
57/// ```rust
58/// # use hypervisor_test_macro::global_asm_data;
59/// fn f() {}
60/// global_asm_data!(
61///     my_code1,
62///     ".global {0}",
63///     ".code64",
64///     "add eax, ebx",
65///     "hlt",
66///     sym f,
67/// );
68/// global_asm_data!(
69///     my_code2,
70///     ".code64",
71///     "add eax, ebx",
72///     "hlt",
73///     options(raw),
74/// );
75/// # fn main() {
76/// assert_eq!([0x01, 0xd8, 0xf4], my_code1::data());
77/// assert_eq!([0x01, 0xd8, 0xf4], my_code2::data());
78/// # }
79/// ```
80///
81/// It is also supported to specify the visibility of the generated module. Note that the below
82/// example won't work if the `pub` in the macro is missing.
83/// ```rust
84/// # use hypervisor_test_macro::global_asm_data;
85/// mod my_mod {
86///     // This use is needed to import the global_asm_data macro to this module.
87///     use super::*;
88///
89///     global_asm_data!(
90///         // pub is needed so that my_mod::my_code is visible to the outer scope.
91///         pub my_code,
92///         ".code64",
93///         "add eax, ebx",
94///         "hlt",
95///     );
96/// }
97/// # fn main() {
98/// assert_eq!([0x01, 0xd8, 0xf4], my_mod::my_code::data());
99/// # }
100/// ```
101#[proc_macro]
102pub fn global_asm_data(item: TokenStream) -> TokenStream {
103    let args = parse_macro_input!(item as GlobalAsmDataArgs);
104    global_asm_data_impl(args).unwrap_or_else(|e| e.to_compile_error().into())
105}
106
107struct GlobalAsmDataArgs {
108    visibility: Visibility,
109    mod_name: Ident,
110    global_asm_strings: Vec<LitStr>,
111    global_asm_rest_args: TokenStream2,
112}
113
114impl Parse for GlobalAsmDataArgs {
115    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
116        // The first argument is visibilty + identifier, e.g. my_code or pub my_code. The identifier
117        // will be used as the name of the gnerated module.
118        let visibility: Visibility = input.parse()?;
119        let mod_name: Ident = input.parse()?;
120        // There must be following arguments, so we consume the first argument separator here.
121        input.parse::<Token![,]>()?;
122
123        // Retrieve the input assemblies, which are a list of comma separated string literals. We
124        // need to obtain the list of assemblies explicitly, so that we can insert the begin tag and
125        // the end tag to the global_asm! call when we generate the result code.
126        let mut global_asm_strings = vec![];
127        loop {
128            let lookahead = input.lookahead1();
129            if !lookahead.peek(LitStr) {
130                // If the upcoming tokens are not string literal, we hit the end of the input
131                // assemblies.
132                break;
133            }
134            global_asm_strings.push(input.parse::<LitStr>()?);
135
136            if input.is_empty() {
137                // In case the current string literal is the last argument.
138                break;
139            }
140            input.parse::<Token![,]>()?;
141            if input.is_empty() {
142                // In case the current string literal is the last argument with a trailing comma.
143                break;
144            }
145        }
146
147        // We store the rest of the arguments, and we will forward them as is to global_asm!.
148        let global_asm_rest_args: TokenStream2 = input.parse()?;
149        Ok(Self {
150            visibility,
151            mod_name,
152            global_asm_strings,
153            global_asm_rest_args,
154        })
155    }
156}
157
158static COUNTER: AtomicU64 = AtomicU64::new(0);
159
160fn global_asm_data_impl(
161    GlobalAsmDataArgs {
162        visibility,
163        mod_name,
164        global_asm_strings,
165        global_asm_rest_args,
166    }: GlobalAsmDataArgs,
167) -> Result<TokenStream, Error> {
168    let span = Span::call_site();
169
170    // Generate the unique tags based on the macro input, code location and a random number to avoid
171    // symbol collision.
172    let tag_base_name = {
173        let content_id = {
174            let mut hasher = DefaultHasher::new();
175            span.source_text().hash(&mut hasher);
176            hasher.finish()
177        };
178        let location_id = format!(
179            "{}_{}_{}_{}",
180            span.start().line,
181            span.start().column,
182            span.end().line,
183            span.end().column
184        );
185        let rand_id: u64 = rand::random();
186        let static_counter_id = COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
187        let prefix = "crosvm_hypervisor_test_macro_global_asm_data";
188        format!("{prefix}_{mod_name}_{content_id}_{location_id}_{static_counter_id}_{rand_id}")
189    };
190    let start_tag = format!("{tag_base_name}_start");
191    let end_tag = format!("{tag_base_name}_end");
192
193    let global_directive = LitStr::new(&format!(".global {start_tag}, {end_tag}"), span);
194    let start_tag_asm = LitStr::new(&format!("{start_tag}:"), span);
195    let end_tag_asm = LitStr::new(&format!("{end_tag}:"), span);
196    let start_tag_ident = Ident::new(&start_tag, span);
197    let end_tag_ident = Ident::new(&end_tag, span);
198
199    Ok(quote! {
200        #visibility mod #mod_name {
201            use super::*;
202
203            extern {
204                static #start_tag_ident: u8;
205                static #end_tag_ident: u8;
206            }
207
208            std::arch::global_asm!(
209                #global_directive,
210                #start_tag_asm,
211                #(#global_asm_strings),*,
212                #end_tag_asm,
213                #global_asm_rest_args
214            );
215            pub fn data() -> &'static [u8] {
216                // SAFETY:
217                // * The extern statics are u8, and any arbitrary bit patterns are valid for u8.
218                // * The data starting from start to end is valid u8.
219                // * Without unsafe block, one can't mutate the value between start and end. In
220                //   addition, it is likely that the data is written to a readonly block, and can't
221                //   be mutated at all.
222                // * The address shouldn't be too large, and won't wrap around.
223                unsafe {
224                    let ptr = std::ptr::addr_of!(#start_tag_ident);
225                    let len = std::ptr::addr_of!(#end_tag_ident).offset_from(ptr);
226                    std::slice::from_raw_parts(
227                        ptr,
228                        len.try_into().expect("length must be positive")
229                    )
230                }
231            }
232        }
233    }
234    .into())
235}