base/sys/unix/
stream_channel.rs

1// Copyright 2022 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use std::io;
6use std::io::Read;
7use std::os::unix::io::AsRawFd;
8use std::os::unix::io::RawFd;
9use std::os::unix::net::UnixStream;
10use std::time::Duration;
11
12use libc::c_void;
13use serde::Deserialize;
14use serde::Serialize;
15
16use super::super::net::UnixSeqpacket;
17use crate::descriptor::AsRawDescriptor;
18use crate::IntoRawDescriptor;
19use crate::RawDescriptor;
20use crate::ReadNotifier;
21use crate::Result;
22
/// How data written to a `StreamChannel` is delimited when read back.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum FramingMode {
    /// Message-preserving framing; on Unix this is backed by a `UnixSeqpacket` pair.
    Message,
    /// Unframed byte-stream framing; on Unix this is backed by a `UnixStream` pair.
    Byte,
}
28
/// Whether the sockets created by `StreamChannel::pair` block on I/O.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum BlockingMode {
    /// Reads and writes wait until they can make progress.
    Blocking,
    /// Reads and writes return immediately (with an error) when they cannot make progress.
    Nonblocking,
}
34
35impl io::Read for StreamChannel {
36    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
37        self.inner_read(buf)
38    }
39}
40
41impl io::Read for &StreamChannel {
42    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
43        self.inner_read(buf)
44    }
45}
46
47impl AsRawDescriptor for StreamChannel {
48    fn as_raw_descriptor(&self) -> RawDescriptor {
49        (&self).as_raw_descriptor()
50    }
51}
52
/// The concrete socket backing a `StreamChannel`; the variant determines its `FramingMode`.
///
/// NOTE(review): this enum derives serde `Serialize`/`Deserialize`, so renaming or reordering
/// variants may change the serialized form — treat the variant list as stable.
#[derive(Debug, Deserialize, Serialize)]
enum SocketType {
    /// Message framing, backed by a `UnixSeqpacket` (see `StreamChannel::inner_read` for how
    /// reads of a single message are bounded).
    Message(UnixSeqpacket),
    /// Byte-stream framing, backed by a `UnixStream`. Its serde representation is delegated to
    /// `crate::with_as_descriptor` (presumably encoding the socket by its descriptor — confirm).
    #[serde(with = "crate::with_as_descriptor")]
    Byte(UnixStream),
}
59
/// An abstraction over a connected pair of Unix sockets: a `UnixStream` for
/// `FramingMode::Byte` or a `UnixSeqpacket` for `FramingMode::Message`. Channels can be used in
/// blocking or non-blocking mode, and both `StreamChannel` and `&StreamChannel` implement
/// `io::Read` and `io::Write`.
///
/// WARNING: partial reads of messages behave differently depending on the platform. In this
/// Unix implementation, reading a message into a buffer smaller than the message returns an
/// error and the rest of that message is lost.
/// See sys::unix::StreamChannel::inner_read for details.
#[derive(Debug, Deserialize, Serialize)]
pub struct StreamChannel {
    // The backing socket; its variant determines the channel's framing mode.
    stream: SocketType,
}
69
70impl StreamChannel {
71    pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
72        match &mut self.stream {
73            SocketType::Byte(sock) => sock.set_nonblocking(nonblocking),
74            SocketType::Message(sock) => sock.set_nonblocking(nonblocking),
75        }
76    }
77
78    pub fn get_framing_mode(&self) -> FramingMode {
79        match &self.stream {
80            SocketType::Message(_) => FramingMode::Message,
81            SocketType::Byte(_) => FramingMode::Byte,
82        }
83    }
84
85    pub(super) fn inner_read(&self, buf: &mut [u8]) -> io::Result<usize> {
86        match &self.stream {
87            SocketType::Byte(sock) => (&mut &*sock).read(buf),
88
89            // On Windows, reading from SOCK_SEQPACKET with a buffer that is too small is an error,
90            // and the extra data will be preserved inside the named pipe.
91            //
92            // Linux though, will silently truncate unless MSG_TRUNC is passed. So we pass it, but
93            // even in that case, Linux will still throw away the extra data. This means there is a
94            // slight behavior difference between platforms from the consumer's perspective.
95            // In practice on Linux, intentional partial reads of messages are usually accomplished
96            // by also passing MSG_PEEK. While we could do this, and hide this rough edge from
97            // consumers, it would add complexity & turn every read into two read syscalls.
98            //
99            // So the compromise is this:
100            // * On Linux: a partial read of a message is an Err and loses data.
101            // * On Windows: a partial read of a message is Ok and does not lose data.
102            SocketType::Message(sock) => {
103                // SAFETY:
104                // Safe because buf is valid, we pass buf's size to recv to bound the return
105                // length, and we check the return code.
106                let retval = unsafe {
107                    // TODO(nkgold|b/152067913): Move this into the UnixSeqpacket struct as a
108                    // recv_with_flags method once that struct's tests are working.
109                    libc::recv(
110                        sock.as_raw_descriptor(),
111                        buf.as_mut_ptr() as *mut c_void,
112                        buf.len(),
113                        libc::MSG_TRUNC,
114                    )
115                };
116                let receive_len = if retval < 0 {
117                    Err(std::io::Error::last_os_error())
118                } else {
119                    Ok(retval)
120                }? as usize;
121
122                if receive_len > buf.len() {
123                    Err(std::io::Error::other(format!(
124                        "packet size {:?} encountered, but buffer was only of size {:?}",
125                        receive_len,
126                        buf.len()
127                    )))
128                } else {
129                    Ok(receive_len)
130                }
131            }
132        }
133    }
134
135    /// Creates a cross platform stream pair.
136    pub fn pair(
137        blocking_mode: BlockingMode,
138        framing_mode: FramingMode,
139    ) -> Result<(StreamChannel, StreamChannel)> {
140        let (pipe_a, pipe_b) = match framing_mode {
141            FramingMode::Byte => {
142                let (pipe_a, pipe_b) = UnixStream::pair()?;
143                (SocketType::Byte(pipe_a), SocketType::Byte(pipe_b))
144            }
145            FramingMode::Message => {
146                let (pipe_a, pipe_b) = UnixSeqpacket::pair()?;
147                (SocketType::Message(pipe_a), SocketType::Message(pipe_b))
148            }
149        };
150        let mut stream_a = StreamChannel { stream: pipe_a };
151        let mut stream_b = StreamChannel { stream: pipe_b };
152        let is_non_blocking = blocking_mode == BlockingMode::Nonblocking;
153        stream_a.set_nonblocking(is_non_blocking)?;
154        stream_b.set_nonblocking(is_non_blocking)?;
155        Ok((stream_a, stream_b))
156    }
157
158    pub fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
159        match &self.stream {
160            SocketType::Byte(sock) => sock.set_read_timeout(timeout),
161            SocketType::Message(sock) => sock.set_read_timeout(timeout),
162        }
163    }
164
165    pub fn set_write_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
166        match &self.stream {
167            SocketType::Byte(sock) => sock.set_write_timeout(timeout),
168            SocketType::Message(sock) => sock.set_write_timeout(timeout),
169        }
170    }
171
172    // WARNING: Generally, multiple StreamChannel ends are not wanted. StreamChannel behavior with
173    // > 1 reader per end is not defined.
174    pub fn try_clone(&self) -> io::Result<Self> {
175        Ok(StreamChannel {
176            stream: match &self.stream {
177                SocketType::Byte(sock) => SocketType::Byte(sock.try_clone()?),
178                SocketType::Message(sock) => SocketType::Message(sock.try_clone()?),
179            },
180        })
181    }
182}
183
184impl io::Write for StreamChannel {
185    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
186        match &mut self.stream {
187            SocketType::Byte(sock) => sock.write(buf),
188            SocketType::Message(sock) => sock.send(buf),
189        }
190    }
191    fn flush(&mut self) -> io::Result<()> {
192        match &mut self.stream {
193            SocketType::Byte(sock) => sock.flush(),
194            SocketType::Message(_) => Ok(()),
195        }
196    }
197}
198
199impl io::Write for &StreamChannel {
200    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
201        match &self.stream {
202            SocketType::Byte(sock) => (&mut &*sock).write(buf),
203            SocketType::Message(sock) => sock.send(buf),
204        }
205    }
206    fn flush(&mut self) -> io::Result<()> {
207        match &self.stream {
208            SocketType::Byte(sock) => (&mut &*sock).flush(),
209            SocketType::Message(_) => Ok(()),
210        }
211    }
212}
213
214impl AsRawFd for StreamChannel {
215    fn as_raw_fd(&self) -> RawFd {
216        match &self.stream {
217            SocketType::Byte(sock) => sock.as_raw_descriptor(),
218            SocketType::Message(sock) => sock.as_raw_descriptor(),
219        }
220    }
221}
222
223impl AsRawFd for &StreamChannel {
224    fn as_raw_fd(&self) -> RawFd {
225        self.as_raw_descriptor()
226    }
227}
228
229impl AsRawDescriptor for &StreamChannel {
230    fn as_raw_descriptor(&self) -> RawDescriptor {
231        match &self.stream {
232            SocketType::Byte(sock) => sock.as_raw_descriptor(),
233            SocketType::Message(sock) => sock.as_raw_descriptor(),
234        }
235    }
236}
237
238impl IntoRawDescriptor for StreamChannel {
239    fn into_raw_descriptor(self) -> RawFd {
240        match self.stream {
241            SocketType::Byte(sock) => sock.into_raw_descriptor(),
242            SocketType::Message(sock) => sock.into_raw_descriptor(),
243        }
244    }
245}
246
247impl ReadNotifier for StreamChannel {
248    /// Returns a RawDescriptor that can be polled for reads using PollContext.
249    fn get_read_notifier(&self) -> &dyn AsRawDescriptor {
250        self
251    }
252}
253
#[cfg(test)]
mod test {
    use std::io::Read;
    use std::io::Write;

