devices/virtio/console/
worker.rs

1// Copyright 2024 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Virtio console device worker thread.
6
7use std::collections::BTreeMap;
8use std::collections::VecDeque;
9use std::sync::mpsc;
10use std::sync::Arc;
11
12use anyhow::anyhow;
13use anyhow::Context;
14use base::error;
15use base::Event;
16use base::EventToken;
17use base::WaitContext;
18use base::WorkerThread;
19use sync::Mutex;
20
21use crate::virtio::console::control::process_control_receive_queue;
22use crate::virtio::console::control::process_control_transmit_queue;
23use crate::virtio::console::control::ControlMsgBytes;
24use crate::virtio::console::input::process_receive_queue;
25use crate::virtio::console::output::process_transmit_queue;
26use crate::virtio::console::port::ConsolePort;
27use crate::virtio::console::port::ConsolePortInfo;
28use crate::virtio::Queue;
29
// Virtqueue indices used by this worker.
//
// Port 0 owns the first receive/transmit pair, the control queues follow, and every
// additional port N >= 1 owns the pair starting at `PORT1_RECEIVEQ_IDX + 2 * (N - 1)`
// (see `receiveq_idx()` / `transmitq_idx()`).
const PORT0_RECEIVEQ_IDX: usize = 0;
const PORT0_TRANSMITQ_IDX: usize = 1;
const CONTROL_RECEIVEQ_IDX: usize = 2;
const CONTROL_TRANSMITQ_IDX: usize = 3;
const PORT1_RECEIVEQ_IDX: usize = 4;
const PORT1_TRANSMITQ_IDX: usize = 5;
36
/// Per-port state held by the console worker while it is running.
///
/// Built from a `ConsolePort` with `WorkerPort::from_console_port()` and handed back with
/// `WorkerPort::into_console_port()`.
pub struct WorkerPort {
    // Port metadata cloned from the `ConsolePort`; `None` if the port had none.
    info: Option<ConsolePortInfo>,

    // Event the worker waits on (as `Token::InputAvailable`) to learn that input is pending.
    in_avail_evt: Event,
    // Buffered input bytes, shared through the `Arc<Mutex<..>>` with the `ConsolePort` it was
    // cloned from; drained into the port's receive queue by `process_receive_queue()`.
    input_buffer: Arc<Mutex<VecDeque<u8>>>,
    // Destination for data taken from the port's transmit queue by `process_transmit_queue()`.
    output: Box<dyn std::io::Write + Send>,
}
44
45impl WorkerPort {
46    pub fn from_console_port(port: &mut ConsolePort) -> WorkerPort {
47        let in_avail_evt = port.clone_in_avail_evt().unwrap();
48        let input_buffer = port.clone_input_buffer();
49        let output = port
50            .take_output()
51            .unwrap_or_else(|| Box::new(std::io::sink()));
52        let info = port.port_info().cloned();
53        WorkerPort {
54            info,
55            in_avail_evt,
56            input_buffer,
57            output,
58        }
59    }
60
61    /// Restore the state retrieved from `ConsolePort` by `WorkerPort::from_console_port()`.
62    pub fn into_console_port(self, console_port: &mut ConsolePort) {
63        console_port.restore_output(self.output);
64    }
65
66    pub fn is_console(&self) -> bool {
67        self.info
68            .as_ref()
69            .map(|info| info.console)
70            .unwrap_or_default()
71    }
72
73    pub fn name(&self) -> Option<&str> {
74        self.info.as_ref().and_then(ConsolePortInfo::name)
75    }
76}
77
/// Identifies which event woke the worker's `WaitContext`.
#[derive(EventToken)]
enum Token {
    /// The receive queue event of the port with this ID was signaled.
    ReceiveQueueAvailable(u32),
    /// The transmit queue event of the port with this ID was signaled.
    TransmitQueueAvailable(u32),
    /// The input-available event of the port with this ID was signaled.
    InputAvailable(u32),
    /// The control receive queue event was signaled.
    ControlReceiveQueueAvailable,
    /// The control transmit queue event was signaled.
    ControlTransmitQueueAvailable,
    /// A `WorkerRequest` may be waiting on the request channel.
    WorkerRequest,
    /// The worker was asked to exit.
    Kill,
}
88
/// Requests sent from `WorkerHandle` to the running `Worker`.
pub enum WorkerRequest {
    /// Begin servicing `queue` as virtqueue index `idx`. The outcome is sent back on
    /// `response_sender`.
    StartQueue {
        idx: usize,
        queue: Queue,
        response_sender: mpsc::SyncSender<anyhow::Result<()>>,
    },
    /// Stop servicing virtqueue index `idx`. The queue is sent back on `response_sender`, or
    /// `None` if that index was not running.
    StopQueue {
        idx: usize,
        response_sender: mpsc::SyncSender<Option<Queue>>,
    },
}
100
/// Event loop state for the virtio console worker thread.
pub struct Worker {
    // Waits on the kill event, the worker request event, every port's input-available event,
    // and the event of every running queue.
    wait_ctx: WaitContext<Token>,

