1use std::fs::File;
7use std::io::IoSliceMut;
8use std::mem;
9
10use base::AsRawDescriptor;
11use base::RawDescriptor;
12use zerocopy::FromBytes;
13use zerocopy::Immutable;
14use zerocopy::IntoBytes;
15
16use crate::connection::Req;
17use crate::message::FrontendReq;
18use crate::message::*;
19use crate::sys::PlatformConnection;
20use crate::Error;
21use crate::Result;
22
/// Source of incoming [`Connection`]s that carry frontend requests ([`FrontendReq`]).
pub trait Listener: Sized {
    /// Accepts an incoming connection.
    ///
    /// Returns `Ok(Some(connection))` when a connection was accepted and `Err` on failure.
    /// NOTE(review): `Ok(None)` presumably means that no connection was pending (for example
    /// when the listener is non-blocking) — confirm against the implementations.
    fn accept(&mut self) -> Result<Option<Connection<FrontendReq>>>;

    /// Changes the blocking mode of the listener.
    ///
    /// NOTE(review): the method name suggests that `true` selects non-blocking mode, but the
    /// parameter is called `block` — confirm which polarity the implementations use.
    fn set_nonblocking(&self, block: bool) -> Result<()>;
}
31
/// Moves the cursor of a list of scatter/gather buffers forward by `count` bytes.
///
/// Buffers that are completely covered by `count` are removed from the front of `bufs`; the
/// first buffer that is only partly covered is shortened so that it begins at the first byte
/// that has not been consumed yet. Leading empty buffers are removed even when `count` is zero.
///
/// When `count` is at least the combined length of all buffers, `bufs` simply becomes empty
/// (unlike `IoSliceMut::advance_slices`, this function does not panic in that case).
fn advance_slices_mut(bufs: &mut &mut [&mut [u8]], mut count: usize) {
    // Walk the front of the list, counting the buffers that `count` swallows whole.
    let mut consumed = 0;
    while let Some(front) = bufs.get(consumed) {
        if front.len() > count {
            break;
        }
        count -= front.len();
        consumed += 1;
    }

    // Temporarily move the outer slice out so it can be re-sliced with its full lifetime.
    let whole = std::mem::take(bufs);
    *bufs = &mut whole[consumed..];

    // `count` is now strictly smaller than the length of the first remaining buffer (if any),
    // so trimming its front cannot go out of bounds.
    if let Some(first) = bufs.first_mut() {
        let partial = std::mem::take(first);
        *first = partial.split_at_mut(count).1;
    }
}
52
/// Typed wrapper around a [`PlatformConnection`] used to exchange messages that start with a
/// [`VhostUserMsgHeader`].
///
/// `R` is the request type of the headers exchanged over the connection; it is tracked at the
/// type level only and no value of `R` is stored.
pub struct Connection<R: Req>(
    // Underlying platform connection; every send and receive is delegated to it.
    pub(crate) PlatformConnection,
    // Zero-sized marker tying the connection to its request type `R`.
    pub(crate) std::marker::PhantomData<R>,
    // `Cell<()>` is not `Sync`, so this zero-sized marker makes `Connection` `!Sync`.
    // NOTE(review): presumably because sends and receives on one connection must not be
    // interleaved from several threads — confirm the intended threading model.
    pub(crate) std::marker::PhantomData<std::cell::Cell<()>>,
);
65
66impl<R: Req> Connection<R> {
67 pub fn send_header_only_message(
69 &self,
70 hdr: &VhostUserMsgHeader<R>,
71 fds: Option<&[RawDescriptor]>,
72 ) -> Result<()> {
73 self.0
74 .send_message(hdr.into_raw().as_bytes(), &[], &[], fds)
75 }
76
77 pub fn send_message<T: IntoBytes + Immutable>(
80 &self,
81 hdr: &VhostUserMsgHeader<R>,
82 body: &T,
83 fds: Option<&[RawDescriptor]>,
84 ) -> Result<()> {
85 self.0
86 .send_message(hdr.into_raw().as_bytes(), body.as_bytes(), &[], fds)
87 }
88
89 pub fn send_message_with_payload<T: IntoBytes + Immutable>(
92 &self,
93 hdr: &VhostUserMsgHeader<R>,
94 body: &T,
95 payload: &[u8],
96 fds: Option<&[RawDescriptor]>,
97 ) -> Result<()> {
98 self.0
99 .send_message(hdr.into_raw().as_bytes(), body.as_bytes(), payload, fds)
100 }
101
102 fn recv_into_bufs_all(&self, mut bufs: &mut [&mut [u8]]) -> Result<Vec<File>> {
115 let mut first_read = true;
116 let mut rfds = Vec::new();
117
118 advance_slices_mut(&mut bufs, 0);
120
121 while !bufs.is_empty() {
122 let mut slices: Vec<IoSliceMut> = bufs.iter_mut().map(|b| IoSliceMut::new(b)).collect();
123 let res = self.0.recv_into_bufs(&mut slices, true);
124 match res {
125 Ok((0, _)) => return Err(Error::PartialMessage),
126 Ok((n, fds)) => {
127 if first_read {
128 first_read = false;
129 if let Some(fds) = fds {
130 rfds = fds;
131 }
132 }
133 advance_slices_mut(&mut bufs, n);
134 }
135 Err(e) => match e {
136 Error::SocketRetry(_) => {}
137 _ => return Err(e),
138 },
139 }
140 }
141 Ok(rfds)
142 }
143
144 pub fn recv_header(&self) -> Result<(VhostUserMsgHeader<R>, Vec<File>)> {
149 let mut hdr_raw = [0u32; 3];
150 let files = self.recv_into_bufs_all(&mut [hdr_raw.as_mut_bytes()])?;
151 let hdr = VhostUserMsgHeader::from_raw(hdr_raw);
152 Ok((hdr, files))
153 }
154
155 pub fn recv_body_bytes(&self, hdr: &VhostUserMsgHeader<R>) -> Result<(Vec<u8>, Vec<File>)> {
157 let mut body = vec![0; hdr.get_size().try_into().unwrap()];
160 let files = self.recv_into_bufs_all(&mut [&mut body[..]])?;
161 Ok((body, files))
162 }
163
164 pub fn recv_message<T: IntoBytes + FromBytes>(
169 &self,
170 ) -> Result<(VhostUserMsgHeader<R>, T, Vec<File>)> {
171 let mut hdr_raw = [0u32; 3];
172 let mut body = T::new_zeroed();
173 let mut slices = [hdr_raw.as_mut_bytes(), body.as_mut_bytes()];
174 let files = self.recv_into_bufs_all(&mut slices)?;
175
176 let hdr = VhostUserMsgHeader::from_raw(hdr_raw);
177
178 Ok((hdr, body, files))
179 }
180
181 pub fn recv_message_with_payload<T: IntoBytes + FromBytes>(
187 &self,
188 ) -> Result<(VhostUserMsgHeader<R>, T, Vec<u8>, Vec<File>, Vec<File>)> {
189 let (hdr, files) = self.recv_header()?;
190
191 let mut body = T::new_zeroed();
192 let payload_size = hdr.get_size() as usize - mem::size_of::<T>();
193 let mut buf: Vec<u8> = vec![0; payload_size];
194 let mut slices = [body.as_mut_bytes(), buf.as_mut_bytes()];
195 let more_files = self.recv_into_bufs_all(&mut slices)?;
196
197 Ok((hdr, body, buf, files, more_files))
198 }
199}
200
201impl<R: Req> AsRawDescriptor for Connection<R> {
202 fn as_raw_descriptor(&self) -> RawDescriptor {
203 self.0.as_raw_descriptor()
204 }
205}
206
#[cfg(test)]
pub(crate) mod tests {
    use std::io::Read;
    use std::io::Seek;
    use std::io::SeekFrom;
    use std::io::Write;

