audio_streams/
shm_streams.rs

1// Copyright 2019 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#[cfg(any(target_os = "android", target_os = "linux"))]
6use std::os::unix::io::RawFd;
7use std::sync::Arc;
8use std::sync::Condvar;
9use std::sync::Mutex;
10use std::time::Duration;
11use std::time::Instant;
12
13use remain::sorted;
14use thiserror::Error;
15
16use crate::BoxError;
17use crate::SampleFormat;
18use crate::StreamDirection;
19use crate::StreamEffect;
20
21type GenericResult<T> = std::result::Result<T, BoxError>;
22
23/// `BufferSet` is used as a callback mechanism for `ServerRequest` objects.
24/// It is meant to be implemented by the audio stream, allowing arbitrary code
25/// to be run after a buffer offset and length is set.
26pub trait BufferSet {
27    /// Called when the client sets a buffer offset and length.
28    ///
29    /// `offset` is the offset within shared memory of the buffer and `frames`
30    /// indicates the number of audio frames that can be read from or written to
31    /// the buffer.
32    fn callback(&mut self, offset: usize, frames: usize) -> GenericResult<()>;
33
34    /// Called when the client ignores a request from the server.
35    fn ignore(&mut self) -> GenericResult<()>;
36}
37
/// Errors produced by this module.
#[sorted]
#[derive(Error, Debug)]
pub enum Error {
    /// The client offered more frames (first field) than the server requested
    /// (second field).
    #[error("Provided number of frames {0} exceeds requested number of frames {1}")]
    TooManyFrames(usize, usize),
}
44
/// `ServerRequest` represents an active request from the server for the client
/// to provide a buffer in shared memory to playback from or capture to.
///
/// A request is answered at most once, via either
/// `ServerRequest::set_buffer_offset_and_frames` or
/// `ServerRequest::ignore_request`; both take the request by value.
pub struct ServerRequest<'a> {
    // Upper bound on the number of frames the client may provide in its answer.
    requested_frames: usize,
    // Stream-side hook that is told how the client answered.
    buffer_set: &'a mut dyn BufferSet,
}
51
52impl<'a> ServerRequest<'a> {
53    /// Create a new ServerRequest object
54    ///
55    /// Create a ServerRequest object representing a request from the server
56    /// for a buffer `requested_frames` in size.
57    ///
58    /// When the client responds to this request by calling
59    /// [`set_buffer_offset_and_frames`](ServerRequest::set_buffer_offset_and_frames),
60    /// BufferSet::callback will be called on `buffer_set`.
61    ///
62    /// # Arguments
63    /// * `requested_frames` - The requested buffer size in frames.
64    /// * `buffer_set` - The object implementing the callback for when a buffer is provided.
65    pub fn new<D: BufferSet>(requested_frames: usize, buffer_set: &'a mut D) -> Self {
66        Self {
67            requested_frames,
68            buffer_set,
69        }
70    }
71
72    /// Get the number of frames of audio data requested by the server.
73    ///
74    /// The returned value should never be greater than the `buffer_size`
75    /// given in [`new_stream`](ShmStreamSource::new_stream).
76    pub fn requested_frames(&self) -> usize {
77        self.requested_frames
78    }
79
80    /// Sets the buffer offset and length for the requested buffer.
81    ///
82    /// Sets the buffer offset and length of the buffer that fulfills this
83    /// server request to `offset` and `length`, respectively. This means that
84    /// `length` bytes of audio samples may be read from/written to that
85    /// location in `client_shm` for a playback/capture stream, respectively.
86    /// This function may only be called once for a `ServerRequest`, at which
87    /// point the ServerRequest is dropped and no further calls are possible.
88    ///
89    /// # Arguments
90    ///
91    /// * `offset` - The value to use as the new buffer offset for the next buffer.
92    /// * `frames` - The length of the next buffer in frames.
93    ///
94    /// # Errors
95    ///
96    /// * If `frames` is greater than `requested_frames`.
97    pub fn set_buffer_offset_and_frames(self, offset: usize, frames: usize) -> GenericResult<()> {
98        if frames > self.requested_frames {
99            return Err(Box::new(Error::TooManyFrames(
100                frames,
101                self.requested_frames,
102            )));
103        }
104
105        self.buffer_set.callback(offset, frames)
106    }
107
108    /// Ignore this request
109    ///
110    /// If the client does not intend to respond to this ServerRequest with a
111    /// buffer, they should call this function. The stream will be notified that
112    /// the request has been ignored and will handle it properly.
113    pub fn ignore_request(self) -> GenericResult<()> {
114        self.buffer_set.ignore()
115    }
116}
117
118/// `ShmStream` allows a client to interact with an active CRAS stream.
119pub trait ShmStream: Send {
120    /// Get the size of a frame of audio data for this stream.
121    fn frame_size(&self) -> usize;
122
123    /// Get the number of channels of audio data for this stream.
124    fn num_channels(&self) -> usize;
125
126    /// Get the frame rate of audio data for this stream.
127    fn frame_rate(&self) -> u32;
128
129    /// Waits until the next server message indicating action is required.
130    ///
131    /// For playback streams, this will be `AUDIO_MESSAGE_REQUEST_DATA`, meaning
132    /// that we must set the buffer offset to the next location where playback
133    /// data can be found.
134    /// For capture streams, this will be `AUDIO_MESSAGE_DATA_READY`, meaning
135    /// that we must set the buffer offset to the next location where captured
136    /// data can be written to.
137    /// Will return early if `timeout` elapses before a message is received.
138    ///
139    /// # Arguments
140    ///
141    /// * `timeout` - The amount of time to wait until a message is received.
142    ///
143    /// # Return value
144    ///
145    /// Returns `Some(request)` where `request` is an object that implements the
146    /// [`ServerRequest`] trait and which can be used to get the
147    /// number of bytes requested for playback streams or that have already been
148    /// written to shm for capture streams.
149    ///
150    /// If the timeout occurs before a message is received, returns `None`.
151    ///
152    /// # Errors
153    ///
154    /// * If an invalid message type is received for the stream.
155    fn wait_for_next_action_with_timeout(
156        &mut self,
157        timeout: Duration,
158    ) -> GenericResult<Option<ServerRequest>>;
159}
160
/// `SharedMemory` specifies features of shared memory areas passed on to `ShmStreamSource`.
pub trait SharedMemory {
    /// Error returned when creating a shared memory area fails.
    type Error: std::error::Error;