    // Currently running queues, keyed by virtqueue index.
    queues: BTreeMap<usize, Queue>,

    // Console ports indexed by port ID. At least port 0 will exist, and other ports may be
    // available if `VIRTIO_CONSOLE_F_MULTIPORT` is enabled.
    ports: Vec<WorkerPort>,

    // Device-to-driver messages to be received by the driver via the control receiveq.
    pending_receive_control_msgs: VecDeque<ControlMsgBytes>,

    // Receiving end of the channel carrying `WorkerRequest`s from the `WorkerHandle`.
    worker_receiver: mpsc::Receiver<WorkerRequest>,
    // Signaled by the `WorkerHandle` after it has queued a request on `worker_receiver`.
    worker_event: Event,
}
117
118impl Worker {
119    pub fn new(
120        ports: Vec<WorkerPort>,
121        worker_receiver: mpsc::Receiver<WorkerRequest>,
122        worker_event: Event,
123    ) -> anyhow::Result<Self> {
124        let wait_ctx = WaitContext::new().context("WaitContext::new() failed")?;
125
126        wait_ctx.add(&worker_event, Token::WorkerRequest)?;
127
128        for (index, port) in ports.iter().enumerate() {
129            let port_id = index as u32;
130            wait_ctx.add(&port.in_avail_evt, Token::InputAvailable(port_id))?;
131        }
132
133        Ok(Worker {
134            wait_ctx,
135            queues: BTreeMap::new(),
136            ports,
137            pending_receive_control_msgs: VecDeque::new(),
138            worker_receiver,
139            worker_event,
140        })
141    }
142
143    pub fn run(&mut self, kill_evt: &Event) -> anyhow::Result<()> {
144        self.wait_ctx.add(kill_evt, Token::Kill)?;
145        let res = self.run_loop();
146        self.wait_ctx.delete(kill_evt)?;
147        res
148    }
149
150    fn run_loop(&mut self) -> anyhow::Result<()> {
151        let mut running = true;
152        while running {
153            let events = self.wait_ctx.wait()?;
154
155            for event in events.iter().filter(|e| e.is_readable) {
156                match event.token {
157                    Token::TransmitQueueAvailable(port_id) => {
158                        if let (Some(port), Some(transmitq)) = (
159                            self.ports.get_mut(port_id as usize),
160                            transmitq_idx(port_id).and_then(|idx| self.queues.get_mut(&idx)),
161                        ) {
162                            transmitq
163                                .event()
164                                .wait()
165                                .context("failed reading transmit queue Event")?;
166                            process_transmit_queue(transmitq, &mut port.output);
167                        }
168                    }
169                    Token::ReceiveQueueAvailable(port_id) | Token::InputAvailable(port_id) => {
170                        let port = self.ports.get_mut(port_id as usize);
171                        let receiveq =
172                            receiveq_idx(port_id).and_then(|idx| self.queues.get_mut(&idx));
173
174                        let event = if matches!(event.token, Token::ReceiveQueueAvailable(..)) {
175                            receiveq.as_ref().map(|q| q.event())
176                        } else {
177                            port.as_ref().map(|p| &p.in_avail_evt)
178                        };
179                        if let Some(event) = event {
180                            event.wait().context("failed to clear receive event")?;
181                        }
182
183                        if let (Some(port), Some(receiveq)) = (port, receiveq) {
184                            let mut input_buffer = port.input_buffer.lock();
185                            process_receive_queue(&mut input_buffer, receiveq);
186                        }
187                    }
188                    Token::ControlReceiveQueueAvailable => {
189                        if let Some(ctrl_receiveq) = self.queues.get_mut(&CONTROL_RECEIVEQ_IDX) {
190                            ctrl_receiveq
191                                .event()
192                                .wait()
193                                .context("failed waiting on control event")?;
194                            process_control_receive_queue(
195                                ctrl_receiveq,
196                                &mut self.pending_receive_control_msgs,
197                            );
198                        }
199                    }
200                    Token::ControlTransmitQueueAvailable => {
201                        if let Some(ctrl_transmitq) = self.queues.get_mut(&CONTROL_TRANSMITQ_IDX) {
202                            ctrl_transmitq
203                                .event()
204                                .wait()
205                                .context("failed waiting on control event")?;
206                            process_control_transmit_queue(
207                                ctrl_transmitq,
208                                &self.ports,
209                                &mut self.pending_receive_control_msgs,
210                            );
211                        }
212
213                        // Attempt to send any new replies if there is space in the receiveq.
214                        if let Some(ctrl_receiveq) = self.queues.get_mut(&CONTROL_RECEIVEQ_IDX) {
215                            process_control_receive_queue(
216                                ctrl_receiveq,
217                                &mut self.pending_receive_control_msgs,
218                            )
219                        }
220                    }
221                    Token::WorkerRequest => {
222                        self.worker_event.wait()?;
223                        self.process_worker_requests();
224                    }
225                    Token::Kill => running = false,
226                }
227            }
228        }
229        Ok(())
230    }
231
232    fn process_worker_requests(&mut self) {
233        while let Ok(request) = self.worker_receiver.try_recv() {
234            match request {
235                WorkerRequest::StartQueue {
236                    idx,
237                    queue,
238                    response_sender,
239                } => {
240                    let res = self.start_queue(idx, queue);
241                    let _ = response_sender.send(res);
242                }
243                WorkerRequest::StopQueue {
244                    idx,
245                    response_sender,
246                } => {
247                    let res = self.stop_queue(idx);
248                    let _ = response_sender.send(res);
249                }
250            }
251        }
252    }
253
254    fn start_queue(&mut self, idx: usize, queue: Queue) -> anyhow::Result<()> {
255        if let Some(port_id) = receiveq_port_id(idx) {
256            self.wait_ctx
257                .add(queue.event(), Token::ReceiveQueueAvailable(port_id))?;
258        } else if let Some(port_id) = transmitq_port_id(idx) {
259            self.wait_ctx
260                .add(queue.event(), Token::TransmitQueueAvailable(port_id))?;
261        } else if idx == CONTROL_RECEIVEQ_IDX {
262            self.wait_ctx
263                .add(queue.event(), Token::ControlReceiveQueueAvailable)?;
264        } else if idx == CONTROL_TRANSMITQ_IDX {
265            self.wait_ctx
266                .add(queue.event(), Token::ControlTransmitQueueAvailable)?;
267        } else {
268            return Err(anyhow!("unhandled queue idx {idx}"));
269        }
270
271        let prev = self.queues.insert(idx, queue);
272        assert!(prev.is_none());
273        Ok(())
274    }
275
276    fn stop_queue(&mut self, idx: usize) -> Option<Queue> {
277        if let Some(queue) = self.queues.remove(&idx) {
278            let _ = self.wait_ctx.delete(queue.event());
279            Some(queue)
280        } else {
281            None
282        }
283    }
284}
285
/// Owner-side handle used to control a running console `Worker`.
pub struct WorkerHandle {
    // Worker thread; it yields the worker's ports when stopped.
    worker_thread: WorkerThread<Vec<WorkerPort>>,
    // Sending end of the channel carrying `WorkerRequest`s to the worker.
    worker_sender: mpsc::Sender<WorkerRequest>,
    // Signaled after a request has been sent so that the worker wakes up and handles it.
    worker_event: Event,
}
291
292impl WorkerHandle {
293    pub fn new(ports: Vec<WorkerPort>) -> anyhow::Result<Self> {
294        let worker_event = Event::new().context("Event::new")?;
295        let worker_event_clone = worker_event.try_clone().context("Event::try_clone")?;
296        let (worker_sender, worker_receiver) = mpsc::channel();
297        let worker_thread = WorkerThread::start("v_console", move |kill_evt| {
298            let mut worker = Worker::new(ports, worker_receiver, worker_event_clone)
299                .expect("console Worker::new() failed");
300            if let Err(e) = worker.run(&kill_evt) {
301                error!("console worker failed: {:#}", e);
302            }
303            worker.ports
304        });
305        Ok(WorkerHandle {
306            worker_thread,
307            worker_sender,
308            worker_event,
309        })
310    }
311
312    pub fn start_queue(&mut self, idx: usize, queue: Queue) -> anyhow::Result<()> {
313        let (response_sender, response_receiver) = mpsc::sync_channel(0);
314        self.worker_sender
315            .send(WorkerRequest::StartQueue {
316                idx,
317                queue,
318                response_sender,
319            })
320            .context("mpsc::Sender::send")?;
321        self.worker_event.signal().context("Event::signal")?;
322        response_receiver.recv().context("mpsc::Receiver::recv")?
323    }
324
325    pub fn stop_queue(&mut self, idx: usize) -> anyhow::Result<Option<Queue>> {
326        let (response_sender, response_receiver) = mpsc::sync_channel(0);
327        self.worker_sender
328            .send(WorkerRequest::StopQueue {
329                idx,
330                response_sender,
331            })
332            .context("mpsc::Sender::send")?;
333        self.worker_event.signal().context("Event::signal")?;
334        response_receiver.recv().context("mpsc::Receiver::recv")
335    }
336
337    pub fn stop(self) -> Vec<WorkerPort> {
338        self.worker_thread.stop()
339    }
340}
341
342fn receiveq_idx(port_id: u32) -> Option<usize> {
343    if port_id == 0 {
344        Some(PORT0_RECEIVEQ_IDX)
345    } else {
346        PORT1_RECEIVEQ_IDX.checked_add((port_id - 1).checked_mul(2)?.try_into().ok()?)
347    }
348}
349
350fn transmitq_idx(port_id: u32) -> Option<usize> {
351    if port_id == 0 {
352        Some(PORT0_TRANSMITQ_IDX)
353    } else {
354        PORT1_TRANSMITQ_IDX.checked_add((port_id - 1).checked_mul(2)?.try_into().ok()?)
355    }
356}
357
358fn receiveq_port_id(queue_idx: usize) -> Option<u32> {
359    if queue_idx == PORT0_RECEIVEQ_IDX {
360        Some(0)
361    } else if queue_idx >= PORT1_RECEIVEQ_IDX && (queue_idx & 1) == 0 {
362        ((queue_idx - PORT1_RECEIVEQ_IDX) / 2)
363            .checked_add(1)?
364            .try_into()
365            .ok()
366    } else {
367        None
368    }
369}
370
371fn transmitq_port_id(queue_idx: usize) -> Option<u32> {
372    if queue_idx == PORT0_TRANSMITQ_IDX {
373        Some(0)
374    } else if queue_idx >= PORT1_TRANSMITQ_IDX && (queue_idx & 1) == 1 {
375        ((queue_idx - PORT1_TRANSMITQ_IDX) / 2)
376            .checked_add(1)?
377            .try_into()
378            .ok()
379    } else {
380        None
381    }
382}
383
#[cfg(test)]
mod tests {
    use super::*;

