1use std::any::Any;
7use std::fs::File;
8use std::io::ErrorKind;
9use std::io::IoSlice;
10use std::io::IoSliceMut;
11use std::os::fd::OwnedFd;
12use std::os::unix::net::UnixListener;
13use std::os::unix::net::UnixStream;
14use std::path::Path;
15use std::path::PathBuf;
16
17use base::AsRawDescriptor;
18use base::RawDescriptor;
19use base::ReadNotifier;
20use base::SafeDescriptor;
21use base::ScmSocket;
22
23use crate::connection::Listener;
24use crate::frontend_server::FrontendServer;
25use crate::message::FrontendReq;
26use crate::message::Req;
27use crate::message::MAX_ATTACHED_FD_ENTRIES;
28use crate::Connection;
29use crate::Error;
30use crate::Frontend;
31use crate::Result;
32
/// Listener type used by this platform implementation: a Unix-domain socket listener.
pub type SystemListener = UnixListener;

/// On this platform, connections are backed by [`SocketPlatformConnection`].
pub use SocketPlatformConnection as PlatformConnection;
37
/// Unix-domain socket listener whose accepted connections carry `FrontendReq` messages.
pub struct SocketListener {
    /// The bound, listening socket.
    fd: SystemListener,
    /// Guard that removes the socket file when dropped; `None` once it has been
    /// handed out by `take_resources_for_parent`.
    // Declared after `fd` on purpose: struct fields drop in declaration order, so
    // the socket is closed before its file is removed.
    drop_path: Option<Box<dyn Any>>,
}
43
44impl SocketListener {
45 pub fn new<P: AsRef<Path>>(path: P, unlink: bool) -> Result<Self> {
51 if unlink {
52 let _ = std::fs::remove_file(&path);
53 }
54 let fd = SystemListener::bind(&path).map_err(Error::SocketError)?;
55
56 struct DropPath {
57 path: PathBuf,
58 }
59
60 impl Drop for DropPath {
61 fn drop(&mut self) {
62 let _ = std::fs::remove_file(&self.path);
63 }
64 }
65
66 Ok(SocketListener {
67 fd,
68 drop_path: Some(Box::new(DropPath {
69 path: path.as_ref().to_owned(),
70 })),
71 })
72 }
73
74 pub fn take_resources_for_parent(&mut self) -> Option<Box<dyn Any>> {
77 self.drop_path.take()
78 }
79}
80
81impl Listener for SocketListener {
82 fn accept(&mut self) -> Result<Option<Connection<FrontendReq>>> {
89 loop {
90 match self.fd.accept() {
91 Ok((stream, _addr)) => {
92 return Ok(Some(Connection::try_from(stream)?));
93 }
94 Err(e) => {
95 match e.kind() {
96 ErrorKind::WouldBlock => return Ok(None),
98 ErrorKind::ConnectionAborted => return Ok(None),
100 ErrorKind::Interrupted => continue,
102 _ => return Err(Error::SocketError(e)),
103 }
104 }
105 }
106 }
107 }
108
109 fn set_nonblocking(&self, block: bool) -> Result<()> {
115 self.fd.set_nonblocking(block).map_err(Error::SocketError)
116 }
117}
118
119impl AsRawDescriptor for SocketListener {
120 fn as_raw_descriptor(&self) -> RawDescriptor {
121 self.fd.as_raw_descriptor()
122 }
123}
124
/// Unix-domain stream connection that can send and receive file descriptors along
/// with message data.
pub struct SocketPlatformConnection {
    /// The connected stream, wrapped in `ScmSocket` so the `*_with_fds` send and
    /// receive calls are available.
    sock: ScmSocket<UnixStream>,
}
129
/// Advances `bufs` past its first `count` bytes, as if they had been consumed (for
/// example by a partial vectored write).
///
/// Buffers that are consumed entirely are dropped from the front of `bufs`, and the
/// first remaining buffer is trimmed by whatever is left of `count`. Zero-length
/// buffers at the cut point are skipped as well, so `count == 0` strips leading empty
/// buffers. If `count` exceeds the total length, `bufs` simply ends up empty.
fn advance_slices(bufs: &mut &mut [&[u8]], mut count: usize) {
    // Number of leading buffers fully covered by `count`.
    let mut consumed = 0;
    while let Some(buf) = bufs.get(consumed) {
        if buf.len() > count {
            break;
        }
        count -= buf.len();
        consumed += 1;
    }

    let remaining = &mut std::mem::take(bufs)[consumed..];
    if let Some(first) = remaining.first_mut() {
        // Here `count` is strictly smaller than this buffer's length.
        *first = &first[count..];
    }
    *bufs = remaining;
}
148
149impl SocketPlatformConnection {
150 fn send_iovec_all(
159 &self,
160 mut iovs: &mut [&[u8]],
161 mut fds: Option<&[RawDescriptor]>,
162 ) -> Result<()> {
163 advance_slices(&mut iovs, 0);
165
166 while !iovs.is_empty() {
167 let iovec: Vec<_> = iovs.iter_mut().map(|i| IoSlice::new(i)).collect();
168 match self.sock.send_vectored_with_fds(&iovec, fds.unwrap_or(&[])) {
169 Ok(n) => {
170 fds = None;
171 advance_slices(&mut iovs, n);
172 }
173 Err(e) => match e.kind() {
174 ErrorKind::WouldBlock | ErrorKind::Interrupted => {}
175 _ => return Err(Error::SocketError(e)),
176 },
177 }
178 }
179 Ok(())
180 }
181
182 pub fn send_message(
188 &self,
189 hdr: &[u8],
190 body: &[u8],
191 payload: &[u8],
192 fds: Option<&[RawDescriptor]>,
193 ) -> Result<()> {
194 let mut iobufs = [hdr, body, payload];
195 self.send_iovec_all(&mut iobufs, fds)
196 }
197
198 pub fn recv_into_bufs(
218 &self,
219 bufs: &mut [IoSliceMut],
220 allow_fd: bool,
221 ) -> Result<(usize, Option<Vec<File>>)> {
222 let max_fds = if allow_fd { MAX_ATTACHED_FD_ENTRIES } else { 0 };
223 let (bytes, fds) = match self.sock.recv_vectored_with_fds(bufs, max_fds) {
224 Ok((bytes, fds)) => (bytes, fds),
225 Err(e) => {
226 return Err(match e.kind() {
227 ErrorKind::WouldBlock | ErrorKind::Interrupted | ErrorKind::OutOfMemory => {
228 Error::SocketRetry(e)
229 }
230 _ => Error::SocketError(e),
231 });
232 }
233 };
234
235 if bytes == 0 {
237 return Err(Error::Disconnect);
238 }
239
240 let files = if fds.is_empty() {
241 None
242 } else {
243 Some(fds.into_iter().map(File::from).collect())
244 };
245
246 Ok((bytes, files))
247 }
248}
249
impl AsRawDescriptor for SocketPlatformConnection {
    /// Returns the raw descriptor of the underlying connected socket.
    fn as_raw_descriptor(&self) -> RawDescriptor {
        self.sock.as_raw_descriptor()
    }
}
255
impl ReadNotifier for SocketPlatformConnection {
    /// Returns the underlying socket itself as the read notifier.
    // NOTE(review): presumably callers poll this descriptor for readability before
    // calling `recv_into_bufs`; confirm against the `ReadNotifier` users.
    fn get_read_notifier(&self) -> &dyn AsRawDescriptor {
        &self.sock
    }
}
261
262impl<R: Req> TryFrom<SafeDescriptor> for Connection<R> {
263 type Error = Error;
264
265 fn try_from(fd: SafeDescriptor) -> Result<Self> {
266 UnixStream::from(fd).try_into()
267 }
268}
269
270impl<R: Req> TryFrom<UnixStream> for Connection<R> {
271 type Error = Error;
272
273 fn try_from(sock: UnixStream) -> Result<Self> {
274 Ok(Self(
275 SocketPlatformConnection {
276 sock: sock.try_into().map_err(Error::SocketError)?,
277 },
278 std::marker::PhantomData,
279 std::marker::PhantomData,
280 ))
281 }
282}
283
284impl<R: Req> Connection<R> {
285 pub fn pair() -> Result<(Self, Self)> {
287 let (client, server) = UnixStream::pair().map_err(Error::SocketError)?;
288 let client_connection = Connection::try_from(client)?;
289 let server_connection = Connection::try_from(server)?;
290 Ok((client_connection, server_connection))
291 }
292}
293
impl<S: Frontend> AsRawDescriptor for FrontendServer<S> {
    /// Returns the raw descriptor of the server's `sub_sock` connection.
    fn as_raw_descriptor(&self) -> RawDescriptor {
        self.sub_sock.as_raw_descriptor()
    }
}
299
impl<S: Frontend> ReadNotifier for FrontendServer<S> {
    /// Delegates to the read notifier of the platform connection held in
    /// `sub_sock`'s first field.
    // NOTE(review): assumes `sub_sock` is a `Connection`, whose field 0 is the
    // `SocketPlatformConnection`; confirm against `FrontendServer`'s definition.
    fn get_read_notifier(&self) -> &dyn AsRawDescriptor {
        self.sub_sock.0.get_read_notifier()
    }
}
305
306impl<S: Frontend> FrontendServer<S> {
307 pub fn with_stream(backend: S) -> Result<(Self, SafeDescriptor)> {
314 let (tx, rx) = UnixStream::pair().map_err(Error::SocketError)?;
315 let rx_connection = Connection::try_from(rx)?;
316 Ok((
317 Self::new(backend, rx_connection)?,
318 SafeDescriptor::from(OwnedFd::from(tx)),
319 ))
320 }
321}
322
#[cfg(test)]
pub(crate) mod tests {
    use tempfile::Builder;
    use tempfile::TempDir;

