1use std::collections::BTreeMap;
6use std::collections::VecDeque;
7use std::io::Write;
8use std::sync::Arc;
9
10use anyhow::anyhow;
11use anyhow::Context;
12use balloon_control::BalloonStats;
13use balloon_control::BalloonTubeCommand;
14use balloon_control::BalloonTubeResult;
15use balloon_control::BalloonWS;
16use balloon_control::WSBucket;
17use balloon_control::VIRTIO_BALLOON_WS_MAX_NUM_BINS;
18use balloon_control::VIRTIO_BALLOON_WS_MIN_NUM_BINS;
19use base::debug;
20use base::error;
21use base::warn;
22use base::AsRawDescriptor;
23use base::Event;
24use base::RawDescriptor;
25#[cfg(feature = "registered_events")]
26use base::SendTube;
27use base::Tube;
28use base::WorkerThread;
29use cros_async::block_on;
30use cros_async::sync::RwLock as AsyncRwLock;
31use cros_async::AsyncTube;
32use cros_async::EventAsync;
33use cros_async::Executor;
34#[cfg(feature = "registered_events")]
35use cros_async::SendTubeAsync;
36use data_model::Le16;
37use data_model::Le32;
38use data_model::Le64;
39use futures::channel::mpsc;
40use futures::channel::oneshot;
41use futures::pin_mut;
42use futures::select;
43use futures::select_biased;
44use futures::FutureExt;
45use futures::StreamExt;
46use remain::sorted;
47use serde::Deserialize;
48use serde::Serialize;
49use snapshot::AnySnapshot;
50use thiserror::Error as ThisError;
51use vm_control::api::VmMemoryClient;
52#[cfg(feature = "registered_events")]
53use vm_control::RegisteredEventWithData;
54use vm_memory::GuestAddress;
55use vm_memory::GuestMemory;
56use zerocopy::FromBytes;
57use zerocopy::Immutable;
58use zerocopy::IntoBytes;
59use zerocopy::KnownLayout;
60
61use super::async_utils;
62use super::copy_config;
63use super::create_stop_oneshot;
64use super::DescriptorChain;
65use super::DeviceType;
66use super::Interrupt;
67use super::Queue;
68use super::Reader;
69use super::StoppedWorker;
70use super::VirtioDevice;
71use crate::UnpinRequest;
72use crate::UnpinResponse;
73
#[sorted]
#[derive(ThisError, Debug)]
/// Errors reported by the balloon device code in this file.
pub enum BalloonError {
    /// Awaiting an async operation failed with a `cros_async::AsyncError`.
    #[error("failed async await: {0}")]
    AsyncAwait(cros_async::AsyncError),
    /// Same message as `AsyncAwait`, but the underlying error is an `anyhow::Error`.
    #[error("failed async await: {0}")]
    AsyncAwaitAnyhow(anyhow::Error),
    /// Creating a `base::Event` failed.
    #[error("failed to create event: {0}")]
    CreatingEvent(base::Error),
    /// Creating the async message receiver failed.
    #[error("failed to create async message receiver: {0}")]
    CreatingMessageReceiver(base::TubeError),
    /// Receiving a command from the command tube failed.
    #[error("failed to receive command message: {0}")]
    ReceivingCommand(base::TubeError),
    /// Sending a response over the command tube failed.
    #[error("failed to send command response: {0}")]
    SendResponse(base::TubeError),
    /// Writing into a virtqueue descriptor failed.
    #[error("failed to write to virtqueue: {0}")]
    WriteQueue(std::io::Error),
    /// Writing the config event failed.
    #[error("failed to write config event: {0}")]
    WritingConfigEvent(base::Error),
}
/// Result alias whose error type is [`BalloonError`].
pub type Result<T> = std::result::Result<T, BalloonError>;
103
// Size used for every balloon virtqueue registered by `Balloon::new`.
const QUEUE_SIZE: u16 = 128;

// Keys of the inflate and deflate queues in the queue map handed to the device.
const INFLATEQ: usize = 0;
const DEFLATEQ: usize = 1;

// Page frame numbers are converted to byte addresses by shifting left by this amount,
// i.e. one balloon page is `1 << 12` = 4096 bytes.
const VIRTIO_BALLOON_PFN_SHIFT: u32 = 12;
const VIRTIO_BALLOON_PF_SIZE: u64 = 1 << VIRTIO_BALLOON_PFN_SHIFT;
112
// Feature bit positions. Throughout this file they are used as `1 << BIT` against the
// device's `features` mask.
const VIRTIO_BALLOON_F_MUST_TELL_HOST: u32 = 0;
const VIRTIO_BALLOON_F_STATS_VQ: u32 = 1;
const VIRTIO_BALLOON_F_DEFLATE_ON_OOM: u32 = 2;
const VIRTIO_BALLOON_F_PAGE_REPORTING: u32 = 5;
const VIRTIO_BALLOON_F_WS_REPORTING: u32 = 8;

