1use std::fs::File;
5use std::mem;
6
7use base::AsRawDescriptor;
8use base::RawDescriptor;
9use base::SafeDescriptor;
10use zerocopy::FromBytes;
11use zerocopy::Immutable;
12use zerocopy::IntoBytes;
13use zerocopy::Ref;
14
15use crate::into_single_file;
16use crate::message::*;
17use crate::BackendReq;
18use crate::Connection;
19use crate::Error;
20use crate::FrontendReq;
21use crate::Result;
22
/// Request handlers invoked by [`BackendServer`] while it processes frontend messages.
///
/// Each method is called from `BackendServer::process_message` (or one of its helpers) after
/// the corresponding `FrontendReq` message has been received and validated.
#[allow(missing_docs)]
pub trait Backend {
    /// Handles a `SET_OWNER` request.
    fn set_owner(&mut self) -> Result<()>;
    /// Handles a `RESET_OWNER` request.
    fn reset_owner(&mut self) -> Result<()>;
    /// Returns the feature bits for a `GET_FEATURES` request.
    ///
    /// The server clears the `VIRTIO_F_RING_PACKED` bit from the returned value before replying.
    fn get_features(&mut self) -> Result<u64>;
    /// Handles a `SET_FEATURES` request.
    ///
    /// `features` is the value sent by the frontend with the `VIRTIO_F_RING_PACKED` bit cleared
    /// by the server.
    fn set_features(&mut self, features: u64) -> Result<()>;
    /// Handles a `SET_MEM_TABLE` request.
    ///
    /// The server only calls this when `files` has as many entries as `ctx` and every region in
    /// `ctx` passed its `is_valid()` check.
    fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>;
    /// Handles a `SET_VRING_NUM` request; `index` and `num` come from the request body.
    fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()>;
    /// Handles a `SET_VRING_ADDR` request.
    ///
    /// `flags` has already been parsed from the raw bits of the request; requests with bits that
    /// do not map to a `VhostUserVringAddrFlags` value are rejected before this is called.
    fn set_vring_addr(
        &mut self,
        index: u32,
        flags: VhostUserVringAddrFlags,
        descriptor: u64,
        used: u64,
        available: u64,
        log: u64,
    ) -> Result<()>;
    /// Handles a `SET_VRING_BASE` request; `base` is the `num` field of the request body.
    fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()>;
    /// Handles a `GET_VRING_BASE` request; the returned state is sent back as the reply.
    fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState>;
    /// Handles a `SET_VRING_KICK` request.
    ///
    /// `index` is the low 8 bits of the request value. `fd` is `None` when the request set
    /// bit 8 (`0x100`) and carried no file.
    fn set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()>;
    /// Handles a `SET_VRING_CALL` request; arguments are decoded as for `set_vring_kick`.
    fn set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()>;
    /// Handles a `SET_VRING_ERR` request; arguments are decoded as for `set_vring_kick`.
    fn set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()>;

