1use std::collections::BTreeMap;
6use std::collections::VecDeque;
7use std::io::Write;
8use std::sync::Arc;
9
10use anyhow::anyhow;
11use anyhow::Context;
12use balloon_control::BalloonStats;
13use balloon_control::BalloonTubeCommand;
14use balloon_control::BalloonTubeResult;
15use balloon_control::BalloonWS;
16use balloon_control::WSBucket;
17use balloon_control::VIRTIO_BALLOON_WS_MAX_NUM_BINS;
18use balloon_control::VIRTIO_BALLOON_WS_MIN_NUM_BINS;
19use base::debug;
20use base::error;
21use base::warn;
22use base::AsRawDescriptor;
23use base::Event;
24use base::RawDescriptor;
25#[cfg(feature = "registered_events")]
26use base::SendTube;
27use base::Tube;
28use base::WorkerThread;
29use cros_async::block_on;
30use cros_async::sync::RwLock as AsyncRwLock;
31use cros_async::AsyncTube;
32use cros_async::EventAsync;
33use cros_async::Executor;
34#[cfg(feature = "registered_events")]
35use cros_async::SendTubeAsync;
36use data_model::Le16;
37use data_model::Le32;
38use data_model::Le64;
39use futures::channel::mpsc;
40use futures::channel::oneshot;
41use futures::pin_mut;
42use futures::select;
43use futures::select_biased;
44use futures::FutureExt;
45use futures::StreamExt;
46use remain::sorted;
47use serde::Deserialize;
48use serde::Serialize;
49use snapshot::AnySnapshot;
50use thiserror::Error as ThisError;
51use vm_control::api::VmMemoryClient;
52#[cfg(feature = "registered_events")]
53use vm_control::RegisteredEventWithData;
54use vm_memory::GuestAddress;
55use vm_memory::GuestMemory;
56use zerocopy::FromBytes;
57use zerocopy::Immutable;
58use zerocopy::IntoBytes;
59use zerocopy::KnownLayout;
60
61use super::async_utils;
62use super::copy_config;
63use super::create_stop_oneshot;
64use super::DescriptorChain;
65use super::DeviceType;
66use super::Interrupt;
67use super::Queue;
68use super::Reader;
69use super::StoppedWorker;
70use super::VirtioDevice;
71use crate::UnpinRequest;
72use crate::UnpinResponse;
73
74#[sorted]
75#[derive(ThisError, Debug)]
76pub enum BalloonError {
77 #[error("failed async await: {0}")]
79 AsyncAwait(cros_async::AsyncError),
80 #[error("failed async await: {0}")]
82 AsyncAwaitAnyhow(anyhow::Error),
83 #[error("failed to create event: {0}")]
85 CreatingEvent(base::Error),
86 #[error("failed to create async message receiver: {0}")]
88 CreatingMessageReceiver(base::TubeError),
89 #[error("failed to receive command message: {0}")]
91 ReceivingCommand(base::TubeError),
92 #[error("failed to send command response: {0}")]
94 SendResponse(base::TubeError),
95 #[error("failed to write to virtqueue: {0}")]
97 WriteQueue(std::io::Error),
98 #[error("failed to write config event: {0}")]
100 WritingConfigEvent(base::Error),
101}
102pub type Result<T> = std::result::Result<T, BalloonError>;
103
104const QUEUE_SIZE: u16 = 128;
105
106const INFLATEQ: usize = 0;
108const DEFLATEQ: usize = 1;
109
110const VIRTIO_BALLOON_PFN_SHIFT: u32 = 12;
111const VIRTIO_BALLOON_PF_SIZE: u64 = 1 << VIRTIO_BALLOON_PFN_SHIFT;
112
113const VIRTIO_BALLOON_F_MUST_TELL_HOST: u32 = 0; const VIRTIO_BALLOON_F_STATS_VQ: u32 = 1; const VIRTIO_BALLOON_F_DEFLATE_ON_OOM: u32 = 2; const VIRTIO_BALLOON_F_PAGE_REPORTING: u32 = 5; const VIRTIO_BALLOON_F_WS_REPORTING: u32 = 8; #[derive(Copy, Clone)]
122#[repr(u32)]
123pub enum BalloonFeatures {
125 PageReporting = VIRTIO_BALLOON_F_PAGE_REPORTING,
127 WSReporting = VIRTIO_BALLOON_F_WS_REPORTING,
129}
130
131#[derive(Copy, Clone, Debug, Default, FromBytes, Immutable, IntoBytes, KnownLayout)]
133#[repr(C)]
134struct virtio_balloon_config {
135 num_pages: Le32,
136 actual: Le32,
137 free_page_hint_cmd_id: Le32,
138 poison_val: Le32,
139 ws_num_bins: u8,
141 _reserved: [u8; 3],
142}
143
144#[derive(Clone, Default, Serialize, Deserialize)]
146struct BalloonState {
147 num_pages: u32,
148 actual_pages: u32,
149 expecting_ws: bool,
150 failable_update: bool,
154 pending_adjusted_responses: VecDeque<u32>,
155}
156
157const VIRTIO_BALLOON_S_SWAP_IN: u16 = 0;
159const VIRTIO_BALLOON_S_SWAP_OUT: u16 = 1;
160const VIRTIO_BALLOON_S_MAJFLT: u16 = 2;
161const VIRTIO_BALLOON_S_MINFLT: u16 = 3;
162const VIRTIO_BALLOON_S_MEMFREE: u16 = 4;
163const VIRTIO_BALLOON_S_MEMTOT: u16 = 5;
164const VIRTIO_BALLOON_S_AVAIL: u16 = 6;
165const VIRTIO_BALLOON_S_CACHES: u16 = 7;
166const VIRTIO_BALLOON_S_HTLB_PGALLOC: u16 = 8;
167const VIRTIO_BALLOON_S_HTLB_PGFAIL: u16 = 9;
168const VIRTIO_BALLOON_S_NONSTANDARD_SHMEM: u16 = 65534;
169const VIRTIO_BALLOON_S_NONSTANDARD_UNEVICTABLE: u16 = 65535;
170
171#[derive(Copy, Clone, FromBytes, Immutable, IntoBytes, KnownLayout)]
173#[repr(C, packed)]
174struct BalloonStat {
175 tag: Le16,
176 val: Le64,
177}
178
179impl BalloonStat {
180 fn update_stats(&self, stats: &mut BalloonStats) {
181 let val = Some(self.val.to_native());
182 match self.tag.to_native() {
183 VIRTIO_BALLOON_S_SWAP_IN => stats.swap_in = val,
184 VIRTIO_BALLOON_S_SWAP_OUT => stats.swap_out = val,
185 VIRTIO_BALLOON_S_MAJFLT => stats.major_faults = val,
186 VIRTIO_BALLOON_S_MINFLT => stats.minor_faults = val,
187 VIRTIO_BALLOON_S_MEMFREE => stats.free_memory = val,
188 VIRTIO_BALLOON_S_MEMTOT => stats.total_memory = val,
189 VIRTIO_BALLOON_S_AVAIL => stats.available_memory = val,
190 VIRTIO_BALLOON_S_CACHES => stats.disk_caches = val,
191 VIRTIO_BALLOON_S_HTLB_PGALLOC => stats.hugetlb_allocations = val,
192 VIRTIO_BALLOON_S_HTLB_PGFAIL => stats.hugetlb_failures = val,
193 VIRTIO_BALLOON_S_NONSTANDARD_SHMEM => stats.shared_memory = val,
194 VIRTIO_BALLOON_S_NONSTANDARD_UNEVICTABLE => stats.unevictable_memory = val,
195 _ => (),
196 }
197 }
198}
199
200#[repr(C)]
202#[derive(Copy, Clone, Debug, Default, FromBytes, Immutable, IntoBytes, KnownLayout)]
203struct virtio_balloon_ws {
204 tag: Le16,
205 node_id: Le16,
206 _reserved: [u8; 4],
209 idle_age_ms: Le64,
210 memory_size_bytes: [Le64; 2],
212}
213
214impl virtio_balloon_ws {
215 fn update_ws(&self, ws: &mut BalloonWS) {
216 let bucket = WSBucket {
217 age: self.idle_age_ms.to_native(),
218 bytes: [
219 self.memory_size_bytes[0].to_native(),
220 self.memory_size_bytes[1].to_native(),
221 ],
222 };
223 ws.ws.push(bucket);
224 }
225}
226
227const _VIRTIO_BALLOON_WS_OP_INVALID: u16 = 0;
228const VIRTIO_BALLOON_WS_OP_REQUEST: u16 = 1;
229const VIRTIO_BALLOON_WS_OP_CONFIG: u16 = 2;
230const _VIRTIO_BALLOON_WS_OP_DISCARD: u16 = 3;
231
232#[repr(C, packed)]
234#[derive(Copy, Clone, Debug, Default, FromBytes, Immutable, IntoBytes, KnownLayout)]
235struct virtio_balloon_op {
236 type_: Le16,
237}
238
239fn invoke_desc_handler<F>(ranges: Vec<(u64, u64)>, desc_handler: &mut F)
240where
241 F: FnMut(Vec<(GuestAddress, u64)>),
242{
243 desc_handler(
244 ranges
245 .into_iter()
246 .map(|range| (GuestAddress(range.0), range.1))
247 .collect(),
248 );
249}
250
251fn release_ranges<F>(
255 release_memory_tube: Option<&Tube>,
256 inflate_ranges: Vec<(u64, u64)>,
257 desc_handler: &mut F,
258) -> anyhow::Result<()>
259where
260 F: FnMut(Vec<(GuestAddress, u64)>),
261{
262 if let Some(tube) = release_memory_tube {
263 let unpin_ranges = inflate_ranges
264 .iter()
265 .map(|v| {
266 (
267 v.0 >> VIRTIO_BALLOON_PFN_SHIFT,
268 v.1 / VIRTIO_BALLOON_PF_SIZE,
269 )
270 })
271 .collect();
272 let req = UnpinRequest {
273 ranges: unpin_ranges,
274 };
275 if let Err(e) = tube.send(&req) {
276 error!("failed to send unpin request: {}", e);
277 } else {
278 match tube.recv() {
279 Ok(resp) => match resp {
280 UnpinResponse::Success => invoke_desc_handler(inflate_ranges, desc_handler),
281 UnpinResponse::Failed => error!("failed to handle unpin request"),
282 },
283 Err(e) => error!("failed to handle get unpin response: {}", e),
284 }
285 }
286 } else {
287 invoke_desc_handler(inflate_ranges, desc_handler);
288 }
289
290 Ok(())
291}
292
293fn handle_address_chain<F>(
295 release_memory_tube: Option<&Tube>,
296 avail_desc: &mut DescriptorChain,
297 desc_handler: &mut F,
298) -> anyhow::Result<()>
299where
300 F: FnMut(Vec<(GuestAddress, u64)>),
301{
302 let mut range_start = 0;
307 let mut range_size = 0;
308 let mut inflate_ranges: Vec<(u64, u64)> = Vec::new();
309 for res in avail_desc.reader.iter::<Le32>() {
310 let pfn = match res {
311 Ok(pfn) => pfn,
312 Err(e) => {
313 error!("error while reading unused pages: {}", e);
314 break;
315 }
316 };
317 let guest_address = (u64::from(pfn.to_native())) << VIRTIO_BALLOON_PFN_SHIFT;
318 if range_start + range_size == guest_address {
319 range_size += VIRTIO_BALLOON_PF_SIZE;
320 } else if range_start == guest_address + VIRTIO_BALLOON_PF_SIZE {
321 range_start = guest_address;
322 range_size += VIRTIO_BALLOON_PF_SIZE;
323 } else {
324 if range_size != 0 {
327 inflate_ranges.push((range_start, range_size));
328 }
329 range_start = guest_address;
330 range_size = VIRTIO_BALLOON_PF_SIZE;
331 }
332 }
333 if range_size != 0 {
334 inflate_ranges.push((range_start, range_size));
335 }
336
337 release_ranges(release_memory_tube, inflate_ranges, desc_handler)
338}
339
340async fn handle_queue<F>(
342 mut queue: Queue,
343 mut queue_event: EventAsync,
344 release_memory_tube: Option<&Tube>,
345 mut desc_handler: F,
346 mut stop_rx: oneshot::Receiver<()>,
347) -> Queue
348where
349 F: FnMut(Vec<(GuestAddress, u64)>),
350{
351 loop {
352 let mut avail_desc = match queue
353 .next_async_interruptable(&mut queue_event, &mut stop_rx)
354 .await
355 {
356 Ok(Some(res)) => res,
357 Ok(None) => return queue,
358 Err(e) => {
359 error!("Failed to read descriptor {}", e);
360 return queue;
361 }
362 };
363 if let Err(e) =
364 handle_address_chain(release_memory_tube, &mut avail_desc, &mut desc_handler)
365 {
366 error!("balloon: failed to process inflate addresses: {}", e);
367 }
368 queue.add_used(avail_desc);
369 queue.trigger_interrupt();
370 }
371}
372
373fn handle_reported_buffer<F>(
375 release_memory_tube: Option<&Tube>,
376 avail_desc: &DescriptorChain,
377 desc_handler: &mut F,
378) -> anyhow::Result<()>
379where
380 F: FnMut(Vec<(GuestAddress, u64)>),
381{
382 let reported_ranges: Vec<(u64, u64)> = avail_desc
383 .reader
384 .get_remaining_regions()
385 .chain(avail_desc.writer.get_remaining_regions())
386 .map(|r| (r.offset, r.len as u64))
387 .collect();
388
389 release_ranges(release_memory_tube, reported_ranges, desc_handler)
390}
391
392async fn handle_reporting_queue<F>(
394 mut queue: Queue,
395 mut queue_event: EventAsync,
396 release_memory_tube: Option<&Tube>,
397 mut desc_handler: F,
398 mut stop_rx: oneshot::Receiver<()>,
399) -> Queue
400where
401 F: FnMut(Vec<(GuestAddress, u64)>),
402{
403 loop {
404 let avail_desc = match queue
405 .next_async_interruptable(&mut queue_event, &mut stop_rx)
406 .await
407 {
408 Ok(Some(res)) => res,
409 Ok(None) => return queue,
410 Err(e) => {
411 error!("Failed to read descriptor {}", e);
412 return queue;
413 }
414 };
415 if let Err(e) = handle_reported_buffer(release_memory_tube, &avail_desc, &mut desc_handler)
416 {
417 error!("balloon: failed to process reported buffer: {}", e);
418 }
419 queue.add_used(avail_desc);
420 queue.trigger_interrupt();
421 }
422}
423
424fn parse_balloon_stats(reader: &mut Reader) -> BalloonStats {
425 let mut stats: BalloonStats = Default::default();
426 for res in reader.iter::<BalloonStat>() {
427 match res {
428 Ok(stat) => stat.update_stats(&mut stats),
429 Err(e) => {
430 error!("error while reading stats: {}", e);
431 break;
432 }
433 };
434 }
435 stats
436}
437
438async fn handle_stats_queue(
443 mut queue: Queue,
444 mut queue_event: EventAsync,
445 mut stats_rx: mpsc::Receiver<()>,
446 command_tube: &AsyncTube,
447 #[cfg(feature = "registered_events")] registered_evt_q: Option<&SendTubeAsync>,
448 state: Arc<AsyncRwLock<BalloonState>>,
449 mut stop_rx: oneshot::Receiver<()>,
450) -> Queue {
451 let mut avail_desc = match queue
452 .next_async_interruptable(&mut queue_event, &mut stop_rx)
453 .await
454 {
455 Ok(Some(res)) => res,
458 Ok(None) => return queue,
459 Err(e) => {
460 error!("Failed to read descriptor {}", e);
461 return queue;
462 }
463 };
464
465 loop {
466 select_biased! {
467 msg = stats_rx.next() => {
468 match msg {
470 Some(()) => (),
471 None => {
472 error!("stats signal channel was closed");
473 return queue;
474 }
475 }
476 }
477 _ = stop_rx => return queue,
478 };
479
480 queue.add_used(avail_desc);
482 queue.trigger_interrupt();
483
484 avail_desc = match queue.next_async(&mut queue_event).await {
485 Err(e) => {
486 error!("Failed to read descriptor {}", e);
487 return queue;
488 }
489 Ok(d) => d,
490 };
491 let stats = parse_balloon_stats(&mut avail_desc.reader);
492
493 let actual_pages = state.lock().await.actual_pages as u64;
494 let result = BalloonTubeResult::Stats {
495 balloon_actual: actual_pages << VIRTIO_BALLOON_PFN_SHIFT,
496 stats,
497 };
498 let send_result = command_tube.send(result).await;
499 if let Err(e) = send_result {
500 error!("failed to send stats result: {}", e);
501 }
502
503 #[cfg(feature = "registered_events")]
504 if let Some(registered_evt_q) = registered_evt_q {
505 if let Err(e) = registered_evt_q
506 .send(&RegisteredEventWithData::VirtioBalloonResize)
507 .await
508 {
509 error!("failed to send VirtioBalloonResize event: {}", e);
510 }
511 }
512 }
513}
514
515async fn send_adjusted_response(
516 tube: &AsyncTube,
517 num_pages: u32,
518) -> std::result::Result<(), base::TubeError> {
519 let num_bytes = (num_pages as u64) << VIRTIO_BALLOON_PFN_SHIFT;
520 let result = BalloonTubeResult::Adjusted { num_bytes };
521 tube.send(result).await
522}
523
524enum WSOp {
525 WSReport,
526 WSConfig {
527 bins: Vec<u32>,
528 refresh_threshold: u32,
529 report_threshold: u32,
530 },
531}
532
533async fn handle_ws_op_queue(
534 mut queue: Queue,
535 mut queue_event: EventAsync,
536 mut ws_op_rx: mpsc::Receiver<WSOp>,
537 state: Arc<AsyncRwLock<BalloonState>>,
538 mut stop_rx: oneshot::Receiver<()>,
539) -> Result<Queue> {
540 loop {
541 let op = select_biased! {
542 next_op = ws_op_rx.next().fuse() => {
543 match next_op {
544 Some(op) => op,
545 None => {
546 error!("ws op tube was closed");
547 break;
548 }
549 }
550 }
551 _ = stop_rx => {
552 break;
553 }
554 };
555 let mut avail_desc = queue
556 .next_async(&mut queue_event)
557 .await
558 .map_err(BalloonError::AsyncAwait)?;
559 let writer = &mut avail_desc.writer;
560
561 match op {
562 WSOp::WSReport => {
563 {
564 let mut state = state.lock().await;
565 state.expecting_ws = true;
566 }
567
568 let ws_r = virtio_balloon_op {
569 type_: VIRTIO_BALLOON_WS_OP_REQUEST.into(),
570 };
571
572 writer.write_obj(ws_r).map_err(BalloonError::WriteQueue)?;
573 }
574 WSOp::WSConfig {
575 bins,
576 refresh_threshold,
577 report_threshold,
578 } => {
579 let cmd = virtio_balloon_op {
580 type_: VIRTIO_BALLOON_WS_OP_CONFIG.into(),
581 };
582
583 writer.write_obj(cmd).map_err(BalloonError::WriteQueue)?;
584 writer
585 .write_all(bins.as_bytes())
586 .map_err(BalloonError::WriteQueue)?;
587 writer
588 .write_obj(refresh_threshold)
589 .map_err(BalloonError::WriteQueue)?;
590 writer
591 .write_obj(report_threshold)
592 .map_err(BalloonError::WriteQueue)?;
593 }
594 }
595
596 queue.add_used(avail_desc);
597 queue.trigger_interrupt();
598 }
599
600 Ok(queue)
601}
602
603fn parse_balloon_ws(reader: &mut Reader) -> BalloonWS {
604 let mut ws = BalloonWS::new();
605 for res in reader.iter::<virtio_balloon_ws>() {
606 match res {
607 Ok(ws_msg) => {
608 ws_msg.update_ws(&mut ws);
609 }
610 Err(e) => {
611 error!("error while reading ws: {}", e);
612 break;
613 }
614 }
615 }
616 if ws.ws.len() < VIRTIO_BALLOON_WS_MIN_NUM_BINS || ws.ws.len() > VIRTIO_BALLOON_WS_MAX_NUM_BINS
617 {
618 error!("unexpected number of WS buckets: {}", ws.ws.len());
619 }
620 ws
621}
622
623async fn handle_ws_data_queue(
629 mut queue: Queue,
630 mut queue_event: EventAsync,
631 command_tube: &AsyncTube,
632 #[cfg(feature = "registered_events")] registered_evt_q: Option<&SendTubeAsync>,
633 state: Arc<AsyncRwLock<BalloonState>>,
634 mut stop_rx: oneshot::Receiver<()>,
635) -> Result<Queue> {
636 loop {
637 let mut avail_desc = match queue
638 .next_async_interruptable(&mut queue_event, &mut stop_rx)
639 .await
640 .map_err(BalloonError::AsyncAwait)?
641 {
642 Some(res) => res,
643 None => return Ok(queue),
644 };
645
646 let ws = parse_balloon_ws(&mut avail_desc.reader);
647
648 let mut state = state.lock().await;
649
650 let balloon_actual = (state.actual_pages as u64) << VIRTIO_BALLOON_PFN_SHIFT;
652
653 if state.expecting_ws {
654 let result = BalloonTubeResult::WorkingSet { ws, balloon_actual };
655 let send_result = command_tube.send(result).await;
656 if let Err(e) = send_result {
657 error!("failed to send ws result: {}", e);
658 }
659
660 state.expecting_ws = false;
661 } else {
662 #[cfg(feature = "registered_events")]
663 if let Some(registered_evt_q) = registered_evt_q {
664 if let Err(e) = registered_evt_q
665 .send(RegisteredEventWithData::from_ws(&ws, balloon_actual))
666 .await
667 {
668 error!("failed to send VirtioBalloonWSReport event: {}", e);
669 }
670 }
671 }
672
673 queue.add_used(avail_desc);
674 queue.trigger_interrupt();
675 }
676}
677
678async fn handle_command_tube(
681 command_tube: &AsyncTube,
682 interrupt: Interrupt,
683 state: Arc<AsyncRwLock<BalloonState>>,
684 mut stats_tx: mpsc::Sender<()>,
685 mut ws_op_tx: mpsc::Sender<WSOp>,
686 mut stop_rx: oneshot::Receiver<()>,
687) -> Result<()> {
688 loop {
689 let cmd_res = select_biased! {
690 res = command_tube.next().fuse() => res,
691 _ = stop_rx => return Ok(())
692 };
693 match cmd_res {
694 Ok(command) => match command {
695 BalloonTubeCommand::Adjust {
696 num_bytes,
697 allow_failure,
698 } => {
699 let num_pages = (num_bytes >> VIRTIO_BALLOON_PFN_SHIFT) as u32;
700 let mut state = state.lock().await;
701
702 state.num_pages = num_pages;
703 interrupt.signal_config_changed();
704
705 if allow_failure {
706 if num_pages == state.actual_pages {
707 send_adjusted_response(command_tube, num_pages)
708 .await
709 .map_err(BalloonError::SendResponse)?;
710 } else {
711 state.failable_update = true;
712 }
713 }
714 }
715 BalloonTubeCommand::WorkingSetConfig {
716 bins,
717 refresh_threshold,
718 report_threshold,
719 } => {
720 if let Err(e) = ws_op_tx.try_send(WSOp::WSConfig {
721 bins,
722 refresh_threshold,
723 report_threshold,
724 }) {
725 error!("failed to send config to ws handler: {}", e);
726 }
727 }
728 BalloonTubeCommand::Stats => {
729 if let Err(e) = stats_tx.try_send(()) {
730 error!("failed to signal the stat handler: {}", e);
731 }
732 }
733 BalloonTubeCommand::WorkingSet => {
734 if let Err(e) = ws_op_tx.try_send(WSOp::WSReport) {
735 error!("failed to send report request to ws handler: {}", e);
736 }
737 }
738 },
739 #[cfg(windows)]
740 Err(base::TubeError::Recv(e)) if e.kind() == std::io::ErrorKind::TimedOut => {
741 let _ = stop_rx.await;
748 return Ok(());
749 }
750 Err(e) => {
751 return Err(BalloonError::ReceivingCommand(e));
752 }
753 }
754 }
755}
756
757async fn handle_pending_adjusted_responses(
758 pending_adjusted_response_event: EventAsync,
759 command_tube: &AsyncTube,
760 state: Arc<AsyncRwLock<BalloonState>>,
761) -> Result<()> {
762 loop {
763 pending_adjusted_response_event
764 .next_val()
765 .await
766 .map_err(BalloonError::AsyncAwait)?;
767 while let Some(num_pages) = state.lock().await.pending_adjusted_responses.pop_front() {
768 send_adjusted_response(command_tube, num_pages)
769 .await
770 .map_err(BalloonError::SendResponse)?;
771 }
772 }
773}
774
775struct BalloonQueues {
777 inflate: Queue,
778 deflate: Queue,
779 stats: Option<Queue>,
780 reporting: Option<Queue>,
781 ws_data: Option<Queue>,
782 ws_op: Option<Queue>,
783}
784
785impl BalloonQueues {
786 fn new(inflate: Queue, deflate: Queue) -> Self {
787 BalloonQueues {
788 inflate,
789 deflate,
790 stats: None,
791 reporting: None,
792 ws_data: None,
793 ws_op: None,
794 }
795 }
796}
797
798struct PausedQueues {
800 inflate: Queue,
801 deflate: Queue,
802 stats: Option<Queue>,
803 reporting: Option<Queue>,
804 ws_data: Option<Queue>,
805 ws_op: Option<Queue>,
806}
807
808impl PausedQueues {
809 fn new(inflate: Queue, deflate: Queue) -> Self {
810 PausedQueues {
811 inflate,
812 deflate,
813 stats: None,
814 reporting: None,
815 ws_data: None,
816 ws_op: None,
817 }
818 }
819}
820
821fn apply_if_some<F, R>(queue_opt: Option<Queue>, mut func: F)
822where
823 F: FnMut(Queue) -> R,
824{
825 if let Some(queue) = queue_opt {
826 func(queue);
827 }
828}
829
830impl From<Box<PausedQueues>> for BTreeMap<usize, Queue> {
831 fn from(queues: Box<PausedQueues>) -> BTreeMap<usize, Queue> {
832 let mut ret = Vec::new();
833 ret.push(queues.inflate);
834 ret.push(queues.deflate);
835 apply_if_some(queues.stats, |stats| ret.push(stats));
836 apply_if_some(queues.reporting, |reporting| ret.push(reporting));
837 apply_if_some(queues.ws_data, |ws_data| ret.push(ws_data));
838 apply_if_some(queues.ws_op, |ws_op| ret.push(ws_op));
839 ret.into_iter().enumerate().collect()
842 }
843}
844
845fn free_memory(
846 vm_memory_client: &VmMemoryClient,
847 mem: &GuestMemory,
848 ranges: Vec<(GuestAddress, u64)>,
849) {
850 #[cfg(any(target_os = "android", target_os = "linux"))]
859 if mem.locked() && !mem.use_punchhole_locked() {
860 for (guest_address, len) in ranges {
861 if let Err(e) = mem.remove_range(guest_address, len) {
862 warn!("Marking pages unused failed: {}, addr={}", e, guest_address);
863 }
864 }
865 return;
866 }
867 if let Err(e) = vm_memory_client.dynamically_free_memory_ranges(ranges) {
868 warn!("Failed to dynamically free memory ranges: {e:#}");
869 }
870}
871
872fn reclaim_memory(vm_memory_client: &VmMemoryClient, ranges: Vec<(GuestAddress, u64)>) {
873 if let Err(e) = vm_memory_client.dynamically_reclaim_memory_ranges(ranges) {
874 warn!("Failed to dynamically reclaim memory range: {e:#}");
875 }
876}
877
878struct WorkerReturn {
881 release_memory_tube: Option<Tube>,
882 command_tube: Tube,
883 #[cfg(feature = "registered_events")]
884 registered_evt_q: Option<SendTube>,
885 paused_queues: Option<PausedQueues>,
886 vm_memory_client: VmMemoryClient,
887}
888
889fn run_worker(
892 inflate_queue: Queue,
893 deflate_queue: Queue,
894 stats_queue: Option<Queue>,
895 reporting_queue: Option<Queue>,
896 ws_data_queue: Option<Queue>,
897 ws_op_queue: Option<Queue>,
898 command_tube: Tube,
899 vm_memory_client: VmMemoryClient,
900 mem: GuestMemory,
901 release_memory_tube: Option<Tube>,
902 interrupt: Interrupt,
903 kill_evt: Event,
904 target_reached_evt: Event,
905 pending_adjusted_response_event: Event,
906 state: Arc<AsyncRwLock<BalloonState>>,
907 #[cfg(feature = "registered_events")] registered_evt_q: Option<SendTube>,
908) -> WorkerReturn {
909 let ex = Executor::new().unwrap();
910 let command_tube = AsyncTube::new(&ex, command_tube).unwrap();
911 #[cfg(feature = "registered_events")]
912 let registered_evt_q_async = registered_evt_q
913 .as_ref()
914 .map(|q| SendTubeAsync::new(q.try_clone().unwrap(), &ex).unwrap());
915
916 let mut stop_queue_oneshots = Vec::new();
917
918 let paused_queues = {
920 let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
922 let inflate_queue_evt = inflate_queue
923 .event()
924 .try_clone()
925 .expect("failed to clone queue event");
926 let inflate = handle_queue(
927 inflate_queue,
928 EventAsync::new(inflate_queue_evt, &ex).expect("failed to create async event"),
929 release_memory_tube.as_ref(),
930 |ranges| free_memory(&vm_memory_client, &mem, ranges),
931 stop_rx,
932 );
933 let inflate = inflate.fuse();
934 pin_mut!(inflate);
935
936 let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
938 let deflate_queue_evt = deflate_queue
939 .event()
940 .try_clone()
941 .expect("failed to clone queue event");
942 let deflate = handle_queue(
943 deflate_queue,
944 EventAsync::new(deflate_queue_evt, &ex).expect("failed to create async event"),
945 None,
946 |ranges| reclaim_memory(&vm_memory_client, ranges),
947 stop_rx,
948 );
949 let deflate = deflate.fuse();
950 pin_mut!(deflate);
951
952 let (stats_tx, stats_rx) = mpsc::channel::<()>(1);
954 let has_stats_queue = stats_queue.is_some();
955 let stats = if let Some(stats_queue) = stats_queue {
956 let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
957 let stats_queue_evt = stats_queue
958 .event()
959 .try_clone()
960 .expect("failed to clone queue event");
961 handle_stats_queue(
962 stats_queue,
963 EventAsync::new(stats_queue_evt, &ex).expect("failed to create async event"),
964 stats_rx,
965 &command_tube,
966 #[cfg(feature = "registered_events")]
967 registered_evt_q_async.as_ref(),
968 state.clone(),
969 stop_rx,
970 )
971 .left_future()
972 } else {
973 std::future::pending().right_future()
974 };
975 let stats = stats.fuse();
976 pin_mut!(stats);
977
978 let has_reporting_queue = reporting_queue.is_some();
980 let reporting = if let Some(reporting_queue) = reporting_queue {
981 let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
982 let reporting_queue_evt = reporting_queue
983 .event()
984 .try_clone()
985 .expect("failed to clone queue event");
986 handle_reporting_queue(
987 reporting_queue,
988 EventAsync::new(reporting_queue_evt, &ex).expect("failed to create async event"),
989 release_memory_tube.as_ref(),
990 |ranges| free_memory(&vm_memory_client, &mem, ranges),
991 stop_rx,
992 )
993 .left_future()
994 } else {
995 std::future::pending().right_future()
996 };
997 let reporting = reporting.fuse();
998 pin_mut!(reporting);
999
1000 let has_ws_data_queue = ws_data_queue.is_some();
1003 let ws_data = if let Some(ws_data_queue) = ws_data_queue {
1004 let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
1005 let ws_data_queue_evt = ws_data_queue
1006 .event()
1007 .try_clone()
1008 .expect("failed to clone queue event");
1009 handle_ws_data_queue(
1010 ws_data_queue,
1011 EventAsync::new(ws_data_queue_evt, &ex).expect("failed to create async event"),
1012 &command_tube,
1013 #[cfg(feature = "registered_events")]
1014 registered_evt_q_async.as_ref(),
1015 state.clone(),
1016 stop_rx,
1017 )
1018 .left_future()
1019 } else {
1020 std::future::pending().right_future()
1021 };
1022 let ws_data = ws_data.fuse();
1023 pin_mut!(ws_data);
1024
1025 let (ws_op_tx, ws_op_rx) = mpsc::channel::<WSOp>(1);
1026 let has_ws_op_queue = ws_op_queue.is_some();
1027 let ws_op = if let Some(ws_op_queue) = ws_op_queue {
1028 let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
1029 let ws_op_queue_evt = ws_op_queue
1030 .event()
1031 .try_clone()
1032 .expect("failed to clone queue event");
1033 handle_ws_op_queue(
1034 ws_op_queue,
1035 EventAsync::new(ws_op_queue_evt, &ex).expect("failed to create async event"),
1036 ws_op_rx,
1037 state.clone(),
1038 stop_rx,
1039 )
1040 .left_future()
1041 } else {
1042 std::future::pending().right_future()
1043 };
1044 let ws_op = ws_op.fuse();
1045 pin_mut!(ws_op);
1046
1047 let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
1049 let command = handle_command_tube(
1050 &command_tube,
1051 interrupt.clone(),
1052 state.clone(),
1053 stats_tx,
1054 ws_op_tx,
1055 stop_rx,
1056 );
1057 pin_mut!(command);
1058
1059 let target_reached = handle_target_reached(&ex, target_reached_evt, &vm_memory_client);
1061 pin_mut!(target_reached);
1062
1063 let kill = async_utils::await_and_exit(&ex, kill_evt);
1065 pin_mut!(kill);
1066
1067 let pending_adjusted = handle_pending_adjusted_responses(
1068 EventAsync::new(pending_adjusted_response_event, &ex)
1069 .expect("failed to create async event"),
1070 &command_tube,
1071 state,
1072 );
1073 pin_mut!(pending_adjusted);
1074
1075 let res = ex.run_until(async {
1076 select! {
1077 _ = kill.fuse() => (),
1078 _ = inflate => return Err(anyhow!("inflate stopped unexpectedly")),
1079 _ = deflate => return Err(anyhow!("deflate stopped unexpectedly")),
1080 _ = stats => return Err(anyhow!("stats stopped unexpectedly")),
1081 _ = reporting => return Err(anyhow!("reporting stopped unexpectedly")),
1082 _ = command.fuse() => return Err(anyhow!("command stopped unexpectedly")),
1083 _ = ws_op => return Err(anyhow!("ws_op stopped unexpectedly")),
1084 _ = pending_adjusted.fuse() => return Err(anyhow!("pending_adjusted stopped unexpectedly")),
1085 _ = ws_data => return Err(anyhow!("ws_data stopped unexpectedly")),
1086 _ = target_reached.fuse() => return Err(anyhow!("target_reached stopped unexpectedly")),
1087 }
1088
1089 for stop_tx in stop_queue_oneshots {
1092 if stop_tx.send(()).is_err() {
1093 return Err(anyhow!("failed to request stop for queue future"));
1094 }
1095 }
1096
1097 let mut paused_queues = PausedQueues::new(
1100 inflate.await,
1101 deflate.await,
1102 );
1103 if has_reporting_queue {
1104 paused_queues.reporting = Some(reporting.await);
1105 }
1106 if has_stats_queue {
1107 paused_queues.stats = Some(stats.await);
1108 }
1109 if has_ws_data_queue {
1110 paused_queues.ws_data = Some(ws_data.await.context("failed to stop ws_data queue")?);
1111 }
1112 if has_ws_op_queue {
1113 paused_queues.ws_op = Some(ws_op.await.context("failed to stop ws_op queue")?);
1114 }
1115 Ok(paused_queues)
1116 });
1117
1118 match res {
1119 Err(e) => {
1120 error!("error happened in executor: {}", e);
1121 None
1122 }
1123 Ok(main_future_res) => match main_future_res {
1124 Ok(paused_queues) => Some(paused_queues),
1125 Err(e) => {
1126 error!("error happened in main balloon future: {}", e);
1127 None
1128 }
1129 },
1130 }
1131 };
1132
1133 WorkerReturn {
1134 command_tube: command_tube.into(),
1135 paused_queues,
1136 release_memory_tube,
1137 #[cfg(feature = "registered_events")]
1138 registered_evt_q,
1139 vm_memory_client,
1140 }
1141}
1142
1143async fn handle_target_reached(
1144 ex: &Executor,
1145 target_reached_evt: Event,
1146 vm_memory_client: &VmMemoryClient,
1147) -> anyhow::Result<()> {
1148 let event_async =
1149 EventAsync::new(target_reached_evt, ex).context("failed to create EventAsync")?;
1150 loop {
1151 let _ = event_async.next_val().await;
1153 if let Err(e) = vm_memory_client.balloon_target_reached(0) {
1156 warn!("Failed to send or receive allocation complete request: {e:#}");
1157 }
1158 }
1159 #[allow(unreachable_code)]
1163 Ok(())
1164}
1165
1166pub struct Balloon {
1168 command_tube: Option<Tube>,
1169 vm_memory_client: Option<VmMemoryClient>,
1170 release_memory_tube: Option<Tube>,
1171 pending_adjusted_response_event: Event,
1172 state: Arc<AsyncRwLock<BalloonState>>,
1173 features: u64,
1174 acked_features: u64,
1175 worker_thread: Option<WorkerThread<WorkerReturn>>,
1176 #[cfg(feature = "registered_events")]
1177 registered_evt_q: Option<SendTube>,
1178 ws_num_bins: u8,
1179 target_reached_evt: Option<Event>,
1180 queue_sizes: Vec<u16>,
1181}
1182
1183#[derive(Serialize, Deserialize)]
1185struct BalloonSnapshot {
1186 state: BalloonState,
1187 features: u64,
1188 acked_features: u64,
1189 ws_num_bins: u8,
1190}
1191
1192impl Balloon {
1193 pub fn new(
1199 base_features: u64,
1200 command_tube: Tube,
1201 vm_memory_client: VmMemoryClient,
1202 release_memory_tube: Option<Tube>,
1203 init_balloon_size: u64,
1204 enabled_features: u64,
1205 #[cfg(feature = "registered_events")] registered_evt_q: Option<SendTube>,
1206 ws_num_bins: u8,
1207 ) -> Result<Balloon> {
1208 let features = base_features
1209 | 1 << VIRTIO_BALLOON_F_MUST_TELL_HOST
1210 | 1 << VIRTIO_BALLOON_F_STATS_VQ
1211 | 1 << VIRTIO_BALLOON_F_DEFLATE_ON_OOM
1212 | enabled_features;
1213
1214 let mut queue_sizes = Vec::new();
1215 queue_sizes.push(QUEUE_SIZE); queue_sizes.push(QUEUE_SIZE); if features & (1 << VIRTIO_BALLOON_F_STATS_VQ) != 0 {
1218 queue_sizes.push(QUEUE_SIZE); }
1220 if features & (1 << VIRTIO_BALLOON_F_PAGE_REPORTING) != 0 {
1221 queue_sizes.push(QUEUE_SIZE); }
1223 if features & (1 << VIRTIO_BALLOON_F_WS_REPORTING) != 0 {
1224 queue_sizes.push(QUEUE_SIZE); queue_sizes.push(QUEUE_SIZE); }
1227
1228 Ok(Balloon {
1229 command_tube: Some(command_tube),
1230 vm_memory_client: Some(vm_memory_client),
1231 release_memory_tube,
1232 pending_adjusted_response_event: Event::new().map_err(BalloonError::CreatingEvent)?,
1233 state: Arc::new(AsyncRwLock::new(BalloonState {
1234 num_pages: (init_balloon_size >> VIRTIO_BALLOON_PFN_SHIFT) as u32,
1235 actual_pages: 0,
1236 failable_update: false,
1237 pending_adjusted_responses: VecDeque::new(),
1238 expecting_ws: false,
1239 })),
1240 worker_thread: None,
1241 features,
1242 acked_features: 0,
1243 #[cfg(feature = "registered_events")]
1244 registered_evt_q,
1245 ws_num_bins,
1246 target_reached_evt: None,
1247 queue_sizes,
1248 })
1249 }
1250
1251 fn get_config(&self) -> virtio_balloon_config {
1252 let state = block_on(self.state.lock());
1253 virtio_balloon_config {
1254 num_pages: state.num_pages.into(),
1255 actual: state.actual_pages.into(),
1256 free_page_hint_cmd_id: 0.into(),
1261 poison_val: 0.into(),
1262 ws_num_bins: self.ws_num_bins,
1263 _reserved: [0, 0, 0],
1264 }
1265 }
1266
1267 fn stop_worker(&mut self) -> StoppedWorker<PausedQueues> {
1268 if let Some(worker_thread) = self.worker_thread.take() {
1269 let worker_ret = worker_thread.stop();
1270 self.release_memory_tube = worker_ret.release_memory_tube;
1271 self.command_tube = Some(worker_ret.command_tube);
1272 #[cfg(feature = "registered_events")]
1273 {
1274 self.registered_evt_q = worker_ret.registered_evt_q;
1275 }
1276 self.vm_memory_client = Some(worker_ret.vm_memory_client);
1277
1278 if let Some(queues) = worker_ret.paused_queues {
1279 StoppedWorker::WithQueues(Box::new(queues))
1280 } else {
1281 StoppedWorker::MissingQueues
1282 }
1283 } else {
1284 StoppedWorker::AlreadyStopped
1285 }
1286 }
1287
1288 fn get_queues_from_map(
1292 &self,
1293 mut queues: BTreeMap<usize, Queue>,
1294 ) -> anyhow::Result<BalloonQueues> {
1295 let inflate_queue = queues.remove(&INFLATEQ).context("missing inflateq")?;
1296 let deflate_queue = queues.remove(&DEFLATEQ).context("missing deflateq")?;
1297 let mut queue_struct = BalloonQueues::new(inflate_queue, deflate_queue);
1298
1299 let mut next_queue_index = 2;
1301 let mut next_queue = || {
1302 let idx = next_queue_index;
1303 next_queue_index += 1;
1304 idx
1305 };
1306
1307 if self.features & (1 << VIRTIO_BALLOON_F_STATS_VQ) != 0 {
1308 let statsq = next_queue();
1309 if self.acked_features & (1 << VIRTIO_BALLOON_F_STATS_VQ) != 0 {
1310 queue_struct.stats = Some(queues.remove(&statsq).context("missing statsq")?);
1311 }
1312 }
1313
1314 if self.features & (1 << VIRTIO_BALLOON_F_PAGE_REPORTING) != 0 {
1315 let reporting_vq = next_queue();
1316 if self.acked_features & (1 << VIRTIO_BALLOON_F_PAGE_REPORTING) != 0 {
1317 queue_struct.reporting = Some(
1318 queues
1319 .remove(&reporting_vq)
1320 .context("missing reporting_vq")?,
1321 );
1322 }
1323 }
1324
1325 if self.features & (1 << VIRTIO_BALLOON_F_WS_REPORTING) != 0 {
1326 let ws_data_vq = next_queue();
1327 let ws_op_vq = next_queue();
1328 if self.acked_features & (1 << VIRTIO_BALLOON_F_WS_REPORTING) != 0 {
1329 queue_struct.ws_data =
1330 Some(queues.remove(&ws_data_vq).context("missing ws_data_vq")?);
1331 queue_struct.ws_op = Some(queues.remove(&ws_op_vq).context("missing ws_op_vq")?);
1332 }
1333 }
1334
1335 if !queues.is_empty() {
1336 return Err(anyhow!("unexpected queues {:?}", queues.into_keys()));
1337 }
1338
1339 Ok(queue_struct)
1340 }
1341
1342 fn start_worker(
1343 &mut self,
1344 mem: GuestMemory,
1345 interrupt: Interrupt,
1346 queues: BalloonQueues,
1347 ) -> anyhow::Result<()> {
1348 let (self_target_reached_evt, target_reached_evt) = Event::new()
1349 .and_then(|e| Ok((e.try_clone()?, e)))
1350 .context("failed to create target_reached Event pair: {}")?;
1351 self.target_reached_evt = Some(self_target_reached_evt);
1352
1353 let state = self.state.clone();
1354
1355 let command_tube = self.command_tube.take().unwrap();
1356
1357 let vm_memory_client = self.vm_memory_client.take().unwrap();
1358 let release_memory_tube = self.release_memory_tube.take();
1359 #[cfg(feature = "registered_events")]
1360 let registered_evt_q = self.registered_evt_q.take();
1361 let pending_adjusted_response_event = self
1362 .pending_adjusted_response_event
1363 .try_clone()
1364 .context("failed to clone Event")?;
1365
1366 self.worker_thread = Some(WorkerThread::start("v_balloon", move |kill_evt| {
1367 run_worker(
1368 queues.inflate,
1369 queues.deflate,
1370 queues.stats,
1371 queues.reporting,
1372 queues.ws_data,
1373 queues.ws_op,
1374 command_tube,
1375 vm_memory_client,
1376 mem,
1377 release_memory_tube,
1378 interrupt,
1379 kill_evt,
1380 target_reached_evt,
1381 pending_adjusted_response_event,
1382 state,
1383 #[cfg(feature = "registered_events")]
1384 registered_evt_q,
1385 )
1386 }));
1387
1388 Ok(())
1389 }
1390}
1391
1392impl VirtioDevice for Balloon {
1393 fn keep_rds(&self) -> Vec<RawDescriptor> {
1394 let mut rds = Vec::new();
1395 if let Some(command_tube) = &self.command_tube {
1396 rds.push(command_tube.as_raw_descriptor());
1397 }
1398 if let Some(vm_memory_client) = &self.vm_memory_client {
1399 rds.push(vm_memory_client.as_raw_descriptor());
1400 }
1401 if let Some(release_memory_tube) = &self.release_memory_tube {
1402 rds.push(release_memory_tube.as_raw_descriptor());
1403 }
1404 #[cfg(feature = "registered_events")]
1405 if let Some(registered_evt_q) = &self.registered_evt_q {
1406 rds.push(registered_evt_q.as_raw_descriptor());
1407 }
1408 rds.push(self.pending_adjusted_response_event.as_raw_descriptor());
1409 rds
1410 }
1411
1412 fn device_type(&self) -> DeviceType {
1413 DeviceType::Balloon
1414 }
1415
1416 fn queue_max_sizes(&self) -> &[u16] {
1417 &self.queue_sizes
1418 }
1419
1420 fn read_config(&self, offset: u64, data: &mut [u8]) {
1421 copy_config(data, 0, self.get_config().as_bytes(), offset);
1422 }
1423
1424 fn write_config(&mut self, offset: u64, data: &[u8]) {
1425 let mut config = self.get_config();
1426 copy_config(config.as_mut_bytes(), offset, data, 0);
1427 let mut state = block_on(self.state.lock());
1428 state.actual_pages = config.actual.to_native();
1429
1430 if config.num_pages == config.actual {
1432 debug!(
1433 "sending target reached event at {}",
1434 u32::from(config.num_pages)
1435 );
1436 self.target_reached_evt.as_ref().map(|e| e.signal());
1437 }
1438 if state.failable_update && state.actual_pages == state.num_pages {
1439 state.failable_update = false;
1440 let num_pages = state.num_pages;
1441 state.pending_adjusted_responses.push_back(num_pages);
1442 let _ = self.pending_adjusted_response_event.signal();
1443 }
1444 }
1445
1446 fn features(&self) -> u64 {
1447 self.features
1448 }
1449
1450 fn ack_features(&mut self, mut value: u64) {
1451 if value & !self.features != 0 {
1452 warn!("virtio_balloon got unknown feature ack {:x}", value);
1453 value &= self.features;
1454 }
1455 self.acked_features |= value;
1456 }
1457
1458 fn activate(
1459 &mut self,
1460 mem: GuestMemory,
1461 interrupt: Interrupt,
1462 queues: BTreeMap<usize, Queue>,
1463 ) -> anyhow::Result<()> {
1464 let queues = self.get_queues_from_map(queues)?;
1465 self.start_worker(mem, interrupt, queues)
1466 }
1467
1468 fn reset(&mut self) -> anyhow::Result<()> {
1469 let _worker = self.stop_worker();
1470 Ok(())
1471 }
1472
1473 fn virtio_sleep(&mut self) -> anyhow::Result<Option<BTreeMap<usize, Queue>>> {
1474 match self.stop_worker() {
1475 StoppedWorker::WithQueues(paused_queues) => Ok(Some(paused_queues.into())),
1476 StoppedWorker::MissingQueues => {
1477 anyhow::bail!("balloon queue workers did not stop cleanly.")
1478 }
1479 StoppedWorker::AlreadyStopped => {
1480 Ok(None)
1482 }
1483 }
1484 }
1485
1486 fn virtio_wake(
1487 &mut self,
1488 queues_state: Option<(GuestMemory, Interrupt, BTreeMap<usize, Queue>)>,
1489 ) -> anyhow::Result<()> {
1490 if let Some((mem, interrupt, queues)) = queues_state {
1491 if queues.len() < 2 {
1492 anyhow::bail!("{} queues were found, but an activated balloon must have at least 2 active queues.", queues.len());
1493 }
1494
1495 let balloon_queues = self.get_queues_from_map(queues)?;
1496 self.start_worker(mem, interrupt, balloon_queues)?;
1497 }
1498 Ok(())
1499 }
1500
1501 fn virtio_snapshot(&mut self) -> anyhow::Result<AnySnapshot> {
1502 let state = self
1503 .state
1504 .lock()
1505 .now_or_never()
1506 .context("failed to acquire balloon lock")?;
1507 AnySnapshot::to_any(BalloonSnapshot {
1508 features: self.features,
1509 acked_features: self.acked_features,
1510 state: state.clone(),
1511 ws_num_bins: self.ws_num_bins,
1512 })
1513 .context("failed to serialize balloon state")
1514 }
1515
1516 fn virtio_restore(&mut self, data: AnySnapshot) -> anyhow::Result<()> {
1517 let snap: BalloonSnapshot = AnySnapshot::from_any(data).context("error deserializing")?;
1518 if snap.features != self.features {
1519 anyhow::bail!(
1520 "balloon: expected features to match, but they did not. Live: {:?}, snapshot {:?}",
1521 self.features,
1522 snap.features,
1523 );
1524 }
1525
1526 let mut state = self
1527 .state
1528 .lock()
1529 .now_or_never()
1530 .context("failed to acquire balloon lock")?;
1531 *state = snap.state;
1532 self.ws_num_bins = snap.ws_num_bins;
1533 self.acked_features = snap.acked_features;
1534 Ok(())
1535 }
1536}
1537
#[cfg(test)]
mod tests {
    use super::*;
    use crate::suspendable_virtio_tests;
    use crate::virtio::descriptor_utils::create_descriptor_chain;
    use crate::virtio::descriptor_utils::DescriptorType;

