use std::fmt;
use std::io;
use std::mem;
use std::mem::size_of;
use std::num::ParseIntError;
use std::os::raw::c_uchar;
use std::os::raw::c_uint;
use std::os::raw::c_ushort;
use std::os::unix::io::AsRawFd;
use std::os::unix::io::IntoRawFd;
use std::os::unix::io::RawFd;
use std::result;
use std::str::FromStr;
use libc::c_void;
use libc::sa_family_t;
use libc::size_t;
use libc::sockaddr;
use libc::socklen_t;
use libc::F_GETFL;
use libc::F_SETFL;
use libc::O_NONBLOCK;
use libc::VMADDR_CID_ANY;
use libc::VMADDR_CID_HOST;
use libc::VMADDR_CID_HYPERVISOR;
use thiserror::Error;
/// Linux's numeric value of the `AF_VSOCK` address family, narrowed to the
/// `sa_family_t` width stored in `sockaddr_vm::svm_family`.
const AF_VSOCK: sa_family_t = 40;
/// Raw CID that `VsockCid::from` maps to `VsockCid::Local`.
const VMADDR_CID_LOCAL: c_uint = 1;
/// Wildcard port value (all bits set).
pub const VMADDR_PORT_ANY: c_uint = c_uint::MAX;
/// Trailing zero bytes that make `sockaddr_vm` exactly as large as a generic
/// `sockaddr`: whatever remains after the family, the reserved half-word,
/// the port and the CID.
const PADDING: usize = size_of::<sockaddr>()
    - (size_of::<sa_family_t>() + size_of::<c_ushort>() + 2 * size_of::<c_uint>());
/// C-layout vsock socket address handed to `bind`, `connect`, `getsockname`
/// and `accept4` through a `*const`/`*mut sockaddr` cast.
///
/// Field order and types define the `#[repr(C)]` layout and must not be
/// changed. `PADDING` sizes `svm_zero` so the whole struct is as large as a
/// generic `sockaddr` (checked at runtime in `VsockSocket::bind`).
/// The derived `Default` is all-zero; callers use `..Default::default()` to
/// zero the reserved and padding bytes.
#[repr(C)]
#[derive(Default)]
struct sockaddr_vm {
    /// Address family; this module writes `AF_VSOCK` and checks it on accept.
    svm_family: sa_family_t,
    /// Reserved; left at zero by this module.
    svm_reserved1: c_ushort,
    /// Port number.
    svm_port: c_uint,
    /// Context ID (CID) of the endpoint.
    svm_cid: c_uint,
    /// Zero padding up to `size_of::<sockaddr>()`.
    svm_zero: [c_uchar; PADDING],
}
#[derive(Error, Debug)]
#[error("failed to parse vsock address")]
pub struct AddrParseError;
/// A vsock context identifier (CID).
///
/// The well-known CIDs have dedicated variants; every other value is carried
/// by [`VsockCid::Cid`]. Converting from a raw `c_uint` always selects the
/// dedicated variant when one exists.
///
/// Variant order determines the derived `Ord`/`PartialOrd`; keep it stable.
#[derive(Debug, Copy, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub enum VsockCid {
    /// Wildcard CID (`VMADDR_CID_ANY`).
    Any,
    /// `VMADDR_CID_HYPERVISOR`.
    Hypervisor,
    /// Local CID (`VMADDR_CID_LOCAL`, raw value 1 in this module).
    Local,
    /// `VMADDR_CID_HOST`.
    Host,
    /// Any other raw CID value.
    Cid(c_uint),
}

impl fmt::Display for VsockCid {
    /// Named CIDs print as their variant name; numeric CIDs print
    /// single-quoted, e.g. `'3'`.
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        let label = match *self {
            VsockCid::Cid(raw) => return write!(fmt, "'{}'", raw),
            VsockCid::Any => "Any",
            VsockCid::Hypervisor => "Hypervisor",
            VsockCid::Local => "Local",
            VsockCid::Host => "Host",
        };
        fmt.write_str(label)
    }
}

impl From<c_uint> for VsockCid {
    /// Maps the well-known raw CIDs to their named variants and wraps every
    /// other value in `Cid`.
    fn from(raw: c_uint) -> Self {
        match raw {
            VMADDR_CID_ANY => Self::Any,
            VMADDR_CID_HYPERVISOR => Self::Hypervisor,
            VMADDR_CID_LOCAL => Self::Local,
            VMADDR_CID_HOST => Self::Host,
            other => Self::Cid(other),
        }
    }
}

