use std::fs::File;
use std::io::Error as IOError;
use std::num::TryFromIntError;
use remain::sorted;
use thiserror::Error as ThisError;
mod backend;
pub use backend::*;
pub mod message;
pub use message::VHOST_USER_F_PROTOCOL_FEATURES;
pub mod connection;
mod sys;
pub use connection::Connection;
pub use message::BackendReq;
pub use message::FrontendReq;
#[cfg(unix)]
pub use sys::unix;
pub(crate) mod backend_client;
pub use backend_client::BackendClient;
mod frontend_server;
pub use self::frontend_server::Frontend;
mod backend_server;
mod frontend_client;
pub use self::backend_server::Backend;
pub use self::backend_server::BackendServer;
pub use self::frontend_client::FrontendClient;
pub use self::frontend_server::FrontendServer;
#[sorted]
#[derive(Debug, ThisError)]
pub enum Error {
#[error("backend internal error")]
BackendInternalError,
#[error("client exited properly")]
ClientExit,
#[error("failed to deserialize data")]
DeserializationFailed,
#[error("client closed the connection")]
Disconnect,
#[error("Failed to enter suspended state")]
EnterSuspendedState(anyhow::Error),
#[error("virtio features mismatch")]
FeatureMismatch,
#[error("frontend Internal error")]
FrontendInternalError,
#[error("wrong number of attached fds")]
IncorrectFds,
#[error("invalid cast to int: {0}")]
InvalidCastToInt(TryFromIntError),
#[error("invalid message")]
InvalidMessage,
#[error("invalid operation")]
InvalidOperation,
#[error("invalid parameters")]
InvalidParam,
#[error("oversized message")]
OversizedMsg,
#[error("partial message")]
PartialMessage,
#[error("buffer for recv was too small, data was dropped: got size {got}, needed {want}")]
RecvBufferTooSmall {
got: usize,
want: usize,
},
#[error("handler failed to handle request: {0}")]
ReqHandlerError(IOError),
#[error("Failed to restore")]
RestoreError(anyhow::Error),
#[error("failed to serialize data")]
SerializationFailed,
#[error("Failed to run device specific sleep: {0}")]
SleepError(anyhow::Error),
#[error("Failed to snapshot")]
SnapshotError(anyhow::Error),
#[error("socket is broken: {0}")]
SocketBroken(std::io::Error),
#[error("can't connect to peer: {0}")]
SocketConnect(std::io::Error),
#[error("socket error: {0}")]
SocketError(std::io::Error),
#[error("Failed get socket from the fd: {0}")]
SocketFromFdError(std::path::PathBuf),
#[error("temporary socket error: {0}")]
SocketRetry(std::io::Error),
#[error("failed to stop queue")]
StopQueueError(anyhow::Error),
#[error("failed to read/write on Tube: {0}")]
TubeError(base::TubeError),
#[error("error occurred in VFIO device: {0}")]
VfioDeviceError(anyhow::Error),
#[error("Vring index not found: {0}")]
VringIndexNotFound(usize),
#[error("Failed to run device specific wake: {0}")]
WakeError(anyhow::Error),
}
/// Converts a `base::TubeError` into an [`Error`].
///
/// A disconnected tube maps to [`Error::Disconnect`]; every other tube failure is wrapped
/// unchanged in [`Error::TubeError`].
impl From<base::TubeError> for Error {
    fn from(tube_err: base::TubeError) -> Self {
        if let base::TubeError::Disconnected = tube_err {
            return Self::Disconnect;
        }
        Self::TubeError(tube_err)
    }
}
impl From<std::io::Error> for Error {
fn from(err: std::io::Error) -> Self {
Error::SocketError(err)
}
}
impl From<base::Error> for Error {
#[allow(unreachable_patterns)] fn from(err: base::Error) -> Self {
match err.errno() {
libc::EAGAIN | libc::EWOULDBLOCK | libc::EINTR | libc::ENOBUFS | libc::ENOMEM => {
Error::SocketRetry(err.into())
}
libc::ECONNRESET | libc::EPIPE => Error::SocketBroken(err.into()),
libc::EACCES => Error::SocketConnect(IOError::from_raw_os_error(libc::EACCES)),
e => Error::SocketError(IOError::from_raw_os_error(e)),
}
}
}
pub type Result<T> = std::result::Result<T, Error>;
pub type HandlerResult<T> = std::result::Result<T, IOError>;
/// Returns the only file in `files`, or `None` unless `files` holds exactly one file.
///
/// When the count is not exactly one, all of the passed files are dropped.
pub(crate) fn into_single_file(mut files: Vec<File>) -> Option<File> {
    match files.len() {
        1 => files.pop(),
        _ => None,
    }
}
#[cfg(test)]
mod test_backend;
#[cfg(test)]
mod tests {
    use std::thread;

    use base::AsRawDescriptor;
    use tempfile::tempfile;

