devices/virtio/vhost_user_backend/
snd.rs1pub mod sys;
6
7use std::borrow::Borrow;
8use std::rc::Rc;
9
10use anyhow::anyhow;
11use anyhow::bail;
12use anyhow::Context;
13use base::error;
14use base::warn;
15use cros_async::sync::RwLock as AsyncRwLock;
16use cros_async::EventAsync;
17use cros_async::Executor;
18use futures::channel::mpsc;
19use futures::FutureExt;
20use hypervisor::ProtectionType;
21use serde::Deserialize;
22use serde::Serialize;
23use snapshot::AnySnapshot;
24pub use sys::run_snd_device;
25pub use sys::Options;
26use vm_memory::GuestMemory;
27use vmm_vhost::message::VhostUserProtocolFeatures;
28use vmm_vhost::VHOST_USER_F_PROTOCOL_FEATURES;
29use zerocopy::IntoBytes;
30
31use crate::virtio;
32use crate::virtio::copy_config;
33use crate::virtio::device_constants::snd::virtio_snd_config;
34use crate::virtio::snd::common_backend::async_funcs::handle_ctrl_queue;
35use crate::virtio::snd::common_backend::async_funcs::handle_pcm_queue;
36use crate::virtio::snd::common_backend::async_funcs::send_pcm_response_worker;
37use crate::virtio::snd::common_backend::create_stream_info_builders;
38use crate::virtio::snd::common_backend::hardcoded_snd_data;
39use crate::virtio::snd::common_backend::hardcoded_virtio_snd_config;
40use crate::virtio::snd::common_backend::stream_info::StreamInfo;
41use crate::virtio::snd::common_backend::stream_info::StreamInfoBuilder;
42use crate::virtio::snd::common_backend::stream_info::StreamInfoSnapshot;
43use crate::virtio::snd::common_backend::Error;
44use crate::virtio::snd::common_backend::PcmResponse;
45use crate::virtio::snd::common_backend::SndData;
46use crate::virtio::snd::common_backend::MAX_QUEUE_NUM;
47use crate::virtio::snd::constants::VIRTIO_SND_R_PCM_PREPARE;
48use crate::virtio::snd::constants::VIRTIO_SND_R_PCM_START;
49use crate::virtio::snd::parameters::Parameters;
50use crate::virtio::vhost_user_backend::handler::DeviceRequestHandler;
51use crate::virtio::vhost_user_backend::handler::Error as DeviceError;
52use crate::virtio::vhost_user_backend::handler::VhostUserDevice;
53use crate::virtio::vhost_user_backend::handler::WorkerState;
54use crate::virtio::vhost_user_backend::VhostUserDeviceBuilder;
55use crate::virtio::Queue;
56
/// Subtracted from a PCM queue index (2 or 3) to get the slot in `SndBackend::response_workers`.
const PCM_RESPONSE_WORKER_IDX_OFFSET: usize = 2;