    /// Creates a new shared memory file descriptor without specifying a name.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` if the shared memory area cannot be created.
    fn anon(size: u64) -> Result<Self, Self::Error>
    where
        Self: Sized;

    /// Gets the size in bytes of the shared memory.
    ///
    /// The size returned here does not reflect changes by other interfaces or users of the shared
    /// memory file descriptor.
    fn size(&self) -> u64;

    /// Returns the underlying raw fd.
    #[cfg(any(target_os = "android", target_os = "linux"))]
    fn as_raw_fd(&self) -> RawFd;
}
180
181/// `ShmStreamSource` creates streams for playback or capture of audio.
182pub trait ShmStreamSource<E: std::error::Error>: Send {
183    /// Creates a new [`ShmStream`]
184    ///
185    /// Creates a new `ShmStream` object, which allows:
186    /// * Waiting until the server has communicated that data is ready or requested that we make
187    ///   more data available.
188    /// * Setting the location and length of buffers for reading/writing audio data.
189    ///
190    /// # Arguments
191    ///
192    /// * `direction` - The direction of the stream, either `Playback` or `Capture`.
193    /// * `num_channels` - The number of audio channels for the stream.
194    /// * `format` - The audio format to use for audio samples.
195    /// * `frame_rate` - The stream's frame rate in Hz.
196    /// * `buffer_size` - The maximum size of an audio buffer. This will be the size used for
197    ///   transfers of audio data between client and server.
198    /// * `effects` - Audio effects to use for the stream, such as echo-cancellation.
199    /// * `client_shm` - The shared memory area that will contain samples.
200    /// * `buffer_offsets` - The two initial values to use as buffer offsets for streams. This way,
201    ///   the server will not write audio data to an arbitrary offset in `client_shm` if the client
202    ///   fails to update offsets in time.
203    ///
204    /// # Errors
205    ///
206    /// * If sending the connect stream message to the server fails.
207    #[allow(clippy::too_many_arguments)]
208    fn new_stream(
209        &mut self,
210        direction: StreamDirection,
211        num_channels: usize,
212        format: SampleFormat,
213        frame_rate: u32,
214        buffer_size: usize,
215        effects: &[StreamEffect],
216        client_shm: &dyn SharedMemory<Error = E>,
217        buffer_offsets: [u64; 2],
218    ) -> GenericResult<Box<dyn ShmStream>>;
219
220    /// Get a list of file descriptors used by the implementation.
221    ///
222    /// Returns any open file descriptors needed by the implementation.
223    /// This list helps users of the ShmStreamSource enter Linux jails without
224    /// closing needed file descriptors.
225    #[cfg(any(target_os = "android", target_os = "linux"))]
226    fn keep_fds(&self) -> Vec<RawFd> {
227        Vec::new()
228    }
229}
230
/// Stream that implements the `ShmStream` trait but does nothing with the samples.
pub struct NullShmStream {
    num_channels: usize,
    frame_rate: u32,
    // Frame count reported by every `ServerRequest` this stream hands out.
    buffer_size: usize,
    frame_size: usize,
    // Time between consecutive requests, derived from `buffer_size` and `frame_rate`.
    interval: Duration,
    // Offset from `start_time` at which the next request becomes due.
    next_frame: Duration,
    // Reference point for `next_frame`; captured when the stream is created.
    start_time: Instant,
}
241
242impl NullShmStream {
243    /// Attempt to create a new NullShmStream with the given number of channels,
244    /// format, frame_rate, and buffer_size.
245    pub fn new(
246        buffer_size: usize,
247        num_channels: usize,
248        format: SampleFormat,
249        frame_rate: u32,
250    ) -> Self {
251        let interval = Duration::from_millis(buffer_size as u64 * 1000 / frame_rate as u64);
252        Self {
253            num_channels,
254            frame_rate,
255            buffer_size,
256            frame_size: format.sample_bytes() * num_channels,
257            interval,
258            next_frame: interval,
259            start_time: Instant::now(),
260        }
261    }
262}
263
264impl BufferSet for NullShmStream {
265    fn callback(&mut self, _offset: usize, _frames: usize) -> GenericResult<()> {
266        Ok(())
267    }
268
269    fn ignore(&mut self) -> GenericResult<()> {
270        Ok(())
271    }
272}
273
274impl ShmStream for NullShmStream {
275    fn frame_size(&self) -> usize {
276        self.frame_size
277    }
278
279    fn num_channels(&self) -> usize {
280        self.num_channels
281    }
282
283    fn frame_rate(&self) -> u32 {
284        self.frame_rate
285    }
286
287    fn wait_for_next_action_with_timeout(
288        &mut self,
289        timeout: Duration,
290    ) -> GenericResult<Option<ServerRequest>> {
291        let elapsed = self.start_time.elapsed();
292        if elapsed < self.next_frame {
293            if timeout < self.next_frame - elapsed {
294                std::thread::sleep(timeout);
295                return Ok(None);
296            } else {
297                std::thread::sleep(self.next_frame - elapsed);
298            }
299        }
300        self.next_frame += self.interval;
301        Ok(Some(ServerRequest::new(self.buffer_size, self)))
302    }
303}
304
/// Source of `NullShmStream` objects.
#[derive(Default)]
pub struct NullShmStreamSource;

