pub(super) mod sys;
use std::collections::BTreeMap;
use std::convert::From;
use std::fs::File;
use std::num::Wrapping;
#[cfg(any(target_os = "android", target_os = "linux"))]
use std::os::unix::io::AsRawFd;
use std::sync::Arc;
use anyhow::bail;
use anyhow::Context;
#[cfg(any(target_os = "android", target_os = "linux"))]
use base::clear_fd_flags;
use base::error;
use base::warn;
use base::Event;
use base::FromRawDescriptor;
use base::IntoRawDescriptor;
use base::Protection;
use base::SafeDescriptor;
use base::SharedMemory;
use cros_async::TaskHandle;
use hypervisor::MemCacheType;
use serde::Deserialize;
use serde::Serialize;
use sync::Mutex;
use thiserror::Error as ThisError;
use vm_control::VmMemorySource;
use vm_memory::GuestAddress;
use vm_memory::GuestMemory;
use vm_memory::MemoryRegion;
use vmm_vhost::message::VhostSharedMemoryRegion;
use vmm_vhost::message::VhostUserConfigFlags;
use vmm_vhost::message::VhostUserExternalMapMsg;
use vmm_vhost::message::VhostUserGpuMapMsg;
use vmm_vhost::message::VhostUserInflight;
use vmm_vhost::message::VhostUserMemoryRegion;
use vmm_vhost::message::VhostUserProtocolFeatures;
use vmm_vhost::message::VhostUserShmemMapMsg;
use vmm_vhost::message::VhostUserShmemMapMsgFlags;
use vmm_vhost::message::VhostUserShmemUnmapMsg;
use vmm_vhost::message::VhostUserSingleMemoryRegion;
use vmm_vhost::message::VhostUserVringAddrFlags;
use vmm_vhost::message::VhostUserVringState;
use vmm_vhost::BackendReq;
use vmm_vhost::Connection;
use vmm_vhost::Error as VhostError;
use vmm_vhost::Frontend;
use vmm_vhost::FrontendClient;
use vmm_vhost::Result as VhostResult;
use vmm_vhost::VHOST_USER_F_PROTOCOL_FEATURES;
use crate::virtio::Interrupt;
use crate::virtio::Queue;
use crate::virtio::QueueConfig;
use crate::virtio::SharedMemoryMapper;
use crate::virtio::SharedMemoryRegion;
/// One frontend (VMM) memory region and the guest physical range it backs.
///
/// Used by `vmm_va_to_gpa` to translate frontend virtual addresses.
#[derive(Default)]
pub struct MappingInfo {
/// Start of the region in the frontend's virtual address space.
pub vmm_addr: u64,
/// Guest physical address corresponding to `vmm_addr`.
pub guest_phys: u64,
/// Length of the region.
pub size: u64,
}
pub fn vmm_va_to_gpa(maps: &[MappingInfo], vmm_va: u64) -> VhostResult<GuestAddress> {
for map in maps {
if vmm_va >= map.vmm_addr && vmm_va < map.vmm_addr + map.size {
return Ok(GuestAddress(vmm_va - map.vmm_addr + map.guest_phys));
}
}
Err(VhostError::InvalidMessage)
}
/// Interface a device backend implements so that `DeviceRequestHandler` can drive it.
pub trait VhostUserDevice {
/// Number of queues the device has; `DeviceRequestHandler::new` creates one vring per queue.
fn max_queue_num(&self) -> usize;
/// Feature bits offered by the device.
fn features(&self) -> u64;
/// Records the feature bits selected by the frontend.
fn ack_features(&mut self, value: u64) -> anyhow::Result<()>;
/// Feature bits acknowledged so far.
fn acked_features(&self) -> u64;
/// Vhost-user protocol features offered by the device.
fn protocol_features(&self) -> VhostUserProtocolFeatures;
/// Records the protocol feature bits selected by the frontend.
fn ack_protocol_features(&mut self, _value: u64) -> anyhow::Result<()>;
/// Protocol feature bits acknowledged so far.
fn acked_protocol_features(&self) -> u64;
/// Fills `dst` with device configuration starting at `offset`.
fn read_config(&self, offset: u64, dst: &mut [u8]);
/// Writes `_data` into the device configuration at `_offset`. The default does nothing.
fn write_config(&self, _offset: u64, _data: &[u8]) {}
/// Starts processing queue `idx` using the given activated queue, guest memory and doorbell.
fn start_queue(
&mut self,
idx: usize,
queue: Queue,
mem: GuestMemory,
doorbell: Interrupt,
) -> anyhow::Result<()>;
/// Stops queue `idx` and hands its `Queue` back to the caller.
fn stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue>;
/// Resets the device.
fn reset(&mut self);
/// Shared memory region exposed by the device, if any. The default is `None`.
fn get_shared_memory_region(&self) -> Option<SharedMemoryRegion> {
None
}
/// Receives the backend-to-frontend request connection.
///
/// The default only logs an error; the connection passed in is dropped.
fn set_backend_req_connection(&mut self, _conn: Arc<VhostBackendReqConnection>) {
error!("set_backend_req_connection is not implemented");
}
/// Called by `DeviceRequestHandler::sleep` after all ready queues have been stopped.
///
/// The default logs an error and still reports success.
fn stop_non_queue_workers(&mut self) -> anyhow::Result<()> {
error!("sleep not implemented for vhost user device");
Ok(())
}
/// Serializes device-specific state.
///
/// The default logs an error and returns an empty buffer.
fn snapshot(&self) -> anyhow::Result<Vec<u8>> {
error!("snapshot not implemented for vhost user device");
Ok(Vec::new())
}
/// Restores device-specific state produced by `snapshot`.
///
/// The default logs an error, ignores `_data` and reports success.
fn restore(&mut self, _data: Vec<u8>) -> anyhow::Result<()> {
error!("restore not implemented for vhost user device");
Ok(())
}
}
/// Per-queue state tracked by `DeviceRequestHandler`.
struct Vring {
/// Queue configuration accumulated from the frontend's vring messages.
queue: QueueConfig,
/// Interrupt installed by `set_vring_call`, if any.
doorbell: Option<Interrupt>,
/// Enable flag, written by `set_features` and `set_vring_enable`.
enabled: bool,
/// Queue handed back by the backend in `sleep`, held until `wake`.
paused_queue: Option<Queue>,
}
/// Serializable form of a `Vring`; the doorbell is not part of the snapshot.
#[derive(Serialize, Deserialize)]
struct VringSnapshot {
/// Snapshot of the `QueueConfig`.
queue: serde_json::Value,
/// Snapshot of the paused `Queue`, present only if the vring was paused.
paused_queue: Option<serde_json::Value>,
/// Value of `Vring::enabled` at snapshot time.
enabled: bool,
}
impl Vring {
    /// Creates a vring with a fresh queue configuration and nothing attached.
    fn new(max_size: u16, features: u64) -> Self {
        let queue = QueueConfig::new(max_size, features);
        Self {
            queue,
            doorbell: None,
            enabled: false,
            paused_queue: None,
        }
    }

