devices/virtio/vhost_user_backend/
handler.rs

1// Copyright 2021 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Library for implementing vhost-user device executables.
6//!
7//! This crate provides
8//! * `VhostUserDevice` trait, which is a collection of methods to handle vhost-user requests, and
9//! * `DeviceRequestHandler` struct, which makes a connection to a VMM and starts an event loop.
10//!
11//! They are expected to be used as follows:
12//!
13//! 1. Define a struct and implement `VhostUserDevice` for it.
14//! 2. Create a `DeviceRequestHandler` with the backend struct.
15//! 3. Drive the `DeviceRequestHandler::run` async fn with an executor.
16//!
17//! ```ignore
18//! struct MyBackend {
19//!   /* fields */
20//! }
21//!
22//! impl VhostUserDevice for MyBackend {
23//!   /* implement methods */
24//! }
25//!
26//! fn main() -> Result<(), Box<dyn Error>> {
27//!   let backend = MyBackend { /* initialize fields */ };
28//!   let handler = DeviceRequestHandler::new(backend);
//!   let socket = std::path::Path::new("/path/to/socket");
30//!   let ex = cros_async::Executor::new()?;
31//!
32//!   if let Err(e) = ex.run_until(handler.run(socket, &ex)) {
33//!     eprintln!("error happened: {}", e);
34//!   }
35//!   Ok(())
36//! }
37//! ```
38// Implementation note:
39// This code lets us take advantage of the vmm_vhost low level implementation of the vhost user
40// protocol. DeviceRequestHandler implements the Backend trait from vmm_vhost, and includes some
41// common code for setting up guest memory and managing partially configured vrings.
42// DeviceRequestHandler::run watches the vhost-user socket and then calls handle_request() when it
43// becomes readable. handle_request() reads and parses the message and then calls one of the
44// Backend trait methods. These dispatch back to the supplied VhostUserDevice implementation (this
45// is what our devices implement).
46
47pub(super) mod sys;
48
49use std::collections::BTreeMap;
50use std::convert::From;
51use std::fs::File;
52use std::io::BufReader;
53use std::io::Write;
54use std::num::Wrapping;
55#[cfg(any(target_os = "android", target_os = "linux"))]
56use std::os::unix::io::AsRawFd;
57use std::sync::Arc;
58
59use anyhow::bail;
60use anyhow::Context;
61#[cfg(any(target_os = "android", target_os = "linux"))]
62use base::clear_fd_flags;
63use base::error;
64use base::trace;
65use base::warn;
66use base::Event;
67use base::Protection;
68use base::SafeDescriptor;
69use base::SharedMemory;
70use base::WorkerThread;
71use cros_async::TaskHandle;
72use hypervisor::MemCacheType;
73use serde::Deserialize;
74use serde::Serialize;
75use snapshot::AnySnapshot;
76use sync::Mutex;
77use thiserror::Error as ThisError;
78use vm_control::VmMemorySource;
79use vm_memory::GuestAddress;
80use vm_memory::GuestMemory;
81use vm_memory::MemoryRegion;
82use vmm_vhost::message::VhostUserConfigFlags;
83use vmm_vhost::message::VhostUserExternalMapMsg;
84use vmm_vhost::message::VhostUserGpuMapMsg;
85use vmm_vhost::message::VhostUserInflight;
86use vmm_vhost::message::VhostUserMMap;
87use vmm_vhost::message::VhostUserMMapFlags;
88use vmm_vhost::message::VhostUserMemoryRegion;
89use vmm_vhost::message::VhostUserMigrationPhase;
90use vmm_vhost::message::VhostUserProtocolFeatures;
91use vmm_vhost::message::VhostUserShMemConfigHeader;
92use vmm_vhost::message::VhostUserSingleMemoryRegion;
93use vmm_vhost::message::VhostUserTransferDirection;
94use vmm_vhost::message::VhostUserVringAddrFlags;
95use vmm_vhost::message::VhostUserVringState;
96use vmm_vhost::BackendReq;
97use vmm_vhost::Connection;
98use vmm_vhost::Error as VhostError;
99use vmm_vhost::Frontend;
100use vmm_vhost::FrontendClient;
101use vmm_vhost::Result as VhostResult;
102use vmm_vhost::VHOST_USER_F_PROTOCOL_FEATURES;
103
104use crate::virtio::Interrupt;
105use crate::virtio::Queue;
106use crate::virtio::QueueConfig;
107use crate::virtio::SharedMemoryMapper;
108use crate::virtio::SharedMemoryRegion;
109
110/// Keeps a mapping from the vmm's virtual addresses to guest addresses.
111/// used to translate messages from the vmm to guest offsets.
112#[derive(Default)]
113pub struct MappingInfo {
114    pub vmm_addr: u64,
115    pub guest_phys: u64,
116    pub size: u64,
117}
118
119pub fn vmm_va_to_gpa(maps: &[MappingInfo], vmm_va: u64) -> VhostResult<GuestAddress> {
120    for map in maps {
121        if vmm_va >= map.vmm_addr && vmm_va < map.vmm_addr + map.size {
122            return Ok(GuestAddress(vmm_va - map.vmm_addr + map.guest_phys));
123        }
124    }
125    Err(VhostError::InvalidMessage)
126}
127
128/// Trait for vhost-user devices. Analogous to the `VirtioDevice` trait.
129///
130/// In contrast with [[vmm_vhost::Backend]], which closely matches the vhost-user spec, this trait
131/// is designed to follow crosvm conventions for implementing devices.
132pub trait VhostUserDevice {
133    /// The maximum number of queues that this backend can manage.
134    fn max_queue_num(&self) -> usize;
135
136    /// The set of feature bits that this backend supports.
137    fn features(&self) -> u64;
138
139    /// Acknowledges that this set of features should be enabled.
140    ///
141    /// Implementations only need to handle device-specific feature bits; the `DeviceRequestHandler`
142    /// framework will manage generic vhost and vring features.
143    ///
144    /// `DeviceRequestHandler` checks for valid features before calling this function, so the
145    /// features in `value` will always be a subset of those advertised by `features()`.
146    fn ack_features(&mut self, _value: u64) -> anyhow::Result<()> {
147        Ok(())
148    }
149
150    /// The set of protocol feature bits that this backend supports.
151    fn protocol_features(&self) -> VhostUserProtocolFeatures;
152
153    /// Reads this device configuration space at `offset`.
154    fn read_config(&self, offset: u64, dst: &mut [u8]);
155
156    /// writes `data` to this device's configuration space at `offset`.
157    fn write_config(&self, _offset: u64, _data: &[u8]) {}
158
159    /// Indicates that the backend should start processing requests for virtio queue number `idx`.
160    /// This method must not block the current thread so device backends should either spawn an
161    /// async task or another thread to handle messages from the Queue.
162    fn start_queue(&mut self, idx: usize, queue: Queue, mem: GuestMemory) -> anyhow::Result<()>;
163
164    /// Indicates that the backend should stop processing requests for virtio queue number `idx`.
165    /// This method should return the queue passed to `start_queue` for the corresponding `idx`.
166    /// This method will only be called for queues that were previously started by `start_queue`.
167    fn stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue>;
168
169    /// Resets the vhost-user backend.
170    fn reset(&mut self);
171
172    /// Returns the device's shared memory region if present.
173    fn get_shared_memory_region(&self) -> Option<SharedMemoryRegion> {
174        None
175    }
176
177    /// Accepts `VhostBackendReqConnection` to conduct Vhost backend to frontend message
178    /// handling.
179    ///
180    /// This method will be called when `VhostUserProtocolFeatures::BACKEND_REQ` is
181    /// negotiated.
182    fn set_backend_req_connection(&mut self, _conn: VhostBackendReqConnection) {}
183
184    /// Enter the "suspended device state" described in the vhost-user spec. See the spec for
185    /// requirements.
186    ///
187    /// One reasonably foolproof way to satisfy the requirements is to stop all worker threads.
188    ///
189    /// Called after a `stop_queue` call if there are no running queues left. Also called soon
190    /// after device creation to ensure the device is acting suspended immediately on construction.
191    ///
192    /// The next `start_queue` call implicitly exits the "suspend device state".
193    ///
194    /// * Ok(())    => device successfully suspended
195    /// * Err(_)    => unrecoverable error
196    fn enter_suspended_state(&mut self) -> anyhow::Result<()>;
197
198    /// Snapshot device and return serialized state.
199    fn snapshot(&mut self) -> anyhow::Result<AnySnapshot>;
200
201    /// Restore device state from a snapshot.
202    fn restore(&mut self, data: AnySnapshot) -> anyhow::Result<()>;
203
204    /// Whether guest memory should be unmapped in forked processes.
205    ///
206    /// This is intended for use in combination with --protected-vm, where the guest memory can be
207    /// dangerous to access. Some systems, e.g. Android, have tools that fork processes and examine
208    /// their memory. This flag effectively hides the guest memory from those tools.
209    ///
210    /// Not compatible with sandboxing.
211    fn unmap_guest_memory_on_fork(&self) -> bool {
212        false
213    }
214}
215
216/// A virtio ring entry.
217struct Vring {
218    // The queue config. This doesn't get mutated by the queue workers.
219    queue: QueueConfig,
220    doorbell: Option<Interrupt>,
221    enabled: bool,
222}
223
224impl Vring {
225    fn new(max_size: u16, features: u64) -> Self {
226        Self {
227            queue: QueueConfig::new(max_size, features),
228            doorbell: None,
229            enabled: false,
230        }
231    }
232
233    fn reset(&mut self) {
234        self.queue.reset();
235        self.doorbell = None;
236        self.enabled = false;
237    }
238}
239
240/// Ops for running vhost-user over a stream (i.e. regular protocol).
241pub(super) struct VhostUserRegularOps;
242
243impl VhostUserRegularOps {
244    pub fn set_mem_table(
245        contexts: &[VhostUserMemoryRegion],
246        files: Vec<File>,
247    ) -> VhostResult<(GuestMemory, Vec<MappingInfo>)> {
248        if files.len() != contexts.len() {
249            return Err(VhostError::InvalidParam(
250                "number of files & contexts was not equal",
251            ));
252        }
253
254        let mut regions = Vec::with_capacity(files.len());
255        for (region, file) in contexts.iter().zip(files.into_iter()) {
256            let region = MemoryRegion::new_from_shm(
257                region.memory_size,
258                GuestAddress(region.guest_phys_addr),
259                region.mmap_offset,
260                Arc::new(
261                    SharedMemory::from_safe_descriptor(
262                        SafeDescriptor::from(file),
263                        region.memory_size,
264                    )
265                    .unwrap(),
266                ),
267            )
268            .map_err(|e| {
269                error!("failed to create a memory region: {}", e);
270                VhostError::InvalidOperation
271            })?;
272            regions.push(region);
273        }
274        let guest_mem = GuestMemory::from_regions(regions).map_err(|e| {
275            error!("failed to create guest memory: {}", e);
276            VhostError::InvalidOperation
277        })?;
278
279        let vmm_maps = contexts
280            .iter()
281            .map(|region| MappingInfo {
282                vmm_addr: region.user_addr,
283                guest_phys: region.guest_phys_addr,
284                size: region.memory_size,
285            })
286            .collect();
287        Ok((guest_mem, vmm_maps))
288    }
289}
290
291/// An adapter that implements `vmm_vhost::Backend` for any type implementing `VhostUserDevice`.
292pub struct DeviceRequestHandler<T: VhostUserDevice> {
293    vrings: Vec<Vring>,
294    owned: bool,
295    vmm_maps: Option<Vec<MappingInfo>>,
296    mem: Option<GuestMemory>,
297    acked_features: u64,
298    acked_protocol_features: VhostUserProtocolFeatures,
299    backend: T,
300    backend_req_connection: Option<VhostBackendReqConnection>,
301    // Thread processing active device state FD.
302    device_state_thread: Option<DeviceStateThread>,
303}
304
305enum DeviceStateThread {
306    Save(WorkerThread<Result<(), ciborium::ser::Error<std::io::Error>>>),
307    Load(WorkerThread<Result<DeviceRequestHandlerSnapshot, ciborium::de::Error<std::io::Error>>>),
308}
309
310#[derive(Serialize, Deserialize)]
311pub struct DeviceRequestHandlerSnapshot {
312    acked_features: u64,
313    acked_protocol_features: u64,
314    backend: AnySnapshot,
315}
316
317impl<T: VhostUserDevice> DeviceRequestHandler<T> {
318    /// Creates a vhost-user handler instance for `backend`.
319    pub(crate) fn new(mut backend: T) -> Self {
320        let mut vrings = Vec::with_capacity(backend.max_queue_num());
321        for _ in 0..backend.max_queue_num() {
322            vrings.push(Vring::new(Queue::MAX_SIZE, backend.features()));
323        }
324
325        // VhostUserDevice implementations must support `enter_suspended_state()`.
326        // Call it on startup to ensure it works and to initialize the device in a suspended state.
327        backend
328            .enter_suspended_state()
329            .expect("enter_suspended_state failed on device init");
330
331        DeviceRequestHandler {
332            vrings,
333            owned: false,
334            vmm_maps: None,
335            mem: None,
336            acked_features: 0,
337            acked_protocol_features: VhostUserProtocolFeatures::empty(),
338            backend,
339            backend_req_connection: None,
340            device_state_thread: None,
341        }
342    }
343
344    /// Check if all queues are stopped.
345    ///
346    /// The device can be suspended with `enter_suspended_state()` only when all queues are stopped.
347    fn all_queues_stopped(&self) -> bool {
348        self.vrings.iter().all(|vring| !vring.queue.ready())
349    }
350}
351
352impl<T: VhostUserDevice> Drop for DeviceRequestHandler<T> {
353    fn drop(&mut self) {
354        for (index, vring) in self.vrings.iter().enumerate() {
355            if vring.queue.ready() {
356                if let Err(e) = self.backend.stop_queue(index) {
357                    error!("Failed to stop queue {} during drop: {:#}", index, e);
358                }
359            }
360        }
361    }
362}
363
364impl<T: VhostUserDevice> AsRef<T> for DeviceRequestHandler<T> {
365    fn as_ref(&self) -> &T {
366        &self.backend
367    }
368}
369
370impl<T: VhostUserDevice> AsMut<T> for DeviceRequestHandler<T> {
371    fn as_mut(&mut self) -> &mut T {
372        &mut self.backend
373    }
374}
375
376impl<T: VhostUserDevice> vmm_vhost::Backend for DeviceRequestHandler<T> {
377    fn set_owner(&mut self) -> VhostResult<()> {
378        if self.owned {
379            return Err(VhostError::InvalidOperation);
380        }
381        self.owned = true;
382        Ok(())
383    }
384
385    fn reset_owner(&mut self) -> VhostResult<()> {
386        self.owned = false;
387        self.acked_features = 0;
388        self.backend.reset();
389        Ok(())
390    }
391
392    fn get_features(&mut self) -> VhostResult<u64> {
393        let features = self.backend.features();
394        Ok(features)
395    }
396
397    fn set_features(&mut self, features: u64) -> VhostResult<()> {
398        if !self.owned {
399            return Err(VhostError::InvalidOperation);
400        }
401
402        let unexpected_features = features & !self.backend.features();
403        if unexpected_features != 0 {
404            error!("unexpected set_features {:#x}", unexpected_features);
405            return Err(VhostError::InvalidParam("unexpected set_features"));
406        }
407
408        if let Err(e) = self.backend.ack_features(features) {
409            error!("failed to acknowledge features 0x{:x}: {}", features, e);
410            return Err(VhostError::InvalidOperation);
411        }
412
413        self.acked_features |= features;
414
415        // If VHOST_USER_F_PROTOCOL_FEATURES has not been negotiated, the ring is initialized in an
416        // enabled state.
417        // If VHOST_USER_F_PROTOCOL_FEATURES has been negotiated, the ring is initialized in a
418        // disabled state.
419        // Client must not pass data to/from the backend until ring is enabled by
420        // VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has been disabled by
421        // VHOST_USER_SET_VRING_ENABLE with parameter 0.
422        let vring_enabled = self.acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES != 0;
423        for v in &mut self.vrings {
424            v.enabled = vring_enabled;
425        }
426
427        Ok(())
428    }
429
430    fn get_protocol_features(&mut self) -> VhostResult<VhostUserProtocolFeatures> {
431        Ok(self.backend.protocol_features())
432    }
433
434    fn set_protocol_features(&mut self, features: u64) -> VhostResult<()> {
435        let features = match VhostUserProtocolFeatures::from_bits(features) {
436            Some(proto_features) => proto_features,
437            None => {
438                error!(
439                    "unsupported bits in VHOST_USER_SET_PROTOCOL_FEATURES: {:#x}",
440                    features
441                );
442                return Err(VhostError::InvalidOperation);
443            }
444        };
445        let supported = self.backend.protocol_features();
446        self.acked_protocol_features = features & supported;
447        Ok(())
448    }
449
450    fn set_mem_table(
451        &mut self,
452        contexts: &[VhostUserMemoryRegion],
453        files: Vec<File>,
454    ) -> VhostResult<()> {
455        let (guest_mem, vmm_maps) = VhostUserRegularOps::set_mem_table(contexts, files)?;
456        if self.backend.unmap_guest_memory_on_fork() {
457            #[cfg(any(target_os = "android", target_os = "linux"))]
458            if let Err(e) = guest_mem.use_dontfork() {
459                error!("failed to set MADV_DONTFORK on guest memory: {e:#}");
460            }
461            #[cfg(not(any(target_os = "android", target_os = "linux")))]
462            error!("unmap_guest_memory_on_fork unsupported; skipping");
463        }
464        self.mem = Some(guest_mem);
465        self.vmm_maps = Some(vmm_maps);
466        Ok(())
467    }
468
469    fn get_queue_num(&mut self) -> VhostResult<u64> {
470        Ok(self.vrings.len() as u64)
471    }
472
473    fn set_vring_num(&mut self, index: u32, num: u32) -> VhostResult<()> {
474        if index as usize >= self.vrings.len() || num == 0 || num > Queue::MAX_SIZE.into() {
475            return Err(VhostError::InvalidParam(
476                "set_vring_num: invalid index or num",
477            ));
478        }
479        self.vrings[index as usize].queue.set_size(num as u16);
480
481        Ok(())
482    }
483
484    fn set_vring_addr(
485        &mut self,
486        index: u32,
487        _flags: VhostUserVringAddrFlags,
488        descriptor: u64,
489        used: u64,
490        available: u64,
491        _log: u64,
492    ) -> VhostResult<()> {
493        if index as usize >= self.vrings.len() {
494            return Err(VhostError::InvalidParam(
495                "set_vring_addr: index out of range",
496            ));
497        }
498
499        let vmm_maps = self
500            .vmm_maps
501            .as_ref()
502            .ok_or(VhostError::InvalidParam("set_vring_addr: missing vmm_maps"))?;
503        let vring = &mut self.vrings[index as usize];
504        vring
505            .queue
506            .set_desc_table(vmm_va_to_gpa(vmm_maps, descriptor)?);
507        vring
508            .queue
509            .set_avail_ring(vmm_va_to_gpa(vmm_maps, available)?);
510        vring.queue.set_used_ring(vmm_va_to_gpa(vmm_maps, used)?);
511
512        Ok(())
513    }
514
515    fn set_vring_base(&mut self, index: u32, base: u32) -> VhostResult<()> {
516        if index as usize >= self.vrings.len() {
517            return Err(VhostError::InvalidParam(
518                "set_vring_base: index out of range",
519            ));
520        }
521
522        let vring = &mut self.vrings[index as usize];
523        vring.queue.set_next_avail(Wrapping(base as u16));
524        vring.queue.set_next_used(Wrapping(base as u16));
525
526        Ok(())
527    }
528
529    fn get_vring_base(&mut self, index: u32) -> VhostResult<VhostUserVringState> {
530        let vring = self
531            .vrings
532            .get_mut(index as usize)
533            .ok_or(VhostError::InvalidParam(
534                "get_vring_base: index out of range",
535            ))?;
536
537        // Quotation from vhost-user spec:
538        // "The back-end must [...] stop ring upon receiving VHOST_USER_GET_VRING_BASE."
539        // We only call `queue.set_ready()` when starting the queue, so if the queue is ready, that
540        // means it is started and should be stopped.
541        let vring_base = if vring.queue.ready() {
542            let queue = match self.backend.stop_queue(index as usize) {
543                Ok(q) => q,
544                Err(e) => {
545                    error!("Failed to stop queue in get_vring_base: {:#}", e);
546                    return Err(VhostError::BackendInternalError);
547                }
548            };
549
550            trace!("stopped queue {index}");
551            vring.reset();
552
553            if self.all_queues_stopped() {
554                trace!("all queues stopped; entering suspended state");
555                self.backend
556                    .enter_suspended_state()
557                    .map_err(VhostError::EnterSuspendedState)?;
558            }
559
560            queue.next_avail_to_process()
561        } else {
562            0
563        };
564
565        Ok(VhostUserVringState::new(index, vring_base.into()))
566    }
567
568    fn set_vring_kick(&mut self, index: u8, file: Option<File>) -> VhostResult<()> {
569        if index as usize >= self.vrings.len() {
570            return Err(VhostError::InvalidParam(
571                "set_vring_kick: index out of range",
572            ));
573        }
574
575        let vring = &mut self.vrings[index as usize];
576        if vring.queue.ready() {
577            error!("kick fd cannot replaced after queue is started");
578            return Err(VhostError::InvalidOperation);
579        }
580
581        let file = file.ok_or(VhostError::InvalidParam("missing file for set_vring_kick"))?;
582
583        // Remove O_NONBLOCK from kick_fd. Otherwise, uring_executor will fails when we read
584        // values via `next_val()` later.
585        // This is only required (and can only be done) on Unix platforms.
586        #[cfg(any(target_os = "android", target_os = "linux"))]
587        if let Err(e) = clear_fd_flags(file.as_raw_fd(), libc::O_NONBLOCK) {
588            error!("failed to remove O_NONBLOCK for kick fd: {}", e);
589            return Err(VhostError::InvalidParam(
590                "could not remove O_NONBLOCK from vring_kick",
591            ));
592        }
593
594        let kick_evt = Event::from(SafeDescriptor::from(file));
595
596        // Enable any virtqueue features that were negotiated (like VIRTIO_RING_F_EVENT_IDX).
597        vring.queue.ack_features(self.acked_features);
598        vring.queue.set_ready(true);
599
600        let mem = self
601            .mem
602            .as_ref()
603            .cloned()
604            .ok_or(VhostError::InvalidOperation)?;
605
606        let doorbell = vring.doorbell.clone().ok_or(VhostError::InvalidOperation)?;
607
608        let queue = match vring.queue.activate(&mem, kick_evt, doorbell) {
609            Ok(queue) => queue,
610            Err(e) => {
611                error!("failed to activate vring: {:#}", e);
612                return Err(VhostError::BackendInternalError);
613            }
614        };
615
616        if let Err(e) = self.backend.start_queue(index as usize, queue, mem) {
617            error!("Failed to start queue {}: {}", index, e);
618            return Err(VhostError::BackendInternalError);
619        }
620        trace!("started queue {index}");
621
622        Ok(())
623    }
624
625    fn set_vring_call(&mut self, index: u8, file: Option<File>) -> VhostResult<()> {
626        if index as usize >= self.vrings.len() {
627            return Err(VhostError::InvalidParam(
628                "set_vring_call: index out of range",
629            ));
630        }
631
632        let backend_req_conn = self.backend_req_connection.clone();
633        let signal_config_change_fn = Box::new(move || {
634            if let Some(frontend) = backend_req_conn.as_ref() {
635                if let Err(e) = frontend.send_config_changed() {
636                    error!("Failed to notify config change: {:#}", e);
637                }
638            } else {
639                error!("No Backend request connection found");
640            }
641        });
642
643        let file = file.ok_or(VhostError::InvalidParam("missing file for set_vring_call"))?;
644        self.vrings[index as usize].doorbell = Some(Interrupt::new_vhost_user(
645            Event::from(SafeDescriptor::from(file)),
646            signal_config_change_fn,
647        ));
648        Ok(())
649    }
650
651    fn set_vring_err(&mut self, _index: u8, _fd: Option<File>) -> VhostResult<()> {
652        // TODO
653        Ok(())
654    }
655
656    fn set_vring_enable(&mut self, index: u32, enable: bool) -> VhostResult<()> {
657        if index as usize >= self.vrings.len() {
658            return Err(VhostError::InvalidParam(
659                "set_vring_enable: index out of range",
660            ));
661        }
662
663        // This request should be handled only when VHOST_USER_F_PROTOCOL_FEATURES
664        // has been negotiated.
665        if self.acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES == 0 {
666            return Err(VhostError::InvalidOperation);
667        }
668
669        // Backend must not pass data to/from the ring until ring is enabled by
670        // VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has been disabled by
671        // VHOST_USER_SET_VRING_ENABLE with parameter 0.
672        self.vrings[index as usize].enabled = enable;
673
674        Ok(())
675    }
676
677    fn get_config(
678        &mut self,
679        offset: u32,
680        size: u32,
681        _flags: VhostUserConfigFlags,
682    ) -> VhostResult<Vec<u8>> {
683        let mut data = vec![0; size as usize];
684        self.backend.read_config(u64::from(offset), &mut data);
685        Ok(data)
686    }
687
688    fn set_config(
689        &mut self,
690        offset: u32,
691        buf: &[u8],
692        _flags: VhostUserConfigFlags,
693    ) -> VhostResult<()> {
694        self.backend.write_config(u64::from(offset), buf);
695        Ok(())
696    }
697
698    fn set_backend_req_fd(&mut self, ep: Connection<BackendReq>) {
699        let conn = VhostBackendReqConnection::new(
700            FrontendClient::new(ep),
701            self.backend.get_shared_memory_region().map(|r| r.id),
702        );
703
704        if self.backend_req_connection.is_some() {
705            warn!("Backend Request Connection already established. Overwriting");
706        }
707        self.backend_req_connection = Some(conn.clone());
708
709        self.backend.set_backend_req_connection(conn);
710    }
711
712    fn get_inflight_fd(
713        &mut self,
714        _inflight: &VhostUserInflight,
715    ) -> VhostResult<(VhostUserInflight, File)> {
716        unimplemented!("get_inflight_fd");
717    }
718
719    fn set_inflight_fd(&mut self, _inflight: &VhostUserInflight, _file: File) -> VhostResult<()> {
720        unimplemented!("set_inflight_fd");
721    }
722
723    fn get_max_mem_slots(&mut self) -> VhostResult<u64> {
724        //TODO
725        Ok(0)
726    }
727
728    fn add_mem_region(
729        &mut self,
730        _region: &VhostUserSingleMemoryRegion,
731        _fd: File,
732    ) -> VhostResult<()> {
733        //TODO
734        Ok(())
735    }
736
737    fn remove_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion) -> VhostResult<()> {
738        //TODO
739        Ok(())
740    }
741
742    fn set_device_state_fd(
743        &mut self,
744        transfer_direction: VhostUserTransferDirection,
745        migration_phase: VhostUserMigrationPhase,
746        fd: File,
747    ) -> VhostResult<Option<File>> {
748        if migration_phase != VhostUserMigrationPhase::Stopped {
749            return Err(VhostError::InvalidOperation);
750        }
751        if !self.all_queues_stopped() {
752            return Err(VhostError::InvalidOperation);
753        }
754        if self.device_state_thread.is_some() {
755            error!("must call check_device_state before starting new state transfer");
756            return Err(VhostError::InvalidOperation);
757        }
758        // `set_device_state_fd` is designed to allow snapshot/restore concurrently with other
759        // methods, but, for simplicitly, we do those operations inline and only spawn a thread to
760        // handle the serialization and data transfer (the latter which seems necessary to
761        // implement the API correctly without, e.g., deadlocking because a pipe is full).
762        match transfer_direction {
763            VhostUserTransferDirection::Save => {
764                // Snapshot the state.
765                let snapshot = DeviceRequestHandlerSnapshot {
766                    acked_features: self.acked_features,
767                    acked_protocol_features: self.acked_protocol_features.bits(),
768                    backend: self.backend.snapshot().map_err(VhostError::SnapshotError)?,
769                };
770                // Spawn thread to write the serialized bytes.
771                self.device_state_thread = Some(DeviceStateThread::Save(WorkerThread::start(
772                    "device_state_save",
773                    move |_kill_event| -> Result<(), ciborium::ser::Error<std::io::Error>> {
774                        let mut w = std::io::BufWriter::new(fd);
775                        ciborium::into_writer(&snapshot, &mut w)?;
776                        w.flush()?;
777                        Ok(())
778                    },
779                )));
780                Ok(None)
781            }
782            VhostUserTransferDirection::Load => {
783                // Spawn a thread to read the bytes and deserialize. Restore will happen in
784                // `check_device_state`.
785                self.device_state_thread = Some(DeviceStateThread::Load(WorkerThread::start(
786                    "device_state_load",
787                    move |_kill_event| ciborium::from_reader(&mut BufReader::new(fd)),
788                )));
789                Ok(None)
790            }
791        }
792    }
793
794    fn check_device_state(&mut self) -> VhostResult<()> {
795        let Some(thread) = self.device_state_thread.take() else {
796            error!("check_device_state: no active state transfer");
797            return Err(VhostError::InvalidOperation);
798        };
799        match thread {
800            DeviceStateThread::Save(worker) => {
801                worker.stop().map_err(|e| {
802                    error!("device state save thread failed: {:#}", e);
803                    VhostError::BackendInternalError
804                })?;
805                Ok(())
806            }
807            DeviceStateThread::Load(worker) => {
808                let snapshot = worker.stop().map_err(|e| {
809                    error!("device state load thread failed: {:#}", e);
810                    VhostError::BackendInternalError
811                })?;
812                self.acked_features = snapshot.acked_features;
813                self.acked_protocol_features =
814                    VhostUserProtocolFeatures::from_bits(snapshot.acked_protocol_features)
815                        .with_context(|| {
816                            format!(
817                                "unsupported bits in acked_protocol_features: {:#x}",
818                                snapshot.acked_protocol_features
819                            )
820                        })
821                        .map_err(VhostError::RestoreError)?;
822                self.backend
823                    .restore(snapshot.backend)
824                    .map_err(VhostError::RestoreError)?;
825                Ok(())
826            }
827        }
828    }
829
830    fn get_shmem_config(&mut self) -> VhostResult<(VhostUserShMemConfigHeader, Vec<u64>)> {
831        Ok(if let Some(r) = self.backend.get_shared_memory_region() {
832            (VhostUserShMemConfigHeader::new(1), vec![r.length])
833        } else {
834            (VhostUserShMemConfigHeader::new(0), Vec::new())
835        })
836    }
837}
838
/// Keeps track of the vhost-user backend request connection, i.e. the channel the backend
/// uses to send requests to the frontend (config-change notifications, shared memory
/// map/unmap requests).
///
/// Clones share the same underlying connection and mapping bookkeeping, because the state
/// lives behind an `Arc`.
#[derive(Clone)]
pub struct VhostBackendReqConnection {
    // Connection plus mapping bookkeeping, shared by all clones and by every
    // `VhostShmemMapper` returned from `shmem_mapper`.
    shared: Arc<Mutex<VhostBackendReqConnectionShared>>,
    // Shared memory region id used in mapping requests. When `None`, `shmem_mapper`
    // returns `None`.
    shmid: Option<u8>,
}
845
/// Mutable state behind `VhostBackendReqConnection`, guarded by a single mutex so that
/// requests and mapping bookkeeping are updated together.
struct VhostBackendReqConnectionShared {
    // Client end of the backend request channel.
    conn: FrontendClient,
    // Mappings added via `VhostShmemMapper::add_mapping`, keyed by offset within the shared
    // memory region; the value is the mapping length needed later to unmap it.
    mapped_regions: BTreeMap<u64 /* offset */, u64 /* size */>,
}
850
851impl VhostBackendReqConnection {
852    fn new(conn: FrontendClient, shmid: Option<u8>) -> Self {
853        Self {
854            shared: Arc::new(Mutex::new(VhostBackendReqConnectionShared {
855                conn,
856                mapped_regions: BTreeMap::new(),
857            })),
858            shmid,
859        }
860    }
861
862    /// Send `VHOST_USER_CONFIG_CHANGE_MSG` to the frontend
863    fn send_config_changed(&self) -> anyhow::Result<()> {
864        let mut shared = self.shared.lock();
865        shared
866            .conn
867            .handle_config_change()
868            .context("Could not send config change message")?;
869        Ok(())
870    }
871
872    /// Create a SharedMemoryMapper trait object using this backend request connection.
873    pub fn shmem_mapper(&self) -> Option<Box<dyn SharedMemoryMapper>> {
874        if let Some(shmid) = self.shmid {
875            Some(Box::new(VhostShmemMapper {
876                shared: self.shared.clone(),
877                shmid,
878            }))
879        } else {
880            None
881        }
882    }
883}
884
/// `SharedMemoryMapper` that forwards map/unmap requests for one shared memory region to the
/// frontend over the backend request connection.
#[derive(Clone)]
struct VhostShmemMapper {
    // State shared with the `VhostBackendReqConnection` that created this mapper.
    shared: Arc<Mutex<VhostBackendReqConnectionShared>>,
    // Shared memory region id placed in every map/unmap message.
    shmid: u8,
}
890
891impl SharedMemoryMapper for VhostShmemMapper {
892    fn add_mapping(
893        &mut self,
894        source: VmMemorySource,
895        offset: u64,
896        prot: Protection,
897        _cache: MemCacheType,
898    ) -> anyhow::Result<()> {
899        let mut shared = self.shared.lock();
900        let size = match source {
901            VmMemorySource::Vulkan {
902                descriptor,
903                handle_type,
904                memory_idx,
905                device_uuid,
906                driver_uuid,
907                size,
908            } => {
909                let msg = VhostUserGpuMapMsg::new(
910                    self.shmid,
911                    offset,
912                    size,
913                    memory_idx,
914                    handle_type,
915                    device_uuid,
916                    driver_uuid,
917                );
918                shared
919                    .conn
920                    .gpu_map(&msg, &descriptor)
921                    .context("map GPU memory")?;
922                size
923            }
924            VmMemorySource::ExternalMapping { ptr, size } => {
925                let msg = VhostUserExternalMapMsg::new(self.shmid, offset, size, ptr);
926                shared
927                    .conn
928                    .external_map(&msg)
929                    .context("create external mapping")?;
930                size
931            }
932            source => {
933                // The last two sources use the same VhostUserMMap, continue matching here
934                // on the aliased `source` above.
935                let (descriptor, fd_offset, size) = match source {
936                    VmMemorySource::Descriptor {
937                        descriptor,
938                        offset,
939                        size,
940                    } => (descriptor, offset, size),
941                    VmMemorySource::SharedMemory(shmem) => {
942                        let size = shmem.size();
943                        let descriptor = SafeDescriptor::from(shmem);
944                        (descriptor, 0, size)
945                    }
946                    _ => bail!("unsupported source"),
947                };
948                let mut flags = VhostUserMMapFlags::empty();
949                anyhow::ensure!(prot.allows(&Protection::read()), "mapping must be readable");
950                if prot.allows(&Protection::write()) {
951                    flags |= VhostUserMMapFlags::MAP_RW;
952                }
953                let msg = VhostUserMMap {
954                    shmid: self.shmid,
955                    padding: Default::default(),
956                    fd_offset,
957                    shm_offset: offset,
958                    len: size,
959                    flags,
960                };
961                shared
962                    .conn
963                    .shmem_map(&msg, &descriptor)
964                    .context("map shmem")?;
965                size
966            }
967        };
968
969        shared.mapped_regions.insert(offset, size);
970        Ok(())
971    }
972
973    fn remove_mapping(&mut self, offset: u64) -> anyhow::Result<()> {
974        let mut shared = self.shared.lock();
975        let size = shared
976            .mapped_regions
977            .remove(&offset)
978            .context("unknown offset")?;
979        let msg = VhostUserMMap {
980            shmid: self.shmid,
981            padding: Default::default(),
982            fd_offset: 0,
983            shm_offset: offset,
984            len: size,
985            flags: VhostUserMMapFlags::empty(),
986        };
987        shared
988            .conn
989            .shmem_unmap(&msg)
990            .context("unmap shmem")
991            .map(|_| ())
992    }
993}
994
/// Pairs a running queue task with the queue value it was started for.
///
/// The generic parameters are left to users of this type. Presumably `T` is the queue (or a
/// handle to it) and `U` is the task's output type; confirm at the use sites.
pub(crate) struct WorkerState<T, U> {
    // Handle to the spawned task; `U` is the value the task resolves to.
    pub(crate) queue_task: TaskHandle<U>,
    pub(crate) queue: T,
}
999
/// Errors for device operations
#[derive(Debug, ThisError)]
pub enum Error {
    /// A queue was asked to stop but no worker/queue was registered for it.
    #[error("worker not found when stopping queue")]
    WorkerNotFound,
}
1006
#[cfg(test)]
mod tests {
    use std::sync::mpsc::channel;
    use std::sync::Barrier;