impl NullShmStreamSource {
    /// Create a new source; equivalent to `NullShmStreamSource::default()`.
    pub fn new() -> Self {
        Self
    }
}
314
315impl<E: std::error::Error> ShmStreamSource<E> for NullShmStreamSource {
316    fn new_stream(
317        &mut self,
318        _direction: StreamDirection,
319        num_channels: usize,
320        format: SampleFormat,
321        frame_rate: u32,
322        buffer_size: usize,
323        _effects: &[StreamEffect],
324        _client_shm: &dyn SharedMemory<Error = E>,
325        _buffer_offsets: [u64; 2],
326    ) -> GenericResult<Box<dyn ShmStream>> {
327        let new_stream = NullShmStream::new(buffer_size, num_channels, format, frame_rate);
328        Ok(Box::new(new_stream))
329    }
330}
331
/// Test double implementing `ShmStream` whose requests are driven manually
/// through `MockShmStream::trigger_callback_with_timeout`.
///
/// Clones share the same request flag and condition variable (it is held in an
/// `Arc`), so a test can keep a clone to trigger requests on a stream that is
/// owned elsewhere.
#[derive(Clone)]
pub struct MockShmStream {
    num_channels: usize,
    frame_rate: u32,
    // Frame count reported by every `ServerRequest` this stream produces.
    request_size: usize,
    frame_size: usize,
    // `true` while a request is outstanding. It is set by the triggering side and
    // cleared when the client answers (or when the trigger gives up).
    request_notifier: Arc<(Mutex<bool>, Condvar)>,
}
340
341impl MockShmStream {
342    /// Attempt to create a new MockShmStream with the given number of
343    /// channels, frame_rate, format, and buffer_size.
344    pub fn new(
345        num_channels: usize,
346        frame_rate: u32,
347        format: SampleFormat,
348        buffer_size: usize,
349    ) -> Self {
350        #[allow(clippy::mutex_atomic)]
351        Self {
352            num_channels,
353            frame_rate,
354            request_size: buffer_size,
355            frame_size: format.sample_bytes() * num_channels,
356            request_notifier: Arc::new((Mutex::new(false), Condvar::new())),
357        }
358    }
359
360    /// Call to request data from the stream, causing it to return from
361    /// `wait_for_next_action_with_timeout`. Will block until
362    /// `set_buffer_offset_and_frames` is called on the ServerRequest returned
363    /// from `wait_for_next_action_with_timeout`, or until `timeout` elapses.
364    /// Returns true if a response was successfully received.
365    pub fn trigger_callback_with_timeout(&mut self, timeout: Duration) -> bool {
366        let (lock, cvar) = &*self.request_notifier;
367        let mut requested = lock.lock().unwrap();
368        *requested = true;
369        cvar.notify_one();
370        let start_time = Instant::now();
371        while *requested {
372            requested = cvar.wait_timeout(requested, timeout).unwrap().0;
373            if start_time.elapsed() > timeout {
374                // We failed to get a callback in time, mark this as false.
375                *requested = false;
376                return false;
377            }
378        }
379
380        true
381    }
382
383    fn notify_request(&mut self) {
384        let (lock, cvar) = &*self.request_notifier;
385        let mut requested = lock.lock().unwrap();
386        *requested = false;
387        cvar.notify_one();
388    }
389}
390
391impl BufferSet for MockShmStream {
392    fn callback(&mut self, _offset: usize, _frames: usize) -> GenericResult<()> {
393        self.notify_request();
394        Ok(())
395    }
396
397    fn ignore(&mut self) -> GenericResult<()> {
398        self.notify_request();
399        Ok(())
400    }
401}
402
403impl ShmStream for MockShmStream {
404    fn frame_size(&self) -> usize {
405        self.frame_size
406    }
407
408    fn num_channels(&self) -> usize {
409        self.num_channels
410    }
411
412    fn frame_rate(&self) -> u32 {
413        self.frame_rate
414    }
415
416    fn wait_for_next_action_with_timeout(
417        &mut self,
418        timeout: Duration,
419    ) -> GenericResult<Option<ServerRequest>> {
420        {
421            let start_time = Instant::now();
422            let (lock, cvar) = &*self.request_notifier;
423            let mut requested = lock.lock().unwrap();
424            while !*requested {
425                requested = cvar.wait_timeout(requested, timeout).unwrap().0;
426                if start_time.elapsed() > timeout {
427                    return Ok(None);
428                }
429            }
430        }
431
432        Ok(Some(ServerRequest::new(self.request_size, self)))
433    }
434}
435
436/// Source of `MockShmStream` objects.
437#[derive(Clone, Default)]
438pub struct MockShmStreamSource {
439    last_stream: Arc<(Mutex<Option<MockShmStream>>, Condvar)>,
440}
441
442impl MockShmStreamSource {
443    pub fn new() -> Self {
444        Default::default()
445    }
446
447    /// Get the last stream that has been created from this source. If no stream
448    /// has been created, block until one has.
449    pub fn get_last_stream(&self) -> MockShmStream {
450        let (last_stream, cvar) = &*self.last_stream;
451        let mut stream = last_stream.lock().unwrap();
452        loop {
453            match &*stream {
454                None => stream = cvar.wait(stream).unwrap(),
455                Some(ref s) => return s.clone(),
456            };
457        }
458    }
459}
460
461impl<E: std::error::Error> ShmStreamSource<E> for MockShmStreamSource {
462    fn new_stream(
463        &mut self,
464        _direction: StreamDirection,
465        num_channels: usize,
466        format: SampleFormat,
467        frame_rate: u32,
468        buffer_size: usize,
469        _effects: &[StreamEffect],
470        _client_shm: &dyn SharedMemory<Error = E>,
471        _buffer_offsets: [u64; 2],
472    ) -> GenericResult<Box<dyn ShmStream>> {
473        let (last_stream, cvar) = &*self.last_stream;
474        let mut stream = last_stream.lock().unwrap();
475
476        let new_stream = MockShmStream::new(num_channels, frame_rate, format, buffer_size);
477        *stream = Some(new_stream.clone());
478        cvar.notify_one();
479        Ok(Box::new(new_stream))
480    }
481}
482
// Tests that run only on Unix. They rely on the `MockSharedMemory` stub defined
// below rather than on a real shared-memory implementation.
#[cfg(all(test, unix))]
pub mod tests {
    use super::*;

