devices/virtio/vhost/
worker.rs1use std::collections::BTreeMap;
6
7use anyhow::Context;
8use base::error;
9use base::Error as SysError;
10use base::Event;
11use base::EventToken;
12use base::Tube;
13use base::WaitContext;
14use libc::EIO;
15use serde::Deserialize;
16use serde::Serialize;
17use vhost::Vhost;
18use vm_memory::GuestMemory;
19
20use super::control_socket::VhostDevRequest;
21use super::control_socket::VhostDevResponse;
22use super::Error;
23use super::Result;
24use crate::virtio::Interrupt;
25use crate::virtio::Queue;
26use crate::virtio::VIRTIO_F_ACCESS_PLATFORM;
27
28#[derive(Clone, Serialize, Deserialize)]
29pub struct VringBase {
30 pub index: usize,
31 pub base: u16,
32}
33
34pub struct Worker<T: Vhost> {
36 name: &'static str, interrupt: Interrupt,
38 pub queues: BTreeMap<usize, Queue>,
39 pub vhost_handle: T,
40 vhost_interrupts: BTreeMap<usize, Event>,
42 vhost_error_events: BTreeMap<usize, Event>,
43 acked_features: u64,
44 pub server_tube: Tube,
45}
46
47impl<T: Vhost> Worker<T> {
48 pub fn new(
49 name: &'static str,
50 queues: BTreeMap<usize, Queue>,
51 vhost_handle: T,
52 interrupt: Interrupt,
53 acked_features: u64,
54 server_tube: Tube,
55 mem: GuestMemory,
56 queue_vrings_base: Option<Vec<VringBase>>,
57 ) -> anyhow::Result<Worker<T>> {
58 let vhost_interrupts = queues
59 .keys()
60 .copied()
61 .map(|i| Ok((i, Event::new().context("failed to create Event")?)))
62 .collect::<anyhow::Result<_>>()?;
63 let vhost_error_events = queues
64 .keys()
65 .copied()
66 .map(|i| Ok((i, Event::new().context("failed to create Event")?)))
67 .collect::<anyhow::Result<_>>()?;
68 let worker = Worker {
69 name,
70 interrupt,
71 queues,
72 vhost_handle,
73 vhost_interrupts,
74 vhost_error_events,
75 acked_features,
76 server_tube,
77 };
78
79 let avail_features = worker
80 .vhost_handle
81 .get_features()
82 .map_err(Error::VhostGetFeatures)?;
83
84 let mut features = worker.acked_features & avail_features;
85 if worker.acked_features & (1u64 << VIRTIO_F_ACCESS_PLATFORM) != 0 {
86 features &= !(1u64 << VIRTIO_F_ACCESS_PLATFORM);
91 }
92
93 worker
94 .vhost_handle
95 .set_features(features)
96 .map_err(Error::VhostSetFeatures)?;
97
98 worker
99 .vhost_handle
100 .set_mem_table(&mem)
101 .map_err(Error::VhostSetMemTable)?;
102
103 for (&queue_index, queue) in worker.queues.iter() {
104 worker
105 .vhost_handle
106 .set_vring_num(queue_index, queue.size())
107 .map_err(Error::VhostSetVringNum)?;
108
109 worker
110 .vhost_handle
111 .set_vring_err(queue_index, &worker.vhost_error_events[&queue_index])
112 .map_err(Error::VhostSetVringErr)?;
113
114 worker
115 .vhost_handle
116 .set_vring_addr(
117 &mem,
118 queue.size(),
119 queue_index,
120 0,
121 queue.desc_table(),
122 queue.used_ring(),
123 queue.avail_ring(),
124 None,
125 )
126 .map_err(Error::VhostSetVringAddr)?;
127 if let Some(vrings_base) = &queue_vrings_base {
128 let base = if let Some(vring_base) = vrings_base
129 .iter()
130 .find(|vring_base| vring_base.index == queue_index)
131 {
132 vring_base.base
133 } else {
134 anyhow::bail!(Error::VringBaseMissing);
135 };
136 worker
137 .vhost_handle
138 .set_vring_base(queue_index, base)
139 .map_err(Error::VhostSetVringBase)?;
140 } else {
141 worker
142 .vhost_handle
143 .set_vring_base(queue_index, 0)
144 .map_err(Error::VhostSetVringBase)?;
145 }
146 worker.set_vring_call_for_entry(queue_index, queue.vector() as usize)?;
147 worker
148 .vhost_handle
149 .set_vring_kick(queue_index, queue.event())
150 .map_err(Error::VhostSetVringKick)?;
151 }
152
153 Ok(worker)
154 }
155
156 pub fn run(&mut self, kill_evt: Event) -> Result<()> {
157 #[derive(EventToken)]
158 enum Token {
159 VhostIrqi { index: usize },
160 VhostError { index: usize },
161 Kill,
162 ControlNotify,
163 }
164
165 let wait_ctx: WaitContext<Token> = WaitContext::build_with(&[(&kill_evt, Token::Kill)])
166 .map_err(Error::CreateWaitContext)?;
167
168 for (&index, vhost_int) in self.vhost_interrupts.iter() {
169 wait_ctx
170 .add(vhost_int, Token::VhostIrqi { index })
171 .map_err(Error::CreateWaitContext)?;
172 }
173 for (&index, event) in self.vhost_error_events.iter() {
174 wait_ctx
175 .add(event, Token::VhostError { index })
176 .map_err(Error::CreateWaitContext)?;
177 }
178 wait_ctx
179 .add(&self.server_tube, Token::ControlNotify)
180 .map_err(Error::CreateWaitContext)?;
181
182 'wait: loop {
183 let events = wait_ctx.wait().map_err(Error::WaitError)?;
184
185 for event in events.iter().filter(|e| e.is_readable) {
186 match event.token {
187 Token::VhostIrqi { index } => {
188 self.vhost_interrupts[&index]
189 .wait()
190 .map_err(Error::VhostIrqRead)?;
191 self.interrupt
192 .signal_used_queue(self.queues[&index].vector());
193 }
194 Token::VhostError { index } => {
195 self.vhost_error_events[&index]
196 .wait()
197 .map_err(Error::VhostErrorRead)?;
198 error!("{} reported error for virtqueue {index}", self.name);
199 }
200 Token::Kill => {
201 let _ = kill_evt.wait();
202 break 'wait;
203 }
204 Token::ControlNotify => match self.server_tube.recv() {
205 Ok(VhostDevRequest::MsixEntryChanged(index)) => {
206 let mut qindex = 0;
207 for (&queue_index, queue) in self.queues.iter() {
208 if queue.vector() == index as u16 {
209 qindex = queue_index;
210 break;
211 }
212 }
213 let response = match self.set_vring_call_for_entry(qindex, index) {
214 Ok(()) => VhostDevResponse::Ok,
215 Err(e) => {
216 error!(
217 "Set vring call failed for masked entry {}: {:?}",
218 index, e
219 );
220 VhostDevResponse::Err(SysError::new(EIO))
221 }
222 };
223 if let Err(e) = self.server_tube.send(&response) {
224 error!("Vhost failed to send VhostMsixEntryMasked Response for entry {}: {:?}", index, e);
225 }
226 }
227 Ok(VhostDevRequest::MsixChanged) => {
228 let response = match self.set_vring_calls() {
229 Ok(()) => VhostDevResponse::Ok,
230 Err(e) => {
231 error!("Set vring calls failed: {:?}", e);
232 VhostDevResponse::Err(SysError::new(EIO))
233 }
234 };
235 if let Err(e) = self.server_tube.send(&response) {
236 error!("Vhost failed to send VhostMsixMasked Response: {:?}", e);
237 }
238 }
239 Err(e) => {
240 error!("Vhost failed to receive Control request: {:?}", e);
241 }
242 },
243 }
244 }
245 }
246 Ok(())
247 }
248
249 fn set_vring_call_for_entry(&self, queue_index: usize, vector: usize) -> Result<()> {
250 if let Some(msix_config) = self.interrupt.get_msix_config() {
259 let msix_config = msix_config.lock();
260 if !msix_config.masked() && !msix_config.table_masked(vector) {
261 if let Some(irqfd) = msix_config.get_irqfd(vector) {
262 self.vhost_handle
263 .set_vring_call(queue_index, irqfd)
264 .map_err(Error::VhostSetVringCall)?;
265 return Ok(());
266 }
267 }
268 }
269 self.vhost_handle
270 .set_vring_call(queue_index, &self.vhost_interrupts[&queue_index])
271 .map_err(Error::VhostSetVringCall)?;
272 Ok(())
273 }
274
275 fn set_vring_calls(&self) -> Result<()> {
276 for (&queue_index, queue) in self.queues.iter() {
277 self.set_vring_call_for_entry(queue_index, queue.vector() as usize)?;
278 }
279 Ok(())
280 }
281}