devices/virtio/vhost_user_backend/
snd.rs

1// Copyright 2021 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5pub mod sys;
6
7use std::borrow::Borrow;
8use std::rc::Rc;
9
10use anyhow::anyhow;
11use anyhow::bail;
12use anyhow::Context;
13use base::error;
14use base::warn;
15use cros_async::sync::RwLock as AsyncRwLock;
16use cros_async::EventAsync;
17use cros_async::Executor;
18use futures::channel::mpsc;
19use futures::FutureExt;
20use hypervisor::ProtectionType;
21use serde::Deserialize;
22use serde::Serialize;
23use snapshot::AnySnapshot;
24pub use sys::run_snd_device;
25pub use sys::Options;
26use vm_memory::GuestMemory;
27use vmm_vhost::message::VhostUserProtocolFeatures;
28use vmm_vhost::VHOST_USER_F_PROTOCOL_FEATURES;
29use zerocopy::IntoBytes;
30
31use crate::virtio;
32use crate::virtio::copy_config;
33use crate::virtio::device_constants::snd::virtio_snd_config;
34use crate::virtio::snd::common_backend::async_funcs::handle_ctrl_queue;
35use crate::virtio::snd::common_backend::async_funcs::handle_pcm_queue;
36use crate::virtio::snd::common_backend::async_funcs::send_pcm_response_worker;
37use crate::virtio::snd::common_backend::create_stream_info_builders;
38use crate::virtio::snd::common_backend::hardcoded_snd_data;
39use crate::virtio::snd::common_backend::hardcoded_virtio_snd_config;
40use crate::virtio::snd::common_backend::stream_info::StreamInfo;
41use crate::virtio::snd::common_backend::stream_info::StreamInfoBuilder;
42use crate::virtio::snd::common_backend::stream_info::StreamInfoSnapshot;
43use crate::virtio::snd::common_backend::Error;
44use crate::virtio::snd::common_backend::PcmResponse;
45use crate::virtio::snd::common_backend::SndData;
46use crate::virtio::snd::common_backend::MAX_QUEUE_NUM;
47use crate::virtio::snd::constants::VIRTIO_SND_R_PCM_PREPARE;
48use crate::virtio::snd::constants::VIRTIO_SND_R_PCM_START;
49use crate::virtio::snd::parameters::Parameters;
50use crate::virtio::vhost_user_backend::handler::DeviceRequestHandler;
51use crate::virtio::vhost_user_backend::handler::Error as DeviceError;
52use crate::virtio::vhost_user_backend::handler::VhostUserDevice;
53use crate::virtio::vhost_user_backend::handler::WorkerState;
54use crate::virtio::vhost_user_backend::VhostUserDeviceBuilder;
55use crate::virtio::Queue;
56
// Queue index -> async worker handling it:
//   0: control queue
//   1: event queue
//   2: tx PCM queue
//   3: rx PCM queue
/// Queue index of the first PCM queue (tx). Subtracting it from a PCM queue index (2 or 3)
/// yields the matching slot (0 or 1) in `SndBackend::response_workers`.
const PCM_RESPONSE_WORKER_IDX_OFFSET: usize = 2;
/// vhost-user backend state for a single virtio-snd card.
///
/// Fields are dropped in declaration order, so reordering them changes the order in which the
/// executor, queue workers and PCM channels are torn down.
struct SndBackend {
    /// Executor used to spawn the per-queue tasks and to wait for their cancellation.
    ex: Executor,
    /// virtio-snd config space, served through `read_config`.
    cfg: virtio_snd_config,
    /// Feature bits reported by `features()`; includes `VHOST_USER_F_PROTOCOL_FEATURES`.
    avail_features: u64,
    /// Running task and its queue for each queue index (0 = ctrl, 1 = event, 2 = tx, 3 = rx).
    workers: [Option<WorkerState<Rc<AsyncRwLock<Queue>>, Result<(), Error>>>; MAX_QUEUE_NUM],
    /// PCM response tasks for tx and rx, indexed by `queue index - PCM_RESPONSE_WORKER_IDX_OFFSET`.
    response_workers: [Option<WorkerState<Rc<AsyncRwLock<Queue>>, Result<(), Error>>>; 2],
    /// Sound device description; compared against the snapshot on restore.
    snd_data: Rc<SndData>,
    /// Per-stream state shared with the ctrl and PCM queue tasks.
    streams: Rc<AsyncRwLock<Vec<AsyncRwLock<StreamInfo>>>>,
    /// Sender for tx PCM responses; cloned into the ctrl handler and the tx (queue 2) handler.
    tx_send: mpsc::UnboundedSender<PcmResponse>,
    /// Sender for rx PCM responses; cloned into the ctrl handler and the rx (queue 3) handler.
    rx_send: mpsc::UnboundedSender<PcmResponse>,
    /// Receiver for tx PCM responses. It is moved into the tx response task the first time
    /// queue 2 is started; once taken, starting queue 2 again returns an error.
    tx_recv: Option<mpsc::UnboundedReceiver<PcmResponse>>,
    /// Receiver for rx PCM responses. It is moved into the rx response task the first time
    /// queue 3 is started; once taken, starting queue 3 again returns an error.
    rx_recv: Option<mpsc::UnboundedReceiver<PcmResponse>>,
    /// Appended to logs for when there are multiple audio devices.
    card_index: usize,
    /// Value reported by `unmap_guest_memory_on_fork()`; always false outside Android/Linux.
    unmap_guest_memory_on_fork: bool,
}
80
/// Device state produced by `SndBackend::snapshot` and consumed by `SndBackend::restore`.
///
/// Serde derives (de)serialize fields in declaration order; keep the order stable, since
/// positional snapshot formats may depend on it (NOTE(review): confirm for `AnySnapshot`).
#[derive(Serialize, Deserialize)]
struct SndBackendSnapshot {
    /// Feature bits at snapshot time; `restore` requires them to equal the device's bits.
    avail_features: u64,
    /// Per-stream snapshots, or `None` if the streams lock could not be taken immediately
    /// while snapshotting.
    stream_infos: Option<Vec<StreamInfoSnapshot>>,
    /// Sound device description; `restore` requires it to equal the device's description.
    snd_data: SndData,
}
87
88impl SndBackend {
89    pub fn new(
90        ex: &Executor,
91        params: Parameters,
92        #[cfg(windows)] audio_client_guid: Option<String>,
93        card_index: usize,
94    ) -> anyhow::Result<Self> {
95        let cfg = hardcoded_virtio_snd_config(&params);
96        let avail_features = virtio::base_features(ProtectionType::Unprotected)
97            | 1 << VHOST_USER_F_PROTOCOL_FEATURES;
98
99        let snd_data = hardcoded_snd_data(&params);
100        let mut keep_rds = Vec::new();
101        let builders = create_stream_info_builders(&params, &snd_data, &mut keep_rds, card_index)?;
102
103        if snd_data.pcm_info_len() != builders.len() {
104            error!(
105                "[Card {}] snd: expected {} stream info builders, got {}",
106                card_index,
107                snd_data.pcm_info_len(),
108                builders.len(),
109            )
110        }
111
112        let streams = builders.into_iter();
113
114        #[cfg(windows)]
115        let streams = streams
116            .map(|stream_builder| stream_builder.audio_client_guid(audio_client_guid.clone()));
117
118        let streams = streams
119            .map(StreamInfoBuilder::build)
120            .map(AsyncRwLock::new)
121            .collect();
122        let streams = Rc::new(AsyncRwLock::new(streams));
123
124        let (tx_send, tx_recv) = mpsc::unbounded();
125        let (rx_send, rx_recv) = mpsc::unbounded();
126
127        #[cfg(any(target_os = "android", target_os = "linux"))]
128        let unmap_guest_memory_on_fork = params.unmap_guest_memory_on_fork;
129        #[cfg(not(any(target_os = "android", target_os = "linux")))]
130        let unmap_guest_memory_on_fork = false;
131
132        Ok(SndBackend {
133            ex: ex.clone(),
134            cfg,
135            avail_features,
136            workers: Default::default(),
137            response_workers: Default::default(),
138            snd_data: Rc::new(snd_data),
139            streams,
140            tx_send,
141            rx_send,
142            tx_recv: Some(tx_recv),
143            rx_recv: Some(rx_recv),
144            card_index,
145            unmap_guest_memory_on_fork,
146        })
147    }
148}
149
150impl VhostUserDeviceBuilder for SndBackend {
151    fn build(self: Box<Self>, _ex: &Executor) -> anyhow::Result<Box<dyn vmm_vhost::Backend>> {
152        let handler = DeviceRequestHandler::new(*self);
153        Ok(Box::new(handler))
154    }
155}
156
157impl VhostUserDevice for SndBackend {
158    fn max_queue_num(&self) -> usize {
159        MAX_QUEUE_NUM
160    }
161
162    fn features(&self) -> u64 {
163        self.avail_features
164    }
165
166    fn protocol_features(&self) -> VhostUserProtocolFeatures {
167        VhostUserProtocolFeatures::CONFIG
168            | VhostUserProtocolFeatures::MQ
169            | VhostUserProtocolFeatures::DEVICE_STATE
170    }
171
172    fn unmap_guest_memory_on_fork(&self) -> bool {
173        self.unmap_guest_memory_on_fork
174    }
175
176    fn read_config(&self, offset: u64, data: &mut [u8]) {
177        copy_config(data, 0, self.cfg.as_bytes(), offset)
178    }
179
180    fn reset(&mut self) {
181        for worker in self.workers.iter_mut().filter_map(Option::take) {
182            let _ = self.ex.run_until(worker.queue_task.cancel());
183        }
184    }
185
186    fn start_queue(
187        &mut self,
188        idx: usize,
189        queue: virtio::Queue,
190        _mem: GuestMemory,
191    ) -> anyhow::Result<()> {
192        if self.workers[idx].is_some() {
193            warn!(
194                "[Card {}] Starting new queue handler without stopping old handler",
195                self.card_index
196            );
197            self.stop_queue(idx)?;
198        }
199
200        let kick_evt = queue
201            .event()
202            .try_clone()
203            .with_context(|| format!("[Card {}] failed to clone queue event", self.card_index))?;
204        let mut kick_evt = EventAsync::new(kick_evt, &self.ex).with_context(|| {
205            format!(
206                "[Card {}] failed to create EventAsync for kick_evt",
207                self.card_index
208            )
209        })?;
210        let queue = Rc::new(AsyncRwLock::new(queue));
211        let card_index = self.card_index;
212        let queue_task = match idx {
213            0 => {
214                // ctrl queue
215                let streams = self.streams.clone();
216                let snd_data = self.snd_data.clone();
217                let tx_send = self.tx_send.clone();
218                let rx_send = self.rx_send.clone();
219                let ctrl_queue = queue.clone();
220
221                let ex_clone = self.ex.clone();
222                Some(self.ex.spawn_local(async move {
223                    handle_ctrl_queue(
224                        &ex_clone,
225                        &streams,
226                        &snd_data,
227                        ctrl_queue,
228                        &mut kick_evt,
229                        tx_send,
230                        rx_send,
231                        card_index,
232                        None,
233                    )
234                    .await
235                }))
236            }
237            // TODO(woodychow): Add event queue support
238            //
239            // Note: Even though we don't support the event queue, we still need to keep track of
240            // the Queue so we can return it back in stop_queue. As such, we create a do nothing
241            // future to "run" this queue so that we track a WorkerState for it (which is how
242            // we return the Queue back).
243            1 => Some(self.ex.spawn_local(async move { Ok(()) })),
244            2 | 3 => {
245                let (send, recv) = if idx == 2 {
246                    (self.tx_send.clone(), self.tx_recv.take())
247                } else {
248                    (self.rx_send.clone(), self.rx_recv.take())
249                };
250                let mut recv = recv.ok_or_else(|| {
251                    anyhow!("[Card {}] queue restart is not supported", self.card_index)
252                })?;
253                let streams = Rc::clone(&self.streams);
254                let queue_pcm_queue = queue.clone();
255                let queue_task = self.ex.spawn_local(async move {
256                    handle_pcm_queue(&streams, send, queue_pcm_queue, &kick_evt, card_index, None)
257                        .await
258                });
259
260                let queue_response_queue = queue.clone();
261                let response_queue_task = self.ex.spawn_local(async move {
262                    send_pcm_response_worker(queue_response_queue, &mut recv, None).await
263                });
264
265                self.response_workers[idx - PCM_RESPONSE_WORKER_IDX_OFFSET] = Some(WorkerState {
266                    queue_task: response_queue_task,
267                    queue: queue.clone(),
268                });
269
270                Some(queue_task)
271            }
272            _ => bail!(
273                "[Card {}] attempted to start unknown queue: {}",
274                self.card_index,
275                idx
276            ),
277        };
278
279        if let Some(queue_task) = queue_task {
280            self.workers[idx] = Some(WorkerState { queue_task, queue });
281        }
282        Ok(())
283    }
284
285    fn stop_queue(&mut self, idx: usize) -> anyhow::Result<virtio::Queue> {
286        let worker_queue_rc = self
287            .workers
288            .get_mut(idx)
289            .and_then(Option::take)
290            .map(|worker| {
291                // Wait for queue_task to be aborted.
292                let _ = self.ex.run_until(worker.queue_task.cancel());
293                worker.queue
294            });
295
296        if idx == 2 || idx == 3 {
297            if let Some(worker) = self
298                .response_workers
299                .get_mut(idx - PCM_RESPONSE_WORKER_IDX_OFFSET)
300                .and_then(Option::take)
301            {
302                // Wait for queue_task to be aborted.
303                let _ = self.ex.run_until(worker.queue_task.cancel());
304            }
305        }
306
307        if let Some(queue_rc) = worker_queue_rc {
308            match Rc::try_unwrap(queue_rc) {
309                Ok(queue_mutex) => Ok(queue_mutex.into_inner()),
310                Err(_) => panic!(
311                    "[Card {}] failed to recover queue from worker",
312                    self.card_index
313                ),
314            }
315        } else {
316            Err(anyhow::Error::new(DeviceError::WorkerNotFound))
317        }
318    }
319
320    fn snapshot(&mut self) -> anyhow::Result<AnySnapshot> {
321        // now_or_never will succeed here because no workers are running.
322        let stream_info_snaps = if let Some(stream_infos) = &self.streams.lock().now_or_never() {
323            let mut snaps = Vec::new();
324            for stream_info in stream_infos.iter() {
325                snaps.push(
326                    stream_info
327                        .lock()
328                        .now_or_never()
329                        .unwrap_or_else(|| {
330                            panic!(
331                                "[Card {}] failed to lock audio state during snapshot",
332                                self.card_index
333                            )
334                        })
335                        .snapshot(),
336                );
337            }
338            Some(snaps)
339        } else {
340            None
341        };
342        let snd_data_ref: &SndData = self.snd_data.borrow();
343        AnySnapshot::to_any(SndBackendSnapshot {
344            avail_features: self.avail_features,
345            stream_infos: stream_info_snaps,
346            snd_data: snd_data_ref.clone(),
347        })
348        .with_context(|| {
349            format!(
350                "[Card {}] Failed to serialize SndBackendSnapshot",
351                self.card_index
352            )
353        })
354    }
355
356    fn restore(&mut self, data: AnySnapshot) -> anyhow::Result<()> {
357        let deser: SndBackendSnapshot = AnySnapshot::from_any(data).with_context(|| {
358            format!(
359                "[Card {}] Failed to deserialize SndBackendSnapshot",
360                self.card_index
361            )
362        })?;
363        anyhow::ensure!(
364            deser.avail_features == self.avail_features,
365            "[Card {}] avail features doesn't match on restore: expected: {}, got: {}",
366            self.card_index,
367            deser.avail_features,
368            self.avail_features
369        );
370        let snd_data = self.snd_data.borrow();
371        anyhow::ensure!(
372            &deser.snd_data == snd_data,
373            "[Card {}] snd data doesn't match on restore: expected: {:?}, got: {:?}",
374            self.card_index,
375            deser.snd_data,
376            snd_data,
377        );
378
379        let ex_clone = self.ex.clone();
380        let streams_rc = self.streams.clone();
381        let tx_send_clone = self.tx_send.clone();
382        let rx_send_clone = self.rx_send.clone();
383
384        let card_index = self.card_index;
385        let restore_task =
386            self.ex.spawn_local(async move {
387                if let Some(stream_infos) = &deser.stream_infos {
388                    for (stream, stream_info) in
389                        streams_rc.lock().await.iter().zip(stream_infos.iter())
390                    {
391                        stream.lock().await.restore(stream_info);
392                        if stream_info.state == VIRTIO_SND_R_PCM_START
393                            || stream_info.state == VIRTIO_SND_R_PCM_PREPARE
394                        {
395                            stream
396                                .lock()
397                                .await
398                                .prepare(&ex_clone, &tx_send_clone, &rx_send_clone)
399                                .await
400                                .unwrap_or_else(|_| {
401                                    panic!("[Card {card_index}] failed to prepare PCM")
402                                });
403                        }
404                        if stream_info.state == VIRTIO_SND_R_PCM_START {
405                            stream.lock().await.start().await.unwrap_or_else(|_| {
406                                panic!("[Card {card_index}] failed to start PCM")
407                            });
408                        }
409                    }
410                }
411            });
412        self.ex
413            .run_until(restore_task)
414            .unwrap_or_else(|_| panic!("[Card {}] failed to restore streams", self.card_index));
415        Ok(())
416    }
417
418    fn enter_suspended_state(&mut self) -> anyhow::Result<()> {
419        // This device has no non-queue workers to stop.
420        Ok(())
421    }
422}