devices/virtio/snd/common_backend/
mod.rs

1// Copyright 2021 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5// virtio-sound spec: https://github.com/oasis-tcs/virtio-spec/blob/master/virtio-sound.tex
6
7use std::collections::BTreeMap;
8use std::io;
9use std::rc::Rc;
10use std::sync::Arc;
11
12use anyhow::anyhow;
13use anyhow::Context;
14use audio_streams::BoxError;
15use base::debug;
16use base::error;
17use base::warn;
18use base::AsRawDescriptor;
19use base::Descriptor;
20use base::Error as SysError;
21use base::Event;
22use base::RawDescriptor;
23use base::Tube;
24use base::WorkerThread;
25use cros_async::block_on;
26use cros_async::sync::Condvar;
27use cros_async::sync::RwLock as AsyncRwLock;
28use cros_async::AsyncError;
29use cros_async::AsyncTube;
30use cros_async::EventAsync;
31use cros_async::Executor;
32use futures::channel::mpsc;
33use futures::channel::oneshot;
34use futures::channel::oneshot::Canceled;
35use futures::future::FusedFuture;
36use futures::join;
37use futures::pin_mut;
38use futures::select;
39use futures::FutureExt;
40use serde::Deserialize;
41use serde::Serialize;
42use snapshot::AnySnapshot;
43use thiserror::Error as ThisError;
44use vm_memory::GuestMemory;
45use zerocopy::IntoBytes;
46
47use crate::virtio::async_utils;
48use crate::virtio::copy_config;
49use crate::virtio::device_constants::snd::virtio_snd_config;
50use crate::virtio::snd::common_backend::async_funcs::*;
51use crate::virtio::snd::common_backend::stream_info::StreamInfo;
52use crate::virtio::snd::common_backend::stream_info::StreamInfoBuilder;
53use crate::virtio::snd::common_backend::stream_info::StreamInfoSnapshot;
54use crate::virtio::snd::constants::*;
55use crate::virtio::snd::file_backend::create_file_stream_source_generators;
56use crate::virtio::snd::file_backend::Error as FileError;
57use crate::virtio::snd::layout::*;
58use crate::virtio::snd::null_backend::create_null_stream_source_generators;
59use crate::virtio::snd::parameters::Parameters;
60use crate::virtio::snd::parameters::StreamSourceBackend;
61use crate::virtio::snd::sys::create_stream_source_generators as sys_create_stream_source_generators;
62use crate::virtio::snd::sys::set_audio_thread_priority;
63use crate::virtio::snd::sys::SysAsyncStreamObjects;
64use crate::virtio::snd::sys::SysAudioStreamSourceGenerator;
65use crate::virtio::snd::sys::SysDirectionOutput;
66use crate::virtio::DescriptorChain;
67use crate::virtio::DeviceType;
68use crate::virtio::Interrupt;
69use crate::virtio::Queue;
70use crate::virtio::VirtioDevice;
71
72pub mod async_funcs;
73pub mod stream_info;
74
75// control + event + tx + rx queue
76pub const MAX_QUEUE_NUM: usize = 4;
77pub const MAX_VRING_LEN: u16 = 1024;
78
79#[derive(ThisError, Debug)]
80pub enum Error {
81    /// next_async failed.
82    #[error("Failed to read descriptor asynchronously: {0}")]
83    Async(AsyncError),
84    /// Creating stream failed.
85    #[error("Failed to create stream: {0}")]
86    CreateStream(BoxError),
87    /// Creating stream failed.
88    #[error("No stream source found.")]
89    EmptyStreamSource,
90    /// Creating kill event failed.
91    #[error("Failed to create kill event: {0}")]
92    CreateKillEvent(SysError),
93    /// Creating WaitContext failed.
94    #[error("Failed to create wait context: {0}")]
95    CreateWaitContext(SysError),
96    #[error("Failed to create file stream source generator")]
97    CreateFileStreamSourceGenerator(FileError),
98    /// Cloning kill event failed.
99    #[error("Failed to clone kill event: {0}")]
100    CloneKillEvent(SysError),
101    // Future error.
102    #[error("Unexpected error. Done was not triggered before dropped: {0}")]
103    DoneNotTriggered(Canceled),
104    /// Error reading message from queue.
105    #[error("Failed to read message: {0}")]
106    ReadMessage(io::Error),
107    /// Failed writing a response to a control message.
108    #[error("Failed to write message response: {0}")]
109    WriteResponse(io::Error),
110    // Mpsc read error.
111    #[error("Error in mpsc: {0}")]
112    MpscSend(futures::channel::mpsc::SendError),
113    // Oneshot send error.
114    #[error("Error in oneshot send")]
115    OneshotSend(()),
116    /// Failure in communicating with the host
117    #[error("Failed to send/receive to/from control tube")]
118    ControlTubeError(base::TubeError),
119    /// Stream not found.
120    #[error("stream id ({0}) < num_streams ({1})")]
121    StreamNotFound(usize, usize),
122    /// Fetch buffer error
123    #[error("Failed to get buffer from CRAS: {0}")]
124    FetchBuffer(BoxError),
125    /// Invalid buffer size
126    #[error("Invalid buffer size")]
127    InvalidBufferSize,
128    /// IoError
129    #[error("I/O failed: {0}")]
130    Io(io::Error),
131    /// Operation not supported.
132    #[error("Operation not supported")]
133    OperationNotSupported,
134    /// Writing to a buffer in the guest failed.
135    #[error("failed to write to buffer: {0}")]
136    WriteBuffer(io::Error),
137    // Invalid PCM worker state.
138    #[error("Invalid PCM worker state")]
139    InvalidPCMWorkerState,
140    // Invalid backend.
141    #[error("Backend is not implemented")]
142    InvalidBackend,
143    // Failed to generate StreamSource
144    #[error("Failed to generate stream source: {0}")]
145    GenerateStreamSource(BoxError),
146    // PCM worker unexpectedly quitted.
147    #[error("PCM worker quitted unexpectedly")]
148    PCMWorkerQuittedUnexpectedly,
149}
150
/// The direction-specific part of an open PCM stream.
pub enum DirectionalStream {
    /// Input stream: the period size and a reader for captured buffers.
    Input(
        usize, // `period_size` in `usize`
        Box<dyn CaptureBufferReader>,
    ),
    /// Output stream, using the platform-specific (`sys`) output representation.
    Output(SysDirectionOutput),
}
158
/// Status of a PCM worker, with explicit discriminants `Pause = 0`, `Running = 1`, `Quit = 2`.
#[derive(Clone, Copy, Eq, PartialEq)]
pub enum WorkerStatus {
    /// Discriminant 0.
    Pause = 0,
    /// Discriminant 1.
    Running = 1,
    /// Discriminant 2.
    Quit = 2,
}
165
166// Stores constant data
167#[derive(Clone, Serialize, Deserialize, PartialEq, Eq, Debug)]
168pub struct SndData {
169    pub(crate) jack_info: Vec<virtio_snd_jack_info>,
170    pub(crate) pcm_info: Vec<virtio_snd_pcm_info>,
171    pub(crate) chmap_info: Vec<virtio_snd_chmap_info>,
172}
173
174impl SndData {
175    pub fn pcm_info_len(&self) -> usize {
176        self.pcm_info.len()
177    }
178
179    pub fn pcm_info_iter(&self) -> std::slice::Iter<'_, virtio_snd_pcm_info> {
180        self.pcm_info.iter()
181    }
182}
183
184const SUPPORTED_FORMATS: u64 = 1 << VIRTIO_SND_PCM_FMT_U8
185    | 1 << VIRTIO_SND_PCM_FMT_S16
186    | 1 << VIRTIO_SND_PCM_FMT_S24
187    | 1 << VIRTIO_SND_PCM_FMT_S32;
188const SUPPORTED_FRAME_RATES: u64 = 1 << VIRTIO_SND_PCM_RATE_8000
189    | 1 << VIRTIO_SND_PCM_RATE_11025
190    | 1 << VIRTIO_SND_PCM_RATE_16000
191    | 1 << VIRTIO_SND_PCM_RATE_22050
192    | 1 << VIRTIO_SND_PCM_RATE_32000
193    | 1 << VIRTIO_SND_PCM_RATE_44100
194    | 1 << VIRTIO_SND_PCM_RATE_48000;
195
/// Response sent from a PCM worker to its PCM queue.
///
/// Carries the descriptor chain being completed and the status to report for it.
pub struct PcmResponse {
    /// Descriptor chain of the PCM request being answered.
    pub(crate) desc_chain: DescriptorChain,
    pub(crate) status: virtio_snd_pcm_status, // response to the pcm message
    pub(crate) done: Option<oneshot::Sender<()>>, // signalled when the pcm response is written to the queue
}
202
/// The virtio-snd device.
pub struct VirtioSnd {
    /// Control tube; moved into the worker thread by `activate` and handed back by
    /// `reset` / `virtio_sleep`.
    control_tube: Option<Tube>,
    /// Config space served by `read_config`.
    cfg: virtio_snd_config,
    /// Jack, PCM and channel-map descriptions.
    snd_data: SndData,
    /// One builder per PCM stream; cloned into the worker on each activation.
    stream_info_builders: Vec<StreamInfoBuilder>,
    /// Features offered to the guest.
    avail_features: u64,
    /// Features acknowledged by the guest (restricted to `avail_features`).
    acked_features: u64,
    /// Maximum size of each queue.
    queue_sizes: Box<[u16]>,
    /// Running worker thread, if the device is active.
    worker_thread: Option<WorkerThread<WorkerReturn>>,
    /// Descriptors reported through `VirtioDevice::keep_rds`.
    keep_rds: Vec<Descriptor>,
    /// Stream state saved by `virtio_sleep` or loaded by `virtio_restore`; consumed by the
    /// next `activate`.
    streams_state: Option<Vec<StreamInfoSnapshot>>,
    /// Sound card index passed to the stream builders and workers.
    card_index: usize,
}
216
/// Serialized form of a [`VirtioSnd`] used by `virtio_snapshot` / `virtio_restore`.
///
/// `avail_features`, `queue_sizes` and `snd_data` are only checked against the restoring
/// device; `acked_features` and `streams_state` are applied to it.
#[derive(Serialize, Deserialize)]
struct VirtioSndSnapshot {
    avail_features: u64,
    acked_features: u64,
    queue_sizes: Vec<u16>,
    streams_state: Option<Vec<StreamInfoSnapshot>>,
    snd_data: SndData,
}
225
226impl VirtioSnd {
227    pub fn new(
228        base_features: u64,
229        params: Parameters,
230        control_tube: Tube,
231    ) -> Result<VirtioSnd, Error> {
232        let params = resize_parameters_pcm_device_config(params);
233        let cfg = hardcoded_virtio_snd_config(&params);
234        let snd_data = hardcoded_snd_data(&params);
235        let avail_features = base_features;
236        let mut keep_rds: Vec<RawDescriptor> = Vec::new();
237        keep_rds.push(control_tube.as_raw_descriptor());
238
239        let stream_info_builders =
240            create_stream_info_builders(&params, &snd_data, &mut keep_rds, params.card_index)?;
241
242        Ok(VirtioSnd {
243            control_tube: Some(control_tube),
244            cfg,
245            snd_data,
246            stream_info_builders,
247            avail_features,
248            acked_features: 0,
249            queue_sizes: vec![MAX_VRING_LEN; MAX_QUEUE_NUM].into_boxed_slice(),
250            worker_thread: None,
251            keep_rds: keep_rds.iter().map(|rd| Descriptor(*rd)).collect(),
252            streams_state: None,
253            card_index: params.card_index,
254        })
255    }
256}
257
258fn create_stream_source_generators(
259    params: &Parameters,
260    snd_data: &SndData,
261    keep_rds: &mut Vec<RawDescriptor>,
262) -> Result<Vec<SysAudioStreamSourceGenerator>, Error> {
263    let generators = match params.backend {
264        StreamSourceBackend::NULL => create_null_stream_source_generators(snd_data),
265        StreamSourceBackend::FILE => {
266            create_file_stream_source_generators(params, snd_data, keep_rds)
267                .map_err(Error::CreateFileStreamSourceGenerator)?
268        }
269        StreamSourceBackend::Sys(backend) => {
270            sys_create_stream_source_generators(backend, params, snd_data)
271        }
272    };
273    Ok(generators)
274}
275
276/// Creates [`StreamInfoBuilder`]s by calling [`create_stream_source_generators()`] then zip
277/// them with [`crate::virtio::snd::parameters::PCMDeviceParameters`] from the params to set
278/// the parameters on each [`StreamInfoBuilder`] (e.g. effects).
279pub(crate) fn create_stream_info_builders(
280    params: &Parameters,
281    snd_data: &SndData,
282    keep_rds: &mut Vec<RawDescriptor>,
283    card_index: usize,
284) -> Result<Vec<StreamInfoBuilder>, Error> {
285    Ok(create_stream_source_generators(params, snd_data, keep_rds)?
286        .into_iter()
287        .map(Arc::new)
288        .zip(snd_data.pcm_info_iter())
289        .map(|(generator, pcm_info)| {
290            let device_params = params.get_device_params(pcm_info).unwrap_or_default();
291            StreamInfo::builder(generator, card_index)
292                .effects(device_params.effects.unwrap_or_default())
293        })
294        .collect())
295}
296
297// To be used with hardcoded_snd_data
298pub fn hardcoded_virtio_snd_config(params: &Parameters) -> virtio_snd_config {
299    virtio_snd_config {
300        jacks: 0.into(),
301        streams: params.get_total_streams().into(),
302        chmaps: (params.num_output_devices * 3 + params.num_input_devices).into(),
303        controls: 0.into(),
304    }
305}
306
307// To be used with hardcoded_virtio_snd_config
308pub fn hardcoded_snd_data(params: &Parameters) -> SndData {
309    let jack_info: Vec<virtio_snd_jack_info> = Vec::new();
310    let mut pcm_info: Vec<virtio_snd_pcm_info> = Vec::new();
311    let mut chmap_info: Vec<virtio_snd_chmap_info> = Vec::new();
312
313    for dev in 0..params.num_output_devices {
314        for _ in 0..params.num_output_streams {
315            pcm_info.push(virtio_snd_pcm_info {
316                hdr: virtio_snd_info {
317                    hda_fn_nid: dev.into(),
318                },
319                features: 0.into(), /* 1 << VIRTIO_SND_PCM_F_XXX */
320                formats: SUPPORTED_FORMATS.into(),
321                rates: SUPPORTED_FRAME_RATES.into(),
322                direction: VIRTIO_SND_D_OUTPUT,
323                channels_min: 1,
324                channels_max: 6,
325                padding: [0; 5],
326            });
327        }
328    }
329    for dev in 0..params.num_input_devices {
330        for _ in 0..params.num_input_streams {
331            pcm_info.push(virtio_snd_pcm_info {
332                hdr: virtio_snd_info {
333                    hda_fn_nid: dev.into(),
334                },
335                features: 0.into(), /* 1 << VIRTIO_SND_PCM_F_XXX */
336                formats: SUPPORTED_FORMATS.into(),
337                rates: SUPPORTED_FRAME_RATES.into(),
338                direction: VIRTIO_SND_D_INPUT,
339                channels_min: 1,
340                channels_max: 2,
341                padding: [0; 5],
342            });
343        }
344    }
345    // Use stereo channel map.
346    let mut positions = [VIRTIO_SND_CHMAP_NONE; VIRTIO_SND_CHMAP_MAX_SIZE];
347    positions[0] = VIRTIO_SND_CHMAP_FL;
348    positions[1] = VIRTIO_SND_CHMAP_FR;
349    for dev in 0..params.num_output_devices {
350        chmap_info.push(virtio_snd_chmap_info {
351            hdr: virtio_snd_info {
352                hda_fn_nid: dev.into(),
353            },
354            direction: VIRTIO_SND_D_OUTPUT,
355            channels: 2,
356            positions,
357        });
358    }
359    for dev in 0..params.num_input_devices {
360        chmap_info.push(virtio_snd_chmap_info {
361            hdr: virtio_snd_info {
362                hda_fn_nid: dev.into(),
363            },
364            direction: VIRTIO_SND_D_INPUT,
365            channels: 2,
366            positions,
367        });
368    }
369    positions[2] = VIRTIO_SND_CHMAP_RL;
370    positions[3] = VIRTIO_SND_CHMAP_RR;
371    for dev in 0..params.num_output_devices {
372        chmap_info.push(virtio_snd_chmap_info {
373            hdr: virtio_snd_info {
374                hda_fn_nid: dev.into(),
375            },
376            direction: VIRTIO_SND_D_OUTPUT,
377            channels: 4,
378            positions,
379        });
380    }
381    positions[2] = VIRTIO_SND_CHMAP_FC;
382    positions[3] = VIRTIO_SND_CHMAP_LFE;
383    positions[4] = VIRTIO_SND_CHMAP_RL;
384    positions[5] = VIRTIO_SND_CHMAP_RR;
385    for dev in 0..params.num_output_devices {
386        chmap_info.push(virtio_snd_chmap_info {
387            hdr: virtio_snd_info {
388                hda_fn_nid: dev.into(),
389            },
390            direction: VIRTIO_SND_D_OUTPUT,
391            channels: 6,
392            positions,
393        });
394    }
395
396    SndData {
397        jack_info,
398        pcm_info,
399        chmap_info,
400    }
401}
402
403fn resize_parameters_pcm_device_config(mut params: Parameters) -> Parameters {
404    if params.output_device_config.len() > params.num_output_devices as usize {
405        warn!("Truncating output device config due to length > number of output devices");
406    }
407    params
408        .output_device_config
409        .resize_with(params.num_output_devices as usize, Default::default);
410
411    if params.input_device_config.len() > params.num_input_devices as usize {
412        warn!("Truncating input device config due to length > number of input devices");
413    }
414    params
415        .input_device_config
416        .resize_with(params.num_input_devices as usize, Default::default);
417
418    params
419}
420
421impl VirtioDevice for VirtioSnd {
422    fn keep_rds(&self) -> Vec<RawDescriptor> {
423        self.keep_rds
424            .iter()
425            .map(|descr| descr.as_raw_descriptor())
426            .collect()
427    }
428
429    fn device_type(&self) -> DeviceType {
430        DeviceType::Sound
431    }
432
433    fn queue_max_sizes(&self) -> &[u16] {
434        &self.queue_sizes
435    }
436
437    fn features(&self) -> u64 {
438        self.avail_features
439    }
440
441    fn ack_features(&mut self, mut v: u64) {
442        // Check if the guest is ACK'ing a feature that we didn't claim to have.
443        let unrequested_features = v & !self.avail_features;
444        if unrequested_features != 0 {
445            warn!("virtio_fs got unknown feature ack: {:x}", v);
446
447            // Don't count these features as acked.
448            v &= !unrequested_features;
449        }
450        self.acked_features |= v;
451    }
452
453    fn read_config(&self, offset: u64, data: &mut [u8]) {
454        copy_config(data, 0, self.cfg.as_bytes(), offset)
455    }
456
457    fn activate(
458        &mut self,
459        _guest_mem: GuestMemory,
460        _interrupt: Interrupt,
461        queues: BTreeMap<usize, Queue>,
462    ) -> anyhow::Result<()> {
463        if queues.len() != self.queue_sizes.len() {
464            return Err(anyhow!(
465                "snd: expected {} queues, got {}",
466                self.queue_sizes.len(),
467                queues.len()
468            ));
469        }
470
471        let snd_data = self.snd_data.clone();
472        let stream_info_builders = self.stream_info_builders.to_vec();
473        let streams_state = self.streams_state.take();
474        let card_index = self.card_index;
475        let control_tube = self.control_tube.take().unwrap();
476        self.worker_thread = Some(WorkerThread::start("v_snd_common", move |kill_evt| {
477            let _thread_priority_handle = set_audio_thread_priority();
478            if let Err(e) = _thread_priority_handle {
479                warn!("Failed to set audio thread to real time: {}", e);
480            };
481            run_worker(
482                queues,
483                snd_data,
484                kill_evt,
485                stream_info_builders,
486                streams_state,
487                card_index,
488                control_tube,
489            )
490        }));
491
492        Ok(())
493    }
494
495    fn reset(&mut self) -> anyhow::Result<()> {
496        if let Some(worker_thread) = self.worker_thread.take() {
497            let worker = worker_thread.stop();
498            self.control_tube = Some(worker.control_tube);
499        }
500
501        Ok(())
502    }
503
504    fn virtio_sleep(&mut self) -> anyhow::Result<Option<BTreeMap<usize, Queue>>> {
505        if let Some(worker_thread) = self.worker_thread.take() {
506            let worker = worker_thread.stop();
507            self.control_tube = Some(worker.control_tube);
508            self.snd_data = worker.snd_data;
509            self.streams_state = Some(worker.streams_state);
510            return Ok(Some(BTreeMap::from_iter(
511                worker.queues.into_iter().enumerate(),
512            )));
513        }
514        Ok(None)
515    }
516
517    fn virtio_wake(
518        &mut self,
519        device_state: Option<(GuestMemory, Interrupt, BTreeMap<usize, Queue>)>,
520    ) -> anyhow::Result<()> {
521        match device_state {
522            None => Ok(()),
523            Some((mem, interrupt, queues)) => {
524                // TODO: activate is just what we want at the moment, but we should probably move
525                // it into a "start workers" function to make it obvious that it isn't strictly
526                // used for activate events.
527                self.activate(mem, interrupt, queues)?;
528                Ok(())
529            }
530        }
531    }
532
533    fn virtio_snapshot(&mut self) -> anyhow::Result<AnySnapshot> {
534        let streams_state = if let Some(states) = &self.streams_state {
535            let mut state_vec = Vec::new();
536            for state in states {
537                state_vec.push(state.clone());
538            }
539            Some(state_vec)
540        } else {
541            None
542        };
543        AnySnapshot::to_any(VirtioSndSnapshot {
544            avail_features: self.avail_features,
545            acked_features: self.acked_features,
546            queue_sizes: self.queue_sizes.to_vec(),
547            streams_state,
548            snd_data: self.snd_data.clone(),
549        })
550        .context("failed to Serialize Sound device")
551    }
552
553    fn virtio_restore(&mut self, data: AnySnapshot) -> anyhow::Result<()> {
554        let mut deser: VirtioSndSnapshot =
555            AnySnapshot::from_any(data).context("failed to Deserialize Sound device")?;
556        anyhow::ensure!(
557            deser.avail_features == self.avail_features,
558            "avail features doesn't match on restore: expected: {}, got: {}",
559            deser.avail_features,
560            self.avail_features
561        );
562        anyhow::ensure!(
563            deser.queue_sizes == self.queue_sizes.to_vec(),
564            "queue sizes doesn't match on restore: expected: {:?}, got: {:?}",
565            deser.queue_sizes,
566            self.queue_sizes.to_vec()
567        );
568        self.acked_features = deser.acked_features;
569        anyhow::ensure!(
570            deser.snd_data == self.snd_data,
571            "snd data doesn't match on restore: expected: {:?}, got: {:?}",
572            deser.snd_data,
573            self.snd_data
574        );
575        self.acked_features = deser.acked_features;
576        self.streams_state = deser.streams_state.take();
577        Ok(())
578    }
579}
580
/// Tells the main loop in `run_worker` whether to reset the streams and run the workers
/// again, or to stop.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum LoopState {
    /// Reset the streams and run the workers again.
    Continue,
    /// Stop running the workers and leave the main loop.
    Break,
}
586
587fn run_worker(
588    queues: BTreeMap<usize, Queue>,
589    snd_data: SndData,
590    kill_evt: Event,
591    stream_info_builders: Vec<StreamInfoBuilder>,
592    streams_state: Option<Vec<StreamInfoSnapshot>>,
593    card_index: usize,
594    control_tube: Tube,
595) -> WorkerReturn {
596    let ex = Executor::new().expect("Failed to create an executor");
597    let control_tube = AsyncTube::new(&ex, control_tube).expect("failed to create async snd tube");
598
599    if snd_data.pcm_info_len() != stream_info_builders.len() {
600        error!(
601            "snd: expected {} streams, got {}",
602            snd_data.pcm_info_len(),
603            stream_info_builders.len(),
604        );
605    }
606    let streams: Vec<AsyncRwLock<StreamInfo>> = stream_info_builders
607        .into_iter()
608        .map(StreamInfoBuilder::build)
609        .map(AsyncRwLock::new)
610        .collect();
611
612    let (tx_send, mut tx_recv) = mpsc::unbounded();
613    let (rx_send, mut rx_recv) = mpsc::unbounded();
614    let tx_send_clone = tx_send.clone();
615    let rx_send_clone = rx_send.clone();
616    let restore_task = ex.spawn_local(async move {
617        if let Some(states) = &streams_state {
618            let ex = Executor::new().expect("Failed to create an executor");
619            for (stream, state) in streams.iter().zip(states.iter()) {
620                stream.lock().await.restore(state);
621                if state.state == VIRTIO_SND_R_PCM_START || state.state == VIRTIO_SND_R_PCM_PREPARE
622                {
623                    stream
624                        .lock()
625                        .await
626                        .prepare(&ex, &tx_send_clone, &rx_send_clone)
627                        .await
628                        .expect("failed to prepare PCM");
629                }
630                if state.state == VIRTIO_SND_R_PCM_START {
631                    stream
632                        .lock()
633                        .await
634                        .start()
635                        .await
636                        .expect("failed to start PCM");
637                }
638            }
639        }
640        streams
641    });
642    let streams = ex
643        .run_until(restore_task)
644        .expect("failed to restore streams");
645    let streams = Rc::new(AsyncRwLock::new(streams));
646
647    let mut queues: Vec<(Queue, EventAsync)> = queues
648        .into_values()
649        .map(|q| {
650            let e = q.event().try_clone().expect("Failed to clone queue event");
651            (
652                q,
653                EventAsync::new(e, &ex).expect("Failed to create async event for queue"),
654            )
655        })
656        .collect();
657
658    let (ctrl_queue, mut ctrl_queue_evt) = queues.remove(0);
659    let ctrl_queue = Rc::new(AsyncRwLock::new(ctrl_queue));
660    let (_event_queue, _event_queue_evt) = queues.remove(0);
661    let (tx_queue, tx_queue_evt) = queues.remove(0);
662    let (rx_queue, rx_queue_evt) = queues.remove(0);
663
664    let tx_queue = Rc::new(AsyncRwLock::new(tx_queue));
665    let rx_queue = Rc::new(AsyncRwLock::new(rx_queue));
666
667    // Exit if the kill event is triggered.
668    let f_kill = async_utils::await_and_exit(&ex, kill_evt).fuse();
669
670    pin_mut!(f_kill);
671
672    loop {
673        if run_worker_once(
674            &ex,
675            &streams,
676            &snd_data,
677            &mut f_kill,
678            ctrl_queue.clone(),
679            &mut ctrl_queue_evt,
680            tx_queue.clone(),
681            &tx_queue_evt,
682            tx_send.clone(),
683            &mut tx_recv,
684            rx_queue.clone(),
685            &rx_queue_evt,
686            rx_send.clone(),
687            &mut rx_recv,
688            card_index,
689            &control_tube,
690        ) == LoopState::Break
691        {
692            break;
693        }
694
695        if let Err(e) = reset_streams(
696            &ex,
697            &streams,
698            &tx_queue,
699            &mut tx_recv,
700            &rx_queue,
701            &mut rx_recv,
702        ) {
703            error!("Error reset streams: {}", e);
704            break;
705        }
706    }
707    let streams_state_task = ex.spawn_local(async move {
708        let mut v = Vec::new();
709        for stream in streams.read_lock().await.iter() {
710            v.push(stream.read_lock().await.snapshot());
711        }
712        v
713    });
714    let streams_state = ex
715        .run_until(streams_state_task)
716        .expect("failed to save streams state");
717    let ctrl_queue = match Rc::try_unwrap(ctrl_queue) {
718        Ok(q) => q.into_inner(),
719        Err(_) => panic!("Too many refs to ctrl_queue"),
720    };
721    let tx_queue = match Rc::try_unwrap(tx_queue) {
722        Ok(q) => q.into_inner(),
723        Err(_) => panic!("Too many refs to tx_queue"),
724    };
725    let rx_queue = match Rc::try_unwrap(rx_queue) {
726        Ok(q) => q.into_inner(),
727        Err(_) => panic!("Too many refs to rx_queue"),
728    };
729    let queues = vec![ctrl_queue, _event_queue, tx_queue, rx_queue];
730
731    WorkerReturn {
732        control_tube: control_tube.into(),
733        queues,
734        snd_data,
735        streams_state,
736    }
737}
738
/// Everything the worker thread hands back to [`VirtioSnd`] when it stops.
struct WorkerReturn {
    /// The control tube, returned so the device can be activated again.
    control_tube: Tube,
    /// The queues in order: control, event, tx, rx.
    queues: Vec<Queue>,
    /// The card description the worker ran with.
    snd_data: SndData,
    /// Snapshot of every stream taken after the worker loop exited.
    streams_state: Vec<StreamInfoSnapshot>,
}
745
746async fn notify_reset_signal(reset_signal: &(AsyncRwLock<bool>, Condvar)) {
747    let (lock, cvar) = reset_signal;
748    *lock.lock().await = true;
749    cvar.notify_all();
750}
751
/// Runs all workers once and returns as soon as any one of them exits.
///
/// Returns [`LoopState::Break`] in two cases. The caller should then not run the workers
/// again and should exit its main loop.
/// * The kill future `f_kill` completes.
/// * The executor fails while waiting for the remaining workers to shut down.
///
/// Every other outcome leads to a reset. This covers any other worker exiting (with or
/// without an error) and an executor failure while the workers are running. The reset signal
/// is raised, the ctrl/tx/rx workers are awaited until they all finish, and
/// [`LoopState::Continue`] is returned. The caller can then continue its main loop by
/// resetting the streams and running the workers again.
///
/// NOTE(review): `f_host_ctrl` is not part of the shutdown wait below; it is only dropped
/// when this function returns.
fn run_worker_once(
    ex: &Executor,
    streams: &Rc<AsyncRwLock<Vec<AsyncRwLock<StreamInfo>>>>,
    snd_data: &SndData,
    mut f_kill: &mut (impl FusedFuture<Output = anyhow::Result<()>> + Unpin),
    ctrl_queue: Rc<AsyncRwLock<Queue>>,
    ctrl_queue_evt: &mut EventAsync,
    tx_queue: Rc<AsyncRwLock<Queue>>,
    tx_queue_evt: &EventAsync,
    tx_send: mpsc::UnboundedSender<PcmResponse>,
    tx_recv: &mut mpsc::UnboundedReceiver<PcmResponse>,
    rx_queue: Rc<AsyncRwLock<Queue>>,
    rx_queue_evt: &EventAsync,
    rx_send: mpsc::UnboundedSender<PcmResponse>,
    rx_recv: &mut mpsc::UnboundedReceiver<PcmResponse>,
    card_index: usize,
    control_tube: &AsyncTube,
) -> LoopState {
    // The ctrl-queue worker consumes `tx_send`/`rx_send`; the pcm queue workers get clones.
    let tx_send2 = tx_send.clone();
    let rx_send2 = rx_send.clone();

    // Flag + condvar passed to every worker below; set via `notify_reset_signal` when the
    // workers are asked to shut down for a reset.
    let reset_signal = (AsyncRwLock::new(false), Condvar::new());

    let f_host_ctrl = handle_ctrl_tube(streams, control_tube, Some(&reset_signal)).fuse();

    let f_ctrl = handle_ctrl_queue(
        ex,
        streams,
        snd_data,
        ctrl_queue,
        ctrl_queue_evt,
        tx_send,
        rx_send,
        card_index,
        Some(&reset_signal),
    )
    .fuse();

    // TODO(woodychow): Enable this when libcras sends jack connect/disconnect evts
    // let f_event = handle_event_queue(
    //     snd_state,
    //     event_queue,
    //     event_queue_evt,
    // );
    let f_tx = handle_pcm_queue(
        streams,
        tx_send2,
        tx_queue.clone(),
        tx_queue_evt,
        card_index,
        Some(&reset_signal),
    )
    .fuse();
    let f_tx_response = send_pcm_response_worker(tx_queue, tx_recv, Some(&reset_signal)).fuse();
    let f_rx = handle_pcm_queue(
        streams,
        rx_send2,
        rx_queue.clone(),
        rx_queue_evt,
        card_index,
        Some(&reset_signal),
    )
    .fuse();
    let f_rx_response = send_pcm_response_worker(rx_queue, rx_recv, Some(&reset_signal)).fuse();

    // Pinned here (not inside `done`) so the unfinished futures can be polled again by the
    // shutdown loop below.
    pin_mut!(
        f_host_ctrl,
        f_ctrl,
        f_tx,
        f_tx_response,
        f_rx,
        f_rx_response
    );

    // Wait for the first worker to finish; only `f_kill` ends the caller's loop.
    let done = async {
        select! {
            res = f_host_ctrl => (res.context("error in handling host control command"), LoopState::Continue),
            res = f_ctrl => (res.context("error in handling ctrl queue"), LoopState::Continue),
            res = f_tx => (res.context("error in handling tx queue"), LoopState::Continue),
            res = f_tx_response => (res.context("error in handling tx response"), LoopState::Continue),
            res = f_rx => (res.context("error in handling rx queue"), LoopState::Continue),
            res = f_rx_response => (res.context("error in handling rx response"), LoopState::Continue),

            // For following workers, do not continue the loop
            res = f_kill => (res.context("error in await_and_exit"), LoopState::Break),
        }
    };

    match ex.run_until(done) {
        Ok((res, loop_state)) => {
            if let Err(e) = res {
                error!("Error in worker: {:#}", e);
            }
            if loop_state == LoopState::Break {
                return LoopState::Break;
            }
        }
        Err(e) => {
            error!("Error happened in executor: {}", e);
        }
    }

    warn!("Shutting down all workers for reset procedure");
    block_on(notify_reset_signal(&reset_signal));

    // Drive each remaining (not yet completed) worker to completion; `complete` fires once
    // every fused future listed here has finished.
    let shutdown = async {
        loop {
            let (res, worker_name) = select!(
                res = f_ctrl => (res, "f_ctrl"),
                res = f_tx => (res, "f_tx"),
                res = f_tx_response => (res, "f_tx_response"),
                res = f_rx => (res, "f_rx"),
                res = f_rx_response => (res, "f_rx_response"),
                complete => break,
            );
            match res {
                Ok(_) => debug!("Worker {} stopped", worker_name),
                Err(e) => error!("Worker {} stopped with error {}", worker_name, e),
            };
        }
    };

    if let Err(e) = ex.run_until(shutdown) {
        error!("Error happened in executor while shutdown: {}", e);
        return LoopState::Break;
    }

    LoopState::Continue
}
889
/// Returns every stream to a released state after the workers have been shut down.
///
/// Each stream is handled in turn:
/// * started streams are stopped;
/// * stopped or prepared streams are released;
/// * every stream is marked `just_reset`.
///
/// Stop/release failures are logged, not returned. Meanwhile the tx/rx response workers keep
/// running against `tx_recv`/`rx_recv`, restarted whenever they return an error, until they
/// return `Ok`. After all streams are handled, `do_reset` raises the local reset signal.
/// NOTE(review): this presumably is what lets `send_pcm_response_worker` return `Ok`; confirm
/// in `async_funcs`.
///
/// # Errors
/// Returns an error only if the executor itself fails.
fn reset_streams(
    ex: &Executor,
    streams: &Rc<AsyncRwLock<Vec<AsyncRwLock<StreamInfo>>>>,
    tx_queue: &Rc<AsyncRwLock<Queue>>,
    tx_recv: &mut mpsc::UnboundedReceiver<PcmResponse>,
    rx_queue: &Rc<AsyncRwLock<Queue>>,
    rx_recv: &mut mpsc::UnboundedReceiver<PcmResponse>,
) -> Result<(), AsyncError> {
    let reset_signal = (AsyncRwLock::new(false), Condvar::new());

    let do_reset = async {
        let streams = streams.read_lock().await;
        for stream_info in &*streams {
            let mut stream_info = stream_info.lock().await;
            if stream_info.state == VIRTIO_SND_R_PCM_START {
                if let Err(e) = stream_info.stop().await {
                    error!("Error on stop while resetting stream: {}", e);
                }
            }
            // Checked after the stop above, so a just-stopped stream is released as well.
            if stream_info.state == VIRTIO_SND_R_PCM_STOP
                || stream_info.state == VIRTIO_SND_R_PCM_PREPARE
            {
                if let Err(e) = stream_info.release().await {
                    error!("Error on release while resetting stream: {}", e);
                }
            }
            stream_info.just_reset = true;
        }

        notify_reset_signal(&reset_signal).await;
    };

    // Run these in a loop to ensure that they will survive until do_reset is finished
    let f_tx_response = async {
        while send_pcm_response_worker(tx_queue.clone(), tx_recv, Some(&reset_signal))
            .await
            .is_err()
        {}
    };

    let f_rx_response = async {
        while send_pcm_response_worker(rx_queue.clone(), rx_recv, Some(&reset_signal))
            .await
            .is_err()
        {}
    };

    let reset = async {
        join!(f_tx_response, f_rx_response, do_reset);
    };

    ex.run_until(reset)
}
943
#[cfg(test)]
// The struct literals below keep a trailing `..Default::default()` even where it may list no
// remaining fields. Presumably this keeps them compiling when `Parameters` /
// `PCMDeviceParameters` gain fields or carry feature-gated ones (TODO confirm), so clippy's
// `needless_update` lint is silenced for this module.
#[allow(clippy::needless_update)]
mod tests {
    use audio_streams::StreamEffect;

