1use std::io;
6use std::io::Read;
7use std::os::unix::io::AsRawFd;
8use std::os::unix::io::RawFd;
9use std::os::unix::net::UnixStream;
10use std::time::Duration;
11
12use libc::c_void;
13use serde::Deserialize;
14use serde::Serialize;
15
16use super::super::net::UnixSeqpacket;
17use crate::descriptor::AsRawDescriptor;
18use crate::IntoRawDescriptor;
19use crate::RawDescriptor;
20use crate::ReadNotifier;
21use crate::Result;
22
/// How data written to a [`StreamChannel`] is delimited when it is read back.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum FramingMode {
    /// Each write is sent as one packet over a `UnixSeqpacket` pair; a read returns one packet.
    Message,
    /// Data flows over a `UnixStream` pair as an unframed byte stream.
    Byte,
}
28
/// Whether the sockets created by [`StreamChannel::pair`] are put in blocking or
/// non-blocking mode.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum BlockingMode {
    /// Reads and writes wait until they can make progress.
    Blocking,
    /// Reads and writes return immediately with an error instead of waiting.
    Nonblocking,
}
34
35impl io::Read for StreamChannel {
36 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
37 self.inner_read(buf)
38 }
39}
40
41impl io::Read for &StreamChannel {
42 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
43 self.inner_read(buf)
44 }
45}
46
47impl AsRawDescriptor for StreamChannel {
48 fn as_raw_descriptor(&self) -> RawDescriptor {
49 (&self).as_raw_descriptor()
50 }
51}
52
/// The concrete socket backing a [`StreamChannel`]; the variant determines its [`FramingMode`].
#[derive(Debug, Deserialize, Serialize)]
enum SocketType {
    /// Packet-oriented socket, used for [`FramingMode::Message`].
    Message(UnixSeqpacket),
    /// Byte-stream socket, used for [`FramingMode::Byte`].
    // `UnixStream` is (de)serialized through the `crate::with_as_descriptor` helper module
    // (presumably via its raw descriptor — see that module).
    #[serde(with = "crate::with_as_descriptor")]
    Byte(UnixStream),
}
59
/// A Unix-socket channel that can be read from and written to, whose framing (unframed byte
/// stream or discrete messages) is chosen when the pair is created by [`StreamChannel::pair`].
#[derive(Debug, Deserialize, Serialize)]
pub struct StreamChannel {
    /// The underlying socket; its variant determines the channel's framing mode.
    stream: SocketType,
}
69
70impl StreamChannel {
71 pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
72 match &mut self.stream {
73 SocketType::Byte(sock) => sock.set_nonblocking(nonblocking),
74 SocketType::Message(sock) => sock.set_nonblocking(nonblocking),
75 }
76 }
77
78 pub fn get_framing_mode(&self) -> FramingMode {
79 match &self.stream {
80 SocketType::Message(_) => FramingMode::Message,
81 SocketType::Byte(_) => FramingMode::Byte,
82 }
83 }
84
85 pub(super) fn inner_read(&self, buf: &mut [u8]) -> io::Result<usize> {
86 match &self.stream {
87 SocketType::Byte(sock) => (&mut &*sock).read(buf),
88
89 SocketType::Message(sock) => {
103 let retval = unsafe {
107 libc::recv(
110 sock.as_raw_descriptor(),
111 buf.as_mut_ptr() as *mut c_void,
112 buf.len(),
113 libc::MSG_TRUNC,
114 )
115 };
116 let receive_len = if retval < 0 {
117 Err(std::io::Error::last_os_error())
118 } else {
119 Ok(retval)
120 }? as usize;
121
122 if receive_len > buf.len() {
123 Err(std::io::Error::other(format!(
124 "packet size {:?} encountered, but buffer was only of size {:?}",
125 receive_len,
126 buf.len()
127 )))
128 } else {
129 Ok(receive_len)
130 }
131 }
132 }
133 }
134
135 pub fn pair(
137 blocking_mode: BlockingMode,
138 framing_mode: FramingMode,
139 ) -> Result<(StreamChannel, StreamChannel)> {
140 let (pipe_a, pipe_b) = match framing_mode {
141 FramingMode::Byte => {
142 let (pipe_a, pipe_b) = UnixStream::pair()?;
143 (SocketType::Byte(pipe_a), SocketType::Byte(pipe_b))
144 }
145 FramingMode::Message => {
146 let (pipe_a, pipe_b) = UnixSeqpacket::pair()?;
147 (SocketType::Message(pipe_a), SocketType::Message(pipe_b))
148 }
149 };
150 let mut stream_a = StreamChannel { stream: pipe_a };
151 let mut stream_b = StreamChannel { stream: pipe_b };
152 let is_non_blocking = blocking_mode == BlockingMode::Nonblocking;
153 stream_a.set_nonblocking(is_non_blocking)?;
154 stream_b.set_nonblocking(is_non_blocking)?;
155 Ok((stream_a, stream_b))
156 }
157
158 pub fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
159 match &self.stream {
160 SocketType::Byte(sock) => sock.set_read_timeout(timeout),
161 SocketType::Message(sock) => sock.set_read_timeout(timeout),
162 }
163 }
164
165 pub fn set_write_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
166 match &self.stream {
167 SocketType::Byte(sock) => sock.set_write_timeout(timeout),
168 SocketType::Message(sock) => sock.set_write_timeout(timeout),
169 }
170 }
171
172 pub fn try_clone(&self) -> io::Result<Self> {
175 Ok(StreamChannel {
176 stream: match &self.stream {
177 SocketType::Byte(sock) => SocketType::Byte(sock.try_clone()?),
178 SocketType::Message(sock) => SocketType::Message(sock.try_clone()?),
179 },
180 })
181 }
182}
183
184impl io::Write for StreamChannel {
185 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
186 match &mut self.stream {
187 SocketType::Byte(sock) => sock.write(buf),
188 SocketType::Message(sock) => sock.send(buf),
189 }
190 }
191 fn flush(&mut self) -> io::Result<()> {
192 match &mut self.stream {
193 SocketType::Byte(sock) => sock.flush(),
194 SocketType::Message(_) => Ok(()),
195 }
196 }
197}
198
199impl io::Write for &StreamChannel {
200 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
201 match &self.stream {
202 SocketType::Byte(sock) => (&mut &*sock).write(buf),
203 SocketType::Message(sock) => sock.send(buf),
204 }
205 }
206 fn flush(&mut self) -> io::Result<()> {
207 match &self.stream {
208 SocketType::Byte(sock) => (&mut &*sock).flush(),
209 SocketType::Message(_) => Ok(()),
210 }
211 }
212}
213
214impl AsRawFd for StreamChannel {
215 fn as_raw_fd(&self) -> RawFd {
216 match &self.stream {
217 SocketType::Byte(sock) => sock.as_raw_descriptor(),
218 SocketType::Message(sock) => sock.as_raw_descriptor(),
219 }
220 }
221}
222
223impl AsRawFd for &StreamChannel {
224 fn as_raw_fd(&self) -> RawFd {
225 self.as_raw_descriptor()
226 }
227}
228
229impl AsRawDescriptor for &StreamChannel {
230 fn as_raw_descriptor(&self) -> RawDescriptor {
231 match &self.stream {
232 SocketType::Byte(sock) => sock.as_raw_descriptor(),
233 SocketType::Message(sock) => sock.as_raw_descriptor(),
234 }
235 }
236}
237
238impl IntoRawDescriptor for StreamChannel {
239 fn into_raw_descriptor(self) -> RawFd {
240 match self.stream {
241 SocketType::Byte(sock) => sock.into_raw_descriptor(),
242 SocketType::Message(sock) => sock.into_raw_descriptor(),
243 }
244 }
245}
246
247impl ReadNotifier for StreamChannel {
248 fn get_read_notifier(&self) -> &dyn AsRawDescriptor {
250 self
251 }
252}
253
#[cfg(test)]
mod test {
    use std::io::Read;
    use std::io::Write;

