vmm_vhost/
backend_server.rs

1// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::fs::File;
5use std::mem;
6
7use base::AsRawDescriptor;
8use base::RawDescriptor;
9use base::SafeDescriptor;
10use zerocopy::FromBytes;
11use zerocopy::Immutable;
12use zerocopy::IntoBytes;
13use zerocopy::Ref;
14
15use crate::into_single_file;
16use crate::message::*;
17use crate::BackendReq;
18use crate::Connection;
19use crate::Error;
20use crate::FrontendReq;
21use crate::Result;
22
23/// Trait for vhost-user backends.
24///
25/// Each method corresponds to a vhost-user protocol method. See the specification for details.
26#[allow(missing_docs)]
27pub trait Backend {
28    fn set_owner(&mut self) -> Result<()>;
29    fn reset_owner(&mut self) -> Result<()>;
30    fn get_features(&mut self) -> Result<u64>;
31    fn set_features(&mut self, features: u64) -> Result<()>;
32    fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>;
33    fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()>;
34    fn set_vring_addr(
35        &mut self,
36        index: u32,
37        flags: VhostUserVringAddrFlags,
38        descriptor: u64,
39        used: u64,
40        available: u64,
41        log: u64,
42    ) -> Result<()>;
43    // TODO: b/331466964 - Argument type is wrong for packed queues.
44    fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()>;
45    // TODO: b/331466964 - Return type is wrong for packed queues.
46    fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState>;
47    fn set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()>;
48    fn set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()>;
49    fn set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()>;
50
51    fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>;
52    fn set_protocol_features(&mut self, features: u64) -> Result<()>;
53    fn get_queue_num(&mut self) -> Result<u64>;
54    fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()>;
55    fn get_config(
56        &mut self,
57        offset: u32,
58        size: u32,
59        flags: VhostUserConfigFlags,
60    ) -> Result<Vec<u8>>;
61    fn set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>;
62    fn set_backend_req_fd(&mut self, _vu_req: Connection<BackendReq>) {}
63    fn get_inflight_fd(
64        &mut self,
65        inflight: &VhostUserInflight,
66    ) -> Result<(VhostUserInflight, File)>;
67    fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()>;
68    fn get_max_mem_slots(&mut self) -> Result<u64>;
69    fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>;
70    fn remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()>;
71    fn set_device_state_fd(
72        &mut self,
73        transfer_direction: VhostUserTransferDirection,
74        migration_phase: VhostUserMigrationPhase,
75        fd: File,
76    ) -> Result<Option<File>>;
77    fn check_device_state(&mut self) -> Result<()>;
78    fn get_shmem_config(&mut self) -> Result<(VhostUserShMemConfigHeader, Vec<u64>)>;
79}
80
81impl<T> Backend for T
82where
83    T: AsMut<dyn Backend>,
84{
85    fn set_owner(&mut self) -> Result<()> {
86        self.as_mut().set_owner()
87    }
88
89    fn reset_owner(&mut self) -> Result<()> {
90        self.as_mut().reset_owner()
91    }
92
93    fn get_features(&mut self) -> Result<u64> {
94        self.as_mut().get_features()
95    }
96
97    fn set_features(&mut self, features: u64) -> Result<()> {
98        self.as_mut().set_features(features)
99    }
100
101    fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()> {
102        self.as_mut().set_mem_table(ctx, files)
103    }
104
105    fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()> {
106        self.as_mut().set_vring_num(index, num)
107    }
108
109    fn set_vring_addr(
110        &mut self,
111        index: u32,
112        flags: VhostUserVringAddrFlags,
113        descriptor: u64,
114        used: u64,
115        available: u64,
116        log: u64,
117    ) -> Result<()> {
118        self.as_mut()
119            .set_vring_addr(index, flags, descriptor, used, available, log)
120    }
121
122    fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()> {
123        self.as_mut().set_vring_base(index, base)
124    }
125
126    fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState> {
127        self.as_mut().get_vring_base(index)
128    }
129
130    fn set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()> {
131        self.as_mut().set_vring_kick(index, fd)
132    }
133
134    fn set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()> {
135        self.as_mut().set_vring_call(index, fd)
136    }
137
138    fn set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()> {
139        self.as_mut().set_vring_err(index, fd)
140    }
141
142    fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures> {
143        self.as_mut().get_protocol_features()
144    }
145
146    fn set_protocol_features(&mut self, features: u64) -> Result<()> {
147        self.as_mut().set_protocol_features(features)
148    }
149
150    fn get_queue_num(&mut self) -> Result<u64> {
151        self.as_mut().get_queue_num()
152    }
153
154    fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()> {
155        self.as_mut().set_vring_enable(index, enable)
156    }
157
158    fn get_config(
159        &mut self,
160        offset: u32,
161        size: u32,
162        flags: VhostUserConfigFlags,
163    ) -> Result<Vec<u8>> {
164        self.as_mut().get_config(offset, size, flags)
165    }
166
167    fn set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()> {
168        self.as_mut().set_config(offset, buf, flags)
169    }
170
171    fn set_backend_req_fd(&mut self, vu_req: Connection<BackendReq>) {
172        self.as_mut().set_backend_req_fd(vu_req)
173    }
174
175    fn get_inflight_fd(
176        &mut self,
177        inflight: &VhostUserInflight,
178    ) -> Result<(VhostUserInflight, File)> {
179        self.as_mut().get_inflight_fd(inflight)
180    }
181
182    fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()> {
183        self.as_mut().set_inflight_fd(inflight, file)
184    }
185
186    fn get_max_mem_slots(&mut self) -> Result<u64> {
187        self.as_mut().get_max_mem_slots()
188    }
189
190    fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()> {
191        self.as_mut().add_mem_region(region, fd)
192    }
193
194    fn remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()> {
195        self.as_mut().remove_mem_region(region)
196    }
197
198    fn set_device_state_fd(
199        &mut self,
200        transfer_direction: VhostUserTransferDirection,
201        migration_phase: VhostUserMigrationPhase,
202        fd: File,
203    ) -> Result<Option<File>> {
204        self.as_mut()
205            .set_device_state_fd(transfer_direction, migration_phase, fd)
206    }
207
208    fn check_device_state(&mut self) -> Result<()> {
209        self.as_mut().check_device_state()
210    }
211
212    fn get_shmem_config(&mut self) -> Result<(VhostUserShMemConfigHeader, Vec<u64>)> {
213        self.as_mut().get_shmem_config()
214    }
215}
216
217/// Handles requests from a vhost-user connection by dispatching them to [[Backend]] methods.
218pub struct BackendServer<S: Backend> {
219    /// Underlying connection for communication.
220    connection: Connection<FrontendReq>,
221    // the vhost-user backend device object
222    backend: S,
223
224    virtio_features: u64,
225    acked_virtio_features: u64,
226    protocol_features: VhostUserProtocolFeatures,
227    acked_protocol_features: u64,
228
229    /// Sending ack for messages without payload.
230    reply_ack_enabled: bool,
231}
232
233impl<S: Backend> AsRef<S> for BackendServer<S> {
234    fn as_ref(&self) -> &S {
235        &self.backend
236    }
237}
238
239impl<S: Backend> BackendServer<S> {
240    pub fn new(connection: Connection<FrontendReq>, backend: S) -> Self {
241        BackendServer {
242            connection,
243            backend,
244            virtio_features: 0,
245            acked_virtio_features: 0,
246            protocol_features: VhostUserProtocolFeatures::empty(),
247            acked_protocol_features: 0,
248            reply_ack_enabled: false,
249        }
250    }
251
252    /// Receives and validates a vhost-user message header and optional files.
253    ///
254    /// Since the length of vhost-user messages are different among message types, regular
255    /// vhost-user messages are sent via an underlying communication channel in stream mode.
256    /// (e.g. `SOCK_STREAM` in UNIX)
257    /// So, the logic of receiving and handling a message consists of the following steps:
258    ///
259    /// 1. Receives a message header and optional attached file.
260    /// 2. Validates the message header.
261    /// 3. Check if optional payloads is expected.
262    /// 4. Wait for the optional payloads.
263    /// 5. Receives optional payloads.
264    /// 6. Processes the message.
265    ///
266    /// This method [`BackendServer::recv_header()`] is in charge of the step (1) and (2),
267    /// [`BackendServer::needs_wait_for_payload()`] is (3), and
268    /// [`BackendServer::process_message()`] is (5) and (6). We need to have the three method
269    /// separately for multi-platform supports; [`BackendServer::recv_header()`] and
270    /// [`BackendServer::process_message()`] need to be separated because the way of waiting for
271    /// incoming messages differs between Unix and Windows so it's the caller's responsibility to
272    /// wait before [`BackendServer::process_message()`].
273    ///
274    /// Note that some vhost-user protocol variant such as VVU doesn't assume stream mode. In this
275    /// case, a message header and its body are sent together so the step (4) is skipped. We handle
276    /// this case in [`BackendServer::needs_wait_for_payload()`].
277    ///
278    /// The following pseudo code describes how a caller should process incoming vhost-user
279    /// messages:
280    /// ```ignore
281    /// loop {
282    ///   // block until a message header comes.
283    ///   // The actual code differs, depending on platforms.
284    ///   connection.wait_readable().unwrap();
285    ///
286    ///   // (1) and (2)
287    ///   let (hdr, files) = backend_server.recv_header();
288    ///
289    ///   // (3)
290    ///   if backend_server.needs_wait_for_payload(&hdr) {
291    ///     // (4) block until a payload comes if needed.
292    ///     connection.wait_readable().unwrap();
293    ///   }
294    ///
295    ///   // (5) and (6)
296    ///   backend_server.process_message(&hdr, &files).unwrap();
297    /// }
298    /// ```
299    pub fn recv_header(&mut self) -> Result<(VhostUserMsgHeader<FrontendReq>, Vec<File>)> {
300        // The underlying communication channel is a Unix domain socket in
301        // stream mode, and recvmsg() is a little tricky here. To successfully
302        // receive attached file descriptors, we need to receive messages and
303        // corresponding attached file descriptors in this way:
304        // . recv messsage header and optional attached file
305        // . validate message header
306        // . recv optional message body and payload according size field in
307        //   message header
308        // . validate message body and optional payload
309        let (hdr, files) = match self.connection.recv_header() {
310            Ok((hdr, files)) => (hdr, files),
311            Err(Error::Disconnect) => {
312                // If the client closed the connection before sending a header, this should be
313                // handled as a legal exit.
314                return Err(Error::ClientExit);
315            }
316            Err(e) => {
317                return Err(e);
318            }
319        };
320
321        if !hdr.is_valid() {
322            return Err(Error::InvalidMessage);
323        }
324
325        self.check_attached_files(&hdr, &files)?;
326
327        Ok((hdr, files))
328    }
329
330    /// Returns whether the caller needs to wait for the incoming message before calling
331    /// [`BackendServer::process_message`].
332    ///
333    /// See [`BackendServer::recv_header`]'s doc comment for the usage.
334    pub fn needs_wait_for_payload(&self, hdr: &VhostUserMsgHeader<FrontendReq>) -> bool {
335        // Since the vhost-user protocol uses stream mode, we need to wait until an additional
336        // payload is available if exists.
337        hdr.get_size() != 0
338    }
339
340    /// Main entrance to request from the communication channel.
341    ///
342    /// Receive and handle one incoming request message from the frontend.
343    /// See [`BackendServer::recv_header`]'s doc comment for the usage.
344    ///
345    /// # Return:
346    /// * `Ok(())`: one request was successfully handled.
347    /// * `Err(ClientExit)`: the frontend closed the connection properly. This isn't an actual
348    ///   failure.
349    /// * `Err(Disconnect)`: the connection was closed unexpectedly.
350    /// * `Err(InvalidMessage)`: the vmm sent a illegal message.
351    /// * other errors: failed to handle a request.
352    pub fn process_message(
353        &mut self,
354        hdr: VhostUserMsgHeader<FrontendReq>,
355        files: Vec<File>,
356    ) -> Result<()> {
357        let (buf, extra_files) = self.connection.recv_body_bytes(&hdr)?;
358        let size = buf.len();
359        if !extra_files.is_empty() {
360            return Err(Error::InvalidMessage);
361        }
362
363        // TODO: The error handling here is inconsistent. Sometimes we report the error to the
364        // client and keep going, sometimes we report the error and then close the connection,
365        // sometimes we just close the connection.
366        match hdr.get_code() {
367            Ok(FrontendReq::SET_OWNER) => {
368                self.check_request_size(&hdr, size, 0)?;
369                let res = self.backend.set_owner();
370                self.send_ack_message(&hdr, res.is_ok())?;
371                res?;
372            }
373            Ok(FrontendReq::RESET_OWNER) => {
374                self.check_request_size(&hdr, size, 0)?;
375                let res = self.backend.reset_owner();
376                self.send_ack_message(&hdr, res.is_ok())?;
377                res?;
378            }
379            Ok(FrontendReq::GET_FEATURES) => {
380                self.check_request_size(&hdr, size, 0)?;
381                let mut features = self.backend.get_features()?;
382
383                // Don't advertise packed queues even if the device does. We don't handle them
384                // properly yet at the protocol layer.
385                // TODO: b/331466964 - Remove once support is added.
386                features &= !(1 << VIRTIO_F_RING_PACKED);
387
388                let msg = VhostUserU64::new(features);
389                self.send_reply_message(&hdr, &msg)?;
390                self.virtio_features = features;
391                self.update_reply_ack_flag();
392            }
393            Ok(FrontendReq::SET_FEATURES) => {
394                let mut msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?;
395
396                // Don't allow packed queues even if the device does. We don't handle them
397                // properly yet at the protocol layer.
398                // TODO: b/331466964 - Remove once support is added.
399                msg.value &= !(1 << VIRTIO_F_RING_PACKED);
400
401                let res = self.backend.set_features(msg.value);
402                self.acked_virtio_features = msg.value;
403                self.update_reply_ack_flag();
404                self.send_ack_message(&hdr, res.is_ok())?;
405                res?;
406            }
407            Ok(FrontendReq::SET_MEM_TABLE) => {
408                let res = self.set_mem_table(&hdr, size, &buf, files);
409                self.send_ack_message(&hdr, res.is_ok())?;
410                res?;
411            }
412            Ok(FrontendReq::SET_VRING_NUM) => {
413                let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
414                let res = self.backend.set_vring_num(msg.index, msg.num);
415                self.send_ack_message(&hdr, res.is_ok())?;
416                res?;
417            }
418            Ok(FrontendReq::SET_VRING_ADDR) => {
419                let msg = self.extract_request_body::<VhostUserVringAddr>(&hdr, size, &buf)?;
420                let flags = match VhostUserVringAddrFlags::from_bits(msg.flags) {
421                    Some(val) => val,
422                    None => return Err(Error::InvalidMessage),
423                };
424                let res = self.backend.set_vring_addr(
425                    msg.index,
426                    flags,
427                    msg.descriptor,
428                    msg.used,
429                    msg.available,
430                    msg.log,
431                );
432                self.send_ack_message(&hdr, res.is_ok())?;
433                res?;
434            }
435            Ok(FrontendReq::SET_VRING_BASE) => {
436                let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
437                let res = self.backend.set_vring_base(msg.index, msg.num);
438                self.send_ack_message(&hdr, res.is_ok())?;
439                res?;
440            }
441            Ok(FrontendReq::GET_VRING_BASE) => {
442                let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
443                let reply = self.backend.get_vring_base(msg.index)?;
444                self.send_reply_message(&hdr, &reply)?;
445            }
446            Ok(FrontendReq::SET_VRING_CALL) => {
447                self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
448                let (index, file) = self.handle_vring_fd_request(&buf, files)?;
449                let res = self.backend.set_vring_call(index, file);
450                self.send_ack_message(&hdr, res.is_ok())?;
451                res?;
452            }
453            Ok(FrontendReq::SET_VRING_KICK) => {
454                self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
455                let (index, file) = self.handle_vring_fd_request(&buf, files)?;
456                let res = self.backend.set_vring_kick(index, file);
457                self.send_ack_message(&hdr, res.is_ok())?;
458                res?;
459            }
460            Ok(FrontendReq::SET_VRING_ERR) => {
461                self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
462                let (index, file) = self.handle_vring_fd_request(&buf, files)?;
463                let res = self.backend.set_vring_err(index, file);
464                self.send_ack_message(&hdr, res.is_ok())?;
465                res?;
466            }
467            Ok(FrontendReq::GET_PROTOCOL_FEATURES) => {
468                self.check_request_size(&hdr, size, 0)?;
469                let features = self.backend.get_protocol_features()?;
470                let msg = VhostUserU64::new(features.bits());
471                self.send_reply_message(&hdr, &msg)?;
472                self.protocol_features = features;
473                self.update_reply_ack_flag();
474            }
475            Ok(FrontendReq::SET_PROTOCOL_FEATURES) => {
476                let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?;
477                let res = self.backend.set_protocol_features(msg.value);
478                self.acked_protocol_features = msg.value;
479                self.update_reply_ack_flag();
480                self.send_ack_message(&hdr, res.is_ok())?;
481                res?;
482            }
483            Ok(FrontendReq::GET_QUEUE_NUM) => {
484                if self.acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() == 0 {
485                    return Err(Error::InvalidOperation);
486                }
487                self.check_request_size(&hdr, size, 0)?;
488                let num = self.backend.get_queue_num()?;
489                let msg = VhostUserU64::new(num);
490                self.send_reply_message(&hdr, &msg)?;
491            }
492            Ok(FrontendReq::SET_VRING_ENABLE) => {
493                let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
494                if self.acked_virtio_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES == 0 {
495                    return Err(Error::InvalidOperation);
496                }
497                let enable = match msg.num {
498                    1 => true,
499                    0 => false,
500                    _ => {
501                        return Err(Error::InvalidParam(
502                            "SET_VRING_ENABLE: num out of range (must be [0, 1])",
503                        ))
504                    }
505                };
506
507                let res = self.backend.set_vring_enable(msg.index, enable);
508                self.send_ack_message(&hdr, res.is_ok())?;
509                res?;
510            }
511            Ok(FrontendReq::GET_CONFIG) => {
512                if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
513                    return Err(Error::InvalidOperation);
514                }
515                self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
516                self.get_config(&hdr, &buf)?;
517            }
518            Ok(FrontendReq::SET_CONFIG) => {
519                if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
520                    return Err(Error::InvalidOperation);
521                }
522                self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
523                let res = self.set_config(&buf);
524                self.send_ack_message(&hdr, res.is_ok())?;
525                res?;
526            }
527            Ok(FrontendReq::SET_BACKEND_REQ_FD) => {
528                if self.acked_protocol_features & VhostUserProtocolFeatures::BACKEND_REQ.bits() == 0
529                {
530                    return Err(Error::InvalidOperation);
531                }
532                self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
533                let res = self.set_backend_req_fd(files);
534                self.send_ack_message(&hdr, res.is_ok())?;
535                res?;
536            }
537            Ok(FrontendReq::GET_INFLIGHT_FD) => {
538                if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits()
539                    == 0
540                {
541                    return Err(Error::InvalidOperation);
542                }
543
544                let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?;
545                let (inflight, file) = self.backend.get_inflight_fd(&msg)?;
546                let reply_hdr = self.new_reply_header::<VhostUserInflight>(&hdr, 0)?;
547                self.connection.send_message(
548                    &reply_hdr,
549                    &inflight,
550                    Some(&[file.as_raw_descriptor()]),
551                )?;
552            }
553            Ok(FrontendReq::SET_INFLIGHT_FD) => {
554                if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits()
555                    == 0
556                {
557                    return Err(Error::InvalidOperation);
558                }
559                let file = into_single_file(files).ok_or(Error::IncorrectFds)?;
560                let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?;
561                let res = self.backend.set_inflight_fd(&msg, file);
562                self.send_ack_message(&hdr, res.is_ok())?;
563                res?;
564            }
565            Ok(FrontendReq::GET_MAX_MEM_SLOTS) => {
566                if self.acked_protocol_features
567                    & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
568                    == 0
569                {
570                    return Err(Error::InvalidOperation);
571                }
572                self.check_request_size(&hdr, size, 0)?;
573                let num = self.backend.get_max_mem_slots()?;
574                let msg = VhostUserU64::new(num);
575                self.send_reply_message(&hdr, &msg)?;
576            }
577            Ok(FrontendReq::ADD_MEM_REG) => {
578                if self.acked_protocol_features
579                    & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
580                    == 0
581                {
582                    return Err(Error::InvalidOperation);
583                }
584                let file = into_single_file(files).ok_or(Error::InvalidParam(
585                    "ADD_MEM_REG: exactly one file must be provided",
586                ))?;
587                let msg =
588                    self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?;
589                let res = self.backend.add_mem_region(&msg, file);
590                self.send_ack_message(&hdr, res.is_ok())?;
591                res?;
592            }
593            Ok(FrontendReq::REM_MEM_REG) => {
594                if self.acked_protocol_features
595                    & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
596                    == 0
597                {
598                    return Err(Error::InvalidOperation);
599                }
600
601                let msg =
602                    self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?;
603                let res = self.backend.remove_mem_region(&msg);
604                self.send_ack_message(&hdr, res.is_ok())?;
605                res?;
606            }
607            Ok(FrontendReq::SET_DEVICE_STATE_FD) => {
608                if self.acked_protocol_features & VhostUserProtocolFeatures::DEVICE_STATE.bits()
609                    == 0
610                {
611                    return Err(Error::InvalidOperation);
612                }
613                // Read request.
614                let msg =
615                    self.extract_request_body::<DeviceStateTransferParameters>(&hdr, size, &buf)?;
616                let transfer_direction = match msg.transfer_direction {
617                    0 => VhostUserTransferDirection::Save,
618                    1 => VhostUserTransferDirection::Load,
619                    _ => return Err(Error::InvalidMessage),
620                };
621                let migration_phase = match msg.migration_phase {
622                    0 => VhostUserMigrationPhase::Stopped,
623                    _ => return Err(Error::InvalidMessage),
624                };
625                // Call backend.
626                let res = self.backend.set_device_state_fd(
627                    transfer_direction,
628                    migration_phase,
629                    files.into_iter().next().ok_or(Error::IncorrectFds)?,
630                );
631                // Send response.
632                let (msg, fds) = match &res {
633                    Ok(None) => (VhostUserU64::new(0x100), None),
634                    Ok(Some(file)) => (VhostUserU64::new(0), Some(file.as_raw_descriptor())),
635                    // Just in case, set the "invalid FD" flag on error.
636                    Err(_) => (VhostUserU64::new(0x101), None),
637                };
638                let reply_hdr: VhostUserMsgHeader<FrontendReq> =
639                    self.new_reply_header::<VhostUserU64>(&hdr, 0)?;
640                self.connection.send_message(
641                    &reply_hdr,
642                    &msg,
643                    fds.as_ref().map(std::slice::from_ref),
644                )?;
645                res?;
646            }
647            Ok(FrontendReq::CHECK_DEVICE_STATE) => {
648                if self.acked_protocol_features & VhostUserProtocolFeatures::DEVICE_STATE.bits()
649                    == 0
650                {
651                    return Err(Error::InvalidOperation);
652                }
653                let res = self.backend.check_device_state();
654                let msg = VhostUserU64::new(if res.is_ok() { 0 } else { 1 });
655                self.send_reply_message(&hdr, &msg)?;
656                res?;
657            }
658            Ok(FrontendReq::GET_SHMEM_CONFIG) => {
659                let (msg, sizes) = self.backend.get_shmem_config()?;
660                let mut buf = Vec::new();
661                for e in sizes {
662                    buf.extend_from_slice(e.as_bytes())
663                }
664                self.send_reply_with_payload(&hdr, &msg, buf.as_slice())?;
665            }
666            _ => {
667                return Err(Error::InvalidMessage);
668            }
669        }
670        Ok(())
671    }
672
673    fn new_reply_header<T: Sized>(
674        &self,
675        req: &VhostUserMsgHeader<FrontendReq>,
676        payload_size: usize,
677    ) -> Result<VhostUserMsgHeader<FrontendReq>> {
678        Ok(VhostUserMsgHeader::new(
679            req.get_code().map_err(|_| Error::InvalidMessage)?,
680            VhostUserHeaderFlag::REPLY.bits(),
681            (mem::size_of::<T>()
682                .checked_add(payload_size)
683                .ok_or(Error::OversizedMsg)?)
684            .try_into()
685            .map_err(Error::InvalidCastToInt)?,
686        ))
687    }
688
689    /// Sends reply back to Vhost frontend in response to a message.
690    fn send_ack_message(
691        &mut self,
692        req: &VhostUserMsgHeader<FrontendReq>,
693        success: bool,
694    ) -> Result<()> {
695        if self.reply_ack_enabled && req.is_need_reply() {
696            let hdr: VhostUserMsgHeader<FrontendReq> =
697                self.new_reply_header::<VhostUserU64>(req, 0)?;
698            let val = if success { 0 } else { 1 };
699            let msg = VhostUserU64::new(val);
700            self.connection.send_message(&hdr, &msg, None)?;
701        }
702        Ok(())
703    }
704
705    fn send_reply_message<T: IntoBytes + Immutable>(
706        &mut self,
707        req: &VhostUserMsgHeader<FrontendReq>,
708        msg: &T,
709    ) -> Result<()> {
710        let hdr = self.new_reply_header::<T>(req, 0)?;
711        self.connection.send_message(&hdr, msg, None)?;
712        Ok(())
713    }
714
715    fn send_reply_with_payload<T: IntoBytes + Immutable>(
716        &mut self,
717        req: &VhostUserMsgHeader<FrontendReq>,
718        msg: &T,
719        payload: &[u8],
720    ) -> Result<()> {
721        let hdr = self.new_reply_header::<T>(req, payload.len())?;
722        self.connection
723            .send_message_with_payload(&hdr, msg, payload, None)?;
724        Ok(())
725    }
726
727    fn set_mem_table(
728        &mut self,
729        hdr: &VhostUserMsgHeader<FrontendReq>,
730        size: usize,
731        buf: &[u8],
732        files: Vec<File>,
733    ) -> Result<()> {
734        self.check_request_size(hdr, size, hdr.get_size() as usize)?;
735
736        let (msg, regions) =
737            Ref::<_, VhostUserMemory>::from_prefix(buf).map_err(|_| Error::InvalidMessage)?;
738        if !msg.is_valid() {
739            return Err(Error::InvalidMessage);
740        }
741
742        // validate number of fds matching number of memory regions
743        if files.len() != msg.num_regions as usize {
744            return Err(Error::InvalidMessage);
745        }
746
747        let (regions, excess) = Ref::<_, [VhostUserMemoryRegion]>::from_prefix_with_elems(
748            regions,
749            msg.num_regions as usize,
750        )
751        .map_err(|_| Error::InvalidMessage)?;
752        if !excess.is_empty() {
753            return Err(Error::InvalidMessage);
754        }
755
756        // Validate memory regions
757        for region in regions.iter() {
758            if !region.is_valid() {
759                return Err(Error::InvalidMessage);
760            }
761        }
762
763        self.backend.set_mem_table(&regions, files)
764    }
765
766    fn get_config(&mut self, hdr: &VhostUserMsgHeader<FrontendReq>, buf: &[u8]) -> Result<()> {
767        let (msg, payload) =
768            Ref::<_, VhostUserConfig>::from_prefix(buf).map_err(|_| Error::InvalidMessage)?;
769        if !msg.is_valid() {
770            return Err(Error::InvalidMessage);
771        }
772        if payload.len() != msg.size as usize {
773            return Err(Error::InvalidMessage);
774        }
775        let flags = match VhostUserConfigFlags::from_bits(msg.flags) {
776            Some(val) => val,
777            None => return Err(Error::InvalidMessage),
778        };
779        let res = self.backend.get_config(msg.offset, msg.size, flags);
780
781        // The response payload size MUST match the request payload size on success. A zero length
782        // response is used to indicate an error.
783        match res {
784            Ok(ref buf) if buf.len() == msg.size as usize => {
785                let reply = VhostUserConfig::new(msg.offset, buf.len() as u32, flags);
786                self.send_reply_with_payload(hdr, &reply, buf.as_slice())?;
787            }
788            Ok(_) => {
789                let reply = VhostUserConfig::new(msg.offset, 0, flags);
790                self.send_reply_message(hdr, &reply)?;
791            }
792            Err(_) => {
793                let reply = VhostUserConfig::new(msg.offset, 0, flags);
794                self.send_reply_message(hdr, &reply)?;
795            }
796        }
797        Ok(())
798    }
799
800    fn set_config(&mut self, buf: &[u8]) -> Result<()> {
801        let (msg, payload) =
802            Ref::<_, VhostUserConfig>::from_prefix(buf).map_err(|_| Error::InvalidMessage)?;
803        if !msg.is_valid() {
804            return Err(Error::InvalidMessage);
805        }
806        if payload.len() != msg.size as usize {
807            return Err(Error::InvalidMessage);
808        }
809        let flags: VhostUserConfigFlags = match VhostUserConfigFlags::from_bits(msg.flags) {
810            Some(val) => val,
811            None => return Err(Error::InvalidMessage),
812        };
813
814        self.backend.set_config(msg.offset, payload, flags)
815    }
816
817    fn set_backend_req_fd(&mut self, files: Vec<File>) -> Result<()> {
818        let file = into_single_file(files).ok_or(Error::InvalidMessage)?;
819        let fd: SafeDescriptor = file.into();
820        let connection = Connection::try_from(fd).map_err(|_| Error::InvalidMessage)?;
821        self.backend.set_backend_req_fd(connection);
822        Ok(())
823    }
824
825    /// Parses an incoming |SET_VRING_KICK| or |SET_VRING_CALL| message into a
826    /// Vring number and an fd.
827    fn handle_vring_fd_request(
828        &mut self,
829        buf: &[u8],
830        files: Vec<File>,
831    ) -> Result<(u8, Option<File>)> {
832        let (msg, _) = VhostUserU64::read_from_prefix(buf).map_err(|_| Error::InvalidMessage)?;
833        if !msg.is_valid() {
834            return Err(Error::InvalidMessage);
835        }
836
837        // Bits (0-7) of the payload contain the vring index. Bit 8 is the
838        // invalid FD flag (VHOST_USER_VRING_NOFD_MASK).
839        // This bit is set when there is no file descriptor
840        // in the ancillary data. This signals that polling will be used
841        // instead of waiting for the call.
842        // If Bit 8 is unset, the data must contain a file descriptor.
843        let has_fd = (msg.value & 0x100u64) == 0;
844
845        let file = into_single_file(files);
846
847        if has_fd && file.is_none() || !has_fd && file.is_some() {
848            return Err(Error::InvalidMessage);
849        }
850
851        Ok((msg.value as u8, file))
852    }
853
854    fn check_request_size(
855        &self,
856        hdr: &VhostUserMsgHeader<FrontendReq>,
857        size: usize,
858        expected: usize,
859    ) -> Result<()> {
860        if hdr.get_size() as usize != expected
861            || hdr.is_reply()
862            || hdr.get_version() != 0x1
863            || size != expected
864        {
865            return Err(Error::InvalidMessage);
866        }
867        Ok(())
868    }
869
870    fn check_attached_files(
871        &self,
872        hdr: &VhostUserMsgHeader<FrontendReq>,
873        files: &[File],
874    ) -> Result<()> {
875        match hdr.get_code() {
876            Ok(FrontendReq::SET_MEM_TABLE)
877            | Ok(FrontendReq::SET_VRING_CALL)
878            | Ok(FrontendReq::SET_VRING_KICK)
879            | Ok(FrontendReq::SET_VRING_ERR)
880            | Ok(FrontendReq::SET_LOG_BASE)
881            | Ok(FrontendReq::SET_LOG_FD)
882            | Ok(FrontendReq::SET_BACKEND_REQ_FD)
883            | Ok(FrontendReq::SET_INFLIGHT_FD)
884            | Ok(FrontendReq::ADD_MEM_REG)
885            | Ok(FrontendReq::SET_DEVICE_STATE_FD) => Ok(()),
886            Err(_) => Err(Error::InvalidMessage),
887            _ if !files.is_empty() => Err(Error::InvalidMessage),
888            _ => Ok(()),
889        }
890    }
891
892    fn extract_request_body<T: Sized + FromBytes + VhostUserMsgValidator>(
893        &self,
894        hdr: &VhostUserMsgHeader<FrontendReq>,
895        size: usize,
896        buf: &[u8],
897    ) -> Result<T> {
898        self.check_request_size(hdr, size, mem::size_of::<T>())?;
899        let (body, _) = T::read_from_prefix(buf).map_err(|_| Error::InvalidMessage)?;
900        if body.is_valid() {
901            Ok(body)
902        } else {
903            Err(Error::InvalidMessage)
904        }
905    }
906
907    fn update_reply_ack_flag(&mut self) {
908        let pflag = VhostUserProtocolFeatures::REPLY_ACK;
909        self.reply_ack_enabled = (self.virtio_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES) != 0
910            && self.protocol_features.contains(pflag)
911            && (self.acked_protocol_features & pflag.bits()) != 0;
912    }
913}
914
915impl<S: Backend> AsRawDescriptor for BackendServer<S> {
916    fn as_raw_descriptor(&self) -> RawDescriptor {
917        // TODO(b/221882601): figure out if this used for polling.
918        self.connection.as_raw_descriptor()
919    }
920}
921
#[cfg(test)]
mod tests {
    use base::INVALID_DESCRIPTOR;

    use super::*;
    use crate::test_backend::TestBackend;
    use crate::Connection;

    /// A freshly constructed `BackendServer` reports a valid (non-invalid) raw descriptor.
    #[test]
    fn test_backend_server_new() {
        let (p1, _p2) = Connection::pair().unwrap();
        let backend = TestBackend::new();
        let handler = BackendServer::new(p1, backend);

        // `assert_ne!` prints both descriptors if the check ever fails.
        assert_ne!(handler.as_raw_descriptor(), INVALID_DESCRIPTOR);
    }
}