/// Optional balloon features. Each variant's discriminant is the matching feature bit
/// position (not a mask).
#[derive(Copy, Clone)]
#[repr(u32)]
pub enum BalloonFeatures {
    /// Bit position `VIRTIO_BALLOON_F_PAGE_REPORTING`.
    PageReporting = VIRTIO_BALLOON_F_PAGE_REPORTING,
    /// Bit position `VIRTIO_BALLOON_F_WS_REPORTING`.
    WSReporting = VIRTIO_BALLOON_F_WS_REPORTING,
}
130
/// Device configuration structure produced by `Balloon::get_config`.
///
/// NOTE(review): the field layout presumably mirrors the virtio balloon config space,
/// extended with `ws_num_bins` — confirm against the spec / guest driver.
#[derive(Copy, Clone, Debug, Default, FromBytes, Immutable, IntoBytes, KnownLayout)]
#[repr(C)]
struct virtio_balloon_config {
    // Filled from `BalloonState::num_pages`.
    num_pages: Le32,
    // Filled from `BalloonState::actual_pages`.
    actual: Le32,
    // Always reported as 0 by `get_config`.
    free_page_hint_cmd_id: Le32,
    // Always reported as 0 by `get_config`.
    poison_val: Le32,
    // Filled from `Balloon::ws_num_bins`.
    ws_num_bins: u8,
    // Always zeroed by `get_config`.
    _reserved: [u8; 3],
}
143
/// Mutable balloon state shared (behind an `AsyncRwLock`) between the device and its worker.
#[derive(Clone, Default, Serialize, Deserialize)]
struct BalloonState {
    // Page count most recently requested; written by the `Adjust` command handler.
    num_pages: u32,
    // Current balloon size in pages; shifted by `VIRTIO_BALLOON_PFN_SHIFT` when reported
    // in bytes.
    actual_pages: u32,
    // Set when a working-set report was explicitly requested, cleared once that report has
    // been forwarded over the command tube.
    expecting_ws: bool,
    // Set by the `Adjust` handler when `allow_failure` was requested and the target is not
    // reached yet. NOTE(review): the consumer of this flag is outside this section.
    failable_update: bool,
    // Page counts for which an `Adjusted` response still has to be sent; drained by
    // `handle_pending_adjusted_responses`.
    pending_adjusted_responses: VecDeque<u32>,
}
156
// Stat tags carried in `BalloonStat::tag`; `BalloonStat::update_stats` maps each one to a
// field of `BalloonStats`.
const VIRTIO_BALLOON_S_SWAP_IN: u16 = 0;
const VIRTIO_BALLOON_S_SWAP_OUT: u16 = 1;
const VIRTIO_BALLOON_S_MAJFLT: u16 = 2;
const VIRTIO_BALLOON_S_MINFLT: u16 = 3;
const VIRTIO_BALLOON_S_MEMFREE: u16 = 4;
const VIRTIO_BALLOON_S_MEMTOT: u16 = 5;
const VIRTIO_BALLOON_S_AVAIL: u16 = 6;
const VIRTIO_BALLOON_S_CACHES: u16 = 7;
const VIRTIO_BALLOON_S_HTLB_PGALLOC: u16 = 8;
const VIRTIO_BALLOON_S_HTLB_PGFAIL: u16 = 9;
// The two tags below use values at the top of the u16 range; as their names say, they are
// non-standard extensions.
const VIRTIO_BALLOON_S_NONSTANDARD_SHMEM: u16 = 65534;
const VIRTIO_BALLOON_S_NONSTANDARD_UNEVICTABLE: u16 = 65535;
170
/// One (tag, value) statistics record as read from the stats queue.
///
/// `packed` removes padding between the 2-byte tag and the 8-byte value.
#[derive(Copy, Clone, FromBytes, Immutable, IntoBytes, KnownLayout)]
#[repr(C, packed)]
struct BalloonStat {
    // One of the `VIRTIO_BALLOON_S_*` tags.
    tag: Le16,
    // Value associated with `tag`.
    val: Le64,
}
178
179impl BalloonStat {
180 fn update_stats(&self, stats: &mut BalloonStats) {
181 let val = Some(self.val.to_native());
182 match self.tag.to_native() {
183 VIRTIO_BALLOON_S_SWAP_IN => stats.swap_in = val,
184 VIRTIO_BALLOON_S_SWAP_OUT => stats.swap_out = val,
185 VIRTIO_BALLOON_S_MAJFLT => stats.major_faults = val,
186 VIRTIO_BALLOON_S_MINFLT => stats.minor_faults = val,
187 VIRTIO_BALLOON_S_MEMFREE => stats.free_memory = val,
188 VIRTIO_BALLOON_S_MEMTOT => stats.total_memory = val,
189 VIRTIO_BALLOON_S_AVAIL => stats.available_memory = val,
190 VIRTIO_BALLOON_S_CACHES => stats.disk_caches = val,
191 VIRTIO_BALLOON_S_HTLB_PGALLOC => stats.hugetlb_allocations = val,
192 VIRTIO_BALLOON_S_HTLB_PGFAIL => stats.hugetlb_failures = val,
193 VIRTIO_BALLOON_S_NONSTANDARD_SHMEM => stats.shared_memory = val,
194 VIRTIO_BALLOON_S_NONSTANDARD_UNEVICTABLE => stats.unevictable_memory = val,
195 _ => (),
196 }
197 }
198}
199
/// One working-set record as read from the ws data queue; each record becomes one
/// `WSBucket` in `update_ws`.
#[repr(C)]
#[derive(Copy, Clone, Debug, Default, FromBytes, Immutable, IntoBytes, KnownLayout)]
struct virtio_balloon_ws {
    // Not read by the code in this section.
    tag: Le16,
    // Not read by the code in this section.
    node_id: Le16,
    // Explicit padding bytes ahead of the 8-byte fields.
    _reserved: [u8; 4],
    // Copied into `WSBucket::age`.
    idle_age_ms: Le64,
    // Copied element-wise into `WSBucket::bytes`.
    // NOTE(review): presumably one entry per memory type (e.g. anon/file) — confirm.
    memory_size_bytes: [Le64; 2],
}
213
214impl virtio_balloon_ws {
215 fn update_ws(&self, ws: &mut BalloonWS) {
216 let bucket = WSBucket {
217 age: self.idle_age_ms.to_native(),
218 bytes: [
219 self.memory_size_bytes[0].to_native(),
220 self.memory_size_bytes[1].to_native(),
221 ],
222 };
223 ws.ws.push(bucket);
224 }
225}
226
// Operation codes written into `virtio_balloon_op::type_`. The underscore-prefixed ones are
// not used in this section.
const _VIRTIO_BALLOON_WS_OP_INVALID: u16 = 0;
const VIRTIO_BALLOON_WS_OP_REQUEST: u16 = 1;
const VIRTIO_BALLOON_WS_OP_CONFIG: u16 = 2;
const _VIRTIO_BALLOON_WS_OP_DISCARD: u16 = 3;

