1#![deny(missing_docs)]
6
7use anyhow::Context;
8use base::error;
9use base::EventToken;
10use base::WaitContext;
11
12use crate::userfaultfd::DeadUffdChecker;
13use crate::userfaultfd::Userfaultfd;
14
15pub trait Token: EventToken {
17 fn uffd_token(idx: u32) -> Self;
18}
19
20pub struct UffdList<'a, T: Token, D: DeadUffdChecker> {
22 list: Vec<Userfaultfd>,
23 dead_uffd_checker: &'a D,
24 wait_ctx: &'a WaitContext<T>,
25 num_static_uffd: Option<usize>,
26}
27
28impl<'a, T: Token, D: DeadUffdChecker> UffdList<'a, T, D> {
29 const ID_MAIN_UFFD: u32 = 0;
30
31 pub fn new(
35 main_uffd: Userfaultfd,
36 dead_uffd_checker: &'a D,
37 wait_ctx: &'a WaitContext<T>,
38 ) -> anyhow::Result<Self> {
39 let mut list = Self {
40 list: Vec::with_capacity(1),
41 dead_uffd_checker,
42 wait_ctx,
43 num_static_uffd: None,
44 };
45 list.register(main_uffd)?;
46 Ok(list)
47 }
48
49 pub fn set_num_static_devices(&mut self, num_static_devices: u32) -> bool {
54 if self.num_static_uffd.is_some() {
55 return false;
56 }
57 let num_static_uffd = num_static_devices as usize + 1;
59 self.num_static_uffd = Some(num_static_uffd);
60 true
61 }
62
63 pub fn register(&mut self, uffd: Userfaultfd) -> anyhow::Result<bool> {
65 let is_dynamic_uffd = self
66 .num_static_uffd
67 .map(|num_static_uffd| self.list.len() >= num_static_uffd)
68 .unwrap_or(false);
69 if is_dynamic_uffd {
70 self.dead_uffd_checker.register(&uffd)?;
72 }
73
74 let id_uffd = self
75 .list
76 .len()
77 .try_into()
78 .context("too many userfaultfd forked")?;
79
80 self.wait_ctx
81 .add(&uffd, T::uffd_token(id_uffd))
82 .context("add to wait context")?;
83 self.list.push(uffd);
84
85 Ok(is_dynamic_uffd)
86 }
87
88 pub fn gc_dead_uffds(&mut self) -> anyhow::Result<()> {
90 let mut idx = self.num_static_uffd.unwrap();
91 let mut is_swapped = false;
92 while idx < self.list.len() {
93 if self.dead_uffd_checker.is_dead(&self.list[idx]) {
94 self.wait_ctx
95 .delete(&self.list[idx])
96 .context("delete dead uffd from wait context")?;
97 self.list.swap_remove(idx);
98 is_swapped = true;
99 } else {
100 if is_swapped {
101 self.wait_ctx
102 .modify(
103 &self.list[idx],
104 base::EventType::ReadWrite,
105 T::uffd_token(idx as u32),
106 )
107 .context("update token")?;
108 is_swapped = false;
109 }
110 idx += 1;
111 }
112 }
113
114 if let Err(e) = self.dead_uffd_checker.reset() {
117 error!("failed to reset dead uffd checker: {:?}", e);
118 }
119 Ok(())
120 }
121
122 pub fn get(&self, id: u32) -> Option<&Userfaultfd> {
124 self.list.get(id as usize)
125 }
126
127 pub fn main_uffd(&self) -> &Userfaultfd {
129 &self.list[Self::ID_MAIN_UFFD as usize]
130 }
131
132 pub fn clone_main_uffd(&self) -> crate::userfaultfd::Result<Userfaultfd> {
134 self.list[Self::ID_MAIN_UFFD as usize].try_clone()
135 }
136
137 pub fn get_list(&self) -> &[Userfaultfd] {
139 &self.list
140 }
141}
142
#[cfg(test)]
mod tests {
    use std::cell::RefCell;
    use std::mem::ManuallyDrop;
    use std::time::Duration;

    use base::AsRawDescriptor;
    use base::Event;
    use base::FromRawDescriptor;
    use base::IntoRawDescriptor;
    use base::RawDescriptor;

    use super::*;

    /// Token used by the tests. It carries the uffd index directly.
    #[derive(EventToken, Clone, Copy)]
    enum TestToken {
        UffdEvents(u32),
    }

