devices/virtio/
descriptor_utils.rs

1// Copyright 2019 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use std::cmp;
6use std::io;
7use std::io::Write;
8use std::iter::FromIterator;
9use std::marker::PhantomData;
10use std::mem::size_of;
11use std::mem::MaybeUninit;
12use std::ptr::copy_nonoverlapping;
13use std::sync::Arc;
14
15use anyhow::Context;
16use base::FileReadWriteAtVolatile;
17use base::FileReadWriteVolatile;
18use base::VolatileSlice;
19use cros_async::MemRegion;
20use cros_async::MemRegionIter;
21use data_model::Le16;
22use data_model::Le32;
23use data_model::Le64;
24use disk::AsyncDisk;
25use smallvec::SmallVec;
26use vm_memory::GuestAddress;
27use vm_memory::GuestMemory;
28use zerocopy::FromBytes;
29use zerocopy::Immutable;
30use zerocopy::IntoBytes;
31use zerocopy::KnownLayout;
32
33use super::DescriptorChain;
34use crate::virtio::SplitDescriptorChain;
35
36struct DescriptorChainRegions {
37    regions: SmallVec<[MemRegion; 2]>,
38
39    // Index of the current region in `regions`.
40    current_region_index: usize,
41
42    // Number of bytes consumed in the current region.
43    current_region_offset: usize,
44
45    // Total bytes consumed in the entire descriptor chain.
46    bytes_consumed: usize,
47}
48
49impl DescriptorChainRegions {
50    fn new(regions: SmallVec<[MemRegion; 2]>) -> Self {
51        DescriptorChainRegions {
52            regions,
53            current_region_index: 0,
54            current_region_offset: 0,
55            bytes_consumed: 0,
56        }
57    }
58
59    fn available_bytes(&self) -> usize {
60        // This is guaranteed not to overflow because the total length of the chain is checked
61        // during all creations of `DescriptorChain` (see `DescriptorChain::new()`).
62        self.get_remaining_regions()
63            .fold(0usize, |count, region| count + region.len)
64    }
65
66    fn bytes_consumed(&self) -> usize {
67        self.bytes_consumed
68    }
69
70    /// Returns all the remaining buffers in the `DescriptorChain`. Calling this function does not
71    /// consume any bytes from the `DescriptorChain`. Instead callers should use the `consume`
72    /// method to advance the `DescriptorChain`. Multiple calls to `get` with no intervening calls
73    /// to `consume` will return the same data.
74    fn get_remaining_regions(&self) -> MemRegionIter {
75        MemRegionIter::new(&self.regions[self.current_region_index..])
76            .skip_bytes(self.current_region_offset)
77    }
78
79    /// Like `get_remaining_regions` but guarantees that the combined length of all the returned
80    /// iovecs is not greater than `count`. The combined length of the returned iovecs may be less
81    /// than `count` but will always be greater than 0 as long as there is still space left in the
82    /// `DescriptorChain`.
83    fn get_remaining_regions_with_count(&self, count: usize) -> MemRegionIter {
84        MemRegionIter::new(&self.regions[self.current_region_index..])
85            .skip_bytes(self.current_region_offset)
86            .take_bytes(count)
87    }
88
89    /// Returns all the remaining buffers in the `DescriptorChain` as `VolatileSlice`s of the given
90    /// `GuestMemory`. Calling this function does not consume any bytes from the `DescriptorChain`.
91    /// Instead callers should use the `consume` method to advance the `DescriptorChain`. Multiple
92    /// calls to `get` with no intervening calls to `consume` will return the same data.
93    fn get_remaining<'mem>(&self, mem: &'mem GuestMemory) -> SmallVec<[VolatileSlice<'mem>; 16]> {
94        self.get_remaining_regions()
95            .filter_map(|region| {
96                mem.get_slice_at_addr(GuestAddress(region.offset), region.len)
97                    .ok()
98            })
99            .collect()
100    }
101
102    /// Like 'get_remaining_regions_with_count' except convert the offsets to volatile slices in
103    /// the 'GuestMemory' given by 'mem'.
104    fn get_remaining_with_count<'mem>(
105        &self,
106        mem: &'mem GuestMemory,
107        count: usize,
108    ) -> SmallVec<[VolatileSlice<'mem>; 16]> {
109        self.get_remaining_regions_with_count(count)
110            .filter_map(|region| {
111                mem.get_slice_at_addr(GuestAddress(region.offset), region.len)
112                    .ok()
113            })
114            .collect()
115    }
116
117    /// Consumes `count` bytes from the `DescriptorChain`. If `count` is larger than
118    /// `self.available_bytes()` then all remaining bytes in the `DescriptorChain` will be consumed.
119    fn consume(&mut self, mut count: usize) {
120        while let Some(region) = self.regions.get(self.current_region_index) {
121            let region_remaining = region.len - self.current_region_offset;
122            if count < region_remaining {
123                // The remaining count to consume is less than the remaining un-consumed length of
124                // the current region. Adjust the region offset without advancing to the next region
125                // and stop.
126                self.current_region_offset += count;
127                self.bytes_consumed += count;
128                return;
129            }
130
131            // The current region has been exhausted. Advance to the next region.
132            self.current_region_index += 1;
133            self.current_region_offset = 0;
134
135            self.bytes_consumed += region_remaining;
136            count -= region_remaining;
137        }
138    }
139
140    fn split_at(&mut self, offset: usize) -> DescriptorChainRegions {
141        let mut other = DescriptorChainRegions {
142            regions: self.regions.clone(),
143            current_region_index: self.current_region_index,
144            current_region_offset: self.current_region_offset,
145            bytes_consumed: self.bytes_consumed,
146        };
147        other.consume(offset);
148        other.bytes_consumed = 0;
149
150        let mut rem = offset;
151        let mut end = self.current_region_index;
152        for region in &mut self.regions[self.current_region_index..] {
153            if rem <= region.len {
154                region.len = rem;
155                break;
156            }
157
158            end += 1;
159            rem -= region.len;
160        }
161
162        self.regions.truncate(end + 1);
163
164        other
165    }
166}
167
168/// Provides high-level interface over the sequence of memory regions
169/// defined by readable descriptors in the descriptor chain.
170///
171/// Note that virtio spec requires driver to place any device-writable
172/// descriptors after any device-readable descriptors (2.6.4.2 in Virtio Spec v1.1).
173/// Reader will skip iterating over descriptor chain when first writable
174/// descriptor is encountered.
175pub struct Reader {
176    mem: GuestMemory,
177    regions: DescriptorChainRegions,
178}
179
180/// An iterator over `FromBytes` objects on readable descriptors in the descriptor chain.
181pub struct ReaderIterator<'a, T: FromBytes> {
182    reader: &'a mut Reader,
183    phantom: PhantomData<T>,
184}
185
186impl<T: FromBytes> Iterator for ReaderIterator<'_, T> {
187    type Item = io::Result<T>;
188
189    fn next(&mut self) -> Option<io::Result<T>> {
190        if self.reader.available_bytes() == 0 {
191            None
192        } else {
193            Some(self.reader.read_obj())
194        }
195    }
196}
197
198impl Reader {
199    /// Construct a new Reader wrapper over `readable_regions`.
200    pub fn new_from_regions(
201        mem: &GuestMemory,
202        readable_regions: SmallVec<[MemRegion; 2]>,
203    ) -> Reader {
204        Reader {
205            mem: mem.clone(),
206            regions: DescriptorChainRegions::new(readable_regions),
207        }
208    }
209
210    /// Reads an object from the descriptor chain buffer without consuming it.
211    pub fn peek_obj<T: FromBytes>(&self) -> io::Result<T> {
212        let mut obj = MaybeUninit::uninit();
213
214        // SAFETY: We pass a valid pointer and size of `obj`.
215        let copied = unsafe {
216            copy_regions_to_mut_ptr(
217                &self.mem,
218                self.get_remaining_regions(),
219                obj.as_mut_ptr() as *mut u8,
220                size_of::<T>(),
221            )?
222        };
223        if copied != size_of::<T>() {
224            return Err(io::Error::from(io::ErrorKind::UnexpectedEof));
225        }
226
227        // SAFETY: `FromBytes` guarantees any set of initialized bytes is a valid value for `T`, and
228        // we initialized all bytes in `obj` in the copy above.
229        Ok(unsafe { obj.assume_init() })
230    }
231
232    /// Reads and consumes an object from the descriptor chain buffer.
233    pub fn read_obj<T: FromBytes>(&mut self) -> io::Result<T> {
234        let obj = self.peek_obj::<T>()?;
235        self.consume(size_of::<T>());
236        Ok(obj)
237    }
238
239    /// Reads objects by consuming all the remaining data in the descriptor chain buffer and returns
240    /// them as a collection. Returns an error if the size of the remaining data is indivisible by
241    /// the size of an object of type `T`.
242    pub fn collect<C: FromIterator<io::Result<T>>, T: FromBytes>(&mut self) -> C {
243        self.iter().collect()
244    }
245
246    /// Creates an iterator for sequentially reading `FromBytes` objects from the `Reader`.
247    /// Unlike `collect`, this doesn't consume all the remaining data in the `Reader` and
248    /// doesn't require the objects to be stored in a separate collection.
249    pub fn iter<T: FromBytes>(&mut self) -> ReaderIterator<T> {
250        ReaderIterator {
251            reader: self,
252            phantom: PhantomData,
253        }
254    }
255
256    /// Reads data into a volatile slice up to the minimum of the slice's length or the number of
257    /// bytes remaining. Returns the number of bytes read.
258    pub fn read_to_volatile_slice(&mut self, slice: VolatileSlice) -> usize {
259        let mut read = 0usize;
260        let mut dst = slice;
261        for src in self.get_remaining() {
262            src.copy_to_volatile_slice(dst);
263            let copied = std::cmp::min(src.size(), dst.size());
264            read += copied;
265            dst = match dst.offset(copied) {
266                Ok(v) => v,
267                Err(_) => break, // The slice is fully consumed
268            };
269        }
270        self.regions.consume(read);
271        read
272    }
273
274    /// Reads data from the descriptor chain buffer and passes the `VolatileSlice`s to the callback
275    /// `cb`.
276    pub fn read_to_cb<C: FnOnce(&[VolatileSlice]) -> usize>(
277        &mut self,
278        cb: C,
279        count: usize,
280    ) -> usize {
281        let iovs = self.regions.get_remaining_with_count(&self.mem, count);
282        let written = cb(&iovs[..]);
283        self.regions.consume(written);
284        written
285    }
286
287    /// Reads data from the descriptor chain buffer into a writable object.
288    /// Returns the number of bytes read from the descriptor chain buffer.
289    /// The number of bytes read can be less than `count` if there isn't
290    /// enough data in the descriptor chain buffer.
291    pub fn read_to<F: FileReadWriteVolatile>(
292        &mut self,
293        mut dst: F,
294        count: usize,
295    ) -> io::Result<usize> {
296        let iovs = self.regions.get_remaining_with_count(&self.mem, count);
297        let written = dst.write_vectored_volatile(&iovs[..])?;
298        self.regions.consume(written);
299        Ok(written)
300    }
301
302    /// Reads data from the descriptor chain buffer into a File at offset `off`.
303    /// Returns the number of bytes read from the descriptor chain buffer.
304    /// The number of bytes read can be less than `count` if there isn't
305    /// enough data in the descriptor chain buffer.
306    pub fn read_to_at<F: FileReadWriteAtVolatile>(
307        &mut self,
308        dst: &F,
309        count: usize,
310        off: u64,
311    ) -> io::Result<usize> {
312        let iovs = self.regions.get_remaining_with_count(&self.mem, count);
313        let written = dst.write_vectored_at_volatile(&iovs[..], off)?;
314        self.regions.consume(written);
315        Ok(written)
316    }
317
318    /// Reads data from the descriptor chain similar to 'read_to' except reading 'count' or
319    /// returning an error if 'count' bytes can't be read.
320    pub fn read_exact_to<F: FileReadWriteVolatile>(
321        &mut self,
322        mut dst: F,
323        mut count: usize,
324    ) -> io::Result<()> {
325        while count > 0 {
326            match self.read_to(&mut dst, count) {
327                Ok(0) => {
328                    return Err(io::Error::new(
329                        io::ErrorKind::UnexpectedEof,
330                        "failed to fill whole buffer",
331                    ))
332                }
333                Ok(n) => count -= n,
334                Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
335                Err(e) => return Err(e),
336            }
337        }
338
339        Ok(())
340    }
341
342    /// Reads data from the descriptor chain similar to 'read_to_at' except reading 'count' or
343    /// returning an error if 'count' bytes can't be read.
344    pub fn read_exact_to_at<F: FileReadWriteAtVolatile>(
345        &mut self,
346        dst: &F,
347        mut count: usize,
348        mut off: u64,
349    ) -> io::Result<()> {
350        while count > 0 {
351            match self.read_to_at(dst, count, off) {
352                Ok(0) => {
353                    return Err(io::Error::new(
354                        io::ErrorKind::UnexpectedEof,
355                        "failed to fill whole buffer",
356                    ))
357                }
358                Ok(n) => {
359                    count -= n;
360                    off += n as u64;
361                }
362                Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
363                Err(e) => return Err(e),
364            }
365        }
366
367        Ok(())
368    }
369
370    /// Reads data from the descriptor chain buffer into an `AsyncDisk` at offset `off`.
371    /// Returns the number of bytes read from the descriptor chain buffer.
372    /// The number of bytes read can be less than `count` if there isn't
373    /// enough data in the descriptor chain buffer.
374    pub async fn read_to_at_fut<F: AsyncDisk + ?Sized>(
375        &mut self,
376        dst: &F,
377        count: usize,
378        off: u64,
379    ) -> disk::Result<usize> {
380        let written = dst
381            .write_from_mem(
382                off,
383                Arc::new(self.mem.clone()),
384                self.regions.get_remaining_regions_with_count(count),
385            )
386            .await?;
387        self.regions.consume(written);
388        Ok(written)
389    }
390
391    /// Reads exactly `count` bytes from the chain to the disk asynchronously or returns an error if
392    /// not enough data can be read.
393    pub async fn read_exact_to_at_fut<F: AsyncDisk + ?Sized>(
394        &mut self,
395        dst: &F,
396        mut count: usize,
397        mut off: u64,
398    ) -> disk::Result<()> {
399        while count > 0 {
400            let nread = self.read_to_at_fut(dst, count, off).await?;
401            if nread == 0 {
402                return Err(disk::Error::ReadingData(io::Error::new(
403                    io::ErrorKind::UnexpectedEof,
404                    "failed to write whole buffer",
405                )));
406            }
407            count -= nread;
408            off += nread as u64;
409        }
410
411        Ok(())
412    }
413
414    /// Returns number of bytes available for reading.  May return an error if the combined
415    /// lengths of all the buffers in the DescriptorChain would cause an integer overflow.
416    pub fn available_bytes(&self) -> usize {
417        self.regions.available_bytes()
418    }
419
420    /// Returns number of bytes already read from the descriptor chain buffer.
421    pub fn bytes_read(&self) -> usize {
422        self.regions.bytes_consumed()
423    }
424
425    pub fn get_remaining_regions(&self) -> MemRegionIter {
426        self.regions.get_remaining_regions()
427    }
428
429    /// Returns a `&[VolatileSlice]` that represents all the remaining data in this `Reader`.
430    /// Calling this method does not actually consume any data from the `Reader` and callers should
431    /// call `consume` to advance the `Reader`.
432    pub fn get_remaining(&self) -> SmallVec<[VolatileSlice; 16]> {
433        self.regions.get_remaining(&self.mem)
434    }
435
436    /// Consumes `amt` bytes from the underlying descriptor chain. If `amt` is larger than the
437    /// remaining data left in this `Reader`, then all remaining data will be consumed.
438    pub fn consume(&mut self, amt: usize) {
439        self.regions.consume(amt)
440    }
441
442    /// Splits this `Reader` into two at the given offset in the `DescriptorChain` buffer. After the
443    /// split, `self` will be able to read up to `offset` bytes while the returned `Reader` can read
444    /// up to `available_bytes() - offset` bytes. If `offset > self.available_bytes()`, then the
445    /// returned `Reader` will not be able to read any bytes.
446    pub fn split_at(&mut self, offset: usize) -> Reader {
447        Reader {
448            mem: self.mem.clone(),
449            regions: self.regions.split_at(offset),
450        }
451    }
452}
453
454/// Copy up to `size` bytes from `src` into `dst`.
455///
456/// Returns the total number of bytes copied.
457///
458/// # Safety
459///
460/// The caller must ensure that it is safe to write `size` bytes of data into `dst`.
461///
462/// After the function returns, it is only safe to assume that the number of bytes indicated by the
463/// return value (which may be less than the requested `size`) have been initialized. Bytes beyond
464/// that point are not initialized by this function.
465unsafe fn copy_regions_to_mut_ptr(
466    mem: &GuestMemory,
467    src: MemRegionIter,
468    dst: *mut u8,
469    size: usize,
470) -> io::Result<usize> {
471    let mut copied = 0;
472    for src_region in src {
473        if copied >= size {
474            break;
475        }
476
477        let remaining = size - copied;
478        let count = cmp::min(remaining, src_region.len);
479
480        let vslice = mem
481            .get_slice_at_addr(GuestAddress(src_region.offset), count)
482            .map_err(|_e| io::Error::from(io::ErrorKind::InvalidData))?;
483
484        // SAFETY: `get_slice_at_addr()` verified that the region points to valid memory, and
485        // the `count` calculation ensures we will write at most `size` bytes into `dst`.
486        unsafe {
487            copy_nonoverlapping(vslice.as_ptr(), dst.add(copied), count);
488        }
489
490        copied += count;
491    }
492
493    Ok(copied)
494}
495
496impl io::Read for Reader {
497    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
498        // SAFETY: We pass a valid pointer and size combination derived from `buf`.
499        let total = unsafe {
500            copy_regions_to_mut_ptr(
501                &self.mem,
502                self.regions.get_remaining_regions(),
503                buf.as_mut_ptr(),
504                buf.len(),
505            )?
506        };
507        self.regions.consume(total);
508        Ok(total)
509    }
510}
511
512/// Provides high-level interface over the sequence of memory regions
513/// defined by writable descriptors in the descriptor chain.
514///
515/// Note that virtio spec requires driver to place any device-writable
516/// descriptors after any device-readable descriptors (2.6.4.2 in Virtio Spec v1.1).
517/// Writer will start iterating the descriptors from the first writable one and will
518/// assume that all following descriptors are writable.
519pub struct Writer {
520    mem: GuestMemory,
521    regions: DescriptorChainRegions,
522}
523
524impl Writer {
525    /// Construct a new Writer wrapper over `writable_regions`.
526    pub fn new_from_regions(
527        mem: &GuestMemory,
528        writable_regions: SmallVec<[MemRegion; 2]>,
529    ) -> Writer {
530        Writer {
531            mem: mem.clone(),
532            regions: DescriptorChainRegions::new(writable_regions),
533        }
534    }
535
536    /// Writes an object to the descriptor chain buffer.
537    pub fn write_obj<T: Immutable + IntoBytes>(&mut self, val: T) -> io::Result<()> {
538        self.write_all(val.as_bytes())
539    }
540
541    /// Writes all objects produced by `iter` into the descriptor chain buffer. Unlike `consume`,
542    /// this doesn't require the values to be stored in an intermediate collection first. It also
543    /// allows callers to choose which elements in a collection to write, for example by using the
544    /// `filter` or `take` methods of the `Iterator` trait.
545    pub fn write_iter<T: Immutable + IntoBytes, I: Iterator<Item = T>>(
546        &mut self,
547        mut iter: I,
548    ) -> io::Result<()> {
549        iter.try_for_each(|v| self.write_obj(v))
550    }
551
552    /// Writes a collection of objects into the descriptor chain buffer.
553    pub fn consume<T: Immutable + IntoBytes, C: IntoIterator<Item = T>>(
554        &mut self,
555        vals: C,
556    ) -> io::Result<()> {
557        self.write_iter(vals.into_iter())
558    }
559
560    /// Returns number of bytes available for writing.  May return an error if the combined
561    /// lengths of all the buffers in the DescriptorChain would cause an overflow.
562    pub fn available_bytes(&self) -> usize {
563        self.regions.available_bytes()
564    }
565
566    /// Reads data into a volatile slice up to the minimum of the slice's length or the number of
567    /// bytes remaining. Returns the number of bytes read.
568    pub fn write_from_volatile_slice(&mut self, slice: VolatileSlice) -> usize {
569        let mut written = 0usize;
570        let mut src = slice;
571        for dst in self.get_remaining() {
572            src.copy_to_volatile_slice(dst);
573            let copied = std::cmp::min(src.size(), dst.size());
574            written += copied;
575            src = match src.offset(copied) {
576                Ok(v) => v,
577                Err(_) => break, // The slice is fully consumed
578            };
579        }
580        self.regions.consume(written);
581        written
582    }
583
584    /// Writes data to the descriptor chain buffer from a readable object.
585    /// Returns the number of bytes written to the descriptor chain buffer.
586    /// The number of bytes written can be less than `count` if
587    /// there isn't enough data in the descriptor chain buffer.
588    pub fn write_from<F: FileReadWriteVolatile>(
589        &mut self,
590        mut src: F,
591        count: usize,
592    ) -> io::Result<usize> {
593        let iovs = self.regions.get_remaining_with_count(&self.mem, count);
594        let read = src.read_vectored_volatile(&iovs[..])?;
595        self.regions.consume(read);
596        Ok(read)
597    }
598
599    /// Writes data to the descriptor chain buffer from a File at offset `off`.
600    /// Returns the number of bytes written to the descriptor chain buffer.
601    /// The number of bytes written can be less than `count` if
602    /// there isn't enough data in the descriptor chain buffer.
603    pub fn write_from_at<F: FileReadWriteAtVolatile>(
604        &mut self,
605        src: &F,
606        count: usize,
607        off: u64,
608    ) -> io::Result<usize> {
609        let iovs = self.regions.get_remaining_with_count(&self.mem, count);
610        let read = src.read_vectored_at_volatile(&iovs[..], off)?;
611        self.regions.consume(read);
612        Ok(read)
613    }
614
615    pub fn write_all_from<F: FileReadWriteVolatile>(
616        &mut self,
617        mut src: F,
618        mut count: usize,
619    ) -> io::Result<()> {
620        while count > 0 {
621            match self.write_from(&mut src, count) {
622                Ok(0) => {
623                    return Err(io::Error::new(
624                        io::ErrorKind::WriteZero,
625                        "failed to write whole buffer",
626                    ))
627                }
628                Ok(n) => count -= n,
629                Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
630                Err(e) => return Err(e),
631            }
632        }
633
634        Ok(())
635    }
636
637    pub fn write_all_from_at<F: FileReadWriteAtVolatile>(
638        &mut self,
639        src: &F,
640        mut count: usize,
641        mut off: u64,
642    ) -> io::Result<()> {
643        while count > 0 {
644            match self.write_from_at(src, count, off) {
645                Ok(0) => {
646                    return Err(io::Error::new(
647                        io::ErrorKind::WriteZero,
648                        "failed to write whole buffer",
649                    ))
650                }
651                Ok(n) => {
652                    count -= n;
653                    off += n as u64;
654                }
655                Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
656                Err(e) => return Err(e),
657            }
658        }
659        Ok(())
660    }
661    /// Writes data to the descriptor chain buffer from an `AsyncDisk` at offset `off`.
662    /// Returns the number of bytes written to the descriptor chain buffer.
663    /// The number of bytes written can be less than `count` if
664    /// there isn't enough data in the descriptor chain buffer.
665    pub async fn write_from_at_fut<F: AsyncDisk + ?Sized>(
666        &mut self,
667        src: &F,
668        count: usize,
669        off: u64,
670    ) -> disk::Result<usize> {
671        let read = src
672            .read_to_mem(
673                off,
674                Arc::new(self.mem.clone()),
675                self.regions.get_remaining_regions_with_count(count),
676            )
677            .await?;
678        self.regions.consume(read);
679        Ok(read)
680    }
681
682    pub async fn write_all_from_at_fut<F: AsyncDisk + ?Sized>(
683        &mut self,
684        src: &F,
685        mut count: usize,
686        mut off: u64,
687    ) -> disk::Result<()> {
688        while count > 0 {
689            let nwritten = self.write_from_at_fut(src, count, off).await?;
690            if nwritten == 0 {
691                return Err(disk::Error::WritingData(io::Error::new(
692                    io::ErrorKind::UnexpectedEof,
693                    "failed to write whole buffer",
694                )));
695            }
696            count -= nwritten;
697            off += nwritten as u64;
698        }
699        Ok(())
700    }
701
702    /// Returns number of bytes already written to the descriptor chain buffer.
703    pub fn bytes_written(&self) -> usize {
704        self.regions.bytes_consumed()
705    }
706
707    pub fn get_remaining_regions(&self) -> MemRegionIter {
708        self.regions.get_remaining_regions()
709    }
710
711    /// Returns a `&[VolatileSlice]` that represents all the remaining data in this `Writer`.
712    /// Calling this method does not actually advance the current position of the `Writer` in the
713    /// buffer and callers should call `consume_bytes` to advance the `Writer`. Not calling
714    /// `consume_bytes` with the amount of data copied into the returned `VolatileSlice`s will
715    /// result in that that data being overwritten the next time data is written into the `Writer`.
716    pub fn get_remaining(&self) -> SmallVec<[VolatileSlice; 16]> {
717        self.regions.get_remaining(&self.mem)
718    }
719
720    /// Consumes `amt` bytes from the underlying descriptor chain. If `amt` is larger than the
721    /// remaining data left in this `Reader`, then all remaining data will be consumed.
722    pub fn consume_bytes(&mut self, amt: usize) {
723        self.regions.consume(amt)
724    }
725
726    /// Splits this `Writer` into two at the given offset in the `DescriptorChain` buffer. After the
727    /// split, `self` will be able to write up to `offset` bytes while the returned `Writer` can
728    /// write up to `available_bytes() - offset` bytes. If `offset > self.available_bytes()`, then
729    /// the returned `Writer` will not be able to write any bytes.
730    pub fn split_at(&mut self, offset: usize) -> Writer {
731        Writer {
732            mem: self.mem.clone(),
733            regions: self.regions.split_at(offset),
734        }
735    }
736}
737
738impl io::Write for Writer {
739    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
740        let mut rem = buf;
741        let mut total = 0;
742        for b in self.regions.get_remaining(&self.mem) {
743            if rem.is_empty() {
744                break;
745            }
746
747            let count = cmp::min(rem.len(), b.size());
748            // SAFETY:
749            // Safe because we have already verified that `vs` points to valid memory.
750            unsafe {
751                copy_nonoverlapping(rem.as_ptr(), b.as_mut_ptr(), count);
752            }
753            rem = &rem[count..];
754            total += count;
755        }
756
757        self.regions.consume(total);
758        Ok(total)
759    }
760
761    fn flush(&mut self) -> io::Result<()> {
762        // Nothing to flush since the writes go straight into the buffer.
763        Ok(())
764    }
765}
766
/// Descriptor flag: another descriptor follows this one in the chain (see its `next` field).
const VIRTQ_DESC_F_NEXT: u16 = 1 << 0;
/// Descriptor flag: the buffer is device-writable rather than device-readable.
const VIRTQ_DESC_F_WRITE: u16 = 1 << 1;
769
/// Direction of a descriptor's buffer from the device's point of view.
///
/// `create_descriptor_chain` sets `VIRTQ_DESC_F_WRITE` for `Writable` entries.
#[derive(Clone, Copy, Eq, PartialEq)]
pub enum DescriptorType {
    /// Buffer the device reads from.
    Readable,
    /// Buffer the device writes to.
    Writable,
}
775
776#[derive(Copy, Clone, Debug, FromBytes, Immutable, IntoBytes, KnownLayout)]
777#[repr(C)]
778struct virtq_desc {
779    addr: Le64,
780    len: Le32,
781    flags: Le16,
782    next: Le16,
783}
784
785/// Test utility function to create a descriptor chain in guest memory.
786pub fn create_descriptor_chain(
787    memory: &GuestMemory,
788    descriptor_array_addr: GuestAddress,
789    mut buffers_start_addr: GuestAddress,
790    descriptors: Vec<(DescriptorType, u32)>,
791    spaces_between_regions: u32,
792) -> anyhow::Result<DescriptorChain> {
793    let descriptors_len = descriptors.len();
794    for (index, (type_, size)) in descriptors.into_iter().enumerate() {
795        let mut flags = 0;
796        if let DescriptorType::Writable = type_ {
797            flags |= VIRTQ_DESC_F_WRITE;
798        }
799        if index + 1 < descriptors_len {
800            flags |= VIRTQ_DESC_F_NEXT;
801        }
802
803        let index = index as u16;
804        let desc = virtq_desc {
805            addr: buffers_start_addr.offset().into(),
806            len: size.into(),
807            flags: flags.into(),
808            next: (index + 1).into(),
809        };
810
811        let offset = size + spaces_between_regions;
812        buffers_start_addr = buffers_start_addr
813            .checked_add(offset as u64)
814            .context("Invalid buffers_start_addr)")?;
815
816        let _ = memory.write_obj_at_addr(
817            desc,
818            descriptor_array_addr
819                .checked_add(index as u64 * std::mem::size_of::<virtq_desc>() as u64)
820                .context("Invalid descriptor_array_addr")?,
821        );
822    }
823
824    let chain = SplitDescriptorChain::new(memory, descriptor_array_addr, 0x100, 0);
825    DescriptorChain::new(chain, memory, 0)
826}
827
#[cfg(test)]
mod tests {
    use std::fs::File;
    use std::io::Read;

