devices/virtio/vhost_user_backend/
handler.rs

1// Copyright 2021 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Library for implementing vhost-user device executables.
6//!
7//! This crate provides
8//! * `VhostUserDevice` trait, which is a collection of methods to handle vhost-user requests, and
9//! * `DeviceRequestHandler` struct, which makes a connection to a VMM and starts an event loop.
10//!
11//! They are expected to be used as follows:
12//!
13//! 1. Define a struct and implement `VhostUserDevice` for it.
14//! 2. Create a `DeviceRequestHandler` with the backend struct.
15//! 3. Drive the `DeviceRequestHandler::run` async fn with an executor.
16//!
17//! ```ignore
18//! struct MyBackend {
19//!   /* fields */
20//! }
21//!
22//! impl VhostUserDevice for MyBackend {
23//!   /* implement methods */
24//! }
25//!
26//! fn main() -> Result<(), Box<dyn Error>> {
27//!   let backend = MyBackend { /* initialize fields */ };
28//!   let handler = DeviceRequestHandler::new(backend);
//!   let socket = std::path::Path::new("/path/to/socket");
30//!   let ex = cros_async::Executor::new()?;
31//!
32//!   if let Err(e) = ex.run_until(handler.run(socket, &ex)) {
33//!     eprintln!("error happened: {}", e);
34//!   }
35//!   Ok(())
36//! }
37//! ```
38// Implementation note:
39// This code lets us take advantage of the vmm_vhost low level implementation of the vhost user
40// protocol. DeviceRequestHandler implements the Backend trait from vmm_vhost, and includes some
41// common code for setting up guest memory and managing partially configured vrings.
42// DeviceRequestHandler::run watches the vhost-user socket and then calls handle_request() when it
43// becomes readable. handle_request() reads and parses the message and then calls one of the
44// Backend trait methods. These dispatch back to the supplied VhostUserDevice implementation (this
45// is what our devices implement).
46
47pub(super) mod sys;
48
49use std::collections::BTreeMap;
50use std::convert::From;
51use std::fs::File;
52use std::io::BufReader;
53use std::io::Write;
54use std::num::Wrapping;
55#[cfg(any(target_os = "android", target_os = "linux"))]
56use std::os::unix::io::AsRawFd;
57use std::sync::Arc;
58
59use anyhow::bail;
60use anyhow::Context;
61#[cfg(any(target_os = "android", target_os = "linux"))]
62use base::clear_fd_flags;
63use base::error;
64use base::trace;
65use base::warn;
66use base::Event;
67use base::Protection;
68use base::SafeDescriptor;
69use base::SharedMemory;
70use base::WorkerThread;
71use cros_async::TaskHandle;
72use hypervisor::MemCacheType;
73use serde::Deserialize;
74use serde::Serialize;
75use snapshot::AnySnapshot;
76use sync::Mutex;
77use thiserror::Error as ThisError;
78use vm_control::VmMemorySource;
79use vm_memory::GuestAddress;
80use vm_memory::GuestMemory;
81use vm_memory::MemoryRegion;
82use vmm_vhost::message::VhostUserConfigFlags;
83use vmm_vhost::message::VhostUserExternalMapMsg;
84use vmm_vhost::message::VhostUserGpuMapMsg;
85use vmm_vhost::message::VhostUserInflight;
86use vmm_vhost::message::VhostUserMMap;
87use vmm_vhost::message::VhostUserMMapFlags;
88use vmm_vhost::message::VhostUserMemoryRegion;
89use vmm_vhost::message::VhostUserMigrationPhase;
90use vmm_vhost::message::VhostUserProtocolFeatures;
91use vmm_vhost::message::VhostUserSingleMemoryRegion;
92use vmm_vhost::message::VhostUserTransferDirection;
93use vmm_vhost::message::VhostUserVringAddrFlags;
94use vmm_vhost::message::VhostUserVringState;
95use vmm_vhost::Connection;
96use vmm_vhost::Error as VhostError;
97use vmm_vhost::Frontend;
98use vmm_vhost::FrontendClient;
99use vmm_vhost::Result as VhostResult;
100use vmm_vhost::VHOST_USER_F_PROTOCOL_FEATURES;
101
102use crate::virtio::Interrupt;
103use crate::virtio::Queue;
104use crate::virtio::QueueConfig;
105use crate::virtio::SharedMemoryMapper;
106use crate::virtio::SharedMemoryRegion;
107
/// One contiguous range of the VMM's virtual address space together with the guest-physical
/// range that backs it.
///
/// A list of these is used to turn VMM virtual addresses found in vhost-user messages into guest
/// addresses (see `vmm_va_to_gpa`).
#[derive(Default)]
pub struct MappingInfo {
    /// First VMM virtual address covered by this mapping.
    pub vmm_addr: u64,
    /// Guest-physical address that `vmm_addr` translates to.
    pub guest_phys: u64,
    /// Length of the mapped range, in bytes.
    pub size: u64,
}
116
117pub fn vmm_va_to_gpa(maps: &[MappingInfo], vmm_va: u64) -> VhostResult<GuestAddress> {
118    for map in maps {
119        if vmm_va >= map.vmm_addr && vmm_va < map.vmm_addr + map.size {
120            return Ok(GuestAddress(vmm_va - map.vmm_addr + map.guest_phys));
121        }
122    }
123    Err(VhostError::InvalidMessage)
124}
125
126/// Trait for vhost-user devices. Analogous to the `VirtioDevice` trait.
127///
128/// In contrast with [[vmm_vhost::Backend]], which closely matches the vhost-user spec, this trait
129/// is designed to follow crosvm conventions for implementing devices.
130pub trait VhostUserDevice {
131    /// The maximum number of queues that this backend can manage.
132    fn max_queue_num(&self) -> usize;
133
134    /// The set of feature bits that this backend supports.
135    fn features(&self) -> u64;
136
137    /// Acknowledges that this set of features should be enabled.
138    ///
139    /// Implementations only need to handle device-specific feature bits; the `DeviceRequestHandler`
140    /// framework will manage generic vhost and vring features.
141    ///
142    /// `DeviceRequestHandler` checks for valid features before calling this function, so the
143    /// features in `value` will always be a subset of those advertised by `features()`.
144    fn ack_features(&mut self, _value: u64) -> anyhow::Result<()> {
145        Ok(())
146    }
147
148    /// The set of protocol feature bits that this backend supports.
149    fn protocol_features(&self) -> VhostUserProtocolFeatures;
150
151    /// Reads this device configuration space at `offset`.
152    fn read_config(&self, offset: u64, dst: &mut [u8]);
153
154    /// writes `data` to this device's configuration space at `offset`.
155    fn write_config(&self, _offset: u64, _data: &[u8]) {}
156
157    /// Indicates that the backend should start processing requests for virtio queue number `idx`.
158    /// This method must not block the current thread so device backends should either spawn an
159    /// async task or another thread to handle messages from the Queue.
160    fn start_queue(&mut self, idx: usize, queue: Queue, mem: GuestMemory) -> anyhow::Result<()>;
161
162    /// Indicates that the backend should stop processing requests for virtio queue number `idx`.
163    /// This method should return the queue passed to `start_queue` for the corresponding `idx`.
164    /// This method will only be called for queues that were previously started by `start_queue`.
165    fn stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue>;
166
167    /// Resets the vhost-user backend.
168    fn reset(&mut self);
169
170    /// Returns the device's shared memory region if present.
171    fn get_shared_memory_region(&self) -> Option<SharedMemoryRegion> {
172        None
173    }
174
175    /// Accepts `VhostBackendReqConnection` to conduct Vhost backend to frontend message
176    /// handling.
177    ///
178    /// This method will be called when `VhostUserProtocolFeatures::BACKEND_REQ` is
179    /// negotiated.
180    fn set_backend_req_connection(&mut self, _conn: VhostBackendReqConnection) {}
181
182    /// Enter the "suspended device state" described in the vhost-user spec. See the spec for
183    /// requirements.
184    ///
185    /// One reasonably foolproof way to satisfy the requirements is to stop all worker threads.
186    ///
187    /// Called after a `stop_queue` call if there are no running queues left. Also called soon
188    /// after device creation to ensure the device is acting suspended immediately on construction.
189    ///
190    /// The next `start_queue` call implicitly exits the "suspend device state".
191    ///
192    /// * Ok(())    => device successfully suspended
193    /// * Err(_)    => unrecoverable error
194    fn enter_suspended_state(&mut self) -> anyhow::Result<()>;
195
196    /// Snapshot device and return serialized state.
197    fn snapshot(&mut self) -> anyhow::Result<AnySnapshot>;
198
199    /// Restore device state from a snapshot.
200    fn restore(&mut self, data: AnySnapshot) -> anyhow::Result<()>;
201
202    /// Whether guest memory should be unmapped in forked processes.
203    ///
204    /// This is intended for use in combination with --protected-vm, where the guest memory can be
205    /// dangerous to access. Some systems, e.g. Android, have tools that fork processes and examine
206    /// their memory. This flag effectively hides the guest memory from those tools.
207    ///
208    /// Not compatible with sandboxing.
209    fn unmap_guest_memory_on_fork(&self) -> bool {
210        false
211    }
212}
213
214/// A virtio ring entry.
215struct Vring {
216    // The queue config. This doesn't get mutated by the queue workers.
217    queue: QueueConfig,
218    doorbell: Option<Interrupt>,
219    enabled: bool,
220}
221
222impl Vring {
223    fn new(max_size: u16, features: u64) -> Self {
224        Self {
225            queue: QueueConfig::new(max_size, features),
226            doorbell: None,
227            enabled: false,
228        }
229    }
230
231    fn reset(&mut self) {
232        self.queue.reset();
233        self.doorbell = None;
234        self.enabled = false;
235    }
236}
237
238/// Ops for running vhost-user over a stream (i.e. regular protocol).
239pub(super) struct VhostUserRegularOps;
240
241impl VhostUserRegularOps {
242    pub fn set_mem_table(
243        contexts: &[VhostUserMemoryRegion],
244        files: Vec<File>,
245    ) -> VhostResult<(GuestMemory, Vec<MappingInfo>)> {
246        if files.len() != contexts.len() {
247            return Err(VhostError::InvalidParam(
248                "number of files & contexts was not equal",
249            ));
250        }
251
252        let mut regions = Vec::with_capacity(files.len());
253        for (region, file) in contexts.iter().zip(files.into_iter()) {
254            let region = MemoryRegion::new_from_shm(
255                region.memory_size,
256                GuestAddress(region.guest_phys_addr),
257                region.mmap_offset,
258                Arc::new(
259                    SharedMemory::from_safe_descriptor(
260                        SafeDescriptor::from(file),
261                        region.memory_size,
262                    )
263                    .unwrap(),
264                ),
265            )
266            .map_err(|e| {
267                error!("failed to create a memory region: {}", e);
268                VhostError::InvalidOperation
269            })?;
270            regions.push(region);
271        }
272        let guest_mem = GuestMemory::from_regions(regions).map_err(|e| {
273            error!("failed to create guest memory: {}", e);
274            VhostError::InvalidOperation
275        })?;
276
277        let vmm_maps = contexts
278            .iter()
279            .map(|region| MappingInfo {
280                vmm_addr: region.user_addr,
281                guest_phys: region.guest_phys_addr,
282                size: region.memory_size,
283            })
284            .collect();
285        Ok((guest_mem, vmm_maps))
286    }
287}
288
289/// An adapter that implements `vmm_vhost::Backend` for any type implementing `VhostUserDevice`.
290pub struct DeviceRequestHandler<T: VhostUserDevice> {
291    vrings: Vec<Vring>,
292    owned: bool,
293    vmm_maps: Option<Vec<MappingInfo>>,
294    mem: Option<GuestMemory>,
295    acked_features: u64,
296    acked_protocol_features: VhostUserProtocolFeatures,
297    backend: T,
298    backend_req_connection: Option<VhostBackendReqConnection>,
299    // Thread processing active device state FD.
300    device_state_thread: Option<DeviceStateThread>,
301}
302
303enum DeviceStateThread {
304    Save(WorkerThread<Result<(), ciborium::ser::Error<std::io::Error>>>),
305    Load(WorkerThread<Result<DeviceRequestHandlerSnapshot, ciborium::de::Error<std::io::Error>>>),
306}
307
308#[derive(Serialize, Deserialize)]
309pub struct DeviceRequestHandlerSnapshot {
310    acked_features: u64,
311    acked_protocol_features: u64,
312    backend: AnySnapshot,
313}
314
315impl<T: VhostUserDevice> DeviceRequestHandler<T> {
316    /// Creates a vhost-user handler instance for `backend`.
317    pub(crate) fn new(mut backend: T) -> Self {
318        let mut vrings = Vec::with_capacity(backend.max_queue_num());
319        for _ in 0..backend.max_queue_num() {
320            vrings.push(Vring::new(Queue::MAX_SIZE, backend.features()));
321        }
322
323        // VhostUserDevice implementations must support `enter_suspended_state()`.
324        // Call it on startup to ensure it works and to initialize the device in a suspended state.
325        backend
326            .enter_suspended_state()
327            .expect("enter_suspended_state failed on device init");
328
329        DeviceRequestHandler {
330            vrings,
331            owned: false,
332            vmm_maps: None,
333            mem: None,
334            acked_features: 0,
335            acked_protocol_features: VhostUserProtocolFeatures::empty(),
336            backend,
337            backend_req_connection: None,
338            device_state_thread: None,
339        }
340    }
341
342    /// Check if all queues are stopped.
343    ///
344    /// The device can be suspended with `enter_suspended_state()` only when all queues are stopped.
345    fn all_queues_stopped(&self) -> bool {
346        self.vrings.iter().all(|vring| !vring.queue.ready())
347    }
348}
349
350impl<T: VhostUserDevice> Drop for DeviceRequestHandler<T> {
351    fn drop(&mut self) {
352        for (index, vring) in self.vrings.iter().enumerate() {
353            if vring.queue.ready() {
354                if let Err(e) = self.backend.stop_queue(index) {
355                    error!("Failed to stop queue {} during drop: {:#}", index, e);
356                }
357            }
358        }
359    }
360}
361
362impl<T: VhostUserDevice> AsRef<T> for DeviceRequestHandler<T> {
363    fn as_ref(&self) -> &T {
364        &self.backend
365    }
366}
367
368impl<T: VhostUserDevice> AsMut<T> for DeviceRequestHandler<T> {
369    fn as_mut(&mut self) -> &mut T {
370        &mut self.backend
371    }
372}
373
374impl<T: VhostUserDevice> vmm_vhost::Backend for DeviceRequestHandler<T> {
375    fn set_owner(&mut self) -> VhostResult<()> {
376        if self.owned {
377            return Err(VhostError::InvalidOperation);
378        }
379        self.owned = true;
380        Ok(())
381    }
382
383    fn reset_owner(&mut self) -> VhostResult<()> {
384        self.owned = false;
385        self.acked_features = 0;
386        self.backend.reset();
387        Ok(())
388    }
389
390    fn get_features(&mut self) -> VhostResult<u64> {
391        let features = self.backend.features();
392        Ok(features)
393    }
394
395    fn set_features(&mut self, features: u64) -> VhostResult<()> {
396        if !self.owned {
397            return Err(VhostError::InvalidOperation);
398        }
399
400        let unexpected_features = features & !self.backend.features();
401        if unexpected_features != 0 {
402            error!("unexpected set_features {:#x}", unexpected_features);
403            return Err(VhostError::InvalidParam("unexpected set_features"));
404        }
405
406        if let Err(e) = self.backend.ack_features(features) {
407            error!("failed to acknowledge features 0x{:x}: {}", features, e);
408            return Err(VhostError::InvalidOperation);
409        }
410
411        self.acked_features |= features;
412
413        // If VHOST_USER_F_PROTOCOL_FEATURES has not been negotiated, the ring is initialized in an
414        // enabled state.
415        // If VHOST_USER_F_PROTOCOL_FEATURES has been negotiated, the ring is initialized in a
416        // disabled state.
417        // Client must not pass data to/from the backend until ring is enabled by
418        // VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has been disabled by
419        // VHOST_USER_SET_VRING_ENABLE with parameter 0.
420        let vring_enabled = self.acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES != 0;
421        for v in &mut self.vrings {
422            v.enabled = vring_enabled;
423        }
424
425        Ok(())
426    }
427
428    fn get_protocol_features(&mut self) -> VhostResult<VhostUserProtocolFeatures> {
429        Ok(self.backend.protocol_features() | VhostUserProtocolFeatures::REPLY_ACK)
430    }
431
432    fn set_protocol_features(&mut self, features: u64) -> VhostResult<()> {
433        let features = match VhostUserProtocolFeatures::from_bits(features) {
434            Some(proto_features) => proto_features,
435            None => {
436                error!(
437                    "unsupported bits in VHOST_USER_SET_PROTOCOL_FEATURES: {:#x}",
438                    features
439                );
440                return Err(VhostError::InvalidOperation);
441            }
442        };
443        let supported = self.get_protocol_features()?;
444        self.acked_protocol_features = features & supported;
445        Ok(())
446    }
447
448    fn set_mem_table(
449        &mut self,
450        contexts: &[VhostUserMemoryRegion],
451        files: Vec<File>,
452    ) -> VhostResult<()> {
453        let (guest_mem, vmm_maps) = VhostUserRegularOps::set_mem_table(contexts, files)?;
454        if self.backend.unmap_guest_memory_on_fork() {
455            #[cfg(any(target_os = "android", target_os = "linux"))]
456            if let Err(e) = guest_mem.use_dontfork() {
457                error!("failed to set MADV_DONTFORK on guest memory: {e:#}");
458            }
459            #[cfg(not(any(target_os = "android", target_os = "linux")))]
460            error!("unmap_guest_memory_on_fork unsupported; skipping");
461        }
462        self.mem = Some(guest_mem);
463        self.vmm_maps = Some(vmm_maps);
464        Ok(())
465    }
466
467    fn get_queue_num(&mut self) -> VhostResult<u64> {
468        Ok(self.vrings.len() as u64)
469    }
470
471    fn set_vring_num(&mut self, index: u32, num: u32) -> VhostResult<()> {
472        if index as usize >= self.vrings.len() || num == 0 || num > Queue::MAX_SIZE.into() {
473            return Err(VhostError::InvalidParam(
474                "set_vring_num: invalid index or num",
475            ));
476        }
477        self.vrings[index as usize].queue.set_size(num as u16);
478
479        Ok(())
480    }
481
482    fn set_vring_addr(
483        &mut self,
484        index: u32,
485        _flags: VhostUserVringAddrFlags,
486        descriptor: u64,
487        used: u64,
488        available: u64,
489        _log: u64,
490    ) -> VhostResult<()> {
491        if index as usize >= self.vrings.len() {
492            return Err(VhostError::InvalidParam(
493                "set_vring_addr: index out of range",
494            ));
495        }
496
497        let vmm_maps = self
498            .vmm_maps
499            .as_ref()
500            .ok_or(VhostError::InvalidParam("set_vring_addr: missing vmm_maps"))?;
501        let vring = &mut self.vrings[index as usize];
502        vring
503            .queue
504            .set_desc_table(vmm_va_to_gpa(vmm_maps, descriptor)?);
505        vring
506            .queue
507            .set_avail_ring(vmm_va_to_gpa(vmm_maps, available)?);
508        vring.queue.set_used_ring(vmm_va_to_gpa(vmm_maps, used)?);
509
510        Ok(())
511    }
512
513    fn set_vring_base(&mut self, index: u32, base: u32) -> VhostResult<()> {
514        if index as usize >= self.vrings.len() {
515            return Err(VhostError::InvalidParam(
516                "set_vring_base: index out of range",
517            ));
518        }
519
520        let vring = &mut self.vrings[index as usize];
521        vring.queue.set_next_avail(Wrapping(base as u16));
522        vring.queue.set_next_used(Wrapping(base as u16));
523
524        Ok(())
525    }
526
527    fn get_vring_base(&mut self, index: u32) -> VhostResult<VhostUserVringState> {
528        let vring = self
529            .vrings
530            .get_mut(index as usize)
531            .ok_or(VhostError::InvalidParam(
532                "get_vring_base: index out of range",
533            ))?;
534
535        // Quotation from vhost-user spec:
536        // "The back-end must [...] stop ring upon receiving VHOST_USER_GET_VRING_BASE."
537        // We only call `queue.set_ready()` when starting the queue, so if the queue is ready, that
538        // means it is started and should be stopped.
539        let vring_base = if vring.queue.ready() {
540            let queue = match self.backend.stop_queue(index as usize) {
541                Ok(q) => q,
542                Err(e) => {
543                    error!("Failed to stop queue in get_vring_base: {:#}", e);
544                    return Err(VhostError::BackendInternalError);
545                }
546            };
547
548            trace!("stopped queue {index}");
549            vring.reset();
550
551            if self.all_queues_stopped() {
552                trace!("all queues stopped; entering suspended state");
553                self.backend
554                    .enter_suspended_state()
555                    .map_err(VhostError::EnterSuspendedState)?;
556            }
557
558            queue.next_avail_to_process()
559        } else {
560            0
561        };
562
563        Ok(VhostUserVringState::new(index, vring_base.into()))
564    }
565
566    fn set_vring_kick(&mut self, index: u8, file: Option<File>) -> VhostResult<()> {
567        if index as usize >= self.vrings.len() {
568            return Err(VhostError::InvalidParam(
569                "set_vring_kick: index out of range",
570            ));
571        }
572
573        let vring = &mut self.vrings[index as usize];
574        if vring.queue.ready() {
575            error!("kick fd cannot replaced after queue is started");
576            return Err(VhostError::InvalidOperation);
577        }
578
579        let file = file.ok_or(VhostError::InvalidParam("missing file for set_vring_kick"))?;
580
581        // Remove O_NONBLOCK from kick_fd. Otherwise, uring_executor will fails when we read
582        // values via `next_val()` later.
583        // This is only required (and can only be done) on Unix platforms.
584        #[cfg(any(target_os = "android", target_os = "linux"))]
585        if let Err(e) = clear_fd_flags(file.as_raw_fd(), libc::O_NONBLOCK) {
586            error!("failed to remove O_NONBLOCK for kick fd: {}", e);
587            return Err(VhostError::InvalidParam(
588                "could not remove O_NONBLOCK from vring_kick",
589            ));
590        }
591
592        let kick_evt = Event::from(SafeDescriptor::from(file));
593
594        // Enable any virtqueue features that were negotiated (like VIRTIO_RING_F_EVENT_IDX).
595        vring.queue.ack_features(self.acked_features);
596        vring.queue.set_ready(true);
597
598        let mem = self
599            .mem
600            .as_ref()
601            .cloned()
602            .ok_or(VhostError::InvalidOperation)?;
603
604        let doorbell = vring.doorbell.clone().ok_or(VhostError::InvalidOperation)?;
605
606        let queue = match vring.queue.activate(&mem, kick_evt, doorbell) {
607            Ok(queue) => queue,
608            Err(e) => {
609                error!("failed to activate vring: {:#}", e);
610                return Err(VhostError::BackendInternalError);
611            }
612        };
613
614        if let Err(e) = self.backend.start_queue(index as usize, queue, mem) {
615            error!("Failed to start queue {}: {}", index, e);
616            return Err(VhostError::BackendInternalError);
617        }
618        trace!("started queue {index}");
619
620        Ok(())
621    }
622
623    fn set_vring_call(&mut self, index: u8, file: Option<File>) -> VhostResult<()> {
624        if index as usize >= self.vrings.len() {
625            return Err(VhostError::InvalidParam(
626                "set_vring_call: index out of range",
627            ));
628        }
629
630        let backend_req_conn = self.backend_req_connection.clone();
631        let signal_config_change_fn = Box::new(move || {
632            if let Some(frontend) = backend_req_conn.as_ref() {
633                if let Err(e) = frontend.send_config_changed() {
634                    error!("Failed to notify config change: {:#}", e);
635                }
636            } else {
637                error!("No Backend request connection found");
638            }
639        });
640
641        let file = file.ok_or(VhostError::InvalidParam("missing file for set_vring_call"))?;
642        self.vrings[index as usize].doorbell = Some(Interrupt::new_vhost_user(
643            Event::from(SafeDescriptor::from(file)),
644            signal_config_change_fn,
645        ));
646        Ok(())
647    }
648
649    fn set_vring_err(&mut self, _index: u8, _fd: Option<File>) -> VhostResult<()> {
650        // TODO
651        Ok(())
652    }
653
654    fn set_vring_enable(&mut self, index: u32, enable: bool) -> VhostResult<()> {
655        if index as usize >= self.vrings.len() {
656            return Err(VhostError::InvalidParam(
657                "set_vring_enable: index out of range",
658            ));
659        }
660
661        // This request should be handled only when VHOST_USER_F_PROTOCOL_FEATURES
662        // has been negotiated.
663        if self.acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES == 0 {
664            return Err(VhostError::InvalidOperation);
665        }
666
667        // Backend must not pass data to/from the ring until ring is enabled by
668        // VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has been disabled by
669        // VHOST_USER_SET_VRING_ENABLE with parameter 0.
670        self.vrings[index as usize].enabled = enable;
671
672        Ok(())
673    }
674
675    fn get_config(
676        &mut self,
677        offset: u32,
678        size: u32,
679        _flags: VhostUserConfigFlags,
680    ) -> VhostResult<Vec<u8>> {
681        let mut data = vec![0; size as usize];
682        self.backend.read_config(u64::from(offset), &mut data);
683        Ok(data)
684    }
685
686    fn set_config(
687        &mut self,
688        offset: u32,
689        buf: &[u8],
690        _flags: VhostUserConfigFlags,
691    ) -> VhostResult<()> {
692        self.backend.write_config(u64::from(offset), buf);
693        Ok(())
694    }
695
696    fn set_backend_req_fd(&mut self, ep: Connection) {
697        let conn = VhostBackendReqConnection::new(
698            FrontendClient::new(
699                ep,
700                self.acked_protocol_features
701                    .contains(VhostUserProtocolFeatures::REPLY_ACK),
702            ),
703            self.backend.get_shared_memory_region().map(|r| r.id),
704        );
705
706        if self.backend_req_connection.is_some() {
707            warn!("Backend Request Connection already established. Overwriting");
708        }
709        self.backend_req_connection = Some(conn.clone());
710
711        self.backend.set_backend_req_connection(conn);
712    }
713
714    fn get_inflight_fd(
715        &mut self,
716        _inflight: &VhostUserInflight,
717    ) -> VhostResult<(VhostUserInflight, File)> {
718        unimplemented!("get_inflight_fd");
719    }
720
721    fn set_inflight_fd(&mut self, _inflight: &VhostUserInflight, _file: File) -> VhostResult<()> {
722        unimplemented!("set_inflight_fd");
723    }
724
725    fn get_max_mem_slots(&mut self) -> VhostResult<u64> {
726        //TODO
727        Ok(0)
728    }
729
730    fn add_mem_region(
731        &mut self,
732        _region: &VhostUserSingleMemoryRegion,
733        _fd: File,
734    ) -> VhostResult<()> {
735        //TODO
736        Ok(())
737    }
738
739    fn remove_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion) -> VhostResult<()> {
740        //TODO
741        Ok(())
742    }
743
744    fn set_device_state_fd(
745        &mut self,
746        transfer_direction: VhostUserTransferDirection,
747        migration_phase: VhostUserMigrationPhase,
748        fd: File,
749    ) -> VhostResult<Option<File>> {
750        if migration_phase != VhostUserMigrationPhase::Stopped {
751            return Err(VhostError::InvalidOperation);
752        }
753        if !self.all_queues_stopped() {
754            return Err(VhostError::InvalidOperation);
755        }
756        if self.device_state_thread.is_some() {
757            error!("must call check_device_state before starting new state transfer");
758            return Err(VhostError::InvalidOperation);
759        }
760        // `set_device_state_fd` is designed to allow snapshot/restore concurrently with other
761        // methods, but, for simplicitly, we do those operations inline and only spawn a thread to
762        // handle the serialization and data transfer (the latter which seems necessary to
763        // implement the API correctly without, e.g., deadlocking because a pipe is full).
764        match transfer_direction {
765            VhostUserTransferDirection::Save => {
766                // Snapshot the state.
767                let snapshot = DeviceRequestHandlerSnapshot {
768                    acked_features: self.acked_features,
769                    acked_protocol_features: self.acked_protocol_features.bits(),
770                    backend: self.backend.snapshot().map_err(VhostError::SnapshotError)?,
771                };
772                // Spawn thread to write the serialized bytes.
773                self.device_state_thread = Some(DeviceStateThread::Save(WorkerThread::start(
774                    "device_state_save",
775                    move |_kill_event| -> Result<(), ciborium::ser::Error<std::io::Error>> {
776                        let mut w = std::io::BufWriter::new(fd);
777                        ciborium::into_writer(&snapshot, &mut w)?;
778                        w.flush()?;
779                        Ok(())
780                    },
781                )));
782                Ok(None)
783            }
784            VhostUserTransferDirection::Load => {
785                // Spawn a thread to read the bytes and deserialize. Restore will happen in
786                // `check_device_state`.
787                self.device_state_thread = Some(DeviceStateThread::Load(WorkerThread::start(
788                    "device_state_load",
789                    move |_kill_event| ciborium::from_reader(&mut BufReader::new(fd)),
790                )));
791                Ok(None)
792            }
793        }
794    }
795
796    fn check_device_state(&mut self) -> VhostResult<()> {
797        let Some(thread) = self.device_state_thread.take() else {
798            error!("check_device_state: no active state transfer");
799            return Err(VhostError::InvalidOperation);
800        };
801        match thread {
802            DeviceStateThread::Save(worker) => {
803                worker.stop().map_err(|e| {
804                    error!("device state save thread failed: {:#}", e);
805                    VhostError::BackendInternalError
806                })?;
807                Ok(())
808            }
809            DeviceStateThread::Load(worker) => {
810                let snapshot = worker.stop().map_err(|e| {
811                    error!("device state load thread failed: {:#}", e);
812                    VhostError::BackendInternalError
813                })?;
814                self.acked_features = snapshot.acked_features;
815                self.acked_protocol_features =
816                    VhostUserProtocolFeatures::from_bits(snapshot.acked_protocol_features)
817                        .with_context(|| {
818                            format!(
819                                "unsupported bits in acked_protocol_features: {:#x}",
820                                snapshot.acked_protocol_features
821                            )
822                        })
823                        .map_err(VhostError::RestoreError)?;
824                self.backend
825                    .restore(snapshot.backend)
826                    .map_err(VhostError::RestoreError)?;
827                Ok(())
828            }
829        }
830    }
831
832    fn get_shmem_config(&mut self) -> VhostResult<Vec<SharedMemoryRegion>> {
833        Ok(self
834            .backend
835            .get_shared_memory_region()
836            .into_iter()
837            .collect())
838    }
839}
840
/// Keeps track of Vhost user backend request connection.
///
/// Cloning is cheap. All clones share one underlying connection and one mapping table
/// through the `Arc<Mutex<..>>`.
#[derive(Clone)]
pub struct VhostBackendReqConnection {
    // Connection plus mapping bookkeeping. Also shared with every `VhostShmemMapper`
    // handed out by `shmem_mapper`.
    shared: Arc<Mutex<VhostBackendReqConnectionShared>>,
    // Shared memory region id used for map/unmap requests. When `None`, `shmem_mapper`
    // returns no mapper.
    shmid: Option<u8>,
}
847
/// Mutable state shared by a `VhostBackendReqConnection` and the mappers created from it.
struct VhostBackendReqConnectionShared {
    // Client used for backend-initiated requests to the frontend (config change
    // notification, GPU/external/shmem map and shmem unmap).
    conn: FrontendClient,
    // Mappings currently established in the shared memory region. `remove_mapping` looks
    // up the size here to build its unmap request.
    mapped_regions: BTreeMap<u64 /* offset */, u64 /* size */>,
}
852
853impl VhostBackendReqConnection {
854    fn new(conn: FrontendClient, shmid: Option<u8>) -> Self {
855        Self {
856            shared: Arc::new(Mutex::new(VhostBackendReqConnectionShared {
857                conn,
858                mapped_regions: BTreeMap::new(),
859            })),
860            shmid,
861        }
862    }
863
864    /// Send `VHOST_USER_CONFIG_CHANGE_MSG` to the frontend
865    fn send_config_changed(&self) -> anyhow::Result<()> {
866        let mut shared = self.shared.lock();
867        shared
868            .conn
869            .handle_config_change()
870            .context("Could not send config change message")?;
871        Ok(())
872    }
873
874    /// Create a SharedMemoryMapper trait object using this backend request connection.
875    pub fn shmem_mapper(&self) -> Option<Box<dyn SharedMemoryMapper>> {
876        if let Some(shmid) = self.shmid {
877            Some(Box::new(VhostShmemMapper {
878                shared: self.shared.clone(),
879                shmid,
880            }))
881        } else {
882            None
883        }
884    }
885}
886
/// `SharedMemoryMapper` that forwards map/unmap requests to the vhost-user frontend over
/// the backend request connection.
#[derive(Clone)]
struct VhostShmemMapper {
    // State shared with the `VhostBackendReqConnection` that created this mapper.
    shared: Arc<Mutex<VhostBackendReqConnectionShared>>,
    // Id of the shared memory region that every mapping is placed in.
    shmid: u8,
}
892
893impl SharedMemoryMapper for VhostShmemMapper {
894    fn add_mapping(
895        &mut self,
896        source: VmMemorySource,
897        offset: u64,
898        prot: Protection,
899        _cache: MemCacheType,
900    ) -> anyhow::Result<()> {
901        let mut shared = self.shared.lock();
902        let size = match source {
903            VmMemorySource::Vulkan {
904                descriptor,
905                handle_type,
906                memory_idx,
907                device_uuid,
908                driver_uuid,
909                size,
910            } => {
911                let msg = VhostUserGpuMapMsg::new(
912                    self.shmid,
913                    offset,
914                    size,
915                    memory_idx,
916                    handle_type,
917                    device_uuid,
918                    driver_uuid,
919                );
920                shared
921                    .conn
922                    .gpu_map(&msg, &descriptor)
923                    .context("map GPU memory")?;
924                size
925            }
926            VmMemorySource::ExternalMapping { ptr, size } => {
927                let msg = VhostUserExternalMapMsg::new(self.shmid, offset, size, ptr);
928                shared
929                    .conn
930                    .external_map(&msg)
931                    .context("create external mapping")?;
932                size
933            }
934            source => {
935                // The last two sources use the same VhostUserMMap, continue matching here
936                // on the aliased `source` above.
937                let (descriptor, fd_offset, size) = match source {
938                    VmMemorySource::Descriptor {
939                        descriptor,
940                        offset,
941                        size,
942                    } => (descriptor, offset, size),
943                    VmMemorySource::SharedMemory(shmem) => {
944                        let size = shmem.size();
945                        let descriptor = SafeDescriptor::from(shmem);
946                        (descriptor, 0, size)
947                    }
948                    _ => bail!("unsupported source"),
949                };
950                let mut flags = VhostUserMMapFlags::empty();
951                anyhow::ensure!(prot.allows(&Protection::read()), "mapping must be readable");
952                if prot.allows(&Protection::write()) {
953                    flags |= VhostUserMMapFlags::MAP_RW;
954                }
955                let msg = VhostUserMMap {
956                    shmid: self.shmid,
957                    padding: Default::default(),
958                    fd_offset,
959                    shm_offset: offset,
960                    len: size,
961                    flags,
962                };
963                shared
964                    .conn
965                    .shmem_map(&msg, &descriptor)
966                    .context("map shmem")?;
967                size
968            }
969        };
970
971        shared.mapped_regions.insert(offset, size);
972        Ok(())
973    }
974
975    fn remove_mapping(&mut self, offset: u64) -> anyhow::Result<()> {
976        let mut shared = self.shared.lock();
977        let size = shared
978            .mapped_regions
979            .remove(&offset)
980            .context("unknown offset")?;
981        let msg = VhostUserMMap {
982            shmid: self.shmid,
983            padding: Default::default(),
984            fd_offset: 0,
985            shm_offset: offset,
986            len: size,
987            flags: VhostUserMMapFlags::empty(),
988        };
989        shared
990            .conn
991            .shmem_unmap(&msg)
992            .context("unmap shmem")
993            .map(|_| ())
994    }
995}
996
/// Bookkeeping for a queue worker: the handle of its spawned task plus an associated queue
/// value.
// NOTE(review): both type parameters are generic here. `T` is presumably the queue (or a
// wrapper around it) and `U` the task's output; confirm against the users of this struct.
pub(crate) struct WorkerState<T, U> {
    // Handle of the spawned task; the task's output type is `U`.
    pub(crate) queue_task: TaskHandle<U>,
    // Queue value kept alongside the task.
    pub(crate) queue: T,
}
1001
/// Errors for device operations
#[derive(Debug, ThisError)]
pub enum Error {
    /// A queue was asked to stop, but no worker/queue is active for it.
    #[error("worker not found when stopping queue")]
    WorkerNotFound,
}
1008
1009#[cfg(test)]
1010mod tests {
1011    use std::sync::mpsc::channel;
1012
1013    use anyhow::bail;
1014    use base::Event;
1015    use virtio_sys::virtio_ring::VIRTIO_RING_F_EVENT_IDX;
1016    use vmm_vhost::BackendServer;
1017    use vmm_vhost::FrontendReq;
1018    use zerocopy::FromBytes;
1019    use zerocopy::FromZeros;
1020    use zerocopy::Immutable;
1021    use zerocopy::IntoBytes;
1022    use zerocopy::KnownLayout;
1023
1024    use super::*;
1025    use crate::virtio::vhost_user_frontend::VhostUserFrontend;
1026    use crate::virtio::DeviceType;
1027    use crate::virtio::VirtioDevice;
1028
    /// Config space served by `FakeBackend::read_config`. The test compares it byte-for-byte
    /// with what the frontend reads back.
    ///
    /// `packed(4)` drops the padding that `repr(C)` would otherwise insert before the `u64`.
    /// This keeps the struct free of padding bytes, as the `IntoBytes` derive requires.
    #[derive(Clone, Copy, Debug, PartialEq, Eq, FromBytes, Immutable, IntoBytes, KnownLayout)]
    #[repr(C, packed(4))]
    struct FakeConfig {
        x: u32,
        y: u64,
    }

    /// Config contents the frontend is expected to read back.
    const FAKE_CONFIG_DATA: FakeConfig = FakeConfig { x: 1, y: 2 };
1037
    /// Minimal `VhostUserDevice` used to drive `DeviceRequestHandler` in tests.
    pub(super) struct FakeBackend {
        // Feature bits offered to the frontend through `features()`.
        avail_features: u64,
        // Union of every feature value accepted by `ack_features()`.
        acked_features: u64,
        // Queues handed over by `start_queue`, indexed by queue number; `None` when stopped.
        active_queues: Vec<Option<Queue>>,
        // When true, `protocol_features()` also advertises `BACKEND_REQ`.
        allow_backend_req: bool,
        // Connection received through `set_backend_req_connection`, if any.
        backend_conn: Option<VhostBackendReqConnection>,
    }

    /// Serialized `FakeBackend` state produced by `snapshot` and consumed by `restore`.
    #[derive(Deserialize, Serialize)]
    struct FakeBackendSnapshot {
        // Fixed payload; `restore` asserts it round-trips unchanged.
        data: Vec<u8>,
    }
1050
1051    impl FakeBackend {
1052        const MAX_QUEUE_NUM: usize = 16;
1053
1054        pub(super) fn new() -> Self {
1055            let mut active_queues = Vec::new();
1056            active_queues.resize_with(Self::MAX_QUEUE_NUM, Default::default);
1057            Self {
1058                avail_features: 1 << VHOST_USER_F_PROTOCOL_FEATURES | 1 << VIRTIO_RING_F_EVENT_IDX,
1059                acked_features: 0,
1060                active_queues,
1061                allow_backend_req: false,
1062                backend_conn: None,
1063            }
1064        }
1065    }
1066
1067    impl VhostUserDevice for FakeBackend {
1068        fn max_queue_num(&self) -> usize {
1069            Self::MAX_QUEUE_NUM
1070        }
1071
1072        fn features(&self) -> u64 {
1073            self.avail_features
1074        }
1075
1076        fn ack_features(&mut self, value: u64) -> anyhow::Result<()> {
1077            let unrequested_features = value & !self.avail_features;
1078            if unrequested_features != 0 {
1079                bail!(
1080                    "invalid protocol features are given: 0x{:x}",
1081                    unrequested_features
1082                );
1083            }
1084            self.acked_features |= value;
1085            Ok(())
1086        }
1087
1088        fn protocol_features(&self) -> VhostUserProtocolFeatures {
1089            let mut features =
1090                VhostUserProtocolFeatures::CONFIG | VhostUserProtocolFeatures::DEVICE_STATE;
1091            if self.allow_backend_req {
1092                features |= VhostUserProtocolFeatures::BACKEND_REQ;
1093            }
1094            features
1095        }
1096
1097        fn read_config(&self, offset: u64, dst: &mut [u8]) {
1098            dst.copy_from_slice(&FAKE_CONFIG_DATA.as_bytes()[offset as usize..]);
1099        }
1100
1101        fn reset(&mut self) {}
1102
1103        fn start_queue(
1104            &mut self,
1105            idx: usize,
1106            queue: Queue,
1107            _mem: GuestMemory,
1108        ) -> anyhow::Result<()> {
1109            self.active_queues[idx] = Some(queue);
1110            Ok(())
1111        }
1112
1113        fn stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue> {
1114            Ok(self.active_queues[idx]
1115                .take()
1116                .ok_or(Error::WorkerNotFound)?)
1117        }
1118
1119        fn set_backend_req_connection(&mut self, conn: VhostBackendReqConnection) {
1120            self.backend_conn = Some(conn);
1121        }
1122
1123        fn enter_suspended_state(&mut self) -> anyhow::Result<()> {
1124            Ok(())
1125        }
1126
1127        fn snapshot(&mut self) -> anyhow::Result<AnySnapshot> {
1128            AnySnapshot::to_any(FakeBackendSnapshot {
1129                data: vec![1, 2, 3],
1130            })
1131            .context("failed to serialize snapshot")
1132        }
1133
1134        fn restore(&mut self, data: AnySnapshot) -> anyhow::Result<()> {
1135            let snapshot: FakeBackendSnapshot =
1136                AnySnapshot::from_any(data).context("failed to deserialize snapshot")?;
1137            assert_eq!(snapshot.data, vec![1, 2, 3], "bad snapshot data");
1138            Ok(())
1139        }
1140    }
1141
1142    fn create_queues(
1143        num: usize,
1144        mem: &GuestMemory,
1145        interrupt: &Interrupt,
1146    ) -> BTreeMap<usize, Queue> {
1147        let mut queues = BTreeMap::new();
1148        for idx in 0..num {
1149            let mut queue = QueueConfig::new(0x10, 0);
1150            queue.set_ready(true);
1151            let queue = queue
1152                .activate(mem, Event::new().unwrap(), interrupt.clone())
1153                .expect("QueueConfig::activate");
1154            queues.insert(idx, queue);
1155        }
1156        queues
1157    }
1158
    /// Runs the lifecycle test without negotiating the backend request channel.
    #[test]
    fn test_vhost_user_lifecycle() {
        test_vhost_user_lifecycle_parameterized(false);
    }

    /// Runs the lifecycle test with `BACKEND_REQ` negotiated. The backend then receives a
    /// `VhostBackendReqConnection` and uses it to send config-change notifications.
    #[test]
    #[cfg(not(windows))] // Windows requires more complex connection setup.
    fn test_vhost_user_lifecycle_with_backend_req() {
        test_vhost_user_lifecycle_parameterized(true);
    }
1169
    /// End-to-end exercise of `DeviceRequestHandler<FakeBackend>` against a real
    /// `VhostUserFrontend` connected through `vmm_vhost::Connection::pair()`.
    ///
    /// The frontend ("VMM side") runs on a spawned thread. The device side runs on this thread
    /// and handles each expected frontend request explicitly via `handle_request`, which
    /// asserts the request type. The order of messages emitted by each frontend operation is
    /// therefore also checked.
    ///
    /// Phase one covers config read, activate, reset, re-activate, sleep, snapshot and wake,
    /// and yields the snapshot. Phase two uses a fresh connection to restore that snapshot
    /// and wake the device.
    ///
    /// When `allow_backend_req` is true, the backend request channel is negotiated, and
    /// `send_config_changed` is checked to still work after reset, sleep/wake and restore.
    fn test_vhost_user_lifecycle_parameterized(allow_backend_req: bool) {
        const QUEUES_NUM: usize = 2;
        // Features the frontend is constructed with and acks.
        const BASE_FEATURES: u64 = 1 << VIRTIO_RING_F_EVENT_IDX;
        // Acked features the device side must report after SET_FEATURES: BASE_FEATURES plus
        // VHOST_USER_F_PROTOCOL_FEATURES.
        const EXPECTED_FEATURES: u64 =
            1 << VHOST_USER_F_PROTOCOL_FEATURES | 1 << VIRTIO_RING_F_EVENT_IDX;

        // First phase: Test normal usage, then take a snapshot and shutdown.
        let snapshot = {
            let (client_connection, server_connection) = vmm_vhost::Connection::pair().unwrap();
            let (shutdown_tx, shutdown_rx) = channel();
            let (vm_evt_wrtube, _vm_evt_rdtube) = base::Tube::directional_pair().unwrap();
            let vmm_thread = std::thread::spawn(move || {
                // VMM side
                let mut vmm_device = VhostUserFrontend::new(
                    DeviceType::Console,
                    BASE_FEATURES,
                    client_connection,
                    vm_evt_wrtube,
                    None,
                    None,
                )
                .unwrap();

                vmm_device.ack_features(BASE_FEATURES);

                let mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
                let interrupt = Interrupt::new_for_test_with_msix();

                println!("read_config");
                let mut config = FakeConfig::new_zeroed();
                vmm_device.read_config(0, config.as_mut_bytes());
                // Check if the obtained config data is correct.
                assert_eq!(config, FAKE_CONFIG_DATA);

                println!("activate");
                vmm_device
                    .activate(
                        mem.clone(),
                        interrupt.clone(),
                        create_queues(QUEUES_NUM, &mem, &interrupt),
                    )
                    .unwrap();

                println!("reset");
                let reset_result = vmm_device.reset();
                assert!(
                    reset_result.is_ok(),
                    "reset failed: {:#}",
                    reset_result.unwrap_err()
                );

                println!("activate");
                vmm_device
                    .activate(
                        mem.clone(),
                        interrupt.clone(),
                        create_queues(QUEUES_NUM, &mem, &interrupt),
                    )
                    .unwrap();

                println!("virtio_sleep");
                let queues = vmm_device
                    .virtio_sleep()
                    .unwrap()
                    .expect("virtio_sleep unexpectedly returned None");

                println!("virtio_snapshot");
                let snapshot = vmm_device
                    .virtio_snapshot()
                    .expect("virtio_snapshot failed");

                println!("virtio_wake");
                vmm_device
                    .virtio_wake(Some((mem.clone(), interrupt.clone(), queues)))
                    .unwrap();

                println!("wait for shutdown signal");
                shutdown_rx.recv().unwrap();

                // The VMM side is supposed to stop before the device side.
                println!("drop");

                // Hand the snapshot back to the main thread for the restore phase.
                snapshot
            });

            // Device side
            let mut handler = DeviceRequestHandler::new(FakeBackend::new());
            handler.as_mut().allow_backend_req = allow_backend_req;

            let mut req_handler = BackendServer::new(server_connection, handler);

            // Each `handle_request` below consumes exactly one frontend message and asserts
            // its type, in the order the frontend operation named in the comment sends them.

            // VhostUserFrontend::new()
            handle_request(&mut req_handler, FrontendReq::SET_OWNER).unwrap();
            handle_request(&mut req_handler, FrontendReq::GET_FEATURES).unwrap();
            handle_request(&mut req_handler, FrontendReq::GET_PROTOCOL_FEATURES).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_PROTOCOL_FEATURES).unwrap();
            if allow_backend_req {
                handle_request(&mut req_handler, FrontendReq::SET_BACKEND_REQ_FD).unwrap();
            }

            // VhostUserFrontend::read_config()
            handle_request(&mut req_handler, FrontendReq::GET_CONFIG).unwrap();

            // VhostUserFrontend::activate()
            handle_request(&mut req_handler, FrontendReq::SET_FEATURES).unwrap();
            assert_eq!(req_handler.as_ref().acked_features, EXPECTED_FEATURES);
            handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();
            for _ in 0..QUEUES_NUM {
                handle_request(&mut req_handler, FrontendReq::SET_VRING_NUM).unwrap();
                handle_request(&mut req_handler, FrontendReq::SET_VRING_ADDR).unwrap();
                handle_request(&mut req_handler, FrontendReq::SET_VRING_BASE).unwrap();
                handle_request(&mut req_handler, FrontendReq::SET_VRING_CALL).unwrap();
                handle_request(&mut req_handler, FrontendReq::SET_VRING_KICK).unwrap();
                handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
            }

            // VhostUserFrontend::reset()
            for _ in 0..QUEUES_NUM {
                handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
                handle_request(&mut req_handler, FrontendReq::GET_VRING_BASE).unwrap();
            }

            // VhostUserFrontend::activate()
            handle_request(&mut req_handler, FrontendReq::SET_FEATURES).unwrap();
            assert_eq!(req_handler.as_ref().acked_features, EXPECTED_FEATURES);
            handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();
            for _ in 0..QUEUES_NUM {
                handle_request(&mut req_handler, FrontendReq::SET_VRING_NUM).unwrap();
                handle_request(&mut req_handler, FrontendReq::SET_VRING_ADDR).unwrap();
                handle_request(&mut req_handler, FrontendReq::SET_VRING_BASE).unwrap();
                handle_request(&mut req_handler, FrontendReq::SET_VRING_CALL).unwrap();
                handle_request(&mut req_handler, FrontendReq::SET_VRING_KICK).unwrap();
                handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
            }

            if allow_backend_req {
                // Make sure the connection still works even after reset/reactivate.
                req_handler
                    .as_ref()
                    .as_ref()
                    .backend_conn
                    .as_ref()
                    .expect("backend_conn missing")
                    .send_config_changed()
                    .expect("send_config_changed failed");
            }

            // VhostUserFrontend::virtio_sleep()
            for _ in 0..QUEUES_NUM {
                handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
                handle_request(&mut req_handler, FrontendReq::GET_VRING_BASE).unwrap();
            }

            // VhostUserFrontend::virtio_snapshot()
            handle_request(&mut req_handler, FrontendReq::SET_DEVICE_STATE_FD).unwrap();
            handle_request(&mut req_handler, FrontendReq::CHECK_DEVICE_STATE).unwrap();

            // VhostUserFrontend::virtio_wake()
            handle_request(&mut req_handler, FrontendReq::SET_FEATURES).unwrap();
            assert_eq!(req_handler.as_ref().acked_features, EXPECTED_FEATURES);
            handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();
            for _ in 0..QUEUES_NUM {
                handle_request(&mut req_handler, FrontendReq::SET_VRING_NUM).unwrap();
                handle_request(&mut req_handler, FrontendReq::SET_VRING_ADDR).unwrap();
                handle_request(&mut req_handler, FrontendReq::SET_VRING_BASE).unwrap();
                handle_request(&mut req_handler, FrontendReq::SET_VRING_CALL).unwrap();
                handle_request(&mut req_handler, FrontendReq::SET_VRING_KICK).unwrap();
                handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
            }

            if allow_backend_req {
                // Make sure the connection still works even after sleep/wake.
                req_handler
                    .as_ref()
                    .as_ref()
                    .backend_conn
                    .as_ref()
                    .expect("backend_conn missing")
                    .send_config_changed()
                    .expect("send_config_changed failed");
            }

            // Ask the client to shutdown, then wait for it to finish.
            shutdown_tx.send(()).unwrap();

            // Verify recv_header fails with `ClientExit` after the client has disconnected.
            match req_handler.recv_header() {
                Err(VhostError::ClientExit) => (),
                r => panic!("expected Err(ClientExit) but got {r:?}"),
            }

            // The VMM thread's return value is the phase-one snapshot.
            vmm_thread.join().unwrap()
        };

        // Second phase: Restore the snapshot.
        {
            let (client_connection, server_connection) = vmm_vhost::Connection::pair().unwrap();
            let (shutdown_tx, shutdown_rx) = channel();
            let (vm_evt_wrtube, _vm_evt_rdtube) = base::Tube::directional_pair().unwrap();
            let vmm_thread = std::thread::spawn(move || {
                // VMM side
                let mut vmm_device = VhostUserFrontend::new(
                    DeviceType::Console,
                    BASE_FEATURES,
                    client_connection,
                    vm_evt_wrtube,
                    None,
                    None,
                )
                .unwrap();

                let mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
                let interrupt = Interrupt::new_for_test_with_msix();

                // The device was never activated on this connection, so sleeping yields no
                // queues.
                println!("virtio_sleep");
                assert!(vmm_device.virtio_sleep().unwrap().is_none());

                println!("virtio_restore");
                vmm_device
                    .virtio_restore(snapshot)
                    .expect("virtio_restore failed");

                println!("virtio_wake");
                vmm_device
                    .virtio_wake(Some((
                        mem.clone(),
                        interrupt.clone(),
                        create_queues(QUEUES_NUM, &mem, &interrupt),
                    )))
                    .unwrap();

                println!("wait for shutdown signal");
                shutdown_rx.recv().unwrap();

                // The VMM side is supposed to stop before the device side.
                println!("drop");
            });

            // Device side
            let mut handler = DeviceRequestHandler::new(FakeBackend::new());
            handler.as_mut().allow_backend_req = allow_backend_req;

            let mut req_handler = BackendServer::new(server_connection, handler);

            // VhostUserFrontend::new()
            handle_request(&mut req_handler, FrontendReq::SET_OWNER).unwrap();
            handle_request(&mut req_handler, FrontendReq::GET_FEATURES).unwrap();
            handle_request(&mut req_handler, FrontendReq::GET_PROTOCOL_FEATURES).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_PROTOCOL_FEATURES).unwrap();
            if allow_backend_req {
                handle_request(&mut req_handler, FrontendReq::SET_BACKEND_REQ_FD).unwrap();
            }

            // VhostUserFrontend::virtio_sleep()
            // (no-op)

            // VhostUserFrontend::virtio_restore()
            // Two SET_FEATURES messages arrive here; the asserts pin the device's acked
            // features after each one.
            handle_request(&mut req_handler, FrontendReq::SET_FEATURES).unwrap();
            assert_eq!(
                req_handler.as_ref().acked_features,
                1 << VHOST_USER_F_PROTOCOL_FEATURES
            );
            handle_request(&mut req_handler, FrontendReq::SET_FEATURES).unwrap();
            assert_eq!(req_handler.as_ref().acked_features, EXPECTED_FEATURES);
            handle_request(&mut req_handler, FrontendReq::SET_DEVICE_STATE_FD).unwrap();
            handle_request(&mut req_handler, FrontendReq::CHECK_DEVICE_STATE).unwrap();

            // VhostUserFrontend::virtio_wake()
            handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();
            for _ in 0..QUEUES_NUM {
                handle_request(&mut req_handler, FrontendReq::SET_VRING_NUM).unwrap();
                handle_request(&mut req_handler, FrontendReq::SET_VRING_ADDR).unwrap();
                handle_request(&mut req_handler, FrontendReq::SET_VRING_BASE).unwrap();
                handle_request(&mut req_handler, FrontendReq::SET_VRING_CALL).unwrap();
                handle_request(&mut req_handler, FrontendReq::SET_VRING_KICK).unwrap();
                handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
            }

            if allow_backend_req {
                // Make sure the connection still works even after restore.
                req_handler
                    .as_ref()
                    .as_ref()
                    .backend_conn
                    .as_ref()
                    .expect("backend_conn missing")
                    .send_config_changed()
                    .expect("send_config_changed failed");
            }

            // Ask the client to shutdown, then wait for it to finish.
            shutdown_tx.send(()).unwrap();
            // Verify recv_header fails with `ClientExit` after the client has disconnected.
            match req_handler.recv_header() {
                Err(VhostError::ClientExit) => (),
                r => panic!("expected Err(ClientExit) but got {r:?}"),
            }
            vmm_thread.join().unwrap();
        }
    }
1470
1471    #[track_caller]
1472    fn handle_request<S: vmm_vhost::Backend>(
1473        handler: &mut BackendServer<S>,
1474        expected_message_type: FrontendReq,
1475    ) -> Result<(), VhostError> {
1476        let (hdr, files) = handler.recv_header()?;
1477        assert_eq!(hdr.get_code(), Ok(expected_message_type));
1478        handler.process_message(hdr, files)
1479    }
1480}