devices/virtio/vhost_user_backend/handler.rs

1// Copyright 2021 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Library for implementing vhost-user device executables.
6//!
7//! This crate provides
8//! * `VhostUserDevice` trait, which is a collection of methods to handle vhost-user requests, and
9//! * `DeviceRequestHandler` struct, which makes a connection to a VMM and starts an event loop.
10//!
11//! They are expected to be used as follows:
12//!
13//! 1. Define a struct and implement `VhostUserDevice` for it.
14//! 2. Create a `DeviceRequestHandler` with the backend struct.
15//! 3. Drive the `DeviceRequestHandler::run` async fn with an executor.
16//!
17//! ```ignore
18//! struct MyBackend {
19//!   /* fields */
20//! }
21//!
22//! impl VhostUserDevice for MyBackend {
23//!   /* implement methods */
24//! }
25//!
26//! fn main() -> Result<(), Box<dyn Error>> {
27//!   let backend = MyBackend { /* initialize fields */ };
28//!   let handler = DeviceRequestHandler::new(backend);
//!   let socket = std::path::Path::new("/path/to/socket");
30//!   let ex = cros_async::Executor::new()?;
31//!
32//!   if let Err(e) = ex.run_until(handler.run(socket, &ex)) {
33//!     eprintln!("error happened: {}", e);
34//!   }
35//!   Ok(())
36//! }
37//! ```
38// Implementation note:
39// This code lets us take advantage of the vmm_vhost low level implementation of the vhost user
40// protocol. DeviceRequestHandler implements the Backend trait from vmm_vhost, and includes some
41// common code for setting up guest memory and managing partially configured vrings.
42// DeviceRequestHandler::run watches the vhost-user socket and then calls handle_request() when it
43// becomes readable. handle_request() reads and parses the message and then calls one of the
44// Backend trait methods. These dispatch back to the supplied VhostUserDevice implementation (this
45// is what our devices implement).
46
47pub(super) mod sys;
48
49use std::collections::BTreeMap;
50use std::convert::From;
51use std::fs::File;
52use std::io::BufReader;
53use std::io::Write;
54use std::num::Wrapping;
55#[cfg(any(target_os = "android", target_os = "linux"))]
56use std::os::unix::io::AsRawFd;
57use std::sync::Arc;
58
59use anyhow::bail;
60use anyhow::Context;
61#[cfg(any(target_os = "android", target_os = "linux"))]
62use base::clear_fd_flags;
63use base::error;
64use base::trace;
65use base::warn;
66use base::Event;
67use base::Protection;
68use base::SafeDescriptor;
69use base::SharedMemory;
70use base::WorkerThread;
71use cros_async::TaskHandle;
72use hypervisor::MemCacheType;
73use serde::Deserialize;
74use serde::Serialize;
75use snapshot::AnySnapshot;
76use sync::Mutex;
77use thiserror::Error as ThisError;
78use vm_control::VmMemorySource;
79use vm_memory::GuestAddress;
80use vm_memory::GuestMemory;
81use vm_memory::MemoryRegion;
82use vmm_vhost::message::VhostSharedMemoryRegion;
83use vmm_vhost::message::VhostUserConfigFlags;
84use vmm_vhost::message::VhostUserExternalMapMsg;
85use vmm_vhost::message::VhostUserGpuMapMsg;
86use vmm_vhost::message::VhostUserInflight;
87use vmm_vhost::message::VhostUserMemoryRegion;
88use vmm_vhost::message::VhostUserMigrationPhase;
89use vmm_vhost::message::VhostUserProtocolFeatures;
90use vmm_vhost::message::VhostUserShmemMapMsg;
91use vmm_vhost::message::VhostUserShmemMapMsgFlags;
92use vmm_vhost::message::VhostUserShmemUnmapMsg;
93use vmm_vhost::message::VhostUserSingleMemoryRegion;
94use vmm_vhost::message::VhostUserTransferDirection;
95use vmm_vhost::message::VhostUserVringAddrFlags;
96use vmm_vhost::message::VhostUserVringState;
97use vmm_vhost::BackendReq;
98use vmm_vhost::Connection;
99use vmm_vhost::Error as VhostError;
100use vmm_vhost::Frontend;
101use vmm_vhost::FrontendClient;
102use vmm_vhost::Result as VhostResult;
103use vmm_vhost::VHOST_USER_F_PROTOCOL_FEATURES;
104
105use crate::virtio::Interrupt;
106use crate::virtio::Queue;
107use crate::virtio::QueueConfig;
108use crate::virtio::SharedMemoryMapper;
109use crate::virtio::SharedMemoryRegion;
110
/// Keeps a mapping from the VMM's virtual addresses to guest physical addresses.
///
/// Used to translate addresses in messages from the VMM (e.g. the descriptor/avail/used ring
/// addresses of `VHOST_USER_SET_VRING_ADDR`) into guest physical addresses.
#[derive(Default)]
pub struct MappingInfo {
    /// Start of the region in the VMM process's virtual address space (`user_addr` of the
    /// memory table entry).
    pub vmm_addr: u64,
    /// Guest physical address that `vmm_addr` corresponds to.
    pub guest_phys: u64,
    /// Length of the region (`memory_size` of the memory table entry; presumably bytes).
    pub size: u64,
}
119
120pub fn vmm_va_to_gpa(maps: &[MappingInfo], vmm_va: u64) -> VhostResult<GuestAddress> {
121    for map in maps {
122        if vmm_va >= map.vmm_addr && vmm_va < map.vmm_addr + map.size {
123            return Ok(GuestAddress(vmm_va - map.vmm_addr + map.guest_phys));
124        }
125    }
126    Err(VhostError::InvalidMessage)
127}
128
/// Trait for vhost-user devices. Analogous to the `VirtioDevice` trait.
///
/// In contrast with [`vmm_vhost::Backend`], which closely matches the vhost-user spec, this trait
/// is designed to follow crosvm conventions for implementing devices. `DeviceRequestHandler`
/// translates vhost-user requests into calls on this trait.
pub trait VhostUserDevice {
    /// The maximum number of queues that this backend can manage.
    ///
    /// `DeviceRequestHandler` allocates one vring per queue and reports this value for
    /// `VHOST_USER_GET_QUEUE_NUM`.
    fn max_queue_num(&self) -> usize;

    /// The set of feature bits that this backend supports.
    ///
    /// `VHOST_USER_SET_FEATURES` is rejected if the frontend acknowledges any bit outside this
    /// set.
    fn features(&self) -> u64;

