devices/virtio/vhost/
worker.rs

1// Copyright 2017 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use std::collections::BTreeMap;
6
7use anyhow::Context;
8use base::error;
9use base::Error as SysError;
10use base::Event;
11use base::EventToken;
12use base::Tube;
13use base::WaitContext;
14use libc::EIO;
15use serde::Deserialize;
16use serde::Serialize;
17use vhost::Vhost;
18use vm_memory::GuestMemory;
19
20use super::control_socket::VhostDevRequest;
21use super::control_socket::VhostDevResponse;
22use super::Error;
23use super::Result;
24use crate::virtio::Interrupt;
25use crate::virtio::Queue;
26use crate::virtio::VIRTIO_F_ACCESS_PLATFORM;
27
28#[derive(Clone, Serialize, Deserialize)]
29pub struct VringBase {
30    pub index: usize,
31    pub base: u16,
32}
33
/// Worker that takes care of running the vhost device.
///
/// `Worker::new` programs the vhost backend (features, memory table and per-queue vring
/// configuration); `Worker::run` then services vhost call/error events and MSI-X control
/// requests until it is told to stop.
pub struct Worker<T: Vhost> {
    /// Device name used in log messages, e.g. "vhost-vsock".
    name: &'static str,
    /// Used to signal used queues to the guest, and queried for the MSI-X configuration when
    /// choosing which descriptor vhost should signal for a queue.
    interrupt: Interrupt,
    /// The device's virtqueues, keyed by queue index.
    pub queues: BTreeMap<usize, Queue>,
    /// Handle to the vhost backend being driven.
    pub vhost_handle: T,
    // Event signaled by vhost when we should send an interrupt for a queue. Keyed by queue index;
    // installed as the vring call descriptor only when a direct MSI-X irqfd cannot be used, in
    // which case `run` relays each signal to the guest.
    vhost_interrupts: BTreeMap<usize, Event>,
    // Per-queue event registered with vhost via `set_vring_err`; `run` logs each signal.
    vhost_error_events: BTreeMap<usize, Event>,
    // Acknowledged feature bits (presumably acked by the guest driver — confirm with caller);
    // intersected with the backend's available features in `new`.
    acked_features: u64,
    /// Control channel: `run` receives `VhostDevRequest`s here and replies with
    /// `VhostDevResponse`s.
    pub server_tube: Tube,
}
46
47impl<T: Vhost> Worker<T> {
48    pub fn new(
49        name: &'static str,
50        queues: BTreeMap<usize, Queue>,
51        vhost_handle: T,
52        interrupt: Interrupt,
53        acked_features: u64,
54        server_tube: Tube,
55        mem: GuestMemory,
56        queue_vrings_base: Option<Vec<VringBase>>,
57    ) -> anyhow::Result<Worker<T>> {
58        let vhost_interrupts = queues
59            .keys()
60            .copied()
61            .map(|i| Ok((i, Event::new().context("failed to create Event")?)))
62            .collect::<anyhow::Result<_>>()?;
63        let vhost_error_events = queues
64            .keys()
65            .copied()
66            .map(|i| Ok((i, Event::new().context("failed to create Event")?)))
67            .collect::<anyhow::Result<_>>()?;
68        let worker = Worker {
69            name,
70            interrupt,
71            queues,
72            vhost_handle,
73            vhost_interrupts,
74            vhost_error_events,
75            acked_features,
76            server_tube,
77        };
78
79        let avail_features = worker
80            .vhost_handle
81            .get_features()
82            .map_err(Error::VhostGetFeatures)?;
83
84        let mut features = worker.acked_features & avail_features;
85        if worker.acked_features & (1u64 << VIRTIO_F_ACCESS_PLATFORM) != 0 {
86            // The vhost API is a bit poorly named, this flag in the context of vhost
87            // means that it will do address translation via its IOTLB APIs. If the
88            // underlying virtio device doesn't use viommu, it doesn't need vhost
89            // translation.
90            features &= !(1u64 << VIRTIO_F_ACCESS_PLATFORM);
91        }
92
93        worker
94            .vhost_handle
95            .set_features(features)
96            .map_err(Error::VhostSetFeatures)?;
97
98        worker
99            .vhost_handle
100            .set_mem_table(&mem)
101            .map_err(Error::VhostSetMemTable)?;
102
103        for (&queue_index, queue) in worker.queues.iter() {
104            worker
105                .vhost_handle
106                .set_vring_num(queue_index, queue.size())
107                .map_err(Error::VhostSetVringNum)?;
108
109            worker
110                .vhost_handle
111                .set_vring_err(queue_index, &worker.vhost_error_events[&queue_index])
112                .map_err(Error::VhostSetVringErr)?;
113
114            worker
115                .vhost_handle
116                .set_vring_addr(
117                    &mem,
118                    queue.size(),
119                    queue_index,
120                    0,
121                    queue.desc_table(),
122                    queue.used_ring(),
123                    queue.avail_ring(),
124                    None,
125                )
126                .map_err(Error::VhostSetVringAddr)?;
127            if let Some(vrings_base) = &queue_vrings_base {
128                let base = if let Some(vring_base) = vrings_base
129                    .iter()
130                    .find(|vring_base| vring_base.index == queue_index)
131                {
132                    vring_base.base
133                } else {
134                    anyhow::bail!(Error::VringBaseMissing);
135                };
136                worker
137                    .vhost_handle
138                    .set_vring_base(queue_index, base)
139                    .map_err(Error::VhostSetVringBase)?;
140            } else {
141                worker
142                    .vhost_handle
143                    .set_vring_base(queue_index, 0)
144                    .map_err(Error::VhostSetVringBase)?;
145            }
146            worker.set_vring_call_for_entry(queue_index, queue.vector() as usize)?;
147            worker
148                .vhost_handle
149                .set_vring_kick(queue_index, queue.event())
150                .map_err(Error::VhostSetVringKick)?;
151        }
152
153        Ok(worker)
154    }
155
156    pub fn run(&mut self, kill_evt: Event) -> Result<()> {
157        #[derive(EventToken)]
158        enum Token {
159            VhostIrqi { index: usize },
160            VhostError { index: usize },
161            Kill,
162            ControlNotify,
163        }
164
165        let wait_ctx: WaitContext<Token> = WaitContext::build_with(&[(&kill_evt, Token::Kill)])
166            .map_err(Error::CreateWaitContext)?;
167
168        for (&index, vhost_int) in self.vhost_interrupts.iter() {
169            wait_ctx
170                .add(vhost_int, Token::VhostIrqi { index })
171                .map_err(Error::CreateWaitContext)?;
172        }
173        for (&index, event) in self.vhost_error_events.iter() {
174            wait_ctx
175                .add(event, Token::VhostError { index })
176                .map_err(Error::CreateWaitContext)?;
177        }
178        wait_ctx
179            .add(&self.server_tube, Token::ControlNotify)
180            .map_err(Error::CreateWaitContext)?;
181
182        'wait: loop {
183            let events = wait_ctx.wait().map_err(Error::WaitError)?;
184
185            for event in events.iter().filter(|e| e.is_readable) {
186                match event.token {
187                    Token::VhostIrqi { index } => {
188                        self.vhost_interrupts[&index]
189                            .wait()
190                            .map_err(Error::VhostIrqRead)?;
191                        self.interrupt
192                            .signal_used_queue(self.queues[&index].vector());
193                    }
194                    Token::VhostError { index } => {
195                        self.vhost_error_events[&index]
196                            .wait()
197                            .map_err(Error::VhostErrorRead)?;
198                        error!("{} reported error for virtqueue {index}", self.name);
199                    }
200                    Token::Kill => {
201                        let _ = kill_evt.wait();
202                        break 'wait;
203                    }
204                    Token::ControlNotify => match self.server_tube.recv() {
205                        Ok(VhostDevRequest::MsixEntryChanged(index)) => {
206                            let mut qindex = 0;
207                            for (&queue_index, queue) in self.queues.iter() {
208                                if queue.vector() == index as u16 {
209                                    qindex = queue_index;
210                                    break;
211                                }
212                            }
213                            let response = match self.set_vring_call_for_entry(qindex, index) {
214                                Ok(()) => VhostDevResponse::Ok,
215                                Err(e) => {
216                                    error!(
217                                        "Set vring call failed for masked entry {}: {:?}",
218                                        index, e
219                                    );
220                                    VhostDevResponse::Err(SysError::new(EIO))
221                                }
222                            };
223                            if let Err(e) = self.server_tube.send(&response) {
224                                error!("Vhost failed to send VhostMsixEntryMasked Response for entry {}: {:?}", index, e);
225                            }
226                        }
227                        Ok(VhostDevRequest::MsixChanged) => {
228                            let response = match self.set_vring_calls() {
229                                Ok(()) => VhostDevResponse::Ok,
230                                Err(e) => {
231                                    error!("Set vring calls failed: {:?}", e);
232                                    VhostDevResponse::Err(SysError::new(EIO))
233                                }
234                            };
235                            if let Err(e) = self.server_tube.send(&response) {
236                                error!("Vhost failed to send VhostMsixMasked Response: {:?}", e);
237                            }
238                        }
239                        Err(e) => {
240                            error!("Vhost failed to receive Control request: {:?}", e);
241                        }
242                    },
243                }
244            }
245        }
246        Ok(())
247    }
248
249    fn set_vring_call_for_entry(&self, queue_index: usize, vector: usize) -> Result<()> {
250        // If MSI-X is enabled and not masked for this queue, then give the irqfd directly to the
251        // vhost driver so that it can send interrupts without context switching into a crosvm
252        // userspace thread.
253        //
254        // Note that if the MSI-X config for the queue is masked or disabled, a config change
255        // notification will cause this function to be reinvoked and we will switch to the indirect
256        // interrupt handling if necessary. The config register write will block while this code is
257        // running.
258        if let Some(msix_config) = self.interrupt.get_msix_config() {
259            let msix_config = msix_config.lock();
260            if !msix_config.masked() && !msix_config.table_masked(vector) {
261                if let Some(irqfd) = msix_config.get_irqfd(vector) {
262                    self.vhost_handle
263                        .set_vring_call(queue_index, irqfd)
264                        .map_err(Error::VhostSetVringCall)?;
265                    return Ok(());
266                }
267            }
268        }
269        self.vhost_handle
270            .set_vring_call(queue_index, &self.vhost_interrupts[&queue_index])
271            .map_err(Error::VhostSetVringCall)?;
272        Ok(())
273    }
274
275    fn set_vring_calls(&self) -> Result<()> {
276        for (&queue_index, queue) in self.queues.iter() {
277            self.set_vring_call_for_entry(queue_index, queue.vector() as usize)?;
278        }
279        Ok(())
280    }
281}