devices/virtio/vhost_user_frontend/mod.rs

mod error;
mod handler;
mod sys;
mod worker;

use std::cell::RefCell;
use std::collections::BTreeMap;
use std::io::Read;
use std::io::Write;

use anyhow::bail;
use anyhow::Context;
use base::error;
use base::trace;
use base::AsRawDescriptor;
#[cfg(windows)]
use base::CloseNotifier;
use base::Event;
use base::RawDescriptor;
use base::ReadNotifier;
use base::SafeDescriptor;
use base::SendTube;
use base::WorkerThread;
use snapshot::AnySnapshot;
use vm_memory::GuestMemory;
use vmm_vhost::message::VhostUserConfigFlags;
use vmm_vhost::message::VhostUserMigrationPhase;
use vmm_vhost::message::VhostUserProtocolFeatures;
use vmm_vhost::message::VhostUserTransferDirection;
use vmm_vhost::BackendClient;
use vmm_vhost::VhostUserMemoryRegionInfo;
use vmm_vhost::VringConfigData;
use vmm_vhost::VHOST_USER_F_PROTOCOL_FEATURES;

use crate::virtio::device_constants::VIRTIO_DEVICE_TYPE_SPECIFIC_FEATURES_MASK;
use crate::virtio::vhost_user_frontend::error::Error;
use crate::virtio::vhost_user_frontend::error::Result;
use crate::virtio::vhost_user_frontend::handler::BackendReqHandler;
use crate::virtio::vhost_user_frontend::handler::BackendReqHandlerImpl;
use crate::virtio::vhost_user_frontend::sys::create_backend_req_handler;
use crate::virtio::vhost_user_frontend::worker::Worker;
use crate::virtio::DeviceType;
use crate::virtio::Interrupt;
use crate::virtio::Queue;
use crate::virtio::SharedMemoryMapper;
use crate::virtio::SharedMemoryRegion;
use crate::virtio::VirtioDevice;
use crate::PciAddress;

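/// Virtio device implementation that forwards all virtqueue and configuration
/// handling to an external vhost-user backend over `backend_client`.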
pub struct VhostUserFrontend {
    device_type: DeviceType,
    worker_thread: Option<WorkerThread<(Option<BackendReqHandler>, SendTube)>>,

    backend_client: BackendClient,
    avail_features: u64,
    acked_features: u64,
    sent_set_features: bool,
    protocol_features: VhostUserProtocolFeatures,
    backend_req_handler: Option<BackendReqHandler>,
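    // Cached result of `get_shared_memory_region()`: the outer `Option` is
    // `None` until the backend has been queried; the inner `Option` is the
    // backend's answer.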
    shmem_region: RefCell<Option<Option<SharedMemoryRegion>>>,

    queue_sizes: Vec<u16>,
    expose_shmem_descriptors_with_viommu: bool,
    pci_address: Option<PciAddress>,
    vm_evt_wrtube: SendTube,

    sent_queues: Option<BTreeMap<usize, Queue>>,
}

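/// Rounds `val` down to a power of two (`None` for 0, or when the next power
/// of two would overflow `u16`).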
fn power_of_two_le(val: u16) -> Option<u16> {
    if val == 0 {
        None
    } else if val.is_power_of_two() {
        Some(val)
    } else {
        val.checked_next_power_of_two()
            .map(|next_pow_two| next_pow_two / 2)
    }
}

impl VhostUserFrontend {
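    /// Creates a `VhostUserFrontend` from an already established vhost-user
    /// `connection`. `max_queue_size`, if given, caps the advertised queue
    /// sizes (rounded down to a power of two).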
    pub fn new(
        device_type: DeviceType,
        mut base_features: u64,
        connection: vmm_vhost::Connection<vmm_vhost::FrontendReq>,
        vm_evt_wrtube: SendTube,
        max_queue_size: Option<u16>,
        pci_address: Option<PciAddress>,
    ) -> Result<VhostUserFrontend> {
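        // Packed virtqueues are not supported by this frontend yet; strip
        // VIRTIO_F_RING_PACKED from the requested feature set.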
        if base_features & (1 << virtio_sys::virtio_config::VIRTIO_F_RING_PACKED) != 0 {
            base_features &= !(1 << virtio_sys::virtio_config::VIRTIO_F_RING_PACKED);
            base::warn!(
                "VIRTIO_F_RING_PACKED requested, but not yet supported by vhost-user frontend. \
                Automatically disabled."
            );
        }

        #[cfg(windows)]
        let backend_pid = connection.target_pid();

        let mut backend_client = BackendClient::new(connection);

        backend_client.set_owner().map_err(Error::SetOwner)?;

        let allow_features = VIRTIO_DEVICE_TYPE_SPECIFIC_FEATURES_MASK
            | base_features
            | 1 << VHOST_USER_F_PROTOCOL_FEATURES;
        let avail_features =
            allow_features & backend_client.get_features().map_err(Error::GetFeatures)?;
        let mut acked_features = 0;

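        // Negotiate vhost-user protocol features if the backend offers
        // VHOST_USER_F_PROTOCOL_FEATURES, limited to the features this
        // frontend can handle.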
        let allow_protocol_features = VhostUserProtocolFeatures::CONFIG
            | VhostUserProtocolFeatures::MQ
            | VhostUserProtocolFeatures::BACKEND_REQ
            | VhostUserProtocolFeatures::DEVICE_STATE
            | VhostUserProtocolFeatures::SHMEM_MAP;

        let mut protocol_features = VhostUserProtocolFeatures::empty();
        if avail_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES != 0 {
            acked_features |= 1 << VHOST_USER_F_PROTOCOL_FEATURES;

            let avail_protocol_features = backend_client
                .get_protocol_features()
                .map_err(Error::GetProtocolFeatures)?;
            protocol_features = allow_protocol_features & avail_protocol_features;
            backend_client
                .set_protocol_features(protocol_features)
                .map_err(Error::SetProtocolFeatures)?;
        }

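        // If the backend can send requests back to the frontend
        // (VHOST_USER_PROTOCOL_F_BACKEND_REQ), set up a request handler and
        // pass the backend its end of the channel.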
        let backend_req_handler =
            if protocol_features.contains(VhostUserProtocolFeatures::BACKEND_REQ) {
                let (handler, tx_fd) = create_backend_req_handler(
                    BackendReqHandlerImpl::new(),
                    #[cfg(windows)]
                    backend_pid,
                )?;
                backend_client
                    .set_backend_req_fd(&tx_fd)
                    .map_err(Error::SetDeviceRequestChannel)?;
                Some(handler)
            } else {
                None
            };

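        // Determine how many queues to expose: ask the backend when it
        // supports VHOST_USER_PROTOCOL_F_MQ, otherwise fall back to the
        // minimum for this device type.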
        let num_queues = if protocol_features.contains(VhostUserProtocolFeatures::MQ) {
            trace!("backend supports VHOST_USER_PROTOCOL_F_MQ");
            let num_queues = backend_client.get_queue_num().map_err(Error::GetQueueNum)?;
            trace!("VHOST_USER_GET_QUEUE_NUM returned {num_queues}");
            num_queues as usize
        } else {
            trace!("backend does not support VHOST_USER_PROTOCOL_F_MQ");
            device_type.min_queues()
        };

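        // Virtqueue sizes must be powers of two; round the requested size down
        // and default to the maximum supported size when none was given.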
        let max_queue_size = max_queue_size
            .and_then(power_of_two_le)
            .unwrap_or(Queue::MAX_SIZE);

        trace!(
            "vhost-user {device_type} frontend with {num_queues} queues x {max_queue_size} entries\
            {}",
            if let Some(pci_address) = pci_address {
                format!(" pci-address {pci_address}")
            } else {
                "".to_string()
            }
        );

        let queue_sizes = vec![max_queue_size; num_queues];

        Ok(VhostUserFrontend {
            device_type,
            worker_thread: None,
            backend_client,
            avail_features,
            acked_features,
            sent_set_features: false,
            protocol_features,
            backend_req_handler,
            shmem_region: RefCell::new(None),
            queue_sizes,
            expose_shmem_descriptors_with_viommu: device_type == DeviceType::Gpu,
            pci_address,
            vm_evt_wrtube,
            sent_queues: None,
        })
    }

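    /// Sends the current guest memory layout to the backend so it can
    /// translate the guest physical addresses used by the vrings.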
    fn set_mem_table(&mut self, mem: &GuestMemory) -> Result<()> {
        let regions: Vec<_> = mem
            .regions()
            .map(|region| VhostUserMemoryRegionInfo {
                guest_phys_addr: region.guest_addr.0,
                memory_size: region.size as u64,
                userspace_addr: region.host_addr as u64,
                mmap_offset: region.shm_offset,
                mmap_handle: region.shm.as_raw_descriptor(),
            })
            .collect();

        self.backend_client
            .set_mem_table(regions.as_slice())
            .map_err(Error::SetMemTable)?;

        Ok(())
    }

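    /// Configures one vring on the backend: ring size, descriptor table /
    /// avail / used ring addresses, base index, and the call (interrupt) and
    /// kick events.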
    fn activate_vring(
        &mut self,
        mem: &GuestMemory,
        queue_index: usize,
        queue: &Queue,
        irqfd: &Event,
    ) -> Result<()> {
        self.backend_client
            .set_vring_num(queue_index, queue.size())
            .map_err(Error::SetVringNum)?;

        let config_data = VringConfigData {
            queue_size: queue.size(),
            flags: 0u32,
            desc_table_addr: mem
                .get_host_address(queue.desc_table())
                .map_err(Error::GetHostAddress)? as u64,
            used_ring_addr: mem
                .get_host_address(queue.used_ring())
                .map_err(Error::GetHostAddress)? as u64,
            avail_ring_addr: mem
                .get_host_address(queue.avail_ring())
                .map_err(Error::GetHostAddress)? as u64,
            log_addr: None,
        };
        self.backend_client
            .set_vring_addr(queue_index, &config_data)
            .map_err(Error::SetVringAddr)?;

        self.backend_client
            .set_vring_base(queue_index, queue.next_avail_to_process())
            .map_err(Error::SetVringBase)?;

        self.backend_client
            .set_vring_call(queue_index, irqfd)
            .map_err(Error::SetVringCall)?;
        self.backend_client
            .set_vring_kick(queue_index, queue.event())
            .map_err(Error::SetVringKick)?;

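        // When VHOST_USER_F_PROTOCOL_FEATURES has been negotiated, rings start
        // out disabled and must be enabled explicitly; without it they are
        // enabled by default.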
        if self.acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES != 0 {
            self.backend_client
                .set_vring_enable(queue_index, true)
                .map_err(Error::SetVringEnable)?;
        }

        Ok(())
    }

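    /// Disables a vring (when protocol features were negotiated) and returns
    /// its base index, i.e. the next avail index the backend would process.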
    fn deactivate_vring(&self, queue_index: usize) -> Result<u16> {
        if self.acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES != 0 {
            self.backend_client
                .set_vring_enable(queue_index, false)
                .map_err(Error::SetVringEnable)?;
        }

        let vring_base = self
            .backend_client
            .get_vring_base(queue_index)
            .map_err(Error::GetVringBase)?;

        vring_base
            .try_into()
            .map_err(|_| Error::VringBaseTooBig(vring_base))
    }

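    /// Spawns the worker thread that services backend requests and monitors
    /// the backend connection, reporting a device crash event if the worker
    /// exits with an error.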
    fn start_worker(&mut self, interrupt: Interrupt, non_msix_evt: Event) {
        assert!(
            self.worker_thread.is_none(),
            "BUG: attempted to start worker twice"
        );

        let label = self.debug_label();

        let mut backend_req_handler = self.backend_req_handler.take();
        if let Some(handler) = &mut backend_req_handler {
            handler.frontend_mut().set_interrupt(interrupt.clone());
        }

        let backend_client_read_notifier =
            SafeDescriptor::try_from(self.backend_client.get_read_notifier())
                .expect("failed to get backend read notifier");
        #[cfg(windows)]
        let backend_client_close_notifier =
            SafeDescriptor::try_from(self.backend_client.get_close_notifier())
                .expect("failed to get backend close notifier");

        let vm_evt_wrtube = self
            .vm_evt_wrtube
            .try_clone()
            .expect("failed to clone vm_evt_wrtube");

        self.worker_thread = Some(WorkerThread::start(label.clone(), move |kill_evt| {
            let mut worker = Worker {
                kill_evt,
                non_msix_evt,
                backend_req_handler,
                backend_client_read_notifier,
                #[cfg(windows)]
                backend_client_close_notifier,
            };
            if let Err(e) = worker
                .run(interrupt)
                .with_context(|| format!("{label}: vhost_user_frontend worker failed"))
            {
                error!("vhost-user worker thread exited with an error: {:#}", e);

                if let Err(e) = vm_evt_wrtube.send(&base::VmEventType::DeviceCrashed) {
                    error!("failed to send crash event: {}", e);
                }
            }
            (worker.backend_req_handler, vm_evt_wrtube)
        }));
    }
}

impl VirtioDevice for VhostUserFrontend {
    fn debug_label(&self) -> String {
        format!("vu-{}", self.device_type())
    }

    fn keep_rds(&self) -> Vec<RawDescriptor> {
        Vec::new()
    }

    fn device_type(&self) -> DeviceType {
        self.device_type
    }

    fn queue_max_sizes(&self) -> &[u16] {
        &self.queue_sizes
    }

    fn features(&self) -> u64 {
        self.avail_features
    }

    fn ack_features(&mut self, features: u64) {
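        // Keep any already-acked bits (e.g. VHOST_USER_F_PROTOCOL_FEATURES
        // acked in new()) in addition to what the guest requested, and only
        // accept features the backend advertised.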
        let features = (features & self.avail_features) | self.acked_features;
        if let Err(e) = self
            .backend_client
            .set_features(features)
            .map_err(Error::SetFeatures)
        {
            error!("failed to enable features 0x{:x}: {}", features, e);
            return;
        }
        self.acked_features = features;
        self.sent_set_features = true;
    }

    fn read_config(&self, offset: u64, data: &mut [u8]) {
        let Ok(offset) = offset.try_into() else {
            error!("failed to read config: invalid config offset is given: {offset}");
            return;
        };
        let Ok(data_len) = data.len().try_into() else {
            error!(
                "failed to read config: invalid config length is given: {}",
                data.len()
            );
            return;
        };
        let (_, config) = match self.backend_client.get_config(
            offset,
            data_len,
            VhostUserConfigFlags::WRITABLE,
            data,
        ) {
            Ok(x) => x,
            Err(e) => {
                error!("failed to read config: {}", Error::GetConfig(e));
                return;
            }
        };
        data.copy_from_slice(&config);
    }

    fn write_config(&mut self, offset: u64, data: &[u8]) {
        let Ok(offset) = offset.try_into() else {
            error!("failed to write config: invalid config offset is given: {offset}");
            return;
        };
        if let Err(e) = self
            .backend_client
            .set_config(offset, VhostUserConfigFlags::empty(), data)
            .map_err(Error::SetConfig)
        {
            error!("failed to write config: {}", e);
        }
    }

    fn activate(
        &mut self,
        mem: GuestMemory,
        interrupt: Interrupt,
        queues: BTreeMap<usize, Queue>,
    ) -> anyhow::Result<()> {
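        // Make sure a SET_FEATURES message has reached the backend before any
        // vring setup; if the guest never called ack_features(), send the
        // currently acked set now.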
        if !self.sent_set_features {
            self.ack_features(self.acked_features);
        }

        self.set_mem_table(&mem)?;

        let msix_config_opt = interrupt
            .get_msix_config()
            .as_ref()
            .ok_or(Error::MsixConfigUnavailable)?;
        let msix_config = msix_config_opt.lock();

        let non_msix_evt = Event::new().map_err(Error::CreateEvent)?;
        for (&queue_index, queue) in queues.iter() {
            let irqfd = msix_config
                .get_irqfd(queue.vector() as usize)
                .unwrap_or(&non_msix_evt);
            self.activate_vring(&mem, queue_index, queue, irqfd)?;
        }

        self.sent_queues = Some(queues);

        drop(msix_config);

        self.start_worker(interrupt, non_msix_evt);
        Ok(())
    }

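    // Stops all activated vrings and the worker thread; queue state is
    // discarded, so a later activate() starts fresh.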
    fn reset(&mut self) -> anyhow::Result<()> {
        if let Some(sent_queues) = self.sent_queues.take() {
            for queue_index in sent_queues.into_keys() {
                let _vring_base = self
                    .deactivate_vring(queue_index)
                    .context("deactivate_vring failed during reset")?;
            }
        }

        if let Some(w) = self.worker_thread.take() {
            let (backend_req_handler, vm_evt_wrtube) = w.stop();
            self.backend_req_handler = backend_req_handler;
            self.vm_evt_wrtube = vm_evt_wrtube;
        }

        self.sent_set_features = false;

        Ok(())
    }

    fn pci_address(&self) -> Option<PciAddress> {
        self.pci_address
    }

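    // Queries the backend for its shared memory region (at most one is
    // supported) and caches the result in `shmem_region`.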
    fn get_shared_memory_region(&self) -> Option<SharedMemoryRegion> {
        if !self
            .protocol_features
            .contains(VhostUserProtocolFeatures::SHMEM_MAP)
        {
            return None;
        }
        if let Some(r) = self.shmem_region.borrow().as_ref() {
            return r.clone();
        }
        let (config_hdr, sizes) = match self
            .backend_client
            .get_shmem_config()
            .map_err(Error::ShmemRegions)
        {
            Ok(x) => x,
            Err(e) => {
                error!("Failed to get shared memory config {}", e);
                return None;
            }
        };
        let region = match config_hdr.nregions {
            0 => None,
            1 => Some(SharedMemoryRegion {
                id: 0,
                length: sizes[0],
            }),
            n => {
                error!(
                    "Failed to get shared memory region {}",
                    Error::TooManyShmemRegions(n as usize)
                );
                return None;
            }
        };
        *self.shmem_region.borrow_mut() = Some(region.clone());
        region
    }

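    // Passes the mapper and the shared memory region id to the backend request
    // handler, which uses them to service mapping requests from the backend.
    // Requires VHOST_USER_PROTOCOL_F_BACKEND_REQ to have been negotiated.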
    fn set_shared_memory_mapper(&mut self, mapper: Box<dyn SharedMemoryMapper>) {
        let Some(backend_req_handler) = self.backend_req_handler.as_mut() else {
            error!(
                "Error setting shared memory mapper {}",
                Error::ProtocolFeatureNotNegoiated(VhostUserProtocolFeatures::BACKEND_REQ)
            );
            return;
        };

        let shmid = self
            .shmem_region
            .borrow()
            .clone()
            .flatten()
            .expect("missing shmid")
            .id;

        backend_req_handler
            .frontend_mut()
            .set_shared_mapper_state(mapper, shmid);
    }

    fn expose_shmem_descriptors_with_viommu(&self) -> bool {
        self.expose_shmem_descriptors_with_viommu
    }

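    // Stops the vrings and the worker thread, reclaiming each queue's state so
    // it can be restored later by `virtio_wake`.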
    fn virtio_sleep(&mut self) -> anyhow::Result<Option<BTreeMap<usize, Queue>>> {
        let Some(mut queues) = self.sent_queues.take() else {
            return Ok(None);
        };

        for (&queue_index, queue) in queues.iter_mut() {
            let vring_base = self
                .deactivate_vring(queue_index)
                .context("deactivate_vring failed during sleep")?;
            queue.vhost_user_reclaim(vring_base);
        }

        if let Some(w) = self.worker_thread.take() {
            let (backend_req_handler, vm_evt_wrtube) = w.stop();
            self.backend_req_handler = backend_req_handler;
            self.vm_evt_wrtube = vm_evt_wrtube;
        }

        Ok(Some(queues))
    }

    fn virtio_wake(
        &mut self,
        queues_state: Option<(GuestMemory, Interrupt, BTreeMap<usize, Queue>)>,
    ) -> anyhow::Result<()> {
        if let Some((mem, interrupt, queues)) = queues_state {
            self.activate(mem, interrupt, queues)?;
        }
        Ok(())
    }

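    // Saves the backend's internal state via the vhost-user device state
    // transfer mechanism (SET_DEVICE_STATE_FD): the backend writes its state
    // into a pipe that is drained into the snapshot.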
    fn virtio_snapshot(&mut self) -> anyhow::Result<AnySnapshot> {
        if !self
            .protocol_features
            .contains(VhostUserProtocolFeatures::DEVICE_STATE)
        {
            bail!("snapshot requires VHOST_USER_PROTOCOL_F_DEVICE_STATE");
        }
        let (mut r, w) = new_pipe_pair()?;
        let backend_r = self
            .backend_client
            .set_device_state_fd(
                VhostUserTransferDirection::Save,
                VhostUserMigrationPhase::Stopped,
                &w,
            )
            .context("failed to negotiate device state fd")?;
        std::mem::drop(w);
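        // With our write end closed above, the read below returns once the
        // backend finishes writing its state and closes its copy.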
        let mut snapshot_bytes = Vec::new();
        if let Some(mut backend_r) = backend_r {
            backend_r.read_to_end(&mut snapshot_bytes)
        } else {
            r.read_to_end(&mut snapshot_bytes)
        }
        .context("failed to read device state")?;
        self.backend_client
            .check_device_state()
            .context("failed to transfer device state")?;
        Ok(AnySnapshot::to_any(snapshot_bytes).map_err(Error::SliceToSerdeValue)?)
    }

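    // Restores backend state produced by `virtio_snapshot`, streaming the
    // bytes back to the backend over a pipe set up with SET_DEVICE_STATE_FD.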
    fn virtio_restore(&mut self, data: AnySnapshot) -> anyhow::Result<()> {
        if !self.sent_set_features {
            self.ack_features(self.acked_features);
        }

        if !self
            .protocol_features
            .contains(VhostUserProtocolFeatures::DEVICE_STATE)
        {
            bail!("restore requires VHOST_USER_PROTOCOL_F_DEVICE_STATE");
        }

        let data_bytes: Vec<u8> = AnySnapshot::from_any(data).map_err(Error::SerdeValueToSlice)?;
        let (r, w) = new_pipe_pair()?;
        let backend_w = self
            .backend_client
            .set_device_state_fd(
                VhostUserTransferDirection::Load,
                VhostUserMigrationPhase::Stopped,
                &r,
            )
            .context("failed to negotiate device state fd")?;
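        // Scope the write ends so they are dropped (closing the pipe) before
        // asking the backend whether the transfer succeeded.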
        {
            let backend_w = backend_w;
            let mut w = w;
            if let Some(mut backend_w) = backend_w {
                backend_w.write_all(data_bytes.as_slice())
            } else {
                w.write_all(data_bytes.as_slice())
            }
            .context("failed to write device state")?;
        }
        self.backend_client
            .check_device_state()
            .context("failed to transfer device state")?;
        Ok(())
    }
}

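// Creates a unidirectional byte pipe used for the device state transfer: an
// anonymous pipe on unix, a byte-mode named pipe pair on Windows.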
#[cfg(unix)]
fn new_pipe_pair() -> anyhow::Result<(impl AsRawDescriptor + Read, impl AsRawDescriptor + Write)> {
    base::pipe().context("failed to create pipe")
}

#[cfg(windows)]
fn new_pipe_pair() -> anyhow::Result<(impl AsRawDescriptor + Read, impl AsRawDescriptor + Write)> {
    base::named_pipes::pair(
        &base::named_pipes::FramingMode::Byte,
        &base::named_pipes::BlockingMode::Wait,
        0,
    )
    .context("failed to create named pipes")
}