vhost/
lib.rs

1// Copyright 2017 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Linux vhost kernel API wrapper.
6
7#![cfg(any(target_os = "android", target_os = "linux"))]
8
9pub mod net;
10#[cfg(any(target_os = "android", target_os = "linux"))]
11#[cfg(target_arch = "aarch64")]
12mod scmi;
13mod vsock;
14
15use std::alloc::Layout;
16use std::io::Error as IoError;
17use std::ptr::null;
18
19use base::ioctl;
20use base::ioctl_with_mut_ref;
21use base::ioctl_with_ptr;
22use base::ioctl_with_ref;
23use base::AsRawDescriptor;
24use base::Event;
25use base::LayoutAllocation;
26use remain::sorted;
27use static_assertions::const_assert;
28use thiserror::Error;
29use vm_memory::GuestAddress;
30use vm_memory::GuestMemory;
31use vm_memory::GuestMemoryError;
32
33#[cfg(any(target_os = "android", target_os = "linux"))]
34pub use crate::net::Net;
35#[cfg(any(target_os = "android", target_os = "linux"))]
36pub use crate::net::NetT;
37#[cfg(any(target_os = "android", target_os = "linux"))]
38#[cfg(target_arch = "aarch64")]
39pub use crate::scmi::Scmi;
40pub use crate::vsock::Vsock;
41
42#[sorted]
43#[derive(Error, Debug)]
44pub enum Error {
45    /// Invalid available address.
46    #[error("invalid available address: {0}")]
47    AvailAddress(GuestMemoryError),
48    /// Invalid descriptor table address.
49    #[error("invalid descriptor table address: {0}")]
50    DescriptorTableAddress(GuestMemoryError),
51    /// Invalid queue.
52    #[error("invalid queue")]
53    InvalidQueue,
54    /// Error while running ioctl.
55    #[error("failed to run ioctl: {0}")]
56    IoctlError(IoError),
57    /// Invalid log address.
58    #[error("invalid log address: {0}")]
59    LogAddress(GuestMemoryError),
60    /// Invalid used address.
61    #[error("invalid used address: {0}")]
62    UsedAddress(GuestMemoryError),
63    /// Error opening vhost device.
64    #[error("failed to open vhost device: {0}")]
65    VhostOpen(IoError),
66}
67
68pub type Result<T> = std::result::Result<T, Error>;
69
70fn ioctl_result<T>() -> Result<T> {
71    Err(Error::IoctlError(IoError::last_os_error()))
72}
73
74/// An interface for setting up vhost-based virtio devices.  Vhost-based devices are different
75/// from regular virtio devices because the host kernel takes care of handling all the data
76/// transfer.  The device itself only needs to deal with setting up the kernel driver and
77/// managing the control channel.
78pub trait Vhost: AsRawDescriptor + std::marker::Sized {
79    /// Set the current process as the owner of this file descriptor.
80    /// This must be run before any other vhost ioctls.
81    fn set_owner(&self) -> Result<()> {
82        // SAFETY:
83        // This ioctl is called on a valid vhost_net descriptor and has its
84        // return value checked.
85        let ret = unsafe { ioctl(self, virtio_sys::VHOST_SET_OWNER) };
86        if ret < 0 {
87            return ioctl_result();
88        }
89        Ok(())
90    }
91
92    /// Give up ownership and reset the device to default values. Allows a subsequent call to
93    /// `set_owner` to succeed.
94    fn reset_owner(&self) -> Result<()> {
95        // SAFETY:
96        // This ioctl is called on a valid vhost fd and has its
97        // return value checked.
98        let ret = unsafe { ioctl(self, virtio_sys::VHOST_RESET_OWNER) };
99        if ret < 0 {
100            return ioctl_result();
101        }
102        Ok(())
103    }
104
105    /// Get a bitmask of supported virtio/vhost features.
106    fn get_features(&self) -> Result<u64> {
107        let mut avail_features: u64 = 0;
108        // SAFETY:
109        // This ioctl is called on a valid vhost_net descriptor and has its
110        // return value checked.
111        let ret = unsafe {
112            ioctl_with_mut_ref(self, virtio_sys::VHOST_GET_FEATURES, &mut avail_features)
113        };
114        if ret < 0 {
115            return ioctl_result();
116        }
117        Ok(avail_features)
118    }
119
120    /// Inform the vhost subsystem which features to enable. This should be a subset of
121    /// supported features from VHOST_GET_FEATURES.
122    ///
123    /// # Arguments
124    /// * `features` - Bitmask of features to set.
125    fn set_features(&self, features: u64) -> Result<()> {
126        // SAFETY:
127        // This ioctl is called on a valid vhost_net descriptor and has its
128        // return value checked.
129        let ret = unsafe { ioctl_with_ref(self, virtio_sys::VHOST_SET_FEATURES, &features) };
130        if ret < 0 {
131            return ioctl_result();
132        }
133        Ok(())
134    }
135
136    /// Set the guest memory mappings for vhost to use.
137    fn set_mem_table(&self, mem: &GuestMemory) -> Result<()> {
138        const SIZE_OF_MEMORY: usize = std::mem::size_of::<virtio_sys::vhost::vhost_memory>();
139        const SIZE_OF_REGION: usize = std::mem::size_of::<virtio_sys::vhost::vhost_memory_region>();
140        const ALIGN_OF_MEMORY: usize = std::mem::align_of::<virtio_sys::vhost::vhost_memory>();
141        const_assert!(
142            ALIGN_OF_MEMORY >= std::mem::align_of::<virtio_sys::vhost::vhost_memory_region>()
143        );
144
145        let num_regions = mem.num_regions() as usize;
146        let size = SIZE_OF_MEMORY + num_regions * SIZE_OF_REGION;
147        let layout = Layout::from_size_align(size, ALIGN_OF_MEMORY).expect("impossible layout");
148        let mut allocation = LayoutAllocation::zeroed(layout);
149
150        // SAFETY:
151        // Safe to obtain an exclusive reference because there are no other
152        // references to the allocation yet and all-zero is a valid bit pattern.
153        let vhost_memory = unsafe { allocation.as_mut::<virtio_sys::vhost::vhost_memory>() };
154
155        vhost_memory.nregions = num_regions as u32;
156        // SAFETY:
157        // regions is a zero-length array, so taking a mut slice requires that
158        // we correctly specify the size to match the amount of backing memory.
159        let vhost_regions = unsafe { vhost_memory.regions.as_mut_slice(num_regions) };
160
161        for region in mem.regions() {
162            vhost_regions[region.index] = virtio_sys::vhost::vhost_memory_region {
163                guest_phys_addr: region.guest_addr.offset(),
164                memory_size: region.size as u64,
165                userspace_addr: region.host_addr as u64,
166                flags_padding: 0u64,
167            };
168        }
169
170        // SAFETY:
171        // This ioctl is called with a pointer that is valid for the lifetime
172        // of this function. The kernel will make its own copy of the memory
173        // tables. As always, check the return value.
174        let ret = unsafe { ioctl_with_ptr(self, virtio_sys::VHOST_SET_MEM_TABLE, vhost_memory) };
175        if ret < 0 {
176            return ioctl_result();
177        }
178
179        Ok(())
180
181        // vhost_memory allocation is deallocated.
182    }
183
184    /// Set the number of descriptors in the vring.
185    ///
186    /// # Arguments
187    /// * `queue_index` - Index of the queue to set descriptor count for.
188    /// * `num` - Number of descriptors in the queue.
189    fn set_vring_num(&self, queue_index: usize, num: u16) -> Result<()> {
190        let vring_state = virtio_sys::vhost::vhost_vring_state {
191            index: queue_index as u32,
192            num: num as u32,
193        };
194
195        // SAFETY:
196        // This ioctl is called on a valid vhost_net descriptor and has its
197        // return value checked.
198        let ret = unsafe { ioctl_with_ref(self, virtio_sys::VHOST_SET_VRING_NUM, &vring_state) };
199        if ret < 0 {
200            return ioctl_result();
201        }
202        Ok(())
203    }
204
205    /// Set the addresses for a given vring.
206    ///
207    /// # Arguments
208    /// * `queue_size` - Actual queue size negotiated by the driver.
209    /// * `queue_index` - Index of the queue to set addresses for.
210    /// * `flags` - Bitmask of vring flags.
211    /// * `desc_addr` - Descriptor table address.
212    /// * `used_addr` - Used ring buffer address.
213    /// * `avail_addr` - Available ring buffer address.
214    /// * `log_addr` - Optional address for logging.
215    fn set_vring_addr(
216        &self,
217        mem: &GuestMemory,
218        queue_size: u16,
219        queue_index: usize,
220        flags: u32,
221        desc_addr: GuestAddress,
222        used_addr: GuestAddress,
223        avail_addr: GuestAddress,
224        log_addr: Option<GuestAddress>,
225    ) -> Result<()> {
226        if queue_size == 0 || !queue_size.is_power_of_two() {
227            return Err(Error::InvalidQueue);
228        }
229
230        let queue_size = usize::from(queue_size);
231
232        let desc_table_size = 16 * queue_size;
233        let desc_table = mem
234            .get_slice_at_addr(desc_addr, desc_table_size)
235            .map_err(Error::DescriptorTableAddress)?;
236
237        let used_ring_size = 6 + 8 * queue_size;
238        let used_ring = mem
239            .get_slice_at_addr(used_addr, used_ring_size)
240            .map_err(Error::UsedAddress)?;
241
242        let avail_ring_size = 6 + 2 * queue_size;
243        let avail_ring = mem
244            .get_slice_at_addr(avail_addr, avail_ring_size)
245            .map_err(Error::AvailAddress)?;
246
247        let log_addr = match log_addr {
248            None => null(),
249            Some(a) => mem.get_host_address(a).map_err(Error::LogAddress)?,
250        };
251
252        let vring_addr = virtio_sys::vhost::vhost_vring_addr {
253            index: queue_index as u32,
254            flags,
255            desc_user_addr: desc_table.as_ptr() as u64,
256            used_user_addr: used_ring.as_ptr() as u64,
257            avail_user_addr: avail_ring.as_ptr() as u64,
258            log_guest_addr: log_addr as u64,
259        };
260
261        // SAFETY:
262        // This ioctl is called on a valid vhost_net descriptor and has its
263        // return value checked.
264        let ret = unsafe { ioctl_with_ref(self, virtio_sys::VHOST_SET_VRING_ADDR, &vring_addr) };
265        if ret < 0 {
266            return ioctl_result();
267        }
268        Ok(())
269    }
270
271    /// Set the first index to look for available descriptors.
272    ///
273    /// # Arguments
274    /// * `queue_index` - Index of the queue to modify.
275    /// * `num` - Index where available descriptors start.
276    fn set_vring_base(&self, queue_index: usize, num: u16) -> Result<()> {
277        let vring_state = virtio_sys::vhost::vhost_vring_state {
278            index: queue_index as u32,
279            num: num as u32,
280        };
281
282        // SAFETY:
283        // This ioctl is called on a valid vhost_net descriptor and has its
284        // return value checked.
285        let ret = unsafe { ioctl_with_ref(self, virtio_sys::VHOST_SET_VRING_BASE, &vring_state) };
286        if ret < 0 {
287            return ioctl_result();
288        }
289        Ok(())
290    }
291
292    /// Gets the index of the next available descriptor in the queue.
293    ///
294    /// # Arguments
295    /// * `queue_index` - Index of the queue to query.
296    fn get_vring_base(&self, queue_index: usize) -> Result<u16> {
297        let mut vring_state = virtio_sys::vhost::vhost_vring_state {
298            index: queue_index as u32,
299            num: 0,
300        };
301
302        // SAFETY:
303        // Safe because this will only modify `vring_state` and we check the return value.
304        let ret =
305            unsafe { ioctl_with_mut_ref(self, virtio_sys::VHOST_GET_VRING_BASE, &mut vring_state) };
306        if ret < 0 {
307            return ioctl_result();
308        }
309
310        Ok(vring_state.num as u16)
311    }
312
313    /// Set the event to trigger when buffers have been used by the host.
314    ///
315    /// # Arguments
316    /// * `queue_index` - Index of the queue to modify.
317    /// * `event` - Event to trigger.
318    fn set_vring_call(&self, queue_index: usize, event: &Event) -> Result<()> {
319        let vring_file = virtio_sys::vhost::vhost_vring_file {
320            index: queue_index as u32,
321            fd: event.as_raw_descriptor(),
322        };
323
324        // SAFETY:
325        // This ioctl is called on a valid vhost_net descriptor and has its
326        // return value checked.
327        let ret = unsafe { ioctl_with_ref(self, virtio_sys::VHOST_SET_VRING_CALL, &vring_file) };
328        if ret < 0 {
329            return ioctl_result();
330        }
331        Ok(())
332    }
333
334    /// Set the event to trigger to signal an error.
335    ///
336    /// # Arguments
337    /// * `queue_index` - Index of the queue to modify.
338    /// * `event` - Event to trigger.
339    fn set_vring_err(&self, queue_index: usize, event: &Event) -> Result<()> {
340        let vring_file = virtio_sys::vhost::vhost_vring_file {
341            index: queue_index as u32,
342            fd: event.as_raw_descriptor(),
343        };
344
345        // SAFETY:
346        // This ioctl is called on a valid vhost_net fd and has its
347        // return value checked.
348        let ret = unsafe { ioctl_with_ref(self, virtio_sys::VHOST_SET_VRING_ERR, &vring_file) };
349        if ret < 0 {
350            return ioctl_result();
351        }
352        Ok(())
353    }
354
355    /// Set the event that will be signaled by the guest when buffers are
356    /// available for the host to process.
357    ///
358    /// # Arguments
359    /// * `queue_index` - Index of the queue to modify.
360    /// * `event` - Event that will be signaled from guest.
361    fn set_vring_kick(&self, queue_index: usize, event: &Event) -> Result<()> {
362        let vring_file = virtio_sys::vhost::vhost_vring_file {
363            index: queue_index as u32,
364            fd: event.as_raw_descriptor(),
365        };
366
367        // SAFETY:
368        // This ioctl is called on a valid vhost_net descriptor and has its
369        // return value checked.
370        let ret = unsafe { ioctl_with_ref(self, virtio_sys::VHOST_SET_VRING_KICK, &vring_file) };
371        if ret < 0 {
372            return ioctl_result();
373        }
374        Ok(())
375    }
376}