devices/virtio/
balloon.rs

1// Copyright 2017 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use std::collections::BTreeMap;
6use std::collections::VecDeque;
7use std::io::Write;
8use std::sync::Arc;
9
10use anyhow::anyhow;
11use anyhow::Context;
12use balloon_control::BalloonStats;
13use balloon_control::BalloonTubeCommand;
14use balloon_control::BalloonTubeResult;
15use balloon_control::BalloonWS;
16use balloon_control::WSBucket;
17use balloon_control::VIRTIO_BALLOON_WS_MAX_NUM_BINS;
18use balloon_control::VIRTIO_BALLOON_WS_MIN_NUM_BINS;
19use base::debug;
20use base::error;
21use base::warn;
22use base::AsRawDescriptor;
23use base::Event;
24use base::RawDescriptor;
25#[cfg(feature = "registered_events")]
26use base::SendTube;
27use base::Tube;
28use base::WorkerThread;
29use cros_async::block_on;
30use cros_async::sync::RwLock as AsyncRwLock;
31use cros_async::AsyncTube;
32use cros_async::EventAsync;
33use cros_async::Executor;
34#[cfg(feature = "registered_events")]
35use cros_async::SendTubeAsync;
36use data_model::Le16;
37use data_model::Le32;
38use data_model::Le64;
39use futures::channel::mpsc;
40use futures::channel::oneshot;
41use futures::pin_mut;
42use futures::select;
43use futures::select_biased;
44use futures::FutureExt;
45use futures::StreamExt;
46use remain::sorted;
47use serde::Deserialize;
48use serde::Serialize;
49use snapshot::AnySnapshot;
50use thiserror::Error as ThisError;
51use vm_control::api::VmMemoryClient;
52#[cfg(feature = "registered_events")]
53use vm_control::RegisteredEventWithData;
54use vm_memory::GuestAddress;
55use vm_memory::GuestMemory;
56use zerocopy::FromBytes;
57use zerocopy::Immutable;
58use zerocopy::IntoBytes;
59use zerocopy::KnownLayout;
60
61use super::async_utils;
62use super::copy_config;
63use super::create_stop_oneshot;
64use super::DescriptorChain;
65use super::DeviceType;
66use super::Interrupt;
67use super::Queue;
68use super::Reader;
69use super::StoppedWorker;
70use super::VirtioDevice;
71use crate::UnpinRequest;
72use crate::UnpinResponse;
73
74#[sorted]
75#[derive(ThisError, Debug)]
76pub enum BalloonError {
77    /// Failed an async await
78    #[error("failed async await: {0}")]
79    AsyncAwait(cros_async::AsyncError),
80    /// Failed an async await
81    #[error("failed async await: {0}")]
82    AsyncAwaitAnyhow(anyhow::Error),
83    /// Failed to create event.
84    #[error("failed to create event: {0}")]
85    CreatingEvent(base::Error),
86    /// Failed to create async message receiver.
87    #[error("failed to create async message receiver: {0}")]
88    CreatingMessageReceiver(base::TubeError),
89    /// Failed to receive command message.
90    #[error("failed to receive command message: {0}")]
91    ReceivingCommand(base::TubeError),
92    /// Failed to send command response.
93    #[error("failed to send command response: {0}")]
94    SendResponse(base::TubeError),
95    /// Error while writing to virtqueue
96    #[error("failed to write to virtqueue: {0}")]
97    WriteQueue(std::io::Error),
98    /// Failed to write config event.
99    #[error("failed to write config event: {0}")]
100    WritingConfigEvent(base::Error),
101}
102pub type Result<T> = std::result::Result<T, BalloonError>;
103
/// Size of each virtqueue exposed by the balloon device.
const QUEUE_SIZE: u16 = 128;

// Virtqueue indexes that do not depend on advertised features.
const INFLATEQ: usize = 0;
const DEFLATEQ: usize = 1;

/// Page frame numbers exchanged with the guest refer to 4 KiB (`1 << 12` byte) pages.
const VIRTIO_BALLOON_PFN_SHIFT: u32 = 12;
/// Size in bytes of one balloon page.
const VIRTIO_BALLOON_PF_SIZE: u64 = 1u64 << VIRTIO_BALLOON_PFN_SHIFT;

// Feature bit numbers advertised by the virtio balloon device.
/// Tell the host before reclaiming pages.
const VIRTIO_BALLOON_F_MUST_TELL_HOST: u32 = 0;
/// Stats reporting virtqueue is enabled.
const VIRTIO_BALLOON_F_STATS_VQ: u32 = 1;
/// Deflate the balloon on guest OOM.
const VIRTIO_BALLOON_F_DEFLATE_ON_OOM: u32 = 2;
/// Page reporting virtqueue is enabled.
const VIRTIO_BALLOON_F_PAGE_REPORTING: u32 = 5;
// TODO(b/273973298): this should maybe be bit 6? to be changed later
/// Working Set reporting virtqueues are enabled.
const VIRTIO_BALLOON_F_WS_REPORTING: u32 = 8;

