1use std::fs::File;
34use std::io::Error as IOError;
35use std::num::TryFromIntError;
36
37use remain::sorted;
38use thiserror::Error as ThisError;
39
40mod backend;
41pub use backend::*;
42
43pub mod message;
44pub use message::VHOST_USER_F_PROTOCOL_FEATURES;
45
46pub mod connection;
47
48mod sys;
49pub use connection::Connection;
50pub use message::BackendReq;
51pub use message::FrontendReq;
52#[cfg(unix)]
53pub use sys::unix;
54
55pub(crate) mod backend_client;
56pub use backend_client::BackendClient;
57mod frontend_server;
58pub use self::frontend_server::Frontend;
59mod backend_server;
60mod frontend_client;
61pub use self::backend_server::Backend;
62pub use self::backend_server::BackendServer;
63pub use self::frontend_client::FrontendClient;
64pub use self::frontend_server::FrontendServer;
65
66#[sorted]
68#[derive(Debug, ThisError)]
69pub enum Error {
70 #[error("backend internal error")]
72 BackendInternalError,
73 #[error("client exited properly")]
75 ClientExit,
76 #[error("client closed the connection")]
79 Disconnect,
80 #[error("Failed to enter suspended state")]
81 EnterSuspendedState(anyhow::Error),
82 #[error("frontend Internal error")]
84 FrontendInternalError,
85 #[error("wrong number of attached fds")]
87 IncorrectFds,
88 #[error("invalid cast to int: {0}")]
90 InvalidCastToInt(TryFromIntError),
91 #[error("invalid message")]
93 InvalidMessage,
94 #[error("invalid operation")]
96 InvalidOperation,
97 #[error("invalid parameters: {0}")]
99 InvalidParam(&'static str),
100 #[error("oversized message")]
102 OversizedMsg,
103 #[error("partial message")]
105 PartialMessage,
106 #[error("buffer for recv was too small, data was dropped: got size {got}, needed {want}")]
108 RecvBufferTooSmall {
109 got: usize,
111 want: usize,
113 },
114 #[error("handler failed to handle request: {0}")]
116 ReqHandlerError(#[source] IOError),
117 #[error("Failed to restore")]
119 RestoreError(anyhow::Error),
120 #[error("Failed to snapshot")]
122 SnapshotError(anyhow::Error),
123 #[error("socket error: {0}")]
125 SocketError(std::io::Error),
126 #[error("temporary socket error: {0}")]
128 SocketRetry(std::io::Error),
129 #[error("failed to read/write on Tube: {0}")]
131 TubeError(base::TubeError),
132}
133
134pub type Result<T> = std::result::Result<T, Error>;
136
/// Result alias for handler code; its error type is [`std::io::Error`].
pub type HandlerResult<T> = ::std::result::Result<T, ::std::io::Error>;
139
/// A shared memory region, identified by `id`, whose size is given by `length`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct SharedMemoryRegion {
    /// Identifier of the region.
    pub id: u8,
    /// Length of the region (NOTE(review): presumably in bytes — confirm).
    pub length: u64,
}
147
/// Returns the only file in `files`, or `None` if `files` does not hold exactly one file.
///
/// When `None` is returned, every file that was passed in is dropped (and thus closed).
pub(crate) fn into_single_file(mut files: Vec<File>) -> Option<File> {
    match files.len() {
        // Exactly one element: `pop` hands it back directly.
        1 => files.pop(),
        _ => None,
    }
}
156
157#[cfg(test)]
158mod test_backend;
159
#[cfg(test)]
mod tests {
    use std::io::ErrorKind;
    use std::thread;

    use base::AsRawDescriptor;
    use tempfile::tempfile;

    use super::*;
    use crate::message::*;
    use crate::test_backend::TestBackend;
    use crate::test_backend::VIRTIO_FEATURES;
    use crate::VhostUserMemoryRegionInfo;
    use crate::VringConfigData;

    /// Creates a connected `BackendClient`/`BackendServer` pair over `Connection::pair()`, with
    /// `backend` serving requests on the server side.
    fn create_client_server_pair<S>(backend: S) -> (BackendClient, BackendServer<S>)
    where
        S: Backend,
    {
        let (client_connection, server_connection) = Connection::pair().unwrap();
        let backend_client = BackendClient::new(client_connection);
        (
            backend_client,
            BackendServer::<S>::new(server_connection, backend),
        )
    }

    /// Receives one request header (plus any attached files) on `h` and processes it.
    fn handle_request(h: &mut BackendServer<TestBackend>) -> Result<()> {
        let (hdr, files) = h.recv_header()?;
        h.process_message(hdr, files)
    }

    #[test]
    fn create_test_backend() {
        let mut backend = TestBackend::new();

        backend.set_owner().unwrap();
        assert!(backend.set_owner().is_err());
    }

    #[test]
    fn test_set_owner() {
        let test_backend = TestBackend::new();
        let (backend_client, mut backend_server) = create_client_server_pair(test_backend);

        assert!(!backend_server.as_ref().owned);
        backend_client.set_owner().unwrap();
        handle_request(&mut backend_server).unwrap();
        assert!(backend_server.as_ref().owned);
        backend_client.set_owner().unwrap();
        assert!(handle_request(&mut backend_server).is_err());
        assert!(backend_server.as_ref().owned);
    }

    #[test]
    fn test_set_features() {
        let test_backend = TestBackend::new();
        let (mut backend_client, mut backend_server) = create_client_server_pair(test_backend);

        // Serve requests on a separate thread. The handle is joined at the end so a failed
        // server-side assertion fails this test; waiting on a barrier the panicking thread never
        // reaches would hang the test instead.
        let server = thread::spawn(move || {
            // set_owner
            handle_request(&mut backend_server).unwrap();
            assert!(backend_server.as_ref().owned);

            // get_features, set_features
            handle_request(&mut backend_server).unwrap();
            handle_request(&mut backend_server).unwrap();
            assert_eq!(
                backend_server.as_ref().acked_features,
                VIRTIO_FEATURES & !0x1
            );

            // get_protocol_features, set_protocol_features
            handle_request(&mut backend_server).unwrap();
            handle_request(&mut backend_server).unwrap();
            assert_eq!(
                backend_server.as_ref().acked_protocol_features,
                VhostUserProtocolFeatures::all().bits()
            );
        });

        backend_client.set_owner().unwrap();

        let features = backend_client.get_features().unwrap();
        assert_eq!(features, VIRTIO_FEATURES);
        backend_client.set_features(VIRTIO_FEATURES & !0x1).unwrap();

        let features = backend_client.get_protocol_features().unwrap();
        assert_eq!(features.bits(), VhostUserProtocolFeatures::all().bits());
        backend_client.set_protocol_features(features).unwrap();

        server.join().expect("backend server thread panicked");
    }

    #[test]
    fn test_client_server_process_no_need_reply() {
        test_client_server_process(false);
    }

    #[test]
    fn test_client_server_process_need_reply() {
        test_client_server_process(true);
    }

    /// Drives a full sequence of requests through a client/server pair.
    ///
    /// `set_need_reply` is forwarded to `BackendClient::set_need_reply` after protocol
    /// features have been negotiated.
    fn test_client_server_process(set_need_reply: bool) {
        let test_backend = TestBackend::new();
        let (mut backend_client, mut backend_server) = create_client_server_pair(test_backend);

        // Each `handle_request` below serves the client call named in its comment, in order.
        // The handle is joined at the end so server-side panics fail the test instead of
        // hanging it.
        let server = thread::spawn(move || {
            // set_owner
            handle_request(&mut backend_server).unwrap();
            assert!(backend_server.as_ref().owned);

            // get_features, set_features
            handle_request(&mut backend_server).unwrap();
            handle_request(&mut backend_server).unwrap();
            assert_eq!(
                backend_server.as_ref().acked_features,
                VIRTIO_FEATURES & !0x1
            );

            // get_protocol_features, set_protocol_features
            handle_request(&mut backend_server).unwrap();
            handle_request(&mut backend_server).unwrap();
            assert_eq!(
                backend_server.as_ref().acked_protocol_features,
                VhostUserProtocolFeatures::all().bits()
            );

            // get_inflight_fd, set_inflight_fd
            handle_request(&mut backend_server).unwrap();
            handle_request(&mut backend_server).unwrap();

            // get_queue_num
            handle_request(&mut backend_server).unwrap();

            // set_mem_table
            handle_request(&mut backend_server).unwrap();

            // set_config, get_config
            handle_request(&mut backend_server).unwrap();
            handle_request(&mut backend_server).unwrap();

            // set_backend_req_fd
            handle_request(&mut backend_server).unwrap();

            // set_vring_enable
            handle_request(&mut backend_server).unwrap();

            // set_vring_num, set_vring_base, set_vring_addr,
            // set_vring_call, set_vring_kick, set_vring_err
            handle_request(&mut backend_server).unwrap();
            handle_request(&mut backend_server).unwrap();
            handle_request(&mut backend_server).unwrap();
            handle_request(&mut backend_server).unwrap();
            handle_request(&mut backend_server).unwrap();
            handle_request(&mut backend_server).unwrap();

            // get_max_mem_slots
            handle_request(&mut backend_server).unwrap();

            // add_mem_region
            handle_request(&mut backend_server).unwrap();

            // remove_mem_region
            handle_request(&mut backend_server).unwrap();

            // set_log_base: expected to fail on the server side.
            handle_request(&mut backend_server).unwrap_err();

            // Close the server end so the client's final request observes the disconnect.
            std::mem::drop(backend_server);
        });

        backend_client.set_owner().unwrap();

        let features = backend_client.get_features().unwrap();
        assert_eq!(features, VIRTIO_FEATURES);
        backend_client.set_features(VIRTIO_FEATURES & !0x1).unwrap();

        let features = backend_client.get_protocol_features().unwrap();
        assert_eq!(features.bits(), VhostUserProtocolFeatures::all().bits());
        backend_client.set_protocol_features(features).unwrap();

        backend_client.set_need_reply(set_need_reply);

        let (inflight_info, inflight_file) = backend_client
            .get_inflight_fd(&VhostUserInflight {
                num_queues: 2,
                queue_size: 256,
                ..Default::default()
            })
            .unwrap();
        backend_client
            .set_inflight_fd(&inflight_info, inflight_file.as_raw_descriptor())
            .unwrap();

        let num = backend_client.get_queue_num().unwrap();
        assert_eq!(num, 2);

        let event = base::Event::new().unwrap();
        let mem = [VhostUserMemoryRegionInfo {
            guest_phys_addr: 0,
            memory_size: 0x10_0000,
            userspace_addr: 0,
            mmap_offset: 0,
            mmap_handle: event.as_raw_descriptor(),
        }];
        backend_client.set_mem_table(&mem).unwrap();

        backend_client
            .set_config(0x100, VhostUserConfigFlags::WRITABLE, &[0xa5u8])
            .unwrap();
        let buf = [0x0u8; 4];
        let (reply_body, reply_payload) = backend_client
            .get_config(0x100, 4, VhostUserConfigFlags::empty(), &buf)
            .unwrap();
        let offset = reply_body.offset;
        assert_eq!(offset, 0x100);
        assert_eq!(reply_payload[0], 0xa5);

        #[cfg(windows)]
        let tubes = base::Tube::pair().unwrap();
        // SAFETY: NOTE(review): the invariants `packed_tube::pack` requires are not visible in
        // this file; confirm and document them here.
        #[cfg(windows)]
        let descriptor =
            unsafe { tube_transporter::packed_tube::pack(tubes.0, std::process::id()).unwrap() };

        #[cfg(unix)]
        let descriptor = base::Event::new().unwrap();

        backend_client.set_backend_req_fd(&descriptor).unwrap();
        backend_client.set_vring_enable(0, true).unwrap();

        backend_client.set_vring_num(0, 256).unwrap();
        backend_client.set_vring_base(0, 0).unwrap();
        let config = VringConfigData {
            queue_size: 128,
            flags: VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits(),
            desc_table_addr: 0x1000,
            used_ring_addr: 0x2000,
            avail_ring_addr: 0x3000,
            log_addr: Some(0x4000),
        };
        backend_client.set_vring_addr(0, &config).unwrap();
        backend_client.set_vring_call(0, &event).unwrap();
        backend_client.set_vring_kick(0, &event).unwrap();
        backend_client.set_vring_err(0, &event).unwrap();

        let max_mem_slots = backend_client.get_max_mem_slots().unwrap();
        assert_eq!(max_mem_slots, 32);

        let region_file = tempfile().unwrap();
        let region = VhostUserMemoryRegionInfo {
            guest_phys_addr: 0x10_0000,
            memory_size: 0x10_0000,
            userspace_addr: 0,
            mmap_offset: 0,
            mmap_handle: region_file.as_raw_descriptor(),
        };
        backend_client.add_mem_region(&region).unwrap();

        backend_client.remove_mem_region(&region).unwrap();

        // The server rejects set_log_base and then closes the connection. With need_reply the
        // client sees the disconnect on this request; without it the send succeeds and the
        // disconnect surfaces on the next request.
        let result = backend_client.set_log_base(0, Some(event.as_raw_descriptor()));
        if set_need_reply {
            assert!(
                matches!(result, Err(Error::Disconnect)),
                "unexpected result: {result:?}"
            );
        } else {
            result.unwrap();
            let result = backend_client.get_features();
            match &result {
                Err(Error::Disconnect) => {}
                Err(Error::SocketError(e))
                    if e.kind() == ErrorKind::ConnectionReset
                        || e.kind() == ErrorKind::BrokenPipe => {}
                _ => panic!("unexpected result: {result:?}"),
            }
        }

        server.join().expect("backend server thread panicked");
    }

    #[test]
    fn test_error_display() {
        assert_eq!(
            format!("{}", Error::InvalidParam("")),
            "invalid parameters: "
        );
        assert_eq!(format!("{}", Error::InvalidOperation), "invalid operation");
    }
}