1use std::fmt;
6use std::fmt::Display;
7use std::sync::Arc;
8use std::sync::MutexGuard;
9
10use anyhow::Context;
11use base::error;
12use base::Error as SysError;
13use base::Event;
14use base::EventType;
15use remain::sorted;
16use sync::Mutex;
17use thiserror::Error;
18use vm_memory::GuestAddress;
19use vm_memory::GuestMemory;
20
21use super::ring_buffer::RingBuffer;
22use super::ring_buffer_stop_cb::RingBufferStopCallback;
23use super::xhci_abi::TransferDescriptor;
24use crate::usb::xhci::xhci_abi::AddressedTrb;
25use crate::utils;
26use crate::utils::EventHandler;
27use crate::utils::EventLoop;
28
/// Errors that can occur while constructing a `RingBufferController`.
#[sorted]
#[derive(Error, Debug)]
pub enum Error {
    /// Registering the controller's trigger event with the event loop failed.
    #[error("failed to add event to event loop: {0}")]
    AddEvent(utils::Error),
    /// Creating the trigger `Event` itself failed.
    #[error("failed to create event: {0}")]
    CreateEvent(SysError),
}

// Module-local shorthand for results using the error type above.
type Result<T> = std::result::Result<T, Error>;
39
/// Processing state of the controller, guarded by `RingBufferController::state`.
#[derive(PartialEq, Copy, Clone, Eq)]
enum RingBufferState {
    /// Transfer descriptors are being dequeued and handed to the handler.
    Running,
    /// No descriptors are processed; `on_event` returns immediately.
    Stopped,
}
47
/// Callback interface invoked by `RingBufferController` for each transfer
/// descriptor dequeued from the ring buffer.
pub trait TransferDescriptorHandler {
    /// Processes one transfer descriptor.
    ///
    /// `trigger_event` is a clone of the controller's trigger event; signaling
    /// it causes the controller's `on_event` to run again and dequeue the next
    /// descriptor. A handler may hold on to it and signal later to process
    /// descriptors lazily.
    fn handle_transfer_descriptor(
        &self,
        descriptor: TransferDescriptor,
        trigger_event: Event,
    ) -> anyhow::Result<()>;

