devices/virtio/vhost_user_frontend/
mod.rs1mod error;
8mod handler;
9mod sys;
10mod worker;
11
12use std::cell::RefCell;
13use std::collections::BTreeMap;
14use std::io::Read;
15use std::io::Write;
16
17use anyhow::bail;
18use anyhow::Context;
19use base::error;
20use base::trace;
21use base::AsRawDescriptor;
22#[cfg(windows)]
23use base::CloseNotifier;
24use base::Event;
25use base::RawDescriptor;
26use base::ReadNotifier;
27use base::SafeDescriptor;
28use base::SendTube;
29use base::WorkerThread;
30use snapshot::AnySnapshot;
31use vm_memory::GuestMemory;
32use vmm_vhost::message::VhostUserConfigFlags;
33use vmm_vhost::message::VhostUserMigrationPhase;
34use vmm_vhost::message::VhostUserProtocolFeatures;
35use vmm_vhost::message::VhostUserTransferDirection;
36use vmm_vhost::BackendClient;
37use vmm_vhost::VhostUserMemoryRegionInfo;
38use vmm_vhost::VringConfigData;
39use vmm_vhost::VHOST_USER_F_PROTOCOL_FEATURES;
40
41use crate::virtio::device_constants::VIRTIO_DEVICE_TYPE_SPECIFIC_FEATURES_MASK;
42use crate::virtio::vhost_user_frontend::error::Error;
43use crate::virtio::vhost_user_frontend::error::Result;
44use crate::virtio::vhost_user_frontend::handler::BackendReqHandler;
45use crate::virtio::vhost_user_frontend::handler::BackendReqHandlerImpl;
46use crate::virtio::vhost_user_frontend::sys::create_backend_req_handler;
47use crate::virtio::vhost_user_frontend::worker::Worker;
48use crate::virtio::DeviceType;
49use crate::virtio::Interrupt;
50use crate::virtio::Queue;
51use crate::virtio::SharedMemoryMapper;
52use crate::virtio::SharedMemoryRegion;
53use crate::virtio::VirtioDevice;
54use crate::PciAddress;
55
/// Frontend side of a vhost-user device, exposed to the rest of the VMM as a `VirtioDevice`.
///
/// Virtio operations (feature negotiation, config space access, queue activation, snapshot and
/// restore) are forwarded to the backend through `backend_client`.
pub struct VhostUserFrontend {
    /// Virtio device type; also used to build the debug label.
    device_type: DeviceType,
    /// Worker thread started by `start_worker`. Stopping it yields back the backend request
    /// handler and the VM event tube that were moved into it.
    worker_thread: Option<WorkerThread<(Option<BackendReqHandler>, SendTube)>>,

    /// Client used to send vhost-user requests to the backend.
    backend_client: BackendClient,
    /// Features reported by the backend, masked by the features this frontend allows.
    avail_features: u64,
    /// Features acknowledged so far (always a subset of `avail_features`).
    acked_features: u64,
    /// Value of `acked_features` at construction or at the most recent `set_features` request;
    /// used to skip redundant `set_features` calls.
    last_acked_features: u64,
    /// Protocol features negotiated with the backend (empty if protocol features are unavailable).
    protocol_features: VhostUserProtocolFeatures,
    /// Handler for backend-initiated requests. `Some` only if `BACKEND_REQ` was negotiated, and
    /// `None` while the worker thread owns it.
    backend_req_handler: Option<BackendReqHandler>,
    /// Cached result of the shared memory query. The outer `None` means "not queried yet"; the
    /// inner `Option` is the backend's answer (`None` when it reported zero regions).
    shmem_region: RefCell<Option<Option<SharedMemoryRegion>>>,

    /// Maximum size of each queue; one entry per queue.
    queue_sizes: Vec<u16>,
    /// Value returned by `expose_shmem_descriptors_with_viommu()`; set for GPU devices.
    expose_shmem_descriptors_with_viommu: bool,
    /// Optional PCI address supplied by the caller of `new`.
    pci_address: Option<PciAddress>,
    /// Tube on which the worker reports `DeviceCrashed` when it exits with an error.
    vm_evt_wrtube: SendTube,

