1pub(super) mod sys;
48
49use std::collections::BTreeMap;
50use std::convert::From;
51use std::fs::File;
52use std::io::BufReader;
53use std::io::Write;
54use std::num::Wrapping;
55#[cfg(any(target_os = "android", target_os = "linux"))]
56use std::os::unix::io::AsRawFd;
57use std::sync::Arc;
58
59use anyhow::bail;
60use anyhow::Context;
61#[cfg(any(target_os = "android", target_os = "linux"))]
62use base::clear_fd_flags;
63use base::error;
64use base::trace;
65use base::warn;
66use base::Event;
67use base::Protection;
68use base::SafeDescriptor;
69use base::SharedMemory;
70use base::WorkerThread;
71use cros_async::TaskHandle;
72use hypervisor::MemCacheType;
73use serde::Deserialize;
74use serde::Serialize;
75use snapshot::AnySnapshot;
76use sync::Mutex;
77use thiserror::Error as ThisError;
78use vm_control::VmMemorySource;
79use vm_memory::GuestAddress;
80use vm_memory::GuestMemory;
81use vm_memory::MemoryRegion;
82use vmm_vhost::message::VhostUserConfigFlags;
83use vmm_vhost::message::VhostUserExternalMapMsg;
84use vmm_vhost::message::VhostUserGpuMapMsg;
85use vmm_vhost::message::VhostUserInflight;
86use vmm_vhost::message::VhostUserMMap;
87use vmm_vhost::message::VhostUserMMapFlags;
88use vmm_vhost::message::VhostUserMemoryRegion;
89use vmm_vhost::message::VhostUserMigrationPhase;
90use vmm_vhost::message::VhostUserProtocolFeatures;
91use vmm_vhost::message::VhostUserSingleMemoryRegion;
92use vmm_vhost::message::VhostUserTransferDirection;
93use vmm_vhost::message::VhostUserVringAddrFlags;
94use vmm_vhost::message::VhostUserVringState;
95use vmm_vhost::Connection;
96use vmm_vhost::Error as VhostError;
97use vmm_vhost::Frontend;
98use vmm_vhost::FrontendClient;
99use vmm_vhost::Result as VhostResult;
100use vmm_vhost::VHOST_USER_F_PROTOCOL_FEATURES;
101
102use crate::virtio::Interrupt;
103use crate::virtio::Queue;
104use crate::virtio::QueueConfig;
105use crate::virtio::SharedMemoryMapper;
106use crate::virtio::SharedMemoryRegion;
107
/// Describes how one region of the frontend's virtual address space corresponds to guest
/// physical memory, as announced by a `SET_MEM_TABLE` request.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct MappingInfo {
    /// Start of the region in the frontend process's virtual address space (`user_addr`).
    pub vmm_addr: u64,
    /// Guest physical address that `vmm_addr` corresponds to.
    pub guest_phys: u64,
    /// Length of the region in bytes.
    pub size: u64,
}
116
117pub fn vmm_va_to_gpa(maps: &[MappingInfo], vmm_va: u64) -> VhostResult<GuestAddress> {
118 for map in maps {
119 if vmm_va >= map.vmm_addr && vmm_va < map.vmm_addr + map.size {
120 return Ok(GuestAddress(vmm_va - map.vmm_addr + map.guest_phys));
121 }
122 }
123 Err(VhostError::InvalidMessage)
124}
125
126pub trait VhostUserDevice {
131 fn max_queue_num(&self) -> usize;
133
134 fn features(&self) -> u64;
136
137 fn ack_features(&mut self, _value: u64) -> anyhow::Result<()> {
145 Ok(())
146 }
147
148 fn protocol_features(&self) -> VhostUserProtocolFeatures;
150
151 fn read_config(&self, offset: u64, dst: &mut [u8]);
153
154 fn write_config(&self, _offset: u64, _data: &[u8]) {}
156
157 fn start_queue(&mut self, idx: usize, queue: Queue, mem: GuestMemory) -> anyhow::Result<()>;
161
162 fn stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue>;
166
167 fn reset(&mut self);
169
170 fn get_shared_memory_region(&self) -> Option<SharedMemoryRegion> {
172 None
173 }
174
175 fn set_backend_req_connection(&mut self, _conn: VhostBackendReqConnection) {}
181
182 fn enter_suspended_state(&mut self) -> anyhow::Result<()>;
195
196 fn snapshot(&mut self) -> anyhow::Result<AnySnapshot>;
198
199 fn restore(&mut self, data: AnySnapshot) -> anyhow::Result<()>;
201
202 fn unmap_guest_memory_on_fork(&self) -> bool {
210 false
211 }
212}
213
214struct Vring {
216 queue: QueueConfig,
218 doorbell: Option<Interrupt>,
219 enabled: bool,
220}
221
222impl Vring {
223 fn new(max_size: u16, features: u64) -> Self {
224 Self {
225 queue: QueueConfig::new(max_size, features),
226 doorbell: None,
227 enabled: false,
228 }
229 }
230
231 fn reset(&mut self) {
232 self.queue.reset();
233 self.doorbell = None;
234 self.enabled = false;
235 }
236}
237
238pub(super) struct VhostUserRegularOps;
240
241impl VhostUserRegularOps {
242 pub fn set_mem_table(
243 contexts: &[VhostUserMemoryRegion],
244 files: Vec<File>,
245 ) -> VhostResult<(GuestMemory, Vec<MappingInfo>)> {
246 if files.len() != contexts.len() {
247 return Err(VhostError::InvalidParam(
248 "number of files & contexts was not equal",
249 ));
250 }
251
252 let mut regions = Vec::with_capacity(files.len());
253 for (region, file) in contexts.iter().zip(files.into_iter()) {
254 let region = MemoryRegion::new_from_shm(
255 region.memory_size,
256 GuestAddress(region.guest_phys_addr),
257 region.mmap_offset,
258 Arc::new(
259 SharedMemory::from_safe_descriptor(
260 SafeDescriptor::from(file),
261 region.memory_size,
262 )
263 .unwrap(),
264 ),
265 )
266 .map_err(|e| {
267 error!("failed to create a memory region: {}", e);
268 VhostError::InvalidOperation
269 })?;
270 regions.push(region);
271 }
272 let guest_mem = GuestMemory::from_regions(regions).map_err(|e| {
273 error!("failed to create guest memory: {}", e);
274 VhostError::InvalidOperation
275 })?;
276
277 let vmm_maps = contexts
278 .iter()
279 .map(|region| MappingInfo {
280 vmm_addr: region.user_addr,
281 guest_phys: region.guest_phys_addr,
282 size: region.memory_size,
283 })
284 .collect();
285 Ok((guest_mem, vmm_maps))
286 }
287}
288
289pub struct DeviceRequestHandler<T: VhostUserDevice> {
291 vrings: Vec<Vring>,
292 owned: bool,
293 vmm_maps: Option<Vec<MappingInfo>>,
294 mem: Option<GuestMemory>,
295 acked_features: u64,
296 acked_protocol_features: VhostUserProtocolFeatures,
297 backend: T,
298 backend_req_connection: Option<VhostBackendReqConnection>,
299 device_state_thread: Option<DeviceStateThread>,
301}
302
303enum DeviceStateThread {
304 Save(WorkerThread<Result<(), ciborium::ser::Error<std::io::Error>>>),
305 Load(WorkerThread<Result<DeviceRequestHandlerSnapshot, ciborium::de::Error<std::io::Error>>>),
306}
307
308#[derive(Serialize, Deserialize)]
309pub struct DeviceRequestHandlerSnapshot {
310 acked_features: u64,
311 acked_protocol_features: u64,
312 backend: AnySnapshot,
313}
314
315impl<T: VhostUserDevice> DeviceRequestHandler<T> {
316 pub(crate) fn new(mut backend: T) -> Self {
318 let mut vrings = Vec::with_capacity(backend.max_queue_num());
319 for _ in 0..backend.max_queue_num() {
320 vrings.push(Vring::new(Queue::MAX_SIZE, backend.features()));
321 }
322
323 backend
326 .enter_suspended_state()
327 .expect("enter_suspended_state failed on device init");
328
329 DeviceRequestHandler {
330 vrings,
331 owned: false,
332 vmm_maps: None,
333 mem: None,
334 acked_features: 0,
335 acked_protocol_features: VhostUserProtocolFeatures::empty(),
336 backend,
337 backend_req_connection: None,
338 device_state_thread: None,
339 }
340 }
341
342 fn all_queues_stopped(&self) -> bool {
346 self.vrings.iter().all(|vring| !vring.queue.ready())
347 }
348}
349
350impl<T: VhostUserDevice> Drop for DeviceRequestHandler<T> {
351 fn drop(&mut self) {
352 for (index, vring) in self.vrings.iter().enumerate() {
353 if vring.queue.ready() {
354 if let Err(e) = self.backend.stop_queue(index) {
355 error!("Failed to stop queue {} during drop: {:#}", index, e);
356 }
357 }
358 }
359 }
360}
361
362impl<T: VhostUserDevice> AsRef<T> for DeviceRequestHandler<T> {
363 fn as_ref(&self) -> &T {
364 &self.backend
365 }
366}
367
368impl<T: VhostUserDevice> AsMut<T> for DeviceRequestHandler<T> {
369 fn as_mut(&mut self) -> &mut T {
370 &mut self.backend
371 }
372}
373
374impl<T: VhostUserDevice> vmm_vhost::Backend for DeviceRequestHandler<T> {
375 fn set_owner(&mut self) -> VhostResult<()> {
376 if self.owned {
377 return Err(VhostError::InvalidOperation);
378 }
379 self.owned = true;
380 Ok(())
381 }
382
383 fn reset_owner(&mut self) -> VhostResult<()> {
384 self.owned = false;
385 self.acked_features = 0;
386 self.backend.reset();
387 Ok(())
388 }
389
390 fn get_features(&mut self) -> VhostResult<u64> {
391 let features = self.backend.features();
392 Ok(features)
393 }
394
395 fn set_features(&mut self, features: u64) -> VhostResult<()> {
396 if !self.owned {
397 return Err(VhostError::InvalidOperation);
398 }
399
400 let unexpected_features = features & !self.backend.features();
401 if unexpected_features != 0 {
402 error!("unexpected set_features {:#x}", unexpected_features);
403 return Err(VhostError::InvalidParam("unexpected set_features"));
404 }
405
406 if let Err(e) = self.backend.ack_features(features) {
407 error!("failed to acknowledge features 0x{:x}: {}", features, e);
408 return Err(VhostError::InvalidOperation);
409 }
410
411 self.acked_features |= features;
412
413 let vring_enabled = self.acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES != 0;
421 for v in &mut self.vrings {
422 v.enabled = vring_enabled;
423 }
424
425 Ok(())
426 }
427
428 fn get_protocol_features(&mut self) -> VhostResult<VhostUserProtocolFeatures> {
429 Ok(self.backend.protocol_features() | VhostUserProtocolFeatures::REPLY_ACK)
430 }
431
432 fn set_protocol_features(&mut self, features: u64) -> VhostResult<()> {
433 let features = match VhostUserProtocolFeatures::from_bits(features) {
434 Some(proto_features) => proto_features,
435 None => {
436 error!(
437 "unsupported bits in VHOST_USER_SET_PROTOCOL_FEATURES: {:#x}",
438 features
439 );
440 return Err(VhostError::InvalidOperation);
441 }
442 };
443 let supported = self.get_protocol_features()?;
444 self.acked_protocol_features = features & supported;
445 Ok(())
446 }
447
448 fn set_mem_table(
449 &mut self,
450 contexts: &[VhostUserMemoryRegion],
451 files: Vec<File>,
452 ) -> VhostResult<()> {
453 let (guest_mem, vmm_maps) = VhostUserRegularOps::set_mem_table(contexts, files)?;
454 if self.backend.unmap_guest_memory_on_fork() {
455 #[cfg(any(target_os = "android", target_os = "linux"))]
456 if let Err(e) = guest_mem.use_dontfork() {
457 error!("failed to set MADV_DONTFORK on guest memory: {e:#}");
458 }
459 #[cfg(not(any(target_os = "android", target_os = "linux")))]
460 error!("unmap_guest_memory_on_fork unsupported; skipping");
461 }
462 self.mem = Some(guest_mem);
463 self.vmm_maps = Some(vmm_maps);
464 Ok(())
465 }
466
467 fn get_queue_num(&mut self) -> VhostResult<u64> {
468 Ok(self.vrings.len() as u64)
469 }
470
471 fn set_vring_num(&mut self, index: u32, num: u32) -> VhostResult<()> {
472 if index as usize >= self.vrings.len() || num == 0 || num > Queue::MAX_SIZE.into() {
473 return Err(VhostError::InvalidParam(
474 "set_vring_num: invalid index or num",
475 ));
476 }
477 self.vrings[index as usize].queue.set_size(num as u16);
478
479 Ok(())
480 }
481
482 fn set_vring_addr(
483 &mut self,
484 index: u32,
485 _flags: VhostUserVringAddrFlags,
486 descriptor: u64,
487 used: u64,
488 available: u64,
489 _log: u64,
490 ) -> VhostResult<()> {
491 if index as usize >= self.vrings.len() {
492 return Err(VhostError::InvalidParam(
493 "set_vring_addr: index out of range",
494 ));
495 }
496
497 let vmm_maps = self
498 .vmm_maps
499 .as_ref()
500 .ok_or(VhostError::InvalidParam("set_vring_addr: missing vmm_maps"))?;
501 let vring = &mut self.vrings[index as usize];
502 vring
503 .queue
504 .set_desc_table(vmm_va_to_gpa(vmm_maps, descriptor)?);
505 vring
506 .queue
507 .set_avail_ring(vmm_va_to_gpa(vmm_maps, available)?);
508 vring.queue.set_used_ring(vmm_va_to_gpa(vmm_maps, used)?);
509
510 Ok(())
511 }
512
513 fn set_vring_base(&mut self, index: u32, base: u32) -> VhostResult<()> {
514 if index as usize >= self.vrings.len() {
515 return Err(VhostError::InvalidParam(
516 "set_vring_base: index out of range",
517 ));
518 }
519
520 let vring = &mut self.vrings[index as usize];
521 vring.queue.set_next_avail(Wrapping(base as u16));
522 vring.queue.set_next_used(Wrapping(base as u16));
523
524 Ok(())
525 }
526
527 fn get_vring_base(&mut self, index: u32) -> VhostResult<VhostUserVringState> {
528 let vring = self
529 .vrings
530 .get_mut(index as usize)
531 .ok_or(VhostError::InvalidParam(
532 "get_vring_base: index out of range",
533 ))?;
534
535 let vring_base = if vring.queue.ready() {
540 let queue = match self.backend.stop_queue(index as usize) {
541 Ok(q) => q,
542 Err(e) => {
543 error!("Failed to stop queue in get_vring_base: {:#}", e);
544 return Err(VhostError::BackendInternalError);
545 }
546 };
547
548 trace!("stopped queue {index}");
549 vring.reset();
550
551 if self.all_queues_stopped() {
552 trace!("all queues stopped; entering suspended state");
553 self.backend
554 .enter_suspended_state()
555 .map_err(VhostError::EnterSuspendedState)?;
556 }
557
558 queue.next_avail_to_process()
559 } else {
560 0
561 };
562
563 Ok(VhostUserVringState::new(index, vring_base.into()))
564 }
565
566 fn set_vring_kick(&mut self, index: u8, file: Option<File>) -> VhostResult<()> {
567 if index as usize >= self.vrings.len() {
568 return Err(VhostError::InvalidParam(
569 "set_vring_kick: index out of range",
570 ));
571 }
572
573 let vring = &mut self.vrings[index as usize];
574 if vring.queue.ready() {
575 error!("kick fd cannot replaced after queue is started");
576 return Err(VhostError::InvalidOperation);
577 }
578
579 let file = file.ok_or(VhostError::InvalidParam("missing file for set_vring_kick"))?;
580
581 #[cfg(any(target_os = "android", target_os = "linux"))]
585 if let Err(e) = clear_fd_flags(file.as_raw_fd(), libc::O_NONBLOCK) {
586 error!("failed to remove O_NONBLOCK for kick fd: {}", e);
587 return Err(VhostError::InvalidParam(
588 "could not remove O_NONBLOCK from vring_kick",
589 ));
590 }
591
592 let kick_evt = Event::from(SafeDescriptor::from(file));
593
594 vring.queue.ack_features(self.acked_features);
596 vring.queue.set_ready(true);
597
598 let mem = self
599 .mem
600 .as_ref()
601 .cloned()
602 .ok_or(VhostError::InvalidOperation)?;
603
604 let doorbell = vring.doorbell.clone().ok_or(VhostError::InvalidOperation)?;
605
606 let queue = match vring.queue.activate(&mem, kick_evt, doorbell) {
607 Ok(queue) => queue,
608 Err(e) => {
609 error!("failed to activate vring: {:#}", e);
610 return Err(VhostError::BackendInternalError);
611 }
612 };
613
614 if let Err(e) = self.backend.start_queue(index as usize, queue, mem) {
615 error!("Failed to start queue {}: {}", index, e);
616 return Err(VhostError::BackendInternalError);
617 }
618 trace!("started queue {index}");
619
620 Ok(())
621 }
622
623 fn set_vring_call(&mut self, index: u8, file: Option<File>) -> VhostResult<()> {
624 if index as usize >= self.vrings.len() {
625 return Err(VhostError::InvalidParam(
626 "set_vring_call: index out of range",
627 ));
628 }
629
630 let backend_req_conn = self.backend_req_connection.clone();
631 let signal_config_change_fn = Box::new(move || {
632 if let Some(frontend) = backend_req_conn.as_ref() {
633 if let Err(e) = frontend.send_config_changed() {
634 error!("Failed to notify config change: {:#}", e);
635 }
636 } else {
637 error!("No Backend request connection found");
638 }
639 });
640
641 let file = file.ok_or(VhostError::InvalidParam("missing file for set_vring_call"))?;
642 self.vrings[index as usize].doorbell = Some(Interrupt::new_vhost_user(
643 Event::from(SafeDescriptor::from(file)),
644 signal_config_change_fn,
645 ));
646 Ok(())
647 }
648
649 fn set_vring_err(&mut self, _index: u8, _fd: Option<File>) -> VhostResult<()> {
650 Ok(())
652 }
653
654 fn set_vring_enable(&mut self, index: u32, enable: bool) -> VhostResult<()> {
655 if index as usize >= self.vrings.len() {
656 return Err(VhostError::InvalidParam(
657 "set_vring_enable: index out of range",
658 ));
659 }
660
661 if self.acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES == 0 {
664 return Err(VhostError::InvalidOperation);
665 }
666
667 self.vrings[index as usize].enabled = enable;
671
672 Ok(())
673 }
674
675 fn get_config(
676 &mut self,
677 offset: u32,
678 size: u32,
679 _flags: VhostUserConfigFlags,
680 ) -> VhostResult<Vec<u8>> {
681 let mut data = vec![0; size as usize];
682 self.backend.read_config(u64::from(offset), &mut data);
683 Ok(data)
684 }
685
686 fn set_config(
687 &mut self,
688 offset: u32,
689 buf: &[u8],
690 _flags: VhostUserConfigFlags,
691 ) -> VhostResult<()> {
692 self.backend.write_config(u64::from(offset), buf);
693 Ok(())
694 }
695
696 fn set_backend_req_fd(&mut self, ep: Connection) {
697 let conn = VhostBackendReqConnection::new(
698 FrontendClient::new(
699 ep,
700 self.acked_protocol_features
701 .contains(VhostUserProtocolFeatures::REPLY_ACK),
702 ),
703 self.backend.get_shared_memory_region().map(|r| r.id),
704 );
705
706 if self.backend_req_connection.is_some() {
707 warn!("Backend Request Connection already established. Overwriting");
708 }
709 self.backend_req_connection = Some(conn.clone());
710
711 self.backend.set_backend_req_connection(conn);
712 }
713
714 fn get_inflight_fd(
715 &mut self,
716 _inflight: &VhostUserInflight,
717 ) -> VhostResult<(VhostUserInflight, File)> {
718 unimplemented!("get_inflight_fd");
719 }
720
721 fn set_inflight_fd(&mut self, _inflight: &VhostUserInflight, _file: File) -> VhostResult<()> {
722 unimplemented!("set_inflight_fd");
723 }
724
725 fn get_max_mem_slots(&mut self) -> VhostResult<u64> {
726 Ok(0)
728 }
729
730 fn add_mem_region(
731 &mut self,
732 _region: &VhostUserSingleMemoryRegion,
733 _fd: File,
734 ) -> VhostResult<()> {
735 Ok(())
737 }
738
739 fn remove_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion) -> VhostResult<()> {
740 Ok(())
742 }
743
744 fn set_device_state_fd(
745 &mut self,
746 transfer_direction: VhostUserTransferDirection,
747 migration_phase: VhostUserMigrationPhase,
748 fd: File,
749 ) -> VhostResult<Option<File>> {
750 if migration_phase != VhostUserMigrationPhase::Stopped {
751 return Err(VhostError::InvalidOperation);
752 }
753 if !self.all_queues_stopped() {
754 return Err(VhostError::InvalidOperation);
755 }
756 if self.device_state_thread.is_some() {
757 error!("must call check_device_state before starting new state transfer");
758 return Err(VhostError::InvalidOperation);
759 }
760 match transfer_direction {
765 VhostUserTransferDirection::Save => {
766 let snapshot = DeviceRequestHandlerSnapshot {
768 acked_features: self.acked_features,
769 acked_protocol_features: self.acked_protocol_features.bits(),
770 backend: self.backend.snapshot().map_err(VhostError::SnapshotError)?,
771 };
772 self.device_state_thread = Some(DeviceStateThread::Save(WorkerThread::start(
774 "device_state_save",
775 move |_kill_event| -> Result<(), ciborium::ser::Error<std::io::Error>> {
776 let mut w = std::io::BufWriter::new(fd);
777 ciborium::into_writer(&snapshot, &mut w)?;
778 w.flush()?;
779 Ok(())
780 },
781 )));
782 Ok(None)
783 }
784 VhostUserTransferDirection::Load => {
785 self.device_state_thread = Some(DeviceStateThread::Load(WorkerThread::start(
788 "device_state_load",
789 move |_kill_event| ciborium::from_reader(&mut BufReader::new(fd)),
790 )));
791 Ok(None)
792 }
793 }
794 }
795
796 fn check_device_state(&mut self) -> VhostResult<()> {
797 let Some(thread) = self.device_state_thread.take() else {
798 error!("check_device_state: no active state transfer");
799 return Err(VhostError::InvalidOperation);
800 };
801 match thread {
802 DeviceStateThread::Save(worker) => {
803 worker.stop().map_err(|e| {
804 error!("device state save thread failed: {:#}", e);
805 VhostError::BackendInternalError
806 })?;
807 Ok(())
808 }
809 DeviceStateThread::Load(worker) => {
810 let snapshot = worker.stop().map_err(|e| {
811 error!("device state load thread failed: {:#}", e);
812 VhostError::BackendInternalError
813 })?;
814 self.acked_features = snapshot.acked_features;
815 self.acked_protocol_features =
816 VhostUserProtocolFeatures::from_bits(snapshot.acked_protocol_features)
817 .with_context(|| {
818 format!(
819 "unsupported bits in acked_protocol_features: {:#x}",
820 snapshot.acked_protocol_features
821 )
822 })
823 .map_err(VhostError::RestoreError)?;
824 self.backend
825 .restore(snapshot.backend)
826 .map_err(VhostError::RestoreError)?;
827 Ok(())
828 }
829 }
830 }
831
832 fn get_shmem_config(&mut self) -> VhostResult<Vec<SharedMemoryRegion>> {
833 Ok(self
834 .backend
835 .get_shared_memory_region()
836 .into_iter()
837 .collect())
838 }
839}
840
841#[derive(Clone)]
843pub struct VhostBackendReqConnection {
844 shared: Arc<Mutex<VhostBackendReqConnectionShared>>,
845 shmid: Option<u8>,
846}
847
848struct VhostBackendReqConnectionShared {
849 conn: FrontendClient,
850 mapped_regions: BTreeMap<u64 , u64 >,
851}
852
853impl VhostBackendReqConnection {
854 fn new(conn: FrontendClient, shmid: Option<u8>) -> Self {
855 Self {
856 shared: Arc::new(Mutex::new(VhostBackendReqConnectionShared {
857 conn,
858 mapped_regions: BTreeMap::new(),
859 })),
860 shmid,
861 }
862 }
863
864 fn send_config_changed(&self) -> anyhow::Result<()> {
866 let mut shared = self.shared.lock();
867 shared
868 .conn
869 .handle_config_change()
870 .context("Could not send config change message")?;
871 Ok(())
872 }
873
874 pub fn shmem_mapper(&self) -> Option<Box<dyn SharedMemoryMapper>> {
876 if let Some(shmid) = self.shmid {
877 Some(Box::new(VhostShmemMapper {
878 shared: self.shared.clone(),
879 shmid,
880 }))
881 } else {
882 None
883 }
884 }
885}
886
887#[derive(Clone)]
888struct VhostShmemMapper {
889 shared: Arc<Mutex<VhostBackendReqConnectionShared>>,
890 shmid: u8,
891}
892
893impl SharedMemoryMapper for VhostShmemMapper {
894 fn add_mapping(
895 &mut self,
896 source: VmMemorySource,
897 offset: u64,
898 prot: Protection,
899 _cache: MemCacheType,
900 ) -> anyhow::Result<()> {
901 let mut shared = self.shared.lock();
902 let size = match source {
903 VmMemorySource::Vulkan {
904 descriptor,
905 handle_type,
906 memory_idx,
907 device_uuid,
908 driver_uuid,
909 size,
910 } => {
911 let msg = VhostUserGpuMapMsg::new(
912 self.shmid,
913 offset,
914 size,
915 memory_idx,
916 handle_type,
917 device_uuid,
918 driver_uuid,
919 );
920 shared
921 .conn
922 .gpu_map(&msg, &descriptor)
923 .context("map GPU memory")?;
924 size
925 }
926 VmMemorySource::ExternalMapping { ptr, size } => {
927 let msg = VhostUserExternalMapMsg::new(self.shmid, offset, size, ptr);
928 shared
929 .conn
930 .external_map(&msg)
931 .context("create external mapping")?;
932 size
933 }
934 source => {
935 let (descriptor, fd_offset, size) = match source {
938 VmMemorySource::Descriptor {
939 descriptor,
940 offset,
941 size,
942 } => (descriptor, offset, size),
943 VmMemorySource::SharedMemory(shmem) => {
944 let size = shmem.size();
945 let descriptor = SafeDescriptor::from(shmem);
946 (descriptor, 0, size)
947 }
948 _ => bail!("unsupported source"),
949 };
950 let mut flags = VhostUserMMapFlags::empty();
951 anyhow::ensure!(prot.allows(&Protection::read()), "mapping must be readable");
952 if prot.allows(&Protection::write()) {
953 flags |= VhostUserMMapFlags::MAP_RW;
954 }
955 let msg = VhostUserMMap {
956 shmid: self.shmid,
957 padding: Default::default(),
958 fd_offset,
959 shm_offset: offset,
960 len: size,
961 flags,
962 };
963 shared
964 .conn
965 .shmem_map(&msg, &descriptor)
966 .context("map shmem")?;
967 size
968 }
969 };
970
971 shared.mapped_regions.insert(offset, size);
972 Ok(())
973 }
974
975 fn remove_mapping(&mut self, offset: u64) -> anyhow::Result<()> {
976 let mut shared = self.shared.lock();
977 let size = shared
978 .mapped_regions
979 .remove(&offset)
980 .context("unknown offset")?;
981 let msg = VhostUserMMap {
982 shmid: self.shmid,
983 padding: Default::default(),
984 fd_offset: 0,
985 shm_offset: offset,
986 len: size,
987 flags: VhostUserMMapFlags::empty(),
988 };
989 shared
990 .conn
991 .shmem_unmap(&msg)
992 .context("unmap shmem")
993 .map(|_| ())
994 }
995}
996
997pub(crate) struct WorkerState<T, U> {
998 pub(crate) queue_task: TaskHandle<U>,
999 pub(crate) queue: T,
1000}
1001
1002#[derive(Debug, ThisError)]
1004pub enum Error {
1005 #[error("worker not found when stopping queue")]
1006 WorkerNotFound,
1007}
1008
#[cfg(test)]
mod tests {
    use std::sync::mpsc::channel;