    use super::*;
    use crate::virtio::snd::parameters::PCMDeviceParameters;

    /// Checks the config space and the PCM/chmap info tables built by `VirtioSnd::new` for
    /// a layout of 3 output devices x 3 streams and 2 input devices x 2 streams.
    #[test]
    fn test_virtio_snd_new() {
        let params = Parameters {
            num_output_devices: 3,
            num_input_devices: 2,
            num_output_streams: 3,
            num_input_streams: 2,
            output_device_config: vec![PCMDeviceParameters {
                effects: Some(vec![StreamEffect::EchoCancellation]),
                ..PCMDeviceParameters::default()
            }],
            input_device_config: vec![PCMDeviceParameters {
                effects: Some(vec![StreamEffect::EchoCancellation]),
                ..PCMDeviceParameters::default()
            }],
            ..Default::default()
        };

        let (t0, _t1) = Tube::pair().expect("failed to create tube");
        let res = VirtioSnd::new(123, params, t0).unwrap();

        // Default values
        assert_eq!(res.snd_data.jack_info.len(), 0);
        assert_eq!(res.acked_features, 0);
        assert!(res.worker_thread.is_none());

        assert_eq!(res.avail_features, 123); // avail_features must be equal to the input
        assert_eq!(res.cfg.jacks.to_native(), 0);
        assert_eq!(res.cfg.streams.to_native(), 13); // (Output = 3*3) + (Input = 2*2)
        assert_eq!(res.cfg.chmaps.to_native(), 11); // (Output = 3*3) + (Input = 2*1)

        // Check snd_data.pcm_info
        assert_eq!(res.snd_data.pcm_info.len(), 13);
        // Check hda_fn_nid (PCM Device number)
        let expected_hda_fn_nid = [0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 0, 1, 1];
        for (i, pcm_info) in res.snd_data.pcm_info.iter().enumerate() {
            assert_eq!(
                pcm_info.hdr.hda_fn_nid.to_native(),
                expected_hda_fn_nid[i],
                "pcm_info index {i} incorrect hda_fn_nid"
            );
        }
        // The first 9 PCM streams must be OUTPUT and the remaining 4 must be INPUT. The
        // length is asserted to be 13 above, so this covers exactly indices 0..13.
        for (i, pcm_info) in res.snd_data.pcm_info.iter().enumerate() {
            let expected_direction = if i < 9 {
                VIRTIO_SND_D_OUTPUT
            } else {
                VIRTIO_SND_D_INPUT
            };
            assert_eq!(
                pcm_info.direction, expected_direction,
                "pcm_info index {i} incorrect direction"
            );
        }

        // Check snd_data.chmap_info
        assert_eq!(res.snd_data.chmap_info.len(), 11);
        // Check hda_fn_nid (PCM Device number)
        let expected_hda_fn_nid = [0, 1, 2, 0, 1, 0, 1, 2, 0, 1, 2];
        for (i, chmap_info) in res.snd_data.chmap_info.iter().enumerate() {
            assert_eq!(
                chmap_info.hdr.hda_fn_nid.to_native(),
                expected_hda_fn_nid[i],
                "chmap_info index {i} incorrect hda_fn_nid"
            );
        }
    }

