use std::os::unix::prelude::AsRawFd;
use std::os::unix::prelude::RawFd;
use std::time::Duration;
use serde::de::DeserializeOwned;
use serde::Deserialize;
use serde::Serialize;
use crate::descriptor::AsRawDescriptor;
use crate::descriptor_reflection::deserialize_with_descriptors;
use crate::descriptor_reflection::SerializeDescriptors;
use crate::handle_eintr;
use crate::tube::Error;
use crate::tube::RecvTube;
use crate::tube::Result;
use crate::tube::SendTube;
use crate::BlockingMode;
use crate::FramingMode;
use crate::RawDescriptor;
use crate::ReadNotifier;
use crate::ScmSocket;
use crate::StreamChannel;
use crate::UnixSeqpacket;
use crate::SCM_SOCKET_MAX_FD_COUNT;
// Descriptor limit used when the caller does not supply one: `Tube::send`, `Tube::recv` and
// `Tube::recv_proto` all pass this value as the maximum number of descriptors per message.
const TUBE_MAX_FDS: usize = 32;
/// Message channel that can both send and receive serde-serializable values.
///
/// Values are encoded as JSON and any descriptors they contain are passed alongside the payload
/// through the wrapped `ScmSocket<StreamChannel>`.
#[derive(Serialize, Deserialize)]
pub struct Tube {
// Underlying socket; all sends, receives and timeout settings go through it.
socket: ScmSocket<StreamChannel>,
}
impl Tube {
/// Creates two connected tubes.
///
/// The underlying channel pair is created in blocking mode with message framing.
///
/// # Errors
///
/// Returns `Error::Pair` if the channel pair cannot be created, or whatever error `Tube::new`
/// reports while wrapping either end.
pub fn pair() -> Result<(Tube, Tube)> {
    let (left, right) = StreamChannel::pair(BlockingMode::Blocking, FramingMode::Message)
        .map_err(|errno| Error::Pair(std::io::Error::from(errno)))?;
    // Tuple fields are evaluated left to right, so the first end is wrapped before the second.
    Ok((Tube::new(left)?, Tube::new(right)?))
}
pub fn new(socket: StreamChannel) -> Result<Tube> {
match socket.get_framing_mode() {
FramingMode::Message => Ok(Tube {
socket: socket.try_into().map_err(Error::DupDescriptor)?,
}),
FramingMode::Byte => Err(Error::InvalidFramingMode),
}
}
pub fn new_from_unix_seqpacket(sock: UnixSeqpacket) -> Result<Tube> {
Ok(Tube {
socket: StreamChannel::from_unix_seqpacket(sock)
.try_into()
.map_err(Error::DupDescriptor)?,
})
}
#[deprecated]
pub fn try_clone(&self) -> Result<Self> {
self.socket
.inner()
.try_clone()
.map(Tube::new)
.map_err(Error::Clone)?
}
/// Sends `msg` through the tube, allowing at most `TUBE_MAX_FDS` descriptors in the message.
///
/// This is `send_with_max_fds` with the default descriptor limit; see that method for the
/// errors that can be returned.
pub fn send<T: Serialize>(&self, msg: &T) -> Result<()> {
self.send_with_max_fds(msg, TUBE_MAX_FDS)
}
pub fn send_with_max_fds<T: Serialize>(&self, msg: &T, max_fds: usize) -> Result<()> {
if max_fds > SCM_SOCKET_MAX_FD_COUNT {
return Err(Error::SendTooManyFds);
}
let msg_serialize = SerializeDescriptors::new(&msg);
let msg_json = serde_json::to_vec(&msg_serialize).map_err(Error::Json)?;
let msg_descriptors = msg_serialize.into_descriptors();
if msg_descriptors.len() > max_fds {
return Err(Error::SendTooManyFds);
}
handle_eintr!(self.socket.send_with_fds(&msg_json, &msg_descriptors))
.map_err(Error::Send)?;
Ok(())
}
/// Receives one message and deserializes it as `T`, accepting at most `TUBE_MAX_FDS`
/// descriptors.
///
/// This is `recv_with_max_fds` with the default descriptor limit; see that method for the
/// errors that can be returned.
pub fn recv<T: DeserializeOwned>(&self) -> Result<T> {
self.recv_with_max_fds(TUBE_MAX_FDS)
}
pub fn recv_with_max_fds<T: DeserializeOwned>(&self, max_fds: usize) -> Result<T> {
if max_fds > SCM_SOCKET_MAX_FD_COUNT {
return Err(Error::RecvTooManyFds);
}
let msg_size = handle_eintr!(self.socket.inner().peek_size()).map_err(Error::Recv)?;
let mut msg_json = vec![0u8; msg_size];
let (msg_json_size, msg_descriptors) =
handle_eintr!(self.socket.recv_with_fds(&mut msg_json, max_fds))
.map_err(Error::Recv)?;
if msg_json_size == 0 {
return Err(Error::Disconnected);
}
deserialize_with_descriptors(
|| serde_json::from_slice(&msg_json[0..msg_json_size]),
msg_descriptors,
)
.map_err(Error::Json)
}
/// Sets the write timeout on the underlying channel.
///
/// # Errors
///
/// Returns `Error::SetSendTimeout` if the channel rejects the timeout.
pub fn set_send_timeout(&self, timeout: Option<Duration>) -> Result<()> {
    let channel = self.socket.inner();
    channel
        .set_write_timeout(timeout)
        .map_err(Error::SetSendTimeout)
}
/// Sets the read timeout on the underlying channel.
///
/// # Errors
///
/// Returns `Error::SetRecvTimeout` if the channel rejects the timeout.
pub fn set_recv_timeout(&self, timeout: Option<Duration>) -> Result<()> {
    let channel = self.socket.inner();
    channel
        .set_read_timeout(timeout)
        .map_err(Error::SetRecvTimeout)
}
/// Encodes `msg` as protobuf bytes and sends them with no descriptors attached.
///
/// # Errors
///
/// * `Error::Proto` if protobuf encoding fails.
/// * `Error::Send` if the socket send fails.
#[cfg(feature = "proto_tube")]
fn send_proto<M: protobuf::Message>(&self, msg: &M) -> Result<()> {
    let encoded = msg.write_to_bytes().map_err(Error::Proto)?;
    // Protobuf messages never carry descriptors, so an empty descriptor list is sent.
    let no_descriptors: [RawFd; 0] = [];
    handle_eintr!(self.socket.send_with_fds(&encoded, &no_descriptors))
        .map(|_sent| ())
        .map_err(Error::Send)
}
/// Receives one message and decodes it as the protobuf type `M`.
///
/// The size of the pending message is peeked first so the receive buffer can be allocated to
/// fit it. Any descriptors that arrive with the message are discarded.
///
/// # Errors
///
/// * `Error::Recv` if peeking the size or receiving the message fails.
/// * `Error::Disconnected` if zero bytes are received.
/// * `Error::Proto` if the payload is not a valid encoding of `M`.
#[cfg(feature = "proto_tube")]
fn recv_proto<M: protobuf::Message>(&self) -> Result<M> {
    let msg_size = handle_eintr!(self.socket.inner().peek_size()).map_err(Error::Recv)?;
    let mut msg_bytes = vec![0u8; msg_size];

    let (msg_bytes_size, _) =
        handle_eintr!(self.socket.recv_with_fds(&mut msg_bytes, TUBE_MAX_FDS))
            .map_err(Error::Recv)?;

    if msg_bytes_size == 0 {
        return Err(Error::Disconnected);
    }

    // Parse only the bytes that were actually received, matching `recv_with_max_fds`. If the
    // receive returned fewer bytes than were peeked, the zero padding left in the buffer must
    // not reach the protobuf parser. `truncate` is a no-op when the lengths already agree.
    msg_bytes.truncate(msg_bytes_size);
    protobuf::Message::parse_from_bytes(&msg_bytes).map_err(Error::Proto)
}
}
/// Exposes the raw descriptor of the tube's underlying socket.
impl AsRawDescriptor for Tube {
// Delegates to the `ScmSocket` wrapper held in `self.socket`.
fn as_raw_descriptor(&self) -> RawDescriptor {
self.socket.as_raw_descriptor()
}
}
/// Exposes the raw file descriptor of the tube's underlying channel.
impl AsRawFd for Tube {
// Delegates to the inner channel of `self.socket` rather than to the `ScmSocket` wrapper.
fn as_raw_fd(&self) -> RawFd {
self.socket.inner().as_raw_fd()
}
}
/// Uses the tube's own socket as its read notifier.
impl ReadNotifier for Tube {
// Returns a reference to `self.socket`.
// NOTE(review): presumably callers wait on this descriptor for readability — confirm against
// the `ReadNotifier` trait documentation.
fn get_read_notifier(&self) -> &dyn AsRawDescriptor {
&self.socket
}
}
/// Exposes the raw descriptor of the value wrapped by `SendTube`.
impl AsRawDescriptor for SendTube {
// Delegates to field `0` of the wrapper (presumably a `Tube`; `SendTube` is defined in
// `crate::tube`).
fn as_raw_descriptor(&self) -> RawDescriptor {
self.0.as_raw_descriptor()
}
}
/// Exposes the raw descriptor of the value wrapped by `RecvTube`.
impl AsRawDescriptor for RecvTube {
// Delegates to field `0` of the wrapper (presumably a `Tube`; `RecvTube` is defined in
// `crate::tube`).
fn as_raw_descriptor(&self) -> RawDescriptor {
self.0.as_raw_descriptor()
}
}
/// Wrapper around a `Tube` whose public API sends and receives protobuf messages instead of
/// serde values. Only compiled when the `proto_tube` feature is enabled.
#[cfg(feature = "proto_tube")]
pub struct ProtoTube(Tube);
#[cfg(feature = "proto_tube")]
impl ProtoTube {
    /// Creates two connected protobuf tubes by wrapping both ends of `Tube::pair`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `Tube::pair`.
    pub fn pair() -> Result<(ProtoTube, ProtoTube)> {
        let (first, second) = Tube::pair()?;
        Ok((ProtoTube(first), ProtoTube(second)))
    }

