vmm_vhost/
connection.rs

1// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Common data structures for listener and connection.
5
6use std::fs::File;
7use std::io::IoSliceMut;
8use std::mem;
9
10use base::AsRawDescriptor;
11use base::RawDescriptor;
12use zerocopy::FromBytes;
13use zerocopy::Immutable;
14use zerocopy::IntoBytes;
15
16use crate::connection::Req;
17use crate::message::FrontendReq;
18use crate::message::*;
19use crate::sys::PlatformConnection;
20use crate::Error;
21use crate::Result;
22
23/// Listener for accepting connections.
24pub trait Listener: Sized {
25    /// Accept an incoming connection.
26    fn accept(&mut self) -> Result<Option<Connection<FrontendReq>>>;
27
28    /// Change blocking status on the listener.
29    fn set_nonblocking(&self, block: bool) -> Result<()>;
30}
31
// Moves the cursor of a list of byte slices forward by `count` bytes.
//
// Mirrors `IoSliceMut::advance_slices`, but works on plain `&mut [u8]` slices. Slices that are
// consumed entirely (including empty ones at the front) are dropped from `bufs`; a slice that is
// only partially consumed is shortened in place. A `count` larger than the total length simply
// leaves `bufs` empty.
fn advance_slices_mut(bufs: &mut &mut [&mut [u8]], mut count: usize) {
    // Count how many leading slices are swallowed whole.
    let mut fully_consumed = 0;
    while let Some(front) = bufs.get(fully_consumed) {
        if front.len() > count {
            break;
        }
        count -= front.len();
        fully_consumed += 1;
    }

    // Temporarily move the outer slice out so it can be re-sliced with its original lifetime.
    let all = std::mem::take(bufs);
    *bufs = &mut all[fully_consumed..];

    // Whatever is left of `count` is strictly smaller than the new first slice; trim it off.
    if let Some(head) = bufs.first_mut() {
        let whole = std::mem::take(head);
        let (_, tail) = whole.split_at_mut(count);
        *head = tail;
    }
}
52
53/// A vhost-user connection at a low abstraction level. Provides methods for sending and receiving
54/// vhost-user message headers and bodies.
55///
56/// Builds on top of `PlatformConnection`, which provides methods for sending and receiving raw
57/// bytes and file descriptors (a thin cross-platform abstraction for unix domain sockets).
58pub struct Connection<R: Req>(
59    pub(crate) PlatformConnection,
60    pub(crate) std::marker::PhantomData<R>,
61    // Mark `Connection` as `!Sync` because message sends and recvs cannot safely be done
62    // concurrently.
63    pub(crate) std::marker::PhantomData<std::cell::Cell<()>>,
64);
65
66impl<R: Req> Connection<R> {
67    /// Sends a header-only message with optional attached file descriptors.
68    pub fn send_header_only_message(
69        &self,
70        hdr: &VhostUserMsgHeader<R>,
71        fds: Option<&[RawDescriptor]>,
72    ) -> Result<()> {
73        self.0
74            .send_message(hdr.into_raw().as_bytes(), &[], &[], fds)
75    }
76
77    /// Send a message with header and body. Optional file descriptors may be attached to
78    /// the message.
79    pub fn send_message<T: IntoBytes + Immutable>(
80        &self,
81        hdr: &VhostUserMsgHeader<R>,
82        body: &T,
83        fds: Option<&[RawDescriptor]>,
84    ) -> Result<()> {
85        self.0
86            .send_message(hdr.into_raw().as_bytes(), body.as_bytes(), &[], fds)
87    }
88
89    /// Send a message with header and body. `payload` is appended to the end of the body. Optional
90    /// file descriptors may also be attached to the message.
91    pub fn send_message_with_payload<T: IntoBytes + Immutable>(
92        &self,
93        hdr: &VhostUserMsgHeader<R>,
94        body: &T,
95        payload: &[u8],
96        fds: Option<&[RawDescriptor]>,
97    ) -> Result<()> {
98        self.0
99            .send_message(hdr.into_raw().as_bytes(), body.as_bytes(), payload, fds)
100    }
101
102    /// Reads all bytes into the given scatter/gather vectors with optional attached files. Will
103    /// loop until all data has been transfered and errors if EOF is reached before then.
104    ///
105    /// # Return:
106    /// * - received fds on success
107    /// * - `Disconnect` - client is closed
108    ///
109    /// # TODO
110    /// This function takes a slice of `&mut [u8]` instead of `IoSliceMut` because the internal
111    /// cursor needs to be moved by `advance_slices_mut()`.
112    /// Once `IoSliceMut::advance_slices()` becomes stable, this should be updated.
113    /// <https://github.com/rust-lang/rust/issues/62726>.
114    fn recv_into_bufs_all(&self, mut bufs: &mut [&mut [u8]]) -> Result<Vec<File>> {
115        let mut first_read = true;
116        let mut rfds = Vec::new();
117
118        // Guarantee that `bufs` becomes empty if it doesn't contain any data.
119        advance_slices_mut(&mut bufs, 0);
120
121        while !bufs.is_empty() {
122            let mut slices: Vec<IoSliceMut> = bufs.iter_mut().map(|b| IoSliceMut::new(b)).collect();
123            let res = self.0.recv_into_bufs(&mut slices, true);
124            match res {
125                Ok((0, _)) => return Err(Error::PartialMessage),
126                Ok((n, fds)) => {
127                    if first_read {
128                        first_read = false;
129                        if let Some(fds) = fds {
130                            rfds = fds;
131                        }
132                    }
133                    advance_slices_mut(&mut bufs, n);
134                }
135                Err(e) => match e {
136                    Error::SocketRetry(_) => {}
137                    _ => return Err(e),
138                },
139            }
140        }
141        Ok(rfds)
142    }
143
144    /// Receive message header
145    ///
146    /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be accepted and all
147    /// other file descriptor will be discard silently.
148    pub fn recv_header(&self) -> Result<(VhostUserMsgHeader<R>, Vec<File>)> {
149        let mut hdr_raw = [0u32; 3];
150        let files = self.recv_into_bufs_all(&mut [hdr_raw.as_mut_bytes()])?;
151        let hdr = VhostUserMsgHeader::from_raw(hdr_raw);
152        Ok((hdr, files))
153    }
154
155    /// Receive the body following the header `hdr`.
156    pub fn recv_body_bytes(&self, hdr: &VhostUserMsgHeader<R>) -> Result<(Vec<u8>, Vec<File>)> {
157        // NOTE: `recv_into_bufs_all` is a noop when the buffer is empty, so `hdr.get_size() == 0`
158        // works as expected.
159        let mut body = vec![0; hdr.get_size().try_into().unwrap()];
160        let files = self.recv_into_bufs_all(&mut [&mut body[..]])?;
161        Ok((body, files))
162    }
163
164    /// Receive a message header and body.
165    ///
166    /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be
167    /// accepted and all other file descriptor will be discard silently.
168    pub fn recv_message<T: IntoBytes + FromBytes>(
169        &self,
170    ) -> Result<(VhostUserMsgHeader<R>, T, Vec<File>)> {
171        let mut hdr_raw = [0u32; 3];
172        let mut body = T::new_zeroed();
173        let mut slices = [hdr_raw.as_mut_bytes(), body.as_mut_bytes()];
174        let files = self.recv_into_bufs_all(&mut slices)?;
175
176        let hdr = VhostUserMsgHeader::from_raw(hdr_raw);
177
178        Ok((hdr, body, files))
179    }
180
181    /// Receive a message header and body, where the body includes a variable length payload at the
182    /// end.
183    ///
184    /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be accepted and all
185    /// other file descriptor will be discard silently.
186    pub fn recv_message_with_payload<T: IntoBytes + FromBytes>(
187        &self,
188    ) -> Result<(VhostUserMsgHeader<R>, T, Vec<u8>, Vec<File>, Vec<File>)> {
189        let (hdr, files) = self.recv_header()?;
190
191        let mut body = T::new_zeroed();
192        let payload_size = hdr.get_size() as usize - mem::size_of::<T>();
193        let mut buf: Vec<u8> = vec![0; payload_size];
194        let mut slices = [body.as_mut_bytes(), buf.as_mut_bytes()];
195        let more_files = self.recv_into_bufs_all(&mut slices)?;
196
197        Ok((hdr, body, buf, files, more_files))
198    }
199}
200
201impl<R: Req> AsRawDescriptor for Connection<R> {
202    fn as_raw_descriptor(&self) -> RawDescriptor {
203        self.0.as_raw_descriptor()
204    }
205}
206
#[cfg(test)]
pub(crate) mod tests {
    use std::io::Read;
    use std::io::Seek;
    use std::io::SeekFrom;
    use std::io::Write;

