use std::io;
use std::io::Read;
use std::os::unix::io::AsRawFd;
use std::os::unix::io::RawFd;
use std::os::unix::net::UnixStream;
use std::time::Duration;
use libc::c_void;
use serde::Deserialize;
use serde::Serialize;
use super::super::net::UnixSeqpacket;
use crate::descriptor::AsRawDescriptor;
use crate::IntoRawDescriptor;
use crate::RawDescriptor;
use crate::ReadNotifier;
use crate::Result;
/// How a [`StreamChannel`] delimits the data that passes through it.
///
/// `Debug`, `PartialEq` and `Eq` are derived so the mode can be compared and
/// printed (e.g. in `assert_eq!`), matching what `BlockingMode` already offers.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum FramingMode {
    /// Packet framing, backed by a `UnixSeqpacket`. Each read is a single `recv`
    /// and fails if the incoming packet does not fit in the caller's buffer.
    Message,
    /// Unframed byte stream, backed by a `UnixStream`.
    Byte,
}
/// Whether the sockets created by `StreamChannel::pair` are left blocking or are
/// switched to non-blocking mode.
///
/// `Debug` is derived so the mode can be printed and used with `assert_eq!`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum BlockingMode {
    /// `pair` calls `set_nonblocking(false)` on both ends.
    Blocking,
    /// `pair` calls `set_nonblocking(true)` on both ends.
    Nonblocking,
}
/// Reading from an owned channel uses the same path as reading through a
/// shared reference; only `&self` access is needed.
impl io::Read for StreamChannel {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        Self::inner_read(self, buf)
    }
}
impl io::Read for &StreamChannel {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.inner_read(buf)
}
}
impl AsRawDescriptor for StreamChannel {
    /// Returns the descriptor of the backing socket.
    fn as_raw_descriptor(&self) -> RawDescriptor {
        // Delegate to the `&StreamChannel` impl so the per-socket match lives in
        // exactly one place.
        <&StreamChannel as AsRawDescriptor>::as_raw_descriptor(&self)
    }
}
/// The socket backing a `StreamChannel`; the variant determines the channel's
/// `FramingMode` (see `StreamChannel::get_framing_mode`).
///
/// NOTE: serde's derived impls encode which variant is present (by name or by
/// index, depending on the serializer), so renaming or reordering variants may
/// change the serialized form.
#[derive(Debug, Deserialize, Serialize)]
enum SocketType {
    /// Packet-framed socket, created by `StreamChannel::pair` for `FramingMode::Message`.
    Message(UnixSeqpacket),
    /// Byte-stream socket, created by `StreamChannel::pair` for `FramingMode::Byte`.
    // `UnixStream` has no serde impl of its own here; it is (de)serialized through
    // the crate's `with_as_descriptor` helper (presumably via its raw descriptor —
    // confirm in that module).
    #[serde(with = "crate::with_as_descriptor")]
    Byte(UnixStream),
}
/// A Unix-socket channel that implements both `io::Read` and `io::Write`, framed
/// either as a byte stream or as discrete packets (see `FramingMode`).
///
/// Connected pairs are created with `StreamChannel::pair`.
#[derive(Debug, Deserialize, Serialize)]
pub struct StreamChannel {
    // Which kind of socket backs this channel; every operation dispatches on it.
    stream: SocketType,
}
impl StreamChannel {
    /// Switches the backing socket between blocking (`false`) and non-blocking
    /// (`true`) mode.
    ///
    /// # Errors
    /// Propagates any error from the socket's own `set_nonblocking`.
    pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
        match self.stream {
            SocketType::Message(ref mut seqpacket) => seqpacket.set_nonblocking(nonblocking),
            SocketType::Byte(ref mut unix_stream) => unix_stream.set_nonblocking(nonblocking),
        }
    }

    /// Reports whether this channel is packet-framed or a plain byte stream.
    pub fn get_framing_mode(&self) -> FramingMode {
        match self.stream {
            SocketType::Byte(..) => FramingMode::Byte,
            SocketType::Message(..) => FramingMode::Message,
        }
    }

    /// Shared read path behind both `io::Read` impls.
    ///
    /// Byte channels read straight from the `UnixStream`. Message channels do a
    /// single `recv` with `MSG_TRUNC`, which lets the returned length exceed
    /// `buf.len()` when the packet did not fit. That case is reported as an
    /// `ErrorKind::Other` error instead of returning truncated data.
    ///
    /// # Errors
    /// The OS error from the read/`recv` call, or the oversized-packet error above.
    pub(super) fn inner_read(&self, buf: &mut [u8]) -> io::Result<usize> {
        let seqpacket = match &self.stream {
            SocketType::Byte(unix_stream) => {
                // `Read` is implemented for `&UnixStream`, so a shared borrow is enough.
                let mut reader = unix_stream;
                return reader.read(buf);
            }
            SocketType::Message(seqpacket) => seqpacket,
        };

        let capacity = buf.len();
        // SAFETY: `buf` is an exclusively borrowed, writable slice and its exact
        // length is passed as the size limit, so `recv` writes only within it. The
        // return value is checked before it is used.
        let received = unsafe {
            libc::recv(
                seqpacket.as_raw_descriptor(),
                buf.as_mut_ptr().cast::<c_void>(),
                capacity,
                libc::MSG_TRUNC,
            )
        };
        if received < 0 {
            return Err(io::Error::last_os_error());
        }

        let packet_len = received as usize;
        if packet_len > capacity {
            return Err(io::Error::new(
                io::ErrorKind::Other,
                format!(
                    "packet size {:?} encountered, but buffer was only of size {:?}",
                    packet_len, capacity
                ),
            ));
        }
        Ok(packet_len)
    }

    /// Creates a connected pair of channels with the requested framing and
    /// blocking behaviour applied to both ends.
    ///
    /// # Errors
    /// Fails if the socket pair cannot be created or if setting either end's
    /// blocking mode fails.
    pub fn pair(
        blocking_mode: BlockingMode,
        framing_mode: FramingMode,
    ) -> Result<(StreamChannel, StreamChannel)> {
        let (left, right) = match framing_mode {
            FramingMode::Message => {
                let (a, b) = UnixSeqpacket::pair()?;
                (SocketType::Message(a), SocketType::Message(b))
            }
            FramingMode::Byte => {
                let (a, b) = UnixStream::pair()?;
                (SocketType::Byte(a), SocketType::Byte(b))
            }
        };

        let nonblocking = matches!(blocking_mode, BlockingMode::Nonblocking);
        let mut channels = (
            StreamChannel { stream: left },
            StreamChannel { stream: right },
        );
        channels.0.set_nonblocking(nonblocking)?;
        channels.1.set_nonblocking(nonblocking)?;
        Ok(channels)
    }

    /// Forwards `timeout` to the backing socket's `set_read_timeout`.
    pub fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
        match self.stream {
            SocketType::Message(ref seqpacket) => seqpacket.set_read_timeout(timeout),
            SocketType::Byte(ref unix_stream) => unix_stream.set_read_timeout(timeout),
        }
    }

    /// Forwards `timeout` to the backing socket's `set_write_timeout`.
    pub fn set_write_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
        match self.stream {
            SocketType::Message(ref seqpacket) => seqpacket.set_write_timeout(timeout),
            SocketType::Byte(ref unix_stream) => unix_stream.set_write_timeout(timeout),
        }
    }

    /// Returns a new `StreamChannel` wrapping a clone of the backing socket, with
    /// the same framing mode.
    ///
    /// # Errors
    /// Propagates any error from the socket's `try_clone`.
    pub fn try_clone(&self) -> io::Result<Self> {
        let stream = match &self.stream {
            SocketType::Message(seqpacket) => SocketType::Message(seqpacket.try_clone()?),
            SocketType::Byte(unix_stream) => SocketType::Byte(unix_stream.try_clone()?),
        };
        Ok(Self { stream })
    }
}
/// Writes go directly to the backing socket: through `send` for message
/// channels, and as an ordinary stream write for byte channels.
impl io::Write for StreamChannel {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        match self.stream {
            SocketType::Message(ref mut seqpacket) => seqpacket.send(buf),
            SocketType::Byte(ref mut unix_stream) => unix_stream.write(buf),
        }
    }

    fn flush(&mut self) -> io::Result<()> {
        match self.stream {
            // `write` hands data to `send` immediately; this wrapper holds nothing back.
            SocketType::Message(_) => Ok(()),
            SocketType::Byte(ref mut unix_stream) => unix_stream.flush(),
        }
    }
}
/// Allows writing through a shared `&StreamChannel`; behaves like the owned impl.
impl io::Write for &StreamChannel {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        match self.stream {
            SocketType::Message(ref seqpacket) => seqpacket.send(buf),
            SocketType::Byte(ref unix_stream) => {
                // `Write` is implemented for `&UnixStream`, so no `&mut` is required.
                let mut writer = unix_stream;
                writer.write(buf)
            }
        }
    }

    fn flush(&mut self) -> io::Result<()> {
        match self.stream {
            // Message writes are sent immediately via `send`; nothing to flush here.
            SocketType::Message(_) => Ok(()),
            SocketType::Byte(ref unix_stream) => {
                let mut writer = unix_stream;
                writer.flush()
            }
        }
    }
}
impl AsRawFd for StreamChannel {
fn as_raw_fd(&self) -> RawFd {
match &self.stream {
SocketType::Byte(sock) => sock.as_raw_descriptor(),
SocketType::Message(sock) => sock.as_raw_descriptor(),
}
}
}
impl AsRawFd for &StreamChannel {
fn as_raw_fd(&self) -> RawFd {
self.as_raw_descriptor()
}
}
impl AsRawDescriptor for &StreamChannel {
    /// Returns the descriptor of whichever socket backs the channel. The owned
    /// `StreamChannel` impls delegate here.
    fn as_raw_descriptor(&self) -> RawDescriptor {
        match self.stream {
            SocketType::Message(ref seqpacket) => seqpacket.as_raw_descriptor(),
            SocketType::Byte(ref unix_stream) => unix_stream.as_raw_descriptor(),
        }
    }
}
impl IntoRawDescriptor for StreamChannel {
fn into_raw_descriptor(self) -> RawFd {
match self.stream {
SocketType::Byte(sock) => sock.into_raw_descriptor(),
SocketType::Message(sock) => sock.into_raw_descriptor(),
}
}
}
impl ReadNotifier for StreamChannel {
    /// The channel is its own read notifier: the descriptor it exposes through
    /// `AsRawDescriptor` is the one to wait on.
    fn get_read_notifier(&self) -> &dyn AsRawDescriptor {
        self as &dyn AsRawDescriptor
    }
}
#[cfg(test)]
mod test {
    use std::io::Read;
    use std::io::Write;

