vmm_vhost/
backend_server.rs

1// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::fs::File;
5use std::mem;
6
7use base::AsRawDescriptor;
8use base::RawDescriptor;
9use base::SafeDescriptor;
10use zerocopy::FromBytes;
11use zerocopy::Immutable;
12use zerocopy::IntoBytes;
13use zerocopy::Ref;
14
15use crate::into_single_file;
16use crate::message::*;
17use crate::Connection;
18use crate::Error;
19use crate::FrontendReq;
20use crate::Result;
21use crate::SharedMemoryRegion;
22
/// Trait for vhost-user backends.
///
/// Each method corresponds to a vhost-user protocol method. See the specification for details.
///
/// [`BackendServer::process_message`] decodes each frontend request and calls the matching
/// method of this trait. The per-method notes below only describe what that dispatcher visibly
/// does with the arguments and the return values.
#[allow(missing_docs)]
pub trait Backend {
    /// Called for `SET_OWNER`.
    fn set_owner(&mut self) -> Result<()>;
    /// Called for `RESET_OWNER`.
    fn reset_owner(&mut self) -> Result<()>;
    /// Called for `GET_FEATURES`. The server clears the `VIRTIO_F_RING_PACKED` bit from the
    /// returned value before replying to the frontend.
    fn get_features(&mut self) -> Result<u64>;
    /// Called for `SET_FEATURES`. The server clears the `VIRTIO_F_RING_PACKED` bit from
    /// `features` before calling this method.
    fn set_features(&mut self, features: u64) -> Result<()>;
    /// Called for `SET_MEM_TABLE` once the server has validated the request: `files` contains
    /// as many entries as the request announced regions, and every entry of `ctx` passed its
    /// `is_valid()` check.
    fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>;
    /// Called for `SET_VRING_NUM` with the `index` and `num` fields of the request.
    fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()>;
    /// Called for `SET_VRING_ADDR`. Requests whose flag bits do not map to
    /// `VhostUserVringAddrFlags` are rejected by the server before this method is reached.
    fn set_vring_addr(
        &mut self,
        index: u32,
        flags: VhostUserVringAddrFlags,
        descriptor: u64,
        used: u64,
        available: u64,
        log: u64,
    ) -> Result<()>;
    // TODO: b/331466964 - Argument type is wrong for packed queues.
    /// Called for `SET_VRING_BASE`; `base` is the `num` field of the request.
    fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()>;
    // TODO: b/331466964 - Return type is wrong for packed queues.
    /// Called for `GET_VRING_BASE`; the returned state is sent back as the reply.
    fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState>;
    /// Called for `SET_VRING_KICK`.
    fn set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()>;
    /// Called for `SET_VRING_CALL`.
    fn set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()>;
    /// Called for `SET_VRING_ERR`.
    fn set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()>;