    /// Acknowledges that this set of features should be enabled.
    ///
    /// Implementations only need to handle device-specific feature bits; the `DeviceRequestHandler`
    /// framework will manage generic vhost and vring features.
    ///
    /// `DeviceRequestHandler` checks for valid features before calling this function, so the
    /// features in `value` will always be a subset of those advertised by `features()`.
    ///
    /// Returning an error makes the `VHOST_USER_SET_FEATURES` request fail.
    fn ack_features(&mut self, _value: u64) -> anyhow::Result<()> {
        Ok(())
    }

    /// The set of protocol feature bits that this backend supports.
    ///
    /// Protocol features acknowledged by the frontend are masked with this set.
    fn protocol_features(&self) -> VhostUserProtocolFeatures;

    /// Reads this device's configuration space at `offset` into `dst`.
    ///
    /// Called for `VHOST_USER_GET_CONFIG`; `dst` is zero-filled and has the requested length.
    fn read_config(&self, offset: u64, dst: &mut [u8]);

    /// Writes `data` to this device's configuration space at `offset`.
    ///
    /// Called for `VHOST_USER_SET_CONFIG`. The default implementation ignores the write.
    fn write_config(&self, _offset: u64, _data: &[u8]) {}

    /// Indicates that the backend should start processing requests for virtio queue number `idx`.
    /// This method must not block the current thread so device backends should either spawn an
    /// async task or another thread to handle messages from the Queue.
    ///
    /// `queue` is already activated with the kick event and doorbell supplied by the frontend;
    /// `mem` is the guest memory from the most recent memory table.
    fn start_queue(&mut self, idx: usize, queue: Queue, mem: GuestMemory) -> anyhow::Result<()>;

    /// Indicates that the backend should stop processing requests for virtio queue number `idx`.
    /// This method should return the queue passed to `start_queue` for the corresponding `idx`.
    /// This method will only be called for queues that were previously started by `start_queue`.
    ///
    /// Called for `VHOST_USER_GET_VRING_BASE` and when the handler is dropped.
    fn stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue>;

    /// Resets the vhost-user backend. Called for `VHOST_USER_RESET_OWNER`.
    fn reset(&mut self);

    /// Returns the device's shared memory region if present.
    ///
    /// The default implementation reports no shared memory region.
    fn get_shared_memory_region(&self) -> Option<SharedMemoryRegion> {
        None
    }

    /// Accepts a [`VhostBackendReqConnection`] that the backend can use to send
    /// backend-to-frontend requests.
    ///
    /// This method will be called when `VhostUserProtocolFeatures::BACKEND_REQ` is
    /// negotiated. The default implementation discards the connection.
    fn set_backend_req_connection(&mut self, _conn: VhostBackendReqConnection) {}

    /// Enter the "suspended device state" described in the vhost-user spec. See the spec for
    /// requirements.
    ///
    /// One reasonably foolproof way to satisfy the requirements is to stop all worker threads.
    ///
    /// Called after a `stop_queue` call if there are no running queues left. Also called soon
    /// after device creation to ensure the device is acting suspended immediately on construction.
    ///
    /// The next `start_queue` call implicitly exits the "suspended device state".
    ///
    /// * Ok(())    => device successfully suspended
    /// * Err(_)    => unrecoverable error
    fn enter_suspended_state(&mut self) -> anyhow::Result<()>;

    /// Snapshot device and return serialized state.
    ///
    /// Called when the frontend starts a `Save` device state transfer.
    fn snapshot(&mut self) -> anyhow::Result<AnySnapshot>;

    /// Restore device state from a snapshot.
    ///
    /// Called from `VHOST_USER_CHECK_DEVICE_STATE` after a `Load` device state transfer has been
    /// read and deserialized.
    fn restore(&mut self, data: AnySnapshot) -> anyhow::Result<()>;