    use cros_async::Executor;
    use tempfile::tempfile;
    use tempfile::NamedTempFile;

    use super::*;

    /// Creates the 64 KiB guest memory, starting at guest address 0, that backs every test chain.
    fn test_memory() -> GuestMemory {
        GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap()
    }

    /// Builds a descriptor chain in `memory` from `descriptors`.
    ///
    /// The descriptor table is at guest address 0x0 and the first buffer at guest address 0x100.
    /// Panics if the chain cannot be created.
    fn test_chain(
        memory: &GuestMemory,
        descriptors: Vec<(DescriptorType, u32)>,
        spaces_between_regions: u32,
    ) -> DescriptorChain {
        create_descriptor_chain(
            memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            descriptors,
            spaces_between_regions,
        )
        .expect("create_descriptor_chain failed")
    }

    /// Descriptor layout shared by several tests.
    ///
    /// It has 128 readable bytes (16 + 16 + 96) followed by 68 writable bytes (64 + 1 + 3).
    fn mixed_descriptors() -> Vec<(DescriptorType, u32)> {
        use DescriptorType::*;

        vec![
            (Readable, 16),
            (Readable, 16),
            (Readable, 96),
            (Writable, 64),
            (Writable, 1),
            (Writable, 3),
        ]
    }