    use tempfile::tempfile;

    use super::*;
    use crate::message::VhostUserEmptyMessage;
    use crate::message::VhostUserU64;

    // A header sent without a body must arrive unchanged and without attached files.
    #[test]
    fn send_header_only() {
        let (sender, receiver) = Connection::pair().unwrap();
        let sent_hdr = VhostUserMsgHeader::new(FrontendReq::GET_FEATURES, 0, 0);
        sender.send_header_only_message(&sent_hdr, None).unwrap();

        let (received_hdr, _, files) = receiver.recv_message::<VhostUserEmptyMessage>().unwrap();
        assert_eq!(sent_hdr, received_hdr);
        assert!(files.is_empty());
    }

    // A header followed by a fixed-size body must round-trip both parts.
    #[test]
    fn send_data() {
        let (sender, receiver) = Connection::pair().unwrap();
        let sent_hdr = VhostUserMsgHeader::new(FrontendReq::SET_FEATURES, 0, 8);
        let sent_body = VhostUserU64::new(0xf00dbeefdeadf00d);
        sender.send_message(&sent_hdr, &sent_body, None).unwrap();

        let (received_hdr, received_body, files) =
            receiver.recv_message::<VhostUserU64>().unwrap();
        assert_eq!(sent_hdr, received_hdr);
        // Copy the field into a local before comparing it.
        let value = received_body.value;
        assert_eq!(value, 0xf00dbeefdeadf00d);
        assert!(files.is_empty());
    }

