1use std::fs::File;
7use std::io::IoSliceMut;
8use std::mem;
9
10use base::AsRawDescriptor;
11use base::RawDescriptor;
12use zerocopy::FromBytes;
13use zerocopy::FromZeros;
14use zerocopy::Immutable;
15use zerocopy::IntoBytes;
16
17use crate::message::*;
18use crate::sys::PlatformConnection;
19use crate::Error;
20use crate::Result;
21
22pub trait Listener: Sized {
24 fn accept(&mut self) -> Result<Option<Connection>>;
26
27 fn set_nonblocking(&self, block: bool) -> Result<()>;
29}
30
/// Advances a list of mutable byte slices by `count` bytes, in place.
///
/// Every slice that `count` fully covers is removed from the front of `bufs`. This includes
/// zero-length slices at the front, even when `count` is 0. The first remaining slice is then
/// trimmed so that it starts at the first unconsumed byte.
///
/// If `count` exceeds the total length of all slices, `bufs` becomes empty and the excess is
/// ignored. Unlike `std::io::IoSliceMut::advance_slices`, this does not panic in that case.
fn advance_slices_mut(bufs: &mut &mut [&mut [u8]], count: usize) {
    // Find how many leading slices are consumed entirely, and the offset left over for the
    // slice that follows them.
    let mut skipped = 0;
    let mut offset = count;
    while let Some(buf) = bufs.get(skipped) {
        if offset < buf.len() {
            break;
        }
        offset -= buf.len();
        skipped += 1;
    }

    // Drop the consumed slices, then trim the partially consumed head (if any). The loop above
    // guarantees `offset < first.len()` whenever a slice remains.
    *bufs = &mut std::mem::take(bufs)[skipped..];
    if let Some(first) = bufs.first_mut() {
        let head = std::mem::take(first);
        *first = &mut head[offset..];
    }
}
51
/// A vhost-user connection endpoint wrapping the platform-specific transport
/// (`PlatformConnection`).
///
/// The second field, `PhantomData<Cell<()>>`, makes `Connection` `!Sync` because `Cell` is
/// `!Sync`. A `&Connection` therefore cannot be shared across threads. The marker does not by
/// itself affect `Send`.
/// NOTE(review): the rationale is not visible here; presumably it prevents concurrent
/// send/recv calls from interleaving message bytes — confirm.
pub struct Connection(
    pub(crate) PlatformConnection,
    pub(crate) std::marker::PhantomData<std::cell::Cell<()>>,
);
63
64impl Connection {
65 pub fn send_header_only_message(
67 &self,
68 hdr: &VhostUserMsgHeader,
69 fds: Option<&[RawDescriptor]>,
70 ) -> Result<()> {
71 self.0.send_message(hdr.as_bytes(), &[], &[], fds)
72 }
73
74 pub fn send_message<T: IntoBytes + Immutable>(
77 &self,
78 hdr: &VhostUserMsgHeader,
79 body: &T,
80 fds: Option<&[RawDescriptor]>,
81 ) -> Result<()> {
82 self.0
83 .send_message(hdr.as_bytes(), body.as_bytes(), &[], fds)
84 }
85
86 pub fn send_message_with_payload<T: IntoBytes + Immutable>(
89 &self,
90 hdr: &VhostUserMsgHeader,
91 body: &T,
92 payload: &[u8],
93 fds: Option<&[RawDescriptor]>,
94 ) -> Result<()> {
95 self.0
96 .send_message(hdr.as_bytes(), body.as_bytes(), payload, fds)
97 }
98
99 fn recv_into_bufs_all(&self, mut bufs: &mut [&mut [u8]]) -> Result<Vec<File>> {
112 let mut first_read = true;
113 let mut rfds = Vec::new();
114
115 advance_slices_mut(&mut bufs, 0);
117
118 while !bufs.is_empty() {
119 let mut slices: Vec<IoSliceMut> = bufs.iter_mut().map(|b| IoSliceMut::new(b)).collect();
120 let res = self.0.recv_into_bufs(&mut slices, true);
121 match res {
122 Ok((0, _)) => return Err(Error::PartialMessage),
123 Ok((n, fds)) => {
124 if first_read {
125 first_read = false;
126 if let Some(fds) = fds {
127 rfds = fds;
128 }
129 }
130 advance_slices_mut(&mut bufs, n);
131 }
132 Err(e) => match e {
133 Error::SocketRetry(_) => {}
134 _ => return Err(e),
135 },
136 }
137 }
138 Ok(rfds)
139 }
140
141 pub fn recv_header(&self) -> Result<(VhostUserMsgHeader, Vec<File>)> {
146 let mut hdr = VhostUserMsgHeader::new_zeroed();
147 let files = self.recv_into_bufs_all(&mut [hdr.as_mut_bytes()])?;
148 Ok((hdr, files))
149 }
150
151 pub fn recv_body_bytes(&self, hdr: &VhostUserMsgHeader) -> Result<(Vec<u8>, Vec<File>)> {
153 let mut body = vec![0; hdr.get_size().try_into().unwrap()];
156 let files = self.recv_into_bufs_all(&mut [&mut body[..]])?;
157 Ok((body, files))
158 }
159
160 pub fn recv_message<T: IntoBytes + FromBytes>(
165 &self,
166 ) -> Result<(VhostUserMsgHeader, T, Vec<File>)> {
167 let mut hdr = VhostUserMsgHeader::new_zeroed();
168 let mut body = T::new_zeroed();
169 let mut slices = [hdr.as_mut_bytes(), body.as_mut_bytes()];
170 let files = self.recv_into_bufs_all(&mut slices)?;
171 Ok((hdr, body, files))
172 }
173
174 pub fn recv_message_with_payload<T: IntoBytes + FromBytes>(
180 &self,
181 ) -> Result<(VhostUserMsgHeader, T, Vec<u8>, Vec<File>, Vec<File>)> {
182 let (hdr, files) = self.recv_header()?;
183
184 let mut body = T::new_zeroed();
185 let payload_size = hdr.get_size() as usize - mem::size_of::<T>();
186 let mut buf: Vec<u8> = vec![0; payload_size];
187 let mut slices = [body.as_mut_bytes(), buf.as_mut_bytes()];
188 let more_files = self.recv_into_bufs_all(&mut slices)?;
189
190 Ok((hdr, body, buf, files, more_files))
191 }
192}
193
194impl AsRawDescriptor for Connection {
195 fn as_raw_descriptor(&self) -> RawDescriptor {
196 self.0.as_raw_descriptor()
197 }
198}
199
#[cfg(test)]
pub(crate) mod tests {
    use std::io::Read;
    use std::io::Seek;
    use std::io::SeekFrom;
    use std::io::Write;