    /// Returns the protocol features for a `GET_PROTOCOL_FEATURES` request.
    fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>;
    /// Handles a `SET_PROTOCOL_FEATURES` request; `features` is the raw value from the frontend.
    fn set_protocol_features(&mut self, features: u64) -> Result<()>;
    /// Returns the value sent in the `GET_QUEUE_NUM` reply.
    ///
    /// Only called when the `MQ` protocol feature has been acked.
    fn get_queue_num(&mut self) -> Result<u64>;
    /// Handles a `SET_VRING_ENABLE` request.
    ///
    /// `enable` is `true` for a request value of 1 and `false` for 0; other values are rejected
    /// by the server.
    fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()>;
    /// Handles a `GET_CONFIG` request.
    ///
    /// The returned buffer is sent back only if its length equals `size`; for any other length,
    /// or on error, the server replies with a zero-sized config.
    fn get_config(
        &mut self,
        offset: u32,
        size: u32,
        flags: VhostUserConfigFlags,
    ) -> Result<Vec<u8>>;
    /// Handles a `SET_CONFIG` request; `buf` is the payload that followed the config header.
    fn set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>;
    /// Receives the connection built from the file attached to a `SET_BACKEND_REQ_FD` request.
    ///
    /// The default implementation drops the connection.
    fn set_backend_req_fd(&mut self, _vu_req: Connection<BackendReq>) {}
    /// Handles a `GET_INFLIGHT_FD` request; the returned struct and file are sent as the reply.
    fn get_inflight_fd(
        &mut self,
        inflight: &VhostUserInflight,
    ) -> Result<(VhostUserInflight, File)>;
    /// Handles a `SET_INFLIGHT_FD` request; `file` is the single file attached to it.
    fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()>;
    /// Returns the value sent in the `GET_MAX_MEM_SLOTS` reply.
    fn get_max_mem_slots(&mut self) -> Result<u64>;
    /// Handles an `ADD_MEM_REG` request; `fd` is the single file attached to it.
    fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>;
    /// Handles a `REM_MEM_REG` request.
    fn remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()>;
    /// Handles a `SET_DEVICE_STATE_FD` request.
    ///
    /// When `Ok(Some(file))` is returned the server attaches that file to its reply; for
    /// `Ok(None)` and `Err(_)` no file is attached.
    fn set_device_state_fd(
        &mut self,
        transfer_direction: VhostUserTransferDirection,
        migration_phase: VhostUserMigrationPhase,
        fd: File,
    ) -> Result<Option<File>>;
    /// Handles a `CHECK_DEVICE_STATE` request; the server replies 0 on `Ok` and 1 on `Err`.
    fn check_device_state(&mut self) -> Result<()>;
    /// Returns the regions serialized into the `GET_SHARED_MEMORY_REGIONS` reply.
    fn get_shared_memory_regions(&mut self) -> Result<Vec<VhostSharedMemoryRegion>>;
}
80
81impl<T> Backend for T
82where
83 T: AsMut<dyn Backend>,
84{
85 fn set_owner(&mut self) -> Result<()> {
86 self.as_mut().set_owner()
87 }
88
89 fn reset_owner(&mut self) -> Result<()> {
90 self.as_mut().reset_owner()
91 }
92
93 fn get_features(&mut self) -> Result<u64> {
94 self.as_mut().get_features()
95 }
96
97 fn set_features(&mut self, features: u64) -> Result<()> {
98 self.as_mut().set_features(features)
99 }
100
101 fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()> {
102 self.as_mut().set_mem_table(ctx, files)
103 }
104
105 fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()> {
106 self.as_mut().set_vring_num(index, num)
107 }
108
109 fn set_vring_addr(
110 &mut self,
111 index: u32,
112 flags: VhostUserVringAddrFlags,
113 descriptor: u64,
114 used: u64,
115 available: u64,
116 log: u64,
117 ) -> Result<()> {
118 self.as_mut()
119 .set_vring_addr(index, flags, descriptor, used, available, log)
120 }
121
122 fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()> {
123 self.as_mut().set_vring_base(index, base)
124 }
125
126 fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState> {
127 self.as_mut().get_vring_base(index)
128 }
129
130 fn set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()> {
131 self.as_mut().set_vring_kick(index, fd)
132 }
133
134 fn set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()> {
135 self.as_mut().set_vring_call(index, fd)
136 }
137
138 fn set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()> {
139 self.as_mut().set_vring_err(index, fd)
140 }
141
142 fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures> {
143 self.as_mut().get_protocol_features()
144 }
145
146 fn set_protocol_features(&mut self, features: u64) -> Result<()> {
147 self.as_mut().set_protocol_features(features)
148 }
149
150 fn get_queue_num(&mut self) -> Result<u64> {
151 self.as_mut().get_queue_num()
152 }
153
154 fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()> {
155 self.as_mut().set_vring_enable(index, enable)
156 }
157
158 fn get_config(
159 &mut self,
160 offset: u32,
161 size: u32,
162 flags: VhostUserConfigFlags,
163 ) -> Result<Vec<u8>> {
164 self.as_mut().get_config(offset, size, flags)
165 }
166
167 fn set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()> {
168 self.as_mut().set_config(offset, buf, flags)
169 }
170
171 fn set_backend_req_fd(&mut self, vu_req: Connection<BackendReq>) {
172 self.as_mut().set_backend_req_fd(vu_req)
173 }
174
175 fn get_inflight_fd(
176 &mut self,
177 inflight: &VhostUserInflight,
178 ) -> Result<(VhostUserInflight, File)> {
179 self.as_mut().get_inflight_fd(inflight)
180 }
181
182 fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()> {
183 self.as_mut().set_inflight_fd(inflight, file)
184 }
185
186 fn get_max_mem_slots(&mut self) -> Result<u64> {
187 self.as_mut().get_max_mem_slots()
188 }
189
190 fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()> {
191 self.as_mut().add_mem_region(region, fd)
192 }
193
194 fn remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()> {
195 self.as_mut().remove_mem_region(region)
196 }
197
198 fn set_device_state_fd(
199 &mut self,
200 transfer_direction: VhostUserTransferDirection,
201 migration_phase: VhostUserMigrationPhase,
202 fd: File,
203 ) -> Result<Option<File>> {
204 self.as_mut()
205 .set_device_state_fd(transfer_direction, migration_phase, fd)
206 }
207
208 fn check_device_state(&mut self) -> Result<()> {
209 self.as_mut().check_device_state()
210 }
211
212 fn get_shared_memory_regions(&mut self) -> Result<Vec<VhostSharedMemoryRegion>> {
213 self.as_mut().get_shared_memory_regions()
214 }
215}
216
/// Receives vhost-user frontend requests over a connection and dispatches them to a
/// [`Backend`] implementation.
pub struct BackendServer<S: Backend> {
    /// Channel used to receive requests from the frontend and to send replies back.
    connection: Connection<FrontendReq>,
    /// Handler that decoded requests are forwarded to.
    backend: S,