    /// Returns the vring to its initial state, dropping the doorbell and any paused queue.
    fn reset(&mut self) {
        self.queue.reset();
        self.doorbell = None;
        self.enabled = false;
        self.paused_queue = None;
    }

    /// Captures the queue configuration, the enable flag and the paused queue (if any).
    fn snapshot(&self) -> anyhow::Result<VringSnapshot> {
        let queue = self.queue.snapshot()?;
        let enabled = self.enabled;
        let paused_queue = match self.paused_queue.as_ref() {
            Some(paused) => Some(paused.snapshot()?),
            None => None,
        };
        Ok(VringSnapshot {
            queue,
            paused_queue,
            enabled,
        })
    }

    /// Restores state from `vring_snapshot`.
    ///
    /// `event` is only needed when the snapshot contains a paused queue; in that case a
    /// missing event is an error.
    fn restore(
        &mut self,
        vring_snapshot: VringSnapshot,
        mem: &GuestMemory,
        event: Option<Event>,
    ) -> anyhow::Result<()> {
        self.queue.restore(vring_snapshot.queue)?;
        self.enabled = vring_snapshot.enabled;
        let restored = match vring_snapshot.paused_queue {
            Some(value) => {
                let event = event.context("missing queue event")?;
                Some(Queue::restore(&self.queue, value, mem, event)?)
            }
            None => None,
        };
        self.paused_queue = restored;
        Ok(())
    }
}
/// Namespace for the helpers that turn raw vhost-user message payloads (memory tables, kick
/// and call files) into the objects the request handler works with.
pub(super) struct VhostUserRegularOps;
impl VhostUserRegularOps {
pub fn set_mem_table(
contexts: &[VhostUserMemoryRegion],
files: Vec<File>,
) -> VhostResult<(GuestMemory, Vec<MappingInfo>)> {
if files.len() != contexts.len() {
return Err(VhostError::InvalidParam);
}
let mut regions = Vec::with_capacity(files.len());
for (region, file) in contexts.iter().zip(files.into_iter()) {
let region = MemoryRegion::new_from_shm(
region.memory_size,
GuestAddress(region.guest_phys_addr),
region.mmap_offset,
Arc::new(
SharedMemory::from_safe_descriptor(
SafeDescriptor::from(file),
region.memory_size,
)
.unwrap(),
),
)
.map_err(|e| {
error!("failed to create a memory region: {}", e);
VhostError::InvalidOperation
})?;
regions.push(region);
}
let guest_mem = GuestMemory::from_regions(regions).map_err(|e| {
error!("failed to create guest memory: {}", e);
VhostError::InvalidOperation
})?;
let vmm_maps = contexts
.iter()
.map(|region| MappingInfo {
vmm_addr: region.user_addr,
guest_phys: region.guest_phys_addr,
size: region.memory_size,
})
.collect();
Ok((guest_mem, vmm_maps))
}
pub fn set_vring_kick(_index: u8, file: Option<File>) -> VhostResult<Event> {
let file = file.ok_or(VhostError::InvalidParam)?;
#[cfg(any(target_os = "android", target_os = "linux"))]
if let Err(e) = clear_fd_flags(file.as_raw_fd(), libc::O_NONBLOCK) {
error!("failed to remove O_NONBLOCK for kick fd: {}", e);
return Err(VhostError::InvalidParam);
}
Ok(Event::from(SafeDescriptor::from(file)))
}
pub fn set_vring_call(
_index: u8,
file: Option<File>,
signal_config_change_fn: Box<dyn Fn() + Send + Sync>,
) -> VhostResult<Interrupt> {
let file = file.ok_or(VhostError::InvalidParam)?;
Ok(Interrupt::new_vhost_user(
Event::from(SafeDescriptor::from(file)),
signal_config_change_fn,
))
}
}
/// Vhost-user request handler that translates frontend requests into calls on a
/// `VhostUserDevice` backend.
pub struct DeviceRequestHandler<T: VhostUserDevice> {
/// One entry per backend queue.
vrings: Vec<Vring>,
/// Set by `set_owner`, cleared by `reset_owner`.
owned: bool,
/// Address translation table from the most recent memory table, if any.
vmm_maps: Option<Vec<MappingInfo>>,
/// Guest memory from the most recent memory table, if any.
mem: Option<GuestMemory>,
/// The wrapped device.
backend: T,
/// Backend-to-frontend connection state, shared with the config-change callbacks given to
/// doorbells.
backend_req_connection: Arc<Mutex<VhostBackendReqConnectionState>>,
}
/// Serialized state of a `DeviceRequestHandler`.
#[derive(Serialize, Deserialize)]
pub struct DeviceRequestHandlerSnapshot {
/// One snapshot per vring, in vring order.
vrings: Vec<VringSnapshot>,
/// Opaque bytes returned by the backend's `snapshot`.
backend: Vec<u8>,
}
impl<T: VhostUserDevice> DeviceRequestHandler<T> {
pub(crate) fn new(backend: T) -> Self {
let mut vrings = Vec::with_capacity(backend.max_queue_num());
for _ in 0..backend.max_queue_num() {
vrings.push(Vring::new(Queue::MAX_SIZE, backend.features()));
}
DeviceRequestHandler {
vrings,
owned: false,
vmm_maps: None,
mem: None,
backend,
backend_req_connection: Arc::new(Mutex::new(
VhostBackendReqConnectionState::NoConnection,
)),
}
}
}
/// Gives shared access to the wrapped backend.
impl<T: VhostUserDevice> AsRef<T> for DeviceRequestHandler<T> {
fn as_ref(&self) -> &T {
&self.backend
}
}
/// Gives exclusive access to the wrapped backend.
impl<T: VhostUserDevice> AsMut<T> for DeviceRequestHandler<T> {
fn as_mut(&mut self) -> &mut T {
&mut self.backend
}
}
impl<T: VhostUserDevice> vmm_vhost::Backend for DeviceRequestHandler<T> {
/// Marks the handler as owned.
///
/// Fails with `VhostError::InvalidOperation` if it is already owned.
fn set_owner(&mut self) -> VhostResult<()> {
    if !self.owned {
        self.owned = true;
        return Ok(());
    }
    Err(VhostError::InvalidOperation)
}
/// Clears the ownership flag and resets the backend.
///
/// The vrings, guest memory and translation table are left as they are.
fn reset_owner(&mut self) -> VhostResult<()> {
self.owned = false;
self.backend.reset();
Ok(())
}
/// Reports the feature bits offered by the backend.
fn get_features(&mut self) -> VhostResult<u64> {
    Ok(self.backend.features())
}
/// Acknowledges `features` on the backend and initializes every vring's enable flag.
///
/// # Errors
///
/// * `VhostError::InvalidOperation` if the handler is not owned or the backend rejects the
///   features.
/// * `VhostError::InvalidParam` if `features` contains bits the backend does not offer.
fn set_features(&mut self, features: u64) -> VhostResult<()> {
    if !self.owned {
        return Err(VhostError::InvalidOperation);
    }
    let unsupported = features & !self.backend.features();
    if unsupported != 0 {
        return Err(VhostError::InvalidParam);
    }
    self.backend.ack_features(features).map_err(|e| {
        error!("failed to acknowledge features 0x{:x}: {}", features, e);
        VhostError::InvalidOperation
    })?;
    // Every vring's flag is set to whether VHOST_USER_F_PROTOCOL_FEATURES was acknowledged.
    let protocol_features_bit: u64 = 1 << VHOST_USER_F_PROTOCOL_FEATURES;
    let vring_enabled = (self.backend.acked_features() & protocol_features_bit) != 0;
    self.vrings
        .iter_mut()
        .for_each(|vring| vring.enabled = vring_enabled);
    Ok(())
}
/// Reports the protocol features offered by the backend.
fn get_protocol_features(&mut self) -> VhostResult<VhostUserProtocolFeatures> {
Ok(self.backend.protocol_features())
}
/// Forwards the selected protocol features to the backend.
///
/// A rejection by the backend is logged and reported as `VhostError::InvalidOperation`.
fn set_protocol_features(&mut self, features: u64) -> VhostResult<()> {
    self.backend.ack_protocol_features(features).map_err(|e| {
        error!("failed to set protocol features 0x{:x}: {}", features, e);
        VhostError::InvalidOperation
    })
}
/// Replaces the guest memory and the address translation table with ones built from the
/// given regions and files.
///
/// On error the previously stored memory and table are kept.
fn set_mem_table(
    &mut self,
    contexts: &[VhostUserMemoryRegion],
    files: Vec<File>,
) -> VhostResult<()> {
    VhostUserRegularOps::set_mem_table(contexts, files).map(|(memory, translations)| {
        self.mem = Some(memory);
        self.vmm_maps = Some(translations);
    })
}
/// Reports the number of vrings, which `new` sizes from the backend's `max_queue_num`.
fn get_queue_num(&mut self) -> VhostResult<u64> {
Ok(self.vrings.len() as u64)
}
/// Sets the size of vring `index` to `num` entries.
///
/// Fails with `VhostError::InvalidParam` if `index` is out of range, or if `num` is zero or
/// larger than `Queue::MAX_SIZE`.
fn set_vring_num(&mut self, index: u32, num: u32) -> VhostResult<()> {
    let max_size = u32::from(Queue::MAX_SIZE);
    if num == 0 || num > max_size {
        return Err(VhostError::InvalidParam);
    }
    let vring = self
        .vrings
        .get_mut(index as usize)
        .ok_or(VhostError::InvalidParam)?;
    // `num` is at most `Queue::MAX_SIZE` here, so the narrowing cast cannot truncate.
    vring.queue.set_size(num as u16);
    Ok(())
}
/// Sets the descriptor table, available ring and used ring addresses of vring `index`.
///
/// The addresses arrive as frontend virtual addresses and are translated to guest physical
/// addresses with the current translation table. Each address is applied as soon as it has
/// been translated, so a failure part-way leaves the earlier addresses updated.
///
/// Fails with `VhostError::InvalidParam` if `index` is out of range or no memory table has
/// been set, and with the translation error if an address is not covered by any region.
fn set_vring_addr(
    &mut self,
    index: u32,
    _flags: VhostUserVringAddrFlags,
    descriptor: u64,
    used: u64,
    available: u64,
    _log: u64,
) -> VhostResult<()> {
    let slot = index as usize;
    if slot >= self.vrings.len() {
        return Err(VhostError::InvalidParam);
    }
    let maps = self.vmm_maps.as_ref().ok_or(VhostError::InvalidParam)?;
    let queue = &mut self.vrings[slot].queue;

    let desc_gpa = vmm_va_to_gpa(maps, descriptor)?;
    queue.set_desc_table(desc_gpa);

    let avail_gpa = vmm_va_to_gpa(maps, available)?;
    queue.set_avail_ring(avail_gpa);

    let used_gpa = vmm_va_to_gpa(maps, used)?;
    queue.set_used_ring(used_gpa);

    Ok(())
}
/// Sets both the next-available and next-used indices of vring `index` to `base`.
///
/// Fails with `VhostError::InvalidParam` if `index` is out of range or `base` is not below
/// `Queue::MAX_SIZE`.
fn set_vring_base(&mut self, index: u32, base: u32) -> VhostResult<()> {
    let slot = index as usize;
    let limit = u32::from(Queue::MAX_SIZE);
    if slot >= self.vrings.len() || base >= limit {
        return Err(VhostError::InvalidParam);
    }
    // `base` is below `Queue::MAX_SIZE`, so it fits in a `u16`.
    let start = Wrapping(base as u16);
    let queue = &mut self.vrings[slot].queue;
    queue.set_next_avail(start);
    queue.set_next_used(start);
    Ok(())
}
/// Stops vring `index` if it is running and reports its next-available index.
///
/// Fails with `VhostError::InvalidParam` if `index` is out of range.
fn get_vring_base(&mut self, index: u32) -> VhostResult<VhostUserVringState> {
let vring = self
.vrings
.get_mut(index as usize)
.ok_or(VhostError::InvalidParam)?;
// A ready queue is treated as started: ask the backend to stop it, then reset the vring.
if vring.queue.ready() {
// A stop failure is only logged; the `Queue` returned on success is dropped.
if let Err(e) = self.backend.stop_queue(index as usize) {
error!("Failed to stop queue in get_vring_base: {:#}", e);
}
vring.reset();
}
// NOTE(review): for a queue that was just stopped, this value is read after
// `vring.reset()`, and nothing from the stopped `Queue` is consulted. Confirm that
// `QueueConfig::reset` leaves `next_avail` at the value the frontend should see.
Ok(VhostUserVringState::new(
index,
vring.queue.next_avail().0 as u32,
))
}
/// Installs the kick event for vring `index` and starts the queue on the backend.
///
/// All preconditions (valid index, queue not yet started, kick file present, guest memory
/// set, doorbell set) are checked before the vring is modified, so a request that fails one
/// of them leaves the vring untouched and can be retried once the missing piece has been
/// supplied.
///
/// # Errors
///
/// * `VhostError::InvalidParam` if `index` is out of range or the kick file is missing or
///   unusable.
/// * `VhostError::InvalidOperation` if the queue is already started, or guest memory or the
///   doorbell has not been set yet.
/// * `VhostError::BackendInternalError` if the queue cannot be activated or the backend
///   fails to start it.
fn set_vring_kick(&mut self, index: u8, file: Option<File>) -> VhostResult<()> {
    if index as usize >= self.vrings.len() {
        return Err(VhostError::InvalidParam);
    }
    let vring = &mut self.vrings[index as usize];
    if vring.queue.ready() {
        error!("kick fd cannot replaced after queue is started");
        return Err(VhostError::InvalidOperation);
    }
    let kick_evt = VhostUserRegularOps::set_vring_kick(index, file)?;

    // Gather everything the backend needs before touching the queue state. Marking the
    // queue ready first would leave it "ready" but never started if one of these is missing.
    let mem = self.mem.clone().ok_or(VhostError::InvalidOperation)?;
    let doorbell = vring.doorbell.clone().ok_or(VhostError::InvalidOperation)?;

    vring.queue.ack_features(self.backend.acked_features());
    vring.queue.set_ready(true);
    let queue = match vring.queue.activate(&mem, kick_evt) {
        Ok(queue) => queue,
        Err(e) => {
            error!("failed to activate vring: {:#}", e);
            return Err(VhostError::BackendInternalError);
        }
    };
    if let Err(e) = self
        .backend
        .start_queue(index as usize, queue, mem, doorbell)
    {
        error!("Failed to start queue {}: {}", index, e);
        return Err(VhostError::BackendInternalError);
    }
    Ok(())
}
fn set_vring_call(&mut self, index: u8, file: Option<File>) -> VhostResult<()> {
if index as usize >= self.vrings.len() {
return Err(VhostError::InvalidParam);
}
let backend_req_conn = self.backend_req_connection.clone();
let signal_config_change_fn = Box::new(move || match &*backend_req_conn.lock() {
VhostBackendReqConnectionState::Connected(frontend) => {
if let Err(e) = frontend.send_config_changed() {
error!("Failed to notify config change: {:#}", e);
}
}
VhostBackendReqConnectionState::NoConnection => {
error!("No Backend request connection found");
}
});
let doorbell = VhostUserRegularOps::set_vring_call(index, file, signal_config_change_fn)?;
self.vrings[index as usize].doorbell = Some(doorbell);
Ok(())
}
/// Accepts the request and does nothing; `_fd`, if any, is dropped when the call returns.
fn set_vring_err(&mut self, _index: u8, _fd: Option<File>) -> VhostResult<()> {
Ok(())
}
/// Sets the enable flag of vring `index`.
///
/// # Errors
///
/// * `VhostError::InvalidParam` if `index` is out of range.
/// * `VhostError::InvalidOperation` if VHOST_USER_F_PROTOCOL_FEATURES has not been
///   acknowledged.
fn set_vring_enable(&mut self, index: u32, enable: bool) -> VhostResult<()> {
    let vring = self
        .vrings
        .get_mut(index as usize)
        .ok_or(VhostError::InvalidParam)?;
    let protocol_features_bit: u64 = 1 << VHOST_USER_F_PROTOCOL_FEATURES;
    if self.backend.acked_features() & protocol_features_bit == 0 {
        return Err(VhostError::InvalidOperation);
    }
    vring.enabled = enable;
    Ok(())
}
/// Reads `size` bytes of device configuration starting at `offset`.
///
/// The buffer is zero-filled before the backend is asked to fill it.
fn get_config(
    &mut self,
    offset: u32,
    size: u32,
    _flags: VhostUserConfigFlags,
) -> VhostResult<Vec<u8>> {
    let mut config = vec![0u8; size as usize];
    self.backend
        .read_config(u64::from(offset), config.as_mut_slice());
    Ok(config)
}
/// Writes `buf` into the device configuration at `offset`.
///
/// Always succeeds: `VhostUserDevice::write_config` has no way to report failure.
fn set_config(
&mut self,
offset: u32,
buf: &[u8],
_flags: VhostUserConfigFlags,
) -> VhostResult<()> {
self.backend.write_config(u64::from(offset), buf);
Ok(())
}
/// Establishes the backend-to-frontend request connection over `ep`.
///
/// The connection is stored in the handler (replacing any previous one, with a warning) and
/// is also handed to the backend.
fn set_backend_req_fd(&mut self, ep: Connection<BackendReq>) {
    let client = FrontendClient::new(ep);
    let shmid = self
        .backend
        .get_shared_memory_region()
        .map(|region| region.id);
    let conn = Arc::new(VhostBackendReqConnection::new(client, shmid));

    // Release the lock before calling into the backend.
    let mut state = self.backend_req_connection.lock();
    if matches!(*state, VhostBackendReqConnectionState::Connected(_)) {
        warn!("Backend Request Connection already established. Overwriting");
    }
    *state = VhostBackendReqConnectionState::Connected(Arc::clone(&conn));
    drop(state);

    self.backend.set_backend_req_connection(conn);
}
/// Not supported.
///
/// # Panics
///
/// Always panics via `unimplemented!`.
fn get_inflight_fd(
&mut self,
_inflight: &VhostUserInflight,
) -> VhostResult<(VhostUserInflight, File)> {
unimplemented!("get_inflight_fd");
}
/// Not supported.
///
/// # Panics
///
/// Always panics via `unimplemented!`.
fn set_inflight_fd(&mut self, _inflight: &VhostUserInflight, _file: File) -> VhostResult<()> {
unimplemented!("set_inflight_fd");
}
/// Always reports zero memory slots.
fn get_max_mem_slots(&mut self) -> VhostResult<u64> {
Ok(0)
}
/// Accepts the request without doing anything: the region is ignored and `_fd` is dropped.
fn add_mem_region(
&mut self,
_region: &VhostUserSingleMemoryRegion,
_fd: File,
) -> VhostResult<()> {
Ok(())
}
/// Accepts the request without doing anything.
fn remove_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion) -> VhostResult<()> {
Ok(())
}
/// Lists the backend's shared memory regions: a single entry if the backend exposes one,
/// otherwise an empty list.
fn get_shared_memory_regions(&mut self) -> VhostResult<Vec<VhostSharedMemoryRegion>> {
    let regions = match self.backend.get_shared_memory_region() {
        Some(region) => vec![VhostSharedMemoryRegion::new(region.id, region.length)],
        None => Vec::new(),
    };
    Ok(regions)
}
/// Pauses the device: stops every ready queue, keeping the returned `Queue` in its vring,
/// then stops the backend's remaining workers.
///
/// # Errors
///
/// * `VhostError::StopQueueError` as soon as one queue fails to stop; queues stopped before
///   it stay paused.
/// * `VhostError::SleepError` if the backend fails to stop its non-queue workers.
fn sleep(&mut self) -> VhostResult<()> {
    for (index, vring) in self.vrings.iter_mut().enumerate() {
        if !vring.queue.ready() {
            continue;
        }
        let queue = self.backend.stop_queue(index).map_err(|e| {
            error!("failed to stop queue index {}: {:#}", index, e);
            VhostError::StopQueueError(e)
        })?;
        vring.paused_queue = Some(queue);
    }
    self.backend
        .stop_non_queue_workers()
        .map_err(VhostError::SleepError)
}
fn wake(&mut self) -> VhostResult<()> {
for (index, vring) in self.vrings.iter_mut().enumerate() {
if let Some(queue) = vring.paused_queue.take() {
let mem = self.mem.clone().ok_or(VhostError::BackendInternalError)?;
let doorbell = vring.doorbell.clone().expect("Failed to clone doorbell");
if let Err(e) = self.backend.start_queue(index, queue, mem, doorbell) {
error!("Failed to start queue {}: {}", index, e);
return Err(VhostError::BackendInternalError);
}
}
}
Ok(())
}
fn snapshot(&mut self) -> VhostResult<Vec<u8>> {
match serde_json::to_vec(&DeviceRequestHandlerSnapshot {
vrings: self
.vrings
.iter()
.map(|vring| vring.snapshot())
.collect::<anyhow::Result<Vec<VringSnapshot>>>()
.map_err(VhostError::SnapshotError)?,
backend: self.backend.snapshot().map_err(VhostError::SnapshotError)?,
}) {
Ok(serialized_json) => Ok(serialized_json),
Err(e) => {
error!("Failed to serialize DeviceRequestHandlerSnapshot: {}", e);
Err(VhostError::SerializationFailed)
}
}
}
/// Restores vring and backend state from bytes produced by `snapshot`.
///
/// `queue_evts` is either empty or holds one kick file per vring, in vring order.
///
/// # Errors
///
/// * `VhostError::DeserializationFailed` if `data_bytes` is not a valid snapshot.
/// * `VhostError::InvalidOperation` if no guest memory has been set.
/// * `VhostError::VringIndexNotFound` if `queue_evts` is non-empty but has fewer files than
///   there are vrings.
/// * `VhostError::RestoreError` if a vring or the backend fails to restore.
///
/// # Panics
///
/// Panics if the snapshot holds a different number of vrings than this handler.
fn restore(&mut self, data_bytes: &[u8], queue_evts: Vec<File>) -> VhostResult<()> {
let device_request_handler_snapshot: DeviceRequestHandlerSnapshot =
serde_json::from_slice(data_bytes).map_err(|e| {
error!("Failed to deserialize DeviceRequestHandlerSnapshot: {}", e);
VhostError::DeserializationFailed
})?;
let mem = self.mem.as_ref().ok_or(VhostError::InvalidOperation)?;
let snapshotted_vrings = device_request_handler_snapshot.vrings;
// NOTE(review): a snapshot with a mismatched vring count aborts the process here rather
// than returning an error — confirm that is intended for externally supplied data.
assert_eq!(snapshotted_vrings.len(), self.vrings.len());
// `None` means no kick files were supplied at all; `Some` means one is expected per vring.
let mut queue_evts_iter = if queue_evts.is_empty() {
None
} else {
Some(queue_evts.into_iter())
};
for (index, (vring, snapshotted_vring)) in self
.vrings
.iter_mut()
.zip(snapshotted_vrings.into_iter())
.enumerate()
{
let queue_evt = if let Some(queue_evts_iter) = &mut queue_evts_iter {
let queue_evt_file = queue_evts_iter
.next()
.ok_or(VhostError::VringIndexNotFound(index))?;
// Reuses the kick-file conversion; the index argument is unused by that helper.
Some(VhostUserRegularOps::set_vring_kick(
index as u8,
Some(queue_evt_file),
)?)
} else {
None
};
vring
.restore(snapshotted_vring, mem, queue_evt)
.map_err(VhostError::RestoreError)?;
}
// The backend is restored last, after every vring has been restored successfully.
self.backend
.restore(device_request_handler_snapshot.backend)
.map_err(VhostError::RestoreError)?;
Ok(())
}
}
/// Whether a backend-to-frontend request connection has been established.
pub enum VhostBackendReqConnectionState {
/// A connection exists and can be used to send requests to the frontend.
Connected(Arc<VhostBackendReqConnection>),
/// No connection has been set up yet.
NoConnection,
}
/// Backend-to-frontend request connection.
pub struct VhostBackendReqConnection {
/// Client used to send requests; shared with any `VhostShmemMapper` created from this
/// connection.
conn: Arc<Mutex<FrontendClient>>,
/// Shared memory bookkeeping; becomes `None` once `take_shmem_mapper` has taken it.
shmem_info: Mutex<Option<ShmemInfo>>,
}
/// Bookkeeping for the mappings made in one shared memory region.
#[derive(Clone)]
struct ShmemInfo {
/// Id of the shared memory region.
shmid: u8,
/// Active mappings, keyed by offset within the region; the value is the mapping's size.
mapped_regions: BTreeMap<u64 , u64 >,
}
impl VhostBackendReqConnection {
    /// Wraps `conn`. When `shmid` is given, bookkeeping for that shared memory region is
    /// created so a mapper can later be taken with `take_shmem_mapper`.
    pub fn new(conn: FrontendClient, shmid: Option<u8>) -> Self {
        let initial_info = shmid.map(|shmid| ShmemInfo {
            shmid,
            mapped_regions: BTreeMap::new(),
        });
        Self {
            conn: Arc::new(Mutex::new(conn)),
            shmem_info: Mutex::new(initial_info),
        }
    }