    // Expected port ID for receive queue indices 0..=9 (`None` where the index is not a port
    // receive queue): 1 is the port0 transmitq, 2 and 3 are the control queues.
    const RECEIVEQ_PORT_IDS: [Option<u32>; 10] = [
        Some(0),
        None,
        None,
        None,
        Some(1),
        None,
        Some(2),
        None,
        Some(3),
        None,
    ];

    // Expected port ID for transmit queue indices 0..=9 (`None` where the index is not a port
    // transmit queue): 0 is the port0 receiveq, 2 and 3 are the control queues, 4 is the port1
    // receiveq.
    const TRANSMITQ_PORT_IDS: [Option<u32>; 10] = [
        None,
        Some(0),
        None,
        None,
        None,
        Some(1),
        None,
        Some(2),
        None,
        Some(3),
    ];

    #[test]
    fn test_receiveq_idx() {
        let cases = [(0, 0), (1, 4), (2, 6), (3, 8)];
        for &(port_id, queue_idx) in &cases {
            assert_eq!(receiveq_idx(port_id), Some(queue_idx));
        }
    }

    #[test]
    fn test_transmitq_idx() {
        let cases = [(0, 1), (1, 5), (2, 7), (3, 9)];
        for &(port_id, queue_idx) in &cases {
            assert_eq!(transmitq_idx(port_id), Some(queue_idx));
        }
    }

    #[test]
    fn test_receiveq_port_id() {
        for (queue_idx, expected) in RECEIVEQ_PORT_IDS.iter().enumerate() {
            assert_eq!(receiveq_port_id(queue_idx), *expected);
        }
    }

    #[test]
    fn test_transmitq_port_id() {
        for (queue_idx, expected) in TRANSMITQ_PORT_IDS.iter().enumerate() {
            assert_eq!(transmitq_port_id(queue_idx), *expected);
        }
    }
}