devices/virtio/vhost_user_frontend/
mod.rs1mod error;
8mod handler;
9mod sys;
10mod worker;
11
12use std::cell::RefCell;
13use std::collections::BTreeMap;
14use std::io::Read;
15use std::io::Write;
16
17use anyhow::bail;
18use anyhow::Context;
19use base::error;
20use base::trace;
21use base::AsRawDescriptor;
22#[cfg(windows)]
23use base::CloseNotifier;
24use base::Event;
25use base::RawDescriptor;
26use base::ReadNotifier;
27use base::SafeDescriptor;
28use base::SendTube;
29use base::WorkerThread;
30use snapshot::AnySnapshot;
31use vm_memory::GuestMemory;
32use vmm_vhost::message::VhostUserConfigFlags;
33use vmm_vhost::message::VhostUserMigrationPhase;
34use vmm_vhost::message::VhostUserProtocolFeatures;
35use vmm_vhost::message::VhostUserTransferDirection;
36use vmm_vhost::BackendClient;
37use vmm_vhost::VhostUserMemoryRegionInfo;
38use vmm_vhost::VringConfigData;
39use vmm_vhost::VHOST_USER_F_PROTOCOL_FEATURES;
40
41use crate::virtio::device_constants::VIRTIO_DEVICE_TYPE_SPECIFIC_FEATURES_MASK;
42use crate::virtio::vhost_user_frontend::error::Error;
43use crate::virtio::vhost_user_frontend::error::Result;
44use crate::virtio::vhost_user_frontend::handler::BackendReqHandler;
45use crate::virtio::vhost_user_frontend::handler::BackendReqHandlerImpl;
46use crate::virtio::vhost_user_frontend::sys::create_backend_req_handler;
47use crate::virtio::vhost_user_frontend::worker::Worker;
48use crate::virtio::DeviceType;
49use crate::virtio::Interrupt;
50use crate::virtio::Queue;
51use crate::virtio::SharedMemoryMapper;
52use crate::virtio::SharedMemoryRegion;
53use crate::virtio::VirtioDevice;
54use crate::PciAddress;
55
/// Frontend half of a vhost-user virtio device.
///
/// Implements `VirtioDevice` for the guest-facing side and forwards device
/// operations (feature negotiation, config space, vring setup, snapshot and
/// restore) to a vhost-user backend through `backend_client`.
///
/// NOTE(review): Rust drops fields in declaration order; keep the field order
/// stable unless the drop-order implications have been checked.
pub struct VhostUserFrontend {
    /// Virtio device type reported by `device_type()`.
    device_type: DeviceType,
    /// Worker started by `start_worker`. When stopped it hands back the backend
    /// request handler and `vm_evt_wrtube` so they can be stored on `self` again.
    worker_thread: Option<WorkerThread<(Option<BackendReqHandler>, SendTube)>>,

    /// Connection used to issue vhost-user requests to the backend.
    backend_client: BackendClient,
    /// Features offered to the driver: what the frontend allows, intersected
    /// with what the backend reports from `get_features`.
    avail_features: u64,
    /// Feature bits most recently sent with `set_features`. It is seeded with
    /// `VHOST_USER_F_PROTOCOL_FEATURES` when that bit is negotiated in `new`.
    acked_features: u64,
    /// Whether `set_features` has succeeded since construction or the last
    /// `reset`. `activate` and `virtio_restore` resend `acked_features` when
    /// this is false.
    sent_set_features: bool,
    /// Protocol features negotiated with the backend. Empty when the backend
    /// does not offer `VHOST_USER_F_PROTOCOL_FEATURES`.
    protocol_features: VhostUserProtocolFeatures,
    /// Handler for backend-initiated requests. Present only when `BACKEND_REQ`
    /// was negotiated. Moved into the worker while it runs.
    backend_req_handler: Option<BackendReqHandler>,
    /// Cache used by `get_shared_memory_region`:
    /// - `None`: the backend has not been queried successfully yet.
    /// - `Some(None)`: the backend reported no region.
    /// - `Some(Some(r))`: the backend reported the single region `r`.
    shmem_region: RefCell<Option<Option<SharedMemoryRegion>>>,

