devices/virtio/snd/common_backend/
mod.rs

1// Copyright 2021 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5// virtio-sound spec: https://github.com/oasis-tcs/virtio-spec/blob/master/virtio-sound.tex
6
7use std::collections::BTreeMap;
8use std::io;
9use std::rc::Rc;
10use std::sync::Arc;
11
12use anyhow::anyhow;
13use anyhow::Context;
14use audio_streams::BoxError;
15use base::debug;
16use base::error;
17use base::warn;
18use base::AsRawDescriptor;
19use base::Descriptor;
20use base::Error as SysError;
21use base::Event;
22use base::RawDescriptor;
23use base::Tube;
24use base::WorkerThread;
25use cros_async::block_on;
26use cros_async::sync::Condvar;
27use cros_async::sync::RwLock as AsyncRwLock;
28use cros_async::AsyncError;
29use cros_async::AsyncTube;
30use cros_async::EventAsync;
31use cros_async::Executor;
32use futures::channel::mpsc;
33use futures::channel::oneshot;
34use futures::channel::oneshot::Canceled;
35use futures::future::FusedFuture;
36use futures::join;
37use futures::pin_mut;
38use futures::select;
39use futures::FutureExt;
40use serde::Deserialize;
41use serde::Serialize;
42use snapshot::AnySnapshot;
43use thiserror::Error as ThisError;
44use vm_memory::GuestMemory;
45use zerocopy::IntoBytes;
46
47use crate::virtio::async_utils;
48use crate::virtio::copy_config;
49use crate::virtio::device_constants::snd::virtio_snd_config;
50use crate::virtio::snd::common_backend::async_funcs::*;
51use crate::virtio::snd::common_backend::stream_info::StreamInfo;
52use crate::virtio::snd::common_backend::stream_info::StreamInfoBuilder;
53use crate::virtio::snd::common_backend::stream_info::StreamInfoSnapshot;
54use crate::virtio::snd::constants::*;
55use crate::virtio::snd::file_backend::create_file_stream_source_generators;
56use crate::virtio::snd::file_backend::Error as FileError;
57use crate::virtio::snd::layout::*;
58use crate::virtio::snd::null_backend::create_null_stream_source_generators;
59use crate::virtio::snd::parameters::Parameters;
60use crate::virtio::snd::parameters::StreamSourceBackend;
61use crate::virtio::snd::sys::create_stream_source_generators as sys_create_stream_source_generators;
62use crate::virtio::snd::sys::set_audio_thread_priority;
63use crate::virtio::snd::sys::SysAsyncStreamObjects;
64use crate::virtio::snd::sys::SysAudioStreamSourceGenerator;
65use crate::virtio::snd::sys::SysDirectionOutput;
66use crate::virtio::DescriptorChain;
67use crate::virtio::DeviceType;
68use crate::virtio::Interrupt;
69use crate::virtio::Queue;
70use crate::virtio::VirtioDevice;
71
72pub mod async_funcs;
73pub mod stream_info;
74
/// Number of virtqueues used by the device: control, event, tx and rx.
pub const MAX_QUEUE_NUM: usize = 4;
/// Maximum number of entries advertised for each virtqueue (2^10).
pub const MAX_VRING_LEN: u16 = 1 << 10;
78
79#[derive(ThisError, Debug)]
80pub enum Error {
81    /// next_async failed.
82    #[error("Failed to read descriptor asynchronously: {0}")]
83    Async(AsyncError),
84    /// Creating stream failed.
85    #[error("Failed to create stream: {0}")]
86    CreateStream(BoxError),
87    /// Creating stream failed.
88    #[error("No stream source found.")]
89    EmptyStreamSource,
90    /// Creating kill event failed.
91    #[error("Failed to create kill event: {0}")]
92    CreateKillEvent(SysError),
93    /// Creating WaitContext failed.
94    #[error("Failed to create wait context: {0}")]
95    CreateWaitContext(SysError),
96    #[error("Failed to create file stream source generator")]
97    CreateFileStreamSourceGenerator(FileError),
98    /// Cloning kill event failed.
99    #[error("Failed to clone kill event: {0}")]
100    CloneKillEvent(SysError),
101    // Future error.
102    #[error("Unexpected error. Done was not triggered before dropped: {0}")]
103    DoneNotTriggered(Canceled),
104    /// Error reading message from queue.
105    #[error("Failed to read message: {0}")]
106    ReadMessage(io::Error),
107    /// Failed writing a response to a control message.
108    #[error("Failed to write message response: {0}")]
109    WriteResponse(io::Error),
110    // Mpsc read error.
111    #[error("Error in mpsc: {0}")]
112    MpscSend(futures::channel::mpsc::SendError),
113    // Oneshot send error.
114    #[error("Error in oneshot send")]
115    OneshotSend(()),
116    /// Failure in communicating with the host
117    #[error("Failed to send/receive to/from control tube")]
118    ControlTubeError(base::TubeError),
119    /// Stream not found.
120    #[error("stream id ({0}) < num_streams ({1})")]
121    StreamNotFound(usize, usize),
122    /// Fetch buffer error
123    #[error("Failed to get buffer from CRAS: {0}")]
124    FetchBuffer(BoxError),
125    /// Invalid buffer size
126    #[error("Invalid buffer size")]
127    InvalidBufferSize,
128    /// IoError
129    #[error("I/O failed: {0}")]
130    Io(io::Error),
131    /// Operation not supported.
132    #[error("Operation not supported")]
133    OperationNotSupported,
134    /// Writing to a buffer in the guest failed.
135    #[error("failed to write to buffer: {0}")]
136    WriteBuffer(io::Error),
137    // Invalid PCM worker state.
138    #[error("Invalid PCM worker state")]
139    InvalidPCMWorkerState,
140    // Invalid backend.
141    #[error("Backend is not implemented")]
142    InvalidBackend,
143    // Failed to generate StreamSource
144    #[error("Failed to generate stream source: {0}")]
145    GenerateStreamSource(BoxError),
146    // PCM worker unexpectedly quitted.
147    #[error("PCM worker quitted unexpectedly")]
148    PCMWorkerQuittedUnexpectedly,
149}
150
151pub enum DirectionalStream {
152    Input(
153        usize, // `period_size` in `usize`
154        Box<dyn CaptureBufferReader>,
155    ),
156    Output(SysDirectionOutput),
157}
158
/// Status values for a worker, with explicit discriminants 0..=2.
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum WorkerStatus {
    /// Discriminant 0.
    Pause = 0,
    /// Discriminant 1.
    Running = 1,
    /// Discriminant 2.
    Quit = 2,
}
165
166// Stores constant data
167#[derive(Clone, Serialize, Deserialize, PartialEq, Eq, Debug)]
168pub struct SndData {
169    pub(crate) jack_info: Vec<virtio_snd_jack_info>,
170    pub(crate) pcm_info: Vec<virtio_snd_pcm_info>,
171    pub(crate) chmap_info: Vec<virtio_snd_chmap_info>,
172}
173
174impl SndData {
175    pub fn pcm_info_len(&self) -> usize {
176        self.pcm_info.len()
177    }
178
179    pub fn pcm_info_iter(&self) -> std::slice::Iter<'_, virtio_snd_pcm_info> {
180        self.pcm_info.iter()
181    }
182}
183
184const SUPPORTED_FORMATS: u64 = 1 << VIRTIO_SND_PCM_FMT_U8
185    | 1 << VIRTIO_SND_PCM_FMT_S16
186    | 1 << VIRTIO_SND_PCM_FMT_S24
187    | 1 << VIRTIO_SND_PCM_FMT_S32;
188const SUPPORTED_FRAME_RATES: u64 = 1 << VIRTIO_SND_PCM_RATE_8000
189    | 1 << VIRTIO_SND_PCM_RATE_11025
190    | 1 << VIRTIO_SND_PCM_RATE_16000
191    | 1 << VIRTIO_SND_PCM_RATE_22050
192    | 1 << VIRTIO_SND_PCM_RATE_32000
193    | 1 << VIRTIO_SND_PCM_RATE_44100
194    | 1 << VIRTIO_SND_PCM_RATE_48000;
195
196// Response from pcm_worker to pcm_queue
197pub struct PcmResponse {
198    pub(crate) desc_chain: DescriptorChain,
199    pub(crate) status: virtio_snd_pcm_status, // response to the pcm message
200    pub(crate) done: Option<oneshot::Sender<()>>, // when pcm response is written to the queue
201}
202
203pub struct VirtioSnd {
204    control_tube: Option<Tube>,
205    cfg: virtio_snd_config,
206    snd_data: SndData,
207    stream_info_builders: Vec<StreamInfoBuilder>,
208    avail_features: u64,
209    acked_features: u64,
210    queue_sizes: Box<[u16]>,
211    worker_thread: Option<WorkerThread<WorkerReturn>>,
212    keep_rds: Vec<Descriptor>,
213    streams_state: Option<Vec<StreamInfoSnapshot>>,
214    card_index: usize,
215}
216
217#[derive(Serialize, Deserialize)]
218struct VirtioSndSnapshot {
219    avail_features: u64,
220    acked_features: u64,
221    queue_sizes: Vec<u16>,
222    streams_state: Option<Vec<StreamInfoSnapshot>>,
223    snd_data: SndData,
224}
225
226impl VirtioSnd {
227    pub fn new(
228        base_features: u64,
229        params: Parameters,
230        control_tube: Tube,
231    ) -> Result<VirtioSnd, Error> {
232        let params = resize_parameters_pcm_device_config(params);
233        let cfg = hardcoded_virtio_snd_config(&params);
234        let snd_data = hardcoded_snd_data(&params);
235        let avail_features = base_features;
236        let mut keep_rds: Vec<RawDescriptor> = Vec::new();
237        keep_rds.push(control_tube.as_raw_descriptor());
238
239        let stream_info_builders =
240            create_stream_info_builders(&params, &snd_data, &mut keep_rds, params.card_index)?;
241
242        Ok(VirtioSnd {
243            control_tube: Some(control_tube),
244            cfg,
245            snd_data,
246            stream_info_builders,
247            avail_features,
248            acked_features: 0,
249            queue_sizes: vec![MAX_VRING_LEN; MAX_QUEUE_NUM].into_boxed_slice(),
250            worker_thread: None,
251            keep_rds: keep_rds.iter().map(|rd| Descriptor(*rd)).collect(),
252            streams_state: None,
253            card_index: params.card_index,
254        })
255    }
256}
257
258fn create_stream_source_generators(
259    params: &Parameters,
260    snd_data: &SndData,
261    keep_rds: &mut Vec<RawDescriptor>,
262) -> Result<Vec<SysAudioStreamSourceGenerator>, Error> {
263    let generators = match params.backend {
264        StreamSourceBackend::NULL => create_null_stream_source_generators(snd_data),
265        StreamSourceBackend::FILE => {
266            create_file_stream_source_generators(params, snd_data, keep_rds)
267                .map_err(Error::CreateFileStreamSourceGenerator)?
268        }
269        StreamSourceBackend::Sys(backend) => {
270            sys_create_stream_source_generators(backend, params, snd_data)
271        }
272    };
273    Ok(generators)
274}
275
276/// Creates [`StreamInfoBuilder`]s by calling [`create_stream_source_generators()`] then zip
277/// them with [`crate::virtio::snd::parameters::PCMDeviceParameters`] from the params to set
278/// the parameters on each [`StreamInfoBuilder`] (e.g. effects).
279pub(crate) fn create_stream_info_builders(
280    params: &Parameters,
281    snd_data: &SndData,
282    keep_rds: &mut Vec<RawDescriptor>,
283    card_index: usize,
284) -> Result<Vec<StreamInfoBuilder>, Error> {
285    Ok(create_stream_source_generators(params, snd_data, keep_rds)?
286        .into_iter()
287        .map(Arc::new)
288        .zip(snd_data.pcm_info_iter())
289        .map(|(generator, pcm_info)| {
290            let device_params = params.get_device_params(pcm_info).unwrap_or_default();
291            StreamInfo::builder(generator, card_index)
292                .effects(device_params.effects.unwrap_or_default())
293        })
294        .collect())
295}
296
297// To be used with hardcoded_snd_data
298pub fn hardcoded_virtio_snd_config(params: &Parameters) -> virtio_snd_config {
299    virtio_snd_config {
300        jacks: 0.into(),
301        streams: params.get_total_streams().into(),
302        chmaps: (params.num_output_devices * 3 + params.num_input_devices).into(),
303    }
304}
305
306// To be used with hardcoded_virtio_snd_config
307pub fn hardcoded_snd_data(params: &Parameters) -> SndData {
308    let jack_info: Vec<virtio_snd_jack_info> = Vec::new();
309    let mut pcm_info: Vec<virtio_snd_pcm_info> = Vec::new();
310    let mut chmap_info: Vec<virtio_snd_chmap_info> = Vec::new();
311
312    for dev in 0..params.num_output_devices {
313        for _ in 0..params.num_output_streams {
314            pcm_info.push(virtio_snd_pcm_info {
315                hdr: virtio_snd_info {
316                    hda_fn_nid: dev.into(),
317                },
318                features: 0.into(), /* 1 << VIRTIO_SND_PCM_F_XXX */
319                formats: SUPPORTED_FORMATS.into(),
320                rates: SUPPORTED_FRAME_RATES.into(),
321                direction: VIRTIO_SND_D_OUTPUT,
322                channels_min: 1,
323                channels_max: 6,
324                padding: [0; 5],
325            });
326        }
327    }
328    for dev in 0..params.num_input_devices {
329        for _ in 0..params.num_input_streams {
330            pcm_info.push(virtio_snd_pcm_info {
331                hdr: virtio_snd_info {
332                    hda_fn_nid: dev.into(),
333                },
334                features: 0.into(), /* 1 << VIRTIO_SND_PCM_F_XXX */
335                formats: SUPPORTED_FORMATS.into(),
336                rates: SUPPORTED_FRAME_RATES.into(),
337                direction: VIRTIO_SND_D_INPUT,
338                channels_min: 1,
339                channels_max: 2,
340                padding: [0; 5],
341            });
342        }
343    }
344    // Use stereo channel map.
345    let mut positions = [VIRTIO_SND_CHMAP_NONE; VIRTIO_SND_CHMAP_MAX_SIZE];
346    positions[0] = VIRTIO_SND_CHMAP_FL;
347    positions[1] = VIRTIO_SND_CHMAP_FR;
348    for dev in 0..params.num_output_devices {
349        chmap_info.push(virtio_snd_chmap_info {
350            hdr: virtio_snd_info {
351                hda_fn_nid: dev.into(),
352            },
353            direction: VIRTIO_SND_D_OUTPUT,
354            channels: 2,
355            positions,
356        });
357    }
358    for dev in 0..params.num_input_devices {
359        chmap_info.push(virtio_snd_chmap_info {
360            hdr: virtio_snd_info {
361                hda_fn_nid: dev.into(),
362            },
363            direction: VIRTIO_SND_D_INPUT,
364            channels: 2,
365            positions,
366        });
367    }
368    positions[2] = VIRTIO_SND_CHMAP_RL;
369    positions[3] = VIRTIO_SND_CHMAP_RR;
370    for dev in 0..params.num_output_devices {
371        chmap_info.push(virtio_snd_chmap_info {
372            hdr: virtio_snd_info {
373                hda_fn_nid: dev.into(),
374            },
375            direction: VIRTIO_SND_D_OUTPUT,
376            channels: 4,
377            positions,
378        });
379    }
380    positions[2] = VIRTIO_SND_CHMAP_FC;
381    positions[3] = VIRTIO_SND_CHMAP_LFE;
382    positions[4] = VIRTIO_SND_CHMAP_RL;
383    positions[5] = VIRTIO_SND_CHMAP_RR;
384    for dev in 0..params.num_output_devices {
385        chmap_info.push(virtio_snd_chmap_info {
386            hdr: virtio_snd_info {
387                hda_fn_nid: dev.into(),
388            },
389            direction: VIRTIO_SND_D_OUTPUT,
390            channels: 6,
391            positions,
392        });
393    }
394
395    SndData {
396        jack_info,
397        pcm_info,
398        chmap_info,
399    }
400}
401
402fn resize_parameters_pcm_device_config(mut params: Parameters) -> Parameters {
403    if params.output_device_config.len() > params.num_output_devices as usize {
404        warn!("Truncating output device config due to length > number of output devices");
405    }
406    params
407        .output_device_config
408        .resize_with(params.num_output_devices as usize, Default::default);
409
410    if params.input_device_config.len() > params.num_input_devices as usize {
411        warn!("Truncating input device config due to length > number of input devices");
412    }
413    params
414        .input_device_config
415        .resize_with(params.num_input_devices as usize, Default::default);
416
417    params
418}
419
420impl VirtioDevice for VirtioSnd {
421    fn keep_rds(&self) -> Vec<RawDescriptor> {
422        self.keep_rds
423            .iter()
424            .map(|descr| descr.as_raw_descriptor())
425            .collect()
426    }
427
428    fn device_type(&self) -> DeviceType {
429        DeviceType::Sound
430    }
431
432    fn queue_max_sizes(&self) -> &[u16] {
433        &self.queue_sizes
434    }
435
436    fn features(&self) -> u64 {
437        self.avail_features
438    }
439
440    fn ack_features(&mut self, mut v: u64) {
441        // Check if the guest is ACK'ing a feature that we didn't claim to have.
442        let unrequested_features = v & !self.avail_features;
443        if unrequested_features != 0 {
444            warn!("virtio_fs got unknown feature ack: {:x}", v);
445
446            // Don't count these features as acked.
447            v &= !unrequested_features;
448        }
449        self.acked_features |= v;
450    }
451
452    fn read_config(&self, offset: u64, data: &mut [u8]) {
453        copy_config(data, 0, self.cfg.as_bytes(), offset)
454    }
455
456    fn activate(
457        &mut self,
458        _guest_mem: GuestMemory,
459        _interrupt: Interrupt,
460        queues: BTreeMap<usize, Queue>,
461    ) -> anyhow::Result<()> {
462        if queues.len() != self.queue_sizes.len() {
463            return Err(anyhow!(
464                "snd: expected {} queues, got {}",
465                self.queue_sizes.len(),
466                queues.len()
467            ));
468        }
469
470        let snd_data = self.snd_data.clone();
471        let stream_info_builders = self.stream_info_builders.to_vec();
472        let streams_state = self.streams_state.take();
473        let card_index = self.card_index;
474        let control_tube = self.control_tube.take().unwrap();
475        self.worker_thread = Some(WorkerThread::start("v_snd_common", move |kill_evt| {
476            let _thread_priority_handle = set_audio_thread_priority();
477            if let Err(e) = _thread_priority_handle {
478                warn!("Failed to set audio thread to real time: {}", e);
479            };
480            run_worker(
481                queues,
482                snd_data,
483                kill_evt,
484                stream_info_builders,
485                streams_state,
486                card_index,
487                control_tube,
488            )
489        }));
490
491        Ok(())
492    }
493
494    fn reset(&mut self) -> anyhow::Result<()> {
495        if let Some(worker_thread) = self.worker_thread.take() {
496            let worker = worker_thread.stop();
497            self.control_tube = Some(worker.control_tube);
498        }
499
500        Ok(())
501    }
502
503    fn virtio_sleep(&mut self) -> anyhow::Result<Option<BTreeMap<usize, Queue>>> {
504        if let Some(worker_thread) = self.worker_thread.take() {
505            let worker = worker_thread.stop();
506            self.control_tube = Some(worker.control_tube);
507            self.snd_data = worker.snd_data;
508            self.streams_state = Some(worker.streams_state);
509            return Ok(Some(BTreeMap::from_iter(
510                worker.queues.into_iter().enumerate(),
511            )));
512        }
513        Ok(None)
514    }
515
516    fn virtio_wake(
517        &mut self,
518        device_state: Option<(GuestMemory, Interrupt, BTreeMap<usize, Queue>)>,
519    ) -> anyhow::Result<()> {
520        match device_state {
521            None => Ok(()),
522            Some((mem, interrupt, queues)) => {
523                // TODO: activate is just what we want at the moment, but we should probably move
524                // it into a "start workers" function to make it obvious that it isn't strictly
525                // used for activate events.
526                self.activate(mem, interrupt, queues)?;
527                Ok(())
528            }
529        }
530    }
531
532    fn virtio_snapshot(&mut self) -> anyhow::Result<AnySnapshot> {
533        let streams_state = if let Some(states) = &self.streams_state {
534            let mut state_vec = Vec::new();
535            for state in states {
536                state_vec.push(state.clone());
537            }
538            Some(state_vec)
539        } else {
540            None
541        };
542        AnySnapshot::to_any(VirtioSndSnapshot {
543            avail_features: self.avail_features,
544            acked_features: self.acked_features,
545            queue_sizes: self.queue_sizes.to_vec(),
546            streams_state,
547            snd_data: self.snd_data.clone(),
548        })
549        .context("failed to Serialize Sound device")
550    }
551
552    fn virtio_restore(&mut self, data: AnySnapshot) -> anyhow::Result<()> {
553        let mut deser: VirtioSndSnapshot =
554            AnySnapshot::from_any(data).context("failed to Deserialize Sound device")?;
555        anyhow::ensure!(
556            deser.avail_features == self.avail_features,
557            "avail features doesn't match on restore: expected: {}, got: {}",
558            deser.avail_features,
559            self.avail_features
560        );
561        anyhow::ensure!(
562            deser.queue_sizes == self.queue_sizes.to_vec(),
563            "queue sizes doesn't match on restore: expected: {:?}, got: {:?}",
564            deser.queue_sizes,
565            self.queue_sizes.to_vec()
566        );
567        self.acked_features = deser.acked_features;
568        anyhow::ensure!(
569            deser.snd_data == self.snd_data,
570            "snd data doesn't match on restore: expected: {:?}, got: {:?}",
571            deser.snd_data,
572            self.snd_data
573        );
574        self.acked_features = deser.acked_features;
575        self.streams_state = deser.streams_state.take();
576        Ok(())
577    }
578}
579
/// Tells the main worker loop what to do after `run_worker_once` returns.
#[derive(PartialEq)]
enum LoopState {
    /// Reset the streams and run the workers again.
    Continue,
    /// Leave the main loop.
    Break,
}
585
/// Body of the `v_snd_common` worker thread.
///
/// Builds a `StreamInfo` from each builder and, if `streams_state` is `Some`, restores each
/// stream, re-preparing or re-starting streams whose saved state was PREPARE/START. It then
/// repeatedly runs `run_worker_once` and `reset_streams` until the kill event fires (or a
/// reset fails). Finally it snapshots every stream and hands the queues, control tube,
/// description data and stream snapshots back in a [`WorkerReturn`].
///
/// `queues` is consumed in key order (`BTreeMap::into_values`) and split positionally as
/// control, event, tx, rx.
///
/// # Panics
/// Panics if the executor, async tube or queue events cannot be created, if restoring
/// (prepare/start) a stream fails, or if a queue `Rc` is still shared when the loop exits.
fn run_worker(
    queues: BTreeMap<usize, Queue>,
    snd_data: SndData,
    kill_evt: Event,
    stream_info_builders: Vec<StreamInfoBuilder>,
    streams_state: Option<Vec<StreamInfoSnapshot>>,
    card_index: usize,
    control_tube: Tube,
) -> WorkerReturn {
    let ex = Executor::new().expect("Failed to create an executor");
    let control_tube = AsyncTube::new(&ex, control_tube).expect("failed to create async snd tube");

    // A mismatch is only logged; the zip in create_stream_info_builders already dropped the
    // surplus entries.
    if snd_data.pcm_info_len() != stream_info_builders.len() {
        error!(
            "snd: expected {} streams, got {}",
            snd_data.pcm_info_len(),
            stream_info_builders.len(),
        );
    }
    let streams: Vec<AsyncRwLock<StreamInfo>> = stream_info_builders
        .into_iter()
        .map(StreamInfoBuilder::build)
        .map(AsyncRwLock::new)
        .collect();

    let (tx_send, mut tx_recv) = mpsc::unbounded();
    let (rx_send, mut rx_recv) = mpsc::unbounded();
    let tx_send_clone = tx_send.clone();
    let rx_send_clone = rx_send.clone();
    // `streams` is moved into the task and handed back as its output.
    let restore_task = ex.spawn_local(async move {
        if let Some(states) = &streams_state {
            // NOTE(review): this shadows the outer executor with a fresh one that nothing in
            // this function ever runs (`run_until` is only called on the outer `ex`). If
            // `prepare` spawns work on the executor it is given, that work may never be
            // polled — confirm against StreamInfo::prepare.
            let ex = Executor::new().expect("Failed to create an executor");
            for (stream, state) in streams.iter().zip(states.iter()) {
                stream.lock().await.restore(state);
                if state.state == VIRTIO_SND_R_PCM_START || state.state == VIRTIO_SND_R_PCM_PREPARE
                {
                    stream
                        .lock()
                        .await
                        .prepare(&ex, &tx_send_clone, &rx_send_clone)
                        .await
                        .expect("failed to prepare PCM");
                }
                if state.state == VIRTIO_SND_R_PCM_START {
                    stream
                        .lock()
                        .await
                        .start()
                        .await
                        .expect("failed to start PCM");
                }
            }
        }
        streams
    });
    let streams = ex
        .run_until(restore_task)
        .expect("failed to restore streams");
    let streams = Rc::new(AsyncRwLock::new(streams));

    let mut queues: Vec<(Queue, EventAsync)> = queues
        .into_values()
        .map(|q| {
            let e = q.event().try_clone().expect("Failed to clone queue event");
            (
                q,
                EventAsync::new(e, &ex).expect("Failed to create async event for queue"),
            )
        })
        .collect();

    // Positional split: control, event, tx, rx.
    let (ctrl_queue, mut ctrl_queue_evt) = queues.remove(0);
    let ctrl_queue = Rc::new(AsyncRwLock::new(ctrl_queue));
    // The event queue is not serviced (see the TODO in run_worker_once) but is still returned
    // in WorkerReturn below.
    let (_event_queue, _event_queue_evt) = queues.remove(0);
    let (tx_queue, tx_queue_evt) = queues.remove(0);
    let (rx_queue, rx_queue_evt) = queues.remove(0);

    let tx_queue = Rc::new(AsyncRwLock::new(tx_queue));
    let rx_queue = Rc::new(AsyncRwLock::new(rx_queue));

    // Exit if the kill event is triggered.
    let f_kill = async_utils::await_and_exit(&ex, kill_evt).fuse();

    pin_mut!(f_kill);

    // Run the workers; whenever one exits for a reason other than the kill event, reset the
    // streams and start them all again.
    loop {
        if run_worker_once(
            &ex,
            &streams,
            &snd_data,
            &mut f_kill,
            ctrl_queue.clone(),
            &mut ctrl_queue_evt,
            tx_queue.clone(),
            &tx_queue_evt,
            tx_send.clone(),
            &mut tx_recv,
            rx_queue.clone(),
            &rx_queue_evt,
            rx_send.clone(),
            &mut rx_recv,
            card_index,
            &control_tube,
        ) == LoopState::Break
        {
            break;
        }

        if let Err(e) = reset_streams(
            &ex,
            &streams,
            &tx_queue,
            &mut tx_recv,
            &rx_queue,
            &mut rx_recv,
        ) {
            error!("Error reset streams: {}", e);
            break;
        }
    }
    // Capture per-stream state so a later activation (after sleep/snapshot) can restore it.
    let streams_state_task = ex.spawn_local(async move {
        let mut v = Vec::new();
        for stream in streams.read_lock().await.iter() {
            v.push(stream.read_lock().await.snapshot());
        }
        v
    });
    let streams_state = ex
        .run_until(streams_state_task)
        .expect("failed to save streams state");
    // By now every worker future holding a clone has been dropped, so each Rc should be unique.
    let ctrl_queue = match Rc::try_unwrap(ctrl_queue) {
        Ok(q) => q.into_inner(),
        Err(_) => panic!("Too many refs to ctrl_queue"),
    };
    let tx_queue = match Rc::try_unwrap(tx_queue) {
        Ok(q) => q.into_inner(),
        Err(_) => panic!("Too many refs to tx_queue"),
    };
    let rx_queue = match Rc::try_unwrap(rx_queue) {
        Ok(q) => q.into_inner(),
        Err(_) => panic!("Too many refs to rx_queue"),
    };
    // Same positional order as the split above.
    let queues = vec![ctrl_queue, _event_queue, tx_queue, rx_queue];

    WorkerReturn {
        control_tube: control_tube.into(),
        queues,
        snd_data,
        streams_state,
    }
}
737
738struct WorkerReturn {
739    control_tube: Tube,
740    queues: Vec<Queue>,
741    snd_data: SndData,
742    streams_state: Vec<StreamInfoSnapshot>,
743}
744
745async fn notify_reset_signal(reset_signal: &(AsyncRwLock<bool>, Condvar)) {
746    let (lock, cvar) = reset_signal;
747    *lock.lock().await = true;
748    cvar.notify_all();
749}
750
/// Runs all workers once and returns as soon as any one of them exits.
///
/// The workers are:
/// - the host control tube handler;
/// - the control queue handler;
/// - the tx and rx PCM queue handlers;
/// - the tx and rx response writers;
/// - `f_kill`.
///
/// All but `f_kill` share one local reset signal.
///
/// Returns [`LoopState::Break`] if the worker `f_kill` exits, or something went
/// wrong during the shutdown process. The caller should not run the workers again and should
/// exit the main loop.
///
/// If this function returns [`LoopState::Continue`], the caller can continue the main loop by
/// resetting the streams and running the workers again.
///
/// `f_kill` is borrowed so the same pinned, fused kill future is reused across calls.
fn run_worker_once(
    ex: &Executor,
    streams: &Rc<AsyncRwLock<Vec<AsyncRwLock<StreamInfo>>>>,
    snd_data: &SndData,
    mut f_kill: &mut (impl FusedFuture<Output = anyhow::Result<()>> + Unpin),
    ctrl_queue: Rc<AsyncRwLock<Queue>>,
    ctrl_queue_evt: &mut EventAsync,
    tx_queue: Rc<AsyncRwLock<Queue>>,
    tx_queue_evt: &EventAsync,
    tx_send: mpsc::UnboundedSender<PcmResponse>,
    tx_recv: &mut mpsc::UnboundedReceiver<PcmResponse>,
    rx_queue: Rc<AsyncRwLock<Queue>>,
    rx_queue_evt: &EventAsync,
    rx_send: mpsc::UnboundedSender<PcmResponse>,
    rx_recv: &mut mpsc::UnboundedReceiver<PcmResponse>,
    card_index: usize,
    control_tube: &AsyncTube,
) -> LoopState {
    // handle_ctrl_queue takes `tx_send`/`rx_send` by value, so the PCM queue handlers get
    // their own clones.
    let tx_send2 = tx_send.clone();
    let rx_send2 = rx_send.clone();

    // Flag + condvar shared by the workers below; raised via notify_reset_signal once the
    // first worker exits.
    let reset_signal = (AsyncRwLock::new(false), Condvar::new());

    let f_host_ctrl = handle_ctrl_tube(streams, control_tube, Some(&reset_signal)).fuse();

    let f_ctrl = handle_ctrl_queue(
        ex,
        streams,
        snd_data,
        ctrl_queue,
        ctrl_queue_evt,
        tx_send,
        rx_send,
        card_index,
        Some(&reset_signal),
    )
    .fuse();

    // TODO(woodychow): Enable this when libcras sends jack connect/disconnect evts
    // let f_event = handle_event_queue(
    //     snd_state,
    //     event_queue,
    //     event_queue_evt,
    // );
    let f_tx = handle_pcm_queue(
        streams,
        tx_send2,
        tx_queue.clone(),
        tx_queue_evt,
        card_index,
        Some(&reset_signal),
    )
    .fuse();
    let f_tx_response = send_pcm_response_worker(tx_queue, tx_recv, Some(&reset_signal)).fuse();
    let f_rx = handle_pcm_queue(
        streams,
        rx_send2,
        rx_queue.clone(),
        rx_queue_evt,
        card_index,
        Some(&reset_signal),
    )
    .fuse();
    let f_rx_response = send_pcm_response_worker(rx_queue, rx_recv, Some(&reset_signal)).fuse();

    pin_mut!(
        f_host_ctrl,
        f_ctrl,
        f_tx,
        f_tx_response,
        f_rx,
        f_rx_response
    );

    // Wait for the first worker to finish; only `f_kill` requests leaving the main loop.
    let done = async {
        select! {
            res = f_host_ctrl => (res.context("error in handling host control command"), LoopState::Continue),
            res = f_ctrl => (res.context("error in handling ctrl queue"), LoopState::Continue),
            res = f_tx => (res.context("error in handling tx queue"), LoopState::Continue),
            res = f_tx_response => (res.context("error in handling tx response"), LoopState::Continue),
            res = f_rx => (res.context("error in handling rx queue"), LoopState::Continue),
            res = f_rx_response => (res.context("error in handling rx response"), LoopState::Continue),

            // For following workers, do not continue the loop
            res = f_kill => (res.context("error in await_and_exit"), LoopState::Break),
        }
    };

    // Any other outcome (a non-kill worker exited, or the executor itself failed) falls
    // through to the reset procedure below.
    match ex.run_until(done) {
        Ok((res, loop_state)) => {
            if let Err(e) = res {
                error!("Error in worker: {:#}", e);
            }
            if loop_state == LoopState::Break {
                return LoopState::Break;
            }
        }
        Err(e) => {
            error!("Error happened in executor: {}", e);
        }
    }

    warn!("Shutting down all workers for reset procedure");
    // Raise the shared reset signal so the remaining workers can wind down (how each reacts
    // is up to the async_funcs handlers).
    block_on(notify_reset_signal(&reset_signal));

    // Drain the queue-related workers until all have completed. The one that already
    // finished above is fused and therefore skipped by `select!`; `complete` fires once every
    // listed future is terminated. `f_host_ctrl` is not drained here and is simply dropped
    // on return — NOTE(review): confirm that is intended.
    let shutdown = async {
        loop {
            let (res, worker_name) = select!(
                res = f_ctrl => (res, "f_ctrl"),
                res = f_tx => (res, "f_tx"),
                res = f_tx_response => (res, "f_tx_response"),
                res = f_rx => (res, "f_rx"),
                res = f_rx_response => (res, "f_rx_response"),
                complete => break,
            );
            match res {
                Ok(_) => debug!("Worker {} stopped", worker_name),
                Err(e) => error!("Worker {} stopped with error {}", worker_name, e),
            };
        }
    };

    if let Err(e) = ex.run_until(shutdown) {
        error!("Error happened in executor while shutdown: {}", e);
        return LoopState::Break;
    }

    LoopState::Continue
}
888
/// Resets every stream after the workers have been shut down.
///
/// For each stream in order:
/// - if it is in START, `stop()` is called;
/// - then, if it is in STOP or PREPARE, `release()` is called. This presumably catches streams
///   just stopped above, assuming `stop()` moves the state to STOP — TODO confirm.
/// - finally `just_reset` is set.
///
/// Errors from stop/release are logged, not propagated. While this runs, the tx and rx
/// response writers keep draining `tx_recv`/`rx_recv` into their queues. They run until the
/// local reset signal is raised at the end of `do_reset`.
///
/// # Errors
/// Returns the executor error if `run_until` fails.
fn reset_streams(
    ex: &Executor,
    streams: &Rc<AsyncRwLock<Vec<AsyncRwLock<StreamInfo>>>>,
    tx_queue: &Rc<AsyncRwLock<Queue>>,
    tx_recv: &mut mpsc::UnboundedReceiver<PcmResponse>,
    rx_queue: &Rc<AsyncRwLock<Queue>>,
    rx_recv: &mut mpsc::UnboundedReceiver<PcmResponse>,
) -> Result<(), AsyncError> {
    // Local signal, separate from the one used in run_worker_once.
    let reset_signal = (AsyncRwLock::new(false), Condvar::new());

    let do_reset = async {
        let streams = streams.read_lock().await;
        for stream_info in &*streams {
            let mut stream_info = stream_info.lock().await;
            if stream_info.state == VIRTIO_SND_R_PCM_START {
                if let Err(e) = stream_info.stop().await {
                    error!("Error on stop while resetting stream: {}", e);
                }
            }
            if stream_info.state == VIRTIO_SND_R_PCM_STOP
                || stream_info.state == VIRTIO_SND_R_PCM_PREPARE
            {
                if let Err(e) = stream_info.release().await {
                    error!("Error on release while resetting stream: {}", e);
                }
            }
            stream_info.just_reset = true;
        }

        // Lets the response writers below finish.
        notify_reset_signal(&reset_signal).await;
    };

    // Run these in a loop to ensure that they will survive until do_reset is finished
    // (each writer is restarted whenever it returns an error and stops once it returns Ok).
    let f_tx_response = async {
        while send_pcm_response_worker(tx_queue.clone(), tx_recv, Some(&reset_signal))
            .await
            .is_err()
        {}
    };

    let f_rx_response = async {
        while send_pcm_response_worker(rx_queue.clone(), rx_recv, Some(&reset_signal))
            .await
            .is_err()
        {}
    };

    let reset = async {
        join!(f_tx_response, f_rx_response, do_reset);
    };

    ex.run_until(reset)
}
942
#[cfg(test)]
// The tests build `Parameters` / `PCMDeviceParameters` with `..Default::default()` struct
// updates. NOTE(review): the lint is presumably allowed because, on some build configurations,
// those updates set no remaining fields. Confirm which configuration triggers it.
#[allow(clippy::needless_update)]
mod tests {
    use audio_streams::StreamEffect;

