vmm_vhost/sys/
unix.rs

1// Copyright 2022 The Chromium OS Authors. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
//! Unix-specific code that keeps the rest of the crate platform independent.
5
6use std::any::Any;
7use std::fs::File;
8use std::io::ErrorKind;
9use std::io::IoSlice;
10use std::io::IoSliceMut;
11use std::os::fd::OwnedFd;
12use std::os::unix::net::UnixListener;
13use std::os::unix::net::UnixStream;
14use std::path::Path;
15use std::path::PathBuf;
16
17use base::AsRawDescriptor;
18use base::RawDescriptor;
19use base::ReadNotifier;
20use base::SafeDescriptor;
21use base::ScmSocket;
22
23use crate::connection::Listener;
24use crate::frontend_server::FrontendServer;
25use crate::message::MAX_ATTACHED_FD_ENTRIES;
26use crate::Connection;
27use crate::Error;
28use crate::Frontend;
29use crate::Result;
30
31/// Alias to enable platform independent code.
32pub type SystemListener = UnixListener;
33
34pub use SocketPlatformConnection as PlatformConnection;
35
36/// Unix domain socket listener for accepting incoming connections.
37pub struct SocketListener {
38    fd: SystemListener,
39    drop_path: Option<Box<dyn Any>>,
40}
41
42impl SocketListener {
43    /// Create a unix domain socket listener.
44    ///
45    /// # Return:
46    /// * - the new SocketListener object on success.
47    /// * - SocketError: failed to create listener socket.
48    pub fn new<P: AsRef<Path>>(path: P, unlink: bool) -> Result<Self> {
49        if unlink {
50            let _ = std::fs::remove_file(&path);
51        }
52        let fd = SystemListener::bind(&path).map_err(Error::SocketError)?;
53
54        struct DropPath {
55            path: PathBuf,
56        }
57
58        impl Drop for DropPath {
59            fn drop(&mut self) {
60                let _ = std::fs::remove_file(&self.path);
61            }
62        }
63
64        Ok(SocketListener {
65            fd,
66            drop_path: Some(Box::new(DropPath {
67                path: path.as_ref().to_owned(),
68            })),
69        })
70    }
71
72    /// Take and return the resources that the parent process needs to keep alive as long as the
73    /// child process lives, in case of incoming fork.
74    pub fn take_resources_for_parent(&mut self) -> Option<Box<dyn Any>> {
75        self.drop_path.take()
76    }
77}
78
79impl Listener for SocketListener {
80    /// Accept an incoming connection.
81    ///
82    /// # Return:
83    /// * - Some(SystemListener): new SystemListener object if new incoming connection is available.
84    /// * - None: no incoming connection available.
85    /// * - SocketError: errors from accept().
86    fn accept(&mut self) -> Result<Option<Connection>> {
87        loop {
88            match self.fd.accept() {
89                Ok((stream, _addr)) => {
90                    return Ok(Some(Connection::try_from(stream)?));
91                }
92                Err(e) => {
93                    match e.kind() {
94                        // No incoming connection available.
95                        ErrorKind::WouldBlock => return Ok(None),
96                        // New connection closed by peer.
97                        ErrorKind::ConnectionAborted => return Ok(None),
98                        // Interrupted by signals, retry
99                        ErrorKind::Interrupted => continue,
100                        _ => return Err(Error::SocketError(e)),
101                    }
102                }
103            }
104        }
105    }
106
107    /// Change blocking status on the listener.
108    ///
109    /// # Return:
110    /// * - () on success.
111    /// * - SocketError: failure from set_nonblocking().
112    fn set_nonblocking(&self, block: bool) -> Result<()> {
113        self.fd.set_nonblocking(block).map_err(Error::SocketError)
114    }
115}
116
117impl AsRawDescriptor for SocketListener {
118    fn as_raw_descriptor(&self) -> RawDescriptor {
119        self.fd.as_raw_descriptor()
120    }
121}
122
123/// Unix domain socket based vhost-user connection.
124pub struct SocketPlatformConnection {
125    sock: ScmSocket<UnixStream>,
126}
127
/// Advances the internal cursor of the slices by `count` bytes.
///
/// Leading slices that are consumed entirely (including empty ones) are dropped from `bufs`, and
/// the first remaining slice is trimmed by whatever is left of `count`. If `count` exceeds the
/// total length, `bufs` ends up empty and the excess is ignored.
///
/// This mirrors `IoSlice::advance_slices`, but operates on `&[u8]`.
fn advance_slices(bufs: &mut &mut [&[u8]], mut count: usize) {
    // Count how many leading slices are swallowed whole by `count`.
    let mut consumed = 0;
    while let Some(front) = bufs.get(consumed) {
        match count.checked_sub(front.len()) {
            Some(rest) => {
                count = rest;
                consumed += 1;
            }
            // `count` ends inside this slice.
            None => break,
        }
    }

    // Move the slice out so it can be re-sliced with its original lifetime.
    let remaining = std::mem::take(bufs);
    *bufs = &mut remaining[consumed..];

    if let Some(first) = bufs.first_mut() {
        let whole: &[u8] = *first;
        *first = &whole[count..];
    }
}
146
147impl SocketPlatformConnection {
148    /// Sends all bytes from scatter-gather vectors with optional attached file descriptors. Will
149    /// loop until all data has been transfered.
150    ///
151    /// # TODO
152    /// This function takes a slice of `&[u8]` instead of `IoSlice` because the internal
153    /// cursor needs to be moved by `advance_slices()`.
154    /// Once `IoSlice::advance_slices()` becomes stable, this should be updated.
155    /// <https://github.com/rust-lang/rust/issues/62726>.
156    fn send_iovec_all(
157        &self,
158        mut iovs: &mut [&[u8]],
159        mut fds: Option<&[RawDescriptor]>,
160    ) -> Result<()> {
161        // Guarantee that `iovs` becomes empty if it doesn't contain any data.
162        advance_slices(&mut iovs, 0);
163
164        while !iovs.is_empty() {
165            let iovec: Vec<_> = iovs.iter_mut().map(|i| IoSlice::new(i)).collect();
166            match self.sock.send_vectored_with_fds(&iovec, fds.unwrap_or(&[])) {
167                Ok(n) => {
168                    fds = None;
169                    advance_slices(&mut iovs, n);
170                }
171                Err(e) => match e.kind() {
172                    ErrorKind::WouldBlock | ErrorKind::Interrupted => {}
173                    _ => return Err(Error::SocketError(e)),
174                },
175            }
176        }
177        Ok(())
178    }
179
180    /// Sends a single message over the socket with optional attached file descriptors.
181    ///
182    /// - `hdr`: vhost message header
183    /// - `body`: vhost message body (may be empty to send a header-only message)
184    /// - `payload`: additional bytes to append to `body` (may be empty)
185    pub fn send_message(
186        &self,
187        hdr: &[u8],
188        body: &[u8],
189        payload: &[u8],
190        fds: Option<&[RawDescriptor]>,
191    ) -> Result<()> {
192        let mut iobufs = [hdr, body, payload];
193        self.send_iovec_all(&mut iobufs, fds)
194    }
195
196    /// Reads bytes from the socket into the given scatter/gather vectors with optional attached
197    /// file.
198    ///
199    /// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little
200    /// tricky to pass file descriptors through such a communication channel. Let's assume that a
201    /// sender sending a message with some file descriptors attached. To successfully receive those
202    /// attached file descriptors, the receiver must obey following rules:
203    ///   1) file descriptors are attached to a message.
204    ///   2) message(packet) boundaries must be respected on the receive side.
205    ///
206    /// In other words, recvmsg() operations must not cross the packet boundary, otherwise the
207    /// attached file descriptors will get lost.
208    /// Note that this function wraps received file descriptors as `File`.
209    ///
210    /// # Return:
211    /// * - (number of bytes received, [received files]) on success
212    /// * - Disconnect: the connection is closed.
213    /// * - SocketRetry: temporary error caused by signals or short of resources.
214    /// * - SocketError: other socket related errors.
215    pub fn recv_into_bufs(
216        &self,
217        bufs: &mut [IoSliceMut],
218        allow_fd: bool,
219    ) -> Result<(usize, Option<Vec<File>>)> {
220        let max_fds = if allow_fd { MAX_ATTACHED_FD_ENTRIES } else { 0 };
221        let (bytes, fds) = match self.sock.recv_vectored_with_fds(bufs, max_fds) {
222            Ok((bytes, fds)) => (bytes, fds),
223            Err(e) => {
224                return Err(match e.kind() {
225                    ErrorKind::WouldBlock | ErrorKind::Interrupted | ErrorKind::OutOfMemory => {
226                        Error::SocketRetry(e)
227                    }
228                    _ => Error::SocketError(e),
229                });
230            }
231        };
232
233        // 0-bytes indicates that the connection is closed.
234        if bytes == 0 {
235            return Err(Error::Disconnect);
236        }
237
238        let files = if fds.is_empty() {
239            None
240        } else {
241            Some(fds.into_iter().map(File::from).collect())
242        };
243
244        Ok((bytes, files))
245    }
246}
247
248impl AsRawDescriptor for SocketPlatformConnection {
249    fn as_raw_descriptor(&self) -> RawDescriptor {
250        self.sock.as_raw_descriptor()
251    }
252}
253
254impl ReadNotifier for SocketPlatformConnection {
255    fn get_read_notifier(&self) -> &dyn AsRawDescriptor {
256        &self.sock
257    }
258}
259
260impl TryFrom<SafeDescriptor> for Connection {
261    type Error = Error;
262
263    fn try_from(fd: SafeDescriptor) -> Result<Self> {
264        UnixStream::from(fd).try_into()
265    }
266}
267
268impl TryFrom<UnixStream> for Connection {
269    type Error = Error;
270
271    fn try_from(sock: UnixStream) -> Result<Self> {
272        Ok(Self(
273            SocketPlatformConnection {
274                sock: sock.try_into().map_err(Error::SocketError)?,
275            },
276            std::marker::PhantomData,
277        ))
278    }
279}
280
281impl Connection {
282    /// Create a pair of unnamed vhost-user connections connected to each other.
283    pub fn pair() -> Result<(Self, Self)> {
284        let (client, server) = UnixStream::pair().map_err(Error::SocketError)?;
285        let client_connection = Connection::try_from(client)?;
286        let server_connection = Connection::try_from(server)?;
287        Ok((client_connection, server_connection))
288    }
289}
290
291impl<S: Frontend> AsRawDescriptor for FrontendServer<S> {
292    fn as_raw_descriptor(&self) -> RawDescriptor {
293        self.sub_sock.as_raw_descriptor()
294    }
295}
296
297impl<S: Frontend> ReadNotifier for FrontendServer<S> {
298    fn get_read_notifier(&self) -> &dyn AsRawDescriptor {
299        self.sub_sock.0.get_read_notifier()
300    }
301}
302
303impl<S: Frontend> FrontendServer<S> {
304    /// Create a `FrontendServer` that uses a Unix stream internally.
305    ///
306    /// The returned `SafeDescriptor` is the client side of the stream and should be sent to the
307    /// backend using [BackendClient::set_slave_request_fd()].
308    ///
309    /// [BackendClient::set_slave_request_fd()]: struct.BackendClient.html#method.set_slave_request_fd
310    pub fn with_stream(backend: S) -> Result<(Self, SafeDescriptor)> {
311        let (tx, rx) = UnixStream::pair().map_err(Error::SocketError)?;
312        let rx_connection = Connection::try_from(rx)?;
313        Ok((
314            Self::new(backend, rx_connection)?,
315            SafeDescriptor::from(OwnedFd::from(tx)),
316        ))
317    }
318}
319
#[cfg(test)]
pub(crate) mod tests {
    use tempfile::Builder;
    use tempfile::TempDir;

