devices/virtio/snd/vios_backend/
shm_vios.rs

1// Copyright 2020 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use std::collections::HashMap;
6use std::collections::VecDeque;
7use std::fs::File;
8use std::io::Error as IOError;
9use std::io::ErrorKind as IOErrorKind;
10use std::io::Seek;
11use std::io::SeekFrom;
12use std::path::Path;
13use std::path::PathBuf;
14use std::sync::mpsc::channel;
15use std::sync::mpsc::Receiver;
16use std::sync::mpsc::RecvError;
17use std::sync::mpsc::Sender;
18use std::sync::Arc;
19
20use base::error;
21use base::AsRawDescriptor;
22use base::Error as BaseError;
23use base::Event;
24use base::EventToken;
25use base::FromRawDescriptor;
26use base::IntoRawDescriptor;
27use base::MemoryMapping;
28use base::MemoryMappingBuilder;
29use base::MmapError;
30use base::RawDescriptor;
31use base::SafeDescriptor;
32use base::ScmSocket;
33use base::UnixSeqpacket;
34use base::VolatileMemory;
35use base::VolatileMemoryError;
36use base::VolatileSlice;
37use base::WaitContext;
38use base::WorkerThread;
39use remain::sorted;
40use serde::Deserialize;
41use serde::Serialize;
42use sync::Mutex;
43use thiserror::Error as ThisError;
44use zerocopy::FromBytes;
45use zerocopy::Immutable;
46use zerocopy::IntoBytes;
47use zerocopy::KnownLayout;
48
49use crate::virtio::snd::constants::*;
50use crate::virtio::snd::layout::*;
51use crate::virtio::snd::vios_backend::streams::StreamState;
52
/// Result type used throughout this module, with [`Error`] as the error type.
pub type Result<T> = std::result::Result<T, Error>;
54
/// Errors produced by the VioS client.
///
/// Most variants wrap the lower level error that caused the failure; the wrapped value is
/// included in the rendered message wherever the `#[error]` format string references it.
#[sorted]
#[derive(ThisError, Debug)]
pub enum Error {
    #[error("Error memory mapping client_shm: {0}")]
    BaseMmapError(BaseError),
    #[error("Sender was dropped without sending buffer status, the recv thread may have exited")]
    BufferStatusSenderLost(RecvError),
    #[error("Command failed with status {0}")]
    CommandFailed(u32),
    #[error("Error duplicating file descriptor: {0}")]
    DupError(BaseError),
    #[error("Failed to create Recv event: {0}")]
    EventCreateError(BaseError),
    #[error("Failed to dup Recv event: {0}")]
    EventDupError(BaseError),
    #[error("Failed to signal event: {0}")]
    EventWriteError(BaseError),
    #[error("Failed to get size of tx shared memory: {0}")]
    FileSizeError(IOError),
    #[error("Error accessing guest's shared memory: {0}")]
    GuestMmapError(MmapError),
    #[error("No jack with id {0}")]
    InvalidJackId(u32),
    #[error("No stream with id {0}")]
    InvalidStreamId(u32),
    #[error("IO buffer operation failed: status = {0}")]
    IOBufferError(u32),
    #[error("No PCM streams available")]
    NoStreamsAvailable,
    #[error("Insuficient space for the new buffer in the queue's buffer area")]
    OutOfSpace,
    #[error("Platform not supported")]
    PlatformNotSupported,
    #[error("{0}")]
    ProtocolError(ProtocolErrorKind),
    #[error("Failed to connect to VioS server {1}: {0:?}")]
    ServerConnectionError(IOError, PathBuf),
    #[error("Failed to communicate with VioS server: {0:?}")]
    ServerError(IOError),
    #[error("Failed to communicate with VioS server: {0:?}")]
    ServerIOError(IOError),
    #[error("Error accessing VioS server's shared memory: {0}")]
    ServerMmapError(MmapError),
    #[error("Failed to duplicate UnixSeqpacket: {0}")]
    UnixSeqpacketDupError(IOError),
    #[error("Unsupported frame rate: {0}")]
    UnsupportedFrameRate(u32),
    #[error("Error accessing volatile memory: {0}")]
    VolatileMemoryError(VolatileMemoryError),
    #[error("Failed to create Recv thread's WaitContext: {0}")]
    WaitContextCreateError(BaseError),
    // NOTE(review): the wrapped error is not part of the rendered message for this variant.
    #[error("Error waiting for events")]
    WaitError(BaseError),
    #[error("Invalid operation for stream direction: {0}")]
    WrongDirection(u8),
    #[error("Set saved params should only be used while restoring the device")]
    WrongSetParams,
}
113
/// Ways in which a message from the server can violate the expected protocol.
///
/// Carried by [`Error::ProtocolError`].
#[derive(ThisError, Debug)]
pub enum ProtocolErrorKind {
    /// The initial config message had this many bytes instead of `size_of::<VioSConfig>()`.
    #[error("The server sent a config of the wrong size: {0}")]
    UnexpectedConfigSize(usize),
    /// Fields are (expected, received); note the message prints them in the opposite order.
    #[error("Received {1} file descriptors from the server, expected {0}")]
    UnexpectedNumberOfFileDescriptors(usize, usize), // expected, received
    /// The version reported by the server differs from `VIOS_VERSION`.
    #[error("Server's version ({0}) doesn't match client's")]
    VersionMismatch(u32),
    /// Fields are (expected, received), both in bytes.
    #[error("Received a msg with an unexpected size: expected {0}, received {1}")]
    UnexpectedMessageSize(usize, usize), // expected, received
}
125
/// The client for the VioS backend
///
/// Uses a protocol equivalent to virtio-snd over a shared memory file and a unix socket for
/// notifications. It's thread safe, it can be encapsulated in an Arc smart pointer and shared
/// between threads.
pub struct VioSClient {
    // The mutexes in this struct should almost never be held simultaneously. If at some point
    // they have to be, the locking order should match the order in which they are declared here.
    //
    // Configuration received from the server on connection, and the jack/stream/chmap info tables
    // queried from the server right after connecting.
    config: VioSConfig,
    jacks: Vec<virtio_snd_jack_info>,
    streams: Vec<virtio_snd_pcm_info>,
    chmaps: Vec<virtio_snd_chmap_info>,
    // The control socket is used from multiple threads to send and wait for a reply, which needs
    // to happen atomically, hence the need for a mutex instead of just sharing clones of the
    // socket.
    control_socket: Mutex<UnixSeqpacket>,
    event_socket: UnixSeqpacket,
    // These are thread safe and don't require locking
    tx: IoBufferQueue,
    rx: IoBufferQueue,
    // This is accessed by the recv_thread and whatever thread processes the events
    events: Arc<Mutex<VecDeque<virtio_snd_event>>>,
    event_notifier: Event,
    // These are accessed by the recv_thread and the stream threads. Each map is keyed by the
    // offset of the buffer in the corresponding shared memory area.
    tx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>,
    rx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>,
    recv_thread_state: Arc<Mutex<ThreadFlags>>,
    // `None` while the background receive thread is not running.
    recv_thread: Mutex<Option<WorkerThread<()>>>,
    // Params are required to be stored for snapshot/restore. On restore, we don't have the params
    // locally available as the VM is started anew, so they need to be restored. Keyed by stream
    // id.
    params: HashMap<u32, virtio_snd_pcm_set_params>,
}
158
/// Serializable copy of the `VioSClient` state that takes part in snapshot/restore.
///
/// Produced by `VioSClient::snapshot` and consumed by `VioSClient::restore`. Sockets, shared
/// memory mappings and thread state are not part of it.
#[derive(Serialize, Deserialize)]
pub struct VioSClientSnapshot {
    config: VioSConfig,
    jacks: Vec<virtio_snd_jack_info>,
    streams: Vec<virtio_snd_pcm_info>,
    chmaps: Vec<virtio_snd_chmap_info>,
    // Last parameters set on each stream, keyed by stream id.
    params: HashMap<u32, virtio_snd_pcm_set_params>,
}
167
168impl VioSClient {
    /// Create a new client given the path to the audio server's socket.
    ///
    /// Connects to `server`, receives the server's `VioSConfig` along with exactly five file
    /// descriptors, checks the message size and protocol version, maps the tx and rx shared
    /// memory files and finally queries and caches the jack, stream and channel map info.
    ///
    /// # Errors
    ///
    /// Fails if the connection cannot be established, if the initial message has the wrong size,
    /// version or number of descriptors, if the shared memory cannot be mapped, or if any of the
    /// info queries fails.
    pub fn try_new<P: AsRef<Path>>(server: P) -> Result<VioSClient> {
        let client_socket = ScmSocket::try_from(
            UnixSeqpacket::connect(server.as_ref())
                .map_err(|e| Error::ServerConnectionError(e, server.as_ref().into()))?,
        )
        .map_err(|e| Error::ServerConnectionError(e, server.as_ref().into()))?;
        let mut config: VioSConfig = Default::default();
        const NUM_FDS: usize = 5;
        // The config is received directly into the bytes of the struct.
        let (recv_size, mut safe_fds) = client_socket
            .recv_with_fds(config.as_mut_bytes(), NUM_FDS)
            .map_err(Error::ServerError)?;

        if recv_size != std::mem::size_of::<VioSConfig>() {
            return Err(Error::ProtocolError(
                ProtocolErrorKind::UnexpectedConfigSize(recv_size),
            ));
        }

        if config.version != VIOS_VERSION {
            return Err(Error::ProtocolError(ProtocolErrorKind::VersionMismatch(
                config.version,
            )));
        }

        // Takes the last descriptor out of `safe_fds` and converts it into a `T`. `expected` and
        // `received` are only used to build the error returned when the vector is empty.
        fn pop<T: FromRawDescriptor>(
            safe_fds: &mut Vec<SafeDescriptor>,
            expected: usize,
            received: usize,
        ) -> Result<T> {
            // SAFETY:
            // Safe because we transfer ownership from the SafeDescriptor to T
            unsafe {
                Ok(T::from_raw_descriptor(
                    safe_fds
                        .pop()
                        .ok_or(Error::ProtocolError(
                            ProtocolErrorKind::UnexpectedNumberOfFileDescriptors(
                                expected, received,
                            ),
                        ))?
                        .into_raw_descriptor(),
                ))
            }
        }

        // Descriptors are popped from the back of the vector, so they are consumed in the reverse
        // of the order in which they were received: the last one is the rx shared memory file and
        // the first one is the event socket.
        let fd_count = safe_fds.len();
        let rx_shm_file = pop::<File>(&mut safe_fds, NUM_FDS, fd_count)?;
        let tx_shm_file = pop::<File>(&mut safe_fds, NUM_FDS, fd_count)?;
        let rx_socket = pop::<UnixSeqpacket>(&mut safe_fds, NUM_FDS, fd_count)?;
        let tx_socket = pop::<UnixSeqpacket>(&mut safe_fds, NUM_FDS, fd_count)?;
        let event_socket = pop::<UnixSeqpacket>(&mut safe_fds, NUM_FDS, fd_count)?;

        // Any descriptor left over means the server sent more than expected.
        if !safe_fds.is_empty() {
            return Err(Error::ProtocolError(
                ProtocolErrorKind::UnexpectedNumberOfFileDescriptors(NUM_FDS, fd_count),
            ));
        }

        let tx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>> =
            Arc::new(Mutex::new(HashMap::new()));
        let rx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>> =
            Arc::new(Mutex::new(HashMap::new()));
        // Events are dropped by the recv thread until someone asks for the event notifier.
        let recv_thread_state = Arc::new(Mutex::new(ThreadFlags {
            reporting_events: false,
        }));

        let mut client = VioSClient {
            config,
            jacks: Vec::new(),
            streams: Vec::new(),
            chmaps: Vec::new(),
            control_socket: Mutex::new(client_socket.into_inner()),
            event_socket,
            tx: IoBufferQueue::new(tx_socket, tx_shm_file)?,
            rx: IoBufferQueue::new(rx_socket, rx_shm_file)?,
            events: Arc::new(Mutex::new(VecDeque::new())),
            event_notifier: Event::new().map_err(Error::EventCreateError)?,
            tx_subscribers,
            rx_subscribers,
            recv_thread_state,
            // The background thread is only created by start_bg_thread().
            recv_thread: Mutex::new(None),
            params: HashMap::new(),
        };
        // Fill in the jacks, streams and chmaps tables from the server.
        client.request_and_cache_info()?;
        Ok(client)
    }
256
257    /// Get the number of jacks
258    pub fn num_jacks(&self) -> u32 {
259        self.config.jacks
260    }
261
262    /// Get the number of pcm streams
263    pub fn num_streams(&self) -> u32 {
264        self.config.streams
265    }
266
267    /// Get the number of channel maps
268    pub fn num_chmaps(&self) -> u32 {
269        self.config.chmaps
270    }
271
272    /// Get the configuration information on a jack
273    pub fn jack_info(&self, idx: u32) -> Option<virtio_snd_jack_info> {
274        self.jacks.get(idx as usize).copied()
275    }
276
277    /// Get the configuration information on a pcm stream
278    pub fn stream_info(&self, idx: u32) -> Option<virtio_snd_pcm_info> {
279        self.streams.get(idx as usize).cloned()
280    }
281
282    /// Get the configuration information on a channel map
283    pub fn chmap_info(&self, idx: u32) -> Option<virtio_snd_chmap_info> {
284        self.chmaps.get(idx as usize).copied()
285    }
286
287    /// Starts the background thread that receives release messages from the server. If the thread
288    /// was already started this function does nothing.
289    /// This thread must be started prior to attempting any stream IO operation or the calling
290    /// thread would block.
291    pub fn start_bg_thread(&self) -> Result<()> {
292        if self.recv_thread.lock().is_some() {
293            return Ok(());
294        }
295        let tx_socket = self.tx.try_clone_socket()?;
296        let rx_socket = self.rx.try_clone_socket()?;
297        let event_socket = self
298            .event_socket
299            .try_clone()
300            .map_err(Error::UnixSeqpacketDupError)?;
301        let mut opt = self.recv_thread.lock();
302        // The lock on recv_thread was released above to avoid holding more than one lock at a time
303        // while duplicating the fds. So we have to check the condition again.
304        if opt.is_none() {
305            let tx_subscribers = self.tx_subscribers.clone();
306            let rx_subscribers = self.rx_subscribers.clone();
307            let event_notifier = self
308                .event_notifier
309                .try_clone()
310                .map_err(Error::EventDupError)?;
311            let events = self.events.clone();
312            let recv_thread_state = self.recv_thread_state.clone();
313            *opt = Some(WorkerThread::start("shm_vios", move |kill_event| {
314                if let Err(e) = run_recv_thread(
315                    kill_event,
316                    tx_subscribers,
317                    rx_subscribers,
318                    event_notifier,
319                    events,
320                    recv_thread_state,
321                    tx_socket,
322                    rx_socket,
323                    event_socket,
324                ) {
325                    error!("virtio-snd shm_vios worker failed: {e:#}");
326                }
327            }));
328        }
329        Ok(())
330    }
331
332    /// Stops the background thread.
333    pub fn stop_bg_thread(&self) -> Result<()> {
334        if let Some(recv_thread) = self.recv_thread.lock().take() {
335            recv_thread.stop();
336        }
337        Ok(())
338    }
339
340    /// Gets an Event object that will trigger every time an event is received from the server
341    pub fn get_event_notifier(&self) -> Result<Event> {
342        // Let the background thread know that there is at least one consumer of events
343        self.recv_thread_state.lock().reporting_events = true;
344        self.event_notifier
345            .try_clone()
346            .map_err(Error::EventDupError)
347    }
348
349    /// Retrieves one event. Callers should have received a notification through the event notifier
350    /// before calling this function.
351    pub fn pop_event(&self) -> Option<virtio_snd_event> {
352        self.events.lock().pop_front()
353    }
354
355    /// Remap a jack. This should only be called if the jack announces support for the operation
356    /// through the features field in the corresponding virtio_snd_jack_info struct.
357    pub fn remap_jack(&self, jack_id: u32, association: u32, sequence: u32) -> Result<()> {
358        if jack_id >= self.config.jacks {
359            return Err(Error::InvalidJackId(jack_id));
360        }
361        let msg = virtio_snd_jack_remap {
362            hdr: virtio_snd_jack_hdr {
363                hdr: virtio_snd_hdr {
364                    code: VIRTIO_SND_R_JACK_REMAP.into(),
365                },
366                jack_id: jack_id.into(),
367            },
368            association: association.into(),
369            sequence: sequence.into(),
370        };
371        let control_socket_lock = self.control_socket.lock();
372        send_cmd(&control_socket_lock, msg)
373    }
374
375    /// Configures a stream with the given parameters.
376    pub fn set_stream_parameters(
377        &mut self,
378        stream_id: u32,
379        params: VioSStreamParams,
380    ) -> Result<()> {
381        self.streams
382            .get(stream_id as usize)
383            .ok_or(Error::InvalidStreamId(stream_id))?;
384        let raw_params: virtio_snd_pcm_set_params = (stream_id, params).into();
385        // Old value is not needed and is dropped
386        let _ = self.params.insert(stream_id, raw_params);
387        let control_socket_lock = self.control_socket.lock();
388        send_cmd(&control_socket_lock, raw_params)
389    }
390
391    /// Configures a stream with the given parameters.
392    pub fn set_stream_parameters_raw(
393        &mut self,
394        raw_params: virtio_snd_pcm_set_params,
395    ) -> Result<()> {
396        let stream_id = raw_params.hdr.stream_id.to_native();
397        // Old value is not needed and is dropped
398        let _ = self.params.insert(stream_id, raw_params);
399        self.streams
400            .get(stream_id as usize)
401            .ok_or(Error::InvalidStreamId(stream_id))?;
402        let control_socket_lock = self.control_socket.lock();
403        send_cmd(&control_socket_lock, raw_params)
404    }
405
406    /// Send the PREPARE_STREAM command to the server.
407    pub fn prepare_stream(&self, stream_id: u32) -> Result<()> {
408        self.common_stream_op(stream_id, VIRTIO_SND_R_PCM_PREPARE)
409    }
410
411    /// Send the RELEASE_STREAM command to the server.
412    pub fn release_stream(&self, stream_id: u32) -> Result<()> {
413        self.common_stream_op(stream_id, VIRTIO_SND_R_PCM_RELEASE)
414    }
415
416    /// Send the START_STREAM command to the server.
417    pub fn start_stream(&self, stream_id: u32) -> Result<()> {
418        self.common_stream_op(stream_id, VIRTIO_SND_R_PCM_START)
419    }
420
421    /// Send the STOP_STREAM command to the server.
422    pub fn stop_stream(&self, stream_id: u32) -> Result<()> {
423        self.common_stream_op(stream_id, VIRTIO_SND_R_PCM_STOP)
424    }
425
426    /// Send audio frames to the server. Blocks the calling thread until the server acknowledges
427    /// the data.
428    pub fn inject_audio_data<R, Cb: FnOnce(VolatileSlice) -> R>(
429        &self,
430        stream_id: u32,
431        size: usize,
432        callback: Cb,
433    ) -> Result<(u32, R)> {
434        if self
435            .streams
436            .get(stream_id as usize)
437            .ok_or(Error::InvalidStreamId(stream_id))?
438            .direction
439            != VIRTIO_SND_D_OUTPUT
440        {
441            return Err(Error::WrongDirection(VIRTIO_SND_D_OUTPUT));
442        }
443        self.streams
444            .get(stream_id as usize)
445            .ok_or(Error::InvalidStreamId(stream_id))?;
446        let dst_offset = self.tx.allocate_buffer(size)?;
447        let buffer_slice = self.tx.buffer_at(dst_offset, size)?;
448        let ret = callback(buffer_slice);
449        // Register to receive the status before sending the buffer to the server
450        let (sender, receiver): (Sender<BufferReleaseMsg>, Receiver<BufferReleaseMsg>) = channel();
451        self.tx_subscribers.lock().insert(dst_offset, sender);
452        self.tx.send_buffer(stream_id, dst_offset, size)?;
453        let (_, latency) = await_status(receiver)?;
454        Ok((latency, ret))
455    }
456
457    /// Request audio frames from the server. It blocks until the data is available.
458    pub fn request_audio_data<R, Cb: FnOnce(&VolatileSlice) -> R>(
459        &self,
460        stream_id: u32,
461        size: usize,
462        callback: Cb,
463    ) -> Result<(u32, R)> {
464        if self
465            .streams
466            .get(stream_id as usize)
467            .ok_or(Error::InvalidStreamId(stream_id))?
468            .direction
469            != VIRTIO_SND_D_INPUT
470        {
471            return Err(Error::WrongDirection(VIRTIO_SND_D_INPUT));
472        }
473        let src_offset = self.rx.allocate_buffer(size)?;
474        // Register to receive the status before sending the buffer to the server
475        let (sender, receiver): (Sender<BufferReleaseMsg>, Receiver<BufferReleaseMsg>) = channel();
476        self.rx_subscribers.lock().insert(src_offset, sender);
477        self.rx.send_buffer(stream_id, src_offset, size)?;
478        // Make sure no mutexes are held while awaiting for the buffer to be written to
479        let (recv_size, latency) = await_status(receiver)?;
480        let buffer_slice = self.rx.buffer_at(src_offset, recv_size)?;
481        Ok((latency, callback(&buffer_slice)))
482    }
483
484    /// Get a list of file descriptors used by the implementation.
485    pub fn keep_rds(&self) -> Vec<RawDescriptor> {
486        let control_desc = self.control_socket.lock().as_raw_descriptor();
487        let event_desc = self.event_socket.as_raw_descriptor();
488        let event_notifier = self.event_notifier.as_raw_descriptor();
489        let mut ret = vec![control_desc, event_desc, event_notifier];
490        ret.append(&mut self.tx.keep_rds());
491        ret.append(&mut self.rx.keep_rds());
492        ret
493    }
494
495    fn common_stream_op(&self, stream_id: u32, op: u32) -> Result<()> {
496        self.streams
497            .get(stream_id as usize)
498            .ok_or(Error::InvalidStreamId(stream_id))?;
499        let msg = virtio_snd_pcm_hdr {
500            hdr: virtio_snd_hdr { code: op.into() },
501            stream_id: stream_id.into(),
502        };
503        let control_socket_lock = self.control_socket.lock();
504        send_cmd(&control_socket_lock, msg)
505    }
506
507    fn request_and_cache_info(&mut self) -> Result<()> {
508        self.request_and_cache_jacks_info()?;
509        self.request_and_cache_streams_info()?;
510        self.request_and_cache_chmaps_info()?;
511        Ok(())
512    }
513
514    fn request_info<T: IntoBytes + FromBytes + Default + Copy + Clone>(
515        &self,
516        req_code: u32,
517        count: usize,
518    ) -> Result<Vec<T>> {
519        let info_size = std::mem::size_of::<T>();
520        let status_size = std::mem::size_of::<virtio_snd_hdr>();
521        let req = virtio_snd_query_info {
522            hdr: virtio_snd_hdr {
523                code: req_code.into(),
524            },
525            start_id: 0u32.into(),
526            count: (count as u32).into(),
527            size: (std::mem::size_of::<virtio_snd_query_info>() as u32).into(),
528        };
529        let control_socket_lock = self.control_socket.lock();
530        seq_socket_send(&control_socket_lock, req.as_bytes())?;
531        let reply = control_socket_lock
532            .recv_as_vec()
533            .map_err(Error::ServerIOError)?;
534        let mut status: virtio_snd_hdr = Default::default();
535        status
536            .as_mut_bytes()
537            .copy_from_slice(&reply[0..status_size]);
538        if status.code.to_native() != VIRTIO_SND_S_OK {
539            return Err(Error::CommandFailed(status.code.to_native()));
540        }
541        if reply.len() != status_size + count * info_size {
542            return Err(Error::ProtocolError(
543                ProtocolErrorKind::UnexpectedMessageSize(count * info_size, reply.len()),
544            ));
545        }
546        Ok(reply[status_size..]
547            .chunks(info_size)
548            .map(|info_buffer| T::read_from_bytes(info_buffer).unwrap())
549            .collect())
550    }
551
552    fn request_and_cache_jacks_info(&mut self) -> Result<()> {
553        let num_jacks = self.config.jacks as usize;
554        if num_jacks == 0 {
555            return Ok(());
556        }
557        self.jacks = self.request_info(VIRTIO_SND_R_JACK_INFO, num_jacks)?;
558        Ok(())
559    }
560
561    fn request_and_cache_streams_info(&mut self) -> Result<()> {
562        let num_streams = self.config.streams as usize;
563        if num_streams == 0 {
564            return Ok(());
565        }
566        self.streams = self.request_info(VIRTIO_SND_R_PCM_INFO, num_streams)?;
567        Ok(())
568    }
569
570    fn request_and_cache_chmaps_info(&mut self) -> Result<()> {
571        let num_chmaps = self.config.chmaps as usize;
572        if num_chmaps == 0 {
573            return Ok(());
574        }
575        self.chmaps = self.request_info(VIRTIO_SND_R_CHMAP_INFO, num_chmaps)?;
576        Ok(())
577    }
578
579    pub fn snapshot(&self) -> VioSClientSnapshot {
580        VioSClientSnapshot {
581            config: self.config,
582            jacks: self.jacks.clone(),
583            streams: self.streams.clone(),
584            chmaps: self.chmaps.clone(),
585            params: self.params.clone(),
586        }
587    }
588
    /// Restores the client state from a snapshot as part of the snapshot/restore flow.
    ///
    /// The snapshot's config must be equal to the config of this client, otherwise an error is
    /// returned and nothing is modified. On success the cached jacks, streams and chmaps info
    /// and the stored stream parameters are replaced with the ones in the snapshot. Nothing is
    /// sent to the server here; see `restore_stream` for that.
    pub fn restore(&mut self, data: VioSClientSnapshot) -> anyhow::Result<()> {
        anyhow::ensure!(
            data.config == self.config,
            "config doesn't match on restore: expected: {:?}, got: {:?}",
            data.config,
            self.config
        );
        self.jacks = data.jacks;
        self.streams = data.streams;
        self.chmaps = data.chmaps;
        self.params = data.params;
        Ok(())
    }
604
605    pub fn restore_stream(&mut self, stream_id: u32, state: StreamState) -> Result<()> {
606        if let Some(params) = self.params.get(&stream_id).cloned() {
607            self.set_stream_parameters_raw(params)?;
608        }
609        match state {
610            StreamState::Started => {
611                // If state != prepared, start will always fail.
612                // As such, it is fine to only print the first error without returning, as the
613                // second action will then fail.
614                if let Err(e) = self.prepare_stream(stream_id) {
615                    error!("failed to prepare stream: {}", e);
616                };
617                self.start_stream(stream_id)
618            }
619            StreamState::Prepared => self.prepare_stream(stream_id),
620            // Nothing to do here
621            _ => Ok(()),
622        }
623    }
624}
625
/// Flags shared between the client and its background receive thread.
#[derive(Clone, Copy)]
struct ThreadFlags {
    // When false the receive thread drops the events it reads from the server instead of
    // queueing them. Set to true by get_event_notifier().
    reporting_events: bool,
}
630
/// Identifies which source woke up the receive thread's wait context.
#[derive(EventToken)]
enum Token {
    // The thread's kill event was signaled.
    Notification,
    // The tx socket has a buffer status message.
    TxBufferMsg,
    // The rx socket has a buffer status message.
    RxBufferMsg,
    // The event socket has an event.
    EventMsg,
}
638
639fn recv_buffer_status_msg(
640    socket: &UnixSeqpacket,
641    subscribers: &Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>,
642) -> Result<()> {
643    let mut msg: IoStatusMsg = Default::default();
644    let size = socket
645        .recv(msg.as_mut_bytes())
646        .map_err(Error::ServerIOError)?;
647    if size != std::mem::size_of::<IoStatusMsg>() {
648        return Err(Error::ProtocolError(
649            ProtocolErrorKind::UnexpectedMessageSize(std::mem::size_of::<IoStatusMsg>(), size),
650        ));
651    }
652    let mut status = msg.status.status.into();
653    if status == u32::MAX {
654        // Anyone waiting for this would continue to wait for as long as status is
655        // u32::MAX
656        status -= 1;
657    }
658    let latency = msg.status.latency_bytes.into();
659    let offset = msg.buffer_offset as usize;
660    let consumed_len = msg.consumed_len as usize;
661    let promise_opt = subscribers.lock().remove(&offset);
662    match promise_opt {
663        None => error!(
664            "Received an unexpected buffer status message: {}. This is a BUG!!",
665            offset
666        ),
667        Some(sender) => {
668            if let Err(e) = sender.send(BufferReleaseMsg {
669                status,
670                latency,
671                consumed_len,
672            }) {
673                error!("Failed to notify waiting thread: {:?}", e);
674            }
675        }
676    }
677    Ok(())
678}
679
680fn recv_event(socket: &UnixSeqpacket) -> Result<virtio_snd_event> {
681    let mut msg: virtio_snd_event = Default::default();
682    let size = socket
683        .recv(msg.as_mut_bytes())
684        .map_err(Error::ServerIOError)?;
685    if size != std::mem::size_of::<virtio_snd_event>() {
686        return Err(Error::ProtocolError(
687            ProtocolErrorKind::UnexpectedMessageSize(std::mem::size_of::<virtio_snd_event>(), size),
688        ));
689    }
690    Ok(msg)
691}
692
/// Body of the background receive thread.
///
/// Waits on the tx, rx and event sockets plus the kill event and handles whatever becomes
/// ready:
/// * buffer status messages are forwarded to the subscriber waiting for that buffer,
/// * events are queued in `event_queue` and `event_notifier` is signaled, but only if
///   `state.reporting_events` is set; otherwise they are dropped,
/// * the kill event ends the loop once the current batch of wait results has been processed.
///
/// Any error while waiting, reading a message or signaling the notifier ends the thread and is
/// returned to the caller.
fn run_recv_thread(
    kill_event: Event,
    tx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>,
    rx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>,
    event_notifier: Event,
    event_queue: Arc<Mutex<VecDeque<virtio_snd_event>>>,
    state: Arc<Mutex<ThreadFlags>>,
    tx_socket: UnixSeqpacket,
    rx_socket: UnixSeqpacket,
    event_socket: UnixSeqpacket,
) -> Result<()> {
    let wait_ctx: WaitContext<Token> = WaitContext::build_with(&[
        (&tx_socket, Token::TxBufferMsg),
        (&rx_socket, Token::RxBufferMsg),
        (&event_socket, Token::EventMsg),
        (&kill_event, Token::Notification),
    ])
    .map_err(Error::WaitContextCreateError)?;
    let mut running = true;
    while running {
        let events = wait_ctx.wait().map_err(Error::WaitError)?;
        for evt in events {
            match evt.token {
                Token::TxBufferMsg => recv_buffer_status_msg(&tx_socket, &tx_subscribers)?,
                Token::RxBufferMsg => recv_buffer_status_msg(&rx_socket, &rx_subscribers)?,
                Token::EventMsg => {
                    // The event is always read from the socket, even if it ends up dropped.
                    let evt = recv_event(&event_socket)?;
                    // Copy the flags so that the state lock is not held while the event queue
                    // lock is taken.
                    let state_cpy = *state.lock();
                    if state_cpy.reporting_events {
                        event_queue.lock().push_back(evt);
                        event_notifier.signal().map_err(Error::EventWriteError)?;
                    } // else just drop the events
                }
                Token::Notification => {
                    // Just consume the notification and check for termination on the next
                    // iteration
                    if let Err(e) = kill_event.wait() {
                        error!("Failed to consume notification from recv thread: {:?}", e);
                    }
                    running = false;
                }
            }
        }
    }
    Ok(())
}
739
740fn await_status(promise: Receiver<BufferReleaseMsg>) -> Result<(usize, u32)> {
741    let BufferReleaseMsg {
742        status,
743        latency,
744        consumed_len,
745    } = promise.recv().map_err(Error::BufferStatusSenderLost)?;
746    if status == VIRTIO_SND_S_OK {
747        Ok((consumed_len, latency))
748    } else {
749        Err(Error::IOBufferError(status))
750    }
751}
752
/// One direction (tx or rx) of audio buffer exchange with the server: a shared memory area
/// where the buffers live and a socket used to announce them.
struct IoBufferQueue {
    // Socket on which buffers are announced to the server.
    socket: UnixSeqpacket,
    // File backing the shared memory area; kept so that its descriptor can be reported by
    // keep_rds().
    file: File,
    // Mapping of `file`.
    mmap: MemoryMapping,
    // Size in bytes of the shared memory area.
    size: usize,
    // Offset at which the next buffer will be allocated, if it fits.
    next: Mutex<usize>,
}
760
761impl IoBufferQueue {
762    fn new(socket: UnixSeqpacket, mut file: File) -> Result<IoBufferQueue> {
763        let size = file.seek(SeekFrom::End(0)).map_err(Error::FileSizeError)? as usize;
764
765        let mmap = MemoryMappingBuilder::new(size)
766            .from_file(&file)
767            .build()
768            .map_err(Error::ServerMmapError)?;
769
770        Ok(IoBufferQueue {
771            socket,
772            file,
773            mmap,
774            size,
775            next: Mutex::new(0),
776        })
777    }
778
779    fn allocate_buffer(&self, size: usize) -> Result<usize> {
780        if size > self.size {
781            return Err(Error::OutOfSpace);
782        }
783        let mut next_lock = self.next.lock();
784        let offset = if size > self.size - *next_lock {
785            // Can't fit the new buffer at the end of the area, so put it at the beginning
786            0
787        } else {
788            *next_lock
789        };
790        *next_lock = offset + size;
791        Ok(offset)
792    }
793
794    fn buffer_at(&self, offset: usize, len: usize) -> Result<VolatileSlice> {
795        self.mmap
796            .get_slice(offset, len)
797            .map_err(Error::VolatileMemoryError)
798    }
799
800    fn try_clone_socket(&self) -> Result<UnixSeqpacket> {
801        self.socket
802            .try_clone()
803            .map_err(Error::UnixSeqpacketDupError)
804    }
805
806    fn send_buffer(&self, stream_id: u32, offset: usize, size: usize) -> Result<()> {
807        let msg = IoTransferMsg::new(stream_id, offset, size);
808        seq_socket_send(&self.socket, msg.as_bytes())
809    }
810
811    fn keep_rds(&self) -> Vec<RawDescriptor> {
812        vec![
813            self.file.as_raw_descriptor(),
814            self.socket.as_raw_descriptor(),
815        ]
816    }
817}
818
/// Groups the parameters used to configure a stream prior to using it.
///
/// Each field is copied as is into the field of the same name of `virtio_snd_pcm_set_params`
/// when the request for the server is built.
pub struct VioSStreamParams {
    pub buffer_bytes: u32,
    pub period_bytes: u32,
    pub features: u32,
    pub channels: u8,
    pub format: u8,
    pub rate: u8,
}
828
829impl From<(u32, VioSStreamParams)> for virtio_snd_pcm_set_params {
830    fn from(val: (u32, VioSStreamParams)) -> Self {
831        virtio_snd_pcm_set_params {
832            hdr: virtio_snd_pcm_hdr {
833                hdr: virtio_snd_hdr {
834                    code: VIRTIO_SND_R_PCM_SET_PARAMS.into(),
835                },
836                stream_id: val.0.into(),
837            },
838            buffer_bytes: val.1.buffer_bytes.into(),
839            period_bytes: val.1.period_bytes.into(),
840            features: val.1.features.into(),
841            channels: val.1.channels,
842            format: val.1.format,
843            rate: val.1.rate,
844            padding: 0u8,
845        }
846    }
847}
848
849fn send_cmd<T: Immutable + IntoBytes>(control_socket: &UnixSeqpacket, data: T) -> Result<()> {
850    seq_socket_send(control_socket, data.as_bytes())?;
851    recv_cmd_status(control_socket)
852}
853
854fn recv_cmd_status(control_socket: &UnixSeqpacket) -> Result<()> {
855    let mut status: virtio_snd_hdr = Default::default();
856    control_socket
857        .recv(status.as_mut_bytes())
858        .map_err(Error::ServerIOError)?;
859    if status.code.to_native() == VIRTIO_SND_S_OK {
860        Ok(())
861    } else {
862        Err(Error::CommandFailed(status.code.to_native()))
863    }
864}
865
866fn seq_socket_send(socket: &UnixSeqpacket, data: &[u8]) -> Result<()> {
867    loop {
868        let send_res = socket.send(data);
869        if let Err(e) = send_res {
870            match e.kind() {
871                // Retry if interrupted
872                IOErrorKind::Interrupted => continue,
873                _ => return Err(Error::ServerIOError(e)),
874            }
875        }
876        // Success
877        break;
878    }
879    Ok(())
880}
881
/// Version of the protocol implemented by this client. Connecting fails if the server's config
/// reports a different version.
const VIOS_VERSION: u32 = 2;
883
/// Configuration message sent by the server right after the connection is established. It is
/// received directly into the bytes of this struct.
#[repr(C)]
#[derive(
    Copy,
    Clone,
    Default,
    FromBytes,
    Immutable,
    IntoBytes,
    KnownLayout,
    Serialize,
    Deserialize,
    PartialEq,
    Eq,
    Debug,
)]
struct VioSConfig {
    // Protocol version of the server, must be equal to VIOS_VERSION.
    version: u32,
    // Number of jacks.
    jacks: u32,
    // Number of pcm streams.
    streams: u32,
    // Number of channel maps.
    chmaps: u32,
}
905
/// Outcome of a buffer transfer, sent by the receive thread to the thread waiting for it.
struct BufferReleaseMsg {
    // Status reported by the server; u32::MAX is replaced with u32::MAX - 1 by the receive
    // thread.
    status: u32,
    // Value of the latency_bytes field of the server's status message.
    latency: u32,
    // Value of the consumed_len field of the server's status message.
    consumed_len: usize,
}
911
/// Message sent to the server to announce a buffer in the shared memory area.
#[repr(C)]
#[derive(Copy, Clone, FromBytes, Immutable, IntoBytes, KnownLayout)]
struct IoTransferMsg {
    // Identifies the stream the buffer belongs to.
    io_xfer: virtio_snd_pcm_xfer,
    // Offset of the buffer within the shared memory area.
    buffer_offset: u32,
    // Length of the buffer in bytes.
    buffer_len: u32,
}
919
920impl IoTransferMsg {
921    fn new(stream_id: u32, buffer_offset: usize, buffer_len: usize) -> IoTransferMsg {
922        IoTransferMsg {
923            io_xfer: virtio_snd_pcm_xfer {
924                stream_id: stream_id.into(),
925            },
926            buffer_offset: buffer_offset as u32,
927            buffer_len: buffer_len as u32,
928        }
929    }
930}
931
/// Message sent by the server to report the outcome of a buffer transfer.
#[repr(C)]
#[derive(Copy, Clone, Default, FromBytes, Immutable, IntoBytes, KnownLayout)]
struct IoStatusMsg {
    // Status and latency of the transfer.
    status: virtio_snd_pcm_status,
    // Offset of the buffer within the shared memory area; used to find the waiting subscriber.
    buffer_offset: u32,
    // Passed on to the waiting thread; the rx path uses it as the length of the received data.
    consumed_len: u32,
}
938}