    use anyhow::bail;
    use base::Event;
    use vmm_vhost::BackendServer;
    use vmm_vhost::FrontendReq;
    use zerocopy::FromBytes;
    use zerocopy::FromZeros;
    use zerocopy::Immutable;
    use zerocopy::IntoBytes;
    use zerocopy::KnownLayout;

    use super::*;
    use crate::virtio::vhost_user_frontend::VhostUserFrontend;
    use crate::virtio::DeviceType;
    use crate::virtio::VirtioDevice;

    /// Config space layout exposed by `FakeBackend`.
    #[derive(Clone, Copy, Debug, PartialEq, Eq, FromBytes, Immutable, IntoBytes, KnownLayout)]
    #[repr(C, packed(4))]
    struct FakeConfig {
        x: u32,
        y: u64,
    }

    const FAKE_CONFIG_DATA: FakeConfig = FakeConfig { x: 1, y: 2 };

    /// Minimal `VhostUserDevice` used to drive `DeviceRequestHandler` in tests.
    pub(super) struct FakeBackend {
        avail_features: u64,
        acked_features: u64,
        active_queues: Vec<Option<Queue>>,
        allow_backend_req: bool,
        backend_conn: Option<VhostBackendReqConnection>,
    }

    #[derive(Deserialize, Serialize)]
    struct FakeBackendSnapshot {
        data: Vec<u8>,
    }