    use super::*;
    use crate::virtio::snd::parameters::PCMDeviceParameters;

    /// Checks the default state, config counts and pcm/chmap layout built by `VirtioSnd::new`.
    #[test]
    fn test_virtio_snd_new() {
        let params = Parameters {
            num_output_devices: 3,
            num_input_devices: 2,
            num_output_streams: 3,
            num_input_streams: 2,
            output_device_config: vec![PCMDeviceParameters {
                effects: Some(vec![StreamEffect::EchoCancellation]),
                ..PCMDeviceParameters::default()
            }],
            input_device_config: vec![PCMDeviceParameters {
                effects: Some(vec![StreamEffect::EchoCancellation]),
                ..PCMDeviceParameters::default()
            }],
            ..Default::default()
        };

        let (t0, _t1) = Tube::pair().expect("failed to create tube");
        let res = VirtioSnd::new(123, params, t0).unwrap();

        // Default values
        assert_eq!(res.snd_data.jack_info.len(), 0);
        assert_eq!(res.acked_features, 0);
        assert!(res.worker_thread.is_none());

        assert_eq!(res.avail_features, 123); // avail_features must be equal to the input
        assert_eq!(res.cfg.jacks.to_native(), 0);
        assert_eq!(res.cfg.streams.to_native(), 13); // (Output = 3*3) + (Input = 2*2)
        assert_eq!(res.cfg.chmaps.to_native(), 11); // (Output = 3*3) + (Input = 2*1)

        // Check snd_data.pcm_info
        assert_eq!(res.snd_data.pcm_info.len(), 13);
        // Check hda_fn_nid (PCM Device number)
        let expected_hda_fn_nid = [0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 0, 1, 1];
        for (i, pcm_info) in res.snd_data.pcm_info.iter().enumerate() {
            assert_eq!(
                pcm_info.hdr.hda_fn_nid.to_native(),
                expected_hda_fn_nid[i],
                "pcm_info index {i} incorrect hda_fn_nid"
            );
        }
        // The first 9 entries (3 output devices * 3 streams) must be OUTPUT, and the remaining 4
        // (2 input devices * 2 streams) must be INPUT. The length was checked to be 13 above.
        for (i, pcm_info) in res.snd_data.pcm_info.iter().enumerate() {
            let expected_direction = if i < 9 {
                VIRTIO_SND_D_OUTPUT
            } else {
                VIRTIO_SND_D_INPUT
            };
            assert_eq!(
                pcm_info.direction, expected_direction,
                "pcm_info index {i} incorrect direction"
            );
        }

        // Check snd_data.chmap_info
        assert_eq!(res.snd_data.chmap_info.len(), 11);
        let expected_hda_fn_nid = [0, 1, 2, 0, 1, 0, 1, 2, 0, 1, 2];
        // Check hda_fn_nid (PCM Device number)
        for (i, chmap_info) in res.snd_data.chmap_info.iter().enumerate() {
            assert_eq!(
                chmap_info.hdr.hda_fn_nid.to_native(),
                expected_hda_fn_nid[i],
                "chmap_info index {i} incorrect hda_fn_nid"
            );
        }
    }