    /// Called for `GET_PROTOCOL_FEATURES`; the returned bits are sent back as the reply.
    fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>;
    /// Called for `SET_PROTOCOL_FEATURES` with the raw value sent by the frontend.
    fn set_protocol_features(&mut self, features: u64) -> Result<()>;
    /// Called for `GET_QUEUE_NUM`. Only dispatched when the `MQ` protocol feature was acked.
    fn get_queue_num(&mut self) -> Result<u64>;
    /// Called for `SET_VRING_ENABLE`. Only dispatched when `VHOST_USER_F_PROTOCOL_FEATURES` is
    /// set in the acked virtio features; `enable` is decoded from a request value of 1 or 0.
    fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()>;
    /// Called for `GET_CONFIG`. Only dispatched when the `CONFIG` protocol feature was acked.
    ///
    /// The returned buffer is forwarded to the frontend only if its length equals `size`; for
    /// any other length, or for an error, the server replies with a zero-length payload.
    fn get_config(
        &mut self,
        offset: u32,
        size: u32,
        flags: VhostUserConfigFlags,
    ) -> Result<Vec<u8>>;
    /// Counterpart of [`Backend::get_config`] for writing the configuration space.
    fn set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>;
    /// Hands over a connection for backend-initiated requests. The default implementation
    /// discards the connection.
    fn set_backend_req_fd(&mut self, _vu_req: Connection) {}
    /// Called for `GET_INFLIGHT_FD`. Only dispatched when the `INFLIGHT_SHMFD` protocol feature
    /// was acked. The returned structure and file are both sent back in the reply.
    fn get_inflight_fd(
        &mut self,
        inflight: &VhostUserInflight,
    ) -> Result<(VhostUserInflight, File)>;
    /// Called for `SET_INFLIGHT_FD`. Only dispatched when the `INFLIGHT_SHMFD` protocol feature
    /// was acked.
    fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()>;
    /// Called for `GET_MAX_MEM_SLOTS`. Only dispatched when the `CONFIGURE_MEM_SLOTS` protocol
    /// feature was acked.
    fn get_max_mem_slots(&mut self) -> Result<u64>;
    /// Called for `ADD_MEM_REG`. Only dispatched when the `CONFIGURE_MEM_SLOTS` protocol
    /// feature was acked.
    fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>;
    /// Called for `REM_MEM_REG`. Only dispatched when the `CONFIGURE_MEM_SLOTS` protocol
    /// feature was acked.
    fn remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()>;
    /// Called for `SET_DEVICE_STATE_FD`. Only dispatched when the `DEVICE_STATE` protocol
    /// feature was acked.
    ///
    /// The server replies with value `0` and the file attached for `Ok(Some(file))`, with
    /// `0x100` and no file for `Ok(None)`, and with `0x101` and no file for an error.
    fn set_device_state_fd(
        &mut self,
        transfer_direction: VhostUserTransferDirection,
        migration_phase: VhostUserMigrationPhase,
        fd: File,
    ) -> Result<Option<File>>;
    /// Called for `CHECK_DEVICE_STATE`. Only dispatched when the `DEVICE_STATE` protocol
    /// feature was acked. The server replies with `0` for `Ok` and `1` for an error.
    fn check_device_state(&mut self) -> Result<()>;
    /// Called for `GET_SHMEM_CONFIG`; the returned regions are packed into the reply.
    fn get_shmem_config(&mut self) -> Result<Vec<SharedMemoryRegion>>;
}
80
81impl<T> Backend for T
82where
83    T: AsMut<dyn Backend>,
84{
85    fn set_owner(&mut self) -> Result<()> {
86        self.as_mut().set_owner()
87    }
88
89    fn reset_owner(&mut self) -> Result<()> {
90        self.as_mut().reset_owner()
91    }
92
93    fn get_features(&mut self) -> Result<u64> {
94        self.as_mut().get_features()
95    }
96
97    fn set_features(&mut self, features: u64) -> Result<()> {
98        self.as_mut().set_features(features)
99    }
100
101    fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()> {
102        self.as_mut().set_mem_table(ctx, files)
103    }
104
105    fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()> {
106        self.as_mut().set_vring_num(index, num)
107    }
108
109    fn set_vring_addr(
110        &mut self,
111        index: u32,
112        flags: VhostUserVringAddrFlags,
113        descriptor: u64,
114        used: u64,
115        available: u64,
116        log: u64,
117    ) -> Result<()> {
118        self.as_mut()
119            .set_vring_addr(index, flags, descriptor, used, available, log)
120    }
121
122    fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()> {
123        self.as_mut().set_vring_base(index, base)
124    }
125
126    fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState> {
127        self.as_mut().get_vring_base(index)
128    }
129
130    fn set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()> {
131        self.as_mut().set_vring_kick(index, fd)
132    }
133
134    fn set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()> {
135        self.as_mut().set_vring_call(index, fd)
136    }
137
138    fn set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()> {
139        self.as_mut().set_vring_err(index, fd)
140    }
141
142    fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures> {
143        self.as_mut().get_protocol_features()
144    }
145
146    fn set_protocol_features(&mut self, features: u64) -> Result<()> {
147        self.as_mut().set_protocol_features(features)
148    }
149
150    fn get_queue_num(&mut self) -> Result<u64> {
151        self.as_mut().get_queue_num()
152    }
153
154    fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()> {
155        self.as_mut().set_vring_enable(index, enable)
156    }
157
158    fn get_config(
159        &mut self,
160        offset: u32,
161        size: u32,
162        flags: VhostUserConfigFlags,
163    ) -> Result<Vec<u8>> {
164        self.as_mut().get_config(offset, size, flags)
165    }
166
167    fn set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()> {
168        self.as_mut().set_config(offset, buf, flags)
169    }
170
171    fn set_backend_req_fd(&mut self, vu_req: Connection) {
172        self.as_mut().set_backend_req_fd(vu_req)
173    }
174
175    fn get_inflight_fd(
176        &mut self,
177        inflight: &VhostUserInflight,
178    ) -> Result<(VhostUserInflight, File)> {
179        self.as_mut().get_inflight_fd(inflight)
180    }
181
182    fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()> {
183        self.as_mut().set_inflight_fd(inflight, file)
184    }
185
186    fn get_max_mem_slots(&mut self) -> Result<u64> {
187        self.as_mut().get_max_mem_slots()
188    }
189
190    fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()> {
191        self.as_mut().add_mem_region(region, fd)
192    }
193
194    fn remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()> {
195        self.as_mut().remove_mem_region(region)
196    }
197
198    fn set_device_state_fd(
199        &mut self,
200        transfer_direction: VhostUserTransferDirection,
201        migration_phase: VhostUserMigrationPhase,
202        fd: File,
203    ) -> Result<Option<File>> {
204        self.as_mut()
205            .set_device_state_fd(transfer_direction, migration_phase, fd)
206    }
207
208    fn check_device_state(&mut self) -> Result<()> {
209        self.as_mut().check_device_state()
210    }
211
212    fn get_shmem_config(&mut self) -> Result<Vec<SharedMemoryRegion>> {
213        self.as_mut().get_shmem_config()
214    }
215}
216
/// Handles requests from a vhost-user connection by dispatching them to [`Backend`] methods.
pub struct BackendServer<S: Backend> {
    /// Underlying connection for communication.
    connection: Connection,
    /// The vhost-user backend device object that requests are dispatched to.
    backend: S,