    /// Maximum size of each queue. One entry per queue; all entries are equal.
    queue_sizes: Vec<u16>,
    /// True only for `DeviceType::Gpu` (see `new`).
    expose_shmem_descriptors_with_viommu: bool,
    /// PCI address requested for this device, if any.
    pci_address: Option<PciAddress>,
    /// Tube the worker uses to send `VmEventType::DeviceCrashed` when it fails.
    vm_evt_wrtube: SendTube,

    /// Queues handed to the backend by the last successful `activate`.
    /// Taken back by `reset` and `virtio_sleep`.
    sent_queues: Option<BTreeMap<usize, Queue>>,
}
82
/// Returns the largest power of two that is less than or equal to `val`.
///
/// Returns `None` only for `val == 0`, which has no power of two below it.
/// Every other `u16` has an answer that fits in a `u16`, including values above
/// `1 << 15`: their answer is `1 << 15`.
fn power_of_two_le(val: u16) -> Option<u16> {
    // floor(log2(val)) is the index of the highest set bit, so `1 << exp` is the
    // largest power of two not exceeding `val`. Rounding *up* with
    // `checked_next_power_of_two` and halving instead would overflow, and so
    // wrongly yield `None`, for every `val` in (2^15, u16::MAX].
    val.checked_ilog2().map(|exp| 1u16 << exp)
}
94
impl VhostUserFrontend {
    /// Connects to a vhost-user backend over `connection` and negotiates
    /// features, protocol features and the queue count with it.
    ///
    /// Backend requests are issued in this order:
    /// 1. `set_owner`
    /// 2. `get_features`
    /// 3. `get_protocol_features` and `set_protocol_features`, only if the
    ///    backend offers `VHOST_USER_F_PROTOCOL_FEATURES`
    /// 4. `set_backend_req_fd`, if `BACKEND_REQ` was negotiated
    /// 5. `get_queue_num`, if `MQ` was negotiated
    ///
    /// Arguments:
    /// * `device_type` - virtio device type presented to the guest.
    /// * `base_features` - virtio features the frontend may offer in addition to
    ///   `VIRTIO_DEVICE_TYPE_SPECIFIC_FEATURES_MASK`. `VIRTIO_F_RING_PACKED` is
    ///   removed (with a warning) because this frontend does not support it.
    /// * `connection` - connection to the backend; consumed by `BackendClient::new`.
    /// * `vm_evt_wrtube` - tube the worker uses to report `DeviceCrashed`.
    /// * `max_queue_size` - optional per-queue size limit, rounded down by
    ///   `power_of_two_le`. `Queue::MAX_SIZE` is used when it is `None` or when
    ///   `power_of_two_le` yields `None`.
    /// * `pci_address` - optional PCI address, returned by `pci_address()` and
    ///   included in trace output.
    ///
    /// # Errors
    ///
    /// Returns an error if any backend request above fails or if the backend
    /// request handler cannot be created.
    pub fn new(
        device_type: DeviceType,
        mut base_features: u64,
        connection: vmm_vhost::Connection<vmm_vhost::FrontendReq>,
        vm_evt_wrtube: SendTube,
        max_queue_size: Option<u16>,
        pci_address: Option<PciAddress>,
    ) -> Result<VhostUserFrontend> {
        // Packed virtqueues are not supported here, so drop the bit instead of
        // failing device creation.
        if base_features & (1 << virtio_sys::virtio_config::VIRTIO_F_RING_PACKED) != 0 {
            base_features &= !(1 << virtio_sys::virtio_config::VIRTIO_F_RING_PACKED);
            base::warn!(
                "VIRTIO_F_RING_PACKED requested, but not yet supported by vhost-user frontend. \
                Automatically disabled."
            );
        }

        // Read before `connection` is moved into `BackendClient::new`; the
        // Windows backend request handler needs it.
        #[cfg(windows)]
        let backend_pid = connection.target_pid();

        let mut backend_client = BackendClient::new(connection);

        backend_client.set_owner().map_err(Error::SetOwner)?;

        // Offer only device-type-specific bits, the caller's base features and
        // the protocol-features bit, and only when the backend also reports them.
        let allow_features = VIRTIO_DEVICE_TYPE_SPECIFIC_FEATURES_MASK
            | base_features
            | 1 << VHOST_USER_F_PROTOCOL_FEATURES;
        let avail_features =
            allow_features & backend_client.get_features().map_err(Error::GetFeatures)?;
        let mut acked_features = 0;

        // Protocol features this frontend is prepared to use.
        let mut allow_protocol_features = VhostUserProtocolFeatures::CONFIG
            | VhostUserProtocolFeatures::MQ
            | VhostUserProtocolFeatures::BACKEND_REQ
            | VhostUserProtocolFeatures::DEVICE_STATE;

        // Only GPU devices request shared memory regions and expose shmem
        // descriptors with a vIOMMU.
        let expose_shmem_descriptors_with_viommu = if device_type == DeviceType::Gpu {
            allow_protocol_features |= VhostUserProtocolFeatures::SHARED_MEMORY_REGIONS;
            true
        } else {
            false
        };

        let mut protocol_features = VhostUserProtocolFeatures::empty();
        if avail_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES != 0 {
            // Recording the bit in `acked_features` here means every later
            // `set_features` (see `ack_features`, which ORs in `acked_features`)
            // keeps it set.
            acked_features |= 1 << VHOST_USER_F_PROTOCOL_FEATURES;

            let avail_protocol_features = backend_client
                .get_protocol_features()
                .map_err(Error::GetProtocolFeatures)?;
            protocol_features = allow_protocol_features & avail_protocol_features;
            backend_client
                .set_protocol_features(protocol_features)
                .map_err(Error::SetProtocolFeatures)?;
        }

        // Give the backend a channel for backend-initiated requests, if it
        // negotiated BACKEND_REQ.
        let backend_req_handler =
            if protocol_features.contains(VhostUserProtocolFeatures::BACKEND_REQ) {
                let (handler, tx_fd) = create_backend_req_handler(
                    BackendReqHandlerImpl::new(),
                    #[cfg(windows)]
                    backend_pid,
                )?;
                backend_client
                    .set_backend_req_fd(&tx_fd)
                    .map_err(Error::SetDeviceRequestChannel)?;
                Some(handler)
            } else {
                None
            };

        // With MQ the backend reports its queue count. Otherwise fall back to
        // the device type's minimum.
        let num_queues = if protocol_features.contains(VhostUserProtocolFeatures::MQ) {
            trace!("backend supports VHOST_USER_PROTOCOL_F_MQ");
            let num_queues = backend_client.get_queue_num().map_err(Error::GetQueueNum)?;
            trace!("VHOST_USER_GET_QUEUE_NUM returned {num_queues}");
            num_queues as usize
        } else {
            trace!("backend does not support VHOST_USER_PROTOCOL_F_MQ");
            device_type.min_queues()
        };

        // Round the requested limit down to a power of two. Use the queue
        // maximum when no usable limit was given.
        let max_queue_size = max_queue_size
            .and_then(power_of_two_le)
            .unwrap_or(Queue::MAX_SIZE);

        trace!(
            "vhost-user {device_type} frontend with {num_queues} queues x {max_queue_size} entries\
            {}",
            if let Some(pci_address) = pci_address {
                format!(" pci-address {pci_address}")
            } else {
                "".to_string()
            }
        );

        // Every queue gets the same maximum size.
        let queue_sizes = vec![max_queue_size; num_queues];

        Ok(VhostUserFrontend {
            device_type,
            worker_thread: None,
            backend_client,
            avail_features,
            acked_features,
            sent_set_features: false,
            protocol_features,
            backend_req_handler,
            shmem_region: RefCell::new(None),
            queue_sizes,
            expose_shmem_descriptors_with_viommu,
            pci_address,
            vm_evt_wrtube,
            sent_queues: None,
        })
    }