/// Optional balloon features, expressed as their feature bit numbers.
#[derive(Clone, Copy)]
#[repr(u32)]
pub enum BalloonFeatures {
    /// Page Reporting enabled.
    PageReporting = VIRTIO_BALLOON_F_PAGE_REPORTING,
    /// WS Reporting enabled.
    WSReporting = VIRTIO_BALLOON_F_WS_REPORTING,
}
130
131// virtio_balloon_config is the balloon device configuration space defined by the virtio spec.
132#[derive(Copy, Clone, Debug, Default, FromBytes, Immutable, IntoBytes, KnownLayout)]
133#[repr(C)]
134struct virtio_balloon_config {
135    num_pages: Le32,
136    actual: Le32,
137    free_page_hint_cmd_id: Le32,
138    poison_val: Le32,
139    // WS field is part of proposed spec extension (b/273973298).
140    ws_num_bins: u8,
141    _reserved: [u8; 3],
142}
143
144// BalloonState is shared by the worker and device thread.
145#[derive(Clone, Default, Serialize, Deserialize)]
146struct BalloonState {
147    num_pages: u32,
148    actual_pages: u32,
149    expecting_ws: bool,
150    // Flag indicating that the balloon is in the process of a failable update. This
151    // is set by an Adjust command that has allow_failure set, and is cleared when the
152    // Adjusted success/failure response is sent.
153    failable_update: bool,
154    pending_adjusted_responses: VecDeque<u32>,
155}
156
// Stat tags carried by `virtio_balloon_stat` entries on the stats queue.
const VIRTIO_BALLOON_S_SWAP_IN: u16 = 0;
const VIRTIO_BALLOON_S_SWAP_OUT: u16 = 1;
const VIRTIO_BALLOON_S_MAJFLT: u16 = 2;
const VIRTIO_BALLOON_S_MINFLT: u16 = 3;
const VIRTIO_BALLOON_S_MEMFREE: u16 = 4;
const VIRTIO_BALLOON_S_MEMTOT: u16 = 5;
const VIRTIO_BALLOON_S_AVAIL: u16 = 6;
const VIRTIO_BALLOON_S_CACHES: u16 = 7;
const VIRTIO_BALLOON_S_HTLB_PGALLOC: u16 = 8;
const VIRTIO_BALLOON_S_HTLB_PGFAIL: u16 = 9;
// Non-standard tags, taken from the top of the 16-bit tag space.
const VIRTIO_BALLOON_S_NONSTANDARD_SHMEM: u16 = u16::MAX - 1;
const VIRTIO_BALLOON_S_NONSTANDARD_UNEVICTABLE: u16 = u16::MAX;
170
171// BalloonStat is used to deserialize stats from the stats_queue.
172#[derive(Copy, Clone, FromBytes, Immutable, IntoBytes, KnownLayout)]
173#[repr(C, packed)]
174struct BalloonStat {
175    tag: Le16,
176    val: Le64,
177}
178
179impl BalloonStat {
180    fn update_stats(&self, stats: &mut BalloonStats) {
181        let val = Some(self.val.to_native());
182        match self.tag.to_native() {
183            VIRTIO_BALLOON_S_SWAP_IN => stats.swap_in = val,
184            VIRTIO_BALLOON_S_SWAP_OUT => stats.swap_out = val,
185            VIRTIO_BALLOON_S_MAJFLT => stats.major_faults = val,
186            VIRTIO_BALLOON_S_MINFLT => stats.minor_faults = val,
187            VIRTIO_BALLOON_S_MEMFREE => stats.free_memory = val,
188            VIRTIO_BALLOON_S_MEMTOT => stats.total_memory = val,
189            VIRTIO_BALLOON_S_AVAIL => stats.available_memory = val,
190            VIRTIO_BALLOON_S_CACHES => stats.disk_caches = val,
191            VIRTIO_BALLOON_S_HTLB_PGALLOC => stats.hugetlb_allocations = val,
192            VIRTIO_BALLOON_S_HTLB_PGFAIL => stats.hugetlb_failures = val,
193            VIRTIO_BALLOON_S_NONSTANDARD_SHMEM => stats.shared_memory = val,
194            VIRTIO_BALLOON_S_NONSTANDARD_UNEVICTABLE => stats.unevictable_memory = val,
195            _ => (),
196        }
197    }
198}
199
200// virtio_balloon_ws is used to deserialize from the ws data vq.
201#[repr(C)]
202#[derive(Copy, Clone, Debug, Default, FromBytes, Immutable, IntoBytes, KnownLayout)]
203struct virtio_balloon_ws {
204    tag: Le16,
205    node_id: Le16,
206    // virtio prefers field members to align on a word boundary so we must pad. see:
207    // https://crsrc.org/o/src/third_party/kernel/v5.15/include/uapi/linux/virtio_balloon.h;l=105
208    _reserved: [u8; 4],
209    idle_age_ms: Le64,
210    // TODO(b/273973298): these should become separate fields - bytes for ANON and FILE
211    memory_size_bytes: [Le64; 2],
212}
213
214impl virtio_balloon_ws {
215    fn update_ws(&self, ws: &mut BalloonWS) {
216        let bucket = WSBucket {
217            age: self.idle_age_ms.to_native(),
218            bytes: [
219                self.memory_size_bytes[0].to_native(),
220                self.memory_size_bytes[1].to_native(),
221            ],
222        };
223        ws.ws.push(bucket);
224    }
225}
226
// Operation codes placed in the `type_` field of `virtio_balloon_op` on the WS op queue.
const _VIRTIO_BALLOON_WS_OP_INVALID: u16 = 0x0;
const VIRTIO_BALLOON_WS_OP_REQUEST: u16 = 0x1;
const VIRTIO_BALLOON_WS_OP_CONFIG: u16 = 0x2;
const _VIRTIO_BALLOON_WS_OP_DISCARD: u16 = 0x3;
231
232// virtio_balloon_op is used to serialize to the ws cmd vq.
233#[repr(C, packed)]
234#[derive(Copy, Clone, Debug, Default, FromBytes, Immutable, IntoBytes, KnownLayout)]
235struct virtio_balloon_op {
236    type_: Le16,
237}
238
239fn invoke_desc_handler<F>(ranges: Vec<(u64, u64)>, desc_handler: &mut F)
240where
241    F: FnMut(Vec<(GuestAddress, u64)>),
242{
243    desc_handler(
244        ranges
245            .into_iter()
246            .map(|range| (GuestAddress(range.0), range.1))
247            .collect(),
248    );
249}
250
251// Release a list of guest memory ranges back to the host system.
252// Unpin requests for each inflate range will be sent via `release_memory_tube`
253// if provided, and then `desc_handler` will be called for each inflate range.
254fn release_ranges<F>(
255    release_memory_tube: Option<&Tube>,
256    inflate_ranges: Vec<(u64, u64)>,
257    desc_handler: &mut F,
258) -> anyhow::Result<()>
259where
260    F: FnMut(Vec<(GuestAddress, u64)>),
261{
262    if let Some(tube) = release_memory_tube {
263        let unpin_ranges = inflate_ranges
264            .iter()
265            .map(|v| {
266                (
267                    v.0 >> VIRTIO_BALLOON_PFN_SHIFT,
268                    v.1 / VIRTIO_BALLOON_PF_SIZE,
269                )
270            })
271            .collect();
272        let req = UnpinRequest {
273            ranges: unpin_ranges,
274        };
275        if let Err(e) = tube.send(&req) {
276            error!("failed to send unpin request: {}", e);
277        } else {
278            match tube.recv() {
279                Ok(resp) => match resp {
280                    UnpinResponse::Success => invoke_desc_handler(inflate_ranges, desc_handler),
281                    UnpinResponse::Failed => error!("failed to handle unpin request"),
282                },
283                Err(e) => error!("failed to handle get unpin response: {}", e),
284            }
285        }
286    } else {
287        invoke_desc_handler(inflate_ranges, desc_handler);
288    }
289
290    Ok(())
291}
292
293// Processes one message's list of addresses.
294fn handle_address_chain<F>(
295    release_memory_tube: Option<&Tube>,
296    avail_desc: &mut DescriptorChain,
297    desc_handler: &mut F,
298) -> anyhow::Result<()>
299where
300    F: FnMut(Vec<(GuestAddress, u64)>),
301{
302    // In a long-running system, there is no reason to expect that
303    // a significant number of freed pages are consecutive. However,
304    // batching is relatively simple and can result in significant
305    // gains in a newly booted system, so it's worth attempting.
306    let mut range_start = 0;
307    let mut range_size = 0;
308    let mut inflate_ranges: Vec<(u64, u64)> = Vec::new();
309    for res in avail_desc.reader.iter::<Le32>() {
310        let pfn = match res {
311            Ok(pfn) => pfn,
312            Err(e) => {
313                error!("error while reading unused pages: {}", e);
314                break;
315            }
316        };
317        let guest_address = (u64::from(pfn.to_native())) << VIRTIO_BALLOON_PFN_SHIFT;
318        if range_start + range_size == guest_address {
319            range_size += VIRTIO_BALLOON_PF_SIZE;
320        } else if range_start == guest_address + VIRTIO_BALLOON_PF_SIZE {
321            range_start = guest_address;
322            range_size += VIRTIO_BALLOON_PF_SIZE;
323        } else {
324            // Discontinuity, so flush the previous range. Note range_size
325            // will be 0 on the first iteration, so skip that.
326            if range_size != 0 {
327                inflate_ranges.push((range_start, range_size));
328            }
329            range_start = guest_address;
330            range_size = VIRTIO_BALLOON_PF_SIZE;
331        }
332    }
333    if range_size != 0 {
334        inflate_ranges.push((range_start, range_size));
335    }
336
337    release_ranges(release_memory_tube, inflate_ranges, desc_handler)
338}
339
340// Async task that handles the main balloon inflate and deflate queues.
341async fn handle_queue<F>(
342    mut queue: Queue,
343    mut queue_event: EventAsync,
344    release_memory_tube: Option<&Tube>,
345    mut desc_handler: F,
346    mut stop_rx: oneshot::Receiver<()>,
347) -> Queue
348where
349    F: FnMut(Vec<(GuestAddress, u64)>),
350{
351    loop {
352        let mut avail_desc = match queue
353            .next_async_interruptable(&mut queue_event, &mut stop_rx)
354            .await
355        {
356            Ok(Some(res)) => res,
357            Ok(None) => return queue,
358            Err(e) => {
359                error!("Failed to read descriptor {}", e);
360                return queue;
361            }
362        };
363        if let Err(e) =
364            handle_address_chain(release_memory_tube, &mut avail_desc, &mut desc_handler)
365        {
366            error!("balloon: failed to process inflate addresses: {}", e);
367        }
368        queue.add_used(avail_desc);
369        queue.trigger_interrupt();
370    }
371}
372
373// Processes one page-reporting descriptor.
374fn handle_reported_buffer<F>(
375    release_memory_tube: Option<&Tube>,
376    avail_desc: &DescriptorChain,
377    desc_handler: &mut F,
378) -> anyhow::Result<()>
379where
380    F: FnMut(Vec<(GuestAddress, u64)>),
381{
382    let reported_ranges: Vec<(u64, u64)> = avail_desc
383        .reader
384        .get_remaining_regions()
385        .chain(avail_desc.writer.get_remaining_regions())
386        .map(|r| (r.offset, r.len as u64))
387        .collect();
388
389    release_ranges(release_memory_tube, reported_ranges, desc_handler)
390}
391
392// Async task that handles the page reporting queue.
393async fn handle_reporting_queue<F>(
394    mut queue: Queue,
395    mut queue_event: EventAsync,
396    release_memory_tube: Option<&Tube>,
397    mut desc_handler: F,
398    mut stop_rx: oneshot::Receiver<()>,
399) -> Queue
400where
401    F: FnMut(Vec<(GuestAddress, u64)>),
402{
403    loop {
404        let avail_desc = match queue
405            .next_async_interruptable(&mut queue_event, &mut stop_rx)
406            .await
407        {
408            Ok(Some(res)) => res,
409            Ok(None) => return queue,
410            Err(e) => {
411                error!("Failed to read descriptor {}", e);
412                return queue;
413            }
414        };
415        if let Err(e) = handle_reported_buffer(release_memory_tube, &avail_desc, &mut desc_handler)
416        {
417            error!("balloon: failed to process reported buffer: {}", e);
418        }
419        queue.add_used(avail_desc);
420        queue.trigger_interrupt();
421    }
422}
423
424fn parse_balloon_stats(reader: &mut Reader) -> BalloonStats {
425    let mut stats: BalloonStats = Default::default();
426    for res in reader.iter::<BalloonStat>() {
427        match res {
428            Ok(stat) => stat.update_stats(&mut stats),
429            Err(e) => {
430                error!("error while reading stats: {}", e);
431                break;
432            }
433        };
434    }
435    stats
436}
437
438// Async task that handles the stats queue. Note that the cadence of this is driven by requests for
439// balloon stats from the control pipe.
440// The guests queues an initial buffer on boot, which is read and then this future will block until
441// signaled from the command socket that stats should be collected again.
442async fn handle_stats_queue(
443    mut queue: Queue,
444    mut queue_event: EventAsync,
445    mut stats_rx: mpsc::Receiver<()>,
446    command_tube: &AsyncTube,
447    #[cfg(feature = "registered_events")] registered_evt_q: Option<&SendTubeAsync>,
448    state: Arc<AsyncRwLock<BalloonState>>,
449    mut stop_rx: oneshot::Receiver<()>,
450) -> Queue {
451    let mut avail_desc = match queue
452        .next_async_interruptable(&mut queue_event, &mut stop_rx)
453        .await
454    {
455        // Consume the first stats buffer sent from the guest at startup. It was not
456        // requested by anyone, and the stats are stale.
457        Ok(Some(res)) => res,
458        Ok(None) => return queue,
459        Err(e) => {
460            error!("Failed to read descriptor {}", e);
461            return queue;
462        }
463    };
464
465    loop {
466        select_biased! {
467            msg = stats_rx.next() => {
468                // Wait for a request to read the stats.
469                match msg {
470                    Some(()) => (),
471                    None => {
472                        error!("stats signal channel was closed");
473                        return queue;
474                    }
475                }
476            }
477            _ = stop_rx => return queue,
478        };
479
480        // Request a new stats_desc to the guest.
481        queue.add_used(avail_desc);
482        queue.trigger_interrupt();
483
484        avail_desc = match queue.next_async(&mut queue_event).await {
485            Err(e) => {
486                error!("Failed to read descriptor {}", e);
487                return queue;
488            }
489            Ok(d) => d,
490        };
491        let stats = parse_balloon_stats(&mut avail_desc.reader);
492
493        let actual_pages = state.lock().await.actual_pages as u64;
494        let result = BalloonTubeResult::Stats {
495            balloon_actual: actual_pages << VIRTIO_BALLOON_PFN_SHIFT,
496            stats,
497        };
498        let send_result = command_tube.send(result).await;
499        if let Err(e) = send_result {
500            error!("failed to send stats result: {}", e);
501        }
502
503        #[cfg(feature = "registered_events")]
504        if let Some(registered_evt_q) = registered_evt_q {
505            if let Err(e) = registered_evt_q
506                .send(&RegisteredEventWithData::VirtioBalloonResize)
507                .await
508            {
509                error!("failed to send VirtioBalloonResize event: {}", e);
510            }
511        }
512    }
513}
514
515async fn send_adjusted_response(
516    tube: &AsyncTube,
517    num_pages: u32,
518) -> std::result::Result<(), base::TubeError> {
519    let num_bytes = (num_pages as u64) << VIRTIO_BALLOON_PFN_SHIFT;
520    let result = BalloonTubeResult::Adjusted { num_bytes };
521    tube.send(result).await
522}
523
// A working-set operation forwarded from the command tube to `handle_ws_op_queue`.
enum WSOp {
    // Ask the guest for a working-set report.
    WSReport,
    // Send a new working-set configuration to the guest.
    WSConfig {
        // Bin boundaries, written to the guest as a raw array of `u32`.
        bins: Vec<u32>,
        refresh_threshold: u32,
        report_threshold: u32,
    },
}
532
533async fn handle_ws_op_queue(
534    mut queue: Queue,
535    mut queue_event: EventAsync,
536    mut ws_op_rx: mpsc::Receiver<WSOp>,
537    state: Arc<AsyncRwLock<BalloonState>>,
538    mut stop_rx: oneshot::Receiver<()>,
539) -> Result<Queue> {
540    loop {
541        let op = select_biased! {
542            next_op = ws_op_rx.next().fuse() => {
543                match next_op {
544                    Some(op) => op,
545                    None => {
546                        error!("ws op tube was closed");
547                        break;
548                    }
549                }
550            }
551            _ = stop_rx => {
552                break;
553            }
554        };
555        let mut avail_desc = queue
556            .next_async(&mut queue_event)
557            .await
558            .map_err(BalloonError::AsyncAwait)?;
559        let writer = &mut avail_desc.writer;
560
561        match op {
562            WSOp::WSReport => {
563                {
564                    let mut state = state.lock().await;
565                    state.expecting_ws = true;
566                }
567
568                let ws_r = virtio_balloon_op {
569                    type_: VIRTIO_BALLOON_WS_OP_REQUEST.into(),
570                };
571
572                writer.write_obj(ws_r).map_err(BalloonError::WriteQueue)?;
573            }
574            WSOp::WSConfig {
575                bins,
576                refresh_threshold,
577                report_threshold,
578            } => {
579                let cmd = virtio_balloon_op {
580                    type_: VIRTIO_BALLOON_WS_OP_CONFIG.into(),
581                };
582
583                writer.write_obj(cmd).map_err(BalloonError::WriteQueue)?;
584                writer
585                    .write_all(bins.as_bytes())
586                    .map_err(BalloonError::WriteQueue)?;
587                writer
588                    .write_obj(refresh_threshold)
589                    .map_err(BalloonError::WriteQueue)?;
590                writer
591                    .write_obj(report_threshold)
592                    .map_err(BalloonError::WriteQueue)?;
593            }
594        }
595
596        queue.add_used(avail_desc);
597        queue.trigger_interrupt();
598    }
599
600    Ok(queue)
601}
602
603fn parse_balloon_ws(reader: &mut Reader) -> BalloonWS {
604    let mut ws = BalloonWS::new();
605    for res in reader.iter::<virtio_balloon_ws>() {
606        match res {
607            Ok(ws_msg) => {
608                ws_msg.update_ws(&mut ws);
609            }
610            Err(e) => {
611                error!("error while reading ws: {}", e);
612                break;
613            }
614        }
615    }
616    if ws.ws.len() < VIRTIO_BALLOON_WS_MIN_NUM_BINS || ws.ws.len() > VIRTIO_BALLOON_WS_MAX_NUM_BINS
617    {
618        error!("unexpected number of WS buckets: {}", ws.ws.len());
619    }
620    ws
621}
622
623// Async task that handles the stats queue. Note that the arrival of events on
624// the WS vq may be the result of either a WS request (WS-R) command having
625// been sent to the guest, or an unprompted send due to memory pressue in the
626// guest. If the data was requested, we should also send that back on the
627// command tube.
628async fn handle_ws_data_queue(
629    mut queue: Queue,
630    mut queue_event: EventAsync,
631    command_tube: &AsyncTube,
632    #[cfg(feature = "registered_events")] registered_evt_q: Option<&SendTubeAsync>,
633    state: Arc<AsyncRwLock<BalloonState>>,
634    mut stop_rx: oneshot::Receiver<()>,
635) -> Result<Queue> {
636    loop {
637        let mut avail_desc = match queue
638            .next_async_interruptable(&mut queue_event, &mut stop_rx)
639            .await
640            .map_err(BalloonError::AsyncAwait)?
641        {
642            Some(res) => res,
643            None => return Ok(queue),
644        };
645
646        let ws = parse_balloon_ws(&mut avail_desc.reader);
647
648        let mut state = state.lock().await;
649
650        // update ws report with balloon pages now that we have a lock on state
651        let balloon_actual = (state.actual_pages as u64) << VIRTIO_BALLOON_PFN_SHIFT;
652
653        if state.expecting_ws {
654            let result = BalloonTubeResult::WorkingSet { ws, balloon_actual };
655            let send_result = command_tube.send(result).await;
656            if let Err(e) = send_result {
657                error!("failed to send ws result: {}", e);
658            }
659
660            state.expecting_ws = false;
661        } else {
662            #[cfg(feature = "registered_events")]
663            if let Some(registered_evt_q) = registered_evt_q {
664                if let Err(e) = registered_evt_q
665                    .send(RegisteredEventWithData::from_ws(&ws, balloon_actual))
666                    .await
667                {
668                    error!("failed to send VirtioBalloonWSReport event: {}", e);
669                }
670            }
671        }
672
673        queue.add_used(avail_desc);
674        queue.trigger_interrupt();
675    }
676}
677
678// Async task that handles the command socket. The command socket handles messages from the host
679// requesting that the guest balloon be adjusted or to report guest memory statistics.
680async fn handle_command_tube(
681    command_tube: &AsyncTube,
682    interrupt: Interrupt,
683    state: Arc<AsyncRwLock<BalloonState>>,
684    mut stats_tx: mpsc::Sender<()>,
685    mut ws_op_tx: mpsc::Sender<WSOp>,
686    mut stop_rx: oneshot::Receiver<()>,
687) -> Result<()> {
688    loop {
689        let cmd_res = select_biased! {
690            res = command_tube.next().fuse() => res,
691            _ = stop_rx => return Ok(())
692        };
693        match cmd_res {
694            Ok(command) => match command {
695                BalloonTubeCommand::Adjust {
696                    num_bytes,
697                    allow_failure,
698                } => {
699                    let num_pages = (num_bytes >> VIRTIO_BALLOON_PFN_SHIFT) as u32;
700                    let mut state = state.lock().await;
701
702                    state.num_pages = num_pages;
703                    interrupt.signal_config_changed();
704
705                    if allow_failure {
706                        if num_pages == state.actual_pages {
707                            send_adjusted_response(command_tube, num_pages)
708                                .await
709                                .map_err(BalloonError::SendResponse)?;
710                        } else {
711                            state.failable_update = true;
712                        }
713                    }
714                }
715                BalloonTubeCommand::WorkingSetConfig {
716                    bins,
717                    refresh_threshold,
718                    report_threshold,
719                } => {
720                    if let Err(e) = ws_op_tx.try_send(WSOp::WSConfig {
721                        bins,
722                        refresh_threshold,
723                        report_threshold,
724                    }) {
725                        error!("failed to send config to ws handler: {}", e);
726                    }
727                }
728                BalloonTubeCommand::Stats => {
729                    if let Err(e) = stats_tx.try_send(()) {
730                        error!("failed to signal the stat handler: {}", e);
731                    }
732                }
733                BalloonTubeCommand::WorkingSet => {
734                    if let Err(e) = ws_op_tx.try_send(WSOp::WSReport) {
735                        error!("failed to send report request to ws handler: {}", e);
736                    }
737                }
738            },
739            #[cfg(windows)]
740            Err(base::TubeError::Recv(e)) if e.kind() == std::io::ErrorKind::TimedOut => {
741                // On Windows, async IO tasks like the next/recv above are cancelled as the VM is
742                // shutting down. For the sake of consistency with unix, we can't *just* return
743                // here; instead, we wait for the stop request to arrive, *and then* return.
744                //
745                // The real fix is to get rid of the global unblock pool, since then we won't
746                // cancel the tasks early (b/196911556).
747                let _ = stop_rx.await;
748                return Ok(());
749            }
750            Err(e) => {
751                return Err(BalloonError::ReceivingCommand(e));
752            }
753        }
754    }
755}
756
757async fn handle_pending_adjusted_responses(
758    pending_adjusted_response_event: EventAsync,
759    command_tube: &AsyncTube,
760    state: Arc<AsyncRwLock<BalloonState>>,
761) -> Result<()> {
762    loop {
763        pending_adjusted_response_event
764            .next_val()
765            .await
766            .map_err(BalloonError::AsyncAwait)?;
767        while let Some(num_pages) = state.lock().await.pending_adjusted_responses.pop_front() {
768            send_adjusted_response(command_tube, num_pages)
769                .await
770                .map_err(BalloonError::SendResponse)?;
771        }
772    }
773}
774
775/// Represents queues & events for the balloon device.
776struct BalloonQueues {
777    inflate: Queue,
778    deflate: Queue,
779    stats: Option<Queue>,
780    reporting: Option<Queue>,
781    ws_data: Option<Queue>,
782    ws_op: Option<Queue>,
783}
784
785impl BalloonQueues {
786    fn new(inflate: Queue, deflate: Queue) -> Self {
787        BalloonQueues {
788            inflate,
789            deflate,
790            stats: None,
791            reporting: None,
792            ws_data: None,
793            ws_op: None,
794        }
795    }
796}
797
798/// When the worker is stopped, the queues are preserved here.
799struct PausedQueues {
800    inflate: Queue,
801    deflate: Queue,
802    stats: Option<Queue>,
803    reporting: Option<Queue>,
804    ws_data: Option<Queue>,
805    ws_op: Option<Queue>,
806}
807
808impl PausedQueues {
809    fn new(inflate: Queue, deflate: Queue) -> Self {
810        PausedQueues {
811            inflate,
812            deflate,
813            stats: None,
814            reporting: None,
815            ws_data: None,
816            ws_op: None,
817        }
818    }
819}
820
821fn apply_if_some<F, R>(queue_opt: Option<Queue>, mut func: F)
822where
823    F: FnMut(Queue) -> R,
824{
825    if let Some(queue) = queue_opt {
826        func(queue);
827    }
828}
829
830impl From<Box<PausedQueues>> for BTreeMap<usize, Queue> {
831    fn from(queues: Box<PausedQueues>) -> BTreeMap<usize, Queue> {
832        let mut ret = Vec::new();
833        ret.push(queues.inflate);
834        ret.push(queues.deflate);
835        apply_if_some(queues.stats, |stats| ret.push(stats));
836        apply_if_some(queues.reporting, |reporting| ret.push(reporting));
837        apply_if_some(queues.ws_data, |ws_data| ret.push(ws_data));
838        apply_if_some(queues.ws_op, |ws_op| ret.push(ws_op));
839        // WARNING: We don't use the indices from the virito spec on purpose, see comment in
840        // get_queues_from_map for the rationale.
841        ret.into_iter().enumerate().collect()
842    }
843}
844
845fn free_memory(
846    vm_memory_client: &VmMemoryClient,
847    mem: &GuestMemory,
848    ranges: Vec<(GuestAddress, u64)>,
849) {
850    // If the memory is locked and device sandboxing is enabled (inferred from
851    // `use_punchhole_locked() == false`), then free the memory directly from the sandboxed
852    // process.
853    //
854    // This used to be necessary, but isn't anymore. Instead we keep it just to avoid disrupting
855    // the behavior of crosvm on ChromeOS (possible perf loss). It is likely incompatible with
856    // non-KVM hypervisors. For now it is OK because we don't have known users that do all of (1)
857    // use non-KVM hypervisor, (2) lock guest memory, and (3) enable device sandboxing.
858    #[cfg(any(target_os = "android", target_os = "linux"))]
859    if mem.locked() && !mem.use_punchhole_locked() {
860        for (guest_address, len) in ranges {
861            if let Err(e) = mem.remove_range(guest_address, len) {
862                warn!("Marking pages unused failed: {}, addr={}", e, guest_address);
863            }
864        }
865        return;
866    }
867    if let Err(e) = vm_memory_client.dynamically_free_memory_ranges(ranges) {
868        warn!("Failed to dynamically free memory ranges: {e:#}");
869    }
870}
871
872fn reclaim_memory(vm_memory_client: &VmMemoryClient, ranges: Vec<(GuestAddress, u64)>) {
873    if let Err(e) = vm_memory_client.dynamically_reclaim_memory_ranges(ranges) {
874        warn!("Failed to dynamically reclaim memory range: {e:#}");
875    }
876}
877
/// Stores data from the worker when it stops so that data can be re-used when
/// the worker is restarted.
///
/// Produced by `run_worker` and consumed by `Balloon::stop_worker`, which moves each resource
/// back into the `Balloon` device.
struct WorkerReturn {
    /// Optional tube for sending inflate ranges to CoIOMMU for unpinning.
    release_memory_tube: Option<Tube>,
    /// Balloon control tube, converted back from the worker's `AsyncTube`.
    command_tube: Tube,
    /// Optional queue for registered-event notifications (feature `registered_events`).
    #[cfg(feature = "registered_events")]
    registered_evt_q: Option<SendTube>,
    /// Queues recovered from the stopped queue futures; `None` if the worker's executor or main
    /// future failed.
    paused_queues: Option<PausedQueues>,
    /// Client used for memory free/reclaim and target-reached requests.
    vm_memory_client: VmMemoryClient,
}
888
// The main worker thread. Initializes the asynchronous worker tasks and passes them to the
// executor to be processed.
//
// The queue handlers, command tube handler, target-reached forwarder, pending-adjusted-response
// handler and kill event are raced in a single `select!`. Only the kill event is an expected way
// out; any other future completing is reported as an error. After the kill event, each future is
// told to stop through its oneshot so that its `Queue` can be recovered into `PausedQueues`.
//
// Returns a `WorkerReturn` that hands every owned resource back to the device. Its
// `paused_queues` is `None` if the executor or the main future failed.
fn run_worker(
    inflate_queue: Queue,
    deflate_queue: Queue,
    stats_queue: Option<Queue>,
    reporting_queue: Option<Queue>,
    ws_data_queue: Option<Queue>,
    ws_op_queue: Option<Queue>,
    command_tube: Tube,
    vm_memory_client: VmMemoryClient,
    mem: GuestMemory,
    release_memory_tube: Option<Tube>,
    interrupt: Interrupt,
    kill_evt: Event,
    target_reached_evt: Event,
    pending_adjusted_response_event: Event,
    state: Arc<AsyncRwLock<BalloonState>>,
    #[cfg(feature = "registered_events")] registered_evt_q: Option<SendTube>,
) -> WorkerReturn {
    let ex = Executor::new().unwrap();
    let command_tube = AsyncTube::new(&ex, command_tube).unwrap();
    // The async sender wraps a clone so that the original `registered_evt_q` can be returned in
    // `WorkerReturn` when the worker stops.
    #[cfg(feature = "registered_events")]
    let registered_evt_q_async = registered_evt_q
        .as_ref()
        .map(|q| SendTubeAsync::new(q.try_clone().unwrap(), &ex).unwrap());

    // One stop sender per future created below via `create_stop_oneshot` (the queue futures and
    // the command future). They are all signaled after the kill event fires.
    let mut stop_queue_oneshots = Vec::new();

    // We need a block to release all references to command_tube at the end before returning it.
    let paused_queues = {
        // The first queue is used for inflate messages
        let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
        let inflate_queue_evt = inflate_queue
            .event()
            .try_clone()
            .expect("failed to clone queue event");
        let inflate = handle_queue(
            inflate_queue,
            EventAsync::new(inflate_queue_evt, &ex).expect("failed to create async event"),
            release_memory_tube.as_ref(),
            |ranges| free_memory(&vm_memory_client, &mem, ranges),
            stop_rx,
        );
        let inflate = inflate.fuse();
        pin_mut!(inflate);

        // The second queue is used for deflate messages
        let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
        let deflate_queue_evt = deflate_queue
            .event()
            .try_clone()
            .expect("failed to clone queue event");
        let deflate = handle_queue(
            deflate_queue,
            EventAsync::new(deflate_queue_evt, &ex).expect("failed to create async event"),
            None,
            |ranges| reclaim_memory(&vm_memory_client, ranges),
            stop_rx,
        );
        let deflate = deflate.fuse();
        pin_mut!(deflate);

        // The next queue is used for stats messages if VIRTIO_BALLOON_F_STATS_VQ is negotiated.
        // `stats_tx` goes to the command handler and `stats_rx` to the stats queue handler.
        let (stats_tx, stats_rx) = mpsc::channel::<()>(1);
        let has_stats_queue = stats_queue.is_some();
        // Absent optional queues are replaced with a never-completing future so that the
        // `select!` below can treat every slot uniformly.
        let stats = if let Some(stats_queue) = stats_queue {
            let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
            let stats_queue_evt = stats_queue
                .event()
                .try_clone()
                .expect("failed to clone queue event");
            handle_stats_queue(
                stats_queue,
                EventAsync::new(stats_queue_evt, &ex).expect("failed to create async event"),
                stats_rx,
                &command_tube,
                #[cfg(feature = "registered_events")]
                registered_evt_q_async.as_ref(),
                state.clone(),
                stop_rx,
            )
            .left_future()
        } else {
            std::future::pending().right_future()
        };
        let stats = stats.fuse();
        pin_mut!(stats);

        // The next queue is used for reporting messages
        let has_reporting_queue = reporting_queue.is_some();
        let reporting = if let Some(reporting_queue) = reporting_queue {
            let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
            let reporting_queue_evt = reporting_queue
                .event()
                .try_clone()
                .expect("failed to clone queue event");
            handle_reporting_queue(
                reporting_queue,
                EventAsync::new(reporting_queue_evt, &ex).expect("failed to create async event"),
                release_memory_tube.as_ref(),
                |ranges| free_memory(&vm_memory_client, &mem, ranges),
                stop_rx,
            )
            .left_future()
        } else {
            std::future::pending().right_future()
        };
        let reporting = reporting.fuse();
        pin_mut!(reporting);

        // If VIRTIO_BALLOON_F_WS_REPORTING is set, 2 queues must be handled - one for WS data and
        // one for WS notifications.
        let has_ws_data_queue = ws_data_queue.is_some();
        let ws_data = if let Some(ws_data_queue) = ws_data_queue {
            let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
            let ws_data_queue_evt = ws_data_queue
                .event()
                .try_clone()
                .expect("failed to clone queue event");
            handle_ws_data_queue(
                ws_data_queue,
                EventAsync::new(ws_data_queue_evt, &ex).expect("failed to create async event"),
                &command_tube,
                #[cfg(feature = "registered_events")]
                registered_evt_q_async.as_ref(),
                state.clone(),
                stop_rx,
            )
            .left_future()
        } else {
            std::future::pending().right_future()
        };
        let ws_data = ws_data.fuse();
        pin_mut!(ws_data);

        // `ws_op_tx` goes to the command handler and `ws_op_rx` to the WS op queue handler.
        let (ws_op_tx, ws_op_rx) = mpsc::channel::<WSOp>(1);
        let has_ws_op_queue = ws_op_queue.is_some();
        let ws_op = if let Some(ws_op_queue) = ws_op_queue {
            let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
            let ws_op_queue_evt = ws_op_queue
                .event()
                .try_clone()
                .expect("failed to clone queue event");
            handle_ws_op_queue(
                ws_op_queue,
                EventAsync::new(ws_op_queue_evt, &ex).expect("failed to create async event"),
                ws_op_rx,
                state.clone(),
                stop_rx,
            )
            .left_future()
        } else {
            std::future::pending().right_future()
        };
        let ws_op = ws_op.fuse();
        pin_mut!(ws_op);

        // Future to handle command messages that resize the balloon.
        let stop_rx = create_stop_oneshot(&mut stop_queue_oneshots);
        let command = handle_command_tube(
            &command_tube,
            interrupt.clone(),
            state.clone(),
            stats_tx,
            ws_op_tx,
            stop_rx,
        );
        pin_mut!(command);

        // Send a message if balloon target reached event is triggered.
        let target_reached = handle_target_reached(&ex, target_reached_evt, &vm_memory_client);
        pin_mut!(target_reached);

        // Exit if the kill event is triggered.
        let kill = async_utils::await_and_exit(&ex, kill_evt);
        pin_mut!(kill);

        let pending_adjusted = handle_pending_adjusted_responses(
            EventAsync::new(pending_adjusted_response_event, &ex)
                .expect("failed to create async event"),
            &command_tube,
            state,
        );
        pin_mut!(pending_adjusted);

        let res = ex.run_until(async {
            // Race every future. Only the kill event is a normal shutdown; anything else
            // finishing first is treated as a failure and no queues are recovered.
            select! {
                _ = kill.fuse() => (),
                _ = inflate => return Err(anyhow!("inflate stopped unexpectedly")),
                _ = deflate => return Err(anyhow!("deflate stopped unexpectedly")),
                _ = stats => return Err(anyhow!("stats stopped unexpectedly")),
                _ = reporting => return Err(anyhow!("reporting stopped unexpectedly")),
                _ = command.fuse() => return Err(anyhow!("command stopped unexpectedly")),
                _ = ws_op => return Err(anyhow!("ws_op stopped unexpectedly")),
                _ = pending_adjusted.fuse() => return Err(anyhow!("pending_adjusted stopped unexpectedly")),
                _ = ws_data => return Err(anyhow!("ws_data stopped unexpectedly")),
                _ = target_reached.fuse() => return Err(anyhow!("target_reached stopped unexpectedly")),
            }

            // Worker is shutting down. To recover the queues, we have to signal
            // all the queue futures to exit.
            for stop_tx in stop_queue_oneshots {
                if stop_tx.send(()).is_err() {
                    return Err(anyhow!("failed to request stop for queue future"));
                }
            }

            // Collect all the queues (awaiting any queue future should now
            // return its Queue immediately). The `has_*` flags ensure that the `pending()`
            // placeholders of absent queues are never awaited.
            let mut paused_queues = PausedQueues::new(
                inflate.await,
                deflate.await,
            );
            if has_reporting_queue {
                paused_queues.reporting = Some(reporting.await);
            }
            if has_stats_queue {
                paused_queues.stats = Some(stats.await);
            }
            // A WS queue future that fails to stop propagates its error via `?`, so every queue
            // collected so far is dropped and `paused_queues` ends up `None`.
            if has_ws_data_queue {
                paused_queues.ws_data = Some(ws_data.await.context("failed to stop ws_data queue")?);
            }
            if has_ws_op_queue {
                paused_queues.ws_op = Some(ws_op.await.context("failed to stop ws_op queue")?);
            }
            Ok(paused_queues)
        });

        // The outer `Result` is the executor's, the inner one is the main future's.
        match res {
            Err(e) => {
                error!("error happened in executor: {}", e);
                None
            }
            Ok(main_future_res) => match main_future_res {
                Ok(paused_queues) => Some(paused_queues),
                Err(e) => {
                    error!("error happened in main balloon future: {}", e);
                    None
                }
            },
        }
    };

    WorkerReturn {
        command_tube: command_tube.into(),
        paused_queues,
        release_memory_tube,
        #[cfg(feature = "registered_events")]
        registered_evt_q,
        vm_memory_client,
    }
}
1142
1143async fn handle_target_reached(
1144    ex: &Executor,
1145    target_reached_evt: Event,
1146    vm_memory_client: &VmMemoryClient,
1147) -> anyhow::Result<()> {
1148    let event_async =
1149        EventAsync::new(target_reached_evt, ex).context("failed to create EventAsync")?;
1150    loop {
1151        // Wait for target reached trigger.
1152        let _ = event_async.next_val().await;
1153        // Send the message to vm_control on the event. We don't have to read the current
1154        // size yet.
1155        if let Err(e) = vm_memory_client.balloon_target_reached(0) {
1156            warn!("Failed to send or receive allocation complete request: {e:#}");
1157        }
1158    }
1159    // The above loop will never terminate and there is no reason to terminate it either. However,
1160    // the function is used in an executor that expects a Result<> return. Make sure that clippy
1161    // doesn't enforce the unreachable_code condition.
1162    #[allow(unreachable_code)]
1163    Ok(())
1164}
1165
/// Virtio device for memory balloon inflation/deflation.
///
/// Resources the worker thread needs (`command_tube`, `vm_memory_client`,
/// `release_memory_tube`, `registered_evt_q`) are moved into it by `start_worker` and moved back
/// by `stop_worker`.
pub struct Balloon {
    /// Balloon control tube; `None` while the worker thread owns it.
    command_tube: Option<Tube>,
    /// Client for memory free/reclaim and target-reached requests; `None` while the worker
    /// thread owns it.
    vm_memory_client: Option<VmMemoryClient>,
    /// Optional tube for sending inflate ranges to CoIOMMU for unpinning.
    release_memory_tube: Option<Tube>,
    /// Signaled by `write_config` after an entry is pushed to `pending_adjusted_responses`.
    pending_adjusted_response_event: Event,
    /// Balloon state shared with the worker's async tasks.
    state: Arc<AsyncRwLock<BalloonState>>,
    /// Features offered to the driver.
    features: u64,
    /// Features acknowledged by the driver (masked to `features` in `ack_features`).
    acked_features: u64,
    /// Running worker thread, if the device is active.
    worker_thread: Option<WorkerThread<WorkerReturn>>,
    /// Optional queue for registered-event notifications.
    #[cfg(feature = "registered_events")]
    registered_evt_q: Option<SendTube>,
    /// Number of working-set bins reported to the guest in the config space.
    ws_num_bins: u8,
    /// Device-side handle of the event signaled when the guest reaches the target size; set by
    /// `start_worker`.
    target_reached_evt: Option<Event>,
    /// Maximum size of each queue, in the order the queues are exposed.
    queue_sizes: Vec<u16>,
}
1182
/// Snapshot of the [Balloon] state.
///
/// Produced by `virtio_snapshot` and applied by `virtio_restore`.
#[derive(Serialize, Deserialize)]
struct BalloonSnapshot {
    /// Copy of the shared balloon state (page counts, pending adjusted responses, etc.).
    state: BalloonState,
    /// Offered features; restore fails unless these match the live device's features.
    features: u64,
    /// Driver-acknowledged features, restored verbatim.
    acked_features: u64,
    /// Working-set bin count, restored verbatim.
    ws_num_bins: u8,
}
1191
1192impl Balloon {
1193    /// Creates a new virtio balloon device.
1194    /// To let Balloon able to successfully release the memory which are pinned
1195    /// by CoIOMMU to host, the release_memory_tube will be used to send the inflate
1196    /// ranges to CoIOMMU with UnpinRequest/UnpinResponse messages, so that The
1197    /// memory in the inflate range can be unpinned first.
1198    pub fn new(
1199        base_features: u64,
1200        command_tube: Tube,
1201        vm_memory_client: VmMemoryClient,
1202        release_memory_tube: Option<Tube>,
1203        init_balloon_size: u64,
1204        enabled_features: u64,
1205        #[cfg(feature = "registered_events")] registered_evt_q: Option<SendTube>,
1206        ws_num_bins: u8,
1207    ) -> Result<Balloon> {
1208        let features = base_features
1209            | 1 << VIRTIO_BALLOON_F_MUST_TELL_HOST
1210            | 1 << VIRTIO_BALLOON_F_STATS_VQ
1211            | 1 << VIRTIO_BALLOON_F_DEFLATE_ON_OOM
1212            | enabled_features;
1213
1214        let mut queue_sizes = Vec::new();
1215        queue_sizes.push(QUEUE_SIZE); // inflateq
1216        queue_sizes.push(QUEUE_SIZE); // deflateq
1217        if features & (1 << VIRTIO_BALLOON_F_STATS_VQ) != 0 {
1218            queue_sizes.push(QUEUE_SIZE); // statsq
1219        }
1220        if features & (1 << VIRTIO_BALLOON_F_PAGE_REPORTING) != 0 {
1221            queue_sizes.push(QUEUE_SIZE); // reporting_vq
1222        }
1223        if features & (1 << VIRTIO_BALLOON_F_WS_REPORTING) != 0 {
1224            queue_sizes.push(QUEUE_SIZE); // ws_data
1225            queue_sizes.push(QUEUE_SIZE); // ws_cmd
1226        }
1227
1228        Ok(Balloon {
1229            command_tube: Some(command_tube),
1230            vm_memory_client: Some(vm_memory_client),
1231            release_memory_tube,
1232            pending_adjusted_response_event: Event::new().map_err(BalloonError::CreatingEvent)?,
1233            state: Arc::new(AsyncRwLock::new(BalloonState {
1234                num_pages: (init_balloon_size >> VIRTIO_BALLOON_PFN_SHIFT) as u32,
1235                actual_pages: 0,
1236                failable_update: false,
1237                pending_adjusted_responses: VecDeque::new(),
1238                expecting_ws: false,
1239            })),
1240            worker_thread: None,
1241            features,
1242            acked_features: 0,
1243            #[cfg(feature = "registered_events")]
1244            registered_evt_q,
1245            ws_num_bins,
1246            target_reached_evt: None,
1247            queue_sizes,
1248        })
1249    }
1250
1251    fn get_config(&self) -> virtio_balloon_config {
1252        let state = block_on(self.state.lock());
1253        virtio_balloon_config {
1254            num_pages: state.num_pages.into(),
1255            actual: state.actual_pages.into(),
1256            // crosvm does not (currently) use free_page_hint_cmd_id or
1257            // poison_val, but they must be present in the right order and size
1258            // for the virtio-balloon driver in the guest to deserialize the
1259            // config correctly.
1260            free_page_hint_cmd_id: 0.into(),
1261            poison_val: 0.into(),
1262            ws_num_bins: self.ws_num_bins,
1263            _reserved: [0, 0, 0],
1264        }
1265    }
1266
1267    fn stop_worker(&mut self) -> StoppedWorker<PausedQueues> {
1268        if let Some(worker_thread) = self.worker_thread.take() {
1269            let worker_ret = worker_thread.stop();
1270            self.release_memory_tube = worker_ret.release_memory_tube;
1271            self.command_tube = Some(worker_ret.command_tube);
1272            #[cfg(feature = "registered_events")]
1273            {
1274                self.registered_evt_q = worker_ret.registered_evt_q;
1275            }
1276            self.vm_memory_client = Some(worker_ret.vm_memory_client);
1277
1278            if let Some(queues) = worker_ret.paused_queues {
1279                StoppedWorker::WithQueues(Box::new(queues))
1280            } else {
1281                StoppedWorker::MissingQueues
1282            }
1283        } else {
1284            StoppedWorker::AlreadyStopped
1285        }
1286    }
1287
1288    /// Given a filtered queue vector from [VirtioDevice::activate], extract
1289    /// the queues (accounting for queues that are missing because the features
1290    /// are not negotiated) into a structure that is easier to work with.
1291    fn get_queues_from_map(
1292        &self,
1293        mut queues: BTreeMap<usize, Queue>,
1294    ) -> anyhow::Result<BalloonQueues> {
1295        let inflate_queue = queues.remove(&INFLATEQ).context("missing inflateq")?;
1296        let deflate_queue = queues.remove(&DEFLATEQ).context("missing deflateq")?;
1297        let mut queue_struct = BalloonQueues::new(inflate_queue, deflate_queue);
1298
1299        // Queues whose existence depends on advertised features start at queue index 2.
1300        let mut next_queue_index = 2;
1301        let mut next_queue = || {
1302            let idx = next_queue_index;
1303            next_queue_index += 1;
1304            idx
1305        };
1306
1307        if self.features & (1 << VIRTIO_BALLOON_F_STATS_VQ) != 0 {
1308            let statsq = next_queue();
1309            if self.acked_features & (1 << VIRTIO_BALLOON_F_STATS_VQ) != 0 {
1310                queue_struct.stats = Some(queues.remove(&statsq).context("missing statsq")?);
1311            }
1312        }
1313
1314        if self.features & (1 << VIRTIO_BALLOON_F_PAGE_REPORTING) != 0 {
1315            let reporting_vq = next_queue();
1316            if self.acked_features & (1 << VIRTIO_BALLOON_F_PAGE_REPORTING) != 0 {
1317                queue_struct.reporting = Some(
1318                    queues
1319                        .remove(&reporting_vq)
1320                        .context("missing reporting_vq")?,
1321                );
1322            }
1323        }
1324
1325        if self.features & (1 << VIRTIO_BALLOON_F_WS_REPORTING) != 0 {
1326            let ws_data_vq = next_queue();
1327            let ws_op_vq = next_queue();
1328            if self.acked_features & (1 << VIRTIO_BALLOON_F_WS_REPORTING) != 0 {
1329                queue_struct.ws_data =
1330                    Some(queues.remove(&ws_data_vq).context("missing ws_data_vq")?);
1331                queue_struct.ws_op = Some(queues.remove(&ws_op_vq).context("missing ws_op_vq")?);
1332            }
1333        }
1334
1335        if !queues.is_empty() {
1336            return Err(anyhow!("unexpected queues {:?}", queues.into_keys()));
1337        }
1338
1339        Ok(queue_struct)
1340    }
1341
1342    fn start_worker(
1343        &mut self,
1344        mem: GuestMemory,
1345        interrupt: Interrupt,
1346        queues: BalloonQueues,
1347    ) -> anyhow::Result<()> {
1348        let (self_target_reached_evt, target_reached_evt) = Event::new()
1349            .and_then(|e| Ok((e.try_clone()?, e)))
1350            .context("failed to create target_reached Event pair: {}")?;
1351        self.target_reached_evt = Some(self_target_reached_evt);
1352
1353        let state = self.state.clone();
1354
1355        let command_tube = self.command_tube.take().unwrap();
1356
1357        let vm_memory_client = self.vm_memory_client.take().unwrap();
1358        let release_memory_tube = self.release_memory_tube.take();
1359        #[cfg(feature = "registered_events")]
1360        let registered_evt_q = self.registered_evt_q.take();
1361        let pending_adjusted_response_event = self
1362            .pending_adjusted_response_event
1363            .try_clone()
1364            .context("failed to clone Event")?;
1365
1366        self.worker_thread = Some(WorkerThread::start("v_balloon", move |kill_evt| {
1367            run_worker(
1368                queues.inflate,
1369                queues.deflate,
1370                queues.stats,
1371                queues.reporting,
1372                queues.ws_data,
1373                queues.ws_op,
1374                command_tube,
1375                vm_memory_client,
1376                mem,
1377                release_memory_tube,
1378                interrupt,
1379                kill_evt,
1380                target_reached_evt,
1381                pending_adjusted_response_event,
1382                state,
1383                #[cfg(feature = "registered_events")]
1384                registered_evt_q,
1385            )
1386        }));
1387
1388        Ok(())
1389    }
1390}
1391
1392impl VirtioDevice for Balloon {
1393    fn keep_rds(&self) -> Vec<RawDescriptor> {
1394        let mut rds = Vec::new();
1395        if let Some(command_tube) = &self.command_tube {
1396            rds.push(command_tube.as_raw_descriptor());
1397        }
1398        if let Some(vm_memory_client) = &self.vm_memory_client {
1399            rds.push(vm_memory_client.as_raw_descriptor());
1400        }
1401        if let Some(release_memory_tube) = &self.release_memory_tube {
1402            rds.push(release_memory_tube.as_raw_descriptor());
1403        }
1404        #[cfg(feature = "registered_events")]
1405        if let Some(registered_evt_q) = &self.registered_evt_q {
1406            rds.push(registered_evt_q.as_raw_descriptor());
1407        }
1408        rds.push(self.pending_adjusted_response_event.as_raw_descriptor());
1409        rds
1410    }
1411
1412    fn device_type(&self) -> DeviceType {
1413        DeviceType::Balloon
1414    }
1415
1416    fn queue_max_sizes(&self) -> &[u16] {
1417        &self.queue_sizes
1418    }
1419
1420    fn read_config(&self, offset: u64, data: &mut [u8]) {
1421        copy_config(data, 0, self.get_config().as_bytes(), offset);
1422    }
1423
1424    fn write_config(&mut self, offset: u64, data: &[u8]) {
1425        let mut config = self.get_config();
1426        copy_config(config.as_mut_bytes(), offset, data, 0);
1427        let mut state = block_on(self.state.lock());
1428        state.actual_pages = config.actual.to_native();
1429
1430        // If balloon has updated to the requested memory, let the hypervisor know.
1431        if config.num_pages == config.actual {
1432            debug!(
1433                "sending target reached event at {}",
1434                u32::from(config.num_pages)
1435            );
1436            self.target_reached_evt.as_ref().map(|e| e.signal());
1437        }
1438        if state.failable_update && state.actual_pages == state.num_pages {
1439            state.failable_update = false;
1440            let num_pages = state.num_pages;
1441            state.pending_adjusted_responses.push_back(num_pages);
1442            let _ = self.pending_adjusted_response_event.signal();
1443        }
1444    }
1445
1446    fn features(&self) -> u64 {
1447        self.features
1448    }
1449
1450    fn ack_features(&mut self, mut value: u64) {
1451        if value & !self.features != 0 {
1452            warn!("virtio_balloon got unknown feature ack {:x}", value);
1453            value &= self.features;
1454        }
1455        self.acked_features |= value;
1456    }
1457
1458    fn activate(
1459        &mut self,
1460        mem: GuestMemory,
1461        interrupt: Interrupt,
1462        queues: BTreeMap<usize, Queue>,
1463    ) -> anyhow::Result<()> {
1464        let queues = self.get_queues_from_map(queues)?;
1465        self.start_worker(mem, interrupt, queues)
1466    }
1467
1468    fn reset(&mut self) -> anyhow::Result<()> {
1469        let _worker = self.stop_worker();
1470        Ok(())
1471    }
1472
1473    fn virtio_sleep(&mut self) -> anyhow::Result<Option<BTreeMap<usize, Queue>>> {
1474        match self.stop_worker() {
1475            StoppedWorker::WithQueues(paused_queues) => Ok(Some(paused_queues.into())),
1476            StoppedWorker::MissingQueues => {
1477                anyhow::bail!("balloon queue workers did not stop cleanly.")
1478            }
1479            StoppedWorker::AlreadyStopped => {
1480                // Device hasn't been activated.
1481                Ok(None)
1482            }
1483        }
1484    }
1485
1486    fn virtio_wake(
1487        &mut self,
1488        queues_state: Option<(GuestMemory, Interrupt, BTreeMap<usize, Queue>)>,
1489    ) -> anyhow::Result<()> {
1490        if let Some((mem, interrupt, queues)) = queues_state {
1491            if queues.len() < 2 {
1492                anyhow::bail!("{} queues were found, but an activated balloon must have at least 2 active queues.", queues.len());
1493            }
1494
1495            let balloon_queues = self.get_queues_from_map(queues)?;
1496            self.start_worker(mem, interrupt, balloon_queues)?;
1497        }
1498        Ok(())
1499    }
1500
1501    fn virtio_snapshot(&mut self) -> anyhow::Result<AnySnapshot> {
1502        let state = self
1503            .state
1504            .lock()
1505            .now_or_never()
1506            .context("failed to acquire balloon lock")?;
1507        AnySnapshot::to_any(BalloonSnapshot {
1508            features: self.features,
1509            acked_features: self.acked_features,
1510            state: state.clone(),
1511            ws_num_bins: self.ws_num_bins,
1512        })
1513        .context("failed to serialize balloon state")
1514    }
1515
1516    fn virtio_restore(&mut self, data: AnySnapshot) -> anyhow::Result<()> {
1517        let snap: BalloonSnapshot = AnySnapshot::from_any(data).context("error deserializing")?;
1518        if snap.features != self.features {
1519            anyhow::bail!(
1520                "balloon: expected features to match, but they did not. Live: {:?}, snapshot {:?}",
1521                self.features,
1522                snap.features,
1523            );
1524        }
1525
1526        let mut state = self
1527            .state
1528            .lock()
1529            .now_or_never()
1530            .context("failed to acquire balloon lock")?;
1531        *state = snap.state;
1532        self.ws_num_bins = snap.ws_num_bins;
1533        self.acked_features = snap.acked_features;
1534        Ok(())
1535    }
1536}
1537
#[cfg(test)]
mod tests {
    use super::*;
    use crate::suspendable_virtio_tests;
    use crate::virtio::descriptor_utils::create_descriptor_chain;
    use crate::virtio::descriptor_utils::DescriptorType;