    /// Splits the reader of a chain built from `mixed_descriptors()` at `offset`.
    ///
    /// Returns `(bytes left in the original reader, bytes in the split-off reader)`.
    fn split_mixed_reader_at(offset: usize) -> (usize, usize) {
        let memory = test_memory();
        let mut chain = test_chain(&memory, mixed_descriptors(), 0);
        let reader = &mut chain.reader;

        let other = reader.split_at(offset);
        (reader.available_bytes(), other.available_bytes())
    }

    #[test]
    fn reader_test_simple_chain() {
        use DescriptorType::*;

        let memory = test_memory();
        let mut chain = test_chain(
            &memory,
            vec![(Readable, 8), (Readable, 16), (Readable, 18), (Readable, 64)],
            0,
        );
        let reader = &mut chain.reader;
        assert_eq!(reader.available_bytes(), 106);
        assert_eq!(reader.bytes_read(), 0);

        let mut buffer = [0u8; 64];
        reader
            .read_exact(&mut buffer)
            .expect("read_exact should not fail here");

        assert_eq!(reader.available_bytes(), 42);
        assert_eq!(reader.bytes_read(), 64);

        let length = reader
            .read(&mut buffer)
            .expect("read should not fail here");
        assert_eq!(length, 42);

        assert_eq!(reader.available_bytes(), 0);
        assert_eq!(reader.bytes_read(), 106);
    }

