1use std::collections::BTreeMap;
6use std::fs::OpenOptions;
7use std::os::unix::prelude::OpenOptionsExt;
8
9use anyhow::anyhow;
10use anyhow::Context;
11use base::error;
12use base::open_file_or_duplicate;
13use base::warn;
14use base::AsRawDescriptor;
15use base::RawDescriptor;
16use base::Tube;
17use base::WorkerThread;
18use data_model::Le64;
19use serde::Deserialize;
20use serde::Serialize;
21use snapshot::AnySnapshot;
22use vhost::Vhost;
23use vhost::Vsock as VhostVsockHandle;
24use vm_memory::GuestMemory;
25use zerocopy::IntoBytes;
26
27use super::control_socket::VhostDevRequest;
28use super::control_socket::VhostDevResponse;
29use super::worker::VringBase;
30use super::worker::Worker;
31use super::Error;
32use crate::pci::MsixStatus;
33use crate::virtio::copy_config;
34use crate::virtio::device_constants::vsock::NUM_QUEUES;
35use crate::virtio::vsock::VsockConfig;
36use crate::virtio::DeviceType;
37use crate::virtio::Interrupt;
38use crate::virtio::Queue;
39use crate::virtio::VirtioDevice;
40
/// Size used for every queue when `VsockConfig::max_queue_sizes` is not provided
/// (and for all queues of a device built with `Vsock::new_for_testing`).
const DEFAULT_MAX_QUEUE_SIZE: u16 = 256;
42
/// vhost-backed virtio-vsock device.
///
/// While active, every queue except index 2 is handed to a vhost `Worker`; queue 2 (the event
/// queue) is split off in `activate` and kept in `event_queue` instead.
pub struct Vsock {
    /// Running worker thread; set by `activate`, taken (and stopped) by `reset` and
    /// `virtio_sleep`.
    worker_thread: Option<WorkerThread<Worker<VhostVsockHandle>>>,
    /// Client end of the worker control tube pair; `control_notify` sends `VhostDevRequest`s
    /// over it and waits for a `VhostDevResponse`.
    worker_client_tube: Tube,
    /// Server end of the control tube pair; moved into the `Worker` on `activate` and put back
    /// when the worker is stopped.
    worker_server_tube: Option<Tube>,
    /// vhost-vsock device handle; `None` while the worker owns it (and always `None` for a
    /// device built with `new_for_testing`).
    vhost_handle: Option<VhostVsockHandle>,
    /// Guest context ID; exposed through config space (`read_config`) and programmed into the
    /// backend with `set_cid` on activation.
    cid: u64,
    /// Feature bits offered to the guest.
    avail_features: u64,
    /// Feature bits the guest acknowledged (a subset of `avail_features`).
    acked_features: u64,
    /// Vring bases saved by `virtio_sleep` or loaded by `virtio_restore`, passed to the worker on
    /// the next `activate`.
    vrings_base: Option<Vec<VringBase>>,
    /// The event queue (index 2), which is not serviced by vhost but must still exist; held here
    /// while the device is active.
    event_queue: Option<Queue>,
    /// Set by `virtio_restore`; the next `activate` posts a TRANSPORT_RESET event on the event
    /// queue and clears it.
    needs_transport_reset: bool,
    /// Per-queue maximum sizes reported through `queue_max_sizes`.
    max_queue_sizes: [u16; NUM_QUEUES],
}
61
/// Serialized device state written by `virtio_snapshot` and read back by `virtio_restore`.
#[derive(Serialize, Deserialize)]
struct VsockSnapshot {
    /// Safeguard only: restore fails if it differs from the current device's CID.
    cid: u64,
    /// Safeguard only: restore fails if it differs from the current device's offered features.
    avail_features: u64,
    /// Restored verbatim into the device.
    acked_features: u64,
    /// Saved vring bases; empty when the device had none recorded at snapshot time.
    vrings_base: Vec<VringBase>,
}
69
70impl Vsock {
71 pub fn new(base_features: u64, vsock_config: &VsockConfig) -> anyhow::Result<Vsock> {
73 let device_file = open_file_or_duplicate(
74 &vsock_config.vhost_device,
75 OpenOptions::new()
76 .read(true)
77 .write(true)
78 .custom_flags(libc::O_CLOEXEC | libc::O_NONBLOCK),
79 )
80 .with_context(|| {
81 format!(
82 "failed to open virtual socket device {}",
83 vsock_config.vhost_device.display(),
84 )
85 })?;
86
87 let handle = VhostVsockHandle::new(device_file);
88
89 let avail_features = base_features;
90
91 let (worker_client_tube, worker_server_tube) = Tube::pair().map_err(Error::CreateTube)?;
92
93 Ok(Vsock {
94 worker_thread: None,
95 worker_client_tube,
96 worker_server_tube: Some(worker_server_tube),
97 vhost_handle: Some(handle),
98 cid: vsock_config.cid,
99 avail_features,
100 acked_features: 0,
101 vrings_base: None,
102 event_queue: None,
103 needs_transport_reset: false,
104 max_queue_sizes: vsock_config
105 .max_queue_sizes
106 .unwrap_or([DEFAULT_MAX_QUEUE_SIZE; NUM_QUEUES]),
107 })
108 }
109
110 pub fn new_for_testing(cid: u64, features: u64) -> Vsock {
111 let (worker_client_tube, worker_server_tube) = Tube::pair().unwrap();
112 Vsock {
113 worker_thread: None,
114 worker_client_tube,
115 worker_server_tube: Some(worker_server_tube),
116 vhost_handle: None,
117 cid,
118 avail_features: features,
119 acked_features: 0,
120 vrings_base: None,
121 event_queue: None,
122 needs_transport_reset: false,
123 max_queue_sizes: [DEFAULT_MAX_QUEUE_SIZE; NUM_QUEUES],
124 }
125 }
126
127 pub fn acked_features(&self) -> u64 {
128 self.acked_features
129 }
130}
131
132impl VirtioDevice for Vsock {
133 fn keep_rds(&self) -> Vec<RawDescriptor> {
134 let mut keep_rds = Vec::new();
135
136 if let Some(handle) = &self.vhost_handle {
137 keep_rds.push(handle.as_raw_descriptor());
138 }
139 keep_rds.push(self.worker_client_tube.as_raw_descriptor());
140 if let Some(worker_server_tube) = &self.worker_server_tube {
141 keep_rds.push(worker_server_tube.as_raw_descriptor());
142 }
143
144 keep_rds
145 }
146
147 fn device_type(&self) -> DeviceType {
148 DeviceType::Vsock
149 }
150
151 fn queue_max_sizes(&self) -> &[u16] {
152 &self.max_queue_sizes[..]
153 }
154
155 fn features(&self) -> u64 {
156 self.avail_features
157 }
158
159 fn read_config(&self, offset: u64, data: &mut [u8]) {
160 let cid = Le64::from(self.cid);
161 copy_config(data, 0, cid.as_bytes(), offset);
162 }
163
164 fn ack_features(&mut self, value: u64) {
165 let mut v = value;
166
167 let unrequested_features = v & !self.avail_features;
169 if unrequested_features != 0 {
170 warn!("vsock: virtio-vsock got unknown feature ack: {:x}", v);
171
172 v &= !unrequested_features;
174 }
175 self.acked_features |= v;
176 }
177
178 fn activate(
179 &mut self,
180 mem: GuestMemory,
181 interrupt: Interrupt,
182 mut queues: BTreeMap<usize, Queue>,
183 ) -> anyhow::Result<()> {
184 if queues.len() != NUM_QUEUES {
185 return Err(anyhow!(
186 "vsock: expected {} queues, got {}",
187 NUM_QUEUES,
188 queues.len()
189 ));
190 }
191
192 let vhost_handle = self.vhost_handle.take().context("missing vhost_handle")?;
193 let acked_features = self.acked_features;
194 let cid = self.cid;
195
196 let mut event_queue = queues.remove(&2).unwrap();
199 if self.needs_transport_reset {
201 self.needs_transport_reset = false;
202
203 let mut avail_desc = event_queue
210 .pop()
211 .expect("event queue is empty, can't send transport reset event");
212 let transport_reset = virtio_sys::virtio_vsock::virtio_vsock_event{
213 id: virtio_sys::virtio_vsock::virtio_vsock_event_id_VIRTIO_VSOCK_EVENT_TRANSPORT_RESET.into(),
214 };
215 avail_desc
216 .writer
217 .write_obj(transport_reset)
218 .expect("failed to write transport reset event");
219 event_queue.add_used(avail_desc);
220 event_queue.trigger_interrupt();
221 }
222 self.event_queue = Some(event_queue);
223
224 let mut worker = Worker::new(
225 "vhost-vsock",
226 queues,
227 vhost_handle,
228 interrupt,
229 acked_features,
230 self.worker_server_tube
231 .take()
232 .expect("worker control tube missing"),
233 mem,
234 self.vrings_base.take(),
235 )
236 .context("vsock worker init exited with error")?;
237 worker
238 .vhost_handle
239 .set_cid(cid)
240 .map_err(Error::VhostVsockSetCid)?;
241 worker
242 .vhost_handle
243 .start()
244 .map_err(Error::VhostVsockStart)?;
245
246 self.worker_thread = Some(WorkerThread::start("vhost_vsock", move |kill_evt| {
247 let result = worker.run(kill_evt);
248 if let Err(e) = result {
249 error!("vsock worker thread exited with error: {:?}", e);
250 }
251 worker
252 }));
253
254 Ok(())
255 }
256
257 fn reset(&mut self) -> anyhow::Result<()> {
258 if let Some(worker_thread) = self.worker_thread.take() {
259 let worker = worker_thread.stop();
260 worker
261 .vhost_handle
262 .stop()
263 .context("failed to stop vrings")?;
264 for (pos, _) in worker.queues.iter() {
266 worker
267 .vhost_handle
268 .get_vring_base(*pos)
269 .context("get_vring_base failed")?;
270 }
271
272 self.vhost_handle = Some(worker.vhost_handle);
273 self.worker_server_tube = Some(worker.server_tube);
274 }
275 self.acked_features = 0;
276 self.vrings_base = None;
277 self.event_queue = None;
278 self.needs_transport_reset = false;
279 Ok(())
280 }
281
282 fn on_device_sandboxed(&mut self) {
283 if let Some(vhost_handle) = &self.vhost_handle {
287 match vhost_handle.set_owner() {
288 Ok(_) => {}
289 Err(e) => error!("{}: failed to set owner: {:?}", self.debug_label(), e),
290 }
291 }
292 }
293
294 fn control_notify(&self, behavior: MsixStatus) {
295 if self.worker_thread.is_none() {
296 return;
297 }
298 match behavior {
299 MsixStatus::EntryChanged(index) => {
300 if let Err(e) = self
301 .worker_client_tube
302 .send(&VhostDevRequest::MsixEntryChanged(index))
303 {
304 error!(
305 "{} failed to send VhostMsixEntryChanged request for entry {}: {:?}",
306 self.debug_label(),
307 index,
308 e
309 );
310 return;
311 }
312 if let Err(e) = self.worker_client_tube.recv::<VhostDevResponse>() {
313 error!(
314 "{} failed to receive VhostMsixEntryChanged response for entry {}: {:?}",
315 self.debug_label(),
316 index,
317 e
318 );
319 }
320 }
321 MsixStatus::Changed => {
322 if let Err(e) = self.worker_client_tube.send(&VhostDevRequest::MsixChanged) {
323 error!(
324 "{} failed to send VhostMsixChanged request: {:?}",
325 self.debug_label(),
326 e
327 );
328 return;
329 }
330 if let Err(e) = self.worker_client_tube.recv::<VhostDevResponse>() {
331 error!(
332 "{} failed to receive VhostMsixChanged response {:?}",
333 self.debug_label(),
334 e
335 );
336 }
337 }
338 _ => {}
339 }
340 }
341
342 fn virtio_sleep(&mut self) -> anyhow::Result<Option<BTreeMap<usize, Queue>>> {
343 if let Some(worker_thread) = self.worker_thread.take() {
344 let worker = worker_thread.stop();
345 worker
346 .vhost_handle
347 .stop()
348 .context("failed to stop vrings")?;
349 let mut queues: BTreeMap<usize, Queue> = worker.queues;
350 let mut vrings_base = Vec::new();
351 for (pos, _) in queues.iter() {
352 let vring_base = VringBase {
353 index: *pos,
354 base: worker.vhost_handle.get_vring_base(*pos)?,
355 };
356 vrings_base.push(vring_base);
357 }
358 self.vrings_base = Some(vrings_base);
359 self.vhost_handle = Some(worker.vhost_handle);
360 self.worker_server_tube = Some(worker.server_tube);
361 queues.insert(
362 2,
363 self.event_queue.take().expect("Vsock event queue missing"),
364 );
365 return Ok(Some(BTreeMap::from_iter(queues)));
366 }
367 Ok(None)
368 }
369
370 fn virtio_wake(
371 &mut self,
372 device_state: Option<(GuestMemory, Interrupt, BTreeMap<usize, Queue>)>,
373 ) -> anyhow::Result<()> {
374 match device_state {
375 None => Ok(()),
376 Some((mem, interrupt, queues)) => {
377 self.activate(mem, interrupt, queues)?;
381 Ok(())
382 }
383 }
384 }
385
386 fn virtio_snapshot(&mut self) -> anyhow::Result<AnySnapshot> {
387 let vrings_base = self.vrings_base.clone().unwrap_or_default();
388 AnySnapshot::to_any(VsockSnapshot {
389 cid: self.cid,
392 avail_features: self.avail_features,
393 acked_features: self.acked_features,
394 vrings_base,
395 })
396 .context("failed to snapshot virtio console")
397 }
398
399 fn virtio_restore(&mut self, data: AnySnapshot) -> anyhow::Result<()> {
400 let deser: VsockSnapshot =
401 AnySnapshot::from_any(data).context("failed to deserialize virtio vsock")?;
402 anyhow::ensure!(
403 self.cid == deser.cid,
404 "Virtio vsock incorrect cid for restore:\n Expected: {}, Actual: {}",
405 self.cid,
406 deser.cid,
407 );
408 anyhow::ensure!(
409 self.avail_features == deser.avail_features,
410 "Virtio vsock incorrect avail features for restore:\n Expected: {}, Actual: {}",
411 self.avail_features,
412 deser.avail_features,
413 );
414 self.acked_features = deser.acked_features;
415 self.vrings_base = Some(deser.vrings_base);
416 self.needs_transport_reset = true;
419 Ok(())
420 }
421}
422
#[cfg(test)]
mod tests {
    use std::convert::TryInto;

