1use std::fs::File;
5use std::mem;
6
7use base::AsRawDescriptor;
8use base::RawDescriptor;
9use base::SafeDescriptor;
10use zerocopy::FromBytes;
11use zerocopy::Immutable;
12use zerocopy::IntoBytes;
13use zerocopy::Ref;
14
15use crate::into_single_file;
16use crate::message::*;
17use crate::Connection;
18use crate::Error;
19use crate::FrontendReq;
20use crate::Result;
21use crate::SharedMemoryRegion;
22
23#[allow(missing_docs)]
27pub trait Backend {
28 fn set_owner(&mut self) -> Result<()>;
29 fn reset_owner(&mut self) -> Result<()>;
30 fn get_features(&mut self) -> Result<u64>;
31 fn set_features(&mut self, features: u64) -> Result<()>;
32 fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>;
33 fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()>;
34 fn set_vring_addr(
35 &mut self,
36 index: u32,
37 flags: VhostUserVringAddrFlags,
38 descriptor: u64,
39 used: u64,
40 available: u64,
41 log: u64,
42 ) -> Result<()>;
43 fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()>;
45 fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState>;
47 fn set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()>;
48 fn set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()>;
49 fn set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()>;
50
51 fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>;
52 fn set_protocol_features(&mut self, features: u64) -> Result<()>;
53 fn get_queue_num(&mut self) -> Result<u64>;
54 fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()>;
55 fn get_config(
56 &mut self,
57 offset: u32,
58 size: u32,
59 flags: VhostUserConfigFlags,
60 ) -> Result<Vec<u8>>;
61 fn set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>;
62 fn set_backend_req_fd(&mut self, _vu_req: Connection) {}
63 fn get_inflight_fd(
64 &mut self,
65 inflight: &VhostUserInflight,
66 ) -> Result<(VhostUserInflight, File)>;
67 fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()>;
68 fn get_max_mem_slots(&mut self) -> Result<u64>;
69 fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>;
70 fn remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()>;
71 fn set_device_state_fd(
72 &mut self,
73 transfer_direction: VhostUserTransferDirection,
74 migration_phase: VhostUserMigrationPhase,
75 fd: File,
76 ) -> Result<Option<File>>;
77 fn check_device_state(&mut self) -> Result<()>;
78 fn get_shmem_config(&mut self) -> Result<Vec<SharedMemoryRegion>>;
79}
80
81impl<T> Backend for T
82where
83 T: AsMut<dyn Backend>,
84{
85 fn set_owner(&mut self) -> Result<()> {
86 self.as_mut().set_owner()
87 }
88
89 fn reset_owner(&mut self) -> Result<()> {
90 self.as_mut().reset_owner()
91 }
92
93 fn get_features(&mut self) -> Result<u64> {
94 self.as_mut().get_features()
95 }
96
97 fn set_features(&mut self, features: u64) -> Result<()> {
98 self.as_mut().set_features(features)
99 }
100
101 fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()> {
102 self.as_mut().set_mem_table(ctx, files)
103 }
104
105 fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()> {
106 self.as_mut().set_vring_num(index, num)
107 }
108
109 fn set_vring_addr(
110 &mut self,
111 index: u32,
112 flags: VhostUserVringAddrFlags,
113 descriptor: u64,
114 used: u64,
115 available: u64,
116 log: u64,
117 ) -> Result<()> {
118 self.as_mut()
119 .set_vring_addr(index, flags, descriptor, used, available, log)
120 }
121
122 fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()> {
123 self.as_mut().set_vring_base(index, base)
124 }
125
126 fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState> {
127 self.as_mut().get_vring_base(index)
128 }
129
130 fn set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()> {
131 self.as_mut().set_vring_kick(index, fd)
132 }
133
134 fn set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()> {
135 self.as_mut().set_vring_call(index, fd)
136 }
137
138 fn set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()> {
139 self.as_mut().set_vring_err(index, fd)
140 }
141
142 fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures> {
143 self.as_mut().get_protocol_features()
144 }
145
146 fn set_protocol_features(&mut self, features: u64) -> Result<()> {
147 self.as_mut().set_protocol_features(features)
148 }
149
150 fn get_queue_num(&mut self) -> Result<u64> {
151 self.as_mut().get_queue_num()
152 }
153
154 fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()> {
155 self.as_mut().set_vring_enable(index, enable)
156 }
157
158 fn get_config(
159 &mut self,
160 offset: u32,
161 size: u32,
162 flags: VhostUserConfigFlags,
163 ) -> Result<Vec<u8>> {
164 self.as_mut().get_config(offset, size, flags)
165 }
166
167 fn set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()> {
168 self.as_mut().set_config(offset, buf, flags)
169 }
170
171 fn set_backend_req_fd(&mut self, vu_req: Connection) {
172 self.as_mut().set_backend_req_fd(vu_req)
173 }
174
175 fn get_inflight_fd(
176 &mut self,
177 inflight: &VhostUserInflight,
178 ) -> Result<(VhostUserInflight, File)> {
179 self.as_mut().get_inflight_fd(inflight)
180 }
181
182 fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()> {
183 self.as_mut().set_inflight_fd(inflight, file)
184 }
185
186 fn get_max_mem_slots(&mut self) -> Result<u64> {
187 self.as_mut().get_max_mem_slots()
188 }
189
190 fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()> {
191 self.as_mut().add_mem_region(region, fd)
192 }
193
194 fn remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()> {
195 self.as_mut().remove_mem_region(region)
196 }
197
198 fn set_device_state_fd(
199 &mut self,
200 transfer_direction: VhostUserTransferDirection,
201 migration_phase: VhostUserMigrationPhase,
202 fd: File,
203 ) -> Result<Option<File>> {
204 self.as_mut()
205 .set_device_state_fd(transfer_direction, migration_phase, fd)
206 }
207
208 fn check_device_state(&mut self) -> Result<()> {
209 self.as_mut().check_device_state()
210 }
211
212 fn get_shmem_config(&mut self) -> Result<Vec<SharedMemoryRegion>> {
213 self.as_mut().get_shmem_config()
214 }
215}
216
217pub struct BackendServer<S: Backend> {
219 connection: Connection,
221 backend: S,
223
224 virtio_features: u64,
225 acked_virtio_features: u64,
226 protocol_features: VhostUserProtocolFeatures,
227 acked_protocol_features: u64,
228
229 reply_ack_enabled: bool,
231}
232
233impl<S: Backend> AsRef<S> for BackendServer<S> {
234 fn as_ref(&self) -> &S {
235 &self.backend
236 }
237}
238
239impl<S: Backend> BackendServer<S> {
240 pub fn new(connection: Connection, backend: S) -> Self {
241 BackendServer {
242 connection,
243 backend,
244 virtio_features: 0,
245 acked_virtio_features: 0,
246 protocol_features: VhostUserProtocolFeatures::empty(),
247 acked_protocol_features: 0,
248 reply_ack_enabled: false,
249 }
250 }
251
252 pub fn recv_header(&mut self) -> Result<(VhostUserMsgHeader, Vec<File>)> {
300 let (hdr, files) = match self.connection.recv_header() {
310 Ok((hdr, files)) => (hdr, files),
311 Err(Error::Disconnect) => {
312 return Err(Error::ClientExit);
315 }
316 Err(e) => {
317 return Err(e);
318 }
319 };
320
321 if !hdr.is_valid() {
322 return Err(Error::InvalidMessage);
323 }
324
325 self.check_attached_files(&hdr, &files)?;
326
327 Ok((hdr, files))
328 }
329
330 pub fn needs_wait_for_payload(&self, hdr: &VhostUserMsgHeader) -> bool {
335 hdr.get_size() != 0
338 }
339
340 pub fn process_message(&mut self, hdr: VhostUserMsgHeader, files: Vec<File>) -> Result<()> {
353 let (buf, extra_files) = self.connection.recv_body_bytes(&hdr)?;
354 let size = buf.len();
355 if !extra_files.is_empty() {
356 return Err(Error::InvalidMessage);
357 }
358
359 match hdr.get_code() {
363 Ok(FrontendReq::SET_OWNER) => {
364 self.check_request_size(&hdr, size, 0)?;
365 let res = self.backend.set_owner();
366 self.send_ack_message(&hdr, res.is_ok())?;
367 res?;
368 }
369 Ok(FrontendReq::RESET_OWNER) => {
370 self.check_request_size(&hdr, size, 0)?;
371 let res = self.backend.reset_owner();
372 self.send_ack_message(&hdr, res.is_ok())?;
373 res?;
374 }
375 Ok(FrontendReq::GET_FEATURES) => {
376 self.check_request_size(&hdr, size, 0)?;
377 let mut features = self.backend.get_features()?;
378
379 features &= !(1 << VIRTIO_F_RING_PACKED);
383
384 let msg = VhostUserU64::new(features);
385 self.send_reply_message(&hdr, &msg)?;
386 self.virtio_features = features;
387 self.update_reply_ack_flag();
388 }
389 Ok(FrontendReq::SET_FEATURES) => {
390 let mut msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?;
391
392 msg.value &= !(1 << VIRTIO_F_RING_PACKED);
396
397 let res = self.backend.set_features(msg.value);
398 self.acked_virtio_features = msg.value;
399 self.update_reply_ack_flag();
400 self.send_ack_message(&hdr, res.is_ok())?;
401 res?;
402 }
403 Ok(FrontendReq::SET_MEM_TABLE) => {
404 let res = self.set_mem_table(&hdr, size, &buf, files);
405 self.send_ack_message(&hdr, res.is_ok())?;
406 res?;
407 }
408 Ok(FrontendReq::SET_VRING_NUM) => {
409 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
410 let res = self.backend.set_vring_num(msg.index, msg.num);
411 self.send_ack_message(&hdr, res.is_ok())?;
412 res?;
413 }
414 Ok(FrontendReq::SET_VRING_ADDR) => {
415 let msg = self.extract_request_body::<VhostUserVringAddr>(&hdr, size, &buf)?;
416 let flags = match VhostUserVringAddrFlags::from_bits(msg.flags) {
417 Some(val) => val,
418 None => return Err(Error::InvalidMessage),
419 };
420 let res = self.backend.set_vring_addr(
421 msg.index,
422 flags,
423 msg.descriptor,
424 msg.used,
425 msg.available,
426 msg.log,
427 );
428 self.send_ack_message(&hdr, res.is_ok())?;
429 res?;
430 }
431 Ok(FrontendReq::SET_VRING_BASE) => {
432 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
433 let res = self.backend.set_vring_base(msg.index, msg.num);
434 self.send_ack_message(&hdr, res.is_ok())?;
435 res?;
436 }
437 Ok(FrontendReq::GET_VRING_BASE) => {
438 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
439 let reply = self.backend.get_vring_base(msg.index)?;
440 self.send_reply_message(&hdr, &reply)?;
441 }
442 Ok(FrontendReq::SET_VRING_CALL) => {
443 self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
444 let (index, file) = self.handle_vring_fd_request(&buf, files)?;
445 let res = self.backend.set_vring_call(index, file);
446 self.send_ack_message(&hdr, res.is_ok())?;
447 res?;
448 }
449 Ok(FrontendReq::SET_VRING_KICK) => {
450 self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
451 let (index, file) = self.handle_vring_fd_request(&buf, files)?;
452 let res = self.backend.set_vring_kick(index, file);
453 self.send_ack_message(&hdr, res.is_ok())?;
454 res?;
455 }
456 Ok(FrontendReq::SET_VRING_ERR) => {
457 self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
458 let (index, file) = self.handle_vring_fd_request(&buf, files)?;
459 let res = self.backend.set_vring_err(index, file);
460 self.send_ack_message(&hdr, res.is_ok())?;
461 res?;
462 }
463 Ok(FrontendReq::GET_PROTOCOL_FEATURES) => {
464 self.check_request_size(&hdr, size, 0)?;
465 let features = self.backend.get_protocol_features()?;
466 let msg = VhostUserU64::new(features.bits());
467 self.send_reply_message(&hdr, &msg)?;
468 self.protocol_features = features;
469 self.update_reply_ack_flag();
470 }
471 Ok(FrontendReq::SET_PROTOCOL_FEATURES) => {
472 let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?;
473 let res = self.backend.set_protocol_features(msg.value);
474 self.acked_protocol_features = msg.value;
475 self.update_reply_ack_flag();
476 self.send_ack_message(&hdr, res.is_ok())?;
477 res?;
478 }
479 Ok(FrontendReq::GET_QUEUE_NUM) => {
480 if self.acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() == 0 {
481 return Err(Error::InvalidOperation);
482 }
483 self.check_request_size(&hdr, size, 0)?;
484 let num = self.backend.get_queue_num()?;
485 let msg = VhostUserU64::new(num);
486 self.send_reply_message(&hdr, &msg)?;
487 }
488 Ok(FrontendReq::SET_VRING_ENABLE) => {
489 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
490 if self.acked_virtio_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES == 0 {
491 return Err(Error::InvalidOperation);
492 }
493 let enable = match msg.num {
494 1 => true,
495 0 => false,
496 _ => {
497 return Err(Error::InvalidParam(
498 "SET_VRING_ENABLE: num out of range (must be [0, 1])",
499 ))
500 }
501 };
502
503 let res = self.backend.set_vring_enable(msg.index, enable);
504 self.send_ack_message(&hdr, res.is_ok())?;
505 res?;
506 }
507 Ok(FrontendReq::GET_CONFIG) => {
508 if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
509 return Err(Error::InvalidOperation);
510 }
511 self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
512 self.get_config(&hdr, &buf)?;
513 }
514 Ok(FrontendReq::SET_CONFIG) => {
515 if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
516 return Err(Error::InvalidOperation);
517 }
518 self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
519 let res = self.set_config(&buf);
520 self.send_ack_message(&hdr, res.is_ok())?;
521 res?;
522 }
523 Ok(FrontendReq::SET_BACKEND_REQ_FD) => {
524 if self.acked_protocol_features & VhostUserProtocolFeatures::BACKEND_REQ.bits() == 0
525 {
526 return Err(Error::InvalidOperation);
527 }
528 self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
529 let res = self.set_backend_req_fd(files);
530 self.send_ack_message(&hdr, res.is_ok())?;
531 res?;
532 }
533 Ok(FrontendReq::GET_INFLIGHT_FD) => {
534 if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits()
535 == 0
536 {
537 return Err(Error::InvalidOperation);
538 }
539
540 let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?;
541 let (inflight, file) = self.backend.get_inflight_fd(&msg)?;
542 let reply_hdr = self.new_reply_header::<VhostUserInflight>(&hdr, 0)?;
543 self.connection.send_message(
544 &reply_hdr,
545 &inflight,
546 Some(&[file.as_raw_descriptor()]),
547 )?;
548 }
549 Ok(FrontendReq::SET_INFLIGHT_FD) => {
550 if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits()
551 == 0
552 {
553 return Err(Error::InvalidOperation);
554 }
555 let file = into_single_file(files).ok_or(Error::IncorrectFds)?;
556 let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?;
557 let res = self.backend.set_inflight_fd(&msg, file);
558 self.send_ack_message(&hdr, res.is_ok())?;
559 res?;
560 }
561 Ok(FrontendReq::GET_MAX_MEM_SLOTS) => {
562 if self.acked_protocol_features
563 & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
564 == 0
565 {
566 return Err(Error::InvalidOperation);
567 }
568 self.check_request_size(&hdr, size, 0)?;
569 let num = self.backend.get_max_mem_slots()?;
570 let msg = VhostUserU64::new(num);
571 self.send_reply_message(&hdr, &msg)?;
572 }
573 Ok(FrontendReq::ADD_MEM_REG) => {
574 if self.acked_protocol_features
575 & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
576 == 0
577 {
578 return Err(Error::InvalidOperation);
579 }
580 let file = into_single_file(files).ok_or(Error::InvalidParam(
581 "ADD_MEM_REG: exactly one file must be provided",
582 ))?;
583 let msg =
584 self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?;
585 let res = self.backend.add_mem_region(&msg, file);
586 self.send_ack_message(&hdr, res.is_ok())?;
587 res?;
588 }
589 Ok(FrontendReq::REM_MEM_REG) => {
590 if self.acked_protocol_features
591 & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
592 == 0
593 {
594 return Err(Error::InvalidOperation);
595 }
596
597 let msg =
598 self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?;
599 let res = self.backend.remove_mem_region(&msg);
600 self.send_ack_message(&hdr, res.is_ok())?;
601 res?;
602 }
603 Ok(FrontendReq::SET_DEVICE_STATE_FD) => {
604 if self.acked_protocol_features & VhostUserProtocolFeatures::DEVICE_STATE.bits()
605 == 0
606 {
607 return Err(Error::InvalidOperation);
608 }
609 let msg =
611 self.extract_request_body::<DeviceStateTransferParameters>(&hdr, size, &buf)?;
612 let transfer_direction = match msg.transfer_direction {
613 0 => VhostUserTransferDirection::Save,
614 1 => VhostUserTransferDirection::Load,
615 _ => return Err(Error::InvalidMessage),
616 };
617 let migration_phase = match msg.migration_phase {
618 0 => VhostUserMigrationPhase::Stopped,
619 _ => return Err(Error::InvalidMessage),
620 };
621 let res = self.backend.set_device_state_fd(
623 transfer_direction,
624 migration_phase,
625 files.into_iter().next().ok_or(Error::IncorrectFds)?,
626 );
627 let (msg, fds) = match &res {
629 Ok(None) => (VhostUserU64::new(0x100), None),
630 Ok(Some(file)) => (VhostUserU64::new(0), Some(file.as_raw_descriptor())),
631 Err(_) => (VhostUserU64::new(0x101), None),
633 };
634 let reply_hdr: VhostUserMsgHeader =
635 self.new_reply_header::<VhostUserU64>(&hdr, 0)?;
636 self.connection.send_message(
637 &reply_hdr,
638 &msg,
639 fds.as_ref().map(std::slice::from_ref),
640 )?;
641 res?;
642 }
643 Ok(FrontendReq::CHECK_DEVICE_STATE) => {
644 if self.acked_protocol_features & VhostUserProtocolFeatures::DEVICE_STATE.bits()
645 == 0
646 {
647 return Err(Error::InvalidOperation);
648 }
649 let res = self.backend.check_device_state();
650 let msg = VhostUserU64::new(if res.is_ok() { 0 } else { 1 });
651 self.send_reply_message(&hdr, &msg)?;
652 res?;
653 }
654 Ok(FrontendReq::GET_SHMEM_CONFIG) => {
655 let msg = VhostUserShMemConfig::new(&self.backend.get_shmem_config()?);
656 self.send_reply_message(&hdr, &msg)?;
657 }
658 _ => {
659 return Err(Error::InvalidMessage);
660 }
661 }
662 Ok(())
663 }
664
665 fn new_reply_header<T: Sized>(
666 &self,
667 req: &VhostUserMsgHeader,
668 payload_size: usize,
669 ) -> Result<VhostUserMsgHeader> {
670 Ok(VhostUserMsgHeader::new_reply_header(
671 req.get_code::<FrontendReq>()
672 .map_err(|_| Error::InvalidMessage)?,
673 (mem::size_of::<T>()
674 .checked_add(payload_size)
675 .ok_or(Error::OversizedMsg)?)
676 .try_into()
677 .map_err(Error::InvalidCastToInt)?,
678 ))
679 }
680
681 fn send_ack_message(&mut self, req: &VhostUserMsgHeader, success: bool) -> Result<()> {
683 if self.reply_ack_enabled && req.is_need_reply() {
684 let hdr: VhostUserMsgHeader = self.new_reply_header::<VhostUserU64>(req, 0)?;
685 let val = if success { 0 } else { 1 };
686 let msg = VhostUserU64::new(val);
687 self.connection.send_message(&hdr, &msg, None)?;
688 }
689 Ok(())
690 }
691
692 fn send_reply_message<T: IntoBytes + Immutable>(
693 &mut self,
694 req: &VhostUserMsgHeader,
695 msg: &T,
696 ) -> Result<()> {
697 let hdr = self.new_reply_header::<T>(req, 0)?;
698 self.connection.send_message(&hdr, msg, None)?;
699 Ok(())
700 }
701
702 fn send_reply_with_payload<T: IntoBytes + Immutable>(
703 &mut self,
704 req: &VhostUserMsgHeader,
705 msg: &T,
706 payload: &[u8],
707 ) -> Result<()> {
708 let hdr = self.new_reply_header::<T>(req, payload.len())?;
709 self.connection
710 .send_message_with_payload(&hdr, msg, payload, None)?;
711 Ok(())
712 }
713
714 fn set_mem_table(
715 &mut self,
716 hdr: &VhostUserMsgHeader,
717 size: usize,
718 buf: &[u8],
719 files: Vec<File>,
720 ) -> Result<()> {
721 self.check_request_size(hdr, size, hdr.get_size() as usize)?;
722
723 let (msg, regions) =
724 Ref::<_, VhostUserMemory>::from_prefix(buf).map_err(|_| Error::InvalidMessage)?;
725 if !msg.is_valid() {
726 return Err(Error::InvalidMessage);
727 }
728
729 if files.len() != msg.num_regions as usize {
731 return Err(Error::InvalidMessage);
732 }
733
734 let (regions, excess) = Ref::<_, [VhostUserMemoryRegion]>::from_prefix_with_elems(
735 regions,
736 msg.num_regions as usize,
737 )
738 .map_err(|_| Error::InvalidMessage)?;
739 if !excess.is_empty() {
740 return Err(Error::InvalidMessage);
741 }
742
743 for region in regions.iter() {
745 if !region.is_valid() {
746 return Err(Error::InvalidMessage);
747 }
748 }
749
750 self.backend.set_mem_table(®ions, files)
751 }
752
753 fn get_config(&mut self, hdr: &VhostUserMsgHeader, buf: &[u8]) -> Result<()> {
754 let (msg, payload) =
755 Ref::<_, VhostUserConfig>::from_prefix(buf).map_err(|_| Error::InvalidMessage)?;
756 if !msg.is_valid() {
757 return Err(Error::InvalidMessage);
758 }
759 if payload.len() != msg.size as usize {
760 return Err(Error::InvalidMessage);
761 }
762 let flags = match VhostUserConfigFlags::from_bits(msg.flags) {
763 Some(val) => val,
764 None => return Err(Error::InvalidMessage),
765 };
766 let res = self.backend.get_config(msg.offset, msg.size, flags);
767
768 match res {
771 Ok(ref buf) if buf.len() == msg.size as usize => {
772 let reply = VhostUserConfig::new(msg.offset, buf.len() as u32, flags);
773 self.send_reply_with_payload(hdr, &reply, buf.as_slice())?;
774 }
775 Ok(_) => {
776 let reply = VhostUserConfig::new(msg.offset, 0, flags);
777 self.send_reply_message(hdr, &reply)?;
778 }
779 Err(_) => {
780 let reply = VhostUserConfig::new(msg.offset, 0, flags);
781 self.send_reply_message(hdr, &reply)?;
782 }
783 }
784 Ok(())
785 }
786
787 fn set_config(&mut self, buf: &[u8]) -> Result<()> {
788 let (msg, payload) =
789 Ref::<_, VhostUserConfig>::from_prefix(buf).map_err(|_| Error::InvalidMessage)?;
790 if !msg.is_valid() {
791 return Err(Error::InvalidMessage);
792 }
793 if payload.len() != msg.size as usize {
794 return Err(Error::InvalidMessage);
795 }
796 let flags: VhostUserConfigFlags = match VhostUserConfigFlags::from_bits(msg.flags) {
797 Some(val) => val,
798 None => return Err(Error::InvalidMessage),
799 };
800
801 self.backend.set_config(msg.offset, payload, flags)
802 }
803
804 fn set_backend_req_fd(&mut self, files: Vec<File>) -> Result<()> {
805 let file = into_single_file(files).ok_or(Error::InvalidMessage)?;
806 let fd: SafeDescriptor = file.into();
807 let connection = Connection::try_from(fd).map_err(|_| Error::InvalidMessage)?;
808 self.backend.set_backend_req_fd(connection);
809 Ok(())
810 }
811
812 fn handle_vring_fd_request(
815 &mut self,
816 buf: &[u8],
817 files: Vec<File>,
818 ) -> Result<(u8, Option<File>)> {
819 let (msg, _) = VhostUserU64::read_from_prefix(buf).map_err(|_| Error::InvalidMessage)?;
820 if !msg.is_valid() {
821 return Err(Error::InvalidMessage);
822 }
823
824 let has_fd = (msg.value & 0x100u64) == 0;
831
832 let file = into_single_file(files);
833
834 if has_fd && file.is_none() || !has_fd && file.is_some() {
835 return Err(Error::InvalidMessage);
836 }
837
838 Ok((msg.value as u8, file))
839 }
840
841 fn check_request_size(
842 &self,
843 hdr: &VhostUserMsgHeader,
844 size: usize,
845 expected: usize,
846 ) -> Result<()> {
847 if hdr.get_size() as usize != expected
848 || hdr.is_reply()
849 || hdr.get_version() != 0x1
850 || size != expected
851 {
852 return Err(Error::InvalidMessage);
853 }
854 Ok(())
855 }
856
857 fn check_attached_files(&self, hdr: &VhostUserMsgHeader, files: &[File]) -> Result<()> {
858 match hdr.get_code() {
859 Ok(FrontendReq::SET_MEM_TABLE)
860 | Ok(FrontendReq::SET_VRING_CALL)
861 | Ok(FrontendReq::SET_VRING_KICK)
862 | Ok(FrontendReq::SET_VRING_ERR)
863 | Ok(FrontendReq::SET_LOG_BASE)
864 | Ok(FrontendReq::SET_LOG_FD)
865 | Ok(FrontendReq::SET_BACKEND_REQ_FD)
866 | Ok(FrontendReq::SET_INFLIGHT_FD)
867 | Ok(FrontendReq::ADD_MEM_REG)
868 | Ok(FrontendReq::SET_DEVICE_STATE_FD) => Ok(()),
869 Err(_) => Err(Error::InvalidMessage),
870 _ if !files.is_empty() => Err(Error::InvalidMessage),
871 _ => Ok(()),
872 }
873 }
874
875 fn extract_request_body<T: Sized + FromBytes + VhostUserMsgValidator>(
876 &self,
877 hdr: &VhostUserMsgHeader,
878 size: usize,
879 buf: &[u8],
880 ) -> Result<T> {
881 self.check_request_size(hdr, size, mem::size_of::<T>())?;
882 let (body, _) = T::read_from_prefix(buf).map_err(|_| Error::InvalidMessage)?;
883 if body.is_valid() {
884 Ok(body)
885 } else {
886 Err(Error::InvalidMessage)
887 }
888 }
889
890 fn update_reply_ack_flag(&mut self) {
891 let pflag = VhostUserProtocolFeatures::REPLY_ACK;
892 self.reply_ack_enabled = (self.virtio_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES) != 0
893 && self.protocol_features.contains(pflag)
894 && (self.acked_protocol_features & pflag.bits()) != 0;
895 }
896}
897
898impl<S: Backend> AsRawDescriptor for BackendServer<S> {
899 fn as_raw_descriptor(&self) -> RawDescriptor {
900 self.connection.as_raw_descriptor()
902 }
903}
904
#[cfg(test)]
mod tests {
    use base::INVALID_DESCRIPTOR;

    use super::*;
    use crate::test_backend::TestBackend;
    use crate::Connection;

    /// A freshly built server exposes the (valid) descriptor of its connection.
    #[test]
    fn test_backend_server_new() {
        let (server_end, _frontend_end) = Connection::pair().unwrap();
        let server = BackendServer::new(server_end, TestBackend::new());

        assert_ne!(server.as_raw_descriptor(), INVALID_DESCRIPTOR);
    }
}