    /// Cancels any in-flight transfers; `callback` should be invoked (or
    /// dropped) once cancellation completes. The default implementation drops
    /// the callback immediately, i.e. there is nothing to cancel.
    fn cancel_transfers(&self, _callback: RingBufferStopCallback) {}
}
65
/// Drives a `RingBuffer`: whenever its trigger event fires on the event loop,
/// it dequeues the next transfer descriptor and passes it to the handler.
pub struct RingBufferController<T: 'static + TransferDescriptorHandler> {
    /// Name used in trace/log messages and the `Display` impl.
    name: String,
    /// Running/Stopped state; locked independently of `ring_buffer`.
    state: Mutex<RingBufferState>,
    /// The guest-memory-backed ring buffer being consumed.
    ring_buffer: Mutex<RingBuffer>,
    /// Handler that receives dequeued transfer descriptors.
    handler: Mutex<T>,
    /// Event loop that `event` is registered on; kept so `Drop` can deregister.
    event_loop: Arc<EventLoop>,
    /// Trigger event: signaled by `start()` and by handlers to schedule the
    /// next `on_event` invocation.
    event: Event,
}
76
77impl<T: 'static + TransferDescriptorHandler> Display for RingBufferController<T> {
78 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
79 write!(f, "RingBufferController `{}`", self.name)
80 }
81}
82
impl<T: Send> RingBufferController<T>
where
    T: 'static + TransferDescriptorHandler,
{
    /// Creates a controller wired to `handler`, registering its trigger event
    /// on `event_loop` so `on_event` runs whenever the event is signaled.
    ///
    /// The controller starts in the `Stopped` state; call `start()` to begin
    /// dequeuing descriptors.
    pub fn new_with_handler(
        name: String,
        mem: GuestMemory,
        event_loop: Arc<EventLoop>,
        handler: T,
    ) -> Result<Arc<RingBufferController<T>>> {
        let evt = Event::new().map_err(Error::CreateEvent)?;
        let controller = Arc::new(RingBufferController {
            name: name.clone(),
            state: Mutex::new(RingBufferState::Stopped),
            ring_buffer: Mutex::new(RingBuffer::new(name, mem)),
            handler: Mutex::new(handler),
            event_loop: event_loop.clone(),
            event: evt,
        });
        // Register a weak reference so the event loop does not keep the
        // controller alive; `Drop` removes the registration.
        let event_handler: Arc<dyn EventHandler> = controller.clone();
        event_loop
            .add_event(
                &controller.event,
                EventType::Read,
                Arc::downgrade(&event_handler),
            )
            .map_err(Error::AddEvent)?;
        Ok(controller)
    }

    // Convenience accessor for the ring buffer lock.
    fn lock_ring_buffer(&self) -> MutexGuard<RingBuffer> {
        self.ring_buffer.lock()
    }

    /// Synchronizes the ring buffer with the TRBs reported complete so far and
    /// returns the resulting (dequeue pointer, consumer cycle state) pair.
    ///
    /// NOTE(review): the name suggests this is meant to be called only after
    /// `stop()` — confirm against callers.
    pub fn get_stopped_dequeue_state(&self) -> (GuestAddress, bool) {
        let mut locked = self.lock_ring_buffer();
        locked.synchronize_with_hardware();
        (
            locked.get_dequeue_pointer(),
            locked.get_consumer_cycle_state(),
        )
    }

    /// Sets the ring buffer's dequeue pointer to `ptr`.
    pub fn set_dequeue_pointer(&self, ptr: GuestAddress) {
        xhci_trace!("{}: set_dequeue_pointer({:x})", self.name, ptr.0);
        self.lock_ring_buffer().set_dequeue_pointer(ptr);
    }

    /// Sets the ring buffer's consumer cycle state bit.
    pub fn set_consumer_cycle_state(&self, state: bool) {
        xhci_trace!("{}: set consumer cycle state: {}", self.name, state);
        self.lock_ring_buffer().set_consumer_cycle_state(state);
    }

    /// Transitions to `Running` and signals the trigger event so `on_event`
    /// starts dequeuing descriptors. A no-op if already running.
    pub fn start(&self) {
        xhci_trace!("start {}", self.name);
        let mut state = self.state.lock();
        if *state != RingBufferState::Running {
            *state = RingBufferState::Running;
            if let Err(e) = self.event.signal() {
                error!("cannot start event ring: {}", e);
            }
        }
    }

    /// Asks the handler to cancel in-flight transfers and marks the controller
    /// stopped. `callback` fires once cancellation completes.
    pub fn stop(&self, callback: RingBufferStopCallback) {
        xhci_trace!("stop {}", self.name);
        // The state lock is held across cancel_transfers so no new descriptor
        // dispatch can interleave with the stop.
        let mut state = self.state.lock();
        self.handler.lock().cancel_transfers(callback);
        *state = RingBufferState::Stopped;
    }

    /// Records `trb` as completed so a later synchronization can advance the
    /// dequeue pointer past it.
    pub fn report_completed_trb(&self, trb: &AddressedTrb) {
        self.lock_ring_buffer().complete(trb);
    }
}
171
172impl<T> Drop for RingBufferController<T>
173where
174 T: 'static + TransferDescriptorHandler,
175{
176 fn drop(&mut self) {
177 if let Err(e) = self.event_loop.remove_event_for_descriptor(&self.event) {
179 error!(
180 "cannot remove ring buffer controller from event loop: {}",
181 e
182 );
183 }
184 }
185}
186
impl<T> EventHandler for RingBufferController<T>
where
    T: 'static + TransferDescriptorHandler + Send,
{
    /// Handles one firing of the trigger event: if running, dequeues the next
    /// transfer descriptor and hands it to the handler together with a clone
    /// of the trigger event (so the handler can schedule the next round).
    fn on_event(&self) -> anyhow::Result<()> {
        // Consume the event so the event loop does not immediately retrigger.
        self.event.wait().context("cannot read from event")?;
        let mut state = self.state.lock();

        match *state {
            RingBufferState::Stopped => return Ok(()),
            RingBufferState::Running => {}
        }

        let transfer_descriptor = self
            .lock_ring_buffer()
            .dequeue_transfer_descriptor()
            .context("cannot dequeue transfer descriptor")?;

        // An empty ring means there is no more work: fall back to Stopped
        // until `start()` signals the event again.
        let transfer_descriptor = match transfer_descriptor {
            Some(t) => t,
            None => {
                *state = RingBufferState::Stopped;
                return Ok(());
            }
        };

        let event = self.event.try_clone().context("cannot clone event")?;
        self.handler
            .lock()
            .handle_transfer_descriptor(transfer_descriptor, event)
    }
}
220
#[cfg(test)]
mod tests {
    use std::mem::size_of;
    use std::sync::mpsc::channel;
    use std::sync::mpsc::Sender;

    use base::pagesize;