    /// Feature bits most recently sent in a `GET_FEATURES` reply.
    virtio_features: u64,
    /// Feature bits most recently received in a `SET_FEATURES` request.
    acked_virtio_features: u64,
    /// Protocol features most recently sent in a `GET_PROTOCOL_FEATURES` reply.
    protocol_features: VhostUserProtocolFeatures,
    /// Raw protocol feature bits most recently received in a `SET_PROTOCOL_FEATURES` request.
    acked_protocol_features: u64,

    /// Whether ack replies are sent for requests that ask for one; recomputed by
    /// `update_reply_ack_flag` whenever one of the feature fields changes.
    reply_ack_enabled: bool,
}
232
233impl<S: Backend> AsRef<S> for BackendServer<S> {
234 fn as_ref(&self) -> &S {
235 &self.backend
236 }
237}
238
239impl<S: Backend> BackendServer<S> {
240 pub fn new(connection: Connection<FrontendReq>, backend: S) -> Self {
241 BackendServer {
242 connection,
243 backend,
244 virtio_features: 0,
245 acked_virtio_features: 0,
246 protocol_features: VhostUserProtocolFeatures::empty(),
247 acked_protocol_features: 0,
248 reply_ack_enabled: false,
249 }
250 }
251
252 pub fn recv_header(&mut self) -> Result<(VhostUserMsgHeader<FrontendReq>, Vec<File>)> {
300 let (hdr, files) = match self.connection.recv_header() {
310 Ok((hdr, files)) => (hdr, files),
311 Err(Error::Disconnect) => {
312 return Err(Error::ClientExit);
315 }
316 Err(e) => {
317 return Err(e);
318 }
319 };
320
321 if !hdr.is_valid() {
322 return Err(Error::InvalidMessage);
323 }
324
325 self.check_attached_files(&hdr, &files)?;
326
327 Ok((hdr, files))
328 }
329
330 pub fn needs_wait_for_payload(&self, hdr: &VhostUserMsgHeader<FrontendReq>) -> bool {
335 hdr.get_size() != 0
338 }
339
    /// Receives the body of the request described by `hdr`, validates it and dispatches it to
    /// the backend, sending a reply or ack where the request type calls for one.
    ///
    /// `files` are the files that were attached to the header (as returned by `recv_header`).
    ///
    /// # Errors
    ///
    /// * `Error::InvalidMessage` for malformed bodies, unexpected files received with the body,
    ///   or request codes that are not handled here.
    /// * `Error::InvalidOperation` when a request depends on a feature that has not been acked.
    /// * Errors from the connection or from the backend are propagated.
    pub fn process_message(
        &mut self,
        hdr: VhostUserMsgHeader<FrontendReq>,
        files: Vec<File>,
    ) -> Result<()> {
        let (buf, extra_files) = self.connection.recv_body_bytes(&hdr)?;
        let size = buf.len();
        // Files are only accepted alongside the header, never with the body.
        if !extra_files.is_empty() {
            return Err(Error::InvalidMessage);
        }

        // For the "set"-style requests below, the ack (if enabled and requested) is sent before
        // the backend's error is propagated with `res?`.
        match hdr.get_code() {
            Ok(FrontendReq::SET_OWNER) => {
                self.check_request_size(&hdr, size, 0)?;
                let res = self.backend.set_owner();
                self.send_ack_message(&hdr, res.is_ok())?;
                res?;
            }
            Ok(FrontendReq::RESET_OWNER) => {
                self.check_request_size(&hdr, size, 0)?;
                let res = self.backend.reset_owner();
                self.send_ack_message(&hdr, res.is_ok())?;
                res?;
            }
            Ok(FrontendReq::GET_FEATURES) => {
                self.check_request_size(&hdr, size, 0)?;
                let mut features = self.backend.get_features()?;

                // The packed-ring feature bit is never advertised, whatever the backend reports.
                features &= !(1 << VIRTIO_F_RING_PACKED);

                let msg = VhostUserU64::new(features);
                self.send_reply_message(&hdr, &msg)?;
                self.virtio_features = features;
                self.update_reply_ack_flag();
            }
            Ok(FrontendReq::SET_FEATURES) => {
                let mut msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?;

                // The packed-ring feature bit is stripped even if the frontend tries to set it.
                msg.value &= !(1 << VIRTIO_F_RING_PACKED);

                let res = self.backend.set_features(msg.value);
                // The acked value is recorded regardless of whether the backend accepted it.
                self.acked_virtio_features = msg.value;
                self.update_reply_ack_flag();
                self.send_ack_message(&hdr, res.is_ok())?;
                res?;
            }
            Ok(FrontendReq::SET_MEM_TABLE) => {
                let res = self.set_mem_table(&hdr, size, &buf, files);
                self.send_ack_message(&hdr, res.is_ok())?;
                res?;
            }
            Ok(FrontendReq::SET_VRING_NUM) => {
                let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
                let res = self.backend.set_vring_num(msg.index, msg.num);
                self.send_ack_message(&hdr, res.is_ok())?;
                res?;
            }
            Ok(FrontendReq::SET_VRING_ADDR) => {
                let msg = self.extract_request_body::<VhostUserVringAddr>(&hdr, size, &buf)?;
                // Reject flag bits that do not map to a known `VhostUserVringAddrFlags` value.
                let flags = match VhostUserVringAddrFlags::from_bits(msg.flags) {
                    Some(val) => val,
                    None => return Err(Error::InvalidMessage),
                };
                let res = self.backend.set_vring_addr(
                    msg.index,
                    flags,
                    msg.descriptor,
                    msg.used,
                    msg.available,
                    msg.log,
                );
                self.send_ack_message(&hdr, res.is_ok())?;
                res?;
            }
            Ok(FrontendReq::SET_VRING_BASE) => {
                let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
                let res = self.backend.set_vring_base(msg.index, msg.num);
                self.send_ack_message(&hdr, res.is_ok())?;
                res?;
            }
            Ok(FrontendReq::GET_VRING_BASE) => {
                let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
                let reply = self.backend.get_vring_base(msg.index)?;
                self.send_reply_message(&hdr, &reply)?;
            }
            Ok(FrontendReq::SET_VRING_CALL) => {
                self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
                let (index, file) = self.handle_vring_fd_request(&buf, files)?;
                let res = self.backend.set_vring_call(index, file);
                self.send_ack_message(&hdr, res.is_ok())?;
                res?;
            }
            Ok(FrontendReq::SET_VRING_KICK) => {
                self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
                let (index, file) = self.handle_vring_fd_request(&buf, files)?;
                let res = self.backend.set_vring_kick(index, file);
                self.send_ack_message(&hdr, res.is_ok())?;
                res?;
            }
            Ok(FrontendReq::SET_VRING_ERR) => {
                self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
                let (index, file) = self.handle_vring_fd_request(&buf, files)?;
                let res = self.backend.set_vring_err(index, file);
                self.send_ack_message(&hdr, res.is_ok())?;
                res?;
            }
            Ok(FrontendReq::GET_PROTOCOL_FEATURES) => {
                self.check_request_size(&hdr, size, 0)?;
                let features = self.backend.get_protocol_features()?;
                let msg = VhostUserU64::new(features.bits());
                self.send_reply_message(&hdr, &msg)?;
                self.protocol_features = features;
                self.update_reply_ack_flag();
            }
            Ok(FrontendReq::SET_PROTOCOL_FEATURES) => {
                let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?;
                let res = self.backend.set_protocol_features(msg.value);
                // The acked value is recorded regardless of whether the backend accepted it.
                self.acked_protocol_features = msg.value;
                self.update_reply_ack_flag();
                self.send_ack_message(&hdr, res.is_ok())?;
                res?;
            }
            Ok(FrontendReq::GET_QUEUE_NUM) => {
                // Requires the MQ protocol feature to have been acked.
                if self.acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() == 0 {
                    return Err(Error::InvalidOperation);
                }
                self.check_request_size(&hdr, size, 0)?;
                let num = self.backend.get_queue_num()?;
                let msg = VhostUserU64::new(num);
                self.send_reply_message(&hdr, &msg)?;
            }
            Ok(FrontendReq::SET_VRING_ENABLE) => {
                let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
                // Requires the VHOST_USER_F_PROTOCOL_FEATURES virtio feature bit to be acked.
                if self.acked_virtio_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES == 0 {
                    return Err(Error::InvalidOperation);
                }
                let enable = match msg.num {
                    1 => true,
                    0 => false,
                    _ => {
                        return Err(Error::InvalidParam(
                            "SET_VRING_ENABLE: num out of range (must be [0, 1])",
                        ))
                    }
                };

                let res = self.backend.set_vring_enable(msg.index, enable);
                self.send_ack_message(&hdr, res.is_ok())?;
                res?;
            }
            Ok(FrontendReq::GET_CONFIG) => {
                if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
                    return Err(Error::InvalidOperation);
                }
                // Variable-sized request: only check that the received body matches the size
                // announced in the header.
                self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
                self.get_config(&hdr, &buf)?;
            }
            Ok(FrontendReq::SET_CONFIG) => {
                if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
                    return Err(Error::InvalidOperation);
                }
                // Variable-sized request: only check that the received body matches the size
                // announced in the header.
                self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
                let res = self.set_config(&buf);
                self.send_ack_message(&hdr, res.is_ok())?;
                res?;
            }
            Ok(FrontendReq::SET_BACKEND_REQ_FD) => {
                if self.acked_protocol_features & VhostUserProtocolFeatures::BACKEND_REQ.bits() == 0
                {
                    return Err(Error::InvalidOperation);
                }
                self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
                let res = self.set_backend_req_fd(files);
                self.send_ack_message(&hdr, res.is_ok())?;
                res?;
            }
            Ok(FrontendReq::GET_INFLIGHT_FD) => {
                if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits()
                    == 0
                {
                    return Err(Error::InvalidOperation);
                }

                let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?;
                let (inflight, file) = self.backend.get_inflight_fd(&msg)?;
                let reply_hdr = self.new_reply_header::<VhostUserInflight>(&hdr, 0)?;
                // The reply carries the backend's file as an attached descriptor.
                self.connection.send_message(
                    &reply_hdr,
                    &inflight,
                    Some(&[file.as_raw_descriptor()]),
                )?;
            }
            Ok(FrontendReq::SET_INFLIGHT_FD) => {
                if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits()
                    == 0
                {
                    return Err(Error::InvalidOperation);
                }
                let file = into_single_file(files).ok_or(Error::IncorrectFds)?;
                let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?;
                let res = self.backend.set_inflight_fd(&msg, file);
                self.send_ack_message(&hdr, res.is_ok())?;
                res?;
            }
            Ok(FrontendReq::GET_MAX_MEM_SLOTS) => {
                if self.acked_protocol_features
                    & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
                    == 0
                {
                    return Err(Error::InvalidOperation);
                }
                self.check_request_size(&hdr, size, 0)?;
                let num = self.backend.get_max_mem_slots()?;
                let msg = VhostUserU64::new(num);
                self.send_reply_message(&hdr, &msg)?;
            }
            Ok(FrontendReq::ADD_MEM_REG) => {
                if self.acked_protocol_features
                    & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
                    == 0
                {
                    return Err(Error::InvalidOperation);
                }
                let file = into_single_file(files).ok_or(Error::InvalidParam(
                    "ADD_MEM_REG: exactly one file must be provided",
                ))?;
                let msg =
                    self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?;
                let res = self.backend.add_mem_region(&msg, file);
                self.send_ack_message(&hdr, res.is_ok())?;
                res?;
            }
            Ok(FrontendReq::REM_MEM_REG) => {
                if self.acked_protocol_features
                    & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
                    == 0
                {
                    return Err(Error::InvalidOperation);
                }

                let msg =
                    self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?;
                let res = self.backend.remove_mem_region(&msg);
                self.send_ack_message(&hdr, res.is_ok())?;
                res?;
            }
            Ok(FrontendReq::SET_DEVICE_STATE_FD) => {
                if self.acked_protocol_features & VhostUserProtocolFeatures::DEVICE_STATE.bits()
                    == 0
                {
                    return Err(Error::InvalidOperation);
                }
                // Decode the request.
                let msg =
                    self.extract_request_body::<DeviceStateTransferParameters>(&hdr, size, &buf)?;
                let transfer_direction = match msg.transfer_direction {
                    0 => VhostUserTransferDirection::Save,
                    1 => VhostUserTransferDirection::Load,
                    _ => return Err(Error::InvalidMessage),
                };
                let migration_phase = match msg.migration_phase {
                    0 => VhostUserMigrationPhase::Stopped,
                    _ => return Err(Error::InvalidMessage),
                };
                // NOTE(review): only the first attached file is used and any further files are
                // dropped, whereas SET_INFLIGHT_FD / ADD_MEM_REG require exactly one file via
                // `into_single_file`. Confirm whether the laxer check is intentional.
                let res = self.backend.set_device_state_fd(
                    transfer_direction,
                    migration_phase,
                    files.into_iter().next().ok_or(Error::IncorrectFds)?,
                );
                // Build the reply. Bit 8 (0x100) is set whenever no file is attached to the
                // reply; on backend error the low bits additionally carry 1.
                let (msg, fds) = match &res {
                    Ok(None) => (VhostUserU64::new(0x100), None),
                    Ok(Some(file)) => (VhostUserU64::new(0), Some(file.as_raw_descriptor())),
                    Err(_) => (VhostUserU64::new(0x101), None),
                };
                let reply_hdr: VhostUserMsgHeader<FrontendReq> =
                    self.new_reply_header::<VhostUserU64>(&hdr, 0)?;
                self.connection.send_message(
                    &reply_hdr,
                    &msg,
                    fds.as_ref().map(std::slice::from_ref),
                )?;
                res?;
            }
            Ok(FrontendReq::CHECK_DEVICE_STATE) => {
                if self.acked_protocol_features & VhostUserProtocolFeatures::DEVICE_STATE.bits()
                    == 0
                {
                    return Err(Error::InvalidOperation);
                }
                let res = self.backend.check_device_state();
                // A reply is always sent: 0 on success, 1 on failure.
                let msg = VhostUserU64::new(if res.is_ok() { 0 } else { 1 });
                self.send_reply_message(&hdr, &msg)?;
                res?;
            }
            Ok(FrontendReq::GET_SHARED_MEMORY_REGIONS) => {
                let regions = self.backend.get_shared_memory_regions()?;
                // The reply is the region count followed by the raw bytes of each region.
                let mut buf = Vec::new();
                let msg = VhostUserU64::new(regions.len() as u64);
                for r in regions {
                    buf.extend_from_slice(r.as_bytes())
                }
                self.send_reply_with_payload(&hdr, &msg, buf.as_slice())?;
            }
            _ => {
                return Err(Error::InvalidMessage);
            }
        }
        Ok(())
    }