    #[test]
    fn writer_test_simple_chain() {
        use DescriptorType::*;

        let memory = test_memory();
        let mut chain = test_chain(
            &memory,
            vec![(Writable, 8), (Writable, 16), (Writable, 18), (Writable, 64)],
            0,
        );
        let writer = &mut chain.writer;
        assert_eq!(writer.available_bytes(), 106);
        assert_eq!(writer.bytes_written(), 0);

        let buffer = [0; 64];
        writer
            .write_all(&buffer)
            .expect("write_all should not fail here");

        assert_eq!(writer.available_bytes(), 42);
        assert_eq!(writer.bytes_written(), 64);

        let length = writer.write(&buffer).expect("write should not fail here");
        assert_eq!(length, 42);

        assert_eq!(writer.available_bytes(), 0);
        assert_eq!(writer.bytes_written(), 106);
    }

    #[test]
    fn reader_test_incompatible_chain() {
        use DescriptorType::*;

        let memory = test_memory();
        let mut chain = test_chain(&memory, vec![(Writable, 8)], 0);
        let reader = &mut chain.reader;
        assert_eq!(reader.available_bytes(), 0);
        assert_eq!(reader.bytes_read(), 0);

        assert!(reader.read_obj::<u8>().is_err());

        assert_eq!(reader.available_bytes(), 0);
        assert_eq!(reader.bytes_read(), 0);
    }

