1use std::fs::File;
9use std::io;
10use std::io::IoSlice;
11use std::io::IoSliceMut;
12use std::mem::size_of;
13use std::mem::size_of_val;
14use std::mem::MaybeUninit;
15use std::os::unix::io::RawFd;
16use std::ptr::copy_nonoverlapping;
17use std::ptr::null_mut;
18use std::ptr::write_unaligned;
19use std::slice;
20
21use libc::c_long;
22use libc::c_void;
23use libc::cmsghdr;
24use libc::iovec;
25use libc::msghdr;
26use libc::recvmsg;
27use libc::SCM_RIGHTS;
28use libc::SOL_SOCKET;
29use serde::Deserialize;
30use serde::Serialize;
31
32use crate::error;
33use crate::sys::sendmsg;
34use crate::AsRawDescriptor;
35use crate::FromRawDescriptor;
36use crate::IoBufMut;
37use crate::RawDescriptor;
38use crate::SafeDescriptor;
39use crate::VolatileSlice;
40
/// Maximum number of file descriptors that `raw_sendmsg` / `raw_recvmsg` accept in a single call;
/// larger requests are rejected with `io::ErrorKind::InvalidInput`.
///
/// NOTE(review): the value appears to mirror the Linux kernel's own `SCM_MAX_FD` limit — confirm.
pub const SCM_MAX_FD: usize = 253;
43
/// Rounds `len` up to the next multiple of `size_of::<c_long>()`.
///
/// Rust counterpart of the C `CMSG_ALIGN` macro; used below to pad control-message payloads and
/// to step from one control-message header to the next.
#[allow(non_snake_case)]
const fn CMSG_ALIGN(len: usize) -> usize {
    let mask = size_of::<c_long>() - 1;
    (len + mask) & !mask
}
51
52#[allow(non_snake_case)]
53const fn CMSG_SPACE(len: usize) -> usize {
54 size_of::<cmsghdr>() + CMSG_ALIGN(len)
55}
56
57#[allow(non_snake_case)]
58const fn CMSG_LEN(len: usize) -> usize {
59 size_of::<cmsghdr>() + len
60}
61
62#[allow(non_snake_case)]
66#[inline(always)]
67fn CMSG_DATA(cmsg_buffer: *mut cmsghdr) -> *mut RawFd {
68 cmsg_buffer.wrapping_offset(1) as *mut RawFd
70}
71
72#[allow(clippy::cast_ptr_alignment, clippy::unnecessary_cast)]
75fn get_next_cmsg(msghdr: &msghdr, cmsg: &cmsghdr, cmsg_ptr: *mut cmsghdr) -> *mut cmsghdr {
76 let next_cmsg =
79 (cmsg_ptr as *mut u8).wrapping_add(CMSG_ALIGN(cmsg.cmsg_len as usize)) as *mut cmsghdr;
80 if next_cmsg
81 .wrapping_offset(1)
82 .wrapping_sub(msghdr.msg_control as usize) as usize
83 > msghdr.msg_controllen as usize
84 {
85 null_mut()
86 } else {
87 next_cmsg
88 }
89}
90
91const CMSG_BUFFER_INLINE_CAPACITY: usize = CMSG_SPACE(size_of::<RawFd>() * 32);
92
93enum CmsgBuffer {
94 Inline([u64; CMSG_BUFFER_INLINE_CAPACITY.div_ceil(8)]),
95 Heap(Box<[cmsghdr]>),
96}
97
98impl CmsgBuffer {
99 fn with_capacity(capacity: usize) -> CmsgBuffer {
100 let cap_in_cmsghdr_units =
101 (capacity.checked_add(size_of::<cmsghdr>()).unwrap() - 1) / size_of::<cmsghdr>();
102 if capacity <= CMSG_BUFFER_INLINE_CAPACITY {
103 CmsgBuffer::Inline([0u64; CMSG_BUFFER_INLINE_CAPACITY.div_ceil(8)])
104 } else {
105 CmsgBuffer::Heap(
106 vec![
107 unsafe { MaybeUninit::<cmsghdr>::zeroed().assume_init() };
111 cap_in_cmsghdr_units
112 ]
113 .into_boxed_slice(),
114 )
115 }
116 }
117
118 fn as_mut_ptr(&mut self) -> *mut cmsghdr {
119 match self {
120 CmsgBuffer::Inline(a) => a.as_mut_ptr() as *mut cmsghdr,
121 CmsgBuffer::Heap(a) => a.as_mut_ptr(),
122 }
123 }
124}
125
126#[allow(clippy::useless_conversion)]
129fn raw_sendmsg(fd: RawFd, iovec: &[iovec], out_fds: &[RawFd]) -> io::Result<usize> {
130 if out_fds.len() > SCM_MAX_FD {
131 error!(
132 "too many fds to send: {} > SCM_MAX_FD (SCM_MAX_FD)",
133 out_fds.len()
134 );
135 return Err(io::Error::from(io::ErrorKind::InvalidInput));
136 }
137
138 let cmsg_capacity = CMSG_SPACE(size_of_val(out_fds));
139 let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity);
140
141 let mut msg: msghdr = unsafe { MaybeUninit::zeroed().assume_init() };
146 msg.msg_iov = iovec.as_ptr() as *mut iovec;
147 msg.msg_iovlen = iovec.len().try_into().unwrap();
148
149 if !out_fds.is_empty() {
150 let mut cmsg: cmsghdr = unsafe { MaybeUninit::zeroed().assume_init() };
155 cmsg.cmsg_len = CMSG_LEN(size_of_val(out_fds)).try_into().unwrap();
156 cmsg.cmsg_level = SOL_SOCKET;
157 cmsg.cmsg_type = SCM_RIGHTS;
158 unsafe {
160 write_unaligned(cmsg_buffer.as_mut_ptr(), cmsg);
163 copy_nonoverlapping(
167 out_fds.as_ptr(),
168 CMSG_DATA(cmsg_buffer.as_mut_ptr()),
169 out_fds.len(),
170 );
171 }
172
173 msg.msg_control = cmsg_buffer.as_mut_ptr() as *mut c_void;
174 msg.msg_controllen = cmsg_capacity.try_into().unwrap();
175 }
176
177 let write_count = unsafe { sendmsg(fd, &msg, 0) };
181
182 if write_count == -1 {
183 Err(io::Error::last_os_error())
184 } else {
185 Ok(write_count as usize)
186 }
187}
188
189#[allow(clippy::useless_conversion, clippy::unnecessary_cast)]
192fn raw_recvmsg(
193 fd: RawFd,
194 iovs: &mut [iovec],
195 max_fds: usize,
196) -> io::Result<(usize, Vec<SafeDescriptor>)> {
197 if max_fds > SCM_MAX_FD {
198 error!("too many fds to recieve: {max_fds} > SCM_MAX_FD (SCM_MAX_FD)");
199 return Err(io::Error::from(io::ErrorKind::InvalidInput));
200 }
201
202 let cmsg_capacity = CMSG_SPACE(max_fds * size_of::<RawFd>());
203 let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity);
204
205 let mut msg: msghdr = unsafe { MaybeUninit::zeroed().assume_init() };
210 msg.msg_iov = iovs.as_mut_ptr() as *mut iovec;
211 msg.msg_iovlen = iovs.len().try_into().unwrap();
212
213 if max_fds > 0 {
214 msg.msg_control = cmsg_buffer.as_mut_ptr() as *mut c_void;
215 msg.msg_controllen = cmsg_capacity.try_into().unwrap();
216 }
217
218 let total_read = unsafe { recvmsg(fd, &mut msg, 0) };
222
223 if total_read == -1 {
224 return Err(io::Error::last_os_error());
225 }
226
227 if total_read == 0 && (msg.msg_controllen as usize) < size_of::<cmsghdr>() {
228 return Ok((0, Vec::new()));
229 }
230
231 let mut cmsg_ptr = msg.msg_control as *mut cmsghdr;
232 let mut in_fds: Vec<SafeDescriptor> = Vec::with_capacity(max_fds);
233 while !cmsg_ptr.is_null() {
234 let cmsg = unsafe { (cmsg_ptr as *mut cmsghdr).read_unaligned() };
238
239 if cmsg.cmsg_level == SOL_SOCKET && cmsg.cmsg_type == SCM_RIGHTS {
240 let fd_count = (cmsg.cmsg_len as usize - CMSG_LEN(0)) / size_of::<RawFd>();
241 let fd_ptr: *const RawFd = CMSG_DATA(cmsg_ptr);
242 for i in 0..fd_count {
243 let fd: RawFd = unsafe { fd_ptr.add(i).read_unaligned() };
245 let sd = unsafe { SafeDescriptor::from_raw_descriptor(fd) };
247 in_fds.push(sd);
248 }
249 }
250
251 cmsg_ptr = get_next_cmsg(&msg, &cmsg, cmsg_ptr);
252 }
253
254 Ok((total_read as usize, in_fds))
255}
256
257pub const SCM_SOCKET_MAX_FD_COUNT: usize = 253;
259
/// Wraps a socket `T` and adds methods for sending and receiving data together with file
/// descriptors, which are transferred as `SCM_RIGHTS` control messages.
#[derive(Serialize, Deserialize)]
pub struct ScmSocket<T: AsRawDescriptor> {
    // The wrapped socket; reachable directly only from within `crate::sys`.
    pub(in crate::sys) socket: T,
}
269
270impl<T: AsRawDescriptor> ScmSocket<T> {
271 pub fn send_with_fds(&self, buf: &[u8], fds: &[RawFd]) -> io::Result<usize> {
282 self.send_vectored_with_fds(&[IoSlice::new(buf)], fds)
283 }
284
285 pub fn send_vectored_with_fds(
296 &self,
297 bufs: &[impl AsIobuf],
298 fds: &[RawFd],
299 ) -> io::Result<usize> {
300 raw_sendmsg(
301 self.socket.as_raw_descriptor(),
302 AsIobuf::as_iobuf_slice(bufs),
303 fds,
304 )
305 }
306
307 pub fn recv_with_fds(
319 &self,
320 buf: &mut [u8],
321 max_descriptors: usize,
322 ) -> io::Result<(usize, Vec<SafeDescriptor>)> {
323 self.recv_vectored_with_fds(&mut [IoSliceMut::new(buf)], max_descriptors)
324 }
325
326 pub fn recv_vectored_with_fds(
338 &self,
339 bufs: &mut [IoSliceMut],
340 max_descriptors: usize,
341 ) -> io::Result<(usize, Vec<SafeDescriptor>)> {
342 raw_recvmsg(
343 self.socket.as_raw_descriptor(),
344 IoSliceMut::as_iobuf_mut_slice(bufs),
345 max_descriptors,
346 )
347 }
348
349 pub fn recv_with_file(&self, buf: &mut [u8]) -> io::Result<(usize, Option<File>)> {
359 let (read_count, mut descriptors) = self.recv_with_fds(buf, 1)?;
360 let file = if descriptors.len() == 1 {
361 Some(File::from(descriptors.swap_remove(0)))
362 } else {
363 None
364 };
365 Ok((read_count, file))
366 }
367
368 pub fn inner(&self) -> &T {
370 &self.socket
371 }
372
373 pub fn inner_mut(&mut self) -> &mut T {
375 &mut self.socket
376 }
377
378 pub fn into_inner(self) -> T {
380 self.socket
381 }
382}
383
384impl<T: AsRawDescriptor> AsRawDescriptor for ScmSocket<T> {
385 fn as_raw_descriptor(&self) -> RawDescriptor {
386 self.socket.as_raw_descriptor()
387 }
388}
389
/// Types that can be described by a POSIX `iovec` (base pointer + length) buffer descriptor.
///
/// # Safety
/// NOTE(review): the precise contract is not spelled out here. The `IoSlice`/`IoSliceMut`
/// implementations in this file reinterpret `&[Self]` as `&[iovec]` with a pointer cast, which is
/// only sound if `Self` has the same memory layout as `iovec` — presumably that (or an equivalent
/// guarantee for the slice methods) is what implementors must uphold; confirm.
pub unsafe trait AsIobuf: Sized {
    /// Returns an `iovec` describing the same memory as `self`.
    fn as_iobuf(&self) -> iovec;