    /// Sends a config-change message to the frontend.
    pub fn send_config_changed(&self) -> anyhow::Result<()> {
        self.conn
            .lock()
            .handle_config_change()
            .context("Could not send config change message")
            .map(|_| ())
    }

    /// Hands out the shared memory mapper for this connection.
    ///
    /// The bookkeeping is moved into the mapper, so this succeeds at most once and fails if
    /// the connection was created without a shared memory id.
    pub fn take_shmem_mapper(&self) -> anyhow::Result<Box<dyn SharedMemoryMapper>> {
        let taken = self.shmem_info.lock().take();
        let shmem_info = taken.context("could not take shared memory mapper information")?;
        let mapper = VhostShmemMapper {
            conn: Arc::clone(&self.conn),
            shmem_info,
        };
        Ok(Box::new(mapper))
    }
}
/// `SharedMemoryMapper` that forwards map and unmap requests to the frontend.
struct VhostShmemMapper {
/// Client shared with the `VhostBackendReqConnection` this mapper was taken from.
conn: Arc<Mutex<FrontendClient>>,
/// Region id and the mappings currently believed to be active.
shmem_info: ShmemInfo,
}
impl SharedMemoryMapper for VhostShmemMapper {
/// Asks the frontend to map `source` at `offset` in the shared memory region and records
/// the mapping.
///
/// The message sent depends on the kind of source: a GPU map message for `Vulkan`, an
/// external map message for `ExternalMapping`, and a shmem map message (carrying `prot`)
/// for `Descriptor` and `SharedMemory`. Any other source is rejected with
/// "unsupported source".
///
/// The mapping is recorded only after the frontend request succeeded. `_cache` is unused.
fn add_mapping(
&mut self,
source: VmMemorySource,
offset: u64,
prot: Protection,
_cache: MemCacheType,
) -> anyhow::Result<()> {
// Each arm sends its request and yields the mapped size for the bookkeeping below.
let size = match source {
VmMemorySource::Vulkan {
descriptor,
handle_type,
memory_idx,
device_uuid,
driver_uuid,
size,
} => {
let msg = VhostUserGpuMapMsg::new(
self.shmem_info.shmid,
offset,
size,
memory_idx,
handle_type,
device_uuid,
driver_uuid,
);
self.conn
.lock()
.gpu_map(&msg, &descriptor)
.context("failed to map memory")?;
size
}
VmMemorySource::ExternalMapping { ptr, size } => {
let msg = VhostUserExternalMapMsg::new(self.shmem_info.shmid, offset, size, ptr);
self.conn
.lock()
.external_map(&msg)
.context("failed to map memory")?;
size
}
// Descriptor-backed sources share one code path once reduced to
// (descriptor, offset within the descriptor, size).
source => {
let (descriptor, fd_offset, size) = match source {
VmMemorySource::Descriptor {
descriptor,
offset,
size,
} => (descriptor, offset, size),
VmMemorySource::SharedMemory(shmem) => {
let size = shmem.size();
// SAFETY: `into_raw_descriptor` consumes `shmem`, so the raw descriptor is
// presumably owned by nothing else when `SafeDescriptor` takes it over —
// confirm against the `base` descriptor traits.
let descriptor =
unsafe {
SafeDescriptor::from_raw_descriptor(shmem.into_raw_descriptor())
};
(descriptor, 0, size)
}
_ => bail!("unsupported source"),
};
let flags = VhostUserShmemMapMsgFlags::from(prot);
let msg = VhostUserShmemMapMsg::new(
self.shmem_info.shmid,
offset,
fd_offset,
size,
flags,
);
self.conn
.lock()
.shmem_map(&msg, &descriptor)
.context("failed to map memory")?;
size
}
};
// Remember the size so `remove_mapping` can build the unmap request from the offset alone.
self.shmem_info.mapped_regions.insert(offset, size);
Ok(())
}
/// Asks the frontend to unmap the mapping that starts at `offset`.
///
/// The mapping is forgotten only after the frontend request succeeded, so a failed unmap
/// can be retried.
///
/// # Errors
///
/// Fails with "unknown offset" if no mapping is recorded at `offset`, or with
/// "failed to unmap memory" if the frontend request fails.
fn remove_mapping(&mut self, offset: u64) -> anyhow::Result<()> {
    let size = *self
        .shmem_info
        .mapped_regions
        .get(&offset)
        .context("unknown offset")?;
    let msg = VhostUserShmemUnmapMsg::new(self.shmem_info.shmid, offset, size);
    self.conn
        .lock()
        .shmem_unmap(&msg)
        .context("failed to unmap memory")?;
    self.shmem_info.mapped_regions.remove(&offset);
    Ok(())
}
}
/// Pairs a queue worker's task handle with the queue value `T` associated with it.
pub(crate) struct WorkerState<T, U> {
/// Handle to the task servicing the queue; `U` is the task's output type.
pub(crate) queue_task: TaskHandle<U>,
/// The queue (or queue handle) belonging to the worker.
pub(crate) queue: T,
}
/// Errors shared by vhost-user device implementations.
#[derive(Debug, ThisError)]
pub enum Error {
/// A queue was asked to stop but no worker exists for it.
#[error("worker not found when stopping queue")]
WorkerNotFound,
}
#[cfg(test)]
mod tests {
use std::sync::mpsc::channel;
use std::sync::Barrier;
use anyhow::anyhow;
use anyhow::bail;
use base::Event;
use vmm_vhost::BackendServer;
use vmm_vhost::FrontendReq;
use zerocopy::AsBytes;
use zerocopy::FromBytes;
use zerocopy::FromZeroes;
use super::sys::test_helpers;
use super::*;
use crate::virtio::vhost_user_frontend::VhostUserFrontend;
use crate::virtio::DeviceType;
use crate::virtio::VirtioDevice;
/// Device configuration served by `FakeBackend`.
///
/// `packed(4)` caps field alignment at 4 bytes, so no padding is inserted between `x` and
/// `y`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, AsBytes, FromZeroes, FromBytes)]
#[repr(C, packed(4))]
struct FakeConfig {
x: u32,
y: u64,
}
/// Configuration contents the test expects to read back through the frontend.
const FAKE_CONFIG_DATA: FakeConfig = FakeConfig { x: 1, y: 2 };
/// Minimal in-memory `VhostUserDevice` used to exercise `DeviceRequestHandler`.
pub(super) struct FakeBackend {
/// Feature bits reported by `features`.
avail_features: u64,
/// Feature bits accumulated by `ack_features`.
acked_features: u64,
/// Protocol features stored by `ack_protocol_features`.
acked_protocol_features: VhostUserProtocolFeatures,
/// Started queues, indexed by queue number; `None` means not started.
active_queues: Vec<Option<Queue>>,
/// When set, `protocol_features` also advertises `BACKEND_REQ`.
allow_backend_req: bool,
/// Connection received through `set_backend_req_connection`, if any.
backend_conn: Option<Arc<VhostBackendReqConnection>>,
}
impl FakeBackend {
    /// Number of queues the fake device exposes.
    const MAX_QUEUE_NUM: usize = 16;