    impl TestToken {
        /// Returns the uffd index stored in the token.
        fn get_idx(&self) -> u32 {
            match self {
                Self::UffdEvents(idx) => *idx,
            }
        }
    }

    impl Token for TestToken {
        fn uffd_token(idx: u32) -> Self {
            TestToken::UffdEvents(idx)
        }
    }

    /// A [DeadUffdChecker] whose notion of "dead" is controlled by the test.
    ///
    /// Every fake uffd is backed by an eventfd, so a test can make it readable on demand.
    struct FakeDeadUffdChecker {
        /// `(descriptor, is_dead)` for every fake uffd created by this checker.
        list: RefCell<Vec<(RawDescriptor, bool)>>,
    }

    impl FakeDeadUffdChecker {
        fn new() -> Self {
            Self {
                list: RefCell::new(Vec::new()),
            }
        }

        /// Creates a fake [Userfaultfd] backed by an eventfd and tracks it as alive.
        fn create_fake_uffd(&self) -> Userfaultfd {
            let raw_desc = Event::new().unwrap().into_raw_descriptor();

            self.list.borrow_mut().push((raw_desc, false));

            // SAFETY: `raw_desc` was just released from a valid `Event`, so it is an open
            // descriptor that nothing else owns. The returned `Userfaultfd` takes ownership.
            unsafe { Userfaultfd::from_raw_descriptor(raw_desc) }
        }

        /// Signals the eventfd behind `raw_desc` so that it becomes readable.
        fn make_readable(&self, raw_desc: RawDescriptor) {
            // SAFETY: `raw_desc` is an eventfd created by `create_fake_uffd` and is still owned
            // by a live `Userfaultfd`. The temporary `Event` is wrapped in `ManuallyDrop` right
            // away, so it never closes the descriptor, even if signalling panics.
            let ev = unsafe { Event::from_raw_descriptor(raw_desc) };
            let ev = ManuallyDrop::new(ev);
            ev.signal().unwrap();
        }

        /// Marks the fake uffd identified by `raw_desc` as dead.
        fn mark_as_dead(&self, raw_desc: RawDescriptor) {
            for (rd, is_dead) in self.list.borrow_mut().iter_mut() {
                if *rd == raw_desc {
                    *is_dead = true;
                }
            }
        }
    }

    impl DeadUffdChecker for FakeDeadUffdChecker {
        fn register(&self, _uffd: &Userfaultfd) -> anyhow::Result<()> {
            // Nothing to register. Deadness is driven by `mark_as_dead`.
            Ok(())
        }

        fn is_dead(&self, uffd: &Userfaultfd) -> bool {
            for (raw_desc, is_dead) in self.list.borrow().iter() {
                if *raw_desc == uffd.as_raw_descriptor() {
                    return *is_dead;
                }
            }
            false
        }

        fn reset(&self) -> anyhow::Result<()> {
            // Nothing to reset in the fake.
            Ok(())
        }
    }

    #[test]
    fn new_success() {
        let wait_ctx = WaitContext::<TestToken>::new().unwrap();
        let fake_checker = FakeDeadUffdChecker::new();
        let main_uffd = fake_checker.create_fake_uffd();

        assert!(UffdList::new(main_uffd, &fake_checker, &wait_ctx).is_ok());
    }

    #[test]
    fn register_success() {
        let wait_ctx = WaitContext::<TestToken>::new().unwrap();
        let fake_checker = FakeDeadUffdChecker::new();
        let main_uffd = fake_checker.create_fake_uffd();
        let uffd = fake_checker.create_fake_uffd();
        let mut uffd_list = UffdList::new(main_uffd, &fake_checker, &wait_ctx).unwrap();

        let result = uffd_list.register(uffd);
        assert!(result.is_ok());
        // Static, because the number of static devices has not been set.
        assert!(!result.unwrap());
    }

    #[test]
    fn register_dynamic_device() {
        let wait_ctx = WaitContext::<TestToken>::new().unwrap();
        let fake_checker = FakeDeadUffdChecker::new();
        let main_uffd = fake_checker.create_fake_uffd();
        let uffd1 = fake_checker.create_fake_uffd();
        let uffd2 = fake_checker.create_fake_uffd();
        let uffd3 = fake_checker.create_fake_uffd();
        let mut uffd_list = UffdList::new(main_uffd, &fake_checker, &wait_ctx).unwrap();

        // Static: the number of static devices is not set yet.
        assert!(!uffd_list.register(uffd1).unwrap());
        // main_uffd + 2 static devices.
        assert!(uffd_list.set_num_static_devices(2));
        // Static: this is the second static device.
        assert!(!uffd_list.register(uffd2).unwrap());
        // Dynamic: registered after all static devices.
        assert!(uffd_list.register(uffd3).unwrap());
    }