    // Minimal `SharedMemory` stand-in. Both stream sources under test ignore
    // `client_shm`, so no real memory or file descriptor is needed.
    struct MockSharedMemory {}

    impl SharedMemory for MockSharedMemory {
        type Error = super::Error;

        fn anon(_: u64) -> Result<Self, Self::Error> {
            Ok(MockSharedMemory {})
        }

        fn size(&self) -> u64 {
            0
        }

        // Placeholder fd; nothing in these tests reads it.
        #[cfg(any(target_os = "android", target_os = "linux"))]
        fn as_raw_fd(&self) -> RawFd {
            0
        }
    }

    // A stream created on a worker thread should receive a request once the
    // main thread triggers it, and the trigger should observe the answer.
    #[test]
    fn mock_trigger_callback() {
        let stream_source = MockShmStreamSource::new();
        let mut thread_stream_source = stream_source.clone();

        let buffer_size = 480;
        let num_channels = 2;
        let format = SampleFormat::S24LE;
        let shm = MockSharedMemory {};

        // Worker thread plays the client: it answers the first request with all
        // of the requested frames and returns how many were requested (0 if the
        // wait timed out).
        let handle = std::thread::spawn(move || {
            let mut stream = thread_stream_source
                .new_stream(
                    StreamDirection::Playback,
                    num_channels,
                    format,
                    44100,
                    buffer_size,
                    &[],
                    &shm,
                    [400, 8000],
                )
                .expect("Failed to create stream");

            let request = stream
                .wait_for_next_action_with_timeout(Duration::from_secs(5))
                .expect("Failed to wait for next action");
            match request {
                Some(r) => {
                    let requested = r.requested_frames();
                    r.set_buffer_offset_and_frames(872, requested)
                        .expect("Failed to set buffer offset and frames");
                    requested
                }
                None => 0,
            }
        });

        // Main thread blocks until the worker has created its stream, then
        // issues a request through a clone of it.
        let mut stream = stream_source.get_last_stream();
        assert!(stream.trigger_callback_with_timeout(Duration::from_secs(1)));

        let requested_frames = handle.join().expect("Failed to join thread");
        assert_eq!(requested_frames, buffer_size);
    }

