devices/virtio/snd/vios_backend/
shm_vios.rs

1// Copyright 2020 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use std::collections::HashMap;
6use std::collections::VecDeque;
7use std::fs::File;
8use std::io::Error as IOError;
9use std::io::ErrorKind as IOErrorKind;
10use std::io::Seek;
11use std::io::SeekFrom;
12use std::path::Path;
13use std::path::PathBuf;
14use std::sync::mpsc::channel;
15use std::sync::mpsc::Receiver;
16use std::sync::mpsc::RecvError;
17use std::sync::mpsc::Sender;
18use std::sync::Arc;
19
20use base::error;
21use base::AsRawDescriptor;
22use base::Error as BaseError;
23use base::Event;
24use base::EventToken;
25use base::FromRawDescriptor;
26use base::IntoRawDescriptor;
27use base::MemoryMapping;
28use base::MemoryMappingBuilder;
29use base::MmapError;
30use base::RawDescriptor;
31use base::SafeDescriptor;
32use base::ScmSocket;
33use base::UnixSeqpacket;
34use base::VolatileMemory;
35use base::VolatileMemoryError;
36use base::VolatileSlice;
37use base::WaitContext;
38use base::WorkerThread;
39use remain::sorted;
40use serde::Deserialize;
41use serde::Serialize;
42use sync::Mutex;
43use thiserror::Error as ThisError;
44use zerocopy::FromBytes;
45use zerocopy::Immutable;
46use zerocopy::IntoBytes;
47use zerocopy::KnownLayout;
48
49use crate::virtio::snd::constants::*;
50use crate::virtio::snd::layout::*;
51use crate::virtio::snd::vios_backend::streams::StreamState;
52
53pub type Result<T> = std::result::Result<T, Error>;
54
55#[sorted]
56#[derive(ThisError, Debug)]
57pub enum Error {
58    #[error("Error memory mapping client_shm: {0}")]
59    BaseMmapError(BaseError),
60    #[error("Sender was dropped without sending buffer status, the recv thread may have exited")]
61    BufferStatusSenderLost(RecvError),
62    #[error("Command failed with status {0}")]
63    CommandFailed(u32),
64    #[error("Expected {0} controls, received {1}")]
65    ControlsMismatch(u32, u32),
66    #[error("Error duplicating file descriptor: {0}")]
67    DupError(BaseError),
68    #[error("Duplicated control index: {0} name: '{1}'")]
69    DuplicatedControlId(u32, String),
70    #[error("Failed to create Recv event: {0}")]
71    EventCreateError(BaseError),
72    #[error("Failed to dup Recv event: {0}")]
73    EventDupError(BaseError),
74    #[error("Failed to signal event: {0}")]
75    EventWriteError(BaseError),
76    #[error("Failed to get size of tx shared memory: {0}")]
77    FileSizeError(IOError),
78    #[error("Error accessing guest's shared memory: {0}")]
79    GuestMmapError(MmapError),
80    #[error("No jack with id {0}")]
81    InvalidJackId(u32),
82    #[error("No stream with id {0}")]
83    InvalidStreamId(u32),
84    #[error("IO buffer operation failed: status = {0}")]
85    IOBufferError(u32),
86    #[error("No PCM streams available")]
87    NoStreamsAvailable,
88    #[error("Insuficient space for the new buffer in the queue's buffer area")]
89    OutOfSpace,
90    #[error("Platform not supported")]
91    PlatformNotSupported,
92    #[error("{0}")]
93    ProtocolError(ProtocolErrorKind),
94    #[error("Failed to connect to VioS server {1}: {0:?}")]
95    ServerConnectionError(IOError, PathBuf),
96    #[error("Failed to communicate with VioS server: {0:?}")]
97    ServerError(IOError),
98    #[error("Failed to communicate with VioS server: {0:?}")]
99    ServerIOError(IOError),
100    #[error("Error accessing VioS server's shared memory: {0}")]
101    ServerMmapError(MmapError),
102    #[error("Failed to duplicate UnixSeqpacket: {0}")]
103    UnixSeqpacketDupError(IOError),
104    #[error("Unsupported frame rate: {0}")]
105    UnsupportedFrameRate(u32),
106    #[error("Error accessing volatile memory: {0}")]
107    VolatileMemoryError(VolatileMemoryError),
108    #[error("Failed to create Recv thread's WaitContext: {0}")]
109    WaitContextCreateError(BaseError),
110    #[error("Error waiting for events")]
111    WaitError(BaseError),
112    #[error("Invalid operation for stream direction: {0}")]
113    WrongDirection(u8),
114    #[error("Set saved params should only be used while restoring the device")]
115    WrongSetParams,
116}
117
118#[derive(ThisError, Debug)]
119pub enum ProtocolErrorKind {
120    #[error("The server sent a config of the wrong size: {0}")]
121    UnexpectedConfigSize(usize),
122    #[error("Received {1} file descriptors from the server, expected {0}")]
123    UnexpectedNumberOfFileDescriptors(usize, usize), // expected, received
124    #[error("Server's version ({0}) doesn't match client's")]
125    VersionMismatch(u32),
126    #[error("Received a msg with an unexpected size: expected {0}, received {1}")]
127    UnexpectedMessageSize(usize, usize), // expected, received
128}
129
130/// The client for the VioS backend
131///
132/// Uses a protocol equivalent to virtio-snd over a shared memory file and a unix socket for
133/// notifications. It's thread safe, it can be encapsulated in an Arc smart pointer and shared
134/// between threads.
135pub struct VioSClient {
136    // These mutexes should almost never be held simultaneously. If at some point they have to the
137    // locking order should match the order in which they are declared here.
138    config: VioSConfig,
139    jacks: Vec<virtio_snd_jack_info>,
140    streams: Vec<virtio_snd_pcm_info>,
141    chmaps: Vec<virtio_snd_chmap_info>,
142    controls: Vec<virtio_snd_ctl_info>,
143    // The control socket is used from multiple threads to send and wait for a reply, which needs
144    // to happen atomically, hence the need for a mutex instead of just sharing clones of the
145    // socket.
146    control_socket: Mutex<UnixSeqpacket>,
147    event_socket: UnixSeqpacket,
148    // These are thread safe and don't require locking
149    tx: IoBufferQueue,
150    rx: IoBufferQueue,
151    // This is accessed by the recv_thread and whatever thread processes the events
152    events: Arc<Mutex<VecDeque<virtio_snd_event>>>,
153    event_notifier: Event,
154    // These are accessed by the recv_thread and the stream threads
155    tx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>,
156    rx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>,
157    recv_thread_state: Arc<Mutex<ThreadFlags>>,
158    recv_thread: Mutex<Option<WorkerThread<()>>>,
159    // Params are required to be stored for snapshot/restore. On restore, we don't have the params
160    // locally available as the VM is started anew, so they need to be restored.
161    params: HashMap<u32, virtio_snd_pcm_set_params>,
162}
163
164#[derive(Serialize, Deserialize)]
165pub struct VioSClientSnapshot {
166    config: VioSConfig,
167    jacks: Vec<virtio_snd_jack_info>,
168    streams: Vec<virtio_snd_pcm_info>,
169    chmaps: Vec<virtio_snd_chmap_info>,
170    controls: Vec<virtio_snd_ctl_info>,
171    params: HashMap<u32, virtio_snd_pcm_set_params>,
172}
173
174impl VioSClient {
175    /// Create a new client given the path to the audio server's socket.
176    pub fn try_new<P: AsRef<Path>>(server: P) -> Result<VioSClient> {
177        let client_socket = ScmSocket::try_from(
178            UnixSeqpacket::connect(server.as_ref())
179                .map_err(|e| Error::ServerConnectionError(e, server.as_ref().into()))?,
180        )
181        .map_err(|e| Error::ServerConnectionError(e, server.as_ref().into()))?;
182
183        let mut config: VioSConfig = Default::default();
184        const NUM_FDS: usize = 5;
185        let (recv_size, mut safe_fds) = client_socket
186            .recv_with_fds(config.as_mut_bytes(), NUM_FDS)
187            .map_err(Error::ServerError)?;
188
189        match config.version {
190            2 => {
191                if recv_size != VIOS_SIZE_V2 {
192                    return Err(Error::ProtocolError(
193                        ProtocolErrorKind::UnexpectedConfigSize(recv_size),
194                    ));
195                }
196            }
197            3 => {
198                if recv_size != VIOS_SIZE_V3 {
199                    return Err(Error::ProtocolError(
200                        ProtocolErrorKind::UnexpectedConfigSize(recv_size),
201                    ));
202                }
203            }
204            _ => {
205                return Err(Error::ProtocolError(ProtocolErrorKind::VersionMismatch(
206                    config.version,
207                )));
208            }
209        }
210
211        fn pop<T: FromRawDescriptor>(
212            safe_fds: &mut Vec<SafeDescriptor>,
213            expected: usize,
214            received: usize,
215        ) -> Result<T> {
216            // SAFETY:
217            // Safe because we transfer ownership from the SafeDescriptor to T
218            unsafe {
219                Ok(T::from_raw_descriptor(
220                    safe_fds
221                        .pop()
222                        .ok_or(Error::ProtocolError(
223                            ProtocolErrorKind::UnexpectedNumberOfFileDescriptors(
224                                expected, received,
225                            ),
226                        ))?
227                        .into_raw_descriptor(),
228                ))
229            }
230        }
231
232        let fd_count = safe_fds.len();
233        let rx_shm_file = pop::<File>(&mut safe_fds, NUM_FDS, fd_count)?;
234        let tx_shm_file = pop::<File>(&mut safe_fds, NUM_FDS, fd_count)?;
235        let rx_socket = pop::<UnixSeqpacket>(&mut safe_fds, NUM_FDS, fd_count)?;
236        let tx_socket = pop::<UnixSeqpacket>(&mut safe_fds, NUM_FDS, fd_count)?;
237        let event_socket = pop::<UnixSeqpacket>(&mut safe_fds, NUM_FDS, fd_count)?;
238
239        if !safe_fds.is_empty() {
240            return Err(Error::ProtocolError(
241                ProtocolErrorKind::UnexpectedNumberOfFileDescriptors(NUM_FDS, fd_count),
242            ));
243        }
244
245        let tx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>> =
246            Arc::new(Mutex::new(HashMap::new()));
247        let rx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>> =
248            Arc::new(Mutex::new(HashMap::new()));
249        let recv_thread_state = Arc::new(Mutex::new(ThreadFlags {
250            reporting_events: false,
251        }));
252
253        let mut client = VioSClient {
254            config,
255            jacks: Vec::new(),
256            streams: Vec::new(),
257            chmaps: Vec::new(),
258            controls: Vec::new(),
259            control_socket: Mutex::new(client_socket.into_inner()),
260            event_socket,
261            tx: IoBufferQueue::new(tx_socket, tx_shm_file)?,
262            rx: IoBufferQueue::new(rx_socket, rx_shm_file)?,
263            events: Arc::new(Mutex::new(VecDeque::new())),
264            event_notifier: Event::new().map_err(Error::EventCreateError)?,
265            tx_subscribers,
266            rx_subscribers,
267            recv_thread_state,
268            recv_thread: Mutex::new(None),
269            params: HashMap::new(),
270        };
271        client.request_and_cache_info()?;
272        Ok(client)
273    }
274
275    /// Get the number of jacks
276    pub fn num_jacks(&self) -> u32 {
277        self.config.jacks
278    }
279
280    /// Get the number of pcm streams
281    pub fn num_streams(&self) -> u32 {
282        self.config.streams
283    }
284
285    /// Get the number of channel maps
286    pub fn num_chmaps(&self) -> u32 {
287        self.config.chmaps
288    }
289
290    /// Get the configuration information on a jack
291    pub fn jack_info(&self, idx: u32) -> Option<virtio_snd_jack_info> {
292        self.jacks.get(idx as usize).copied()
293    }
294
295    /// Get the configuration information on a pcm stream
296    pub fn stream_info(&self, idx: u32) -> Option<virtio_snd_pcm_info> {
297        self.streams.get(idx as usize).cloned()
298    }
299
300    /// Get the configuration information on a channel map
301    pub fn chmap_info(&self, idx: u32) -> Option<virtio_snd_chmap_info> {
302        self.chmaps.get(idx as usize).copied()
303    }
304
305    /// Get the number of controls
306    pub fn num_controls(&self) -> u32 {
307        self.config.controls
308    }
309
310    /// Get the configuration information on a control
311    pub fn control_info(&self, idx: u32) -> Option<virtio_snd_ctl_info> {
312        self.controls.get(idx as usize).copied()
313    }
314
315    /// Set the value of a control
316    pub fn set_control(&self, control_id: u32, value: virtio_snd_ctl_value) -> Result<()> {
317        // Define a combined struct locally to ensure atomic transmission
318        #[repr(C)]
319        #[derive(Copy, Clone, Immutable, IntoBytes, KnownLayout)]
320        struct virtio_snd_ctl_set_value {
321            hdr: virtio_snd_ctl_hdr,
322            value: virtio_snd_ctl_value,
323        }
324
325        let control_socket_lock = self.control_socket.lock();
326        send_cmd(
327            &control_socket_lock,
328            virtio_snd_ctl_set_value {
329                hdr: virtio_snd_ctl_hdr {
330                    hdr: virtio_snd_hdr {
331                        code: VIRTIO_SND_R_CTL_WRITE.into(),
332                    },
333                    control_id: control_id.into(),
334                },
335                value,
336            },
337        )
338    }
339
340    /// Get the value of a control
341    pub fn get_control(&self, control_id: u32) -> Result<virtio_snd_ctl_value> {
342        let msg = virtio_snd_ctl_hdr {
343            hdr: virtio_snd_hdr {
344                code: VIRTIO_SND_R_CTL_READ.into(),
345            },
346            control_id: control_id.into(),
347        };
348        let control_socket_lock = self.control_socket.lock();
349        seq_socket_send(&control_socket_lock, msg.as_bytes())?;
350        let reply = control_socket_lock
351            .recv_as_vec()
352            .map_err(Error::ServerIOError)?;
353        let status_size = std::mem::size_of::<virtio_snd_hdr>();
354        let mut status: virtio_snd_hdr = Default::default();
355        if reply.len() < status_size {
356            return Err(Error::ProtocolError(
357                ProtocolErrorKind::UnexpectedMessageSize(status_size, reply.len()),
358            ));
359        }
360        status
361            .as_mut_bytes()
362            .copy_from_slice(&reply[0..status_size]);
363        if status.code.to_native() != VIRTIO_SND_S_OK {
364            return Err(Error::CommandFailed(status.code.to_native()));
365        }
366        if reply.len() != status_size + std::mem::size_of::<virtio_snd_ctl_value>() {
367            return Err(Error::ProtocolError(
368                ProtocolErrorKind::UnexpectedMessageSize(
369                    status_size + std::mem::size_of::<virtio_snd_ctl_value>(),
370                    reply.len(),
371                ),
372            ));
373        }
374        let mut value: virtio_snd_ctl_value = Default::default();
375        value.as_mut_bytes().copy_from_slice(&reply[status_size..]);
376        Ok(value)
377    }
378
379    /// Starts the background thread that receives release messages from the server. If the thread
380    /// was already started this function does nothing.
381    /// This thread must be started prior to attempting any stream IO operation or the calling
382    /// thread would block.
383    pub fn start_bg_thread(&self) -> Result<()> {
384        if self.recv_thread.lock().is_some() {
385            return Ok(());
386        }
387        let tx_socket = self.tx.try_clone_socket()?;
388        let rx_socket = self.rx.try_clone_socket()?;
389        let event_socket = self
390            .event_socket
391            .try_clone()
392            .map_err(Error::UnixSeqpacketDupError)?;
393        let mut opt = self.recv_thread.lock();
394        // The lock on recv_thread was released above to avoid holding more than one lock at a time
395        // while duplicating the fds. So we have to check the condition again.
396        if opt.is_none() {
397            let tx_subscribers = self.tx_subscribers.clone();
398            let rx_subscribers = self.rx_subscribers.clone();
399            let event_notifier = self
400                .event_notifier
401                .try_clone()
402                .map_err(Error::EventDupError)?;
403            let events = self.events.clone();
404            let recv_thread_state = self.recv_thread_state.clone();
405            *opt = Some(WorkerThread::start("shm_vios", move |kill_event| {
406                if let Err(e) = run_recv_thread(
407                    kill_event,
408                    tx_subscribers,
409                    rx_subscribers,
410                    event_notifier,
411                    events,
412                    recv_thread_state,
413                    tx_socket,
414                    rx_socket,
415                    event_socket,
416                ) {
417                    error!("virtio-snd shm_vios worker failed: {e:#}");
418                }
419            }));
420        }
421        Ok(())
422    }
423
424    /// Stops the background thread.
425    pub fn stop_bg_thread(&self) -> Result<()> {
426        if let Some(recv_thread) = self.recv_thread.lock().take() {
427            recv_thread.stop();
428        }
429        Ok(())
430    }
431
432    /// Gets an Event object that will trigger every time an event is received from the server
433    pub fn get_event_notifier(&self) -> Result<Event> {
434        // Let the background thread know that there is at least one consumer of events
435        self.recv_thread_state.lock().reporting_events = true;
436        self.event_notifier
437            .try_clone()
438            .map_err(Error::EventDupError)
439    }
440
441    /// Retrieves one event. Callers should have received a notification through the event notifier
442    /// before calling this function.
443    pub fn pop_event(&self) -> Option<virtio_snd_event> {
444        self.events.lock().pop_front()
445    }
446
447    /// Remap a jack. This should only be called if the jack announces support for the operation
448    /// through the features field in the corresponding virtio_snd_jack_info struct.
449    pub fn remap_jack(&self, jack_id: u32, association: u32, sequence: u32) -> Result<()> {
450        if jack_id >= self.config.jacks {
451            return Err(Error::InvalidJackId(jack_id));
452        }
453        let msg = virtio_snd_jack_remap {
454            hdr: virtio_snd_jack_hdr {
455                hdr: virtio_snd_hdr {
456                    code: VIRTIO_SND_R_JACK_REMAP.into(),
457                },
458                jack_id: jack_id.into(),
459            },
460            association: association.into(),
461            sequence: sequence.into(),
462        };
463        let control_socket_lock = self.control_socket.lock();
464        send_cmd(&control_socket_lock, msg)
465    }
466
467    /// Configures a stream with the given parameters.
468    pub fn set_stream_parameters(
469        &mut self,
470        stream_id: u32,
471        params: VioSStreamParams,
472    ) -> Result<()> {
473        self.streams
474            .get(stream_id as usize)
475            .ok_or(Error::InvalidStreamId(stream_id))?;
476        let raw_params: virtio_snd_pcm_set_params = (stream_id, params).into();
477        // Old value is not needed and is dropped
478        let _ = self.params.insert(stream_id, raw_params);
479        let control_socket_lock = self.control_socket.lock();
480        send_cmd(&control_socket_lock, raw_params)
481    }
482
483    /// Configures a stream with the given parameters.
484    pub fn set_stream_parameters_raw(
485        &mut self,
486        raw_params: virtio_snd_pcm_set_params,
487    ) -> Result<()> {
488        let stream_id = raw_params.hdr.stream_id.to_native();
489        // Old value is not needed and is dropped
490        let _ = self.params.insert(stream_id, raw_params);
491        self.streams
492            .get(stream_id as usize)
493            .ok_or(Error::InvalidStreamId(stream_id))?;
494        let control_socket_lock = self.control_socket.lock();
495        send_cmd(&control_socket_lock, raw_params)
496    }
497
498    /// Send the PREPARE_STREAM command to the server.
499    pub fn prepare_stream(&self, stream_id: u32) -> Result<()> {
500        self.common_stream_op(stream_id, VIRTIO_SND_R_PCM_PREPARE)
501    }
502
503    /// Send the RELEASE_STREAM command to the server.
504    pub fn release_stream(&self, stream_id: u32) -> Result<()> {
505        self.common_stream_op(stream_id, VIRTIO_SND_R_PCM_RELEASE)
506    }
507
508    /// Send the START_STREAM command to the server.
509    pub fn start_stream(&self, stream_id: u32) -> Result<()> {
510        self.common_stream_op(stream_id, VIRTIO_SND_R_PCM_START)
511    }
512
513    /// Send the STOP_STREAM command to the server.
514    pub fn stop_stream(&self, stream_id: u32) -> Result<()> {
515        self.common_stream_op(stream_id, VIRTIO_SND_R_PCM_STOP)
516    }
517
518    /// Send audio frames to the server. Blocks the calling thread until the server acknowledges
519    /// the data.
520    pub fn inject_audio_data<R, Cb: FnOnce(VolatileSlice) -> R>(
521        &self,
522        stream_id: u32,
523        size: usize,
524        callback: Cb,
525    ) -> Result<(u32, R)> {
526        if self
527            .streams
528            .get(stream_id as usize)
529            .ok_or(Error::InvalidStreamId(stream_id))?
530            .direction
531            != VIRTIO_SND_D_OUTPUT
532        {
533            return Err(Error::WrongDirection(VIRTIO_SND_D_OUTPUT));
534        }
535        self.streams
536            .get(stream_id as usize)
537            .ok_or(Error::InvalidStreamId(stream_id))?;
538        let dst_offset = self.tx.allocate_buffer(size)?;
539        let buffer_slice = self.tx.buffer_at(dst_offset, size)?;
540        let ret = callback(buffer_slice);
541        // Register to receive the status before sending the buffer to the server
542        let (sender, receiver): (Sender<BufferReleaseMsg>, Receiver<BufferReleaseMsg>) = channel();
543        self.tx_subscribers.lock().insert(dst_offset, sender);
544        self.tx.send_buffer(stream_id, dst_offset, size)?;
545        let (_, latency) = await_status(receiver)?;
546        Ok((latency, ret))
547    }
548
549    /// Request audio frames from the server. It blocks until the data is available.
550    pub fn request_audio_data<R, Cb: FnOnce(&VolatileSlice) -> R>(
551        &self,
552        stream_id: u32,
553        size: usize,
554        callback: Cb,
555    ) -> Result<(u32, R)> {
556        if self
557            .streams
558            .get(stream_id as usize)
559            .ok_or(Error::InvalidStreamId(stream_id))?
560            .direction
561            != VIRTIO_SND_D_INPUT
562        {
563            return Err(Error::WrongDirection(VIRTIO_SND_D_INPUT));
564        }
565        let src_offset = self.rx.allocate_buffer(size)?;
566        // Register to receive the status before sending the buffer to the server
567        let (sender, receiver): (Sender<BufferReleaseMsg>, Receiver<BufferReleaseMsg>) = channel();
568        self.rx_subscribers.lock().insert(src_offset, sender);
569        self.rx.send_buffer(stream_id, src_offset, size)?;
570        // Make sure no mutexes are held while awaiting for the buffer to be written to
571        let (recv_size, latency) = await_status(receiver)?;
572        let buffer_slice = self.rx.buffer_at(src_offset, recv_size)?;
573        Ok((latency, callback(&buffer_slice)))
574    }
575
576    /// Get a list of file descriptors used by the implementation.
577    pub fn keep_rds(&self) -> Vec<RawDescriptor> {
578        let control_desc = self.control_socket.lock().as_raw_descriptor();
579        let event_desc = self.event_socket.as_raw_descriptor();
580        let event_notifier = self.event_notifier.as_raw_descriptor();
581        let mut ret = vec![control_desc, event_desc, event_notifier];
582        ret.append(&mut self.tx.keep_rds());
583        ret.append(&mut self.rx.keep_rds());
584        ret
585    }
586
587    fn common_stream_op(&self, stream_id: u32, op: u32) -> Result<()> {
588        self.streams
589            .get(stream_id as usize)
590            .ok_or(Error::InvalidStreamId(stream_id))?;
591        let msg = virtio_snd_pcm_hdr {
592            hdr: virtio_snd_hdr { code: op.into() },
593            stream_id: stream_id.into(),
594        };
595        let control_socket_lock = self.control_socket.lock();
596        send_cmd(&control_socket_lock, msg)
597    }
598
599    fn request_and_cache_info(&mut self) -> Result<()> {
600        self.request_and_cache_jacks_info()?;
601        self.request_and_cache_streams_info()?;
602        self.request_and_cache_chmaps_info()?;
603        self.request_and_cache_controls_info()?;
604        Ok(())
605    }
606
607    fn request_info<T: Default + Copy + Clone>(
608        &self,
609        req_code: u32,
610        count: usize,
611    ) -> Result<Vec<T>> {
612        let info_size = std::mem::size_of::<T>();
613        let status_size = std::mem::size_of::<virtio_snd_hdr>();
614        let req = virtio_snd_query_info {
615            hdr: virtio_snd_hdr {
616                code: req_code.into(),
617            },
618            start_id: 0u32.into(),
619            count: (count as u32).into(),
620            size: (std::mem::size_of::<virtio_snd_query_info>() as u32).into(),
621        };
622        let control_socket_lock = self.control_socket.lock();
623        seq_socket_send(&control_socket_lock, req.as_bytes())?;
624        let reply = control_socket_lock
625            .recv_as_vec()
626            .map_err(Error::ServerIOError)?;
627        let mut status: virtio_snd_hdr = Default::default();
628        status
629            .as_mut_bytes()
630            .copy_from_slice(&reply[0..status_size]);
631        if status.code.to_native() != VIRTIO_SND_S_OK {
632            return Err(Error::CommandFailed(status.code.to_native()));
633        }
634        if reply.len() != status_size + count * info_size {
635            return Err(Error::ProtocolError(
636                ProtocolErrorKind::UnexpectedMessageSize(
637                    status_size + count * info_size,
638                    reply.len(),
639                ),
640            ));
641        }
642        Ok(reply[status_size..]
643            .chunks(info_size)
644            .map(|info_buffer| {
645                // SAFETY:
646                // `info_buffer` is guaranteed to be `info_size` bytes long by `chunks()`.
647                // `info_size` is exactly `std::mem::size_of::<T>()`.
648                // We use `read_unaligned` because the byte slice may not be properly aligned for
649                // `T`.
650                unsafe { std::ptr::read_unaligned(info_buffer.as_ptr() as *const T) }
651            })
652            .collect())
653    }
654
655    fn request_and_cache_jacks_info(&mut self) -> Result<()> {
656        let num_jacks = self.config.jacks as usize;
657        if num_jacks == 0 {
658            return Ok(());
659        }
660        self.jacks = self.request_info(VIRTIO_SND_R_JACK_INFO, num_jacks)?;
661        Ok(())
662    }
663
664    fn request_and_cache_streams_info(&mut self) -> Result<()> {
665        let num_streams = self.config.streams as usize;
666        if num_streams == 0 {
667            return Ok(());
668        }
669        self.streams = self.request_info(VIRTIO_SND_R_PCM_INFO, num_streams)?;
670        Ok(())
671    }
672
673    fn request_and_cache_chmaps_info(&mut self) -> Result<()> {
674        let num_chmaps = self.config.chmaps as usize;
675        if num_chmaps == 0 {
676            return Ok(());
677        }
678        self.chmaps = self.request_info(VIRTIO_SND_R_CHMAP_INFO, num_chmaps)?;
679        Ok(())
680    }
681
682    fn request_and_cache_controls_info(&mut self) -> Result<()> {
683        let num_controls = self.config.controls as usize;
684        if num_controls == 0 {
685            return Ok(());
686        }
687        let raw_controls: Vec<virtio_snd_ctl_info> =
688            self.request_info(VIRTIO_SND_R_CTL_INFO, num_controls)?;
689
690        let mut seen = std::collections::HashSet::new();
691
692        self.controls.clear();
693        for (i, info) in raw_controls.into_iter().enumerate() {
694            let name_bytes = info.name; // Copy name bytes
695            let index = info.index.to_native();
696            let key = (name_bytes, index);
697
698            if seen.contains(&key) {
699                let name = std::str::from_utf8(&info.name)
700                    .unwrap_or("invalid utf8")
701                    .trim_matches(char::from(0));
702                base::error!(
703                    "CTLS: encountered a duplicate vec_idx: {}, index: {}, name: {}",
704                    i,
705                    index,
706                    name
707                );
708                return Err(Error::DuplicatedControlId(index, name.to_owned()));
709            }
710            seen.insert(key);
711            self.controls.push(info);
712            let name = std::str::from_utf8(&info.name)
713                .unwrap_or("invalid utf8")
714                .trim_matches(char::from(0));
715            base::info!(
716                "CTLS: kept vec_idx: {}, index: {}, name: {}",
717                i,
718                index,
719                name
720            );
721        }
722
723        if self.controls.len() != self.config.controls as usize {
724            return Err(Error::ControlsMismatch(
725                self.config.controls,
726                self.controls.len() as u32,
727            ));
728        }
729        Ok(())
730    }
731
732    pub fn snapshot(&self) -> VioSClientSnapshot {
733        VioSClientSnapshot {
734            config: self.config,
735            jacks: self.jacks.clone(),
736            streams: self.streams.clone(),
737            chmaps: self.chmaps.clone(),
738            controls: self.controls.clone(),
739            params: self.params.clone(),
740        }
741    }
742
743    // Function called `restore` to signify it will happen as part of the snapshot/restore flow. No
744    // data is actually restored in the case of VioSClient.
745    pub fn restore(&mut self, data: VioSClientSnapshot) -> anyhow::Result<()> {
746        anyhow::ensure!(
747            data.config == self.config,
748            "config doesn't match on restore: expected: {:?}, got: {:?}",
749            data.config,
750            self.config
751        );
752        self.jacks = data.jacks;
753        self.streams = data.streams;
754        self.chmaps = data.chmaps;
755        self.controls = data.controls;
756        self.params = data.params;
757        Ok(())
758    }
759
760    pub fn restore_stream(&mut self, stream_id: u32, state: StreamState) -> Result<()> {
761        if let Some(params) = self.params.get(&stream_id).cloned() {
762            self.set_stream_parameters_raw(params)?;
763        }
764        match state {
765            StreamState::Started => {
766                // If state != prepared, start will always fail.
767                // As such, it is fine to only print the first error without returning, as the
768                // second action will then fail.
769                if let Err(e) = self.prepare_stream(stream_id) {
770                    error!("failed to prepare stream: {}", e);
771                };
772                self.start_stream(stream_id)
773            }
774            StreamState::Prepared => self.prepare_stream(stream_id),
775            // Nothing to do here
776            _ => Ok(()),
777        }
778    }
779}
780
/// State shared between the client and its background receive thread.
#[derive(Clone, Copy)]
struct ThreadFlags {
    // Starts out false; set once `get_event_notifier` has been called, meaning there is at
    // least one consumer of events.
    reporting_events: bool,
}
785
/// Identifies which descriptor woke the receive thread's `WaitContext`.
///
/// Variant order feeds the `EventToken` derive; keep it stable unless that is intended.
#[derive(EventToken)]
enum Token {
    /// `kill_event` was signaled: the receive thread should stop.
    Notification,
    /// A buffer status message is readable on the TX socket.
    TxBufferMsg,
    /// A buffer status message is readable on the RX socket.
    RxBufferMsg,
    /// A `virtio_snd_event` is readable on the event socket.
    EventMsg,
}
793
794fn recv_buffer_status_msg(
795    socket: &UnixSeqpacket,
796    subscribers: &Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>,
797) -> Result<()> {
798    let mut msg: IoStatusMsg = Default::default();
799    let size = socket
800        .recv(msg.as_mut_bytes())
801        .map_err(Error::ServerIOError)?;
802    if size != std::mem::size_of::<IoStatusMsg>() {
803        return Err(Error::ProtocolError(
804            ProtocolErrorKind::UnexpectedMessageSize(std::mem::size_of::<IoStatusMsg>(), size),
805        ));
806    }
807    let mut status = msg.status.status.into();
808    if status == u32::MAX {
809        // Anyone waiting for this would continue to wait for as long as status is
810        // u32::MAX
811        status -= 1;
812    }
813    let latency = msg.status.latency_bytes.into();
814    let offset = msg.buffer_offset as usize;
815    let consumed_len = msg.consumed_len as usize;
816    let promise_opt = subscribers.lock().remove(&offset);
817    match promise_opt {
818        None => error!(
819            "Received an unexpected buffer status message: {}. This is a BUG!!",
820            offset
821        ),
822        Some(sender) => {
823            if let Err(e) = sender.send(BufferReleaseMsg {
824                status,
825                latency,
826                consumed_len,
827            }) {
828                error!("Failed to notify waiting thread: {:?}", e);
829            }
830        }
831    }
832    Ok(())
833}
834
835fn recv_event(socket: &UnixSeqpacket) -> Result<virtio_snd_event> {
836    let mut msg: virtio_snd_event = Default::default();
837    let size = socket
838        .recv(msg.as_mut_bytes())
839        .map_err(Error::ServerIOError)?;
840    if size != std::mem::size_of::<virtio_snd_event>() {
841        return Err(Error::ProtocolError(
842            ProtocolErrorKind::UnexpectedMessageSize(std::mem::size_of::<virtio_snd_event>(), size),
843        ));
844    }
845    Ok(msg)
846}
847
848fn run_recv_thread(
849    kill_event: Event,
850    tx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>,
851    rx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>,
852    event_notifier: Event,
853    event_queue: Arc<Mutex<VecDeque<virtio_snd_event>>>,
854    state: Arc<Mutex<ThreadFlags>>,
855    tx_socket: UnixSeqpacket,
856    rx_socket: UnixSeqpacket,
857    event_socket: UnixSeqpacket,
858) -> Result<()> {
859    let wait_ctx: WaitContext<Token> = WaitContext::build_with(&[
860        (&tx_socket, Token::TxBufferMsg),
861        (&rx_socket, Token::RxBufferMsg),
862        (&event_socket, Token::EventMsg),
863        (&kill_event, Token::Notification),
864    ])
865    .map_err(Error::WaitContextCreateError)?;
866    let mut running = true;
867    while running {
868        let events = wait_ctx.wait().map_err(Error::WaitError)?;
869        for evt in events {
870            match evt.token {
871                Token::TxBufferMsg => recv_buffer_status_msg(&tx_socket, &tx_subscribers)?,
872                Token::RxBufferMsg => recv_buffer_status_msg(&rx_socket, &rx_subscribers)?,
873                Token::EventMsg => {
874                    let evt = recv_event(&event_socket)?;
875                    let state_cpy = *state.lock();
876                    if state_cpy.reporting_events {
877                        event_queue.lock().push_back(evt);
878                        event_notifier.signal().map_err(Error::EventWriteError)?;
879                    } // else just drop the events
880                }
881                Token::Notification => {
882                    // Just consume the notification and check for termination on the next
883                    // iteration
884                    if let Err(e) = kill_event.wait() {
885                        error!("Failed to consume notification from recv thread: {:?}", e);
886                    }
887                    running = false;
888                }
889            }
890        }
891    }
892    Ok(())
893}
894
895fn await_status(promise: Receiver<BufferReleaseMsg>) -> Result<(usize, u32)> {
896    let BufferReleaseMsg {
897        status,
898        latency,
899        consumed_len,
900    } = promise.recv().map_err(Error::BufferStatusSenderLost)?;
901    if status == VIRTIO_SND_S_OK {
902        Ok((consumed_len, latency))
903    } else {
904        Err(Error::IOBufferError(status))
905    }
906}
907
/// A memory-mapped file carved into I/O buffers, plus the socket used to hand those buffers
/// (by offset) to the server.
// NOTE(review): field order is also drop order; keep it unless that is intended.
struct IoBufferQueue {
    // Socket on which `IoTransferMsg`s describing buffers are sent.
    socket: UnixSeqpacket,
    // File backing the mapping; its length at construction time determines `size`.
    file: File,
    // Mapping covering the whole of `file`.
    mmap: MemoryMapping,
    // Size in bytes of the mapped region.
    size: usize,
    // Offset at which the next allocation is attempted (bump pointer, wraps to 0).
    next: Mutex<usize>,
}
915
916impl IoBufferQueue {
917    fn new(socket: UnixSeqpacket, mut file: File) -> Result<IoBufferQueue> {
918        let size = file.seek(SeekFrom::End(0)).map_err(Error::FileSizeError)? as usize;
919
920        let mmap = MemoryMappingBuilder::new(size)
921            .from_file(&file)
922            .build()
923            .map_err(Error::ServerMmapError)?;
924
925        Ok(IoBufferQueue {
926            socket,
927            file,
928            mmap,
929            size,
930            next: Mutex::new(0),
931        })
932    }
933
934    fn allocate_buffer(&self, size: usize) -> Result<usize> {
935        if size > self.size {
936            return Err(Error::OutOfSpace);
937        }
938        let mut next_lock = self.next.lock();
939        let offset = if size > self.size - *next_lock {
940            // Can't fit the new buffer at the end of the area, so put it at the beginning
941            0
942        } else {
943            *next_lock
944        };
945        *next_lock = offset + size;
946        Ok(offset)
947    }
948
949    fn buffer_at(&self, offset: usize, len: usize) -> Result<VolatileSlice> {
950        self.mmap
951            .get_slice(offset, len)
952            .map_err(Error::VolatileMemoryError)
953    }
954
955    fn try_clone_socket(&self) -> Result<UnixSeqpacket> {
956        self.socket
957            .try_clone()
958            .map_err(Error::UnixSeqpacketDupError)
959    }
960
961    fn send_buffer(&self, stream_id: u32, offset: usize, size: usize) -> Result<()> {
962        let msg = IoTransferMsg::new(stream_id, offset, size);
963        seq_socket_send(&self.socket, msg.as_bytes())
964    }
965
966    fn keep_rds(&self) -> Vec<RawDescriptor> {
967        vec![
968            self.file.as_raw_descriptor(),
969            self.socket.as_raw_descriptor(),
970        ]
971    }
972}
973
/// Groups the parameters used to configure a stream prior to using it.
///
/// Combined with a stream id (see the `From<(u32, VioSStreamParams)>` conversion) to build a
/// `virtio_snd_pcm_set_params` request. It is plain data, so it is cheap to copy, compare and
/// debug-print.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct VioSStreamParams {
    /// Size of the stream's buffer, in bytes.
    pub buffer_bytes: u32,
    /// Size of one period, in bytes.
    pub period_bytes: u32,
    /// Feature bits requested for the stream.
    pub features: u32,
    /// Number of channels.
    pub channels: u8,
    /// Sample format code (presumably a `VIRTIO_SND_PCM_FMT_*` value — confirm).
    pub format: u8,
    /// Frame rate code (presumably a `VIRTIO_SND_PCM_RATE_*` value — confirm).
    pub rate: u8,
}
983
984impl From<(u32, VioSStreamParams)> for virtio_snd_pcm_set_params {
985    fn from(val: (u32, VioSStreamParams)) -> Self {
986        virtio_snd_pcm_set_params {
987            hdr: virtio_snd_pcm_hdr {
988                hdr: virtio_snd_hdr {
989                    code: VIRTIO_SND_R_PCM_SET_PARAMS.into(),
990                },
991                stream_id: val.0.into(),
992            },
993            buffer_bytes: val.1.buffer_bytes.into(),
994            period_bytes: val.1.period_bytes.into(),
995            features: val.1.features.into(),
996            channels: val.1.channels,
997            format: val.1.format,
998            rate: val.1.rate,
999            padding: 0u8,
1000        }
1001    }
1002}
1003
1004fn send_cmd<T: Immutable + IntoBytes>(control_socket: &UnixSeqpacket, data: T) -> Result<()> {
1005    seq_socket_send(control_socket, data.as_bytes())?;
1006    recv_cmd_status(control_socket)
1007}
1008
1009fn recv_cmd_status(control_socket: &UnixSeqpacket) -> Result<()> {
1010    let mut status: virtio_snd_hdr = Default::default();
1011    control_socket
1012        .recv(status.as_mut_bytes())
1013        .map_err(Error::ServerIOError)?;
1014    if status.code.to_native() == VIRTIO_SND_S_OK {
1015        Ok(())
1016    } else {
1017        Err(Error::CommandFailed(status.code.to_native()))
1018    }
1019}
1020
1021fn seq_socket_send(socket: &UnixSeqpacket, data: &[u8]) -> Result<()> {
1022    loop {
1023        let send_res = socket.send(data);
1024        if let Err(e) = send_res {
1025            match e.kind() {
1026                // Retry if interrupted
1027                IOErrorKind::Interrupted => continue,
1028                _ => return Err(Error::ServerIOError(e)),
1029            }
1030        }
1031        // Success
1032        break;
1033    }
1034    Ok(())
1035}
1036
/// Configuration block describing the VioS device: protocol version and element counts.
///
/// `#[repr(C)]` with byte-conversion derives, so field order defines the byte layout. Version-2
/// configs cover only the first four fields (`VIOS_SIZE_V2` bytes). Version 3 appends
/// `controls` (`VIOS_SIZE_V3` bytes).
#[repr(C)]
#[derive(
    Copy,
    Clone,
    Default,
    FromBytes,
    Immutable,
    IntoBytes,
    KnownLayout,
    Serialize,
    Deserialize,
    PartialEq,
    Eq,
    Debug,
)]
struct VioSConfig {
    // Fields below belong to version 2
    // Protocol version of this config.
    version: u32,
    // Number of jacks.
    jacks: u32,
    // Number of streams.
    streams: u32,
    // Number of channel maps.
    chmaps: u32,