    impl FakeBackend {
        const MAX_QUEUE_NUM: usize = 16;

        pub(super) fn new() -> Self {
            let mut active_queues = Vec::new();
            active_queues.resize_with(Self::MAX_QUEUE_NUM, Default::default);
            Self {
                avail_features: 1 << VHOST_USER_F_PROTOCOL_FEATURES,
                acked_features: 0,
                active_queues,
                allow_backend_req: false,
                backend_conn: None,
            }
        }
    }

    impl VhostUserDevice for FakeBackend {
        fn max_queue_num(&self) -> usize {
            Self::MAX_QUEUE_NUM
        }

        fn features(&self) -> u64 {
            self.avail_features
        }

        fn ack_features(&mut self, value: u64) -> anyhow::Result<()> {
            let unrequested_features = value & !self.avail_features;
            if unrequested_features != 0 {
                bail!(
                    "invalid protocol features are given: 0x{:x}",
                    unrequested_features
                );
            }
            self.acked_features |= value;
            Ok(())
        }

        fn protocol_features(&self) -> VhostUserProtocolFeatures {
            let mut features =
                VhostUserProtocolFeatures::CONFIG | VhostUserProtocolFeatures::DEVICE_STATE;
            if self.allow_backend_req {
                features |= VhostUserProtocolFeatures::BACKEND_REQ;
            }
            features
        }

