1use std::collections::HashMap;
6use std::collections::VecDeque;
7use std::fs::File;
8use std::io::Error as IOError;
9use std::io::ErrorKind as IOErrorKind;
10use std::io::Seek;
11use std::io::SeekFrom;
12use std::path::Path;
13use std::path::PathBuf;
14use std::sync::mpsc::channel;
15use std::sync::mpsc::Receiver;
16use std::sync::mpsc::RecvError;
17use std::sync::mpsc::Sender;
18use std::sync::Arc;
19
20use base::error;
21use base::AsRawDescriptor;
22use base::Error as BaseError;
23use base::Event;
24use base::EventToken;
25use base::FromRawDescriptor;
26use base::IntoRawDescriptor;
27use base::MemoryMapping;
28use base::MemoryMappingBuilder;
29use base::MmapError;
30use base::RawDescriptor;
31use base::SafeDescriptor;
32use base::ScmSocket;
33use base::UnixSeqpacket;
34use base::VolatileMemory;
35use base::VolatileMemoryError;
36use base::VolatileSlice;
37use base::WaitContext;
38use base::WorkerThread;
39use remain::sorted;
40use serde::Deserialize;
41use serde::Serialize;
42use sync::Mutex;
43use thiserror::Error as ThisError;
44use zerocopy::FromBytes;
45use zerocopy::Immutable;
46use zerocopy::IntoBytes;
47use zerocopy::KnownLayout;
48
49use crate::virtio::snd::constants::*;
50use crate::virtio::snd::layout::*;
51use crate::virtio::snd::vios_backend::streams::StreamState;
52
53pub type Result<T> = std::result::Result<T, Error>;
54
55#[sorted]
56#[derive(ThisError, Debug)]
57pub enum Error {
58 #[error("Error memory mapping client_shm: {0}")]
59 BaseMmapError(BaseError),
60 #[error("Sender was dropped without sending buffer status, the recv thread may have exited")]
61 BufferStatusSenderLost(RecvError),
62 #[error("Command failed with status {0}")]
63 CommandFailed(u32),
64 #[error("Expected {0} controls, received {1}")]
65 ControlsMismatch(u32, u32),
66 #[error("Error duplicating file descriptor: {0}")]
67 DupError(BaseError),
68 #[error("Duplicated control index: {0} name: '{1}'")]
69 DuplicatedControlId(u32, String),
70 #[error("Failed to create Recv event: {0}")]
71 EventCreateError(BaseError),
72 #[error("Failed to dup Recv event: {0}")]
73 EventDupError(BaseError),
74 #[error("Failed to signal event: {0}")]
75 EventWriteError(BaseError),
76 #[error("Failed to get size of tx shared memory: {0}")]
77 FileSizeError(IOError),
78 #[error("Error accessing guest's shared memory: {0}")]
79 GuestMmapError(MmapError),
80 #[error("No jack with id {0}")]
81 InvalidJackId(u32),
82 #[error("No stream with id {0}")]
83 InvalidStreamId(u32),
84 #[error("IO buffer operation failed: status = {0}")]
85 IOBufferError(u32),
86 #[error("No PCM streams available")]
87 NoStreamsAvailable,
88 #[error("Insuficient space for the new buffer in the queue's buffer area")]
89 OutOfSpace,
90 #[error("Platform not supported")]
91 PlatformNotSupported,
92 #[error("{0}")]
93 ProtocolError(ProtocolErrorKind),
94 #[error("Failed to connect to VioS server {1}: {0:?}")]
95 ServerConnectionError(IOError, PathBuf),
96 #[error("Failed to communicate with VioS server: {0:?}")]
97 ServerError(IOError),
98 #[error("Failed to communicate with VioS server: {0:?}")]
99 ServerIOError(IOError),
100 #[error("Error accessing VioS server's shared memory: {0}")]
101 ServerMmapError(MmapError),
102 #[error("Failed to duplicate UnixSeqpacket: {0}")]
103 UnixSeqpacketDupError(IOError),
104 #[error("Unsupported frame rate: {0}")]
105 UnsupportedFrameRate(u32),
106 #[error("Error accessing volatile memory: {0}")]
107 VolatileMemoryError(VolatileMemoryError),
108 #[error("Failed to create Recv thread's WaitContext: {0}")]
109 WaitContextCreateError(BaseError),
110 #[error("Error waiting for events")]
111 WaitError(BaseError),
112 #[error("Invalid operation for stream direction: {0}")]
113 WrongDirection(u8),
114 #[error("Set saved params should only be used while restoring the device")]
115 WrongSetParams,
116}
117
118#[derive(ThisError, Debug)]
119pub enum ProtocolErrorKind {
120 #[error("The server sent a config of the wrong size: {0}")]
121 UnexpectedConfigSize(usize),
122 #[error("Received {1} file descriptors from the server, expected {0}")]
123 UnexpectedNumberOfFileDescriptors(usize, usize), #[error("Server's version ({0}) doesn't match client's")]
125 VersionMismatch(u32),
126 #[error("Received a msg with an unexpected size: expected {0}, received {1}")]
127 UnexpectedMessageSize(usize, usize), }
129
130pub struct VioSClient {
136 config: VioSConfig,
139 jacks: Vec<virtio_snd_jack_info>,
140 streams: Vec<virtio_snd_pcm_info>,
141 chmaps: Vec<virtio_snd_chmap_info>,
142 controls: Vec<virtio_snd_ctl_info>,
143 control_socket: Mutex<UnixSeqpacket>,
147 event_socket: UnixSeqpacket,
148 tx: IoBufferQueue,
150 rx: IoBufferQueue,
151 events: Arc<Mutex<VecDeque<virtio_snd_event>>>,
153 event_notifier: Event,
154 tx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>,
156 rx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>,
157 recv_thread_state: Arc<Mutex<ThreadFlags>>,
158 recv_thread: Mutex<Option<WorkerThread<()>>>,
159 params: HashMap<u32, virtio_snd_pcm_set_params>,
162}
163
/// Serializable state of a `VioSClient`, produced by `VioSClient::snapshot` and
/// consumed by `VioSClient::restore`.
///
/// NOTE(review): field order may matter for non-self-describing serde formats;
/// keep it stable to stay compatible with existing snapshots.
#[derive(Serialize, Deserialize)]
pub struct VioSClientSnapshot {
    // Server config at snapshot time; `restore` rejects a snapshot whose config
    // differs from the live client's.
    config: VioSConfig,
    // Cached jack info.
    jacks: Vec<virtio_snd_jack_info>,
    // Cached PCM stream info.
    streams: Vec<virtio_snd_pcm_info>,
    // Cached channel-map info.
    chmaps: Vec<virtio_snd_chmap_info>,
    // Cached control info.
    controls: Vec<virtio_snd_ctl_info>,
    // Last parameters set per stream id.
    params: HashMap<u32, virtio_snd_pcm_set_params>,
}
173
174impl VioSClient {
175 pub fn try_new<P: AsRef<Path>>(server: P) -> Result<VioSClient> {
177 let client_socket = ScmSocket::try_from(
178 UnixSeqpacket::connect(server.as_ref())
179 .map_err(|e| Error::ServerConnectionError(e, server.as_ref().into()))?,
180 )
181 .map_err(|e| Error::ServerConnectionError(e, server.as_ref().into()))?;
182
183 let mut config: VioSConfig = Default::default();
184 const NUM_FDS: usize = 5;
185 let (recv_size, mut safe_fds) = client_socket
186 .recv_with_fds(config.as_mut_bytes(), NUM_FDS)
187 .map_err(Error::ServerError)?;
188
189 match config.version {
190 2 => {
191 if recv_size != VIOS_SIZE_V2 {
192 return Err(Error::ProtocolError(
193 ProtocolErrorKind::UnexpectedConfigSize(recv_size),
194 ));
195 }
196 }
197 3 => {
198 if recv_size != VIOS_SIZE_V3 {
199 return Err(Error::ProtocolError(
200 ProtocolErrorKind::UnexpectedConfigSize(recv_size),
201 ));
202 }
203 }
204 _ => {
205 return Err(Error::ProtocolError(ProtocolErrorKind::VersionMismatch(
206 config.version,
207 )));
208 }
209 }
210
211 fn pop<T: FromRawDescriptor>(
212 safe_fds: &mut Vec<SafeDescriptor>,
213 expected: usize,
214 received: usize,
215 ) -> Result<T> {
216 unsafe {
219 Ok(T::from_raw_descriptor(
220 safe_fds
221 .pop()
222 .ok_or(Error::ProtocolError(
223 ProtocolErrorKind::UnexpectedNumberOfFileDescriptors(
224 expected, received,
225 ),
226 ))?
227 .into_raw_descriptor(),
228 ))
229 }
230 }
231
232 let fd_count = safe_fds.len();
233 let rx_shm_file = pop::<File>(&mut safe_fds, NUM_FDS, fd_count)?;
234 let tx_shm_file = pop::<File>(&mut safe_fds, NUM_FDS, fd_count)?;
235 let rx_socket = pop::<UnixSeqpacket>(&mut safe_fds, NUM_FDS, fd_count)?;
236 let tx_socket = pop::<UnixSeqpacket>(&mut safe_fds, NUM_FDS, fd_count)?;
237 let event_socket = pop::<UnixSeqpacket>(&mut safe_fds, NUM_FDS, fd_count)?;
238
239 if !safe_fds.is_empty() {
240 return Err(Error::ProtocolError(
241 ProtocolErrorKind::UnexpectedNumberOfFileDescriptors(NUM_FDS, fd_count),
242 ));
243 }
244
245 let tx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>> =
246 Arc::new(Mutex::new(HashMap::new()));
247 let rx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>> =
248 Arc::new(Mutex::new(HashMap::new()));
249 let recv_thread_state = Arc::new(Mutex::new(ThreadFlags {
250 reporting_events: false,
251 }));
252
253 let mut client = VioSClient {
254 config,
255 jacks: Vec::new(),
256 streams: Vec::new(),
257 chmaps: Vec::new(),
258 controls: Vec::new(),
259 control_socket: Mutex::new(client_socket.into_inner()),
260 event_socket,
261 tx: IoBufferQueue::new(tx_socket, tx_shm_file)?,
262 rx: IoBufferQueue::new(rx_socket, rx_shm_file)?,
263 events: Arc::new(Mutex::new(VecDeque::new())),
264 event_notifier: Event::new().map_err(Error::EventCreateError)?,
265 tx_subscribers,
266 rx_subscribers,
267 recv_thread_state,
268 recv_thread: Mutex::new(None),
269 params: HashMap::new(),
270 };
271 client.request_and_cache_info()?;
272 Ok(client)
273 }
274
275 pub fn num_jacks(&self) -> u32 {
277 self.config.jacks
278 }
279
280 pub fn num_streams(&self) -> u32 {
282 self.config.streams
283 }
284
285 pub fn num_chmaps(&self) -> u32 {
287 self.config.chmaps
288 }
289
290 pub fn jack_info(&self, idx: u32) -> Option<virtio_snd_jack_info> {
292 self.jacks.get(idx as usize).copied()
293 }
294
295 pub fn stream_info(&self, idx: u32) -> Option<virtio_snd_pcm_info> {
297 self.streams.get(idx as usize).cloned()
298 }
299
300 pub fn chmap_info(&self, idx: u32) -> Option<virtio_snd_chmap_info> {
302 self.chmaps.get(idx as usize).copied()
303 }
304
305 pub fn num_controls(&self) -> u32 {
307 self.config.controls
308 }
309
310 pub fn control_info(&self, idx: u32) -> Option<virtio_snd_ctl_info> {
312 self.controls.get(idx as usize).copied()
313 }
314
315 pub fn set_control(&self, control_id: u32, value: virtio_snd_ctl_value) -> Result<()> {
317 #[repr(C)]
319 #[derive(Copy, Clone, Immutable, IntoBytes, KnownLayout)]
320 struct virtio_snd_ctl_set_value {
321 hdr: virtio_snd_ctl_hdr,
322 value: virtio_snd_ctl_value,
323 }
324
325 let control_socket_lock = self.control_socket.lock();
326 send_cmd(
327 &control_socket_lock,
328 virtio_snd_ctl_set_value {
329 hdr: virtio_snd_ctl_hdr {
330 hdr: virtio_snd_hdr {
331 code: VIRTIO_SND_R_CTL_WRITE.into(),
332 },
333 control_id: control_id.into(),
334 },
335 value,
336 },
337 )
338 }
339
340 pub fn get_control(&self, control_id: u32) -> Result<virtio_snd_ctl_value> {
342 let msg = virtio_snd_ctl_hdr {
343 hdr: virtio_snd_hdr {
344 code: VIRTIO_SND_R_CTL_READ.into(),
345 },
346 control_id: control_id.into(),
347 };
348 let control_socket_lock = self.control_socket.lock();
349 seq_socket_send(&control_socket_lock, msg.as_bytes())?;
350 let reply = control_socket_lock
351 .recv_as_vec()
352 .map_err(Error::ServerIOError)?;
353 let status_size = std::mem::size_of::<virtio_snd_hdr>();
354 let mut status: virtio_snd_hdr = Default::default();
355 if reply.len() < status_size {
356 return Err(Error::ProtocolError(
357 ProtocolErrorKind::UnexpectedMessageSize(status_size, reply.len()),
358 ));
359 }
360 status
361 .as_mut_bytes()
362 .copy_from_slice(&reply[0..status_size]);
363 if status.code.to_native() != VIRTIO_SND_S_OK {
364 return Err(Error::CommandFailed(status.code.to_native()));
365 }
366 if reply.len() != status_size + std::mem::size_of::<virtio_snd_ctl_value>() {
367 return Err(Error::ProtocolError(
368 ProtocolErrorKind::UnexpectedMessageSize(
369 status_size + std::mem::size_of::<virtio_snd_ctl_value>(),
370 reply.len(),
371 ),
372 ));
373 }
374 let mut value: virtio_snd_ctl_value = Default::default();
375 value.as_mut_bytes().copy_from_slice(&reply[status_size..]);
376 Ok(value)
377 }
378
379 pub fn start_bg_thread(&self) -> Result<()> {
384 if self.recv_thread.lock().is_some() {
385 return Ok(());
386 }
387 let tx_socket = self.tx.try_clone_socket()?;
388 let rx_socket = self.rx.try_clone_socket()?;
389 let event_socket = self
390 .event_socket
391 .try_clone()
392 .map_err(Error::UnixSeqpacketDupError)?;
393 let mut opt = self.recv_thread.lock();
394 if opt.is_none() {
397 let tx_subscribers = self.tx_subscribers.clone();
398 let rx_subscribers = self.rx_subscribers.clone();
399 let event_notifier = self
400 .event_notifier
401 .try_clone()
402 .map_err(Error::EventDupError)?;
403 let events = self.events.clone();
404 let recv_thread_state = self.recv_thread_state.clone();
405 *opt = Some(WorkerThread::start("shm_vios", move |kill_event| {
406 if let Err(e) = run_recv_thread(
407 kill_event,
408 tx_subscribers,
409 rx_subscribers,
410 event_notifier,
411 events,
412 recv_thread_state,
413 tx_socket,
414 rx_socket,
415 event_socket,
416 ) {
417 error!("virtio-snd shm_vios worker failed: {e:#}");
418 }
419 }));
420 }
421 Ok(())
422 }
423
424 pub fn stop_bg_thread(&self) -> Result<()> {
426 if let Some(recv_thread) = self.recv_thread.lock().take() {
427 recv_thread.stop();
428 }
429 Ok(())
430 }
431
432 pub fn get_event_notifier(&self) -> Result<Event> {
434 self.recv_thread_state.lock().reporting_events = true;
436 self.event_notifier
437 .try_clone()
438 .map_err(Error::EventDupError)
439 }
440
441 pub fn pop_event(&self) -> Option<virtio_snd_event> {
444 self.events.lock().pop_front()
445 }
446
447 pub fn remap_jack(&self, jack_id: u32, association: u32, sequence: u32) -> Result<()> {
450 if jack_id >= self.config.jacks {
451 return Err(Error::InvalidJackId(jack_id));
452 }
453 let msg = virtio_snd_jack_remap {
454 hdr: virtio_snd_jack_hdr {
455 hdr: virtio_snd_hdr {
456 code: VIRTIO_SND_R_JACK_REMAP.into(),
457 },
458 jack_id: jack_id.into(),
459 },
460 association: association.into(),
461 sequence: sequence.into(),
462 };
463 let control_socket_lock = self.control_socket.lock();
464 send_cmd(&control_socket_lock, msg)
465 }
466
467 pub fn set_stream_parameters(
469 &mut self,
470 stream_id: u32,
471 params: VioSStreamParams,
472 ) -> Result<()> {
473 self.streams
474 .get(stream_id as usize)
475 .ok_or(Error::InvalidStreamId(stream_id))?;
476 let raw_params: virtio_snd_pcm_set_params = (stream_id, params).into();
477 let _ = self.params.insert(stream_id, raw_params);
479 let control_socket_lock = self.control_socket.lock();
480 send_cmd(&control_socket_lock, raw_params)
481 }
482
483 pub fn set_stream_parameters_raw(
485 &mut self,
486 raw_params: virtio_snd_pcm_set_params,
487 ) -> Result<()> {
488 let stream_id = raw_params.hdr.stream_id.to_native();
489 let _ = self.params.insert(stream_id, raw_params);
491 self.streams
492 .get(stream_id as usize)
493 .ok_or(Error::InvalidStreamId(stream_id))?;
494 let control_socket_lock = self.control_socket.lock();
495 send_cmd(&control_socket_lock, raw_params)
496 }
497
498 pub fn prepare_stream(&self, stream_id: u32) -> Result<()> {
500 self.common_stream_op(stream_id, VIRTIO_SND_R_PCM_PREPARE)
501 }
502
503 pub fn release_stream(&self, stream_id: u32) -> Result<()> {
505 self.common_stream_op(stream_id, VIRTIO_SND_R_PCM_RELEASE)
506 }
507
508 pub fn start_stream(&self, stream_id: u32) -> Result<()> {
510 self.common_stream_op(stream_id, VIRTIO_SND_R_PCM_START)
511 }
512
513 pub fn stop_stream(&self, stream_id: u32) -> Result<()> {
515 self.common_stream_op(stream_id, VIRTIO_SND_R_PCM_STOP)
516 }
517
518 pub fn inject_audio_data<R, Cb: FnOnce(VolatileSlice) -> R>(
521 &self,
522 stream_id: u32,
523 size: usize,
524 callback: Cb,
525 ) -> Result<(u32, R)> {
526 if self
527 .streams
528 .get(stream_id as usize)
529 .ok_or(Error::InvalidStreamId(stream_id))?
530 .direction
531 != VIRTIO_SND_D_OUTPUT
532 {
533 return Err(Error::WrongDirection(VIRTIO_SND_D_OUTPUT));
534 }
535 self.streams
536 .get(stream_id as usize)
537 .ok_or(Error::InvalidStreamId(stream_id))?;
538 let dst_offset = self.tx.allocate_buffer(size)?;
539 let buffer_slice = self.tx.buffer_at(dst_offset, size)?;
540 let ret = callback(buffer_slice);
541 let (sender, receiver): (Sender<BufferReleaseMsg>, Receiver<BufferReleaseMsg>) = channel();
543 self.tx_subscribers.lock().insert(dst_offset, sender);
544 self.tx.send_buffer(stream_id, dst_offset, size)?;
545 let (_, latency) = await_status(receiver)?;
546 Ok((latency, ret))
547 }
548
549 pub fn request_audio_data<R, Cb: FnOnce(&VolatileSlice) -> R>(
551 &self,
552 stream_id: u32,
553 size: usize,
554 callback: Cb,
555 ) -> Result<(u32, R)> {
556 if self
557 .streams
558 .get(stream_id as usize)
559 .ok_or(Error::InvalidStreamId(stream_id))?
560 .direction
561 != VIRTIO_SND_D_INPUT
562 {
563 return Err(Error::WrongDirection(VIRTIO_SND_D_INPUT));
564 }
565 let src_offset = self.rx.allocate_buffer(size)?;
566 let (sender, receiver): (Sender<BufferReleaseMsg>, Receiver<BufferReleaseMsg>) = channel();
568 self.rx_subscribers.lock().insert(src_offset, sender);
569 self.rx.send_buffer(stream_id, src_offset, size)?;
570 let (recv_size, latency) = await_status(receiver)?;
572 let buffer_slice = self.rx.buffer_at(src_offset, recv_size)?;
573 Ok((latency, callback(&buffer_slice)))
574 }
575
576 pub fn keep_rds(&self) -> Vec<RawDescriptor> {
578 let control_desc = self.control_socket.lock().as_raw_descriptor();
579 let event_desc = self.event_socket.as_raw_descriptor();
580 let event_notifier = self.event_notifier.as_raw_descriptor();
581 let mut ret = vec![control_desc, event_desc, event_notifier];
582 ret.append(&mut self.tx.keep_rds());
583 ret.append(&mut self.rx.keep_rds());
584 ret
585 }
586
587 fn common_stream_op(&self, stream_id: u32, op: u32) -> Result<()> {
588 self.streams
589 .get(stream_id as usize)
590 .ok_or(Error::InvalidStreamId(stream_id))?;
591 let msg = virtio_snd_pcm_hdr {
592 hdr: virtio_snd_hdr { code: op.into() },
593 stream_id: stream_id.into(),
594 };
595 let control_socket_lock = self.control_socket.lock();
596 send_cmd(&control_socket_lock, msg)
597 }
598
599 fn request_and_cache_info(&mut self) -> Result<()> {
600 self.request_and_cache_jacks_info()?;
601 self.request_and_cache_streams_info()?;
602 self.request_and_cache_chmaps_info()?;
603 self.request_and_cache_controls_info()?;
604 Ok(())
605 }
606
607 fn request_info<T: Default + Copy + Clone>(
608 &self,
609 req_code: u32,
610 count: usize,
611 ) -> Result<Vec<T>> {
612 let info_size = std::mem::size_of::<T>();
613 let status_size = std::mem::size_of::<virtio_snd_hdr>();
614 let req = virtio_snd_query_info {
615 hdr: virtio_snd_hdr {
616 code: req_code.into(),
617 },
618 start_id: 0u32.into(),
619 count: (count as u32).into(),
620 size: (std::mem::size_of::<virtio_snd_query_info>() as u32).into(),
621 };
622 let control_socket_lock = self.control_socket.lock();
623 seq_socket_send(&control_socket_lock, req.as_bytes())?;
624 let reply = control_socket_lock
625 .recv_as_vec()
626 .map_err(Error::ServerIOError)?;
627 let mut status: virtio_snd_hdr = Default::default();
628 status
629 .as_mut_bytes()
630 .copy_from_slice(&reply[0..status_size]);
631 if status.code.to_native() != VIRTIO_SND_S_OK {
632 return Err(Error::CommandFailed(status.code.to_native()));
633 }
634 if reply.len() != status_size + count * info_size {
635 return Err(Error::ProtocolError(
636 ProtocolErrorKind::UnexpectedMessageSize(
637 status_size + count * info_size,
638 reply.len(),
639 ),
640 ));
641 }
642 Ok(reply[status_size..]
643 .chunks(info_size)
644 .map(|info_buffer| {
645 unsafe { std::ptr::read_unaligned(info_buffer.as_ptr() as *const T) }
651 })
652 .collect())
653 }
654
655 fn request_and_cache_jacks_info(&mut self) -> Result<()> {
656 let num_jacks = self.config.jacks as usize;
657 if num_jacks == 0 {
658 return Ok(());
659 }
660 self.jacks = self.request_info(VIRTIO_SND_R_JACK_INFO, num_jacks)?;
661 Ok(())
662 }
663
664 fn request_and_cache_streams_info(&mut self) -> Result<()> {
665 let num_streams = self.config.streams as usize;
666 if num_streams == 0 {
667 return Ok(());
668 }
669 self.streams = self.request_info(VIRTIO_SND_R_PCM_INFO, num_streams)?;
670 Ok(())
671 }
672
673 fn request_and_cache_chmaps_info(&mut self) -> Result<()> {
674 let num_chmaps = self.config.chmaps as usize;
675 if num_chmaps == 0 {
676 return Ok(());
677 }
678 self.chmaps = self.request_info(VIRTIO_SND_R_CHMAP_INFO, num_chmaps)?;
679 Ok(())
680 }
681
682 fn request_and_cache_controls_info(&mut self) -> Result<()> {
683 let num_controls = self.config.controls as usize;
684 if num_controls == 0 {
685 return Ok(());
686 }
687 let raw_controls: Vec<virtio_snd_ctl_info> =
688 self.request_info(VIRTIO_SND_R_CTL_INFO, num_controls)?;
689
690 let mut seen = std::collections::HashSet::new();
691
692 self.controls.clear();
693 for (i, info) in raw_controls.into_iter().enumerate() {
694 let name_bytes = info.name; let index = info.index.to_native();
696 let key = (name_bytes, index);
697
698 if seen.contains(&key) {
699 let name = std::str::from_utf8(&info.name)
700 .unwrap_or("invalid utf8")
701 .trim_matches(char::from(0));
702 base::error!(
703 "CTLS: encountered a duplicate vec_idx: {}, index: {}, name: {}",
704 i,
705 index,
706 name
707 );
708 return Err(Error::DuplicatedControlId(index, name.to_owned()));
709 }
710 seen.insert(key);
711 self.controls.push(info);
712 let name = std::str::from_utf8(&info.name)
713 .unwrap_or("invalid utf8")
714 .trim_matches(char::from(0));
715 base::info!(
716 "CTLS: kept vec_idx: {}, index: {}, name: {}",
717 i,
718 index,
719 name
720 );
721 }
722
723 if self.controls.len() != self.config.controls as usize {
724 return Err(Error::ControlsMismatch(
725 self.config.controls,
726 self.controls.len() as u32,
727 ));
728 }
729 Ok(())
730 }
731
732 pub fn snapshot(&self) -> VioSClientSnapshot {
733 VioSClientSnapshot {
734 config: self.config,
735 jacks: self.jacks.clone(),
736 streams: self.streams.clone(),
737 chmaps: self.chmaps.clone(),
738 controls: self.controls.clone(),
739 params: self.params.clone(),
740 }
741 }
742
743 pub fn restore(&mut self, data: VioSClientSnapshot) -> anyhow::Result<()> {
746 anyhow::ensure!(
747 data.config == self.config,
748 "config doesn't match on restore: expected: {:?}, got: {:?}",
749 data.config,
750 self.config
751 );
752 self.jacks = data.jacks;
753 self.streams = data.streams;
754 self.chmaps = data.chmaps;
755 self.controls = data.controls;
756 self.params = data.params;
757 Ok(())
758 }
759
760 pub fn restore_stream(&mut self, stream_id: u32, state: StreamState) -> Result<()> {
761 if let Some(params) = self.params.get(&stream_id).cloned() {
762 self.set_stream_parameters_raw(params)?;
763 }
764 match state {
765 StreamState::Started => {
766 if let Err(e) = self.prepare_stream(stream_id) {
770 error!("failed to prepare stream: {}", e);
771 };
772 self.start_stream(stream_id)
773 }
774 StreamState::Prepared => self.prepare_stream(stream_id),
775 _ => Ok(()),
777 }
778 }
779}
780
/// Flags shared between a `VioSClient` and its receive thread.
#[derive(Copy, Clone)]
struct ThreadFlags {
    /// While false, the receive thread reads incoming events and discards them
    /// instead of queuing them. `get_event_notifier` sets it to true.
    reporting_events: bool,
}
785
/// Wait-context tokens telling the receive thread which descriptor is ready.
// NOTE(review): the EventToken derive may encode variants by position; keep the
// declaration order stable.
#[derive(EventToken)]
enum Token {
    /// The kill event was signaled; the receive thread should exit.
    Notification,
    /// A buffer status message is readable on the tx socket.
    TxBufferMsg,
    /// A buffer status message is readable on the rx socket.
    RxBufferMsg,
    /// An event message is readable on the event socket.
    EventMsg,
}
793
794fn recv_buffer_status_msg(
795 socket: &UnixSeqpacket,
796 subscribers: &Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>,
797) -> Result<()> {
798 let mut msg: IoStatusMsg = Default::default();
799 let size = socket
800 .recv(msg.as_mut_bytes())
801 .map_err(Error::ServerIOError)?;
802 if size != std::mem::size_of::<IoStatusMsg>() {
803 return Err(Error::ProtocolError(
804 ProtocolErrorKind::UnexpectedMessageSize(std::mem::size_of::<IoStatusMsg>(), size),
805 ));
806 }
807 let mut status = msg.status.status.into();
808 if status == u32::MAX {
809 status -= 1;
812 }
813 let latency = msg.status.latency_bytes.into();
814 let offset = msg.buffer_offset as usize;
815 let consumed_len = msg.consumed_len as usize;
816 let promise_opt = subscribers.lock().remove(&offset);
817 match promise_opt {
818 None => error!(
819 "Received an unexpected buffer status message: {}. This is a BUG!!",
820 offset
821 ),
822 Some(sender) => {
823 if let Err(e) = sender.send(BufferReleaseMsg {
824 status,
825 latency,
826 consumed_len,
827 }) {
828 error!("Failed to notify waiting thread: {:?}", e);
829 }
830 }
831 }
832 Ok(())
833}
834
835fn recv_event(socket: &UnixSeqpacket) -> Result<virtio_snd_event> {
836 let mut msg: virtio_snd_event = Default::default();
837 let size = socket
838 .recv(msg.as_mut_bytes())
839 .map_err(Error::ServerIOError)?;
840 if size != std::mem::size_of::<virtio_snd_event>() {
841 return Err(Error::ProtocolError(
842 ProtocolErrorKind::UnexpectedMessageSize(std::mem::size_of::<virtio_snd_event>(), size),
843 ));
844 }
845 Ok(msg)
846}
847
848fn run_recv_thread(
849 kill_event: Event,
850 tx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>,
851 rx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>,
852 event_notifier: Event,
853 event_queue: Arc<Mutex<VecDeque<virtio_snd_event>>>,
854 state: Arc<Mutex<ThreadFlags>>,
855 tx_socket: UnixSeqpacket,
856 rx_socket: UnixSeqpacket,
857 event_socket: UnixSeqpacket,
858) -> Result<()> {
859 let wait_ctx: WaitContext<Token> = WaitContext::build_with(&[
860 (&tx_socket, Token::TxBufferMsg),
861 (&rx_socket, Token::RxBufferMsg),
862 (&event_socket, Token::EventMsg),
863 (&kill_event, Token::Notification),
864 ])
865 .map_err(Error::WaitContextCreateError)?;
866 let mut running = true;
867 while running {
868 let events = wait_ctx.wait().map_err(Error::WaitError)?;
869 for evt in events {
870 match evt.token {
871 Token::TxBufferMsg => recv_buffer_status_msg(&tx_socket, &tx_subscribers)?,
872 Token::RxBufferMsg => recv_buffer_status_msg(&rx_socket, &rx_subscribers)?,
873 Token::EventMsg => {
874 let evt = recv_event(&event_socket)?;
875 let state_cpy = *state.lock();
876 if state_cpy.reporting_events {
877 event_queue.lock().push_back(evt);
878 event_notifier.signal().map_err(Error::EventWriteError)?;
879 } }
881 Token::Notification => {
882 if let Err(e) = kill_event.wait() {
885 error!("Failed to consume notification from recv thread: {:?}", e);
886 }
887 running = false;
888 }
889 }
890 }
891 }
892 Ok(())
893}
894
895fn await_status(promise: Receiver<BufferReleaseMsg>) -> Result<(usize, u32)> {
896 let BufferReleaseMsg {
897 status,
898 latency,
899 consumed_len,
900 } = promise.recv().map_err(Error::BufferStatusSenderLost)?;
901 if status == VIRTIO_SND_S_OK {
902 Ok((consumed_len, latency))
903 } else {
904 Err(Error::IOBufferError(status))
905 }
906}
907
908struct IoBufferQueue {
909 socket: UnixSeqpacket,
910 file: File,
911 mmap: MemoryMapping,
912 size: usize,
913 next: Mutex<usize>,
914}
915
916impl IoBufferQueue {
917 fn new(socket: UnixSeqpacket, mut file: File) -> Result<IoBufferQueue> {
918 let size = file.seek(SeekFrom::End(0)).map_err(Error::FileSizeError)? as usize;
919
920 let mmap = MemoryMappingBuilder::new(size)
921 .from_file(&file)
922 .build()
923 .map_err(Error::ServerMmapError)?;
924
925 Ok(IoBufferQueue {
926 socket,
927 file,
928 mmap,
929 size,
930 next: Mutex::new(0),
931 })
932 }
933
934 fn allocate_buffer(&self, size: usize) -> Result<usize> {
935 if size > self.size {
936 return Err(Error::OutOfSpace);
937 }
938 let mut next_lock = self.next.lock();
939 let offset = if size > self.size - *next_lock {
940 0
942 } else {
943 *next_lock
944 };
945 *next_lock = offset + size;
946 Ok(offset)
947 }
948
949 fn buffer_at(&self, offset: usize, len: usize) -> Result<VolatileSlice> {
950 self.mmap
951 .get_slice(offset, len)
952 .map_err(Error::VolatileMemoryError)
953 }
954
955 fn try_clone_socket(&self) -> Result<UnixSeqpacket> {
956 self.socket
957 .try_clone()
958 .map_err(Error::UnixSeqpacketDupError)
959 }
960
961 fn send_buffer(&self, stream_id: u32, offset: usize, size: usize) -> Result<()> {
962 let msg = IoTransferMsg::new(stream_id, offset, size);
963 seq_socket_send(&self.socket, msg.as_bytes())
964 }
965
966 fn keep_rds(&self) -> Vec<RawDescriptor> {
967 vec![
968 self.file.as_raw_descriptor(),
969 self.socket.as_raw_descriptor(),
970 ]
971 }
972}
973
974pub struct VioSStreamParams {
976 pub buffer_bytes: u32,
977 pub period_bytes: u32,
978 pub features: u32,
979 pub channels: u8,
980 pub format: u8,
981 pub rate: u8,
982}
983
984impl From<(u32, VioSStreamParams)> for virtio_snd_pcm_set_params {
985 fn from(val: (u32, VioSStreamParams)) -> Self {
986 virtio_snd_pcm_set_params {
987 hdr: virtio_snd_pcm_hdr {
988 hdr: virtio_snd_hdr {
989 code: VIRTIO_SND_R_PCM_SET_PARAMS.into(),
990 },
991 stream_id: val.0.into(),
992 },
993 buffer_bytes: val.1.buffer_bytes.into(),
994 period_bytes: val.1.period_bytes.into(),
995 features: val.1.features.into(),
996 channels: val.1.channels,
997 format: val.1.format,
998 rate: val.1.rate,
999 padding: 0u8,
1000 }
1001 }
1002}
1003
1004fn send_cmd<T: Immutable + IntoBytes>(control_socket: &UnixSeqpacket, data: T) -> Result<()> {
1005 seq_socket_send(control_socket, data.as_bytes())?;
1006 recv_cmd_status(control_socket)
1007}
1008
1009fn recv_cmd_status(control_socket: &UnixSeqpacket) -> Result<()> {
1010 let mut status: virtio_snd_hdr = Default::default();
1011 control_socket
1012 .recv(status.as_mut_bytes())
1013 .map_err(Error::ServerIOError)?;
1014 if status.code.to_native() == VIRTIO_SND_S_OK {
1015 Ok(())
1016 } else {
1017 Err(Error::CommandFailed(status.code.to_native()))
1018 }
1019}
1020
1021fn seq_socket_send(socket: &UnixSeqpacket, data: &[u8]) -> Result<()> {
1022 loop {
1023 let send_res = socket.send(data);
1024 if let Err(e) = send_res {
1025 match e.kind() {
1026 IOErrorKind::Interrupted => continue,
1028 _ => return Err(Error::ServerIOError(e)),
1029 }
1030 }
1031 break;
1033 }
1034 Ok(())
1035}
1036
/// Handshake configuration sent by the VioS server when a client connects.
///
/// The struct is received directly from the socket as raw bytes
/// (`as_mut_bytes`) and is also serialized into snapshots. Do not reorder the
/// fields.
#[repr(C)]
#[derive(
    Copy,
    Clone,
    Default,
    FromBytes,
    Immutable,
    IntoBytes,
    KnownLayout,
    Serialize,
    Deserialize,
    PartialEq,
    Eq,
    Debug,
)]
struct VioSConfig {
    // Protocol version; the client accepts 2 and 3.
    version: u32,
    // Number of jacks, PCM streams and channel maps the server exposes.
    jacks: u32,
    streams: u32,
    chmaps: u32,