    /// A per-device config list longer than the number of devices is truncated to that number.
    #[test]
    fn test_resize_parameters_pcm_device_config_truncate() {
        let params = Parameters {
            num_output_devices: 1,
            num_input_devices: 1,
            output_device_config: vec![PCMDeviceParameters::default(); 3],
            input_device_config: vec![PCMDeviceParameters::default(); 3],
            ..Parameters::default()
        };
        let params = resize_parameters_pcm_device_config(params);
        assert_eq!(params.output_device_config.len(), 1);
        assert_eq!(params.input_device_config.len(), 1);
    }

    /// A per-device config list shorter than the number of devices keeps its existing entries
    /// and is padded with `PCMDeviceParameters::default()` up to that number.
    #[test]
    fn test_resize_parameters_pcm_device_config_extend() {
        let params = Parameters {
            num_output_devices: 3,
            num_input_devices: 2,
            num_output_streams: 3,
            num_input_streams: 2,
            output_device_config: vec![PCMDeviceParameters {
                effects: Some(vec![StreamEffect::EchoCancellation]),
                ..PCMDeviceParameters::default()
            }],
            input_device_config: vec![PCMDeviceParameters {
                effects: Some(vec![StreamEffect::EchoCancellation]),
                ..PCMDeviceParameters::default()
            }],
            ..Default::default()
        };

        let params = resize_parameters_pcm_device_config(params);

        // Check output_device_config correctly extended
        assert_eq!(
            params.output_device_config,
            vec![
                PCMDeviceParameters {
                    // Keep from the parameters
                    effects: Some(vec![StreamEffect::EchoCancellation]),
                    ..PCMDeviceParameters::default()
                },
                PCMDeviceParameters::default(), // Extended with default
                PCMDeviceParameters::default(), // Extended with default
            ]
        );

        // Check input_device_config correctly extended
        assert_eq!(
            params.input_device_config,
            vec![
                PCMDeviceParameters {
                    // Keep from the parameters
                    effects: Some(vec![StreamEffect::EchoCancellation]),
                    ..PCMDeviceParameters::default()
                },
                PCMDeviceParameters::default(), // Extended with default
            ]
        );
    }
}