    /// Creates a backend that offers only VHOST_USER_F_PROTOCOL_FEATURES, has no started
    /// queues and does not advertise backend requests.
    pub(super) fn new() -> Self {
        let active_queues = std::iter::repeat_with(|| None)
            .take(Self::MAX_QUEUE_NUM)
            .collect();
        Self {
            avail_features: 1 << VHOST_USER_F_PROTOCOL_FEATURES,
            acked_features: 0,
            acked_protocol_features: VhostUserProtocolFeatures::empty(),
            active_queues,
            allow_backend_req: false,
            backend_conn: None,
        }
    }
}
impl VhostUserDevice for FakeBackend {
    fn max_queue_num(&self) -> usize {
        Self::MAX_QUEUE_NUM
    }

    fn features(&self) -> u64 {
        self.avail_features
    }

    /// Accumulates `value` into the acknowledged features; bits that were never offered are
    /// rejected.
    fn ack_features(&mut self, value: u64) -> anyhow::Result<()> {
        let unrequested_features = value & !self.avail_features;
        if unrequested_features != 0 {
            bail!(
                "invalid protocol features are given: 0x{:x}",
                unrequested_features
            );
        }
        self.acked_features |= value;
        Ok(())
    }

    fn acked_features(&self) -> u64 {
        self.acked_features
    }