    // Number of controls. A version 2 handshake is only VIOS_SIZE_V2 bytes
    // long, so for version 2 servers this keeps its default value.
    controls: u32,
}
1062
1063const VIOS_SIZE_V2: usize = 4 * std::mem::size_of::<u32>();
1064const VIOS_SIZE_V3: usize = std::mem::size_of::<VioSConfig>();
1065
/// Release notification sent by the receive thread to the caller waiting on a
/// buffer.
struct BufferReleaseMsg {
    /// Length reported by the server for the buffer; for capture, the number of
    /// bytes handed to the caller.
    consumed_len: usize,
    /// Latency reported by the server.
    latency: u32,
    /// Buffer status code; `VIRTIO_SND_S_OK` on success.
    status: u32,
}
1071
1072#[repr(C)]
1073#[derive(Copy, Clone, FromBytes, Immutable, IntoBytes, KnownLayout)]
1074struct IoTransferMsg {
1075 io_xfer: virtio_snd_pcm_xfer,
1076 buffer_offset: u32,
1077 buffer_len: u32,
1078}
1079
1080impl IoTransferMsg {
1081 fn new(stream_id: u32, buffer_offset: usize, buffer_len: usize) -> IoTransferMsg {
1082 IoTransferMsg {
1083 io_xfer: virtio_snd_pcm_xfer {
1084 stream_id: stream_id.into(),
1085 },
1086 buffer_offset: buffer_offset as u32,
1087 buffer_len: buffer_len as u32,
1088 }
1089 }
1090}
1091
/// Buffer status message read from a tx/rx socket by the receive thread.
/// It is received as raw bytes, so the field order is the wire layout.
#[repr(C)]
#[derive(Copy, Clone, Default, FromBytes, Immutable, IntoBytes, KnownLayout)]
struct IoStatusMsg {
    // PCM status; its `status` and `latency_bytes` are forwarded to the waiter.
    status: virtio_snd_pcm_status,
    // Offset of the released buffer in the shared region; used as the subscriber key.
    buffer_offset: u32,
    // Length reported for the buffer; for capture, the number of bytes handed to the caller.
    consumed_len: u32,
}