    /// Whether guest memory should be unmapped in forked processes.
    ///
    /// This is intended for use in combination with --protected-vm, where the guest memory can be
    /// dangerous to access. Some systems, e.g. Android, have tools that fork processes and examine
    /// their memory. This flag effectively hides the guest memory from those tools.
    ///
    /// On Linux/Android the handler applies `MADV_DONTFORK` to guest memory when this returns
    /// `true`; on other platforms it only logs an error.
    ///
    /// Not compatible with sandboxing.
    fn unmap_guest_memory_on_fork(&self) -> bool {
        false
    }
}
216
/// A virtio ring entry.
///
/// Per-queue state accumulated from vhost-user messages (size, addresses, base, doorbell); the
/// queue is activated from it when the kick fd arrives.
struct Vring {
    // The queue config. This doesn't get mutated by the queue workers.
    queue: QueueConfig,
    // Interrupt installed by VHOST_USER_SET_VRING_CALL. The queue cannot be started without it.
    doorbell: Option<Interrupt>,
    // Ring enable state, initialized by SET_FEATURES and updated by SET_VRING_ENABLE.
    enabled: bool,
}
224
225impl Vring {
226    fn new(max_size: u16, features: u64) -> Self {
227        Self {
228            queue: QueueConfig::new(max_size, features),
229            doorbell: None,
230            enabled: false,
231        }
232    }
233
234    fn reset(&mut self) {
235        self.queue.reset();
236        self.doorbell = None;
237        self.enabled = false;
238    }
239}
240
/// Ops for running vhost-user over a stream (i.e. the regular protocol).
///
/// A unit struct: it carries no state and only namespaces the associated helper functions.
pub(super) struct VhostUserRegularOps;
243
244impl VhostUserRegularOps {
245    pub fn set_mem_table(
246        contexts: &[VhostUserMemoryRegion],
247        files: Vec<File>,
248    ) -> VhostResult<(GuestMemory, Vec<MappingInfo>)> {
249        if files.len() != contexts.len() {
250            return Err(VhostError::InvalidParam(
251                "number of files & contexts was not equal",
252            ));
253        }
254
255        let mut regions = Vec::with_capacity(files.len());
256        for (region, file) in contexts.iter().zip(files.into_iter()) {
257            let region = MemoryRegion::new_from_shm(
258                region.memory_size,
259                GuestAddress(region.guest_phys_addr),
260                region.mmap_offset,
261                Arc::new(
262                    SharedMemory::from_safe_descriptor(
263                        SafeDescriptor::from(file),
264                        region.memory_size,
265                    )
266                    .unwrap(),
267                ),
268            )
269            .map_err(|e| {
270                error!("failed to create a memory region: {}", e);
271                VhostError::InvalidOperation
272            })?;
273            regions.push(region);
274        }
275        let guest_mem = GuestMemory::from_regions(regions).map_err(|e| {
276            error!("failed to create guest memory: {}", e);
277            VhostError::InvalidOperation
278        })?;
279
280        let vmm_maps = contexts
281            .iter()
282            .map(|region| MappingInfo {
283                vmm_addr: region.user_addr,
284                guest_phys: region.guest_phys_addr,
285                size: region.memory_size,
286            })
287            .collect();
288        Ok((guest_mem, vmm_maps))
289    }
290}
291
/// An adapter that implements `vmm_vhost::Backend` for any type implementing `VhostUserDevice`.
///
/// It tracks the state negotiated over the vhost-user connection (ownership, features, guest
/// memory, per-queue configuration) and forwards device-specific work to `backend`.
pub struct DeviceRequestHandler<T: VhostUserDevice> {
    // One entry per queue, `backend.max_queue_num()` in total.
    vrings: Vec<Vring>,
    // Set by VHOST_USER_SET_OWNER, cleared by VHOST_USER_RESET_OWNER.
    owned: bool,
    // VMM virtual address -> guest physical address mappings from the last SET_MEM_TABLE.
    vmm_maps: Option<Vec<MappingInfo>>,
    // Guest memory from the last SET_MEM_TABLE; `None` until the frontend sends one.
    mem: Option<GuestMemory>,
    // Virtio feature bits acknowledged by the frontend.
    acked_features: u64,
    // Protocol features acknowledged by the frontend, masked to those the backend supports.
    acked_protocol_features: VhostUserProtocolFeatures,
    // The device implementation that requests are dispatched to.
    backend: T,
    // Connection for backend-initiated requests, if the frontend provided one.
    backend_req_connection: Option<VhostBackendReqConnection>,
    // Thread processing active device state FD.
    device_state_thread: Option<DeviceStateThread>,
}
305
/// Worker thread for a device state transfer started by `set_device_state_fd`.
///
/// Its result is collected (and, for loads, applied) by `check_device_state`.
enum DeviceStateThread {
    /// Serializes an already-taken snapshot with CBOR and writes it to the transfer fd.
    Save(WorkerThread<Result<(), ciborium::ser::Error<std::io::Error>>>),
    /// Reads a CBOR-encoded snapshot from the transfer fd and deserializes it.
    Load(WorkerThread<Result<DeviceRequestHandlerSnapshot, ciborium::de::Error<std::io::Error>>>),
}
310
/// Serialized state of a `DeviceRequestHandler`, exchanged over the device state fd.
#[derive(Serialize, Deserialize)]
pub struct DeviceRequestHandlerSnapshot {
    // Negotiated virtio feature bits.
    acked_features: u64,
    // Raw bits of the negotiated `VhostUserProtocolFeatures`.
    acked_protocol_features: u64,
    // Device-specific state produced by `VhostUserDevice::snapshot`.
    backend: AnySnapshot,
}
317
318impl<T: VhostUserDevice> DeviceRequestHandler<T> {
319    /// Creates a vhost-user handler instance for `backend`.
320    pub(crate) fn new(mut backend: T) -> Self {
321        let mut vrings = Vec::with_capacity(backend.max_queue_num());
322        for _ in 0..backend.max_queue_num() {
323            vrings.push(Vring::new(Queue::MAX_SIZE, backend.features()));
324        }
325
326        // VhostUserDevice implementations must support `enter_suspended_state()`.
327        // Call it on startup to ensure it works and to initialize the device in a suspended state.
328        backend
329            .enter_suspended_state()
330            .expect("enter_suspended_state failed on device init");
331
332        DeviceRequestHandler {
333            vrings,
334            owned: false,
335            vmm_maps: None,
336            mem: None,
337            acked_features: 0,
338            acked_protocol_features: VhostUserProtocolFeatures::empty(),
339            backend,
340            backend_req_connection: None,
341            device_state_thread: None,
342        }
343    }
344
345    /// Check if all queues are stopped.
346    ///
347    /// The device can be suspended with `enter_suspended_state()` only when all queues are stopped.
348    fn all_queues_stopped(&self) -> bool {
349        self.vrings.iter().all(|vring| !vring.queue.ready())
350    }
351}
352
353impl<T: VhostUserDevice> Drop for DeviceRequestHandler<T> {
354    fn drop(&mut self) {
355        for (index, vring) in self.vrings.iter().enumerate() {
356            if vring.queue.ready() {
357                if let Err(e) = self.backend.stop_queue(index) {
358                    error!("Failed to stop queue {} during drop: {:#}", index, e);
359                }
360            }
361        }
362    }
363}
364
365impl<T: VhostUserDevice> AsRef<T> for DeviceRequestHandler<T> {
366    fn as_ref(&self) -> &T {
367        &self.backend
368    }
369}
370
371impl<T: VhostUserDevice> AsMut<T> for DeviceRequestHandler<T> {
372    fn as_mut(&mut self) -> &mut T {
373        &mut self.backend
374    }
375}
376
377impl<T: VhostUserDevice> vmm_vhost::Backend for DeviceRequestHandler<T> {
378    fn set_owner(&mut self) -> VhostResult<()> {
379        if self.owned {
380            return Err(VhostError::InvalidOperation);
381        }
382        self.owned = true;
383        Ok(())
384    }
385
386    fn reset_owner(&mut self) -> VhostResult<()> {
387        self.owned = false;
388        self.acked_features = 0;
389        self.backend.reset();
390        Ok(())
391    }
392
393    fn get_features(&mut self) -> VhostResult<u64> {
394        let features = self.backend.features();
395        Ok(features)
396    }
397
398    fn set_features(&mut self, features: u64) -> VhostResult<()> {
399        if !self.owned {
400            return Err(VhostError::InvalidOperation);
401        }
402
403        let unexpected_features = features & !self.backend.features();
404        if unexpected_features != 0 {
405            error!("unexpected set_features {:#x}", unexpected_features);
406            return Err(VhostError::InvalidParam("unexpected set_features"));
407        }
408
409        if let Err(e) = self.backend.ack_features(features) {
410            error!("failed to acknowledge features 0x{:x}: {}", features, e);
411            return Err(VhostError::InvalidOperation);
412        }
413
414        self.acked_features |= features;
415
416        // If VHOST_USER_F_PROTOCOL_FEATURES has not been negotiated, the ring is initialized in an
417        // enabled state.
418        // If VHOST_USER_F_PROTOCOL_FEATURES has been negotiated, the ring is initialized in a
419        // disabled state.
420        // Client must not pass data to/from the backend until ring is enabled by
421        // VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has been disabled by
422        // VHOST_USER_SET_VRING_ENABLE with parameter 0.
423        let vring_enabled = self.acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES != 0;
424        for v in &mut self.vrings {
425            v.enabled = vring_enabled;
426        }
427
428        Ok(())
429    }
430
431    fn get_protocol_features(&mut self) -> VhostResult<VhostUserProtocolFeatures> {
432        Ok(self.backend.protocol_features())
433    }
434
435    fn set_protocol_features(&mut self, features: u64) -> VhostResult<()> {
436        let features = match VhostUserProtocolFeatures::from_bits(features) {
437            Some(proto_features) => proto_features,
438            None => {
439                error!(
440                    "unsupported bits in VHOST_USER_SET_PROTOCOL_FEATURES: {:#x}",
441                    features
442                );
443                return Err(VhostError::InvalidOperation);
444            }
445        };
446        let supported = self.backend.protocol_features();
447        self.acked_protocol_features = features & supported;
448        Ok(())
449    }
450
451    fn set_mem_table(
452        &mut self,
453        contexts: &[VhostUserMemoryRegion],
454        files: Vec<File>,
455    ) -> VhostResult<()> {
456        let (guest_mem, vmm_maps) = VhostUserRegularOps::set_mem_table(contexts, files)?;
457        if self.backend.unmap_guest_memory_on_fork() {
458            #[cfg(any(target_os = "android", target_os = "linux"))]
459            if let Err(e) = guest_mem.use_dontfork() {
460                error!("failed to set MADV_DONTFORK on guest memory: {e:#}");
461            }
462            #[cfg(not(any(target_os = "android", target_os = "linux")))]
463            error!("unmap_guest_memory_on_fork unsupported; skipping");
464        }
465        self.mem = Some(guest_mem);
466        self.vmm_maps = Some(vmm_maps);
467        Ok(())
468    }
469
470    fn get_queue_num(&mut self) -> VhostResult<u64> {
471        Ok(self.vrings.len() as u64)
472    }
473
474    fn set_vring_num(&mut self, index: u32, num: u32) -> VhostResult<()> {
475        if index as usize >= self.vrings.len() || num == 0 || num > Queue::MAX_SIZE.into() {
476            return Err(VhostError::InvalidParam(
477                "set_vring_num: invalid index or num",
478            ));
479        }
480        self.vrings[index as usize].queue.set_size(num as u16);
481
482        Ok(())
483    }
484
485    fn set_vring_addr(
486        &mut self,
487        index: u32,
488        _flags: VhostUserVringAddrFlags,
489        descriptor: u64,
490        used: u64,
491        available: u64,
492        _log: u64,
493    ) -> VhostResult<()> {
494        if index as usize >= self.vrings.len() {
495            return Err(VhostError::InvalidParam(
496                "set_vring_addr: index out of range",
497            ));
498        }
499
500        let vmm_maps = self
501            .vmm_maps
502            .as_ref()
503            .ok_or(VhostError::InvalidParam("set_vring_addr: missing vmm_maps"))?;
504        let vring = &mut self.vrings[index as usize];
505        vring
506            .queue
507            .set_desc_table(vmm_va_to_gpa(vmm_maps, descriptor)?);
508        vring
509            .queue
510            .set_avail_ring(vmm_va_to_gpa(vmm_maps, available)?);
511        vring.queue.set_used_ring(vmm_va_to_gpa(vmm_maps, used)?);
512
513        Ok(())
514    }
515
516    fn set_vring_base(&mut self, index: u32, base: u32) -> VhostResult<()> {
517        if index as usize >= self.vrings.len() {
518            return Err(VhostError::InvalidParam(
519                "set_vring_base: index out of range",
520            ));
521        }
522
523        let vring = &mut self.vrings[index as usize];
524        vring.queue.set_next_avail(Wrapping(base as u16));
525        vring.queue.set_next_used(Wrapping(base as u16));
526
527        Ok(())
528    }
529
530    fn get_vring_base(&mut self, index: u32) -> VhostResult<VhostUserVringState> {
531        let vring = self
532            .vrings
533            .get_mut(index as usize)
534            .ok_or(VhostError::InvalidParam(
535                "get_vring_base: index out of range",
536            ))?;
537
538        // Quotation from vhost-user spec:
539        // "The back-end must [...] stop ring upon receiving VHOST_USER_GET_VRING_BASE."
540        // We only call `queue.set_ready()` when starting the queue, so if the queue is ready, that
541        // means it is started and should be stopped.
542        let vring_base = if vring.queue.ready() {
543            let queue = match self.backend.stop_queue(index as usize) {
544                Ok(q) => q,
545                Err(e) => {
546                    error!("Failed to stop queue in get_vring_base: {:#}", e);
547                    return Err(VhostError::BackendInternalError);
548                }
549            };
550
551            trace!("stopped queue {index}");
552            vring.reset();
553
554            if self.all_queues_stopped() {
555                trace!("all queues stopped; entering suspended state");
556                self.backend
557                    .enter_suspended_state()
558                    .map_err(VhostError::EnterSuspendedState)?;
559            }
560
561            queue.next_avail_to_process()
562        } else {
563            0
564        };
565
566        Ok(VhostUserVringState::new(index, vring_base.into()))
567    }
568
569    fn set_vring_kick(&mut self, index: u8, file: Option<File>) -> VhostResult<()> {
570        if index as usize >= self.vrings.len() {
571            return Err(VhostError::InvalidParam(
572                "set_vring_kick: index out of range",
573            ));
574        }
575
576        let vring = &mut self.vrings[index as usize];
577        if vring.queue.ready() {
578            error!("kick fd cannot replaced after queue is started");
579            return Err(VhostError::InvalidOperation);
580        }
581
582        let file = file.ok_or(VhostError::InvalidParam("missing file for set_vring_kick"))?;
583
584        // Remove O_NONBLOCK from kick_fd. Otherwise, uring_executor will fails when we read
585        // values via `next_val()` later.
586        // This is only required (and can only be done) on Unix platforms.
587        #[cfg(any(target_os = "android", target_os = "linux"))]
588        if let Err(e) = clear_fd_flags(file.as_raw_fd(), libc::O_NONBLOCK) {
589            error!("failed to remove O_NONBLOCK for kick fd: {}", e);
590            return Err(VhostError::InvalidParam(
591                "could not remove O_NONBLOCK from vring_kick",
592            ));
593        }
594
595        let kick_evt = Event::from(SafeDescriptor::from(file));
596
597        // Enable any virtqueue features that were negotiated (like VIRTIO_RING_F_EVENT_IDX).
598        vring.queue.ack_features(self.acked_features);
599        vring.queue.set_ready(true);
600
601        let mem = self
602            .mem
603            .as_ref()
604            .cloned()
605            .ok_or(VhostError::InvalidOperation)?;
606
607        let doorbell = vring.doorbell.clone().ok_or(VhostError::InvalidOperation)?;
608
609        let queue = match vring.queue.activate(&mem, kick_evt, doorbell) {
610            Ok(queue) => queue,
611            Err(e) => {
612                error!("failed to activate vring: {:#}", e);
613                return Err(VhostError::BackendInternalError);
614            }
615        };
616
617        if let Err(e) = self.backend.start_queue(index as usize, queue, mem) {
618            error!("Failed to start queue {}: {}", index, e);
619            return Err(VhostError::BackendInternalError);
620        }
621        trace!("started queue {index}");
622
623        Ok(())
624    }
625
626    fn set_vring_call(&mut self, index: u8, file: Option<File>) -> VhostResult<()> {
627        if index as usize >= self.vrings.len() {
628            return Err(VhostError::InvalidParam(
629                "set_vring_call: index out of range",
630            ));
631        }
632
633        let backend_req_conn = self.backend_req_connection.clone();
634        let signal_config_change_fn = Box::new(move || {
635            if let Some(frontend) = backend_req_conn.as_ref() {
636                if let Err(e) = frontend.send_config_changed() {
637                    error!("Failed to notify config change: {:#}", e);
638                }
639            } else {
640                error!("No Backend request connection found");
641            }
642        });
643
644        let file = file.ok_or(VhostError::InvalidParam("missing file for set_vring_call"))?;
645        self.vrings[index as usize].doorbell = Some(Interrupt::new_vhost_user(
646            Event::from(SafeDescriptor::from(file)),
647            signal_config_change_fn,
648        ));
649        Ok(())
650    }
651
652    fn set_vring_err(&mut self, _index: u8, _fd: Option<File>) -> VhostResult<()> {
653        // TODO
654        Ok(())
655    }
656
657    fn set_vring_enable(&mut self, index: u32, enable: bool) -> VhostResult<()> {
658        if index as usize >= self.vrings.len() {
659            return Err(VhostError::InvalidParam(
660                "set_vring_enable: index out of range",
661            ));
662        }
663
664        // This request should be handled only when VHOST_USER_F_PROTOCOL_FEATURES
665        // has been negotiated.
666        if self.acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES == 0 {
667            return Err(VhostError::InvalidOperation);
668        }
669
670        // Backend must not pass data to/from the ring until ring is enabled by
671        // VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has been disabled by
672        // VHOST_USER_SET_VRING_ENABLE with parameter 0.
673        self.vrings[index as usize].enabled = enable;
674
675        Ok(())
676    }
677
678    fn get_config(
679        &mut self,
680        offset: u32,
681        size: u32,
682        _flags: VhostUserConfigFlags,
683    ) -> VhostResult<Vec<u8>> {
684        let mut data = vec![0; size as usize];
685        self.backend.read_config(u64::from(offset), &mut data);
686        Ok(data)
687    }
688
689    fn set_config(
690        &mut self,
691        offset: u32,
692        buf: &[u8],
693        _flags: VhostUserConfigFlags,
694    ) -> VhostResult<()> {
695        self.backend.write_config(u64::from(offset), buf);
696        Ok(())
697    }
698
699    fn set_backend_req_fd(&mut self, ep: Connection<BackendReq>) {
700        let conn = VhostBackendReqConnection::new(
701            FrontendClient::new(ep),
702            self.backend.get_shared_memory_region().map(|r| r.id),
703        );
704
705        if self.backend_req_connection.is_some() {
706            warn!("Backend Request Connection already established. Overwriting");
707        }
708        self.backend_req_connection = Some(conn.clone());
709
710        self.backend.set_backend_req_connection(conn);
711    }
712
713    fn get_inflight_fd(
714        &mut self,
715        _inflight: &VhostUserInflight,
716    ) -> VhostResult<(VhostUserInflight, File)> {
717        unimplemented!("get_inflight_fd");
718    }
719
720    fn set_inflight_fd(&mut self, _inflight: &VhostUserInflight, _file: File) -> VhostResult<()> {
721        unimplemented!("set_inflight_fd");
722    }
723
724    fn get_max_mem_slots(&mut self) -> VhostResult<u64> {
725        //TODO
726        Ok(0)
727    }
728
729    fn add_mem_region(
730        &mut self,
731        _region: &VhostUserSingleMemoryRegion,
732        _fd: File,
733    ) -> VhostResult<()> {
734        //TODO
735        Ok(())
736    }
737
738    fn remove_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion) -> VhostResult<()> {
739        //TODO
740        Ok(())
741    }
742
743    fn set_device_state_fd(
744        &mut self,
745        transfer_direction: VhostUserTransferDirection,
746        migration_phase: VhostUserMigrationPhase,
747        fd: File,
748    ) -> VhostResult<Option<File>> {
749        if migration_phase != VhostUserMigrationPhase::Stopped {
750            return Err(VhostError::InvalidOperation);
751        }
752        if !self.all_queues_stopped() {
753            return Err(VhostError::InvalidOperation);
754        }
755        if self.device_state_thread.is_some() {
756            error!("must call check_device_state before starting new state transfer");
757            return Err(VhostError::InvalidOperation);
758        }
759        // `set_device_state_fd` is designed to allow snapshot/restore concurrently with other
760        // methods, but, for simplicitly, we do those operations inline and only spawn a thread to
761        // handle the serialization and data transfer (the latter which seems necessary to
762        // implement the API correctly without, e.g., deadlocking because a pipe is full).
763        match transfer_direction {
764            VhostUserTransferDirection::Save => {
765                // Snapshot the state.
766                let snapshot = DeviceRequestHandlerSnapshot {
767                    acked_features: self.acked_features,
768                    acked_protocol_features: self.acked_protocol_features.bits(),
769                    backend: self.backend.snapshot().map_err(VhostError::SnapshotError)?,
770                };
771                // Spawn thread to write the serialized bytes.
772                self.device_state_thread = Some(DeviceStateThread::Save(WorkerThread::start(
773                    "device_state_save",
774                    move |_kill_event| -> Result<(), ciborium::ser::Error<std::io::Error>> {
775                        let mut w = std::io::BufWriter::new(fd);
776                        ciborium::into_writer(&snapshot, &mut w)?;
777                        w.flush()?;
778                        Ok(())
779                    },
780                )));
781                Ok(None)
782            }
783            VhostUserTransferDirection::Load => {
784                // Spawn a thread to read the bytes and deserialize. Restore will happen in
785                // `check_device_state`.
786                self.device_state_thread = Some(DeviceStateThread::Load(WorkerThread::start(
787                    "device_state_load",
788                    move |_kill_event| ciborium::from_reader(&mut BufReader::new(fd)),
789                )));
790                Ok(None)
791            }
792        }
793    }
794
795    fn check_device_state(&mut self) -> VhostResult<()> {
796        let Some(thread) = self.device_state_thread.take() else {
797            error!("check_device_state: no active state transfer");
798            return Err(VhostError::InvalidOperation);
799        };
800        match thread {
801            DeviceStateThread::Save(worker) => {
802                worker.stop().map_err(|e| {
803                    error!("device state save thread failed: {:#}", e);
804                    VhostError::BackendInternalError
805                })?;
806                Ok(())
807            }
808            DeviceStateThread::Load(worker) => {
809                let snapshot = worker.stop().map_err(|e| {
810                    error!("device state load thread failed: {:#}", e);
811                    VhostError::BackendInternalError
812                })?;
813                self.acked_features = snapshot.acked_features;
814                self.acked_protocol_features =
815                    VhostUserProtocolFeatures::from_bits(snapshot.acked_protocol_features)
816                        .with_context(|| {
817                            format!(
818                                "unsupported bits in acked_protocol_features: {:#x}",
819                                snapshot.acked_protocol_features
820                            )
821                        })
822                        .map_err(VhostError::RestoreError)?;
823                self.backend
824                    .restore(snapshot.backend)
825                    .map_err(VhostError::RestoreError)?;
826                Ok(())
827            }
828        }
829    }
830
831    fn get_shared_memory_regions(&mut self) -> VhostResult<Vec<VhostSharedMemoryRegion>> {
832        Ok(if let Some(r) = self.backend.get_shared_memory_region() {
833            vec![VhostSharedMemoryRegion::new(r.id, r.length)]
834        } else {
835            Vec::new()
836        })
837    }
838}
839
/// Keeps track of Vhost user backend request connection.
///
/// This is the handle a device uses to send messages to the frontend, for example config
/// change notifications or shared memory map/unmap requests made through `shmem_mapper`.
/// Cloning is cheap: every clone shares the same connection and mapping table through an
/// `Arc<Mutex<_>>`.
#[derive(Clone)]
pub struct VhostBackendReqConnection {
    // Connection plus mapped-region bookkeeping. This is shared by all clones and by every
    // `VhostShmemMapper` created from this handle.
    shared: Arc<Mutex<VhostBackendReqConnectionShared>>,
    // Shared memory region id used in map/unmap requests. When `None`, `shmem_mapper`
    // returns `None`.
    shmid: Option<u8>,
}
846
/// State behind a `VhostBackendReqConnection`.
///
/// The client and the mapping table sit behind one mutex, so a request to the frontend and the
/// matching table update can happen while holding a single lock.
struct VhostBackendReqConnectionShared {
    // Client used to send backend-initiated requests to the frontend.
    conn: FrontendClient,
    // Active mappings keyed by offset within the shared memory region, with each mapping's size
    // as the value. The size is needed to build the unmap request later.
    mapped_regions: BTreeMap<u64 /* offset */, u64 /* size */>,
}
851
852impl VhostBackendReqConnection {
853    fn new(conn: FrontendClient, shmid: Option<u8>) -> Self {
854        Self {
855            shared: Arc::new(Mutex::new(VhostBackendReqConnectionShared {
856                conn,
857                mapped_regions: BTreeMap::new(),
858            })),
859            shmid,
860        }
861    }
862
863    /// Send `VHOST_USER_CONFIG_CHANGE_MSG` to the frontend
864    fn send_config_changed(&self) -> anyhow::Result<()> {
865        let mut shared = self.shared.lock();
866        shared
867            .conn
868            .handle_config_change()
869            .context("Could not send config change message")?;
870        Ok(())
871    }
872
873    /// Create a SharedMemoryMapper trait object using this backend request connection.
874    pub fn shmem_mapper(&self) -> Option<Box<dyn SharedMemoryMapper>> {
875        if let Some(shmid) = self.shmid {
876            Some(Box::new(VhostShmemMapper {
877                shared: self.shared.clone(),
878                shmid,
879            }))
880        } else {
881            None
882        }
883    }
884}
885
/// `SharedMemoryMapper` that forwards map/unmap requests for a single shared memory region to
/// the frontend over the backend request connection.
#[derive(Clone)]
struct VhostShmemMapper {
    // State shared with the `VhostBackendReqConnection` this mapper was created from.
    shared: Arc<Mutex<VhostBackendReqConnectionShared>>,
    // Id of the shared memory region that every request from this mapper targets.
    shmid: u8,
}
891
892impl SharedMemoryMapper for VhostShmemMapper {
893    fn add_mapping(
894        &mut self,
895        source: VmMemorySource,
896        offset: u64,
897        prot: Protection,
898        _cache: MemCacheType,
899    ) -> anyhow::Result<()> {
900        let mut shared = self.shared.lock();
901        let size = match source {
902            VmMemorySource::Vulkan {
903                descriptor,
904                handle_type,
905                memory_idx,
906                device_uuid,
907                driver_uuid,
908                size,
909            } => {
910                let msg = VhostUserGpuMapMsg::new(
911                    self.shmid,
912                    offset,
913                    size,
914                    memory_idx,
915                    handle_type,
916                    device_uuid,
917                    driver_uuid,
918                );
919                shared
920                    .conn
921                    .gpu_map(&msg, &descriptor)
922                    .context("map GPU memory")?;
923                size
924            }
925            VmMemorySource::ExternalMapping { ptr, size } => {
926                let msg = VhostUserExternalMapMsg::new(self.shmid, offset, size, ptr);
927                shared
928                    .conn
929                    .external_map(&msg)
930                    .context("create external mapping")?;
931                size
932            }
933            source => {
934                // The last two sources use the same VhostUserShmemMapMsg, continue matching here
935                // on the aliased `source` above.
936                let (descriptor, fd_offset, size) = match source {
937                    VmMemorySource::Descriptor {
938                        descriptor,
939                        offset,
940                        size,
941                    } => (descriptor, offset, size),
942                    VmMemorySource::SharedMemory(shmem) => {
943                        let size = shmem.size();
944                        let descriptor = SafeDescriptor::from(shmem);
945                        (descriptor, 0, size)
946                    }
947                    _ => bail!("unsupported source"),
948                };
949                let flags = VhostUserShmemMapMsgFlags::from(prot);
950                let msg = VhostUserShmemMapMsg::new(self.shmid, offset, fd_offset, size, flags);
951                shared
952                    .conn
953                    .shmem_map(&msg, &descriptor)
954                    .context("map shmem")?;
955                size
956            }
957        };
958
959        shared.mapped_regions.insert(offset, size);
960        Ok(())
961    }
962
963    fn remove_mapping(&mut self, offset: u64) -> anyhow::Result<()> {
964        let mut shared = self.shared.lock();
965        let size = shared
966            .mapped_regions
967            .remove(&offset)
968            .context("unknown offset")?;
969        let msg = VhostUserShmemUnmapMsg::new(self.shmid, offset, size);
970        shared
971            .conn
972            .shmem_unmap(&msg)
973            .context("unmap shmem")
974            .map(|_| ())
975    }
976}
977
/// Pairs a queue with the task handle associated with it.
///
/// NOTE(review): presumably used by device implementations to track a started queue's worker
/// task next to the queue it processes. Confirm against the users in `sys`/device code.
pub(crate) struct WorkerState<T, U> {
    // Handle to the spawned task. `U` is the handle's type parameter, presumably the task's
    // output type.
    pub(crate) queue_task: TaskHandle<U>,
    // The queue (or queue wrapper) associated with `queue_task`.
    pub(crate) queue: T,
}
982
983/// Errors for device operations
984#[derive(Debug, ThisError)]
985pub enum Error {
986    #[error("worker not found when stopping queue")]
987    WorkerNotFound,
988}
989
#[cfg(test)]
mod tests {
    use std::sync::mpsc::channel;
    use std::sync::Barrier;