    /// Always offers `CONFIG`; offers `BACKEND_REQ` only when `allow_backend_req` is set.
    fn protocol_features(&self) -> VhostUserProtocolFeatures {
        let mut features = VhostUserProtocolFeatures::CONFIG;
        if self.allow_backend_req {
            features |= VhostUserProtocolFeatures::BACKEND_REQ;
        }
        features
    }

    /// Stores the intersection of `features` and the offered protocol features; bit patterns
    /// that are not valid protocol features are rejected.
    fn ack_protocol_features(&mut self, features: u64) -> anyhow::Result<()> {
        // Build the error lazily so the success path does not format a message.
        let features = VhostUserProtocolFeatures::from_bits(features)
            .ok_or_else(|| anyhow!("invalid protocol features are given: 0x{:x}", features))?;
        let supported = self.protocol_features();
        self.acked_protocol_features = features & supported;
        Ok(())
    }

    fn acked_protocol_features(&self) -> u64 {
        self.acked_protocol_features.bits()
    }

    /// Copies `dst.len()` bytes of `FAKE_CONFIG_DATA` starting at `offset` into `dst`.
    ///
    /// Reads shorter than the remaining config are supported; a window that extends past
    /// the end of the config still panics.
    fn read_config(&self, offset: u64, dst: &mut [u8]) {
        let start = offset as usize;
        let end = start + dst.len();
        dst.copy_from_slice(&FAKE_CONFIG_DATA.as_bytes()[start..end]);
    }

