devices/virtio/vhost_user_frontend/
mod.rs1mod error;
8mod handler;
9mod sys;
10mod worker;
11
12use std::cell::RefCell;
13use std::collections::BTreeMap;
14use std::io::Read;
15use std::io::Write;
16
17use anyhow::bail;
18use anyhow::Context;
19use base::error;
20use base::trace;
21use base::AsRawDescriptor;
22#[cfg(windows)]
23use base::CloseNotifier;
24use base::Event;
25use base::RawDescriptor;
26use base::ReadNotifier;
27use base::SafeDescriptor;
28use base::SendTube;
29use base::WorkerThread;
30use snapshot::AnySnapshot;
31use vm_memory::GuestMemory;
32use vmm_vhost::message::VhostUserConfigFlags;
33use vmm_vhost::message::VhostUserMigrationPhase;
34use vmm_vhost::message::VhostUserProtocolFeatures;
35use vmm_vhost::message::VhostUserTransferDirection;
36use vmm_vhost::BackendClient;
37use vmm_vhost::VhostUserMemoryRegionInfo;
38use vmm_vhost::VringConfigData;
39use vmm_vhost::VHOST_USER_F_PROTOCOL_FEATURES;
40
41use crate::virtio::device_constants::VIRTIO_DEVICE_TYPE_SPECIFIC_FEATURES_MASK;
42use crate::virtio::vhost_user_frontend::error::Error;
43use crate::virtio::vhost_user_frontend::error::Result;
44use crate::virtio::vhost_user_frontend::handler::BackendReqHandler;
45use crate::virtio::vhost_user_frontend::handler::BackendReqHandlerImpl;
46use crate::virtio::vhost_user_frontend::sys::create_backend_req_handler;
47use crate::virtio::vhost_user_frontend::worker::Worker;
48use crate::virtio::DeviceType;
49use crate::virtio::Interrupt;
50use crate::virtio::Queue;
51use crate::virtio::SharedMemoryMapper;
52use crate::virtio::SharedMemoryRegion;
53use crate::virtio::VirtioDevice;
54use crate::PciAddress;
55
56pub struct VhostUserFrontend {
57 device_type: DeviceType,
58 worker_thread: Option<WorkerThread<(Option<BackendReqHandler>, SendTube)>>,
59
60 backend_client: BackendClient,
61 avail_features: u64,
62 acked_features: u64,
63 sent_set_features: bool,
64 protocol_features: VhostUserProtocolFeatures,
65 backend_req_handler: Option<BackendReqHandler>,
69 shmem_region: RefCell<Option<Option<SharedMemoryRegion>>>,
71
72 queue_sizes: Vec<u16>,
73 expose_shmem_descriptors_with_viommu: bool,
74 pci_address: Option<PciAddress>,
75 vm_evt_wrtube: SendTube,
76
77 sent_queues: Option<BTreeMap<usize, Queue>>,
81}
82
/// Returns the largest power of two that is less than or equal to `val`.
///
/// Returns `None` only when `val` is 0. Used to round a requested maximum queue size down
/// to a power of two.
fn power_of_two_le(val: u16) -> Option<u16> {
    if val == 0 {
        return None;
    }
    // The highest set bit of `val` is the largest power of two not exceeding it. Rounding up
    // with `checked_next_power_of_two` and halving would overflow for values above 0x8000
    // (e.g. 40000) and wrongly yield `None` instead of 32768.
    Some(1u16 << (u16::BITS - 1 - val.leading_zeros()))
}
94
95impl VhostUserFrontend {
96 pub fn new(
105 device_type: DeviceType,
106 mut base_features: u64,
107 connection: vmm_vhost::Connection,
108 vm_evt_wrtube: SendTube,
109 max_queue_size: Option<u16>,
110 pci_address: Option<PciAddress>,
111 ) -> Result<VhostUserFrontend> {
112 if base_features & (1 << virtio_sys::virtio_config::VIRTIO_F_RING_PACKED) != 0 {
116 base_features &= !(1 << virtio_sys::virtio_config::VIRTIO_F_RING_PACKED);
117 base::warn!(
118 "VIRTIO_F_RING_PACKED requested, but not yet supported by vhost-user frontend. \
119 Automatically disabled."
120 );
121 }
122
123 #[cfg(windows)]
124 let backend_pid = connection.target_pid();
125
126 let mut backend_client = BackendClient::new(connection);
127
128 backend_client.set_owner().map_err(Error::SetOwner)?;
129
130 let allow_features = VIRTIO_DEVICE_TYPE_SPECIFIC_FEATURES_MASK
131 | base_features
132 | 1 << VHOST_USER_F_PROTOCOL_FEATURES;
133 let avail_features =
134 allow_features & backend_client.get_features().map_err(Error::GetFeatures)?;
135 let mut acked_features = 0;
136
137 let allow_protocol_features = VhostUserProtocolFeatures::CONFIG
138 | VhostUserProtocolFeatures::MQ
139 | VhostUserProtocolFeatures::BACKEND_REQ
140 | VhostUserProtocolFeatures::DEVICE_STATE
141 | VhostUserProtocolFeatures::SHMEM_MAP
142 | VhostUserProtocolFeatures::REPLY_ACK;
148
149 let mut protocol_features = VhostUserProtocolFeatures::empty();
150 if avail_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES != 0 {
151 acked_features |= 1 << VHOST_USER_F_PROTOCOL_FEATURES;
159
160 let avail_protocol_features = backend_client
161 .get_protocol_features()
162 .map_err(Error::GetProtocolFeatures)?;
163 protocol_features = allow_protocol_features & avail_protocol_features;
164 backend_client
165 .set_protocol_features(protocol_features)
166 .map_err(Error::SetProtocolFeatures)?;
167 }
168
169 let backend_req_handler =
171 if protocol_features.contains(VhostUserProtocolFeatures::BACKEND_REQ) {
172 let (mut handler, tx_fd) = create_backend_req_handler(
173 BackendReqHandlerImpl::new(),
174 #[cfg(windows)]
175 backend_pid,
176 )?;
177 handler.set_reply_ack_flag(
178 protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK),
179 );
180 backend_client
181 .set_backend_req_fd(&tx_fd)
182 .map_err(Error::SetDeviceRequestChannel)?;
183 Some(handler)
184 } else {
185 None
186 };
187
188 let num_queues = if protocol_features.contains(VhostUserProtocolFeatures::MQ) {
192 trace!("backend supports VHOST_USER_PROTOCOL_F_MQ");
193 let num_queues = backend_client.get_queue_num().map_err(Error::GetQueueNum)?;
194 trace!("VHOST_USER_GET_QUEUE_NUM returned {num_queues}");
195 num_queues as usize
196 } else {
197 trace!("backend does not support VHOST_USER_PROTOCOL_F_MQ");
198 device_type.min_queues()
199 };
200
201 let max_queue_size = max_queue_size
203 .and_then(power_of_two_le)
204 .unwrap_or(Queue::MAX_SIZE);
205
206 trace!(
207 "vhost-user {device_type} frontend with {num_queues} queues x {max_queue_size} entries\
208 {}",
209 if let Some(pci_address) = pci_address {
210 format!(" pci-address {pci_address}")
211 } else {
212 "".to_string()
213 }
214 );
215
216 let queue_sizes = vec![max_queue_size; num_queues];
217
218 Ok(VhostUserFrontend {
219 device_type,
220 worker_thread: None,
221 backend_client,
222 avail_features,
223 acked_features,
224 sent_set_features: false,
225 protocol_features,
226 backend_req_handler,
227 shmem_region: RefCell::new(None),
228 queue_sizes,
229 expose_shmem_descriptors_with_viommu: device_type == DeviceType::Gpu,
230 pci_address,
231 vm_evt_wrtube,
232 sent_queues: None,
233 })
234 }
235
236 fn set_mem_table(&mut self, mem: &GuestMemory) -> Result<()> {
237 let regions: Vec<_> = mem
238 .regions()
239 .map(|region| VhostUserMemoryRegionInfo {
240 guest_phys_addr: region.guest_addr.0,
241 memory_size: region.size as u64,
242 userspace_addr: region.host_addr as u64,
243 mmap_offset: region.shm_offset,
244 mmap_handle: region.shm.as_raw_descriptor(),
245 })
246 .collect();
247
248 self.backend_client
249 .set_mem_table(regions.as_slice())
250 .map_err(Error::SetMemTable)?;
251
252 Ok(())
253 }
254
255 fn activate_vring(
257 &mut self,
258 mem: &GuestMemory,
259 queue_index: usize,
260 queue: &Queue,
261 irqfd: &Event,
262 ) -> Result<()> {
263 self.backend_client
264 .set_vring_num(queue_index, queue.size())
265 .map_err(Error::SetVringNum)?;
266
267 let config_data = VringConfigData {
268 queue_size: queue.size(),
269 flags: 0u32,
270 desc_table_addr: mem
271 .get_host_address(queue.desc_table())
272 .map_err(Error::GetHostAddress)? as u64,
273 used_ring_addr: mem
274 .get_host_address(queue.used_ring())
275 .map_err(Error::GetHostAddress)? as u64,
276 avail_ring_addr: mem
277 .get_host_address(queue.avail_ring())
278 .map_err(Error::GetHostAddress)? as u64,
279 log_addr: None,
280 };
281 self.backend_client
282 .set_vring_addr(queue_index, &config_data)
283 .map_err(Error::SetVringAddr)?;
284
285 self.backend_client
286 .set_vring_base(queue_index, queue.next_avail_to_process())
287 .map_err(Error::SetVringBase)?;
288
289 self.backend_client
290 .set_vring_call(queue_index, irqfd)
291 .map_err(Error::SetVringCall)?;
292 self.backend_client
293 .set_vring_kick(queue_index, queue.event())
294 .map_err(Error::SetVringKick)?;
295
296 if self.acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES != 0 {
299 self.backend_client
300 .set_vring_enable(queue_index, true)
301 .map_err(Error::SetVringEnable)?;
302 }
303
304 Ok(())
305 }
306
307 fn deactivate_vring(&self, queue_index: usize) -> Result<u16> {
309 if self.acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES != 0 {
310 self.backend_client
311 .set_vring_enable(queue_index, false)
312 .map_err(Error::SetVringEnable)?;
313 }
314
315 let vring_base = self
316 .backend_client
317 .get_vring_base(queue_index)
318 .map_err(Error::GetVringBase)?;
319
320 vring_base
321 .try_into()
322 .map_err(|_| Error::VringBaseTooBig(vring_base))
323 }
324
325 fn start_worker(&mut self, interrupt: Interrupt, non_msix_evt: Event) {
328 assert!(
329 self.worker_thread.is_none(),
330 "BUG: attempted to start worker twice"
331 );
332
333 let label = self.debug_label();
334
335 let mut backend_req_handler = self.backend_req_handler.take();
336 if let Some(handler) = &mut backend_req_handler {
337 handler.frontend_mut().set_interrupt(interrupt.clone());
339 }
340
341 let backend_client_read_notifier =
342 SafeDescriptor::try_from(self.backend_client.get_read_notifier())
343 .expect("failed to get backend read notifier");
344 #[cfg(windows)]
345 let backend_client_close_notifier =
346 SafeDescriptor::try_from(self.backend_client.get_close_notifier())
347 .expect("failed to get backend close notifier");
348
349 let vm_evt_wrtube = self
350 .vm_evt_wrtube
351 .try_clone()
352 .expect("failed to clone vm_evt_wrtube");
353
354 self.worker_thread = Some(WorkerThread::start(label.clone(), move |kill_evt| {
355 let mut worker = Worker {
356 kill_evt,
357 non_msix_evt,
358 backend_req_handler,
359 backend_client_read_notifier,
360 #[cfg(windows)]
361 backend_client_close_notifier,
362 };
363 if let Err(e) = worker
364 .run(interrupt)
365 .with_context(|| format!("{label}: vhost_user_frontend worker failed"))
366 {
367 error!("vhost-user worker thread exited with an error: {:#}", e);
368
369 if let Err(e) = vm_evt_wrtube.send(&base::VmEventType::DeviceCrashed) {
370 error!("failed to send crash event: {}", e);
371 }
372 }
373 (worker.backend_req_handler, vm_evt_wrtube)
374 }));
375 }
376}
377
378impl VirtioDevice for VhostUserFrontend {
379 fn debug_label(&self) -> String {
381 format!("vu-{}", self.device_type())
382 }
383
384 fn keep_rds(&self) -> Vec<RawDescriptor> {
385 Vec::new()
386 }
387
388 fn device_type(&self) -> DeviceType {
389 self.device_type
390 }
391
392 fn queue_max_sizes(&self) -> &[u16] {
393 &self.queue_sizes
394 }
395
396 fn features(&self) -> u64 {
397 self.avail_features
398 }
399
400 fn ack_features(&mut self, features: u64) {
401 self.acked_features |= features & self.avail_features;
402 }
403
404 fn read_config(&self, offset: u64, data: &mut [u8]) {
405 let Ok(offset) = offset.try_into() else {
406 error!("failed to read config: invalid config offset is given: {offset}");
407 return;
408 };
409 let Ok(data_len) = data.len().try_into() else {
410 error!(
411 "failed to read config: invalid config length is given: {}",
412 data.len()
413 );
414 return;
415 };
416 let (_, config) = match self.backend_client.get_config(
417 offset,
418 data_len,
419 VhostUserConfigFlags::WRITABLE,
420 data,
421 ) {
422 Ok(x) => x,
423 Err(e) => {
424 error!("failed to read config: {}", Error::GetConfig(e));
425 return;
426 }
427 };
428 data.copy_from_slice(&config);
429 }
430
431 fn write_config(&mut self, offset: u64, data: &[u8]) {
432 let Ok(offset) = offset.try_into() else {
433 error!("failed to write config: invalid config offset is given: {offset}");
434 return;
435 };
436 if let Err(e) = self
437 .backend_client
438 .set_config(offset, VhostUserConfigFlags::empty(), data)
439 .map_err(Error::SetConfig)
440 {
441 error!("failed to write config: {}", e);
442 }
443 }
444
445 fn activate(
446 &mut self,
447 mem: GuestMemory,
448 interrupt: Interrupt,
449 queues: BTreeMap<usize, Queue>,
450 ) -> anyhow::Result<()> {
451 if !self.sent_set_features {
452 self.backend_client
453 .set_features(self.acked_features)
454 .map_err(Error::SetFeatures)?;
455 self.sent_set_features = true;
456 }
457
458 self.set_mem_table(&mem)?;
459
460 let msix_config_opt = interrupt
461 .get_msix_config()
462 .as_ref()
463 .ok_or(Error::MsixConfigUnavailable)?;
464 let msix_config = msix_config_opt.lock();
465
466 let non_msix_evt = Event::new().map_err(Error::CreateEvent)?;
467 for (&queue_index, queue) in queues.iter() {
468 let irqfd = msix_config
469 .get_irqfd(queue.vector() as usize)
470 .unwrap_or(&non_msix_evt);
471 self.activate_vring(&mem, queue_index, queue, irqfd)?;
472 }
473
474 self.sent_queues = Some(queues);
475
476 drop(msix_config);
477
478 self.start_worker(interrupt, non_msix_evt);
479 Ok(())
480 }
481
482 fn reset(&mut self) -> anyhow::Result<()> {
483 if let Some(sent_queues) = self.sent_queues.take() {
487 for queue_index in sent_queues.into_keys() {
488 let _vring_base = self
489 .deactivate_vring(queue_index)
490 .context("deactivate_vring failed during reset")?;
491 }
492 }
493
494 if let Some(w) = self.worker_thread.take() {
495 let (backend_req_handler, vm_evt_wrtube) = w.stop();
496 self.backend_req_handler = backend_req_handler;
497 self.vm_evt_wrtube = vm_evt_wrtube;
498 }
499
500 self.sent_set_features = false;
501
502 Ok(())
503 }
504
505 fn pci_address(&self) -> Option<PciAddress> {
506 self.pci_address
507 }
508
509 fn get_shared_memory_region(&self) -> Option<SharedMemoryRegion> {
510 if !self
511 .protocol_features
512 .contains(VhostUserProtocolFeatures::SHMEM_MAP)
513 {
514 return None;
515 }
516 if let Some(r) = self.shmem_region.borrow().as_ref() {
517 return *r;
518 }
519 let regions = match self
520 .backend_client
521 .get_shmem_config()
522 .map_err(Error::ShmemRegions)
523 {
524 Ok(x) => x,
525 Err(e) => {
526 error!("Failed to get shared memory config {}", e);
527 return None;
528 }
529 };
530 let region = match regions.len() {
531 0 => None,
532 1 => Some(regions[0]),
533 n => {
534 error!(
535 "Failed to get shared memory region {}",
536 Error::TooManyShmemRegions(n)
537 );
538 return None;
539 }
540 };
541 *self.shmem_region.borrow_mut() = Some(region);
542 region
543 }
544
545 fn set_shared_memory_mapper(&mut self, mapper: Box<dyn SharedMemoryMapper>) {
546 let Some(backend_req_handler) = self.backend_req_handler.as_mut() else {
549 error!(
550 "Error setting shared memory mapper {}",
551 Error::ProtocolFeatureNotNegoiated(VhostUserProtocolFeatures::BACKEND_REQ)
552 );
553 return;
554 };
555
556 let shmid = self
558 .shmem_region
559 .borrow()
560 .flatten()
561 .expect("missing shmid")
562 .id;
563
564 backend_req_handler
565 .frontend_mut()
566 .set_shared_mapper_state(mapper, shmid);
567 }
568
569 fn expose_shmem_descriptors_with_viommu(&self) -> bool {
570 self.expose_shmem_descriptors_with_viommu
571 }
572
573 fn virtio_sleep(&mut self) -> anyhow::Result<Option<BTreeMap<usize, Queue>>> {
574 let Some(mut queues) = self.sent_queues.take() else {
575 return Ok(None);
576 };
577
578 for (&queue_index, queue) in queues.iter_mut() {
579 let vring_base = self
580 .deactivate_vring(queue_index)
581 .context("deactivate_vring failed during sleep")?;
582 queue.vhost_user_reclaim(vring_base);
583 }
584
585 if let Some(w) = self.worker_thread.take() {
586 let (backend_req_handler, vm_evt_wrtube) = w.stop();
587 self.backend_req_handler = backend_req_handler;
588 self.vm_evt_wrtube = vm_evt_wrtube;
589 }
590
591 self.sent_set_features = false;
592
593 Ok(Some(queues))
594 }
595
596 fn virtio_wake(
597 &mut self,
598 queues_state: Option<(GuestMemory, Interrupt, BTreeMap<usize, Queue>)>,
601 ) -> anyhow::Result<()> {
602 if let Some((mem, interrupt, queues)) = queues_state {
603 self.activate(mem, interrupt, queues)?;
604 }
605 Ok(())
606 }
607
608 fn virtio_snapshot(&mut self) -> anyhow::Result<AnySnapshot> {
609 if !self
610 .protocol_features
611 .contains(VhostUserProtocolFeatures::DEVICE_STATE)
612 {
613 bail!("snapshot requires VHOST_USER_PROTOCOL_F_DEVICE_STATE");
614 }
615 let (mut r, w) = new_pipe_pair()?;
618 let backend_r = self
619 .backend_client
620 .set_device_state_fd(
621 VhostUserTransferDirection::Save,
622 VhostUserMigrationPhase::Stopped,
623 &w,
624 )
625 .context("failed to negotiate device state fd")?;
626 std::mem::drop(w);
629 let mut snapshot_bytes = Vec::new();
631 if let Some(mut backend_r) = backend_r {
632 backend_r.read_to_end(&mut snapshot_bytes)
633 } else {
634 r.read_to_end(&mut snapshot_bytes)
635 }
636 .context("failed to read device state")?;
637 self.backend_client
639 .check_device_state()
640 .context("failed to transfer device state")?;
641 Ok(AnySnapshot::to_any(VhostUserDeviceState {
642 acked_features: self.acked_features,
643 backend_state: snapshot_bytes,
644 })
645 .map_err(Error::SliceToSerdeValue)?)
646 }
647
648 fn virtio_restore(&mut self, data: AnySnapshot) -> anyhow::Result<()> {
649 if !self.sent_set_features {
651 self.backend_client
652 .set_features(self.acked_features)
653 .map_err(Error::SetFeatures)?;
654 self.sent_set_features = true;
655 }
656
657 if !self
658 .protocol_features
659 .contains(VhostUserProtocolFeatures::DEVICE_STATE)
660 {
661 bail!("restore requires VHOST_USER_PROTOCOL_F_DEVICE_STATE");
662 }
663
664 let device_state: VhostUserDeviceState =
665 AnySnapshot::from_any(data).map_err(Error::SerdeValueToSlice)?;
666 let missing_features = !self.avail_features & device_state.acked_features;
667 if missing_features != 0 {
668 bail!("The destination backend doesn't support all features acknowledged by the source, missing: {}", missing_features);
669 }
670 self.ack_features(device_state.acked_features);
672 let (r, w) = new_pipe_pair()?;
675 let backend_w = self
676 .backend_client
677 .set_device_state_fd(
678 VhostUserTransferDirection::Load,
679 VhostUserMigrationPhase::Stopped,
680 &r,
681 )
682 .context("failed to negotiate device state fd")?;
683 {
685 let backend_w = backend_w;
689 let mut w = w;
690 if let Some(mut backend_w) = backend_w {
691 backend_w.write_all(device_state.backend_state.as_slice())
692 } else {
693 w.write_all(device_state.backend_state.as_slice())
694 }
695 .context("failed to write device state")?;
696 }
697 self.backend_client
699 .check_device_state()
700 .context("failed to transfer device state")?;
701 Ok(())
702 }
703}
704
705#[derive(serde::Serialize, serde::Deserialize, Debug)]
706struct VhostUserDeviceState {
707 acked_features: u64,
708 backend_state: Vec<u8>,
709}
710
711#[cfg(unix)]
712fn new_pipe_pair() -> anyhow::Result<(impl AsRawDescriptor + Read, impl AsRawDescriptor + Write)> {
713 base::pipe().context("failed to create pipe")
714}
715
/// Creates a connected pair of byte-framed, blocking named pipes used to stream device state,
/// returning `(read end, write end)`.
#[cfg(windows)]
fn new_pipe_pair() -> anyhow::Result<(impl AsRawDescriptor + Read, impl AsRawDescriptor + Write)> {
    use base::named_pipes::BlockingMode;
    use base::named_pipes::FramingMode;

    let (read_end, write_end) =
        base::named_pipes::pair(&FramingMode::Byte, &BlockingMode::Wait, 0)
            .context("failed to create named pipes")?;
    Ok((read_end, write_end))
}