    use anyhow::bail;
    use base::Event;
    use vmm_vhost::BackendServer;
    use vmm_vhost::FrontendReq;
    use zerocopy::FromBytes;
    use zerocopy::FromZeros;
    use zerocopy::Immutable;
    use zerocopy::IntoBytes;
    use zerocopy::KnownLayout;

    use super::*;
    use crate::virtio::vhost_user_frontend::VhostUserFrontend;
    use crate::virtio::DeviceType;
    use crate::virtio::VirtioDevice;

    /// Config space layout served by `FakeBackend::read_config`.
    #[derive(Clone, Copy, Debug, PartialEq, Eq, FromBytes, Immutable, IntoBytes, KnownLayout)]
    #[repr(C, packed(4))]
    struct FakeConfig {
        x: u32,
        y: u64,
    }

    const FAKE_CONFIG_DATA: FakeConfig = FakeConfig { x: 1, y: 2 };

    /// Minimal `VhostUserDevice` that records started queues, optionally negotiates
    /// `BACKEND_REQ`, and snapshots/restores a fixed payload.
    pub(super) struct FakeBackend {
        avail_features: u64,
        acked_features: u64,
        active_queues: Vec<Option<Queue>>,
        allow_backend_req: bool,
        backend_conn: Option<VhostBackendReqConnection>,
    }