    #[test]
    fn set_num_static_devices_twice() {
        let wait_ctx = WaitContext::<TestToken>::new().unwrap();
        let fake_checker = FakeDeadUffdChecker::new();
        let main_uffd = fake_checker.create_fake_uffd();
        let mut uffd_list = UffdList::new(main_uffd, &fake_checker, &wait_ctx).unwrap();

        assert!(uffd_list.set_num_static_devices(2));
        assert!(!uffd_list.set_num_static_devices(2));
    }

    #[test]
    fn register_token() {
        let wait_ctx = WaitContext::<TestToken>::new().unwrap();
        let fake_checker = FakeDeadUffdChecker::new();
        let main_uffd = fake_checker.create_fake_uffd();
        let uffd1 = fake_checker.create_fake_uffd();
        let uffd2 = fake_checker.create_fake_uffd();
        let rd2 = uffd2.as_raw_descriptor();
        let uffd3 = fake_checker.create_fake_uffd();
        let mut uffd_list = UffdList::new(main_uffd, &fake_checker, &wait_ctx).unwrap();
        uffd_list.register(uffd1).unwrap();
        uffd_list.register(uffd2).unwrap();
        uffd_list.register(uffd3).unwrap();

        fake_checker.make_readable(rd2);

        // The token of the readable uffd must resolve back to it.
        let events = wait_ctx.wait_timeout(Duration::from_millis(1)).unwrap();
        assert_eq!(events.len(), 1);
        assert_eq!(
            uffd_list
                .get(events[0].token.get_idx())
                .unwrap()
                .as_raw_descriptor(),
            rd2
        );
    }

    #[test]
    fn gc_dead_uffds_with_all_alive() {
        let wait_ctx = WaitContext::<TestToken>::new().unwrap();
        let fake_checker = FakeDeadUffdChecker::new();
        let main_uffd = fake_checker.create_fake_uffd();
        let uffd1 = fake_checker.create_fake_uffd();
        let uffd2 = fake_checker.create_fake_uffd();
        let uffd3 = fake_checker.create_fake_uffd();
        let mut uffd_list = UffdList::new(main_uffd, &fake_checker, &wait_ctx).unwrap();
        uffd_list.set_num_static_devices(1);
        uffd_list.register(uffd1).unwrap();
        uffd_list.register(uffd2).unwrap();
        uffd_list.register(uffd3).unwrap();

        assert!(uffd_list.gc_dead_uffds().is_ok());
        assert_eq!(uffd_list.get_list().len(), 4);
    }

    #[test]
    fn gc_dead_uffds_with_dead_static_device() {
        let wait_ctx = WaitContext::<TestToken>::new().unwrap();
        let fake_checker = FakeDeadUffdChecker::new();
        let main_uffd = fake_checker.create_fake_uffd();
        let uffd1 = fake_checker.create_fake_uffd();
        let uffd2 = fake_checker.create_fake_uffd();
        let rd2 = uffd2.as_raw_descriptor();
        let uffd3 = fake_checker.create_fake_uffd();
        let mut uffd_list = UffdList::new(main_uffd, &fake_checker, &wait_ctx).unwrap();
        uffd_list.set_num_static_devices(2);
        uffd_list.register(uffd1).unwrap();
        uffd_list.register(uffd2).unwrap();
        uffd_list.register(uffd3).unwrap();
        // uffd2 is static, so it must survive GC even when dead.
        fake_checker.mark_as_dead(rd2);

        assert!(uffd_list.gc_dead_uffds().is_ok());
        assert_eq!(uffd_list.get_list().len(), 4);
    }

