devices/virtio/vhost_user_backend/snd.rs

1// Copyright 2021 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5pub mod sys;
6
7use std::borrow::Borrow;
8use std::rc::Rc;
9
10use anyhow::anyhow;
11use anyhow::bail;
12use anyhow::Context;
13use base::error;
14use base::warn;
15use cros_async::sync::RwLock as AsyncRwLock;
16use cros_async::EventAsync;
17use cros_async::Executor;
18use futures::channel::mpsc;
19use futures::FutureExt;
20use hypervisor::ProtectionType;
21use serde::Deserialize;
22use serde::Serialize;
23use snapshot::AnySnapshot;
24pub use sys::run_snd_device;
25pub use sys::Options;
26use vm_memory::GuestMemory;
27use vmm_vhost::message::VhostUserProtocolFeatures;
28use vmm_vhost::VHOST_USER_F_PROTOCOL_FEATURES;
29use zerocopy::IntoBytes;
30
31use crate::virtio;
32use crate::virtio::copy_config;
33use crate::virtio::device_constants::snd::virtio_snd_config;
34use crate::virtio::snd::common_backend::async_funcs::handle_ctrl_queue;
35use crate::virtio::snd::common_backend::async_funcs::handle_pcm_queue;
36use crate::virtio::snd::common_backend::async_funcs::send_pcm_response_worker;
37use crate::virtio::snd::common_backend::create_stream_info_builders;
38use crate::virtio::snd::common_backend::hardcoded_snd_data;
39use crate::virtio::snd::common_backend::hardcoded_virtio_snd_config;
40use crate::virtio::snd::common_backend::stream_info::StreamInfo;
41use crate::virtio::snd::common_backend::stream_info::StreamInfoBuilder;
42use crate::virtio::snd::common_backend::stream_info::StreamInfoSnapshot;
43use crate::virtio::snd::common_backend::Error;
44use crate::virtio::snd::common_backend::PcmResponse;
45use crate::virtio::snd::common_backend::SndData;
46use crate::virtio::snd::common_backend::MAX_QUEUE_NUM;
47use crate::virtio::snd::constants::VIRTIO_SND_R_PCM_PREPARE;
48use crate::virtio::snd::constants::VIRTIO_SND_R_PCM_START;
49use crate::virtio::snd::parameters::Parameters;
50use crate::virtio::vhost_user_backend::handler::DeviceRequestHandler;
51use crate::virtio::vhost_user_backend::handler::Error as DeviceError;
52use crate::virtio::vhost_user_backend::handler::VhostUserDevice;
53use crate::virtio::vhost_user_backend::handler::WorkerState;
54use crate::virtio::vhost_user_backend::VhostUserDeviceBuilder;
55use crate::virtio::Queue;
56
/// Offset subtracted from a PCM queue index to obtain its slot in `response_workers`.
///
/// Queue (and async worker) indices used by this backend:
/// - `0`: ctrl
/// - `1`: event
/// - `2`: tx
/// - `3`: rx
///
/// Only tx and rx get a PCM response worker, so queue `2` maps to response slot `0` and
/// queue `3` maps to response slot `1`.
const PCM_RESPONSE_WORKER_IDX_OFFSET: usize = 2;
63struct SndBackend {
64    ex: Executor,
65    cfg: virtio_snd_config,
66    avail_features: u64,
67    workers: [Option<WorkerState<Rc<AsyncRwLock<Queue>>, Result<(), Error>>>; MAX_QUEUE_NUM],
68    // tx and rx
69    response_workers: [Option<WorkerState<Rc<AsyncRwLock<Queue>>, Result<(), Error>>>; 2],
70    snd_data: Rc<SndData>,
71    streams: Rc<AsyncRwLock<Vec<AsyncRwLock<StreamInfo>>>>,
72    tx_send: mpsc::UnboundedSender<PcmResponse>,
73    rx_send: mpsc::UnboundedSender<PcmResponse>,
74    tx_recv: Option<mpsc::UnboundedReceiver<PcmResponse>>,
75    rx_recv: Option<mpsc::UnboundedReceiver<PcmResponse>>,
76    // Appended to logs for when there are mutliple audio devices.
77    card_index: usize,
78}
79
/// Serialized state written by `SndBackend::snapshot` and read back by `SndBackend::restore`.
///
/// Field order is part of the serialized layout, so do not reorder fields.
#[derive(Serialize, Deserialize)]
struct SndBackendSnapshot {
    // Feature bits at snapshot time; restore fails unless they equal the backend's.
    avail_features: u64,
    // One snapshot per stream, or `None` if the stream list lock could not be taken
    // immediately when the snapshot was made.
    stream_infos: Option<Vec<StreamInfoSnapshot>>,
    // Copy of the backend's sound data; restore fails unless it is equal.
    snd_data: SndData,
}
86
87impl SndBackend {
88    pub fn new(
89        ex: &Executor,
90        params: Parameters,
91        #[cfg(windows)] audio_client_guid: Option<String>,
92        card_index: usize,
93    ) -> anyhow::Result<Self> {
94        let cfg = hardcoded_virtio_snd_config(&params);
95        let avail_features = virtio::base_features(ProtectionType::Unprotected)
96            | 1 << VHOST_USER_F_PROTOCOL_FEATURES;
97
98        let snd_data = hardcoded_snd_data(&params);
99        let mut keep_rds = Vec::new();
100        let builders = create_stream_info_builders(&params, &snd_data, &mut keep_rds, card_index)?;
101
102        if snd_data.pcm_info_len() != builders.len() {
103            error!(
104                "[Card {}] snd: expected {} stream info builders, got {}",
105                card_index,
106                snd_data.pcm_info_len(),
107                builders.len(),
108            )
109        }
110
111        let streams = builders.into_iter();
112
113        #[cfg(windows)]
114        let streams = streams
115            .map(|stream_builder| stream_builder.audio_client_guid(audio_client_guid.clone()));
116
117        let streams = streams
118            .map(StreamInfoBuilder::build)
119            .map(AsyncRwLock::new)
120            .collect();
121        let streams = Rc::new(AsyncRwLock::new(streams));
122
123        let (tx_send, tx_recv) = mpsc::unbounded();
124        let (rx_send, rx_recv) = mpsc::unbounded();
125
126        Ok(SndBackend {
127            ex: ex.clone(),
128            cfg,
129            avail_features,
130            workers: Default::default(),
131            response_workers: Default::default(),
132            snd_data: Rc::new(snd_data),
133            streams,
134            tx_send,
135            rx_send,
136            tx_recv: Some(tx_recv),
137            rx_recv: Some(rx_recv),
138            card_index,
139        })
140    }
141}
142
143impl VhostUserDeviceBuilder for SndBackend {
144    fn build(self: Box<Self>, _ex: &Executor) -> anyhow::Result<Box<dyn vmm_vhost::Backend>> {
145        let handler = DeviceRequestHandler::new(*self);
146        Ok(Box::new(handler))
147    }
148}
149
150impl VhostUserDevice for SndBackend {
151    fn max_queue_num(&self) -> usize {
152        MAX_QUEUE_NUM
153    }
154
155    fn features(&self) -> u64 {
156        self.avail_features
157    }
158
159    fn protocol_features(&self) -> VhostUserProtocolFeatures {
160        VhostUserProtocolFeatures::CONFIG
161            | VhostUserProtocolFeatures::MQ
162            | VhostUserProtocolFeatures::DEVICE_STATE
163    }
164
165    fn read_config(&self, offset: u64, data: &mut [u8]) {
166        copy_config(data, 0, self.cfg.as_bytes(), offset)
167    }
168
169    fn reset(&mut self) {
170        for worker in self.workers.iter_mut().filter_map(Option::take) {
171            let _ = self.ex.run_until(worker.queue_task.cancel());
172        }
173    }
174
175    fn start_queue(
176        &mut self,
177        idx: usize,
178        queue: virtio::Queue,
179        _mem: GuestMemory,
180    ) -> anyhow::Result<()> {
181        if self.workers[idx].is_some() {
182            warn!(
183                "[Card {}] Starting new queue handler without stopping old handler",
184                self.card_index
185            );
186            self.stop_queue(idx)?;
187        }
188
189        let kick_evt = queue
190            .event()
191            .try_clone()
192            .with_context(|| format!("[Card {}] failed to clone queue event", self.card_index))?;
193        let mut kick_evt = EventAsync::new(kick_evt, &self.ex).with_context(|| {
194            format!(
195                "[Card {}] failed to create EventAsync for kick_evt",
196                self.card_index
197            )
198        })?;
199        let queue = Rc::new(AsyncRwLock::new(queue));
200        let card_index = self.card_index;
201        let queue_task = match idx {
202            0 => {
203                // ctrl queue
204                let streams = self.streams.clone();
205                let snd_data = self.snd_data.clone();
206                let tx_send = self.tx_send.clone();
207                let rx_send = self.rx_send.clone();
208                let ctrl_queue = queue.clone();
209
210                let ex_clone = self.ex.clone();
211                Some(self.ex.spawn_local(async move {
212                    handle_ctrl_queue(
213                        &ex_clone,
214                        &streams,
215                        &snd_data,
216                        ctrl_queue,
217                        &mut kick_evt,
218                        tx_send,
219                        rx_send,
220                        card_index,
221                        None,
222                    )
223                    .await
224                }))
225            }
226            // TODO(woodychow): Add event queue support
227            //
228            // Note: Even though we don't support the event queue, we still need to keep track of
229            // the Queue so we can return it back in stop_queue. As such, we create a do nothing
230            // future to "run" this queue so that we track a WorkerState for it (which is how
231            // we return the Queue back).
232            1 => Some(self.ex.spawn_local(async move { Ok(()) })),
233            2 | 3 => {
234                let (send, recv) = if idx == 2 {
235                    (self.tx_send.clone(), self.tx_recv.take())
236                } else {
237                    (self.rx_send.clone(), self.rx_recv.take())
238                };
239                let mut recv = recv.ok_or_else(|| {
240                    anyhow!("[Card {}] queue restart is not supported", self.card_index)
241                })?;
242                let streams = Rc::clone(&self.streams);
243                let queue_pcm_queue = queue.clone();
244                let queue_task = self.ex.spawn_local(async move {
245                    handle_pcm_queue(&streams, send, queue_pcm_queue, &kick_evt, card_index, None)
246                        .await
247                });
248
249                let queue_response_queue = queue.clone();
250                let response_queue_task = self.ex.spawn_local(async move {
251                    send_pcm_response_worker(queue_response_queue, &mut recv, None).await
252                });
253
254                self.response_workers[idx - PCM_RESPONSE_WORKER_IDX_OFFSET] = Some(WorkerState {
255                    queue_task: response_queue_task,
256                    queue: queue.clone(),
257                });
258
259                Some(queue_task)
260            }
261            _ => bail!(
262                "[Card {}] attempted to start unknown queue: {}",
263                self.card_index,
264                idx
265            ),
266        };
267
268        if let Some(queue_task) = queue_task {
269            self.workers[idx] = Some(WorkerState { queue_task, queue });
270        }
271        Ok(())
272    }
273
274    fn stop_queue(&mut self, idx: usize) -> anyhow::Result<virtio::Queue> {
275        let worker_queue_rc = self
276            .workers
277            .get_mut(idx)
278            .and_then(Option::take)
279            .map(|worker| {
280                // Wait for queue_task to be aborted.
281                let _ = self.ex.run_until(worker.queue_task.cancel());
282                worker.queue
283            });
284
285        if idx == 2 || idx == 3 {
286            if let Some(worker) = self
287                .response_workers
288                .get_mut(idx - PCM_RESPONSE_WORKER_IDX_OFFSET)
289                .and_then(Option::take)
290            {
291                // Wait for queue_task to be aborted.
292                let _ = self.ex.run_until(worker.queue_task.cancel());
293            }
294        }
295
296        if let Some(queue_rc) = worker_queue_rc {
297            match Rc::try_unwrap(queue_rc) {
298                Ok(queue_mutex) => Ok(queue_mutex.into_inner()),
299                Err(_) => panic!(
300                    "[Card {}] failed to recover queue from worker",
301                    self.card_index
302                ),
303            }
304        } else {
305            Err(anyhow::Error::new(DeviceError::WorkerNotFound))
306        }
307    }
308
309    fn snapshot(&mut self) -> anyhow::Result<AnySnapshot> {
310        // now_or_never will succeed here because no workers are running.
311        let stream_info_snaps = if let Some(stream_infos) = &self.streams.lock().now_or_never() {
312            let mut snaps = Vec::new();
313            for stream_info in stream_infos.iter() {
314                snaps.push(
315                    stream_info
316                        .lock()
317                        .now_or_never()
318                        .unwrap_or_else(|| {
319                            panic!(
320                                "[Card {}] failed to lock audio state during snapshot",
321                                self.card_index
322                            )
323                        })
324                        .snapshot(),
325                );
326            }
327            Some(snaps)
328        } else {
329            None
330        };
331        let snd_data_ref: &SndData = self.snd_data.borrow();
332        AnySnapshot::to_any(SndBackendSnapshot {
333            avail_features: self.avail_features,
334            stream_infos: stream_info_snaps,
335            snd_data: snd_data_ref.clone(),
336        })
337        .with_context(|| {
338            format!(
339                "[Card {}] Failed to serialize SndBackendSnapshot",
340                self.card_index
341            )
342        })
343    }
344
345    fn restore(&mut self, data: AnySnapshot) -> anyhow::Result<()> {
346        let deser: SndBackendSnapshot = AnySnapshot::from_any(data).with_context(|| {
347            format!(
348                "[Card {}] Failed to deserialize SndBackendSnapshot",
349                self.card_index
350            )
351        })?;
352        anyhow::ensure!(
353            deser.avail_features == self.avail_features,
354            "[Card {}] avail features doesn't match on restore: expected: {}, got: {}",
355            self.card_index,
356            deser.avail_features,
357            self.avail_features
358        );
359        let snd_data = self.snd_data.borrow();
360        anyhow::ensure!(
361            &deser.snd_data == snd_data,
362            "[Card {}] snd data doesn't match on restore: expected: {:?}, got: {:?}",
363            self.card_index,
364            deser.snd_data,
365            snd_data,
366        );
367
368        let ex_clone = self.ex.clone();
369        let streams_rc = self.streams.clone();
370        let tx_send_clone = self.tx_send.clone();
371        let rx_send_clone = self.rx_send.clone();
372
373        let card_index = self.card_index;
374        let restore_task =
375            self.ex.spawn_local(async move {
376                if let Some(stream_infos) = &deser.stream_infos {
377                    for (stream, stream_info) in
378                        streams_rc.lock().await.iter().zip(stream_infos.iter())
379                    {
380                        stream.lock().await.restore(stream_info);
381                        if stream_info.state == VIRTIO_SND_R_PCM_START
382                            || stream_info.state == VIRTIO_SND_R_PCM_PREPARE
383                        {
384                            stream
385                                .lock()
386                                .await
387                                .prepare(&ex_clone, &tx_send_clone, &rx_send_clone)
388                                .await
389                                .unwrap_or_else(|_| {
390                                    panic!("[Card {card_index}] failed to prepare PCM")
391                                });
392                        }
393                        if stream_info.state == VIRTIO_SND_R_PCM_START {
394                            stream.lock().await.start().await.unwrap_or_else(|_| {
395                                panic!("[Card {card_index}] failed to start PCM")
396                            });
397                        }
398                    }
399                }
400            });
401        self.ex
402            .run_until(restore_task)
403            .unwrap_or_else(|_| panic!("[Card {}] failed to restore streams", self.card_index));
404        Ok(())
405    }
406
407    fn enter_suspended_state(&mut self) -> anyhow::Result<()> {
408        // This device has no non-queue workers to stop.
409        Ok(())
410    }
411}