673
674 fn new_reply_header<T: Sized>(
675 &self,
676 req: &VhostUserMsgHeader<FrontendReq>,
677 payload_size: usize,
678 ) -> Result<VhostUserMsgHeader<FrontendReq>> {
679 Ok(VhostUserMsgHeader::new(
680 req.get_code().map_err(|_| Error::InvalidMessage)?,
681 VhostUserHeaderFlag::REPLY.bits(),
682 (mem::size_of::<T>()
683 .checked_add(payload_size)
684 .ok_or(Error::OversizedMsg)?)
685 .try_into()
686 .map_err(Error::InvalidCastToInt)?,
687 ))
688 }
689
690 fn send_ack_message(
692 &mut self,
693 req: &VhostUserMsgHeader<FrontendReq>,
694 success: bool,
695 ) -> Result<()> {
696 if self.reply_ack_enabled && req.is_need_reply() {
697 let hdr: VhostUserMsgHeader<FrontendReq> =
698 self.new_reply_header::<VhostUserU64>(req, 0)?;
699 let val = if success { 0 } else { 1 };
700 let msg = VhostUserU64::new(val);
701 self.connection.send_message(&hdr, &msg, None)?;
702 }
703 Ok(())
704 }
705
706 fn send_reply_message<T: IntoBytes + Immutable>(
707 &mut self,
708 req: &VhostUserMsgHeader<FrontendReq>,
709 msg: &T,
710 ) -> Result<()> {
711 let hdr = self.new_reply_header::<T>(req, 0)?;
712 self.connection.send_message(&hdr, msg, None)?;
713 Ok(())
714 }
715
716 fn send_reply_with_payload<T: IntoBytes + Immutable>(
717 &mut self,
718 req: &VhostUserMsgHeader<FrontendReq>,
719 msg: &T,
720 payload: &[u8],
721 ) -> Result<()> {
722 let hdr = self.new_reply_header::<T>(req, payload.len())?;
723 self.connection
724 .send_message_with_payload(&hdr, msg, payload, None)?;
725 Ok(())
726 }
727
728 fn set_mem_table(
729 &mut self,
730 hdr: &VhostUserMsgHeader<FrontendReq>,
731 size: usize,
732 buf: &[u8],
733 files: Vec<File>,
734 ) -> Result<()> {
735 self.check_request_size(hdr, size, hdr.get_size() as usize)?;
736
737 let (msg, regions) =
738 Ref::<_, VhostUserMemory>::from_prefix(buf).map_err(|_| Error::InvalidMessage)?;
739 if !msg.is_valid() {
740 return Err(Error::InvalidMessage);
741 }
742
743 if files.len() != msg.num_regions as usize {
745 return Err(Error::InvalidMessage);
746 }
747
748 let (regions, excess) = Ref::<_, [VhostUserMemoryRegion]>::from_prefix_with_elems(
749 regions,
750 msg.num_regions as usize,
751 )
752 .map_err(|_| Error::InvalidMessage)?;
753 if !excess.is_empty() {
754 return Err(Error::InvalidMessage);
755 }
756
757 for region in regions.iter() {
759 if !region.is_valid() {
760 return Err(Error::InvalidMessage);
761 }
762 }
763
764 self.backend.set_mem_table(®ions, files)
765 }
766
767 fn get_config(&mut self, hdr: &VhostUserMsgHeader<FrontendReq>, buf: &[u8]) -> Result<()> {
768 let (msg, payload) =
769 Ref::<_, VhostUserConfig>::from_prefix(buf).map_err(|_| Error::InvalidMessage)?;
770 if !msg.is_valid() {
771 return Err(Error::InvalidMessage);
772 }
773 if payload.len() != msg.size as usize {
774 return Err(Error::InvalidMessage);
775 }
776 let flags = match VhostUserConfigFlags::from_bits(msg.flags) {
777 Some(val) => val,
778 None => return Err(Error::InvalidMessage),
779 };
780 let res = self.backend.get_config(msg.offset, msg.size, flags);
781
782 match res {
785 Ok(ref buf) if buf.len() == msg.size as usize => {
786 let reply = VhostUserConfig::new(msg.offset, buf.len() as u32, flags);
787 self.send_reply_with_payload(hdr, &reply, buf.as_slice())?;
788 }
789 Ok(_) => {
790 let reply = VhostUserConfig::new(msg.offset, 0, flags);
791 self.send_reply_message(hdr, &reply)?;
792 }
793 Err(_) => {
794 let reply = VhostUserConfig::new(msg.offset, 0, flags);
795 self.send_reply_message(hdr, &reply)?;
796 }
797 }
798 Ok(())
799 }
800
801 fn set_config(&mut self, buf: &[u8]) -> Result<()> {
802 let (msg, payload) =
803 Ref::<_, VhostUserConfig>::from_prefix(buf).map_err(|_| Error::InvalidMessage)?;
804 if !msg.is_valid() {
805 return Err(Error::InvalidMessage);
806 }
807 if payload.len() != msg.size as usize {
808 return Err(Error::InvalidMessage);
809 }
810 let flags: VhostUserConfigFlags = match VhostUserConfigFlags::from_bits(msg.flags) {
811 Some(val) => val,
812 None => return Err(Error::InvalidMessage),
813 };
814
815 self.backend.set_config(msg.offset, payload, flags)
816 }
817
818 fn set_backend_req_fd(&mut self, files: Vec<File>) -> Result<()> {
819 let file = into_single_file(files).ok_or(Error::InvalidMessage)?;
820 let fd: SafeDescriptor = file.into();
821 let connection = Connection::try_from(fd).map_err(|_| Error::InvalidMessage)?;
822 self.backend.set_backend_req_fd(connection);
823 Ok(())
824 }
825
826 fn handle_vring_fd_request(
829 &mut self,
830 buf: &[u8],
831 files: Vec<File>,
832 ) -> Result<(u8, Option<File>)> {
833 let (msg, _) = VhostUserU64::read_from_prefix(buf).map_err(|_| Error::InvalidMessage)?;
834 if !msg.is_valid() {
835 return Err(Error::InvalidMessage);
836 }
837
838 let has_fd = (msg.value & 0x100u64) == 0;
845
846 let file = into_single_file(files);
847
848 if has_fd && file.is_none() || !has_fd && file.is_some() {
849 return Err(Error::InvalidMessage);
850 }
851
852 Ok((msg.value as u8, file))
853 }
854
855 fn check_request_size(
856 &self,
857 hdr: &VhostUserMsgHeader<FrontendReq>,
858 size: usize,
859 expected: usize,
860 ) -> Result<()> {
861 if hdr.get_size() as usize != expected
862 || hdr.is_reply()
863 || hdr.get_version() != 0x1
864 || size != expected
865 {
866 return Err(Error::InvalidMessage);
867 }
868 Ok(())
869 }
870
871 fn check_attached_files(
872 &self,
873 hdr: &VhostUserMsgHeader<FrontendReq>,
874 files: &[File],
875 ) -> Result<()> {
876 match hdr.get_code() {
877 Ok(FrontendReq::SET_MEM_TABLE)
878 | Ok(FrontendReq::SET_VRING_CALL)
879 | Ok(FrontendReq::SET_VRING_KICK)
880 | Ok(FrontendReq::SET_VRING_ERR)
881 | Ok(FrontendReq::SET_LOG_BASE)
882 | Ok(FrontendReq::SET_LOG_FD)
883 | Ok(FrontendReq::SET_BACKEND_REQ_FD)
884 | Ok(FrontendReq::SET_INFLIGHT_FD)
885 | Ok(FrontendReq::ADD_MEM_REG)
886 | Ok(FrontendReq::SET_DEVICE_STATE_FD) => Ok(()),
887 Err(_) => Err(Error::InvalidMessage),
888 _ if !files.is_empty() => Err(Error::InvalidMessage),
889 _ => Ok(()),
890 }
891 }
892
893 fn extract_request_body<T: Sized + FromBytes + VhostUserMsgValidator>(
894 &self,
895 hdr: &VhostUserMsgHeader<FrontendReq>,
896 size: usize,
897 buf: &[u8],
898 ) -> Result<T> {
899 self.check_request_size(hdr, size, mem::size_of::<T>())?;
900 let (body, _) = T::read_from_prefix(buf).map_err(|_| Error::InvalidMessage)?;
901 if body.is_valid() {
902 Ok(body)
903 } else {
904 Err(Error::InvalidMessage)
905 }
906 }
907
908 fn update_reply_ack_flag(&mut self) {
909 let pflag = VhostUserProtocolFeatures::REPLY_ACK;
910 self.reply_ack_enabled = (self.virtio_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES) != 0
911 && self.protocol_features.contains(pflag)
912 && (self.acked_protocol_features & pflag.bits()) != 0;
913 }
914}
915
916impl<S: Backend> AsRawDescriptor for BackendServer<S> {
917 fn as_raw_descriptor(&self) -> RawDescriptor {
918 self.connection.as_raw_descriptor()
920 }
921}
922
#[cfg(test)]
mod tests {
    use base::INVALID_DESCRIPTOR;

    use super::*;
    use crate::test_backend::TestBackend;
    use crate::Connection;

    /// A freshly constructed server reports a descriptor other than `INVALID_DESCRIPTOR`.
    #[test]
    fn test_backend_server_new() {
        let (server_end, _peer_end) = Connection::pair().unwrap();
        let server = BackendServer::new(server_end, TestBackend::new());

        let descriptor = server.as_raw_descriptor();
        assert!(descriptor != INVALID_DESCRIPTOR);
    }
}