/// Header written to the ws op queue to tell the guest which operation is requested.
#[repr(C, packed)]
#[derive(Copy, Clone, Debug, Default, FromBytes, Immutable, IntoBytes, KnownLayout)]
struct virtio_balloon_op {
    // One of the `VIRTIO_BALLOON_WS_OP_*` codes.
    type_: Le16,
}
238
239fn invoke_desc_handler<F>(ranges: Vec<(u64, u64)>, desc_handler: &mut F)
240where
241 F: FnMut(Vec<(GuestAddress, u64)>),
242{
243 desc_handler(
244 ranges
245 .into_iter()
246 .map(|range| (GuestAddress(range.0), range.1))
247 .collect(),
248 );
249}
250
251fn release_ranges<F>(
255 release_memory_tube: Option<&Tube>,
256 inflate_ranges: Vec<(u64, u64)>,
257 desc_handler: &mut F,
258) -> anyhow::Result<()>
259where
260 F: FnMut(Vec<(GuestAddress, u64)>),
261{
262 if let Some(tube) = release_memory_tube {
263 let unpin_ranges = inflate_ranges
264 .iter()
265 .map(|v| {
266 (
267 v.0 >> VIRTIO_BALLOON_PFN_SHIFT,
268 v.1 / VIRTIO_BALLOON_PF_SIZE,
269 )
270 })
271 .collect();
272 let req = UnpinRequest {
273 ranges: unpin_ranges,
274 };
275 if let Err(e) = tube.send(&req) {
276 error!("failed to send unpin request: {}", e);
277 } else {
278 match tube.recv() {
279 Ok(resp) => match resp {
280 UnpinResponse::Success => invoke_desc_handler(inflate_ranges, desc_handler),
281 UnpinResponse::Failed => error!("failed to handle unpin request"),
282 },
283 Err(e) => error!("failed to handle get unpin response: {}", e),
284 }
285 }
286 } else {
287 invoke_desc_handler(inflate_ranges, desc_handler);
288 }
289
290 Ok(())
291}
292
293fn handle_address_chain<F>(
295 release_memory_tube: Option<&Tube>,
296 avail_desc: &mut DescriptorChain,
297 desc_handler: &mut F,
298) -> anyhow::Result<()>
299where
300 F: FnMut(Vec<(GuestAddress, u64)>),
301{
302 let mut range_start = 0;
307 let mut range_size = 0;
308 let mut inflate_ranges: Vec<(u64, u64)> = Vec::new();
309 for res in avail_desc.reader.iter::<Le32>() {
310 let pfn = match res {
311 Ok(pfn) => pfn,
312 Err(e) => {
313 error!("error while reading unused pages: {}", e);
314 break;
315 }
316 };
317 let guest_address = (u64::from(pfn.to_native())) << VIRTIO_BALLOON_PFN_SHIFT;
318 if range_start + range_size == guest_address {
319 range_size += VIRTIO_BALLOON_PF_SIZE;
320 } else if range_start == guest_address + VIRTIO_BALLOON_PF_SIZE {
321 range_start = guest_address;
322 range_size += VIRTIO_BALLOON_PF_SIZE;
323 } else {
324 if range_size != 0 {
327 inflate_ranges.push((range_start, range_size));
328 }
329 range_start = guest_address;
330 range_size = VIRTIO_BALLOON_PF_SIZE;
331 }
332 }
333 if range_size != 0 {
334 inflate_ranges.push((range_start, range_size));
335 }
336
337 release_ranges(release_memory_tube, inflate_ranges, desc_handler)
338}
339
340async fn handle_queue<F>(
342 mut queue: Queue,
343 mut queue_event: EventAsync,
344 release_memory_tube: Option<&Tube>,
345 mut desc_handler: F,
346 mut stop_rx: oneshot::Receiver<()>,
347) -> Queue
348where
349 F: FnMut(Vec<(GuestAddress, u64)>),
350{
351 loop {
352 let mut avail_desc = match queue
353 .next_async_interruptable(&mut queue_event, &mut stop_rx)
354 .await
355 {
356 Ok(Some(res)) => res,
357 Ok(None) => return queue,
358 Err(e) => {
359 error!("Failed to read descriptor {}", e);
360 return queue;
361 }
362 };
363 if let Err(e) =
364 handle_address_chain(release_memory_tube, &mut avail_desc, &mut desc_handler)
365 {
366 error!("balloon: failed to process inflate addresses: {}", e);
367 }
368 queue.add_used(avail_desc);
369 queue.trigger_interrupt();
370 }
371}
372
373fn handle_reported_buffer<F>(
375 release_memory_tube: Option<&Tube>,
376 avail_desc: &DescriptorChain,
377 desc_handler: &mut F,
378) -> anyhow::Result<()>
379where
380 F: FnMut(Vec<(GuestAddress, u64)>),
381{
382 let reported_ranges: Vec<(u64, u64)> = avail_desc
383 .reader
384 .get_remaining_regions()
385 .chain(avail_desc.writer.get_remaining_regions())
386 .map(|r| (r.offset, r.len as u64))
387 .collect();
388
389 release_ranges(release_memory_tube, reported_ranges, desc_handler)
390}
391
392async fn handle_reporting_queue<F>(
394 mut queue: Queue,
395 mut queue_event: EventAsync,
396 release_memory_tube: Option<&Tube>,
397 mut desc_handler: F,
398 mut stop_rx: oneshot::Receiver<()>,
399) -> Queue
400where
401 F: FnMut(Vec<(GuestAddress, u64)>),
402{
403 loop {
404 let avail_desc = match queue
405 .next_async_interruptable(&mut queue_event, &mut stop_rx)
406 .await
407 {
408 Ok(Some(res)) => res,
409 Ok(None) => return queue,
410 Err(e) => {
411 error!("Failed to read descriptor {}", e);
412 return queue;
413 }
414 };
415 if let Err(e) = handle_reported_buffer(release_memory_tube, &avail_desc, &mut desc_handler)
416 {
417 error!("balloon: failed to process reported buffer: {}", e);
418 }
419 queue.add_used(avail_desc);
420 queue.trigger_interrupt();
421 }
422}
423
424fn parse_balloon_stats(reader: &mut Reader) -> BalloonStats {
425 let mut stats: BalloonStats = Default::default();
426 for res in reader.iter::<BalloonStat>() {
427 match res {
428 Ok(stat) => stat.update_stats(&mut stats),
429 Err(e) => {
430 error!("error while reading stats: {}", e);
431 break;
432 }
433 };
434 }
435 stats
436}
437
/// Serves stats requests on the stats virtqueue until stopped.
///
/// The first descriptor chain received from the guest is held without being parsed. Each
/// time a request arrives on `stats_rx`, the held chain is returned to the guest
/// (`add_used` + interrupt), the next chain is awaited and parsed with
/// `parse_balloon_stats`, and the result is sent over `command_tube` as
/// `BalloonTubeResult::Stats` together with the current balloon size in bytes. With the
/// "registered_events" feature, a `VirtioBalloonResize` event is then sent on
/// `registered_evt_q` when one is provided.
///
/// Returns the queue when `stop_rx` fires, when `stats_rx` is closed, or when reading a
/// descriptor fails. Send failures are only logged.
async fn handle_stats_queue(
    mut queue: Queue,
    mut queue_event: EventAsync,
    mut stats_rx: mpsc::Receiver<()>,
    command_tube: &AsyncTube,
    #[cfg(feature = "registered_events")] registered_evt_q: Option<&SendTubeAsync>,
    state: Arc<AsyncRwLock<BalloonState>>,
    mut stop_rx: oneshot::Receiver<()>,
) -> Queue {
    // Take the initial descriptor chain; its contents are never parsed, it is only handed
    // back to the guest when the first stats request arrives.
    let mut avail_desc = match queue
        .next_async_interruptable(&mut queue_event, &mut stop_rx)
        .await
    {
        Ok(Some(res)) => res,
        Ok(None) => return queue,
        Err(e) => {
            error!("Failed to read descriptor {}", e);
            return queue;
        }
    };

    loop {
        // Wait for either a stats request or a stop request; the request channel is polled
        // first.
        select_biased! {
            msg = stats_rx.next() => {
                match msg {
                    Some(()) => (),
                    None => {
                        error!("stats signal channel was closed");
                        return queue;
                    }
                }
            }
            _ = stop_rx => return queue,
        };

        // Return the held chain to the guest and notify it.
        // NOTE(review): presumably this is what prompts the guest to submit fresh stats.
        queue.add_used(avail_desc);
        queue.trigger_interrupt();

        // NOTE(review): unlike the waits above, this await does not observe `stop_rx`.
        avail_desc = match queue.next_async(&mut queue_event).await {
            Err(e) => {
                error!("Failed to read descriptor {}", e);
                return queue;
            }
            Ok(d) => d,
        };
        let stats = parse_balloon_stats(&mut avail_desc.reader);

        // The lock guard is a temporary and is released at the end of this statement.
        let actual_pages = state.lock().await.actual_pages as u64;
        let result = BalloonTubeResult::Stats {
            balloon_actual: actual_pages << VIRTIO_BALLOON_PFN_SHIFT,
            stats,
        };
        let send_result = command_tube.send(result).await;
        if let Err(e) = send_result {
            error!("failed to send stats result: {}", e);
        }

        #[cfg(feature = "registered_events")]
        if let Some(registered_evt_q) = registered_evt_q {
            if let Err(e) = registered_evt_q
                .send(&RegisteredEventWithData::VirtioBalloonResize)
                .await
            {
                error!("failed to send VirtioBalloonResize event: {}", e);
            }
        }
    }
}
514
515async fn send_adjusted_response(
516 tube: &AsyncTube,
517 num_pages: u32,
518) -> std::result::Result<(), base::TubeError> {
519 let num_bytes = (num_pages as u64) << VIRTIO_BALLOON_PFN_SHIFT;
520 let result = BalloonTubeResult::Adjusted { num_bytes };
521 tube.send(result).await
522}
523
/// Working-set operations passed from the command handler to `handle_ws_op_queue`.
enum WSOp {
    /// Ask the guest for a working-set report.
    WSReport,
    /// Send a working-set configuration to the guest; the three fields are written to the
    /// queue in declaration order after the op header.
    WSConfig {
        bins: Vec<u32>,
        refresh_threshold: u32,
        report_threshold: u32,
    },
}
532
533async fn handle_ws_op_queue(
534 mut queue: Queue,
535 mut queue_event: EventAsync,
536 mut ws_op_rx: mpsc::Receiver<WSOp>,
537 state: Arc<AsyncRwLock<BalloonState>>,
538 mut stop_rx: oneshot::Receiver<()>,
539) -> Result<Queue> {
540 loop {
541 let op = select_biased! {
542 next_op = ws_op_rx.next().fuse() => {
543 match next_op {
544 Some(op) => op,
545 None => {
546 error!("ws op tube was closed");
547 break;
548 }
549 }
550 }
551 _ = stop_rx => {
552 break;
553 }
554 };
555 let mut avail_desc = queue
556 .next_async(&mut queue_event)
557 .await
558 .map_err(BalloonError::AsyncAwait)?;
559 let writer = &mut avail_desc.writer;
560
561 match op {
562 WSOp::WSReport => {
563 {
564 let mut state = state.lock().await;
565 state.expecting_ws = true;
566 }
567
568 let ws_r = virtio_balloon_op {
569 type_: VIRTIO_BALLOON_WS_OP_REQUEST.into(),
570 };
571
572 writer.write_obj(ws_r).map_err(BalloonError::WriteQueue)?;
573 }
574 WSOp::WSConfig {
575 bins,
576 refresh_threshold,
577 report_threshold,
578 } => {
579 let cmd = virtio_balloon_op {
580 type_: VIRTIO_BALLOON_WS_OP_CONFIG.into(),
581 };
582
583 writer.write_obj(cmd).map_err(BalloonError::WriteQueue)?;
584 writer
585 .write_all(bins.as_bytes())
586 .map_err(BalloonError::WriteQueue)?;
587 writer
588 .write_obj(refresh_threshold)
589 .map_err(BalloonError::WriteQueue)?;
590 writer
591 .write_obj(report_threshold)
592 .map_err(BalloonError::WriteQueue)?;
593 }
594 }
595
596 queue.add_used(avail_desc);
597 queue.trigger_interrupt();
598 }
599
600 Ok(queue)
601}
602
603fn parse_balloon_ws(reader: &mut Reader) -> BalloonWS {
604 let mut ws = BalloonWS::new();
605 for res in reader.iter::<virtio_balloon_ws>() {
606 match res {
607 Ok(ws_msg) => {
608 ws_msg.update_ws(&mut ws);
609 }
610 Err(e) => {
611 error!("error while reading ws: {}", e);
612 break;
613 }
614 }
615 }
616 if ws.ws.len() < VIRTIO_BALLOON_WS_MIN_NUM_BINS || ws.ws.len() > VIRTIO_BALLOON_WS_MAX_NUM_BINS
617 {
618 error!("unexpected number of WS buckets: {}", ws.ws.len());
619 }
620 ws
621}
622
623async fn handle_ws_data_queue(
629 mut queue: Queue,
630 mut queue_event: EventAsync,
631 command_tube: &AsyncTube,
632 #[cfg(feature = "registered_events")] registered_evt_q: Option<&SendTubeAsync>,
633 state: Arc<AsyncRwLock<BalloonState>>,
634 mut stop_rx: oneshot::Receiver<()>,
635) -> Result<Queue> {
636 loop {
637 let mut avail_desc = match queue
638 .next_async_interruptable(&mut queue_event, &mut stop_rx)
639 .await
640 .map_err(BalloonError::AsyncAwait)?
641 {
642 Some(res) => res,
643 None => return Ok(queue),
644 };
645
646 let ws = parse_balloon_ws(&mut avail_desc.reader);
647
648 let mut state = state.lock().await;
649
650 let balloon_actual = (state.actual_pages as u64) << VIRTIO_BALLOON_PFN_SHIFT;
652
653 if state.expecting_ws {
654 let result = BalloonTubeResult::WorkingSet { ws, balloon_actual };
655 let send_result = command_tube.send(result).await;
656 if let Err(e) = send_result {
657 error!("failed to send ws result: {}", e);
658 }
659
660 state.expecting_ws = false;
661 } else {
662 #[cfg(feature = "registered_events")]
663 if let Some(registered_evt_q) = registered_evt_q {
664 if let Err(e) = registered_evt_q
665 .send(RegisteredEventWithData::from_ws(&ws, balloon_actual))
666 .await
667 {
668 error!("failed to send VirtioBalloonWSReport event: {}", e);
669 }
670 }
671 }
672
673 queue.add_used(avail_desc);
674 queue.trigger_interrupt();
675 }
676}
677
/// Processes `BalloonTubeCommand`s arriving on `command_tube` until stopped.
///
/// * `Adjust` stores the new target page count in the shared state and signals a config
///   change; with `allow_failure` it either answers immediately (target already reached)
///   or sets `failable_update`.
/// * `WorkingSetConfig` / `WorkingSet` are forwarded to the ws op handler via `ws_op_tx`.
/// * `Stats` signals the stats handler via `stats_tx`.
///
/// Forwarding uses `try_send`; failures there are only logged. Returns `Ok(())` when
/// `stop_rx` fires, and an error when receiving a command or sending an `Adjusted`
/// response fails.
async fn handle_command_tube(
    command_tube: &AsyncTube,
    interrupt: Interrupt,
    state: Arc<AsyncRwLock<BalloonState>>,
    mut stats_tx: mpsc::Sender<()>,
    mut ws_op_tx: mpsc::Sender<WSOp>,
    mut stop_rx: oneshot::Receiver<()>,
) -> Result<()> {
    loop {
        // Commands are polled before the stop request.
        let cmd_res = select_biased! {
            res = command_tube.next().fuse() => res,
            _ = stop_rx => return Ok(())
        };
        match cmd_res {
            Ok(command) => match command {
                BalloonTubeCommand::Adjust {
                    num_bytes,
                    allow_failure,
                } => {
                    // NOTE(review): the `as u32` cast silently truncates page counts that
                    // do not fit in 32 bits.
                    let num_pages = (num_bytes >> VIRTIO_BALLOON_PFN_SHIFT) as u32;
                    // The lock is held until the end of this arm, including across the
                    // response send below.
                    let mut state = state.lock().await;

                    state.num_pages = num_pages;
                    interrupt.signal_config_changed();

                    if allow_failure {
                        if num_pages == state.actual_pages {
                            // Target already reached: answer right away.
                            send_adjusted_response(command_tube, num_pages)
                                .await
                                .map_err(BalloonError::SendResponse)?;
                        } else {
                            state.failable_update = true;
                        }
                    }
                }
                BalloonTubeCommand::WorkingSetConfig {
                    bins,
                    refresh_threshold,
                    report_threshold,
                } => {
                    if let Err(e) = ws_op_tx.try_send(WSOp::WSConfig {
                        bins,
                        refresh_threshold,
                        report_threshold,
                    }) {
                        error!("failed to send config to ws handler: {}", e);
                    }
                }
                BalloonTubeCommand::Stats => {
                    if let Err(e) = stats_tx.try_send(()) {
                        error!("failed to signal the stat handler: {}", e);
                    }
                }
                BalloonTubeCommand::WorkingSet => {
                    if let Err(e) = ws_op_tx.try_send(WSOp::WSReport) {
                        error!("failed to send report request to ws handler: {}", e);
                    }
                }
            },
            // Windows only: a timed-out receive is not treated as an error; the task waits
            // for the stop request and then returns normally.
            // NOTE(review): presumably the pending receive is cancelled this way during
            // shutdown on Windows — confirm.
            #[cfg(windows)]
            Err(base::TubeError::Recv(e)) if e.kind() == std::io::ErrorKind::TimedOut => {
                let _ = stop_rx.await;
                return Ok(());
            }
            Err(e) => {
                return Err(BalloonError::ReceivingCommand(e));
            }
        }
    }
}
756
/// Sends queued `Adjusted` responses each time `pending_adjusted_response_event` fires.
///
/// After every event, `pending_adjusted_responses` is drained front to back and one
/// response per entry is sent over `command_tube`. The loop never ends on its own; the
/// function only returns with an error, when waiting on the event or sending fails.
async fn handle_pending_adjusted_responses(
    pending_adjusted_response_event: EventAsync,
    command_tube: &AsyncTube,
    state: Arc<AsyncRwLock<BalloonState>>,
) -> Result<()> {
    loop {
        pending_adjusted_response_event
            .next_val()
            .await
            .map_err(BalloonError::AsyncAwait)?;
        // NOTE(review): the lock guard is a temporary in the `while let` scrutinee, so it
        // presumably stays alive for the loop body, i.e. while the response is being sent
        // — confirm this is intended before restructuring.
        while let Some(num_pages) = state.lock().await.pending_adjusted_responses.pop_front() {
            send_adjusted_response(command_tube, num_pages)
                .await
                .map_err(BalloonError::SendResponse)?;
        }
    }
}
774
/// The set of virtqueues for the balloon device: inflate and deflate always exist, the
/// others are optional.
struct BalloonQueues {
    inflate: Queue,
    deflate: Queue,
    stats: Option<Queue>,
    reporting: Option<Queue>,
    ws_data: Option<Queue>,
    ws_op: Option<Queue>,
}

