devices/virtio/snd/vios_backend/
worker.rs

1// Copyright 2021 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use std::io::Read;
6use std::sync::mpsc::Sender;
7use std::sync::Arc;
8use std::thread;
9
10use base::error;
11use base::warn;
12use base::Event;
13use base::EventToken;
14use base::WaitContext;
15use data_model::Le32;
16use sync::Mutex;
17use zerocopy::Immutable;
18use zerocopy::IntoBytes;
19
20use super::super::constants::*;
21use super::super::layout::*;
22use super::streams::*;
23use super::Result;
24use super::SoundError;
25use super::*;
26use crate::virtio::DescriptorChain;
27use crate::virtio::Queue;
28
/// Worker that emulates a virtio-snd device on top of a `VioSClient`.
///
/// The control and event queues are serviced by `control_loop()`; buffers on the tx and rx queues
/// are dispatched by a separate IO thread spawned in `try_new()`.
pub struct Worker {
    // Lock order: Must never hold more than one queue lock at the same time.
    pub control_queue: Arc<Mutex<Queue>>,
    // Taken by control_loop() while it runs and put back when it returns Ok.
    pub event_queue: Option<Queue>,
    // Client handle; every stream is also given a clone of it.
    vios_client: Arc<Mutex<VioSClient>>,
    // One proxy per stream reported by the client, indexed by stream id. Drained when
    // control_loop() exits cleanly.
    streams: Vec<StreamProxy>,
    // Shared with the non-capture streams.
    pub tx_queue: Arc<Mutex<Queue>>,
    // Shared with the capture (VIRTIO_SND_D_INPUT) streams.
    pub rx_queue: Arc<Mutex<Queue>>,
    // Handle of the thread running io_loop(); taken and joined by stop_io_thread().
    io_thread: Option<thread::JoinHandle<Result<()>>>,
    // Signalled by stop_io_thread() to make the IO thread exit.
    io_kill: Event,
    // saved_stream_state holds the previous state of streams. When the sound device is newly
    // created, this will be empty. It will only contain state if the sound device is put to sleep
    // OR if we restore a VM.
    pub saved_stream_state: Vec<StreamSnapshot>,
}
44
45impl Worker {
46    /// Creates a new virtio-snd worker.
47    pub fn try_new(
48        vios_client: Arc<Mutex<VioSClient>>,
49        control_queue: Arc<Mutex<Queue>>,
50        event_queue: Queue,
51        tx_queue: Arc<Mutex<Queue>>,
52        rx_queue: Arc<Mutex<Queue>>,
53        saved_stream_state: Vec<StreamSnapshot>,
54    ) -> Result<Worker> {
55        let num_streams = vios_client.lock().num_streams();
56        let mut streams: Vec<StreamProxy> = Vec::with_capacity(num_streams as usize);
57        {
58            for stream_id in 0..num_streams {
59                let capture = vios_client
60                    .lock()
61                    .stream_info(stream_id)
62                    .map(|i| i.direction == VIRTIO_SND_D_INPUT)
63                    .unwrap_or(false);
64                let io_queue = if capture { &rx_queue } else { &tx_queue };
65                streams.push(Stream::try_new(
66                    stream_id,
67                    vios_client.clone(),
68                    control_queue.clone(),
69                    io_queue.clone(),
70                    capture,
71                    saved_stream_state.get(stream_id as usize).cloned(),
72                )?);
73            }
74        }
75        let (self_kill_io, kill_io) = Event::new()
76            .and_then(|e| Ok((e.try_clone()?, e)))
77            .map_err(SoundError::CreateEvent)?;
78
79        let senders: Vec<Sender<Box<StreamMsg>>> =
80            streams.iter().map(|sp| sp.msg_sender().clone()).collect();
81        let tx_queue_thread = tx_queue.clone();
82        let rx_queue_thread = rx_queue.clone();
83        let io_thread = thread::Builder::new()
84            .name("v_snd_io".to_string())
85            .spawn(move || {
86                try_set_real_time_priority();
87
88                io_loop(tx_queue_thread, rx_queue_thread, senders, kill_io)
89            })
90            .map_err(SoundError::CreateThread)?;
91        Ok(Worker {
92            control_queue,
93            event_queue: Some(event_queue),
94            vios_client,
95            streams,
96            tx_queue,
97            rx_queue,
98            io_thread: Some(io_thread),
99            io_kill: self_kill_io,
100            saved_stream_state: Vec::new(),
101        })
102    }
103
104    /// Emulates the virtio-snd device. It won't return until something is written to the kill_evt
105    /// event or an unrecoverable error occurs.
106    pub fn control_loop(&mut self, kill_evt: Event) -> Result<()> {
107        let event_notifier = self
108            .vios_client
109            .lock()
110            .get_event_notifier()
111            .map_err(SoundError::ClientEventNotifier)?;
112        #[derive(EventToken)]
113        enum Token {
114            ControlQAvailable,
115            EventQAvailable,
116            EventTriggered,
117            Kill,
118        }
119        let wait_ctx: WaitContext<Token> = WaitContext::build_with(&[
120            (self.control_queue.lock().event(), Token::ControlQAvailable),
121            (
122                self.event_queue.as_ref().expect("queue missing").event(),
123                Token::EventQAvailable,
124            ),
125            (&event_notifier, Token::EventTriggered),
126            (&kill_evt, Token::Kill),
127        ])
128        .map_err(SoundError::WaitCtx)?;
129
130        let mut event_queue = self.event_queue.take().expect("event_queue missing");
131        'wait: loop {
132            let wait_events = wait_ctx.wait().map_err(SoundError::WaitCtx)?;
133
134            for wait_evt in wait_events.iter().filter(|e| e.is_readable) {
135                match wait_evt.token {
136                    Token::ControlQAvailable => {
137                        self.control_queue
138                            .lock()
139                            .event()
140                            .wait()
141                            .map_err(SoundError::QueueEvt)?;
142                        self.process_controlq_buffers()?;
143                    }
144                    Token::EventQAvailable => {
145                        // Just read from the event object to make sure the producer of such events
146                        // never blocks. The buffers will only be used when actual virtio-snd
147                        // events are triggered.
148                        event_queue.event().wait().map_err(SoundError::QueueEvt)?;
149                    }
150                    Token::EventTriggered => {
151                        event_notifier.wait().map_err(SoundError::QueueEvt)?;
152                        self.process_event_triggered(&mut event_queue)?;
153                    }
154                    Token::Kill => {
155                        let _ = kill_evt.wait();
156                        break 'wait;
157                    }
158                }
159            }
160        }
161        self.saved_stream_state = self
162            .streams
163            .drain(..)
164            .map(|stream| stream.stop_thread())
165            .collect();
166        self.event_queue = Some(event_queue);
167        Ok(())
168    }
169
170    fn stop_io_thread(&mut self) {
171        if let Err(e) = self.io_kill.signal() {
172            error!(
173                "virtio-snd: Failed to send Break msg to stream thread: {}",
174                e
175            );
176        }
177        if let Some(th) = self.io_thread.take() {
178            match th.join() {
179                Err(e) => {
180                    error!("virtio-snd: Panic detected on stream thread: {:?}", e);
181                }
182                Ok(r) => {
183                    if let Err(e) = r {
184                        error!("virtio-snd: IO thread exited with and error: {}", e);
185                    }
186                }
187            }
188        }
189    }
190
191    // Pops and handles all available ontrol queue buffers. Logs minor errors, but returns an
192    // Err if it encounters an unrecoverable error.
193    fn process_controlq_buffers(&mut self) -> Result<()> {
194        while let Some(mut avail_desc) = lock_pop_unlock(&self.control_queue) {
195            let reader = &mut avail_desc.reader;
196            let available_bytes = reader.available_bytes();
197            let Ok(hdr) = reader.peek_obj::<virtio_snd_hdr>() else {
198                error!(
199                    "virtio-snd: Message received on control queue is too small: {}",
200                    available_bytes
201                );
202                return reply_control_op_status(
203                    VIRTIO_SND_S_BAD_MSG,
204                    avail_desc,
205                    &self.control_queue,
206                );
207            };
208            let mut read_buf = vec![0u8; available_bytes];
209            reader
210                .read_exact(&mut read_buf)
211                .map_err(SoundError::QueueIO)?;
212            let request_type = hdr.code.to_native();
213            match request_type {
214                VIRTIO_SND_R_JACK_INFO => {
215                    let (code, info_vec) = {
216                        match self.parse_info_query(&read_buf) {
217                            None => (VIRTIO_SND_S_BAD_MSG, Vec::new()),
218                            Some((start_id, count)) => {
219                                let end_id = start_id.saturating_add(count);
220                                if end_id > self.vios_client.lock().num_jacks() {
221                                    error!(
222                                        "virtio-snd: Requested info on invalid jacks ids: {}..{}",
223                                        start_id,
224                                        end_id - 1
225                                    );
226                                    (VIRTIO_SND_S_NOT_SUPP, Vec::new())
227                                } else {
228                                    (
229                                        VIRTIO_SND_S_OK,
230                                        // Safe to unwrap because we just ensured all the ids are
231                                        // valid
232                                        (start_id..end_id)
233                                            .map(|id| {
234                                                self.vios_client.lock().jack_info(id).unwrap()
235                                            })
236                                            .collect(),
237                                    )
238                                }
239                            }
240                        }
241                    };
242                    self.send_info_reply(avail_desc, code, info_vec)?;
243                }
244                VIRTIO_SND_R_JACK_REMAP => {
245                    let code = if read_buf.len() != std::mem::size_of::<virtio_snd_jack_remap>() {
246                        error!(
247                        "virtio-snd: The driver sent the wrong number bytes for a jack_remap struct: {}",
248                        read_buf.len()
249                        );
250                        VIRTIO_SND_S_BAD_MSG
251                    } else {
252                        let mut request: virtio_snd_jack_remap = Default::default();
253                        request.as_mut_bytes().copy_from_slice(&read_buf);
254                        let jack_id = request.hdr.jack_id.to_native();
255                        let association = request.association.to_native();
256                        let sequence = request.sequence.to_native();
257                        if let Err(e) =
258                            self.vios_client
259                                .lock()
260                                .remap_jack(jack_id, association, sequence)
261                        {
262                            error!("virtio-snd: Failed to remap jack: {}", e);
263                            vios_error_to_status_code(e)
264                        } else {
265                            VIRTIO_SND_S_OK
266                        }
267                    };
268                    let writer = &mut avail_desc.writer;
269                    writer
270                        .write_obj(virtio_snd_hdr {
271                            code: Le32::from(code),
272                        })
273                        .map_err(SoundError::QueueIO)?;
274                    {
275                        let mut queue_lock = self.control_queue.lock();
276                        queue_lock.add_used(avail_desc);
277                        queue_lock.trigger_interrupt();
278                    }
279                }
280                VIRTIO_SND_R_CHMAP_INFO => {
281                    let (code, info_vec) = {
282                        match self.parse_info_query(&read_buf) {
283                            None => (VIRTIO_SND_S_BAD_MSG, Vec::new()),
284                            Some((start_id, count)) => {
285                                let end_id = start_id.saturating_add(count);
286                                let num_chmaps = self.vios_client.lock().num_chmaps();
287                                if end_id > num_chmaps {
288                                    error!(
289                                        "virtio-snd: Requested info on invalid chmaps ids: {}..{}",
290                                        start_id,
291                                        end_id - 1
292                                    );
293                                    (VIRTIO_SND_S_NOT_SUPP, Vec::new())
294                                } else {
295                                    (
296                                        VIRTIO_SND_S_OK,
297                                        // Safe to unwrap because we just ensured all the ids are
298                                        // valid
299                                        (start_id..end_id)
300                                            .map(|id| {
301                                                self.vios_client.lock().chmap_info(id).unwrap()
302                                            })
303                                            .collect(),
304                                    )
305                                }
306                            }
307                        }
308                    };
309                    self.send_info_reply(avail_desc, code, info_vec)?;
310                }
311                VIRTIO_SND_R_PCM_INFO => {
312                    let (code, info_vec) = {
313                        match self.parse_info_query(&read_buf) {
314                            None => (VIRTIO_SND_S_BAD_MSG, Vec::new()),
315                            Some((start_id, count)) => {
316                                let end_id = start_id.saturating_add(count);
317                                if end_id > self.vios_client.lock().num_streams() {
318                                    error!(
319                                        "virtio-snd: Requested info on invalid stream ids: {}..{}",
320                                        start_id,
321                                        end_id - 1
322                                    );
323                                    (VIRTIO_SND_S_NOT_SUPP, Vec::new())
324                                } else {
325                                    (
326                                        VIRTIO_SND_S_OK,
327                                        // Safe to unwrap because we just ensured all the ids are
328                                        // valid
329                                        (start_id..end_id)
330                                            .map(|id| {
331                                                self.vios_client.lock().stream_info(id).unwrap()
332                                            })
333                                            .collect(),
334                                    )
335                                }
336                            }
337                        }
338                    };
339                    self.send_info_reply(avail_desc, code, info_vec)?;
340                }
341                VIRTIO_SND_R_PCM_SET_PARAMS => self.process_set_params(avail_desc, &read_buf)?,
342                VIRTIO_SND_R_PCM_PREPARE => {
343                    self.try_parse_pcm_hdr_and_send_msg(&read_buf, StreamMsg::Prepare(avail_desc))?
344                }
345                VIRTIO_SND_R_PCM_RELEASE => {
346                    self.try_parse_pcm_hdr_and_send_msg(&read_buf, StreamMsg::Release(avail_desc))?
347                }
348                VIRTIO_SND_R_PCM_START => {
349                    self.try_parse_pcm_hdr_and_send_msg(&read_buf, StreamMsg::Start(avail_desc))?
350                }
351                VIRTIO_SND_R_PCM_STOP => {
352                    self.try_parse_pcm_hdr_and_send_msg(&read_buf, StreamMsg::Stop(avail_desc))?
353                }
354                _ => {
355                    error!(
356                        "virtio-snd: Unknown control queue mesage code: {}",
357                        request_type
358                    );
359                    reply_control_op_status(
360                        VIRTIO_SND_S_NOT_SUPP,
361                        avail_desc,
362                        &self.control_queue,
363                    )?;
364                }
365            }
366        }
367        Ok(())
368    }
369
370    fn process_event_triggered(&mut self, event_queue: &mut Queue) -> Result<()> {
371        while let Some(evt) = self.vios_client.lock().pop_event() {
372            if let Some(mut desc) = event_queue.pop() {
373                let writer = &mut desc.writer;
374                writer.write_obj(evt).map_err(SoundError::QueueIO)?;
375                event_queue.add_used(desc);
376                event_queue.trigger_interrupt();
377            } else {
378                warn!("virtio-snd: Dropping event because there are no buffers in virtqueue");
379            }
380        }
381        Ok(())
382    }
383
384    fn parse_info_query(&mut self, read_buf: &[u8]) -> Option<(u32, u32)> {
385        if read_buf.len() != std::mem::size_of::<virtio_snd_query_info>() {
386            error!(
387                "virtio-snd: The driver sent the wrong number bytes for a pcm_info struct: {}",
388                read_buf.len()
389            );
390            return None;
391        }
392        let mut query: virtio_snd_query_info = Default::default();
393        query.as_mut_bytes().copy_from_slice(read_buf);
394        let start_id = query.start_id.to_native();
395        let count = query.count.to_native();
396        Some((start_id, count))
397    }
398
399    // Returns Err if it encounters an unrecoverable error, Ok otherwise
400    fn process_set_params(&mut self, desc: DescriptorChain, read_buf: &[u8]) -> Result<()> {
401        if read_buf.len() != std::mem::size_of::<virtio_snd_pcm_set_params>() {
402            error!(
403                "virtio-snd: The driver sent a buffer of the wrong size for a set_params struct: {}",
404                read_buf.len()
405                );
406            return reply_control_op_status(VIRTIO_SND_S_BAD_MSG, desc, &self.control_queue);
407        }
408        let mut params: virtio_snd_pcm_set_params = Default::default();
409        params.as_mut_bytes().copy_from_slice(read_buf);
410        let stream_id = params.hdr.stream_id.to_native();
411        if stream_id < self.vios_client.lock().num_streams() {
412            self.streams[stream_id as usize].send(StreamMsg::SetParams(desc, params))
413        } else {
414            error!(
415                "virtio-snd: Driver requested operation on invalid stream: {}",
416                stream_id
417            );
418            reply_control_op_status(VIRTIO_SND_S_BAD_MSG, desc, &self.control_queue)
419        }
420    }
421
422    // Returns Err if it encounters an unrecoverable error, Ok otherwise
423    fn try_parse_pcm_hdr_and_send_msg(&mut self, read_buf: &[u8], msg: StreamMsg) -> Result<()> {
424        if read_buf.len() != std::mem::size_of::<virtio_snd_pcm_hdr>() {
425            error!(
426                "virtio-snd: The driver sent a buffer too small to contain a header: {}",
427                read_buf.len()
428            );
429            return reply_control_op_status(
430                VIRTIO_SND_S_BAD_MSG,
431                match msg {
432                    StreamMsg::Prepare(d)
433                    | StreamMsg::Start(d)
434                    | StreamMsg::Stop(d)
435                    | StreamMsg::Release(d) => d,
436                    _ => panic!("virtio-snd: Can't handle message. This is a BUG!!"),
437                },
438                &self.control_queue,
439            );
440        }
441        let mut pcm_hdr: virtio_snd_pcm_hdr = Default::default();
442        pcm_hdr.as_mut_bytes().copy_from_slice(read_buf);
443        let stream_id = pcm_hdr.stream_id.to_native();
444        if stream_id < self.vios_client.lock().num_streams() {
445            self.streams[stream_id as usize].send(msg)
446        } else {
447            error!(
448                "virtio-snd: Driver requested operation on invalid stream: {}",
449                stream_id
450            );
451            reply_control_op_status(
452                VIRTIO_SND_S_BAD_MSG,
453                match msg {
454                    StreamMsg::Prepare(d)
455                    | StreamMsg::Start(d)
456                    | StreamMsg::Stop(d)
457                    | StreamMsg::Release(d) => d,
458                    _ => panic!("virtio-snd: Can't handle message. This is a BUG!!"),
459                },
460                &self.control_queue,
461            )
462        }
463    }
464
465    fn send_info_reply<T: Immutable + IntoBytes>(
466        &mut self,
467        mut desc: DescriptorChain,
468        code: u32,
469        info_vec: Vec<T>,
470    ) -> Result<()> {
471        let writer = &mut desc.writer;
472        writer
473            .write_obj(virtio_snd_hdr {
474                code: Le32::from(code),
475            })
476            .map_err(SoundError::QueueIO)?;
477        for info in info_vec {
478            writer.write_obj(info).map_err(SoundError::QueueIO)?;
479        }
480        {
481            let mut queue_lock = self.control_queue.lock();
482            queue_lock.add_used(desc);
483            queue_lock.trigger_interrupt();
484        }
485        Ok(())
486    }
487}
488
489impl Drop for Worker {
490    fn drop(&mut self) {
491        self.stop_io_thread();
492    }
493}
494
495fn io_loop(
496    tx_queue: Arc<Mutex<Queue>>,
497    rx_queue: Arc<Mutex<Queue>>,
498    senders: Vec<Sender<Box<StreamMsg>>>,
499    kill_evt: Event,
500) -> Result<()> {
501    #[derive(EventToken)]
502    enum Token {
503        TxQAvailable,
504        RxQAvailable,
505        Kill,
506    }
507    let wait_ctx: WaitContext<Token> = WaitContext::build_with(&[
508        (tx_queue.lock().event(), Token::TxQAvailable),
509        (rx_queue.lock().event(), Token::RxQAvailable),
510        (&kill_evt, Token::Kill),
511    ])
512    .map_err(SoundError::WaitCtx)?;
513
514    'wait: loop {
515        let wait_events = wait_ctx.wait().map_err(SoundError::WaitCtx)?;
516        for wait_evt in wait_events.iter().filter(|e| e.is_readable) {
517            let queue = match wait_evt.token {
518                Token::TxQAvailable => {
519                    tx_queue
520                        .lock()
521                        .event()
522                        .wait()
523                        .map_err(SoundError::QueueEvt)?;
524                    &tx_queue
525                }
526                Token::RxQAvailable => {
527                    rx_queue
528                        .lock()
529                        .event()
530                        .wait()
531                        .map_err(SoundError::QueueEvt)?;
532                    &rx_queue
533                }
534                Token::Kill => {
535                    let _ = kill_evt.wait();
536                    break 'wait;
537                }
538            };
539            while let Some(mut avail_desc) = lock_pop_unlock(queue) {
540                let reader = &mut avail_desc.reader;
541                let xfer: virtio_snd_pcm_xfer = reader.read_obj().map_err(SoundError::QueueIO)?;
542                let stream_id = xfer.stream_id.to_native();
543                if stream_id as usize >= senders.len() {
544                    error!(
545                        "virtio-snd: Driver sent buffer for invalid stream: {}",
546                        stream_id
547                    );
548                    reply_pcm_buffer_status(VIRTIO_SND_S_IO_ERR, 0, avail_desc, queue)?;
549                } else {
550                    StreamProxy::send_msg(
551                        &senders[stream_id as usize],
552                        StreamMsg::Buffer(avail_desc),
553                    )?;
554                }
555            }
556        }
557    }
558    Ok(())
559}
560
561// If queue.lock().pop() is used directly in the condition of a 'while' loop the lock is held over
562// the entire loop block. Encapsulating it in this fuction guarantees that the lock is dropped
563// immediately after pop() is called, which allows the code to remain somewhat simpler.
564fn lock_pop_unlock(queue: &Arc<Mutex<Queue>>) -> Option<DescriptorChain> {
565    queue.lock().pop()
566}