    use super::*;
    use crate::EventContext;
    use crate::EventToken;
    use crate::ReadNotifier;

    #[derive(EventToken, Debug, Eq, PartialEq, Copy, Clone)]
    enum Token {
        ReceivedData,
    }

    /// Waits until `receiver` is reported readable, asserts that it is the only readable
    /// event, and returns the context so callers can poll it again later.
    fn wait_for_data(receiver: &StreamChannel) -> EventContext<Token> {
        let event_ctx: EventContext<Token> =
            EventContext::build_with(&[(receiver.get_read_notifier(), Token::ReceivedData)])
                .unwrap();
        let events = event_ctx.wait().unwrap();
        let tokens: Vec<Token> = events
            .iter()
            .filter(|e| e.is_readable)
            .map(|e| e.token)
            .collect();
        assert_eq!(tokens, vec![Token::ReceivedData]);
        event_ctx
    }

    #[test]
    fn test_non_blocking_pair_byte() {
        let (mut sender, mut receiver) =
            StreamChannel::pair(BlockingMode::Nonblocking, FramingMode::Byte).unwrap();

        sender.write_all(&[75, 77, 54, 82, 76, 65]).unwrap();

        // Wait for the data to arrive.
        let event_ctx = wait_for_data(&receiver);

        // Smaller than what we sent so we get multiple chunks
        let mut recv_buffer: [u8; 4] = [0; 4];

        let mut size = receiver.read(&mut recv_buffer).unwrap();
        assert_eq!(size, 4);
        assert_eq!(recv_buffer, [75, 77, 54, 82]);

        size = receiver.read(&mut recv_buffer).unwrap();
        assert_eq!(size, 2);
        assert_eq!(recv_buffer[0..2], [76, 65]);

        // Now that we've polled for & received all data, polling again should show no events.
        assert_eq!(
            event_ctx
                .wait_timeout(std::time::Duration::new(0, 0))
                .unwrap()
                .len(),
            0
        );
    }