    use anyhow::bail;
    use base::Event;
    use vmm_vhost::BackendServer;
    use vmm_vhost::FrontendReq;
    use zerocopy::FromBytes;
    use zerocopy::FromZeros;
    use zerocopy::Immutable;
    use zerocopy::IntoBytes;
    use zerocopy::KnownLayout;

    use super::*;
    use crate::virtio::vhost_user_frontend::VhostUserFrontend;
    use crate::virtio::DeviceType;
    use crate::virtio::VirtioDevice;

    /// Fake device config space served by `FakeBackend::read_config` and checked by the VMM
    /// side of the lifecycle test.
    #[derive(Clone, Copy, Debug, PartialEq, Eq, FromBytes, Immutable, IntoBytes, KnownLayout)]
    #[repr(C, packed(4))]
    struct FakeConfig {
        x: u32,
        y: u64,
    }

    const FAKE_CONFIG_DATA: FakeConfig = FakeConfig { x: 1, y: 2 };

    /// Minimal `VhostUserDevice` used to drive `DeviceRequestHandler` in tests.
    pub(super) struct FakeBackend {
        avail_features: u64,
        acked_features: u64,
        // One slot per queue index; `Some` while the queue is started.
        active_queues: Vec<Option<Queue>>,
        // When true, `protocol_features` also advertises BACKEND_REQ.
        allow_backend_req: bool,
        // Stored by `set_backend_req_connection`; the test uses it to send config changes.
        backend_conn: Option<VhostBackendReqConnection>,
    }

