vmm_vhost/
connection.rs

1// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Common data structures for listener and connection.
5
6use std::fs::File;
7use std::io::IoSliceMut;
8use std::mem;
9
10use base::AsRawDescriptor;
11use base::RawDescriptor;
12use zerocopy::FromBytes;
13use zerocopy::FromZeros;
14use zerocopy::Immutable;
15use zerocopy::IntoBytes;
16
17use crate::message::*;
18use crate::sys::PlatformConnection;
19use crate::Error;
20use crate::Result;
21
22/// Listener for accepting connections.
23pub trait Listener: Sized {
24    /// Accept an incoming connection.
25    fn accept(&mut self) -> Result<Option<Connection>>;
26
27    /// Change blocking status on the listener.
28    fn set_nonblocking(&self, block: bool) -> Result<()>;
29}
30
// Moves the start of `bufs` forward by `count` bytes. This mirrors the unstable
// `IoSliceMut::advance_slices`, but operates on plain `&mut [u8]` slices.
//
// Leading slices that `count` covers completely are removed from `bufs`. Zero-length slices
// met before the cut point count as covered. The first surviving slice is then trimmed by
// whatever byte count is left over. If `count` exceeds the total length, `bufs` ends up empty.
fn advance_slices_mut(bufs: &mut &mut [&mut [u8]], count: usize) {
    let mut leftover = count;

    // Count how many leading slices are entirely consumed by `count`.
    let skipped = bufs
        .iter()
        .take_while(|slice| {
            let consumed = leftover >= slice.len();
            if consumed {
                leftover -= slice.len();
            }
            consumed
        })
        .count();

    let whole = std::mem::take(bufs);
    let tail = &mut whole[skipped..];
    if let Some(head) = tail.first_mut() {
        // Here `leftover < head.len()`, otherwise `take_while` would have skipped this slice.
        let untrimmed = std::mem::take(head);
        *head = &mut untrimmed[leftover..];
    }
    *bufs = tail;
}
51
/// A vhost-user connection at a low abstraction level. Provides methods for sending and receiving
/// vhost-user message headers and bodies.
///
/// Builds on top of `PlatformConnection`, which provides methods for sending and receiving raw
/// bytes and file descriptors (a thin cross-platform abstraction for unix domain sockets).
pub struct Connection(
    // The underlying platform transport; every send/recv in this type is delegated to it.
    pub(crate) PlatformConnection,
    // Mark `Connection` as `!Sync` because message sends and recvs cannot safely be done
    // concurrently. For example, a header and its body are read by separate calls, so
    // interleaved use from several threads could mix up messages. `Cell<()>` is `Send` but
    // `!Sync`, so this marker removes `Sync` without affecting `Send`.
    pub(crate) std::marker::PhantomData<std::cell::Cell<()>>,
);
63
64impl Connection {
65    /// Sends a header-only message with optional attached file descriptors.
66    pub fn send_header_only_message(
67        &self,
68        hdr: &VhostUserMsgHeader,
69        fds: Option<&[RawDescriptor]>,
70    ) -> Result<()> {
71        self.0.send_message(hdr.as_bytes(), &[], &[], fds)
72    }
73
74    /// Send a message with header and body. Optional file descriptors may be attached to
75    /// the message.
76    pub fn send_message<T: IntoBytes + Immutable>(
77        &self,
78        hdr: &VhostUserMsgHeader,
79        body: &T,
80        fds: Option<&[RawDescriptor]>,
81    ) -> Result<()> {
82        self.0
83            .send_message(hdr.as_bytes(), body.as_bytes(), &[], fds)
84    }
85
86    /// Send a message with header and body. `payload` is appended to the end of the body. Optional
87    /// file descriptors may also be attached to the message.
88    pub fn send_message_with_payload<T: IntoBytes + Immutable>(
89        &self,
90        hdr: &VhostUserMsgHeader,
91        body: &T,
92        payload: &[u8],
93        fds: Option<&[RawDescriptor]>,
94    ) -> Result<()> {
95        self.0
96            .send_message(hdr.as_bytes(), body.as_bytes(), payload, fds)
97    }
98
99    /// Reads all bytes into the given scatter/gather vectors with optional attached files. Will
100    /// loop until all data has been transfered and errors if EOF is reached before then.
101    ///
102    /// # Return:
103    /// * - received fds on success
104    /// * - `Disconnect` - client is closed
105    ///
106    /// # TODO
107    /// This function takes a slice of `&mut [u8]` instead of `IoSliceMut` because the internal
108    /// cursor needs to be moved by `advance_slices_mut()`.
109    /// Once `IoSliceMut::advance_slices()` becomes stable, this should be updated.
110    /// <https://github.com/rust-lang/rust/issues/62726>.
111    fn recv_into_bufs_all(&self, mut bufs: &mut [&mut [u8]]) -> Result<Vec<File>> {
112        let mut first_read = true;
113        let mut rfds = Vec::new();
114
115        // Guarantee that `bufs` becomes empty if it doesn't contain any data.
116        advance_slices_mut(&mut bufs, 0);
117
118        while !bufs.is_empty() {
119            let mut slices: Vec<IoSliceMut> = bufs.iter_mut().map(|b| IoSliceMut::new(b)).collect();
120            let res = self.0.recv_into_bufs(&mut slices, true);
121            match res {
122                Ok((0, _)) => return Err(Error::PartialMessage),
123                Ok((n, fds)) => {
124                    if first_read {
125                        first_read = false;
126                        if let Some(fds) = fds {
127                            rfds = fds;
128                        }
129                    }
130                    advance_slices_mut(&mut bufs, n);
131                }
132                Err(e) => match e {
133                    Error::SocketRetry(_) => {}
134                    _ => return Err(e),
135                },
136            }
137        }
138        Ok(rfds)
139    }
140
141    /// Receive message header
142    ///
143    /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be accepted and all
144    /// other file descriptor will be discard silently.
145    pub fn recv_header(&self) -> Result<(VhostUserMsgHeader, Vec<File>)> {
146        let mut hdr = VhostUserMsgHeader::new_zeroed();
147        let files = self.recv_into_bufs_all(&mut [hdr.as_mut_bytes()])?;
148        Ok((hdr, files))
149    }
150
151    /// Receive the body following the header `hdr`.
152    pub fn recv_body_bytes(&self, hdr: &VhostUserMsgHeader) -> Result<(Vec<u8>, Vec<File>)> {
153        // NOTE: `recv_into_bufs_all` is a noop when the buffer is empty, so `hdr.get_size() == 0`
154        // works as expected.
155        let mut body = vec![0; hdr.get_size().try_into().unwrap()];
156        let files = self.recv_into_bufs_all(&mut [&mut body[..]])?;
157        Ok((body, files))
158    }
159
160    /// Receive a message header and body.
161    ///
162    /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be
163    /// accepted and all other file descriptor will be discard silently.
164    pub fn recv_message<T: IntoBytes + FromBytes>(
165        &self,
166    ) -> Result<(VhostUserMsgHeader, T, Vec<File>)> {
167        let mut hdr = VhostUserMsgHeader::new_zeroed();
168        let mut body = T::new_zeroed();
169        let mut slices = [hdr.as_mut_bytes(), body.as_mut_bytes()];
170        let files = self.recv_into_bufs_all(&mut slices)?;
171        Ok((hdr, body, files))
172    }
173
174    /// Receive a message header and body, where the body includes a variable length payload at the
175    /// end.
176    ///
177    /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be accepted and all
178    /// other file descriptor will be discard silently.
179    pub fn recv_message_with_payload<T: IntoBytes + FromBytes>(
180        &self,
181    ) -> Result<(VhostUserMsgHeader, T, Vec<u8>, Vec<File>, Vec<File>)> {
182        let (hdr, files) = self.recv_header()?;
183
184        let mut body = T::new_zeroed();
185        let payload_size = hdr.get_size() as usize - mem::size_of::<T>();
186        let mut buf: Vec<u8> = vec![0; payload_size];
187        let mut slices = [body.as_mut_bytes(), buf.as_mut_bytes()];
188        let more_files = self.recv_into_bufs_all(&mut slices)?;
189
190        Ok((hdr, body, buf, files, more_files))
191    }
192}
193
194impl AsRawDescriptor for Connection {
195    fn as_raw_descriptor(&self) -> RawDescriptor {
196        self.0.as_raw_descriptor()
197    }
198}
199
#[cfg(test)]
pub(crate) mod tests {
    use std::io::Read;
    use std::io::Seek;
    use std::io::SeekFrom;
    use std::io::Write;

