vmm_vhost/
lib.rs

1// Copyright (C) 2019 Alibaba Cloud. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause
3
4//! Virtio Vhost Backend Drivers
5//!
6//! Virtio devices use virtqueues to transport data efficiently. The first generation of virtqueue
7//! is a set of three different single-producer, single-consumer ring structures designed to store
8//! generic scatter-gather I/O. The virtio specification 1.1 introduces an alternative compact
//! virtqueue layout named "Packed Virtqueue", which is friendlier to memory caches and to
//! hardware-implemented virtio devices. The packed virtqueue uses read-write memory, which means
//! the memory is both read and written by host and guest. The new Packed Virtqueue is
//! preferred for performance.
13//!
//! Vhost is a mechanism to improve the performance of Virtio devices by delegating data plane
//! operations to dedicated I/O service processes. Only the configuration, I/O submission
//! notification, and I/O completion interrupts are piped through the hypervisor.
17//! It uses the same virtqueue layout as Virtio to allow Vhost devices to be mapped directly to
18//! Virtio devices. This allows a Vhost device to be accessed directly by a guest OS inside a
19//! hypervisor process with an existing Virtio (PCI) driver.
20//!
//! The initial vhost implementation is part of the Linux kernel and uses an ioctl interface to
//! communicate with userspace applications. Dedicated kernel worker threads are created to handle
//! I/O requests from the guest.
24//!
//! Later, the vhost-user protocol was introduced to complement the ioctl interface used to control
//! the vhost implementation in the Linux kernel. It implements the control plane needed to
//! establish virtqueue sharing with a user space process on the same host. It communicates over a
//! Unix domain socket, sharing file descriptors in the ancillary data of the messages. The
//! protocol defines two sides of the communication: frontend and backend. The frontend is the
//! application that shares its virtqueues; the backend is the consumer of the virtqueues. Either
//! side can act as the client (i.e. connecting) or the server (listening) in the socket
//! communication.
32
33use std::fs::File;
34use std::io::Error as IOError;
35use std::num::TryFromIntError;
36
37use remain::sorted;
38use thiserror::Error as ThisError;
39
40mod backend;
41pub use backend::*;
42
43pub mod message;
44pub use message::VHOST_USER_F_PROTOCOL_FEATURES;
45
46pub mod connection;
47
48mod sys;
49pub use connection::Connection;
50pub use message::BackendReq;
51pub use message::FrontendReq;
52#[cfg(unix)]
53pub use sys::unix;
54
55pub(crate) mod backend_client;
56pub use backend_client::BackendClient;
57mod frontend_server;
58pub use self::frontend_server::Frontend;
59mod backend_server;
60mod frontend_client;
61pub use self::backend_server::Backend;
62pub use self::backend_server::BackendServer;
63pub use self::frontend_client::FrontendClient;
64pub use self::frontend_server::FrontendServer;
65
66/// Errors for vhost-user operations
67#[sorted]
68#[derive(Debug, ThisError)]
69pub enum Error {
70    /// Failure from the backend side.
71    #[error("backend internal error")]
72    BackendInternalError,
73    /// client exited properly.
74    #[error("client exited properly")]
75    ClientExit,
76    /// client disconnected.
77    /// If connection is closed properly, use `ClientExit` instead.
78    #[error("client closed the connection")]
79    Disconnect,
80    #[error("Failed to enter suspended state")]
81    EnterSuspendedState(anyhow::Error),
82    /// Failure from the frontend side.
83    #[error("frontend Internal error")]
84    FrontendInternalError,
85    /// Fd array in question is too big or too small
86    #[error("wrong number of attached fds")]
87    IncorrectFds,
88    /// Invalid cast to int.
89    #[error("invalid cast to int: {0}")]
90    InvalidCastToInt(TryFromIntError),
91    /// Invalid message format, flag or content.
92    #[error("invalid message")]
93    InvalidMessage,
94    /// Unsupported operations due to that the protocol feature hasn't been negotiated.
95    #[error("invalid operation")]
96    InvalidOperation,
97    /// Invalid parameters.
98    #[error("invalid parameters: {0}")]
99    InvalidParam(&'static str),
100    /// Message is too large
101    #[error("oversized message")]
102    OversizedMsg,
103    /// Only part of a message have been sent or received successfully
104    #[error("partial message")]
105    PartialMessage,
106    /// Provided recv buffer was too small, and data was dropped.
107    #[error("buffer for recv was too small, data was dropped: got size {got}, needed {want}")]
108    RecvBufferTooSmall {
109        /// The size of the buffer received.
110        got: usize,
111        /// The expected size of the buffer.
112        want: usize,
113    },
114    /// Error from request handler
115    #[error("handler failed to handle request: {0}")]
116    ReqHandlerError(#[source] IOError),
117    /// Failure to restore.
118    #[error("Failed to restore")]
119    RestoreError(anyhow::Error),
120    /// Failure to snapshot.
121    #[error("Failed to snapshot")]
122    SnapshotError(anyhow::Error),
123    /// Generic socket errors.
124    #[error("socket error: {0}")]
125    SocketError(std::io::Error),
126    /// Should retry the socket operation again.
127    #[error("temporary socket error: {0}")]
128    SocketRetry(std::io::Error),
129    /// Error from tx/rx on a Tube.
130    #[error("failed to read/write on Tube: {0}")]
131    TubeError(base::TubeError),
132}
133
134/// Result of vhost-user operations
135pub type Result<T> = std::result::Result<T, Error>;
136
/// Shorthand for the outcome of a request handler, which reports failures as
/// [`std::io::Error`].
pub type HandlerResult<T> = std::io::Result<T>;
139
/// A shared memory region of a device, identified by `id` and sized by `length`.
///
/// The type is a small plain-data value: it is `Copy`, comparable for equality, and printable
/// with `{:?}` for diagnostics.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct SharedMemoryRegion {
    /// The id of the shared memory region. A device may have multiple regions, but each
    /// must have a unique id. The meaning of a particular region is device-specific.
    pub id: u8,
    /// The length of the region. NOTE(review): presumably in bytes — confirm against the code
    /// that maps the region.
    pub length: u64,
}
147
/// Collapses a vector of files into its only element.
///
/// Yields `Some(file)` when `files` holds exactly one file. When `files` is empty or holds more
/// than one file, yields `None` and every file in `files` is dropped.
pub(crate) fn into_single_file(files: Vec<File>) -> Option<File> {
    let mut remaining = files.into_iter();
    match (remaining.next(), remaining.next()) {
        (Some(only), None) => Some(only),
        _ => None,
    }
}
156
157#[cfg(test)]
158mod test_backend;
159
#[cfg(test)]
mod tests {
    use std::io::ErrorKind;
    use std::thread;

    use base::AsRawDescriptor;
    use tempfile::tempfile;

    use super::*;
    use crate::message::*;
    use crate::test_backend::TestBackend;
    use crate::test_backend::VIRTIO_FEATURES;
    use crate::VhostUserMemoryRegionInfo;
    use crate::VringConfigData;

    /// Creates a connected client/server pair over `Connection::pair()`, with `backend` serving
    /// requests on the server side.
    fn create_client_server_pair<S>(backend: S) -> (BackendClient, BackendServer<S>)
    where
        S: Backend,
    {
        let (client_connection, server_connection) = Connection::pair().unwrap();
        let backend_client = BackendClient::new(client_connection);
        (
            backend_client,
            BackendServer::<S>::new(server_connection, backend),
        )
    }

    /// Utility function to process a header and a message together.
    fn handle_request(h: &mut BackendServer<TestBackend>) -> Result<()> {
        // We assume that a header comes together with message body in tests so we don't wait before
        // calling `process_message()`.
        let (hdr, files) = h.recv_header()?;
        h.process_message(hdr, files)
    }

    #[test]
    fn create_test_backend() {
        let mut backend = TestBackend::new();

        backend.set_owner().unwrap();
        assert!(backend.set_owner().is_err());
    }

    #[test]
    fn test_set_owner() {
        let test_backend = TestBackend::new();
        let (backend_client, mut backend_server) = create_client_server_pair(test_backend);

        assert!(!backend_server.as_ref().owned);
        backend_client.set_owner().unwrap();
        handle_request(&mut backend_server).unwrap();
        assert!(backend_server.as_ref().owned);
        backend_client.set_owner().unwrap();
        assert!(handle_request(&mut backend_server).is_err());
        assert!(backend_server.as_ref().owned);
    }

    #[test]
    fn test_set_features() {
        let test_backend = TestBackend::new();
        let (mut backend_client, mut backend_server) = create_client_server_pair(test_backend);

        // The server thread is joined at the end of the test (rather than meeting the client at a
        // barrier) so that an assertion failure on the server side fails this test instead of
        // leaving the client blocked forever.
        let server_thread = thread::spawn(move || {
            // set_owner()
            handle_request(&mut backend_server).unwrap();
            assert!(backend_server.as_ref().owned);

            // get/set_features()
            handle_request(&mut backend_server).unwrap();
            handle_request(&mut backend_server).unwrap();
            assert_eq!(
                backend_server.as_ref().acked_features,
                VIRTIO_FEATURES & !0x1
            );

            // get/set_protocol_features()
            handle_request(&mut backend_server).unwrap();
            handle_request(&mut backend_server).unwrap();
            assert_eq!(
                backend_server.as_ref().acked_protocol_features,
                VhostUserProtocolFeatures::all().bits()
            );
        });

        backend_client.set_owner().unwrap();

        // set virtio features
        let features = backend_client.get_features().unwrap();
        assert_eq!(features, VIRTIO_FEATURES);
        backend_client.set_features(VIRTIO_FEATURES & !0x1).unwrap();

        // set vhost protocol features
        let features = backend_client.get_protocol_features().unwrap();
        assert_eq!(features.bits(), VhostUserProtocolFeatures::all().bits());
        backend_client.set_protocol_features(features).unwrap();

        server_thread.join().expect("backend server thread panicked");
    }

    #[test]
    fn test_client_server_process_no_need_reply() {
        test_client_server_process(false);
    }

    #[test]
    fn test_client_server_process_need_reply() {
        test_client_server_process(true);
    }

    /// Drives a full sequence of requests from the client to a server running on another thread.
    ///
    /// `set_need_reply` is forwarded to `BackendClient::set_need_reply`; it changes how the final,
    /// unimplemented `set_log_base` request surfaces its failure to the client.
    fn test_client_server_process(set_need_reply: bool) {
        let test_backend = TestBackend::new();
        let (mut backend_client, mut backend_server) = create_client_server_pair(test_backend);

        // Joined at the end of the test so that a panic in the server thread is reported as a
        // test failure rather than a hang.
        let server_thread = thread::spawn(move || {
            // set_owner()
            handle_request(&mut backend_server).unwrap();
            assert!(backend_server.as_ref().owned);

            // get/set_features()
            handle_request(&mut backend_server).unwrap();
            handle_request(&mut backend_server).unwrap();
            assert_eq!(
                backend_server.as_ref().acked_features,
                VIRTIO_FEATURES & !0x1
            );

            handle_request(&mut backend_server).unwrap();
            handle_request(&mut backend_server).unwrap();
            assert_eq!(
                backend_server.as_ref().acked_protocol_features,
                VhostUserProtocolFeatures::all().bits()
            );

            // get_inflight_fd()
            handle_request(&mut backend_server).unwrap();
            // set_inflight_fd()
            handle_request(&mut backend_server).unwrap();

            // get_queue_num()
            handle_request(&mut backend_server).unwrap();

            // set_mem_table()
            handle_request(&mut backend_server).unwrap();

            // get/set_config()
            handle_request(&mut backend_server).unwrap();
            handle_request(&mut backend_server).unwrap();

            // set_backend_req_fd
            handle_request(&mut backend_server).unwrap();

            // set_vring_enable
            handle_request(&mut backend_server).unwrap();

            // set_vring_xxx
            handle_request(&mut backend_server).unwrap();
            handle_request(&mut backend_server).unwrap();
            handle_request(&mut backend_server).unwrap();
            handle_request(&mut backend_server).unwrap();
            handle_request(&mut backend_server).unwrap();
            handle_request(&mut backend_server).unwrap();

            // get_max_mem_slots()
            handle_request(&mut backend_server).unwrap();

            // add_mem_region()
            handle_request(&mut backend_server).unwrap();

            // remove_mem_region()
            handle_request(&mut backend_server).unwrap();

            // set_log_base
            //
            // Results in an error because it isn't implemented. When `set_need_reply` is true, the
            // client waits for an ACK that will never come, instead they will get an error only
            // when we drop `backend_server` below.
            handle_request(&mut backend_server).unwrap_err();

            std::mem::drop(backend_server);
        });

        backend_client.set_owner().unwrap();

        // set virtio features
        let features = backend_client.get_features().unwrap();
        assert_eq!(features, VIRTIO_FEATURES);
        backend_client.set_features(VIRTIO_FEATURES & !0x1).unwrap();

        // set vhost protocol features
        let features = backend_client.get_protocol_features().unwrap();
        assert_eq!(features.bits(), VhostUserProtocolFeatures::all().bits());
        backend_client.set_protocol_features(features).unwrap();

        backend_client.set_need_reply(set_need_reply);

        // Retrieve inflight I/O tracking information
        let (inflight_info, inflight_file) = backend_client
            .get_inflight_fd(&VhostUserInflight {
                num_queues: 2,
                queue_size: 256,
                ..Default::default()
            })
            .unwrap();
        // Set the buffer back to the backend
        backend_client
            .set_inflight_fd(&inflight_info, inflight_file.as_raw_descriptor())
            .unwrap();

        let num = backend_client.get_queue_num().unwrap();
        assert_eq!(num, 2);

        let event = base::Event::new().unwrap();
        let mem = [VhostUserMemoryRegionInfo {
            guest_phys_addr: 0,
            memory_size: 0x10_0000,
            userspace_addr: 0,
            mmap_offset: 0,
            mmap_handle: event.as_raw_descriptor(),
        }];
        backend_client.set_mem_table(&mem).unwrap();

        backend_client
            .set_config(0x100, VhostUserConfigFlags::WRITABLE, &[0xa5u8])
            .unwrap();
        let buf = [0x0u8; 4];
        let (reply_body, reply_payload) = backend_client
            .get_config(0x100, 4, VhostUserConfigFlags::empty(), &buf)
            .unwrap();
        let offset = reply_body.offset;
        assert_eq!(offset, 0x100);
        assert_eq!(reply_payload[0], 0xa5);

        #[cfg(windows)]
        let tubes = base::Tube::pair().unwrap();
        #[cfg(windows)]
        let descriptor =
            // SAFETY:
            // Safe because we will be importing the Tube in the other thread.
            unsafe { tube_transporter::packed_tube::pack(tubes.0, std::process::id()).unwrap() };

        #[cfg(unix)]
        let descriptor = base::Event::new().unwrap();

        backend_client.set_backend_req_fd(&descriptor).unwrap();
        backend_client.set_vring_enable(0, true).unwrap();

        backend_client.set_vring_num(0, 256).unwrap();
        backend_client.set_vring_base(0, 0).unwrap();
        let config = VringConfigData {
            queue_size: 128,
            flags: VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits(),
            desc_table_addr: 0x1000,
            used_ring_addr: 0x2000,
            avail_ring_addr: 0x3000,
            log_addr: Some(0x4000),
        };
        backend_client.set_vring_addr(0, &config).unwrap();
        backend_client.set_vring_call(0, &event).unwrap();
        backend_client.set_vring_kick(0, &event).unwrap();
        backend_client.set_vring_err(0, &event).unwrap();

        let max_mem_slots = backend_client.get_max_mem_slots().unwrap();
        assert_eq!(max_mem_slots, 32);

        let region_file = tempfile().unwrap();
        let region = VhostUserMemoryRegionInfo {
            guest_phys_addr: 0x10_0000,
            memory_size: 0x10_0000,
            userspace_addr: 0,
            mmap_offset: 0,
            mmap_handle: region_file.as_raw_descriptor(),
        };
        backend_client.add_mem_region(&region).unwrap();

        backend_client.remove_mem_region(&region).unwrap();

        // set_log_base isn't implemented by the server and so will break the connection.
        let result = backend_client.set_log_base(0, Some(event.as_raw_descriptor()));
        if set_need_reply {
            // When using `set_need_reply`, we'll get an immediate disconnect error.
            assert!(
                matches!(result, Err(Error::Disconnect)),
                "unexpected result: {result:?}"
            );
        } else {
            // When not using `set_need_reply`, it will seem to succeed and then the next request
            // will fail.
            result.unwrap();
            let result = backend_client.get_features();
            match &result {
                // Windows errors with Disconnect and Unix with a SocketError.
                Err(Error::Disconnect) => {}
                Err(Error::SocketError(e))
                    if e.kind() == ErrorKind::ConnectionReset
                        || e.kind() == ErrorKind::BrokenPipe => {}
                _ => panic!("unexpected result: {result:?}"),
            }
        }

        server_thread.join().expect("backend server thread panicked");
    }

    #[test]
    fn test_error_display() {
        assert_eq!(
            format!("{}", Error::InvalidParam("")),
            "invalid parameters: "
        );
        assert_eq!(format!("{}", Error::InvalidOperation), "invalid operation");
    }
}