vmm_vhost/sys/
unix.rs

1// Copyright 2022 The Chromium OS Authors. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
//! Unix-specific code that keeps the rest of the crate platform-independent.
5
6use std::any::Any;
7use std::fs::File;
8use std::io::ErrorKind;
9use std::io::IoSlice;
10use std::io::IoSliceMut;
11use std::os::fd::OwnedFd;
12use std::os::unix::net::UnixListener;
13use std::os::unix::net::UnixStream;
14use std::path::Path;
15use std::path::PathBuf;
16
17use base::AsRawDescriptor;
18use base::RawDescriptor;
19use base::ReadNotifier;
20use base::SafeDescriptor;
21use base::ScmSocket;
22
23use crate::connection::Listener;
24use crate::frontend_server::FrontendServer;
25use crate::message::FrontendReq;
26use crate::message::Req;
27use crate::message::MAX_ATTACHED_FD_ENTRIES;
28use crate::Connection;
29use crate::Error;
30use crate::Frontend;
31use crate::Result;
32
33/// Alias to enable platform independent code.
34pub type SystemListener = UnixListener;
35
36pub use SocketPlatformConnection as PlatformConnection;
37
38/// Unix domain socket listener for accepting incoming connections.
39pub struct SocketListener {
40    fd: SystemListener,
41    drop_path: Option<Box<dyn Any>>,
42}
43
44impl SocketListener {
45    /// Create a unix domain socket listener.
46    ///
47    /// # Return:
48    /// * - the new SocketListener object on success.
49    /// * - SocketError: failed to create listener socket.
50    pub fn new<P: AsRef<Path>>(path: P, unlink: bool) -> Result<Self> {
51        if unlink {
52            let _ = std::fs::remove_file(&path);
53        }
54        let fd = SystemListener::bind(&path).map_err(Error::SocketError)?;
55
56        struct DropPath {
57            path: PathBuf,
58        }
59
60        impl Drop for DropPath {
61            fn drop(&mut self) {
62                let _ = std::fs::remove_file(&self.path);
63            }
64        }
65
66        Ok(SocketListener {
67            fd,
68            drop_path: Some(Box::new(DropPath {
69                path: path.as_ref().to_owned(),
70            })),
71        })
72    }
73
74    /// Take and return the resources that the parent process needs to keep alive as long as the
75    /// child process lives, in case of incoming fork.
76    pub fn take_resources_for_parent(&mut self) -> Option<Box<dyn Any>> {
77        self.drop_path.take()
78    }
79}
80
81impl Listener for SocketListener {
82    /// Accept an incoming connection.
83    ///
84    /// # Return:
85    /// * - Some(SystemListener): new SystemListener object if new incoming connection is available.
86    /// * - None: no incoming connection available.
87    /// * - SocketError: errors from accept().
88    fn accept(&mut self) -> Result<Option<Connection<FrontendReq>>> {
89        loop {
90            match self.fd.accept() {
91                Ok((stream, _addr)) => {
92                    return Ok(Some(Connection::try_from(stream)?));
93                }
94                Err(e) => {
95                    match e.kind() {
96                        // No incoming connection available.
97                        ErrorKind::WouldBlock => return Ok(None),
98                        // New connection closed by peer.
99                        ErrorKind::ConnectionAborted => return Ok(None),
100                        // Interrupted by signals, retry
101                        ErrorKind::Interrupted => continue,
102                        _ => return Err(Error::SocketError(e)),
103                    }
104                }
105            }
106        }
107    }
108
109    /// Change blocking status on the listener.
110    ///
111    /// # Return:
112    /// * - () on success.
113    /// * - SocketError: failure from set_nonblocking().
114    fn set_nonblocking(&self, block: bool) -> Result<()> {
115        self.fd.set_nonblocking(block).map_err(Error::SocketError)
116    }
117}
118
119impl AsRawDescriptor for SocketListener {
120    fn as_raw_descriptor(&self) -> RawDescriptor {
121        self.fd.as_raw_descriptor()
122    }
123}
124
/// Unix domain socket based vhost-user connection.
///
/// This is the Unix `PlatformConnection`; all message I/O goes through `sock`.
pub struct SocketPlatformConnection {
    /// Stream socket wrapped in `ScmSocket` so that file descriptors can be sent and received
    /// alongside the message bytes.
    sock: ScmSocket<UnixStream>,
}
129
/// Moves the start of the scatter/gather list `bufs` forward by `count` bytes.
///
/// Slices that are fully consumed are dropped from the front, and the first remaining slice is
/// trimmed by the leftover byte count. Leading empty slices are always dropped, even when `count`
/// is 0. If `count` exceeds the total length, `bufs` simply ends up empty.
///
/// This mirrors the nightly `IoSlice::advance_slices` API but operates on `&[u8]`.
fn advance_slices(bufs: &mut &mut [&[u8]], count: usize) {
    let mut remaining = count;
    let mut skipped = 0;

    // Drop every front slice that `remaining` covers completely (including empty ones).
    while let Some(front) = bufs.get(skipped) {
        if remaining < front.len() {
            break;
        }
        remaining -= front.len();
        skipped += 1;
    }

    let rest = &mut std::mem::take(bufs)[skipped..];
    if let Some(first) = rest.first_mut() {
        // `remaining < first.len()` here, so this slice stays non-empty.
        let head = *first;
        *first = &head[remaining..];
    }
    *bufs = rest;
}
148
149impl SocketPlatformConnection {
150    /// Sends all bytes from scatter-gather vectors with optional attached file descriptors. Will
151    /// loop until all data has been transfered.
152    ///
153    /// # TODO
154    /// This function takes a slice of `&[u8]` instead of `IoSlice` because the internal
155    /// cursor needs to be moved by `advance_slices()`.
156    /// Once `IoSlice::advance_slices()` becomes stable, this should be updated.
157    /// <https://github.com/rust-lang/rust/issues/62726>.
158    fn send_iovec_all(
159        &self,
160        mut iovs: &mut [&[u8]],
161        mut fds: Option<&[RawDescriptor]>,
162    ) -> Result<()> {
163        // Guarantee that `iovs` becomes empty if it doesn't contain any data.
164        advance_slices(&mut iovs, 0);
165
166        while !iovs.is_empty() {
167            let iovec: Vec<_> = iovs.iter_mut().map(|i| IoSlice::new(i)).collect();
168            match self.sock.send_vectored_with_fds(&iovec, fds.unwrap_or(&[])) {
169                Ok(n) => {
170                    fds = None;
171                    advance_slices(&mut iovs, n);
172                }
173                Err(e) => match e.kind() {
174                    ErrorKind::WouldBlock | ErrorKind::Interrupted => {}
175                    _ => return Err(Error::SocketError(e)),
176                },
177            }
178        }
179        Ok(())
180    }
181
182    /// Sends a single message over the socket with optional attached file descriptors.
183    ///
184    /// - `hdr`: vhost message header
185    /// - `body`: vhost message body (may be empty to send a header-only message)
186    /// - `payload`: additional bytes to append to `body` (may be empty)
187    pub fn send_message(
188        &self,
189        hdr: &[u8],
190        body: &[u8],
191        payload: &[u8],
192        fds: Option<&[RawDescriptor]>,
193    ) -> Result<()> {
194        let mut iobufs = [hdr, body, payload];
195        self.send_iovec_all(&mut iobufs, fds)
196    }
197
198    /// Reads bytes from the socket into the given scatter/gather vectors with optional attached
199    /// file.
200    ///
201    /// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little
202    /// tricky to pass file descriptors through such a communication channel. Let's assume that a
203    /// sender sending a message with some file descriptors attached. To successfully receive those
204    /// attached file descriptors, the receiver must obey following rules:
205    ///   1) file descriptors are attached to a message.
206    ///   2) message(packet) boundaries must be respected on the receive side.
207    ///
208    /// In other words, recvmsg() operations must not cross the packet boundary, otherwise the
209    /// attached file descriptors will get lost.
210    /// Note that this function wraps received file descriptors as `File`.
211    ///
212    /// # Return:
213    /// * - (number of bytes received, [received files]) on success
214    /// * - Disconnect: the connection is closed.
215    /// * - SocketRetry: temporary error caused by signals or short of resources.
216    /// * - SocketError: other socket related errors.
217    pub fn recv_into_bufs(
218        &self,
219        bufs: &mut [IoSliceMut],
220        allow_fd: bool,
221    ) -> Result<(usize, Option<Vec<File>>)> {
222        let max_fds = if allow_fd { MAX_ATTACHED_FD_ENTRIES } else { 0 };
223        let (bytes, fds) = match self.sock.recv_vectored_with_fds(bufs, max_fds) {
224            Ok((bytes, fds)) => (bytes, fds),
225            Err(e) => {
226                return Err(match e.kind() {
227                    ErrorKind::WouldBlock | ErrorKind::Interrupted | ErrorKind::OutOfMemory => {
228                        Error::SocketRetry(e)
229                    }
230                    _ => Error::SocketError(e),
231                });
232            }
233        };
234
235        // 0-bytes indicates that the connection is closed.
236        if bytes == 0 {
237            return Err(Error::Disconnect);
238        }
239
240        let files = if fds.is_empty() {
241            None
242        } else {
243            Some(fds.into_iter().map(File::from).collect())
244        };
245
246        Ok((bytes, files))
247    }
248}
249
250impl AsRawDescriptor for SocketPlatformConnection {
251    fn as_raw_descriptor(&self) -> RawDescriptor {
252        self.sock.as_raw_descriptor()
253    }
254}
255
256impl ReadNotifier for SocketPlatformConnection {
257    fn get_read_notifier(&self) -> &dyn AsRawDescriptor {
258        &self.sock
259    }
260}
261
262impl<R: Req> TryFrom<SafeDescriptor> for Connection<R> {
263    type Error = Error;
264
265    fn try_from(fd: SafeDescriptor) -> Result<Self> {
266        UnixStream::from(fd).try_into()
267    }
268}
269
270impl<R: Req> TryFrom<UnixStream> for Connection<R> {
271    type Error = Error;
272
273    fn try_from(sock: UnixStream) -> Result<Self> {
274        Ok(Self(
275            SocketPlatformConnection {
276                sock: sock.try_into().map_err(Error::SocketError)?,
277            },
278            std::marker::PhantomData,
279            std::marker::PhantomData,
280        ))
281    }
282}
283
284impl<R: Req> Connection<R> {
285    /// Create a pair of unnamed vhost-user connections connected to each other.
286    pub fn pair() -> Result<(Self, Self)> {
287        let (client, server) = UnixStream::pair().map_err(Error::SocketError)?;
288        let client_connection = Connection::try_from(client)?;
289        let server_connection = Connection::try_from(server)?;
290        Ok((client_connection, server_connection))
291    }
292}
293
294impl<S: Frontend> AsRawDescriptor for FrontendServer<S> {
295    fn as_raw_descriptor(&self) -> RawDescriptor {
296        self.sub_sock.as_raw_descriptor()
297    }
298}
299
300impl<S: Frontend> ReadNotifier for FrontendServer<S> {
301    fn get_read_notifier(&self) -> &dyn AsRawDescriptor {
302        self.sub_sock.0.get_read_notifier()
303    }
304}
305
306impl<S: Frontend> FrontendServer<S> {
307    /// Create a `FrontendServer` that uses a Unix stream internally.
308    ///
309    /// The returned `SafeDescriptor` is the client side of the stream and should be sent to the
310    /// backend using [BackendClient::set_slave_request_fd()].
311    ///
312    /// [BackendClient::set_slave_request_fd()]: struct.BackendClient.html#method.set_slave_request_fd
313    pub fn with_stream(backend: S) -> Result<(Self, SafeDescriptor)> {
314        let (tx, rx) = UnixStream::pair().map_err(Error::SocketError)?;
315        let rx_connection = Connection::try_from(rx)?;
316        Ok((
317            Self::new(backend, rx_connection)?,
318            SafeDescriptor::from(OwnedFd::from(tx)),
319        ))
320    }
321}
322
#[cfg(test)]
pub(crate) mod tests {
    use tempfile::Builder;
    use tempfile::TempDir;