    use super::*;
    use crate::backend_client::BackendClient;
    use crate::connection::Listener;
    use crate::Connection;

    /// Creates a fresh temporary directory under `/tmp` for socket files.
    pub(crate) fn temp_dir() -> TempDir {
        Builder::new().prefix("/tmp/vhost_test").tempdir().unwrap()
    }

    #[test]
    fn create_listener() {
        let dir = temp_dir();
        let mut path = dir.path().to_owned();
        path.push("sock");
        let listener = SocketListener::new(&path, true).unwrap();

        assert!(listener.as_raw_descriptor() > 0);
    }

    #[test]
    fn accept_connection() {
        let dir = temp_dir();
        let mut path = dir.path().to_owned();
        path.push("sock");
        let mut listener = SocketListener::new(&path, true).unwrap();
        listener.set_nonblocking(true).unwrap();

        // Accepting on a non-blocking listener with no pending client yields `None`.
        let conn = listener.accept().unwrap();
        assert!(conn.is_none());
    }

    #[test]
    fn test_create_failure() {
        let dir = temp_dir();
        let mut path = dir.path().to_owned();
        path.push("sock");

        // `let _ =` drops the listener at the end of the statement; its drop guard
        // removes the socket file again.
        let _ = SocketListener::new(&path, true).unwrap();
        // The file is gone, so binding without unlinking succeeds. This listener is
        // dropped immediately too, removing the file once more.
        assert!(SocketListener::new(&path, false).is_ok());
        // Nothing is listening at `path` any more.
        assert!(UnixStream::connect(&path).is_err());

        // While a listener is alive, binding the same path without unlinking fails.
        let mut listener = SocketListener::new(&path, true).unwrap();
        assert!(SocketListener::new(&path, false).is_err());
        listener.set_nonblocking(true).unwrap();

        let sock = UnixStream::connect(&path).unwrap();
        let backend_connection = Connection::try_from(sock).unwrap();
        let _backend_client = BackendClient::new(backend_connection);
        let _server_connection = listener.accept().unwrap().unwrap();
    }

    #[test]
    fn test_advance_slices() {
        // Test case from https://doc.rust-lang.org/std/io/struct.IoSlice.html#method.advance_slices
        let buf1 = [1; 8];
        let buf2 = [2; 16];
        let buf3 = [3; 8];
        let mut bufs = &mut [&buf1[..], &buf2[..], &buf3[..]][..];
        advance_slices(&mut bufs, 10);
        assert_eq!(bufs.len(), 2);
        assert_eq!(bufs[0], [2; 14].as_ref());
        assert_eq!(bufs[1], [3; 8].as_ref());
    }
}