    #[test]
    fn desc_parsing_inflate() {
        // `handle_address_chain` must decode every 32-bit PFN in a readable inflate descriptor
        // into a guest address and hand the resulting ranges to the closure.
        let pfns = [0x10u32, 0xaa55aa55u32];
        let guest_mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
        for (slot, &pfn) in pfns.iter().enumerate() {
            let addr = GuestAddress(0x100 + 4 * slot as u64);
            guest_mem.write_obj_at_addr(pfn, addr).unwrap();
        }

        let mut chain = create_descriptor_chain(
            &guest_mem,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![(DescriptorType::Readable, 8)],
            0,
        )
        .expect("create_descriptor_chain failed");

        let mut seen = Vec::new();
        let outcome = handle_address_chain(None, &mut chain, &mut |mut ranges| {
            seen.append(&mut ranges)
        });
        assert!(outcome.is_ok());
        assert_eq!(seen.len(), pfns.len());
        for (range, &pfn) in seen.iter().zip(pfns.iter()) {
            assert_eq!(
                range.0,
                GuestAddress(u64::from(pfn) << VIRTIO_BALLOON_PFN_SHIFT)
            );
        }
    }

    /// Keeps the host ends of the device's tubes alive for the duration of a test.
    struct BalloonContext {
        _ctrl_tube: Tube,
        _mem_client_tube: Tube,
    }

    fn modify_device(_balloon_context: &mut BalloonContext, balloon: &mut Balloon) {
        // Inverting every bit guarantees a value different from the current one.
        balloon.ws_num_bins = !balloon.ws_num_bins;
    }

    fn create_device() -> (BalloonContext, Balloon) {
        let (ctrl_host, ctrl_device) = Tube::pair().unwrap();
        let (mem_host, mem_device) = Tube::pair().unwrap();
        let balloon = Balloon::new(
            0,
            ctrl_device,
            VmMemoryClient::new(mem_device),
            None,
            1024,
            0,
            #[cfg(feature = "registered_events")]
            None,
            0,
        )
        .unwrap();
        let context = BalloonContext {
            _ctrl_tube: ctrl_host,
            _mem_client_tube: mem_host,
        };
        (context, balloon)
    }

    suspendable_virtio_tests!(balloon, create_device, 2, modify_device);
}