    use tempfile::tempfile;

    use super::*;
    use crate::message::VhostUserEmptyMessage;
    use crate::message::VhostUserU64;

    /// A bare header sent from one end arrives unchanged and without files at the other end.
    #[test]
    fn send_header_only() {
        let (sender, receiver) = Connection::pair().unwrap();
        let sent_hdr = VhostUserMsgHeader::new(FrontendReq::GET_FEATURES, 0, 0);
        sender.send_header_only_message(&sent_hdr, None).unwrap();

        let (received_hdr, _, files) = receiver.recv_message::<VhostUserEmptyMessage>().unwrap();
        assert_eq!(sent_hdr, received_hdr);
        assert!(files.is_empty());
    }

    /// A header with a `u64` body round-trips through a connection pair.
    #[test]
    fn send_data() {
        let (sender, receiver) = Connection::pair().unwrap();
        let sent_hdr = VhostUserMsgHeader::new(FrontendReq::SET_FEATURES, 0, 8);
        let sent_body = VhostUserU64::new(0xf00dbeefdeadf00d);
        sender.send_message(&sent_hdr, &sent_body, None).unwrap();

        let (received_hdr, received_body, files) =
            receiver.recv_message::<VhostUserU64>().unwrap();
        assert_eq!(sent_hdr, received_hdr);
        // Copy the field into a local before comparing it.
        let received_value = received_body.value;
        assert_eq!(received_value, 0xf00dbeefdeadf00d);
        assert!(files.is_empty());
    }

    /// A descriptor attached to a header-only message is delivered as a usable file.
    #[test]
    fn send_fd() {
        let (sender, receiver) = Connection::pair().unwrap();

        let mut scratch = tempfile().unwrap();
        write!(scratch, "test").unwrap();

        let sent_hdr = VhostUserMsgHeader::new(FrontendReq::SET_MEM_TABLE, 0, 0);
        let descriptors = [scratch.as_raw_descriptor()];
        sender
            .send_header_only_message(&sent_hdr, Some(&descriptors))
            .unwrap();

        let (received_hdr, _, files) = receiver.recv_message::<VhostUserEmptyMessage>().unwrap();
        assert_eq!(sent_hdr, received_hdr);
        assert_eq!(files.len(), 1);

        // Read the received file from the start and compare with what was written.
        let mut received_file = &files[0];
        let mut text = String::new();
        received_file.seek(SeekFrom::Start(0)).unwrap();
        received_file.read_to_string(&mut text).unwrap();
        assert_eq!(text, "test");
    }

    /// Advancing by 10 bytes drops the 8-byte buffer and trims 2 bytes off the next one.
    #[test]
    fn test_advance_slices_mut() {
        let mut first = [1; 8];
        let mut second = [2; 16];
        let mut third = [3; 8];
        let mut cursor = &mut [&mut first[..], &mut second[..], &mut third[..]][..];
        advance_slices_mut(&mut cursor, 10);
        assert_eq!(cursor[0], [2; 14].as_ref());
        assert_eq!(cursor[1], [3; 8].as_ref());
    }
}