    #[test]
    fn writer_test_incompatible_chain() {
        use DescriptorType::*;

        let memory = test_memory();
        let mut chain = test_chain(&memory, vec![(Readable, 8)], 0);
        let writer = &mut chain.writer;
        assert_eq!(writer.available_bytes(), 0);
        assert_eq!(writer.bytes_written(), 0);

        assert!(writer.write_obj(0u8).is_err());

        assert_eq!(writer.available_bytes(), 0);
        assert_eq!(writer.bytes_written(), 0);
    }

    #[test]
    fn reader_failing_io() {
        use DescriptorType::*;

        let memory = test_memory();
        let mut chain = test_chain(&memory, vec![(Readable, 256), (Readable, 256)], 0);
        let reader = &mut chain.reader;

        // Open a device file in read-only mode so that writing to it triggers an I/O error.
        let device_file = if cfg!(windows) { "NUL" } else { "/dev/zero" };
        let mut ro_file = File::open(device_file).expect("failed to open device file");

        reader
            .read_exact_to(&mut ro_file, 512)
            .expect_err("successfully wrote guest data to a read-only file");

        // The write above should have failed entirely, so no bytes were consumed at all.
        assert_eq!(reader.available_bytes(), 512);
        assert_eq!(reader.bytes_read(), 0);
    }

