devices/virtio/vhost_user_backend/
snd.rs1pub mod sys;
6
7use std::borrow::Borrow;
8use std::rc::Rc;
9
10use anyhow::anyhow;
11use anyhow::bail;
12use anyhow::Context;
13use base::error;
14use base::warn;
15use cros_async::sync::RwLock as AsyncRwLock;
16use cros_async::EventAsync;
17use cros_async::Executor;
18use futures::channel::mpsc;
19use futures::FutureExt;
20use hypervisor::ProtectionType;
21use serde::Deserialize;
22use serde::Serialize;
23use snapshot::AnySnapshot;
24pub use sys::run_snd_device;
25pub use sys::Options;
26use vm_memory::GuestMemory;
27use vmm_vhost::message::VhostUserProtocolFeatures;
28use vmm_vhost::VHOST_USER_F_PROTOCOL_FEATURES;
29use zerocopy::IntoBytes;
30
31use crate::virtio;
32use crate::virtio::copy_config;
33use crate::virtio::device_constants::snd::virtio_snd_config;
34use crate::virtio::snd::common_backend::async_funcs::handle_ctrl_queue;
35use crate::virtio::snd::common_backend::async_funcs::handle_pcm_queue;
36use crate::virtio::snd::common_backend::async_funcs::send_pcm_response_worker;
37use crate::virtio::snd::common_backend::create_stream_info_builders;
38use crate::virtio::snd::common_backend::hardcoded_snd_data;
39use crate::virtio::snd::common_backend::hardcoded_virtio_snd_config;
40use crate::virtio::snd::common_backend::stream_info::StreamInfo;
41use crate::virtio::snd::common_backend::stream_info::StreamInfoBuilder;
42use crate::virtio::snd::common_backend::stream_info::StreamInfoSnapshot;
43use crate::virtio::snd::common_backend::Error;
44use crate::virtio::snd::common_backend::PcmResponse;
45use crate::virtio::snd::common_backend::SndData;
46use crate::virtio::snd::common_backend::MAX_QUEUE_NUM;
47use crate::virtio::snd::constants::VIRTIO_SND_R_PCM_PREPARE;
48use crate::virtio::snd::constants::VIRTIO_SND_R_PCM_START;
49use crate::virtio::snd::parameters::Parameters;
50use crate::virtio::vhost_user_backend::handler::DeviceRequestHandler;
51use crate::virtio::vhost_user_backend::handler::Error as DeviceError;
52use crate::virtio::vhost_user_backend::handler::VhostUserDevice;
53use crate::virtio::vhost_user_backend::handler::WorkerState;
54use crate::virtio::vhost_user_backend::VhostUserDeviceBuilder;
55use crate::virtio::Queue;
56
/// Subtracted from a PCM queue index (2 or 3) to obtain that queue's slot in
/// `SndBackend::response_workers`.
const PCM_RESPONSE_WORKER_IDX_OFFSET: usize = 2;
63struct SndBackend {
64 ex: Executor,
65 cfg: virtio_snd_config,
66 avail_features: u64,
67 workers: [Option<WorkerState<Rc<AsyncRwLock<Queue>>, Result<(), Error>>>; MAX_QUEUE_NUM],
68 response_workers: [Option<WorkerState<Rc<AsyncRwLock<Queue>>, Result<(), Error>>>; 2],
70 snd_data: Rc<SndData>,
71 streams: Rc<AsyncRwLock<Vec<AsyncRwLock<StreamInfo>>>>,
72 tx_send: mpsc::UnboundedSender<PcmResponse>,
73 rx_send: mpsc::UnboundedSender<PcmResponse>,
74 tx_recv: Option<mpsc::UnboundedReceiver<PcmResponse>>,
75 rx_recv: Option<mpsc::UnboundedReceiver<PcmResponse>>,
76 card_index: usize,
78}
79
80#[derive(Serialize, Deserialize)]
81struct SndBackendSnapshot {
82 avail_features: u64,
83 stream_infos: Option<Vec<StreamInfoSnapshot>>,
84 snd_data: SndData,
85}
86
87impl SndBackend {
88 pub fn new(
89 ex: &Executor,
90 params: Parameters,
91 #[cfg(windows)] audio_client_guid: Option<String>,
92 card_index: usize,
93 ) -> anyhow::Result<Self> {
94 let cfg = hardcoded_virtio_snd_config(¶ms);
95 let avail_features = virtio::base_features(ProtectionType::Unprotected)
96 | 1 << VHOST_USER_F_PROTOCOL_FEATURES;
97
98 let snd_data = hardcoded_snd_data(¶ms);
99 let mut keep_rds = Vec::new();
100 let builders = create_stream_info_builders(¶ms, &snd_data, &mut keep_rds, card_index)?;
101
102 if snd_data.pcm_info_len() != builders.len() {
103 error!(
104 "[Card {}] snd: expected {} stream info builders, got {}",
105 card_index,
106 snd_data.pcm_info_len(),
107 builders.len(),
108 )
109 }
110
111 let streams = builders.into_iter();
112
113 #[cfg(windows)]
114 let streams = streams
115 .map(|stream_builder| stream_builder.audio_client_guid(audio_client_guid.clone()));
116
117 let streams = streams
118 .map(StreamInfoBuilder::build)
119 .map(AsyncRwLock::new)
120 .collect();
121 let streams = Rc::new(AsyncRwLock::new(streams));
122
123 let (tx_send, tx_recv) = mpsc::unbounded();
124 let (rx_send, rx_recv) = mpsc::unbounded();
125
126 Ok(SndBackend {
127 ex: ex.clone(),
128 cfg,
129 avail_features,
130 workers: Default::default(),
131 response_workers: Default::default(),
132 snd_data: Rc::new(snd_data),
133 streams,
134 tx_send,
135 rx_send,
136 tx_recv: Some(tx_recv),
137 rx_recv: Some(rx_recv),
138 card_index,
139 })
140 }
141}
142
143impl VhostUserDeviceBuilder for SndBackend {
144 fn build(self: Box<Self>, _ex: &Executor) -> anyhow::Result<Box<dyn vmm_vhost::Backend>> {
145 let handler = DeviceRequestHandler::new(*self);
146 Ok(Box::new(handler))
147 }
148}
149
150impl VhostUserDevice for SndBackend {
151 fn max_queue_num(&self) -> usize {
152 MAX_QUEUE_NUM
153 }
154
155 fn features(&self) -> u64 {
156 self.avail_features
157 }
158
159 fn protocol_features(&self) -> VhostUserProtocolFeatures {
160 VhostUserProtocolFeatures::CONFIG
161 | VhostUserProtocolFeatures::MQ
162 | VhostUserProtocolFeatures::DEVICE_STATE
163 }
164
165 fn read_config(&self, offset: u64, data: &mut [u8]) {
166 copy_config(data, 0, self.cfg.as_bytes(), offset)
167 }
168
169 fn reset(&mut self) {
170 for worker in self.workers.iter_mut().filter_map(Option::take) {
171 let _ = self.ex.run_until(worker.queue_task.cancel());
172 }
173 }
174
175 fn start_queue(
176 &mut self,
177 idx: usize,
178 queue: virtio::Queue,
179 _mem: GuestMemory,
180 ) -> anyhow::Result<()> {
181 if self.workers[idx].is_some() {
182 warn!(
183 "[Card {}] Starting new queue handler without stopping old handler",
184 self.card_index
185 );
186 self.stop_queue(idx)?;
187 }
188
189 let kick_evt = queue
190 .event()
191 .try_clone()
192 .with_context(|| format!("[Card {}] failed to clone queue event", self.card_index))?;
193 let mut kick_evt = EventAsync::new(kick_evt, &self.ex).with_context(|| {
194 format!(
195 "[Card {}] failed to create EventAsync for kick_evt",
196 self.card_index
197 )
198 })?;
199 let queue = Rc::new(AsyncRwLock::new(queue));
200 let card_index = self.card_index;
201 let queue_task = match idx {
202 0 => {
203 let streams = self.streams.clone();
205 let snd_data = self.snd_data.clone();
206 let tx_send = self.tx_send.clone();
207 let rx_send = self.rx_send.clone();
208 let ctrl_queue = queue.clone();
209
210 let ex_clone = self.ex.clone();
211 Some(self.ex.spawn_local(async move {
212 handle_ctrl_queue(
213 &ex_clone,
214 &streams,
215 &snd_data,
216 ctrl_queue,
217 &mut kick_evt,
218 tx_send,
219 rx_send,
220 card_index,
221 None,
222 )
223 .await
224 }))
225 }
226 1 => Some(self.ex.spawn_local(async move { Ok(()) })),
233 2 | 3 => {
234 let (send, recv) = if idx == 2 {
235 (self.tx_send.clone(), self.tx_recv.take())
236 } else {
237 (self.rx_send.clone(), self.rx_recv.take())
238 };
239 let mut recv = recv.ok_or_else(|| {
240 anyhow!("[Card {}] queue restart is not supported", self.card_index)
241 })?;
242 let streams = Rc::clone(&self.streams);
243 let queue_pcm_queue = queue.clone();
244 let queue_task = self.ex.spawn_local(async move {
245 handle_pcm_queue(&streams, send, queue_pcm_queue, &kick_evt, card_index, None)
246 .await
247 });
248
249 let queue_response_queue = queue.clone();
250 let response_queue_task = self.ex.spawn_local(async move {
251 send_pcm_response_worker(queue_response_queue, &mut recv, None).await
252 });
253
254 self.response_workers[idx - PCM_RESPONSE_WORKER_IDX_OFFSET] = Some(WorkerState {
255 queue_task: response_queue_task,
256 queue: queue.clone(),
257 });
258
259 Some(queue_task)
260 }
261 _ => bail!(
262 "[Card {}] attempted to start unknown queue: {}",
263 self.card_index,
264 idx
265 ),
266 };
267
268 if let Some(queue_task) = queue_task {
269 self.workers[idx] = Some(WorkerState { queue_task, queue });
270 }
271 Ok(())
272 }
273
274 fn stop_queue(&mut self, idx: usize) -> anyhow::Result<virtio::Queue> {
275 let worker_queue_rc = self
276 .workers
277 .get_mut(idx)
278 .and_then(Option::take)
279 .map(|worker| {
280 let _ = self.ex.run_until(worker.queue_task.cancel());
282 worker.queue
283 });
284
285 if idx == 2 || idx == 3 {
286 if let Some(worker) = self
287 .response_workers
288 .get_mut(idx - PCM_RESPONSE_WORKER_IDX_OFFSET)
289 .and_then(Option::take)
290 {
291 let _ = self.ex.run_until(worker.queue_task.cancel());
293 }
294 }
295
296 if let Some(queue_rc) = worker_queue_rc {
297 match Rc::try_unwrap(queue_rc) {
298 Ok(queue_mutex) => Ok(queue_mutex.into_inner()),
299 Err(_) => panic!(
300 "[Card {}] failed to recover queue from worker",
301 self.card_index
302 ),
303 }
304 } else {
305 Err(anyhow::Error::new(DeviceError::WorkerNotFound))
306 }
307 }
308
309 fn snapshot(&mut self) -> anyhow::Result<AnySnapshot> {
310 let stream_info_snaps = if let Some(stream_infos) = &self.streams.lock().now_or_never() {
312 let mut snaps = Vec::new();
313 for stream_info in stream_infos.iter() {
314 snaps.push(
315 stream_info
316 .lock()
317 .now_or_never()
318 .unwrap_or_else(|| {
319 panic!(
320 "[Card {}] failed to lock audio state during snapshot",
321 self.card_index
322 )
323 })
324 .snapshot(),
325 );
326 }
327 Some(snaps)
328 } else {
329 None
330 };
331 let snd_data_ref: &SndData = self.snd_data.borrow();
332 AnySnapshot::to_any(SndBackendSnapshot {
333 avail_features: self.avail_features,
334 stream_infos: stream_info_snaps,
335 snd_data: snd_data_ref.clone(),
336 })
337 .with_context(|| {
338 format!(
339 "[Card {}] Failed to serialize SndBackendSnapshot",
340 self.card_index
341 )
342 })
343 }
344
345 fn restore(&mut self, data: AnySnapshot) -> anyhow::Result<()> {
346 let deser: SndBackendSnapshot = AnySnapshot::from_any(data).with_context(|| {
347 format!(
348 "[Card {}] Failed to deserialize SndBackendSnapshot",
349 self.card_index
350 )
351 })?;
352 anyhow::ensure!(
353 deser.avail_features == self.avail_features,
354 "[Card {}] avail features doesn't match on restore: expected: {}, got: {}",
355 self.card_index,
356 deser.avail_features,
357 self.avail_features
358 );
359 let snd_data = self.snd_data.borrow();
360 anyhow::ensure!(
361 &deser.snd_data == snd_data,
362 "[Card {}] snd data doesn't match on restore: expected: {:?}, got: {:?}",
363 self.card_index,
364 deser.snd_data,
365 snd_data,
366 );
367
368 let ex_clone = self.ex.clone();
369 let streams_rc = self.streams.clone();
370 let tx_send_clone = self.tx_send.clone();
371 let rx_send_clone = self.rx_send.clone();
372
373 let card_index = self.card_index;
374 let restore_task =
375 self.ex.spawn_local(async move {
376 if let Some(stream_infos) = &deser.stream_infos {
377 for (stream, stream_info) in
378 streams_rc.lock().await.iter().zip(stream_infos.iter())
379 {
380 stream.lock().await.restore(stream_info);
381 if stream_info.state == VIRTIO_SND_R_PCM_START
382 || stream_info.state == VIRTIO_SND_R_PCM_PREPARE
383 {
384 stream
385 .lock()
386 .await
387 .prepare(&ex_clone, &tx_send_clone, &rx_send_clone)
388 .await
389 .unwrap_or_else(|_| {
390 panic!("[Card {card_index}] failed to prepare PCM")
391 });
392 }
393 if stream_info.state == VIRTIO_SND_R_PCM_START {
394 stream.lock().await.start().await.unwrap_or_else(|_| {
395 panic!("[Card {card_index}] failed to start PCM")
396 });
397 }
398 }
399 }
400 });
401 self.ex
402 .run_until(restore_task)
403 .unwrap_or_else(|_| panic!("[Card {}] failed to restore streams", self.card_index));
404 Ok(())
405 }
406
407 fn enter_suspended_state(&mut self) -> anyhow::Result<()> {
408 Ok(())
410 }
411}