use std::cmp::min;
use std::mem::size_of;
use std::ptr::copy;
use std::ptr::read_volatile;
use std::ptr::write_bytes;
use std::ptr::write_volatile;
use std::result;
use std::slice;

use remain::sorted;
use thiserror::Error;
use zerocopy::FromBytes;
use zerocopy::IntoBytes;

use crate::IoBufMut;

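/// Errors that can occur when accessing a region of volatile memory.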
#[sorted]
#[derive(Error, Eq, PartialEq, Debug)]
pub enum VolatileMemoryError {
    #[error("address 0x{addr:x} is out of bounds")]
    OutOfBounds { addr: usize },
    #[error("address 0x{base:x} offset by 0x{offset:x} would overflow")]
    Overflow { base: usize, offset: usize },
}

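/// Result type returned by volatile memory operations.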
pub type VolatileMemoryResult<T> = result::Result<T, VolatileMemoryError>;

use crate::VolatileMemoryError as Error;
type Result<T> = VolatileMemoryResult<T>;

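/// A trait for types that can hand out volatile views (`VolatileSlice`) into the memory they own.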
pub trait VolatileMemory {
    /// Gets a slice of memory at `offset` that is `count` bytes in length and supports volatile
    /// access.
    fn get_slice(&self, offset: usize, count: usize) -> Result<VolatileSlice>;
}

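/// A slice of raw memory that supports volatile access; a thin wrapper around [`IoBufMut`].
///
/// Illustrative usage (a minimal sketch, marked `ignore` because the import path depends on how
/// the crate re-exports this type):
///
/// ```ignore
/// let mut bytes = [0u8; 8];
/// let slice = VolatileSlice::new(&mut bytes);
/// slice.write_bytes(0xAB);
/// assert!(!slice.is_all_zero());
/// ```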
#[derive(Copy, Clone, Debug)]
#[repr(transparent)]
pub struct VolatileSlice<'a>(IoBufMut<'a>);

impl<'a> VolatileSlice<'a> {
    /// Creates a `VolatileSlice` covering all of `buf`.
    pub fn new(buf: &mut [u8]) -> VolatileSlice {
        VolatileSlice(IoBufMut::new(buf))
    }

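    /// Creates a `VolatileSlice` from a raw pointer and a length.
    ///
    /// # Safety
    ///
    /// `addr` must be valid for reads and writes of `len` bytes and must remain valid for the
    /// entire lifetime `'a` of the returned slice.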
    pub unsafe fn from_raw_parts(addr: *mut u8, len: usize) -> VolatileSlice<'a> {
        VolatileSlice(IoBufMut::from_raw_parts(addr, len))
    }

    /// Returns a pointer to the beginning of the slice.
    pub fn as_ptr(&self) -> *const u8 {
        self.0.as_ptr()
    }

    /// Returns a mutable pointer to the beginning of the slice.
    pub fn as_mut_ptr(&self) -> *mut u8 {
        self.0.as_mut_ptr()
    }

    /// Returns the length of the slice in bytes.
    pub fn size(&self) -> usize {
        self.0.len()
    }

    /// Advances the starting position of the slice by `count` bytes, shrinking it accordingly.
    pub fn advance(&mut self, count: usize) {
        self.0.advance(count)
    }

    /// Shortens the slice to at most `len` bytes.
    pub fn truncate(&mut self, len: usize) {
        self.0.truncate(len)
    }

    /// Returns a reference to the slice as an `IoBufMut`.
    pub fn as_iobuf(&self) -> &IoBufMut {
        &self.0
    }

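    /// Converts a slice of `VolatileSlice`s into a slice of `IoBufMut`s.
    ///
    /// This is a zero-cost cast: `VolatileSlice` is `#[repr(transparent)]` over `IoBufMut`, so
    /// both slices refer to the same memory.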
    #[allow(clippy::wrong_self_convention)]
    pub fn as_iobufs<'mem, 'slice>(
        iovs: &'slice [VolatileSlice<'mem>],
    ) -> &'slice [IoBufMut<'mem>] {
        // SAFETY: `VolatileSlice` is `#[repr(transparent)]` over `IoBufMut`, so a slice of one can
        // be reinterpreted as a slice of the other.
        unsafe { slice::from_raw_parts(iovs.as_ptr() as *const IoBufMut, iovs.len()) }
    }

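    /// Converts a mutable slice of `VolatileSlice`s into a mutable slice of `IoBufMut`s.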
    #[inline]
    pub fn as_iobufs_mut<'mem, 'slice>(
        iovs: &'slice mut [VolatileSlice<'mem>],
    ) -> &'slice mut [IoBufMut<'mem>] {
        // SAFETY: same layout argument as `as_iobufs`; `#[repr(transparent)]` guarantees the two
        // element types have identical layout.
        unsafe { slice::from_raw_parts_mut(iovs.as_mut_ptr() as *mut IoBufMut, iovs.len()) }
    }

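    /// Returns a subslice starting `count` bytes into this slice and running to its end, or an
    /// error if the offset is out of bounds or the address computation overflows.
    ///
    /// A minimal illustrative sketch:
    ///
    /// ```ignore
    /// let mut bytes = [0u8; 32];
    /// let slice = VolatileSlice::new(&mut bytes);
    /// let tail = slice.offset(8).unwrap();
    /// assert_eq!(tail.size(), 24);
    /// ```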
    pub fn offset(self, count: usize) -> Result<VolatileSlice<'a>> {
        let new_addr = (self.as_mut_ptr() as usize).checked_add(count).ok_or(
            VolatileMemoryError::Overflow {
                base: self.as_mut_ptr() as usize,
                offset: count,
            },
        )?;
        let new_size = self
            .size()
            .checked_sub(count)
            .ok_or(VolatileMemoryError::OutOfBounds { addr: new_addr })?;

        // SAFETY: the new slice has the same lifetime as `self` and points to a subset of its
        // memory.
        unsafe { Ok(VolatileSlice::from_raw_parts(new_addr as *mut u8, new_size)) }
    }

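    /// Returns a subslice of `count` bytes starting at `offset` bytes into this slice, or an
    /// error if the requested range does not fit.
    ///
    /// A minimal illustrative sketch:
    ///
    /// ```ignore
    /// let mut bytes = [0u8; 32];
    /// let slice = VolatileSlice::new(&mut bytes);
    /// let sub = slice.sub_slice(8, 16).unwrap();
    /// assert_eq!(sub.size(), 16);
    /// assert!(slice.sub_slice(24, 16).is_err());
    /// ```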
    pub fn sub_slice(self, offset: usize, count: usize) -> Result<VolatileSlice<'a>> {
        let mem_end = offset
            .checked_add(count)
            .ok_or(VolatileMemoryError::Overflow {
                base: offset,
                offset: count,
            })?;
        if mem_end > self.size() {
            return Err(Error::OutOfBounds { addr: mem_end });
        }
        let new_addr = (self.as_mut_ptr() as usize).checked_add(offset).ok_or(
            VolatileMemoryError::Overflow {
                base: self.as_mut_ptr() as usize,
                offset,
            },
        )?;

        // SAFETY: the range [offset, offset + count) was checked to lie within `self`, so the new
        // slice points to a subset of its memory and shares its lifetime.
        Ok(unsafe { VolatileSlice::from_raw_parts(new_addr as *mut u8, count) })
    }

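    /// Sets each byte of this slice to `value`.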
    pub fn write_bytes(&self, value: u8) {
        // SAFETY: the pointer and length come from `self`, so the write stays within the slice.
        unsafe {
            write_bytes(self.as_mut_ptr(), value, self.size());
        }
    }

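    /// Copies as many whole elements of `T` as fit in both this slice and `buf` into `buf`,
    /// using volatile reads.
    ///
    /// A minimal illustrative sketch:
    ///
    /// ```ignore
    /// let mut bytes = [0xAAu8; 4];
    /// let slice = VolatileSlice::new(&mut bytes);
    /// let mut out = [0u8; 8];
    /// slice.copy_to(&mut out);
    /// assert_eq!(out[..4], [0xAA; 4]);
    /// ```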
    pub fn copy_to<T>(&self, buf: &mut [T])
    where
        T: FromBytes + IntoBytes + Copy,
    {
        let mut addr = self.as_mut_ptr() as *const u8;
        for v in buf.iter_mut().take(self.size() / size_of::<T>()) {
            // SAFETY: the loop is bounded so that every read stays within the slice, and `T` can
            // be constructed from any bit pattern (`FromBytes`).
            unsafe {
                *v = read_volatile(addr as *const T);
                addr = addr.add(size_of::<T>());
            }
        }
    }

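    /// Copies up to `min(self.size(), slice.size())` bytes from this slice into `slice`.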
    pub fn copy_to_volatile_slice(&self, slice: VolatileSlice) {
        // SAFETY: the copy length is clamped to the smaller of the two slices, so both the source
        // and destination ranges are valid; `ptr::copy` also tolerates overlapping regions.
        unsafe {
            copy(
                self.as_mut_ptr() as *const u8,
                slice.as_mut_ptr(),
                min(self.size(), slice.size()),
            );
        }
    }

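    /// Copies as many whole elements of `T` as fit in both `buf` and this slice from `buf` into
    /// this slice, using volatile writes.
    ///
    /// A minimal illustrative sketch:
    ///
    /// ```ignore
    /// let mut bytes = [0u8; 8];
    /// let slice = VolatileSlice::new(&mut bytes);
    /// slice.copy_from(&[1u8, 2, 3, 4]);
    /// assert_eq!(bytes[..4], [1, 2, 3, 4]);
    /// ```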
    pub fn copy_from<T>(&self, buf: &[T])
    where
        T: IntoBytes + Copy,
    {
        let mut addr = self.as_mut_ptr();
        for v in buf.iter().take(self.size() / size_of::<T>()) {
            // SAFETY: the loop is bounded so that every write stays within the slice.
            unsafe {
                write_volatile(addr as *mut T, *v);
                addr = addr.add(size_of::<T>());
            }
        }
    }

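    /// Returns whether every byte of this slice is zero.
    ///
    /// The 16-byte-aligned middle of the slice is scanned as `u128` words; any unaligned head and
    /// tail bytes are checked individually.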
    pub fn is_all_zero(&self) -> bool {
        const MASK_4BIT: usize = 0x0f;
        let head_addr = self.as_ptr() as usize;
        // Round the start up and the end down to the nearest 16-byte boundary.
        let aligned_head_addr = (head_addr + MASK_4BIT) & !MASK_4BIT;
        let tail_addr = head_addr + self.size();
        let aligned_tail_addr = tail_addr & !MASK_4BIT;

        if aligned_head_addr >= aligned_tail_addr {
            // The slice does not cover a full 16-byte-aligned block, so reading `u128` words here
            // could run past its bounds. Fall back to the byte-by-byte check.
            // SAFETY: [head_addr, tail_addr) is exactly the memory covered by the slice.
            return unsafe { is_all_zero_naive(head_addr, tail_addr) };
        }

        // Check the 16-byte-aligned middle of the slice one `u128` word at a time.
        if (aligned_head_addr..aligned_tail_addr)
            .step_by(16)
            // SAFETY: [aligned_head_addr, aligned_tail_addr) lies within the slice and every
            // address produced by `step_by(16)` is 16-byte aligned.
            .any(|aligned_addr| unsafe { *(aligned_addr as *const u128) } != 0)
        {
            return false;
        }

        // Check the unaligned head and tail byte by byte.
        // SAFETY: both ranges lie within the slice.
        unsafe {
            is_all_zero_naive(head_addr, aligned_head_addr)
                && is_all_zero_naive(aligned_tail_addr, tail_addr)
        }
    }
}

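/// Checks whether every byte in the address range `[head_addr, tail_addr)` is zero.
///
/// # Safety
///
/// The entire range `[head_addr, tail_addr)` must be valid for reads.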
unsafe fn is_all_zero_naive(head_addr: usize, tail_addr: usize) -> bool {
    (head_addr..tail_addr).all(|addr| *(addr as *const u8) == 0)
}

impl VolatileMemory for VolatileSlice<'_> {
    fn get_slice(&self, offset: usize, count: usize) -> Result<VolatileSlice> {
        self.sub_slice(offset, count)
    }
}

impl PartialEq<VolatileSlice<'_>> for VolatileSlice<'_> {
    fn eq(&self, other: &VolatileSlice) -> bool {
        let size = self.size();
        if size != other.size() {
            return false;
        }

        // SAFETY: both pointers are valid for reads of `size` bytes.
        let cmp = unsafe { libc::memcmp(self.as_ptr() as _, other.as_ptr() as _, size) };

        cmp == 0
    }
}

impl Eq for VolatileSlice<'_> {}

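// The `Write` implementation consumes the slice from the front: each `write` copies as many bytes
// as fit and then advances past them, so `flush` has nothing to do.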
impl std::io::Write for VolatileSlice<'_> {
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let len = buf.len().min(self.size());
        self.copy_from(&buf[..len]);
        self.advance(len);
        Ok(len)
    }

    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use std::io::Write;
    use std::sync::Arc;
    use std::sync::Barrier;
    use std::thread::spawn;

    use super::*;

    #[derive(Clone)]
    struct VecMem {
        mem: Arc<Vec<u8>>,
    }

    impl VecMem {
        fn new(size: usize) -> VecMem {
            VecMem {
                mem: Arc::new(vec![0u8; size]),
            }
        }
    }

    impl VolatileMemory for VecMem {
        fn get_slice(&self, offset: usize, count: usize) -> Result<VolatileSlice> {
            let mem_end = offset
                .checked_add(count)
                .ok_or(VolatileMemoryError::Overflow {
                    base: offset,
                    offset: count,
                })?;
            if mem_end > self.mem.len() {
                return Err(Error::OutOfBounds { addr: mem_end });
            }

            let new_addr = (self.mem.as_ptr() as usize).checked_add(offset).ok_or(
                VolatileMemoryError::Overflow {
                    base: self.mem.as_ptr() as usize,
                    offset,
                },
            )?;

            Ok(unsafe { VolatileSlice::from_raw_parts(new_addr as *mut u8, count) })
        }
    }

    #[test]
    fn observe_mutate() {
        let a = VecMem::new(1);
        let a_clone = a.clone();
        a.get_slice(0, 1).unwrap().write_bytes(99);

        let start_barrier = Arc::new(Barrier::new(2));
        let thread_start_barrier = start_barrier.clone();
        let end_barrier = Arc::new(Barrier::new(2));
        let thread_end_barrier = end_barrier.clone();
        spawn(move || {
            thread_start_barrier.wait();
            a_clone.get_slice(0, 1).unwrap().write_bytes(0);
            thread_end_barrier.wait();
        });

        let mut byte = [0u8; 1];
        a.get_slice(0, 1).unwrap().copy_to(&mut byte);
        assert_eq!(byte[0], 99);

        start_barrier.wait();
        end_barrier.wait();

        a.get_slice(0, 1).unwrap().copy_to(&mut byte);
        assert_eq!(byte[0], 0);
    }

    #[test]
    fn slice_size() {
        let a = VecMem::new(100);
        let s = a.get_slice(0, 27).unwrap();
        assert_eq!(s.size(), 27);

        let s = a.get_slice(34, 27).unwrap();
        assert_eq!(s.size(), 27);

        let s = s.get_slice(20, 5).unwrap();
        assert_eq!(s.size(), 5);
    }

    #[test]
    fn slice_overflow_error() {
        let a = VecMem::new(1);
        let res = a.get_slice(usize::MAX, 1).unwrap_err();
        assert_eq!(
            res,
            Error::Overflow {
                base: usize::MAX,
                offset: 1,
            }
        );
    }

    #[test]
    fn slice_oob_error() {
        let a = VecMem::new(100);
        a.get_slice(50, 50).unwrap();
        let res = a.get_slice(55, 50).unwrap_err();
        assert_eq!(res, Error::OutOfBounds { addr: 105 });
    }

    #[test]
    fn is_all_zero_16bytes_aligned() {
        let a = VecMem::new(1024);
        let slice = a.get_slice(0, 1024).unwrap();

        assert!(slice.is_all_zero());
        a.get_slice(129, 1).unwrap().write_bytes(1);
        assert!(!slice.is_all_zero());
    }

    #[test]
    fn is_all_zero_head_not_aligned() {
        let a = VecMem::new(1024);
        let slice = a.get_slice(1, 1023).unwrap();

        assert!(slice.is_all_zero());
        a.get_slice(0, 1).unwrap().write_bytes(1);
        assert!(slice.is_all_zero());
        a.get_slice(1, 1).unwrap().write_bytes(1);
        assert!(!slice.is_all_zero());
        a.get_slice(1, 1).unwrap().write_bytes(0);
        a.get_slice(129, 1).unwrap().write_bytes(1);
        assert!(!slice.is_all_zero());
    }

    #[test]
    fn is_all_zero_tail_not_aligned() {
        let a = VecMem::new(1024);
        let slice = a.get_slice(0, 1023).unwrap();

        assert!(slice.is_all_zero());
        a.get_slice(1023, 1).unwrap().write_bytes(1);
        assert!(slice.is_all_zero());
        a.get_slice(1022, 1).unwrap().write_bytes(1);
        assert!(!slice.is_all_zero());
        a.get_slice(1022, 1).unwrap().write_bytes(0);
        a.get_slice(0, 1).unwrap().write_bytes(1);
        assert!(!slice.is_all_zero());
    }

    #[test]
    fn is_all_zero_no_aligned_16bytes() {
        let a = VecMem::new(1024);
        let slice = a.get_slice(1, 16).unwrap();

        assert!(slice.is_all_zero());
        a.get_slice(0, 1).unwrap().write_bytes(1);
        assert!(slice.is_all_zero());
        for i in 1..17 {
            a.get_slice(i, 1).unwrap().write_bytes(1);
            assert!(!slice.is_all_zero());
            a.get_slice(i, 1).unwrap().write_bytes(0);
        }
        a.get_slice(17, 1).unwrap().write_bytes(1);
        assert!(slice.is_all_zero());
    }

    #[test]
    fn write_partial() {
        let mem = VecMem::new(1024);
        let mut slice = mem.get_slice(1, 16).unwrap();
        slice.write_bytes(0xCC);

        let write_len = slice.write(&[1, 2, 3, 4]).unwrap();
        assert_eq!(write_len, 4);
        assert_eq!(slice.size(), 16 - 4);

        assert_eq!(mem.mem[1..=4], [1, 2, 3, 4]);

        assert_eq!(mem.mem[5], 0xCC);
    }

    #[test]
    fn write_multiple() {
        let mem = VecMem::new(1024);
        let mut slice = mem.get_slice(1, 16).unwrap();
        slice.write_bytes(0xCC);

        let write_len = slice.write(&[1, 2, 3, 4]).unwrap();
        assert_eq!(write_len, 4);
        assert_eq!(slice.size(), 16 - 4);

        assert_eq!(mem.mem[5], 0xCC);

        let write2_len = slice.write(&[5, 6, 7, 8]).unwrap();
        assert_eq!(write2_len, 4);
        assert_eq!(slice.size(), 16 - 4 - 4);

        assert_eq!(mem.mem[1..=8], [1, 2, 3, 4, 5, 6, 7, 8]);

        assert_eq!(mem.mem[9], 0xCC);
    }

    #[test]
    fn write_exact_slice_size() {
        let mem = VecMem::new(1024);
        let mut slice = mem.get_slice(1, 4).unwrap();
        slice.write_bytes(0xCC);

        let write_len = slice.write(&[1, 2, 3, 4]).unwrap();
        assert_eq!(write_len, 4);
        assert_eq!(slice.size(), 0);

        assert_eq!(mem.mem[1..=4], [1, 2, 3, 4]);

        assert_eq!(mem.mem[5], 0);
    }

    #[test]
    fn write_more_than_slice_size() {
        let mem = VecMem::new(1024);
        let mut slice = mem.get_slice(1, 4).unwrap();
        slice.write_bytes(0xCC);

        let write_len = slice.write(&[1, 2, 3, 4, 5]).unwrap();
        assert_eq!(write_len, 4);
        assert_eq!(slice.size(), 0);

        assert_eq!(mem.mem[1..=4], [1, 2, 3, 4]);

        assert_eq!(mem.mem[5], 0);
    }

    #[test]
    fn write_empty_slice() {
        let mem = VecMem::new(1024);
        let mut slice = mem.get_slice(1, 0).unwrap();

        assert_eq!(slice.write(&[1, 2, 3, 4]).unwrap(), 0);
        assert_eq!(slice.write(&[5, 6, 7, 8]).unwrap(), 0);
        assert_eq!(slice.write(&[]).unwrap(), 0);
    }
}