impl FromStr for VsockCid {
    type Err = ParseIntError;

    /// Parses a base-10 unsigned integer and normalises it via
    /// `From<c_uint>`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        s.parse::<c_uint>().map(Self::from)
    }
}

impl From<VsockCid> for c_uint {
    /// Yields the raw CID value; `c_uint -> VsockCid -> c_uint` round-trips.
    fn from(cid: VsockCid) -> c_uint {
        match cid {
            VsockCid::Cid(raw) => raw,
            VsockCid::Host => VMADDR_CID_HOST,
            VsockCid::Local => VMADDR_CID_LOCAL,
            VsockCid::Hypervisor => VMADDR_CID_HYPERVISOR,
            VsockCid::Any => VMADDR_CID_ANY,
        }
    }
}
/// A vsock endpoint: context ID plus port.
///
/// Field order determines the derived ordering (by `cid`, then `port`).
#[derive(Debug, Copy, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct SocketAddr {
    pub cid: VsockCid,
    pub port: c_uint,
}

/// Values that can be converted into a vsock [`SocketAddr`]; accepted by the
/// `bind`/`connect` functions of this module.
pub trait ToSocketAddr {
    /// Produces the address, or [`AddrParseError`] if the value cannot be
    /// interpreted as one.
    fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>;
}

impl ToSocketAddr for SocketAddr {
    fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
        Ok(*self)
    }
}

impl ToSocketAddr for str {
    /// Parses `vsock:<cid>:<port>` via `SocketAddr`'s `FromStr`.
    fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
        self.parse()
    }
}

/// Owned strings are parsed exactly like `str`, so a `String` or `&String`
/// can be passed to `bind`/`connect` without an explicit `.as_str()`.
impl ToSocketAddr for String {
    fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
        self.as_str().to_socket_addr()
    }
}

impl ToSocketAddr for (VsockCid, c_uint) {
    /// Interprets the tuple as `(cid, port)`.
    fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
        let (cid, port) = *self;
        Ok(SocketAddr { cid, port })
    }
}

impl<'a, T: ToSocketAddr + ?Sized> ToSocketAddr for &'a T {
    /// Forwards to the referenced value.
    fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
        (**self).to_socket_addr()
    }
}

impl FromStr for SocketAddr {
    type Err = AddrParseError;

    /// Parses `vsock:<cid>:<port>`, where `<cid>` and `<port>` are base-10
    /// unsigned integers.
    ///
    /// # Errors
    /// Returns [`AddrParseError`] for a wrong prefix, fewer or more than three
    /// `:`-separated fields, or a field that does not parse as a number.
    fn from_str(s: &str) -> Result<SocketAddr, AddrParseError> {
        let mut fields = s.split(':');
        // Inspect the fields directly instead of collecting them into a Vec;
        // the fourth `None` rejects inputs with extra fields.
        match (fields.next(), fields.next(), fields.next(), fields.next()) {
            (Some("vsock"), Some(cid), Some(port), None) => Ok(SocketAddr {
                cid: cid.parse().map_err(|_| AddrParseError)?,
                port: port.parse().map_err(|_| AddrParseError)?,
            }),
            _ => Err(AddrParseError),
        }
    }
}