    use tempfile::tempfile;

    use super::*;
    use crate::message::FrontendReq;
    use crate::message::VhostUserEmptyMessage;
    use crate::message::VhostUserU64;

    /// A header-only message round-trips through a connection pair without files.
    #[test]
    fn send_header_only() {
        let (client_connection, server_connection) = Connection::pair().unwrap();
        let hdr1 = VhostUserMsgHeader::new_request_header(FrontendReq::GET_FEATURES, 0, false);
        client_connection
            .send_header_only_message(&hdr1, None)
            .unwrap();
        let (hdr2, _, files) = server_connection
            .recv_message::<VhostUserEmptyMessage>()
            .unwrap();
        assert_eq!(hdr1, hdr2);
        assert!(files.is_empty());
    }

    /// A header plus a fixed-size body round-trips intact.
    #[test]
    fn send_data() {
        let (client_connection, server_connection) = Connection::pair().unwrap();
        let hdr1 = VhostUserMsgHeader::new_request_header(FrontendReq::SET_FEATURES, 8, false);
        client_connection
            .send_message(&hdr1, &VhostUserU64::new(0xf00dbeefdeadf00d), None)
            .unwrap();
        let (hdr2, body, files) = server_connection.recv_message::<VhostUserU64>().unwrap();
        assert_eq!(hdr1, hdr2);
        let value = body.value;
        assert_eq!(value, 0xf00dbeefdeadf00d);
        assert!(files.is_empty());
    }