    /// Queues that have been activated on the backend. Set by `activate` and taken by `reset`
    /// and `virtio_sleep`.
    sent_queues: Option<BTreeMap<usize, Queue>>,
}
83
/// Returns the largest power of two that is less than or equal to `val`.
///
/// Returns `None` only when `val` is zero, since no power of two is `<= 0`.
fn power_of_two_le(val: u16) -> Option<u16> {
    if val == 0 {
        return None;
    }
    // Keep only the highest set bit. Unlike rounding up and halving, this cannot overflow for
    // values above 0x8000 (e.g. 40000 correctly yields 32768).
    let highest_bit = u16::BITS - 1 - val.leading_zeros();
    Some(1 << highest_bit)
}
95
96impl VhostUserFrontend {
/// Creates a vhost-user frontend for `device_type` over `connection` and negotiates virtio
/// and vhost-user protocol features with the backend.
///
/// # Arguments
///
/// * `device_type` - virtio device type presented by this frontend.
/// * `base_features` - features allowed in addition to the device-type-specific feature bits.
///   `VIRTIO_F_RING_PACKED` is cleared (with a warning) if it is set.
/// * `connection` - connection to the vhost-user backend.
/// * `vm_evt_wrtube` - tube on which the worker reports `DeviceCrashed`.
/// * `max_queue_size` - optional cap on the queue size, rounded down to a power of two.
///   `Queue::MAX_SIZE` is used when it is `None` or when rounding yields `None`.
/// * `pci_address` - optional PCI address, returned later from `pci_address()`.
///
/// # Errors
///
/// Returns an error if any request to the backend made during negotiation fails, or if the
/// backend request handler cannot be created.
pub fn new(
    device_type: DeviceType,
    mut base_features: u64,
    connection: vmm_vhost::Connection,
    vm_evt_wrtube: SendTube,
    max_queue_size: Option<u16>,
    pci_address: Option<PciAddress>,
) -> Result<VhostUserFrontend> {
    // Packed rings are not supported by this frontend, so never offer them.
    if base_features & (1 << virtio_sys::virtio_config::VIRTIO_F_RING_PACKED) != 0 {
        base_features &= !(1 << virtio_sys::virtio_config::VIRTIO_F_RING_PACKED);
        base::warn!(
            "VIRTIO_F_RING_PACKED requested, but not yet supported by vhost-user frontend. \
            Automatically disabled."
        );
    }

    // Captured before `connection` is moved into the client; needed by the Windows backend
    // request handler setup below.
    #[cfg(windows)]
    let backend_pid = connection.target_pid();

    let mut backend_client = BackendClient::new(connection);

    backend_client.set_owner().map_err(Error::SetOwner)?;

    // Only features that are both allowed here and reported by the backend are made available.
    let allow_features = VIRTIO_DEVICE_TYPE_SPECIFIC_FEATURES_MASK
        | base_features
        | 1 << VHOST_USER_F_PROTOCOL_FEATURES;
    let avail_features =
        allow_features & backend_client.get_features().map_err(Error::GetFeatures)?;
    let mut acked_features = 0;

    // Protocol features this frontend is willing to negotiate.
    let allow_protocol_features = VhostUserProtocolFeatures::CONFIG
        | VhostUserProtocolFeatures::MQ
        | VhostUserProtocolFeatures::BACKEND_REQ
        | VhostUserProtocolFeatures::DEVICE_STATE
        | VhostUserProtocolFeatures::SHMEM_MAP
        | VhostUserProtocolFeatures::REPLY_ACK;

    let mut protocol_features = VhostUserProtocolFeatures::empty();
    if avail_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES != 0 {
        // The protocol-features bit is acknowledged by the frontend itself rather than via
        // `ack_features`.
        acked_features |= 1 << VHOST_USER_F_PROTOCOL_FEATURES;

        let avail_protocol_features = backend_client
            .get_protocol_features()
            .map_err(Error::GetProtocolFeatures)?;
        protocol_features = allow_protocol_features & avail_protocol_features;
        backend_client
            .set_protocol_features(protocol_features)
            .map_err(Error::SetProtocolFeatures)?;
    }

    // Set up the backend-to-frontend request channel only if the backend negotiated it.
    let backend_req_handler =
        if protocol_features.contains(VhostUserProtocolFeatures::BACKEND_REQ) {
            let (mut handler, tx_fd) = create_backend_req_handler(
                BackendReqHandlerImpl::new(),
                #[cfg(windows)]
                backend_pid,
            )?;
            handler.set_reply_ack_flag(
                protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK),
            );
            backend_client
                .set_backend_req_fd(&tx_fd)
                .map_err(Error::SetDeviceRequestChannel)?;
            Some(handler)
        } else {
            None
        };

    // With MQ the backend reports its queue count; otherwise fall back to the device type's
    // minimum.
    let num_queues = if protocol_features.contains(VhostUserProtocolFeatures::MQ) {
        trace!("backend supports VHOST_USER_PROTOCOL_F_MQ");
        let num_queues = backend_client.get_queue_num().map_err(Error::GetQueueNum)?;
        trace!("VHOST_USER_GET_QUEUE_NUM returned {num_queues}");
        num_queues as usize
    } else {
        trace!("backend does not support VHOST_USER_PROTOCOL_F_MQ");
        device_type.min_queues()
    };

    // Round a requested size down to a power of two; default to the largest supported size.
    let max_queue_size = max_queue_size
        .and_then(power_of_two_le)
        .unwrap_or(Queue::MAX_SIZE);

    trace!(
        "vhost-user {device_type} frontend with {num_queues} queues x {max_queue_size} entries\
        {}",
        if let Some(pci_address) = pci_address {
            format!(" pci-address {pci_address}")
        } else {
            "".to_string()
        }
    );

    // Every queue gets the same maximum size.
    let queue_sizes = vec![max_queue_size; num_queues];

    Ok(VhostUserFrontend {
        device_type,
        worker_thread: None,
        backend_client,
        avail_features,
        acked_features,
        // No `set_features` request has been sent yet; `activate` sends one only if the
        // acknowledged set changes from this initial value.
        last_acked_features: acked_features,
        protocol_features,
        backend_req_handler,
        shmem_region: RefCell::new(None),
        queue_sizes,
        expose_shmem_descriptors_with_viommu: device_type == DeviceType::Gpu,
        pci_address,
        vm_evt_wrtube,
        sent_queues: None,
    })
}
236
237 fn set_mem_table(&mut self, mem: &GuestMemory) -> Result<()> {
238 let regions: Vec<_> = mem
239 .regions()
240 .map(|region| VhostUserMemoryRegionInfo {
241 guest_phys_addr: region.guest_addr.0,
242 memory_size: region.size as u64,
243 userspace_addr: region.host_addr as u64,
244 mmap_offset: region.shm_offset,
245 mmap_handle: region.shm.as_raw_descriptor(),
246 })
247 .collect();
248
249 self.backend_client
250 .set_mem_table(regions.as_slice())
251 .map_err(Error::SetMemTable)?;
252
253 Ok(())
254 }
255
/// Configures one virtqueue on the backend and, when protocol features were acknowledged,
/// enables it.
///
/// The requests are issued in this order: ring size, ring addresses, base index, call
/// event, kick event, and finally (conditionally) enable.
///
/// # Arguments
///
/// * `mem` - guest memory used to translate the ring addresses to host addresses.
/// * `queue_index` - index of the queue on the backend.
/// * `queue` - queue whose size, ring addresses, next-avail index and event are sent.
/// * `irqfd` - event registered as the queue's call event.
///
/// # Errors
///
/// Returns an error if an address cannot be translated or any backend request fails.
fn activate_vring(
    &mut self,
    mem: &GuestMemory,
    queue_index: usize,
    queue: &Queue,
    irqfd: &Event,
) -> Result<()> {
    self.backend_client
        .set_vring_num(queue_index, queue.size())
        .map_err(Error::SetVringNum)?;

    // The ring addresses sent to the backend are host addresses obtained from `mem`.
    let config_data = VringConfigData {
        queue_size: queue.size(),
        flags: 0u32,
        desc_table_addr: mem
            .get_host_address(queue.desc_table())
            .map_err(Error::GetHostAddress)? as u64,
        used_ring_addr: mem
            .get_host_address(queue.used_ring())
            .map_err(Error::GetHostAddress)? as u64,
        avail_ring_addr: mem
            .get_host_address(queue.avail_ring())
            .map_err(Error::GetHostAddress)? as u64,
        log_addr: None,
    };
    self.backend_client
        .set_vring_addr(queue_index, &config_data)
        .map_err(Error::SetVringAddr)?;

    // Start the backend at the next available index this queue would process.
    self.backend_client
        .set_vring_base(queue_index, queue.next_avail_to_process())
        .map_err(Error::SetVringBase)?;

    self.backend_client
        .set_vring_call(queue_index, irqfd)
        .map_err(Error::SetVringCall)?;
    self.backend_client
        .set_vring_kick(queue_index, queue.event())
        .map_err(Error::SetVringKick)?;

    // The explicit enable request is only sent when VHOST_USER_F_PROTOCOL_FEATURES was
    // acknowledged.
    if self.acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES != 0 {
        self.backend_client
            .set_vring_enable(queue_index, true)
            .map_err(Error::SetVringEnable)?;
    }

    Ok(())
}
307
308 fn deactivate_vring(&self, queue_index: usize) -> Result<u16> {
310 if self.acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES != 0 {
311 self.backend_client
312 .set_vring_enable(queue_index, false)
313 .map_err(Error::SetVringEnable)?;
314 }
315
316 let vring_base = self
317 .backend_client
318 .get_vring_base(queue_index)
319 .map_err(Error::GetVringBase)?;
320
321 vring_base
322 .try_into()
323 .map_err(|_| Error::VringBaseTooBig(vring_base))
324 }
325
/// Spawns the worker thread for this device.
///
/// The backend request handler (if any) is moved into the worker and handed back, together
/// with the cloned VM event tube, when the worker is stopped.
///
/// # Panics
///
/// Panics if a worker is already running, if the backend client's read (or, on Windows,
/// close) notifier cannot be converted to a `SafeDescriptor`, or if `vm_evt_wrtube` cannot
/// be cloned.
fn start_worker(&mut self, interrupt: Interrupt, non_msix_evt: Event) {
    assert!(
        self.worker_thread.is_none(),
        "BUG: attempted to start worker twice"
    );

    let label = self.debug_label();

    // Give the handler its own copy of the interrupt before moving it into the worker.
    let mut backend_req_handler = self.backend_req_handler.take();
    if let Some(handler) = &mut backend_req_handler {
        handler.frontend_mut().set_interrupt(interrupt.clone());
    }

    let backend_client_read_notifier =
        SafeDescriptor::try_from(self.backend_client.get_read_notifier())
            .expect("failed to get backend read notifier");
    #[cfg(windows)]
    let backend_client_close_notifier =
        SafeDescriptor::try_from(self.backend_client.get_close_notifier())
            .expect("failed to get backend close notifier");

    // The worker gets a clone; `reset`/`virtio_sleep` store the returned tube back into `self`.
    let vm_evt_wrtube = self
        .vm_evt_wrtube
        .try_clone()
        .expect("failed to clone vm_evt_wrtube");

    self.worker_thread = Some(WorkerThread::start(label.clone(), move |kill_evt| {
        let mut worker = Worker {
            kill_evt,
            non_msix_evt,
            backend_req_handler,
            backend_client_read_notifier,
            #[cfg(windows)]
            backend_client_close_notifier,
        };
        if let Err(e) = worker
            .run(interrupt)
            .with_context(|| format!("{label}: vhost_user_frontend worker failed"))
        {
            error!("vhost-user worker thread exited with an error: {:#}", e);

            // Report the failure as a device crash on the VM event tube.
            if let Err(e) = vm_evt_wrtube.send(&base::VmEventType::DeviceCrashed) {
                error!("failed to send crash event: {}", e);
            }
        }
        // Return the moved-in state so it can be restored when the worker is stopped.
        (worker.backend_req_handler, vm_evt_wrtube)
    }));
}
377}
378
379impl VirtioDevice for VhostUserFrontend {
/// Returns a label of the form `vu-<device type>`; also used as the worker thread's label.
fn debug_label(&self) -> String {
    format!("vu-{}", self.device_type())
}
384
/// Returns an empty list: this device reports no descriptors to keep.
fn keep_rds(&self) -> Vec<RawDescriptor> {
    Vec::new()
}
388
/// Returns the virtio device type given to `new`.
fn device_type(&self) -> DeviceType {
    self.device_type
}
392
/// Returns the maximum size of each queue, one entry per queue.
fn queue_max_sizes(&self) -> &[u16] {
    &self.queue_sizes
}
396
/// Returns the features available to be acknowledged, as computed in `new`.
fn features(&self) -> u64 {
    self.avail_features
}
400
/// Records acknowledged features, ignoring any bits that are not in `avail_features`.
///
/// Nothing is sent to the backend here; `activate` and `virtio_restore` send the set.
fn ack_features(&mut self, features: u64) {
    self.acked_features |= features & self.avail_features;
}
404
405 fn read_config(&self, offset: u64, data: &mut [u8]) {
406 let Ok(offset) = offset.try_into() else {
407 error!("failed to read config: invalid config offset is given: {offset}");
408 return;
409 };
410 let Ok(data_len) = data.len().try_into() else {
411 error!(
412 "failed to read config: invalid config length is given: {}",
413 data.len()
414 );
415 return;
416 };
417 let (_, config) = match self.backend_client.get_config(
418 offset,
419 data_len,
420 VhostUserConfigFlags::WRITABLE,
421 data,
422 ) {
423 Ok(x) => x,
424 Err(e) => {
425 error!("failed to read config: {}", Error::GetConfig(e));
426 return;
427 }
428 };
429 data.copy_from_slice(&config);
430 }
431
432 fn write_config(&mut self, offset: u64, data: &[u8]) {
433 let Ok(offset) = offset.try_into() else {
434 error!("failed to write config: invalid config offset is given: {offset}");
435 return;
436 };
437 if let Err(e) = self
438 .backend_client
439 .set_config(offset, VhostUserConfigFlags::empty(), data)
440 .map_err(Error::SetConfig)
441 {
442 error!("failed to write config: {}", e);
443 }
444 }
445
/// Activates the device: sends features and the memory table to the backend, configures
/// every queue, and starts the worker thread.
///
/// # Errors
///
/// Returns an error if a backend request fails, if `interrupt` has no MSI-X config, or if
/// the fallback event cannot be created. Queues configured before a failure are not rolled
/// back by this function.
fn activate(
    &mut self,
    mem: GuestMemory,
    interrupt: Interrupt,
    queues: BTreeMap<usize, Queue>,
) -> anyhow::Result<()> {
    // Skip the request when the acknowledged set is unchanged since it was last recorded.
    if self.last_acked_features != self.acked_features {
        self.backend_client
            .set_features(self.acked_features)
            .map_err(Error::SetFeatures)?;
        self.last_acked_features = self.acked_features;
    }

    self.set_mem_table(&mem)?;

    let msix_config_opt = interrupt
        .get_msix_config()
        .as_ref()
        .ok_or(Error::MsixConfigUnavailable)?;
    // The lock is held while the vrings are configured because `irqfd` borrows from it.
    let msix_config = msix_config_opt.lock();

    // Used as the call event for any queue whose vector has no irqfd; it is handed to the
    // worker afterwards.
    let non_msix_evt = Event::new().map_err(Error::CreateEvent)?;
    for (&queue_index, queue) in queues.iter() {
        let irqfd = msix_config
            .get_irqfd(queue.vector() as usize)
            .unwrap_or(&non_msix_evt);
        self.activate_vring(&mem, queue_index, queue, irqfd)?;
    }

    // Remember the queues so `reset` and `virtio_sleep` can deactivate them.
    self.sent_queues = Some(queues);

    // Release the lock before `interrupt` is moved into the worker.
    drop(msix_config);

    self.start_worker(interrupt, non_msix_evt);
    Ok(())
}
482
483 fn reset(&mut self) -> anyhow::Result<()> {
484 if let Some(sent_queues) = self.sent_queues.take() {
488 for queue_index in sent_queues.into_keys() {
489 let _vring_base = self
490 .deactivate_vring(queue_index)
491 .context("deactivate_vring failed during reset")?;
492 }
493 }
494
495 if let Some(w) = self.worker_thread.take() {
496 let (backend_req_handler, vm_evt_wrtube) = w.stop();
497 self.backend_req_handler = backend_req_handler;
498 self.vm_evt_wrtube = vm_evt_wrtube;
499 }
500
501 Ok(())
502 }
503
/// Returns the PCI address given to `new`, if any.
fn pci_address(&self) -> Option<PciAddress> {
    self.pci_address
}
507
508 fn get_shared_memory_region(&self) -> Option<SharedMemoryRegion> {
509 if !self
510 .protocol_features
511 .contains(VhostUserProtocolFeatures::SHMEM_MAP)
512 {
513 return None;
514 }
515 if let Some(r) = self.shmem_region.borrow().as_ref() {
516 return *r;
517 }
518 let regions = match self
519 .backend_client
520 .get_shmem_config()
521 .map_err(Error::ShmemRegions)
522 {
523 Ok(x) => x,
524 Err(e) => {
525 error!("Failed to get shared memory config {}", e);
526 return None;
527 }
528 };
529 let region = match regions.len() {
530 0 => None,
531 1 => Some(regions[0]),
532 n => {
533 error!(
534 "Failed to get shared memory region {}",
535 Error::TooManyShmemRegions(n)
536 );
537 return None;
538 }
539 };
540 *self.shmem_region.borrow_mut() = Some(region);
541 region
542 }
543
544 fn set_shared_memory_mapper(&mut self, mapper: Box<dyn SharedMemoryMapper>) {
545 let Some(backend_req_handler) = self.backend_req_handler.as_mut() else {
548 error!(
549 "Error setting shared memory mapper {}",
550 Error::ProtocolFeatureNotNegoiated(VhostUserProtocolFeatures::BACKEND_REQ)
551 );
552 return;
553 };
554
555 let shmid = self
557 .shmem_region
558 .borrow()
559 .flatten()
560 .expect("missing shmid")
561 .id;
562
563 backend_req_handler
564 .frontend_mut()
565 .set_shared_mapper_state(mapper, shmid);
566 }
567
/// Returns the flag computed in `new`, which is `true` only for GPU devices.
fn expose_shmem_descriptors_with_viommu(&self) -> bool {
    self.expose_shmem_descriptors_with_viommu
}
571
572 fn virtio_sleep(&mut self) -> anyhow::Result<Option<BTreeMap<usize, Queue>>> {
573 let Some(mut queues) = self.sent_queues.take() else {
574 return Ok(None);
575 };
576
577 for (&queue_index, queue) in queues.iter_mut() {
578 let vring_base = self
579 .deactivate_vring(queue_index)
580 .context("deactivate_vring failed during sleep")?;
581 queue.vhost_user_reclaim(vring_base);
582 }
583
584 if let Some(w) = self.worker_thread.take() {
585 let (backend_req_handler, vm_evt_wrtube) = w.stop();
586 self.backend_req_handler = backend_req_handler;
587 self.vm_evt_wrtube = vm_evt_wrtube;
588 }
589
590 Ok(Some(queues))
591 }
592
593 fn virtio_wake(
594 &mut self,
595 queues_state: Option<(GuestMemory, Interrupt, BTreeMap<usize, Queue>)>,
596 ) -> anyhow::Result<()> {
597 if let Some((mem, interrupt, queues)) = queues_state {
598 self.activate(mem, interrupt, queues)?;
599 }
600 Ok(())
601 }
602
/// Snapshots the device by reading the backend's state over a pipe.
///
/// The result contains the currently acknowledged features and the raw bytes produced by
/// the backend.
///
/// # Errors
///
/// Fails if `DEVICE_STATE` was not negotiated, if the pipe cannot be created, if a backend
/// request fails, if reading the state fails, or if the state cannot be serialized.
fn virtio_snapshot(&mut self) -> anyhow::Result<AnySnapshot> {
    if !self
        .protocol_features
        .contains(VhostUserProtocolFeatures::DEVICE_STATE)
    {
        bail!("snapshot requires VHOST_USER_PROTOCOL_F_DEVICE_STATE");
    }
    // `w` is offered to the backend as the channel to write its state into.
    let (mut r, w) = new_pipe_pair()?;
    let backend_r = self
        .backend_client
        .set_device_state_fd(
            VhostUserTransferDirection::Save,
            VhostUserMigrationPhase::Stopped,
            &w,
        )
        .context("failed to negotiate device state fd")?;
    // Close the local write end before reading; presumably required so `read_to_end` on `r`
    // can observe end-of-file once the backend is done. TODO confirm
    std::mem::drop(w);
    let mut snapshot_bytes = Vec::new();
    // If the backend supplied its own descriptor, read from that instead of our pipe.
    if let Some(mut backend_r) = backend_r {
        backend_r.read_to_end(&mut snapshot_bytes)
    } else {
        r.read_to_end(&mut snapshot_bytes)
    }
    .context("failed to read device state")?;
    // Ask the backend whether the transfer succeeded on its side.
    self.backend_client
        .check_device_state()
        .context("failed to transfer device state")?;
    Ok(AnySnapshot::to_any(VhostUserDeviceState {
        acked_features: self.acked_features,
        backend_state: snapshot_bytes,
    })
    .map_err(Error::SliceToSerdeValue)?)
}
642
/// Restores the device from a snapshot produced by `virtio_snapshot`.
///
/// The acknowledged features from the snapshot are re-applied, then the saved backend state
/// is written to the backend over a pipe.
///
/// # Errors
///
/// Fails if `DEVICE_STATE` was not negotiated, if the snapshot cannot be deserialized, if
/// the snapshot acknowledges features this backend does not offer, if the pipe cannot be
/// created, if a backend request fails, or if writing the state fails.
fn virtio_restore(&mut self, data: AnySnapshot) -> anyhow::Result<()> {
    if !self
        .protocol_features
        .contains(VhostUserProtocolFeatures::DEVICE_STATE)
    {
        bail!("restore requires VHOST_USER_PROTOCOL_F_DEVICE_STATE");
    }

    let device_state: VhostUserDeviceState =
        AnySnapshot::from_any(data).map_err(Error::SerdeValueToSlice)?;

    // Every feature acknowledged in the snapshot must be available from this backend.
    let missing_features = !self.avail_features & device_state.acked_features;
    if missing_features != 0 {
        bail!("The destination backend doesn't support all features acknowledged by the source, missing: {}", missing_features);
    }
    self.acked_features = device_state.acked_features;
    // Send the restored feature set only if it differs from what was last recorded.
    if self.last_acked_features != self.acked_features {
        self.backend_client
            .set_features(self.acked_features)
            .map_err(Error::SetFeatures)?;
        self.last_acked_features = self.acked_features;
    }

    // `r` is offered to the backend as the channel to read its state from.
    let (r, w) = new_pipe_pair()?;
    let backend_w = self
        .backend_client
        .set_device_state_fd(
            VhostUserTransferDirection::Load,
            VhostUserMigrationPhase::Stopped,
            &r,
        )
        .context("failed to negotiate device state fd")?;
    {
        // Both writers are moved into this scope so they are dropped (closed) at its end,
        // before `check_device_state` is sent; presumably this is what signals end-of-data
        // to the backend. TODO confirm
        let backend_w = backend_w;
        let mut w = w;
        // If the backend supplied its own descriptor, write to that instead of our pipe.
        if let Some(mut backend_w) = backend_w {
            backend_w.write_all(device_state.backend_state.as_slice())
        } else {
            w.write_all(device_state.backend_state.as_slice())
        }
        .context("failed to write device state")?;
    }
    // Ask the backend whether the transfer succeeded on its side.
    self.backend_client
        .check_device_state()
        .context("failed to transfer device state")?;
    Ok(())
}
698}
699
/// Serialized form of a vhost-user frontend snapshot, produced by `virtio_snapshot` and
/// consumed by `virtio_restore`.
#[derive(serde::Serialize, serde::Deserialize, Debug)]
struct VhostUserDeviceState {
    /// Features acknowledged at snapshot time; re-applied to the backend on restore.
    acked_features: u64,
    /// Raw state bytes read from the backend; written back to the backend unmodified on restore.
    backend_state: Vec<u8>,
}
705
/// Creates a `(reader, writer)` pair using `base::pipe` (unix variant).
///
/// # Errors
///
/// Returns an error with context "failed to create pipe" if the pipe cannot be created.
#[cfg(unix)]
fn new_pipe_pair() -> anyhow::Result<(impl AsRawDescriptor + Read, impl AsRawDescriptor + Write)> {
    base::pipe().context("failed to create pipe")
}
710
/// Creates a `(reader, writer)` pair using a byte-framed, blocking named pipe pair
/// (Windows variant).
///
/// # Errors
///
/// Returns an error with context "failed to create named pipes" if creation fails.
#[cfg(windows)]
fn new_pipe_pair() -> anyhow::Result<(impl AsRawDescriptor + Read, impl AsRawDescriptor + Write)> {
    base::named_pipes::pair(
        &base::named_pipes::FramingMode::Byte,
        &base::named_pipes::BlockingMode::Wait,
        0,
    )
    .context("failed to create named pipes")
}