    use super::super::xhci_abi::LinkTrb;
    use super::super::xhci_abi::NormalTrb;
    use super::super::xhci_abi::Trb;
    use super::super::xhci_abi::TrbType;
    use super::*;

    /// Forwards each TRB's data buffer pointer over a channel and immediately
    /// retriggers the controller.
    struct TestHandler {
        sender: Sender<i32>,
    }

    impl TransferDescriptorHandler for TestHandler {
        fn handle_transfer_descriptor(
            &self,
            descriptor: TransferDescriptor,
            trigger_event: Event,
        ) -> anyhow::Result<()> {
            for atrb in &descriptor {
                assert_eq!(atrb.trb.get_trb_type().unwrap(), TrbType::Normal);
                self.sender.send(atrb.trb.get_parameter() as i32).unwrap();
            }
            trigger_event.signal().unwrap();
            Ok(())
        }
    }

    /// Reports a descriptor's TRB addresses only when the *next* descriptor is
    /// handled, emulating a handler that completes transfers lazily.
    struct TestLazyHandler {
        sender: Sender<u64>,
        processing: Mutex<Vec<u64>>,
    }

    impl TransferDescriptorHandler for TestLazyHandler {
        fn handle_transfer_descriptor(
            &self,
            descriptor: TransferDescriptor,
            trigger_event: Event,
        ) -> anyhow::Result<()> {
            // Flush the previously received descriptor's addresses first, then
            // queue the new one.
            let mut locked = self.processing.lock();
            for a in locked.iter() {
                self.sender.send(*a).unwrap();
            }
            trigger_event.signal().unwrap();
            *locked = descriptor.iter().map(|atrb| atrb.gpa).collect();
            Ok(())
        }
        // No cancel_transfers override: the trait's default (drop the callback
        // immediately) is exactly what these tests need.
    }

    /// Builds guest memory holding three ring segments at 0x100, 0x200 and
    /// 0x300, chained by link TRBs, carrying normal TRBs with data buffer
    /// pointers 1..=6. The final link TRB toggles the cycle bit and points
    /// back to 0x100.
    fn setup_mem() -> GuestMemory {
        let trb_size = size_of::<Trb>() as u64;
        let gm = GuestMemory::new(&[(GuestAddress(0), pagesize() as u64)]).unwrap();

        // Segment 1 at 0x100: TRBs 1 (chained) and 2, then a link to 0x200.
        let mut trb = NormalTrb::new();
        trb.set_trb_type(TrbType::Normal);
        trb.set_data_buffer_pointer(1);
        trb.set_chain(true);
        gm.write_obj_at_addr(trb, GuestAddress(0x100)).unwrap();

        trb.set_data_buffer_pointer(2);
        gm.write_obj_at_addr(trb, GuestAddress(0x100 + trb_size))
            .unwrap();

        let mut ltrb = LinkTrb::new();
        ltrb.set_trb_type(TrbType::Link);
        ltrb.set_ring_segment_pointer(0x200);
        gm.write_obj_at_addr(ltrb, GuestAddress(0x100 + 2 * trb_size))
            .unwrap();

        // Segment 2 at 0x200: TRBs 3 (chained) and 4, then a link to 0x300.
        trb.set_data_buffer_pointer(3);
        gm.write_obj_at_addr(trb, GuestAddress(0x200)).unwrap();

        trb.set_data_buffer_pointer(4);
        trb.set_chain(false);
        gm.write_obj_at_addr(trb, GuestAddress(0x200 + trb_size))
            .unwrap();

        ltrb.set_ring_segment_pointer(0x300);
        gm.write_obj_at_addr(ltrb, GuestAddress(0x200 + 2 * trb_size))
            .unwrap();

        // Segment 3 at 0x300: TRBs 5 (chained) and 6, then a link back to
        // 0x100 with toggle-cycle set.
        trb.set_data_buffer_pointer(5);
        trb.set_chain(true);
        gm.write_obj_at_addr(trb, GuestAddress(0x300)).unwrap();

        trb.set_data_buffer_pointer(6);
        trb.set_chain(false);
        gm.write_obj_at_addr(trb, GuestAddress(0x300 + trb_size))
            .unwrap();

        ltrb.set_ring_segment_pointer(0x100);
        ltrb.set_toggle_cycle(true);
        gm.write_obj_at_addr(ltrb, GuestAddress(0x300 + 2 * trb_size))
            .unwrap();
        gm
    }