impl BalloonQueues {
    /// Creates a queue set holding only the two mandatory queues; every optional queue
    /// starts as `None`.
    fn new(inflate: Queue, deflate: Queue) -> Self {
        BalloonQueues {
            inflate,
            deflate,
            stats: None,
            reporting: None,
            ws_data: None,
            ws_op: None,
        }
    }
}
797
/// Queues handed back by the worker's queue futures after they were asked to stop.
struct PausedQueues {
    inflate: Queue,
    deflate: Queue,
    stats: Option<Queue>,
    reporting: Option<Queue>,
    ws_data: Option<Queue>,
    ws_op: Option<Queue>,
}

impl PausedQueues {
    /// Creates a set holding only the two mandatory queues; every optional queue starts
    /// as `None`.
    fn new(inflate: Queue, deflate: Queue) -> Self {
        PausedQueues {
            inflate,
            deflate,
            stats: None,
            reporting: None,
            ws_data: None,
            ws_op: None,
        }
    }
}
820
821fn apply_if_some<F, R>(queue_opt: Option<Queue>, mut func: F)
822where
823 F: FnMut(Queue) -> R,
824{
825 if let Some(queue) = queue_opt {
826 func(queue);
827 }
828}
829
830impl From<Box<PausedQueues>> for BTreeMap<usize, Queue> {
831 fn from(queues: Box<PausedQueues>) -> BTreeMap<usize, Queue> {
832 let mut ret = Vec::new();
833 ret.push(queues.inflate);
834 ret.push(queues.deflate);
835 apply_if_some(queues.stats, |stats| ret.push(stats));
836 apply_if_some(queues.reporting, |reporting| ret.push(reporting));
837 apply_if_some(queues.ws_data, |ws_data| ret.push(ws_data));
838 apply_if_some(queues.ws_op, |ws_op| ret.push(ws_op));
839 ret.into_iter().enumerate().collect()
842 }
843}
844
845fn free_memory(
846 vm_memory_client: &VmMemoryClient,
847 mem: &GuestMemory,
848 ranges: Vec<(GuestAddress, u64)>,
849) {
850 #[cfg(any(target_os = "android", target_os = "linux"))]
853 if mem.locked() && !mem.use_punchhole_locked() {
854 for (guest_address, len) in ranges {
855 if let Err(e) = mem.remove_range(guest_address, len) {
856 warn!("Marking pages unused failed: {}, addr={}", e, guest_address);
857 }
858 }
859 return;
860 }
861 if let Err(e) = vm_memory_client.dynamically_free_memory_ranges(ranges) {
862 warn!("Failed to dynamically free memory ranges: {e:#}");
863 }
864}
865
866fn reclaim_memory(vm_memory_client: &VmMemoryClient, ranges: Vec<(GuestAddress, u64)>) {
867 if let Err(e) = vm_memory_client.dynamically_reclaim_memory_ranges(ranges) {
868 warn!("Failed to dynamically reclaim memory range: {e:#}");
869 }
870}
871
/// Resources handed back by `run_worker` so the device can reclaim them after the worker
/// has stopped.
struct WorkerReturn {
    release_memory_tube: Option<Tube>,
    command_tube: Tube,
    #[cfg(feature = "registered_events")]
    registered_evt_q: Option<SendTube>,
    // `None` when the worker ended with an error and the queues could not be recovered.
    paused_queues: Option<PausedQueues>,
    vm_memory_client: VmMemoryClient,
}
882
/// Runs the balloon worker until the kill event fires or one of its tasks ends.
///
/// One future is created per present queue (absent optional queues get a never-completing
/// placeholder), plus futures for the command tube, the target-reached event, the kill
/// event and the pending adjusted responses. All of them are driven on a new `Executor`.
///
/// When the kill future completes, every queue future is asked to stop through its oneshot
/// and the queues they hand back are collected into `PausedQueues`. If any other future
/// finishes first, or stopping fails, the error is logged and `paused_queues` in the
/// returned `WorkerReturn` is `None`.
///
/// Panics if the executor, the async tube wrappers or the async queue events cannot be
/// created (`unwrap` / `expect` below).
fn run_worker(
    inflate_queue: Queue,
    deflate_queue: Queue,
    stats_queue: Option<Queue>,
    reporting_queue: Option<Queue>,
    ws_data_queue: Option<Queue>,
    ws_op_queue: Option<Queue>,
    command_tube: Tube,
    vm_memory_client: VmMemoryClient,
    mem: GuestMemory,
    release_memory_tube: Option<Tube>,
    interrupt: Interrupt,
    kill_evt: Event,
    target_reached_evt: Event,
    pending_adjusted_response_event: Event,
    state: Arc<AsyncRwLock<BalloonState>>,
    #[cfg(feature = "registered_events")] registered_evt_q: Option<SendTube>,
) -> WorkerReturn {
    let ex = Executor::new().unwrap();
    let command_tube = AsyncTube::new(&ex, command_tube).unwrap();
    #[cfg(feature = "registered_events")]
    let registered_evt_q_async = registered_evt_q
        .as_ref()
        .map(|q| SendTubeAsync::new(q.try_clone().unwrap(), &ex).unwrap());

    // One stop sender per stoppable future; all are fired after the kill event.
    let mut stop_queue_oneshots = Vec::new();

    // The futures below borrow `command_tube`, `vm_memory_client`, `mem` and
    // `release_memory_tube`; this block ends those borrows before the values are moved into
    // the `WorkerReturn` at the bottom.
    let paused_queues = {
        // Inflate queue: pages given up by the guest are released via `free_memory`.
        let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
        let inflate_queue_evt = inflate_queue
            .event()
            .try_clone()
            .expect("failed to clone queue event");
        let inflate = handle_queue(
            inflate_queue,
            EventAsync::new(inflate_queue_evt, &ex).expect("failed to create async event"),
            release_memory_tube.as_ref(),
            |ranges| free_memory(&vm_memory_client, &mem, ranges),
            stop_rx,
        );
        let inflate = inflate.fuse();
        pin_mut!(inflate);

        // Deflate queue: pages taken back by the guest go through `reclaim_memory`; no
        // release tube is involved.
        let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
        let deflate_queue_evt = deflate_queue
            .event()
            .try_clone()
            .expect("failed to clone queue event");
        let deflate = handle_queue(
            deflate_queue,
            EventAsync::new(deflate_queue_evt, &ex).expect("failed to create async event"),
            None,
            |ranges| reclaim_memory(&vm_memory_client, ranges),
            stop_rx,
        );
        let deflate = deflate.fuse();
        pin_mut!(deflate);

        // Stats queue (optional). The channel carries stats requests from the command
        // handler to the stats handler.
        let (stats_tx, stats_rx) = mpsc::channel::<()>(1);
        let has_stats_queue = stats_queue.is_some();
        let stats = if let Some(stats_queue) = stats_queue {
            let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
            let stats_queue_evt = stats_queue
                .event()
                .try_clone()
                .expect("failed to clone queue event");
            handle_stats_queue(
                stats_queue,
                EventAsync::new(stats_queue_evt, &ex).expect("failed to create async event"),
                stats_rx,
                &command_tube,
                #[cfg(feature = "registered_events")]
                registered_evt_q_async.as_ref(),
                state.clone(),
                stop_rx,
            )
            .left_future()
        } else {
            std::future::pending().right_future()
        };
        let stats = stats.fuse();
        pin_mut!(stats);

        // Reporting queue (optional): reported buffers are released via `free_memory`.
        let has_reporting_queue = reporting_queue.is_some();
        let reporting = if let Some(reporting_queue) = reporting_queue {
            let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
            let reporting_queue_evt = reporting_queue
                .event()
                .try_clone()
                .expect("failed to clone queue event");
            handle_reporting_queue(
                reporting_queue,
                EventAsync::new(reporting_queue_evt, &ex).expect("failed to create async event"),
                release_memory_tube.as_ref(),
                |ranges| free_memory(&vm_memory_client, &mem, ranges),
                stop_rx,
            )
            .left_future()
        } else {
            std::future::pending().right_future()
        };
        let reporting = reporting.fuse();
        pin_mut!(reporting);

        // Working-set data queue (optional): receives reports from the guest.
        let has_ws_data_queue = ws_data_queue.is_some();
        let ws_data = if let Some(ws_data_queue) = ws_data_queue {
            let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
            let ws_data_queue_evt = ws_data_queue
                .event()
                .try_clone()
                .expect("failed to clone queue event");
            handle_ws_data_queue(
                ws_data_queue,
                EventAsync::new(ws_data_queue_evt, &ex).expect("failed to create async event"),
                &command_tube,
                #[cfg(feature = "registered_events")]
                registered_evt_q_async.as_ref(),
                state.clone(),
                stop_rx,
            )
            .left_future()
        } else {
            std::future::pending().right_future()
        };
        let ws_data = ws_data.fuse();
        pin_mut!(ws_data);

        // Working-set op queue (optional). The channel carries operations from the command
        // handler to the op handler.
        let (ws_op_tx, ws_op_rx) = mpsc::channel::<WSOp>(1);
        let has_ws_op_queue = ws_op_queue.is_some();
        let ws_op = if let Some(ws_op_queue) = ws_op_queue {
            let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
            let ws_op_queue_evt = ws_op_queue
                .event()
                .try_clone()
                .expect("failed to clone queue event");
            handle_ws_op_queue(
                ws_op_queue,
                EventAsync::new(ws_op_queue_evt, &ex).expect("failed to create async event"),
                ws_op_rx,
                state.clone(),
                stop_rx,
            )
            .left_future()
        } else {
            std::future::pending().right_future()
        };
        let ws_op = ws_op.fuse();
        pin_mut!(ws_op);

        // Command tube handler; it also gets a stop oneshot.
        let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
        let command = handle_command_tube(
            &command_tube,
            interrupt.clone(),
            state.clone(),
            stats_tx,
            ws_op_tx,
            stop_rx,
        );
        pin_mut!(command);

        let target_reached = handle_target_reached(&ex, target_reached_evt, &vm_memory_client);
        pin_mut!(target_reached);

        // Completes when `kill_evt` is signalled.
        let kill = async_utils::await_and_exit(&ex, kill_evt);
        pin_mut!(kill);

        let pending_adjusted = handle_pending_adjusted_responses(
            EventAsync::new(pending_adjusted_response_event, &ex)
                .expect("failed to create async event"),
            &command_tube,
            state,
        );
        pin_mut!(pending_adjusted);

        let res = ex.run_until(async {
            // Only the kill arm falls through; any other future finishing is an error.
            select! {
                _ = kill.fuse() => (),
                _ = inflate => return Err(anyhow!("inflate stopped unexpectedly")),
                _ = deflate => return Err(anyhow!("deflate stopped unexpectedly")),
                _ = stats => return Err(anyhow!("stats stopped unexpectedly")),
                _ = reporting => return Err(anyhow!("reporting stopped unexpectedly")),
                _ = command.fuse() => return Err(anyhow!("command stopped unexpectedly")),
                _ = ws_op => return Err(anyhow!("ws_op stopped unexpectedly")),
                _ = pending_adjusted.fuse() => return Err(anyhow!("pending_adjusted stopped unexpectedly")),
                _ = ws_data => return Err(anyhow!("ws_data stopped unexpectedly")),
                _ = target_reached.fuse() => return Err(anyhow!("target_reached stopped unexpectedly")),
            }

            // Ask every stoppable future to stop.
            for stop_tx in stop_queue_oneshots {
                if stop_tx.send(()).is_err() {
                    return Err(anyhow!("failed to request stop for queue future"));
                }
            }

            // Drive each queue future to completion to get its queue back. Optional
            // futures are awaited only when their queue exists, because the placeholder
            // futures never complete.
            let mut paused_queues = PausedQueues::new(
                inflate.await,
                deflate.await,
            );
            if has_reporting_queue {
                paused_queues.reporting = Some(reporting.await);
            }
            if has_stats_queue {
                paused_queues.stats = Some(stats.await);
            }
            if has_ws_data_queue {
                paused_queues.ws_data = Some(ws_data.await.context("failed to stop ws_data queue")?);
            }
            if has_ws_op_queue {
                paused_queues.ws_op = Some(ws_op.await.context("failed to stop ws_op queue")?);
            }
            Ok(paused_queues)
        });

        // Outer error: the executor itself failed; inner error: the main future failed.
        match res {
            Err(e) => {
                error!("error happened in executor: {}", e);
                None
            }
            Ok(main_future_res) => match main_future_res {
                Ok(paused_queues) => Some(paused_queues),
                Err(e) => {
                    error!("error happened in main balloon future: {}", e);
                    None
                }
            },
        }
    };

    WorkerReturn {
        command_tube: command_tube.into(),
        paused_queues,
        release_memory_tube,
        #[cfg(feature = "registered_events")]
        registered_evt_q,
        vm_memory_client,
    }
}
1136
1137async fn handle_target_reached(
1138 ex: &Executor,
1139 target_reached_evt: Event,
1140 vm_memory_client: &VmMemoryClient,
1141) -> anyhow::Result<()> {
1142 let event_async =
1143 EventAsync::new(target_reached_evt, ex).context("failed to create EventAsync")?;
1144 loop {
1145 let _ = event_async.next_val().await;
1147 if let Err(e) = vm_memory_client.balloon_target_reached(0) {
1150 warn!("Failed to send or receive allocation complete request: {e:#}");
1151 }
1152 }
1153 #[allow(unreachable_code)]
1157 Ok(())
1158}
1159
/// Virtio balloon device.
///
/// The tubes and the memory client are `Option`s and are restored from the worker's return
/// value in `stop_worker`.
/// NOTE(review): presumably they are moved into the worker on activation — that code is
/// outside this section.
pub struct Balloon {
    command_tube: Option<Tube>,
    vm_memory_client: Option<VmMemoryClient>,
    release_memory_tube: Option<Tube>,
    // Event the worker waits on before draining `pending_adjusted_responses`.
    pending_adjusted_response_event: Event,
    // State shared with the worker.
    state: Arc<AsyncRwLock<BalloonState>>,
    // Feature bits offered by the device.
    features: u64,
    // Starts at 0 in `new`.
    acked_features: u64,
    // Running worker, if any.
    worker_thread: Option<WorkerThread<WorkerReturn>>,
    #[cfg(feature = "registered_events")]
    registered_evt_q: Option<SendTube>,
    // Reported to the guest in the config structure.
    ws_num_bins: u8,
    // Starts as `None` in `new`.
    target_reached_evt: Option<Event>,
    // One entry per virtqueue the device exposes.
    queue_sizes: Vec<u16>,
}
1176
/// Serializable subset of the device: the shared state, both feature masks and the number
/// of working-set bins.
#[derive(Serialize, Deserialize)]
struct BalloonSnapshot {
    state: BalloonState,
    features: u64,
    acked_features: u64,
    ws_num_bins: u8,
}
1185
1186impl Balloon {
1187 pub fn new(
1193 base_features: u64,
1194 command_tube: Tube,
1195 vm_memory_client: VmMemoryClient,
1196 release_memory_tube: Option<Tube>,
1197 init_balloon_size: u64,
1198 enabled_features: u64,
1199 #[cfg(feature = "registered_events")] registered_evt_q: Option<SendTube>,
1200 ws_num_bins: u8,
1201 ) -> Result<Balloon> {
1202 let features = base_features
1203 | 1 << VIRTIO_BALLOON_F_MUST_TELL_HOST
1204 | 1 << VIRTIO_BALLOON_F_STATS_VQ
1205 | 1 << VIRTIO_BALLOON_F_DEFLATE_ON_OOM
1206 | enabled_features;
1207
1208 let mut queue_sizes = Vec::new();
1209 queue_sizes.push(QUEUE_SIZE); queue_sizes.push(QUEUE_SIZE); if features & (1 << VIRTIO_BALLOON_F_STATS_VQ) != 0 {
1212 queue_sizes.push(QUEUE_SIZE); }
1214 if features & (1 << VIRTIO_BALLOON_F_PAGE_REPORTING) != 0 {
1215 queue_sizes.push(QUEUE_SIZE); }
1217 if features & (1 << VIRTIO_BALLOON_F_WS_REPORTING) != 0 {
1218 queue_sizes.push(QUEUE_SIZE); queue_sizes.push(QUEUE_SIZE); }
1221
1222 Ok(Balloon {
1223 command_tube: Some(command_tube),
1224 vm_memory_client: Some(vm_memory_client),
1225 release_memory_tube,
1226 pending_adjusted_response_event: Event::new().map_err(BalloonError::CreatingEvent)?,
1227 state: Arc::new(AsyncRwLock::new(BalloonState {
1228 num_pages: (init_balloon_size >> VIRTIO_BALLOON_PFN_SHIFT) as u32,
1229 actual_pages: 0,
1230 failable_update: false,
1231 pending_adjusted_responses: VecDeque::new(),
1232 expecting_ws: false,
1233 })),
1234 worker_thread: None,
1235 features,
1236 acked_features: 0,
1237 #[cfg(feature = "registered_events")]
1238 registered_evt_q,
1239 ws_num_bins,
1240 target_reached_evt: None,
1241 queue_sizes,
1242 })
1243 }
1244
1245 fn get_config(&self) -> virtio_balloon_config {
1246 let state = block_on(self.state.lock());
1247 virtio_balloon_config {
1248 num_pages: state.num_pages.into(),
1249 actual: state.actual_pages.into(),
1250 free_page_hint_cmd_id: 0.into(),
1255 poison_val: 0.into(),
1256 ws_num_bins: self.ws_num_bins,
1257 _reserved: [0, 0, 0],
1258 }
1259 }
1260
1261 fn stop_worker(&mut self) -> StoppedWorker<PausedQueues> {
1262 if let Some(worker_thread) = self.worker_thread.take() {
1263 let worker_ret = worker_thread.stop();
1264 self.release_memory_tube = worker_ret.release_memory_tube;
1265 self.command_tube = Some(worker_ret.command_tube);
1266 #[cfg(feature = "registered_events")]
1267 {
1268 self.registered_evt_q = worker_ret.registered_evt_q;
1269 }
1270 self.vm_memory_client = Some(worker_ret.vm_memory_client);
1271
1272 if let Some(queues) = worker_ret.paused_queues {
1273 StoppedWorker::WithQueues(Box::new(queues))
1274 } else {
1275 StoppedWorker::MissingQueues
1276 }
1277 } else {
1278 StoppedWorker::AlreadyStopped
1279 }
1280 }
1281
1282 fn get_queues_from_map(
1286 &self,
1287 mut queues: BTreeMap<usize, Queue>,
1288 ) -> anyhow::Result<BalloonQueues> {
1289 let inflate_queue = queues.remove(&INFLATEQ).context("missing inflateq")?;
1290 let deflate_queue = queues.remove(&DEFLATEQ).context("missing deflateq")?;
1291 let mut queue_struct = BalloonQueues::new(inflate_queue, deflate_queue);
1292
1293 let mut next_queue_index = 2;
1295 let mut next_queue = || {
1296 let idx = next_queue_index;
1297 next_queue_index += 1;
1298 idx
1299 };
1300
1301 if self.features & (1 << VIRTIO_BALLOON_F_STATS_VQ) != 0 {
1302 let statsq = next_queue();
1303 if self.acked_features & (1 << VIRTIO_BALLOON_F_STATS_VQ) != 0 {
1304 queue_struct.stats = Some(queues.remove(&statsq).context("missing statsq")?);
1305 }
1306 }
1307
1308 if self.features & (1 << VIRTIO_BALLOON_F_PAGE_REPORTING) != 0 {
1309 let reporting_vq = next_queue();
1310 if self.acked_features & (1 << VIRTIO_BALLOON_F_PAGE_REPORTING) != 0 {
1311 queue_struct.reporting = Some(
1312 queues
1313 .remove(&reporting_vq)
1314 .context("missing reporting_vq")?,
1315 );
1316 }
1317 }
1318
1319 if self.features & (1 << VIRTIO_BALLOON_F_WS_REPORTING) != 0 {
1320 let ws_data_vq = next_queue();
1321 let ws_op_vq = next_queue();
1322 if self.acked_features & (1 << VIRTIO_BALLOON_F_WS_REPORTING) != 0 {
1323 queue_struct.ws_data =
1324 Some(queues.remove(&ws_data_vq).context("missing ws_data_vq")?);
1325 queue_struct.ws_op = Some(queues.remove(&ws_op_vq).context("missing ws_op_vq")?);
1326 }
1327 }
1328
1329 if !queues.is_empty() {
1330 return Err(anyhow!("unexpected queues {:?}", queues.into_keys()));
1331 }
1332
1333 Ok(queue_struct)
1334 }
1335
1336 fn start_worker(
1337 &mut self,
1338 mem: GuestMemory,
1339 interrupt: Interrupt,
1340 queues: BalloonQueues,
1341 ) -> anyhow::Result<()> {
1342 let (self_target_reached_evt, target_reached_evt) = Event::new()
1343 .and_then(|e| Ok((e.try_clone()?, e)))
1344 .context("failed to create target_reached Event pair: {}")?;
1345 self.target_reached_evt = Some(self_target_reached_evt);
1346
1347 let state = self.state.clone();
1348
1349 let command_tube = self.command_tube.take().unwrap();
1350
1351 let vm_memory_client = self.vm_memory_client.take().unwrap();
1352 let release_memory_tube = self.release_memory_tube.take();
1353 #[cfg(feature = "registered_events")]
1354 let registered_evt_q = self.registered_evt_q.take();
1355 let pending_adjusted_response_event = self
1356 .pending_adjusted_response_event
1357 .try_clone()
1358 .context("failed to clone Event")?;
1359
1360 self.worker_thread = Some(WorkerThread::start("v_balloon", move |kill_evt| {
1361 run_worker(
1362 queues.inflate,
1363 queues.deflate,
1364 queues.stats,
1365 queues.reporting,
1366 queues.ws_data,
1367 queues.ws_op,
1368 command_tube,
1369 vm_memory_client,
1370 mem,
1371 release_memory_tube,
1372 interrupt,
1373 kill_evt,
1374 target_reached_evt,
1375 pending_adjusted_response_event,
1376 state,
1377 #[cfg(feature = "registered_events")]
1378 registered_evt_q,
1379 )
1380 }));
1381
1382 Ok(())
1383 }
1384}
1385
1386impl VirtioDevice for Balloon {
1387 fn keep_rds(&self) -> Vec<RawDescriptor> {
1388 let mut rds = Vec::new();
1389 if let Some(command_tube) = &self.command_tube {
1390 rds.push(command_tube.as_raw_descriptor());
1391 }
1392 if let Some(vm_memory_client) = &self.vm_memory_client {
1393 rds.push(vm_memory_client.as_raw_descriptor());
1394 }
1395 if let Some(release_memory_tube) = &self.release_memory_tube {
1396 rds.push(release_memory_tube.as_raw_descriptor());
1397 }
1398 #[cfg(feature = "registered_events")]
1399 if let Some(registered_evt_q) = &self.registered_evt_q {
1400 rds.push(registered_evt_q.as_raw_descriptor());
1401 }
1402 rds.push(self.pending_adjusted_response_event.as_raw_descriptor());
1403 rds
1404 }
1405
1406 fn device_type(&self) -> DeviceType {
1407 DeviceType::Balloon
1408 }
1409
1410 fn queue_max_sizes(&self) -> &[u16] {
1411 &self.queue_sizes
1412 }
1413
1414 fn read_config(&self, offset: u64, data: &mut [u8]) {
1415 copy_config(data, 0, self.get_config().as_bytes(), offset);
1416 }
1417
1418 fn write_config(&mut self, offset: u64, data: &[u8]) {
1419 let mut config = self.get_config();
1420 copy_config(config.as_mut_bytes(), offset, data, 0);
1421 let mut state = block_on(self.state.lock());
1422 state.actual_pages = config.actual.to_native();
1423
1424 if config.num_pages == config.actual {
1426 debug!(
1427 "sending target reached event at {}",
1428 u32::from(config.num_pages)
1429 );
1430 self.target_reached_evt.as_ref().map(|e| e.signal());
1431 }
1432 if state.failable_update && state.actual_pages == state.num_pages {
1433 state.failable_update = false;
1434 let num_pages = state.num_pages;
1435 state.pending_adjusted_responses.push_back(num_pages);
1436 let _ = self.pending_adjusted_response_event.signal();
1437 }
1438 }
1439
1440 fn features(&self) -> u64 {
1441 self.features
1442 }
1443
1444 fn ack_features(&mut self, mut value: u64) {
1445 if value & !self.features != 0 {
1446 warn!("virtio_balloon got unknown feature ack {:x}", value);
1447 value &= self.features;
1448 }
1449 self.acked_features |= value;
1450 }
1451
1452 fn activate(
1453 &mut self,
1454 mem: GuestMemory,
1455 interrupt: Interrupt,
1456 queues: BTreeMap<usize, Queue>,
1457 ) -> anyhow::Result<()> {
1458 let queues = self.get_queues_from_map(queues)?;
1459 self.start_worker(mem, interrupt, queues)
1460 }
1461
1462 fn reset(&mut self) -> anyhow::Result<()> {
1463 let _worker = self.stop_worker();
1464 Ok(())
1465 }
1466
1467 fn virtio_sleep(&mut self) -> anyhow::Result<Option<BTreeMap<usize, Queue>>> {
1468 match self.stop_worker() {
1469 StoppedWorker::WithQueues(paused_queues) => Ok(Some(paused_queues.into())),
1470 StoppedWorker::MissingQueues => {
1471 anyhow::bail!("balloon queue workers did not stop cleanly.")
1472 }
1473 StoppedWorker::AlreadyStopped => {
1474 Ok(None)
1476 }
1477 }
1478 }
1479
1480 fn virtio_wake(
1481 &mut self,
1482 queues_state: Option<(GuestMemory, Interrupt, BTreeMap<usize, Queue>)>,
1483 ) -> anyhow::Result<()> {
1484 if let Some((mem, interrupt, queues)) = queues_state {
1485 if queues.len() < 2 {
1486 anyhow::bail!("{} queues were found, but an activated balloon must have at least 2 active queues.", queues.len());
1487 }
1488
1489 let balloon_queues = self.get_queues_from_map(queues)?;
1490 self.start_worker(mem, interrupt, balloon_queues)?;
1491 }
1492 Ok(())
1493 }
1494
1495 fn virtio_snapshot(&mut self) -> anyhow::Result<AnySnapshot> {
1496 let state = self
1497 .state
1498 .lock()
1499 .now_or_never()
1500 .context("failed to acquire balloon lock")?;
1501 AnySnapshot::to_any(BalloonSnapshot {
1502 features: self.features,
1503 acked_features: self.acked_features,
1504 state: state.clone(),
1505 ws_num_bins: self.ws_num_bins,
1506 })
1507 .context("failed to serialize balloon state")
1508 }
1509
1510 fn virtio_restore(&mut self, data: AnySnapshot) -> anyhow::Result<()> {
1511 let snap: BalloonSnapshot = AnySnapshot::from_any(data).context("error deserializing")?;
1512 if snap.features != self.features {
1513 anyhow::bail!(
1514 "balloon: expected features to match, but they did not. Live: {:?}, snapshot {:?}",
1515 self.features,
1516 snap.features,
1517 );
1518 }
1519
1520 let mut state = self
1521 .state
1522 .lock()
1523 .now_or_never()
1524 .context("failed to acquire balloon lock")?;
1525 *state = snap.state;
1526 self.ws_num_bins = snap.ws_num_bins;
1527 self.acked_features = snap.acked_features;
1528 Ok(())
1529 }
1530}
1531
#[cfg(test)]
mod tests {
    use super::*;
    use crate::suspendable_virtio_tests;
    use crate::virtio::descriptor_utils::create_descriptor_chain;
    use crate::virtio::descriptor_utils::DescriptorType;

