devices/virtio/
balloon.rs

1// Copyright 2017 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use std::collections::BTreeMap;
6use std::collections::VecDeque;
7use std::io::Write;
8use std::sync::Arc;
9
10use anyhow::anyhow;
11use anyhow::Context;
12use balloon_control::BalloonStats;
13use balloon_control::BalloonTubeCommand;
14use balloon_control::BalloonTubeResult;
15use balloon_control::BalloonWS;
16use balloon_control::WSBucket;
17use balloon_control::VIRTIO_BALLOON_WS_MAX_NUM_BINS;
18use balloon_control::VIRTIO_BALLOON_WS_MIN_NUM_BINS;
19use base::debug;
20use base::error;
21use base::warn;
22use base::AsRawDescriptor;
23use base::Event;
24use base::RawDescriptor;
25#[cfg(feature = "registered_events")]
26use base::SendTube;
27use base::Tube;
28use base::WorkerThread;
29use cros_async::block_on;
30use cros_async::sync::RwLock as AsyncRwLock;
31use cros_async::AsyncTube;
32use cros_async::EventAsync;
33use cros_async::Executor;
34#[cfg(feature = "registered_events")]
35use cros_async::SendTubeAsync;
36use data_model::Le16;
37use data_model::Le32;
38use data_model::Le64;
39use futures::channel::mpsc;
40use futures::channel::oneshot;
41use futures::pin_mut;
42use futures::select;
43use futures::select_biased;
44use futures::FutureExt;
45use futures::StreamExt;
46use remain::sorted;
47use serde::Deserialize;
48use serde::Serialize;
49use snapshot::AnySnapshot;
50use thiserror::Error as ThisError;
51use vm_control::api::VmMemoryClient;
52#[cfg(feature = "registered_events")]
53use vm_control::RegisteredEventWithData;
54use vm_memory::GuestAddress;
55use vm_memory::GuestMemory;
56use zerocopy::FromBytes;
57use zerocopy::Immutable;
58use zerocopy::IntoBytes;
59use zerocopy::KnownLayout;
60
61use super::async_utils;
62use super::copy_config;
63use super::create_stop_oneshot;
64use super::DescriptorChain;
65use super::DeviceType;
66use super::Interrupt;
67use super::Queue;
68use super::Reader;
69use super::StoppedWorker;
70use super::VirtioDevice;
71use crate::UnpinRequest;
72use crate::UnpinResponse;
73
74#[sorted]
75#[derive(ThisError, Debug)]
76pub enum BalloonError {
77    /// Failed an async await
78    #[error("failed async await: {0}")]
79    AsyncAwait(cros_async::AsyncError),
80    /// Failed an async await
81    #[error("failed async await: {0}")]
82    AsyncAwaitAnyhow(anyhow::Error),
83    /// Failed to create event.
84    #[error("failed to create event: {0}")]
85    CreatingEvent(base::Error),
86    /// Failed to create async message receiver.
87    #[error("failed to create async message receiver: {0}")]
88    CreatingMessageReceiver(base::TubeError),
89    /// Failed to receive command message.
90    #[error("failed to receive command message: {0}")]
91    ReceivingCommand(base::TubeError),
92    /// Failed to send command response.
93    #[error("failed to send command response: {0}")]
94    SendResponse(base::TubeError),
95    /// Error while writing to virtqueue
96    #[error("failed to write to virtqueue: {0}")]
97    WriteQueue(std::io::Error),
98    /// Failed to write config event.
99    #[error("failed to write config event: {0}")]
100    WritingConfigEvent(base::Error),
101}
102pub type Result<T> = std::result::Result<T, BalloonError>;
103
// Queue size constant for the balloon device.
// NOTE(review): its consumers are outside this excerpt — presumably the per-virtqueue
// maximum size; confirm against the VirtioDevice implementation.
const QUEUE_SIZE: u16 = 128;

// Indexes of the virtqueues that exist regardless of the advertised features.
const INFLATEQ: usize = 0;
const DEFLATEQ: usize = 1;

// Balloon page frame numbers are expressed in units of 1 << 12 bytes.
const VIRTIO_BALLOON_PFN_SHIFT: u32 = 12;
const VIRTIO_BALLOON_PF_SIZE: u64 = 1 << VIRTIO_BALLOON_PFN_SHIFT;