    /// Sends the guest memory layout to the backend with `set_mem_table`.
    ///
    /// Each `GuestMemory` region is described by:
    /// - its guest-physical address and size,
    /// - its host address,
    /// - the descriptor and offset of the shared memory object backing it.
    ///
    /// # Errors
    ///
    /// Returns `Error::SetMemTable` if the backend request fails.
    fn set_mem_table(&mut self, mem: &GuestMemory) -> Result<()> {
        let regions: Vec<_> = mem
            .regions()
            .map(|region| VhostUserMemoryRegionInfo {
                guest_phys_addr: region.guest_addr.0,
                memory_size: region.size as u64,
                userspace_addr: region.host_addr as u64,
                mmap_offset: region.shm_offset,
                mmap_handle: region.shm.as_raw_descriptor(),
            })
            .collect();

        self.backend_client
            .set_mem_table(regions.as_slice())
            .map_err(Error::SetMemTable)?;

        Ok(())
    }

    /// Configures vring `queue_index` on the backend from `queue` and starts it.
    ///
    /// Sends, in order:
    /// 1. the ring size,
    /// 2. the descriptor, used and avail ring addresses, translated to host
    ///    addresses through `mem`,
    /// 3. the starting index `queue.next_avail_to_process()`,
    /// 4. `irqfd` as the call event,
    /// 5. the queue's event as the kick event.
    ///
    /// If `VHOST_USER_F_PROTOCOL_FEATURES` was acked, the ring is then enabled
    /// explicitly.
    ///
    /// # Errors
    ///
    /// Returns an error if a ring address cannot be translated or if any
    /// backend request fails.
    fn activate_vring(
        &mut self,
        mem: &GuestMemory,
        queue_index: usize,
        queue: &Queue,
        irqfd: &Event,
    ) -> Result<()> {
        self.backend_client
            .set_vring_num(queue_index, queue.size())
            .map_err(Error::SetVringNum)?;

        let config_data = VringConfigData {
            queue_size: queue.size(),
            flags: 0u32,
            desc_table_addr: mem
                .get_host_address(queue.desc_table())
                .map_err(Error::GetHostAddress)? as u64,
            used_ring_addr: mem
                .get_host_address(queue.used_ring())
                .map_err(Error::GetHostAddress)? as u64,
            avail_ring_addr: mem
                .get_host_address(queue.avail_ring())
                .map_err(Error::GetHostAddress)? as u64,
            log_addr: None,
        };
        self.backend_client
            .set_vring_addr(queue_index, &config_data)
            .map_err(Error::SetVringAddr)?;

        self.backend_client
            .set_vring_base(queue_index, queue.next_avail_to_process())
            .map_err(Error::SetVringBase)?;

        self.backend_client
            .set_vring_call(queue_index, irqfd)
            .map_err(Error::SetVringCall)?;
        self.backend_client
            .set_vring_kick(queue_index, queue.event())
            .map_err(Error::SetVringKick)?;

        // NOTE(review): presumably required because, per the vhost-user spec,
        // rings start disabled once VHOST_USER_F_PROTOCOL_FEATURES has been
        // negotiated — confirm.
        if self.acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES != 0 {
            self.backend_client
                .set_vring_enable(queue_index, true)
                .map_err(Error::SetVringEnable)?;
        }

        Ok(())
    }