    /// A header, fixed-size body and trailing payload round-trip through
    /// `recv_message_with_payload`, with the payload size derived from the header.
    #[test]
    fn send_data_with_payload() {
        let (client_connection, server_connection) = Connection::pair().unwrap();
        let payload = [0xa5u8; 4];
        // Header size covers the 8-byte body plus the 4-byte payload.
        let hdr1 = VhostUserMsgHeader::new_request_header(FrontendReq::SET_FEATURES, 12, false);
        client_connection
            .send_message_with_payload(&hdr1, &VhostUserU64::new(0x1234), &payload, None)
            .unwrap();
        let (hdr2, body, buf, files, more_files) = server_connection
            .recv_message_with_payload::<VhostUserU64>()
            .unwrap();
        assert_eq!(hdr1, hdr2);
        let value = body.value;
        assert_eq!(value, 0x1234);
        assert_eq!(buf, payload);
        assert!(files.is_empty());
        assert!(more_files.is_empty());
    }

    /// A file descriptor sent with a header-only message arrives as a usable `File`.
    #[test]
    fn send_fd() {
        let (client_connection, server_connection) = Connection::pair().unwrap();

        let mut fd = tempfile().unwrap();
        write!(fd, "test").unwrap();

        let hdr1 = VhostUserMsgHeader::new_request_header(FrontendReq::SET_MEM_TABLE, 0, false);
        client_connection
            .send_header_only_message(&hdr1, Some(&[fd.as_raw_descriptor()]))
            .unwrap();

        let (hdr2, _, files) = server_connection
            .recv_message::<VhostUserEmptyMessage>()
            .unwrap();
        assert_eq!(hdr1, hdr2);
        assert_eq!(files.len(), 1);
        let mut file = &files[0];
        let mut content = String::new();
        file.seek(SeekFrom::Start(0)).unwrap();
        file.read_to_string(&mut content).unwrap();
        assert_eq!(content, "test");
    }

    /// Advancing into the middle of the second slice drops the first and trims the second.
    #[test]
    fn test_advance_slices_mut() {
        let mut buf1 = [1; 8];
        let mut buf2 = [2; 16];
        let mut buf3 = [3; 8];
        let mut bufs = &mut [&mut buf1[..], &mut buf2[..], &mut buf3[..]][..];
        advance_slices_mut(&mut bufs, 10);
        assert_eq!(bufs.len(), 2);
        assert_eq!(bufs[0], [2; 14].as_ref());
        assert_eq!(bufs[1], [3; 8].as_ref());
    }

    /// Advancing by zero strips leading empty slices (relied on by `recv_into_bufs_all`).
    #[test]
    fn test_advance_slices_mut_drops_leading_empty() {
        let mut buf1 = [0u8; 0];
        let mut buf2 = [2u8; 4];
        let mut bufs = &mut [&mut buf1[..], &mut buf2[..]][..];
        advance_slices_mut(&mut bufs, 0);
        assert_eq!(bufs.len(), 1);
        assert_eq!(bufs[0], [2; 4].as_ref());
    }

    /// Advancing exactly to a slice boundary, then to the end, leaves nothing behind.
    #[test]
    fn test_advance_slices_mut_exact_boundary() {
        let mut buf1 = [1u8; 4];
        let mut buf2 = [2u8; 4];
        let mut bufs = &mut [&mut buf1[..], &mut buf2[..]][..];
        advance_slices_mut(&mut bufs, 4);
        assert_eq!(bufs.len(), 1);
        assert_eq!(bufs[0], [2; 4].as_ref());
        advance_slices_mut(&mut bufs, 4);
        assert!(bufs.is_empty());
    }
}