// Bit positions in the virtio balloon feature bitmap.
// Tell before reclaiming pages.
const VIRTIO_BALLOON_F_MUST_TELL_HOST: u32 = 0;
// Stats reporting enabled.
const VIRTIO_BALLOON_F_STATS_VQ: u32 = 1;
// Deflate balloon on OOM.
const VIRTIO_BALLOON_F_DEFLATE_ON_OOM: u32 = 2;
// Page reporting virtqueue.
const VIRTIO_BALLOON_F_PAGE_REPORTING: u32 = 5;
// Working Set Reporting virtqueues.
// TODO(b/273973298): this should maybe be bit 6? to be changed later
const VIRTIO_BALLOON_F_WS_REPORTING: u32 = 8;
120
121#[derive(Copy, Clone)]
122#[repr(u32)]
123// Balloon virtqueues
124pub enum BalloonFeatures {
125    // Page Reporting enabled
126    PageReporting = VIRTIO_BALLOON_F_PAGE_REPORTING,
127    // WS Reporting enabled
128    WSReporting = VIRTIO_BALLOON_F_WS_REPORTING,
129}
130
131// virtio_balloon_config is the balloon device configuration space defined by the virtio spec.
132#[derive(Copy, Clone, Debug, Default, FromBytes, Immutable, IntoBytes, KnownLayout)]
133#[repr(C)]
134struct virtio_balloon_config {
135    num_pages: Le32,
136    actual: Le32,
137    free_page_hint_cmd_id: Le32,
138    poison_val: Le32,
139    // WS field is part of proposed spec extension (b/273973298).
140    ws_num_bins: u8,
141    _reserved: [u8; 3],
142}
143
144// BalloonState is shared by the worker and device thread.
145#[derive(Clone, Default, Serialize, Deserialize)]
146struct BalloonState {
147    num_pages: u32,
148    actual_pages: u32,
149    expecting_ws: bool,
150    // Flag indicating that the balloon is in the process of a failable update. This
151    // is set by an Adjust command that has allow_failure set, and is cleared when the
152    // Adjusted success/failure response is sent.
153    failable_update: bool,
154    pending_adjusted_responses: VecDeque<u32>,
155}
156
// Tag values identifying each statistic in a virtio_balloon_stat entry. `BalloonStat::update_stats`
// maps every tag to the matching `BalloonStats` field.
const VIRTIO_BALLOON_S_SWAP_IN: u16 = 0;
const VIRTIO_BALLOON_S_SWAP_OUT: u16 = 1;
const VIRTIO_BALLOON_S_MAJFLT: u16 = 2;
const VIRTIO_BALLOON_S_MINFLT: u16 = 3;
const VIRTIO_BALLOON_S_MEMFREE: u16 = 4;
const VIRTIO_BALLOON_S_MEMTOT: u16 = 5;
const VIRTIO_BALLOON_S_AVAIL: u16 = 6;
const VIRTIO_BALLOON_S_CACHES: u16 = 7;
const VIRTIO_BALLOON_S_HTLB_PGALLOC: u16 = 8;
const VIRTIO_BALLOON_S_HTLB_PGFAIL: u16 = 9;
// Non-standard tags, taken from the top of the u16 range.
const VIRTIO_BALLOON_S_NONSTANDARD_SHMEM: u16 = 65534;
const VIRTIO_BALLOON_S_NONSTANDARD_UNEVICTABLE: u16 = 65535;
170
171// BalloonStat is used to deserialize stats from the stats_queue.
172#[derive(Copy, Clone, FromBytes, Immutable, IntoBytes, KnownLayout)]
173#[repr(C, packed)]
174struct BalloonStat {
175    tag: Le16,
176    val: Le64,
177}
178
179impl BalloonStat {
180    fn update_stats(&self, stats: &mut BalloonStats) {
181        let val = Some(self.val.to_native());
182        match self.tag.to_native() {
183            VIRTIO_BALLOON_S_SWAP_IN => stats.swap_in = val,
184            VIRTIO_BALLOON_S_SWAP_OUT => stats.swap_out = val,
185            VIRTIO_BALLOON_S_MAJFLT => stats.major_faults = val,
186            VIRTIO_BALLOON_S_MINFLT => stats.minor_faults = val,
187            VIRTIO_BALLOON_S_MEMFREE => stats.free_memory = val,
188            VIRTIO_BALLOON_S_MEMTOT => stats.total_memory = val,
189            VIRTIO_BALLOON_S_AVAIL => stats.available_memory = val,
190            VIRTIO_BALLOON_S_CACHES => stats.disk_caches = val,
191            VIRTIO_BALLOON_S_HTLB_PGALLOC => stats.hugetlb_allocations = val,
192            VIRTIO_BALLOON_S_HTLB_PGFAIL => stats.hugetlb_failures = val,
193            VIRTIO_BALLOON_S_NONSTANDARD_SHMEM => stats.shared_memory = val,
194            VIRTIO_BALLOON_S_NONSTANDARD_UNEVICTABLE => stats.unevictable_memory = val,
195            _ => (),
196        }
197    }
198}
199
200// virtio_balloon_ws is used to deserialize from the ws data vq.
201#[repr(C)]
202#[derive(Copy, Clone, Debug, Default, FromBytes, Immutable, IntoBytes, KnownLayout)]
203struct virtio_balloon_ws {
204    tag: Le16,
205    node_id: Le16,
206    // virtio prefers field members to align on a word boundary so we must pad. see:
207    // https://crsrc.org/o/src/third_party/kernel/v5.15/include/uapi/linux/virtio_balloon.h;l=105
208    _reserved: [u8; 4],
209    idle_age_ms: Le64,
210    // TODO(b/273973298): these should become separate fields - bytes for ANON and FILE
211    memory_size_bytes: [Le64; 2],
212}
213
214impl virtio_balloon_ws {
215    fn update_ws(&self, ws: &mut BalloonWS) {
216        let bucket = WSBucket {
217            age: self.idle_age_ms.to_native(),
218            bytes: [
219                self.memory_size_bytes[0].to_native(),
220                self.memory_size_bytes[1].to_native(),
221            ],
222        };
223        ws.ws.push(bucket);
224    }
225}
226
// Operation codes carried by `virtio_balloon_op` on the ws op virtqueue. The
// underscore-prefixed codes are not referenced in the code shown here.
const _VIRTIO_BALLOON_WS_OP_INVALID: u16 = 0;
const VIRTIO_BALLOON_WS_OP_REQUEST: u16 = 1;
const VIRTIO_BALLOON_WS_OP_CONFIG: u16 = 2;
const _VIRTIO_BALLOON_WS_OP_DISCARD: u16 = 3;
231
232// virtio_balloon_op is used to serialize to the ws cmd vq.
233#[repr(C, packed)]
234#[derive(Copy, Clone, Debug, Default, FromBytes, Immutable, IntoBytes, KnownLayout)]
235struct virtio_balloon_op {
236    type_: Le16,
237}
238
239fn invoke_desc_handler<F>(ranges: Vec<(u64, u64)>, desc_handler: &mut F)
240where
241    F: FnMut(Vec<(GuestAddress, u64)>),
242{
243    desc_handler(
244        ranges
245            .into_iter()
246            .map(|range| (GuestAddress(range.0), range.1))
247            .collect(),
248    );
249}
250
251// Release a list of guest memory ranges back to the host system.
252// Unpin requests for each inflate range will be sent via `release_memory_tube`
253// if provided, and then `desc_handler` will be called for each inflate range.
254fn release_ranges<F>(
255    release_memory_tube: Option<&Tube>,
256    inflate_ranges: Vec<(u64, u64)>,
257    desc_handler: &mut F,
258) -> anyhow::Result<()>
259where
260    F: FnMut(Vec<(GuestAddress, u64)>),
261{
262    if let Some(tube) = release_memory_tube {
263        let unpin_ranges = inflate_ranges
264            .iter()
265            .map(|v| {
266                (
267                    v.0 >> VIRTIO_BALLOON_PFN_SHIFT,
268                    v.1 / VIRTIO_BALLOON_PF_SIZE,
269                )
270            })
271            .collect();
272        let req = UnpinRequest {
273            ranges: unpin_ranges,
274        };
275        if let Err(e) = tube.send(&req) {
276            error!("failed to send unpin request: {}", e);
277        } else {
278            match tube.recv() {
279                Ok(resp) => match resp {
280                    UnpinResponse::Success => invoke_desc_handler(inflate_ranges, desc_handler),
281                    UnpinResponse::Failed => error!("failed to handle unpin request"),
282                },
283                Err(e) => error!("failed to handle get unpin response: {}", e),
284            }
285        }
286    } else {
287        invoke_desc_handler(inflate_ranges, desc_handler);
288    }
289
290    Ok(())
291}
292
293// Processes one message's list of addresses.
294fn handle_address_chain<F>(
295    release_memory_tube: Option<&Tube>,
296    avail_desc: &mut DescriptorChain,
297    desc_handler: &mut F,
298) -> anyhow::Result<()>
299where
300    F: FnMut(Vec<(GuestAddress, u64)>),
301{
302    // In a long-running system, there is no reason to expect that
303    // a significant number of freed pages are consecutive. However,
304    // batching is relatively simple and can result in significant
305    // gains in a newly booted system, so it's worth attempting.
306    let mut range_start = 0;
307    let mut range_size = 0;
308    let mut inflate_ranges: Vec<(u64, u64)> = Vec::new();
309    for res in avail_desc.reader.iter::<Le32>() {
310        let pfn = match res {
311            Ok(pfn) => pfn,
312            Err(e) => {
313                error!("error while reading unused pages: {}", e);
314                break;
315            }
316        };
317        let guest_address = (u64::from(pfn.to_native())) << VIRTIO_BALLOON_PFN_SHIFT;
318        if range_start + range_size == guest_address {
319            range_size += VIRTIO_BALLOON_PF_SIZE;
320        } else if range_start == guest_address + VIRTIO_BALLOON_PF_SIZE {
321            range_start = guest_address;
322            range_size += VIRTIO_BALLOON_PF_SIZE;
323        } else {
324            // Discontinuity, so flush the previous range. Note range_size
325            // will be 0 on the first iteration, so skip that.
326            if range_size != 0 {
327                inflate_ranges.push((range_start, range_size));
328            }
329            range_start = guest_address;
330            range_size = VIRTIO_BALLOON_PF_SIZE;
331        }
332    }
333    if range_size != 0 {
334        inflate_ranges.push((range_start, range_size));
335    }
336
337    release_ranges(release_memory_tube, inflate_ranges, desc_handler)
338}
339
340// Async task that handles the main balloon inflate and deflate queues.
341async fn handle_queue<F>(
342    mut queue: Queue,
343    mut queue_event: EventAsync,
344    release_memory_tube: Option<&Tube>,
345    mut desc_handler: F,
346    mut stop_rx: oneshot::Receiver<()>,
347) -> Queue
348where
349    F: FnMut(Vec<(GuestAddress, u64)>),
350{
351    loop {
352        let mut avail_desc = match queue
353            .next_async_interruptable(&mut queue_event, &mut stop_rx)
354            .await
355        {
356            Ok(Some(res)) => res,
357            Ok(None) => return queue,
358            Err(e) => {
359                error!("Failed to read descriptor {}", e);
360                return queue;
361            }
362        };
363        if let Err(e) =
364            handle_address_chain(release_memory_tube, &mut avail_desc, &mut desc_handler)
365        {
366            error!("balloon: failed to process inflate addresses: {}", e);
367        }
368        queue.add_used(avail_desc);
369        queue.trigger_interrupt();
370    }
371}
372
373// Processes one page-reporting descriptor.
374fn handle_reported_buffer<F>(
375    release_memory_tube: Option<&Tube>,
376    avail_desc: &DescriptorChain,
377    desc_handler: &mut F,
378) -> anyhow::Result<()>
379where
380    F: FnMut(Vec<(GuestAddress, u64)>),
381{
382    let reported_ranges: Vec<(u64, u64)> = avail_desc
383        .reader
384        .get_remaining_regions()
385        .chain(avail_desc.writer.get_remaining_regions())
386        .map(|r| (r.offset, r.len as u64))
387        .collect();
388
389    release_ranges(release_memory_tube, reported_ranges, desc_handler)
390}
391
392// Async task that handles the page reporting queue.
393async fn handle_reporting_queue<F>(
394    mut queue: Queue,
395    mut queue_event: EventAsync,
396    release_memory_tube: Option<&Tube>,
397    mut desc_handler: F,
398    mut stop_rx: oneshot::Receiver<()>,
399) -> Queue
400where
401    F: FnMut(Vec<(GuestAddress, u64)>),
402{
403    loop {
404        let avail_desc = match queue
405            .next_async_interruptable(&mut queue_event, &mut stop_rx)
406            .await
407        {
408            Ok(Some(res)) => res,
409            Ok(None) => return queue,
410            Err(e) => {
411                error!("Failed to read descriptor {}", e);
412                return queue;
413            }
414        };
415        if let Err(e) = handle_reported_buffer(release_memory_tube, &avail_desc, &mut desc_handler)
416        {
417            error!("balloon: failed to process reported buffer: {}", e);
418        }
419        queue.add_used(avail_desc);
420        queue.trigger_interrupt();
421    }
422}
423
424fn parse_balloon_stats(reader: &mut Reader) -> BalloonStats {
425    let mut stats: BalloonStats = Default::default();
426    for res in reader.iter::<BalloonStat>() {
427        match res {
428            Ok(stat) => stat.update_stats(&mut stats),
429            Err(e) => {
430                error!("error while reading stats: {}", e);
431                break;
432            }
433        };
434    }
435    stats
436}
437
438// Async task that handles the stats queue. Note that the cadence of this is driven by requests for
439// balloon stats from the control pipe.
440// The guests queues an initial buffer on boot, which is read and then this future will block until
441// signaled from the command socket that stats should be collected again.
442async fn handle_stats_queue(
443    mut queue: Queue,
444    mut queue_event: EventAsync,
445    mut stats_rx: mpsc::Receiver<()>,
446    command_tube: &AsyncTube,
447    #[cfg(feature = "registered_events")] registered_evt_q: Option<&SendTubeAsync>,
448    state: Arc<AsyncRwLock<BalloonState>>,
449    mut stop_rx: oneshot::Receiver<()>,
450) -> Queue {
451    let mut avail_desc = match queue
452        .next_async_interruptable(&mut queue_event, &mut stop_rx)
453        .await
454    {
455        // Consume the first stats buffer sent from the guest at startup. It was not
456        // requested by anyone, and the stats are stale.
457        Ok(Some(res)) => res,
458        Ok(None) => return queue,
459        Err(e) => {
460            error!("Failed to read descriptor {}", e);
461            return queue;
462        }
463    };
464
465    loop {
466        select_biased! {
467            msg = stats_rx.next() => {
468                // Wait for a request to read the stats.
469                match msg {
470                    Some(()) => (),
471                    None => {
472                        error!("stats signal channel was closed");
473                        return queue;
474                    }
475                }
476            }
477            _ = stop_rx => return queue,
478        };
479
480        // Request a new stats_desc to the guest.
481        queue.add_used(avail_desc);
482        queue.trigger_interrupt();
483
484        avail_desc = match queue.next_async(&mut queue_event).await {
485            Err(e) => {
486                error!("Failed to read descriptor {}", e);
487                return queue;
488            }
489            Ok(d) => d,
490        };
491        let stats = parse_balloon_stats(&mut avail_desc.reader);
492
493        let actual_pages = state.lock().await.actual_pages as u64;
494        let result = BalloonTubeResult::Stats {
495            balloon_actual: actual_pages << VIRTIO_BALLOON_PFN_SHIFT,
496            stats,
497        };
498        let send_result = command_tube.send(result).await;
499        if let Err(e) = send_result {
500            error!("failed to send stats result: {}", e);
501        }
502
503        #[cfg(feature = "registered_events")]
504        if let Some(registered_evt_q) = registered_evt_q {
505            if let Err(e) = registered_evt_q
506                .send(&RegisteredEventWithData::VirtioBalloonResize)
507                .await
508            {
509                error!("failed to send VirtioBalloonResize event: {}", e);
510            }
511        }
512    }
513}
514
515async fn send_adjusted_response(
516    tube: &AsyncTube,
517    num_pages: u32,
518) -> std::result::Result<(), base::TubeError> {
519    let num_bytes = (num_pages as u64) << VIRTIO_BALLOON_PFN_SHIFT;
520    let result = BalloonTubeResult::Adjusted { num_bytes };
521    tube.send(result).await
522}
523
// Operations forwarded from the command tube handler to the ws op queue handler.
enum WSOp {
    // Ask the guest for a working-set report.
    WSReport,
    // Send a working-set configuration to the guest. The fields are written to the ws op
    // virtqueue in declaration order, after the operation header.
    WSConfig {
        bins: Vec<u32>,
        refresh_threshold: u32,
        report_threshold: u32,
    },
}
532
533async fn handle_ws_op_queue(
534    mut queue: Queue,
535    mut queue_event: EventAsync,
536    mut ws_op_rx: mpsc::Receiver<WSOp>,
537    state: Arc<AsyncRwLock<BalloonState>>,
538    mut stop_rx: oneshot::Receiver<()>,
539) -> Result<Queue> {
540    loop {
541        let op = select_biased! {
542            next_op = ws_op_rx.next().fuse() => {
543                match next_op {
544                    Some(op) => op,
545                    None => {
546                        error!("ws op tube was closed");
547                        break;
548                    }
549                }
550            }
551            _ = stop_rx => {
552                break;
553            }
554        };
555        let mut avail_desc = queue
556            .next_async(&mut queue_event)
557            .await
558            .map_err(BalloonError::AsyncAwait)?;
559        let writer = &mut avail_desc.writer;
560
561        match op {
562            WSOp::WSReport => {
563                {
564                    let mut state = state.lock().await;
565                    state.expecting_ws = true;
566                }
567
568                let ws_r = virtio_balloon_op {
569                    type_: VIRTIO_BALLOON_WS_OP_REQUEST.into(),
570                };
571
572                writer.write_obj(ws_r).map_err(BalloonError::WriteQueue)?;
573            }
574            WSOp::WSConfig {
575                bins,
576                refresh_threshold,
577                report_threshold,
578            } => {
579                let cmd = virtio_balloon_op {
580                    type_: VIRTIO_BALLOON_WS_OP_CONFIG.into(),
581                };
582
583                writer.write_obj(cmd).map_err(BalloonError::WriteQueue)?;
584                writer
585                    .write_all(bins.as_bytes())
586                    .map_err(BalloonError::WriteQueue)?;
587                writer
588                    .write_obj(refresh_threshold)
589                    .map_err(BalloonError::WriteQueue)?;
590                writer
591                    .write_obj(report_threshold)
592                    .map_err(BalloonError::WriteQueue)?;
593            }
594        }
595
596        queue.add_used(avail_desc);
597        queue.trigger_interrupt();
598    }
599
600    Ok(queue)
601}
602
603fn parse_balloon_ws(reader: &mut Reader) -> BalloonWS {
604    let mut ws = BalloonWS::new();
605    for res in reader.iter::<virtio_balloon_ws>() {
606        match res {
607            Ok(ws_msg) => {
608                ws_msg.update_ws(&mut ws);
609            }
610            Err(e) => {
611                error!("error while reading ws: {}", e);
612                break;
613            }
614        }
615    }
616    if ws.ws.len() < VIRTIO_BALLOON_WS_MIN_NUM_BINS || ws.ws.len() > VIRTIO_BALLOON_WS_MAX_NUM_BINS
617    {
618        error!("unexpected number of WS buckets: {}", ws.ws.len());
619    }
620    ws
621}
622
623// Async task that handles the stats queue. Note that the arrival of events on
624// the WS vq may be the result of either a WS request (WS-R) command having
625// been sent to the guest, or an unprompted send due to memory pressue in the
626// guest. If the data was requested, we should also send that back on the
627// command tube.
628async fn handle_ws_data_queue(
629    mut queue: Queue,
630    mut queue_event: EventAsync,
631    command_tube: &AsyncTube,
632    #[cfg(feature = "registered_events")] registered_evt_q: Option<&SendTubeAsync>,
633    state: Arc<AsyncRwLock<BalloonState>>,
634    mut stop_rx: oneshot::Receiver<()>,
635) -> Result<Queue> {
636    loop {
637        let mut avail_desc = match queue
638            .next_async_interruptable(&mut queue_event, &mut stop_rx)
639            .await
640            .map_err(BalloonError::AsyncAwait)?
641        {
642            Some(res) => res,
643            None => return Ok(queue),
644        };
645
646        let ws = parse_balloon_ws(&mut avail_desc.reader);
647
648        let mut state = state.lock().await;
649
650        // update ws report with balloon pages now that we have a lock on state
651        let balloon_actual = (state.actual_pages as u64) << VIRTIO_BALLOON_PFN_SHIFT;
652
653        if state.expecting_ws {
654            let result = BalloonTubeResult::WorkingSet { ws, balloon_actual };
655            let send_result = command_tube.send(result).await;
656            if let Err(e) = send_result {
657                error!("failed to send ws result: {}", e);
658            }
659
660            state.expecting_ws = false;
661        } else {
662            #[cfg(feature = "registered_events")]
663            if let Some(registered_evt_q) = registered_evt_q {
664                if let Err(e) = registered_evt_q
665                    .send(RegisteredEventWithData::from_ws(&ws, balloon_actual))
666                    .await
667                {
668                    error!("failed to send VirtioBalloonWSReport event: {}", e);
669                }
670            }
671        }
672
673        queue.add_used(avail_desc);
674        queue.trigger_interrupt();
675    }
676}
677
678// Async task that handles the command socket. The command socket handles messages from the host
679// requesting that the guest balloon be adjusted or to report guest memory statistics.
680async fn handle_command_tube(
681    command_tube: &AsyncTube,
682    interrupt: Interrupt,
683    state: Arc<AsyncRwLock<BalloonState>>,
684    mut stats_tx: mpsc::Sender<()>,
685    mut ws_op_tx: mpsc::Sender<WSOp>,
686    mut stop_rx: oneshot::Receiver<()>,
687) -> Result<()> {
688    loop {
689        let cmd_res = select_biased! {
690            res = command_tube.next().fuse() => res,
691            _ = stop_rx => return Ok(())
692        };
693        match cmd_res {
694            Ok(command) => match command {
695                BalloonTubeCommand::Adjust {
696                    num_bytes,
697                    allow_failure,
698                } => {
699                    let num_pages = (num_bytes >> VIRTIO_BALLOON_PFN_SHIFT) as u32;
700                    let mut state = state.lock().await;
701
702                    state.num_pages = num_pages;
703                    interrupt.signal_config_changed();
704
705                    if allow_failure {
706                        if num_pages == state.actual_pages {
707                            send_adjusted_response(command_tube, num_pages)
708                                .await
709                                .map_err(BalloonError::SendResponse)?;
710                        } else {
711                            state.failable_update = true;
712                        }
713                    }
714                }
715                BalloonTubeCommand::WorkingSetConfig {
716                    bins,
717                    refresh_threshold,
718                    report_threshold,
719                } => {
720                    if let Err(e) = ws_op_tx.try_send(WSOp::WSConfig {
721                        bins,
722                        refresh_threshold,
723                        report_threshold,
724                    }) {
725                        error!("failed to send config to ws handler: {}", e);
726                    }
727                }
728                BalloonTubeCommand::Stats => {
729                    if let Err(e) = stats_tx.try_send(()) {
730                        error!("failed to signal the stat handler: {}", e);
731                    }
732                }
733                BalloonTubeCommand::WorkingSet => {
734                    if let Err(e) = ws_op_tx.try_send(WSOp::WSReport) {
735                        error!("failed to send report request to ws handler: {}", e);
736                    }
737                }
738            },
739            #[cfg(windows)]
740            Err(base::TubeError::Recv(e)) if e.kind() == std::io::ErrorKind::TimedOut => {
741                // On Windows, async IO tasks like the next/recv above are cancelled as the VM is
742                // shutting down. For the sake of consistency with unix, we can't *just* return
743                // here; instead, we wait for the stop request to arrive, *and then* return.
744                //
745                // The real fix is to get rid of the global unblock pool, since then we won't
746                // cancel the tasks early (b/196911556).
747                let _ = stop_rx.await;
748                return Ok(());
749            }
750            Err(e) => {
751                return Err(BalloonError::ReceivingCommand(e));
752            }
753        }
754    }
755}
756
757async fn handle_pending_adjusted_responses(
758    pending_adjusted_response_event: EventAsync,
759    command_tube: &AsyncTube,
760    state: Arc<AsyncRwLock<BalloonState>>,
761) -> Result<()> {
762    loop {
763        pending_adjusted_response_event
764            .next_val()
765            .await
766            .map_err(BalloonError::AsyncAwait)?;
767        while let Some(num_pages) = state.lock().await.pending_adjusted_responses.pop_front() {
768            send_adjusted_response(command_tube, num_pages)
769                .await
770                .map_err(BalloonError::SendResponse)?;
771        }
772    }
773}
774
775/// Represents queues & events for the balloon device.
776struct BalloonQueues {
777    inflate: Queue,
778    deflate: Queue,
779    stats: Option<Queue>,
780    reporting: Option<Queue>,
781    ws_data: Option<Queue>,
782    ws_op: Option<Queue>,
783}
784
785impl BalloonQueues {
786    fn new(inflate: Queue, deflate: Queue) -> Self {
787        BalloonQueues {
788            inflate,
789            deflate,
790            stats: None,
791            reporting: None,
792            ws_data: None,
793            ws_op: None,
794        }
795    }
796}
797
798/// When the worker is stopped, the queues are preserved here.
799struct PausedQueues {
800    inflate: Queue,
801    deflate: Queue,
802    stats: Option<Queue>,
803    reporting: Option<Queue>,
804    ws_data: Option<Queue>,
805    ws_op: Option<Queue>,
806}
807
808impl PausedQueues {
809    fn new(inflate: Queue, deflate: Queue) -> Self {
810        PausedQueues {
811            inflate,
812            deflate,
813            stats: None,
814            reporting: None,
815            ws_data: None,
816            ws_op: None,
817        }
818    }
819}
820
821fn apply_if_some<F, R>(queue_opt: Option<Queue>, mut func: F)
822where
823    F: FnMut(Queue) -> R,
824{
825    if let Some(queue) = queue_opt {
826        func(queue);
827    }
828}
829
830impl From<Box<PausedQueues>> for BTreeMap<usize, Queue> {
831    fn from(queues: Box<PausedQueues>) -> BTreeMap<usize, Queue> {
832        let mut ret = Vec::new();
833        ret.push(queues.inflate);
834        ret.push(queues.deflate);
835        apply_if_some(queues.stats, |stats| ret.push(stats));
836        apply_if_some(queues.reporting, |reporting| ret.push(reporting));
837        apply_if_some(queues.ws_data, |ws_data| ret.push(ws_data));
838        apply_if_some(queues.ws_op, |ws_op| ret.push(ws_op));
839        // WARNING: We don't use the indices from the virito spec on purpose, see comment in
840        // get_queues_from_map for the rationale.
841        ret.into_iter().enumerate().collect()
842    }
843}
844
845fn free_memory(
846    vm_memory_client: &VmMemoryClient,
847    mem: &GuestMemory,
848    ranges: Vec<(GuestAddress, u64)>,
849) {
850    // When `--lock-guest-memory` is used, it is not possible to free the memory from the main
851    // process, so we free it from the sandboxed balloon process directly.
852    #[cfg(any(target_os = "android", target_os = "linux"))]
853    if mem.locked() && !mem.use_punchhole_locked() {
854        for (guest_address, len) in ranges {
855            if let Err(e) = mem.remove_range(guest_address, len) {
856                warn!("Marking pages unused failed: {}, addr={}", e, guest_address);
857            }
858        }
859        return;
860    }
861    if let Err(e) = vm_memory_client.dynamically_free_memory_ranges(ranges) {
862        warn!("Failed to dynamically free memory ranges: {e:#}");
863    }
864}
865
866fn reclaim_memory(vm_memory_client: &VmMemoryClient, ranges: Vec<(GuestAddress, u64)>) {
867    if let Err(e) = vm_memory_client.dynamically_reclaim_memory_ranges(ranges) {
868        warn!("Failed to dynamically reclaim memory range: {e:#}");
869    }
870}
871
872/// Stores data from the worker when it stops so that data can be re-used when
873/// the worker is restarted.
874struct WorkerReturn {
875    release_memory_tube: Option<Tube>,
876    command_tube: Tube,
877    #[cfg(feature = "registered_events")]
878    registered_evt_q: Option<SendTube>,
879    paused_queues: Option<PausedQueues>,
880    vm_memory_client: VmMemoryClient,
881}
882
883// The main worker thread. Initialized the asynchronous worker tasks and passes them to the executor
884// to be processed.
885fn run_worker(
886    inflate_queue: Queue,
887    deflate_queue: Queue,
888    stats_queue: Option<Queue>,
889    reporting_queue: Option<Queue>,
890    ws_data_queue: Option<Queue>,
891    ws_op_queue: Option<Queue>,
892    command_tube: Tube,
893    vm_memory_client: VmMemoryClient,
894    mem: GuestMemory,
895    release_memory_tube: Option<Tube>,
896    interrupt: Interrupt,
897    kill_evt: Event,
898    target_reached_evt: Event,
899    pending_adjusted_response_event: Event,
900    state: Arc<AsyncRwLock<BalloonState>>,
901    #[cfg(feature = "registered_events")] registered_evt_q: Option<SendTube>,
902) -> WorkerReturn {
903    let ex = Executor::new().unwrap();
904    let command_tube = AsyncTube::new(&ex, command_tube).unwrap();
905    #[cfg(feature = "registered_events")]
906    let registered_evt_q_async = registered_evt_q
907        .as_ref()
908        .map(|q| SendTubeAsync::new(q.try_clone().unwrap(), &ex).unwrap());
909
910    let mut stop_queue_oneshots = Vec::new();
911
912    // We need a block to release all references to command_tube at the end before returning it.
913    let paused_queues = {
914        // The first queue is used for inflate messages
915        let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
916        let inflate_queue_evt = inflate_queue
917            .event()
918            .try_clone()
919            .expect("failed to clone queue event");
920        let inflate = handle_queue(
921            inflate_queue,
922            EventAsync::new(inflate_queue_evt, &ex).expect("failed to create async event"),
923            release_memory_tube.as_ref(),
924            |ranges| free_memory(&vm_memory_client, &mem, ranges),
925            stop_rx,
926        );
927        let inflate = inflate.fuse();
928        pin_mut!(inflate);
929
930        // The second queue is used for deflate messages
931        let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
932        let deflate_queue_evt = deflate_queue
933            .event()
934            .try_clone()
935            .expect("failed to clone queue event");
936        let deflate = handle_queue(
937            deflate_queue,
938            EventAsync::new(deflate_queue_evt, &ex).expect("failed to create async event"),
939            None,
940            |ranges| reclaim_memory(&vm_memory_client, ranges),
941            stop_rx,
942        );
943        let deflate = deflate.fuse();
944        pin_mut!(deflate);
945
946        // The next queue is used for stats messages if VIRTIO_BALLOON_F_STATS_VQ is negotiated.
947        let (stats_tx, stats_rx) = mpsc::channel::<()>(1);
948        let has_stats_queue = stats_queue.is_some();
949        let stats = if let Some(stats_queue) = stats_queue {
950            let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
951            let stats_queue_evt = stats_queue
952                .event()
953                .try_clone()
954                .expect("failed to clone queue event");
955            handle_stats_queue(
956                stats_queue,
957                EventAsync::new(stats_queue_evt, &ex).expect("failed to create async event"),
958                stats_rx,
959                &command_tube,
960                #[cfg(feature = "registered_events")]
961                registered_evt_q_async.as_ref(),
962                state.clone(),
963                stop_rx,
964            )
965            .left_future()
966        } else {
967            std::future::pending().right_future()
968        };
969        let stats = stats.fuse();
970        pin_mut!(stats);
971
972        // The next queue is used for reporting messages
973        let has_reporting_queue = reporting_queue.is_some();
974        let reporting = if let Some(reporting_queue) = reporting_queue {
975            let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
976            let reporting_queue_evt = reporting_queue
977                .event()
978                .try_clone()
979                .expect("failed to clone queue event");
980            handle_reporting_queue(
981                reporting_queue,
982                EventAsync::new(reporting_queue_evt, &ex).expect("failed to create async event"),
983                release_memory_tube.as_ref(),
984                |ranges| free_memory(&vm_memory_client, &mem, ranges),
985                stop_rx,
986            )
987            .left_future()
988        } else {
989            std::future::pending().right_future()
990        };
991        let reporting = reporting.fuse();
992        pin_mut!(reporting);
993
994        // If VIRTIO_BALLOON_F_WS_REPORTING is set 2 queues must handled - one for WS data and one
995        // for WS notifications.
996        let has_ws_data_queue = ws_data_queue.is_some();
997        let ws_data = if let Some(ws_data_queue) = ws_data_queue {
998            let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
999            let ws_data_queue_evt = ws_data_queue
1000                .event()
1001                .try_clone()
1002                .expect("failed to clone queue event");
1003            handle_ws_data_queue(
1004                ws_data_queue,
1005                EventAsync::new(ws_data_queue_evt, &ex).expect("failed to create async event"),
1006                &command_tube,
1007                #[cfg(feature = "registered_events")]
1008                registered_evt_q_async.as_ref(),
1009                state.clone(),
1010                stop_rx,
1011            )
1012            .left_future()
1013        } else {
1014            std::future::pending().right_future()
1015        };
1016        let ws_data = ws_data.fuse();
1017        pin_mut!(ws_data);
1018
1019        let (ws_op_tx, ws_op_rx) = mpsc::channel::<WSOp>(1);
1020        let has_ws_op_queue = ws_op_queue.is_some();
1021        let ws_op = if let Some(ws_op_queue) = ws_op_queue {
1022            let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
1023            let ws_op_queue_evt = ws_op_queue
1024                .event()
1025                .try_clone()
1026                .expect("failed to clone queue event");
1027            handle_ws_op_queue(
1028                ws_op_queue,
1029                EventAsync::new(ws_op_queue_evt, &ex).expect("failed to create async event"),
1030                ws_op_rx,
1031                state.clone(),
1032                stop_rx,
1033            )
1034            .left_future()
1035        } else {
1036            std::future::pending().right_future()
1037        };
1038        let ws_op = ws_op.fuse();
1039        pin_mut!(ws_op);
1040
1041        // Future to handle command messages that resize the balloon.
1042        let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
1043        let command = handle_command_tube(
1044            &command_tube,
1045            interrupt.clone(),
1046            state.clone(),
1047            stats_tx,
1048            ws_op_tx,
1049            stop_rx,
1050        );
1051        pin_mut!(command);
1052
1053        // Send a message if balloon target reached event is triggered.
1054        let target_reached = handle_target_reached(&ex, target_reached_evt, &vm_memory_client);
1055        pin_mut!(target_reached);
1056
1057        // Exit if the kill event is triggered.
1058        let kill = async_utils::await_and_exit(&ex, kill_evt);
1059        pin_mut!(kill);
1060
1061        let pending_adjusted = handle_pending_adjusted_responses(
1062            EventAsync::new(pending_adjusted_response_event, &ex)
1063                .expect("failed to create async event"),
1064            &command_tube,
1065            state,
1066        );
1067        pin_mut!(pending_adjusted);
1068
1069        let res = ex.run_until(async {
1070            select! {
1071                _ = kill.fuse() => (),
1072                _ = inflate => return Err(anyhow!("inflate stopped unexpectedly")),
1073                _ = deflate => return Err(anyhow!("deflate stopped unexpectedly")),
1074                _ = stats => return Err(anyhow!("stats stopped unexpectedly")),
1075                _ = reporting => return Err(anyhow!("reporting stopped unexpectedly")),
1076                _ = command.fuse() => return Err(anyhow!("command stopped unexpectedly")),
1077                _ = ws_op => return Err(anyhow!("ws_op stopped unexpectedly")),
1078                _ = pending_adjusted.fuse() => return Err(anyhow!("pending_adjusted stopped unexpectedly")),
1079                _ = ws_data => return Err(anyhow!("ws_data stopped unexpectedly")),
1080                _ = target_reached.fuse() => return Err(anyhow!("target_reached stopped unexpectedly")),
1081            }
1082
1083            // Worker is shutting down. To recover the queues, we have to signal
1084            // all the queue futures to exit.
1085            for stop_tx in stop_queue_oneshots {
1086                if stop_tx.send(()).is_err() {
1087                    return Err(anyhow!("failed to request stop for queue future"));
1088                }
1089            }
1090
1091            // Collect all the queues (awaiting any queue future should now
1092            // return its Queue immediately).
1093            let mut paused_queues = PausedQueues::new(
1094                inflate.await,
1095                deflate.await,
1096            );
1097            if has_reporting_queue {
1098                paused_queues.reporting = Some(reporting.await);
1099            }
1100            if has_stats_queue {
1101                paused_queues.stats = Some(stats.await);
1102            }
1103            if has_ws_data_queue {
1104                paused_queues.ws_data = Some(ws_data.await.context("failed to stop ws_data queue")?);
1105            }
1106            if has_ws_op_queue {
1107                paused_queues.ws_op = Some(ws_op.await.context("failed to stop ws_op queue")?);
1108            }
1109            Ok(paused_queues)
1110        });
1111
1112        match res {
1113            Err(e) => {
1114                error!("error happened in executor: {}", e);
1115                None
1116            }
1117            Ok(main_future_res) => match main_future_res {
1118                Ok(paused_queues) => Some(paused_queues),
1119                Err(e) => {
1120                    error!("error happened in main balloon future: {}", e);
1121                    None
1122                }
1123            },
1124        }
1125    };
1126
1127    WorkerReturn {
1128        command_tube: command_tube.into(),
1129        paused_queues,
1130        release_memory_tube,
1131        #[cfg(feature = "registered_events")]
1132        registered_evt_q,
1133        vm_memory_client,
1134    }
1135}
1136
1137async fn handle_target_reached(
1138    ex: &Executor,
1139    target_reached_evt: Event,
1140    vm_memory_client: &VmMemoryClient,
1141) -> anyhow::Result<()> {
1142    let event_async =
1143        EventAsync::new(target_reached_evt, ex).context("failed to create EventAsync")?;
1144    loop {
1145        // Wait for target reached trigger.
1146        let _ = event_async.next_val().await;
1147        // Send the message to vm_control on the event. We don't have to read the current
1148        // size yet.
1149        if let Err(e) = vm_memory_client.balloon_target_reached(0) {
1150            warn!("Failed to send or receive allocation complete request: {e:#}");
1151        }
1152    }
1153    // The above loop will never terminate and there is no reason to terminate it either. However,
1154    // the function is used in an executor that expects a Result<> return. Make sure that clippy
1155    // doesn't enforce the unreachable_code condition.
1156    #[allow(unreachable_code)]
1157    Ok(())
1158}
1159
/// Virtio device for memory balloon inflation/deflation.
pub struct Balloon {
    /// Command tube; moved into the worker by `start_worker` and restored by `stop_worker`.
    command_tube: Option<Tube>,
    /// Memory client; moved into the worker by `start_worker` and restored by `stop_worker`.
    vm_memory_client: Option<VmMemoryClient>,
    /// Optional tube used to send inflate ranges to CoIOMMU (see [`Balloon::new`]); moved into
    /// the worker while it runs.
    release_memory_tube: Option<Tube>,
    /// Signaled by `write_config` after a response is pushed onto
    /// `pending_adjusted_responses`; a clone is given to the worker.
    pending_adjusted_response_event: Event,
    /// Balloon state shared between the device and the worker.
    state: Arc<AsyncRwLock<BalloonState>>,
    /// Feature bits offered by the device.
    features: u64,
    /// Feature bits acknowledged through `ack_features`.
    acked_features: u64,
    /// Handle of the running worker thread, if any.
    worker_thread: Option<WorkerThread<WorkerReturn>>,
    /// Optional registered-events tube; moved into the worker while it runs.
    #[cfg(feature = "registered_events")]
    registered_evt_q: Option<SendTube>,
    /// Value reported in the `ws_num_bins` config field.
    ws_num_bins: u8,
    /// Device-side handle of the target-reached event; created by `start_worker` and signaled
    /// by `write_config`.
    target_reached_evt: Option<Event>,
    /// Maximum size of every queue the device exposes; returned by `queue_max_sizes`.
    queue_sizes: Vec<u16>,
}
1176
/// Snapshot of the [Balloon] state.
///
/// Produced by `virtio_snapshot` and consumed by `virtio_restore`.
#[derive(Serialize, Deserialize)]
struct BalloonSnapshot {
    /// Copy of the shared balloon state at snapshot time.
    state: BalloonState,
    /// Offered feature bits; `virtio_restore` requires them to match the live device.
    features: u64,
    /// Feature bits that had been acknowledged at snapshot time.
    acked_features: u64,
    /// The device's `ws_num_bins` value at snapshot time.
    ws_num_bins: u8,
}
1185
1186impl Balloon {
1187    /// Creates a new virtio balloon device.
1188    /// To let Balloon able to successfully release the memory which are pinned
1189    /// by CoIOMMU to host, the release_memory_tube will be used to send the inflate
1190    /// ranges to CoIOMMU with UnpinRequest/UnpinResponse messages, so that The
1191    /// memory in the inflate range can be unpinned first.
1192    pub fn new(
1193        base_features: u64,
1194        command_tube: Tube,
1195        vm_memory_client: VmMemoryClient,
1196        release_memory_tube: Option<Tube>,
1197        init_balloon_size: u64,
1198        enabled_features: u64,
1199        #[cfg(feature = "registered_events")] registered_evt_q: Option<SendTube>,
1200        ws_num_bins: u8,
1201    ) -> Result<Balloon> {
1202        let features = base_features
1203            | 1 << VIRTIO_BALLOON_F_MUST_TELL_HOST
1204            | 1 << VIRTIO_BALLOON_F_STATS_VQ
1205            | 1 << VIRTIO_BALLOON_F_DEFLATE_ON_OOM
1206            | enabled_features;
1207
1208        let mut queue_sizes = Vec::new();
1209        queue_sizes.push(QUEUE_SIZE); // inflateq
1210        queue_sizes.push(QUEUE_SIZE); // deflateq
1211        if features & (1 << VIRTIO_BALLOON_F_STATS_VQ) != 0 {
1212            queue_sizes.push(QUEUE_SIZE); // statsq
1213        }
1214        if features & (1 << VIRTIO_BALLOON_F_PAGE_REPORTING) != 0 {
1215            queue_sizes.push(QUEUE_SIZE); // reporting_vq
1216        }
1217        if features & (1 << VIRTIO_BALLOON_F_WS_REPORTING) != 0 {
1218            queue_sizes.push(QUEUE_SIZE); // ws_data
1219            queue_sizes.push(QUEUE_SIZE); // ws_cmd
1220        }
1221
1222        Ok(Balloon {
1223            command_tube: Some(command_tube),
1224            vm_memory_client: Some(vm_memory_client),
1225            release_memory_tube,
1226            pending_adjusted_response_event: Event::new().map_err(BalloonError::CreatingEvent)?,
1227            state: Arc::new(AsyncRwLock::new(BalloonState {
1228                num_pages: (init_balloon_size >> VIRTIO_BALLOON_PFN_SHIFT) as u32,
1229                actual_pages: 0,
1230                failable_update: false,
1231                pending_adjusted_responses: VecDeque::new(),
1232                expecting_ws: false,
1233            })),
1234            worker_thread: None,
1235            features,
1236            acked_features: 0,
1237            #[cfg(feature = "registered_events")]
1238            registered_evt_q,
1239            ws_num_bins,
1240            target_reached_evt: None,
1241            queue_sizes,
1242        })
1243    }
1244
1245    fn get_config(&self) -> virtio_balloon_config {
1246        let state = block_on(self.state.lock());
1247        virtio_balloon_config {
1248            num_pages: state.num_pages.into(),
1249            actual: state.actual_pages.into(),
1250            // crosvm does not (currently) use free_page_hint_cmd_id or
1251            // poison_val, but they must be present in the right order and size
1252            // for the virtio-balloon driver in the guest to deserialize the
1253            // config correctly.
1254            free_page_hint_cmd_id: 0.into(),
1255            poison_val: 0.into(),
1256            ws_num_bins: self.ws_num_bins,
1257            _reserved: [0, 0, 0],
1258        }
1259    }
1260
1261    fn stop_worker(&mut self) -> StoppedWorker<PausedQueues> {
1262        if let Some(worker_thread) = self.worker_thread.take() {
1263            let worker_ret = worker_thread.stop();
1264            self.release_memory_tube = worker_ret.release_memory_tube;
1265            self.command_tube = Some(worker_ret.command_tube);
1266            #[cfg(feature = "registered_events")]
1267            {
1268                self.registered_evt_q = worker_ret.registered_evt_q;
1269            }
1270            self.vm_memory_client = Some(worker_ret.vm_memory_client);
1271
1272            if let Some(queues) = worker_ret.paused_queues {
1273                StoppedWorker::WithQueues(Box::new(queues))
1274            } else {
1275                StoppedWorker::MissingQueues
1276            }
1277        } else {
1278            StoppedWorker::AlreadyStopped
1279        }
1280    }
1281
1282    /// Given a filtered queue vector from [VirtioDevice::activate], extract
1283    /// the queues (accounting for queues that are missing because the features
1284    /// are not negotiated) into a structure that is easier to work with.
1285    fn get_queues_from_map(
1286        &self,
1287        mut queues: BTreeMap<usize, Queue>,
1288    ) -> anyhow::Result<BalloonQueues> {
1289        let inflate_queue = queues.remove(&INFLATEQ).context("missing inflateq")?;
1290        let deflate_queue = queues.remove(&DEFLATEQ).context("missing deflateq")?;
1291        let mut queue_struct = BalloonQueues::new(inflate_queue, deflate_queue);
1292
1293        // Queues whose existence depends on advertised features start at queue index 2.
1294        let mut next_queue_index = 2;
1295        let mut next_queue = || {
1296            let idx = next_queue_index;
1297            next_queue_index += 1;
1298            idx
1299        };
1300
1301        if self.features & (1 << VIRTIO_BALLOON_F_STATS_VQ) != 0 {
1302            let statsq = next_queue();
1303            if self.acked_features & (1 << VIRTIO_BALLOON_F_STATS_VQ) != 0 {
1304                queue_struct.stats = Some(queues.remove(&statsq).context("missing statsq")?);
1305            }
1306        }
1307
1308        if self.features & (1 << VIRTIO_BALLOON_F_PAGE_REPORTING) != 0 {
1309            let reporting_vq = next_queue();
1310            if self.acked_features & (1 << VIRTIO_BALLOON_F_PAGE_REPORTING) != 0 {
1311                queue_struct.reporting = Some(
1312                    queues
1313                        .remove(&reporting_vq)
1314                        .context("missing reporting_vq")?,
1315                );
1316            }
1317        }
1318
1319        if self.features & (1 << VIRTIO_BALLOON_F_WS_REPORTING) != 0 {
1320            let ws_data_vq = next_queue();
1321            let ws_op_vq = next_queue();
1322            if self.acked_features & (1 << VIRTIO_BALLOON_F_WS_REPORTING) != 0 {
1323                queue_struct.ws_data =
1324                    Some(queues.remove(&ws_data_vq).context("missing ws_data_vq")?);
1325                queue_struct.ws_op = Some(queues.remove(&ws_op_vq).context("missing ws_op_vq")?);
1326            }
1327        }
1328
1329        if !queues.is_empty() {
1330            return Err(anyhow!("unexpected queues {:?}", queues.into_keys()));
1331        }
1332
1333        Ok(queue_struct)
1334    }
1335
1336    fn start_worker(
1337        &mut self,
1338        mem: GuestMemory,
1339        interrupt: Interrupt,
1340        queues: BalloonQueues,
1341    ) -> anyhow::Result<()> {
1342        let (self_target_reached_evt, target_reached_evt) = Event::new()
1343            .and_then(|e| Ok((e.try_clone()?, e)))
1344            .context("failed to create target_reached Event pair: {}")?;
1345        self.target_reached_evt = Some(self_target_reached_evt);
1346
1347        let state = self.state.clone();
1348
1349        let command_tube = self.command_tube.take().unwrap();
1350
1351        let vm_memory_client = self.vm_memory_client.take().unwrap();
1352        let release_memory_tube = self.release_memory_tube.take();
1353        #[cfg(feature = "registered_events")]
1354        let registered_evt_q = self.registered_evt_q.take();
1355        let pending_adjusted_response_event = self
1356            .pending_adjusted_response_event
1357            .try_clone()
1358            .context("failed to clone Event")?;
1359
1360        self.worker_thread = Some(WorkerThread::start("v_balloon", move |kill_evt| {
1361            run_worker(
1362                queues.inflate,
1363                queues.deflate,
1364                queues.stats,
1365                queues.reporting,
1366                queues.ws_data,
1367                queues.ws_op,
1368                command_tube,
1369                vm_memory_client,
1370                mem,
1371                release_memory_tube,
1372                interrupt,
1373                kill_evt,
1374                target_reached_evt,
1375                pending_adjusted_response_event,
1376                state,
1377                #[cfg(feature = "registered_events")]
1378                registered_evt_q,
1379            )
1380        }));
1381
1382        Ok(())
1383    }
1384}
1385
1386impl VirtioDevice for Balloon {
1387    fn keep_rds(&self) -> Vec<RawDescriptor> {
1388        let mut rds = Vec::new();
1389        if let Some(command_tube) = &self.command_tube {
1390            rds.push(command_tube.as_raw_descriptor());
1391        }
1392        if let Some(vm_memory_client) = &self.vm_memory_client {
1393            rds.push(vm_memory_client.as_raw_descriptor());
1394        }
1395        if let Some(release_memory_tube) = &self.release_memory_tube {
1396            rds.push(release_memory_tube.as_raw_descriptor());
1397        }
1398        #[cfg(feature = "registered_events")]
1399        if let Some(registered_evt_q) = &self.registered_evt_q {
1400            rds.push(registered_evt_q.as_raw_descriptor());
1401        }
1402        rds.push(self.pending_adjusted_response_event.as_raw_descriptor());
1403        rds
1404    }
1405
    /// Identifies this device as a balloon.
    fn device_type(&self) -> DeviceType {
        DeviceType::Balloon
    }
1409
    /// Returns the per-queue maximum sizes computed in [`Balloon::new`] from the offered
    /// features.
    fn queue_max_sizes(&self) -> &[u16] {
        &self.queue_sizes
    }
1413
    /// Copies the current device config, starting at `offset`, into `data`.
    fn read_config(&self, offset: u64, data: &mut [u8]) {
        copy_config(data, 0, self.get_config().as_bytes(), offset);
    }
1417
1418    fn write_config(&mut self, offset: u64, data: &[u8]) {
1419        let mut config = self.get_config();
1420        copy_config(config.as_mut_bytes(), offset, data, 0);
1421        let mut state = block_on(self.state.lock());
1422        state.actual_pages = config.actual.to_native();
1423
1424        // If balloon has updated to the requested memory, let the hypervisor know.
1425        if config.num_pages == config.actual {
1426            debug!(
1427                "sending target reached event at {}",
1428                u32::from(config.num_pages)
1429            );
1430            self.target_reached_evt.as_ref().map(|e| e.signal());
1431        }
1432        if state.failable_update && state.actual_pages == state.num_pages {
1433            state.failable_update = false;
1434            let num_pages = state.num_pages;
1435            state.pending_adjusted_responses.push_back(num_pages);
1436            let _ = self.pending_adjusted_response_event.signal();
1437        }
1438    }
1439
    /// Returns the feature bits offered by the device, as computed in [`Balloon::new`].
    fn features(&self) -> u64 {
        self.features
    }
1443
1444    fn ack_features(&mut self, mut value: u64) {
1445        if value & !self.features != 0 {
1446            warn!("virtio_balloon got unknown feature ack {:x}", value);
1447            value &= self.features;
1448        }
1449        self.acked_features |= value;
1450    }
1451
1452    fn activate(
1453        &mut self,
1454        mem: GuestMemory,
1455        interrupt: Interrupt,
1456        queues: BTreeMap<usize, Queue>,
1457    ) -> anyhow::Result<()> {
1458        let queues = self.get_queues_from_map(queues)?;
1459        self.start_worker(mem, interrupt, queues)
1460    }
1461
    /// Stops the worker thread, if one is running. Whatever `stop_worker` reports, including
    /// any recovered queues, is discarded.
    fn reset(&mut self) -> anyhow::Result<()> {
        let _worker = self.stop_worker();
        Ok(())
    }
1466
1467    fn virtio_sleep(&mut self) -> anyhow::Result<Option<BTreeMap<usize, Queue>>> {
1468        match self.stop_worker() {
1469            StoppedWorker::WithQueues(paused_queues) => Ok(Some(paused_queues.into())),
1470            StoppedWorker::MissingQueues => {
1471                anyhow::bail!("balloon queue workers did not stop cleanly.")
1472            }
1473            StoppedWorker::AlreadyStopped => {
1474                // Device hasn't been activated.
1475                Ok(None)
1476            }
1477        }
1478    }
1479
1480    fn virtio_wake(
1481        &mut self,
1482        queues_state: Option<(GuestMemory, Interrupt, BTreeMap<usize, Queue>)>,
1483    ) -> anyhow::Result<()> {
1484        if let Some((mem, interrupt, queues)) = queues_state {
1485            if queues.len() < 2 {
1486                anyhow::bail!("{} queues were found, but an activated balloon must have at least 2 active queues.", queues.len());
1487            }
1488
1489            let balloon_queues = self.get_queues_from_map(queues)?;
1490            self.start_worker(mem, interrupt, balloon_queues)?;
1491        }
1492        Ok(())
1493    }
1494
1495    fn virtio_snapshot(&mut self) -> anyhow::Result<AnySnapshot> {
1496        let state = self
1497            .state
1498            .lock()
1499            .now_or_never()
1500            .context("failed to acquire balloon lock")?;
1501        AnySnapshot::to_any(BalloonSnapshot {
1502            features: self.features,
1503            acked_features: self.acked_features,
1504            state: state.clone(),
1505            ws_num_bins: self.ws_num_bins,
1506        })
1507        .context("failed to serialize balloon state")
1508    }
1509
1510    fn virtio_restore(&mut self, data: AnySnapshot) -> anyhow::Result<()> {
1511        let snap: BalloonSnapshot = AnySnapshot::from_any(data).context("error deserializing")?;
1512        if snap.features != self.features {
1513            anyhow::bail!(
1514                "balloon: expected features to match, but they did not. Live: {:?}, snapshot {:?}",
1515                self.features,
1516                snap.features,
1517            );
1518        }
1519
1520        let mut state = self
1521            .state
1522            .lock()
1523            .now_or_never()
1524            .context("failed to acquire balloon lock")?;
1525        *state = snap.state;
1526        self.ws_num_bins = snap.ws_num_bins;
1527        self.acked_features = snap.acked_features;
1528        Ok(())
1529    }
1530}
1531
// Unit tests for descriptor parsing plus the shared suspend/resume test suite.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::suspendable_virtio_tests;
    use crate::virtio::descriptor_utils::create_descriptor_chain;
    use crate::virtio::descriptor_utils::DescriptorType;

    #[test]
    fn desc_parsing_inflate() {
        // Check that the memory addresses are parsed correctly by 'handle_address_chain' and passed
        // to the closure.
        let memory_start_addr = GuestAddress(0x0);
        let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
        // Two 32-bit page frame numbers, stored back to back at 0x100 and 0x104.
        memory
            .write_obj_at_addr(0x10u32, GuestAddress(0x100))
            .unwrap();
        memory
            .write_obj_at_addr(0xaa55aa55u32, GuestAddress(0x104))
            .unwrap();

        // A single readable 8-byte descriptor at 0x100 that covers both values.
        let mut chain = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![(DescriptorType::Readable, 8)],
            0,
        )
        .expect("create_descriptor_chain failed");

        // Collect every range the closure is handed.
        let mut addrs = Vec::new();
        let res = handle_address_chain(None, &mut chain, &mut |mut ranges| {
            addrs.append(&mut ranges)
        });
        assert!(res.is_ok());
        assert_eq!(addrs.len(), 2);
        // Each reported address is the page frame number shifted by VIRTIO_BALLOON_PFN_SHIFT.
        assert_eq!(
            addrs[0].0,
            GuestAddress(0x10u64 << VIRTIO_BALLOON_PFN_SHIFT)
        );
        assert_eq!(
            addrs[1].0,
            GuestAddress(0xaa55aa55u64 << VIRTIO_BALLOON_PFN_SHIFT)
        );
    }

    // Holds the far ends of the device's tubes so they are not dropped while the device under
    // test is alive.
    struct BalloonContext {
        _ctrl_tube: Tube,
        _mem_client_tube: Tube,
    }

    // Changes device state that is part of the snapshot by inverting every bit of
    // `ws_num_bins`.
    fn modify_device(_balloon_context: &mut BalloonContext, balloon: &mut Balloon) {
        balloon.ws_num_bins = !balloon.ws_num_bins;
    }

    // Builds a balloon with no base or extra features, an initial size of 1024, no
    // release-memory tube and `ws_num_bins` of 0, together with the context that keeps the
    // other tube ends alive.
    fn create_device() -> (BalloonContext, Balloon) {
        let (_ctrl_tube, ctrl_tube_device) = Tube::pair().unwrap();
        let (_mem_client_tube, mem_client_tube_device) = Tube::pair().unwrap();
        (
            BalloonContext {
                _ctrl_tube,
                _mem_client_tube,
            },
            Balloon::new(
                0,
                ctrl_tube_device,
                VmMemoryClient::new(mem_client_tube_device),
                None,
                1024,
                0,
                #[cfg(feature = "registered_events")]
                None,
                0,
            )
            .unwrap(),
        )
    }

    // NOTE(review): the macro is defined elsewhere; the `2` is presumably the number of
    // queues the suite activates - confirm against the macro definition.
    suspendable_virtio_tests!(balloon, create_device, 2, modify_device);
}