    use super::*;
    use crate::EventContext;
    use crate::EventToken;
    use crate::ReadNotifier;

    #[derive(EventToken, Debug, Eq, PartialEq, Copy, Clone)]
    enum Token {
        ReceivedData,
    }

    /// Byte framing: a 6-byte write can be drained by consecutive short reads.
    #[test]
    fn test_non_blocking_pair_byte() {
        let (mut sender, mut receiver) =
            StreamChannel::pair(BlockingMode::Nonblocking, FramingMode::Byte).unwrap();

        sender.write_all(&[75, 77, 54, 82, 76, 65]).unwrap();

        // Wait until the receiver reports readable data.
        let event_ctx: EventContext<Token> =
            EventContext::build_with(&[(receiver.get_read_notifier(), Token::ReceivedData)])
                .unwrap();
        let events = event_ctx.wait().unwrap();
        let tokens: Vec<Token> = events
            .iter()
            .filter(|e| e.is_readable)
            .map(|e| e.token)
            .collect();
        assert_eq!(tokens, vec![Token::ReceivedData]);

        // A byte stream has no message boundaries: a 4-byte buffer receives the first four
        // bytes and the next read receives the remaining two.
        let mut recv_buffer: [u8; 4] = [0; 4];

        let mut size = receiver.read(&mut recv_buffer).unwrap();
        assert_eq!(size, 4);
        assert_eq!(recv_buffer, [75, 77, 54, 82]);

        size = receiver.read(&mut recv_buffer).unwrap();
        assert_eq!(size, 2);
        assert_eq!(recv_buffer[0..2], [76, 65]);

        // Everything has been consumed, so no further readiness should be reported.
        assert_eq!(
            event_ctx
                .wait_timeout(std::time::Duration::new(0, 0))
                .unwrap()
                .len(),
            0
        );
    }

