1pub(super) mod sys;
48
49use std::collections::BTreeMap;
50use std::convert::From;
51use std::fs::File;
52use std::io::BufReader;
53use std::io::Write;
54use std::num::Wrapping;
55#[cfg(any(target_os = "android", target_os = "linux"))]
56use std::os::unix::io::AsRawFd;
57use std::sync::Arc;
58
59use anyhow::bail;
60use anyhow::Context;
61#[cfg(any(target_os = "android", target_os = "linux"))]
62use base::clear_fd_flags;
63use base::error;
64use base::trace;
65use base::warn;
66use base::Event;
67use base::Protection;
68use base::SafeDescriptor;
69use base::SharedMemory;
70use base::WorkerThread;
71use cros_async::TaskHandle;
72use hypervisor::MemCacheType;
73use serde::Deserialize;
74use serde::Serialize;
75use snapshot::AnySnapshot;
76use sync::Mutex;
77use thiserror::Error as ThisError;
78use vm_control::VmMemorySource;
79use vm_memory::GuestAddress;
80use vm_memory::GuestMemory;
81use vm_memory::MemoryRegion;
82use vmm_vhost::message::VhostSharedMemoryRegion;
83use vmm_vhost::message::VhostUserConfigFlags;
84use vmm_vhost::message::VhostUserExternalMapMsg;
85use vmm_vhost::message::VhostUserGpuMapMsg;
86use vmm_vhost::message::VhostUserInflight;
87use vmm_vhost::message::VhostUserMemoryRegion;
88use vmm_vhost::message::VhostUserMigrationPhase;
89use vmm_vhost::message::VhostUserProtocolFeatures;
90use vmm_vhost::message::VhostUserShmemMapMsg;
91use vmm_vhost::message::VhostUserShmemMapMsgFlags;
92use vmm_vhost::message::VhostUserShmemUnmapMsg;
93use vmm_vhost::message::VhostUserSingleMemoryRegion;
94use vmm_vhost::message::VhostUserTransferDirection;
95use vmm_vhost::message::VhostUserVringAddrFlags;
96use vmm_vhost::message::VhostUserVringState;
97use vmm_vhost::BackendReq;
98use vmm_vhost::Connection;
99use vmm_vhost::Error as VhostError;
100use vmm_vhost::Frontend;
101use vmm_vhost::FrontendClient;
102use vmm_vhost::Result as VhostResult;
103use vmm_vhost::VHOST_USER_F_PROTOCOL_FEATURES;
104
105use crate::virtio::Interrupt;
106use crate::virtio::Queue;
107use crate::virtio::QueueConfig;
108use crate::virtio::SharedMemoryMapper;
109use crate::virtio::SharedMemoryRegion;
110
/// One entry of the address-translation table built from `SET_MEM_TABLE`.
///
/// Describes a memory region as seen by the frontend (VMM): `vmm_addr` is the region's start in
/// the frontend's virtual address space, `guest_phys` its start in guest physical memory, and
/// `size` its length in bytes. Used by [`vmm_va_to_gpa`] to translate ring addresses.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct MappingInfo {
    /// Start of the region in the frontend's virtual address space.
    pub vmm_addr: u64,
    /// Start of the region in guest physical address space.
    pub guest_phys: u64,
    /// Length of the region in bytes.
    pub size: u64,
}
119
120pub fn vmm_va_to_gpa(maps: &[MappingInfo], vmm_va: u64) -> VhostResult<GuestAddress> {
121 for map in maps {
122 if vmm_va >= map.vmm_addr && vmm_va < map.vmm_addr + map.size {
123 return Ok(GuestAddress(vmm_va - map.vmm_addr + map.guest_phys));
124 }
125 }
126 Err(VhostError::InvalidMessage)
127}
128
129pub trait VhostUserDevice {
134 fn max_queue_num(&self) -> usize;
136
137 fn features(&self) -> u64;
139
140 fn ack_features(&mut self, _value: u64) -> anyhow::Result<()> {
148 Ok(())
149 }
150
151 fn protocol_features(&self) -> VhostUserProtocolFeatures;
153
154 fn read_config(&self, offset: u64, dst: &mut [u8]);
156
157 fn write_config(&self, _offset: u64, _data: &[u8]) {}
159
160 fn start_queue(&mut self, idx: usize, queue: Queue, mem: GuestMemory) -> anyhow::Result<()>;
164
165 fn stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue>;
169
170 fn reset(&mut self);
172
173 fn get_shared_memory_region(&self) -> Option<SharedMemoryRegion> {
175 None
176 }
177
178 fn set_backend_req_connection(&mut self, _conn: VhostBackendReqConnection) {}
184
185 fn enter_suspended_state(&mut self) -> anyhow::Result<()>;
198
199 fn snapshot(&mut self) -> anyhow::Result<AnySnapshot>;
201
202 fn restore(&mut self, data: AnySnapshot) -> anyhow::Result<()>;
204
205 fn unmap_guest_memory_on_fork(&self) -> bool {
213 false
214 }
215}
216
217struct Vring {
219 queue: QueueConfig,
221 doorbell: Option<Interrupt>,
222 enabled: bool,
223}
224
225impl Vring {
226 fn new(max_size: u16, features: u64) -> Self {
227 Self {
228 queue: QueueConfig::new(max_size, features),
229 doorbell: None,
230 enabled: false,
231 }
232 }
233
234 fn reset(&mut self) {
235 self.queue.reset();
236 self.doorbell = None;
237 self.enabled = false;
238 }
239}
240
241pub(super) struct VhostUserRegularOps;
243
244impl VhostUserRegularOps {
245 pub fn set_mem_table(
246 contexts: &[VhostUserMemoryRegion],
247 files: Vec<File>,
248 ) -> VhostResult<(GuestMemory, Vec<MappingInfo>)> {
249 if files.len() != contexts.len() {
250 return Err(VhostError::InvalidParam(
251 "number of files & contexts was not equal",
252 ));
253 }
254
255 let mut regions = Vec::with_capacity(files.len());
256 for (region, file) in contexts.iter().zip(files.into_iter()) {
257 let region = MemoryRegion::new_from_shm(
258 region.memory_size,
259 GuestAddress(region.guest_phys_addr),
260 region.mmap_offset,
261 Arc::new(
262 SharedMemory::from_safe_descriptor(
263 SafeDescriptor::from(file),
264 region.memory_size,
265 )
266 .unwrap(),
267 ),
268 )
269 .map_err(|e| {
270 error!("failed to create a memory region: {}", e);
271 VhostError::InvalidOperation
272 })?;
273 regions.push(region);
274 }
275 let guest_mem = GuestMemory::from_regions(regions).map_err(|e| {
276 error!("failed to create guest memory: {}", e);
277 VhostError::InvalidOperation
278 })?;
279
280 let vmm_maps = contexts
281 .iter()
282 .map(|region| MappingInfo {
283 vmm_addr: region.user_addr,
284 guest_phys: region.guest_phys_addr,
285 size: region.memory_size,
286 })
287 .collect();
288 Ok((guest_mem, vmm_maps))
289 }
290}
291
292pub struct DeviceRequestHandler<T: VhostUserDevice> {
294 vrings: Vec<Vring>,
295 owned: bool,
296 vmm_maps: Option<Vec<MappingInfo>>,
297 mem: Option<GuestMemory>,
298 acked_features: u64,
299 acked_protocol_features: VhostUserProtocolFeatures,
300 backend: T,
301 backend_req_connection: Option<VhostBackendReqConnection>,
302 device_state_thread: Option<DeviceStateThread>,
304}
305
306enum DeviceStateThread {
307 Save(WorkerThread<Result<(), ciborium::ser::Error<std::io::Error>>>),
308 Load(WorkerThread<Result<DeviceRequestHandlerSnapshot, ciborium::de::Error<std::io::Error>>>),
309}
310
311#[derive(Serialize, Deserialize)]
312pub struct DeviceRequestHandlerSnapshot {
313 acked_features: u64,
314 acked_protocol_features: u64,
315 backend: AnySnapshot,
316}
317
318impl<T: VhostUserDevice> DeviceRequestHandler<T> {
319 pub(crate) fn new(mut backend: T) -> Self {
321 let mut vrings = Vec::with_capacity(backend.max_queue_num());
322 for _ in 0..backend.max_queue_num() {
323 vrings.push(Vring::new(Queue::MAX_SIZE, backend.features()));
324 }
325
326 backend
329 .enter_suspended_state()
330 .expect("enter_suspended_state failed on device init");
331
332 DeviceRequestHandler {
333 vrings,
334 owned: false,
335 vmm_maps: None,
336 mem: None,
337 acked_features: 0,
338 acked_protocol_features: VhostUserProtocolFeatures::empty(),
339 backend,
340 backend_req_connection: None,
341 device_state_thread: None,
342 }
343 }
344
345 fn all_queues_stopped(&self) -> bool {
349 self.vrings.iter().all(|vring| !vring.queue.ready())
350 }
351}
352
353impl<T: VhostUserDevice> Drop for DeviceRequestHandler<T> {
354 fn drop(&mut self) {
355 for (index, vring) in self.vrings.iter().enumerate() {
356 if vring.queue.ready() {
357 if let Err(e) = self.backend.stop_queue(index) {
358 error!("Failed to stop queue {} during drop: {:#}", index, e);
359 }
360 }
361 }
362 }
363}
364
365impl<T: VhostUserDevice> AsRef<T> for DeviceRequestHandler<T> {
366 fn as_ref(&self) -> &T {
367 &self.backend
368 }
369}
370
371impl<T: VhostUserDevice> AsMut<T> for DeviceRequestHandler<T> {
372 fn as_mut(&mut self) -> &mut T {
373 &mut self.backend
374 }
375}
376
377impl<T: VhostUserDevice> vmm_vhost::Backend for DeviceRequestHandler<T> {
378 fn set_owner(&mut self) -> VhostResult<()> {
379 if self.owned {
380 return Err(VhostError::InvalidOperation);
381 }
382 self.owned = true;
383 Ok(())
384 }
385
386 fn reset_owner(&mut self) -> VhostResult<()> {
387 self.owned = false;
388 self.acked_features = 0;
389 self.backend.reset();
390 Ok(())
391 }
392
393 fn get_features(&mut self) -> VhostResult<u64> {
394 let features = self.backend.features();
395 Ok(features)
396 }
397
398 fn set_features(&mut self, features: u64) -> VhostResult<()> {
399 if !self.owned {
400 return Err(VhostError::InvalidOperation);
401 }
402
403 let unexpected_features = features & !self.backend.features();
404 if unexpected_features != 0 {
405 error!("unexpected set_features {:#x}", unexpected_features);
406 return Err(VhostError::InvalidParam("unexpected set_features"));
407 }
408
409 if let Err(e) = self.backend.ack_features(features) {
410 error!("failed to acknowledge features 0x{:x}: {}", features, e);
411 return Err(VhostError::InvalidOperation);
412 }
413
414 self.acked_features |= features;
415
416 let vring_enabled = self.acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES != 0;
424 for v in &mut self.vrings {
425 v.enabled = vring_enabled;
426 }
427
428 Ok(())
429 }
430
431 fn get_protocol_features(&mut self) -> VhostResult<VhostUserProtocolFeatures> {
432 Ok(self.backend.protocol_features())
433 }
434
435 fn set_protocol_features(&mut self, features: u64) -> VhostResult<()> {
436 let features = match VhostUserProtocolFeatures::from_bits(features) {
437 Some(proto_features) => proto_features,
438 None => {
439 error!(
440 "unsupported bits in VHOST_USER_SET_PROTOCOL_FEATURES: {:#x}",
441 features
442 );
443 return Err(VhostError::InvalidOperation);
444 }
445 };
446 let supported = self.backend.protocol_features();
447 self.acked_protocol_features = features & supported;
448 Ok(())
449 }
450
451 fn set_mem_table(
452 &mut self,
453 contexts: &[VhostUserMemoryRegion],
454 files: Vec<File>,
455 ) -> VhostResult<()> {
456 let (guest_mem, vmm_maps) = VhostUserRegularOps::set_mem_table(contexts, files)?;
457 if self.backend.unmap_guest_memory_on_fork() {
458 #[cfg(any(target_os = "android", target_os = "linux"))]
459 if let Err(e) = guest_mem.use_dontfork() {
460 error!("failed to set MADV_DONTFORK on guest memory: {e:#}");
461 }
462 #[cfg(not(any(target_os = "android", target_os = "linux")))]
463 error!("unmap_guest_memory_on_fork unsupported; skipping");
464 }
465 self.mem = Some(guest_mem);
466 self.vmm_maps = Some(vmm_maps);
467 Ok(())
468 }
469
470 fn get_queue_num(&mut self) -> VhostResult<u64> {
471 Ok(self.vrings.len() as u64)
472 }
473
474 fn set_vring_num(&mut self, index: u32, num: u32) -> VhostResult<()> {
475 if index as usize >= self.vrings.len() || num == 0 || num > Queue::MAX_SIZE.into() {
476 return Err(VhostError::InvalidParam(
477 "set_vring_num: invalid index or num",
478 ));
479 }
480 self.vrings[index as usize].queue.set_size(num as u16);
481
482 Ok(())
483 }
484
485 fn set_vring_addr(
486 &mut self,
487 index: u32,
488 _flags: VhostUserVringAddrFlags,
489 descriptor: u64,
490 used: u64,
491 available: u64,
492 _log: u64,
493 ) -> VhostResult<()> {
494 if index as usize >= self.vrings.len() {
495 return Err(VhostError::InvalidParam(
496 "set_vring_addr: index out of range",
497 ));
498 }
499
500 let vmm_maps = self
501 .vmm_maps
502 .as_ref()
503 .ok_or(VhostError::InvalidParam("set_vring_addr: missing vmm_maps"))?;
504 let vring = &mut self.vrings[index as usize];
505 vring
506 .queue
507 .set_desc_table(vmm_va_to_gpa(vmm_maps, descriptor)?);
508 vring
509 .queue
510 .set_avail_ring(vmm_va_to_gpa(vmm_maps, available)?);
511 vring.queue.set_used_ring(vmm_va_to_gpa(vmm_maps, used)?);
512
513 Ok(())
514 }
515
516 fn set_vring_base(&mut self, index: u32, base: u32) -> VhostResult<()> {
517 if index as usize >= self.vrings.len() {
518 return Err(VhostError::InvalidParam(
519 "set_vring_base: index out of range",
520 ));
521 }
522
523 let vring = &mut self.vrings[index as usize];
524 vring.queue.set_next_avail(Wrapping(base as u16));
525 vring.queue.set_next_used(Wrapping(base as u16));
526
527 Ok(())
528 }
529
530 fn get_vring_base(&mut self, index: u32) -> VhostResult<VhostUserVringState> {
531 let vring = self
532 .vrings
533 .get_mut(index as usize)
534 .ok_or(VhostError::InvalidParam(
535 "get_vring_base: index out of range",
536 ))?;
537
538 let vring_base = if vring.queue.ready() {
543 let queue = match self.backend.stop_queue(index as usize) {
544 Ok(q) => q,
545 Err(e) => {
546 error!("Failed to stop queue in get_vring_base: {:#}", e);
547 return Err(VhostError::BackendInternalError);
548 }
549 };
550
551 trace!("stopped queue {index}");
552 vring.reset();
553
554 if self.all_queues_stopped() {
555 trace!("all queues stopped; entering suspended state");
556 self.backend
557 .enter_suspended_state()
558 .map_err(VhostError::EnterSuspendedState)?;
559 }
560
561 queue.next_avail_to_process()
562 } else {
563 0
564 };
565
566 Ok(VhostUserVringState::new(index, vring_base.into()))
567 }
568
569 fn set_vring_kick(&mut self, index: u8, file: Option<File>) -> VhostResult<()> {
570 if index as usize >= self.vrings.len() {
571 return Err(VhostError::InvalidParam(
572 "set_vring_kick: index out of range",
573 ));
574 }
575
576 let vring = &mut self.vrings[index as usize];
577 if vring.queue.ready() {
578 error!("kick fd cannot replaced after queue is started");
579 return Err(VhostError::InvalidOperation);
580 }
581
582 let file = file.ok_or(VhostError::InvalidParam("missing file for set_vring_kick"))?;
583
584 #[cfg(any(target_os = "android", target_os = "linux"))]
588 if let Err(e) = clear_fd_flags(file.as_raw_fd(), libc::O_NONBLOCK) {
589 error!("failed to remove O_NONBLOCK for kick fd: {}", e);
590 return Err(VhostError::InvalidParam(
591 "could not remove O_NONBLOCK from vring_kick",
592 ));
593 }
594
595 let kick_evt = Event::from(SafeDescriptor::from(file));
596
597 vring.queue.ack_features(self.acked_features);
599 vring.queue.set_ready(true);
600
601 let mem = self
602 .mem
603 .as_ref()
604 .cloned()
605 .ok_or(VhostError::InvalidOperation)?;
606
607 let doorbell = vring.doorbell.clone().ok_or(VhostError::InvalidOperation)?;
608
609 let queue = match vring.queue.activate(&mem, kick_evt, doorbell) {
610 Ok(queue) => queue,
611 Err(e) => {
612 error!("failed to activate vring: {:#}", e);
613 return Err(VhostError::BackendInternalError);
614 }
615 };
616
617 if let Err(e) = self.backend.start_queue(index as usize, queue, mem) {
618 error!("Failed to start queue {}: {}", index, e);
619 return Err(VhostError::BackendInternalError);
620 }
621 trace!("started queue {index}");
622
623 Ok(())
624 }
625
626 fn set_vring_call(&mut self, index: u8, file: Option<File>) -> VhostResult<()> {
627 if index as usize >= self.vrings.len() {
628 return Err(VhostError::InvalidParam(
629 "set_vring_call: index out of range",
630 ));
631 }
632
633 let backend_req_conn = self.backend_req_connection.clone();
634 let signal_config_change_fn = Box::new(move || {
635 if let Some(frontend) = backend_req_conn.as_ref() {
636 if let Err(e) = frontend.send_config_changed() {
637 error!("Failed to notify config change: {:#}", e);
638 }
639 } else {
640 error!("No Backend request connection found");
641 }
642 });
643
644 let file = file.ok_or(VhostError::InvalidParam("missing file for set_vring_call"))?;
645 self.vrings[index as usize].doorbell = Some(Interrupt::new_vhost_user(
646 Event::from(SafeDescriptor::from(file)),
647 signal_config_change_fn,
648 ));
649 Ok(())
650 }
651
652 fn set_vring_err(&mut self, _index: u8, _fd: Option<File>) -> VhostResult<()> {
653 Ok(())
655 }
656
657 fn set_vring_enable(&mut self, index: u32, enable: bool) -> VhostResult<()> {
658 if index as usize >= self.vrings.len() {
659 return Err(VhostError::InvalidParam(
660 "set_vring_enable: index out of range",
661 ));
662 }
663
664 if self.acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES == 0 {
667 return Err(VhostError::InvalidOperation);
668 }
669
670 self.vrings[index as usize].enabled = enable;
674
675 Ok(())
676 }
677
678 fn get_config(
679 &mut self,
680 offset: u32,
681 size: u32,
682 _flags: VhostUserConfigFlags,
683 ) -> VhostResult<Vec<u8>> {
684 let mut data = vec![0; size as usize];
685 self.backend.read_config(u64::from(offset), &mut data);
686 Ok(data)
687 }
688
689 fn set_config(
690 &mut self,
691 offset: u32,
692 buf: &[u8],
693 _flags: VhostUserConfigFlags,
694 ) -> VhostResult<()> {
695 self.backend.write_config(u64::from(offset), buf);
696 Ok(())
697 }
698
699 fn set_backend_req_fd(&mut self, ep: Connection<BackendReq>) {
700 let conn = VhostBackendReqConnection::new(
701 FrontendClient::new(ep),
702 self.backend.get_shared_memory_region().map(|r| r.id),
703 );
704
705 if self.backend_req_connection.is_some() {
706 warn!("Backend Request Connection already established. Overwriting");
707 }
708 self.backend_req_connection = Some(conn.clone());
709
710 self.backend.set_backend_req_connection(conn);
711 }
712
713 fn get_inflight_fd(
714 &mut self,
715 _inflight: &VhostUserInflight,
716 ) -> VhostResult<(VhostUserInflight, File)> {
717 unimplemented!("get_inflight_fd");
718 }
719
720 fn set_inflight_fd(&mut self, _inflight: &VhostUserInflight, _file: File) -> VhostResult<()> {
721 unimplemented!("set_inflight_fd");
722 }
723
724 fn get_max_mem_slots(&mut self) -> VhostResult<u64> {
725 Ok(0)
727 }
728
729 fn add_mem_region(
730 &mut self,
731 _region: &VhostUserSingleMemoryRegion,
732 _fd: File,
733 ) -> VhostResult<()> {
734 Ok(())
736 }
737
738 fn remove_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion) -> VhostResult<()> {
739 Ok(())
741 }
742
743 fn set_device_state_fd(
744 &mut self,
745 transfer_direction: VhostUserTransferDirection,
746 migration_phase: VhostUserMigrationPhase,
747 fd: File,
748 ) -> VhostResult<Option<File>> {
749 if migration_phase != VhostUserMigrationPhase::Stopped {
750 return Err(VhostError::InvalidOperation);
751 }
752 if !self.all_queues_stopped() {
753 return Err(VhostError::InvalidOperation);
754 }
755 if self.device_state_thread.is_some() {
756 error!("must call check_device_state before starting new state transfer");
757 return Err(VhostError::InvalidOperation);
758 }
759 match transfer_direction {
764 VhostUserTransferDirection::Save => {
765 let snapshot = DeviceRequestHandlerSnapshot {
767 acked_features: self.acked_features,
768 acked_protocol_features: self.acked_protocol_features.bits(),
769 backend: self.backend.snapshot().map_err(VhostError::SnapshotError)?,
770 };
771 self.device_state_thread = Some(DeviceStateThread::Save(WorkerThread::start(
773 "device_state_save",
774 move |_kill_event| -> Result<(), ciborium::ser::Error<std::io::Error>> {
775 let mut w = std::io::BufWriter::new(fd);
776 ciborium::into_writer(&snapshot, &mut w)?;
777 w.flush()?;
778 Ok(())
779 },
780 )));
781 Ok(None)
782 }
783 VhostUserTransferDirection::Load => {
784 self.device_state_thread = Some(DeviceStateThread::Load(WorkerThread::start(
787 "device_state_load",
788 move |_kill_event| ciborium::from_reader(&mut BufReader::new(fd)),
789 )));
790 Ok(None)
791 }
792 }
793 }
794
795 fn check_device_state(&mut self) -> VhostResult<()> {
796 let Some(thread) = self.device_state_thread.take() else {
797 error!("check_device_state: no active state transfer");
798 return Err(VhostError::InvalidOperation);
799 };
800 match thread {
801 DeviceStateThread::Save(worker) => {
802 worker.stop().map_err(|e| {
803 error!("device state save thread failed: {:#}", e);
804 VhostError::BackendInternalError
805 })?;
806 Ok(())
807 }
808 DeviceStateThread::Load(worker) => {
809 let snapshot = worker.stop().map_err(|e| {
810 error!("device state load thread failed: {:#}", e);
811 VhostError::BackendInternalError
812 })?;
813 self.acked_features = snapshot.acked_features;
814 self.acked_protocol_features =
815 VhostUserProtocolFeatures::from_bits(snapshot.acked_protocol_features)
816 .with_context(|| {
817 format!(
818 "unsupported bits in acked_protocol_features: {:#x}",
819 snapshot.acked_protocol_features
820 )
821 })
822 .map_err(VhostError::RestoreError)?;
823 self.backend
824 .restore(snapshot.backend)
825 .map_err(VhostError::RestoreError)?;
826 Ok(())
827 }
828 }
829 }
830
831 fn get_shared_memory_regions(&mut self) -> VhostResult<Vec<VhostSharedMemoryRegion>> {
832 Ok(if let Some(r) = self.backend.get_shared_memory_region() {
833 vec![VhostSharedMemoryRegion::new(r.id, r.length)]
834 } else {
835 Vec::new()
836 })
837 }
838}
839
840#[derive(Clone)]
842pub struct VhostBackendReqConnection {
843 shared: Arc<Mutex<VhostBackendReqConnectionShared>>,
844 shmid: Option<u8>,
845}
846
847struct VhostBackendReqConnectionShared {
848 conn: FrontendClient,
849 mapped_regions: BTreeMap<u64 , u64 >,
850}
851
852impl VhostBackendReqConnection {
853 fn new(conn: FrontendClient, shmid: Option<u8>) -> Self {
854 Self {
855 shared: Arc::new(Mutex::new(VhostBackendReqConnectionShared {
856 conn,
857 mapped_regions: BTreeMap::new(),
858 })),
859 shmid,
860 }
861 }
862
863 fn send_config_changed(&self) -> anyhow::Result<()> {
865 let mut shared = self.shared.lock();
866 shared
867 .conn
868 .handle_config_change()
869 .context("Could not send config change message")?;
870 Ok(())
871 }
872
873 pub fn shmem_mapper(&self) -> Option<Box<dyn SharedMemoryMapper>> {
875 if let Some(shmid) = self.shmid {
876 Some(Box::new(VhostShmemMapper {
877 shared: self.shared.clone(),
878 shmid,
879 }))
880 } else {
881 None
882 }
883 }
884}
885
886#[derive(Clone)]
887struct VhostShmemMapper {
888 shared: Arc<Mutex<VhostBackendReqConnectionShared>>,
889 shmid: u8,
890}
891
892impl SharedMemoryMapper for VhostShmemMapper {
893 fn add_mapping(
894 &mut self,
895 source: VmMemorySource,
896 offset: u64,
897 prot: Protection,
898 _cache: MemCacheType,
899 ) -> anyhow::Result<()> {
900 let mut shared = self.shared.lock();
901 let size = match source {
902 VmMemorySource::Vulkan {
903 descriptor,
904 handle_type,
905 memory_idx,
906 device_uuid,
907 driver_uuid,
908 size,
909 } => {
910 let msg = VhostUserGpuMapMsg::new(
911 self.shmid,
912 offset,
913 size,
914 memory_idx,
915 handle_type,
916 device_uuid,
917 driver_uuid,
918 );
919 shared
920 .conn
921 .gpu_map(&msg, &descriptor)
922 .context("map GPU memory")?;
923 size
924 }
925 VmMemorySource::ExternalMapping { ptr, size } => {
926 let msg = VhostUserExternalMapMsg::new(self.shmid, offset, size, ptr);
927 shared
928 .conn
929 .external_map(&msg)
930 .context("create external mapping")?;
931 size
932 }
933 source => {
934 let (descriptor, fd_offset, size) = match source {
937 VmMemorySource::Descriptor {
938 descriptor,
939 offset,
940 size,
941 } => (descriptor, offset, size),
942 VmMemorySource::SharedMemory(shmem) => {
943 let size = shmem.size();
944 let descriptor = SafeDescriptor::from(shmem);
945 (descriptor, 0, size)
946 }
947 _ => bail!("unsupported source"),
948 };
949 let flags = VhostUserShmemMapMsgFlags::from(prot);
950 let msg = VhostUserShmemMapMsg::new(self.shmid, offset, fd_offset, size, flags);
951 shared
952 .conn
953 .shmem_map(&msg, &descriptor)
954 .context("map shmem")?;
955 size
956 }
957 };
958
959 shared.mapped_regions.insert(offset, size);
960 Ok(())
961 }
962
963 fn remove_mapping(&mut self, offset: u64) -> anyhow::Result<()> {
964 let mut shared = self.shared.lock();
965 let size = shared
966 .mapped_regions
967 .remove(&offset)
968 .context("unknown offset")?;
969 let msg = VhostUserShmemUnmapMsg::new(self.shmid, offset, size);
970 shared
971 .conn
972 .shmem_unmap(&msg)
973 .context("unmap shmem")
974 .map(|_| ())
975 }
976}
977
978pub(crate) struct WorkerState<T, U> {
979 pub(crate) queue_task: TaskHandle<U>,
980 pub(crate) queue: T,
981}
982
983#[derive(Debug, ThisError)]
985pub enum Error {
986 #[error("worker not found when stopping queue")]
987 WorkerNotFound,
988}
989
#[cfg(test)]
mod tests {
    use std::sync::mpsc::channel;
    use std::sync::Barrier;