    /// Views a slice of `Self` as a slice of `iovec`s.
    #[allow(clippy::wrong_self_convention)]
    fn as_iobuf_slice(bufs: &[Self]) -> &[iovec];

    /// Views a mutable slice of `Self` as a mutable slice of `iovec`s.
    fn as_iobuf_mut_slice(bufs: &mut [Self]) -> &mut [iovec];
}
407
408unsafe impl AsIobuf for IoSlice<'_> {
412 fn as_iobuf(&self) -> iovec {
413 iovec {
414 iov_base: self.as_ptr() as *mut c_void,
415 iov_len: self.len(),
416 }
417 }
418
419 fn as_iobuf_slice(bufs: &[Self]) -> &[iovec] {
420 unsafe { slice::from_raw_parts(bufs.as_ptr() as *const iovec, bufs.len()) }
423 }
424
425 fn as_iobuf_mut_slice(bufs: &mut [Self]) -> &mut [iovec] {
426 unsafe { slice::from_raw_parts_mut(bufs.as_mut_ptr() as *mut iovec, bufs.len()) }
429 }
430}
431
432unsafe impl AsIobuf for IoSliceMut<'_> {
436 fn as_iobuf(&self) -> iovec {
437 iovec {
438 iov_base: self.as_ptr() as *mut c_void,
439 iov_len: self.len(),
440 }
441 }
442
443 fn as_iobuf_slice(bufs: &[Self]) -> &[iovec] {
444 unsafe { slice::from_raw_parts(bufs.as_ptr() as *const iovec, bufs.len()) }
447 }
448
449 fn as_iobuf_mut_slice(bufs: &mut [Self]) -> &mut [iovec] {
450 unsafe { slice::from_raw_parts_mut(bufs.as_mut_ptr() as *mut iovec, bufs.len()) }
453 }
454}
455
unsafe impl AsIobuf for VolatileSlice<'_> {
    // NOTE(review): `self.as_iobuf()` below presumably resolves to an *inherent*
    // `VolatileSlice::as_iobuf` (inherent methods take precedence over trait methods), whose
    // result is converted to `&iovec` via `as_ref()` and copied out. It is therefore not a
    // recursive call to this trait method; if that inherent method is ever renamed or removed,
    // this would recurse. Confirm against `VolatileSlice`'s definition.
    fn as_iobuf(&self) -> iovec {
        *self.as_iobuf().as_ref()
    }

    // Appears to convert `&[VolatileSlice]` to `&[IoBufMut]` and then to `&[iovec]` using helpers
    // defined outside this file.
    fn as_iobuf_slice(bufs: &[Self]) -> &[iovec] {
        IoBufMut::as_iobufs(VolatileSlice::as_iobufs(bufs))
    }

    // Mutable counterpart of `as_iobuf_slice`, going through the same two helper conversions.
    fn as_iobuf_mut_slice(bufs: &mut [Self]) -> &mut [iovec] {
        IoBufMut::as_iobufs_mut(VolatileSlice::as_iobufs_mut(bufs))
    }
}
472
#[cfg(test)]
#[cfg(any(target_os = "android", target_os = "linux"))]
mod tests {
    use std::io::Write;
    use std::mem::size_of;
    use std::os::fd::AsRawFd;
    use std::os::unix::net::UnixDatagram;