    /// Message framing: a 6-byte write is received as one 6-byte packet.
    #[test]
    fn test_non_blocking_pair_message() {
        let (mut sender, mut receiver) =
            StreamChannel::pair(BlockingMode::Nonblocking, FramingMode::Message).unwrap();

        sender.write_all(&[75, 77, 54, 82, 76, 65]).unwrap();

        // Wait until the receiver reports readable data.
        let event_ctx: EventContext<Token> =
            EventContext::build_with(&[(receiver.get_read_notifier(), Token::ReceivedData)])
                .unwrap();
        let events = event_ctx.wait().unwrap();
        let tokens: Vec<Token> = events
            .iter()
            .filter(|e| e.is_readable)
            .map(|e| e.token)
            .collect();
        assert_eq!(tokens, vec![Token::ReceivedData]);

        // The whole packet fits, so it is returned by a single read.
        let mut recv_buffer: [u8; 6] = [0; 6];

        let size = receiver.read(&mut recv_buffer).unwrap();
        assert_eq!(size, 6);
        assert_eq!(recv_buffer, [75, 77, 54, 82, 76, 65]);

        // The packet has been consumed, so no further readiness should be reported.
        assert_eq!(
            event_ctx
                .wait_timeout(std::time::Duration::new(0, 0))
                .unwrap()
                .len(),
            0
        );
    }

    /// Message framing: a packet larger than the read buffer is reported as an error rather
    /// than silently truncated, and the oversized packet is consumed by that read.
    #[test]
    fn test_non_blocking_pair_message_too_large_for_buffer() {
        let (mut sender, mut receiver) =
            StreamChannel::pair(BlockingMode::Nonblocking, FramingMode::Message).unwrap();

        sender.write_all(&[75, 77, 54, 82, 76, 65]).unwrap();

        let mut recv_buffer: [u8; 4] = [0; 4];
        let err = receiver.read(&mut recv_buffer).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::Other);

        // Nothing is left to read after the oversized packet was dequeued.
        let err = receiver.read(&mut recv_buffer).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::WouldBlock);
    }

    /// Reading a non-blocking channel with no pending data fails with `WouldBlock`.
    #[test]
    fn test_non_blocking_pair_error_no_data() {
        let (mut sender, mut receiver) =
            StreamChannel::pair(BlockingMode::Nonblocking, FramingMode::Byte).unwrap();
        receiver
            .set_nonblocking(true)
            .expect("Failed to set receiver to nonblocking mode.");

        sender.write_all(&[75, 77]).unwrap();

        // Wait until the receiver reports readable data.
        let event_ctx: EventContext<Token> =
            EventContext::build_with(&[(receiver.get_read_notifier(), Token::ReceivedData)])
                .unwrap();
        let events = event_ctx.wait().unwrap();
        let tokens: Vec<Token> = events
            .iter()
            .filter(|e| e.is_readable)
            .map(|e| e.token)
            .collect();
        assert_eq!(tokens, vec![Token::ReceivedData]);

        let mut recv_buffer: [u8; 4] = [0; 4];
        let size = receiver.read(&mut recv_buffer).unwrap();
        assert_eq!(size, 2);
        assert_eq!(recv_buffer, [75, 77, 00, 00]);

        // The stream is drained but still open, so a non-blocking read must report
        // `WouldBlock` specifically (not EOF or some other failure).
        let err = receiver.read(&mut recv_buffer).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::WouldBlock);
    }
}