    /// Stops vring `queue_index` on the backend and returns its vring base.
    ///
    /// If `VHOST_USER_F_PROTOCOL_FEATURES` was acked, the ring is disabled
    /// first. Then `get_vring_base` is issued and its result is narrowed to
    /// `u16`. Callers such as `virtio_sleep` hand this value to
    /// `Queue::vhost_user_reclaim`.
    ///
    /// # Errors
    ///
    /// Returns an error if a backend request fails, or `VringBaseTooBig` if
    /// the returned base does not fit in a `u16`.
    fn deactivate_vring(&self, queue_index: usize) -> Result<u16> {
        if self.acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES != 0 {
            self.backend_client
                .set_vring_enable(queue_index, false)
                .map_err(Error::SetVringEnable)?;
        }

        let vring_base = self
            .backend_client
            .get_vring_base(queue_index)
            .map_err(Error::GetVringBase)?;

        vring_base
            .try_into()
            .map_err(|_| Error::VringBaseTooBig(vring_base))
    }

    /// Spawns the worker thread that services the backend connection.
    ///
    /// The backend request handler (if any) is given a clone of `interrupt`
    /// and moved into the worker. The worker also receives `non_msix_evt` and
    /// the backend client's read notifier (plus its close notifier on
    /// Windows).
    ///
    /// If the worker's `run` returns an error, the error is logged and
    /// `VmEventType::DeviceCrashed` is sent on a clone of `vm_evt_wrtube`. The
    /// thread returns the handler and the tube so `reset` and `virtio_sleep`
    /// can store them back on `self`.
    ///
    /// # Panics
    ///
    /// Panics if a worker is already running. Also panics if the backend
    /// notifiers cannot be converted to `SafeDescriptor`s or if
    /// `vm_evt_wrtube` cannot be cloned.
    fn start_worker(&mut self, interrupt: Interrupt, non_msix_evt: Event) {
        assert!(
            self.worker_thread.is_none(),
            "BUG: attempted to start worker twice"
        );

        let label = self.debug_label();

        // The handler lives in the worker until the worker is stopped.
        let mut backend_req_handler = self.backend_req_handler.take();
        if let Some(handler) = &mut backend_req_handler {
            handler.frontend_mut().set_interrupt(interrupt.clone());
        }

        let backend_client_read_notifier =
            SafeDescriptor::try_from(self.backend_client.get_read_notifier())
                .expect("failed to get backend read notifier");
        #[cfg(windows)]
        let backend_client_close_notifier =
            SafeDescriptor::try_from(self.backend_client.get_close_notifier())
                .expect("failed to get backend close notifier");

        let vm_evt_wrtube = self
            .vm_evt_wrtube
            .try_clone()
            .expect("failed to clone vm_evt_wrtube");

        self.worker_thread = Some(WorkerThread::start(label.clone(), move |kill_evt| {
            let mut worker = Worker {
                kill_evt,
                non_msix_evt,
                backend_req_handler,
                backend_client_read_notifier,
                #[cfg(windows)]
                backend_client_close_notifier,
            };
            if let Err(e) = worker
                .run(interrupt)
                .with_context(|| format!("{label}: vhost_user_frontend worker failed"))
            {
                error!("vhost-user worker thread exited with an error: {:#}", e);

                if let Err(e) = vm_evt_wrtube.send(&base::VmEventType::DeviceCrashed) {
                    error!("failed to send crash event: {}", e);
                }
            }
            // Hand ownership back to the frontend when the thread is stopped.
            (worker.backend_req_handler, vm_evt_wrtube)
        }));
    }
}
377
378impl VirtioDevice for VhostUserFrontend {
379 fn debug_label(&self) -> String {
381 format!("vu-{}", self.device_type())
382 }
383
384 fn keep_rds(&self) -> Vec<RawDescriptor> {
385 Vec::new()
386 }
387
388 fn device_type(&self) -> DeviceType {
389 self.device_type
390 }
391
392 fn queue_max_sizes(&self) -> &[u16] {
393 &self.queue_sizes
394 }
395
396 fn features(&self) -> u64 {
397 self.avail_features
398 }
399
400 fn ack_features(&mut self, features: u64) {
401 let features = (features & self.avail_features) | self.acked_features;
402 if let Err(e) = self
403 .backend_client
404 .set_features(features)
405 .map_err(Error::SetFeatures)
406 {
407 error!("failed to enable features 0x{:x}: {}", features, e);
408 return;
409 }
410 self.acked_features = features;
411 self.sent_set_features = true;
412 }
413
414 fn read_config(&self, offset: u64, data: &mut [u8]) {
415 let Ok(offset) = offset.try_into() else {
416 error!("failed to read config: invalid config offset is given: {offset}");
417 return;
418 };
419 let Ok(data_len) = data.len().try_into() else {
420 error!(
421 "failed to read config: invalid config length is given: {}",
422 data.len()
423 );
424 return;
425 };
426 let (_, config) = match self.backend_client.get_config(
427 offset,
428 data_len,
429 VhostUserConfigFlags::WRITABLE,
430 data,
431 ) {
432 Ok(x) => x,
433 Err(e) => {
434 error!("failed to read config: {}", Error::GetConfig(e));
435 return;
436 }
437 };
438 data.copy_from_slice(&config);
439 }
440
441 fn write_config(&mut self, offset: u64, data: &[u8]) {
442 let Ok(offset) = offset.try_into() else {
443 error!("failed to write config: invalid config offset is given: {offset}");
444 return;
445 };
446 if let Err(e) = self
447 .backend_client
448 .set_config(offset, VhostUserConfigFlags::empty(), data)
449 .map_err(Error::SetConfig)
450 {
451 error!("failed to write config: {}", e);
452 }
453 }
454
455 fn activate(
456 &mut self,
457 mem: GuestMemory,
458 interrupt: Interrupt,
459 queues: BTreeMap<usize, Queue>,
460 ) -> anyhow::Result<()> {
461 if !self.sent_set_features {
463 self.ack_features(self.acked_features);
464 }
465
466 self.set_mem_table(&mem)?;
467
468 let msix_config_opt = interrupt
469 .get_msix_config()
470 .as_ref()
471 .ok_or(Error::MsixConfigUnavailable)?;
472 let msix_config = msix_config_opt.lock();
473
474 let non_msix_evt = Event::new().map_err(Error::CreateEvent)?;
475 for (&queue_index, queue) in queues.iter() {
476 let irqfd = msix_config
477 .get_irqfd(queue.vector() as usize)
478 .unwrap_or(&non_msix_evt);
479 self.activate_vring(&mem, queue_index, queue, irqfd)?;
480 }
481
482 self.sent_queues = Some(queues);
483
484 drop(msix_config);
485
486 self.start_worker(interrupt, non_msix_evt);
487 Ok(())
488 }
489
490 fn reset(&mut self) -> anyhow::Result<()> {
491 if let Some(sent_queues) = self.sent_queues.take() {
492 for queue_index in sent_queues.into_keys() {
493 let _vring_base = self
494 .deactivate_vring(queue_index)
495 .context("deactivate_vring failed during reset")?;
496 }
497 }
498
499 if let Some(w) = self.worker_thread.take() {
500 let (backend_req_handler, vm_evt_wrtube) = w.stop();
501 self.backend_req_handler = backend_req_handler;
502 self.vm_evt_wrtube = vm_evt_wrtube;
503 }
504
505 self.sent_set_features = false;
506
507 Ok(())
508 }
509
510 fn pci_address(&self) -> Option<PciAddress> {
511 self.pci_address
512 }
513
514 fn get_shared_memory_region(&self) -> Option<SharedMemoryRegion> {
515 if !self
516 .protocol_features
517 .contains(VhostUserProtocolFeatures::SHARED_MEMORY_REGIONS)
518 {
519 return None;
520 }
521 if let Some(r) = self.shmem_region.borrow().as_ref() {
522 return r.clone();
523 }
524 let regions = match self
525 .backend_client
526 .get_shared_memory_regions()
527 .map_err(Error::ShmemRegions)
528 {
529 Ok(x) => x,
530 Err(e) => {
531 error!("Failed to get shared memory regions {}", e);
532 return None;
533 }
534 };
535 let region = match regions.len() {
536 0 => None,
537 1 => Some(SharedMemoryRegion {
538 id: regions[0].id,
539 length: regions[0].length,
540 }),
541 n => {
542 error!(
543 "Failed to get shared memory regions {}",
544 Error::TooManyShmemRegions(n)
545 );
546 return None;
547 }
548 };
549
550 *self.shmem_region.borrow_mut() = Some(region.clone());
551 region
552 }
553
554 fn set_shared_memory_mapper(&mut self, mapper: Box<dyn SharedMemoryMapper>) {
555 let Some(backend_req_handler) = self.backend_req_handler.as_mut() else {
558 error!(
559 "Error setting shared memory mapper {}",
560 Error::ProtocolFeatureNotNegoiated(VhostUserProtocolFeatures::BACKEND_REQ)
561 );
562 return;
563 };
564
565 let shmid = self
567 .shmem_region
568 .borrow()
569 .clone()
570 .flatten()
571 .expect("missing shmid")
572 .id;
573
574 backend_req_handler
575 .frontend_mut()
576 .set_shared_mapper_state(mapper, shmid);
577 }
578
579 fn expose_shmem_descriptors_with_viommu(&self) -> bool {
580 self.expose_shmem_descriptors_with_viommu
581 }
582
583 fn virtio_sleep(&mut self) -> anyhow::Result<Option<BTreeMap<usize, Queue>>> {
584 let Some(mut queues) = self.sent_queues.take() else {
585 return Ok(None);
586 };
587
588 for (&queue_index, queue) in queues.iter_mut() {
589 let vring_base = self
590 .deactivate_vring(queue_index)
591 .context("deactivate_vring failed during sleep")?;
592 queue.vhost_user_reclaim(vring_base);
593 }
594
595 if let Some(w) = self.worker_thread.take() {
596 let (backend_req_handler, vm_evt_wrtube) = w.stop();
597 self.backend_req_handler = backend_req_handler;
598 self.vm_evt_wrtube = vm_evt_wrtube;
599 }
600
601 Ok(Some(queues))
602 }
603
604 fn virtio_wake(
605 &mut self,
606 queues_state: Option<(GuestMemory, Interrupt, BTreeMap<usize, Queue>)>,
609 ) -> anyhow::Result<()> {
610 if let Some((mem, interrupt, queues)) = queues_state {
611 self.activate(mem, interrupt, queues)?;
612 }
613 Ok(())
614 }
615
616 fn virtio_snapshot(&mut self) -> anyhow::Result<AnySnapshot> {
617 if !self
618 .protocol_features
619 .contains(VhostUserProtocolFeatures::DEVICE_STATE)
620 {
621 bail!("snapshot requires VHOST_USER_PROTOCOL_F_DEVICE_STATE");
622 }
623 let (mut r, w) = new_pipe_pair()?;
626 let backend_r = self
627 .backend_client
628 .set_device_state_fd(
629 VhostUserTransferDirection::Save,
630 VhostUserMigrationPhase::Stopped,
631 &w,
632 )
633 .context("failed to negotiate device state fd")?;
634 std::mem::drop(w);
637 let mut snapshot_bytes = Vec::new();
639 if let Some(mut backend_r) = backend_r {
640 backend_r.read_to_end(&mut snapshot_bytes)
641 } else {
642 r.read_to_end(&mut snapshot_bytes)
643 }
644 .context("failed to read device state")?;
645 self.backend_client
647 .check_device_state()
648 .context("failed to transfer device state")?;
649 Ok(AnySnapshot::to_any(snapshot_bytes).map_err(Error::SliceToSerdeValue)?)
650 }
651
652 fn virtio_restore(&mut self, data: AnySnapshot) -> anyhow::Result<()> {
653 if !self.sent_set_features {
655 self.ack_features(self.acked_features);
656 }
657
658 if !self
659 .protocol_features
660 .contains(VhostUserProtocolFeatures::DEVICE_STATE)
661 {
662 bail!("restore requires VHOST_USER_PROTOCOL_F_DEVICE_STATE");
663 }
664
665 let data_bytes: Vec<u8> = AnySnapshot::from_any(data).map_err(Error::SerdeValueToSlice)?;
666 let (r, w) = new_pipe_pair()?;
669 let backend_w = self
670 .backend_client
671 .set_device_state_fd(
672 VhostUserTransferDirection::Load,
673 VhostUserMigrationPhase::Stopped,
674 &r,
675 )
676 .context("failed to negotiate device state fd")?;
677 {
679 let backend_w = backend_w;
683 let mut w = w;
684 if let Some(mut backend_w) = backend_w {
685 backend_w.write_all(data_bytes.as_slice())
686 } else {
687 w.write_all(data_bytes.as_slice())
688 }
689 .context("failed to write device state")?;
690 }
691 self.backend_client
693 .check_device_state()
694 .context("failed to transfer device state")?;
695 Ok(())
696 }
697}
698
699#[cfg(unix)]
700fn new_pipe_pair() -> anyhow::Result<(impl AsRawDescriptor + Read, impl AsRawDescriptor + Write)> {
701 base::pipe().context("failed to create pipe")
702}
703
/// Creates a byte-framed, blocking named-pipe pair used to move device state
/// between the frontend and the backend. The first end is read from and the
/// second written to.
#[cfg(windows)]
fn new_pipe_pair() -> anyhow::Result<(impl AsRawDescriptor + Read, impl AsRawDescriptor + Write)> {
    let framing = base::named_pipes::FramingMode::Byte;
    let blocking = base::named_pipes::BlockingMode::Wait;
    let (read_end, write_end) = base::named_pipes::pair(&framing, &blocking, 0)
        .context("failed to create named pipes")?;
    Ok((read_end, write_end))
}
712}