vmm_vhost/
backend_server.rs

1// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::fs::File;
5use std::mem;
6
7use base::AsRawDescriptor;
8use base::RawDescriptor;
9use base::SafeDescriptor;
10use zerocopy::FromBytes;
11use zerocopy::Immutable;
12use zerocopy::IntoBytes;
13use zerocopy::Ref;
14
15use crate::into_single_file;
16use crate::message::*;
17use crate::BackendReq;
18use crate::Connection;
19use crate::Error;
20use crate::FrontendReq;
21use crate::Result;
22
23/// Trait for vhost-user backends.
24///
25/// Each method corresponds to a vhost-user protocol method. See the specification for details.
26#[allow(missing_docs)]
27pub trait Backend {
28    fn set_owner(&mut self) -> Result<()>;
29    fn reset_owner(&mut self) -> Result<()>;
30    fn get_features(&mut self) -> Result<u64>;
31    fn set_features(&mut self, features: u64) -> Result<()>;
32    fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>;
33    fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()>;
34    fn set_vring_addr(
35        &mut self,
36        index: u32,
37        flags: VhostUserVringAddrFlags,
38        descriptor: u64,
39        used: u64,
40        available: u64,
41        log: u64,
42    ) -> Result<()>;
43    // TODO: b/331466964 - Argument type is wrong for packed queues.
44    fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()>;
45    // TODO: b/331466964 - Return type is wrong for packed queues.
46    fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState>;
47    fn set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()>;
48    fn set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()>;
49    fn set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()>;
50
51    fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>;
52    fn set_protocol_features(&mut self, features: u64) -> Result<()>;
53    fn get_queue_num(&mut self) -> Result<u64>;
54    fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()>;
55    fn get_config(
56        &mut self,
57        offset: u32,
58        size: u32,
59        flags: VhostUserConfigFlags,
60    ) -> Result<Vec<u8>>;
61    fn set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>;
62    fn set_backend_req_fd(&mut self, _vu_req: Connection<BackendReq>) {}
63    fn get_inflight_fd(
64        &mut self,
65        inflight: &VhostUserInflight,
66    ) -> Result<(VhostUserInflight, File)>;
67    fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()>;
68    fn get_max_mem_slots(&mut self) -> Result<u64>;
69    fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>;
70    fn remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()>;
71    fn set_device_state_fd(
72        &mut self,
73        transfer_direction: VhostUserTransferDirection,
74        migration_phase: VhostUserMigrationPhase,
75        fd: File,
76    ) -> Result<Option<File>>;
77    fn check_device_state(&mut self) -> Result<()>;
78    fn get_shared_memory_regions(&mut self) -> Result<Vec<VhostSharedMemoryRegion>>;
79}
80
81impl<T> Backend for T
82where
83    T: AsMut<dyn Backend>,
84{
85    fn set_owner(&mut self) -> Result<()> {
86        self.as_mut().set_owner()
87    }
88
89    fn reset_owner(&mut self) -> Result<()> {
90        self.as_mut().reset_owner()
91    }
92
93    fn get_features(&mut self) -> Result<u64> {
94        self.as_mut().get_features()
95    }
96
97    fn set_features(&mut self, features: u64) -> Result<()> {
98        self.as_mut().set_features(features)
99    }
100
101    fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()> {
102        self.as_mut().set_mem_table(ctx, files)
103    }
104
105    fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()> {
106        self.as_mut().set_vring_num(index, num)
107    }
108
109    fn set_vring_addr(
110        &mut self,
111        index: u32,
112        flags: VhostUserVringAddrFlags,
113        descriptor: u64,
114        used: u64,
115        available: u64,
116        log: u64,
117    ) -> Result<()> {
118        self.as_mut()
119            .set_vring_addr(index, flags, descriptor, used, available, log)
120    }
121
122    fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()> {
123        self.as_mut().set_vring_base(index, base)
124    }
125
126    fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState> {
127        self.as_mut().get_vring_base(index)
128    }
129
130    fn set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()> {
131        self.as_mut().set_vring_kick(index, fd)
132    }
133
134    fn set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()> {
135        self.as_mut().set_vring_call(index, fd)
136    }
137
138    fn set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()> {
139        self.as_mut().set_vring_err(index, fd)
140    }
141
142    fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures> {
143        self.as_mut().get_protocol_features()
144    }
145
146    fn set_protocol_features(&mut self, features: u64) -> Result<()> {
147        self.as_mut().set_protocol_features(features)
148    }
149
150    fn get_queue_num(&mut self) -> Result<u64> {
151        self.as_mut().get_queue_num()
152    }
153
154    fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()> {
155        self.as_mut().set_vring_enable(index, enable)
156    }
157
158    fn get_config(
159        &mut self,
160        offset: u32,
161        size: u32,
162        flags: VhostUserConfigFlags,
163    ) -> Result<Vec<u8>> {
164        self.as_mut().get_config(offset, size, flags)
165    }
166
167    fn set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()> {
168        self.as_mut().set_config(offset, buf, flags)
169    }
170
171    fn set_backend_req_fd(&mut self, vu_req: Connection<BackendReq>) {
172        self.as_mut().set_backend_req_fd(vu_req)
173    }
174
175    fn get_inflight_fd(
176        &mut self,
177        inflight: &VhostUserInflight,
178    ) -> Result<(VhostUserInflight, File)> {
179        self.as_mut().get_inflight_fd(inflight)
180    }
181
182    fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()> {
183        self.as_mut().set_inflight_fd(inflight, file)
184    }
185
186    fn get_max_mem_slots(&mut self) -> Result<u64> {
187        self.as_mut().get_max_mem_slots()
188    }
189
190    fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()> {
191        self.as_mut().add_mem_region(region, fd)
192    }
193
194    fn remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()> {
195        self.as_mut().remove_mem_region(region)
196    }
197
198    fn set_device_state_fd(
199        &mut self,
200        transfer_direction: VhostUserTransferDirection,
201        migration_phase: VhostUserMigrationPhase,
202        fd: File,
203    ) -> Result<Option<File>> {
204        self.as_mut()
205            .set_device_state_fd(transfer_direction, migration_phase, fd)
206    }
207
208    fn check_device_state(&mut self) -> Result<()> {
209        self.as_mut().check_device_state()
210    }
211
212    fn get_shared_memory_regions(&mut self) -> Result<Vec<VhostSharedMemoryRegion>> {
213        self.as_mut().get_shared_memory_regions()
214    }
215}
216
/// Handles requests from a vhost-user connection by dispatching them to [`Backend`] methods.
pub struct BackendServer<S: Backend> {
    /// Underlying connection for communication.
    connection: Connection<FrontendReq>,
    /// The vhost-user backend device object that decoded requests are dispatched to.
    backend: S,