    use super::*;
    use crate::EventContext;
    use crate::EventToken;
    use crate::ReadNotifier;

    #[derive(EventToken, Debug, Eq, PartialEq, Copy, Clone)]
    enum Token {
        ReceivedData,
    }

    /// Builds an event context on `receiver`'s read notifier, waits (no timeout)
    /// for an event, and asserts that `Token::ReceivedData` is the only readable
    /// token. Returns the context so callers can later check it has drained.
    fn wait_for_data(receiver: &StreamChannel) -> EventContext<Token> {
        let event_ctx: EventContext<Token> =
            EventContext::build_with(&[(receiver.get_read_notifier(), Token::ReceivedData)])
                .unwrap();
        let tokens: Vec<Token> = event_ctx
            .wait()
            .unwrap()
            .iter()
            .filter(|e| e.is_readable)
            .map(|e| e.token)
            .collect();
        assert_eq!(tokens, vec![Token::ReceivedData]);
        event_ctx
    }

    #[test]
    fn test_non_blocking_pair_byte() {
        let (mut sender, mut receiver) =
            StreamChannel::pair(BlockingMode::Nonblocking, FramingMode::Byte).unwrap();
        sender.write_all(&[75, 77, 54, 82, 76, 65]).unwrap();

        let event_ctx = wait_for_data(&receiver);

        // A 4-byte buffer splits the 6-byte stream across two reads.
        let mut recv_buffer: [u8; 4] = [0; 4];
        let size = receiver.read(&mut recv_buffer).unwrap();
        assert_eq!(size, 4);
        assert_eq!(recv_buffer, [75, 77, 54, 82]);
        let size = receiver.read(&mut recv_buffer).unwrap();
        assert_eq!(size, 2);
        assert_eq!(recv_buffer[0..2], [76, 65]);

        // Everything was consumed, so a zero-timeout poll reports no events.
        assert_eq!(
            event_ctx
                .wait_timeout(std::time::Duration::new(0, 0))
                .unwrap()
                .len(),
            0
        );
    }