    /// A per-device config list longer than the device count is truncated.
    #[test]
    fn test_resize_parameters_pcm_device_config_truncate() {
        // If pcm_device_config is larger than number of devices, it will be truncated
        let params = Parameters {
            num_output_devices: 1,
            num_input_devices: 1,
            output_device_config: vec![PCMDeviceParameters::default(); 3],
            input_device_config: vec![PCMDeviceParameters::default(); 3],
            ..Parameters::default()
        };
        let params = resize_parameters_pcm_device_config(params);
        assert_eq!(params.output_device_config.len(), 1);
        assert_eq!(params.input_device_config.len(), 1);
    }

    /// A per-device config list shorter than the device count keeps the given entries and
    /// is padded with `PCMDeviceParameters::default()`.
    #[test]
    fn test_resize_parameters_pcm_device_config_extend() {
        let params = Parameters {
            num_output_devices: 3,
            num_input_devices: 2,
            num_output_streams: 3,
            num_input_streams: 2,
            output_device_config: vec![PCMDeviceParameters {
                effects: Some(vec![StreamEffect::EchoCancellation]),
                ..PCMDeviceParameters::default()
            }],
            input_device_config: vec![PCMDeviceParameters {
                effects: Some(vec![StreamEffect::EchoCancellation]),
                ..PCMDeviceParameters::default()
            }],
            ..Default::default()
        };

        let params = resize_parameters_pcm_device_config(params);

        // Check output_device_config correctly extended
        assert_eq!(
            params.output_device_config,
            vec![
                PCMDeviceParameters {
                    // Keep from the parameters
                    effects: Some(vec![StreamEffect::EchoCancellation]),
                    ..PCMDeviceParameters::default()
                },
                PCMDeviceParameters::default(), // Extended with default
                PCMDeviceParameters::default(), // Extended with default
            ]
        );

        // Check input_device_config correctly extended
        assert_eq!(
            params.input_device_config,
            vec![
                PCMDeviceParameters {
                    // Keep from the parameters
                    effects: Some(vec![StreamEffect::EchoCancellation]),
                    ..PCMDeviceParameters::default()
                },
                PCMDeviceParameters::default(), // Extended with default
            ]
        );
    }
}