hypervisor_test_macro/lib.rs
1// Copyright 2024 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#![warn(missing_docs)]
6#![recursion_limit = "128"]
7
8//! Macros for hypervisor tests
9
10use std::collections::hash_map::DefaultHasher;
11use std::hash::Hash;
12use std::hash::Hasher;
13use std::sync::atomic::AtomicU64;
14
15use proc_macro::TokenStream;
16use proc_macro2::Span;
17use proc_macro2::TokenStream as TokenStream2;
18use quote::quote;
19use syn::parse::Parse;
20use syn::parse_macro_input;
21use syn::Error;
22use syn::Ident;
23use syn::LitStr;
24use syn::Token;
25use syn::Visibility;
26
27/// Embed the compiled assembly as an array.
28///
29/// This macro will generate a module with the given `$name` and provides a `data` function in the
30/// module to allow accessing the compiled machine code as an array.
31///
32/// Note that this macro uses [`std::arch::global_asm`], so we can only use this macro in a global
33/// scope, outside a function.
34///
35/// # Example
36///
37/// Given the following x86 assembly:
38/// ```Text
39/// 0: 01 d8 add eax,ebx
40/// 2: f4 hlt
41/// ```
42///
43/// ```rust
44/// # use hypervisor_test_macro::global_asm_data;
45/// global_asm_data!(
46/// my_code,
47/// ".code64",
48/// "add eax, ebx",
49/// "hlt",
50/// );
51/// # fn main() {
52/// assert_eq!([0x01, 0xd8, 0xf4], my_code::data());
53/// # }
54/// ```
55///
/// Arbitrary operands and options supported by [`std::arch::global_asm`] can also be passed.
57/// ```rust
58/// # use hypervisor_test_macro::global_asm_data;
59/// fn f() {}
60/// global_asm_data!(
61/// my_code1,
62/// ".global {0}",
63/// ".code64",
64/// "add eax, ebx",
65/// "hlt",
66/// sym f,
67/// );
68/// global_asm_data!(
69/// my_code2,
70/// ".code64",
71/// "add eax, ebx",
72/// "hlt",
73/// options(raw),
74/// );
75/// # fn main() {
76/// assert_eq!([0x01, 0xd8, 0xf4], my_code1::data());
77/// assert_eq!([0x01, 0xd8, 0xf4], my_code2::data());
78/// # }
79/// ```
80///
/// The visibility of the generated module can also be specified. Note that the example below
/// won't work if the `pub` in the macro is missing.
83/// ```rust
84/// # use hypervisor_test_macro::global_asm_data;
85/// mod my_mod {
86/// // This use is needed to import the global_asm_data macro to this module.
87/// use super::*;
88///
89/// global_asm_data!(
90/// // pub is needed so that my_mod::my_code is visible to the outer scope.
91/// pub my_code,
92/// ".code64",
93/// "add eax, ebx",
94/// "hlt",
95/// );
96/// }
97/// # fn main() {
98/// assert_eq!([0x01, 0xd8, 0xf4], my_mod::my_code::data());
99/// # }
100/// ```
101#[proc_macro]
102pub fn global_asm_data(item: TokenStream) -> TokenStream {
103 let args = parse_macro_input!(item as GlobalAsmDataArgs);
104 global_asm_data_impl(args).unwrap_or_else(|e| e.to_compile_error().into())
105}
106
107struct GlobalAsmDataArgs {
108 visibility: Visibility,
109 mod_name: Ident,
110 global_asm_strings: Vec<LitStr>,
111 global_asm_rest_args: TokenStream2,
112}
113
114impl Parse for GlobalAsmDataArgs {
115 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
116 // The first argument is visibilty + identifier, e.g. my_code or pub my_code. The identifier
117 // will be used as the name of the gnerated module.
118 let visibility: Visibility = input.parse()?;
119 let mod_name: Ident = input.parse()?;
120 // There must be following arguments, so we consume the first argument separator here.
121 input.parse::<Token![,]>()?;
122
123 // Retrieve the input assemblies, which are a list of comma separated string literals. We
124 // need to obtain the list of assemblies explicitly, so that we can insert the begin tag and
125 // the end tag to the global_asm! call when we generate the result code.
126 let mut global_asm_strings = vec![];
127 loop {
128 let lookahead = input.lookahead1();
129 if !lookahead.peek(LitStr) {
130 // If the upcoming tokens are not string literal, we hit the end of the input
131 // assemblies.
132 break;
133 }
134 global_asm_strings.push(input.parse::<LitStr>()?);
135
136 if input.is_empty() {
137 // In case the current string literal is the last argument.
138 break;
139 }
140 input.parse::<Token![,]>()?;
141 if input.is_empty() {
142 // In case the current string literal is the last argument with a trailing comma.
143 break;
144 }
145 }
146
147 // We store the rest of the arguments, and we will forward them as is to global_asm!.
148 let global_asm_rest_args: TokenStream2 = input.parse()?;
149 Ok(Self {
150 visibility,
151 mod_name,
152 global_asm_strings,
153 global_asm_rest_args,
154 })
155 }
156}
157
/// Counter mixed into the generated symbol names; it is incremented once per macro expansion.
static COUNTER: AtomicU64 = AtomicU64::new(0);
159
160fn global_asm_data_impl(
161 GlobalAsmDataArgs {
162 visibility,
163 mod_name,
164 global_asm_strings,
165 global_asm_rest_args,
166 }: GlobalAsmDataArgs,
167) -> Result<TokenStream, Error> {
168 let span = Span::call_site();
169
170 // Generate the unique tags based on the macro input, code location and a random number to avoid
171 // symbol collision.
172 let tag_base_name = {
173 let content_id = {
174 let mut hasher = DefaultHasher::new();
175 span.source_text().hash(&mut hasher);
176 hasher.finish()
177 };
178 let location_id = format!(
179 "{}_{}_{}_{}",
180 span.start().line,
181 span.start().column,
182 span.end().line,
183 span.end().column
184 );
185 let rand_id: u64 = rand::random();
186 let static_counter_id = COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
187 let prefix = "crosvm_hypervisor_test_macro_global_asm_data";
188 format!("{prefix}_{mod_name}_{content_id}_{location_id}_{static_counter_id}_{rand_id}")
189 };
190 let start_tag = format!("{tag_base_name}_start");
191 let end_tag = format!("{tag_base_name}_end");
192
193 let global_directive = LitStr::new(&format!(".global {start_tag}, {end_tag}"), span);
194 let start_tag_asm = LitStr::new(&format!("{start_tag}:"), span);
195 let end_tag_asm = LitStr::new(&format!("{end_tag}:"), span);
196 let start_tag_ident = Ident::new(&start_tag, span);
197 let end_tag_ident = Ident::new(&end_tag, span);
198
199 Ok(quote! {
200 #visibility mod #mod_name {
201 use super::*;
202
203 extern {
204 static #start_tag_ident: u8;
205 static #end_tag_ident: u8;
206 }
207
208 std::arch::global_asm!(
209 #global_directive,
210 #start_tag_asm,
211 #(#global_asm_strings),*,
212 #end_tag_asm,
213 #global_asm_rest_args
214 );
215 pub fn data() -> &'static [u8] {
216 // SAFETY:
217 // * The extern statics are u8, and any arbitrary bit patterns are valid for u8.
218 // * The data starting from start to end is valid u8.
219 // * Without unsafe block, one can't mutate the value between start and end. In
220 // addition, it is likely that the data is written to a readonly block, and can't
221 // be mutated at all.
222 // * The address shouldn't be too large, and won't wrap around.
223 unsafe {
224 let ptr = std::ptr::addr_of!(#start_tag_ident);
225 let len = std::ptr::addr_of!(#end_tag_ident).offset_from(ptr);
226 std::slice::from_raw_parts(
227 ptr,
228 len.try_into().expect("length must be positive")
229 )
230 }
231 }
232 }
233 }
234 .into())
235}