    use super::*;
    use crate::AsRawDescriptor;
    use crate::Event;
    use crate::EventExt;

    /// Asserts that this module's `CMSG_SPACE` matches libc's for a payload of `$len` fds.
    macro_rules! CMSG_SPACE_TEST {
        ($len:literal) => {
            assert_eq!(
                CMSG_SPACE(size_of::<[RawFd; $len]>()) as libc::c_uint,
                // libc exposes CMSG_SPACE as an `unsafe fn`; it is called with a plain integer.
                unsafe { libc::CMSG_SPACE(size_of::<[RawFd; $len]>() as libc::c_uint) }
            );
        };
    }

    #[test]
    #[allow(clippy::erasing_op, clippy::identity_op)]
    fn buffer_len() {
        CMSG_SPACE_TEST!(0);
        CMSG_SPACE_TEST!(1);
        CMSG_SPACE_TEST!(2);
        CMSG_SPACE_TEST!(3);
        CMSG_SPACE_TEST!(4);
    }

    #[test]
    fn send_recv_no_fd() {
        let (u1, u2) = UnixDatagram::pair().expect("failed to create socket pair");
        let (s1, s2) = (
            ScmSocket::try_from(u1).unwrap(),
            ScmSocket::try_from(u2).unwrap(),
        );

        let send_buf = [1u8, 1, 2, 21, 34, 55];
        let write_count = s1
            .send_with_fds(&send_buf, &[])
            .expect("failed to send data");

        assert_eq!(write_count, 6);

        let mut buf = [0; 6];
        let (read_count, files) = s2.recv_with_fds(&mut buf, 1).expect("failed to recv data");

        assert_eq!(read_count, 6);
        assert_eq!(files.len(), 0);
        assert_eq!(buf, [1, 1, 2, 21, 34, 55]);

        let write_count = s1
            .send_with_fds(&send_buf, &[])
            .expect("failed to send data");

        assert_eq!(write_count, 6);
        let (read_count, files) = s2.recv_with_fds(&mut buf, 1).expect("failed to recv data");

        assert_eq!(read_count, 6);
        assert_eq!(files.len(), 0);
        assert_eq!(buf, [1, 1, 2, 21, 34, 55]);
    }