impl fmt::Display for SocketAddr {
    /// Formats as `<cid>:<port>` using `VsockCid`'s `Display`. This is not
    /// the `vsock:<cid>:<port>` form accepted by `FromStr`.
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        write!(fmt, "{}:{}", self.cid, self.port)
    }
}
/// Sets or clears `O_NONBLOCK` on `fd`, preserving its other status flags.
///
/// # Errors
/// Returns the OS error if reading (`F_GETFL`) or writing (`F_SETFL`) the
/// file status flags fails.
///
/// # Safety
/// Callers must ensure `fd` is a descriptor they own for the duration of the
/// call; the flags of whatever open file `fd` refers to are modified.
unsafe fn set_nonblocking(fd: RawFd, nonblocking: bool) -> io::Result<()> {
    let current = libc::fcntl(fd, F_GETFL, 0);
    if current < 0 {
        return Err(io::Error::last_os_error());
    }
    let wanted = if nonblocking {
        current | O_NONBLOCK
    } else {
        current & !O_NONBLOCK
    };
    if libc::fcntl(fd, F_SETFL, wanted) < 0 {
        Err(io::Error::last_os_error())
    } else {
        Ok(())
    }
}
#[derive(Debug)]
pub struct VsockSocket {
fd: RawFd,
}
impl VsockSocket {
pub fn new() -> io::Result<Self> {
let fd = unsafe { libc::socket(libc::AF_VSOCK, libc::SOCK_STREAM | libc::SOCK_CLOEXEC, 0) };
if fd < 0 {
Err(io::Error::last_os_error())
} else {
Ok(VsockSocket { fd })
}
}
pub fn bind<A: ToSocketAddr>(&mut self, addr: A) -> io::Result<()> {
let sockaddr = addr
.to_socket_addr()
.map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?;
assert_eq!(size_of::<sockaddr_vm>(), size_of::<sockaddr>());
let svm = sockaddr_vm {
svm_family: AF_VSOCK,
svm_cid: sockaddr.cid.into(),
svm_port: sockaddr.port,
..Default::default()
};
let ret = unsafe {
libc::bind(
self.fd,
&svm as *const sockaddr_vm as *const sockaddr,
size_of::<sockaddr_vm>() as socklen_t,
)
};
if ret < 0 {
let bind_err = io::Error::last_os_error();
Err(bind_err)
} else {
Ok(())
}
}
pub fn connect<A: ToSocketAddr>(self, addr: A) -> io::Result<VsockStream> {
let sockaddr = addr
.to_socket_addr()
.map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?;
let svm = sockaddr_vm {
svm_family: AF_VSOCK,
svm_cid: sockaddr.cid.into(),
svm_port: sockaddr.port,
..Default::default()
};
let ret = unsafe {
libc::connect(
self.fd,
&svm as *const sockaddr_vm as *const sockaddr,
size_of::<sockaddr_vm>() as socklen_t,
)
};
if ret < 0 {
let connect_err = io::Error::last_os_error();
Err(connect_err)
} else {
Ok(VsockStream { sock: self })
}
}
pub fn listen(self) -> io::Result<VsockListener> {
let ret = unsafe { libc::listen(self.fd, 1) };
if ret < 0 {
let listen_err = io::Error::last_os_error();
return Err(listen_err);
}
Ok(VsockListener { sock: self })
}
pub fn local_port(&self) -> io::Result<u32> {
let mut svm: sockaddr_vm = Default::default();
let mut addrlen = size_of::<sockaddr_vm>() as socklen_t;
let ret = unsafe {
libc::getsockname(
self.fd,
&mut svm as *mut sockaddr_vm as *mut sockaddr,
&mut addrlen as *mut socklen_t,
)
};
if ret < 0 {
let getsockname_err = io::Error::last_os_error();
Err(getsockname_err)
} else {
assert_eq!(addrlen as usize, size_of::<sockaddr_vm>());
Ok(svm.svm_port)
}
}
pub fn try_clone(&self) -> io::Result<Self> {
let dup_fd = unsafe { libc::fcntl(self.fd, libc::F_DUPFD_CLOEXEC, 0) };
if dup_fd < 0 {
Err(io::Error::last_os_error())
} else {
Ok(Self { fd: dup_fd })
}
}
pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
unsafe { set_nonblocking(self.fd, nonblocking) }
}
}
impl IntoRawFd for VsockSocket {
fn into_raw_fd(self) -> RawFd {
let fd = self.fd;
mem::forget(self);
fd
}
}
impl AsRawFd for VsockSocket {
fn as_raw_fd(&self) -> RawFd {
self.fd
}
}
impl Drop for VsockSocket {
fn drop(&mut self) {
unsafe { libc::close(self.fd) };
}
}
/// A connected vsock stream socket, readable and writable via `std::io`.
#[derive(Debug)]
pub struct VsockStream {
    sock: VsockSocket,
}

impl VsockStream {
    /// Opens a new vsock socket and connects it to `addr`.
    pub fn connect<A: ToSocketAddr>(addr: A) -> io::Result<VsockStream> {
        VsockSocket::new()?.connect(addr)
    }

    /// Returns the local port of this stream.
    pub fn local_port(&self) -> io::Result<u32> {
        self.sock.local_port()
    }

