1use std::fmt;
7use std::io;
8use std::mem;
9use std::mem::size_of;
10use std::num::ParseIntError;
11use std::os::raw::c_uchar;
12use std::os::raw::c_uint;
13use std::os::raw::c_ushort;
14use std::os::unix::io::AsRawFd;
15use std::os::unix::io::IntoRawFd;
16use std::os::unix::io::RawFd;
17use std::result;
18use std::str::FromStr;
19
20use libc::c_void;
21use libc::sa_family_t;
22use libc::size_t;
23use libc::sockaddr;
24use libc::socklen_t;
25use libc::F_GETFL;
26use libc::F_SETFL;
27use libc::O_NONBLOCK;
28use libc::VMADDR_CID_ANY;
29use libc::VMADDR_CID_HOST;
30use libc::VMADDR_CID_HYPERVISOR;
31use thiserror::Error;
32
33const AF_VSOCK: sa_family_t = 40;
35
/// Raw CID value that maps to and from [`VsockCid::Local`].
const VMADDR_CID_LOCAL: c_uint = 1;
38
/// Port value equal to `c_uint::MAX`.
///
/// By its name this is the vsock "any port" wildcard — see the vsock documentation for how
/// the kernel treats it when binding.
pub const VMADDR_PORT_ANY: c_uint = c_uint::MAX;
41
42const PADDING: usize = size_of::<sockaddr>()
45 - size_of::<sa_family_t>()
46 - size_of::<c_ushort>()
47 - (2 * size_of::<c_uint>());
48
49#[repr(C)]
50#[derive(Default)]
51struct sockaddr_vm {
52 svm_family: sa_family_t,
53 svm_reserved1: c_ushort,
54 svm_port: c_uint,
55 svm_cid: c_uint,
56 svm_zero: [c_uchar; PADDING],
57}
58
59#[derive(Error, Debug)]
60#[error("failed to parse vsock address")]
61pub struct AddrParseError;
62
/// A vsock context identifier (CID).
///
/// The well-known CIDs get their own variants; every other value is carried in
/// [`VsockCid::Cid`]. The derived ordering follows the variant declaration order.
#[derive(Debug, Copy, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub enum VsockCid {
    /// Corresponds to `VMADDR_CID_ANY`.
    Any,
    /// Corresponds to `VMADDR_CID_HYPERVISOR`.
    Hypervisor,
    /// Corresponds to `VMADDR_CID_LOCAL`.
    Local,
    /// Corresponds to `VMADDR_CID_HOST`.
    Host,
    /// Any other CID, stored as its raw integer value.
    Cid(c_uint),
}
77
78impl fmt::Display for VsockCid {
79 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
80 match &self {
81 VsockCid::Any => write!(fmt, "Any"),
82 VsockCid::Hypervisor => write!(fmt, "Hypervisor"),
83 VsockCid::Local => write!(fmt, "Local"),
84 VsockCid::Host => write!(fmt, "Host"),
85 VsockCid::Cid(c) => write!(fmt, "'{c}'"),
86 }
87 }
88}
89
90impl From<c_uint> for VsockCid {
91 fn from(c: c_uint) -> Self {
92 match c {
93 VMADDR_CID_ANY => VsockCid::Any,
94 VMADDR_CID_HYPERVISOR => VsockCid::Hypervisor,
95 VMADDR_CID_LOCAL => VsockCid::Local,
96 VMADDR_CID_HOST => VsockCid::Host,
97 _ => VsockCid::Cid(c),
98 }
99 }
100}
101
102impl FromStr for VsockCid {
103 type Err = ParseIntError;
104
105 fn from_str(s: &str) -> Result<Self, Self::Err> {
106 let c: c_uint = s.parse()?;
107 Ok(c.into())
108 }
109}
110
111impl From<VsockCid> for c_uint {
112 fn from(cid: VsockCid) -> c_uint {
113 match cid {
114 VsockCid::Any => VMADDR_CID_ANY,
115 VsockCid::Hypervisor => VMADDR_CID_HYPERVISOR,
116 VsockCid::Local => VMADDR_CID_LOCAL,
117 VsockCid::Host => VMADDR_CID_HOST,
118 VsockCid::Cid(c) => c,
119 }
120 }
121}
122
123#[derive(Debug, Copy, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)]
125pub struct SocketAddr {
126 pub cid: VsockCid,
127 pub port: c_uint,
128}
129
130pub trait ToSocketAddr {
131 fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>;
132}
133
134impl ToSocketAddr for SocketAddr {
135 fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
136 Ok(*self)
137 }
138}
139
140impl ToSocketAddr for str {
141 fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
142 self.parse()
143 }
144}
145
146impl ToSocketAddr for (VsockCid, c_uint) {
147 fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
148 let (cid, port) = *self;
149 Ok(SocketAddr { cid, port })
150 }
151}
152
153impl<T: ToSocketAddr + ?Sized> ToSocketAddr for &T {
154 fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
155 (**self).to_socket_addr()
156 }
157}
158
159impl FromStr for SocketAddr {
160 type Err = AddrParseError;
161
162 fn from_str(s: &str) -> Result<SocketAddr, AddrParseError> {
165 let components: Vec<&str> = s.split(':').collect();
166 if components.len() != 3 || components[0] != "vsock" {
167 return Err(AddrParseError);
168 }
169
170 Ok(SocketAddr {
171 cid: components[1].parse().map_err(|_| AddrParseError)?,
172 port: components[2].parse().map_err(|_| AddrParseError)?,
173 })
174 }
175}
176
177impl fmt::Display for SocketAddr {
178 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
179 write!(fmt, "{}:{}", self.cid, self.port)
180 }
181}
182
183unsafe fn set_nonblocking(fd: RawFd, nonblocking: bool) -> io::Result<()> {
186 let flags = libc::fcntl(fd, F_GETFL, 0);
187 if flags < 0 {
188 return Err(io::Error::last_os_error());
189 }
190
191 let flags = if nonblocking {
192 flags | O_NONBLOCK
193 } else {
194 flags & !O_NONBLOCK
195 };
196
197 let ret = libc::fcntl(fd, F_SETFL, flags);
198 if ret < 0 {
199 return Err(io::Error::last_os_error());
200 }
201
202 Ok(())
203}
204
/// An owned vsock socket file descriptor.
///
/// The descriptor is closed when the value is dropped, unless ownership has been released
/// through `IntoRawFd`.
#[derive(Debug)]
pub struct VsockSocket {
    /// The underlying descriptor, owned by this value.
    fd: RawFd,
}
214
215impl VsockSocket {
216 pub fn new() -> io::Result<Self> {
217 let fd = unsafe { libc::socket(libc::AF_VSOCK, libc::SOCK_STREAM | libc::SOCK_CLOEXEC, 0) };
219 if fd < 0 {
220 Err(io::Error::last_os_error())
221 } else {
222 Ok(VsockSocket { fd })
223 }
224 }
225
226 pub fn bind<A: ToSocketAddr>(&mut self, addr: A) -> io::Result<()> {
227 let sockaddr = addr
228 .to_socket_addr()
229 .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?;
230
231 assert_eq!(size_of::<sockaddr_vm>(), size_of::<sockaddr>());
233
234 let svm = sockaddr_vm {
235 svm_family: AF_VSOCK,
236 svm_cid: sockaddr.cid.into(),
237 svm_port: sockaddr.port,
238 ..Default::default()
239 };
240
241 let ret = unsafe {
244 libc::bind(
245 self.fd,
246 &svm as *const sockaddr_vm as *const sockaddr,
247 size_of::<sockaddr_vm>() as socklen_t,
248 )
249 };
250 if ret < 0 {
251 let bind_err = io::Error::last_os_error();
252 Err(bind_err)
253 } else {
254 Ok(())
255 }
256 }
257
258 pub fn connect<A: ToSocketAddr>(self, addr: A) -> io::Result<VsockStream> {
259 let sockaddr = addr
260 .to_socket_addr()
261 .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?;
262
263 let svm = sockaddr_vm {
264 svm_family: AF_VSOCK,
265 svm_cid: sockaddr.cid.into(),
266 svm_port: sockaddr.port,
267 ..Default::default()
268 };
269
270 let ret = unsafe {
273 libc::connect(
274 self.fd,
275 &svm as *const sockaddr_vm as *const sockaddr,
276 size_of::<sockaddr_vm>() as socklen_t,
277 )
278 };
279 if ret < 0 {
280 let connect_err = io::Error::last_os_error();
281 Err(connect_err)
282 } else {
283 Ok(VsockStream { sock: self })
284 }
285 }
286
287 pub fn listen(self) -> io::Result<VsockListener> {
288 let ret = unsafe { libc::listen(self.fd, 1) };
291 if ret < 0 {
292 let listen_err = io::Error::last_os_error();
293 return Err(listen_err);
294 }
295 Ok(VsockListener { sock: self })
296 }
297
298 pub fn local_port(&self) -> io::Result<u32> {
300 let mut svm: sockaddr_vm = Default::default();
301
302 let mut addrlen = size_of::<sockaddr_vm>() as socklen_t;
303 let ret = unsafe {
306 libc::getsockname(
308 self.fd,
309 &mut svm as *mut sockaddr_vm as *mut sockaddr,
310 &mut addrlen as *mut socklen_t,
311 )
312 };
313 if ret < 0 {
314 let getsockname_err = io::Error::last_os_error();
315 Err(getsockname_err)
316 } else {
317 assert_eq!(addrlen as usize, size_of::<sockaddr_vm>());
319
320 Ok(svm.svm_port)
321 }
322 }
323
324 pub fn try_clone(&self) -> io::Result<Self> {
325 let dup_fd = unsafe { libc::fcntl(self.fd, libc::F_DUPFD_CLOEXEC, 0) };
328 if dup_fd < 0 {
329 Err(io::Error::last_os_error())
330 } else {
331 Ok(Self { fd: dup_fd })
332 }
333 }
334
335 pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
336 unsafe { set_nonblocking(self.fd, nonblocking) }
339 }
340}
341
342impl IntoRawFd for VsockSocket {
343 fn into_raw_fd(self) -> RawFd {
344 let fd = self.fd;
345 mem::forget(self);
346 fd
347 }
348}
349
350impl AsRawFd for VsockSocket {
351 fn as_raw_fd(&self) -> RawFd {
352 self.fd
353 }
354}
355
356impl Drop for VsockSocket {
357 fn drop(&mut self) {
358 unsafe { libc::close(self.fd) };
362 }
363}
364
365#[derive(Debug)]
367pub struct VsockStream {
368 sock: VsockSocket,
369}
370
371impl VsockStream {
372 pub fn connect<A: ToSocketAddr>(addr: A) -> io::Result<VsockStream> {
373 let sock = VsockSocket::new()?;
374 sock.connect(addr)
375 }
376
377 pub fn local_port(&self) -> io::Result<u32> {
379 self.sock.local_port()
380 }
381
382 pub fn try_clone(&self) -> io::Result<VsockStream> {
383 self.sock.try_clone().map(|f| VsockStream { sock: f })
384 }
385
386 pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
387 self.sock.set_nonblocking(nonblocking)
388 }
389}
390
391impl io::Read for VsockStream {
392 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
393 let ret = unsafe {
396 libc::read(
397 self.sock.as_raw_fd(),
398 buf as *mut [u8] as *mut c_void,
399 buf.len() as size_t,
400 )
401 };
402 if ret < 0 {
403 return Err(io::Error::last_os_error());
404 }
405
406 Ok(ret as usize)
407 }
408}
409
410impl io::Write for VsockStream {
411 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
412 let ret = unsafe {
415 libc::write(
416 self.sock.as_raw_fd(),
417 buf as *const [u8] as *const c_void,
418 buf.len() as size_t,
419 )
420 };
421 if ret < 0 {
422 return Err(io::Error::last_os_error());
423 }
424
425 Ok(ret as usize)
426 }
427
428 fn flush(&mut self) -> io::Result<()> {
429 Ok(())
431 }
432}
433
434impl AsRawFd for VsockStream {
435 fn as_raw_fd(&self) -> RawFd {
436 self.sock.as_raw_fd()
437 }
438}
439
440impl IntoRawFd for VsockStream {
441 fn into_raw_fd(self) -> RawFd {
442 self.sock.into_raw_fd()
443 }
444}
445
446#[derive(Debug)]
448pub struct VsockListener {
449 sock: VsockSocket,
450}
451
452impl VsockListener {
453 pub fn bind<A: ToSocketAddr>(addr: A) -> io::Result<VsockListener> {
456 let mut sock = VsockSocket::new()?;
457 sock.bind(addr)?;
458 sock.listen()
459 }
460
461 pub fn local_port(&self) -> io::Result<u32> {
463 self.sock.local_port()
464 }
465
466 pub fn accept(&self) -> io::Result<(VsockStream, SocketAddr)> {
470 let mut svm: sockaddr_vm = Default::default();
471
472 let mut socklen: socklen_t = size_of::<sockaddr_vm>() as socklen_t;
473 let fd = unsafe {
476 libc::accept4(
477 self.sock.as_raw_fd(),
478 &mut svm as *mut sockaddr_vm as *mut sockaddr,
479 &mut socklen as *mut socklen_t,
480 libc::SOCK_CLOEXEC,
481 )
482 };
483 if fd < 0 {
484 return Err(io::Error::last_os_error());
485 }
486
487 if svm.svm_family != AF_VSOCK {
488 return Err(io::Error::new(
489 io::ErrorKind::InvalidData,
490 format!("unexpected address family: {}", svm.svm_family),
491 ));
492 }
493
494 Ok((
495 VsockStream {
496 sock: VsockSocket { fd },
497 },
498 SocketAddr {
499 cid: svm.svm_cid.into(),
500 port: svm.svm_port,
501 },
502 ))
503 }
504
505 pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
506 self.sock.set_nonblocking(nonblocking)
507 }
508}
509
510impl AsRawFd for VsockListener {
511 fn as_raw_fd(&self) -> RawFd {
512 self.sock.as_raw_fd()
513 }
514}