base/sys/linux/
vsock.rs

1// Copyright 2021 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5/// Support for virtual sockets.
6use std::fmt;
7use std::io;
8use std::mem;
9use std::mem::size_of;
10use std::num::ParseIntError;
11use std::os::raw::c_uchar;
12use std::os::raw::c_uint;
13use std::os::raw::c_ushort;
14use std::os::unix::io::AsRawFd;
15use std::os::unix::io::IntoRawFd;
16use std::os::unix::io::RawFd;
17use std::result;
18use std::str::FromStr;
19
20use libc::c_void;
21use libc::sa_family_t;
22use libc::size_t;
23use libc::sockaddr;
24use libc::socklen_t;
25use libc::F_GETFL;
26use libc::F_SETFL;
27use libc::O_NONBLOCK;
28use libc::VMADDR_CID_ANY;
29use libc::VMADDR_CID_HOST;
30use libc::VMADDR_CID_HYPERVISOR;
31use thiserror::Error;
32
33// The domain for vsock sockets.
34const AF_VSOCK: sa_family_t = 40;
35
36// Vsock loopback address.
37const VMADDR_CID_LOCAL: c_uint = 1;
38
39/// Vsock equivalent of binding on port 0. Binds to a random port.
40pub const VMADDR_PORT_ANY: c_uint = c_uint::MAX;
41
42// The number of bytes of padding to be added to the sockaddr_vm struct.  Taken directly
43// from linux/vm_sockets.h.
44const PADDING: usize = size_of::<sockaddr>()
45    - size_of::<sa_family_t>()
46    - size_of::<c_ushort>()
47    - (2 * size_of::<c_uint>());
48
49#[repr(C)]
50#[derive(Default)]
51struct sockaddr_vm {
52    svm_family: sa_family_t,
53    svm_reserved1: c_ushort,
54    svm_port: c_uint,
55    svm_cid: c_uint,
56    svm_zero: [c_uchar; PADDING],
57}
58
59#[derive(Error, Debug)]
60#[error("failed to parse vsock address")]
61pub struct AddrParseError;
62
/// The vsock equivalent of an IP address.
///
/// The derived `PartialOrd`/`Ord` follow declaration order, so every named variant sorts
/// before any `Cid(_)`, and `Cid` values compare numerically.
#[derive(Debug, Copy, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub enum VsockCid {
    /// Vsock equivalent of INADDR_ANY. Indicates the context id of the current endpoint.
    Any,
    /// An address that refers to the bare-metal machine that serves as the hypervisor.
    Hypervisor,
    /// The loopback address.
    Local,
    /// The parent machine. It may not be the hypervisor for nested VMs.
    Host,
    /// An assigned CID that serves as the address for VSOCK.
    Cid(c_uint),
}