    /// Snapshot payload produced by `FakeBackend::snapshot` and checked by `restore`.
    #[derive(Deserialize, Serialize)]
    struct FakeBackendSnapshot {
        data: Vec<u8>,
    }

    impl FakeBackend {
        const MAX_QUEUE_NUM: usize = 16;

        pub(super) fn new() -> Self {
            let mut active_queues = Vec::new();
            active_queues.resize_with(Self::MAX_QUEUE_NUM, Default::default);
            Self {
                avail_features: 1 << VHOST_USER_F_PROTOCOL_FEATURES,
                acked_features: 0,
                active_queues,
                allow_backend_req: false,
                backend_conn: None,
            }
        }
    }

    impl VhostUserDevice for FakeBackend {
        fn max_queue_num(&self) -> usize {
            Self::MAX_QUEUE_NUM
        }

        fn features(&self) -> u64 {
            self.avail_features
        }

        fn ack_features(&mut self, value: u64) -> anyhow::Result<()> {
            // These are virtio device features (checked against `avail_features`), not
            // vhost-user protocol features.
            let unrequested_features = value & !self.avail_features;
            if unrequested_features != 0 {
                bail!("invalid features are given: 0x{:x}", unrequested_features);
            }
            self.acked_features |= value;
            Ok(())
        }

