base/sys/unix/
net.rs

// Copyright 2018 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
5use std::cmp::Ordering;
6use std::convert::TryFrom;
7use std::ffi::OsString;
8use std::fs::remove_file;
9use std::io;
10use std::mem;
11use std::mem::size_of;
12use std::net::Ipv4Addr;
13use std::net::Ipv6Addr;
14use std::net::SocketAddr;
15use std::net::SocketAddrV4;
16use std::net::SocketAddrV6;
17use std::net::TcpListener;
18use std::net::TcpStream;
19use std::net::ToSocketAddrs;
20use std::ops::Deref;
21use std::os::fd::OwnedFd;
22use std::os::unix::ffi::OsStringExt;
23use std::path::Path;
24use std::path::PathBuf;
25use std::ptr::null_mut;
26use std::time::Duration;
27use std::time::Instant;
28
29use libc::c_int;
30use libc::recvfrom;
31use libc::sa_family_t;
32use libc::sockaddr;
33use libc::sockaddr_in;
34use libc::sockaddr_in6;
35use libc::socklen_t;
36use libc::AF_INET;
37use libc::AF_INET6;
38use libc::MSG_PEEK;
39use libc::MSG_TRUNC;
40use log::warn;
41use serde::Deserialize;
42use serde::Serialize;
43
44use crate::descriptor::AsRawDescriptor;
45use crate::descriptor::FromRawDescriptor;
46use crate::descriptor::IntoRawDescriptor;
47use crate::handle_eintr_errno;
48use crate::sys::sockaddr_un;
49use crate::sys::sockaddrv4_to_lib_c;
50use crate::sys::sockaddrv6_to_lib_c;
51use crate::Error;
52use crate::RawDescriptor;
53use crate::SafeDescriptor;
54
/// Selects between the two Internet Protocol families so callers can handle IPv4 and IPv6
/// sockets through one code path.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum InetVersion {
    /// Internet Protocol version 4.
    V4,
    /// Internet Protocol version 6.
    V6,
}
61
62impl InetVersion {
63    pub fn from_sockaddr(s: &SocketAddr) -> Self {
64        match s {
65            SocketAddr::V4(_) => InetVersion::V4,
66            SocketAddr::V6(_) => InetVersion::V6,
67        }
68    }
69}
70
71impl From<InetVersion> for sa_family_t {
72    fn from(v: InetVersion) -> sa_family_t {
73        match v {
74            InetVersion::V4 => AF_INET as sa_family_t,
75            InetVersion::V6 => AF_INET6 as sa_family_t,
76        }
77    }
78}
79
80pub(in crate::sys) fn socket(
81    domain: c_int,
82    sock_type: c_int,
83    protocol: c_int,
84) -> io::Result<SafeDescriptor> {
85    // SAFETY:
86    // Safe socket initialization since we handle the returned error.
87    match unsafe { libc::socket(domain, sock_type, protocol) } {
88        -1 => Err(io::Error::last_os_error()),
89        // SAFETY:
90        // Safe because we own the file descriptor.
91        fd => Ok(unsafe { SafeDescriptor::from_raw_descriptor(fd) }),
92    }
93}
94
95pub(in crate::sys) fn socketpair(
96    domain: c_int,
97    sock_type: c_int,
98    protocol: c_int,
99) -> io::Result<(SafeDescriptor, SafeDescriptor)> {
100    let mut fds = [0, 0];
101    // SAFETY:
102    // Safe because we give enough space to store all the fds and we check the return value.
103    match unsafe { libc::socketpair(domain, sock_type, protocol, fds.as_mut_ptr()) } {
104        -1 => Err(io::Error::last_os_error()),
105        _ => Ok(
106            // SAFETY:
107            // Safe because we own the file descriptors.
108            unsafe {
109                (
110                    SafeDescriptor::from_raw_descriptor(fds[0]),
111                    SafeDescriptor::from_raw_descriptor(fds[1]),
112                )
113            },
114        ),
115    }
116}
117
118/// A TCP socket.
119///
120/// Do not use this class unless you need to change socket options or query the
121/// state of the socket prior to calling listen or connect. Instead use either TcpStream or
122/// TcpListener.
123#[derive(Debug)]
124pub struct TcpSocket {
125    pub(in crate::sys) inet_version: InetVersion,
126    pub(in crate::sys) descriptor: SafeDescriptor,
127}
128
129impl TcpSocket {
130    pub fn bind<A: ToSocketAddrs>(&mut self, addr: A) -> io::Result<()> {
131        let sockaddr = addr
132            .to_socket_addrs()
133            .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?
134            .next()
135            .unwrap();
136
137        let ret = match sockaddr {
138            SocketAddr::V4(a) => {
139                let sin = sockaddrv4_to_lib_c(&a);
140                // SAFETY:
141                // Safe because this doesn't modify any memory and we check the return value.
142                unsafe {
143                    libc::bind(
144                        self.as_raw_descriptor(),
145                        &sin as *const sockaddr_in as *const sockaddr,
146                        size_of::<sockaddr_in>() as socklen_t,
147                    )
148                }
149            }
150            SocketAddr::V6(a) => {
151                let sin6 = sockaddrv6_to_lib_c(&a);
152                // SAFETY:
153                // Safe because this doesn't modify any memory and we check the return value.
154                unsafe {
155                    libc::bind(
156                        self.as_raw_descriptor(),
157                        &sin6 as *const sockaddr_in6 as *const sockaddr,
158                        size_of::<sockaddr_in6>() as socklen_t,
159                    )
160                }
161            }
162        };
163        if ret < 0 {
164            let bind_err = io::Error::last_os_error();
165            Err(bind_err)
166        } else {
167            Ok(())
168        }
169    }
170
171    pub fn connect<A: ToSocketAddrs>(self, addr: A) -> io::Result<TcpStream> {
172        let sockaddr = addr
173            .to_socket_addrs()
174            .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?
175            .next()
176            .unwrap();
177
178        let ret = match sockaddr {
179            SocketAddr::V4(a) => {
180                let sin = sockaddrv4_to_lib_c(&a);
181                // SAFETY:
182                // Safe because this doesn't modify any memory and we check the return value.
183                unsafe {
184                    libc::connect(
185                        self.as_raw_descriptor(),
186                        &sin as *const sockaddr_in as *const sockaddr,
187                        size_of::<sockaddr_in>() as socklen_t,
188                    )
189                }
190            }
191            SocketAddr::V6(a) => {
192                let sin6 = sockaddrv6_to_lib_c(&a);
193                // SAFETY:
194                // Safe because this doesn't modify any memory and we check the return value.
195                unsafe {
196                    libc::connect(
197                        self.as_raw_descriptor(),
198                        &sin6 as *const sockaddr_in6 as *const sockaddr,
199                        size_of::<sockaddr_in>() as socklen_t,
200                    )
201                }
202            }
203        };
204
205        if ret < 0 {
206            let connect_err = io::Error::last_os_error();
207            Err(connect_err)
208        } else {
209            Ok(TcpStream::from(self.descriptor))
210        }
211    }
212
213    pub fn listen(self) -> io::Result<TcpListener> {
214        // SAFETY:
215        // Safe because this doesn't modify any memory and we check the return value.
216        let ret = unsafe { libc::listen(self.as_raw_descriptor(), 1) };
217        if ret < 0 {
218            let listen_err = io::Error::last_os_error();
219            Err(listen_err)
220        } else {
221            Ok(TcpListener::from(self.descriptor))
222        }
223    }
224
225    /// Returns the port that this socket is bound to. This can only succeed after bind is called.
226    pub fn local_port(&self) -> io::Result<u16> {
227        match self.inet_version {
228            InetVersion::V4 => {
229                let mut sin = sockaddrv4_to_lib_c(&SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0));
230
231                let mut addrlen = size_of::<sockaddr_in>() as socklen_t;
232                // SAFETY:
233                // Safe because we give a valid pointer for addrlen and check the length.
234                let ret = unsafe {
235                    // Get the socket address that was actually bound.
236                    libc::getsockname(
237                        self.as_raw_descriptor(),
238                        &mut sin as *mut sockaddr_in as *mut sockaddr,
239                        &mut addrlen as *mut socklen_t,
240                    )
241                };
242                if ret < 0 {
243                    let getsockname_err = io::Error::last_os_error();
244                    Err(getsockname_err)
245                } else {
246                    // If this doesn't match, it's not safe to get the port out of the sockaddr.
247                    assert_eq!(addrlen as usize, size_of::<sockaddr_in>());
248
249                    Ok(u16::from_be(sin.sin_port))
250                }
251            }
252            InetVersion::V6 => {
253                let mut sin6 = sockaddrv6_to_lib_c(&SocketAddrV6::new(
254                    Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0),
255                    0,
256                    0,
257                    0,
258                ));
259
260                let mut addrlen = size_of::<sockaddr_in6>() as socklen_t;
261                // SAFETY:
262                // Safe because we give a valid pointer for addrlen and check the length.
263                let ret = unsafe {
264                    // Get the socket address that was actually bound.
265                    libc::getsockname(
266                        self.as_raw_descriptor(),
267                        &mut sin6 as *mut sockaddr_in6 as *mut sockaddr,
268                        &mut addrlen as *mut socklen_t,
269                    )
270                };
271                if ret < 0 {
272                    let getsockname_err = io::Error::last_os_error();
273                    Err(getsockname_err)
274                } else {
275                    // If this doesn't match, it's not safe to get the port out of the sockaddr.
276                    assert_eq!(addrlen as usize, size_of::<sockaddr_in>());
277
278                    Ok(u16::from_be(sin6.sin6_port))
279                }
280            }
281        }
282    }
283}
284
285impl AsRawDescriptor for TcpSocket {
286    fn as_raw_descriptor(&self) -> RawDescriptor {
287        self.descriptor.as_raw_descriptor()
288    }
289}
290
291// Offset of sun_path in structure sockaddr_un.
292pub(in crate::sys) fn sun_path_offset() -> usize {
293    std::mem::offset_of!(libc::sockaddr_un, sun_path)
294}
295
296/// A Unix `SOCK_SEQPACKET` socket point to given `path`
297#[derive(Debug, Serialize, Deserialize)]
298pub struct UnixSeqpacket(SafeDescriptor);
299
300impl UnixSeqpacket {
301    /// Open a `SOCK_SEQPACKET` connection to socket named by `path`.
302    ///
303    /// # Arguments
304    /// * `path` - Path to `SOCK_SEQPACKET` socket
305    ///
306    /// # Returns
307    /// A `UnixSeqpacket` structure point to the socket
308    ///
309    /// # Errors
310    /// Return `io::Error` when error occurs.
311    pub fn connect<P: AsRef<Path>>(path: P) -> io::Result<Self> {
312        let descriptor = socket(libc::AF_UNIX, libc::SOCK_SEQPACKET, 0)?;
313        let (addr, len) = sockaddr_un(path.as_ref())?;
314        // SAFETY:
315        // Safe connect since we handle the error and use the right length generated from
316        // `sockaddr_un`.
317        unsafe {
318            let ret = libc::connect(
319                descriptor.as_raw_descriptor(),
320                &addr as *const _ as *const _,
321                len,
322            );
323            if ret < 0 {
324                return Err(io::Error::last_os_error());
325            }
326        }
327        Ok(UnixSeqpacket(descriptor))
328    }
329
330    /// Clone the underlying FD.
331    pub fn try_clone(&self) -> io::Result<Self> {
332        Ok(Self(self.0.try_clone()?))
333    }
334
335    /// Gets the number of bytes that can be read from this socket without blocking.
336    pub fn get_readable_bytes(&self) -> io::Result<usize> {
337        let mut byte_count = 0i32;
338        // SAFETY:
339        // Safe because self has valid raw descriptor and return value are checked.
340        let ret = unsafe { libc::ioctl(self.as_raw_descriptor(), libc::FIONREAD, &mut byte_count) };
341        if ret < 0 {
342            Err(io::Error::last_os_error())
343        } else {
344            Ok(byte_count as usize)
345        }
346    }
347
348    /// Gets the number of bytes in the next packet. This blocks as if `recv` were called,
349    /// respecting the blocking and timeout settings of the underlying socket.
350    pub fn next_packet_size(&self) -> io::Result<usize> {
351        #[cfg(not(debug_assertions))]
352        let buf = null_mut();
353        // Work around for qemu's syscall translation which will reject null pointers in recvfrom.
354        // This only matters for running the unit tests for a non-native architecture. See the
355        // upstream thread for the qemu fix:
356        // https://lists.nongnu.org/archive/html/qemu-devel/2021-03/msg09027.html
357        #[cfg(debug_assertions)]
358        let buf = &mut 0 as *mut _ as *mut _;
359
360        // SAFETY:
361        // This form of recvfrom doesn't modify any data because all null pointers are used. We only
362        // use the return value and check for errors on an FD owned by this structure.
363        let ret = unsafe {
364            recvfrom(
365                self.as_raw_descriptor(),
366                buf,
367                0,
368                MSG_TRUNC | MSG_PEEK,
369                null_mut(),
370                null_mut(),
371            )
372        };
373        if ret < 0 {
374            Err(io::Error::last_os_error())
375        } else {
376            Ok(ret as usize)
377        }
378    }
379
380    /// Write data from a given buffer to the socket fd
381    ///
382    /// # Arguments
383    /// * `buf` - A reference to the data buffer.
384    ///
385    /// # Returns
386    /// * `usize` - The size of bytes written to the buffer.
387    ///
388    /// # Errors
389    /// Returns error when `libc::write` failed.
390    pub fn send(&self, buf: &[u8]) -> io::Result<usize> {
391        // SAFETY:
392        // Safe since we make sure the input `count` == `buf.len()` and handle the returned error.
393        unsafe {
394            let ret = libc::write(
395                self.as_raw_descriptor(),
396                buf.as_ptr() as *const _,
397                buf.len(),
398            );
399            if ret < 0 {
400                Err(io::Error::last_os_error())
401            } else {
402                Ok(ret as usize)
403            }
404        }
405    }
406
407    /// Read data from the socket fd to a given buffer
408    ///
409    /// # Arguments
410    /// * `buf` - A mut reference to the data buffer.
411    ///
412    /// # Returns
413    /// * `usize` - The size of bytes read to the buffer.
414    ///
415    /// # Errors
416    /// Returns error when `libc::read` failed.
417    pub fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
418        // SAFETY:
419        // Safe since we make sure the input `count` == `buf.len()` and handle the returned error.
420        unsafe {
421            let ret = libc::read(
422                self.as_raw_descriptor(),
423                buf.as_mut_ptr() as *mut _,
424                buf.len(),
425            );
426            if ret < 0 {
427                Err(io::Error::last_os_error())
428            } else {
429                Ok(ret as usize)
430            }
431        }
432    }
433
434    /// Read data from the socket fd to a given `Vec`, resizing it to the received packet's size.
435    ///
436    /// # Arguments
437    /// * `buf` - A mut reference to a `Vec` to resize and read into.
438    ///
439    /// # Errors
440    /// Returns error when `libc::read` or `get_readable_bytes` failed.
441    pub fn recv_to_vec(&self, buf: &mut Vec<u8>) -> io::Result<()> {
442        let packet_size = self.next_packet_size()?;
443        buf.resize(packet_size, 0);
444        let read_bytes = self.recv(buf)?;
445        buf.resize(read_bytes, 0);
446        Ok(())
447    }
448
449    /// Read data from the socket fd to a new `Vec`.
450    ///
451    /// # Returns
452    /// * `vec` - A new `Vec` with the entire received packet.
453    ///
454    /// # Errors
455    /// Returns error when `libc::read` or `get_readable_bytes` failed.
456    pub fn recv_as_vec(&self) -> io::Result<Vec<u8>> {
457        let mut buf = Vec::new();
458        self.recv_to_vec(&mut buf)?;
459        Ok(buf)
460    }
461
462    #[allow(clippy::useless_conversion)]
463    fn set_timeout(&self, timeout: Option<Duration>, kind: libc::c_int) -> io::Result<()> {
464        let timeval = match timeout {
465            Some(t) => {
466                if t.as_secs() == 0 && t.subsec_micros() == 0 {
467                    return Err(io::Error::new(
468                        io::ErrorKind::InvalidInput,
469                        "zero timeout duration is invalid",
470                    ));
471                }
472                // subsec_micros fits in i32 because it is defined to be less than one million.
473                let nsec = t.subsec_micros() as i32;
474                libc::timeval {
475                    tv_sec: t.as_secs() as libc::time_t,
476                    tv_usec: libc::suseconds_t::from(nsec),
477                }
478            }
479            None => libc::timeval {
480                tv_sec: 0,
481                tv_usec: 0,
482            },
483        };
484        // SAFETY:
485        // Safe because we own the fd, and the length of the pointer's data is the same as the
486        // passed in length parameter. The level argument is valid, the kind is assumed to be valid,
487        // and the return value is checked.
488        let ret = unsafe {
489            libc::setsockopt(
490                self.as_raw_descriptor(),
491                libc::SOL_SOCKET,
492                kind,
493                &timeval as *const libc::timeval as *const libc::c_void,
494                mem::size_of::<libc::timeval>() as libc::socklen_t,
495            )
496        };
497        if ret < 0 {
498            Err(io::Error::last_os_error())
499        } else {
500            Ok(())
501        }
502    }
503
504    /// Sets or removes the timeout for read/recv operations on this socket.
505    pub fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
506        self.set_timeout(timeout, libc::SO_RCVTIMEO)
507    }
508
509    /// Sets or removes the timeout for write/send operations on this socket.
510    pub fn set_write_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
511        self.set_timeout(timeout, libc::SO_SNDTIMEO)
512    }
513
514    /// Sets the blocking mode for this socket.
515    pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
516        let mut nonblocking = nonblocking as libc::c_int;
517        // SAFETY:
518        // Safe because the return value is checked, and this ioctl call sets the nonblocking mode
519        // and does not continue holding the file descriptor after the call.
520        let ret = unsafe { libc::ioctl(self.as_raw_descriptor(), libc::FIONBIO, &mut nonblocking) };
521        if ret < 0 {
522            Err(io::Error::last_os_error())
523        } else {
524            Ok(())
525        }
526    }
527}
528
529impl From<UnixSeqpacket> for SafeDescriptor {
530    fn from(s: UnixSeqpacket) -> Self {
531        s.0
532    }
533}
534
535impl From<SafeDescriptor> for UnixSeqpacket {
536    fn from(s: SafeDescriptor) -> Self {
537        Self(s)
538    }
539}
540
541impl FromRawDescriptor for UnixSeqpacket {
542    unsafe fn from_raw_descriptor(descriptor: RawDescriptor) -> Self {
543        Self(SafeDescriptor::from_raw_descriptor(descriptor))
544    }
545}
546
547impl AsRawDescriptor for UnixSeqpacket {
548    fn as_raw_descriptor(&self) -> RawDescriptor {
549        self.0.as_raw_descriptor()
550    }
551}
552
553impl IntoRawDescriptor for UnixSeqpacket {
554    fn into_raw_descriptor(self) -> RawDescriptor {
555        self.0.into_raw_descriptor()
556    }
557}
558
559impl io::Read for UnixSeqpacket {
560    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
561        self.recv(buf)
562    }
563}
564
565impl io::Write for UnixSeqpacket {
566    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
567        self.send(buf)
568    }
569
570    fn flush(&mut self) -> io::Result<()> {
571        Ok(())
572    }
573}
574
575/// Like a `UnixListener` but for accepting `UnixSeqpacket` type sockets.
576pub struct UnixSeqpacketListener {
577    descriptor: SafeDescriptor,
578    no_path: bool,
579}
580
581impl UnixSeqpacketListener {
582    /// Creates a new `UnixSeqpacketListener` bound to the given path.
583    pub fn bind<P: AsRef<Path>>(path: P) -> io::Result<Self> {
584        if path.as_ref().starts_with("/proc/self/fd/") {
585            let fd = path
586                .as_ref()
587                .file_name()
588                .expect("Failed to get fd filename")
589                .to_str()
590                .expect("fd filename should be unicode")
591                .parse::<i32>()
592                .expect("fd should be an integer");
593            let mut result: c_int = 0;
594            let mut result_len = size_of::<c_int>() as libc::socklen_t;
595            // SAFETY: Safe because fd and other args are valid and the return value is checked.
596            let ret = unsafe {
597                libc::getsockopt(
598                    fd,
599                    libc::SOL_SOCKET,
600                    libc::SO_ACCEPTCONN,
601                    &mut result as *mut _ as *mut libc::c_void,
602                    &mut result_len,
603                )
604            };
605            if ret < 0 {
606                return Err(io::Error::last_os_error());
607            }
608            if result != 1 {
609                return Err(io::Error::new(
610                    io::ErrorKind::InvalidInput,
611                    "specified descriptor is not a listening socket",
612                ));
613            }
614            // SAFETY:
615            // Safe because we validated the socket file descriptor.
616            let descriptor = unsafe { SafeDescriptor::from_raw_descriptor(fd) };
617            return Ok(UnixSeqpacketListener {
618                descriptor,
619                no_path: true,
620            });
621        }
622
623        let descriptor = socket(libc::AF_UNIX, libc::SOCK_SEQPACKET, 0)?;
624        let (addr, len) = sockaddr_un(path.as_ref())?;
625
626        // SAFETY:
627        // Safe connect since we handle the error and use the right length generated from
628        // `sockaddr_un`.
629        unsafe {
630            let ret = handle_eintr_errno!(libc::bind(
631                descriptor.as_raw_descriptor(),
632                &addr as *const _ as *const _,
633                len
634            ));
635            if ret < 0 {
636                return Err(io::Error::last_os_error());
637            }
638            let ret = handle_eintr_errno!(libc::listen(descriptor.as_raw_descriptor(), 128));
639            if ret < 0 {
640                return Err(io::Error::last_os_error());
641            }
642        }
643        Ok(UnixSeqpacketListener {
644            descriptor,
645            no_path: false,
646        })
647    }
648
649    pub fn accept_with_timeout(&self, timeout: Duration) -> io::Result<UnixSeqpacket> {
650        let start = Instant::now();
651
652        loop {
653            let mut fds = libc::pollfd {
654                fd: self.as_raw_descriptor(),
655                events: libc::POLLIN,
656                revents: 0,
657            };
658            let elapsed = Instant::now().saturating_duration_since(start);
659            let remaining = timeout.checked_sub(elapsed).unwrap_or(Duration::ZERO);
660            let cur_timeout_ms = i32::try_from(remaining.as_millis()).unwrap_or(i32::MAX);
661            // SAFETY:
662            // Safe because we give a valid pointer to a list (of 1) FD and we check
663            // the return value.
664            match unsafe { libc::poll(&mut fds, 1, cur_timeout_ms) }.cmp(&0) {
665                Ordering::Greater => return self.accept(),
666                Ordering::Equal => return Err(io::Error::from_raw_os_error(libc::ETIMEDOUT)),
667                Ordering::Less => {
668                    if Error::last() != Error::new(libc::EINTR) {
669                        return Err(io::Error::last_os_error());
670                    }
671                }
672            }
673        }
674    }
675
676    /// Gets the path that this listener is bound to.
677    pub fn path(&self) -> io::Result<PathBuf> {
678        let mut addr = sockaddr_un(Path::new(""))?.0;
679        if self.no_path {
680            return Err(io::Error::new(
681                io::ErrorKind::InvalidInput,
682                "socket has no path",
683            ));
684        }
685        let sun_path_offset = (&addr.sun_path as *const _ as usize
686            - &addr.sun_family as *const _ as usize)
687            as libc::socklen_t;
688        let mut len = mem::size_of::<libc::sockaddr_un>() as libc::socklen_t;
689        // SAFETY:
690        // Safe because the length given matches the length of the data of the given pointer, and we
691        // check the return value.
692        let ret = unsafe {
693            handle_eintr_errno!(libc::getsockname(
694                self.as_raw_descriptor(),
695                &mut addr as *mut libc::sockaddr_un as *mut libc::sockaddr,
696                &mut len
697            ))
698        };
699        if ret < 0 {
700            return Err(io::Error::last_os_error());
701        }
702        if addr.sun_family != libc::AF_UNIX as libc::sa_family_t
703            || addr.sun_path[0] == 0
704            || len < 1 + sun_path_offset
705        {
706            return Err(io::Error::new(
707                io::ErrorKind::InvalidInput,
708                "getsockname on socket returned invalid value",
709            ));
710        }
711
712        let path_os_str = OsString::from_vec(
713            addr.sun_path[..(len - sun_path_offset - 1) as usize]
714                .iter()
715                .map(|&c| c as _)
716                .collect(),
717        );
718        Ok(path_os_str.into())
719    }
720
721    /// Sets the blocking mode for this socket.
722    pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
723        let mut nonblocking = nonblocking as libc::c_int;
724        // SAFETY:
725        // Safe because the return value is checked, and this ioctl call sets the nonblocking mode
726        // and does not continue holding the file descriptor after the call.
727        let ret = unsafe { libc::ioctl(self.as_raw_descriptor(), libc::FIONBIO, &mut nonblocking) };
728        if ret < 0 {
729            Err(io::Error::last_os_error())
730        } else {
731            Ok(())
732        }
733    }
734}
735
736impl AsRawDescriptor for UnixSeqpacketListener {
737    fn as_raw_descriptor(&self) -> RawDescriptor {
738        self.descriptor.as_raw_descriptor()
739    }
740}
741
742impl From<UnixSeqpacketListener> for OwnedFd {
743    fn from(val: UnixSeqpacketListener) -> Self {
744        val.descriptor.into()
745    }
746}
747
748/// Used to attempt to clean up a `UnixSeqpacketListener` after it is dropped.
749pub struct UnlinkUnixSeqpacketListener(pub UnixSeqpacketListener);
750
751impl AsRawDescriptor for UnlinkUnixSeqpacketListener {
752    fn as_raw_descriptor(&self) -> RawDescriptor {
753        self.0.as_raw_descriptor()
754    }
755}
756
757impl AsRef<UnixSeqpacketListener> for UnlinkUnixSeqpacketListener {
758    fn as_ref(&self) -> &UnixSeqpacketListener {
759        &self.0
760    }
761}
762
763impl Deref for UnlinkUnixSeqpacketListener {
764    type Target = UnixSeqpacketListener;
765    fn deref(&self) -> &Self::Target {
766        &self.0
767    }
768}
769
770impl Drop for UnlinkUnixSeqpacketListener {
771    fn drop(&mut self) {
772        if let Ok(path) = self.0.path() {
773            if let Err(e) = remove_file(path) {
774                warn!("failed to remove control socket file: {:?}", e);
775            }
776        }
777    }
778}
779
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sockaddr_un_zero_length_input() {
        sockaddr_un(Path::new("")).expect("sockaddr_un failed");
    }

    #[test]
    fn sockaddr_un_long_input_err() {
        let too_long = "a".repeat(108);
        assert!(sockaddr_un(Path::new(&too_long)).is_err());
    }

    #[test]
    fn sockaddr_un_long_input_pass() {
        let longest = "a".repeat(107);
        sockaddr_un(Path::new(&longest)).expect("sockaddr_un failed");
    }

    #[test]
    fn sockaddr_un_len_check() {
        let path = "a".repeat(50);
        let (_, len) = sockaddr_un(Path::new(&path)).expect("sockaddr_un failed");
        assert_eq!(len, (sun_path_offset() + path.len() + 1) as u32);
    }

    #[test]
    // c_char is u8 on aarch64 and i8 on x86, so clippy's suggested fix of changing
    // `'a' as libc::c_char` below to `b'a'` won't work everywhere.
    #[allow(clippy::unnecessary_cast, clippy::char_lit_as_u8)]
    fn sockaddr_un_pass() {
        let path_size = 50;
        let path = "a".repeat(path_size);
        let (addr, len) = sockaddr_un(Path::new(&path)).expect("sockaddr_un failed");
        assert_eq!(len, (sun_path_offset() + path_size + 1) as u32);
        assert_eq!(addr.sun_family, libc::AF_UNIX as libc::sa_family_t);

        // Expected `sun_path`: `path_size` 'a' characters followed by NUL padding.
        let mut expected = [0 as libc::c_char; 108];
        expected[..path_size].fill('a' as libc::c_char);

        for (actual, wanted) in addr.sun_path.iter().zip(expected.iter()) {
            assert_eq!(actual, wanted);
        }
    }
}