impl fmt::Display for VsockCid {
    /// Writes the variant name for the well-known CIDs, and the number wrapped in single
    /// quotes (e.g. `'3'`) for an assigned CID.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let label = match *self {
            VsockCid::Any => "Any",
            VsockCid::Hypervisor => "Hypervisor",
            VsockCid::Local => "Local",
            VsockCid::Host => "Host",
            VsockCid::Cid(cid) => return write!(f, "'{cid}'"),
        };
        f.write_str(label)
    }
}
89
90impl From<c_uint> for VsockCid {
91    fn from(c: c_uint) -> Self {
92        match c {
93            VMADDR_CID_ANY => VsockCid::Any,
94            VMADDR_CID_HYPERVISOR => VsockCid::Hypervisor,
95            VMADDR_CID_LOCAL => VsockCid::Local,
96            VMADDR_CID_HOST => VsockCid::Host,
97            _ => VsockCid::Cid(c),
98        }
99    }
100}
101
102impl FromStr for VsockCid {
103    type Err = ParseIntError;
104
105    fn from_str(s: &str) -> Result<Self, Self::Err> {
106        let c: c_uint = s.parse()?;
107        Ok(c.into())
108    }
109}
110
111impl From<VsockCid> for c_uint {
112    fn from(cid: VsockCid) -> c_uint {
113        match cid {
114            VsockCid::Any => VMADDR_CID_ANY,
115            VsockCid::Hypervisor => VMADDR_CID_HYPERVISOR,
116            VsockCid::Local => VMADDR_CID_LOCAL,
117            VsockCid::Host => VMADDR_CID_HOST,
118            VsockCid::Cid(c) => c,
119        }
120    }
121}
122
123/// An address associated with a virtual socket.
124#[derive(Debug, Copy, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)]
125pub struct SocketAddr {
126    pub cid: VsockCid,
127    pub port: c_uint,
128}
129
130pub trait ToSocketAddr {
131    fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>;
132}
133
134impl ToSocketAddr for SocketAddr {
135    fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
136        Ok(*self)
137    }
138}
139
140impl ToSocketAddr for str {
141    fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
142        self.parse()
143    }
144}
145
146impl ToSocketAddr for (VsockCid, c_uint) {
147    fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
148        let (cid, port) = *self;
149        Ok(SocketAddr { cid, port })
150    }
151}
152
153impl<T: ToSocketAddr + ?Sized> ToSocketAddr for &T {
154    fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
155        (**self).to_socket_addr()
156    }
157}
158
159impl FromStr for SocketAddr {
160    type Err = AddrParseError;
161
162    /// Parse a vsock SocketAddr from a string. vsock socket addresses are of the form
163    /// "vsock:cid:port".
164    fn from_str(s: &str) -> Result<SocketAddr, AddrParseError> {
165        let components: Vec<&str> = s.split(':').collect();
166        if components.len() != 3 || components[0] != "vsock" {
167            return Err(AddrParseError);
168        }
169
170        Ok(SocketAddr {
171            cid: components[1].parse().map_err(|_| AddrParseError)?,
172            port: components[2].parse().map_err(|_| AddrParseError)?,
173        })
174    }
175}
176
177impl fmt::Display for SocketAddr {
178    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
179        write!(fmt, "{}:{}", self.cid, self.port)
180    }
181}
182
183/// Sets `fd` to be blocking or nonblocking. `fd` must be a valid fd of a type that accepts the
184/// `O_NONBLOCK` flag. This includes regular files, pipes, and sockets.
185unsafe fn set_nonblocking(fd: RawFd, nonblocking: bool) -> io::Result<()> {
186    let flags = libc::fcntl(fd, F_GETFL, 0);
187    if flags < 0 {
188        return Err(io::Error::last_os_error());
189    }
190
191    let flags = if nonblocking {
192        flags | O_NONBLOCK
193    } else {
194        flags & !O_NONBLOCK
195    };
196
197    let ret = libc::fcntl(fd, F_SETFL, flags);
198    if ret < 0 {
199        return Err(io::Error::last_os_error());
200    }
201
202    Ok(())
203}
204
205/// A virtual socket.
206///
207/// Do not use this class unless you need to change socket options or query the
208/// state of the socket prior to calling listen or connect. Instead use either VsockStream or
209/// VsockListener.
210#[derive(Debug)]
211pub struct VsockSocket {
212    fd: RawFd,
213}
214
215impl VsockSocket {
216    pub fn new() -> io::Result<Self> {
217        // SAFETY: trivially safe
218        let fd = unsafe { libc::socket(libc::AF_VSOCK, libc::SOCK_STREAM | libc::SOCK_CLOEXEC, 0) };
219        if fd < 0 {
220            Err(io::Error::last_os_error())
221        } else {
222            Ok(VsockSocket { fd })
223        }
224    }
225
226    pub fn bind<A: ToSocketAddr>(&mut self, addr: A) -> io::Result<()> {
227        let sockaddr = addr
228            .to_socket_addr()
229            .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?;
230
231        // The compiler should optimize this out since these are both compile-time constants.
232        assert_eq!(size_of::<sockaddr_vm>(), size_of::<sockaddr>());
233
234        let svm = sockaddr_vm {
235            svm_family: AF_VSOCK,
236            svm_cid: sockaddr.cid.into(),
237            svm_port: sockaddr.port,
238            ..Default::default()
239        };
240
241        // SAFETY:
242        // Safe because this doesn't modify any memory and we check the return value.
243        let ret = unsafe {
244            libc::bind(
245                self.fd,
246                &svm as *const sockaddr_vm as *const sockaddr,
247                size_of::<sockaddr_vm>() as socklen_t,
248            )
249        };
250        if ret < 0 {
251            let bind_err = io::Error::last_os_error();
252            Err(bind_err)
253        } else {
254            Ok(())
255        }
256    }
257
258    pub fn connect<A: ToSocketAddr>(self, addr: A) -> io::Result<VsockStream> {
259        let sockaddr = addr
260            .to_socket_addr()
261            .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?;
262
263        let svm = sockaddr_vm {
264            svm_family: AF_VSOCK,
265            svm_cid: sockaddr.cid.into(),
266            svm_port: sockaddr.port,
267            ..Default::default()
268        };
269
270        // SAFETY:
271        // Safe because this just connects a vsock socket, and the return value is checked.
272        let ret = unsafe {
273            libc::connect(
274                self.fd,
275                &svm as *const sockaddr_vm as *const sockaddr,
276                size_of::<sockaddr_vm>() as socklen_t,
277            )
278        };
279        if ret < 0 {
280            let connect_err = io::Error::last_os_error();
281            Err(connect_err)
282        } else {
283            Ok(VsockStream { sock: self })
284        }
285    }
286
287    pub fn listen(self) -> io::Result<VsockListener> {
288        // SAFETY:
289        // Safe because this doesn't modify any memory and we check the return value.
290        let ret = unsafe { libc::listen(self.fd, 1) };
291        if ret < 0 {
292            let listen_err = io::Error::last_os_error();
293            return Err(listen_err);
294        }
295        Ok(VsockListener { sock: self })
296    }
297
298    /// Returns the port that this socket is bound to. This can only succeed after bind is called.
299    pub fn local_port(&self) -> io::Result<u32> {
300        let mut svm: sockaddr_vm = Default::default();
301
302        let mut addrlen = size_of::<sockaddr_vm>() as socklen_t;
303        // SAFETY:
304        // Safe because we give a valid pointer for addrlen and check the length.
305        let ret = unsafe {
306            // Get the socket address that was actually bound.
307            libc::getsockname(
308                self.fd,
309                &mut svm as *mut sockaddr_vm as *mut sockaddr,
310                &mut addrlen as *mut socklen_t,
311            )
312        };
313        if ret < 0 {
314            let getsockname_err = io::Error::last_os_error();
315            Err(getsockname_err)
316        } else {
317            // If this doesn't match, it's not safe to get the port out of the sockaddr.
318            assert_eq!(addrlen as usize, size_of::<sockaddr_vm>());
319
320            Ok(svm.svm_port)
321        }
322    }
323
324    pub fn try_clone(&self) -> io::Result<Self> {
325        // SAFETY:
326        // Safe because this doesn't modify any memory and we check the return value.
327        let dup_fd = unsafe { libc::fcntl(self.fd, libc::F_DUPFD_CLOEXEC, 0) };
328        if dup_fd < 0 {
329            Err(io::Error::last_os_error())
330        } else {
331            Ok(Self { fd: dup_fd })
332        }
333    }
334
335    pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
336        // SAFETY:
337        // Safe because the fd is valid and owned by this stream.
338        unsafe { set_nonblocking(self.fd, nonblocking) }
339    }
340}
341
342impl IntoRawFd for VsockSocket {
343    fn into_raw_fd(self) -> RawFd {
344        let fd = self.fd;
345        mem::forget(self);
346        fd
347    }
348}
349
350impl AsRawFd for VsockSocket {
351    fn as_raw_fd(&self) -> RawFd {
352        self.fd
353    }
354}
355
356impl Drop for VsockSocket {
357    fn drop(&mut self) {
358        // SAFETY:
359        // Safe because this doesn't modify any memory and we are the only
360        // owner of the file descriptor.
361        unsafe { libc::close(self.fd) };
362    }
363}
364
365/// A virtual stream socket.
366#[derive(Debug)]
367pub struct VsockStream {
368    sock: VsockSocket,
369}
370
371impl VsockStream {
372    pub fn connect<A: ToSocketAddr>(addr: A) -> io::Result<VsockStream> {
373        let sock = VsockSocket::new()?;
374        sock.connect(addr)
375    }
376
377    /// Returns the port that this stream is bound to.
378    pub fn local_port(&self) -> io::Result<u32> {
379        self.sock.local_port()
380    }
381
382    pub fn try_clone(&self) -> io::Result<VsockStream> {
383        self.sock.try_clone().map(|f| VsockStream { sock: f })
384    }
385
386    pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
387        self.sock.set_nonblocking(nonblocking)
388    }
389}
390
391impl io::Read for VsockStream {
392    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
393        // SAFETY:
394        // Safe because this will only modify the contents of |buf| and we check the return value.
395        let ret = unsafe {
396            libc::read(
397                self.sock.as_raw_fd(),
398                buf as *mut [u8] as *mut c_void,
399                buf.len() as size_t,
400            )
401        };
402        if ret < 0 {
403            return Err(io::Error::last_os_error());
404        }
405
406        Ok(ret as usize)
407    }
408}
409
410impl io::Write for VsockStream {
411    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
412        // SAFETY:
413        // Safe because this doesn't modify any memory and we check the return value.
414        let ret = unsafe {
415            libc::write(
416                self.sock.as_raw_fd(),
417                buf as *const [u8] as *const c_void,
418                buf.len() as size_t,
419            )
420        };
421        if ret < 0 {
422            return Err(io::Error::last_os_error());
423        }
424
425        Ok(ret as usize)
426    }
427
428    fn flush(&mut self) -> io::Result<()> {
429        // No buffered data so nothing to do.
430        Ok(())
431    }
432}
433
434impl AsRawFd for VsockStream {
435    fn as_raw_fd(&self) -> RawFd {
436        self.sock.as_raw_fd()
437    }
438}
439
440impl IntoRawFd for VsockStream {
441    fn into_raw_fd(self) -> RawFd {
442        self.sock.into_raw_fd()
443    }
444}
445
446/// Represents a virtual socket server.
447#[derive(Debug)]
448pub struct VsockListener {
449    sock: VsockSocket,
450}
451
452impl VsockListener {
453    /// Creates a new `VsockListener` bound to the specified port on the current virtual socket
454    /// endpoint.
455    pub fn bind<A: ToSocketAddr>(addr: A) -> io::Result<VsockListener> {
456        let mut sock = VsockSocket::new()?;
457        sock.bind(addr)?;
458        sock.listen()
459    }
460
461    /// Returns the port that this listener is bound to.
462    pub fn local_port(&self) -> io::Result<u32> {
463        self.sock.local_port()
464    }
465
466    /// Accepts a new incoming connection on this listener.  Blocks the calling thread until a
467    /// new connection is established.  When established, returns the corresponding `VsockStream`
468    /// and the remote peer's address.
469    pub fn accept(&self) -> io::Result<(VsockStream, SocketAddr)> {
470        let mut svm: sockaddr_vm = Default::default();
471
472        let mut socklen: socklen_t = size_of::<sockaddr_vm>() as socklen_t;
473        // SAFETY:
474        // Safe because this will only modify |svm| and we check the return value.
475        let fd = unsafe {
476            libc::accept4(
477                self.sock.as_raw_fd(),
478                &mut svm as *mut sockaddr_vm as *mut sockaddr,
479                &mut socklen as *mut socklen_t,
480                libc::SOCK_CLOEXEC,
481            )
482        };
483        if fd < 0 {
484            return Err(io::Error::last_os_error());
485        }
486
487        if svm.svm_family != AF_VSOCK {
488            return Err(io::Error::new(
489                io::ErrorKind::InvalidData,
490                format!("unexpected address family: {}", svm.svm_family),
491            ));
492        }
493
494        Ok((
495            VsockStream {
496                sock: VsockSocket { fd },
497            },
498            SocketAddr {
499                cid: svm.svm_cid.into(),
500                port: svm.svm_port,
501            },
502        ))
503    }
504
505    pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
506        self.sock.set_nonblocking(nonblocking)
507    }
508}
509
510impl AsRawFd for VsockListener {
511    fn as_raw_fd(&self) -> RawFd {
512        self.sock.as_raw_fd()
513    }
514}