1use std::any::Any;
7use std::fs::File;
8use std::io::ErrorKind;
9use std::io::IoSlice;
10use std::io::IoSliceMut;
11use std::os::fd::OwnedFd;
12use std::os::unix::net::UnixListener;
13use std::os::unix::net::UnixStream;
14use std::path::Path;
15use std::path::PathBuf;
16
17use base::AsRawDescriptor;
18use base::RawDescriptor;
19use base::ReadNotifier;
20use base::SafeDescriptor;
21use base::ScmSocket;
22
23use crate::connection::Listener;
24use crate::frontend_server::FrontendServer;
25use crate::message::MAX_ATTACHED_FD_ENTRIES;
26use crate::Connection;
27use crate::Error;
28use crate::Frontend;
29use crate::Result;
30
/// Listener socket type wrapped by [`SocketListener`] on this platform.
pub type SystemListener = UnixListener;

/// Platform connection type for this platform: the Unix-socket based
/// `SocketPlatformConnection` defined in this module.
pub use SocketPlatformConnection as PlatformConnection;
35
/// Listens on a Unix domain socket path and accepts incoming connections.
pub struct SocketListener {
    /// The bound listening socket.
    fd: SystemListener,
    /// As populated by `SocketListener::new`, a guard whose `Drop` removes the socket file;
    /// `None` once `take_resources_for_parent` has handed it to the caller.
    /// Struct fields drop in declaration order, so `fd` is closed before this guard runs.
    drop_path: Option<Box<dyn Any>>,
}
41
42impl SocketListener {
43 pub fn new<P: AsRef<Path>>(path: P, unlink: bool) -> Result<Self> {
49 if unlink {
50 let _ = std::fs::remove_file(&path);
51 }
52 let fd = SystemListener::bind(&path).map_err(Error::SocketError)?;
53
54 struct DropPath {
55 path: PathBuf,
56 }
57
58 impl Drop for DropPath {
59 fn drop(&mut self) {
60 let _ = std::fs::remove_file(&self.path);
61 }
62 }
63
64 Ok(SocketListener {
65 fd,
66 drop_path: Some(Box::new(DropPath {
67 path: path.as_ref().to_owned(),
68 })),
69 })
70 }
71
72 pub fn take_resources_for_parent(&mut self) -> Option<Box<dyn Any>> {
75 self.drop_path.take()
76 }
77}
78
79impl Listener for SocketListener {
80 fn accept(&mut self) -> Result<Option<Connection>> {
87 loop {
88 match self.fd.accept() {
89 Ok((stream, _addr)) => {
90 return Ok(Some(Connection::try_from(stream)?));
91 }
92 Err(e) => {
93 match e.kind() {
94 ErrorKind::WouldBlock => return Ok(None),
96 ErrorKind::ConnectionAborted => return Ok(None),
98 ErrorKind::Interrupted => continue,
100 _ => return Err(Error::SocketError(e)),
101 }
102 }
103 }
104 }
105 }
106
107 fn set_nonblocking(&self, block: bool) -> Result<()> {
113 self.fd.set_nonblocking(block).map_err(Error::SocketError)
114 }
115}
116
117impl AsRawDescriptor for SocketListener {
118 fn as_raw_descriptor(&self) -> RawDescriptor {
119 self.fd.as_raw_descriptor()
120 }
121}
122
/// Unix-domain stream connection able to send and receive file descriptors alongside data.
pub struct SocketPlatformConnection {
    /// Connected stream wrapped in an `ScmSocket`; all message I/O in this module goes
    /// through its `send_vectored_with_fds` / `recv_vectored_with_fds`.
    sock: ScmSocket<UnixStream>,
}
127
/// Advances `bufs` past its first `count` bytes.
///
/// Slices that are fully consumed are dropped from the front of `bufs`, and the first
/// remaining slice is trimmed by whatever part of `count` is left over. Zero-length slices
/// at the head are dropped even when `count` is 0, so `advance_slices(bufs, 0)` strips
/// leading empty slices. If `count` covers the total length of `bufs`, `bufs` becomes empty.
fn advance_slices(bufs: &mut &mut [&[u8]], mut count: usize) {
    // Count how many whole slices `count` swallows (an empty slice is always swallowed).
    let mut skipped = 0;
    while let Some(head) = bufs.get(skipped) {
        if head.len() > count {
            break;
        }
        count -= head.len();
        skipped += 1;
    }

    // Re-point `bufs` at the unconsumed tail, then trim the partially consumed head.
    *bufs = &mut std::mem::take(bufs)[skipped..];
    if let Some(first) = bufs.first_mut() {
        *first = &first[count..];
    }
}
146
147impl SocketPlatformConnection {
148 fn send_iovec_all(
157 &self,
158 mut iovs: &mut [&[u8]],
159 mut fds: Option<&[RawDescriptor]>,
160 ) -> Result<()> {
161 advance_slices(&mut iovs, 0);
163
164 while !iovs.is_empty() {
165 let iovec: Vec<_> = iovs.iter_mut().map(|i| IoSlice::new(i)).collect();
166 match self.sock.send_vectored_with_fds(&iovec, fds.unwrap_or(&[])) {
167 Ok(n) => {
168 fds = None;
169 advance_slices(&mut iovs, n);
170 }
171 Err(e) => match e.kind() {
172 ErrorKind::WouldBlock | ErrorKind::Interrupted => {}
173 _ => return Err(Error::SocketError(e)),
174 },
175 }
176 }
177 Ok(())
178 }
179
180 pub fn send_message(
186 &self,
187 hdr: &[u8],
188 body: &[u8],
189 payload: &[u8],
190 fds: Option<&[RawDescriptor]>,
191 ) -> Result<()> {
192 let mut iobufs = [hdr, body, payload];
193 self.send_iovec_all(&mut iobufs, fds)
194 }
195
196 pub fn recv_into_bufs(
216 &self,
217 bufs: &mut [IoSliceMut],
218 allow_fd: bool,
219 ) -> Result<(usize, Option<Vec<File>>)> {
220 let max_fds = if allow_fd { MAX_ATTACHED_FD_ENTRIES } else { 0 };
221 let (bytes, fds) = match self.sock.recv_vectored_with_fds(bufs, max_fds) {
222 Ok((bytes, fds)) => (bytes, fds),
223 Err(e) => {
224 return Err(match e.kind() {
225 ErrorKind::WouldBlock | ErrorKind::Interrupted | ErrorKind::OutOfMemory => {
226 Error::SocketRetry(e)
227 }
228 _ => Error::SocketError(e),
229 });
230 }
231 };
232
233 if bytes == 0 {
235 return Err(Error::Disconnect);
236 }
237
238 let files = if fds.is_empty() {
239 None
240 } else {
241 Some(fds.into_iter().map(File::from).collect())
242 };
243
244 Ok((bytes, files))
245 }
246}
247
248impl AsRawDescriptor for SocketPlatformConnection {
249 fn as_raw_descriptor(&self) -> RawDescriptor {
250 self.sock.as_raw_descriptor()
251 }
252}
253
254impl ReadNotifier for SocketPlatformConnection {
255 fn get_read_notifier(&self) -> &dyn AsRawDescriptor {
256 &self.sock
257 }
258}
259
260impl TryFrom<SafeDescriptor> for Connection {
261 type Error = Error;
262
263 fn try_from(fd: SafeDescriptor) -> Result<Self> {
264 UnixStream::from(fd).try_into()
265 }
266}
267
268impl TryFrom<UnixStream> for Connection {
269 type Error = Error;
270
271 fn try_from(sock: UnixStream) -> Result<Self> {
272 Ok(Self(
273 SocketPlatformConnection {
274 sock: sock.try_into().map_err(Error::SocketError)?,
275 },
276 std::marker::PhantomData,
277 ))
278 }
279}
280
281impl Connection {
282 pub fn pair() -> Result<(Self, Self)> {
284 let (client, server) = UnixStream::pair().map_err(Error::SocketError)?;
285 let client_connection = Connection::try_from(client)?;
286 let server_connection = Connection::try_from(server)?;
287 Ok((client_connection, server_connection))
288 }
289}
290
impl<S: Frontend> AsRawDescriptor for FrontendServer<S> {
    /// Returns the raw descriptor of the server's `sub_sock` connection.
    fn as_raw_descriptor(&self) -> RawDescriptor {
        self.sub_sock.as_raw_descriptor()
    }
}
296
impl<S: Frontend> ReadNotifier for FrontendServer<S> {
    /// Delegates to the read notifier of the platform connection held inside `sub_sock`
    /// (its `.0` field).
    fn get_read_notifier(&self) -> &dyn AsRawDescriptor {
        self.sub_sock.0.get_read_notifier()
    }
}
302
303impl<S: Frontend> FrontendServer<S> {
304 pub fn with_stream(backend: S) -> Result<(Self, SafeDescriptor)> {
311 let (tx, rx) = UnixStream::pair().map_err(Error::SocketError)?;
312 let rx_connection = Connection::try_from(rx)?;
313 Ok((
314 Self::new(backend, rx_connection)?,
315 SafeDescriptor::from(OwnedFd::from(tx)),
316 ))
317 }
318}
319
#[cfg(test)]
pub(crate) mod tests {
    use tempfile::Builder;
    use tempfile::TempDir;