    #[test]
    fn writer_failing_io() {
        use DescriptorType::*;

        let memory = test_memory();
        let mut chain = test_chain(&memory, vec![(Writable, 256), (Writable, 256)], 0);
        let writer = &mut chain.writer;

        // The file holds only 384 bytes, so filling all 512 writable bytes must fail.
        let mut file = tempfile().unwrap();
        file.set_len(384).unwrap();

        writer
            .write_all_from(&mut file, 512)
            .expect_err("successfully wrote more bytes than the source file contains");

        // The 384 bytes that the file did contain were still copied into the chain.
        assert_eq!(writer.available_bytes(), 128);
        assert_eq!(writer.bytes_written(), 384);
    }

    #[test]
    fn reader_writer_shared_chain() {
        let memory = test_memory();
        let mut chain = test_chain(&memory, mixed_descriptors(), 0);
        let reader = &mut chain.reader;
        let writer = &mut chain.writer;

        assert_eq!(reader.bytes_read(), 0);
        assert_eq!(writer.bytes_written(), 0);

        let mut buffer = Vec::with_capacity(200);

        assert_eq!(
            reader
                .read_to_end(&mut buffer)
                .expect("read should not fail here"),
            128
        );

        // The writable descriptors are only 68 bytes long.
        writer
            .write_all(&buffer[..68])
            .expect("write should not fail here");

        assert_eq!(reader.available_bytes(), 0);
        assert_eq!(reader.bytes_read(), 128);
        assert_eq!(writer.available_bytes(), 0);
        assert_eq!(writer.bytes_written(), 68);
    }