        fn protocol_features(&self) -> VhostUserProtocolFeatures {
            let mut features =
                VhostUserProtocolFeatures::CONFIG | VhostUserProtocolFeatures::DEVICE_STATE;
            if self.allow_backend_req {
                features |= VhostUserProtocolFeatures::BACKEND_REQ;
            }
            features
        }

        fn read_config(&self, offset: u64, dst: &mut [u8]) {
            dst.copy_from_slice(&FAKE_CONFIG_DATA.as_bytes()[offset as usize..]);
        }

        fn reset(&mut self) {}

        fn start_queue(
            &mut self,
            idx: usize,
            queue: Queue,
            _mem: GuestMemory,
        ) -> anyhow::Result<()> {
            self.active_queues[idx] = Some(queue);
            Ok(())
        }

        fn stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue> {
            Ok(self.active_queues[idx]
                .take()
                .ok_or(Error::WorkerNotFound)?)
        }

        fn set_backend_req_connection(&mut self, conn: VhostBackendReqConnection) {
            self.backend_conn = Some(conn);
        }

        fn enter_suspended_state(&mut self) -> anyhow::Result<()> {
            Ok(())
        }

        fn snapshot(&mut self) -> anyhow::Result<AnySnapshot> {
            AnySnapshot::to_any(FakeBackendSnapshot {
                data: vec![1, 2, 3],
            })
            .context("failed to serialize snapshot")
        }