    use super::*;

    #[test]
    fn ack_features() {
        let cid = 5;
        let features: u64 = (1 << 20) | (1 << 49) | (1 << 2) | (1 << 19);
        let mut acked_features: u64 = 0;
        let mut unavailable_features: u64 = 0;

        let mut vsock = Vsock::new_for_testing(cid, features);
        assert_eq!(acked_features, vsock.acked_features());

        acked_features |= 1 << 2;
        vsock.ack_features(acked_features);
        assert_eq!(acked_features, vsock.acked_features());

        acked_features |= 1 << 49;
        vsock.ack_features(acked_features);
        assert_eq!(acked_features, vsock.acked_features());

        acked_features |= 1 << 60;
        unavailable_features |= 1 << 60;
        vsock.ack_features(acked_features);
        assert_eq!(
            acked_features & !unavailable_features,
            vsock.acked_features()
        );

        acked_features |= 1 << 1;
        unavailable_features |= 1 << 1;
        vsock.ack_features(acked_features);
        assert_eq!(
            acked_features & !unavailable_features,
            vsock.acked_features()
        );
    }

    /// Separate acks of disjoint feature bits must accumulate rather than overwrite.
    #[test]
    fn ack_features_accumulates() {
        let features: u64 = (1 << 2) | (1 << 19);
        let mut vsock = Vsock::new_for_testing(5, features);

        vsock.ack_features(1 << 2);
        vsock.ack_features(1 << 19);
        assert_eq!(features, vsock.acked_features());
    }