    use super::*;
    use crate::message::*;
    use crate::test_backend::TestBackend;
    use crate::test_backend::VIRTIO_FEATURES;
    use crate::VhostUserMemoryRegionInfo;
    use crate::VringConfigData;

    /// Builds a connected `BackendClient` / `BackendServer` pair over `Connection::pair()`.
    fn create_client_server_pair<S>(backend: S) -> (BackendClient, BackendServer<S>)
    where
        S: Backend,
    {
        let (client_connection, server_connection) = Connection::pair().unwrap();
        let backend_client = BackendClient::new(client_connection);
        (
            backend_client,
            BackendServer::<S>::new(server_connection, backend),
        )
    }

    /// Receives one request header on the server side and processes that message.
    fn handle_request(h: &mut BackendServer<TestBackend>) -> Result<()> {
        let (hdr, files) = h.recv_header()?;
        h.process_message(hdr, files)
    }

    #[test]
    fn create_test_backend() {
        let mut backend = TestBackend::new();
        backend.set_owner().unwrap();
        assert!(backend.set_owner().is_err());
    }

    #[test]
    fn test_set_owner() {
        let test_backend = TestBackend::new();
        let (backend_client, mut backend_server) = create_client_server_pair(test_backend);

        assert!(!backend_server.as_ref().owned);
        backend_client.set_owner().unwrap();
        handle_request(&mut backend_server).unwrap();
        assert!(backend_server.as_ref().owned);
        // A second SET_OWNER must be rejected and must not clear ownership.
        backend_client.set_owner().unwrap();
        assert!(handle_request(&mut backend_server).is_err());
        assert!(backend_server.as_ref().owned);
    }

    #[test]
    fn test_set_features() {
        let test_backend = TestBackend::new();
        let (mut backend_client, mut backend_server) = create_client_server_pair(test_backend);

        // The server runs on its own thread. It is joined at the end (rather than synchronised
        // with a barrier) so that a panic on the server side fails the test instead of leaving
        // the main thread blocked forever.
        let server = thread::spawn(move || {
            handle_request(&mut backend_server).unwrap(); // set_owner
            assert!(backend_server.as_ref().owned);

            handle_request(&mut backend_server).unwrap(); // get_features
            handle_request(&mut backend_server).unwrap(); // set_features
            assert_eq!(
                backend_server.as_ref().acked_features,
                VIRTIO_FEATURES & !0x1
            );

            handle_request(&mut backend_server).unwrap(); // get_protocol_features
            handle_request(&mut backend_server).unwrap(); // set_protocol_features
            assert_eq!(
                backend_server.as_ref().acked_protocol_features,
                VhostUserProtocolFeatures::all().bits()
            );
        });

        backend_client.set_owner().unwrap();

        let features = backend_client.get_features().unwrap();
        assert_eq!(features, VIRTIO_FEATURES);
        backend_client.set_features(VIRTIO_FEATURES & !0x1).unwrap();

        let features = backend_client.get_protocol_features().unwrap();
        assert_eq!(features.bits(), VhostUserProtocolFeatures::all().bits());
        backend_client.set_protocol_features(features).unwrap();

        server.join().expect("backend server thread panicked");
    }