    use anyhow::bail;
    use base::Event;
    use vmm_vhost::BackendServer;
    use vmm_vhost::FrontendReq;
    use zerocopy::FromBytes;
    use zerocopy::FromZeros;
    use zerocopy::Immutable;
    use zerocopy::IntoBytes;
    use zerocopy::KnownLayout;

    use super::*;
    use crate::virtio::vhost_user_frontend::VhostUserFrontend;
    use crate::virtio::DeviceType;
    use crate::virtio::VirtioDevice;

    /// Config-space layout served by `FakeBackend::read_config`.
    #[derive(Clone, Copy, Debug, PartialEq, Eq, FromBytes, Immutable, IntoBytes, KnownLayout)]
    #[repr(C, packed(4))]
    struct FakeConfig {
        x: u32,
        y: u64,
    }

    const FAKE_CONFIG_DATA: FakeConfig = FakeConfig { x: 1, y: 2 };

    /// Minimal `VhostUserDevice` that only records which queues are running.
    pub(super) struct FakeBackend {
        avail_features: u64,
        acked_features: u64,
        active_queues: Vec<Option<Queue>>,
        allow_backend_req: bool,
        backend_conn: Option<VhostBackendReqConnection>,
    }

    #[derive(Deserialize, Serialize)]
    struct FakeBackendSnapshot {
        data: Vec<u8>,
    }