    #[test]
    fn reader_writer_shattered_object() {
        use DescriptorType::*;

        let memory = test_memory();
        let secret: Le32 = 0x12345678.into();

        // Write the object through four 1-byte writable regions separated by 123-byte gaps.
        let mut chain_writer = test_chain(
            &memory,
            vec![(Writable, 1), (Writable, 1), (Writable, 1), (Writable, 1)],
            123,
        );
        let writer = &mut chain_writer.writer;
        writer
            .write_obj(secret)
            .expect("write_obj should not fail here");

        // Now create a new descriptor chain pointing to the same memory and read it back.
        let mut chain_reader = test_chain(
            &memory,
            vec![(Readable, 1), (Readable, 1), (Readable, 1), (Readable, 1)],
            123,
        );
        let reader = &mut chain_reader.reader;
        let read_secret = reader
            .read_obj::<Le32>()
            .expect("read_obj should not fail here");
        assert_eq!(read_secret, secret);
    }

    #[test]
    fn reader_unexpected_eof() {
        use DescriptorType::*;

        let memory = test_memory();
        let mut chain = test_chain(&memory, vec![(Readable, 256), (Readable, 256)], 0);
        let reader = &mut chain.reader;

        let mut buf = vec![0; 1024];

        assert_eq!(
            reader
                .read_exact(&mut buf[..])
                .expect_err("read more bytes than available")
                .kind(),
            io::ErrorKind::UnexpectedEof
        );
    }

    #[test]
    fn split_border() {
        assert_eq!(split_mixed_reader_at(32), (32, 96));
    }

    #[test]
    fn split_middle() {
        assert_eq!(split_mixed_reader_at(24), (24, 104));
    }

    #[test]
    fn split_end() {
        assert_eq!(split_mixed_reader_at(128), (128, 0));
    }

    #[test]
    fn split_beginning() {
        assert_eq!(split_mixed_reader_at(0), (0, 128));
    }

    #[test]
    fn split_outofbounds() {
        let (_, other_available) = split_mixed_reader_at(256);
        assert_eq!(
            other_available, 0,
            "Reader returned from out-of-bounds split still has available bytes"
        );
    }

    #[test]
    fn read_full() {
        use DescriptorType::*;

        let memory = test_memory();
        let mut chain = test_chain(
            &memory,
            vec![(Readable, 16), (Readable, 16), (Readable, 16)],
            0,
        );
        let reader = &mut chain.reader;

        let mut buf = [0u8; 64];
        assert_eq!(
            reader.read(&mut buf[..]).expect("failed to read to buffer"),
            48
        );
    }

    #[test]
    fn write_full() {
        use DescriptorType::*;

        let memory = test_memory();
        let mut chain = test_chain(
            &memory,
            vec![(Writable, 16), (Writable, 16), (Writable, 16)],
            0,
        );
        let writer = &mut chain.writer;

        let buf = [0xdeu8; 64];
        assert_eq!(
            writer.write(&buf[..]).expect("failed to write from buffer"),
            48
        );
    }

    #[test]
    fn consume_collect() {
        use DescriptorType::*;

        let memory = test_memory();
        let vs: Vec<Le64> = vec![
            0x0101010101010101.into(),
            0x0202020202020202.into(),
            0x0303030303030303.into(),
        ];

        let mut write_chain = test_chain(&memory, vec![(Writable, 24)], 0);
        let writer = &mut write_chain.writer;
        writer
            .consume(vs.clone())
            .expect("failed to consume() a vector");

        let mut read_chain = test_chain(&memory, vec![(Readable, 24)], 0);
        let reader = &mut read_chain.reader;
        let vs_read = reader
            .collect::<io::Result<Vec<Le64>>, _>()
            .expect("failed to collect() values");
        assert_eq!(vs, vs_read);
    }