    #[test]
    fn test_non_blocking_pair_message() {
        let (mut sender, mut receiver) =
            StreamChannel::pair(BlockingMode::Nonblocking, FramingMode::Message).unwrap();

        sender.write_all(&[75, 77, 54, 82, 76, 65]).unwrap();

        // Wait for the data to arrive.
        let event_ctx = wait_for_data(&receiver);

        // Unlike Byte format, Message mode returns an error (and on Linux discards the rest of
        // the message) if the buffer is smaller than the packet; make the buffer the right size.
        let mut recv_buffer: [u8; 6] = [0; 6];

        let size = receiver.read(&mut recv_buffer).unwrap();
        assert_eq!(size, 6);
        assert_eq!(recv_buffer, [75, 77, 54, 82, 76, 65]);

        // Now that we've polled for & received all data, polling again should show no events.
        assert_eq!(
            event_ctx
                .wait_timeout(std::time::Duration::new(0, 0))
                .unwrap()
                .len(),
            0
        );
    }

    #[test]
    fn test_message_partial_read_is_error() {
        let (mut sender, mut receiver) =
            StreamChannel::pair(BlockingMode::Nonblocking, FramingMode::Message).unwrap();

        sender.write_all(&[75, 77, 54, 82, 76, 65]).unwrap();

        // Wait for the packet to arrive.
        wait_for_data(&receiver);

        // A 4-byte buffer cannot hold the 6-byte packet: on Unix this is an error.
        let mut recv_buffer: [u8; 4] = [0; 4];
        assert!(receiver.read(&mut recv_buffer).is_err());

        // The oversized packet was consumed (its excess discarded), so nothing is left to read.
        assert_eq!(
            receiver.read(&mut recv_buffer).unwrap_err().kind(),
            std::io::ErrorKind::WouldBlock
        );
    }

    #[test]
    fn test_non_blocking_pair_error_no_data() {
        let (mut sender, mut receiver) =
            StreamChannel::pair(BlockingMode::Nonblocking, FramingMode::Byte).unwrap();
        receiver
            .set_nonblocking(true)
            .expect("Failed to set receiver to nonblocking mode.");

        sender.write_all(&[75, 77]).unwrap();

        // Wait for the data to arrive.
        wait_for_data(&receiver);

        // We only read 2 bytes, even though we requested 4 bytes.
        let mut recv_buffer: [u8; 4] = [0; 4];
        let size = receiver.read(&mut recv_buffer).unwrap();
        assert_eq!(size, 2);
        assert_eq!(recv_buffer, [75, 77, 00, 00]);

        // Further reads should report WouldBlock since there is no available data and this is
        // a non blocking pipe.
        assert_eq!(
            receiver.read(&mut recv_buffer).unwrap_err().kind(),
            std::io::ErrorKind::WouldBlock
        );
    }
}