        fn restore(&mut self, data: AnySnapshot) -> anyhow::Result<()> {
            let snapshot: FakeBackendSnapshot =
                AnySnapshot::from_any(data).context("failed to deserialize snapshot")?;
            assert_eq!(snapshot.data, vec![1, 2, 3], "bad snapshot data");
            Ok(())
        }
    }

    #[test]
    fn test_vhost_user_lifecycle() {
        test_vhost_user_lifecycle_parameterized(false);
    }

    #[test]
    #[cfg(not(windows))] // Windows requires more complex connection setup.
    fn test_vhost_user_lifecycle_with_backend_req() {
        test_vhost_user_lifecycle_parameterized(true);
    }

    /// Runs a full frontend lifecycle against a `FakeBackend`.
    ///
    /// The lifecycle is: new, read_config, activate, reset, activate, sleep, snapshot,
    /// restore, wake, drop. The VMM side runs on a separate thread. The device side checks that
    /// it receives exactly the expected sequence of requests and that the backend request
    /// connection (when enabled) keeps working.
    fn test_vhost_user_lifecycle_parameterized(allow_backend_req: bool) {
        const QUEUES_NUM: usize = 2;

        let (client_connection, server_connection) =
            vmm_vhost::Connection::<FrontendReq>::pair().unwrap();

        let vmm_bar = Arc::new(Barrier::new(2));
        let dev_bar = vmm_bar.clone();

        let (ready_tx, ready_rx) = channel();
        let (shutdown_tx, shutdown_rx) = channel();
        let (vm_evt_wrtube, _vm_evt_rdtube) = base::Tube::directional_pair().unwrap();

        std::thread::spawn(move || {
            // VMM side
            ready_rx.recv().unwrap(); // Ensure the device is ready.

            let mut vmm_device = VhostUserFrontend::new(
                DeviceType::Console,
                0,
                client_connection,
                vm_evt_wrtube,
                None,
                None,
            )
            .unwrap();

            println!("read_config");
            let mut config = FakeConfig::new_zeroed();
            vmm_device.read_config(0, config.as_mut_bytes());
            // Check if the obtained config data is correct.
            assert_eq!(config, FAKE_CONFIG_DATA);

            let activate = |vmm_device: &mut VhostUserFrontend| {
                let mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
                let interrupt = Interrupt::new_for_test_with_msix();

                let mut queues = BTreeMap::new();
                for idx in 0..QUEUES_NUM {
                    let mut queue = QueueConfig::new(0x10, 0);
                    queue.set_ready(true);
                    let queue = queue
                        .activate(&mem, Event::new().unwrap(), interrupt.clone())
                        .expect("QueueConfig::activate");
                    queues.insert(idx, queue);
                }

                println!("activate");
                vmm_device.activate(mem, interrupt, queues).unwrap();
            };

            activate(&mut vmm_device);

            println!("reset");
            let reset_result = vmm_device.reset();
            assert!(
                reset_result.is_ok(),
                "reset failed: {:#}",
                reset_result.unwrap_err()
            );

            activate(&mut vmm_device);

            println!("virtio_sleep");
            let queues = vmm_device
                .virtio_sleep()
                .unwrap()
                .expect("virtio_sleep unexpectedly returned None");

            println!("virtio_snapshot");
            let snapshot = vmm_device
                .virtio_snapshot()
                .expect("virtio_snapshot failed");
            println!("virtio_restore");
            vmm_device
                .virtio_restore(snapshot)
                .expect("virtio_restore failed");

            println!("virtio_wake");
            let mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
            let interrupt = Interrupt::new_for_test_with_msix();
            vmm_device
                .virtio_wake(Some((mem, interrupt, queues)))
                .unwrap();

            println!("wait for shutdown signal");
            shutdown_rx.recv().unwrap();

            // The VMM side is supposed to stop before the device side.
            println!("drop");
            drop(vmm_device);

            vmm_bar.wait();
        });

        // Device side
        let mut handler = DeviceRequestHandler::new(FakeBackend::new());
        handler.as_mut().allow_backend_req = allow_backend_req;

        // Notify listener is ready.
        ready_tx.send(()).unwrap();

        let mut req_handler = BackendServer::new(server_connection, handler);

        // VhostUserFrontend::new()
        handle_request(&mut req_handler, FrontendReq::SET_OWNER).unwrap();
        handle_request(&mut req_handler, FrontendReq::GET_FEATURES).unwrap();
        handle_request(&mut req_handler, FrontendReq::GET_PROTOCOL_FEATURES).unwrap();
        handle_request(&mut req_handler, FrontendReq::SET_PROTOCOL_FEATURES).unwrap();
        if allow_backend_req {
            handle_request(&mut req_handler, FrontendReq::SET_BACKEND_REQ_FD).unwrap();
        }

        // VhostUserFrontend::read_config()
        handle_request(&mut req_handler, FrontendReq::GET_CONFIG).unwrap();

        // VhostUserFrontend::activate()
        handle_request(&mut req_handler, FrontendReq::SET_FEATURES).unwrap();
        handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();
        handle_vring_setup(&mut req_handler, QUEUES_NUM);

        // VhostUserFrontend::reset()
        handle_vring_teardown(&mut req_handler, QUEUES_NUM);

        // VhostUserFrontend::activate()
        handle_request(&mut req_handler, FrontendReq::SET_FEATURES).unwrap();
        handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();
        handle_vring_setup(&mut req_handler, QUEUES_NUM);

        if allow_backend_req {
            // Make sure the connection still works even after reset/reactivate.
            send_config_changed_from_backend(req_handler.as_ref().as_ref());
        }

        // VhostUserFrontend::virtio_sleep()
        handle_vring_teardown(&mut req_handler, QUEUES_NUM);

        // VhostUserFrontend::virtio_snapshot()
        handle_request(&mut req_handler, FrontendReq::SET_DEVICE_STATE_FD).unwrap();
        handle_request(&mut req_handler, FrontendReq::CHECK_DEVICE_STATE).unwrap();
        // VhostUserFrontend::virtio_restore()
        handle_request(&mut req_handler, FrontendReq::SET_DEVICE_STATE_FD).unwrap();
        handle_request(&mut req_handler, FrontendReq::CHECK_DEVICE_STATE).unwrap();

        // VhostUserFrontend::virtio_wake()
        handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();
        handle_vring_setup(&mut req_handler, QUEUES_NUM);

        if allow_backend_req {
            // Make sure the connection still works even after sleep/wake.
            send_config_changed_from_backend(req_handler.as_ref().as_ref());
        }

        // Ask the client to shut down, then wait for it to finish.
        shutdown_tx.send(()).unwrap();
        dev_bar.wait();

        // Verify recv_header fails with `ClientExit` after the client has disconnected.
        match req_handler.recv_header() {
            Err(VhostError::ClientExit) => (),
            r => panic!("expected Err(ClientExit) but got {r:?}"),
        }
    }