    /// Feature bits most recently sent to the frontend in reply to `GET_FEATURES`.
    virtio_features: u64,
    /// Feature bits most recently received from the frontend via `SET_FEATURES`.
    acked_virtio_features: u64,
    /// Protocol features most recently sent in reply to `GET_PROTOCOL_FEATURES`.
    protocol_features: VhostUserProtocolFeatures,
    /// Protocol feature bits most recently received via `SET_PROTOCOL_FEATURES`.
    acked_protocol_features: u64,

    /// Sending ack for messages without payload.
    reply_ack_enabled: bool,
}
232
233impl<S: Backend> AsRef<S> for BackendServer<S> {
234    fn as_ref(&self) -> &S {
235        &self.backend
236    }
237}
238
239impl<S: Backend> BackendServer<S> {
240    pub fn new(connection: Connection, backend: S) -> Self {
241        BackendServer {
242            connection,
243            backend,
244            virtio_features: 0,
245            acked_virtio_features: 0,
246            protocol_features: VhostUserProtocolFeatures::empty(),
247            acked_protocol_features: 0,
248            reply_ack_enabled: false,
249        }
250    }
251
252    /// Receives and validates a vhost-user message header and optional files.
253    ///
254    /// Since the length of vhost-user messages are different among message types, regular
255    /// vhost-user messages are sent via an underlying communication channel in stream mode.
256    /// (e.g. `SOCK_STREAM` in UNIX)
257    /// So, the logic of receiving and handling a message consists of the following steps:
258    ///
259    /// 1. Receives a message header and optional attached file.
260    /// 2. Validates the message header.
261    /// 3. Check if optional payloads is expected.
262    /// 4. Wait for the optional payloads.
263    /// 5. Receives optional payloads.
264    /// 6. Processes the message.
265    ///
266    /// This method [`BackendServer::recv_header()`] is in charge of the step (1) and (2),
267    /// [`BackendServer::needs_wait_for_payload()`] is (3), and
268    /// [`BackendServer::process_message()`] is (5) and (6). We need to have the three method
269    /// separately for multi-platform supports; [`BackendServer::recv_header()`] and
270    /// [`BackendServer::process_message()`] need to be separated because the way of waiting for
271    /// incoming messages differs between Unix and Windows so it's the caller's responsibility to
272    /// wait before [`BackendServer::process_message()`].
273    ///
274    /// Note that some vhost-user protocol variant such as VVU doesn't assume stream mode. In this
275    /// case, a message header and its body are sent together so the step (4) is skipped. We handle
276    /// this case in [`BackendServer::needs_wait_for_payload()`].
277    ///
278    /// The following pseudo code describes how a caller should process incoming vhost-user
279    /// messages:
280    /// ```ignore
281    /// loop {
282    ///   // block until a message header comes.
283    ///   // The actual code differs, depending on platforms.
284    ///   connection.wait_readable().unwrap();
285    ///
286    ///   // (1) and (2)
287    ///   let (hdr, files) = backend_server.recv_header();
288    ///
289    ///   // (3)
290    ///   if backend_server.needs_wait_for_payload(&hdr) {
291    ///     // (4) block until a payload comes if needed.
292    ///     connection.wait_readable().unwrap();
293    ///   }
294    ///
295    ///   // (5) and (6)
296    ///   backend_server.process_message(&hdr, &files).unwrap();
297    /// }
298    /// ```
299    pub fn recv_header(&mut self) -> Result<(VhostUserMsgHeader, Vec<File>)> {
300        // The underlying communication channel is a Unix domain socket in
301        // stream mode, and recvmsg() is a little tricky here. To successfully
302        // receive attached file descriptors, we need to receive messages and
303        // corresponding attached file descriptors in this way:
304        // . recv messsage header and optional attached file
305        // . validate message header
306        // . recv optional message body and payload according size field in
307        //   message header
308        // . validate message body and optional payload
309        let (hdr, files) = match self.connection.recv_header() {
310            Ok((hdr, files)) => (hdr, files),
311            Err(Error::Disconnect) => {
312                // If the client closed the connection before sending a header, this should be
313                // handled as a legal exit.
314                return Err(Error::ClientExit);
315            }
316            Err(e) => {
317                return Err(e);
318            }
319        };
320
321        if !hdr.is_valid() {
322            return Err(Error::InvalidMessage);
323        }
324
325        self.check_attached_files(&hdr, &files)?;
326
327        Ok((hdr, files))
328    }
329
330    /// Returns whether the caller needs to wait for the incoming message before calling
331    /// [`BackendServer::process_message`].
332    ///
333    /// See [`BackendServer::recv_header`]'s doc comment for the usage.
334    pub fn needs_wait_for_payload(&self, hdr: &VhostUserMsgHeader) -> bool {
335        // Since the vhost-user protocol uses stream mode, we need to wait until an additional
336        // payload is available if exists.
337        hdr.get_size() != 0
338    }
339
    /// Main entrance to request from the communication channel.
    ///
    /// Receive and handle one incoming request message from the frontend.
    /// See [`BackendServer::recv_header`]'s doc comment for the usage.
    ///
    /// The message body is read from the connection first; `hdr` and `files` are the values
    /// previously returned by [`BackendServer::recv_header`]. The request code in `hdr` then
    /// selects which [`Backend`] method is called.
    ///
    /// # Return:
    /// * `Ok(())`: one request was successfully handled.
    /// * `Err(ClientExit)`: the frontend closed the connection properly. This isn't an actual
    ///   failure.
    /// * `Err(Disconnect)`: the connection was closed unexpectedly.
    /// * `Err(InvalidMessage)`: the vmm sent an illegal message.
    /// * `Err(InvalidOperation)`: the request depends on a feature that has not been acked.
    /// * other errors: failed to handle a request.
    pub fn process_message(&mut self, hdr: VhostUserMsgHeader, files: Vec<File>) -> Result<()> {
        let (buf, extra_files) = self.connection.recv_body_bytes(&hdr)?;
        let size = buf.len();
        // Files are only accepted together with the header, never with the body.
        if !extra_files.is_empty() {
            return Err(Error::InvalidMessage);
        }

        // TODO: The error handling here is inconsistent. Sometimes we report the error to the
        // client and keep going, sometimes we report the error and then close the connection,
        // sometimes we just close the connection.
        //
        // Most "set" style arms below follow the same pattern: call the backend, send an ack
        // carrying the success flag (only if acks are enabled and requested), and only then
        // propagate the backend's error to the caller.
        match hdr.get_code() {
            Ok(FrontendReq::SET_OWNER) => {
                self.check_request_size(&hdr, size, 0)?;
                let res = self.backend.set_owner();
                self.send_ack_message(&hdr, res.is_ok())?;
                res?;
            }
            Ok(FrontendReq::RESET_OWNER) => {
                self.check_request_size(&hdr, size, 0)?;
                let res = self.backend.reset_owner();
                self.send_ack_message(&hdr, res.is_ok())?;
                res?;
            }
            Ok(FrontendReq::GET_FEATURES) => {
                self.check_request_size(&hdr, size, 0)?;
                let mut features = self.backend.get_features()?;

                // Don't advertise packed queues even if the device does. We don't handle them
                // properly yet at the protocol layer.
                // TODO: b/331466964 - Remove once support is added.
                features &= !(1 << VIRTIO_F_RING_PACKED);

                let msg = VhostUserU64::new(features);
                self.send_reply_message(&hdr, &msg)?;
                // Remember what was advertised; the reply-ack flag is recomputed from it.
                self.virtio_features = features;
                self.update_reply_ack_flag();
            }
            Ok(FrontendReq::SET_FEATURES) => {
                let mut msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?;

                // Don't allow packed queues even if the device does. We don't handle them
                // properly yet at the protocol layer.
                // TODO: b/331466964 - Remove once support is added.
                msg.value &= !(1 << VIRTIO_F_RING_PACKED);

                let res = self.backend.set_features(msg.value);
                // The acked value is recorded regardless of whether the backend accepted it.
                self.acked_virtio_features = msg.value;
                self.update_reply_ack_flag();
                self.send_ack_message(&hdr, res.is_ok())?;
                res?;
            }
            Ok(FrontendReq::SET_MEM_TABLE) => {
                let res = self.set_mem_table(&hdr, size, &buf, files);
                self.send_ack_message(&hdr, res.is_ok())?;
                res?;
            }
            Ok(FrontendReq::SET_VRING_NUM) => {
                let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
                let res = self.backend.set_vring_num(msg.index, msg.num);
                self.send_ack_message(&hdr, res.is_ok())?;
                res?;
            }
            Ok(FrontendReq::SET_VRING_ADDR) => {
                let msg = self.extract_request_body::<VhostUserVringAddr>(&hdr, size, &buf)?;
                // Reject requests carrying flag bits that `VhostUserVringAddrFlags` doesn't know.
                let flags = match VhostUserVringAddrFlags::from_bits(msg.flags) {
                    Some(val) => val,
                    None => return Err(Error::InvalidMessage),
                };
                let res = self.backend.set_vring_addr(
                    msg.index,
                    flags,
                    msg.descriptor,
                    msg.used,
                    msg.available,
                    msg.log,
                );
                self.send_ack_message(&hdr, res.is_ok())?;
                res?;
            }
            Ok(FrontendReq::SET_VRING_BASE) => {
                let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
                let res = self.backend.set_vring_base(msg.index, msg.num);
                self.send_ack_message(&hdr, res.is_ok())?;
                res?;
            }
            Ok(FrontendReq::GET_VRING_BASE) => {
                let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
                let reply = self.backend.get_vring_base(msg.index)?;
                self.send_reply_message(&hdr, &reply)?;
            }
            Ok(FrontendReq::SET_VRING_CALL) => {
                self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
                let (index, file) = self.handle_vring_fd_request(&buf, files)?;
                let res = self.backend.set_vring_call(index, file);
                self.send_ack_message(&hdr, res.is_ok())?;
                res?;
            }
            Ok(FrontendReq::SET_VRING_KICK) => {
                self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
                let (index, file) = self.handle_vring_fd_request(&buf, files)?;
                let res = self.backend.set_vring_kick(index, file);
                self.send_ack_message(&hdr, res.is_ok())?;
                res?;
            }
            Ok(FrontendReq::SET_VRING_ERR) => {
                self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
                let (index, file) = self.handle_vring_fd_request(&buf, files)?;
                let res = self.backend.set_vring_err(index, file);
                self.send_ack_message(&hdr, res.is_ok())?;
                res?;
            }
            Ok(FrontendReq::GET_PROTOCOL_FEATURES) => {
                self.check_request_size(&hdr, size, 0)?;
                let features = self.backend.get_protocol_features()?;
                let msg = VhostUserU64::new(features.bits());
                self.send_reply_message(&hdr, &msg)?;
                self.protocol_features = features;
                self.update_reply_ack_flag();
            }
            Ok(FrontendReq::SET_PROTOCOL_FEATURES) => {
                let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?;
                let res = self.backend.set_protocol_features(msg.value);
                // The acked value is recorded regardless of whether the backend accepted it.
                self.acked_protocol_features = msg.value;
                self.update_reply_ack_flag();
                self.send_ack_message(&hdr, res.is_ok())?;
                res?;
            }
            Ok(FrontendReq::GET_QUEUE_NUM) => {
                // Requires the MQ protocol feature to have been acked.
                if self.acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() == 0 {
                    return Err(Error::InvalidOperation);
                }
                self.check_request_size(&hdr, size, 0)?;
                let num = self.backend.get_queue_num()?;
                let msg = VhostUserU64::new(num);
                self.send_reply_message(&hdr, &msg)?;
            }
            Ok(FrontendReq::SET_VRING_ENABLE) => {
                let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
                // `<<` binds tighter than `&`, which binds tighter than `==`: this tests whether
                // the VHOST_USER_F_PROTOCOL_FEATURES bit is missing from the acked features.
                if self.acked_virtio_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES == 0 {
                    return Err(Error::InvalidOperation);
                }
                let enable = match msg.num {
                    1 => true,
                    0 => false,
                    _ => {
                        return Err(Error::InvalidParam(
                            "SET_VRING_ENABLE: num out of range (must be [0, 1])",
                        ))
                    }
                };

                let res = self.backend.set_vring_enable(msg.index, enable);
                self.send_ack_message(&hdr, res.is_ok())?;
                res?;
            }
            Ok(FrontendReq::GET_CONFIG) => {
                // Requires the CONFIG protocol feature to have been acked.
                if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
                    return Err(Error::InvalidOperation);
                }
                // The body length is variable, so it is compared against the header's size field.
                self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
                self.get_config(&hdr, &buf)?;
            }
            Ok(FrontendReq::SET_CONFIG) => {
                // Requires the CONFIG protocol feature to have been acked.
                if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
                    return Err(Error::InvalidOperation);
                }
                self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
                let res = self.set_config(&buf);
                self.send_ack_message(&hdr, res.is_ok())?;
                res?;
            }
            Ok(FrontendReq::SET_BACKEND_REQ_FD) => {
                // Requires the BACKEND_REQ protocol feature to have been acked.
                if self.acked_protocol_features & VhostUserProtocolFeatures::BACKEND_REQ.bits() == 0
                {
                    return Err(Error::InvalidOperation);
                }
                self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
                let res = self.set_backend_req_fd(files);
                self.send_ack_message(&hdr, res.is_ok())?;
                res?;
            }
            Ok(FrontendReq::GET_INFLIGHT_FD) => {
                // Requires the INFLIGHT_SHMFD protocol feature to have been acked.
                if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits()
                    == 0
                {
                    return Err(Error::InvalidOperation);
                }

                let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?;
                let (inflight, file) = self.backend.get_inflight_fd(&msg)?;
                // The reply carries the backend's descriptor alongside the inflight structure.
                let reply_hdr = self.new_reply_header::<VhostUserInflight>(&hdr, 0)?;
                self.connection.send_message(
                    &reply_hdr,
                    &inflight,
                    Some(&[file.as_raw_descriptor()]),
                )?;
            }
            Ok(FrontendReq::SET_INFLIGHT_FD) => {
                // Requires the INFLIGHT_SHMFD protocol feature to have been acked.
                if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits()
                    == 0
                {
                    return Err(Error::InvalidOperation);
                }
                let file = into_single_file(files).ok_or(Error::IncorrectFds)?;
                let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?;
                let res = self.backend.set_inflight_fd(&msg, file);
                self.send_ack_message(&hdr, res.is_ok())?;
                res?;
            }
            Ok(FrontendReq::GET_MAX_MEM_SLOTS) => {
                // Requires the CONFIGURE_MEM_SLOTS protocol feature to have been acked.
                if self.acked_protocol_features
                    & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
                    == 0
                {
                    return Err(Error::InvalidOperation);
                }
                self.check_request_size(&hdr, size, 0)?;
                let num = self.backend.get_max_mem_slots()?;
                let msg = VhostUserU64::new(num);
                self.send_reply_message(&hdr, &msg)?;
            }
            Ok(FrontendReq::ADD_MEM_REG) => {
                // Requires the CONFIGURE_MEM_SLOTS protocol feature to have been acked.
                if self.acked_protocol_features
                    & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
                    == 0
                {
                    return Err(Error::InvalidOperation);
                }
                let file = into_single_file(files).ok_or(Error::InvalidParam(
                    "ADD_MEM_REG: exactly one file must be provided",
                ))?;
                let msg =
                    self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?;
                let res = self.backend.add_mem_region(&msg, file);
                self.send_ack_message(&hdr, res.is_ok())?;
                res?;
            }
            Ok(FrontendReq::REM_MEM_REG) => {
                // Requires the CONFIGURE_MEM_SLOTS protocol feature to have been acked.
                if self.acked_protocol_features
                    & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
                    == 0
                {
                    return Err(Error::InvalidOperation);
                }

                let msg =
                    self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?;
                let res = self.backend.remove_mem_region(&msg);
                self.send_ack_message(&hdr, res.is_ok())?;
                res?;
            }
            Ok(FrontendReq::SET_DEVICE_STATE_FD) => {
                // Requires the DEVICE_STATE protocol feature to have been acked.
                if self.acked_protocol_features & VhostUserProtocolFeatures::DEVICE_STATE.bits()
                    == 0
                {
                    return Err(Error::InvalidOperation);
                }
                // Read request.
                let msg =
                    self.extract_request_body::<DeviceStateTransferParameters>(&hdr, size, &buf)?;
                let transfer_direction = match msg.transfer_direction {
                    0 => VhostUserTransferDirection::Save,
                    1 => VhostUserTransferDirection::Load,
                    _ => return Err(Error::InvalidMessage),
                };
                let migration_phase = match msg.migration_phase {
                    0 => VhostUserMigrationPhase::Stopped,
                    _ => return Err(Error::InvalidMessage),
                };
                // Call backend. Only the first attached file is used.
                let res = self.backend.set_device_state_fd(
                    transfer_direction,
                    migration_phase,
                    files.into_iter().next().ok_or(Error::IncorrectFds)?,
                );
                // Send response. A reply is always sent here, unlike the ack-based arms.
                let (msg, fds) = match &res {
                    Ok(None) => (VhostUserU64::new(0x100), None),
                    Ok(Some(file)) => (VhostUserU64::new(0), Some(file.as_raw_descriptor())),
                    // Just in case, set the "invalid FD" flag on error.
                    Err(_) => (VhostUserU64::new(0x101), None),
                };
                let reply_hdr: VhostUserMsgHeader =
                    self.new_reply_header::<VhostUserU64>(&hdr, 0)?;
                self.connection.send_message(
                    &reply_hdr,
                    &msg,
                    fds.as_ref().map(std::slice::from_ref),
                )?;
                res?;
            }
            Ok(FrontendReq::CHECK_DEVICE_STATE) => {
                // Requires the DEVICE_STATE protocol feature to have been acked.
                if self.acked_protocol_features & VhostUserProtocolFeatures::DEVICE_STATE.bits()
                    == 0
                {
                    return Err(Error::InvalidOperation);
                }
                let res = self.backend.check_device_state();
                // The reply value is 0 on success and 1 on failure.
                let msg = VhostUserU64::new(if res.is_ok() { 0 } else { 1 });
                self.send_reply_message(&hdr, &msg)?;
                res?;
            }
            Ok(FrontendReq::GET_SHMEM_CONFIG) => {
                let msg = VhostUserShMemConfig::new(&self.backend.get_shmem_config()?);
                self.send_reply_message(&hdr, &msg)?;
            }
            // Unknown or undecodable request codes are treated as protocol violations.
            _ => {
                return Err(Error::InvalidMessage);
            }
        }
        Ok(())
    }