    use tempfile::tempfile;

    use super::*;
    use crate::message::FrontendReq;
    use crate::message::VhostUserEmptyMessage;
    use crate::message::VhostUserU64;

    #[test]
    fn send_header_only() {
        let (client_connection, server_connection) = Connection::pair().unwrap();
        let hdr1 = VhostUserMsgHeader::new_request_header(FrontendReq::GET_FEATURES, 0, false);
        client_connection
            .send_header_only_message(&hdr1, None)
            .unwrap();
        let (hdr2, _, files) = server_connection
            .recv_message::<VhostUserEmptyMessage>()
            .unwrap();
        assert_eq!(hdr1, hdr2);
        assert!(files.is_empty());
    }

    #[test]
    fn send_data() {
        let (client_connection, server_connection) = Connection::pair().unwrap();
        let hdr1 = VhostUserMsgHeader::new_request_header(FrontendReq::SET_FEATURES, 8, false);
        client_connection
            .send_message(&hdr1, &VhostUserU64::new(0xf00dbeefdeadf00d), None)
            .unwrap();
        let (hdr2, body, files) = server_connection.recv_message::<VhostUserU64>().unwrap();
        assert_eq!(hdr1, hdr2);
        let value = body.value;
        assert_eq!(value, 0xf00dbeefdeadf00d);
        assert!(files.is_empty());
    }