    /// Receives one request, asserts it is `expected_message_type`, and processes it.
    fn handle_request<S: vmm_vhost::Backend>(
        handler: &mut BackendServer<S>,
        expected_message_type: FrontendReq,
    ) -> Result<(), VhostError> {
        let (hdr, files) = handler.recv_header()?;
        assert_eq!(hdr.get_code(), Ok(expected_message_type));
        handler.process_message(hdr, files)
    }

    /// Handles the per-queue messages the frontend sends to configure and enable
    /// `num_queues` vrings (used by activate and wake).
    fn handle_vring_setup<S: vmm_vhost::Backend>(handler: &mut BackendServer<S>, num_queues: usize) {
        for _ in 0..num_queues {
            handle_request(handler, FrontendReq::SET_VRING_NUM).unwrap();
            handle_request(handler, FrontendReq::SET_VRING_ADDR).unwrap();
            handle_request(handler, FrontendReq::SET_VRING_BASE).unwrap();
            handle_request(handler, FrontendReq::SET_VRING_CALL).unwrap();
            handle_request(handler, FrontendReq::SET_VRING_KICK).unwrap();
            handle_request(handler, FrontendReq::SET_VRING_ENABLE).unwrap();
        }
    }

    /// Handles the per-queue messages the frontend sends to disable and stop `num_queues`
    /// vrings (used by reset and sleep).
    fn handle_vring_teardown<S: vmm_vhost::Backend>(
        handler: &mut BackendServer<S>,
        num_queues: usize,
    ) {
        for _ in 0..num_queues {
            handle_request(handler, FrontendReq::SET_VRING_ENABLE).unwrap();
            handle_request(handler, FrontendReq::GET_VRING_BASE).unwrap();
        }
    }

    /// Sends a config change notification through the backend request connection that
    /// `FakeBackend` received via `set_backend_req_connection`.
    fn send_config_changed_from_backend(backend: &FakeBackend) {
        backend
            .backend_conn
            .as_ref()
            .expect("backend_conn missing")
            .send_config_changed()
            .expect("send_config_changed failed");
    }
}