    /// Snapshot payload for `FakeBackend`; `restore` asserts it round-trips unchanged.
    #[derive(Deserialize, Serialize)]
    struct FakeBackendSnapshot {
        data: Vec<u8>,
    }

    impl FakeBackend {
        const MAX_QUEUE_NUM: usize = 16;

        /// Creates a backend that advertises only `VHOST_USER_F_PROTOCOL_FEATURES`, has no
        /// started queues, and does not allow backend requests.
        pub(super) fn new() -> Self {
            let mut active_queues = Vec::new();
            active_queues.resize_with(Self::MAX_QUEUE_NUM, Default::default);
            Self {
                avail_features: 1 << VHOST_USER_F_PROTOCOL_FEATURES,
                acked_features: 0,
                active_queues,
                allow_backend_req: false,
                backend_conn: None,
            }
        }
    }

    impl VhostUserDevice for FakeBackend {
        fn max_queue_num(&self) -> usize {
            Self::MAX_QUEUE_NUM
        }

        fn features(&self) -> u64 {
            self.avail_features
        }

        fn ack_features(&mut self, value: u64) -> anyhow::Result<()> {
            // Reject any bit that was not advertised by `features`.
            // NOTE(review): the error text says "protocol features" although this validates
            // device feature bits.
            let unrequested_features = value & !self.avail_features;
            if unrequested_features != 0 {
                bail!(
                    "invalid protocol features are given: 0x{:x}",
                    unrequested_features
                );
            }
            self.acked_features |= value;
            Ok(())
        }

