pub(super) mod sys;
use std::collections::BTreeMap;
use std::convert::From;
use std::fs::File;
use std::num::Wrapping;
#[cfg(any(target_os = "android", target_os = "linux"))]
use std::os::unix::io::AsRawFd;
use std::sync::Arc;
use anyhow::bail;
use anyhow::Context;
#[cfg(any(target_os = "android", target_os = "linux"))]
use base::clear_fd_flags;
use base::error;
use base::trace;
use base::warn;
use base::Event;
use base::Protection;
use base::SafeDescriptor;
use base::SharedMemory;
use base::WorkerThread;
use cros_async::TaskHandle;
use hypervisor::MemCacheType;
use serde::Deserialize;
use serde::Serialize;
use sync::Mutex;
use thiserror::Error as ThisError;
use vm_control::VmMemorySource;
use vm_memory::GuestAddress;
use vm_memory::GuestMemory;
use vm_memory::MemoryRegion;
use vmm_vhost::message::VhostSharedMemoryRegion;
use vmm_vhost::message::VhostUserConfigFlags;
use vmm_vhost::message::VhostUserExternalMapMsg;
use vmm_vhost::message::VhostUserGpuMapMsg;
use vmm_vhost::message::VhostUserInflight;
use vmm_vhost::message::VhostUserMemoryRegion;
use vmm_vhost::message::VhostUserMigrationPhase;
use vmm_vhost::message::VhostUserProtocolFeatures;
use vmm_vhost::message::VhostUserShmemMapMsg;
use vmm_vhost::message::VhostUserShmemMapMsgFlags;
use vmm_vhost::message::VhostUserShmemUnmapMsg;
use vmm_vhost::message::VhostUserSingleMemoryRegion;
use vmm_vhost::message::VhostUserTransferDirection;
use vmm_vhost::message::VhostUserVringAddrFlags;
use vmm_vhost::message::VhostUserVringState;
use vmm_vhost::BackendReq;
use vmm_vhost::Connection;
use vmm_vhost::Error as VhostError;
use vmm_vhost::Frontend;
use vmm_vhost::FrontendClient;
use vmm_vhost::Result as VhostResult;
use vmm_vhost::VHOST_USER_F_PROTOCOL_FEATURES;
use crate::virtio::Interrupt;
use crate::virtio::Queue;
use crate::virtio::QueueConfig;
use crate::virtio::SharedMemoryMapper;
use crate::virtio::SharedMemoryRegion;
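/// A mapping from a range of VMM virtual addresses to the corresponding guest physical
/// addresses; used to translate vring addresses received from the frontend.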
#[derive(Default)]
pub struct MappingInfo {
pub vmm_addr: u64,
pub guest_phys: u64,
pub size: u64,
}
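/// Translates a VMM virtual address into a guest physical address using the supplied mappings,
/// or returns `InvalidMessage` if the address is not covered by any of them.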
pub fn vmm_va_to_gpa(maps: &[MappingInfo], vmm_va: u64) -> VhostResult<GuestAddress> {
for map in maps {
if vmm_va >= map.vmm_addr && vmm_va < map.vmm_addr + map.size {
return Ok(GuestAddress(vmm_va - map.vmm_addr + map.guest_phys));
}
}
Err(VhostError::InvalidMessage)
}
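/// A vhost-user device backend.
///
/// `DeviceRequestHandler` drives an implementation of this trait by translating vhost-user
/// protocol requests into calls on these methods.
///
/// Rough usage sketch (the backend type below is illustrative; connection setup and the request
/// loop live in the platform-specific `sys` module):
///
/// ```ignore
/// struct MyBackend {
///     // device state
/// }
///
/// impl VhostUserDevice for MyBackend {
///     // implement the required methods
/// }
///
/// let handler = DeviceRequestHandler::new(MyBackend { /* ... */ });
/// // Hand `handler` to the vhost-user connection / request loop in `sys`.
/// ```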
pub trait VhostUserDevice {
fn max_queue_num(&self) -> usize;
fn features(&self) -> u64;
fn ack_features(&mut self, _value: u64) -> anyhow::Result<()> {
Ok(())
}
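/// The set of vhost-user protocol features supported by this backend.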
fn protocol_features(&self) -> VhostUserProtocolFeatures;
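/// Reads `dst.len()` bytes of the device configuration space starting at `offset` into `dst`.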
fn read_config(&self, offset: u64, dst: &mut [u8]);
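/// Writes `data` to the device configuration space at `offset`. The default implementation
/// ignores writes.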
fn write_config(&self, _offset: u64, _data: &[u8]) {}
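/// Starts processing queue `idx`, which has been activated with the given `queue` and guest
/// memory.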
fn start_queue(&mut self, idx: usize, queue: Queue, mem: GuestMemory) -> anyhow::Result<()>;
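/// Stops processing queue `idx` and returns its `Queue` so the handler can report the vring
/// state back to the frontend.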
fn stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue>;
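/// Resets the device to its initial state.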
fn reset(&mut self);
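/// Returns the device's shared memory region, if it has one. The default implementation
/// reports none.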
fn get_shared_memory_region(&self) -> Option<SharedMemoryRegion> {
None
}
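/// Gives the backend a connection for sending requests back to the frontend (e.g. config
/// change notifications). Called when the frontend sets up the backend request channel.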
fn set_backend_req_connection(&mut self, _conn: Arc<VhostBackendReqConnection>) {}
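/// Puts the backend into a suspended state in which no queues are being processed; called at
/// construction time and again once all queues have been stopped.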
fn enter_suspended_state(&mut self) -> anyhow::Result<()>;
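/// Serializes the backend's state into a `serde_json::Value` for device state save.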
fn snapshot(&mut self) -> anyhow::Result<serde_json::Value>;
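/// Restores the backend's state from a value previously produced by `snapshot`.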
fn restore(&mut self, data: serde_json::Value) -> anyhow::Result<()>;
}
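/// Per-vring state tracked by `DeviceRequestHandler`: the queue configuration, its interrupt
/// (doorbell), and whether the ring has been enabled.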
struct Vring {
queue: QueueConfig,
doorbell: Option<Interrupt>,
enabled: bool,
}
impl Vring {
fn new(max_size: u16, features: u64) -> Self {
Self {
queue: QueueConfig::new(max_size, features),
doorbell: None,
enabled: false,
}
}
fn reset(&mut self) {
self.queue.reset();
self.doorbell = None;
self.enabled = false;
}
}
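/// Helpers that convert raw vhost-user messages (memory regions, kick and call file
/// descriptors) into the types used by the rest of the handler.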
pub(super) struct VhostUserRegularOps;
impl VhostUserRegularOps {
pub fn set_mem_table(
contexts: &[VhostUserMemoryRegion],
files: Vec<File>,
) -> VhostResult<(GuestMemory, Vec<MappingInfo>)> {
if files.len() != contexts.len() {
return Err(VhostError::InvalidParam);
}
let mut regions = Vec::with_capacity(files.len());
for (region, file) in contexts.iter().zip(files.into_iter()) {
let region = MemoryRegion::new_from_shm(
region.memory_size,
GuestAddress(region.guest_phys_addr),
region.mmap_offset,
Arc::new(
SharedMemory::from_safe_descriptor(
SafeDescriptor::from(file),
region.memory_size,
)
.unwrap(),
),
)
.map_err(|e| {
error!("failed to create a memory region: {}", e);
VhostError::InvalidOperation
})?;
regions.push(region);
}
let guest_mem = GuestMemory::from_regions(regions).map_err(|e| {
error!("failed to create guest memory: {}", e);
VhostError::InvalidOperation
})?;
let vmm_maps = contexts
.iter()
.map(|region| MappingInfo {
vmm_addr: region.user_addr,
guest_phys: region.guest_phys_addr,
size: region.memory_size,
})
.collect();
Ok((guest_mem, vmm_maps))
}
pub fn set_vring_kick(_index: u8, file: Option<File>) -> VhostResult<Event> {
let file = file.ok_or(VhostError::InvalidParam)?;
#[cfg(any(target_os = "android", target_os = "linux"))]
if let Err(e) = clear_fd_flags(file.as_raw_fd(), libc::O_NONBLOCK) {
error!("failed to remove O_NONBLOCK for kick fd: {}", e);
return Err(VhostError::InvalidParam);
}
Ok(Event::from(SafeDescriptor::from(file)))
}
pub fn set_vring_call(
_index: u8,
file: Option<File>,
signal_config_change_fn: Box<dyn Fn() + Send + Sync>,
) -> VhostResult<Interrupt> {
let file = file.ok_or(VhostError::InvalidParam)?;
Ok(Interrupt::new_vhost_user(
Event::from(SafeDescriptor::from(file)),
signal_config_change_fn,
))
}
}
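/// An adapter that implements `vmm_vhost::Backend` on top of a `VhostUserDevice`, translating
/// vhost-user protocol requests into calls on the backend.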
pub struct DeviceRequestHandler<T: VhostUserDevice> {
vrings: Vec<Vring>,
owned: bool,
vmm_maps: Option<Vec<MappingInfo>>,
mem: Option<GuestMemory>,
acked_features: u64,
acked_protocol_features: VhostUserProtocolFeatures,
backend: T,
backend_req_connection: Arc<Mutex<VhostBackendReqConnectionState>>,
device_state_thread: Option<DeviceStateThread>,
}
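/// Background worker used for a device state transfer: either writing a snapshot to the
/// frontend-provided fd or reading one from it.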
enum DeviceStateThread {
Save(WorkerThread<serde_json::Result<()>>),
Load(WorkerThread<serde_json::Result<DeviceRequestHandlerSnapshot>>),
}
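/// Serialized state of a `DeviceRequestHandler`, including the backend's own snapshot.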
#[derive(Serialize, Deserialize)]
pub struct DeviceRequestHandlerSnapshot {
acked_features: u64,
acked_protocol_features: u64,
backend: serde_json::Value,
}
impl<T: VhostUserDevice> DeviceRequestHandler<T> {
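/// Creates a handler for `backend` with all vrings unconfigured. The backend is immediately
/// placed in the suspended state.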
pub(crate) fn new(mut backend: T) -> Self {
let mut vrings = Vec::with_capacity(backend.max_queue_num());
for _ in 0..backend.max_queue_num() {
vrings.push(Vring::new(Queue::MAX_SIZE, backend.features()));
}
backend
.enter_suspended_state()
.expect("enter_suspended_state failed on device init");
DeviceRequestHandler {
vrings,
owned: false,
vmm_maps: None,
mem: None,
acked_features: 0,
acked_protocol_features: VhostUserProtocolFeatures::empty(),
backend,
backend_req_connection: Arc::new(Mutex::new(
VhostBackendReqConnectionState::NoConnection,
)),
device_state_thread: None,
}
}
fn all_queues_stopped(&self) -> bool {
self.vrings.iter().all(|vring| !vring.queue.ready())
}
}
impl<T: VhostUserDevice> AsRef<T> for DeviceRequestHandler<T> {
fn as_ref(&self) -> &T {
&self.backend
}
}
impl<T: VhostUserDevice> AsMut<T> for DeviceRequestHandler<T> {
fn as_mut(&mut self) -> &mut T {
&mut self.backend
}
}
impl<T: VhostUserDevice> vmm_vhost::Backend for DeviceRequestHandler<T> {
fn set_owner(&mut self) -> VhostResult<()> {
if self.owned {
return Err(VhostError::InvalidOperation);
}
self.owned = true;
Ok(())
}
fn reset_owner(&mut self) -> VhostResult<()> {
self.owned = false;
self.acked_features = 0;
self.backend.reset();
Ok(())
}
fn get_features(&mut self) -> VhostResult<u64> {
let features = self.backend.features();
Ok(features)
}
fn set_features(&mut self, features: u64) -> VhostResult<()> {
if !self.owned {
return Err(VhostError::InvalidOperation);
}
let unexpected_features = features & !self.backend.features();
if unexpected_features != 0 {
error!("unexpected set_features {:#x}", unexpected_features);
return Err(VhostError::InvalidParam);
}
if let Err(e) = self.backend.ack_features(features) {
error!("failed to acknowledge features 0x{:x}: {}", features, e);
return Err(VhostError::InvalidOperation);
}
self.acked_features |= features;
let vring_enabled = self.acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES != 0;
for v in &mut self.vrings {
v.enabled = vring_enabled;
}
Ok(())
}
fn get_protocol_features(&mut self) -> VhostResult<VhostUserProtocolFeatures> {
Ok(self.backend.protocol_features())
}
fn set_protocol_features(&mut self, features: u64) -> VhostResult<()> {
let features = match VhostUserProtocolFeatures::from_bits(features) {
Some(proto_features) => proto_features,
None => {
error!(
"unsupported bits in VHOST_USER_SET_PROTOCOL_FEATURES: {:#x}",
features
);
return Err(VhostError::InvalidOperation);
}
};
let supported = self.backend.protocol_features();
self.acked_protocol_features = features & supported;
Ok(())
}
fn set_mem_table(
&mut self,
contexts: &[VhostUserMemoryRegion],
files: Vec<File>,
) -> VhostResult<()> {
let (guest_mem, vmm_maps) = VhostUserRegularOps::set_mem_table(contexts, files)?;
self.mem = Some(guest_mem);
self.vmm_maps = Some(vmm_maps);
Ok(())
}
fn get_queue_num(&mut self) -> VhostResult<u64> {
Ok(self.vrings.len() as u64)
}
fn set_vring_num(&mut self, index: u32, num: u32) -> VhostResult<()> {
if index as usize >= self.vrings.len() || num == 0 || num > Queue::MAX_SIZE.into() {
return Err(VhostError::InvalidParam);
}
self.vrings[index as usize].queue.set_size(num as u16);
Ok(())
}
fn set_vring_addr(
&mut self,
index: u32,
_flags: VhostUserVringAddrFlags,
descriptor: u64,
used: u64,
available: u64,
_log: u64,
) -> VhostResult<()> {
if index as usize >= self.vrings.len() {
return Err(VhostError::InvalidParam);
}
let vmm_maps = self.vmm_maps.as_ref().ok_or(VhostError::InvalidParam)?;
let vring = &mut self.vrings[index as usize];
vring
.queue
.set_desc_table(vmm_va_to_gpa(vmm_maps, descriptor)?);
vring
.queue
.set_avail_ring(vmm_va_to_gpa(vmm_maps, available)?);
vring.queue.set_used_ring(vmm_va_to_gpa(vmm_maps, used)?);
Ok(())
}
fn set_vring_base(&mut self, index: u32, base: u32) -> VhostResult<()> {
if index as usize >= self.vrings.len() || base >= Queue::MAX_SIZE.into() {
return Err(VhostError::InvalidParam);
}
let vring = &mut self.vrings[index as usize];
vring.queue.set_next_avail(Wrapping(base as u16));
vring.queue.set_next_used(Wrapping(base as u16));
Ok(())
}
fn get_vring_base(&mut self, index: u32) -> VhostResult<VhostUserVringState> {
let vring = self
.vrings
.get_mut(index as usize)
.ok_or(VhostError::InvalidParam)?;
let vring_base = if vring.queue.ready() {
let queue = match self.backend.stop_queue(index as usize) {
Ok(q) => q,
Err(e) => {
error!("Failed to stop queue in get_vring_base: {:#}", e);
return Err(VhostError::BackendInternalError);
}
};
trace!("stopped queue {index}");
vring.reset();
if self.all_queues_stopped() {
trace!("all queues stopped; entering suspended state");
self.backend
.enter_suspended_state()
.map_err(VhostError::EnterSuspendedState)?;
}
queue.next_avail_to_process()
} else {
0
};
Ok(VhostUserVringState::new(index, vring_base.into()))
}
fn set_vring_kick(&mut self, index: u8, file: Option<File>) -> VhostResult<()> {
if index as usize >= self.vrings.len() {
return Err(VhostError::InvalidParam);
}
let vring = &mut self.vrings[index as usize];
if vring.queue.ready() {
error!("kick fd cannot replaced after queue is started");
return Err(VhostError::InvalidOperation);
}
let kick_evt = VhostUserRegularOps::set_vring_kick(index, file)?;
vring.queue.ack_features(self.acked_features);
vring.queue.set_ready(true);
let mem = self
.mem
.as_ref()
.cloned()
.ok_or(VhostError::InvalidOperation)?;
let doorbell = vring.doorbell.clone().ok_or(VhostError::InvalidOperation)?;
let queue = match vring.queue.activate(&mem, kick_evt, doorbell) {
Ok(queue) => queue,
Err(e) => {
error!("failed to activate vring: {:#}", e);
return Err(VhostError::BackendInternalError);
}
};
if let Err(e) = self.backend.start_queue(index as usize, queue, mem) {
error!("Failed to start queue {}: {}", index, e);
return Err(VhostError::BackendInternalError);
}
trace!("started queue {index}");
Ok(())
}
fn set_vring_call(&mut self, index: u8, file: Option<File>) -> VhostResult<()> {
if index as usize >= self.vrings.len() {
return Err(VhostError::InvalidParam);
}
let backend_req_conn = self.backend_req_connection.clone();
let signal_config_change_fn = Box::new(move || match &*backend_req_conn.lock() {
VhostBackendReqConnectionState::Connected(frontend) => {
if let Err(e) = frontend.send_config_changed() {
error!("Failed to notify config change: {:#}", e);
}
}
VhostBackendReqConnectionState::NoConnection => {
error!("No Backend request connection found");
}
});
let doorbell = VhostUserRegularOps::set_vring_call(index, file, signal_config_change_fn)?;
self.vrings[index as usize].doorbell = Some(doorbell);
Ok(())
}
fn set_vring_err(&mut self, _index: u8, _fd: Option<File>) -> VhostResult<()> {
Ok(())
}
fn set_vring_enable(&mut self, index: u32, enable: bool) -> VhostResult<()> {
if index as usize >= self.vrings.len() {
return Err(VhostError::InvalidParam);
}
if self.acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES == 0 {
return Err(VhostError::InvalidOperation);
}
self.vrings[index as usize].enabled = enable;
Ok(())
}
fn get_config(
&mut self,
offset: u32,
size: u32,
_flags: VhostUserConfigFlags,
) -> VhostResult<Vec<u8>> {
let mut data = vec![0; size as usize];
self.backend.read_config(u64::from(offset), &mut data);
Ok(data)
}
fn set_config(
&mut self,
offset: u32,
buf: &[u8],
_flags: VhostUserConfigFlags,
) -> VhostResult<()> {
self.backend.write_config(u64::from(offset), buf);
Ok(())
}
fn set_backend_req_fd(&mut self, ep: Connection<BackendReq>) {
let conn = Arc::new(VhostBackendReqConnection::new(
FrontendClient::new(ep),
self.backend.get_shared_memory_region().map(|r| r.id),
));
{
let mut backend_req_conn = self.backend_req_connection.lock();
if let VhostBackendReqConnectionState::Connected(_) = &*backend_req_conn {
warn!("Backend Request Connection already established. Overwriting");
}
*backend_req_conn = VhostBackendReqConnectionState::Connected(conn.clone());
}
self.backend.set_backend_req_connection(conn);
}
fn get_inflight_fd(
&mut self,
_inflight: &VhostUserInflight,
) -> VhostResult<(VhostUserInflight, File)> {
unimplemented!("get_inflight_fd");
}
fn set_inflight_fd(&mut self, _inflight: &VhostUserInflight, _file: File) -> VhostResult<()> {
unimplemented!("set_inflight_fd");
}
fn get_max_mem_slots(&mut self) -> VhostResult<u64> {
Ok(0)
}
fn add_mem_region(
&mut self,
_region: &VhostUserSingleMemoryRegion,
_fd: File,
) -> VhostResult<()> {
Ok(())
}
fn remove_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion) -> VhostResult<()> {
Ok(())
}
fn set_device_state_fd(
&mut self,
transfer_direction: VhostUserTransferDirection,
migration_phase: VhostUserMigrationPhase,
mut fd: File,
) -> VhostResult<Option<File>> {
if migration_phase != VhostUserMigrationPhase::Stopped {
return Err(VhostError::InvalidOperation);
}
if !self.all_queues_stopped() {
return Err(VhostError::InvalidOperation);
}
if self.device_state_thread.is_some() {
error!("must call check_device_state before starting new state transfer");
return Err(VhostError::InvalidOperation);
}
match transfer_direction {
VhostUserTransferDirection::Save => {
let snapshot = DeviceRequestHandlerSnapshot {
acked_features: self.acked_features,
acked_protocol_features: self.acked_protocol_features.bits(),
backend: self.backend.snapshot().map_err(VhostError::SnapshotError)?,
};
self.device_state_thread = Some(DeviceStateThread::Save(WorkerThread::start(
"device_state_save",
move |_kill_event| serde_json::to_writer(&mut fd, &snapshot),
)));
Ok(None)
}
VhostUserTransferDirection::Load => {
self.device_state_thread = Some(DeviceStateThread::Load(WorkerThread::start(
"device_state_load",
move |_kill_event| serde_json::from_reader(&mut fd),
)));
Ok(None)
}
}
}
fn check_device_state(&mut self) -> VhostResult<()> {
let Some(thread) = self.device_state_thread.take() else {
error!("check_device_state: no active state transfer");
return Err(VhostError::InvalidOperation);
};
match thread {
DeviceStateThread::Save(worker) => {
worker.stop().map_err(|e| {
error!("device state save thread failed: {:#}", e);
VhostError::BackendInternalError
})?;
Ok(())
}
DeviceStateThread::Load(worker) => {
let snapshot = worker.stop().map_err(|e| {
error!("device state load thread failed: {:#}", e);
VhostError::BackendInternalError
})?;
self.acked_features = snapshot.acked_features;
self.acked_protocol_features =
VhostUserProtocolFeatures::from_bits(snapshot.acked_protocol_features)
.with_context(|| {
format!(
"unsupported bits in acked_protocol_features: {:#x}",
snapshot.acked_protocol_features
)
})
.map_err(VhostError::RestoreError)?;
self.backend
.restore(snapshot.backend)
.map_err(VhostError::RestoreError)?;
Ok(())
}
}
}
fn get_shared_memory_regions(&mut self) -> VhostResult<Vec<VhostSharedMemoryRegion>> {
Ok(if let Some(r) = self.backend.get_shared_memory_region() {
vec![VhostSharedMemoryRegion::new(r.id, r.length)]
} else {
Vec::new()
})
}
}
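/// Whether a backend-to-frontend request connection has been established.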
pub enum VhostBackendReqConnectionState {
Connected(Arc<VhostBackendReqConnection>),
NoConnection,
}
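/// A connection for sending requests from the device backend to the vhost-user frontend.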
pub struct VhostBackendReqConnection {
conn: Arc<Mutex<FrontendClient>>,
shmem_info: Mutex<Option<ShmemInfo>>,
}
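/// Bookkeeping for the device's shared memory region: its id and the currently mapped ranges.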
#[derive(Clone)]
struct ShmemInfo {
shmid: u8,
mapped_regions: BTreeMap<u64 /* offset */, u64 /* size */>,
}
impl VhostBackendReqConnection {
pub fn new(conn: FrontendClient, shmid: Option<u8>) -> Self {
let shmem_info = Mutex::new(shmid.map(|shmid| ShmemInfo {
shmid,
mapped_regions: BTreeMap::new(),
}));
Self {
conn: Arc::new(Mutex::new(conn)),
shmem_info,
}
}
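/// Notifies the frontend that the device configuration has changed.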
pub fn send_config_changed(&self) -> anyhow::Result<()> {
self.conn
.lock()
.handle_config_change()
.context("Could not send config change message")?;
Ok(())
}
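/// Takes the `SharedMemoryMapper` for this connection. Fails if the device has no shared
/// memory region or if the mapper was already taken.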
pub fn take_shmem_mapper(&self) -> anyhow::Result<Box<dyn SharedMemoryMapper>> {
let shmem_info = self
.shmem_info
.lock()
.take()
.context("could not take shared memory mapper information")?;
Ok(Box::new(VhostShmemMapper {
conn: self.conn.clone(),
shmem_info,
}))
}
}
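/// A `SharedMemoryMapper` that forwards map and unmap requests to the frontend over the
/// backend request connection.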
struct VhostShmemMapper {
conn: Arc<Mutex<FrontendClient>>,
shmem_info: ShmemInfo,
}
impl SharedMemoryMapper for VhostShmemMapper {
fn add_mapping(
&mut self,
source: VmMemorySource,
offset: u64,
prot: Protection,
_cache: MemCacheType,
) -> anyhow::Result<()> {
let size = match source {
VmMemorySource::Vulkan {
descriptor,
handle_type,
memory_idx,
device_uuid,
driver_uuid,
size,
} => {
let msg = VhostUserGpuMapMsg::new(
self.shmem_info.shmid,
offset,
size,
memory_idx,
handle_type,
device_uuid,
driver_uuid,
);
self.conn
.lock()
.gpu_map(&msg, &descriptor)
.context("failed to map memory")?;
size
}
VmMemorySource::ExternalMapping { ptr, size } => {
let msg = VhostUserExternalMapMsg::new(self.shmem_info.shmid, offset, size, ptr);
self.conn
.lock()
.external_map(&msg)
.context("failed to map memory")?;
size
}
source => {
let (descriptor, fd_offset, size) = match source {
VmMemorySource::Descriptor {
descriptor,
offset,
size,
} => (descriptor, offset, size),
VmMemorySource::SharedMemory(shmem) => {
let size = shmem.size();
let descriptor = SafeDescriptor::from(shmem);
(descriptor, 0, size)
}
_ => bail!("unsupported source"),
};
let flags = VhostUserShmemMapMsgFlags::from(prot);
let msg = VhostUserShmemMapMsg::new(
self.shmem_info.shmid,
offset,
fd_offset,
size,
flags,
);
self.conn
.lock()
.shmem_map(&msg, &descriptor)
.context("failed to map memory")?;
size
}
};
self.shmem_info.mapped_regions.insert(offset, size);
Ok(())
}
fn remove_mapping(&mut self, offset: u64) -> anyhow::Result<()> {
let size = self
.shmem_info
.mapped_regions
.remove(&offset)
.context("unknown offset")?;
let msg = VhostUserShmemUnmapMsg::new(self.shmem_info.shmid, offset, size);
self.conn
.lock()
.shmem_unmap(&msg)
.context("failed to map memory")
.map(|_| ())
}
}
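/// Pairs a queue with the async task handle that is processing it.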
pub(crate) struct WorkerState<T, U> {
pub(crate) queue_task: TaskHandle<U>,
pub(crate) queue: T,
}
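/// Errors for vhost-user device operation.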
#[derive(Debug, ThisError)]
pub enum Error {
#[error("worker not found when stopping queue")]
WorkerNotFound,
}
#[cfg(test)]
mod tests {
use std::sync::mpsc::channel;
use std::sync::Barrier;
use anyhow::bail;
use base::Event;
use vmm_vhost::BackendServer;
use vmm_vhost::FrontendReq;
use zerocopy::AsBytes;
use zerocopy::FromBytes;
use zerocopy::FromZeroes;
use super::*;
use crate::virtio::vhost_user_frontend::VhostUserFrontend;
use crate::virtio::DeviceType;
use crate::virtio::VirtioDevice;
#[derive(Clone, Copy, Debug, PartialEq, Eq, AsBytes, FromZeroes, FromBytes)]
#[repr(C, packed(4))]
struct FakeConfig {
x: u32,
y: u64,
}
const FAKE_CONFIG_DATA: FakeConfig = FakeConfig { x: 1, y: 2 };
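/// A minimal `VhostUserDevice` implementation used to exercise `DeviceRequestHandler` in these
/// tests.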
pub(super) struct FakeBackend {
avail_features: u64,
acked_features: u64,
active_queues: Vec<Option<Queue>>,
allow_backend_req: bool,
backend_conn: Option<Arc<VhostBackendReqConnection>>,
}
#[derive(Deserialize, Serialize)]
struct FakeBackendSnapshot {
data: Vec<u8>,
}
impl FakeBackend {
const MAX_QUEUE_NUM: usize = 16;
pub(super) fn new() -> Self {
let mut active_queues = Vec::new();
active_queues.resize_with(Self::MAX_QUEUE_NUM, Default::default);
Self {
avail_features: 1 << VHOST_USER_F_PROTOCOL_FEATURES,
acked_features: 0,
active_queues,
allow_backend_req: false,
backend_conn: None,
}
}
}
impl VhostUserDevice for FakeBackend {
fn max_queue_num(&self) -> usize {
Self::MAX_QUEUE_NUM
}
fn features(&self) -> u64 {
self.avail_features
}
fn ack_features(&mut self, value: u64) -> anyhow::Result<()> {
let unrequested_features = value & !self.avail_features;
if unrequested_features != 0 {
bail!(
"invalid protocol features are given: 0x{:x}",
unrequested_features
);
}
self.acked_features |= value;
Ok(())
}
fn protocol_features(&self) -> VhostUserProtocolFeatures {
let mut features =
VhostUserProtocolFeatures::CONFIG | VhostUserProtocolFeatures::DEVICE_STATE;
if self.allow_backend_req {
features |= VhostUserProtocolFeatures::BACKEND_REQ;
}
features
}
fn read_config(&self, offset: u64, dst: &mut [u8]) {
dst.copy_from_slice(&FAKE_CONFIG_DATA.as_bytes()[offset as usize..]);
}
fn reset(&mut self) {}
fn start_queue(
&mut self,
idx: usize,
queue: Queue,
_mem: GuestMemory,
) -> anyhow::Result<()> {
self.active_queues[idx] = Some(queue);
Ok(())
}
fn stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue> {
Ok(self.active_queues[idx]
.take()
.ok_or(Error::WorkerNotFound)?)
}
fn set_backend_req_connection(&mut self, conn: Arc<VhostBackendReqConnection>) {
self.backend_conn = Some(conn);
}
fn enter_suspended_state(&mut self) -> anyhow::Result<()> {
Ok(())
}
fn snapshot(&mut self) -> anyhow::Result<serde_json::Value> {
serde_json::to_value(FakeBackendSnapshot {
data: vec![1, 2, 3],
})
.context("failed to serialize snapshot")
}
fn restore(&mut self, data: serde_json::Value) -> anyhow::Result<()> {
let snapshot: FakeBackendSnapshot =
serde_json::from_value(data).context("failed to deserialize snapshot")?;
assert_eq!(snapshot.data, vec![1, 2, 3], "bad snapshot data");
Ok(())
}
}
#[test]
fn test_vhost_user_lifecycle() {
test_vhost_user_lifecycle_parameterized(false);
}
#[test]
#[cfg(not(windows))]
fn test_vhost_user_lifecycle_with_backend_req() {
test_vhost_user_lifecycle_parameterized(true);
}
fn test_vhost_user_lifecycle_parameterized(allow_backend_req: bool) {
const QUEUES_NUM: usize = 2;
let (client_connection, server_connection) =
vmm_vhost::Connection::<FrontendReq>::pair().unwrap();
let vmm_bar = Arc::new(Barrier::new(2));
let dev_bar = vmm_bar.clone();
let (ready_tx, ready_rx) = channel();
let (shutdown_tx, shutdown_rx) = channel();
std::thread::spawn(move || {
// VMM side of the connection.
ready_rx.recv().unwrap(); // Wait until the device side has created its request handler.
let mut vmm_device =
VhostUserFrontend::new(DeviceType::Console, 0, client_connection, None, None)
.unwrap();
println!("read_config");
let mut buf = vec![0; std::mem::size_of::<FakeConfig>()];
vmm_device.read_config(0, &mut buf);
let config = FakeConfig::read_from(buf.as_bytes()).unwrap();
assert_eq!(config, FAKE_CONFIG_DATA);
let activate = |vmm_device: &mut VhostUserFrontend| {
let mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
let interrupt = Interrupt::new_for_test_with_msix();
let mut queues = BTreeMap::new();
for idx in 0..QUEUES_NUM {
let mut queue = QueueConfig::new(0x10, 0);
queue.set_ready(true);
let queue = queue
.activate(&mem, Event::new().unwrap(), interrupt.clone())
.expect("QueueConfig::activate");
queues.insert(idx, queue);
}
println!("activate");
vmm_device
.activate(mem.clone(), interrupt.clone(), queues)
.unwrap();
};
activate(&mut vmm_device);
println!("reset");
let reset_result = vmm_device.reset();
assert!(
reset_result.is_ok(),
"reset failed: {:#}",
reset_result.unwrap_err()
);
activate(&mut vmm_device);
println!("virtio_sleep");
let queues = vmm_device
.virtio_sleep()
.unwrap()
.expect("virtio_sleep unexpectedly returned None");
println!("virtio_snapshot");
let snapshot = vmm_device
.virtio_snapshot()
.expect("virtio_snapshot failed");
println!("virtio_restore");
vmm_device
.virtio_restore(snapshot)
.expect("virtio_restore failed");
println!("virtio_wake");
let mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
let interrupt = Interrupt::new_for_test_with_msix();
vmm_device
.virtio_wake(Some((mem, interrupt, queues)))
.unwrap();
println!("wait for shutdown signal");
shutdown_rx.recv().unwrap();
println!("drop");
drop(vmm_device);
vmm_bar.wait();
});
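// Device side of the connection.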
let mut handler = DeviceRequestHandler::new(FakeBackend::new());
handler.as_mut().allow_backend_req = allow_backend_req;
ready_tx.send(()).unwrap();
let mut req_handler = BackendServer::new(server_connection, handler);
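// Setup requests sent while the VMM thread constructs `VhostUserFrontend`.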
handle_request(&mut req_handler, FrontendReq::SET_OWNER).unwrap();
handle_request(&mut req_handler, FrontendReq::GET_FEATURES).unwrap();
handle_request(&mut req_handler, FrontendReq::SET_FEATURES).unwrap();
handle_request(&mut req_handler, FrontendReq::GET_PROTOCOL_FEATURES).unwrap();
handle_request(&mut req_handler, FrontendReq::SET_PROTOCOL_FEATURES).unwrap();
if allow_backend_req {
handle_request(&mut req_handler, FrontendReq::SET_BACKEND_REQ_FD).unwrap();
}
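// Corresponds to `read_config()` on the VMM thread.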
handle_request(&mut req_handler, FrontendReq::GET_CONFIG).unwrap();
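// Corresponds to the first `activate()` on the VMM thread.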
handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();
for _ in 0..QUEUES_NUM {
handle_request(&mut req_handler, FrontendReq::SET_VRING_NUM).unwrap();
handle_request(&mut req_handler, FrontendReq::SET_VRING_ADDR).unwrap();
handle_request(&mut req_handler, FrontendReq::SET_VRING_BASE).unwrap();
handle_request(&mut req_handler, FrontendReq::SET_VRING_CALL).unwrap();
handle_request(&mut req_handler, FrontendReq::SET_VRING_KICK).unwrap();
handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
}
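// Corresponds to `reset()` on the VMM thread.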
for _ in 0..QUEUES_NUM {
handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
handle_request(&mut req_handler, FrontendReq::GET_VRING_BASE).unwrap();
}
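// Corresponds to the second `activate()` on the VMM thread.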
handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();
for _ in 0..QUEUES_NUM {
handle_request(&mut req_handler, FrontendReq::SET_VRING_NUM).unwrap();
handle_request(&mut req_handler, FrontendReq::SET_VRING_ADDR).unwrap();
handle_request(&mut req_handler, FrontendReq::SET_VRING_BASE).unwrap();
handle_request(&mut req_handler, FrontendReq::SET_VRING_CALL).unwrap();
handle_request(&mut req_handler, FrontendReq::SET_VRING_KICK).unwrap();
handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
}
if allow_backend_req {
req_handler
.as_ref()
.as_ref()
.backend_conn
.as_ref()
.expect("backend_conn missing")
.send_config_changed()
.expect("send_config_changed failed");
}
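// Corresponds to `virtio_sleep()` on the VMM thread.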
for _ in 0..QUEUES_NUM {
handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
handle_request(&mut req_handler, FrontendReq::GET_VRING_BASE).unwrap();
}
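// Corresponds to `virtio_snapshot()` on the VMM thread.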
handle_request(&mut req_handler, FrontendReq::SET_DEVICE_STATE_FD).unwrap();
handle_request(&mut req_handler, FrontendReq::CHECK_DEVICE_STATE).unwrap();
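// Corresponds to `virtio_restore()` on the VMM thread.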
handle_request(&mut req_handler, FrontendReq::SET_DEVICE_STATE_FD).unwrap();
handle_request(&mut req_handler, FrontendReq::CHECK_DEVICE_STATE).unwrap();
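// Corresponds to `virtio_wake()` on the VMM thread.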
handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();
for _ in 0..QUEUES_NUM {
handle_request(&mut req_handler, FrontendReq::SET_VRING_NUM).unwrap();
handle_request(&mut req_handler, FrontendReq::SET_VRING_ADDR).unwrap();
handle_request(&mut req_handler, FrontendReq::SET_VRING_BASE).unwrap();
handle_request(&mut req_handler, FrontendReq::SET_VRING_CALL).unwrap();
handle_request(&mut req_handler, FrontendReq::SET_VRING_KICK).unwrap();
handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
}
if allow_backend_req {
req_handler
.as_ref()
.as_ref()
.backend_conn
.as_ref()
.expect("backend_conn missing")
.send_config_changed()
.expect("send_config_changed failed");
}
shutdown_tx.send(()).unwrap();
dev_bar.wait();
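// The VMM thread has dropped its end of the connection, so the next receive should report
// `ClientExit`.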
match req_handler.recv_header() {
Err(VhostError::ClientExit) => (),
r => panic!("expected Err(ClientExit) but got {:?}", r),
}
}
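/// Receives the next frontend request, asserts that it matches `expected_message_type`, and
/// processes it.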
fn handle_request<S: vmm_vhost::Backend>(
handler: &mut BackendServer<S>,
expected_message_type: FrontendReq,
) -> Result<(), VhostError> {
let (hdr, files) = handler.recv_header()?;
assert_eq!(hdr.get_code(), Ok(expected_message_type));
handler.process_message(hdr, files)
}
}