        fn read_config(&self, offset: u64, dst: &mut [u8]) {
            dst.copy_from_slice(&FAKE_CONFIG_DATA.as_bytes()[offset as usize..]);
        }

        fn reset(&mut self) {}

        fn start_queue(
            &mut self,
            idx: usize,
            queue: Queue,
            _mem: GuestMemory,
        ) -> anyhow::Result<()> {
            self.active_queues[idx] = Some(queue);
            Ok(())
        }

        fn stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue> {
            Ok(self.active_queues[idx]
                .take()
                .ok_or(Error::WorkerNotFound)?)
        }

        fn set_backend_req_connection(&mut self, conn: VhostBackendReqConnection) {
            self.backend_conn = Some(conn);
        }

        fn enter_suspended_state(&mut self) -> anyhow::Result<()> {
            Ok(())
        }

        fn snapshot(&mut self) -> anyhow::Result<AnySnapshot> {
            AnySnapshot::to_any(FakeBackendSnapshot {
                data: vec![1, 2, 3],
            })
            .context("failed to serialize snapshot")
        }

        fn restore(&mut self, data: AnySnapshot) -> anyhow::Result<()> {
            let snapshot: FakeBackendSnapshot =
                AnySnapshot::from_any(data).context("failed to deserialize snapshot")?;
            assert_eq!(snapshot.data, vec![1, 2, 3], "bad snapshot data");
            Ok(())
        }
    }