    /// Checks that the two 32-bit page frame numbers of an inflate descriptor are parsed into
    /// guest addresses shifted by `VIRTIO_BALLOON_PFN_SHIFT`.
    #[test]
    fn desc_parsing_inflate() {
        // Page frame numbers placed back to back in the 8-byte readable buffer at 0x100.
        let pfns: [u32; 2] = [0x10, 0xaa55aa55];
        let buffer_addrs = [GuestAddress(0x100), GuestAddress(0x104)];

        let guest_mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
        for (pfn, addr) in pfns.iter().zip(buffer_addrs.iter()) {
            guest_mem.write_obj_at_addr(*pfn, *addr).unwrap();
        }

        let mut desc_chain = create_descriptor_chain(
            &guest_mem,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![(DescriptorType::Readable, 8)],
            0,
        )
        .expect("create_descriptor_chain failed");

        let mut collected = Vec::new();
        let outcome = handle_address_chain(None, &mut desc_chain, &mut |mut ranges| {
            collected.append(&mut ranges)
        });

        assert!(outcome.is_ok());
        assert_eq!(collected.len(), 2);
        for (range, pfn) in collected.iter().zip(pfns.iter()) {
            assert_eq!(
                range.0,
                GuestAddress(u64::from(*pfn) << VIRTIO_BALLOON_PFN_SHIFT)
            );
        }
    }

    /// Keeps the far ends of the device's tubes alive for the duration of a test.
    struct BalloonContext {
        _ctrl_tube: Tube,
        _mem_client_tube: Tube,
    }

    /// Changes the device in a way the snapshot/restore tests can observe.
    fn modify_device(_balloon_context: &mut BalloonContext, balloon: &mut Balloon) {
        balloon.ws_num_bins = !balloon.ws_num_bins;
    }

    /// Builds a balloon device together with the context owning the other tube ends.
    fn create_device() -> (BalloonContext, Balloon) {
        let (_ctrl_tube, ctrl_tube_device) = Tube::pair().unwrap();
        let (_mem_client_tube, mem_client_tube_device) = Tube::pair().unwrap();

        let device = Balloon::new(
            0,
            ctrl_tube_device,
            VmMemoryClient::new(mem_client_tube_device),
            None,
            1024,
            0,
            #[cfg(feature = "registered_events")]
            None,
            0,
        )
        .unwrap();

        let context = BalloonContext {
            _ctrl_tube,
            _mem_client_tube,
        };

        (context, device)
    }

    suspendable_virtio_tests!(balloon, create_device, 2, modify_device);
}