    /// Sends the protobuf message `msg` through the wrapped tube.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the wrapped tube's `send_proto`.
    pub fn send_proto<M: protobuf::Message>(&self, msg: &M) -> Result<()> {
        let ProtoTube(tube) = self;
        tube.send_proto(msg)
    }

    /// Receives one message from the wrapped tube and decodes it as the protobuf type `M`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the wrapped tube's `recv_proto`.
    pub fn recv_proto<M: protobuf::Message>(&self) -> Result<M> {
        let ProtoTube(tube) = self;
        tube.recv_proto()
    }

    /// Builds a protobuf tube on top of an existing `UnixSeqpacket` socket.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `Tube::new_from_unix_seqpacket`.
    pub fn new_from_unix_seqpacket(sock: UnixSeqpacket) -> Result<ProtoTube> {
        Tube::new_from_unix_seqpacket(sock).map(ProtoTube)
    }
}
#[cfg(all(feature = "proto_tube", test))]
#[allow(unused_variables)]
mod tests {
    use protos::cdisk_spec::ComponentDisk;

    use super::*;

    /// A protobuf message sent on one end of a `ProtoTube` pair is received unchanged on the
    /// other end.
    #[test]
    fn tube_serializes_and_deserializes() {
        let (sender, receiver) = ProtoTube::pair().unwrap();
        let sent = ComponentDisk {
            file_path: "/some/cool/path".to_string(),
            offset: 99,
            ..ComponentDisk::new()
        };

        sender.send_proto(&sent).unwrap();
        let received: ComponentDisk = receiver.recv_proto().unwrap();

        assert!(sent == received);
    }
}