1use std::cmp::Ordering;
6use std::convert::TryFrom;
7use std::ffi::OsString;
8use std::fs::remove_file;
9use std::io;
10use std::mem;
11use std::mem::size_of;
12use std::net::Ipv4Addr;
13use std::net::Ipv6Addr;
14use std::net::SocketAddr;
15use std::net::SocketAddrV4;
16use std::net::SocketAddrV6;
17use std::net::TcpListener;
18use std::net::TcpStream;
19use std::net::ToSocketAddrs;
20use std::ops::Deref;
21use std::os::fd::OwnedFd;
22use std::os::unix::ffi::OsStringExt;
23use std::path::Path;
24use std::path::PathBuf;
25use std::ptr::null_mut;
26use std::time::Duration;
27use std::time::Instant;
28
29use libc::c_int;
30use libc::recvfrom;
31use libc::sa_family_t;
32use libc::sockaddr;
33use libc::sockaddr_in;
34use libc::sockaddr_in6;
35use libc::socklen_t;
36use libc::AF_INET;
37use libc::AF_INET6;
38use libc::MSG_PEEK;
39use libc::MSG_TRUNC;
40use log::warn;
41use serde::Deserialize;
42use serde::Serialize;
43
44use crate::descriptor::AsRawDescriptor;
45use crate::descriptor::FromRawDescriptor;
46use crate::descriptor::IntoRawDescriptor;
47use crate::handle_eintr_errno;
48use crate::sys::sockaddr_un;
49use crate::sys::sockaddrv4_to_lib_c;
50use crate::sys::sockaddrv6_to_lib_c;
51use crate::Error;
52use crate::RawDescriptor;
53use crate::SafeDescriptor;
54
/// IP protocol family of an inet socket address.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum InetVersion {
    /// IPv4 addresses (mapped to `AF_INET`).
    V4,
    /// IPv6 addresses (mapped to `AF_INET6`).
    V6,
}
61
62impl InetVersion {
63 pub fn from_sockaddr(s: &SocketAddr) -> Self {
64 match s {
65 SocketAddr::V4(_) => InetVersion::V4,
66 SocketAddr::V6(_) => InetVersion::V6,
67 }
68 }
69}
70
71impl From<InetVersion> for sa_family_t {
72 fn from(v: InetVersion) -> sa_family_t {
73 match v {
74 InetVersion::V4 => AF_INET as sa_family_t,
75 InetVersion::V6 => AF_INET6 as sa_family_t,
76 }
77 }
78}
79
80pub(in crate::sys) fn socket(
81 domain: c_int,
82 sock_type: c_int,
83 protocol: c_int,
84) -> io::Result<SafeDescriptor> {
85 match unsafe { libc::socket(domain, sock_type, protocol) } {
88 -1 => Err(io::Error::last_os_error()),
89 fd => Ok(unsafe { SafeDescriptor::from_raw_descriptor(fd) }),
92 }
93}
94
95pub(in crate::sys) fn socketpair(
96 domain: c_int,
97 sock_type: c_int,
98 protocol: c_int,
99) -> io::Result<(SafeDescriptor, SafeDescriptor)> {
100 let mut fds = [0, 0];
101 match unsafe { libc::socketpair(domain, sock_type, protocol, fds.as_mut_ptr()) } {
104 -1 => Err(io::Error::last_os_error()),
105 _ => Ok(
106 unsafe {
109 (
110 SafeDescriptor::from_raw_descriptor(fds[0]),
111 SafeDescriptor::from_raw_descriptor(fds[1]),
112 )
113 },
114 ),
115 }
116}
117
118#[derive(Debug)]
124pub struct TcpSocket {
125 pub(in crate::sys) inet_version: InetVersion,
126 pub(in crate::sys) descriptor: SafeDescriptor,
127}
128
129impl TcpSocket {
130 pub fn bind<A: ToSocketAddrs>(&mut self, addr: A) -> io::Result<()> {
131 let sockaddr = addr
132 .to_socket_addrs()
133 .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?
134 .next()
135 .unwrap();
136
137 let ret = match sockaddr {
138 SocketAddr::V4(a) => {
139 let sin = sockaddrv4_to_lib_c(&a);
140 unsafe {
143 libc::bind(
144 self.as_raw_descriptor(),
145 &sin as *const sockaddr_in as *const sockaddr,
146 size_of::<sockaddr_in>() as socklen_t,
147 )
148 }
149 }
150 SocketAddr::V6(a) => {
151 let sin6 = sockaddrv6_to_lib_c(&a);
152 unsafe {
155 libc::bind(
156 self.as_raw_descriptor(),
157 &sin6 as *const sockaddr_in6 as *const sockaddr,
158 size_of::<sockaddr_in6>() as socklen_t,
159 )
160 }
161 }
162 };
163 if ret < 0 {
164 let bind_err = io::Error::last_os_error();
165 Err(bind_err)
166 } else {
167 Ok(())
168 }
169 }
170
171 pub fn connect<A: ToSocketAddrs>(self, addr: A) -> io::Result<TcpStream> {
172 let sockaddr = addr
173 .to_socket_addrs()
174 .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?
175 .next()
176 .unwrap();
177
178 let ret = match sockaddr {
179 SocketAddr::V4(a) => {
180 let sin = sockaddrv4_to_lib_c(&a);
181 unsafe {
184 libc::connect(
185 self.as_raw_descriptor(),
186 &sin as *const sockaddr_in as *const sockaddr,
187 size_of::<sockaddr_in>() as socklen_t,
188 )
189 }
190 }
191 SocketAddr::V6(a) => {
192 let sin6 = sockaddrv6_to_lib_c(&a);
193 unsafe {
196 libc::connect(
197 self.as_raw_descriptor(),
198 &sin6 as *const sockaddr_in6 as *const sockaddr,
199 size_of::<sockaddr_in>() as socklen_t,
200 )
201 }
202 }
203 };
204
205 if ret < 0 {
206 let connect_err = io::Error::last_os_error();
207 Err(connect_err)
208 } else {
209 Ok(TcpStream::from(self.descriptor))
210 }
211 }
212
213 pub fn listen(self) -> io::Result<TcpListener> {
214 let ret = unsafe { libc::listen(self.as_raw_descriptor(), 1) };
217 if ret < 0 {
218 let listen_err = io::Error::last_os_error();
219 Err(listen_err)
220 } else {
221 Ok(TcpListener::from(self.descriptor))
222 }
223 }
224
225 pub fn local_port(&self) -> io::Result<u16> {
227 match self.inet_version {
228 InetVersion::V4 => {
229 let mut sin = sockaddrv4_to_lib_c(&SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0));
230
231 let mut addrlen = size_of::<sockaddr_in>() as socklen_t;
232 let ret = unsafe {
235 libc::getsockname(
237 self.as_raw_descriptor(),
238 &mut sin as *mut sockaddr_in as *mut sockaddr,
239 &mut addrlen as *mut socklen_t,
240 )
241 };
242 if ret < 0 {
243 let getsockname_err = io::Error::last_os_error();
244 Err(getsockname_err)
245 } else {
246 assert_eq!(addrlen as usize, size_of::<sockaddr_in>());
248
249 Ok(u16::from_be(sin.sin_port))
250 }
251 }
252 InetVersion::V6 => {
253 let mut sin6 = sockaddrv6_to_lib_c(&SocketAddrV6::new(
254 Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0),
255 0,
256 0,
257 0,
258 ));
259
260 let mut addrlen = size_of::<sockaddr_in6>() as socklen_t;
261 let ret = unsafe {
264 libc::getsockname(
266 self.as_raw_descriptor(),
267 &mut sin6 as *mut sockaddr_in6 as *mut sockaddr,
268 &mut addrlen as *mut socklen_t,
269 )
270 };
271 if ret < 0 {
272 let getsockname_err = io::Error::last_os_error();
273 Err(getsockname_err)
274 } else {
275 assert_eq!(addrlen as usize, size_of::<sockaddr_in>());
277
278 Ok(u16::from_be(sin6.sin6_port))
279 }
280 }
281 }
282 }
283}
284
285impl AsRawDescriptor for TcpSocket {
286 fn as_raw_descriptor(&self) -> RawDescriptor {
287 self.descriptor.as_raw_descriptor()
288 }
289}
290
291pub(in crate::sys) fn sun_path_offset() -> usize {
293 std::mem::offset_of!(libc::sockaddr_un, sun_path)
294}
295
296#[derive(Debug, Serialize, Deserialize)]
298pub struct UnixSeqpacket(SafeDescriptor);
299
300impl UnixSeqpacket {
301 pub fn connect<P: AsRef<Path>>(path: P) -> io::Result<Self> {
312 let descriptor = socket(libc::AF_UNIX, libc::SOCK_SEQPACKET, 0)?;
313 let (addr, len) = sockaddr_un(path.as_ref())?;
314 unsafe {
318 let ret = libc::connect(
319 descriptor.as_raw_descriptor(),
320 &addr as *const _ as *const _,
321 len,
322 );
323 if ret < 0 {
324 return Err(io::Error::last_os_error());
325 }
326 }
327 Ok(UnixSeqpacket(descriptor))
328 }
329
330 pub fn try_clone(&self) -> io::Result<Self> {
332 Ok(Self(self.0.try_clone()?))
333 }
334
335 pub fn get_readable_bytes(&self) -> io::Result<usize> {
337 let mut byte_count = 0i32;
338 let ret = unsafe { libc::ioctl(self.as_raw_descriptor(), libc::FIONREAD, &mut byte_count) };
341 if ret < 0 {
342 Err(io::Error::last_os_error())
343 } else {
344 Ok(byte_count as usize)
345 }
346 }
347
348 pub fn next_packet_size(&self) -> io::Result<usize> {
351 #[cfg(not(debug_assertions))]
352 let buf = null_mut();
353 #[cfg(debug_assertions)]
358 let buf = &mut 0 as *mut _ as *mut _;
359
360 let ret = unsafe {
364 recvfrom(
365 self.as_raw_descriptor(),
366 buf,
367 0,
368 MSG_TRUNC | MSG_PEEK,
369 null_mut(),
370 null_mut(),
371 )
372 };
373 if ret < 0 {
374 Err(io::Error::last_os_error())
375 } else {
376 Ok(ret as usize)
377 }
378 }
379
380 pub fn send(&self, buf: &[u8]) -> io::Result<usize> {
391 unsafe {
394 let ret = libc::write(
395 self.as_raw_descriptor(),
396 buf.as_ptr() as *const _,
397 buf.len(),
398 );
399 if ret < 0 {
400 Err(io::Error::last_os_error())
401 } else {
402 Ok(ret as usize)
403 }
404 }
405 }
406
407 pub fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
418 unsafe {
421 let ret = libc::read(
422 self.as_raw_descriptor(),
423 buf.as_mut_ptr() as *mut _,
424 buf.len(),
425 );
426 if ret < 0 {
427 Err(io::Error::last_os_error())
428 } else {
429 Ok(ret as usize)
430 }
431 }
432 }
433
434 pub fn recv_to_vec(&self, buf: &mut Vec<u8>) -> io::Result<()> {
442 let packet_size = self.next_packet_size()?;
443 buf.resize(packet_size, 0);
444 let read_bytes = self.recv(buf)?;
445 buf.resize(read_bytes, 0);
446 Ok(())
447 }
448
449 pub fn recv_as_vec(&self) -> io::Result<Vec<u8>> {
457 let mut buf = Vec::new();
458 self.recv_to_vec(&mut buf)?;
459 Ok(buf)
460 }
461
462 #[allow(clippy::useless_conversion)]
463 fn set_timeout(&self, timeout: Option<Duration>, kind: libc::c_int) -> io::Result<()> {
464 let timeval = match timeout {
465 Some(t) => {
466 if t.as_secs() == 0 && t.subsec_micros() == 0 {
467 return Err(io::Error::new(
468 io::ErrorKind::InvalidInput,
469 "zero timeout duration is invalid",
470 ));
471 }
472 let nsec = t.subsec_micros() as i32;
474 libc::timeval {
475 tv_sec: t.as_secs() as libc::time_t,
476 tv_usec: libc::suseconds_t::from(nsec),
477 }
478 }
479 None => libc::timeval {
480 tv_sec: 0,
481 tv_usec: 0,
482 },
483 };
484 let ret = unsafe {
489 libc::setsockopt(
490 self.as_raw_descriptor(),
491 libc::SOL_SOCKET,
492 kind,
493 &timeval as *const libc::timeval as *const libc::c_void,
494 mem::size_of::<libc::timeval>() as libc::socklen_t,
495 )
496 };
497 if ret < 0 {
498 Err(io::Error::last_os_error())
499 } else {
500 Ok(())
501 }
502 }
503
504 pub fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
506 self.set_timeout(timeout, libc::SO_RCVTIMEO)
507 }
508
509 pub fn set_write_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
511 self.set_timeout(timeout, libc::SO_SNDTIMEO)
512 }
513
514 pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
516 let mut nonblocking = nonblocking as libc::c_int;
517 let ret = unsafe { libc::ioctl(self.as_raw_descriptor(), libc::FIONBIO, &mut nonblocking) };
521 if ret < 0 {
522 Err(io::Error::last_os_error())
523 } else {
524 Ok(())
525 }
526 }
527}
528
529impl From<UnixSeqpacket> for SafeDescriptor {
530 fn from(s: UnixSeqpacket) -> Self {
531 s.0
532 }
533}
534
535impl From<SafeDescriptor> for UnixSeqpacket {
536 fn from(s: SafeDescriptor) -> Self {
537 Self(s)
538 }
539}
540
541impl FromRawDescriptor for UnixSeqpacket {
542 unsafe fn from_raw_descriptor(descriptor: RawDescriptor) -> Self {
543 Self(SafeDescriptor::from_raw_descriptor(descriptor))
544 }
545}
546
547impl AsRawDescriptor for UnixSeqpacket {
548 fn as_raw_descriptor(&self) -> RawDescriptor {
549 self.0.as_raw_descriptor()
550 }
551}
552
553impl IntoRawDescriptor for UnixSeqpacket {
554 fn into_raw_descriptor(self) -> RawDescriptor {
555 self.0.into_raw_descriptor()
556 }
557}
558
559impl io::Read for UnixSeqpacket {
560 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
561 self.recv(buf)
562 }
563}
564
565impl io::Write for UnixSeqpacket {
566 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
567 self.send(buf)
568 }
569
570 fn flush(&mut self) -> io::Result<()> {
571 Ok(())
572 }
573}
574
575pub struct UnixSeqpacketListener {
577 descriptor: SafeDescriptor,
578 no_path: bool,
579}
580
581impl UnixSeqpacketListener {
582 pub fn bind<P: AsRef<Path>>(path: P) -> io::Result<Self> {
584 if path.as_ref().starts_with("/proc/self/fd/") {
585 let fd = path
586 .as_ref()
587 .file_name()
588 .expect("Failed to get fd filename")
589 .to_str()
590 .expect("fd filename should be unicode")
591 .parse::<i32>()
592 .expect("fd should be an integer");
593 let mut result: c_int = 0;
594 let mut result_len = size_of::<c_int>() as libc::socklen_t;
595 let ret = unsafe {
597 libc::getsockopt(
598 fd,
599 libc::SOL_SOCKET,
600 libc::SO_ACCEPTCONN,
601 &mut result as *mut _ as *mut libc::c_void,
602 &mut result_len,
603 )
604 };
605 if ret < 0 {
606 return Err(io::Error::last_os_error());
607 }
608 if result != 1 {
609 return Err(io::Error::new(
610 io::ErrorKind::InvalidInput,
611 "specified descriptor is not a listening socket",
612 ));
613 }
614 let descriptor = unsafe { SafeDescriptor::from_raw_descriptor(fd) };
617 return Ok(UnixSeqpacketListener {
618 descriptor,
619 no_path: true,
620 });
621 }
622
623 let descriptor = socket(libc::AF_UNIX, libc::SOCK_SEQPACKET, 0)?;
624 let (addr, len) = sockaddr_un(path.as_ref())?;
625
626 unsafe {
630 let ret = handle_eintr_errno!(libc::bind(
631 descriptor.as_raw_descriptor(),
632 &addr as *const _ as *const _,
633 len
634 ));
635 if ret < 0 {
636 return Err(io::Error::last_os_error());
637 }
638 let ret = handle_eintr_errno!(libc::listen(descriptor.as_raw_descriptor(), 128));
639 if ret < 0 {
640 return Err(io::Error::last_os_error());
641 }
642 }
643 Ok(UnixSeqpacketListener {
644 descriptor,
645 no_path: false,
646 })
647 }
648
649 pub fn accept_with_timeout(&self, timeout: Duration) -> io::Result<UnixSeqpacket> {
650 let start = Instant::now();
651
652 loop {
653 let mut fds = libc::pollfd {
654 fd: self.as_raw_descriptor(),
655 events: libc::POLLIN,
656 revents: 0,
657 };
658 let elapsed = Instant::now().saturating_duration_since(start);
659 let remaining = timeout.checked_sub(elapsed).unwrap_or(Duration::ZERO);
660 let cur_timeout_ms = i32::try_from(remaining.as_millis()).unwrap_or(i32::MAX);
661 match unsafe { libc::poll(&mut fds, 1, cur_timeout_ms) }.cmp(&0) {
665 Ordering::Greater => return self.accept(),
666 Ordering::Equal => return Err(io::Error::from_raw_os_error(libc::ETIMEDOUT)),
667 Ordering::Less => {
668 if Error::last() != Error::new(libc::EINTR) {
669 return Err(io::Error::last_os_error());
670 }
671 }
672 }
673 }
674 }
675
676 pub fn path(&self) -> io::Result<PathBuf> {
678 let mut addr = sockaddr_un(Path::new(""))?.0;
679 if self.no_path {
680 return Err(io::Error::new(
681 io::ErrorKind::InvalidInput,
682 "socket has no path",
683 ));
684 }
685 let sun_path_offset = (&addr.sun_path as *const _ as usize
686 - &addr.sun_family as *const _ as usize)
687 as libc::socklen_t;
688 let mut len = mem::size_of::<libc::sockaddr_un>() as libc::socklen_t;
689 let ret = unsafe {
693 handle_eintr_errno!(libc::getsockname(
694 self.as_raw_descriptor(),
695 &mut addr as *mut libc::sockaddr_un as *mut libc::sockaddr,
696 &mut len
697 ))
698 };
699 if ret < 0 {
700 return Err(io::Error::last_os_error());
701 }
702 if addr.sun_family != libc::AF_UNIX as libc::sa_family_t
703 || addr.sun_path[0] == 0
704 || len < 1 + sun_path_offset
705 {
706 return Err(io::Error::new(
707 io::ErrorKind::InvalidInput,
708 "getsockname on socket returned invalid value",
709 ));
710 }
711
712 let path_os_str = OsString::from_vec(
713 addr.sun_path[..(len - sun_path_offset - 1) as usize]
714 .iter()
715 .map(|&c| c as _)
716 .collect(),
717 );
718 Ok(path_os_str.into())
719 }
720
721 pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
723 let mut nonblocking = nonblocking as libc::c_int;
724 let ret = unsafe { libc::ioctl(self.as_raw_descriptor(), libc::FIONBIO, &mut nonblocking) };
728 if ret < 0 {
729 Err(io::Error::last_os_error())
730 } else {
731 Ok(())
732 }
733 }
734}
735
736impl AsRawDescriptor for UnixSeqpacketListener {
737 fn as_raw_descriptor(&self) -> RawDescriptor {
738 self.descriptor.as_raw_descriptor()
739 }
740}
741
742impl From<UnixSeqpacketListener> for OwnedFd {
743 fn from(val: UnixSeqpacketListener) -> Self {
744 val.descriptor.into()
745 }
746}
747
748pub struct UnlinkUnixSeqpacketListener(pub UnixSeqpacketListener);
750
751impl AsRawDescriptor for UnlinkUnixSeqpacketListener {
752 fn as_raw_descriptor(&self) -> RawDescriptor {
753 self.0.as_raw_descriptor()
754 }
755}
756
757impl AsRef<UnixSeqpacketListener> for UnlinkUnixSeqpacketListener {
758 fn as_ref(&self) -> &UnixSeqpacketListener {
759 &self.0
760 }
761}
762
763impl Deref for UnlinkUnixSeqpacketListener {
764 type Target = UnixSeqpacketListener;
765 fn deref(&self) -> &Self::Target {
766 &self.0
767 }
768}
769
770impl Drop for UnlinkUnixSeqpacketListener {
771 fn drop(&mut self) {
772 if let Ok(path) = self.0.path() {
773 if let Err(e) = remove_file(path) {
774 warn!("failed to remove control socket file: {:?}", e);
775 }
776 }
777 }
778}
779
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a relative path made of `count` copies of the letter `a`.
    fn path_of_a(count: usize) -> PathBuf {
        PathBuf::from("a".repeat(count))
    }

    #[test]
    fn sockaddr_un_zero_length_input() {
        let _ = sockaddr_un(Path::new("")).expect("sockaddr_un failed");
    }

    #[test]
    fn sockaddr_un_long_input_err() {
        // With a 108-byte `sun_path` (Linux), a 108-byte path leaves no room for the NUL.
        assert!(sockaddr_un(path_of_a(108).as_path()).is_err());
    }

    #[test]
    fn sockaddr_un_long_input_pass() {
        let _ = sockaddr_un(path_of_a(107).as_path()).expect("sockaddr_un failed");
    }

    #[test]
    fn sockaddr_un_len_check() {
        let (_, len) = sockaddr_un(path_of_a(50).as_path()).expect("sockaddr_un failed");
        // Header, then 50 path bytes, then the trailing NUL.
        let expected = sun_path_offset() + 50 + 1;
        assert_eq!(len, expected as u32);
    }

    #[test]
    #[allow(clippy::unnecessary_cast)]
    #[allow(clippy::char_lit_as_u8)]
    fn sockaddr_un_pass() {
        let path_size = 50;
        let (addr, len) =
            sockaddr_un(path_of_a(path_size).as_path()).expect("sockaddr_un failed");
        assert_eq!(len, (sun_path_offset() + path_size + 1) as u32);
        assert_eq!(addr.sun_family, libc::AF_UNIX as libc::sa_family_t);

        // The path bytes must be copied verbatim, and everything after them must be zero.
        for (index, &actual) in addr.sun_path.iter().enumerate() {
            let expected = if index < path_size {
                'a' as libc::c_char
            } else {
                0
            };
            assert_eq!(actual, expected);
        }
    }
}