    #[test]
    fn test_client_server_process() {
        let test_backend = TestBackend::new();
        let (mut backend_client, mut backend_server) = create_client_server_pair(test_backend);

        // Each server-side `handle_request` pairs, in order, with the client call named in its
        // comment. The thread is joined at the end so server-side panics fail the test.
        let server = thread::spawn(move || {
            handle_request(&mut backend_server).unwrap(); // set_owner
            assert!(backend_server.as_ref().owned);

            handle_request(&mut backend_server).unwrap(); // get_features
            handle_request(&mut backend_server).unwrap(); // set_features
            assert_eq!(
                backend_server.as_ref().acked_features,
                VIRTIO_FEATURES & !0x1
            );

            handle_request(&mut backend_server).unwrap(); // get_protocol_features
            handle_request(&mut backend_server).unwrap(); // set_protocol_features
            assert_eq!(
                backend_server.as_ref().acked_protocol_features,
                VhostUserProtocolFeatures::all().bits()
            );

            handle_request(&mut backend_server).unwrap(); // get_inflight_fd
            handle_request(&mut backend_server).unwrap(); // set_inflight_fd
            handle_request(&mut backend_server).unwrap(); // get_queue_num
            handle_request(&mut backend_server).unwrap(); // set_mem_table
            handle_request(&mut backend_server).unwrap(); // set_config
            handle_request(&mut backend_server).unwrap(); // get_config
            handle_request(&mut backend_server).unwrap(); // set_backend_req_fd
            handle_request(&mut backend_server).unwrap(); // set_vring_enable
            // The backend is expected to reject both log requests.
            handle_request(&mut backend_server).unwrap_err(); // set_log_base
            handle_request(&mut backend_server).unwrap_err(); // set_log_fd
            handle_request(&mut backend_server).unwrap(); // set_vring_num
            handle_request(&mut backend_server).unwrap(); // set_vring_base
            handle_request(&mut backend_server).unwrap(); // set_vring_addr
            handle_request(&mut backend_server).unwrap(); // set_vring_call
            handle_request(&mut backend_server).unwrap(); // set_vring_kick
            handle_request(&mut backend_server).unwrap(); // set_vring_err
            handle_request(&mut backend_server).unwrap(); // get_max_mem_slots
            handle_request(&mut backend_server).unwrap(); // add_mem_region
            handle_request(&mut backend_server).unwrap(); // remove_mem_region
        });

        backend_client.set_owner().unwrap();

        let features = backend_client.get_features().unwrap();
        assert_eq!(features, VIRTIO_FEATURES);
        backend_client.set_features(VIRTIO_FEATURES & !0x1).unwrap();

        let features = backend_client.get_protocol_features().unwrap();
        assert_eq!(features.bits(), VhostUserProtocolFeatures::all().bits());
        backend_client.set_protocol_features(features).unwrap();

        let (inflight_info, inflight_file) = backend_client
            .get_inflight_fd(&VhostUserInflight {
                num_queues: 2,
                queue_size: 256,
                ..Default::default()
            })
            .unwrap();
        backend_client
            .set_inflight_fd(&inflight_info, inflight_file.as_raw_descriptor())
            .unwrap();

        let num = backend_client.get_queue_num().unwrap();
        assert_eq!(num, 2);

        let event = base::Event::new().unwrap();
        let mem = [VhostUserMemoryRegionInfo {
            guest_phys_addr: 0,
            memory_size: 0x10_0000,
            userspace_addr: 0,
            mmap_offset: 0,
            mmap_handle: event.as_raw_descriptor(),
        }];
        backend_client.set_mem_table(&mem).unwrap();

        backend_client
            .set_config(0x100, VhostUserConfigFlags::WRITABLE, &[0xa5u8])
            .unwrap();
        let buf = [0x0u8; 4];
        let (reply_body, reply_payload) = backend_client
            .get_config(0x100, 4, VhostUserConfigFlags::empty(), &buf)
            .unwrap();
        let offset = reply_body.offset;
        assert_eq!(offset, 0x100);
        assert_eq!(reply_payload[0], 0xa5);

        #[cfg(windows)]
        let tubes = base::Tube::pair().unwrap();
        #[cfg(windows)]
        // SAFETY: NOTE(review): the tube was created just above and is packed for the current
        // process; confirm this satisfies `packed_tube::pack`'s safety contract.
        let descriptor =
            unsafe { tube_transporter::packed_tube::pack(tubes.0, std::process::id()).unwrap() };
        #[cfg(unix)]
        let descriptor = base::Event::new().unwrap();

        backend_client.set_backend_req_fd(&descriptor).unwrap();
        backend_client.set_vring_enable(0, true).unwrap();

        // These succeed on the client side; the server thread asserts that they are rejected.
        backend_client
            .set_log_base(0, Some(event.as_raw_descriptor()))
            .unwrap();
        backend_client
            .set_log_fd(event.as_raw_descriptor())
            .unwrap();

        backend_client.set_vring_num(0, 256).unwrap();
        backend_client.set_vring_base(0, 0).unwrap();
        let config = VringConfigData {
            queue_size: 128,
            flags: VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits(),
            desc_table_addr: 0x1000,
            used_ring_addr: 0x2000,
            avail_ring_addr: 0x3000,
            log_addr: Some(0x4000),
        };
        backend_client.set_vring_addr(0, &config).unwrap();
        backend_client.set_vring_call(0, &event).unwrap();
        backend_client.set_vring_kick(0, &event).unwrap();
        backend_client.set_vring_err(0, &event).unwrap();

        let max_mem_slots = backend_client.get_max_mem_slots().unwrap();
        assert_eq!(max_mem_slots, 32);

        let region_file = tempfile().unwrap();
        let region = VhostUserMemoryRegionInfo {
            guest_phys_addr: 0x10_0000,
            memory_size: 0x10_0000,
            userspace_addr: 0,
            mmap_offset: 0,
            mmap_handle: region_file.as_raw_descriptor(),
        };
        backend_client.add_mem_region(&region).unwrap();
        backend_client.remove_mem_region(&region).unwrap();

        server.join().expect("backend server thread panicked");
    }

    #[test]
    fn test_error_display() {
        assert_eq!(format!("{}", Error::InvalidParam), "invalid parameters");
        assert_eq!(format!("{}", Error::InvalidOperation), "invalid operation");
    }

    #[test]
    fn test_error_from_base_error() {
        let e: Error = base::Error::new(libc::EAGAIN).into();
        if let Error::SocketRetry(e1) = e {
            assert_eq!(e1.raw_os_error().unwrap(), libc::EAGAIN);
        } else {
            panic!("invalid error code conversion!");
        }
    }
}