    use super::*;
    use crate::backend_client::BackendClient;
    use crate::connection::Listener;
    use crate::Connection;

    /// Creates a temporary directory (prefix `/tmp/vhost_test`) to hold test socket files.
    pub(crate) fn temp_dir() -> TempDir {
        Builder::new().prefix("/tmp/vhost_test").tempdir().unwrap()
    }

    #[test]
    fn create_listener() {
        let dir = temp_dir();
        let mut path = dir.path().to_owned();
        path.push("sock");
        let listener = SocketListener::new(&path, true).unwrap();

        assert!(listener.as_raw_descriptor() > 0);
    }

    #[test]
    fn accept_connection() {
        let dir = temp_dir();
        let mut path = dir.path().to_owned();
        path.push("sock");
        let mut listener = SocketListener::new(&path, true).unwrap();
        listener.set_nonblocking(true).unwrap();

        // No client has connected, so a non-blocking accept reports that nothing is pending.
        let conn = listener.accept().unwrap();
        assert!(conn.is_none());
    }

    #[test]
    fn test_create_failure() {
        let dir = temp_dir();
        let mut path = dir.path().to_owned();
        path.push("sock");
        // `let _ =` drops the listener immediately, and its drop guard removes the socket
        // file...
        let _ = SocketListener::new(&path, true).unwrap();
        // ...so binding again without unlinking succeeds. That listener is dropped at the end
        // of the statement too, removing the file again, so there is nothing to connect to.
        assert!(SocketListener::new(&path, false).is_ok());
        assert!(UnixStream::connect(&path).is_err());

        let mut listener = SocketListener::new(&path, true).unwrap();
        // While a listener is alive its socket file exists, so binding without unlinking fails.
        assert!(SocketListener::new(&path, false).is_err());
        listener.set_nonblocking(true).unwrap();

        let sock = UnixStream::connect(&path).unwrap();
        let backend_connection = Connection::try_from(sock).unwrap();
        let _backend_client = BackendClient::new(backend_connection);
        let _server_connection = listener.accept().unwrap().unwrap();
    }

