devices/virtio/vhost_user_backend/
handler.rs

1// Copyright 2021 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Library for implementing vhost-user device executables.
6//!
7//! This crate provides
8//! * `VhostUserDevice` trait, which is a collection of methods to handle vhost-user requests, and
9//! * `DeviceRequestHandler` struct, which makes a connection to a VMM and starts an event loop.
10//!
11//! They are expected to be used as follows:
12//!
13//! 1. Define a struct and implement `VhostUserDevice` for it.
14//! 2. Create a `DeviceRequestHandler` with the backend struct.
15//! 3. Drive the `DeviceRequestHandler::run` async fn with an executor.
16//!
17//! ```ignore
18//! struct MyBackend {
19//!   /* fields */
20//! }
21//!
22//! impl VhostUserDevice for MyBackend {
23//!   /* implement methods */
24//! }
25//!
26//! fn main() -> Result<(), Box<dyn Error>> {
27//!   let backend = MyBackend { /* initialize fields */ };
28//!   let handler = DeviceRequestHandler::new(backend);
//!   let socket = std::path::Path::new("/path/to/socket");
30//!   let ex = cros_async::Executor::new()?;
31//!
32//!   if let Err(e) = ex.run_until(handler.run(socket, &ex)) {
33//!     eprintln!("error happened: {}", e);
34//!   }
35//!   Ok(())
36//! }
37//! ```
38// Implementation note:
39// This code lets us take advantage of the vmm_vhost low level implementation of the vhost user
40// protocol. DeviceRequestHandler implements the Backend trait from vmm_vhost, and includes some
41// common code for setting up guest memory and managing partially configured vrings.
42// DeviceRequestHandler::run watches the vhost-user socket and then calls handle_request() when it
43// becomes readable. handle_request() reads and parses the message and then calls one of the
44// Backend trait methods. These dispatch back to the supplied VhostUserDevice implementation (this
45// is what our devices implement).
46
47pub(super) mod sys;
48
49use std::collections::BTreeMap;
50use std::convert::From;
51use std::fs::File;
52use std::io::BufReader;
53use std::io::Write;
54use std::num::Wrapping;
55#[cfg(any(target_os = "android", target_os = "linux"))]
56use std::os::unix::io::AsRawFd;
57use std::sync::Arc;
58
59use anyhow::bail;
60use anyhow::Context;
61#[cfg(any(target_os = "android", target_os = "linux"))]
62use base::clear_fd_flags;
63use base::error;
64use base::trace;
65use base::warn;
66use base::Event;
67use base::Protection;
68use base::SafeDescriptor;
69use base::SharedMemory;
70use base::WorkerThread;
71use cros_async::TaskHandle;
72use hypervisor::MemCacheType;
73use serde::Deserialize;
74use serde::Serialize;
75use snapshot::AnySnapshot;
76use sync::Mutex;
77use thiserror::Error as ThisError;
78use vm_control::VmMemorySource;
79use vm_memory::GuestAddress;
80use vm_memory::GuestMemory;
81use vm_memory::MemoryRegion;
82use vmm_vhost::message::VhostUserConfigFlags;
83use vmm_vhost::message::VhostUserExternalMapMsg;
84use vmm_vhost::message::VhostUserGpuMapMsg;
85use vmm_vhost::message::VhostUserInflight;
86use vmm_vhost::message::VhostUserMMap;
87use vmm_vhost::message::VhostUserMMapFlags;
88use vmm_vhost::message::VhostUserMemoryRegion;
89use vmm_vhost::message::VhostUserMigrationPhase;
90use vmm_vhost::message::VhostUserProtocolFeatures;
91use vmm_vhost::message::VhostUserSingleMemoryRegion;
92use vmm_vhost::message::VhostUserTransferDirection;
93use vmm_vhost::message::VhostUserVringAddrFlags;
94use vmm_vhost::message::VhostUserVringState;
95use vmm_vhost::Connection;
96use vmm_vhost::Error as VhostError;
97use vmm_vhost::Frontend;
98use vmm_vhost::FrontendClient;
99use vmm_vhost::Result as VhostResult;
100use vmm_vhost::VHOST_USER_F_PROTOCOL_FEATURES;
101
102use crate::virtio::Interrupt;
103use crate::virtio::Queue;
104use crate::virtio::QueueConfig;
105use crate::virtio::SharedMemoryMapper;
106use crate::virtio::SharedMemoryRegion;
107
/// One translation entry between the VMM's (frontend's) virtual address space and guest physical
/// memory. Lists of these are used to translate addresses found in vhost-user messages into guest
/// addresses.
#[derive(Default)]
pub struct MappingInfo {
    /// First VMM virtual address covered by this entry.
    pub vmm_addr: u64,
    /// Guest physical address corresponding to `vmm_addr`.
    pub guest_phys: u64,
    /// Length of the covered range, in bytes.
    pub size: u64,
}
116
117pub fn vmm_va_to_gpa(maps: &[MappingInfo], vmm_va: u64) -> VhostResult<GuestAddress> {
118    for map in maps {
119        if vmm_va >= map.vmm_addr && vmm_va < map.vmm_addr + map.size {
120            return Ok(GuestAddress(vmm_va - map.vmm_addr + map.guest_phys));
121        }
122    }
123    Err(VhostError::InvalidMessage)
124}
125
126/// Trait for vhost-user devices. Analogous to the `VirtioDevice` trait.
127///
128/// In contrast with [[vmm_vhost::Backend]], which closely matches the vhost-user spec, this trait
129/// is designed to follow crosvm conventions for implementing devices.
130pub trait VhostUserDevice {
131    /// The maximum number of queues that this backend can manage.
132    fn max_queue_num(&self) -> usize;
133
134    /// The set of feature bits that this backend supports.
135    fn features(&self) -> u64;
136
137    /// Acknowledges that this set of features should be enabled.
138    ///
139    /// Implementations only need to handle device-specific feature bits; the `DeviceRequestHandler`
140    /// framework will manage generic vhost and vring features.
141    ///
142    /// `DeviceRequestHandler` checks for valid features before calling this function, so the
143    /// features in `value` will always be a subset of those advertised by `features()`.
144    fn ack_features(&mut self, _value: u64) -> anyhow::Result<()> {
145        Ok(())
146    }
147
148    /// The set of protocol feature bits that this backend supports.
149    fn protocol_features(&self) -> VhostUserProtocolFeatures;
150
151    /// Reads this device configuration space at `offset`.
152    fn read_config(&self, offset: u64, dst: &mut [u8]);
153
154    /// writes `data` to this device's configuration space at `offset`.
155    fn write_config(&self, _offset: u64, _data: &[u8]) {}
156
157    /// Indicates that the backend should start processing requests for virtio queue number `idx`.
158    /// This method must not block the current thread so device backends should either spawn an
159    /// async task or another thread to handle messages from the Queue.
160    fn start_queue(&mut self, idx: usize, queue: Queue, mem: GuestMemory) -> anyhow::Result<()>;
161
162    /// Indicates that the backend should stop processing requests for virtio queue number `idx`.
163    /// This method should return the queue passed to `start_queue` for the corresponding `idx`.
164    /// This method will only be called for queues that were previously started by `start_queue`.
165    fn stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue>;
166
167    /// Resets the vhost-user backend.
168    fn reset(&mut self);
169
170    /// Returns the device's shared memory region if present.
171    fn get_shared_memory_region(&self) -> Option<SharedMemoryRegion> {
172        None
173    }
174
175    /// Accepts `VhostBackendReqConnection` to conduct Vhost backend to frontend message
176    /// handling.
177    ///
178    /// This method will be called when `VhostUserProtocolFeatures::BACKEND_REQ` is
179    /// negotiated.
180    fn set_backend_req_connection(&mut self, _conn: VhostBackendReqConnection) {}
181
182    /// Enter the "suspended device state" described in the vhost-user spec. See the spec for
183    /// requirements.
184    ///
185    /// One reasonably foolproof way to satisfy the requirements is to stop all worker threads.
186    ///
187    /// Called after a `stop_queue` call if there are no running queues left. Also called soon
188    /// after device creation to ensure the device is acting suspended immediately on construction.
189    ///
190    /// The next `start_queue` call implicitly exits the "suspend device state".
191    ///
192    /// * Ok(())    => device successfully suspended
193    /// * Err(_)    => unrecoverable error
194    fn enter_suspended_state(&mut self) -> anyhow::Result<()>;
195
196    /// Snapshot device and return serialized state.
197    fn snapshot(&mut self) -> anyhow::Result<AnySnapshot>;
198
199    /// Restore device state from a snapshot.
200    fn restore(&mut self, data: AnySnapshot) -> anyhow::Result<()>;
201
202    /// Whether guest memory should be unmapped in forked processes.
203    ///
204    /// This is intended for use in combination with --protected-vm, where the guest memory can be
205    /// dangerous to access. Some systems, e.g. Android, have tools that fork processes and examine
206    /// their memory. This flag effectively hides the guest memory from those tools.
207    ///
208    /// Not compatible with sandboxing.
209    fn unmap_guest_memory_on_fork(&self) -> bool {
210        false
211    }
212}
213
214/// A virtio ring entry.
215struct Vring {
216    // The queue config. This doesn't get mutated by the queue workers.
217    queue: QueueConfig,
218    doorbell: Option<Interrupt>,
219    enabled: bool,
220}
221
222impl Vring {
223    fn new(max_size: u16, features: u64) -> Self {
224        Self {
225            queue: QueueConfig::new(max_size, features),
226            doorbell: None,
227            enabled: false,
228        }
229    }
230
231    fn reset(&mut self) {
232        self.queue.reset();
233        self.doorbell = None;
234        self.enabled = false;
235    }
236}
237
238/// Ops for running vhost-user over a stream (i.e. regular protocol).
239pub(super) struct VhostUserRegularOps;
240
241impl VhostUserRegularOps {
242    pub fn set_mem_table(
243        contexts: &[VhostUserMemoryRegion],
244        files: Vec<File>,
245    ) -> VhostResult<(GuestMemory, Vec<MappingInfo>)> {
246        if files.len() != contexts.len() {
247            return Err(VhostError::InvalidParam(
248                "number of files & contexts was not equal",
249            ));
250        }
251
252        let mut regions = Vec::with_capacity(files.len());
253        for (region, file) in contexts.iter().zip(files.into_iter()) {
254            let region = MemoryRegion::new_from_shm(
255                region.memory_size,
256                GuestAddress(region.guest_phys_addr),
257                region.mmap_offset,
258                Arc::new(
259                    SharedMemory::from_safe_descriptor(
260                        SafeDescriptor::from(file),
261                        region.memory_size,
262                    )
263                    .unwrap(),
264                ),
265            )
266            .map_err(|e| {
267                error!("failed to create a memory region: {}", e);
268                VhostError::InvalidOperation
269            })?;
270            regions.push(region);
271        }
272        let guest_mem = GuestMemory::from_regions(regions).map_err(|e| {
273            error!("failed to create guest memory: {}", e);
274            VhostError::InvalidOperation
275        })?;
276
277        let vmm_maps = contexts
278            .iter()
279            .map(|region| MappingInfo {
280                vmm_addr: region.user_addr,
281                guest_phys: region.guest_phys_addr,
282                size: region.memory_size,
283            })
284            .collect();
285        Ok((guest_mem, vmm_maps))
286    }
287}
288
289/// An adapter that implements `vmm_vhost::Backend` for any type implementing `VhostUserDevice`.
290pub struct DeviceRequestHandler<T: VhostUserDevice> {
291    vrings: Vec<Vring>,
292    owned: bool,
293    vmm_maps: Option<Vec<MappingInfo>>,
294    mem: Option<GuestMemory>,
295    acked_features: u64,
296    acked_protocol_features: VhostUserProtocolFeatures,
297    backend: T,
298    backend_req_connection: Option<VhostBackendReqConnection>,
299    // Thread processing active device state FD.
300    device_state_thread: Option<DeviceStateThread>,
301}
302
303enum DeviceStateThread {
304    Save(WorkerThread<Result<(), ciborium::ser::Error<std::io::Error>>>),
305    Load(WorkerThread<Result<DeviceRequestHandlerSnapshot, ciborium::de::Error<std::io::Error>>>),
306}
307
308#[derive(Serialize, Deserialize)]
309pub struct DeviceRequestHandlerSnapshot {
310    acked_features: u64,
311    acked_protocol_features: u64,
312    backend: AnySnapshot,
313}
314
315impl<T: VhostUserDevice> DeviceRequestHandler<T> {
316    /// Creates a vhost-user handler instance for `backend`.
317    pub(crate) fn new(mut backend: T) -> Self {
318        let mut vrings = Vec::with_capacity(backend.max_queue_num());
319        for _ in 0..backend.max_queue_num() {
320            vrings.push(Vring::new(Queue::MAX_SIZE, backend.features()));
321        }
322
323        // VhostUserDevice implementations must support `enter_suspended_state()`.
324        // Call it on startup to ensure it works and to initialize the device in a suspended state.
325        backend
326            .enter_suspended_state()
327            .expect("enter_suspended_state failed on device init");
328
329        DeviceRequestHandler {
330            vrings,
331            owned: false,
332            vmm_maps: None,
333            mem: None,
334            acked_features: 0,
335            acked_protocol_features: VhostUserProtocolFeatures::empty(),
336            backend,
337            backend_req_connection: None,
338            device_state_thread: None,
339        }
340    }
341
342    /// Check if all queues are stopped.
343    ///
344    /// The device can be suspended with `enter_suspended_state()` only when all queues are stopped.
345    fn all_queues_stopped(&self) -> bool {
346        self.vrings.iter().all(|vring| !vring.queue.ready())
347    }
348}
349
350impl<T: VhostUserDevice> Drop for DeviceRequestHandler<T> {
351    fn drop(&mut self) {
352        for (index, vring) in self.vrings.iter().enumerate() {
353            if vring.queue.ready() {
354                if let Err(e) = self.backend.stop_queue(index) {
355                    error!("Failed to stop queue {} during drop: {:#}", index, e);
356                }
357            }
358        }
359    }
360}
361
362impl<T: VhostUserDevice> AsRef<T> for DeviceRequestHandler<T> {
363    fn as_ref(&self) -> &T {
364        &self.backend
365    }
366}
367
368impl<T: VhostUserDevice> AsMut<T> for DeviceRequestHandler<T> {
369    fn as_mut(&mut self) -> &mut T {
370        &mut self.backend
371    }
372}
373
374impl<T: VhostUserDevice> vmm_vhost::Backend for DeviceRequestHandler<T> {
375    fn set_owner(&mut self) -> VhostResult<()> {
376        if self.owned {
377            return Err(VhostError::InvalidOperation);
378        }
379        self.owned = true;
380        Ok(())
381    }
382
383    fn reset_owner(&mut self) -> VhostResult<()> {
384        self.owned = false;
385        self.acked_features = 0;
386        self.backend.reset();
387        Ok(())
388    }
389
390    fn get_features(&mut self) -> VhostResult<u64> {
391        let features = self.backend.features();
392        Ok(features)
393    }
394
395    fn set_features(&mut self, features: u64) -> VhostResult<()> {
396        if !self.owned {
397            return Err(VhostError::InvalidOperation);
398        }
399
400        let unexpected_features = features & !self.backend.features();
401        if unexpected_features != 0 {
402            error!("unexpected set_features {:#x}", unexpected_features);
403            return Err(VhostError::InvalidParam("unexpected set_features"));
404        }
405
406        if let Err(e) = self.backend.ack_features(features) {
407            error!("failed to acknowledge features 0x{:x}: {}", features, e);
408            return Err(VhostError::InvalidOperation);
409        }
410
411        self.acked_features |= features;
412
413        // If VHOST_USER_F_PROTOCOL_FEATURES has not been negotiated, the ring is initialized in an
414        // enabled state.
415        // If VHOST_USER_F_PROTOCOL_FEATURES has been negotiated, the ring is initialized in a
416        // disabled state.
417        // Client must not pass data to/from the backend until ring is enabled by
418        // VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has been disabled by
419        // VHOST_USER_SET_VRING_ENABLE with parameter 0.
420        let vring_enabled = self.acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES != 0;
421        for v in &mut self.vrings {
422            v.enabled = vring_enabled;
423        }
424
425        Ok(())
426    }
427
428    fn get_protocol_features(&mut self) -> VhostResult<VhostUserProtocolFeatures> {
429        Ok(self.backend.protocol_features() | VhostUserProtocolFeatures::REPLY_ACK)
430    }
431
432    fn set_protocol_features(&mut self, features: u64) -> VhostResult<()> {
433        let features = match VhostUserProtocolFeatures::from_bits(features) {
434            Some(proto_features) => proto_features,
435            None => {
436                error!(
437                    "unsupported bits in VHOST_USER_SET_PROTOCOL_FEATURES: {:#x}",
438                    features
439                );
440                return Err(VhostError::InvalidOperation);
441            }
442        };
443        let supported = self.get_protocol_features()?;
444        self.acked_protocol_features = features & supported;
445        Ok(())
446    }
447
448    fn set_mem_table(
449        &mut self,
450        contexts: &[VhostUserMemoryRegion],
451        files: Vec<File>,
452    ) -> VhostResult<()> {
453        let (guest_mem, vmm_maps) = VhostUserRegularOps::set_mem_table(contexts, files)?;
454        if self.backend.unmap_guest_memory_on_fork() {
455            #[cfg(any(target_os = "android", target_os = "linux"))]
456            if let Err(e) = guest_mem.use_dontfork() {
457                error!("failed to set MADV_DONTFORK on guest memory: {e:#}");
458            }
459            #[cfg(not(any(target_os = "android", target_os = "linux")))]
460            error!("unmap_guest_memory_on_fork unsupported; skipping");
461        }
462        self.mem = Some(guest_mem);
463        self.vmm_maps = Some(vmm_maps);
464        Ok(())
465    }
466
467    fn get_queue_num(&mut self) -> VhostResult<u64> {
468        Ok(self.vrings.len() as u64)
469    }
470
471    fn set_vring_num(&mut self, index: u32, num: u32) -> VhostResult<()> {
472        if index as usize >= self.vrings.len() || num == 0 || num > Queue::MAX_SIZE.into() {
473            return Err(VhostError::InvalidParam(
474                "set_vring_num: invalid index or num",
475            ));
476        }
477        self.vrings[index as usize].queue.set_size(num as u16);
478
479        Ok(())
480    }
481
482    fn set_vring_addr(
483        &mut self,
484        index: u32,
485        _flags: VhostUserVringAddrFlags,
486        descriptor: u64,
487        used: u64,
488        available: u64,
489        _log: u64,
490    ) -> VhostResult<()> {
491        if index as usize >= self.vrings.len() {
492            return Err(VhostError::InvalidParam(
493                "set_vring_addr: index out of range",
494            ));
495        }
496
497        let vmm_maps = self
498            .vmm_maps
499            .as_ref()
500            .ok_or(VhostError::InvalidParam("set_vring_addr: missing vmm_maps"))?;
501        let vring = &mut self.vrings[index as usize];
502        vring
503            .queue
504            .set_desc_table(vmm_va_to_gpa(vmm_maps, descriptor)?);
505        vring
506            .queue
507            .set_avail_ring(vmm_va_to_gpa(vmm_maps, available)?);
508        vring.queue.set_used_ring(vmm_va_to_gpa(vmm_maps, used)?);
509
510        Ok(())
511    }
512
513    fn set_vring_base(&mut self, index: u32, base: u32) -> VhostResult<()> {
514        if index as usize >= self.vrings.len() {
515            return Err(VhostError::InvalidParam(
516                "set_vring_base: index out of range",
517            ));
518        }
519
520        let vring = &mut self.vrings[index as usize];
521        vring.queue.set_next_avail(Wrapping(base as u16));
522        vring.queue.set_next_used(Wrapping(base as u16));
523
524        Ok(())
525    }
526
527    fn get_vring_base(&mut self, index: u32) -> VhostResult<VhostUserVringState> {
528        let vring = self
529            .vrings
530            .get_mut(index as usize)
531            .ok_or(VhostError::InvalidParam(
532                "get_vring_base: index out of range",
533            ))?;
534
535        // Quotation from vhost-user spec:
536        // "The back-end must [...] stop ring upon receiving VHOST_USER_GET_VRING_BASE."
537        // We only call `queue.set_ready()` when starting the queue, so if the queue is ready, that
538        // means it is started and should be stopped.
539        let vring_base = if vring.queue.ready() {
540            let queue = match self.backend.stop_queue(index as usize) {
541                Ok(q) => q,
542                Err(e) => {
543                    error!("Failed to stop queue in get_vring_base: {:#}", e);
544                    return Err(VhostError::BackendInternalError);
545                }
546            };
547
548            trace!("stopped queue {index}");
549            vring.reset();
550
551            if self.all_queues_stopped() {
552                trace!("all queues stopped; entering suspended state");
553                self.backend
554                    .enter_suspended_state()
555                    .map_err(VhostError::EnterSuspendedState)?;
556            }
557
558            queue.next_avail_to_process()
559        } else {
560            0
561        };
562
563        Ok(VhostUserVringState::new(index, vring_base.into()))
564    }
565
566    fn set_vring_kick(&mut self, index: u8, file: Option<File>) -> VhostResult<()> {
567        if index as usize >= self.vrings.len() {
568            return Err(VhostError::InvalidParam(
569                "set_vring_kick: index out of range",
570            ));
571        }
572
573        let vring = &mut self.vrings[index as usize];
574        if vring.queue.ready() {
575            error!("kick fd cannot replaced after queue is started");
576            return Err(VhostError::InvalidOperation);
577        }
578
579        let file = file.ok_or(VhostError::InvalidParam("missing file for set_vring_kick"))?;
580
581        // Remove O_NONBLOCK from kick_fd. Otherwise, uring_executor will fails when we read
582        // values via `next_val()` later.
583        // This is only required (and can only be done) on Unix platforms.
584        #[cfg(any(target_os = "android", target_os = "linux"))]
585        if let Err(e) = clear_fd_flags(file.as_raw_fd(), libc::O_NONBLOCK) {
586            error!("failed to remove O_NONBLOCK for kick fd: {}", e);
587            return Err(VhostError::InvalidParam(
588                "could not remove O_NONBLOCK from vring_kick",
589            ));
590        }
591
592        let kick_evt = Event::from(SafeDescriptor::from(file));
593
594        // Enable any virtqueue features that were negotiated (like VIRTIO_RING_F_EVENT_IDX).
595        vring.queue.ack_features(self.acked_features);
596        vring.queue.set_ready(true);
597
598        let mem = self
599            .mem
600            .as_ref()
601            .cloned()
602            .ok_or(VhostError::InvalidOperation)?;
603
604        let doorbell = vring.doorbell.clone().ok_or(VhostError::InvalidOperation)?;
605
606        let queue = match vring.queue.activate(&mem, kick_evt, doorbell) {
607            Ok(queue) => queue,
608            Err(e) => {
609                error!("failed to activate vring: {:#}", e);
610                return Err(VhostError::BackendInternalError);
611            }
612        };
613
614        if let Err(e) = self.backend.start_queue(index as usize, queue, mem) {
615            error!("Failed to start queue {}: {}", index, e);
616            return Err(VhostError::BackendInternalError);
617        }
618        trace!("started queue {index}");
619
620        Ok(())
621    }
622
623    fn set_vring_call(&mut self, index: u8, file: Option<File>) -> VhostResult<()> {
624        if index as usize >= self.vrings.len() {
625            return Err(VhostError::InvalidParam(
626                "set_vring_call: index out of range",
627            ));
628        }
629
630        let backend_req_conn = self.backend_req_connection.clone();
631        let signal_config_change_fn = Box::new(move || {
632            if let Some(frontend) = backend_req_conn.as_ref() {
633                if let Err(e) = frontend.send_config_changed() {
634                    error!("Failed to notify config change: {:#}", e);
635                }
636            } else {
637                error!("No Backend request connection found");
638            }
639        });
640
641        let file = file.ok_or(VhostError::InvalidParam("missing file for set_vring_call"))?;
642        self.vrings[index as usize].doorbell = Some(Interrupt::new_vhost_user(
643            Event::from(SafeDescriptor::from(file)),
644            signal_config_change_fn,
645        ));
646        Ok(())
647    }
648
649    fn set_vring_err(&mut self, _index: u8, _fd: Option<File>) -> VhostResult<()> {
650        // TODO
651        Ok(())
652    }
653
654    fn set_vring_enable(&mut self, index: u32, enable: bool) -> VhostResult<()> {
655        if index as usize >= self.vrings.len() {
656            return Err(VhostError::InvalidParam(
657                "set_vring_enable: index out of range",
658            ));
659        }
660
661        // This request should be handled only when VHOST_USER_F_PROTOCOL_FEATURES
662        // has been negotiated.
663        if self.acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES == 0 {
664            return Err(VhostError::InvalidOperation);
665        }
666
667        // Backend must not pass data to/from the ring until ring is enabled by
668        // VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has been disabled by
669        // VHOST_USER_SET_VRING_ENABLE with parameter 0.
670        self.vrings[index as usize].enabled = enable;
671
672        Ok(())
673    }
674
675    fn get_config(
676        &mut self,
677        offset: u32,
678        size: u32,
679        _flags: VhostUserConfigFlags,
680    ) -> VhostResult<Vec<u8>> {
681        let mut data = vec![0; size as usize];
682        self.backend.read_config(u64::from(offset), &mut data);
683        Ok(data)
684    }
685
686    fn set_config(
687        &mut self,
688        offset: u32,
689        buf: &[u8],
690        _flags: VhostUserConfigFlags,
691    ) -> VhostResult<()> {
692        self.backend.write_config(u64::from(offset), buf);
693        Ok(())
694    }
695
696    fn set_backend_req_fd(&mut self, ep: Connection) {
697        let conn = VhostBackendReqConnection::new(
698            FrontendClient::new(
699                ep,
700                self.acked_protocol_features
701                    .contains(VhostUserProtocolFeatures::REPLY_ACK),
702            ),
703            self.backend.get_shared_memory_region().map(|r| r.id),
704        );
705
706        if self.backend_req_connection.is_some() {
707            warn!("Backend Request Connection already established. Overwriting");
708        }
709        self.backend_req_connection = Some(conn.clone());
710
711        self.backend.set_backend_req_connection(conn);
712    }
713
714    fn get_inflight_fd(
715        &mut self,
716        _inflight: &VhostUserInflight,
717    ) -> VhostResult<(VhostUserInflight, File)> {
718        unimplemented!("get_inflight_fd");
719    }
720
721    fn set_inflight_fd(&mut self, _inflight: &VhostUserInflight, _file: File) -> VhostResult<()> {
722        unimplemented!("set_inflight_fd");
723    }
724
725    fn get_max_mem_slots(&mut self) -> VhostResult<u64> {
726        //TODO
727        Ok(0)
728    }
729
730    fn add_mem_region(
731        &mut self,
732        _region: &VhostUserSingleMemoryRegion,
733        _fd: File,
734    ) -> VhostResult<()> {
735        //TODO
736        Ok(())
737    }
738
739    fn remove_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion) -> VhostResult<()> {
740        //TODO
741        Ok(())
742    }
743
744    fn set_device_state_fd(
745        &mut self,
746        transfer_direction: VhostUserTransferDirection,
747        migration_phase: VhostUserMigrationPhase,
748        fd: File,
749    ) -> VhostResult<Option<File>> {
750        if migration_phase != VhostUserMigrationPhase::Stopped {
751            return Err(VhostError::InvalidOperation);
752        }
753        if !self.all_queues_stopped() {
754            return Err(VhostError::InvalidOperation);
755        }
756        if self.device_state_thread.is_some() {
757            error!("must call check_device_state before starting new state transfer");
758            return Err(VhostError::InvalidOperation);
759        }
760        // `set_device_state_fd` is designed to allow snapshot/restore concurrently with other
761        // methods, but, for simplicitly, we do those operations inline and only spawn a thread to
762        // handle the serialization and data transfer (the latter which seems necessary to
763        // implement the API correctly without, e.g., deadlocking because a pipe is full).
764        match transfer_direction {
765            VhostUserTransferDirection::Save => {
766                // Snapshot the state.
767                let snapshot = DeviceRequestHandlerSnapshot {
768                    acked_features: self.acked_features,
769                    acked_protocol_features: self.acked_protocol_features.bits(),
770                    backend: self.backend.snapshot().map_err(VhostError::SnapshotError)?,
771                };
772                // Spawn thread to write the serialized bytes.
773                self.device_state_thread = Some(DeviceStateThread::Save(WorkerThread::start(
774                    "device_state_save",
775                    move |_kill_event| -> Result<(), ciborium::ser::Error<std::io::Error>> {
776                        let mut w = std::io::BufWriter::new(fd);
777                        ciborium::into_writer(&snapshot, &mut w)?;
778                        w.flush()?;
779                        Ok(())
780                    },
781                )));
782                Ok(None)
783            }
784            VhostUserTransferDirection::Load => {
785                // Spawn a thread to read the bytes and deserialize. Restore will happen in
786                // `check_device_state`.
787                self.device_state_thread = Some(DeviceStateThread::Load(WorkerThread::start(
788                    "device_state_load",
789                    move |_kill_event| ciborium::from_reader(&mut BufReader::new(fd)),
790                )));
791                Ok(None)
792            }
793        }
794    }
795
796    fn check_device_state(&mut self) -> VhostResult<()> {
797        let Some(thread) = self.device_state_thread.take() else {
798            error!("check_device_state: no active state transfer");
799            return Err(VhostError::InvalidOperation);
800        };
801        match thread {
802            DeviceStateThread::Save(worker) => {
803                worker.stop().map_err(|e| {
804                    error!("device state save thread failed: {:#}", e);
805                    VhostError::BackendInternalError
806                })?;
807                Ok(())
808            }
809            DeviceStateThread::Load(worker) => {
810                let snapshot = worker.stop().map_err(|e| {
811                    error!("device state load thread failed: {:#}", e);
812                    VhostError::BackendInternalError
813                })?;
814                self.acked_features = snapshot.acked_features;
815                self.acked_protocol_features =
816                    VhostUserProtocolFeatures::from_bits(snapshot.acked_protocol_features)
817                        .with_context(|| {
818                            format!(
819                                "unsupported bits in acked_protocol_features: {:#x}",
820                                snapshot.acked_protocol_features
821                            )
822                        })
823                        .map_err(VhostError::RestoreError)?;
824                self.backend
825                    .restore(snapshot.backend)
826                    .map_err(VhostError::RestoreError)?;
827                Ok(())
828            }
829        }
830    }
831
832    fn get_shmem_config(&mut self) -> VhostResult<Vec<SharedMemoryRegion>> {
833        Ok(self
834            .backend
835            .get_shared_memory_region()
836            .into_iter()
837            .collect())
838    }
839}
840
841/// Keeps track of Vhost user backend request connection.
842#[derive(Clone)]
843pub struct VhostBackendReqConnection {
844    shared: Arc<Mutex<VhostBackendReqConnectionShared>>,
845    shmid: Option<u8>,
846}
847
848struct VhostBackendReqConnectionShared {
849    conn: FrontendClient,
850    mapped_regions: BTreeMap<u64 /* offset */, u64 /* size */>,
851}
852
853impl VhostBackendReqConnection {
854    fn new(conn: FrontendClient, shmid: Option<u8>) -> Self {
855        Self {
856            shared: Arc::new(Mutex::new(VhostBackendReqConnectionShared {
857                conn,
858                mapped_regions: BTreeMap::new(),
859            })),
860            shmid,
861        }
862    }
863
864    /// Send `VHOST_USER_CONFIG_CHANGE_MSG` to the frontend
865    fn send_config_changed(&self) -> anyhow::Result<()> {
866        let mut shared = self.shared.lock();
867        shared
868            .conn
869            .handle_config_change()
870            .context("Could not send config change message")?;
871        Ok(())
872    }
873
874    /// Create a SharedMemoryMapper trait object using this backend request connection.
875    pub fn shmem_mapper(&self) -> Option<Box<dyn SharedMemoryMapper>> {
876        if let Some(shmid) = self.shmid {
877            Some(Box::new(VhostShmemMapper {
878                shared: self.shared.clone(),
879                shmid,
880            }))
881        } else {
882            None
883        }
884    }
885}
886
887#[derive(Clone)]
888struct VhostShmemMapper {
889    shared: Arc<Mutex<VhostBackendReqConnectionShared>>,
890    shmid: u8,
891}
892
893impl SharedMemoryMapper for VhostShmemMapper {
894    fn add_mapping(
895        &mut self,
896        source: VmMemorySource,
897        offset: u64,
898        prot: Protection,
899        _cache: MemCacheType,
900    ) -> anyhow::Result<()> {
901        let mut shared = self.shared.lock();
902        let size = match source {
903            VmMemorySource::Vulkan {
904                descriptor,
905                handle_type,
906                memory_idx,
907                device_uuid,
908                driver_uuid,
909                size,
910            } => {
911                let msg = VhostUserGpuMapMsg::new(
912                    self.shmid,
913                    offset,
914                    size,
915                    memory_idx,
916                    handle_type,
917                    device_uuid,
918                    driver_uuid,
919                );
920                shared
921                    .conn
922                    .gpu_map(&msg, &descriptor)
923                    .context("map GPU memory")?;
924                size
925            }
926            VmMemorySource::ExternalMapping { ptr, size } => {
927                let msg = VhostUserExternalMapMsg::new(self.shmid, offset, size, ptr);
928                shared
929                    .conn
930                    .external_map(&msg)
931                    .context("create external mapping")?;
932                size
933            }
934            source => {
935                // The last two sources use the same VhostUserMMap, continue matching here
936                // on the aliased `source` above.
937                let (descriptor, fd_offset, size) = match source {
938                    VmMemorySource::Descriptor {
939                        descriptor,
940                        offset,
941                        size,
942                    } => (descriptor, offset, size),
943                    VmMemorySource::SharedMemory(shmem) => {
944                        let size = shmem.size();
945                        let descriptor = SafeDescriptor::from(shmem);
946                        (descriptor, 0, size)
947                    }
948                    _ => bail!("unsupported source"),
949                };
950                let mut flags = VhostUserMMapFlags::empty();
951                anyhow::ensure!(prot.allows(&Protection::read()), "mapping must be readable");
952                if prot.allows(&Protection::write()) {
953                    flags |= VhostUserMMapFlags::MAP_RW;
954                }
955                let msg = VhostUserMMap {
956                    shmid: self.shmid,
957                    padding: Default::default(),
958                    fd_offset,
959                    shm_offset: offset,
960                    len: size,
961                    flags,
962                };
963                shared
964                    .conn
965                    .shmem_map(&msg, &descriptor)
966                    .context("map shmem")?;
967                size
968            }
969        };
970
971        shared.mapped_regions.insert(offset, size);
972        Ok(())
973    }
974
975    fn remove_mapping(&mut self, offset: u64) -> anyhow::Result<()> {
976        let mut shared = self.shared.lock();
977        let size = shared
978            .mapped_regions
979            .remove(&offset)
980            .context("unknown offset")?;
981        let msg = VhostUserMMap {
982            shmid: self.shmid,
983            padding: Default::default(),
984            fd_offset: 0,
985            shm_offset: offset,
986            len: size,
987            flags: VhostUserMMapFlags::empty(),
988        };
989        shared
990            .conn
991            .shmem_unmap(&msg)
992            .context("unmap shmem")
993            .map(|_| ())
994    }
995}
996
/// Pairs a spawned queue task handle with the queue value associated with it.
///
/// NOTE(review): the concrete meaning of `T` and `U` is chosen by the devices that use this
/// type; confirm at the call sites.
pub(crate) struct WorkerState<T, U> {
    /// Handle to the spawned queue task (presumably `U` is the task's output type).
    pub(crate) queue_task: TaskHandle<U>,
    /// The queue (or queue-related state) kept alongside the task.
    pub(crate) queue: T,
}
1001
/// Errors for device operations.
#[derive(Debug, ThisError)]
pub enum Error {
    /// Returned when stopping a queue whose worker cannot be found.
    #[error("worker not found when stopping queue")]
    WorkerNotFound,
}
1008
#[cfg(test)]
mod tests {
    use std::sync::mpsc::channel;
    use std::sync::Barrier;