    // A file attached to a message must be usable on the receiving side.
    #[test]
    fn send_fd() {
        let (sender, receiver) = Connection::pair().unwrap();

        let mut backing_file = tempfile().unwrap();
        write!(backing_file, "test").unwrap();

        // Normal case for sending/receiving file descriptors
        let sent_hdr = VhostUserMsgHeader::new(FrontendReq::SET_MEM_TABLE, 0, 0);
        let attached = [backing_file.as_raw_descriptor()];
        sender
            .send_header_only_message(&sent_hdr, Some(&attached))
            .unwrap();

        let (received_hdr, _, files) = receiver.recv_message::<VhostUserEmptyMessage>().unwrap();
        assert_eq!(sent_hdr, received_hdr);
        assert_eq!(files.len(), 1);

        // Rewind the received file and check that it exposes the written content.
        let mut received_file = &files[0];
        let mut text = String::new();
        received_file.seek(SeekFrom::Start(0)).unwrap();
        received_file.read_to_string(&mut text).unwrap();
        assert_eq!(text, "test");
    }

    // Advancing by 10 bytes consumes the first slice (8 bytes) and 2 bytes of the second.
    #[test]
    fn test_advance_slices_mut() {
        // Test case from https://doc.rust-lang.org/std/io/struct.IoSliceMut.html#method.advance_slices
        let mut first = [1; 8];
        let mut second = [2; 16];
        let mut third = [3; 8];
        let mut remaining = &mut [&mut first[..], &mut second[..], &mut third[..]][..];
        advance_slices_mut(&mut remaining, 10);
        assert_eq!(remaining[0], [2; 14].as_ref());
        assert_eq!(remaining[1], [3; 8].as_ref());
    }
}