    #[test]
    fn test_advance_slices() {
        let buf1 = [1; 8];
        let buf2 = [2; 16];
        let buf3 = [3; 8];
        let mut bufs = &mut [&buf1[..], &buf2[..], &buf3[..]][..];
        advance_slices(&mut bufs, 10);
        assert_eq!(bufs[0], [2; 14].as_ref());
        assert_eq!(bufs[1], [3; 8].as_ref());
    }

    #[test]
    fn test_advance_slices_edge_cases() {
        let empty: [u8; 0] = [];
        let buf1 = [1; 8];
        let buf2 = [2; 8];

        // Advancing by zero drops leading empty slices only.
        let mut bufs = &mut [&empty[..], &buf1[..], &buf2[..]][..];
        advance_slices(&mut bufs, 0);
        assert_eq!(bufs.len(), 2);
        assert_eq!(bufs[0], [1; 8].as_ref());

        // Landing exactly on a boundary removes the consumed slice and any empty ones after it.
        let mut bufs = &mut [&buf1[..], &empty[..], &buf2[..]][..];
        advance_slices(&mut bufs, 8);
        assert_eq!(bufs.len(), 1);
        assert_eq!(bufs[0], [2; 8].as_ref());

        // Consuming more than everything leaves nothing.
        let mut bufs = &mut [&buf1[..], &buf2[..]][..];
        advance_slices(&mut bufs, 100);
        assert!(bufs.is_empty());
    }
}