    impl FakeBackend {
        const MAX_QUEUE_NUM: usize = 16;

        pub(super) fn new() -> Self {
            Self {
                avail_features: 1 << VHOST_USER_F_PROTOCOL_FEATURES,
                acked_features: 0,
                active_queues: (0..Self::MAX_QUEUE_NUM).map(|_| None).collect(),
                allow_backend_req: false,
                backend_conn: None,
            }
        }
    }

    impl VhostUserDevice for FakeBackend {
        fn max_queue_num(&self) -> usize {
            Self::MAX_QUEUE_NUM
        }

        fn features(&self) -> u64 {
            self.avail_features
        }

        fn ack_features(&mut self, value: u64) -> anyhow::Result<()> {
            let unknown = value & !self.avail_features;
            if unknown != 0 {
                bail!("invalid protocol features are given: 0x{:x}", unknown);
            }
            self.acked_features |= value;
            Ok(())
        }

        fn protocol_features(&self) -> VhostUserProtocolFeatures {
            let base = VhostUserProtocolFeatures::CONFIG | VhostUserProtocolFeatures::DEVICE_STATE;
            if self.allow_backend_req {
                base | VhostUserProtocolFeatures::BACKEND_REQ
            } else {
                base
            }
        }

        fn read_config(&self, offset: u64, dst: &mut [u8]) {
            dst.copy_from_slice(&FAKE_CONFIG_DATA.as_bytes()[offset as usize..]);
        }

