pub(super) mod sys;

use std::collections::BTreeMap;
use std::convert::From;
use std::fs::File;
use std::io::BufReader;
use std::io::Write;
use std::num::Wrapping;
#[cfg(any(target_os = "android", target_os = "linux"))]
use std::os::unix::io::AsRawFd;
use std::sync::Arc;

use anyhow::bail;
use anyhow::Context;
#[cfg(any(target_os = "android", target_os = "linux"))]
use base::clear_fd_flags;
use base::error;
use base::trace;
use base::warn;
use base::Event;
use base::Protection;
use base::SafeDescriptor;
use base::SharedMemory;
use base::WorkerThread;
use cros_async::TaskHandle;
use hypervisor::MemCacheType;
use serde::Deserialize;
use serde::Serialize;
use snapshot::AnySnapshot;
use sync::Mutex;
use thiserror::Error as ThisError;
use vm_control::VmMemorySource;
use vm_memory::GuestAddress;
use vm_memory::GuestMemory;
use vm_memory::MemoryRegion;
use vmm_vhost::message::VhostUserConfigFlags;
use vmm_vhost::message::VhostUserExternalMapMsg;
use vmm_vhost::message::VhostUserGpuMapMsg;
use vmm_vhost::message::VhostUserInflight;
use vmm_vhost::message::VhostUserMMap;
use vmm_vhost::message::VhostUserMMapFlags;
use vmm_vhost::message::VhostUserMemoryRegion;
use vmm_vhost::message::VhostUserMigrationPhase;
use vmm_vhost::message::VhostUserProtocolFeatures;
use vmm_vhost::message::VhostUserShMemConfigHeader;
use vmm_vhost::message::VhostUserSingleMemoryRegion;
use vmm_vhost::message::VhostUserTransferDirection;
use vmm_vhost::message::VhostUserVringAddrFlags;
use vmm_vhost::message::VhostUserVringState;
use vmm_vhost::BackendReq;
use vmm_vhost::Connection;
use vmm_vhost::Error as VhostError;
use vmm_vhost::Frontend;
use vmm_vhost::FrontendClient;
use vmm_vhost::Result as VhostResult;
use vmm_vhost::VHOST_USER_F_PROTOCOL_FEATURES;

use crate::virtio::Interrupt;
use crate::virtio::Queue;
use crate::virtio::QueueConfig;
use crate::virtio::SharedMemoryMapper;
use crate::virtio::SharedMemoryRegion;

/// Mapping from a range of VMM virtual address space to the corresponding
/// guest physical address range, as communicated via `SET_MEM_TABLE`.
#[derive(Default)]
pub struct MappingInfo {
    pub vmm_addr: u64,
    pub guest_phys: u64,
    pub size: u64,
}
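/// Translates a VMM virtual address into a guest physical address using the
/// mappings recorded from `SET_MEM_TABLE`.
///
/// A small worked example (addresses are illustrative only):
///
/// ```ignore
/// let maps = vec![MappingInfo {
///     vmm_addr: 0x1000,
///     guest_phys: 0x8000,
///     size: 0x1000,
/// }];
/// // 0x1234 lies 0x234 bytes into the mapping, so it translates to 0x8234.
/// assert_eq!(vmm_va_to_gpa(&maps, 0x1234).unwrap(), GuestAddress(0x8234));
/// // Addresses outside every mapping are rejected as an invalid message.
/// assert!(vmm_va_to_gpa(&maps, 0x3000).is_err());
/// ```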
pub fn vmm_va_to_gpa(maps: &[MappingInfo], vmm_va: u64) -> VhostResult<GuestAddress> {
    for map in maps {
        if vmm_va >= map.vmm_addr && vmm_va < map.vmm_addr + map.size {
            return Ok(GuestAddress(vmm_va - map.vmm_addr + map.guest_phys));
        }
    }
    Err(VhostError::InvalidMessage)
}
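/// Trait for vhost-user device backends that are served by
/// [`DeviceRequestHandler`].
///
/// A minimal, hypothetical implementation might look roughly like the sketch
/// below; `NopDevice` is illustrative only and not part of this crate:
///
/// ```ignore
/// use anyhow::Context;
///
/// struct NopDevice {
///     queue: Option<Queue>,
/// }
///
/// impl VhostUserDevice for NopDevice {
///     fn max_queue_num(&self) -> usize {
///         1
///     }
///     fn features(&self) -> u64 {
///         1 << VHOST_USER_F_PROTOCOL_FEATURES
///     }
///     fn protocol_features(&self) -> VhostUserProtocolFeatures {
///         VhostUserProtocolFeatures::CONFIG
///     }
///     fn read_config(&self, _offset: u64, dst: &mut [u8]) {
///         dst.fill(0);
///     }
///     fn start_queue(&mut self, _idx: usize, queue: Queue, _mem: GuestMemory) -> anyhow::Result<()> {
///         self.queue = Some(queue);
///         Ok(())
///     }
///     fn stop_queue(&mut self, _idx: usize) -> anyhow::Result<Queue> {
///         self.queue.take().context("queue was not started")
///     }
///     fn reset(&mut self) {
///         self.queue = None;
///     }
///     fn enter_suspended_state(&mut self) -> anyhow::Result<()> {
///         Ok(())
///     }
///     fn snapshot(&mut self) -> anyhow::Result<AnySnapshot> {
///         AnySnapshot::to_any(()).context("failed to serialize snapshot")
///     }
///     fn restore(&mut self, _data: AnySnapshot) -> anyhow::Result<()> {
///         Ok(())
///     }
/// }
/// ```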
pub trait VhostUserDevice {
    /// The maximum number of queues that this backend can manage.
    fn max_queue_num(&self) -> usize;

    /// The virtio device feature bits offered to the frontend.
    fn features(&self) -> u64;

    /// Acknowledges the feature bits negotiated by the frontend.
    fn ack_features(&mut self, _value: u64) -> anyhow::Result<()> {
        Ok(())
    }

    /// The vhost-user protocol features supported by this backend.
    fn protocol_features(&self) -> VhostUserProtocolFeatures;

    /// Reads `dst.len()` bytes of the device configuration space starting at `offset`.
    fn read_config(&self, offset: u64, dst: &mut [u8]);

    /// Writes `data` to the device configuration space starting at `offset`.
    fn write_config(&self, _offset: u64, _data: &[u8]) {}

    /// Starts processing the already-activated `queue` at index `idx`.
    fn start_queue(&mut self, idx: usize, queue: Queue, mem: GuestMemory) -> anyhow::Result<()>;

    /// Stops processing the queue at index `idx` and returns it to the caller.
    fn stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue>;

    /// Resets the device to its initial state.
    fn reset(&mut self);

    /// Returns the shared memory region exposed by the device, if any.
    fn get_shared_memory_region(&self) -> Option<SharedMemoryRegion> {
        None
    }

    /// Provides the connection used for backend-initiated requests to the frontend.
    fn set_backend_req_connection(&mut self, _conn: VhostBackendReqConnection) {}

    /// Enters a suspended state in which the backend performs no work until a
    /// queue is started again.
    fn enter_suspended_state(&mut self) -> anyhow::Result<()>;

    /// Serializes the device state.
    fn snapshot(&mut self) -> anyhow::Result<AnySnapshot>;

    /// Restores device state previously produced by `snapshot`.
    fn restore(&mut self, data: AnySnapshot) -> anyhow::Result<()>;

    /// Whether guest memory should be unmapped in forked child processes
    /// (`MADV_DONTFORK`).
    fn unmap_guest_memory_on_fork(&self) -> bool {
        false
    }
}

struct Vring {
    queue: QueueConfig,
    doorbell: Option<Interrupt>,
    enabled: bool,
}

impl Vring {
    fn new(max_size: u16, features: u64) -> Self {
        Self {
            queue: QueueConfig::new(max_size, features),
            doorbell: None,
            enabled: false,
        }
    }

    fn reset(&mut self) {
        self.queue.reset();
        self.doorbell = None;
        self.enabled = false;
    }
}
pub(super) struct VhostUserRegularOps;

impl VhostUserRegularOps {
    pub fn set_mem_table(
        contexts: &[VhostUserMemoryRegion],
        files: Vec<File>,
    ) -> VhostResult<(GuestMemory, Vec<MappingInfo>)> {
        if files.len() != contexts.len() {
            return Err(VhostError::InvalidParam(
                "number of files & contexts was not equal",
            ));
        }

        let mut regions = Vec::with_capacity(files.len());
        for (region, file) in contexts.iter().zip(files.into_iter()) {
            let region = MemoryRegion::new_from_shm(
                region.memory_size,
                GuestAddress(region.guest_phys_addr),
                region.mmap_offset,
                Arc::new(
                    SharedMemory::from_safe_descriptor(
                        SafeDescriptor::from(file),
                        region.memory_size,
                    )
                    .unwrap(),
                ),
            )
            .map_err(|e| {
                error!("failed to create a memory region: {}", e);
                VhostError::InvalidOperation
            })?;
            regions.push(region);
        }
        let guest_mem = GuestMemory::from_regions(regions).map_err(|e| {
            error!("failed to create guest memory: {}", e);
            VhostError::InvalidOperation
        })?;

        let vmm_maps = contexts
            .iter()
            .map(|region| MappingInfo {
                vmm_addr: region.user_addr,
                guest_phys: region.guest_phys_addr,
                size: region.memory_size,
            })
            .collect();
        Ok((guest_mem, vmm_maps))
    }
}
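/// Request handler that implements [`vmm_vhost::Backend`] by dispatching
/// vhost-user frontend messages to a [`VhostUserDevice`] backend.
///
/// A hedged usage sketch; how `connection` is obtained and the `MyDevice`
/// backend are assumptions, not part of this module:
///
/// ```ignore
/// let handler = DeviceRequestHandler::new(MyDevice::new());
/// let mut server = vmm_vhost::BackendServer::new(connection, handler);
/// loop {
///     let (hdr, files) = server.recv_header()?;
///     server.process_message(hdr, files)?;
/// }
/// ```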
pub struct DeviceRequestHandler<T: VhostUserDevice> {
    vrings: Vec<Vring>,
    owned: bool,
    vmm_maps: Option<Vec<MappingInfo>>,
    mem: Option<GuestMemory>,
    acked_features: u64,
    acked_protocol_features: VhostUserProtocolFeatures,
    backend: T,
    backend_req_connection: Option<VhostBackendReqConnection>,
    device_state_thread: Option<DeviceStateThread>,
}

enum DeviceStateThread {
    Save(WorkerThread<Result<(), ciborium::ser::Error<std::io::Error>>>),
    Load(WorkerThread<Result<DeviceRequestHandlerSnapshot, ciborium::de::Error<std::io::Error>>>),
}

#[derive(Serialize, Deserialize)]
pub struct DeviceRequestHandlerSnapshot {
    acked_features: u64,
    acked_protocol_features: u64,
    backend: AnySnapshot,
}

impl<T: VhostUserDevice> DeviceRequestHandler<T> {
    pub(crate) fn new(mut backend: T) -> Self {
        let mut vrings = Vec::with_capacity(backend.max_queue_num());
        for _ in 0..backend.max_queue_num() {
            vrings.push(Vring::new(Queue::MAX_SIZE, backend.features()));
        }

        // The backend starts out suspended; it leaves this state when queues are started.
        backend
            .enter_suspended_state()
            .expect("enter_suspended_state failed on device init");

        DeviceRequestHandler {
            vrings,
            owned: false,
            vmm_maps: None,
            mem: None,
            acked_features: 0,
            acked_protocol_features: VhostUserProtocolFeatures::empty(),
            backend,
            backend_req_connection: None,
            device_state_thread: None,
        }
    }

    /// Returns true if all queues are currently stopped.
    fn all_queues_stopped(&self) -> bool {
        self.vrings.iter().all(|vring| !vring.queue.ready())
    }
}

impl<T: VhostUserDevice> Drop for DeviceRequestHandler<T> {
    fn drop(&mut self) {
        for (index, vring) in self.vrings.iter().enumerate() {
            if vring.queue.ready() {
                if let Err(e) = self.backend.stop_queue(index) {
                    error!("Failed to stop queue {} during drop: {:#}", index, e);
                }
            }
        }
    }
}

impl<T: VhostUserDevice> AsRef<T> for DeviceRequestHandler<T> {
    fn as_ref(&self) -> &T {
        &self.backend
    }
}

impl<T: VhostUserDevice> AsMut<T> for DeviceRequestHandler<T> {
    fn as_mut(&mut self) -> &mut T {
        &mut self.backend
    }
}

impl<T: VhostUserDevice> vmm_vhost::Backend for DeviceRequestHandler<T> {
    fn set_owner(&mut self) -> VhostResult<()> {
        if self.owned {
            return Err(VhostError::InvalidOperation);
        }
        self.owned = true;
        Ok(())
    }

    fn reset_owner(&mut self) -> VhostResult<()> {
        self.owned = false;
        self.acked_features = 0;
        self.backend.reset();
        Ok(())
    }

    fn get_features(&mut self) -> VhostResult<u64> {
        let features = self.backend.features();
        Ok(features)
    }

    fn set_features(&mut self, features: u64) -> VhostResult<()> {
        if !self.owned {
            return Err(VhostError::InvalidOperation);
        }

        let unexpected_features = features & !self.backend.features();
        if unexpected_features != 0 {
            error!("unexpected set_features {:#x}", unexpected_features);
            return Err(VhostError::InvalidParam("unexpected set_features"));
        }

        if let Err(e) = self.backend.ack_features(features) {
            error!("failed to acknowledge features 0x{:x}: {}", features, e);
            return Err(VhostError::InvalidOperation);
        }

        self.acked_features |= features;

        let vring_enabled = self.acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES != 0;
        for v in &mut self.vrings {
            v.enabled = vring_enabled;
        }

        Ok(())
    }

    fn get_protocol_features(&mut self) -> VhostResult<VhostUserProtocolFeatures> {
        Ok(self.backend.protocol_features())
    }

    fn set_protocol_features(&mut self, features: u64) -> VhostResult<()> {
        let features = match VhostUserProtocolFeatures::from_bits(features) {
            Some(proto_features) => proto_features,
            None => {
                error!(
                    "unsupported bits in VHOST_USER_SET_PROTOCOL_FEATURES: {:#x}",
                    features
                );
                return Err(VhostError::InvalidOperation);
            }
        };
        let supported = self.backend.protocol_features();
        self.acked_protocol_features = features & supported;
        Ok(())
    }
449
450 fn set_mem_table(
451 &mut self,
452 contexts: &[VhostUserMemoryRegion],
453 files: Vec<File>,
454 ) -> VhostResult<()> {
455 let (guest_mem, vmm_maps) = VhostUserRegularOps::set_mem_table(contexts, files)?;
456 if self.backend.unmap_guest_memory_on_fork() {
457 #[cfg(any(target_os = "android", target_os = "linux"))]
458 if let Err(e) = guest_mem.use_dontfork() {
459 error!("failed to set MADV_DONTFORK on guest memory: {e:#}");
460 }
461 #[cfg(not(any(target_os = "android", target_os = "linux")))]
462 error!("unmap_guest_memory_on_fork unsupported; skipping");
463 }
464 self.mem = Some(guest_mem);
465 self.vmm_maps = Some(vmm_maps);
466 Ok(())
467 }
468
469 fn get_queue_num(&mut self) -> VhostResult<u64> {
470 Ok(self.vrings.len() as u64)
471 }
472
473 fn set_vring_num(&mut self, index: u32, num: u32) -> VhostResult<()> {
474 if index as usize >= self.vrings.len() || num == 0 || num > Queue::MAX_SIZE.into() {
475 return Err(VhostError::InvalidParam(
476 "set_vring_num: invalid index or num",
477 ));
478 }
479 self.vrings[index as usize].queue.set_size(num as u16);
480
481 Ok(())
482 }
483
484 fn set_vring_addr(
485 &mut self,
486 index: u32,
487 _flags: VhostUserVringAddrFlags,
488 descriptor: u64,
489 used: u64,
490 available: u64,
491 _log: u64,
492 ) -> VhostResult<()> {
493 if index as usize >= self.vrings.len() {
494 return Err(VhostError::InvalidParam(
495 "set_vring_addr: index out of range",
496 ));
497 }
498
499 let vmm_maps = self
500 .vmm_maps
501 .as_ref()
502 .ok_or(VhostError::InvalidParam("set_vring_addr: missing vmm_maps"))?;
503 let vring = &mut self.vrings[index as usize];
504 vring
505 .queue
506 .set_desc_table(vmm_va_to_gpa(vmm_maps, descriptor)?);
507 vring
508 .queue
509 .set_avail_ring(vmm_va_to_gpa(vmm_maps, available)?);
510 vring.queue.set_used_ring(vmm_va_to_gpa(vmm_maps, used)?);
511
512 Ok(())
513 }
514
515 fn set_vring_base(&mut self, index: u32, base: u32) -> VhostResult<()> {
516 if index as usize >= self.vrings.len() {
517 return Err(VhostError::InvalidParam(
518 "set_vring_base: index out of range",
519 ));
520 }
521
522 let vring = &mut self.vrings[index as usize];
523 vring.queue.set_next_avail(Wrapping(base as u16));
524 vring.queue.set_next_used(Wrapping(base as u16));
525
526 Ok(())
527 }
528
529 fn get_vring_base(&mut self, index: u32) -> VhostResult<VhostUserVringState> {
530 let vring = self
531 .vrings
532 .get_mut(index as usize)
533 .ok_or(VhostError::InvalidParam(
534 "get_vring_base: index out of range",
535 ))?;
536
537 let vring_base = if vring.queue.ready() {
542 let queue = match self.backend.stop_queue(index as usize) {
543 Ok(q) => q,
544 Err(e) => {
545 error!("Failed to stop queue in get_vring_base: {:#}", e);
546 return Err(VhostError::BackendInternalError);
547 }
548 };
549
550 trace!("stopped queue {index}");
551 vring.reset();
552
553 if self.all_queues_stopped() {
554 trace!("all queues stopped; entering suspended state");
555 self.backend
556 .enter_suspended_state()
557 .map_err(VhostError::EnterSuspendedState)?;
558 }
559
560 queue.next_avail_to_process()
561 } else {
562 0
563 };
564
565 Ok(VhostUserVringState::new(index, vring_base.into()))
566 }
567
    fn set_vring_kick(&mut self, index: u8, file: Option<File>) -> VhostResult<()> {
        if index as usize >= self.vrings.len() {
            return Err(VhostError::InvalidParam(
                "set_vring_kick: index out of range",
            ));
        }

        let vring = &mut self.vrings[index as usize];
        if vring.queue.ready() {
            error!("kick fd cannot be replaced after the queue is started");
            return Err(VhostError::InvalidOperation);
        }

        let file = file.ok_or(VhostError::InvalidParam("missing file for set_vring_kick"))?;

        #[cfg(any(target_os = "android", target_os = "linux"))]
        if let Err(e) = clear_fd_flags(file.as_raw_fd(), libc::O_NONBLOCK) {
            error!("failed to remove O_NONBLOCK for kick fd: {}", e);
            return Err(VhostError::InvalidParam(
                "could not remove O_NONBLOCK from vring_kick",
            ));
        }

        let kick_evt = Event::from(SafeDescriptor::from(file));

        vring.queue.ack_features(self.acked_features);
        vring.queue.set_ready(true);

        let mem = self
            .mem
            .as_ref()
            .cloned()
            .ok_or(VhostError::InvalidOperation)?;

        let doorbell = vring.doorbell.clone().ok_or(VhostError::InvalidOperation)?;

        let queue = match vring.queue.activate(&mem, kick_evt, doorbell) {
            Ok(queue) => queue,
            Err(e) => {
                error!("failed to activate vring: {:#}", e);
                return Err(VhostError::BackendInternalError);
            }
        };

        if let Err(e) = self.backend.start_queue(index as usize, queue, mem) {
            error!("Failed to start queue {}: {}", index, e);
            return Err(VhostError::BackendInternalError);
        }
        trace!("started queue {index}");

        Ok(())
    }

    fn set_vring_call(&mut self, index: u8, file: Option<File>) -> VhostResult<()> {
        if index as usize >= self.vrings.len() {
            return Err(VhostError::InvalidParam(
                "set_vring_call: index out of range",
            ));
        }

        let backend_req_conn = self.backend_req_connection.clone();
        let signal_config_change_fn = Box::new(move || {
            if let Some(frontend) = backend_req_conn.as_ref() {
                if let Err(e) = frontend.send_config_changed() {
                    error!("Failed to notify config change: {:#}", e);
                }
            } else {
                error!("No Backend request connection found");
            }
        });

        let file = file.ok_or(VhostError::InvalidParam("missing file for set_vring_call"))?;
        self.vrings[index as usize].doorbell = Some(Interrupt::new_vhost_user(
            Event::from(SafeDescriptor::from(file)),
            signal_config_change_fn,
        ));
        Ok(())
    }

    fn set_vring_err(&mut self, _index: u8, _fd: Option<File>) -> VhostResult<()> {
        Ok(())
    }

    fn set_vring_enable(&mut self, index: u32, enable: bool) -> VhostResult<()> {
        if index as usize >= self.vrings.len() {
            return Err(VhostError::InvalidParam(
                "set_vring_enable: index out of range",
            ));
        }

        // SET_VRING_ENABLE is only valid after VHOST_USER_F_PROTOCOL_FEATURES
        // has been negotiated.
        if self.acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES == 0 {
            return Err(VhostError::InvalidOperation);
        }

        self.vrings[index as usize].enabled = enable;

        Ok(())
    }

    fn get_config(
        &mut self,
        offset: u32,
        size: u32,
        _flags: VhostUserConfigFlags,
    ) -> VhostResult<Vec<u8>> {
        let mut data = vec![0; size as usize];
        self.backend.read_config(u64::from(offset), &mut data);
        Ok(data)
    }

    fn set_config(
        &mut self,
        offset: u32,
        buf: &[u8],
        _flags: VhostUserConfigFlags,
    ) -> VhostResult<()> {
        self.backend.write_config(u64::from(offset), buf);
        Ok(())
    }

    fn set_backend_req_fd(&mut self, ep: Connection<BackendReq>) {
        let conn = VhostBackendReqConnection::new(
            FrontendClient::new(ep),
            self.backend.get_shared_memory_region().map(|r| r.id),
        );

        if self.backend_req_connection.is_some() {
            warn!("Backend Request Connection already established. Overwriting");
        }
        self.backend_req_connection = Some(conn.clone());

        self.backend.set_backend_req_connection(conn);
    }

    fn get_inflight_fd(
        &mut self,
        _inflight: &VhostUserInflight,
    ) -> VhostResult<(VhostUserInflight, File)> {
        unimplemented!("get_inflight_fd");
    }

    fn set_inflight_fd(&mut self, _inflight: &VhostUserInflight, _file: File) -> VhostResult<()> {
        unimplemented!("set_inflight_fd");
    }

    fn get_max_mem_slots(&mut self) -> VhostResult<u64> {
        Ok(0)
    }

    fn add_mem_region(
        &mut self,
        _region: &VhostUserSingleMemoryRegion,
        _fd: File,
    ) -> VhostResult<()> {
        Ok(())
    }

    fn remove_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion) -> VhostResult<()> {
        Ok(())
    }

    fn set_device_state_fd(
        &mut self,
        transfer_direction: VhostUserTransferDirection,
        migration_phase: VhostUserMigrationPhase,
        fd: File,
    ) -> VhostResult<Option<File>> {
        if migration_phase != VhostUserMigrationPhase::Stopped {
            return Err(VhostError::InvalidOperation);
        }
        if !self.all_queues_stopped() {
            return Err(VhostError::InvalidOperation);
        }
        if self.device_state_thread.is_some() {
            error!("must call check_device_state before starting new state transfer");
            return Err(VhostError::InvalidOperation);
        }
        // The (de)serialization runs on a worker thread; its result is
        // collected in `check_device_state`.
        match transfer_direction {
            VhostUserTransferDirection::Save => {
                let snapshot = DeviceRequestHandlerSnapshot {
                    acked_features: self.acked_features,
                    acked_protocol_features: self.acked_protocol_features.bits(),
                    backend: self.backend.snapshot().map_err(VhostError::SnapshotError)?,
                };
                self.device_state_thread = Some(DeviceStateThread::Save(WorkerThread::start(
                    "device_state_save",
                    move |_kill_event| -> Result<(), ciborium::ser::Error<std::io::Error>> {
                        let mut w = std::io::BufWriter::new(fd);
                        ciborium::into_writer(&snapshot, &mut w)?;
                        w.flush()?;
                        Ok(())
                    },
                )));
                Ok(None)
            }
            VhostUserTransferDirection::Load => {
                self.device_state_thread = Some(DeviceStateThread::Load(WorkerThread::start(
                    "device_state_load",
                    move |_kill_event| ciborium::from_reader(&mut BufReader::new(fd)),
                )));
                Ok(None)
            }
        }
    }

    fn check_device_state(&mut self) -> VhostResult<()> {
        let Some(thread) = self.device_state_thread.take() else {
            error!("check_device_state: no active state transfer");
            return Err(VhostError::InvalidOperation);
        };
        match thread {
            DeviceStateThread::Save(worker) => {
                worker.stop().map_err(|e| {
                    error!("device state save thread failed: {:#}", e);
                    VhostError::BackendInternalError
                })?;
                Ok(())
            }
            DeviceStateThread::Load(worker) => {
                let snapshot = worker.stop().map_err(|e| {
                    error!("device state load thread failed: {:#}", e);
                    VhostError::BackendInternalError
                })?;
                self.acked_features = snapshot.acked_features;
                self.acked_protocol_features =
                    VhostUserProtocolFeatures::from_bits(snapshot.acked_protocol_features)
                        .with_context(|| {
                            format!(
                                "unsupported bits in acked_protocol_features: {:#x}",
                                snapshot.acked_protocol_features
                            )
                        })
                        .map_err(VhostError::RestoreError)?;
                self.backend
                    .restore(snapshot.backend)
                    .map_err(VhostError::RestoreError)?;
                Ok(())
            }
        }
    }

    fn get_shmem_config(&mut self) -> VhostResult<(VhostUserShMemConfigHeader, Vec<u64>)> {
        Ok(if let Some(r) = self.backend.get_shared_memory_region() {
            (VhostUserShMemConfigHeader::new(1), vec![r.length])
        } else {
            (VhostUserShMemConfigHeader::new(0), Vec::new())
        })
    }
}
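/// Connection to the vhost-user frontend over which the backend can issue
/// device-initiated requests (config change notifications and shared memory
/// mapping requests).
///
/// A hedged sketch of how a backend might drive the shared memory mapper;
/// `source`, `prot`, and `cache` are assumed to come from the caller:
///
/// ```ignore
/// use anyhow::Context;
///
/// fn map_and_unmap(
///     conn: &VhostBackendReqConnection,
///     source: VmMemorySource,
///     prot: Protection,
///     cache: MemCacheType,
/// ) -> anyhow::Result<()> {
///     let mut mapper = conn
///         .shmem_mapper()
///         .context("device declared no shared memory region")?;
///     mapper.add_mapping(source, /* offset= */ 0, prot, cache)?;
///     mapper.remove_mapping(/* offset= */ 0)?;
///     Ok(())
/// }
/// ```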
#[derive(Clone)]
pub struct VhostBackendReqConnection {
    shared: Arc<Mutex<VhostBackendReqConnectionShared>>,
    shmid: Option<u8>,
}

struct VhostBackendReqConnectionShared {
    conn: FrontendClient,
    mapped_regions: BTreeMap<u64 /* offset */, u64 /* size */>,
}

impl VhostBackendReqConnection {
    fn new(conn: FrontendClient, shmid: Option<u8>) -> Self {
        Self {
            shared: Arc::new(Mutex::new(VhostBackendReqConnectionShared {
                conn,
                mapped_regions: BTreeMap::new(),
            })),
            shmid,
        }
    }

    /// Notifies the frontend that the device configuration has changed.
    fn send_config_changed(&self) -> anyhow::Result<()> {
        let mut shared = self.shared.lock();
        shared
            .conn
            .handle_config_change()
            .context("Could not send config change message")?;
        Ok(())
    }

    /// Returns a mapper for the device's shared memory region, if the device
    /// declared one.
    pub fn shmem_mapper(&self) -> Option<Box<dyn SharedMemoryMapper>> {
        if let Some(shmid) = self.shmid {
            Some(Box::new(VhostShmemMapper {
                shared: self.shared.clone(),
                shmid,
            }))
        } else {
            None
        }
    }
}

#[derive(Clone)]
struct VhostShmemMapper {
    shared: Arc<Mutex<VhostBackendReqConnectionShared>>,
    shmid: u8,
}

impl SharedMemoryMapper for VhostShmemMapper {
    fn add_mapping(
        &mut self,
        source: VmMemorySource,
        offset: u64,
        prot: Protection,
        _cache: MemCacheType,
    ) -> anyhow::Result<()> {
        let mut shared = self.shared.lock();
        let size = match source {
            VmMemorySource::Vulkan {
                descriptor,
                handle_type,
                memory_idx,
                device_uuid,
                driver_uuid,
                size,
            } => {
                let msg = VhostUserGpuMapMsg::new(
                    self.shmid,
                    offset,
                    size,
                    memory_idx,
                    handle_type,
                    device_uuid,
                    driver_uuid,
                );
                shared
                    .conn
                    .gpu_map(&msg, &descriptor)
                    .context("map GPU memory")?;
                size
            }
            VmMemorySource::ExternalMapping { ptr, size } => {
                let msg = VhostUserExternalMapMsg::new(self.shmid, offset, size, ptr);
                shared
                    .conn
                    .external_map(&msg)
                    .context("create external mapping")?;
                size
            }
            source => {
                let (descriptor, fd_offset, size) = match source {
                    VmMemorySource::Descriptor {
                        descriptor,
                        offset,
                        size,
                    } => (descriptor, offset, size),
                    VmMemorySource::SharedMemory(shmem) => {
                        let size = shmem.size();
                        let descriptor = SafeDescriptor::from(shmem);
                        (descriptor, 0, size)
                    }
                    _ => bail!("unsupported source"),
                };
                let mut flags = VhostUserMMapFlags::empty();
                anyhow::ensure!(prot.allows(&Protection::read()), "mapping must be readable");
                if prot.allows(&Protection::write()) {
                    flags |= VhostUserMMapFlags::MAP_RW;
                }
                let msg = VhostUserMMap {
                    shmid: self.shmid,
                    padding: Default::default(),
                    fd_offset,
                    shm_offset: offset,
                    len: size,
                    flags,
                };
                shared
                    .conn
                    .shmem_map(&msg, &descriptor)
                    .context("map shmem")?;
                size
            }
        };

        shared.mapped_regions.insert(offset, size);
        Ok(())
    }

    fn remove_mapping(&mut self, offset: u64) -> anyhow::Result<()> {
        let mut shared = self.shared.lock();
        let size = shared
            .mapped_regions
            .remove(&offset)
            .context("unknown offset")?;
        let msg = VhostUserMMap {
            shmid: self.shmid,
            padding: Default::default(),
            fd_offset: 0,
            shm_offset: offset,
            len: size,
            flags: VhostUserMMapFlags::empty(),
        };
        shared
            .conn
            .shmem_unmap(&msg)
            .context("unmap shmem")
            .map(|_| ())
    }
}

pub(crate) struct WorkerState<T, U> {
    pub(crate) queue_task: TaskHandle<U>,
    pub(crate) queue: T,
}

#[derive(Debug, ThisError)]
pub enum Error {
    #[error("worker not found when stopping queue")]
    WorkerNotFound,
}

#[cfg(test)]
mod tests {
    use std::sync::mpsc::channel;
    use std::sync::Barrier;

    use anyhow::bail;
    use base::Event;
    use vmm_vhost::BackendServer;
    use vmm_vhost::FrontendReq;
    use zerocopy::FromBytes;
    use zerocopy::FromZeros;
    use zerocopy::Immutable;
    use zerocopy::IntoBytes;
    use zerocopy::KnownLayout;

    use super::*;
    use crate::virtio::vhost_user_frontend::VhostUserFrontend;
    use crate::virtio::DeviceType;
    use crate::virtio::VirtioDevice;

    #[derive(Clone, Copy, Debug, PartialEq, Eq, FromBytes, Immutable, IntoBytes, KnownLayout)]
    #[repr(C, packed(4))]
    struct FakeConfig {
        x: u32,
        y: u64,
    }

    const FAKE_CONFIG_DATA: FakeConfig = FakeConfig { x: 1, y: 2 };

    pub(super) struct FakeBackend {
        avail_features: u64,
        acked_features: u64,
        active_queues: Vec<Option<Queue>>,
        allow_backend_req: bool,
        backend_conn: Option<VhostBackendReqConnection>,
    }

    #[derive(Deserialize, Serialize)]
    struct FakeBackendSnapshot {
        data: Vec<u8>,
    }

    impl FakeBackend {
        const MAX_QUEUE_NUM: usize = 16;

        pub(super) fn new() -> Self {
            let mut active_queues = Vec::new();
            active_queues.resize_with(Self::MAX_QUEUE_NUM, Default::default);
            Self {
                avail_features: 1 << VHOST_USER_F_PROTOCOL_FEATURES,
                acked_features: 0,
                active_queues,
                allow_backend_req: false,
                backend_conn: None,
            }
        }
    }

    impl VhostUserDevice for FakeBackend {
        fn max_queue_num(&self) -> usize {
            Self::MAX_QUEUE_NUM
        }

        fn features(&self) -> u64 {
            self.avail_features
        }

        fn ack_features(&mut self, value: u64) -> anyhow::Result<()> {
            let unrequested_features = value & !self.avail_features;
            if unrequested_features != 0 {
                bail!(
                    "invalid protocol features are given: 0x{:x}",
                    unrequested_features
                );
            }
            self.acked_features |= value;
            Ok(())
        }

        fn protocol_features(&self) -> VhostUserProtocolFeatures {
            let mut features =
                VhostUserProtocolFeatures::CONFIG | VhostUserProtocolFeatures::DEVICE_STATE;
            if self.allow_backend_req {
                features |= VhostUserProtocolFeatures::BACKEND_REQ;
            }
            features
        }

        fn read_config(&self, offset: u64, dst: &mut [u8]) {
            dst.copy_from_slice(&FAKE_CONFIG_DATA.as_bytes()[offset as usize..]);
        }

        fn reset(&mut self) {}

        fn start_queue(
            &mut self,
            idx: usize,
            queue: Queue,
            _mem: GuestMemory,
        ) -> anyhow::Result<()> {
            self.active_queues[idx] = Some(queue);
            Ok(())
        }

        fn stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue> {
            Ok(self.active_queues[idx]
                .take()
                .ok_or(Error::WorkerNotFound)?)
        }

        fn set_backend_req_connection(&mut self, conn: VhostBackendReqConnection) {
            self.backend_conn = Some(conn);
        }

        fn enter_suspended_state(&mut self) -> anyhow::Result<()> {
            Ok(())
        }

        fn snapshot(&mut self) -> anyhow::Result<AnySnapshot> {
            AnySnapshot::to_any(FakeBackendSnapshot {
                data: vec![1, 2, 3],
            })
            .context("failed to serialize snapshot")
        }

        fn restore(&mut self, data: AnySnapshot) -> anyhow::Result<()> {
            let snapshot: FakeBackendSnapshot =
                AnySnapshot::from_any(data).context("failed to deserialize snapshot")?;
            assert_eq!(snapshot.data, vec![1, 2, 3], "bad snapshot data");
            Ok(())
        }
    }
    #[test]
    fn test_vhost_user_lifecycle() {
        test_vhost_user_lifecycle_parameterized(false);
    }

    #[test]
    #[cfg(not(windows))]
    fn test_vhost_user_lifecycle_with_backend_req() {
        test_vhost_user_lifecycle_parameterized(true);
    }
    fn test_vhost_user_lifecycle_parameterized(allow_backend_req: bool) {
        const QUEUES_NUM: usize = 2;

        let (client_connection, server_connection) =
            vmm_vhost::Connection::<FrontendReq>::pair().unwrap();

        let vmm_bar = Arc::new(Barrier::new(2));
        let dev_bar = vmm_bar.clone();

        let (ready_tx, ready_rx) = channel();
        let (shutdown_tx, shutdown_rx) = channel();
        let (vm_evt_wrtube, _vm_evt_rdtube) = base::Tube::directional_pair().unwrap();

        std::thread::spawn(move || {
            // Wait until the backend is ready to serve requests before creating
            // the frontend device.
            ready_rx.recv().unwrap();

            let mut vmm_device = VhostUserFrontend::new(
                DeviceType::Console,
                0,
                client_connection,
                vm_evt_wrtube,
                None,
                None,
            )
            .unwrap();

            println!("read_config");
            let mut config = FakeConfig::new_zeroed();
            vmm_device.read_config(0, config.as_mut_bytes());
            assert_eq!(config, FAKE_CONFIG_DATA);

            let activate = |vmm_device: &mut VhostUserFrontend| {
                let mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
                let interrupt = Interrupt::new_for_test_with_msix();

                let mut queues = BTreeMap::new();
                for idx in 0..QUEUES_NUM {
                    let mut queue = QueueConfig::new(0x10, 0);
                    queue.set_ready(true);
                    let queue = queue
                        .activate(&mem, Event::new().unwrap(), interrupt.clone())
                        .expect("QueueConfig::activate");
                    queues.insert(idx, queue);
                }

                println!("activate");
                vmm_device.activate(mem, interrupt, queues).unwrap();
            };

            activate(&mut vmm_device);

            println!("reset");
            let reset_result = vmm_device.reset();
            assert!(
                reset_result.is_ok(),
                "reset failed: {:#}",
                reset_result.unwrap_err()
            );

            activate(&mut vmm_device);

            println!("virtio_sleep");
            let queues = vmm_device
                .virtio_sleep()
                .unwrap()
                .expect("virtio_sleep unexpectedly returned None");

            println!("virtio_snapshot");
            let snapshot = vmm_device
                .virtio_snapshot()
                .expect("virtio_snapshot failed");
            println!("virtio_restore");
            vmm_device
                .virtio_restore(snapshot)
                .expect("virtio_restore failed");

            println!("virtio_wake");
            let mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
            let interrupt = Interrupt::new_for_test_with_msix();
            vmm_device
                .virtio_wake(Some((mem, interrupt, queues)))
                .unwrap();

            println!("wait for shutdown signal");
            shutdown_rx.recv().unwrap();

            println!("drop");
            drop(vmm_device);

            vmm_bar.wait();
        });

        let mut handler = DeviceRequestHandler::new(FakeBackend::new());
        handler.as_mut().allow_backend_req = allow_backend_req;

        // Notify the VMM thread that the backend is about to start serving requests.
        ready_tx.send(()).unwrap();

        let mut req_handler = BackendServer::new(server_connection, handler);

        handle_request(&mut req_handler, FrontendReq::SET_OWNER).unwrap();
        handle_request(&mut req_handler, FrontendReq::GET_FEATURES).unwrap();
        handle_request(&mut req_handler, FrontendReq::GET_PROTOCOL_FEATURES).unwrap();
        handle_request(&mut req_handler, FrontendReq::SET_PROTOCOL_FEATURES).unwrap();
        if allow_backend_req {
            handle_request(&mut req_handler, FrontendReq::SET_BACKEND_REQ_FD).unwrap();
        }

        handle_request(&mut req_handler, FrontendReq::GET_CONFIG).unwrap();

        // First activation of the device.
        handle_request(&mut req_handler, FrontendReq::SET_FEATURES).unwrap();
        handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();
        for _ in 0..QUEUES_NUM {
            handle_request(&mut req_handler, FrontendReq::SET_VRING_NUM).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ADDR).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_BASE).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_CALL).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_KICK).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
        }

        // Device reset on the VMM side stops the queues.
        for _ in 0..QUEUES_NUM {
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
            handle_request(&mut req_handler, FrontendReq::GET_VRING_BASE).unwrap();
        }

        // Second activation.
        handle_request(&mut req_handler, FrontendReq::SET_FEATURES).unwrap();
        handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();
        for _ in 0..QUEUES_NUM {
            handle_request(&mut req_handler, FrontendReq::SET_VRING_NUM).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ADDR).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_BASE).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_CALL).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_KICK).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
        }

        if allow_backend_req {
            req_handler
                .as_ref()
                .as_ref()
                .backend_conn
                .as_ref()
                .expect("backend_conn missing")
                .send_config_changed()
                .expect("send_config_changed failed");
        }

        // virtio_sleep stops the queues again.
        for _ in 0..QUEUES_NUM {
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
            handle_request(&mut req_handler, FrontendReq::GET_VRING_BASE).unwrap();
        }

        // Snapshot, then restore, the device state.
        handle_request(&mut req_handler, FrontendReq::SET_DEVICE_STATE_FD).unwrap();
        handle_request(&mut req_handler, FrontendReq::CHECK_DEVICE_STATE).unwrap();
        handle_request(&mut req_handler, FrontendReq::SET_DEVICE_STATE_FD).unwrap();
        handle_request(&mut req_handler, FrontendReq::CHECK_DEVICE_STATE).unwrap();

        // virtio_wake restarts the queues.
        handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();
        for _ in 0..QUEUES_NUM {
            handle_request(&mut req_handler, FrontendReq::SET_VRING_NUM).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ADDR).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_BASE).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_CALL).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_KICK).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
        }

        if allow_backend_req {
            req_handler
                .as_ref()
                .as_ref()
                .backend_conn
                .as_ref()
                .expect("backend_conn missing")
                .send_config_changed()
                .expect("send_config_changed failed");
        }

        // Ask the VMM thread to shut down and drop the frontend device.
        shutdown_tx.send(()).unwrap();
        dev_bar.wait();

        // Dropping the frontend closes the connection, which the backend
        // observes as ClientExit.
        match req_handler.recv_header() {
            Err(VhostError::ClientExit) => (),
            r => panic!("expected Err(ClientExit) but got {r:?}"),
        }
    }

    fn handle_request<S: vmm_vhost::Backend>(
        handler: &mut BackendServer<S>,
        expected_message_type: FrontendReq,
    ) -> Result<(), VhostError> {
        let (hdr, files) = handler.recv_header()?;
        assert_eq!(hdr.get_code(), Ok(expected_message_type));
        handler.process_message(hdr, files)
    }
}