    use anyhow::bail;
    use base::Event;
    use vmm_vhost::BackendServer;
    use vmm_vhost::FrontendReq;
    use zerocopy::FromBytes;
    use zerocopy::FromZeros;
    use zerocopy::Immutable;
    use zerocopy::IntoBytes;
    use zerocopy::KnownLayout;

    use super::*;
    use crate::virtio::vhost_user_frontend::VhostUserFrontend;
    use crate::virtio::DeviceType;
    use crate::virtio::VirtioDevice;

    /// Fake device config space served by `FakeBackend::read_config`.
    ///
    /// `packed(4)` caps the alignment of `y`, so there is no padding between `x` and `y` and
    /// the byte view is exactly 12 bytes.
    #[derive(Clone, Copy, Debug, PartialEq, Eq, FromBytes, Immutable, IntoBytes, KnownLayout)]
    #[repr(C, packed(4))]
    struct FakeConfig {
        x: u32,
        y: u64,
    }

    /// Config value the frontend is expected to read back from the backend.
    const FAKE_CONFIG_DATA: FakeConfig = FakeConfig { x: 1, y: 2 };

    /// Minimal `VhostUserDevice` used to drive `DeviceRequestHandler` in tests.
    pub(super) struct FakeBackend {
        /// Virtio feature bits advertised by `features()`.
        avail_features: u64,
        /// Union of all feature bits acked so far.
        acked_features: u64,
        /// Queues handed over by `start_queue`, indexed by queue number.
        active_queues: Vec<Option<Queue>>,
        /// Whether `protocol_features()` advertises `BACKEND_REQ`.
        allow_backend_req: bool,
        /// Connection received through `set_backend_req_connection`, if any.
        backend_conn: Option<VhostBackendReqConnection>,
    }