        fn reset(&mut self) {}

        fn start_queue(
            &mut self,
            idx: usize,
            queue: Queue,
            _mem: GuestMemory,
        ) -> anyhow::Result<()> {
            self.active_queues[idx] = Some(queue);
            Ok(())
        }

        fn stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue> {
            self.active_queues[idx]
                .take()
                .ok_or_else(|| anyhow::Error::from(Error::WorkerNotFound))
        }

        fn set_backend_req_connection(&mut self, conn: VhostBackendReqConnection) {
            self.backend_conn = Some(conn);
        }

        fn enter_suspended_state(&mut self) -> anyhow::Result<()> {
            Ok(())
        }

        fn snapshot(&mut self) -> anyhow::Result<AnySnapshot> {
            let state = FakeBackendSnapshot {
                data: vec![1, 2, 3],
            };
            AnySnapshot::to_any(state).context("failed to serialize snapshot")
        }

        fn restore(&mut self, data: AnySnapshot) -> anyhow::Result<()> {
            let state: FakeBackendSnapshot =
                AnySnapshot::from_any(data).context("failed to deserialize snapshot")?;
            assert_eq!(state.data, vec![1, 2, 3], "bad snapshot data");
            Ok(())
        }
    }

    #[test]
    fn test_vhost_user_lifecycle() {
        test_vhost_user_lifecycle_parameterized(false);
    }