        fn protocol_features(&self) -> VhostUserProtocolFeatures {
            let mut features =
                VhostUserProtocolFeatures::CONFIG | VhostUserProtocolFeatures::DEVICE_STATE;
            if self.allow_backend_req {
                features |= VhostUserProtocolFeatures::BACKEND_REQ;
            }
            features
        }

        fn read_config(&self, offset: u64, dst: &mut [u8]) {
            // NOTE(review): `copy_from_slice` panics unless `dst.len()` equals the number of
            // config bytes remaining after `offset`, so partial reads are not supported here.
            dst.copy_from_slice(&FAKE_CONFIG_DATA.as_bytes()[offset as usize..]);
        }

        fn reset(&mut self) {}

        fn start_queue(
            &mut self,
            idx: usize,
            queue: Queue,
            _mem: GuestMemory,
        ) -> anyhow::Result<()> {
            self.active_queues[idx] = Some(queue);
            Ok(())
        }

        fn stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue> {
            Ok(self.active_queues[idx]
                .take()
                .ok_or(Error::WorkerNotFound)?)
        }

        fn set_backend_req_connection(&mut self, conn: VhostBackendReqConnection) {
            self.backend_conn = Some(conn);
        }

        fn enter_suspended_state(&mut self) -> anyhow::Result<()> {
            Ok(())
        }