    #[test]
    fn test_vhost_user_lifecycle() {
        test_vhost_user_lifecycle_parameterized(false);
    }

    #[test]
    #[cfg(not(windows))]
    fn test_vhost_user_lifecycle_with_backend_req() {
        test_vhost_user_lifecycle_parameterized(true);
    }

    /// Runs a frontend (`VhostUserFrontend`) on a separate thread against a
    /// `DeviceRequestHandler<FakeBackend>` and checks that every request the frontend sends
    /// through activate, reset, sleep, snapshot/restore and wake is handled successfully.
    fn test_vhost_user_lifecycle_parameterized(allow_backend_req: bool) {
        const QUEUES_NUM: usize = 2;

        let (client_connection, server_connection) = vmm_vhost::Connection::pair().unwrap();

        let (ready_tx, ready_rx) = channel();
        let (shutdown_tx, shutdown_rx) = channel();
        let (vm_evt_wrtube, _vm_evt_rdtube) = base::Tube::directional_pair().unwrap();

        // Frontend (VMM) side.
        let vmm_thread = std::thread::spawn(move || {
            // Wait until the backend side has been configured.
            ready_rx.recv().unwrap();
            let mut vmm_device = VhostUserFrontend::new(
                DeviceType::Console,
                0,
                client_connection,
                vm_evt_wrtube,
                None,
                None,
            )
            .unwrap();

            println!("read_config");
            let mut config = FakeConfig::new_zeroed();
            vmm_device.read_config(0, config.as_mut_bytes());
            assert_eq!(config, FAKE_CONFIG_DATA);

            let activate = |vmm_device: &mut VhostUserFrontend| {
                let mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
                let interrupt = Interrupt::new_for_test_with_msix();

                let mut queues = BTreeMap::new();
                for idx in 0..QUEUES_NUM {
                    let mut queue = QueueConfig::new(0x10, 0);
                    queue.set_ready(true);
                    let queue = queue
                        .activate(&mem, Event::new().unwrap(), interrupt.clone())
                        .expect("QueueConfig::activate");
                    queues.insert(idx, queue);
                }

                println!("activate");
                vmm_device.activate(mem, interrupt, queues).unwrap();
            };

            activate(&mut vmm_device);

            println!("reset");
            let reset_result = vmm_device.reset();
            assert!(
                reset_result.is_ok(),
                "reset failed: {:#}",
                reset_result.unwrap_err()
            );

            activate(&mut vmm_device);

            println!("virtio_sleep");
            let queues = vmm_device
                .virtio_sleep()
                .unwrap()
                .expect("virtio_sleep unexpectedly returned None");

            println!("virtio_snapshot");
            let snapshot = vmm_device
                .virtio_snapshot()
                .expect("virtio_snapshot failed");
            println!("virtio_restore");
            vmm_device
                .virtio_restore(snapshot)
                .expect("virtio_restore failed");

            println!("virtio_wake");
            let mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
            let interrupt = Interrupt::new_for_test_with_msix();
            vmm_device
                .virtio_wake(Some((mem, interrupt, queues)))
                .unwrap();

            println!("wait for shutdown signal");
            shutdown_rx.recv().unwrap();

            // Dropping the frontend closes the connection; the backend side then expects
            // ClientExit.
            println!("drop");
            drop(vmm_device);
        });

        // Backend (device) side.
        let mut handler = DeviceRequestHandler::new(FakeBackend::new());
        handler.as_mut().allow_backend_req = allow_backend_req;

        // Let the frontend thread start now that the backend is configured.
        ready_tx.send(()).unwrap();

        let mut req_handler = BackendServer::new(server_connection, handler);

        // Frontend: VhostUserFrontend::new()
        handle_request(&mut req_handler, FrontendReq::SET_OWNER).unwrap();
        handle_request(&mut req_handler, FrontendReq::GET_FEATURES).unwrap();
        handle_request(&mut req_handler, FrontendReq::GET_PROTOCOL_FEATURES).unwrap();
        handle_request(&mut req_handler, FrontendReq::SET_PROTOCOL_FEATURES).unwrap();
        if allow_backend_req {
            handle_request(&mut req_handler, FrontendReq::SET_BACKEND_REQ_FD).unwrap();
        }

        // Frontend: read_config()
        handle_request(&mut req_handler, FrontendReq::GET_CONFIG).unwrap();

        // Frontend: activate()
        handle_request(&mut req_handler, FrontendReq::SET_FEATURES).unwrap();
        handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();
        for _ in 0..QUEUES_NUM {
            handle_request(&mut req_handler, FrontendReq::SET_VRING_NUM).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ADDR).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_BASE).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_CALL).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_KICK).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
        }

        // Frontend: reset()
        for _ in 0..QUEUES_NUM {
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
            handle_request(&mut req_handler, FrontendReq::GET_VRING_BASE).unwrap();
        }

        // Frontend: activate() again
        handle_request(&mut req_handler, FrontendReq::SET_FEATURES).unwrap();
        handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();
        for _ in 0..QUEUES_NUM {
            handle_request(&mut req_handler, FrontendReq::SET_VRING_NUM).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ADDR).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_BASE).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_CALL).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_KICK).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
        }

        if allow_backend_req {
            req_handler
                .as_ref()
                .as_ref()
                .backend_conn
                .as_ref()
                .expect("backend_conn missing")
                .send_config_changed()
                .expect("send_config_changed failed");
        }

        // Frontend: virtio_sleep()
        for _ in 0..QUEUES_NUM {
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
            handle_request(&mut req_handler, FrontendReq::GET_VRING_BASE).unwrap();
        }

        // Frontend: virtio_snapshot()
        handle_request(&mut req_handler, FrontendReq::SET_DEVICE_STATE_FD).unwrap();
        handle_request(&mut req_handler, FrontendReq::CHECK_DEVICE_STATE).unwrap();
        // Frontend: virtio_restore()
        handle_request(&mut req_handler, FrontendReq::SET_FEATURES).unwrap();
        handle_request(&mut req_handler, FrontendReq::SET_DEVICE_STATE_FD).unwrap();
        handle_request(&mut req_handler, FrontendReq::CHECK_DEVICE_STATE).unwrap();

        // Frontend: virtio_wake()
        handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();
        for _ in 0..QUEUES_NUM {
            handle_request(&mut req_handler, FrontendReq::SET_VRING_NUM).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ADDR).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_BASE).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_CALL).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_KICK).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
        }

        if allow_backend_req {
            req_handler
                .as_ref()
                .as_ref()
                .backend_conn
                .as_ref()
                .expect("backend_conn missing")
                .send_config_changed()
                .expect("send_config_changed failed");
        }

        shutdown_tx.send(()).unwrap();
        // Wait for the frontend thread to drop its device and exit. Joining (instead of
        // waiting on a barrier) also surfaces a panic on that thread as a test failure
        // rather than hanging the test forever.
        vmm_thread.join().expect("frontend thread panicked");

        // The frontend has closed its end of the connection.
        match req_handler.recv_header() {
            Err(VhostError::ClientExit) => (),
            r => panic!("expected Err(ClientExit) but got {r:?}"),
        }
    }

    /// Receives one request, asserts it is `expected_message_type`, and processes it.
    fn handle_request<S: vmm_vhost::Backend>(
        handler: &mut BackendServer<S>,
        expected_message_type: FrontendReq,
    ) -> Result<(), VhostError> {
        let (hdr, files) = handler.recv_header()?;
        assert_eq!(hdr.get_code(), Ok(expected_message_type));
        handler.process_message(hdr, files)
    }
}