    #[test]
    fn send_recv_only_fd() {
        let (u1, u2) = UnixDatagram::pair().expect("failed to create socket pair");
        let (s1, s2) = (
            ScmSocket::try_from(u1).unwrap(),
            ScmSocket::try_from(u2).unwrap(),
        );

        let evt = Event::new().expect("failed to create event");
        let write_count = s1
            .send_with_fds(&[], &[evt.as_raw_descriptor()])
            .expect("failed to send fd");

        assert_eq!(write_count, 0);

        let mut buf = [];
        let (read_count, file_opt) = s2.recv_with_file(&mut buf).expect("failed to recv fd");

        let mut file = file_opt.unwrap();

        assert_eq!(read_count, 0);
        assert!(file.as_raw_fd() >= 0);
        assert_ne!(file.as_raw_fd(), s1.as_raw_descriptor());
        assert_ne!(file.as_raw_fd(), s2.as_raw_descriptor());
        assert_ne!(file.as_raw_fd(), evt.as_raw_descriptor());

        // Native-endian bytes of the counter value, written through the received descriptor.
        file.write_all(&1203u64.to_ne_bytes())
            .expect("failed to write to sent fd");

        assert_eq!(evt.read_count().expect("failed to read from event"), 1203);
    }

    #[test]
    fn send_recv_with_fd() {
        let (u1, u2) = UnixDatagram::pair().expect("failed to create socket pair");
        let (s1, s2) = (
            ScmSocket::try_from(u1).unwrap(),
            ScmSocket::try_from(u2).unwrap(),
        );

        let evt = Event::new().expect("failed to create event");
        let write_count = s1
            .send_with_fds(&[237], &[evt.as_raw_descriptor()])
            .expect("failed to send fd");

        assert_eq!(write_count, 1);

        let mut buf = [0u8];
        let (read_count, mut files) = s2.recv_with_fds(&mut buf, 2).expect("failed to recv fd");

        assert_eq!(read_count, 1);
        assert_eq!(buf[0], 237);
        assert_eq!(files.len(), 1);
        assert!(files[0].as_raw_descriptor() >= 0);
        assert_ne!(files[0].as_raw_descriptor(), s1.as_raw_descriptor());
        assert_ne!(files[0].as_raw_descriptor(), s2.as_raw_descriptor());
        assert_ne!(files[0].as_raw_descriptor(), evt.as_raw_descriptor());

        let mut file = File::from(files.swap_remove(0));

        // Native-endian bytes of the counter value, written through the received descriptor.
        file.write_all(&1203u64.to_ne_bytes())
            .expect("failed to write to sent fd");

        assert_eq!(evt.read_count().expect("failed to read from event"), 1203);
    }

    #[test]
    fn send_too_many_fds_is_rejected() {
        let (u1, _u2) = UnixDatagram::pair().expect("failed to create socket pair");
        let s1 = ScmSocket::try_from(u1).unwrap();

        // The limit is checked before any descriptor is used, so placeholder values suffice.
        let fds: [RawFd; SCM_MAX_FD + 1] = [-1; SCM_MAX_FD + 1];
        let err = s1
            .send_with_fds(&[1u8], &fds)
            .err()
            .expect("sending more than SCM_MAX_FD fds should fail");
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
    }

    #[test]
    fn recv_too_many_fds_is_rejected() {
        let (_u1, u2) = UnixDatagram::pair().expect("failed to create socket pair");
        let s2 = ScmSocket::try_from(u2).unwrap();

        let mut buf = [0u8; 1];
        let err = s2
            .recv_with_fds(&mut buf, SCM_MAX_FD + 1)
            .err()
            .expect("receiving more than SCM_MAX_FD fds should fail");
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
    }
}