    /// Backend snapshot payload produced by `snapshot` and checked by `restore`.
    #[derive(Deserialize, Serialize)]
    struct FakeBackendSnapshot {
        data: Vec<u8>,
    }

    impl FakeBackend {
        const MAX_QUEUE_NUM: usize = 16;

        /// Creates a backend that advertises only `VHOST_USER_F_PROTOCOL_FEATURES`, has no
        /// started queues and does not allow backend requests.
        pub(super) fn new() -> Self {
            let mut active_queues = Vec::new();
            active_queues.resize_with(Self::MAX_QUEUE_NUM, Default::default);
            Self {
                avail_features: 1 << VHOST_USER_F_PROTOCOL_FEATURES,
                acked_features: 0,
                active_queues,
                allow_backend_req: false,
                backend_conn: None,
            }
        }
    }

    impl VhostUserDevice for FakeBackend {
        fn max_queue_num(&self) -> usize {
            Self::MAX_QUEUE_NUM
        }

        fn features(&self) -> u64 {
            self.avail_features
        }

        fn ack_features(&mut self, value: u64) -> anyhow::Result<()> {
            // Reject any bit that was not advertised by `features()`.
            let unrequested_features = value & !self.avail_features;
            if unrequested_features != 0 {
                bail!(
                    "invalid protocol features are given: 0x{:x}",
                    unrequested_features
                );
            }
            self.acked_features |= value;
            Ok(())
        }