    /// An inflate descriptor carrying two back-to-back 32-bit PFNs must decode into two
    /// guest addresses, each equal to its PFN shifted by `VIRTIO_BALLOON_PFN_SHIFT`.
    #[test]
    fn desc_parsing_inflate() {
        let pfns = [0x10u32, 0xaa55aa55u32];

        let guest_mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
        // Lay the PFNs out contiguously, 4 bytes apart, starting at 0x100.
        for (slot, &pfn) in pfns.iter().enumerate() {
            guest_mem
                .write_obj_at_addr(pfn, GuestAddress(0x100 + 4 * slot as u64))
                .unwrap();
        }

        let mut chain = create_descriptor_chain(
            &guest_mem,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![(DescriptorType::Readable, 8)],
            0,
        )
        .expect("create_descriptor_chain failed");

        let mut decoded = Vec::new();
        let outcome = handle_address_chain(None, &mut chain, &mut |mut ranges| {
            decoded.append(&mut ranges)
        });

        assert!(outcome.is_ok());
        assert_eq!(decoded.len(), pfns.len());
        for (range, &pfn) in decoded.iter().zip(pfns.iter()) {
            assert_eq!(
                range.0,
                GuestAddress(u64::from(pfn) << VIRTIO_BALLOON_PFN_SHIFT)
            );
        }
    }

    /// Host-side tube ends that must outlive the device under test.
    struct BalloonContext {
        _ctrl_tube: Tube,
        _mem_client_tube: Tube,
    }

    /// Mutates snapshotted device state so the suspend tests can detect a restore.
    fn modify_device(_balloon_context: &mut BalloonContext, balloon: &mut Balloon) {
        // Flipping every bit guarantees the new value differs from the old one.
        balloon.ws_num_bins ^= u8::MAX;
    }

    /// Builds a balloon with no optional features, plus the tube ends it talks to.
    fn create_device() -> (BalloonContext, Balloon) {
        let (ctrl_host_end, ctrl_device_end) = Tube::pair().unwrap();
        let (mem_host_end, mem_device_end) = Tube::pair().unwrap();

        let balloon = Balloon::new(
            0,
            ctrl_device_end,
            VmMemoryClient::new(mem_device_end),
            None,
            1024,
            0,
            #[cfg(feature = "registered_events")]
            None,
            0,
        )
        .unwrap();

        let context = BalloonContext {
            _ctrl_tube: ctrl_host_end,
            _mem_client_tube: mem_host_end,
        };
        (context, balloon)
    }

    suspendable_virtio_tests!(balloon, create_device, 2, modify_device);
}