    fn reset(&mut self) {}

    /// Records `queue` as started; memory and doorbell are ignored.
    fn start_queue(
        &mut self,
        idx: usize,
        queue: Queue,
        _mem: GuestMemory,
        _doorbell: Interrupt,
    ) -> anyhow::Result<()> {
        self.active_queues[idx] = Some(queue);
        Ok(())
    }

    /// Returns the queue recorded for `idx`, or `Error::WorkerNotFound` if it is not started.
    fn stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue> {
        Ok(self.active_queues[idx]
            .take()
            .ok_or(Error::WorkerNotFound)?)
    }

    fn set_backend_req_connection(&mut self, conn: Arc<VhostBackendReqConnection>) {
        self.backend_conn = Some(conn);
    }
}
/// Runs the activation scenario without a backend request connection.
#[test]
fn test_vhost_user_activate() {
test_vhost_user_activate_parameterized(false);
}
/// Runs the activation scenario with a backend request connection; not built on Windows.
#[test]
#[cfg(not(windows))] fn test_vhost_user_activate_with_backend_req() {
test_vhost_user_activate_parameterized(true);
}
/// End-to-end scenario: a `VhostUserFrontend` on a spawned thread talks to a
/// `DeviceRequestHandler<FakeBackend>` on the test thread.
///
/// The frontend reads the config, activates, resets, activates again, sleeps and wakes. The
/// test thread serves exactly the requests it expects, in order, asserting the type of each
/// one. With `allow_backend_req` the backend also sends config-change messages over the
/// backend request connection.
fn test_vhost_user_activate_parameterized(allow_backend_req: bool) {
const QUEUES_NUM: usize = 2;
let (dev, vmm) = test_helpers::setup();
// Both threads meet at this barrier after the frontend has been dropped.
let vmm_bar = Arc::new(Barrier::new(2));
let dev_bar = vmm_bar.clone();
let (ready_tx, ready_rx) = channel();
let (shutdown_tx, shutdown_rx) = channel();
// Frontend (VMM) side.
std::thread::spawn(move || {
// Wait until the handler has been created before connecting.
ready_rx.recv().unwrap(); let connection = test_helpers::connect(vmm);
let mut vmm_device =
VhostUserFrontend::new(DeviceType::Console, 0, connection, None, None).unwrap();
println!("read_config");
let mut buf = vec![0; std::mem::size_of::<FakeConfig>()];
vmm_device.read_config(0, &mut buf);
let config = FakeConfig::read_from(buf.as_bytes()).unwrap();
assert_eq!(config, FAKE_CONFIG_DATA);
// Activates the device with fresh guest memory and `QUEUES_NUM` ready queues.
let activate = |vmm_device: &mut VhostUserFrontend| {
let mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
let interrupt = Interrupt::new_for_test_with_msix();
let mut queues = BTreeMap::new();
for idx in 0..QUEUES_NUM {
let mut queue = QueueConfig::new(0x10, 0);
queue.set_ready(true);
let queue = queue
.activate(&mem, Event::new().unwrap())
.expect("QueueConfig::activate");
queues.insert(idx, queue);
}
println!("activate");
vmm_device
.activate(mem.clone(), interrupt.clone(), queues)
.unwrap();
};
activate(&mut vmm_device);
println!("reset");
let reset_result = vmm_device.reset();
assert!(
reset_result.is_ok(),
"reset failed: {:#}",
reset_result.unwrap_err()
);
// The device must be usable again after a reset.
activate(&mut vmm_device);
println!("virtio_sleep");
vmm_device.virtio_sleep().unwrap();
println!("virtio_wake");
vmm_device.virtio_wake(None).unwrap();
println!("wait for shutdown signal");
shutdown_rx.recv().unwrap();
println!("drop");
drop(vmm_device);
vmm_bar.wait();
});
// Device (backend) side.
let mut handler = DeviceRequestHandler::new(FakeBackend::new());
handler.as_mut().allow_backend_req = allow_backend_req;
ready_tx.send(()).unwrap();
let mut req_handler = test_helpers::listen(dev, handler);
// Feature negotiation.
handle_request(&mut req_handler, FrontendReq::SET_OWNER).unwrap();
handle_request(&mut req_handler, FrontendReq::GET_FEATURES).unwrap();
handle_request(&mut req_handler, FrontendReq::SET_FEATURES).unwrap();
handle_request(&mut req_handler, FrontendReq::GET_PROTOCOL_FEATURES).unwrap();
handle_request(&mut req_handler, FrontendReq::SET_PROTOCOL_FEATURES).unwrap();
if allow_backend_req {
handle_request(&mut req_handler, FrontendReq::SET_BACKEND_REQ_FD).unwrap();
}
// The frontend's `read_config`.
handle_request(&mut req_handler, FrontendReq::GET_CONFIG).unwrap();
// First activation.
handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();
for _ in 0..QUEUES_NUM {
handle_request(&mut req_handler, FrontendReq::SET_VRING_NUM).unwrap();
handle_request(&mut req_handler, FrontendReq::SET_VRING_ADDR).unwrap();
handle_request(&mut req_handler, FrontendReq::SET_VRING_BASE).unwrap();
handle_request(&mut req_handler, FrontendReq::SET_VRING_CALL).unwrap();
handle_request(&mut req_handler, FrontendReq::SET_VRING_KICK).unwrap();
handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
}
// Requests generated by the frontend's `reset`.
for _ in 0..QUEUES_NUM {
handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
handle_request(&mut req_handler, FrontendReq::GET_VRING_BASE).unwrap();
}
// Second activation.
handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();
for _ in 0..QUEUES_NUM {
handle_request(&mut req_handler, FrontendReq::SET_VRING_NUM).unwrap();
handle_request(&mut req_handler, FrontendReq::SET_VRING_ADDR).unwrap();
handle_request(&mut req_handler, FrontendReq::SET_VRING_BASE).unwrap();
handle_request(&mut req_handler, FrontendReq::SET_VRING_CALL).unwrap();
handle_request(&mut req_handler, FrontendReq::SET_VRING_KICK).unwrap();
handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
}
// The backend request connection must work while the device is active.
if allow_backend_req {
req_handler
.as_ref()
.as_ref()
.backend_conn
.as_ref()
.expect("backend_conn missing")
.send_config_changed()
.expect("send_config_changed failed");
}
handle_request(&mut req_handler, FrontendReq::SLEEP).unwrap();
handle_request(&mut req_handler, FrontendReq::WAKE).unwrap();
// ...and must still work after a sleep/wake cycle.
if allow_backend_req {
req_handler
.as_ref()
.as_ref()
.backend_conn
.as_ref()
.expect("backend_conn missing")
.send_config_changed()
.expect("send_config_changed failed");
}
// Let the frontend thread drop the device, then wait for it to finish.
shutdown_tx.send(()).unwrap();
dev_bar.wait();
// With the frontend gone, the next receive must report that the client exited.
match req_handler.recv_header() {
Err(VhostError::ClientExit) => (),
r => panic!("expected Err(ClientExit) but got {:?}", r),
}
}
/// Receives one request, asserts that it has the expected type, and processes it.
///
/// Returns any error from receiving or processing the request.
fn handle_request<S: vmm_vhost::Backend>(
    handler: &mut BackendServer<S>,
    expected_message_type: FrontendReq,
) -> Result<(), VhostError> {
    let (hdr, files) = handler.recv_header()?;
    let received_type = hdr.get_code();
    assert_eq!(received_type, Ok(expected_message_type));
    handler.process_message(hdr, files)
}
}