        fn protocol_features(&self) -> VhostUserProtocolFeatures {
            let mut features =
                VhostUserProtocolFeatures::CONFIG | VhostUserProtocolFeatures::DEVICE_STATE;
            if self.allow_backend_req {
                features |= VhostUserProtocolFeatures::BACKEND_REQ;
            }
            features
        }

        fn read_config(&self, offset: u64, dst: &mut [u8]) {
            // `copy_from_slice` panics unless `dst` is exactly as long as the config bytes from
            // `offset` to the end; the lifecycle test reads the whole config from offset 0.
            dst.copy_from_slice(&FAKE_CONFIG_DATA.as_bytes()[offset as usize..]);
        }

        fn reset(&mut self) {}

        fn start_queue(
            &mut self,
            idx: usize,
            queue: Queue,
            _mem: GuestMemory,
        ) -> anyhow::Result<()> {
            self.active_queues[idx] = Some(queue);
            Ok(())
        }

        fn stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue> {
            Ok(self.active_queues[idx]
                .take()
                .ok_or(Error::WorkerNotFound)?)
        }

        fn set_backend_req_connection(&mut self, conn: VhostBackendReqConnection) {
            self.backend_conn = Some(conn);
        }

        fn enter_suspended_state(&mut self) -> anyhow::Result<()> {
            Ok(())
        }