    #[test]
    fn gc_dead_uffds_with_dead_dynamic_device() {
        let wait_ctx = WaitContext::<TestToken>::new().unwrap();
        let fake_checker = FakeDeadUffdChecker::new();
        let main_uffd = fake_checker.create_fake_uffd();
        let uffd1 = fake_checker.create_fake_uffd();
        let uffd2 = fake_checker.create_fake_uffd();
        let rd2 = uffd2.as_raw_descriptor();
        let uffd3 = fake_checker.create_fake_uffd();
        let rd3 = uffd3.as_raw_descriptor();
        let mut uffd_list = UffdList::new(main_uffd, &fake_checker, &wait_ctx).unwrap();
        uffd_list.set_num_static_devices(1);
        uffd_list.register(uffd1).unwrap();
        uffd_list.register(uffd2).unwrap();
        uffd_list.register(uffd3).unwrap();
        fake_checker.mark_as_dead(rd2);

        assert!(uffd_list.gc_dead_uffds().is_ok());
        assert_eq!(uffd_list.get_list().len(), 3);

        // uffd3 was moved into uffd2's slot; its token must follow it.
        fake_checker.make_readable(rd3);
        let events = wait_ctx.wait_timeout(Duration::from_millis(1)).unwrap();
        assert_eq!(events.len(), 1);
        assert_eq!(
            uffd_list
                .get(events[0].token.get_idx())
                .unwrap()
                .as_raw_descriptor(),
            rd3
        );
    }

    #[test]
    fn gc_dead_uffds_with_dead_dynamic_device_readable_before_gc() {
        let wait_ctx = WaitContext::<TestToken>::new().unwrap();
        let fake_checker = FakeDeadUffdChecker::new();
        let main_uffd = fake_checker.create_fake_uffd();
        let uffd1 = fake_checker.create_fake_uffd();
        let uffd2 = fake_checker.create_fake_uffd();
        let rd2 = uffd2.as_raw_descriptor();
        let uffd3 = fake_checker.create_fake_uffd();
        let rd3 = uffd3.as_raw_descriptor();
        let mut uffd_list = UffdList::new(main_uffd, &fake_checker, &wait_ctx).unwrap();
        uffd_list.set_num_static_devices(1);
        uffd_list.register(uffd1).unwrap();
        uffd_list.register(uffd2).unwrap();
        uffd_list.register(uffd3).unwrap();
        fake_checker.mark_as_dead(rd2);
        fake_checker.make_readable(rd3);

        assert!(uffd_list.gc_dead_uffds().is_ok());
        assert_eq!(uffd_list.get_list().len(), 3);

        // An event that was already pending before GC must be reported with the new token.
        let events = wait_ctx.wait_timeout(Duration::from_millis(1)).unwrap();
        assert_eq!(events.len(), 1);
        assert_eq!(
            uffd_list
                .get(events[0].token.get_idx())
                .unwrap()
                .as_raw_descriptor(),
            rd3
        );
    }

    #[test]
    fn gc_dead_uffds_with_many_dead_dynamic_device() {
        let wait_ctx = WaitContext::<TestToken>::new().unwrap();
        let fake_checker = FakeDeadUffdChecker::new();
        let main_uffd = fake_checker.create_fake_uffd();
        let uffd1 = fake_checker.create_fake_uffd();
        let uffd2 = fake_checker.create_fake_uffd();
        fake_checker.mark_as_dead(uffd2.as_raw_descriptor());
        let uffd3 = fake_checker.create_fake_uffd();
        fake_checker.mark_as_dead(uffd3.as_raw_descriptor());
        let uffd4 = fake_checker.create_fake_uffd();
        let uffd5 = fake_checker.create_fake_uffd();
        fake_checker.mark_as_dead(uffd5.as_raw_descriptor());
        let rd4 = uffd4.as_raw_descriptor();
        let mut uffd_list = UffdList::new(main_uffd, &fake_checker, &wait_ctx).unwrap();
        uffd_list.set_num_static_devices(0);
        uffd_list.register(uffd1).unwrap();
        uffd_list.register(uffd2).unwrap();
        uffd_list.register(uffd3).unwrap();
        uffd_list.register(uffd4).unwrap();
        uffd_list.register(uffd5).unwrap();

        assert!(uffd_list.gc_dead_uffds().is_ok());
        // main_uffd, uffd1 and uffd4 remain.
        assert_eq!(uffd_list.get_list().len(), 3);

        fake_checker.make_readable(rd4);
        let events = wait_ctx.wait_timeout(Duration::from_millis(1)).unwrap();
        assert_eq!(events.len(), 1);
        assert_eq!(
            uffd_list
                .get(events[0].token.get_idx())
                .unwrap()
                .as_raw_descriptor(),
            rd4
        );
    }
}