    #[test]
    fn send_payload() {
        let (client_connection, server_connection) = Connection::pair().unwrap();
        let payload = [0xa5u8; 5];
        // 8 bytes of `VhostUserU64` body followed by the 5-byte payload.
        let hdr1 = VhostUserMsgHeader::new_request_header(FrontendReq::SET_FEATURES, 13, false);
        client_connection
            .send_message_with_payload(
                &hdr1,
                &VhostUserU64::new(0x1122334455667788),
                &payload,
                None,
            )
            .unwrap();
        let (hdr2, body, buf, files, more_files) = server_connection
            .recv_message_with_payload::<VhostUserU64>()
            .unwrap();
        assert_eq!(hdr1, hdr2);
        let value = body.value;
        assert_eq!(value, 0x1122334455667788);
        assert_eq!(buf, payload);
        assert!(files.is_empty());
        assert!(more_files.is_empty());
    }

    #[test]
    fn send_fd() {
        let (client_connection, server_connection) = Connection::pair().unwrap();

        let mut fd = tempfile().unwrap();
        write!(fd, "test").unwrap();

        // Normal case for sending/receiving file descriptors
        let hdr1 = VhostUserMsgHeader::new_request_header(FrontendReq::SET_MEM_TABLE, 0, false);
        client_connection
            .send_header_only_message(&hdr1, Some(&[fd.as_raw_descriptor()]))
            .unwrap();

        let (hdr2, _, files) = server_connection
            .recv_message::<VhostUserEmptyMessage>()
            .unwrap();
        assert_eq!(hdr1, hdr2);
        assert_eq!(files.len(), 1);
        let mut file = &files[0];
        let mut content = String::new();
        file.seek(SeekFrom::Start(0)).unwrap();
        file.read_to_string(&mut content).unwrap();
        assert_eq!(content, "test");
    }

    #[test]
    fn test_advance_slices_mut() {
        // Test case from https://doc.rust-lang.org/std/io/struct.IoSliceMut.html#method.advance_slices
        let mut buf1 = [1; 8];
        let mut buf2 = [2; 16];
        let mut buf3 = [3; 8];
        let mut bufs = &mut [&mut buf1[..], &mut buf2[..], &mut buf3[..]][..];
        advance_slices_mut(&mut bufs, 10);
        assert_eq!(bufs.len(), 2);
        assert_eq!(bufs[0], [2; 14].as_ref());
        assert_eq!(bufs[1], [3; 8].as_ref());
    }

    #[test]
    fn test_advance_slices_mut_boundaries() {
        let mut empty: [u8; 0] = [];
        let mut buf1 = [1u8; 4];
        let mut buf2 = [2u8; 4];
        let mut bufs = &mut [&mut empty[..], &mut buf1[..], &mut buf2[..]][..];

        // Advancing by zero only drops leading empty slices.
        advance_slices_mut(&mut bufs, 0);
        assert_eq!(bufs.len(), 2);
        assert_eq!(&bufs[0][..], &[1u8; 4][..]);

        // Advancing exactly to a slice boundary removes the consumed slice entirely.
        advance_slices_mut(&mut bufs, 4);
        assert_eq!(bufs.len(), 1);
        assert_eq!(&bufs[0][..], &[2u8; 4][..]);

        // Advancing past the end leaves nothing.
        advance_slices_mut(&mut bufs, 100);
        assert!(bufs.is_empty());
    }
}