    #[test]
    fn test_ring_buffer_controller() {
        let (tx, rx) = channel();
        let mem = setup_mem();
        let (l, j) = EventLoop::start("test".to_string(), None).unwrap();
        let l = Arc::new(l);
        let controller = RingBufferController::new_with_handler(
            "".to_string(),
            mem,
            l.clone(),
            TestHandler { sender: tx },
        )
        .unwrap();
        controller.set_dequeue_pointer(GuestAddress(0x100));
        controller.set_consumer_cycle_state(false);
        controller.start();
        // All six TRBs across the three segments are delivered in order.
        for i in 1..=6 {
            assert_eq!(rx.recv().unwrap(), i);
        }
        l.stop();
        j.join().unwrap();
    }

    #[test]
    fn synchronize_dequeue_pointer() {
        let (tx, rx) = channel();
        let mem = setup_mem();
        let (l, j) = EventLoop::start("test".to_string(), None).unwrap();
        let l = Arc::new(l);
        let controller = RingBufferController::new_with_handler(
            "".to_string(),
            mem,
            l.clone(),
            TestHandler { sender: tx },
        )
        .unwrap();
        controller.set_dequeue_pointer(GuestAddress(0x100));
        controller.set_consumer_cycle_state(false);
        controller.start();
        for i in 1..=6 {
            assert_eq!(rx.recv().unwrap(), i);
        }

        // Report the TRB at 0x210 complete; the synchronized dequeue pointer
        // should land just past it.
        let mut trb = Trb::new();
        trb.set_cycle(false);
        let atrb = AddressedTrb { trb, gpa: 0x210 };
        let null_callback = RingBufferStopCallback::new(move || {});
        controller.stop(null_callback);
        controller.report_completed_trb(&atrb);
        let (dq, cycle) = controller.get_stopped_dequeue_state();
        assert_eq!(dq.offset(), 0x220);
        assert!(!cycle);
        l.stop();
        j.join().unwrap();
    }

    #[test]
    fn synchronize_dequeue_pointer_across_link_trb() {
        let (tx, rx) = channel();
        let mem = setup_mem();
        let (l, j) = EventLoop::start("test".to_string(), None).unwrap();
        let l = Arc::new(l);
        let controller = RingBufferController::new_with_handler(
            "".to_string(),
            mem,
            l.clone(),
            TestHandler { sender: tx },
        )
        .unwrap();
        controller.set_dequeue_pointer(GuestAddress(0x100));
        controller.set_consumer_cycle_state(false);
        controller.start();

        for i in 1..=6 {
            assert_eq!(rx.recv().unwrap(), i);
        }

        // Completing the first TRB of segment 3 advances the pointer past it,
        // i.e. across the link TRB at the end of segment 2.
        let mut trb = Trb::new();
        trb.set_cycle(false);
        let atrb = AddressedTrb { trb, gpa: 0x300 };
        let null_callback = RingBufferStopCallback::new(move || {});
        controller.stop(null_callback);
        controller.report_completed_trb(&atrb);
        let (dq, cycle) = controller.get_stopped_dequeue_state();
        assert_eq!(dq.offset(), 0x310);
        assert!(!cycle);
        l.stop();
        j.join().unwrap();
    }

    #[test]
    fn synchronize_dequeue_pointer_for_lazy_handler() {
        let (tx, rx) = channel();
        let mem = setup_mem();
        let (l, j) = EventLoop::start("test".to_string(), None).unwrap();
        let l = Arc::new(l);
        let controller = RingBufferController::new_with_handler(
            "".to_string(),
            mem,
            l.clone(),
            TestLazyHandler {
                sender: tx,
                processing: Mutex::new(Vec::new()),
            },
        )
        .unwrap();
        controller.set_dequeue_pointer(GuestAddress(0x100));
        controller.set_consumer_cycle_state(false);
        controller.start();
        // The lazy handler only flushes a descriptor's addresses when the next
        // descriptor arrives, so the last descriptor stays queued.
        assert_eq!(rx.recv().unwrap(), 0x100);
        assert_eq!(rx.recv().unwrap(), 0x110);
        assert_eq!(rx.recv().unwrap(), 0x200);
        assert_eq!(rx.recv().unwrap(), 0x210);
        assert!(rx.try_recv().is_err());

        let null_callback = RingBufferStopCallback::new(move || {});
        controller.stop(null_callback);
        // Nothing was reported complete, so the synchronized dequeue pointer
        // is still at the start of the ring.
        let (dq, cycle) = controller.get_stopped_dequeue_state();
        assert_eq!(dq.offset(), 0x100);
        assert!(!cycle);
        l.stop();
        j.join().unwrap();
    }
}
466}