    /// Returns a second stream handle backed by a duplicated descriptor.
    pub fn try_clone(&self) -> io::Result<VsockStream> {
        let sock = self.sock.try_clone()?;
        Ok(VsockStream { sock })
    }

    /// Enables or disables non-blocking mode on the stream.
    pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
        self.sock.set_nonblocking(nonblocking)
    }
}

impl io::Read for VsockStream {
    /// Performs one `read(2)` call; any OS error is returned unchanged.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let fd = self.sock.as_raw_fd();
        // SAFETY: `buf` is valid for writes of `buf.len()` bytes for the
        // whole call.
        let n = unsafe { libc::read(fd, buf.as_mut_ptr().cast::<c_void>(), buf.len() as size_t) };
        if n < 0 {
            Err(io::Error::last_os_error())
        } else {
            Ok(n as usize)
        }
    }
}

impl io::Write for VsockStream {
    /// Performs one `write(2)` call; any OS error is returned unchanged.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let fd = self.sock.as_raw_fd();
        // SAFETY: `buf` is valid for reads of `buf.len()` bytes for the
        // whole call.
        let n = unsafe { libc::write(fd, buf.as_ptr().cast::<c_void>(), buf.len() as size_t) };
        if n < 0 {
            Err(io::Error::last_os_error())
        } else {
            Ok(n as usize)
        }
    }

    /// No-op: writes go straight to the descriptor with no user-space buffer.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}

impl AsRawFd for VsockStream {
    /// Borrows the underlying descriptor.
    fn as_raw_fd(&self) -> RawFd {
        self.sock.as_raw_fd()
    }
}

impl IntoRawFd for VsockStream {
    /// Releases the underlying descriptor without closing it.
    fn into_raw_fd(self) -> RawFd {
        self.sock.into_raw_fd()
    }
}
/// A vsock socket listening for incoming stream connections.
#[derive(Debug)]
pub struct VsockListener {
    sock: VsockSocket,
}

impl VsockListener {
    /// Creates a vsock socket, binds it to `addr`, and starts listening.
    ///
    /// # Errors
    /// Any error from socket creation, `bind`, or `listen`.
    pub fn bind<A: ToSocketAddr>(addr: A) -> io::Result<VsockListener> {
        let mut sock = VsockSocket::new()?;
        sock.bind(addr)?;
        sock.listen()
    }

    /// Returns the port this listener is bound to.
    pub fn local_port(&self) -> io::Result<u32> {
        self.sock.local_port()
    }

    /// Accepts one pending connection and returns the stream together with
    /// the peer's address. The accepted descriptor has close-on-exec set.
    ///
    /// # Errors
    /// The OS error from `accept4(2)`, or `ErrorKind::InvalidData` if the peer
    /// address is not an `AF_VSOCK` address. In that case the accepted
    /// connection is closed rather than leaked.
    pub fn accept(&self) -> io::Result<(VsockStream, SocketAddr)> {
        let mut svm: sockaddr_vm = Default::default();
        let mut socklen: socklen_t = size_of::<sockaddr_vm>() as socklen_t;
        // SAFETY: `svm` and `socklen` are valid for writes, and `socklen` is
        // initialised to the size of the `svm` buffer.
        let fd = unsafe {
            libc::accept4(
                self.sock.as_raw_fd(),
                &mut svm as *mut sockaddr_vm as *mut sockaddr,
                &mut socklen as *mut socklen_t,
                libc::SOCK_CLOEXEC,
            )
        };
        if fd < 0 {
            return Err(io::Error::last_os_error());
        }
        // Take ownership right away so that any early return below closes
        // the accepted descriptor (previously it leaked on the family check).
        let sock = VsockSocket { fd };
        if svm.svm_family != AF_VSOCK {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!("unexpected address family: {}", svm.svm_family),
            ));
        }
        Ok((
            VsockStream { sock },
            SocketAddr {
                cid: svm.svm_cid.into(),
                port: svm.svm_port,
            },
        ))
    }

    /// Enables or disables non-blocking mode on the listening socket.
    pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
        self.sock.set_nonblocking(nonblocking)
    }
}

impl AsRawFd for VsockListener {
    /// Borrows the listening socket's descriptor.
    fn as_raw_fd(&self) -> RawFd {
        self.sock.as_raw_fd()
    }
}