devices/virtio/snd/common_backend/
async_funcs.rs

1// Copyright 2021 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use std::fmt;
6use std::io;
7use std::io::Read;
8use std::io::Write;
9use std::rc::Rc;
10use std::sync::atomic::AtomicBool;
11use std::sync::atomic::Ordering;
12use std::time::Duration;
13
14use async_trait::async_trait;
15use audio_streams::capture::AsyncCaptureBuffer;
16use audio_streams::AsyncPlaybackBuffer;
17use audio_streams::BoxError;
18use base::debug;
19use base::error;
20use base::info;
21use cros_async::sync::Condvar;
22use cros_async::sync::RwLock as AsyncRwLock;
23use cros_async::AsyncTube;
24use cros_async::EventAsync;
25use cros_async::Executor;
26use cros_async::TimerAsync;
27use futures::channel::mpsc;
28use futures::channel::oneshot;
29use futures::pin_mut;
30use futures::select;
31use futures::FutureExt;
32use futures::SinkExt;
33use futures::StreamExt;
34use thiserror::Error as ThisError;
35use vm_control::SndControlCommand;
36use vm_control::VmResponse;
37use zerocopy::IntoBytes;
38
39use super::Error;
40use super::SndData;
41use super::WorkerStatus;
42use crate::virtio::snd::common::*;
43use crate::virtio::snd::common_backend::stream_info::SetParams;
44use crate::virtio::snd::common_backend::stream_info::StreamInfo;
45use crate::virtio::snd::common_backend::DirectionalStream;
46use crate::virtio::snd::common_backend::PcmResponse;
47use crate::virtio::snd::constants::*;
48use crate::virtio::snd::layout::*;
49use crate::virtio::DescriptorChain;
50use crate::virtio::Queue;
51use crate::virtio::Reader;
52use crate::virtio::Writer;
53
54/// Trait to wrap system specific helpers for reading from the start point capture buffer.
55#[async_trait(?Send)]
56pub trait CaptureBufferReader {
57    async fn get_next_capture_period(
58        &mut self,
59        ex: &Executor,
60    ) -> Result<AsyncCaptureBuffer, BoxError>;
61}
62
63/// Trait to wrap system specific helpers for writing to endpoint playback buffers.
64#[async_trait(?Send)]
65pub trait PlaybackBufferWriter {
66    fn new(guest_period_bytes: usize) -> Self
67    where
68        Self: Sized;
69
70    /// Returns the period of the endpoint device.
71    fn endpoint_period_bytes(&self) -> usize;
72
73    /// Read audio samples from the tx virtqueue.
74    fn copy_to_buffer(
75        &mut self,
76        dst_buf: &mut AsyncPlaybackBuffer<'_>,
77        reader: &mut Reader,
78    ) -> Result<usize, Error> {
79        dst_buf.copy_from(reader).map_err(Error::Io)
80    }
81}
82
83#[derive(Debug)]
84enum VirtioSndPcmCmd {
85    SetParams { set_params: SetParams },
86    Prepare,
87    Start,
88    Stop,
89    Release,
90}
91
92impl fmt::Display for VirtioSndPcmCmd {
93    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
94        let cmd_code = match self {
95            VirtioSndPcmCmd::SetParams { set_params: _ } => VIRTIO_SND_R_PCM_SET_PARAMS,
96            VirtioSndPcmCmd::Prepare => VIRTIO_SND_R_PCM_PREPARE,
97            VirtioSndPcmCmd::Start => VIRTIO_SND_R_PCM_START,
98            VirtioSndPcmCmd::Stop => VIRTIO_SND_R_PCM_STOP,
99            VirtioSndPcmCmd::Release => VIRTIO_SND_R_PCM_RELEASE,
100        };
101        f.write_str(get_virtio_snd_r_pcm_cmd_name(cmd_code))
102    }
103}
104
105#[derive(ThisError, Debug)]
106enum VirtioSndPcmCmdError {
107    #[error("SetParams requires additional parameters")]
108    SetParams,
109    #[error("Invalid virtio snd command code")]
110    InvalidCode,
111}
112
113impl TryFrom<u32> for VirtioSndPcmCmd {
114    type Error = VirtioSndPcmCmdError;
115
116    fn try_from(code: u32) -> Result<Self, Self::Error> {
117        match code {
118            VIRTIO_SND_R_PCM_PREPARE => Ok(VirtioSndPcmCmd::Prepare),
119            VIRTIO_SND_R_PCM_START => Ok(VirtioSndPcmCmd::Start),
120            VIRTIO_SND_R_PCM_STOP => Ok(VirtioSndPcmCmd::Stop),
121            VIRTIO_SND_R_PCM_RELEASE => Ok(VirtioSndPcmCmd::Release),
122            VIRTIO_SND_R_PCM_SET_PARAMS => Err(VirtioSndPcmCmdError::SetParams),
123            _ => Err(VirtioSndPcmCmdError::InvalidCode),
124        }
125    }
126}
127
128impl VirtioSndPcmCmd {
129    fn with_set_params_and_direction(
130        set_params: virtio_snd_pcm_set_params,
131        dir: u8,
132    ) -> VirtioSndPcmCmd {
133        let buffer_bytes: u32 = set_params.buffer_bytes.into();
134        let period_bytes: u32 = set_params.period_bytes.into();
135        VirtioSndPcmCmd::SetParams {
136            set_params: SetParams {
137                channels: set_params.channels,
138                format: from_virtio_sample_format(set_params.format).unwrap(),
139                frame_rate: from_virtio_frame_rate(set_params.rate).unwrap(),
140                buffer_bytes: buffer_bytes as usize,
141                period_bytes: period_bytes as usize,
142                dir,
143            },
144        }
145    }
146}
147
148// Returns true if the operation is successful. Returns error if there is
149// a runtime/internal error
150async fn process_pcm_ctrl(
151    ex: &Executor,
152    tx_send: &mpsc::UnboundedSender<PcmResponse>,
153    rx_send: &mpsc::UnboundedSender<PcmResponse>,
154    streams: &Rc<AsyncRwLock<Vec<AsyncRwLock<StreamInfo>>>>,
155    cmd: VirtioSndPcmCmd,
156    writer: &mut Writer,
157    stream_id: usize,
158    card_index: usize,
159) -> Result<(), Error> {
160    let streams = streams.read_lock().await;
161    let mut stream = match streams.get(stream_id) {
162        Some(stream_info) => stream_info.lock().await,
163        None => {
164            error!(
165                "[Card {}] Stream id={} not found for {}. Error code: VIRTIO_SND_S_BAD_MSG",
166                card_index, stream_id, cmd
167            );
168            return writer
169                .write_obj(VIRTIO_SND_S_BAD_MSG)
170                .map_err(Error::WriteResponse);
171        }
172    };
173
174    debug!("[Card {}] {} for stream id={}", card_index, cmd, stream_id);
175
176    let result = match cmd {
177        VirtioSndPcmCmd::SetParams { set_params } => {
178            let result = stream.set_params(set_params).await;
179            if result.is_ok() {
180                debug!(
181                    "[Card {}] VIRTIO_SND_R_PCM_SET_PARAMS for stream id={}. Stream info: {:#?}",
182                    card_index, stream_id, *stream
183                );
184            }
185            result
186        }
187        VirtioSndPcmCmd::Prepare => stream.prepare(ex, tx_send, rx_send).await,
188        VirtioSndPcmCmd::Start => stream.start().await,
189        VirtioSndPcmCmd::Stop => stream.stop().await,
190        VirtioSndPcmCmd::Release => stream.release().await,
191    };
192    match result {
193        Ok(_) => writer
194            .write_obj(VIRTIO_SND_S_OK)
195            .map_err(Error::WriteResponse),
196        Err(Error::OperationNotSupported) => {
197            error!(
198                "[Card {}] {} for stream id={} failed. Error code: VIRTIO_SND_S_NOT_SUPP.",
199                card_index, cmd, stream_id
200            );
201
202            writer
203                .write_obj(VIRTIO_SND_S_NOT_SUPP)
204                .map_err(Error::WriteResponse)
205        }
206        Err(e) => {
207            // Runtime/internal error would be more appropriate, but there's
208            // no such error type
209            error!(
210                "[Card {}] {} for stream id={} failed. Error code: VIRTIO_SND_S_IO_ERR. Actual error: {}",
211                card_index, cmd, stream_id, e
212            );
213            writer
214                .write_obj(VIRTIO_SND_S_IO_ERR)
215                .map_err(Error::WriteResponse)
216        }
217    }
218}
219
220async fn write_data(
221    mut dst_buf: AsyncPlaybackBuffer<'_>,
222    reader: Option<&mut Reader>,
223    buffer_writer: &mut Box<dyn PlaybackBufferWriter>,
224) -> Result<u32, Error> {
225    let transferred = match reader {
226        Some(reader) => buffer_writer.copy_to_buffer(&mut dst_buf, reader)?,
227        None => dst_buf
228            .copy_from(&mut io::repeat(0).take(buffer_writer.endpoint_period_bytes() as u64))
229            .map_err(Error::Io)?,
230    };
231
232    if transferred != buffer_writer.endpoint_period_bytes() {
233        error!(
234            "Bytes written {} != period_bytes {}",
235            transferred,
236            buffer_writer.endpoint_period_bytes()
237        );
238        Err(Error::InvalidBufferSize)
239    } else {
240        dst_buf.commit().await;
241        Ok(dst_buf.latency_bytes())
242    }
243}
244
245async fn read_data(
246    mut src_buf: AsyncCaptureBuffer<'_>,
247    writer: Option<&mut Writer>,
248    period_bytes: usize,
249) -> Result<u32, Error> {
250    let transferred = match writer {
251        Some(writer) => src_buf.copy_to(writer),
252        None => src_buf.copy_to(&mut io::sink()),
253    }
254    .map_err(Error::Io)?;
255    if transferred != period_bytes {
256        error!(
257            "Bytes written {} != period_bytes {}",
258            transferred, period_bytes
259        );
260        Err(Error::InvalidBufferSize)
261    } else {
262        src_buf.commit().await;
263        Ok(src_buf.latency_bytes())
264    }
265}
266
267impl From<Result<u32, Error>> for virtio_snd_pcm_status {
268    fn from(res: Result<u32, Error>) -> Self {
269        match res {
270            Ok(latency_bytes) => virtio_snd_pcm_status::new(StatusCode::OK, latency_bytes),
271            Err(e) => {
272                error!("PCM I/O message failed: {}", e);
273                virtio_snd_pcm_status::new(StatusCode::IoErr, 0)
274            }
275        }
276    }
277}
278
279// Drain all DescriptorChain in desc_receiver during WorkerStatus::Quit process.
280async fn drain_desc_receiver(
281    desc_receiver: &mut mpsc::UnboundedReceiver<DescriptorChain>,
282    sender: &mut mpsc::UnboundedSender<PcmResponse>,
283) -> Result<(), Error> {
284    let mut o_desc_chain = desc_receiver.next().await;
285    while let Some(desc_chain) = o_desc_chain {
286        // From the virtio-snd spec:
287        // The device MUST complete all pending I/O messages for the specified stream ID.
288        let status = virtio_snd_pcm_status::new(StatusCode::OK, 0);
289        // Fetch next DescriptorChain to see if the current one is the last one.
290        o_desc_chain = desc_receiver.next().await;
291        let (done, future) = if o_desc_chain.is_none() {
292            let (done, future) = oneshot::channel();
293            (Some(done), Some(future))
294        } else {
295            (None, None)
296        };
297        sender
298            .send(PcmResponse {
299                desc_chain,
300                status,
301                done,
302            })
303            .await
304            .map_err(Error::MpscSend)?;
305
306        if let Some(f) = future {
307            // From the virtio-snd spec:
308            // The device MUST NOT complete the control request (VIRTIO_SND_R_PCM_RELEASE)
309            // while there are pending I/O messages for the specified stream ID.
310            f.await.map_err(Error::DoneNotTriggered)?;
311        };
312    }
313    Ok(())
314}
315
316/// Start a pcm worker that receives descriptors containing PCM frames (audio data) from the tx/rx
317/// queue, and forward them to CRAS. One pcm worker per stream.
318///
319/// This worker is started when VIRTIO_SND_R_PCM_PREPARE is called, and returned before
320/// VIRTIO_SND_R_PCM_RELEASE is completed for the stream.
321pub async fn start_pcm_worker(
322    ex: Executor,
323    dstream: DirectionalStream,
324    mut desc_receiver: mpsc::UnboundedReceiver<DescriptorChain>,
325    status_mutex: Rc<AsyncRwLock<WorkerStatus>>,
326    mut sender: mpsc::UnboundedSender<PcmResponse>,
327    period_dur: Duration,
328    card_index: usize,
329    muted: Rc<AtomicBool>,
330    release_signal: Rc<(AsyncRwLock<bool>, Condvar)>,
331) -> Result<(), Error> {
332    let res = pcm_worker_loop(
333        ex,
334        dstream,
335        &mut desc_receiver,
336        &status_mutex,
337        &mut sender,
338        period_dur,
339        card_index,
340        muted,
341        release_signal,
342    )
343    .await;
344    *status_mutex.lock().await = WorkerStatus::Quit;
345    if res.is_err() {
346        error!(
347            "[Card {}] pcm_worker error: {:#?}. Draining desc_receiver",
348            card_index,
349            res.as_ref().err()
350        );
351        // On error, guaranteed that desc_receiver has not been drained, so drain it here.
352        // Note that drain blocks until the stream is release.
353        drain_desc_receiver(&mut desc_receiver, &mut sender).await?;
354    }
355    res
356}
357
/// Main loop of a pcm worker.
///
/// Output streams: each iteration waits for the next playback buffer from the backend and fills
/// it with the next guest descriptor chain, or with silence when paused, muted, or when no chain
/// is ready (underrun).
///
/// Input streams: each iteration waits for the next capture period and copies it into the next
/// guest descriptor chain, or discards it when paused, muted, or when no chain is ready (overrun).
///
/// Processed chains are sent to `sender` with their status. Returns `Ok(())` after the release
/// signal fires or the status becomes `WorkerStatus::Quit`; in both cases `desc_receiver` has
/// been drained. Returns `Err(Error::InvalidPCMWorkerState)` if the channel is closed while the
/// status is still `Running`. Other errors propagate from buffer fetching, the silence/discard
/// paths of `write_data`/`read_data`, draining, or sending responses.
async fn pcm_worker_loop(
    ex: Executor,
    dstream: DirectionalStream,
    desc_receiver: &mut mpsc::UnboundedReceiver<DescriptorChain>,
    status_mutex: &Rc<AsyncRwLock<WorkerStatus>>,
    sender: &mut mpsc::UnboundedSender<PcmResponse>,
    period_dur: Duration,
    card_index: usize,
    muted: Rc<AtomicBool>,
    release_signal: Rc<(AsyncRwLock<bool>, Condvar)>,
) -> Result<(), Error> {
    // Completes once release_signal is set and a grace period of two periods has elapsed. It is
    // fused and pinned once so that every loop iteration below can poll the same future.
    let on_release = async {
        await_reset_signal(Some(&*release_signal)).await;
        // After receiving release signal, wait for up to 2 periods,
        // giving it a chance to respond to the last buffer.
        if let Err(e) = TimerAsync::sleep(&ex, period_dur * 2).await {
            error!(
                "[Card {}] Error on sleep after receiving reset signal: {}",
                card_index, e
            )
        }
    }
    .fuse();
    pin_mut!(on_release);

    match dstream {
        DirectionalStream::Output(mut sys_direction_output) => loop {
            // On Windows the playback stream and buffer writer sit behind async locks that are
            // taken for each iteration; on Android/Linux they are borrowed directly.
            #[cfg(windows)]
            let (mut stream, mut buffer_writer_lock) = (
                sys_direction_output
                    .async_playback_buffer_stream
                    .lock()
                    .await,
                sys_direction_output.buffer_writer.lock().await,
            );
            #[cfg(windows)]
            let buffer_writer = &mut buffer_writer_lock;
            #[cfg(any(target_os = "android", target_os = "linux"))]
            let (stream, buffer_writer) = (
                &mut sys_direction_output.async_playback_buffer_stream,
                &mut sys_direction_output.buffer_writer,
            );

            let next_buf = stream.next_playback_buffer(&ex).fuse();
            pin_mut!(next_buf);

            // Race the release signal against the next backend buffer. During the grace period
            // of on_release, buffers that become ready can still win and are processed normally.
            let dst_buf = select! {
                _ = on_release => {
                    drain_desc_receiver(desc_receiver, sender).await?;
                    break Ok(());
                },
                buf = next_buf => buf.map_err(Error::FetchBuffer)?,
            };
            // The status lock is held for the rest of this iteration, including the data copy
            // and the response send below.
            let worker_status = status_mutex.lock().await;
            match *worker_status {
                WorkerStatus::Quit => {
                    // Complete all pending chains first, then feed the backend one last silent
                    // period before exiting.
                    drain_desc_receiver(desc_receiver, sender).await?;
                    if let Err(e) = write_data(dst_buf, None, buffer_writer).await {
                        error!(
                            "[Card {}] Error on write_data after worker quit: {}",
                            card_index, e
                        )
                    }
                    break Ok(());
                }
                WorkerStatus::Pause => {
                    // Paused: keep the backend fed with silence.
                    write_data(dst_buf, None, buffer_writer).await?;
                }
                // try_next does not wait: Err means no chain is queued right now, Ok(None)
                // means the channel has been closed.
                WorkerStatus::Running => match desc_receiver.try_next() {
                    Err(e) => {
                        error!(
                            "[Card {}] Underrun. No new DescriptorChain while running: {}",
                            card_index, e
                        );
                        write_data(dst_buf, None, buffer_writer).await?;
                    }
                    Ok(None) => {
                        error!("[Card {}] Unreachable. status should be Quit when the channel is closed", card_index);
                        write_data(dst_buf, None, buffer_writer).await?;
                        return Err(Error::InvalidPCMWorkerState);
                    }
                    Ok(Some(mut desc_chain)) => {
                        // When muted, the guest data is ignored and silence is played instead.
                        let reader = if muted.load(Ordering::Relaxed) {
                            None
                        } else {
                            // stream_id was already read in handle_pcm_queue
                            Some(&mut desc_chain.reader)
                        };
                        // A failed transfer is reported to the guest as an IoErr status rather
                        // than ending the worker.
                        let status = write_data(dst_buf, reader, buffer_writer).await.into();
                        sender
                            .send(PcmResponse {
                                desc_chain,
                                status,
                                done: None,
                            })
                            .await
                            .map_err(Error::MpscSend)?;
                    }
                },
            }
        },
        DirectionalStream::Input(period_bytes, mut buffer_reader) => loop {
            let next_buf = buffer_reader.get_next_capture_period(&ex).fuse();
            pin_mut!(next_buf);

            // Race the release signal against the next captured period (see the output branch).
            let src_buf = select! {
                _ = on_release => {
                    drain_desc_receiver(desc_receiver, sender).await?;
                    break Ok(());
                },
                buf = next_buf => buf.map_err(Error::FetchBuffer)?,
            };

            // Held for the rest of this iteration, as in the output branch.
            let worker_status = status_mutex.lock().await;
            match *worker_status {
                WorkerStatus::Quit => {
                    // Complete all pending chains, then consume (discard) this last period.
                    drain_desc_receiver(desc_receiver, sender).await?;
                    if let Err(e) = read_data(src_buf, None, period_bytes).await {
                        error!(
                            "[Card {}] Error on read_data after worker quit: {}",
                            card_index, e
                        )
                    }
                    break Ok(());
                }
                WorkerStatus::Pause => {
                    // Paused: keep consuming captured periods, discarding the data.
                    read_data(src_buf, None, period_bytes).await?;
                }
                WorkerStatus::Running => match desc_receiver.try_next() {
                    Err(e) => {
                        error!(
                            "[Card {}] Overrun. No new DescriptorChain while running: {}",
                            card_index, e
                        );
                        read_data(src_buf, None, period_bytes).await?;
                    }
                    Ok(None) => {
                        error!("[Card {}] Unreachable. status should be Quit when the channel is closed", card_index);
                        read_data(src_buf, None, period_bytes).await?;
                        return Err(Error::InvalidPCMWorkerState);
                    }
                    Ok(Some(mut desc_chain)) => {
                        // When muted, captured data is discarded instead of being copied into
                        // the guest buffer.
                        let writer = if muted.load(Ordering::Relaxed) {
                            None
                        } else {
                            Some(&mut desc_chain.writer)
                        };
                        let status = read_data(src_buf, writer, period_bytes).await.into();
                        sender
                            .send(PcmResponse {
                                desc_chain,
                                status,
                                done: None,
                            })
                            .await
                            .map_err(Error::MpscSend)?;
                    }
                },
            }
        },
    }
}
520
521// Defer pcm message response to the pcm response worker
522async fn defer_pcm_response_to_worker(
523    desc_chain: DescriptorChain,
524    status: virtio_snd_pcm_status,
525    response_sender: &mut mpsc::UnboundedSender<PcmResponse>,
526) -> Result<(), Error> {
527    response_sender
528        .send(PcmResponse {
529            desc_chain,
530            status,
531            done: None,
532        })
533        .await
534        .map_err(Error::MpscSend)
535}
536
537fn send_pcm_response(
538    mut desc_chain: DescriptorChain,
539    queue: &mut Queue,
540    status: virtio_snd_pcm_status,
541) -> Result<(), Error> {
542    let writer = &mut desc_chain.writer;
543
544    // For rx queue only. Fast forward the unused audio data buffer.
545    if writer.available_bytes() > std::mem::size_of::<virtio_snd_pcm_status>() {
546        writer
547            .consume_bytes(writer.available_bytes() - std::mem::size_of::<virtio_snd_pcm_status>());
548    }
549    writer.write_obj(status).map_err(Error::WriteResponse)?;
550    queue.add_used(desc_chain);
551    queue.trigger_interrupt();
552    Ok(())
553}
554
555// Await until reset_signal has been released
556async fn await_reset_signal(reset_signal_option: Option<&(AsyncRwLock<bool>, Condvar)>) {
557    match reset_signal_option {
558        Some((lock, cvar)) => {
559            let mut reset = lock.lock().await;
560            while !*reset {
561                reset = cvar.wait(reset).await;
562            }
563        }
564        None => futures::future::pending().await,
565    };
566}
567
568pub async fn send_pcm_response_worker(
569    queue: Rc<AsyncRwLock<Queue>>,
570    recv: &mut mpsc::UnboundedReceiver<PcmResponse>,
571    reset_signal: Option<&(AsyncRwLock<bool>, Condvar)>,
572) -> Result<(), Error> {
573    let on_reset = await_reset_signal(reset_signal).fuse();
574    pin_mut!(on_reset);
575
576    loop {
577        let next_async = recv.next().fuse();
578        pin_mut!(next_async);
579
580        let res = select! {
581            _ = on_reset => break,
582            res = next_async => res,
583        };
584
585        if let Some(r) = res {
586            send_pcm_response(r.desc_chain, &mut *queue.lock().await, r.status)?;
587
588            // Resume pcm_worker
589            if let Some(done) = r.done {
590                done.send(()).map_err(Error::OneshotSend)?;
591            }
592        } else {
593            debug!("PcmResponse channel is closed.");
594            break;
595        }
596    }
597    Ok(())
598}
599
600/// Handle messages from the control tube. This one is not related to virtio spec.
601pub async fn handle_ctrl_tube(
602    streams: &Rc<AsyncRwLock<Vec<AsyncRwLock<StreamInfo>>>>,
603    control_tube: &AsyncTube,
604    reset_signal: Option<&(AsyncRwLock<bool>, Condvar)>,
605) -> Result<(), Error> {
606    let on_reset = await_reset_signal(reset_signal).fuse();
607    pin_mut!(on_reset);
608
609    loop {
610        let next_async = control_tube.next().fuse();
611        pin_mut!(next_async);
612
613        let cmd = select! {
614            _ = on_reset => break,
615            res = next_async => res,
616        };
617
618        match cmd {
619            Ok(cmd) => {
620                let resp = match cmd {
621                    SndControlCommand::MuteAll(muted) => {
622                        let streams = streams.read_lock().await;
623                        for stream in streams.iter() {
624                            let stream_info = stream.lock().await;
625                            stream_info.muted.store(muted, Ordering::Relaxed);
626                            info!("Stream muted = {:?}", muted);
627                        }
628                        VmResponse::Ok
629                    }
630                };
631                control_tube
632                    .send(resp)
633                    .await
634                    .map_err(Error::ControlTubeError)?;
635            }
636            Err(e) => {
637                error!("Failed to read the command: {}", e);
638                return Err(Error::ControlTubeError(e));
639            }
640        }
641    }
642
643    Ok(())
644}
645
646/// Handle messages from the tx or the rx queue. One invocation is needed for
647/// each queue.
648pub async fn handle_pcm_queue(
649    streams: &Rc<AsyncRwLock<Vec<AsyncRwLock<StreamInfo>>>>,
650    mut response_sender: mpsc::UnboundedSender<PcmResponse>,
651    queue: Rc<AsyncRwLock<Queue>>,
652    queue_event: &EventAsync,
653    card_index: usize,
654    reset_signal: Option<&(AsyncRwLock<bool>, Condvar)>,
655) -> Result<(), Error> {
656    let on_reset = await_reset_signal(reset_signal).fuse();
657    pin_mut!(on_reset);
658
659    loop {
660        // Manual queue.next_async() to avoid holding the mutex
661        let next_async = async {
662            loop {
663                // Check if there are more descriptors available.
664                if let Some(chain) = queue.lock().await.pop() {
665                    return Ok(chain);
666                }
667                queue_event.next_val().await?;
668            }
669        }
670        .fuse();
671        pin_mut!(next_async);
672
673        let mut desc_chain = select! {
674            _ = on_reset => break,
675            res = next_async => res.map_err(Error::Async)?,
676        };
677
678        let pcm_xfer: virtio_snd_pcm_xfer =
679            desc_chain.reader.read_obj().map_err(Error::ReadMessage)?;
680        let stream_id: usize = u32::from(pcm_xfer.stream_id) as usize;
681
682        let streams = streams.read_lock().await;
683        let stream_info = match streams.get(stream_id) {
684            Some(stream_info) => stream_info.read_lock().await,
685            None => {
686                error!(
687                    "[Card {}] stream_id ({}) >= num_streams ({})",
688                    card_index,
689                    stream_id,
690                    streams.len()
691                );
692                defer_pcm_response_to_worker(
693                    desc_chain,
694                    virtio_snd_pcm_status::new(StatusCode::IoErr, 0),
695                    &mut response_sender,
696                )
697                .await?;
698                continue;
699            }
700        };
701
702        match stream_info.sender.as_ref() {
703            Some(mut s) => {
704                s.send(desc_chain).await.map_err(Error::MpscSend)?;
705                if *stream_info.status_mutex.lock().await == WorkerStatus::Quit {
706                    // If sender channel is still intact but worker status is quit,
707                    // the worker quitted unexpectedly. Return error to request a reset.
708                    return Err(Error::PCMWorkerQuittedUnexpectedly);
709                }
710            }
711            None => {
712                if !stream_info.just_reset {
713                    error!(
714                        "[Card {}] stream {} is not ready. state: {}",
715                        card_index,
716                        stream_id,
717                        get_virtio_snd_r_pcm_cmd_name(stream_info.state)
718                    );
719                }
720                defer_pcm_response_to_worker(
721                    desc_chain,
722                    virtio_snd_pcm_status::new(StatusCode::IoErr, 0),
723                    &mut response_sender,
724                )
725                .await?;
726            }
727        };
728    }
729    Ok(())
730}
731
732/// Handle all the control messages from the ctrl queue.
733pub async fn handle_ctrl_queue(
734    ex: &Executor,
735    streams: &Rc<AsyncRwLock<Vec<AsyncRwLock<StreamInfo>>>>,
736    snd_data: &SndData,
737    queue: Rc<AsyncRwLock<Queue>>,
738    queue_event: &mut EventAsync,
739    tx_send: mpsc::UnboundedSender<PcmResponse>,
740    rx_send: mpsc::UnboundedSender<PcmResponse>,
741    card_index: usize,
742    reset_signal: Option<&(AsyncRwLock<bool>, Condvar)>,
743) -> Result<(), Error> {
744    let on_reset = await_reset_signal(reset_signal).fuse();
745    pin_mut!(on_reset);
746
747    let mut queue = queue.lock().await;
748    loop {
749        let mut desc_chain = {
750            let next_async = queue.next_async(queue_event).fuse();
751            pin_mut!(next_async);
752
753            select! {
754                _ = on_reset => break,
755                res = next_async => res.map_err(Error::Async)?,
756            }
757        };
758
759        let reader = &mut desc_chain.reader;
760        let writer = &mut desc_chain.writer;
761        // Don't advance the reader
762        let code = reader
763            .peek_obj::<virtio_snd_hdr>()
764            .map_err(Error::ReadMessage)?
765            .code
766            .into();
767
768        let handle_ctrl_msg = async {
769            match code {
770                VIRTIO_SND_R_JACK_INFO => {
771                    let query_info: virtio_snd_query_info =
772                        reader.read_obj().map_err(Error::ReadMessage)?;
773                    let start_id: usize = u32::from(query_info.start_id) as usize;
774                    let count: usize = u32::from(query_info.count) as usize;
775                    if start_id + count > snd_data.jack_info.len() {
776                        error!(
777                            "[Card {}] start_id({}) + count({}) must be smaller than \
778                            the number of jacks ({})",
779                            card_index,
780                            start_id,
781                            count,
782                            snd_data.jack_info.len()
783                        );
784                        return writer
785                            .write_obj(VIRTIO_SND_S_BAD_MSG)
786                            .map_err(Error::WriteResponse);
787                    }
788                    // The response consists of the virtio_snd_hdr structure (contains the request
789                    // status code), followed by the device-writable information structures of the
790                    // item. Each information structure begins with the following common header
791                    writer
792                        .write_obj(VIRTIO_SND_S_OK)
793                        .map_err(Error::WriteResponse)?;
794                    for i in start_id..(start_id + count) {
795                        writer
796                            .write_all(snd_data.jack_info[i].as_bytes())
797                            .map_err(Error::WriteResponse)?;
798                    }
799                    Ok(())
800                }
801                VIRTIO_SND_R_PCM_INFO => {
802                    let query_info: virtio_snd_query_info =
803                        reader.read_obj().map_err(Error::ReadMessage)?;
804                    let start_id: usize = u32::from(query_info.start_id) as usize;
805                    let count: usize = u32::from(query_info.count) as usize;
806                    if start_id + count > snd_data.pcm_info.len() {
807                        error!(
808                            "[Card {}] start_id({}) + count({}) must be smaller than \
809                            the number of streams ({})",
810                            card_index,
811                            start_id,
812                            count,
813                            snd_data.pcm_info.len()
814                        );
815                        return writer
816                            .write_obj(VIRTIO_SND_S_BAD_MSG)
817                            .map_err(Error::WriteResponse);
818                    }
819                    // The response consists of the virtio_snd_hdr structure (contains the request
820                    // status code), followed by the device-writable information structures of the
821                    // item. Each information structure begins with the following common header
822                    writer
823                        .write_obj(VIRTIO_SND_S_OK)
824                        .map_err(Error::WriteResponse)?;
825                    for i in start_id..(start_id + count) {
826                        writer
827                            .write_all(snd_data.pcm_info[i].as_bytes())
828                            .map_err(Error::WriteResponse)?;
829                    }
830                    Ok(())
831                }
832                VIRTIO_SND_R_CHMAP_INFO => {
833                    let query_info: virtio_snd_query_info =
834                        reader.read_obj().map_err(Error::ReadMessage)?;
835                    let start_id: usize = u32::from(query_info.start_id) as usize;
836                    let count: usize = u32::from(query_info.count) as usize;
837                    if start_id + count > snd_data.chmap_info.len() {
838                        error!(
839                            "[Card {}] start_id({}) + count({}) must be smaller than \
840                            the number of chmaps ({})",
841                            card_index,
842                            start_id,
843                            count,
844                            snd_data.chmap_info.len()
845                        );
846                        return writer
847                            .write_obj(VIRTIO_SND_S_BAD_MSG)
848                            .map_err(Error::WriteResponse);
849                    }
850                    // The response consists of the virtio_snd_hdr structure (contains the request
851                    // status code), followed by the device-writable information structures of the
852                    // item. Each information structure begins with the following common header
853                    writer
854                        .write_obj(VIRTIO_SND_S_OK)
855                        .map_err(Error::WriteResponse)?;
856                    for i in start_id..(start_id + count) {
857                        writer
858                            .write_all(snd_data.chmap_info[i].as_bytes())
859                            .map_err(Error::WriteResponse)?;
860                    }
861                    Ok(())
862                }
863                VIRTIO_SND_R_JACK_REMAP => {
864                    unreachable!("remap is unsupported");
865                }
866                VIRTIO_SND_R_PCM_SET_PARAMS => {
867                    // Raise VIRTIO_SND_S_BAD_MSG or IO error?
868                    let set_params: virtio_snd_pcm_set_params =
869                        reader.read_obj().map_err(Error::ReadMessage)?;
870                    let stream_id: usize = u32::from(set_params.hdr.stream_id) as usize;
871                    let buffer_bytes: u32 = set_params.buffer_bytes.into();
872                    let period_bytes: u32 = set_params.period_bytes.into();
873
874                    let dir = match snd_data.pcm_info.get(stream_id) {
875                        Some(pcm_info) => {
876                            if set_params.channels < pcm_info.channels_min
877                                || set_params.channels > pcm_info.channels_max
878                            {
879                                error!(
880                                    "[Card {}] Number of channels ({}) must be between {} and {}",
881                                    card_index,
882                                    set_params.channels,
883                                    pcm_info.channels_min,
884                                    pcm_info.channels_max
885                                );
886                                return writer
887                                    .write_obj(VIRTIO_SND_S_NOT_SUPP)
888                                    .map_err(Error::WriteResponse);
889                            }
890                            if (u64::from(pcm_info.formats) & (1 << set_params.format)) == 0 {
891                                error!(
892                                    "[Card {}] PCM format {} is not supported.",
893                                    card_index, set_params.format
894                                );
895                                return writer
896                                    .write_obj(VIRTIO_SND_S_NOT_SUPP)
897                                    .map_err(Error::WriteResponse);
898                            }
899                            if (u64::from(pcm_info.rates) & (1 << set_params.rate)) == 0 {
900                                error!(
901                                    "[Card {}] PCM frame rate {} is not supported.",
902                                    card_index, set_params.rate
903                                );
904                                return writer
905                                    .write_obj(VIRTIO_SND_S_NOT_SUPP)
906                                    .map_err(Error::WriteResponse);
907                            }
908
909                            pcm_info.direction
910                        }
911                        None => {
912                            error!(
913                                "[Card {}] stream_id {} < streams {}",
914                                card_index,
915                                stream_id,
916                                snd_data.pcm_info.len()
917                            );
918                            return writer
919                                .write_obj(VIRTIO_SND_S_BAD_MSG)
920                                .map_err(Error::WriteResponse);
921                        }
922                    };
923
924                    if set_params.features != 0 {
925                        error!("[Card {}] No feature is supported", card_index);
926                        return writer
927                            .write_obj(VIRTIO_SND_S_NOT_SUPP)
928                            .map_err(Error::WriteResponse);
929                    }
930
931                    if buffer_bytes % period_bytes != 0 {
932                        error!(
933                            "[Card {}] buffer_bytes({}) must be dividable by period_bytes({})",
934                            card_index, buffer_bytes, period_bytes
935                        );
936                        return writer
937                            .write_obj(VIRTIO_SND_S_BAD_MSG)
938                            .map_err(Error::WriteResponse);
939                    }
940
941                    process_pcm_ctrl(
942                        ex,
943                        &tx_send,
944                        &rx_send,
945                        streams,
946                        VirtioSndPcmCmd::with_set_params_and_direction(set_params, dir),
947                        writer,
948                        stream_id,
949                        card_index,
950                    )
951                    .await
952                }
953                VIRTIO_SND_R_PCM_PREPARE
954                | VIRTIO_SND_R_PCM_START
955                | VIRTIO_SND_R_PCM_STOP
956                | VIRTIO_SND_R_PCM_RELEASE => {
957                    let hdr: virtio_snd_pcm_hdr = reader.read_obj().map_err(Error::ReadMessage)?;
958                    let stream_id: usize = u32::from(hdr.stream_id) as usize;
959                    let cmd = match VirtioSndPcmCmd::try_from(code) {
960                        Ok(cmd) => cmd,
961                        Err(err) => {
962                            error!(
963                                "[Card {}] Error converting code to command: {}",
964                                card_index, err
965                            );
966                            return writer
967                                .write_obj(VIRTIO_SND_S_BAD_MSG)
968                                .map_err(Error::WriteResponse);
969                        }
970                    };
971                    process_pcm_ctrl(
972                        ex, &tx_send, &rx_send, streams, cmd, writer, stream_id, card_index,
973                    )
974                    .await
975                    .and(Ok(()))?;
976                    Ok(())
977                }
978                c => {
979                    error!("[Card {}] Unrecognized code: {}", card_index, c);
980                    writer
981                        .write_obj(VIRTIO_SND_S_BAD_MSG)
982                        .map_err(Error::WriteResponse)
983                }
984            }
985        };
986
987        handle_ctrl_msg.await?;
988        queue.add_used(desc_chain);
989        queue.trigger_interrupt();
990    }
991    Ok(())
992}
993
994/// Send events to the audio driver.
995pub async fn handle_event_queue(
996    mut queue: Queue,
997    mut queue_event: EventAsync,
998) -> Result<(), Error> {
999    loop {
1000        let desc_chain = queue
1001            .next_async(&mut queue_event)
1002            .await
1003            .map_err(Error::Async)?;
1004
1005        // TODO(woodychow): Poll and forward events from cras asynchronously (API to be added)
1006        queue.add_used(desc_chain);
1007        queue.trigger_interrupt();
1008    }
1009}
1010
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use audio_streams::NoopStreamSourceGenerator;
    use base::Tube;

    use super::*;
    use crate::virtio::snd::common_backend::notify_reset_signal;

    /// `handle_ctrl_tube` must return `Ok` once the reset signal is raised, even though no
    /// command is ever sent on the tube.
    #[test]
    fn test_handle_ctrl_tube_reset_signal() {
        let ex = Executor::new().expect("Failed to create an executor");
        let result = ex.run_until(async {
            let streams: Rc<AsyncRwLock<Vec<AsyncRwLock<StreamInfo>>>> = Default::default();
            // Bind the peer as `_t1` (not `_`) so it is not dropped until the end of this block.
            let (t0, _t1) = Tube::pair().expect("Failed to create tube pairs");
            let t0 = AsyncTube::new(&ex, t0).expect("Failed to create async tube");
            let reset_signal = (AsyncRwLock::new(false), Condvar::new());

            let handle_future = handle_ctrl_tube(&streams, &t0, Some(&reset_signal));
            let notify_future = notify_reset_signal(&reset_signal);
            let (result, _) = futures::join!(handle_future, notify_future);

            assert!(
                result.is_ok(),
                "handle_ctrl_tube returns an error after reset signal"
            );
        });

        assert!(result.is_ok(), "ex.run_until returns an error");
    }

    /// Builds a `StreamInfo` for card 0 backed by a no-op stream source.
    fn new_stream() -> StreamInfo {
        let card_index = 0;
        StreamInfo::builder(
            Arc::new(Box::new(NoopStreamSourceGenerator::new())),
            card_index,
        )
        .build()
    }

    /// Sending `SndControlCommand::MuteAll(true)` over the control tube must be answered with
    /// `VmResponse::Ok` and leave the stream marked as muted.
    #[test]
    fn test_handle_ctrl_tube_receive_mute_cmd() {
        let ex = Executor::new().expect("Failed to create an executor");
        let result = ex.run_until(async {
            let streams: Vec<AsyncRwLock<StreamInfo>> = vec![AsyncRwLock::new(new_stream())];
            let streams = Rc::new(AsyncRwLock::new(streams));

            let (t0, t1) = Tube::pair().expect("Failed to create tube pairs");
            let t0 = AsyncTube::new(&ex, t0).expect("Failed to create an async tube");
            let t1 = AsyncTube::new(&ex, t1).expect("Failed to create an async tube");
            let reset_signal = (AsyncRwLock::new(false), Condvar::new());

            let handle_future = handle_ctrl_tube(&streams, &t0, Some(&reset_signal));
            let tube_future = async {
                // Fail loudly if the command cannot be sent. Otherwise the test would likely
                // block waiting for a response to a command that never arrived, because the
                // reset signal below is only raised after a response is received.
                let send_result = t1.send(&SndControlCommand::MuteAll(true)).await;
                assert!(send_result.is_ok(), "Failed to send the mute command");
                let recv_result = t1.next::<VmResponse>().await;
                notify_reset_signal(&reset_signal).await;
                recv_result
            };
            let (handle_result, tube_result) = futures::join!(handle_future, tube_future);

            assert!(
                handle_result.is_ok(),
                "handle_ctrl_tube returns an error after reset signal"
            );
            assert!(tube_result.is_ok(), "Failed to receive data from the tube");
            assert!(
                matches!(tube_result.unwrap(), VmResponse::Ok),
                "tube_result is not Ok",
            );

            let streams = streams.read_lock().await;
            let stream = streams
                .first()
                .expect("streams should contain the stream created above")
                .lock()
                .await;
            assert!(stream.muted.load(Ordering::Relaxed), "Stream is not muted");
        });

        assert!(result.is_ok(), "ex.run_until returns an error");
    }
}