/// vhost-user backend state for one virtio-snd card.
struct SndBackend {
    /// Executor on which all queue and response tasks are spawned.
    ex: Executor,
    /// Config space returned to the frontend by `read_config`.
    cfg: virtio_snd_config,
    /// Feature bits reported by `features`; includes `VHOST_USER_F_PROTOCOL_FEATURES`.
    avail_features: u64,
    /// Queue handler task per queue, indexed by queue index; `None` when not started.
    workers: [Option<WorkerState<Rc<AsyncRwLock<Queue>>, Result<(), Error>>>; MAX_QUEUE_NUM],
    /// Response-writer tasks for queues 2 and 3, indexed by
    /// `idx - PCM_RESPONSE_WORKER_IDX_OFFSET`.
    response_workers: [Option<WorkerState<Rc<AsyncRwLock<Queue>>, Result<(), Error>>>; 2],
    /// Card/stream description built from the parameters; also compared on restore.
    snd_data: Rc<SndData>,
    /// Per-stream state shared with the control and PCM queue handlers.
    streams: Rc<AsyncRwLock<Vec<AsyncRwLock<StreamInfo>>>>,
    /// Sender paired with `tx_recv` (handed to the queue 2 handler).
    tx_send: mpsc::UnboundedSender<PcmResponse>,
    /// Sender paired with `rx_recv` (handed to the queue 3 handler).
    rx_send: mpsc::UnboundedSender<PcmResponse>,
    /// Receiver for queue 2 responses; taken by `start_queue`, so a second start of queue 2
    /// fails with "queue restart is not supported".
    tx_recv: Option<mpsc::UnboundedReceiver<PcmResponse>>,
    /// Receiver for queue 3 responses; taken by `start_queue` like `tx_recv`.
    rx_recv: Option<mpsc::UnboundedReceiver<PcmResponse>>,
    /// Card number used to prefix log and error messages.
    card_index: usize,
    /// Value reported by `unmap_guest_memory_on_fork`; always false off Linux/Android.
    unmap_guest_memory_on_fork: bool,
}
80
/// Serialized backend state written by `snapshot` and consumed by `restore`.
#[derive(Serialize, Deserialize)]
struct SndBackendSnapshot {
    /// Must equal the restoring backend's `avail_features`, otherwise restore fails.
    avail_features: u64,
    /// Per-stream state; `None` when the stream table lock could not be acquired without
    /// waiting at snapshot time.
    stream_infos: Option<Vec<StreamInfoSnapshot>>,
    /// Must equal the restoring backend's `SndData`, otherwise restore fails.
    snd_data: SndData,
}
87
88impl SndBackend {
89 pub fn new(
90 ex: &Executor,
91 params: Parameters,
92 #[cfg(windows)] audio_client_guid: Option<String>,
93 card_index: usize,
94 ) -> anyhow::Result<Self> {
95 let cfg = hardcoded_virtio_snd_config(¶ms);
96 let avail_features = virtio::base_features(ProtectionType::Unprotected)
97 | 1 << VHOST_USER_F_PROTOCOL_FEATURES;
98
99 let snd_data = hardcoded_snd_data(¶ms);
100 let mut keep_rds = Vec::new();
101 let builders = create_stream_info_builders(¶ms, &snd_data, &mut keep_rds, card_index)?;
102
103 if snd_data.pcm_info_len() != builders.len() {
104 error!(
105 "[Card {}] snd: expected {} stream info builders, got {}",
106 card_index,
107 snd_data.pcm_info_len(),
108 builders.len(),
109 )
110 }
111
112 let streams = builders.into_iter();
113
114 #[cfg(windows)]
115 let streams = streams
116 .map(|stream_builder| stream_builder.audio_client_guid(audio_client_guid.clone()));
117
118 let streams = streams
119 .map(StreamInfoBuilder::build)
120 .map(AsyncRwLock::new)
121 .collect();
122 let streams = Rc::new(AsyncRwLock::new(streams));
123
124 let (tx_send, tx_recv) = mpsc::unbounded();
125 let (rx_send, rx_recv) = mpsc::unbounded();
126
127 #[cfg(any(target_os = "android", target_os = "linux"))]
128 let unmap_guest_memory_on_fork = params.unmap_guest_memory_on_fork;
129 #[cfg(not(any(target_os = "android", target_os = "linux")))]
130 let unmap_guest_memory_on_fork = false;
131
132 Ok(SndBackend {
133 ex: ex.clone(),
134 cfg,
135 avail_features,
136 workers: Default::default(),
137 response_workers: Default::default(),
138 snd_data: Rc::new(snd_data),
139 streams,
140 tx_send,
141 rx_send,
142 tx_recv: Some(tx_recv),
143 rx_recv: Some(rx_recv),
144 card_index,
145 unmap_guest_memory_on_fork,
146 })
147 }
148}
149
150impl VhostUserDeviceBuilder for SndBackend {
151 fn build(self: Box<Self>, _ex: &Executor) -> anyhow::Result<Box<dyn vmm_vhost::Backend>> {
152 let handler = DeviceRequestHandler::new(*self);
153 Ok(Box::new(handler))
154 }
155}
156
157impl VhostUserDevice for SndBackend {
158 fn max_queue_num(&self) -> usize {
159 MAX_QUEUE_NUM
160 }
161
162 fn features(&self) -> u64 {
163 self.avail_features
164 }
165
166 fn protocol_features(&self) -> VhostUserProtocolFeatures {
167 VhostUserProtocolFeatures::CONFIG
168 | VhostUserProtocolFeatures::MQ
169 | VhostUserProtocolFeatures::DEVICE_STATE
170 }
171
172 fn unmap_guest_memory_on_fork(&self) -> bool {
173 self.unmap_guest_memory_on_fork
174 }
175
176 fn read_config(&self, offset: u64, data: &mut [u8]) {
177 copy_config(data, 0, self.cfg.as_bytes(), offset)
178 }
179
180 fn reset(&mut self) {
181 for worker in self.workers.iter_mut().filter_map(Option::take) {
182 let _ = self.ex.run_until(worker.queue_task.cancel());
183 }
184 }
185
186 fn start_queue(
187 &mut self,
188 idx: usize,
189 queue: virtio::Queue,
190 _mem: GuestMemory,
191 ) -> anyhow::Result<()> {
192 if self.workers[idx].is_some() {
193 warn!(
194 "[Card {}] Starting new queue handler without stopping old handler",
195 self.card_index
196 );
197 self.stop_queue(idx)?;
198 }
199
200 let kick_evt = queue
201 .event()
202 .try_clone()
203 .with_context(|| format!("[Card {}] failed to clone queue event", self.card_index))?;
204 let mut kick_evt = EventAsync::new(kick_evt, &self.ex).with_context(|| {
205 format!(
206 "[Card {}] failed to create EventAsync for kick_evt",
207 self.card_index
208 )
209 })?;
210 let queue = Rc::new(AsyncRwLock::new(queue));
211 let card_index = self.card_index;
212 let queue_task = match idx {
213 0 => {
214 let streams = self.streams.clone();
216 let snd_data = self.snd_data.clone();
217 let tx_send = self.tx_send.clone();
218 let rx_send = self.rx_send.clone();
219 let ctrl_queue = queue.clone();
220
221 let ex_clone = self.ex.clone();
222 Some(self.ex.spawn_local(async move {
223 handle_ctrl_queue(
224 &ex_clone,
225 &streams,
226 &snd_data,
227 ctrl_queue,
228 &mut kick_evt,
229 tx_send,
230 rx_send,
231 card_index,
232 None,
233 )
234 .await
235 }))
236 }
237 1 => Some(self.ex.spawn_local(async move { Ok(()) })),
244 2 | 3 => {
245 let (send, recv) = if idx == 2 {
246 (self.tx_send.clone(), self.tx_recv.take())
247 } else {
248 (self.rx_send.clone(), self.rx_recv.take())
249 };
250 let mut recv = recv.ok_or_else(|| {
251 anyhow!("[Card {}] queue restart is not supported", self.card_index)
252 })?;
253 let streams = Rc::clone(&self.streams);
254 let queue_pcm_queue = queue.clone();
255 let queue_task = self.ex.spawn_local(async move {
256 handle_pcm_queue(&streams, send, queue_pcm_queue, &kick_evt, card_index, None)
257 .await
258 });
259
260 let queue_response_queue = queue.clone();
261 let response_queue_task = self.ex.spawn_local(async move {
262 send_pcm_response_worker(queue_response_queue, &mut recv, None).await
263 });
264
265 self.response_workers[idx - PCM_RESPONSE_WORKER_IDX_OFFSET] = Some(WorkerState {
266 queue_task: response_queue_task,
267 queue: queue.clone(),
268 });
269
270 Some(queue_task)
271 }
272 _ => bail!(
273 "[Card {}] attempted to start unknown queue: {}",
274 self.card_index,
275 idx
276 ),
277 };
278
279 if let Some(queue_task) = queue_task {
280 self.workers[idx] = Some(WorkerState { queue_task, queue });
281 }
282 Ok(())
283 }
284
285 fn stop_queue(&mut self, idx: usize) -> anyhow::Result<virtio::Queue> {
286 let worker_queue_rc = self
287 .workers
288 .get_mut(idx)
289 .and_then(Option::take)
290 .map(|worker| {
291 let _ = self.ex.run_until(worker.queue_task.cancel());
293 worker.queue
294 });
295
296 if idx == 2 || idx == 3 {
297 if let Some(worker) = self
298 .response_workers
299 .get_mut(idx - PCM_RESPONSE_WORKER_IDX_OFFSET)
300 .and_then(Option::take)
301 {
302 let _ = self.ex.run_until(worker.queue_task.cancel());
304 }
305 }
306
307 if let Some(queue_rc) = worker_queue_rc {
308 match Rc::try_unwrap(queue_rc) {
309 Ok(queue_mutex) => Ok(queue_mutex.into_inner()),
310 Err(_) => panic!(
311 "[Card {}] failed to recover queue from worker",
312 self.card_index
313 ),
314 }
315 } else {
316 Err(anyhow::Error::new(DeviceError::WorkerNotFound))
317 }
318 }
319
320 fn snapshot(&mut self) -> anyhow::Result<AnySnapshot> {
321 let stream_info_snaps = if let Some(stream_infos) = &self.streams.lock().now_or_never() {
323 let mut snaps = Vec::new();
324 for stream_info in stream_infos.iter() {
325 snaps.push(
326 stream_info
327 .lock()
328 .now_or_never()
329 .unwrap_or_else(|| {
330 panic!(
331 "[Card {}] failed to lock audio state during snapshot",
332 self.card_index
333 )
334 })
335 .snapshot(),
336 );
337 }
338 Some(snaps)
339 } else {
340 None
341 };
342 let snd_data_ref: &SndData = self.snd_data.borrow();
343 AnySnapshot::to_any(SndBackendSnapshot {
344 avail_features: self.avail_features,
345 stream_infos: stream_info_snaps,
346 snd_data: snd_data_ref.clone(),
347 })
348 .with_context(|| {
349 format!(
350 "[Card {}] Failed to serialize SndBackendSnapshot",
351 self.card_index
352 )
353 })
354 }
355
356 fn restore(&mut self, data: AnySnapshot) -> anyhow::Result<()> {
357 let deser: SndBackendSnapshot = AnySnapshot::from_any(data).with_context(|| {
358 format!(
359 "[Card {}] Failed to deserialize SndBackendSnapshot",
360 self.card_index
361 )
362 })?;
363 anyhow::ensure!(
364 deser.avail_features == self.avail_features,
365 "[Card {}] avail features doesn't match on restore: expected: {}, got: {}",
366 self.card_index,
367 deser.avail_features,
368 self.avail_features
369 );
370 let snd_data = self.snd_data.borrow();
371 anyhow::ensure!(
372 &deser.snd_data == snd_data,
373 "[Card {}] snd data doesn't match on restore: expected: {:?}, got: {:?}",
374 self.card_index,
375 deser.snd_data,
376 snd_data,
377 );
378
379 let ex_clone = self.ex.clone();
380 let streams_rc = self.streams.clone();
381 let tx_send_clone = self.tx_send.clone();
382 let rx_send_clone = self.rx_send.clone();
383
384 let card_index = self.card_index;
385 let restore_task =
386 self.ex.spawn_local(async move {
387 if let Some(stream_infos) = &deser.stream_infos {
388 for (stream, stream_info) in
389 streams_rc.lock().await.iter().zip(stream_infos.iter())
390 {
391 stream.lock().await.restore(stream_info);
392 if stream_info.state == VIRTIO_SND_R_PCM_START
393 || stream_info.state == VIRTIO_SND_R_PCM_PREPARE
394 {
395 stream
396 .lock()
397 .await
398 .prepare(&ex_clone, &tx_send_clone, &rx_send_clone)
399 .await
400 .unwrap_or_else(|_| {
401 panic!("[Card {card_index}] failed to prepare PCM")
402 });
403 }
404 if stream_info.state == VIRTIO_SND_R_PCM_START {
405 stream.lock().await.start().await.unwrap_or_else(|_| {
406 panic!("[Card {card_index}] failed to start PCM")
407 });
408 }
409 }
410 }
411 });
412 self.ex
413 .run_until(restore_task)
414 .unwrap_or_else(|_| panic!("[Card {}] failed to restore streams", self.card_index));
415 Ok(())
416 }
417
418 fn enter_suspended_state(&mut self) -> anyhow::Result<()> {
419 Ok(())
421 }
422}