    use super::*;
    use crate::backend_client::BackendClient;
    use crate::connection::Listener;
    use crate::Connection;

    /// Creates a fresh temporary directory for the socket files used by tests.
    pub(crate) fn temp_dir() -> TempDir {
        Builder::new().prefix("/tmp/vhost_test").tempdir().unwrap()
    }

    /// Returns the path of the test socket inside `dir`.
    fn socket_path(dir: &TempDir) -> PathBuf {
        dir.path().join("sock")
    }

    /// Binding a listener yields a usable descriptor.
    #[test]
    fn create_listener() {
        let dir = temp_dir();
        let sock_path = socket_path(&dir);
        let listener = SocketListener::new(&sock_path, true).unwrap();

        assert!(listener.as_raw_descriptor() > 0);
    }

    /// A non-blocking listener with no pending connection reports `None`.
    #[test]
    fn accept_connection() {
        let dir = temp_dir();
        let sock_path = socket_path(&dir);
        let mut listener = SocketListener::new(&sock_path, true).unwrap();
        listener.set_nonblocking(true).unwrap();

        // accept on a fd without incoming connection
        let accepted = listener.accept().unwrap();
        assert!(accepted.is_none());
    }

    #[test]
    fn test_create_failure() {
        let dir = temp_dir();
        let sock_path = socket_path(&dir);

        // `let _ =` drops each listener at the end of its statement, which also drops its path
        // guard and removes the socket file. The outcome of the second attempt is computed but
        // not asserted. With no listener left, connecting must fail.
        let _ = SocketListener::new(&sock_path, true).unwrap();
        let _ = SocketListener::new(&sock_path, false).is_err();
        assert!(UnixStream::connect(&sock_path).is_err());

        // While a listener is alive, binding the same path again without unlinking fails.
        let mut listener = SocketListener::new(&sock_path, true).unwrap();
        assert!(SocketListener::new(&sock_path, false).is_err());
        listener.set_nonblocking(true).unwrap();

        // A client can now connect and the listener can accept it.
        let client_stream = UnixStream::connect(&sock_path).unwrap();
        let client_connection = Connection::try_from(client_stream).unwrap();
        let _backend_client = BackendClient::new(client_connection);
        let _server_connection = listener.accept().unwrap().unwrap();
    }

    #[test]
    fn test_advance_slices() {
        // Test case from https://doc.rust-lang.org/std/io/struct.IoSlice.html#method.advance_slices
        let ones = [1; 8];
        let twos = [2; 16];
        let threes = [3; 8];
        let mut bufs = &mut [&ones[..], &twos[..], &threes[..]][..];

        // 10 bytes consume all of the first slice and 2 bytes of the second.
        advance_slices(&mut bufs, 10);
        assert_eq!(bufs[0], [2; 14].as_ref());
        assert_eq!(bufs[1], [3; 8].as_ref());
    }
}
387}