    // The null stream should pace its requests rather than handing them out
    // back-to-back.
    #[test]
    fn null_consumption_rate() {
        let frame_rate = 44100;
        let buffer_size = 480;
        // Truncated to whole milliseconds, so this is at most the stream's real
        // period; the `elapsed > interval` check below is only a lower bound.
        let interval = Duration::from_millis(buffer_size as u64 * 1000 / frame_rate as u64);

        let shm = MockSharedMemory {};

        let start = Instant::now();

        let mut stream_source = NullShmStreamSource::new();
        let mut stream = stream_source
            .new_stream(
                StreamDirection::Playback,
                2,
                SampleFormat::S24LE,
                frame_rate,
                buffer_size,
                &[],
                &shm,
                [400, 8000],
            )
            .expect("Failed to create stream");

        let timeout = Duration::from_secs(5);
        let request = stream
            .wait_for_next_action_with_timeout(timeout)
            .expect("Failed to wait for first request")
            .expect("First request should not have timed out");
        request
            .set_buffer_offset_and_frames(276, 480)
            .expect("Failed to set buffer offset and length");

        // The second call should block until the first buffer is consumed.
        let _request = stream
            .wait_for_next_action_with_timeout(timeout)
            .expect("Failed to wait for second request");
        let elapsed = start.elapsed();
        assert!(
            elapsed > interval,
            "wait_for_next_action_with_timeout didn't block long enough: {elapsed:?}"
        );

        assert!(
            elapsed < timeout,
            "wait_for_next_action_with_timeout blocked for too long: {elapsed:?}"
        );
    }
}