    /// Resetting a device that was never activated succeeds and clears negotiated features.
    #[test]
    fn reset_clears_acked_features() {
        let features: u64 = 1 << 2;
        let mut vsock = Vsock::new_for_testing(5, features);
        vsock.ack_features(features);
        assert_eq!(features, vsock.acked_features());

        vsock
            .reset()
            .expect("reset of an inactive device should succeed");
        assert_eq!(0, vsock.acked_features());
    }

    /// Without an explicit configuration every queue uses the default maximum size.
    #[test]
    fn default_queue_max_sizes() {
        let vsock = Vsock::new_for_testing(5, 0);
        assert_eq!(
            &[DEFAULT_MAX_QUEUE_SIZE; NUM_QUEUES][..],
            vsock.queue_max_sizes()
        );
    }

    #[test]
    fn read_config() {
        let cid = 0xfca9a559fdcb9756;
        let vsock = Vsock::new_for_testing(cid, 0);

        let mut buf = [0u8; 8];
        vsock.read_config(0, &mut buf);
        assert_eq!(cid, u64::from_le_bytes(buf));

        vsock.read_config(0, &mut buf[..4]);
        assert_eq!(
            (cid & 0xffffffff) as u32,
            u32::from_le_bytes(buf[..4].try_into().unwrap())
        );

        vsock.read_config(4, &mut buf[..4]);
        assert_eq!(
            (cid >> 32) as u32,
            u32::from_le_bytes(buf[..4].try_into().unwrap())
        );

        // Reads past the end of config space must leave the buffer untouched.
        let data: [u8; 8] = [8, 226, 5, 46, 159, 59, 89, 77];
        buf.copy_from_slice(&data);

        vsock.read_config(12, &mut buf);
        assert_eq!(&buf, &data);
    }

    #[test]
    fn features() {
        let cid = 5;
        let features: u64 = 0xfc195ae8db88cff9;

        let vsock = Vsock::new_for_testing(cid, features);
        assert_eq!(features, vsock.features());
    }
}