        fn snapshot(&mut self) -> anyhow::Result<AnySnapshot> {
            AnySnapshot::to_any(FakeBackendSnapshot {
                data: vec![1, 2, 3],
            })
            .context("failed to serialize snapshot")
        }

        fn restore(&mut self, data: AnySnapshot) -> anyhow::Result<()> {
            let snapshot: FakeBackendSnapshot =
                AnySnapshot::from_any(data).context("failed to deserialize snapshot")?;
            assert_eq!(snapshot.data, vec![1, 2, 3], "bad snapshot data");
            Ok(())
        }
    }

    #[test]
    fn test_vhost_user_lifecycle() {
        test_vhost_user_lifecycle_parameterized(false);
    }

    #[test]
    #[cfg(not(windows))] // Windows requires more complex connection setup.
    fn test_vhost_user_lifecycle_with_backend_req() {
        test_vhost_user_lifecycle_parameterized(true);
    }

    /// Runs a full frontend lifecycle against `DeviceRequestHandler`: init, config read,
    /// activate, reset, re-activate, sleep, snapshot/restore, wake, and shutdown. The VMM
    /// (frontend) side runs on a spawned thread; this thread plays the device and asserts
    /// that each request arrives in the expected order.
    fn test_vhost_user_lifecycle_parameterized(allow_backend_req: bool) {
        const QUEUES_NUM: usize = 2;

        let (client_connection, server_connection) =
            vmm_vhost::Connection::<FrontendReq>::pair().unwrap();

        let vmm_bar = Arc::new(Barrier::new(2));
        let dev_bar = vmm_bar.clone();

        let (ready_tx, ready_rx) = channel();
        let (shutdown_tx, shutdown_rx) = channel();
        let (vm_evt_wrtube, _vm_evt_rdtube) = base::Tube::directional_pair().unwrap();

        std::thread::spawn(move || {
            // VMM side
            ready_rx.recv().unwrap(); // Ensure the device is ready.

            let mut vmm_device = VhostUserFrontend::new(
                DeviceType::Console,
                0,
                client_connection,
                vm_evt_wrtube,
                None,
                None,
            )
            .unwrap();

            println!("read_config");
            let mut config = FakeConfig::new_zeroed();
            vmm_device.read_config(0, config.as_mut_bytes());
            // Check if the obtained config data is correct.
            assert_eq!(config, FAKE_CONFIG_DATA);

            let activate = |vmm_device: &mut VhostUserFrontend| {
                let mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
                let interrupt = Interrupt::new_for_test_with_msix();

                let mut queues = BTreeMap::new();
                for idx in 0..QUEUES_NUM {
                    let mut queue = QueueConfig::new(0x10, 0);
                    queue.set_ready(true);
                    let queue = queue
                        .activate(&mem, Event::new().unwrap(), interrupt.clone())
                        .expect("QueueConfig::activate");
                    queues.insert(idx, queue);
                }

                println!("activate");
                vmm_device.activate(mem, interrupt, queues).unwrap();
            };

            activate(&mut vmm_device);

            println!("reset");
            let reset_result = vmm_device.reset();
            assert!(
                reset_result.is_ok(),
                "reset failed: {:#}",
                reset_result.unwrap_err()
            );

            activate(&mut vmm_device);

            println!("virtio_sleep");
            let queues = vmm_device
                .virtio_sleep()
                .unwrap()
                .expect("virtio_sleep unexpectedly returned None");

            println!("virtio_snapshot");
            let snapshot = vmm_device
                .virtio_snapshot()
                .expect("virtio_snapshot failed");
            println!("virtio_restore");
            vmm_device
                .virtio_restore(snapshot)
                .expect("virtio_restore failed");

            println!("virtio_wake");
            let mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
            let interrupt = Interrupt::new_for_test_with_msix();
            vmm_device
                .virtio_wake(Some((mem, interrupt, queues)))
                .unwrap();

            println!("wait for shutdown signal");
            shutdown_rx.recv().unwrap();

            // The VMM side is supposed to stop before the device side.
            println!("drop");
            drop(vmm_device);

            vmm_bar.wait();
        });

        // Device side
        // Each `handle_request` below asserts the next request type, so the sequence must
        // mirror, in order, the requests produced by the VMM thread's frontend calls.
        let mut handler = DeviceRequestHandler::new(FakeBackend::new());
        handler.as_mut().allow_backend_req = allow_backend_req;

        // Notify listener is ready.
        ready_tx.send(()).unwrap();

        let mut req_handler = BackendServer::new(server_connection, handler);

        // VhostUserFrontend::new()
        handle_request(&mut req_handler, FrontendReq::SET_OWNER).unwrap();
        handle_request(&mut req_handler, FrontendReq::GET_FEATURES).unwrap();
        handle_request(&mut req_handler, FrontendReq::GET_PROTOCOL_FEATURES).unwrap();
        handle_request(&mut req_handler, FrontendReq::SET_PROTOCOL_FEATURES).unwrap();
        if allow_backend_req {
            handle_request(&mut req_handler, FrontendReq::SET_BACKEND_REQ_FD).unwrap();
        }

        // VhostUserFrontend::read_config()
        handle_request(&mut req_handler, FrontendReq::GET_CONFIG).unwrap();

        // VhostUserFrontend::activate()
        handle_request(&mut req_handler, FrontendReq::SET_FEATURES).unwrap();
        handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();
        for _ in 0..QUEUES_NUM {
            handle_request(&mut req_handler, FrontendReq::SET_VRING_NUM).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ADDR).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_BASE).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_CALL).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_KICK).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
        }

        // VhostUserFrontend::reset()
        for _ in 0..QUEUES_NUM {
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
            handle_request(&mut req_handler, FrontendReq::GET_VRING_BASE).unwrap();
        }

        // VhostUserFrontend::activate()
        handle_request(&mut req_handler, FrontendReq::SET_FEATURES).unwrap();
        handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();
        for _ in 0..QUEUES_NUM {
            handle_request(&mut req_handler, FrontendReq::SET_VRING_NUM).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ADDR).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_BASE).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_CALL).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_KICK).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
        }

        if allow_backend_req {
            // Make sure the connection still works even after reset/reactivate.
            req_handler
                .as_ref()
                .as_ref()
                .backend_conn
                .as_ref()
                .expect("backend_conn missing")
                .send_config_changed()
                .expect("send_config_changed failed");
        }

        // VhostUserFrontend::virtio_sleep()
        for _ in 0..QUEUES_NUM {
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
            handle_request(&mut req_handler, FrontendReq::GET_VRING_BASE).unwrap();
        }

        // VhostUserFrontend::virtio_snapshot()
        handle_request(&mut req_handler, FrontendReq::SET_DEVICE_STATE_FD).unwrap();
        handle_request(&mut req_handler, FrontendReq::CHECK_DEVICE_STATE).unwrap();
        // VhostUserFrontend::virtio_restore()
        handle_request(&mut req_handler, FrontendReq::SET_DEVICE_STATE_FD).unwrap();
        handle_request(&mut req_handler, FrontendReq::CHECK_DEVICE_STATE).unwrap();

        // VhostUserFrontend::virtio_wake()
        handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();
        for _ in 0..QUEUES_NUM {
            handle_request(&mut req_handler, FrontendReq::SET_VRING_NUM).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ADDR).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_BASE).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_CALL).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_KICK).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
        }

        if allow_backend_req {
            // Make sure the connection still works even after sleep/wake.
            req_handler
                .as_ref()
                .as_ref()
                .backend_conn
                .as_ref()
                .expect("backend_conn missing")
                .send_config_changed()
                .expect("send_config_changed failed");
        }

        // Ask the client to shut down, then wait for it to finish.
        shutdown_tx.send(()).unwrap();
        dev_bar.wait();

        // Verify recv_header fails with `ClientExit` after the client has disconnected.
        match req_handler.recv_header() {
            Err(VhostError::ClientExit) => (),
            r => panic!("expected Err(ClientExit) but got {r:?}"),
        }
    }

    /// Receives one request from the frontend, asserts its type is `expected_message_type`,
    /// and dispatches it to the backend.
    fn handle_request<S: vmm_vhost::Backend>(
        handler: &mut BackendServer<S>,
        expected_message_type: FrontendReq,
    ) -> Result<(), VhostError> {
        let (hdr, files) = handler.recv_header()?;
        assert_eq!(hdr.get_code(), Ok(expected_message_type));
        handler.process_message(hdr, files)
    }
}