1use std::fs::File;
34use std::io::Error as IOError;
35use std::num::TryFromIntError;
36
37use remain::sorted;
38use thiserror::Error as ThisError;
39
40mod backend;
41pub use backend::*;
42
43pub mod message;
44pub use message::VHOST_USER_F_PROTOCOL_FEATURES;
45
46pub mod connection;
47
48mod sys;
49pub use connection::Connection;
50pub use message::BackendReq;
51pub use message::FrontendReq;
52#[cfg(unix)]
53pub use sys::unix;
54
55pub(crate) mod backend_client;
56pub use backend_client::BackendClient;
57mod frontend_server;
58pub use self::frontend_server::Frontend;
59mod backend_server;
60mod frontend_client;
61pub use self::backend_server::Backend;
62pub use self::backend_server::BackendServer;
63pub use self::frontend_client::FrontendClient;
64pub use self::frontend_server::FrontendServer;
65
66#[sorted]
68#[derive(Debug, ThisError)]
69pub enum Error {
70 #[error("backend internal error")]
72 BackendInternalError,
73 #[error("client exited properly")]
75 ClientExit,
76 #[error("client closed the connection")]
79 Disconnect,
80 #[error("Failed to enter suspended state")]
81 EnterSuspendedState(anyhow::Error),
82 #[error("frontend Internal error")]
84 FrontendInternalError,
85 #[error("wrong number of attached fds")]
87 IncorrectFds,
88 #[error("invalid cast to int: {0}")]
90 InvalidCastToInt(TryFromIntError),
91 #[error("invalid message")]
93 InvalidMessage,
94 #[error("invalid operation")]
96 InvalidOperation,
97 #[error("invalid parameters: {0}")]
99 InvalidParam(&'static str),
100 #[error("oversized message")]
102 OversizedMsg,
103 #[error("partial message")]
105 PartialMessage,
106 #[error("buffer for recv was too small, data was dropped: got size {got}, needed {want}")]
108 RecvBufferTooSmall {
109 got: usize,
111 want: usize,
113 },
114 #[error("handler failed to handle request: {0}")]
116 ReqHandlerError(#[source] IOError),
117 #[error("Failed to restore")]
119 RestoreError(anyhow::Error),
120 #[error("Failed to snapshot")]
122 SnapshotError(anyhow::Error),
123 #[error("socket error: {0}")]
125 SocketError(std::io::Error),
126 #[error("temporary socket error: {0}")]
128 SocketRetry(std::io::Error),
129 #[error("failed to read/write on Tube: {0}")]
131 TubeError(base::TubeError),
132}
133
134pub type Result<T> = std::result::Result<T, Error>;
136
/// Result type for handler code, using `std::io::Error` as the error type.
pub type HandlerResult<T> = std::result::Result<T, std::io::Error>;
139
/// Consumes `files` and returns the only `File` in it when it holds exactly one entry.
///
/// Returns `None` when `files` is empty or contains more than one file; in that case all
/// of the files in the vector are dropped.
pub(crate) fn into_single_file(mut files: Vec<File>) -> Option<File> {
    match files.len() {
        1 => files.pop(),
        _ => None,
    }
}
148
149#[cfg(test)]
150mod test_backend;
151
#[cfg(test)]
mod tests {
    use std::thread;

    use base::AsRawDescriptor;
    use tempfile::tempfile;

    use super::*;
    use crate::message::*;
    use crate::test_backend::TestBackend;
    use crate::test_backend::VIRTIO_FEATURES;
    use crate::VhostUserMemoryRegionInfo;
    use crate::VringConfigData;

    /// Builds a `BackendClient` and a `BackendServer` wrapping `backend`, connected to each
    /// other through `Connection::pair()`.
    fn create_client_server_pair<S>(backend: S) -> (BackendClient, BackendServer<S>)
    where
        S: Backend,
    {
        let (client_connection, server_connection) = Connection::pair().unwrap();
        let backend_client = BackendClient::new(client_connection);
        (
            backend_client,
            BackendServer::<S>::new(server_connection, backend),
        )
    }

    /// Receives one request header (plus any attached files) on the server side and
    /// passes it to `process_message`, returning that result.
    fn handle_request(h: &mut BackendServer<TestBackend>) -> Result<()> {
        let (hdr, files) = h.recv_header()?;
        h.process_message(hdr, files)
    }

    #[test]
    fn create_test_backend() {
        let mut backend = TestBackend::new();

        backend.set_owner().unwrap();
        // A second `set_owner` on an already-owned backend must fail.
        assert!(backend.set_owner().is_err());
    }

    #[test]
    fn test_set_owner() {
        let test_backend = TestBackend::new();
        let (backend_client, mut backend_server) = create_client_server_pair(test_backend);

        assert!(!backend_server.as_ref().owned);
        backend_client.set_owner().unwrap();
        handle_request(&mut backend_server).unwrap();
        assert!(backend_server.as_ref().owned);
        // A second `set_owner` request is sent successfully, but the server rejects it and
        // the backend stays owned.
        backend_client.set_owner().unwrap();
        assert!(handle_request(&mut backend_server).is_err());
        assert!(backend_server.as_ref().owned);
    }

    #[test]
    fn test_set_features() {
        let test_backend = TestBackend::new();
        let (mut backend_client, mut backend_server) = create_client_server_pair(test_backend);

        // The server runs on its own thread. It is joined (rather than synchronized with a
        // barrier) so that a panic there fails this test instead of deadlocking it.
        let server = thread::spawn(move || {
            // set_owner
            handle_request(&mut backend_server).unwrap();
            assert!(backend_server.as_ref().owned);

            // get_features, set_features
            handle_request(&mut backend_server).unwrap();
            handle_request(&mut backend_server).unwrap();
            assert_eq!(
                backend_server.as_ref().acked_features,
                VIRTIO_FEATURES & !0x1
            );

            // get_protocol_features, set_protocol_features
            handle_request(&mut backend_server).unwrap();
            handle_request(&mut backend_server).unwrap();
            assert_eq!(
                backend_server.as_ref().acked_protocol_features,
                VhostUserProtocolFeatures::all().bits()
            );
        });

        backend_client.set_owner().unwrap();

        let features = backend_client.get_features().unwrap();
        assert_eq!(features, VIRTIO_FEATURES);
        backend_client.set_features(VIRTIO_FEATURES & !0x1).unwrap();

        let features = backend_client.get_protocol_features().unwrap();
        assert_eq!(features.bits(), VhostUserProtocolFeatures::all().bits());
        backend_client.set_protocol_features(features).unwrap();

        // Propagates any assertion failure from the server thread.
        server.join().expect("backend server thread panicked");
    }

    #[test]
    fn test_client_server_process() {
        let test_backend = TestBackend::new();
        let (mut backend_client, mut backend_server) = create_client_server_pair(test_backend);

        // Each `handle_request` below serves exactly one of the client calls made further
        // down, in the same order; the comments name the matching client call. The thread
        // is joined at the end so that a panic here fails the test instead of hanging it.
        let server = thread::spawn(move || {
            // set_owner
            handle_request(&mut backend_server).unwrap();
            assert!(backend_server.as_ref().owned);

            // get_features, set_features
            handle_request(&mut backend_server).unwrap();
            handle_request(&mut backend_server).unwrap();
            assert_eq!(
                backend_server.as_ref().acked_features,
                VIRTIO_FEATURES & !0x1
            );

            // get_protocol_features, set_protocol_features
            handle_request(&mut backend_server).unwrap();
            handle_request(&mut backend_server).unwrap();
            assert_eq!(
                backend_server.as_ref().acked_protocol_features,
                VhostUserProtocolFeatures::all().bits()
            );

            // get_inflight_fd, set_inflight_fd
            handle_request(&mut backend_server).unwrap();
            handle_request(&mut backend_server).unwrap();

            // get_queue_num
            handle_request(&mut backend_server).unwrap();

            // set_mem_table
            handle_request(&mut backend_server).unwrap();

            // set_config, get_config
            handle_request(&mut backend_server).unwrap();
            handle_request(&mut backend_server).unwrap();

            // set_backend_req_fd
            handle_request(&mut backend_server).unwrap();

            // set_vring_enable
            handle_request(&mut backend_server).unwrap();

            // set_log_base, set_log_fd: the server side is expected to reject both.
            handle_request(&mut backend_server).unwrap_err();
            handle_request(&mut backend_server).unwrap_err();

            // set_vring_num, set_vring_base, set_vring_addr,
            // set_vring_call, set_vring_kick, set_vring_err
            handle_request(&mut backend_server).unwrap();
            handle_request(&mut backend_server).unwrap();
            handle_request(&mut backend_server).unwrap();
            handle_request(&mut backend_server).unwrap();
            handle_request(&mut backend_server).unwrap();
            handle_request(&mut backend_server).unwrap();

            // get_max_mem_slots
            handle_request(&mut backend_server).unwrap();

            // add_mem_region
            handle_request(&mut backend_server).unwrap();

            // remove_mem_region
            handle_request(&mut backend_server).unwrap();
        });

        backend_client.set_owner().unwrap();

        let features = backend_client.get_features().unwrap();
        assert_eq!(features, VIRTIO_FEATURES);
        backend_client.set_features(VIRTIO_FEATURES & !0x1).unwrap();

        let features = backend_client.get_protocol_features().unwrap();
        assert_eq!(features.bits(), VhostUserProtocolFeatures::all().bits());
        backend_client.set_protocol_features(features).unwrap();

        let (inflight_info, inflight_file) = backend_client
            .get_inflight_fd(&VhostUserInflight {
                num_queues: 2,
                queue_size: 256,
                ..Default::default()
            })
            .unwrap();
        backend_client
            .set_inflight_fd(&inflight_info, inflight_file.as_raw_descriptor())
            .unwrap();

        let num = backend_client.get_queue_num().unwrap();
        assert_eq!(num, 2);

        let event = base::Event::new().unwrap();
        let mem = [VhostUserMemoryRegionInfo {
            guest_phys_addr: 0,
            memory_size: 0x10_0000,
            userspace_addr: 0,
            mmap_offset: 0,
            mmap_handle: event.as_raw_descriptor(),
        }];
        backend_client.set_mem_table(&mem).unwrap();

        // Write one byte of config at offset 0x100, then read it back.
        backend_client
            .set_config(0x100, VhostUserConfigFlags::WRITABLE, &[0xa5u8])
            .unwrap();
        let buf = [0x0u8; 4];
        let (reply_body, reply_payload) = backend_client
            .get_config(0x100, 4, VhostUserConfigFlags::empty(), &buf)
            .unwrap();
        let offset = reply_body.offset;
        assert_eq!(offset, 0x100);
        assert_eq!(reply_payload[0], 0xa5);

        #[cfg(windows)]
        let tubes = base::Tube::pair().unwrap();
        #[cfg(windows)]
        // SAFETY: the tube is packed for this same process (`std::process::id()`).
        // NOTE(review): confirm this satisfies `packed_tube::pack`'s documented contract.
        let descriptor =
            unsafe { tube_transporter::packed_tube::pack(tubes.0, std::process::id()).unwrap() };

        #[cfg(unix)]
        let descriptor = base::Event::new().unwrap();

        backend_client.set_backend_req_fd(&descriptor).unwrap();
        backend_client.set_vring_enable(0, true).unwrap();

        // The server thread expects both of these requests to be rejected.
        backend_client
            .set_log_base(0, Some(event.as_raw_descriptor()))
            .unwrap();
        backend_client
            .set_log_fd(event.as_raw_descriptor())
            .unwrap();

        backend_client.set_vring_num(0, 256).unwrap();
        backend_client.set_vring_base(0, 0).unwrap();
        let config = VringConfigData {
            queue_size: 128,
            flags: VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits(),
            desc_table_addr: 0x1000,
            used_ring_addr: 0x2000,
            avail_ring_addr: 0x3000,
            log_addr: Some(0x4000),
        };
        backend_client.set_vring_addr(0, &config).unwrap();
        backend_client.set_vring_call(0, &event).unwrap();
        backend_client.set_vring_kick(0, &event).unwrap();
        backend_client.set_vring_err(0, &event).unwrap();

        let max_mem_slots = backend_client.get_max_mem_slots().unwrap();
        assert_eq!(max_mem_slots, 32);

        let region_file = tempfile().unwrap();
        let region = VhostUserMemoryRegionInfo {
            guest_phys_addr: 0x10_0000,
            memory_size: 0x10_0000,
            userspace_addr: 0,
            mmap_offset: 0,
            mmap_handle: region_file.as_raw_descriptor(),
        };
        backend_client.add_mem_region(&region).unwrap();

        backend_client.remove_mem_region(&region).unwrap();

        // Propagates any assertion failure from the server thread.
        server.join().expect("backend server thread panicked");
    }

    #[test]
    fn test_error_display() {
        assert_eq!(
            format!("{}", Error::InvalidParam("")),
            "invalid parameters: "
        );
        assert_eq!(format!("{}", Error::InvalidOperation), "invalid operation");
    }
}