1use std::fs::File;
5use std::mem;
6
7use base::AsRawDescriptor;
8use base::RawDescriptor;
9use base::SafeDescriptor;
10use zerocopy::FromBytes;
11use zerocopy::Immutable;
12use zerocopy::IntoBytes;
13use zerocopy::Ref;
14
15use crate::into_single_file;
16use crate::message::*;
17use crate::BackendReq;
18use crate::Connection;
19use crate::Error;
20use crate::FrontendReq;
21use crate::Result;
22
23#[allow(missing_docs)]
27pub trait Backend {
28 fn set_owner(&mut self) -> Result<()>;
29 fn reset_owner(&mut self) -> Result<()>;
30 fn get_features(&mut self) -> Result<u64>;
31 fn set_features(&mut self, features: u64) -> Result<()>;
32 fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>;
33 fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()>;
34 fn set_vring_addr(
35 &mut self,
36 index: u32,
37 flags: VhostUserVringAddrFlags,
38 descriptor: u64,
39 used: u64,
40 available: u64,
41 log: u64,
42 ) -> Result<()>;
43 fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()>;
45 fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState>;
47 fn set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()>;
48 fn set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()>;
49 fn set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()>;
50
51 fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>;
52 fn set_protocol_features(&mut self, features: u64) -> Result<()>;
53 fn get_queue_num(&mut self) -> Result<u64>;
54 fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()>;
55 fn get_config(
56 &mut self,
57 offset: u32,
58 size: u32,
59 flags: VhostUserConfigFlags,
60 ) -> Result<Vec<u8>>;
61 fn set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>;
62 fn set_backend_req_fd(&mut self, _vu_req: Connection<BackendReq>) {}
63 fn get_inflight_fd(
64 &mut self,
65 inflight: &VhostUserInflight,
66 ) -> Result<(VhostUserInflight, File)>;
67 fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()>;
68 fn get_max_mem_slots(&mut self) -> Result<u64>;
69 fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>;
70 fn remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()>;
71 fn set_device_state_fd(
72 &mut self,
73 transfer_direction: VhostUserTransferDirection,
74 migration_phase: VhostUserMigrationPhase,
75 fd: File,
76 ) -> Result<Option<File>>;
77 fn check_device_state(&mut self) -> Result<()>;
78 fn get_shmem_config(&mut self) -> Result<(VhostUserShMemConfigHeader, Vec<u64>)>;
79}
80
81impl<T> Backend for T
82where
83 T: AsMut<dyn Backend>,
84{
85 fn set_owner(&mut self) -> Result<()> {
86 self.as_mut().set_owner()
87 }
88
89 fn reset_owner(&mut self) -> Result<()> {
90 self.as_mut().reset_owner()
91 }
92
93 fn get_features(&mut self) -> Result<u64> {
94 self.as_mut().get_features()
95 }
96
97 fn set_features(&mut self, features: u64) -> Result<()> {
98 self.as_mut().set_features(features)
99 }
100
101 fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()> {
102 self.as_mut().set_mem_table(ctx, files)
103 }
104
105 fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()> {
106 self.as_mut().set_vring_num(index, num)
107 }
108
109 fn set_vring_addr(
110 &mut self,
111 index: u32,
112 flags: VhostUserVringAddrFlags,
113 descriptor: u64,
114 used: u64,
115 available: u64,
116 log: u64,
117 ) -> Result<()> {
118 self.as_mut()
119 .set_vring_addr(index, flags, descriptor, used, available, log)
120 }
121
122 fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()> {
123 self.as_mut().set_vring_base(index, base)
124 }
125
126 fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState> {
127 self.as_mut().get_vring_base(index)
128 }
129
130 fn set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()> {
131 self.as_mut().set_vring_kick(index, fd)
132 }
133
134 fn set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()> {
135 self.as_mut().set_vring_call(index, fd)
136 }
137
138 fn set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()> {
139 self.as_mut().set_vring_err(index, fd)
140 }
141
142 fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures> {
143 self.as_mut().get_protocol_features()
144 }
145
146 fn set_protocol_features(&mut self, features: u64) -> Result<()> {
147 self.as_mut().set_protocol_features(features)
148 }
149
150 fn get_queue_num(&mut self) -> Result<u64> {
151 self.as_mut().get_queue_num()
152 }
153
154 fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()> {
155 self.as_mut().set_vring_enable(index, enable)
156 }
157
158 fn get_config(
159 &mut self,
160 offset: u32,
161 size: u32,
162 flags: VhostUserConfigFlags,
163 ) -> Result<Vec<u8>> {
164 self.as_mut().get_config(offset, size, flags)
165 }
166
167 fn set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()> {
168 self.as_mut().set_config(offset, buf, flags)
169 }
170
171 fn set_backend_req_fd(&mut self, vu_req: Connection<BackendReq>) {
172 self.as_mut().set_backend_req_fd(vu_req)
173 }
174
175 fn get_inflight_fd(
176 &mut self,
177 inflight: &VhostUserInflight,
178 ) -> Result<(VhostUserInflight, File)> {
179 self.as_mut().get_inflight_fd(inflight)
180 }
181
182 fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()> {
183 self.as_mut().set_inflight_fd(inflight, file)
184 }
185
186 fn get_max_mem_slots(&mut self) -> Result<u64> {
187 self.as_mut().get_max_mem_slots()
188 }
189
190 fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()> {
191 self.as_mut().add_mem_region(region, fd)
192 }
193
194 fn remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()> {
195 self.as_mut().remove_mem_region(region)
196 }
197
198 fn set_device_state_fd(
199 &mut self,
200 transfer_direction: VhostUserTransferDirection,
201 migration_phase: VhostUserMigrationPhase,
202 fd: File,
203 ) -> Result<Option<File>> {
204 self.as_mut()
205 .set_device_state_fd(transfer_direction, migration_phase, fd)
206 }
207
208 fn check_device_state(&mut self) -> Result<()> {
209 self.as_mut().check_device_state()
210 }
211
212 fn get_shmem_config(&mut self) -> Result<(VhostUserShMemConfigHeader, Vec<u64>)> {
213 self.as_mut().get_shmem_config()
214 }
215}
216
217pub struct BackendServer<S: Backend> {
219 connection: Connection<FrontendReq>,
221 backend: S,
223
224 virtio_features: u64,
225 acked_virtio_features: u64,
226 protocol_features: VhostUserProtocolFeatures,
227 acked_protocol_features: u64,
228
229 reply_ack_enabled: bool,
231}
232
233impl<S: Backend> AsRef<S> for BackendServer<S> {
234 fn as_ref(&self) -> &S {
235 &self.backend
236 }
237}
238
239impl<S: Backend> BackendServer<S> {
240 pub fn new(connection: Connection<FrontendReq>, backend: S) -> Self {
241 BackendServer {
242 connection,
243 backend,
244 virtio_features: 0,
245 acked_virtio_features: 0,
246 protocol_features: VhostUserProtocolFeatures::empty(),
247 acked_protocol_features: 0,
248 reply_ack_enabled: false,
249 }
250 }
251
252 pub fn recv_header(&mut self) -> Result<(VhostUserMsgHeader<FrontendReq>, Vec<File>)> {
300 let (hdr, files) = match self.connection.recv_header() {
310 Ok((hdr, files)) => (hdr, files),
311 Err(Error::Disconnect) => {
312 return Err(Error::ClientExit);
315 }
316 Err(e) => {
317 return Err(e);
318 }
319 };
320
321 if !hdr.is_valid() {
322 return Err(Error::InvalidMessage);
323 }
324
325 self.check_attached_files(&hdr, &files)?;
326
327 Ok((hdr, files))
328 }
329
330 pub fn needs_wait_for_payload(&self, hdr: &VhostUserMsgHeader<FrontendReq>) -> bool {
335 hdr.get_size() != 0
338 }
339
340 pub fn process_message(
353 &mut self,
354 hdr: VhostUserMsgHeader<FrontendReq>,
355 files: Vec<File>,
356 ) -> Result<()> {
357 let (buf, extra_files) = self.connection.recv_body_bytes(&hdr)?;
358 let size = buf.len();
359 if !extra_files.is_empty() {
360 return Err(Error::InvalidMessage);
361 }
362
363 match hdr.get_code() {
367 Ok(FrontendReq::SET_OWNER) => {
368 self.check_request_size(&hdr, size, 0)?;
369 let res = self.backend.set_owner();
370 self.send_ack_message(&hdr, res.is_ok())?;
371 res?;
372 }
373 Ok(FrontendReq::RESET_OWNER) => {
374 self.check_request_size(&hdr, size, 0)?;
375 let res = self.backend.reset_owner();
376 self.send_ack_message(&hdr, res.is_ok())?;
377 res?;
378 }
379 Ok(FrontendReq::GET_FEATURES) => {
380 self.check_request_size(&hdr, size, 0)?;
381 let mut features = self.backend.get_features()?;
382
383 features &= !(1 << VIRTIO_F_RING_PACKED);
387
388 let msg = VhostUserU64::new(features);
389 self.send_reply_message(&hdr, &msg)?;
390 self.virtio_features = features;
391 self.update_reply_ack_flag();
392 }
393 Ok(FrontendReq::SET_FEATURES) => {
394 let mut msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?;
395
396 msg.value &= !(1 << VIRTIO_F_RING_PACKED);
400
401 let res = self.backend.set_features(msg.value);
402 self.acked_virtio_features = msg.value;
403 self.update_reply_ack_flag();
404 self.send_ack_message(&hdr, res.is_ok())?;
405 res?;
406 }
407 Ok(FrontendReq::SET_MEM_TABLE) => {
408 let res = self.set_mem_table(&hdr, size, &buf, files);
409 self.send_ack_message(&hdr, res.is_ok())?;
410 res?;
411 }
412 Ok(FrontendReq::SET_VRING_NUM) => {
413 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
414 let res = self.backend.set_vring_num(msg.index, msg.num);
415 self.send_ack_message(&hdr, res.is_ok())?;
416 res?;
417 }
418 Ok(FrontendReq::SET_VRING_ADDR) => {
419 let msg = self.extract_request_body::<VhostUserVringAddr>(&hdr, size, &buf)?;
420 let flags = match VhostUserVringAddrFlags::from_bits(msg.flags) {
421 Some(val) => val,
422 None => return Err(Error::InvalidMessage),
423 };
424 let res = self.backend.set_vring_addr(
425 msg.index,
426 flags,
427 msg.descriptor,
428 msg.used,
429 msg.available,
430 msg.log,
431 );
432 self.send_ack_message(&hdr, res.is_ok())?;
433 res?;
434 }
435 Ok(FrontendReq::SET_VRING_BASE) => {
436 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
437 let res = self.backend.set_vring_base(msg.index, msg.num);
438 self.send_ack_message(&hdr, res.is_ok())?;
439 res?;
440 }
441 Ok(FrontendReq::GET_VRING_BASE) => {
442 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
443 let reply = self.backend.get_vring_base(msg.index)?;
444 self.send_reply_message(&hdr, &reply)?;
445 }
446 Ok(FrontendReq::SET_VRING_CALL) => {
447 self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
448 let (index, file) = self.handle_vring_fd_request(&buf, files)?;
449 let res = self.backend.set_vring_call(index, file);
450 self.send_ack_message(&hdr, res.is_ok())?;
451 res?;
452 }
453 Ok(FrontendReq::SET_VRING_KICK) => {
454 self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
455 let (index, file) = self.handle_vring_fd_request(&buf, files)?;
456 let res = self.backend.set_vring_kick(index, file);
457 self.send_ack_message(&hdr, res.is_ok())?;
458 res?;
459 }
460 Ok(FrontendReq::SET_VRING_ERR) => {
461 self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
462 let (index, file) = self.handle_vring_fd_request(&buf, files)?;
463 let res = self.backend.set_vring_err(index, file);
464 self.send_ack_message(&hdr, res.is_ok())?;
465 res?;
466 }
467 Ok(FrontendReq::GET_PROTOCOL_FEATURES) => {
468 self.check_request_size(&hdr, size, 0)?;
469 let features = self.backend.get_protocol_features()?;
470 let msg = VhostUserU64::new(features.bits());
471 self.send_reply_message(&hdr, &msg)?;
472 self.protocol_features = features;
473 self.update_reply_ack_flag();
474 }
475 Ok(FrontendReq::SET_PROTOCOL_FEATURES) => {
476 let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?;
477 let res = self.backend.set_protocol_features(msg.value);
478 self.acked_protocol_features = msg.value;
479 self.update_reply_ack_flag();
480 self.send_ack_message(&hdr, res.is_ok())?;
481 res?;
482 }
483 Ok(FrontendReq::GET_QUEUE_NUM) => {
484 if self.acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() == 0 {
485 return Err(Error::InvalidOperation);
486 }
487 self.check_request_size(&hdr, size, 0)?;
488 let num = self.backend.get_queue_num()?;
489 let msg = VhostUserU64::new(num);
490 self.send_reply_message(&hdr, &msg)?;
491 }
492 Ok(FrontendReq::SET_VRING_ENABLE) => {
493 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
494 if self.acked_virtio_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES == 0 {
495 return Err(Error::InvalidOperation);
496 }
497 let enable = match msg.num {
498 1 => true,
499 0 => false,
500 _ => {
501 return Err(Error::InvalidParam(
502 "SET_VRING_ENABLE: num out of range (must be [0, 1])",
503 ))
504 }
505 };
506
507 let res = self.backend.set_vring_enable(msg.index, enable);
508 self.send_ack_message(&hdr, res.is_ok())?;
509 res?;
510 }
511 Ok(FrontendReq::GET_CONFIG) => {
512 if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
513 return Err(Error::InvalidOperation);
514 }
515 self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
516 self.get_config(&hdr, &buf)?;
517 }
518 Ok(FrontendReq::SET_CONFIG) => {
519 if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
520 return Err(Error::InvalidOperation);
521 }
522 self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
523 let res = self.set_config(&buf);
524 self.send_ack_message(&hdr, res.is_ok())?;
525 res?;
526 }
527 Ok(FrontendReq::SET_BACKEND_REQ_FD) => {
528 if self.acked_protocol_features & VhostUserProtocolFeatures::BACKEND_REQ.bits() == 0
529 {
530 return Err(Error::InvalidOperation);
531 }
532 self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
533 let res = self.set_backend_req_fd(files);
534 self.send_ack_message(&hdr, res.is_ok())?;
535 res?;
536 }
537 Ok(FrontendReq::GET_INFLIGHT_FD) => {
538 if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits()
539 == 0
540 {
541 return Err(Error::InvalidOperation);
542 }
543
544 let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?;
545 let (inflight, file) = self.backend.get_inflight_fd(&msg)?;
546 let reply_hdr = self.new_reply_header::<VhostUserInflight>(&hdr, 0)?;
547 self.connection.send_message(
548 &reply_hdr,
549 &inflight,
550 Some(&[file.as_raw_descriptor()]),
551 )?;
552 }
553 Ok(FrontendReq::SET_INFLIGHT_FD) => {
554 if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits()
555 == 0
556 {
557 return Err(Error::InvalidOperation);
558 }
559 let file = into_single_file(files).ok_or(Error::IncorrectFds)?;
560 let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?;
561 let res = self.backend.set_inflight_fd(&msg, file);
562 self.send_ack_message(&hdr, res.is_ok())?;
563 res?;
564 }
565 Ok(FrontendReq::GET_MAX_MEM_SLOTS) => {
566 if self.acked_protocol_features
567 & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
568 == 0
569 {
570 return Err(Error::InvalidOperation);
571 }
572 self.check_request_size(&hdr, size, 0)?;
573 let num = self.backend.get_max_mem_slots()?;
574 let msg = VhostUserU64::new(num);
575 self.send_reply_message(&hdr, &msg)?;
576 }
577 Ok(FrontendReq::ADD_MEM_REG) => {
578 if self.acked_protocol_features
579 & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
580 == 0
581 {
582 return Err(Error::InvalidOperation);
583 }
584 let file = into_single_file(files).ok_or(Error::InvalidParam(
585 "ADD_MEM_REG: exactly one file must be provided",
586 ))?;
587 let msg =
588 self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?;
589 let res = self.backend.add_mem_region(&msg, file);
590 self.send_ack_message(&hdr, res.is_ok())?;
591 res?;
592 }
593 Ok(FrontendReq::REM_MEM_REG) => {
594 if self.acked_protocol_features
595 & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
596 == 0
597 {
598 return Err(Error::InvalidOperation);
599 }
600
601 let msg =
602 self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?;
603 let res = self.backend.remove_mem_region(&msg);
604 self.send_ack_message(&hdr, res.is_ok())?;
605 res?;
606 }
607 Ok(FrontendReq::SET_DEVICE_STATE_FD) => {
608 if self.acked_protocol_features & VhostUserProtocolFeatures::DEVICE_STATE.bits()
609 == 0
610 {
611 return Err(Error::InvalidOperation);
612 }
613 let msg =
615 self.extract_request_body::<DeviceStateTransferParameters>(&hdr, size, &buf)?;
616 let transfer_direction = match msg.transfer_direction {
617 0 => VhostUserTransferDirection::Save,
618 1 => VhostUserTransferDirection::Load,
619 _ => return Err(Error::InvalidMessage),
620 };
621 let migration_phase = match msg.migration_phase {
622 0 => VhostUserMigrationPhase::Stopped,
623 _ => return Err(Error::InvalidMessage),
624 };
625 let res = self.backend.set_device_state_fd(
627 transfer_direction,
628 migration_phase,
629 files.into_iter().next().ok_or(Error::IncorrectFds)?,
630 );
631 let (msg, fds) = match &res {
633 Ok(None) => (VhostUserU64::new(0x100), None),
634 Ok(Some(file)) => (VhostUserU64::new(0), Some(file.as_raw_descriptor())),
635 Err(_) => (VhostUserU64::new(0x101), None),
637 };
638 let reply_hdr: VhostUserMsgHeader<FrontendReq> =
639 self.new_reply_header::<VhostUserU64>(&hdr, 0)?;
640 self.connection.send_message(
641 &reply_hdr,
642 &msg,
643 fds.as_ref().map(std::slice::from_ref),
644 )?;
645 res?;
646 }
647 Ok(FrontendReq::CHECK_DEVICE_STATE) => {
648 if self.acked_protocol_features & VhostUserProtocolFeatures::DEVICE_STATE.bits()
649 == 0
650 {
651 return Err(Error::InvalidOperation);
652 }
653 let res = self.backend.check_device_state();
654 let msg = VhostUserU64::new(if res.is_ok() { 0 } else { 1 });
655 self.send_reply_message(&hdr, &msg)?;
656 res?;
657 }
658 Ok(FrontendReq::GET_SHMEM_CONFIG) => {
659 let (msg, sizes) = self.backend.get_shmem_config()?;
660 let mut buf = Vec::new();
661 for e in sizes {
662 buf.extend_from_slice(e.as_bytes())
663 }
664 self.send_reply_with_payload(&hdr, &msg, buf.as_slice())?;
665 }
666 _ => {
667 return Err(Error::InvalidMessage);
668 }
669 }
670 Ok(())
671 }
672
673 fn new_reply_header<T: Sized>(
674 &self,
675 req: &VhostUserMsgHeader<FrontendReq>,
676 payload_size: usize,
677 ) -> Result<VhostUserMsgHeader<FrontendReq>> {
678 Ok(VhostUserMsgHeader::new(
679 req.get_code().map_err(|_| Error::InvalidMessage)?,
680 VhostUserHeaderFlag::REPLY.bits(),
681 (mem::size_of::<T>()
682 .checked_add(payload_size)
683 .ok_or(Error::OversizedMsg)?)
684 .try_into()
685 .map_err(Error::InvalidCastToInt)?,
686 ))
687 }
688
689 fn send_ack_message(
691 &mut self,
692 req: &VhostUserMsgHeader<FrontendReq>,
693 success: bool,
694 ) -> Result<()> {
695 if self.reply_ack_enabled && req.is_need_reply() {
696 let hdr: VhostUserMsgHeader<FrontendReq> =
697 self.new_reply_header::<VhostUserU64>(req, 0)?;
698 let val = if success { 0 } else { 1 };
699 let msg = VhostUserU64::new(val);
700 self.connection.send_message(&hdr, &msg, None)?;
701 }
702 Ok(())
703 }
704
705 fn send_reply_message<T: IntoBytes + Immutable>(
706 &mut self,
707 req: &VhostUserMsgHeader<FrontendReq>,
708 msg: &T,
709 ) -> Result<()> {
710 let hdr = self.new_reply_header::<T>(req, 0)?;
711 self.connection.send_message(&hdr, msg, None)?;
712 Ok(())
713 }
714
715 fn send_reply_with_payload<T: IntoBytes + Immutable>(
716 &mut self,
717 req: &VhostUserMsgHeader<FrontendReq>,
718 msg: &T,
719 payload: &[u8],
720 ) -> Result<()> {
721 let hdr = self.new_reply_header::<T>(req, payload.len())?;
722 self.connection
723 .send_message_with_payload(&hdr, msg, payload, None)?;
724 Ok(())
725 }
726
727 fn set_mem_table(
728 &mut self,
729 hdr: &VhostUserMsgHeader<FrontendReq>,
730 size: usize,
731 buf: &[u8],
732 files: Vec<File>,
733 ) -> Result<()> {
734 self.check_request_size(hdr, size, hdr.get_size() as usize)?;
735
736 let (msg, regions) =
737 Ref::<_, VhostUserMemory>::from_prefix(buf).map_err(|_| Error::InvalidMessage)?;
738 if !msg.is_valid() {
739 return Err(Error::InvalidMessage);
740 }
741
742 if files.len() != msg.num_regions as usize {
744 return Err(Error::InvalidMessage);
745 }
746
747 let (regions, excess) = Ref::<_, [VhostUserMemoryRegion]>::from_prefix_with_elems(
748 regions,
749 msg.num_regions as usize,
750 )
751 .map_err(|_| Error::InvalidMessage)?;
752 if !excess.is_empty() {
753 return Err(Error::InvalidMessage);
754 }
755
756 for region in regions.iter() {
758 if !region.is_valid() {
759 return Err(Error::InvalidMessage);
760 }
761 }
762
763 self.backend.set_mem_table(®ions, files)
764 }
765
766 fn get_config(&mut self, hdr: &VhostUserMsgHeader<FrontendReq>, buf: &[u8]) -> Result<()> {
767 let (msg, payload) =
768 Ref::<_, VhostUserConfig>::from_prefix(buf).map_err(|_| Error::InvalidMessage)?;
769 if !msg.is_valid() {
770 return Err(Error::InvalidMessage);
771 }
772 if payload.len() != msg.size as usize {
773 return Err(Error::InvalidMessage);
774 }
775 let flags = match VhostUserConfigFlags::from_bits(msg.flags) {
776 Some(val) => val,
777 None => return Err(Error::InvalidMessage),
778 };
779 let res = self.backend.get_config(msg.offset, msg.size, flags);
780
781 match res {
784 Ok(ref buf) if buf.len() == msg.size as usize => {
785 let reply = VhostUserConfig::new(msg.offset, buf.len() as u32, flags);
786 self.send_reply_with_payload(hdr, &reply, buf.as_slice())?;
787 }
788 Ok(_) => {
789 let reply = VhostUserConfig::new(msg.offset, 0, flags);
790 self.send_reply_message(hdr, &reply)?;
791 }
792 Err(_) => {
793 let reply = VhostUserConfig::new(msg.offset, 0, flags);
794 self.send_reply_message(hdr, &reply)?;
795 }
796 }
797 Ok(())
798 }
799
800 fn set_config(&mut self, buf: &[u8]) -> Result<()> {
801 let (msg, payload) =
802 Ref::<_, VhostUserConfig>::from_prefix(buf).map_err(|_| Error::InvalidMessage)?;
803 if !msg.is_valid() {
804 return Err(Error::InvalidMessage);
805 }
806 if payload.len() != msg.size as usize {
807 return Err(Error::InvalidMessage);
808 }
809 let flags: VhostUserConfigFlags = match VhostUserConfigFlags::from_bits(msg.flags) {
810 Some(val) => val,
811 None => return Err(Error::InvalidMessage),
812 };
813
814 self.backend.set_config(msg.offset, payload, flags)
815 }
816
817 fn set_backend_req_fd(&mut self, files: Vec<File>) -> Result<()> {
818 let file = into_single_file(files).ok_or(Error::InvalidMessage)?;
819 let fd: SafeDescriptor = file.into();
820 let connection = Connection::try_from(fd).map_err(|_| Error::InvalidMessage)?;
821 self.backend.set_backend_req_fd(connection);
822 Ok(())
823 }
824
825 fn handle_vring_fd_request(
828 &mut self,
829 buf: &[u8],
830 files: Vec<File>,
831 ) -> Result<(u8, Option<File>)> {
832 let (msg, _) = VhostUserU64::read_from_prefix(buf).map_err(|_| Error::InvalidMessage)?;
833 if !msg.is_valid() {
834 return Err(Error::InvalidMessage);
835 }
836
837 let has_fd = (msg.value & 0x100u64) == 0;
844
845 let file = into_single_file(files);
846
847 if has_fd && file.is_none() || !has_fd && file.is_some() {
848 return Err(Error::InvalidMessage);
849 }
850
851 Ok((msg.value as u8, file))
852 }
853
854 fn check_request_size(
855 &self,
856 hdr: &VhostUserMsgHeader<FrontendReq>,
857 size: usize,
858 expected: usize,
859 ) -> Result<()> {
860 if hdr.get_size() as usize != expected
861 || hdr.is_reply()
862 || hdr.get_version() != 0x1
863 || size != expected
864 {
865 return Err(Error::InvalidMessage);
866 }
867 Ok(())
868 }
869
870 fn check_attached_files(
871 &self,
872 hdr: &VhostUserMsgHeader<FrontendReq>,
873 files: &[File],
874 ) -> Result<()> {
875 match hdr.get_code() {
876 Ok(FrontendReq::SET_MEM_TABLE)
877 | Ok(FrontendReq::SET_VRING_CALL)
878 | Ok(FrontendReq::SET_VRING_KICK)
879 | Ok(FrontendReq::SET_VRING_ERR)
880 | Ok(FrontendReq::SET_LOG_BASE)
881 | Ok(FrontendReq::SET_LOG_FD)
882 | Ok(FrontendReq::SET_BACKEND_REQ_FD)
883 | Ok(FrontendReq::SET_INFLIGHT_FD)
884 | Ok(FrontendReq::ADD_MEM_REG)
885 | Ok(FrontendReq::SET_DEVICE_STATE_FD) => Ok(()),
886 Err(_) => Err(Error::InvalidMessage),
887 _ if !files.is_empty() => Err(Error::InvalidMessage),
888 _ => Ok(()),
889 }
890 }
891
892 fn extract_request_body<T: Sized + FromBytes + VhostUserMsgValidator>(
893 &self,
894 hdr: &VhostUserMsgHeader<FrontendReq>,
895 size: usize,
896 buf: &[u8],
897 ) -> Result<T> {
898 self.check_request_size(hdr, size, mem::size_of::<T>())?;
899 let (body, _) = T::read_from_prefix(buf).map_err(|_| Error::InvalidMessage)?;
900 if body.is_valid() {
901 Ok(body)
902 } else {
903 Err(Error::InvalidMessage)
904 }
905 }
906
907 fn update_reply_ack_flag(&mut self) {
908 let pflag = VhostUserProtocolFeatures::REPLY_ACK;
909 self.reply_ack_enabled = (self.virtio_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES) != 0
910 && self.protocol_features.contains(pflag)
911 && (self.acked_protocol_features & pflag.bits()) != 0;
912 }
913}
914
915impl<S: Backend> AsRawDescriptor for BackendServer<S> {
916 fn as_raw_descriptor(&self) -> RawDescriptor {
917 self.connection.as_raw_descriptor()
919 }
920}
921
#[cfg(test)]
mod tests {
    use base::INVALID_DESCRIPTOR;

    use super::*;
    use crate::test_backend::TestBackend;
    use crate::Connection;

    /// A freshly constructed server exposes a valid descriptor for its connection.
    #[test]
    fn test_backend_server_new() {
        let (server_end, _peer) = Connection::pair().unwrap();
        let server = BackendServer::new(server_end, TestBackend::new());

        assert_ne!(server.as_raw_descriptor(), INVALID_DESCRIPTOR);
    }
}