    // Fields below belong to version 3
    // Number of controls.
    controls: u32,
}
1062
1063const VIOS_SIZE_V2: usize = 4 * std::mem::size_of::<u32>();
1064const VIOS_SIZE_V3: usize = std::mem::size_of::<VioSConfig>();
1065
/// Completion report for one I/O buffer, passed from the receive thread to the thread waiting
/// on that buffer.
struct BufferReleaseMsg {
    /// Number of bytes the server reported as consumed from the buffer.
    consumed_len: usize,
    /// Latency reported by the server, in bytes.
    latency: u32,
    /// Virtio-snd status code (`recv_buffer_status_msg` never delivers `u32::MAX`).
    status: u32,
}
1071
/// Message sent to the server to hand over one buffer of an I/O memory region.
///
/// `#[repr(C)]` so that `as_bytes()` yields the on-socket layout; do not reorder fields.
#[repr(C)]
#[derive(Copy, Clone, FromBytes, Immutable, IntoBytes, KnownLayout)]
struct IoTransferMsg {
    /// Transfer header carrying the stream id.
    io_xfer: virtio_snd_pcm_xfer,
    /// Byte offset of the buffer within the region.
    buffer_offset: u32,
    /// Length of the buffer in bytes.
    buffer_len: u32,
}
1079
1080impl IoTransferMsg {
1081    fn new(stream_id: u32, buffer_offset: usize, buffer_len: usize) -> IoTransferMsg {
1082        IoTransferMsg {
1083            io_xfer: virtio_snd_pcm_xfer {
1084                stream_id: stream_id.into(),
1085            },
1086            buffer_offset: buffer_offset as u32,
1087            buffer_len: buffer_len as u32,
1088        }
1089    }
1090}
1091
/// Buffer status message received from the server on the TX/RX sockets.
///
/// `#[repr(C)]` and filled in place via `as_mut_bytes()`; field order is the wire layout.
#[repr(C)]
#[derive(Copy, Clone, Default, FromBytes, Immutable, IntoBytes, KnownLayout)]
struct IoStatusMsg {
    /// PCM status: status code plus latency in bytes.
    status: virtio_snd_pcm_status,
    /// Offset of the buffer this status refers to; used to find the waiting subscriber.
    buffer_offset: u32,
    /// Number of bytes the server reports as consumed from the buffer.
    consumed_len: u32,
}