664
665    fn new_reply_header<T: Sized>(
666        &self,
667        req: &VhostUserMsgHeader,
668        payload_size: usize,
669    ) -> Result<VhostUserMsgHeader> {
670        Ok(VhostUserMsgHeader::new_reply_header(
671            req.get_code::<FrontendReq>()
672                .map_err(|_| Error::InvalidMessage)?,
673            (mem::size_of::<T>()
674                .checked_add(payload_size)
675                .ok_or(Error::OversizedMsg)?)
676            .try_into()
677            .map_err(Error::InvalidCastToInt)?,
678        ))
679    }
680
681    /// Sends reply back to Vhost frontend in response to a message.
682    fn send_ack_message(&mut self, req: &VhostUserMsgHeader, success: bool) -> Result<()> {
683        if self.reply_ack_enabled && req.is_need_reply() {
684            let hdr: VhostUserMsgHeader = self.new_reply_header::<VhostUserU64>(req, 0)?;
685            let val = if success { 0 } else { 1 };
686            let msg = VhostUserU64::new(val);
687            self.connection.send_message(&hdr, &msg, None)?;
688        }
689        Ok(())
690    }
691
692    fn send_reply_message<T: IntoBytes + Immutable>(
693        &mut self,
694        req: &VhostUserMsgHeader,
695        msg: &T,
696    ) -> Result<()> {
697        let hdr = self.new_reply_header::<T>(req, 0)?;
698        self.connection.send_message(&hdr, msg, None)?;
699        Ok(())
700    }
701
702    fn send_reply_with_payload<T: IntoBytes + Immutable>(
703        &mut self,
704        req: &VhostUserMsgHeader,
705        msg: &T,
706        payload: &[u8],
707    ) -> Result<()> {
708        let hdr = self.new_reply_header::<T>(req, payload.len())?;
709        self.connection
710            .send_message_with_payload(&hdr, msg, payload, None)?;
711        Ok(())
712    }
713
714    fn set_mem_table(
715        &mut self,
716        hdr: &VhostUserMsgHeader,
717        size: usize,
718        buf: &[u8],
719        files: Vec<File>,
720    ) -> Result<()> {
721        self.check_request_size(hdr, size, hdr.get_size() as usize)?;
722
723        let (msg, regions) =
724            Ref::<_, VhostUserMemory>::from_prefix(buf).map_err(|_| Error::InvalidMessage)?;
725        if !msg.is_valid() {
726            return Err(Error::InvalidMessage);
727        }
728
729        // validate number of fds matching number of memory regions
730        if files.len() != msg.num_regions as usize {
731            return Err(Error::InvalidMessage);
732        }
733
734        let (regions, excess) = Ref::<_, [VhostUserMemoryRegion]>::from_prefix_with_elems(
735            regions,
736            msg.num_regions as usize,
737        )
738        .map_err(|_| Error::InvalidMessage)?;
739        if !excess.is_empty() {
740            return Err(Error::InvalidMessage);
741        }
742
743        // Validate memory regions
744        for region in regions.iter() {
745            if !region.is_valid() {
746                return Err(Error::InvalidMessage);
747            }
748        }
749
750        self.backend.set_mem_table(&regions, files)
751    }
752
753    fn get_config(&mut self, hdr: &VhostUserMsgHeader, buf: &[u8]) -> Result<()> {
754        let (msg, payload) =
755            Ref::<_, VhostUserConfig>::from_prefix(buf).map_err(|_| Error::InvalidMessage)?;
756        if !msg.is_valid() {
757            return Err(Error::InvalidMessage);
758        }
759        if payload.len() != msg.size as usize {
760            return Err(Error::InvalidMessage);
761        }
762        let flags = match VhostUserConfigFlags::from_bits(msg.flags) {
763            Some(val) => val,
764            None => return Err(Error::InvalidMessage),
765        };
766        let res = self.backend.get_config(msg.offset, msg.size, flags);
767
768        // The response payload size MUST match the request payload size on success. A zero length
769        // response is used to indicate an error.
770        match res {
771            Ok(ref buf) if buf.len() == msg.size as usize => {
772                let reply = VhostUserConfig::new(msg.offset, buf.len() as u32, flags);
773                self.send_reply_with_payload(hdr, &reply, buf.as_slice())?;
774            }
775            Ok(_) => {
776                let reply = VhostUserConfig::new(msg.offset, 0, flags);
777                self.send_reply_message(hdr, &reply)?;
778            }
779            Err(_) => {
780                let reply = VhostUserConfig::new(msg.offset, 0, flags);
781                self.send_reply_message(hdr, &reply)?;
782            }
783        }
784        Ok(())
785    }
786
787    fn set_config(&mut self, buf: &[u8]) -> Result<()> {
788        let (msg, payload) =
789            Ref::<_, VhostUserConfig>::from_prefix(buf).map_err(|_| Error::InvalidMessage)?;
790        if !msg.is_valid() {
791            return Err(Error::InvalidMessage);
792        }
793        if payload.len() != msg.size as usize {
794            return Err(Error::InvalidMessage);
795        }
796        let flags: VhostUserConfigFlags = match VhostUserConfigFlags::from_bits(msg.flags) {
797            Some(val) => val,
798            None => return Err(Error::InvalidMessage),
799        };
800
801        self.backend.set_config(msg.offset, payload, flags)
802    }
803
804    fn set_backend_req_fd(&mut self, files: Vec<File>) -> Result<()> {
805        let file = into_single_file(files).ok_or(Error::InvalidMessage)?;
806        let fd: SafeDescriptor = file.into();
807        let connection = Connection::try_from(fd).map_err(|_| Error::InvalidMessage)?;
808        self.backend.set_backend_req_fd(connection);
809        Ok(())
810    }
811
812    /// Parses an incoming |SET_VRING_KICK| or |SET_VRING_CALL| message into a
813    /// Vring number and an fd.
814    fn handle_vring_fd_request(
815        &mut self,
816        buf: &[u8],
817        files: Vec<File>,
818    ) -> Result<(u8, Option<File>)> {
819        let (msg, _) = VhostUserU64::read_from_prefix(buf).map_err(|_| Error::InvalidMessage)?;
820        if !msg.is_valid() {
821            return Err(Error::InvalidMessage);
822        }
823
824        // Bits (0-7) of the payload contain the vring index. Bit 8 is the
825        // invalid FD flag (VHOST_USER_VRING_NOFD_MASK).
826        // This bit is set when there is no file descriptor
827        // in the ancillary data. This signals that polling will be used
828        // instead of waiting for the call.
829        // If Bit 8 is unset, the data must contain a file descriptor.
830        let has_fd = (msg.value & 0x100u64) == 0;
831
832        let file = into_single_file(files);
833
834        if has_fd && file.is_none() || !has_fd && file.is_some() {
835            return Err(Error::InvalidMessage);
836        }
837
838        Ok((msg.value as u8, file))
839    }
840
841    fn check_request_size(
842        &self,
843        hdr: &VhostUserMsgHeader,
844        size: usize,
845        expected: usize,
846    ) -> Result<()> {
847        if hdr.get_size() as usize != expected
848            || hdr.is_reply()
849            || hdr.get_version() != 0x1
850            || size != expected
851        {
852            return Err(Error::InvalidMessage);
853        }
854        Ok(())
855    }
856
857    fn check_attached_files(&self, hdr: &VhostUserMsgHeader, files: &[File]) -> Result<()> {
858        match hdr.get_code() {
859            Ok(FrontendReq::SET_MEM_TABLE)
860            | Ok(FrontendReq::SET_VRING_CALL)
861            | Ok(FrontendReq::SET_VRING_KICK)
862            | Ok(FrontendReq::SET_VRING_ERR)
863            | Ok(FrontendReq::SET_LOG_BASE)
864            | Ok(FrontendReq::SET_LOG_FD)
865            | Ok(FrontendReq::SET_BACKEND_REQ_FD)
866            | Ok(FrontendReq::SET_INFLIGHT_FD)
867            | Ok(FrontendReq::ADD_MEM_REG)
868            | Ok(FrontendReq::SET_DEVICE_STATE_FD) => Ok(()),
869            Err(_) => Err(Error::InvalidMessage),
870            _ if !files.is_empty() => Err(Error::InvalidMessage),
871            _ => Ok(()),
872        }
873    }
874
875    fn extract_request_body<T: Sized + FromBytes + VhostUserMsgValidator>(
876        &self,
877        hdr: &VhostUserMsgHeader,
878        size: usize,
879        buf: &[u8],
880    ) -> Result<T> {
881        self.check_request_size(hdr, size, mem::size_of::<T>())?;
882        let (body, _) = T::read_from_prefix(buf).map_err(|_| Error::InvalidMessage)?;
883        if body.is_valid() {
884            Ok(body)
885        } else {
886            Err(Error::InvalidMessage)
887        }
888    }
889
890    fn update_reply_ack_flag(&mut self) {
891        let pflag = VhostUserProtocolFeatures::REPLY_ACK;
892        self.reply_ack_enabled = (self.virtio_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES) != 0
893            && self.protocol_features.contains(pflag)
894            && (self.acked_protocol_features & pflag.bits()) != 0;
895    }
896}
897
898impl<S: Backend> AsRawDescriptor for BackendServer<S> {
899    fn as_raw_descriptor(&self) -> RawDescriptor {
900        // TODO(b/221882601): figure out if this used for polling.
901        self.connection.as_raw_descriptor()
902    }
903}
904
#[cfg(test)]
mod tests {
    use base::INVALID_DESCRIPTOR;

    use super::*;
    use crate::test_backend::TestBackend;
    use crate::Connection;

    /// A server built on one end of a connection pair must report a usable descriptor.
    #[test]
    fn test_backend_server_new() {
        // Binding the peer end to a named variable keeps it alive until the test finishes.
        let (server_end, _peer_end) = Connection::pair().unwrap();
        let server = BackendServer::new(server_end, TestBackend::new());

        assert_ne!(server.as_raw_descriptor(), INVALID_DESCRIPTOR);
    }
}