    #[test]
    fn get_remaining_region_with_count() {
        let memory = test_memory();
        let chain = test_chain(&memory, mixed_descriptors(), 0);

        let Reader {
            mem: _,
            mut regions,
        } = chain.reader;

        let drain = regions
            .get_remaining_regions_with_count(usize::MAX)
            .fold(0usize, |total, region| total + region.len);
        assert_eq!(drain, 128);

        let exact = regions
            .get_remaining_regions_with_count(32)
            .fold(0usize, |total, region| total + region.len);
        assert!(exact > 0);
        assert!(exact <= 32);

        let split = regions
            .get_remaining_regions_with_count(24)
            .fold(0usize, |total, region| total + region.len);
        assert!(split > 0);
        assert!(split <= 24);

        regions.consume(64);

        let first = regions
            .get_remaining_regions_with_count(8)
            .fold(0usize, |total, region| total + region.len);
        assert!(first > 0);
        assert!(first <= 8);
    }

    #[test]
    fn get_remaining_with_count() {
        let memory = test_memory();
        let chain = test_chain(&memory, mixed_descriptors(), 0);

        let Reader {
            mem: _,
            mut regions,
        } = chain.reader;

        let drain = regions
            .get_remaining_with_count(&memory, usize::MAX)
            .iter()
            .fold(0usize, |total, iov| total + iov.size());
        assert_eq!(drain, 128);

        let exact = regions
            .get_remaining_with_count(&memory, 32)
            .iter()
            .fold(0usize, |total, iov| total + iov.size());
        assert!(exact > 0);
        assert!(exact <= 32);

        let split = regions
            .get_remaining_with_count(&memory, 24)
            .iter()
            .fold(0usize, |total, iov| total + iov.size());
        assert!(split > 0);
        assert!(split <= 24);

        regions.consume(64);

        let first = regions
            .get_remaining_with_count(&memory, 8)
            .iter()
            .fold(0usize, |total, iov| total + iov.size());
        assert!(first > 0);
        assert!(first <= 8);
    }

    #[test]
    fn reader_peek_obj() {
        use DescriptorType::*;

        let memory = test_memory();

        // Write test data to memory.
        memory
            .write_obj_at_addr(Le16::from(0xBEEF), GuestAddress(0x100))
            .unwrap();
        memory
            .write_obj_at_addr(Le16::from(0xDEAD), GuestAddress(0x200))
            .unwrap();

        // Two 2-byte descriptors at 0x100 and 0x200 (2 bytes + 0xfe-byte gap).
        let mut chain_reader = test_chain(&memory, vec![(Readable, 2), (Readable, 2)], 0x100 - 2);
        let reader = &mut chain_reader.reader;

        // peek_obj() at the beginning of the chain should return the first object.
        let peek1 = reader.peek_obj::<Le16>().unwrap();
        assert_eq!(peek1, Le16::from(0xBEEF));

        // peek_obj() again should return the same object, since it was not consumed.
        let peek2 = reader.peek_obj::<Le16>().unwrap();
        assert_eq!(peek2, Le16::from(0xBEEF));

        // peek_obj() of an object spanning two descriptors should copy from both.
        let peek3 = reader.peek_obj::<Le32>().unwrap();
        assert_eq!(peek3, Le32::from(0xDEADBEEF));

        // read_obj() should return the first object.
        let read1 = reader.read_obj::<Le16>().unwrap();
        assert_eq!(read1, Le16::from(0xBEEF));

        // peek_obj() of a value that is larger than the rest of the chain should fail.
        reader
            .peek_obj::<Le32>()
            .expect_err("peek_obj past end of chain");

        // read_obj() again should return the second object.
        let read2 = reader.read_obj::<Le16>().unwrap();
        assert_eq!(read2, Le16::from(0xDEAD));

        // peek_obj() should fail at the end of the chain.
        reader
            .peek_obj::<Le16>()
            .expect_err("peek_obj past end of chain");
    }

    #[test]
    fn region_reader_failing_io() {
        let ex = Executor::new().unwrap();
        ex.run_until(region_reader_failing_io_async(&ex)).unwrap();
    }

    async fn region_reader_failing_io_async(ex: &Executor) {
        use DescriptorType::*;

        let memory = test_memory();
        let mut chain = test_chain(&memory, vec![(Readable, 256), (Readable, 256)], 0);
        let reader = &mut chain.reader;

        // Open a file in read-only mode so that writing to it triggers an I/O error.
        let named_temp_file = NamedTempFile::new().expect("failed to create temp file");
        let ro_file =
            File::open(named_temp_file.path()).expect("failed to open temp file read only");
        let async_ro_file = disk::SingleFileDisk::new(ro_file, ex).expect("Failed to create SFD");

        reader
            .read_exact_to_at_fut(&async_ro_file, 512, 0)
            .await
            .expect_err("successfully wrote guest data to a read-only SingleFileDisk");

        // The write above should have failed entirely, so no bytes were consumed at all.
        assert_eq!(reader.available_bytes(), 512);
        assert_eq!(reader.bytes_read(), 0);
    }

    #[test]
    fn region_writer_failing_io() {
        let ex = Executor::new().unwrap();
        ex.run_until(region_writer_failing_io_async(&ex)).unwrap();
    }

    async fn region_writer_failing_io_async(ex: &Executor) {
        use DescriptorType::*;

        let memory = test_memory();
        let mut chain = test_chain(&memory, vec![(Writable, 256), (Writable, 256)], 0);
        let writer = &mut chain.writer;

        // The file holds only 384 bytes, so filling all 512 writable bytes must fail.
        let file = tempfile().expect("failed to create temp file");
        file.set_len(384).unwrap();
        let async_file = disk::SingleFileDisk::new(file, ex).expect("Failed to create SFD");

        writer
            .write_all_from_at_fut(&async_file, 512, 0)
            .await
            .expect_err("successfully wrote more bytes than the SingleFileDisk contains");

        // The 384 bytes that the file did contain were still copied into the chain.
        assert_eq!(writer.available_bytes(), 128);
        assert_eq!(writer.bytes_written(), 384);
    }
}