        fn snapshot(&mut self) -> anyhow::Result<AnySnapshot> {
            AnySnapshot::to_any(FakeBackendSnapshot {
                data: vec![1, 2, 3],
            })
            .context("failed to serialize snapshot")
        }

        fn restore(&mut self, data: AnySnapshot) -> anyhow::Result<()> {
            // Round-trip check: the restored payload must equal what `snapshot` produced.
            let snapshot: FakeBackendSnapshot =
                AnySnapshot::from_any(data).context("failed to deserialize snapshot")?;
            assert_eq!(snapshot.data, vec![1, 2, 3], "bad snapshot data");
            Ok(())
        }
    }

    #[test]
    fn test_vhost_user_lifecycle() {
        test_vhost_user_lifecycle_parameterized(false);
    }

    #[test]
    #[cfg(not(windows))] // Windows requires more complex connection setup.
    fn test_vhost_user_lifecycle_with_backend_req() {
        test_vhost_user_lifecycle_parameterized(true);
    }

    /// Drives a full frontend/backend lifecycle over a connected socket pair.
    ///
    /// The lifecycle is: setup, read_config, activate, reset, activate again, sleep,
    /// snapshot/restore, wake, shutdown. The VMM side runs on a spawned thread. The device
    /// side runs on this thread and handles the exact sequence of requests the frontend is
    /// expected to send. With `allow_backend_req`, it also checks that backend-initiated
    /// messages still work after reset and after sleep/wake.
    fn test_vhost_user_lifecycle_parameterized(allow_backend_req: bool) {
        const QUEUES_NUM: usize = 2;

        let (client_connection, server_connection) = vmm_vhost::Connection::pair().unwrap();

        let vmm_bar = Arc::new(Barrier::new(2));
        let dev_bar = vmm_bar.clone();

        let (ready_tx, ready_rx) = channel();
        let (shutdown_tx, shutdown_rx) = channel();
        let (vm_evt_wrtube, _vm_evt_rdtube) = base::Tube::directional_pair().unwrap();

        std::thread::spawn(move || {
            // VMM side
            ready_rx.recv().unwrap(); // Ensure the device is ready.

            let mut vmm_device = VhostUserFrontend::new(
                DeviceType::Console,
                0,
                client_connection,
                vm_evt_wrtube,
                None,
                None,
            )
            .unwrap();

            println!("read_config");
            let mut config = FakeConfig::new_zeroed();
            vmm_device.read_config(0, config.as_mut_bytes());
            // Check if the obtained config data is correct.
            assert_eq!(config, FAKE_CONFIG_DATA);

            // Activates the device with a fresh guest memory and QUEUES_NUM ready queues.
            let activate = |vmm_device: &mut VhostUserFrontend| {
                let mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
                let interrupt = Interrupt::new_for_test_with_msix();

                let mut queues = BTreeMap::new();
                for idx in 0..QUEUES_NUM {
                    let mut queue = QueueConfig::new(0x10, 0);
                    queue.set_ready(true);
                    let queue = queue
                        .activate(&mem, Event::new().unwrap(), interrupt.clone())
                        .expect("QueueConfig::activate");
                    queues.insert(idx, queue);
                }

                println!("activate");
                vmm_device.activate(mem, interrupt, queues).unwrap();
            };

            activate(&mut vmm_device);

            println!("reset");
            let reset_result = vmm_device.reset();
            assert!(
                reset_result.is_ok(),
                "reset failed: {:#}",
                reset_result.unwrap_err()
            );

            activate(&mut vmm_device);

            println!("virtio_sleep");
            let queues = vmm_device
                .virtio_sleep()
                .unwrap()
                .expect("virtio_sleep unexpectedly returned None");

            println!("virtio_snapshot");
            let snapshot = vmm_device
                .virtio_snapshot()
                .expect("virtio_snapshot failed");
            println!("virtio_restore");
            vmm_device
                .virtio_restore(snapshot)
                .expect("virtio_restore failed");

            println!("virtio_wake");
            let mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
            let interrupt = Interrupt::new_for_test_with_msix();
            vmm_device
                .virtio_wake(Some((mem, interrupt, queues)))
                .unwrap();

            println!("wait for shutdown signal");
            shutdown_rx.recv().unwrap();

            // The VMM side is supposed to stop before the device side.
            println!("drop");
            drop(vmm_device);

            vmm_bar.wait();
        });

        // Device side
        let mut handler = DeviceRequestHandler::new(FakeBackend::new());
        handler.as_mut().allow_backend_req = allow_backend_req;

        // Notify listener is ready.
        ready_tx.send(()).unwrap();

        let mut req_handler = BackendServer::new(server_connection, handler);

        // Each block below mirrors, in order, the requests sent by the named frontend call.

        // VhostUserFrontend::new()
        handle_request(&mut req_handler, FrontendReq::SET_OWNER).unwrap();
        handle_request(&mut req_handler, FrontendReq::GET_FEATURES).unwrap();
        handle_request(&mut req_handler, FrontendReq::GET_PROTOCOL_FEATURES).unwrap();
        handle_request(&mut req_handler, FrontendReq::SET_PROTOCOL_FEATURES).unwrap();
        if allow_backend_req {
            handle_request(&mut req_handler, FrontendReq::SET_BACKEND_REQ_FD).unwrap();
        }

        // VhostUserFrontend::read_config()
        handle_request(&mut req_handler, FrontendReq::GET_CONFIG).unwrap();

        // VhostUserFrontend::activate()
        handle_request(&mut req_handler, FrontendReq::SET_FEATURES).unwrap();
        handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();
        for _ in 0..QUEUES_NUM {
            handle_request(&mut req_handler, FrontendReq::SET_VRING_NUM).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ADDR).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_BASE).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_CALL).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_KICK).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
        }

        // VhostUserFrontend::reset()
        for _ in 0..QUEUES_NUM {
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
            handle_request(&mut req_handler, FrontendReq::GET_VRING_BASE).unwrap();
        }

        // VhostUserFrontend::activate()
        handle_request(&mut req_handler, FrontendReq::SET_FEATURES).unwrap();
        handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();
        for _ in 0..QUEUES_NUM {
            handle_request(&mut req_handler, FrontendReq::SET_VRING_NUM).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ADDR).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_BASE).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_CALL).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_KICK).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
        }

        if allow_backend_req {
            // Make sure the connection still works even after reset/reactivate.
            req_handler
                .as_ref()
                .as_ref()
                .backend_conn
                .as_ref()
                .expect("backend_conn missing")
                .send_config_changed()
                .expect("send_config_changed failed");
        }

        // VhostUserFrontend::virtio_sleep()
        for _ in 0..QUEUES_NUM {
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
            handle_request(&mut req_handler, FrontendReq::GET_VRING_BASE).unwrap();
        }

        // VhostUserFrontend::virtio_snapshot()
        handle_request(&mut req_handler, FrontendReq::SET_DEVICE_STATE_FD).unwrap();
        handle_request(&mut req_handler, FrontendReq::CHECK_DEVICE_STATE).unwrap();
        // VhostUserFrontend::virtio_restore()
        handle_request(&mut req_handler, FrontendReq::SET_FEATURES).unwrap();
        handle_request(&mut req_handler, FrontendReq::SET_DEVICE_STATE_FD).unwrap();
        handle_request(&mut req_handler, FrontendReq::CHECK_DEVICE_STATE).unwrap();

        // VhostUserFrontend::virtio_wake()
        handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();
        for _ in 0..QUEUES_NUM {
            handle_request(&mut req_handler, FrontendReq::SET_VRING_NUM).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ADDR).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_BASE).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_CALL).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_KICK).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
        }

        if allow_backend_req {
            // Make sure the connection still works even after sleep/wake.
            req_handler
                .as_ref()
                .as_ref()
                .backend_conn
                .as_ref()
                .expect("backend_conn missing")
                .send_config_changed()
                .expect("send_config_changed failed");
        }

        // Ask the client to shut down, then wait for it to finish.
        shutdown_tx.send(()).unwrap();
        dev_bar.wait();

        // Verify recv_header fails with `ClientExit` after the client has disconnected.
        match req_handler.recv_header() {
            Err(VhostError::ClientExit) => (),
            r => panic!("expected Err(ClientExit) but got {r:?}"),
        }
    }

    /// Receives the next frontend message, asserts that its type is `expected_message_type`,
    /// and dispatches it to the backend via `process_message`.
    fn handle_request<S: vmm_vhost::Backend>(
        handler: &mut BackendServer<S>,
        expected_message_type: FrontendReq,
    ) -> Result<(), VhostError> {
        let (hdr, files) = handler.recv_header()?;
        assert_eq!(hdr.get_code(), Ok(expected_message_type));
        handler.process_message(hdr, files)
    }
}