    /// Virtio feature bits last sent in reply to `GET_FEATURES`, with `VIRTIO_F_RING_PACKED`
    /// cleared.
    virtio_features: u64,
    /// Virtio feature bits last received in `SET_FEATURES` (with `VIRTIO_F_RING_PACKED` cleared).
    acked_virtio_features: u64,
    /// Protocol features last sent in reply to `GET_PROTOCOL_FEATURES`.
    protocol_features: VhostUserProtocolFeatures,
    /// Protocol feature bits last received in `SET_PROTOCOL_FEATURES`; these gate the optional
    /// requests handled by `process_message`.
    acked_protocol_features: u64,

    /// Sending ack for messages without payload: when set, `send_ack_message` replies to
    /// requests that carry the need-reply flag.
    reply_ack_enabled: bool,
}
232
233impl<S: Backend> AsRef<S> for BackendServer<S> {
234    fn as_ref(&self) -> &S {
235        &self.backend
236    }
237}
238
239impl<S: Backend> BackendServer<S> {
240    pub fn new(connection: Connection<FrontendReq>, backend: S) -> Self {
241        BackendServer {
242            connection,
243            backend,
244            virtio_features: 0,
245            acked_virtio_features: 0,
246            protocol_features: VhostUserProtocolFeatures::empty(),
247            acked_protocol_features: 0,
248            reply_ack_enabled: false,
249        }
250    }
251
252    /// Receives and validates a vhost-user message header and optional files.
253    ///
254    /// Since the length of vhost-user messages are different among message types, regular
255    /// vhost-user messages are sent via an underlying communication channel in stream mode.
256    /// (e.g. `SOCK_STREAM` in UNIX)
257    /// So, the logic of receiving and handling a message consists of the following steps:
258    ///
259    /// 1. Receives a message header and optional attached file.
260    /// 2. Validates the message header.
261    /// 3. Check if optional payloads is expected.
262    /// 4. Wait for the optional payloads.
263    /// 5. Receives optional payloads.
264    /// 6. Processes the message.
265    ///
266    /// This method [`BackendServer::recv_header()`] is in charge of the step (1) and (2),
267    /// [`BackendServer::needs_wait_for_payload()`] is (3), and
268    /// [`BackendServer::process_message()`] is (5) and (6). We need to have the three method
269    /// separately for multi-platform supports; [`BackendServer::recv_header()`] and
270    /// [`BackendServer::process_message()`] need to be separated because the way of waiting for
271    /// incoming messages differs between Unix and Windows so it's the caller's responsibility to
272    /// wait before [`BackendServer::process_message()`].
273    ///
274    /// Note that some vhost-user protocol variant such as VVU doesn't assume stream mode. In this
275    /// case, a message header and its body are sent together so the step (4) is skipped. We handle
276    /// this case in [`BackendServer::needs_wait_for_payload()`].
277    ///
278    /// The following pseudo code describes how a caller should process incoming vhost-user
279    /// messages:
280    /// ```ignore
281    /// loop {
282    ///   // block until a message header comes.
283    ///   // The actual code differs, depending on platforms.
284    ///   connection.wait_readable().unwrap();
285    ///
286    ///   // (1) and (2)
287    ///   let (hdr, files) = backend_server.recv_header();
288    ///
289    ///   // (3)
290    ///   if backend_server.needs_wait_for_payload(&hdr) {
291    ///     // (4) block until a payload comes if needed.
292    ///     connection.wait_readable().unwrap();
293    ///   }
294    ///
295    ///   // (5) and (6)
296    ///   backend_server.process_message(&hdr, &files).unwrap();
297    /// }
298    /// ```
299    pub fn recv_header(&mut self) -> Result<(VhostUserMsgHeader<FrontendReq>, Vec<File>)> {
300        // The underlying communication channel is a Unix domain socket in
301        // stream mode, and recvmsg() is a little tricky here. To successfully
302        // receive attached file descriptors, we need to receive messages and
303        // corresponding attached file descriptors in this way:
304        // . recv messsage header and optional attached file
305        // . validate message header
306        // . recv optional message body and payload according size field in
307        //   message header
308        // . validate message body and optional payload
309        let (hdr, files) = match self.connection.recv_header() {
310            Ok((hdr, files)) => (hdr, files),
311            Err(Error::Disconnect) => {
312                // If the client closed the connection before sending a header, this should be
313                // handled as a legal exit.
314                return Err(Error::ClientExit);
315            }
316            Err(e) => {
317                return Err(e);
318            }
319        };
320
321        if !hdr.is_valid() {
322            return Err(Error::InvalidMessage);
323        }
324
325        self.check_attached_files(&hdr, &files)?;
326
327        Ok((hdr, files))
328    }
329
330    /// Returns whether the caller needs to wait for the incoming message before calling
331    /// [`BackendServer::process_message`].
332    ///
333    /// See [`BackendServer::recv_header`]'s doc comment for the usage.
334    pub fn needs_wait_for_payload(&self, hdr: &VhostUserMsgHeader<FrontendReq>) -> bool {
335        // Since the vhost-user protocol uses stream mode, we need to wait until an additional
336        // payload is available if exists.
337        hdr.get_size() != 0
338    }
339
340    /// Main entrance to request from the communication channel.
341    ///
342    /// Receive and handle one incoming request message from the frontend.
343    /// See [`BackendServer::recv_header`]'s doc comment for the usage.
344    ///
345    /// # Return:
346    /// * `Ok(())`: one request was successfully handled.
347    /// * `Err(ClientExit)`: the frontend closed the connection properly. This isn't an actual
348    ///   failure.
349    /// * `Err(Disconnect)`: the connection was closed unexpectedly.
350    /// * `Err(InvalidMessage)`: the vmm sent a illegal message.
351    /// * other errors: failed to handle a request.
352    pub fn process_message(
353        &mut self,
354        hdr: VhostUserMsgHeader<FrontendReq>,
355        files: Vec<File>,
356    ) -> Result<()> {
357        let (buf, extra_files) = self.connection.recv_body_bytes(&hdr)?;
358        let size = buf.len();
359        if !extra_files.is_empty() {
360            return Err(Error::InvalidMessage);
361        }
362
363        // TODO: The error handling here is inconsistent. Sometimes we report the error to the
364        // client and keep going, sometimes we report the error and then close the connection,
365        // sometimes we just close the connection.
366        match hdr.get_code() {
367            Ok(FrontendReq::SET_OWNER) => {
368                self.check_request_size(&hdr, size, 0)?;
369                let res = self.backend.set_owner();
370                self.send_ack_message(&hdr, res.is_ok())?;
371                res?;
372            }
373            Ok(FrontendReq::RESET_OWNER) => {
374                self.check_request_size(&hdr, size, 0)?;
375                let res = self.backend.reset_owner();
376                self.send_ack_message(&hdr, res.is_ok())?;
377                res?;
378            }
379            Ok(FrontendReq::GET_FEATURES) => {
380                self.check_request_size(&hdr, size, 0)?;
381                let mut features = self.backend.get_features()?;
382
383                // Don't advertise packed queues even if the device does. We don't handle them
384                // properly yet at the protocol layer.
385                // TODO: b/331466964 - Remove once support is added.
386                features &= !(1 << VIRTIO_F_RING_PACKED);
387
388                let msg = VhostUserU64::new(features);
389                self.send_reply_message(&hdr, &msg)?;
390                self.virtio_features = features;
391                self.update_reply_ack_flag();
392            }
393            Ok(FrontendReq::SET_FEATURES) => {
394                let mut msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?;
395
396                // Don't allow packed queues even if the device does. We don't handle them
397                // properly yet at the protocol layer.
398                // TODO: b/331466964 - Remove once support is added.
399                msg.value &= !(1 << VIRTIO_F_RING_PACKED);
400
401                let res = self.backend.set_features(msg.value);
402                self.acked_virtio_features = msg.value;
403                self.update_reply_ack_flag();
404                self.send_ack_message(&hdr, res.is_ok())?;
405                res?;
406            }
407            Ok(FrontendReq::SET_MEM_TABLE) => {
408                let res = self.set_mem_table(&hdr, size, &buf, files);
409                self.send_ack_message(&hdr, res.is_ok())?;
410                res?;
411            }
412            Ok(FrontendReq::SET_VRING_NUM) => {
413                let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
414                let res = self.backend.set_vring_num(msg.index, msg.num);
415                self.send_ack_message(&hdr, res.is_ok())?;
416                res?;
417            }
418            Ok(FrontendReq::SET_VRING_ADDR) => {
419                let msg = self.extract_request_body::<VhostUserVringAddr>(&hdr, size, &buf)?;
420                let flags = match VhostUserVringAddrFlags::from_bits(msg.flags) {
421                    Some(val) => val,
422                    None => return Err(Error::InvalidMessage),
423                };
424                let res = self.backend.set_vring_addr(
425                    msg.index,
426                    flags,
427                    msg.descriptor,
428                    msg.used,
429                    msg.available,
430                    msg.log,
431                );
432                self.send_ack_message(&hdr, res.is_ok())?;
433                res?;
434            }
435            Ok(FrontendReq::SET_VRING_BASE) => {
436                let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
437                let res = self.backend.set_vring_base(msg.index, msg.num);
438                self.send_ack_message(&hdr, res.is_ok())?;
439                res?;
440            }
441            Ok(FrontendReq::GET_VRING_BASE) => {
442                let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
443                let reply = self.backend.get_vring_base(msg.index)?;
444                self.send_reply_message(&hdr, &reply)?;
445            }
446            Ok(FrontendReq::SET_VRING_CALL) => {
447                self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
448                let (index, file) = self.handle_vring_fd_request(&buf, files)?;
449                let res = self.backend.set_vring_call(index, file);
450                self.send_ack_message(&hdr, res.is_ok())?;
451                res?;
452            }
453            Ok(FrontendReq::SET_VRING_KICK) => {
454                self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
455                let (index, file) = self.handle_vring_fd_request(&buf, files)?;
456                let res = self.backend.set_vring_kick(index, file);
457                self.send_ack_message(&hdr, res.is_ok())?;
458                res?;
459            }
460            Ok(FrontendReq::SET_VRING_ERR) => {
461                self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
462                let (index, file) = self.handle_vring_fd_request(&buf, files)?;
463                let res = self.backend.set_vring_err(index, file);
464                self.send_ack_message(&hdr, res.is_ok())?;
465                res?;
466            }
467            Ok(FrontendReq::GET_PROTOCOL_FEATURES) => {
468                self.check_request_size(&hdr, size, 0)?;
469                let features = self.backend.get_protocol_features()?;
470                let msg = VhostUserU64::new(features.bits());
471                self.send_reply_message(&hdr, &msg)?;
472                self.protocol_features = features;
473                self.update_reply_ack_flag();
474            }
475            Ok(FrontendReq::SET_PROTOCOL_FEATURES) => {
476                let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?;
477                let res = self.backend.set_protocol_features(msg.value);
478                self.acked_protocol_features = msg.value;
479                self.update_reply_ack_flag();
480                self.send_ack_message(&hdr, res.is_ok())?;
481                res?;
482            }
483            Ok(FrontendReq::GET_QUEUE_NUM) => {
484                if self.acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() == 0 {
485                    return Err(Error::InvalidOperation);
486                }
487                self.check_request_size(&hdr, size, 0)?;
488                let num = self.backend.get_queue_num()?;
489                let msg = VhostUserU64::new(num);
490                self.send_reply_message(&hdr, &msg)?;
491            }
492            Ok(FrontendReq::SET_VRING_ENABLE) => {
493                let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
494                if self.acked_virtio_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES == 0 {
495                    return Err(Error::InvalidOperation);
496                }
497                let enable = match msg.num {
498                    1 => true,
499                    0 => false,
500                    _ => {
501                        return Err(Error::InvalidParam(
502                            "SET_VRING_ENABLE: num out of range (must be [0, 1])",
503                        ))
504                    }
505                };
506
507                let res = self.backend.set_vring_enable(msg.index, enable);
508                self.send_ack_message(&hdr, res.is_ok())?;
509                res?;
510            }
511            Ok(FrontendReq::GET_CONFIG) => {
512                if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
513                    return Err(Error::InvalidOperation);
514                }
515                self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
516                self.get_config(&hdr, &buf)?;
517            }
518            Ok(FrontendReq::SET_CONFIG) => {
519                if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
520                    return Err(Error::InvalidOperation);
521                }
522                self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
523                let res = self.set_config(&buf);
524                self.send_ack_message(&hdr, res.is_ok())?;
525                res?;
526            }
527            Ok(FrontendReq::SET_BACKEND_REQ_FD) => {
528                if self.acked_protocol_features & VhostUserProtocolFeatures::BACKEND_REQ.bits() == 0
529                {
530                    return Err(Error::InvalidOperation);
531                }
532                self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
533                let res = self.set_backend_req_fd(files);
534                self.send_ack_message(&hdr, res.is_ok())?;
535                res?;
536            }
537            Ok(FrontendReq::GET_INFLIGHT_FD) => {
538                if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits()
539                    == 0
540                {
541                    return Err(Error::InvalidOperation);
542                }
543
544                let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?;
545                let (inflight, file) = self.backend.get_inflight_fd(&msg)?;
546                let reply_hdr = self.new_reply_header::<VhostUserInflight>(&hdr, 0)?;
547                self.connection.send_message(
548                    &reply_hdr,
549                    &inflight,
550                    Some(&[file.as_raw_descriptor()]),
551                )?;
552            }
553            Ok(FrontendReq::SET_INFLIGHT_FD) => {
554                if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits()
555                    == 0
556                {
557                    return Err(Error::InvalidOperation);
558                }
559                let file = into_single_file(files).ok_or(Error::IncorrectFds)?;
560                let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?;
561                let res = self.backend.set_inflight_fd(&msg, file);
562                self.send_ack_message(&hdr, res.is_ok())?;
563                res?;
564            }
565            Ok(FrontendReq::GET_MAX_MEM_SLOTS) => {
566                if self.acked_protocol_features
567                    & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
568                    == 0
569                {
570                    return Err(Error::InvalidOperation);
571                }
572                self.check_request_size(&hdr, size, 0)?;
573                let num = self.backend.get_max_mem_slots()?;
574                let msg = VhostUserU64::new(num);
575                self.send_reply_message(&hdr, &msg)?;
576            }
577            Ok(FrontendReq::ADD_MEM_REG) => {
578                if self.acked_protocol_features
579                    & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
580                    == 0
581                {
582                    return Err(Error::InvalidOperation);
583                }
584                let file = into_single_file(files).ok_or(Error::InvalidParam(
585                    "ADD_MEM_REG: exactly one file must be provided",
586                ))?;
587                let msg =
588                    self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?;
589                let res = self.backend.add_mem_region(&msg, file);
590                self.send_ack_message(&hdr, res.is_ok())?;
591                res?;
592            }
593            Ok(FrontendReq::REM_MEM_REG) => {
594                if self.acked_protocol_features
595                    & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
596                    == 0
597                {
598                    return Err(Error::InvalidOperation);
599                }
600
601                let msg =
602                    self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?;
603                let res = self.backend.remove_mem_region(&msg);
604                self.send_ack_message(&hdr, res.is_ok())?;
605                res?;
606            }
607            Ok(FrontendReq::SET_DEVICE_STATE_FD) => {
608                if self.acked_protocol_features & VhostUserProtocolFeatures::DEVICE_STATE.bits()
609                    == 0
610                {
611                    return Err(Error::InvalidOperation);
612                }
613                // Read request.
614                let msg =
615                    self.extract_request_body::<DeviceStateTransferParameters>(&hdr, size, &buf)?;
616                let transfer_direction = match msg.transfer_direction {
617                    0 => VhostUserTransferDirection::Save,
618                    1 => VhostUserTransferDirection::Load,
619                    _ => return Err(Error::InvalidMessage),
620                };
621                let migration_phase = match msg.migration_phase {
622                    0 => VhostUserMigrationPhase::Stopped,
623                    _ => return Err(Error::InvalidMessage),
624                };
625                // Call backend.
626                let res = self.backend.set_device_state_fd(
627                    transfer_direction,
628                    migration_phase,
629                    files.into_iter().next().ok_or(Error::IncorrectFds)?,
630                );
631                // Send response.
632                let (msg, fds) = match &res {
633                    Ok(None) => (VhostUserU64::new(0x100), None),
634                    Ok(Some(file)) => (VhostUserU64::new(0), Some(file.as_raw_descriptor())),
635                    // Just in case, set the "invalid FD" flag on error.
636                    Err(_) => (VhostUserU64::new(0x101), None),
637                };
638                let reply_hdr: VhostUserMsgHeader<FrontendReq> =
639                    self.new_reply_header::<VhostUserU64>(&hdr, 0)?;
640                self.connection.send_message(
641                    &reply_hdr,
642                    &msg,
643                    fds.as_ref().map(std::slice::from_ref),
644                )?;
645                res?;
646            }
647            Ok(FrontendReq::CHECK_DEVICE_STATE) => {
648                if self.acked_protocol_features & VhostUserProtocolFeatures::DEVICE_STATE.bits()
649                    == 0
650                {
651                    return Err(Error::InvalidOperation);
652                }
653                let res = self.backend.check_device_state();
654                let msg = VhostUserU64::new(if res.is_ok() { 0 } else { 1 });
655                self.send_reply_message(&hdr, &msg)?;
656                res?;
657            }
658            Ok(FrontendReq::GET_SHARED_MEMORY_REGIONS) => {
659                let regions = self.backend.get_shared_memory_regions()?;
660                let mut buf = Vec::new();
661                let msg = VhostUserU64::new(regions.len() as u64);
662                for r in regions {
663                    buf.extend_from_slice(r.as_bytes())
664                }
665                self.send_reply_with_payload(&hdr, &msg, buf.as_slice())?;
666            }
667            _ => {
668                return Err(Error::InvalidMessage);
669            }
670        }
671        Ok(())
672    }
673
674    fn new_reply_header<T: Sized>(
675        &self,
676        req: &VhostUserMsgHeader<FrontendReq>,
677        payload_size: usize,
678    ) -> Result<VhostUserMsgHeader<FrontendReq>> {
679        Ok(VhostUserMsgHeader::new(
680            req.get_code().map_err(|_| Error::InvalidMessage)?,
681            VhostUserHeaderFlag::REPLY.bits(),
682            (mem::size_of::<T>()
683                .checked_add(payload_size)
684                .ok_or(Error::OversizedMsg)?)
685            .try_into()
686            .map_err(Error::InvalidCastToInt)?,
687        ))
688    }
689
690    /// Sends reply back to Vhost frontend in response to a message.
691    fn send_ack_message(
692        &mut self,
693        req: &VhostUserMsgHeader<FrontendReq>,
694        success: bool,
695    ) -> Result<()> {
696        if self.reply_ack_enabled && req.is_need_reply() {
697            let hdr: VhostUserMsgHeader<FrontendReq> =
698                self.new_reply_header::<VhostUserU64>(req, 0)?;
699            let val = if success { 0 } else { 1 };
700            let msg = VhostUserU64::new(val);
701            self.connection.send_message(&hdr, &msg, None)?;
702        }
703        Ok(())
704    }
705
706    fn send_reply_message<T: IntoBytes + Immutable>(
707        &mut self,
708        req: &VhostUserMsgHeader<FrontendReq>,
709        msg: &T,
710    ) -> Result<()> {
711        let hdr = self.new_reply_header::<T>(req, 0)?;
712        self.connection.send_message(&hdr, msg, None)?;
713        Ok(())
714    }
715
716    fn send_reply_with_payload<T: IntoBytes + Immutable>(
717        &mut self,
718        req: &VhostUserMsgHeader<FrontendReq>,
719        msg: &T,
720        payload: &[u8],
721    ) -> Result<()> {
722        let hdr = self.new_reply_header::<T>(req, payload.len())?;
723        self.connection
724            .send_message_with_payload(&hdr, msg, payload, None)?;
725        Ok(())
726    }
727
728    fn set_mem_table(
729        &mut self,
730        hdr: &VhostUserMsgHeader<FrontendReq>,
731        size: usize,
732        buf: &[u8],
733        files: Vec<File>,
734    ) -> Result<()> {
735        self.check_request_size(hdr, size, hdr.get_size() as usize)?;
736
737        let (msg, regions) =
738            Ref::<_, VhostUserMemory>::from_prefix(buf).map_err(|_| Error::InvalidMessage)?;
739        if !msg.is_valid() {
740            return Err(Error::InvalidMessage);
741        }
742
743        // validate number of fds matching number of memory regions
744        if files.len() != msg.num_regions as usize {
745            return Err(Error::InvalidMessage);
746        }
747
748        let (regions, excess) = Ref::<_, [VhostUserMemoryRegion]>::from_prefix_with_elems(
749            regions,
750            msg.num_regions as usize,
751        )
752        .map_err(|_| Error::InvalidMessage)?;
753        if !excess.is_empty() {
754            return Err(Error::InvalidMessage);
755        }
756
757        // Validate memory regions
758        for region in regions.iter() {
759            if !region.is_valid() {
760                return Err(Error::InvalidMessage);
761            }
762        }
763
764        self.backend.set_mem_table(&regions, files)
765    }
766
767    fn get_config(&mut self, hdr: &VhostUserMsgHeader<FrontendReq>, buf: &[u8]) -> Result<()> {
768        let (msg, payload) =
769            Ref::<_, VhostUserConfig>::from_prefix(buf).map_err(|_| Error::InvalidMessage)?;
770        if !msg.is_valid() {
771            return Err(Error::InvalidMessage);
772        }
773        if payload.len() != msg.size as usize {
774            return Err(Error::InvalidMessage);
775        }
776        let flags = match VhostUserConfigFlags::from_bits(msg.flags) {
777            Some(val) => val,
778            None => return Err(Error::InvalidMessage),
779        };
780        let res = self.backend.get_config(msg.offset, msg.size, flags);
781
782        // The response payload size MUST match the request payload size on success. A zero length
783        // response is used to indicate an error.
784        match res {
785            Ok(ref buf) if buf.len() == msg.size as usize => {
786                let reply = VhostUserConfig::new(msg.offset, buf.len() as u32, flags);
787                self.send_reply_with_payload(hdr, &reply, buf.as_slice())?;
788            }
789            Ok(_) => {
790                let reply = VhostUserConfig::new(msg.offset, 0, flags);
791                self.send_reply_message(hdr, &reply)?;
792            }
793            Err(_) => {
794                let reply = VhostUserConfig::new(msg.offset, 0, flags);
795                self.send_reply_message(hdr, &reply)?;
796            }
797        }
798        Ok(())
799    }
800
801    fn set_config(&mut self, buf: &[u8]) -> Result<()> {
802        let (msg, payload) =
803            Ref::<_, VhostUserConfig>::from_prefix(buf).map_err(|_| Error::InvalidMessage)?;
804        if !msg.is_valid() {
805            return Err(Error::InvalidMessage);
806        }
807        if payload.len() != msg.size as usize {
808            return Err(Error::InvalidMessage);
809        }
810        let flags: VhostUserConfigFlags = match VhostUserConfigFlags::from_bits(msg.flags) {
811            Some(val) => val,
812            None => return Err(Error::InvalidMessage),
813        };
814
815        self.backend.set_config(msg.offset, payload, flags)
816    }
817
818    fn set_backend_req_fd(&mut self, files: Vec<File>) -> Result<()> {
819        let file = into_single_file(files).ok_or(Error::InvalidMessage)?;
820        let fd: SafeDescriptor = file.into();
821        let connection = Connection::try_from(fd).map_err(|_| Error::InvalidMessage)?;
822        self.backend.set_backend_req_fd(connection);
823        Ok(())
824    }
825
826    /// Parses an incoming |SET_VRING_KICK| or |SET_VRING_CALL| message into a
827    /// Vring number and an fd.
828    fn handle_vring_fd_request(
829        &mut self,
830        buf: &[u8],
831        files: Vec<File>,
832    ) -> Result<(u8, Option<File>)> {
833        let (msg, _) = VhostUserU64::read_from_prefix(buf).map_err(|_| Error::InvalidMessage)?;
834        if !msg.is_valid() {
835            return Err(Error::InvalidMessage);
836        }
837
838        // Bits (0-7) of the payload contain the vring index. Bit 8 is the
839        // invalid FD flag (VHOST_USER_VRING_NOFD_MASK).
840        // This bit is set when there is no file descriptor
841        // in the ancillary data. This signals that polling will be used
842        // instead of waiting for the call.
843        // If Bit 8 is unset, the data must contain a file descriptor.
844        let has_fd = (msg.value & 0x100u64) == 0;
845
846        let file = into_single_file(files);
847
848        if has_fd && file.is_none() || !has_fd && file.is_some() {
849            return Err(Error::InvalidMessage);
850        }
851
852        Ok((msg.value as u8, file))
853    }
854
855    fn check_request_size(
856        &self,
857        hdr: &VhostUserMsgHeader<FrontendReq>,
858        size: usize,
859        expected: usize,
860    ) -> Result<()> {
861        if hdr.get_size() as usize != expected
862            || hdr.is_reply()
863            || hdr.get_version() != 0x1
864            || size != expected
865        {
866            return Err(Error::InvalidMessage);
867        }
868        Ok(())
869    }
870
871    fn check_attached_files(
872        &self,
873        hdr: &VhostUserMsgHeader<FrontendReq>,
874        files: &[File],
875    ) -> Result<()> {
876        match hdr.get_code() {
877            Ok(FrontendReq::SET_MEM_TABLE)
878            | Ok(FrontendReq::SET_VRING_CALL)
879            | Ok(FrontendReq::SET_VRING_KICK)
880            | Ok(FrontendReq::SET_VRING_ERR)
881            | Ok(FrontendReq::SET_LOG_BASE)
882            | Ok(FrontendReq::SET_LOG_FD)
883            | Ok(FrontendReq::SET_BACKEND_REQ_FD)
884            | Ok(FrontendReq::SET_INFLIGHT_FD)
885            | Ok(FrontendReq::ADD_MEM_REG)
886            | Ok(FrontendReq::SET_DEVICE_STATE_FD) => Ok(()),
887            Err(_) => Err(Error::InvalidMessage),
888            _ if !files.is_empty() => Err(Error::InvalidMessage),
889            _ => Ok(()),
890        }
891    }
892
893    fn extract_request_body<T: Sized + FromBytes + VhostUserMsgValidator>(
894        &self,
895        hdr: &VhostUserMsgHeader<FrontendReq>,
896        size: usize,
897        buf: &[u8],
898    ) -> Result<T> {
899        self.check_request_size(hdr, size, mem::size_of::<T>())?;
900        let (body, _) = T::read_from_prefix(buf).map_err(|_| Error::InvalidMessage)?;
901        if body.is_valid() {
902            Ok(body)
903        } else {
904            Err(Error::InvalidMessage)
905        }
906    }
907
908    fn update_reply_ack_flag(&mut self) {
909        let pflag = VhostUserProtocolFeatures::REPLY_ACK;
910        self.reply_ack_enabled = (self.virtio_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES) != 0
911            && self.protocol_features.contains(pflag)
912            && (self.acked_protocol_features & pflag.bits()) != 0;
913    }
914}
915
916impl<S: Backend> AsRawDescriptor for BackendServer<S> {
917    fn as_raw_descriptor(&self) -> RawDescriptor {
918        // TODO(b/221882601): figure out if this used for polling.
919        self.connection.as_raw_descriptor()
920    }
921}
922
#[cfg(test)]
mod tests {
    use base::INVALID_DESCRIPTOR;

    use super::*;
    use crate::test_backend::TestBackend;
    use crate::Connection;

    /// A freshly constructed server must expose a valid descriptor for its connection.
    #[test]
    fn test_backend_server_new() {
        // Keep the peer end alive for the duration of the test.
        let (server_end, _peer_end) = Connection::pair().unwrap();
        let server = BackendServer::new(server_end, TestBackend::new());

        assert_ne!(server.as_raw_descriptor(), INVALID_DESCRIPTOR);
    }
}