    // Backend requests are not exercised on Windows.
    // NOTE(review): presumably because connection setup differs there; confirm.
    #[test]
    #[cfg(not(windows))]
    fn test_vhost_user_lifecycle_with_backend_req() {
        test_vhost_user_lifecycle_parameterized(true);
    }

    /// Runs a real `VhostUserFrontend` on a helper thread against a `DeviceRequestHandler`
    /// wrapping `FakeBackend`. The frontend is driven through activate, reset, activate,
    /// sleep, snapshot/restore and wake, and the test checks that the backend receives
    /// exactly the expected sequence of requests.
    fn test_vhost_user_lifecycle_parameterized(allow_backend_req: bool) {
        const QUEUES_NUM: usize = 2;

        let (client_connection, server_connection) =
            vmm_vhost::Connection::<FrontendReq>::pair().unwrap();

        // Keeps the frontend alive until the backend side has finished its checks.
        let vmm_bar = Arc::new(Barrier::new(2));
        let dev_bar = Arc::clone(&vmm_bar);

        let (ready_tx, ready_rx) = channel();
        let (shutdown_tx, shutdown_rx) = channel();
        let (vm_evt_wrtube, _vm_evt_rdtube) = base::Tube::directional_pair().unwrap();

        std::thread::spawn(move || {
            // Don't start talking until the backend handler has been constructed.
            ready_rx.recv().unwrap();
            let mut vmm_device = VhostUserFrontend::new(
                DeviceType::Console,
                0,
                client_connection,
                vm_evt_wrtube,
                None,
                None,
            )
            .unwrap();

            println!("read_config");
            let mut config = FakeConfig::new_zeroed();
            vmm_device.read_config(0, config.as_mut_bytes());
            assert_eq!(config, FAKE_CONFIG_DATA);

            let activate = |vmm_device: &mut VhostUserFrontend| {
                let mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
                let interrupt = Interrupt::new_for_test_with_msix();

                let queues = (0..QUEUES_NUM)
                    .map(|idx| {
                        let mut queue_config = QueueConfig::new(0x10, 0);
                        queue_config.set_ready(true);
                        let queue = queue_config
                            .activate(&mem, Event::new().unwrap(), interrupt.clone())
                            .expect("QueueConfig::activate");
                        (idx, queue)
                    })
                    .collect::<BTreeMap<_, _>>();

                println!("activate");
                vmm_device.activate(mem, interrupt, queues).unwrap();
            };

            activate(&mut vmm_device);

            println!("reset");
            if let Err(e) = vmm_device.reset() {
                panic!("reset failed: {:#}", e);
            }

            activate(&mut vmm_device);

            println!("virtio_sleep");
            let queues = vmm_device
                .virtio_sleep()
                .unwrap()
                .expect("virtio_sleep unexpectedly returned None");

            println!("virtio_snapshot");
            let snapshot = vmm_device
                .virtio_snapshot()
                .expect("virtio_snapshot failed");
            println!("virtio_restore");
            vmm_device
                .virtio_restore(snapshot)
                .expect("virtio_restore failed");

            println!("virtio_wake");
            let mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
            let interrupt = Interrupt::new_for_test_with_msix();
            vmm_device
                .virtio_wake(Some((mem, interrupt, queues)))
                .unwrap();

            println!("wait for shutdown signal");
            shutdown_rx.recv().unwrap();

            println!("drop");
            drop(vmm_device);

            vmm_bar.wait();
        });

        let mut handler = DeviceRequestHandler::new(FakeBackend::new());
        handler.as_mut().allow_backend_req = allow_backend_req;

        // Only now may the frontend connect.
        ready_tx.send(()).unwrap();

        let mut req_handler = BackendServer::new(server_connection, handler);

        // Frontend construction: ownership, feature and config negotiation.
        for req in [
            FrontendReq::SET_OWNER,
            FrontendReq::GET_FEATURES,
            FrontendReq::GET_PROTOCOL_FEATURES,
            FrontendReq::SET_PROTOCOL_FEATURES,
        ] {
            handle_request(&mut req_handler, req).unwrap();
        }
        if allow_backend_req {
            handle_request(&mut req_handler, FrontendReq::SET_BACKEND_REQ_FD).unwrap();
        }
        handle_request(&mut req_handler, FrontendReq::GET_CONFIG).unwrap();

        // First `activate`.
        handle_request(&mut req_handler, FrontendReq::SET_FEATURES).unwrap();
        handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();
        handle_queue_setup(&mut req_handler, QUEUES_NUM);

        // `reset`.
        handle_queue_teardown(&mut req_handler, QUEUES_NUM);

        // Second `activate`.
        handle_request(&mut req_handler, FrontendReq::SET_FEATURES).unwrap();
        handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();
        handle_queue_setup(&mut req_handler, QUEUES_NUM);
        if allow_backend_req {
            notify_config_changed(&req_handler);
        }

        // `virtio_sleep`.
        handle_queue_teardown(&mut req_handler, QUEUES_NUM);

        // `virtio_snapshot` followed by `virtio_restore`.
        for req in [
            FrontendReq::SET_DEVICE_STATE_FD,
            FrontendReq::CHECK_DEVICE_STATE,
            FrontendReq::SET_DEVICE_STATE_FD,
            FrontendReq::CHECK_DEVICE_STATE,
        ] {
            handle_request(&mut req_handler, req).unwrap();
        }

        // `virtio_wake`.
        handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();
        handle_queue_setup(&mut req_handler, QUEUES_NUM);
        if allow_backend_req {
            notify_config_changed(&req_handler);
        }

        shutdown_tx.send(()).unwrap();
        dev_bar.wait();

        // The frontend has been dropped, so the connection must report a clean client exit.
        let r = req_handler.recv_header();
        assert!(
            matches!(r, Err(VhostError::ClientExit)),
            "expected Err(ClientExit) but got {r:?}"
        );
    }