    #[test]
    fn test_non_blocking_pair_message() {
        let (mut sender, mut receiver) =
            StreamChannel::pair(BlockingMode::Nonblocking, FramingMode::Message).unwrap();
        sender.write_all(&[75, 77, 54, 82, 76, 65]).unwrap();

        let event_ctx = wait_for_data(&receiver);

        // The whole packet arrives in one read when the buffer is large enough.
        let mut recv_buffer: [u8; 6] = [0; 6];
        let size = receiver.read(&mut recv_buffer).unwrap();
        assert_eq!(size, 6);
        assert_eq!(recv_buffer, [75, 77, 54, 82, 76, 65]);

        assert_eq!(
            event_ctx
                .wait_timeout(std::time::Duration::new(0, 0))
                .unwrap()
                .len(),
            0
        );
    }

    #[test]
    fn test_non_blocking_pair_error_no_data() {
        let (mut sender, mut receiver) =
            StreamChannel::pair(BlockingMode::Nonblocking, FramingMode::Byte).unwrap();
        receiver
            .set_nonblocking(true)
            .expect("Failed to set receiver to nonblocking mode.");
        sender.write_all(&[75, 77]).unwrap();

        let _event_ctx = wait_for_data(&receiver);

        let mut recv_buffer: [u8; 4] = [0; 4];
        let size = receiver.read(&mut recv_buffer).unwrap();
        assert_eq!(size, 2);
        assert_eq!(recv_buffer, [75, 77, 0, 0]);

        // The pending data is drained. A non-blocking read must now report
        // `WouldBlock` specifically, not merely some error.
        let err = receiver.read(&mut recv_buffer).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::WouldBlock);
    }
}