    use super::*;
    use crate::backend_client::BackendClient;
    use crate::connection::Listener;
    use crate::Connection;

    /// Creates a fresh temporary directory (name prefixed with `/tmp/vhost_test`) for socket
    /// files.
    pub(crate) fn temp_dir() -> TempDir {
        Builder::new().prefix("/tmp/vhost_test").tempdir().unwrap()
    }

    #[test]
    fn create_listener() {
        let dir = temp_dir();
        let path = dir.path().join("sock");
        let listener = SocketListener::new(&path, true).unwrap();

        // Any valid descriptor (including 0) is acceptable.
        assert!(listener.as_raw_descriptor() >= 0);
    }

    #[test]
    fn accept_connection() {
        let dir = temp_dir();
        let path = dir.path().join("sock");
        let mut listener = SocketListener::new(&path, true).unwrap();
        listener.set_nonblocking(true).unwrap();

        // accept on a fd without incoming connection
        let conn = listener.accept().unwrap();
        assert!(conn.is_none());
    }

    #[test]
    fn test_create_failure() {
        let dir = temp_dir();
        let path = dir.path().join("sock");

        {
            // Keep the first listener alive (`let _ = …` would drop it and delete the socket
            // file at once): binding the same path again without unlinking must then fail.
            let _listener = SocketListener::new(&path, true).unwrap();
            assert!(SocketListener::new(&path, false).is_err());
        }
        // Dropping the listener removed the socket file, so nobody is listening any more.
        assert!(UnixStream::connect(&path).is_err());

        let mut listener = SocketListener::new(&path, true).unwrap();
        assert!(SocketListener::new(&path, false).is_err());
        listener.set_nonblocking(true).unwrap();

        let sock = UnixStream::connect(&path).unwrap();
        let backend_connection = Connection::try_from(sock).unwrap();
        let _backend_client = BackendClient::new(backend_connection);
        let _server_connection = listener.accept().unwrap().unwrap();
    }

    #[test]
    fn test_advance_slices() {
        // Test case from https://doc.rust-lang.org/std/io/struct.IoSlice.html#method.advance_slices
        let buf1 = [1; 8];
        let buf2 = [2; 16];
        let buf3 = [3; 8];
        let mut bufs = &mut [&buf1[..], &buf2[..], &buf3[..]][..];
        advance_slices(&mut bufs, 10);
        assert_eq!(bufs.len(), 2);
        assert_eq!(bufs[0], [2; 14].as_ref());
        assert_eq!(bufs[1], [3; 8].as_ref());
    }
}