    /// Consumes the per-queue requests the frontend sends when it (re)starts `queues` queues.
    fn handle_queue_setup<S: vmm_vhost::Backend>(
        req_handler: &mut BackendServer<S>,
        queues: usize,
    ) {
        for _ in 0..queues {
            for req in [
                FrontendReq::SET_VRING_NUM,
                FrontendReq::SET_VRING_ADDR,
                FrontendReq::SET_VRING_BASE,
                FrontendReq::SET_VRING_CALL,
                FrontendReq::SET_VRING_KICK,
                FrontendReq::SET_VRING_ENABLE,
            ] {
                handle_request(req_handler, req).unwrap();
            }
        }
    }

    /// Consumes the per-queue requests the frontend sends when it stops `queues` queues.
    fn handle_queue_teardown<S: vmm_vhost::Backend>(
        req_handler: &mut BackendServer<S>,
        queues: usize,
    ) {
        for _ in 0..queues {
            handle_request(req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
            handle_request(req_handler, FrontendReq::GET_VRING_BASE).unwrap();
        }
    }

    /// Sends a config-change notification from the backend through its stored connection.
    fn notify_config_changed(req_handler: &BackendServer<DeviceRequestHandler<FakeBackend>>) {
        req_handler
            .as_ref()
            .as_ref()
            .backend_conn
            .as_ref()
            .expect("backend_conn missing")
            .send_config_changed()
            .expect("send_config_changed failed");
    }

    /// Receives one frontend message, checks its type, and lets the handler process it.
    fn handle_request<S: vmm_vhost::Backend>(
        handler: &mut BackendServer<S>,
        expected_message_type: FrontendReq,
    ) -> Result<(), VhostError> {
        let (hdr, files) = handler.recv_header()?;
        assert_eq!(hdr.get_code(), Ok(expected_message_type));
        handler.process_message(hdr, files)
    }
}