devices/usb/xhci/
ring_buffer_controller.rs1use std::fmt;
6use std::fmt::Display;
7use std::sync::Arc;
8use std::sync::MutexGuard;
9
10use anyhow::Context;
11use base::debug;
12use base::error;
13use base::info;
14use base::Error as SysError;
15use base::Event;
16use base::EventType;
17use remain::sorted;
18use sync::Mutex;
19use thiserror::Error;
20use vm_memory::GuestAddress;
21use vm_memory::GuestMemory;
22
23use super::ring_buffer::RingBuffer;
24use super::ring_buffer_stop_cb::RingBufferStopCallback;
25use super::xhci_abi::TransferDescriptor;
26use crate::utils;
27use crate::utils::EventHandler;
28use crate::utils::EventLoop;
29
/// Errors produced by this module.
#[sorted]
#[derive(Error, Debug)]
pub enum Error {
    /// Registering the controller's event with the event loop failed.
    #[error("failed to add event to event loop: {0}")]
    AddEvent(utils::Error),
    /// Creating the controller's event failed.
    #[error("failed to create event: {0}")]
    CreateEvent(SysError),
}

/// Result type used by this module's fallible functions.
type Result<T> = std::result::Result<T, Error>;
40
/// Lifecycle state of a `RingBufferController`.
///
/// `Debug` is derived so the state can be formatted in diagnostics and compared with
/// `assert_eq!`.
#[derive(Debug, PartialEq, Copy, Clone, Eq)]
enum RingBufferState {
    /// Each time the controller's event fires, one transfer descriptor is dequeued and handed
    /// to the handler.
    Running,
    /// A stop was requested and the handler asked to finish it asynchronously; the next event
    /// moves the controller to `Stopped` and clears the pending stop callbacks.
    Stopping,
    /// Nothing is dequeued when the event fires. This is the initial state, and it is also
    /// entered when the ring has no more transfer descriptors.
    Stopped,
}
51
/// Consumer of the transfer descriptors that a `RingBufferController` dequeues from its ring.
pub trait TransferDescriptorHandler {
    /// Processes one transfer descriptor taken from the ring.
    ///
    /// `complete_event` is a clone of the controller's own event. Signaling it makes the
    /// controller's event handler run again, which (while running) dequeues the next transfer
    /// descriptor. An error returned here is propagated out of the controller's event handler.
    ///
    /// The controller calls this while holding both its state lock and the lock around the
    /// handler. NOTE(review): implementations therefore presumably must not call back into the
    /// controller synchronously (e.g. `start`/`stop`) — confirm.
    fn handle_transfer_descriptor(
        &self,
        descriptor: TransferDescriptor,
        complete_event: Event,
    ) -> anyhow::Result<()>;

    /// Called when the controller is asked to stop while it is not already stopped.
    ///
    /// This is called with the controller's state lock held.
    ///
    /// * `true` (the default): the controller enters the stopping state and keeps the stop
    ///   callback. The next time its event fires, the controller becomes stopped and the
    ///   callback is cleared.
    /// * `false`: the controller is marked stopped immediately and the callback is dropped
    ///   right away.
    fn stop(&self) -> bool {
        true
    }
}
76
/// Drives a transfer ring.
///
/// Each time its event fires while running, the controller dequeues one transfer descriptor
/// from its ring buffer and passes it to a [`TransferDescriptorHandler`].
pub struct RingBufferController<T: 'static + TransferDescriptorHandler> {
    /// Name used in log/trace messages and in `Display`; also passed to the ring buffer.
    name: String,
    /// Current lifecycle state.
    state: Mutex<RingBufferState>,
    /// Callbacks handed to `stop` that are waiting for the controller to reach the stopped
    /// state; the list is cleared once it does.
    stop_callback: Mutex<Vec<RingBufferStopCallback>>,
    /// Ring from which transfer descriptors are dequeued.
    ring_buffer: Mutex<RingBuffer>,
    /// Receives each dequeued transfer descriptor.
    handler: Mutex<T>,
    /// Event loop that `event` is registered with; used to unregister it on drop.
    event_loop: Arc<EventLoop>,
    /// Signaled by `start` to begin processing. A clone is handed to the handler as its
    /// completion event with every transfer descriptor.
    event: Event,
}
88
89impl<T: 'static + TransferDescriptorHandler> Display for RingBufferController<T> {
90 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
91 write!(f, "RingBufferController `{}`", self.name)
92 }
93}
94
95impl<T: Send> RingBufferController<T>
96where
97 T: 'static + TransferDescriptorHandler,
98{
99 pub fn new_with_handler(
101 name: String,
102 mem: GuestMemory,
103 event_loop: Arc<EventLoop>,
104 handler: T,
105 ) -> Result<Arc<RingBufferController<T>>> {
106 let evt = Event::new().map_err(Error::CreateEvent)?;
107 let controller = Arc::new(RingBufferController {
108 name: name.clone(),
109 state: Mutex::new(RingBufferState::Stopped),
110 stop_callback: Mutex::new(Vec::new()),
111 ring_buffer: Mutex::new(RingBuffer::new(name, mem)),
112 handler: Mutex::new(handler),
113 event_loop: event_loop.clone(),
114 event: evt,
115 });
116 let event_handler: Arc<dyn EventHandler> = controller.clone();
117 event_loop
118 .add_event(
119 &controller.event,
120 EventType::Read,
121 Arc::downgrade(&event_handler),
122 )
123 .map_err(Error::AddEvent)?;
124 Ok(controller)
125 }
126
127 fn lock_ring_buffer(&self) -> MutexGuard<RingBuffer> {
128 self.ring_buffer.lock()
129 }
130
131 pub fn get_dequeue_pointer(&self) -> GuestAddress {
133 self.lock_ring_buffer().get_dequeue_pointer()
134 }
135
136 pub fn set_dequeue_pointer(&self, ptr: GuestAddress) {
138 xhci_trace!("{}: set_dequeue_pointer({:x})", self.name, ptr.0);
139 self.lock_ring_buffer().set_dequeue_pointer(ptr);
141 }
142
143 pub fn get_consumer_cycle_state(&self) -> bool {
145 self.lock_ring_buffer().get_consumer_cycle_state()
146 }
147
148 pub fn set_consumer_cycle_state(&self, state: bool) {
150 xhci_trace!("{}: set consumer cycle state: {}", self.name, state);
151 self.lock_ring_buffer().set_consumer_cycle_state(state);
153 }
154
155 pub fn start(&self) {
157 xhci_trace!("start {}", self.name);
158 let mut state = self.state.lock();
159 if *state != RingBufferState::Running {
160 *state = RingBufferState::Running;
161 if let Err(e) = self.event.signal() {
162 error!("cannot start event ring: {}", e);
163 }
164 }
165 }
166
167 pub fn stop(&self, callback: RingBufferStopCallback) {
169 xhci_trace!("stop {}", self.name);
170 let mut state = self.state.lock();
171 if *state == RingBufferState::Stopped {
172 info!("xhci: {} is already stopped", self.name);
173 return;
174 }
175 if self.handler.lock().stop() {
176 *state = RingBufferState::Stopping;
177 self.stop_callback.lock().push(callback);
178 } else {
179 *state = RingBufferState::Stopped;
180 }
181 }
182}
183
184impl<T> Drop for RingBufferController<T>
185where
186 T: 'static + TransferDescriptorHandler,
187{
188 fn drop(&mut self) {
189 if let Err(e) = self.event_loop.remove_event_for_descriptor(&self.event) {
191 error!(
192 "cannot remove ring buffer controller from event loop: {}",
193 e
194 );
195 }
196 }
197}
198
199impl<T> EventHandler for RingBufferController<T>
200where
201 T: 'static + TransferDescriptorHandler + Send,
202{
203 fn on_event(&self) -> anyhow::Result<()> {
204 self.event.wait().context("cannot read from event")?;
206 let mut state = self.state.lock();
207
208 match *state {
209 RingBufferState::Stopped => return Ok(()),
210 RingBufferState::Stopping => {
211 debug!("xhci: {}: stopping ring buffer controller", self.name);
212 *state = RingBufferState::Stopped;
213 self.stop_callback.lock().clear();
214 return Ok(());
215 }
216 RingBufferState::Running => {}
217 }
218
219 let transfer_descriptor = self
220 .lock_ring_buffer()
221 .dequeue_transfer_descriptor()
222 .context("cannot dequeue transfer descriptor")?;
223
224 let transfer_descriptor = match transfer_descriptor {
225 Some(t) => t,
226 None => {
227 *state = RingBufferState::Stopped;
228 self.stop_callback.lock().clear();
229 return Ok(());
230 }
231 };
232
233 let event = self.event.try_clone().context("cannot clone event")?;
234 self.handler
235 .lock()
236 .handle_transfer_descriptor(transfer_descriptor, event)
237 }
238}
239
#[cfg(test)]
mod tests {
    use std::mem::size_of;
    use std::sync::mpsc::channel;
    use std::sync::mpsc::Sender;
    use std::time::Duration;

    use base::pagesize;

    use super::super::xhci_abi::LinkTrb;
    use super::super::xhci_abi::NormalTrb;
    use super::super::xhci_abi::Trb;
    use super::super::xhci_abi::TrbType;
    use super::*;

    /// Upper bound on how long the test waits for each forwarded TRB.
    ///
    /// The channel's sender lives inside the controller, which the test keeps alive. A plain
    /// `recv()` would therefore block forever if the controller stopped delivering TRBs,
    /// instead of failing the test.
    const RECV_TIMEOUT: Duration = Duration::from_secs(10);

    /// Handler that forwards the parameter of every TRB it receives over a channel.
    struct TestHandler {
        sender: Sender<i32>,
    }

    impl TransferDescriptorHandler for TestHandler {
        fn handle_transfer_descriptor(
            &self,
            descriptor: TransferDescriptor,
            complete_event: Event,
        ) -> anyhow::Result<()> {
            for atrb in descriptor {
                assert_eq!(atrb.trb.get_trb_type().unwrap(), TrbType::Normal);
                self.sender.send(atrb.trb.get_parameter() as i32).unwrap();
            }
            // Signaling the completion event makes the controller dequeue the next descriptor.
            complete_event.signal().unwrap();
            Ok(())
        }
    }

    /// Writes a ring of three segments, joined by link TRBs, into guest memory:
    ///
    /// ```text
    ///  0x100      0x200      0x300
    ///  trb 1      trb 3      trb 5
    ///  trb 2      trb 4      trb 6
    ///  link ->    link ->    link -> 0x100
    /// ```
    ///
    /// Each normal TRB's data buffer pointer is its number. TRBs 1-3 and 5 have the chain bit
    /// set, so the ring holds two transfer descriptors:
    /// * TRBs 1-4, crossing the first link;
    /// * TRBs 5-6.
    fn setup_mem() -> GuestMemory {
        let trb_size = size_of::<Trb>() as u64;
        let gm = GuestMemory::new(&[(GuestAddress(0), pagesize() as u64)]).unwrap();

        // First segment.
        let mut trb = NormalTrb::new();
        trb.set_trb_type(TrbType::Normal);
        trb.set_data_buffer_pointer(1);
        trb.set_chain(true);
        gm.write_obj_at_addr(trb, GuestAddress(0x100)).unwrap();

        trb.set_data_buffer_pointer(2);
        gm.write_obj_at_addr(trb, GuestAddress(0x100 + trb_size))
            .unwrap();

        let mut ltrb = LinkTrb::new();
        ltrb.set_trb_type(TrbType::Link);
        ltrb.set_ring_segment_pointer(0x200);
        gm.write_obj_at_addr(ltrb, GuestAddress(0x100 + 2 * trb_size))
            .unwrap();

        // Second segment: TRB 4 ends the first transfer descriptor.
        trb.set_data_buffer_pointer(3);
        gm.write_obj_at_addr(trb, GuestAddress(0x200)).unwrap();

        trb.set_data_buffer_pointer(4);
        trb.set_chain(false);
        gm.write_obj_at_addr(trb, GuestAddress(0x200 + trb_size))
            .unwrap();

        ltrb.set_ring_segment_pointer(0x300);
        gm.write_obj_at_addr(ltrb, GuestAddress(0x200 + 2 * trb_size))
            .unwrap();

        // Third segment: second transfer descriptor, then link back to the start.
        trb.set_data_buffer_pointer(5);
        trb.set_chain(true);
        gm.write_obj_at_addr(trb, GuestAddress(0x300)).unwrap();

        trb.set_data_buffer_pointer(6);
        trb.set_chain(false);
        gm.write_obj_at_addr(trb, GuestAddress(0x300 + trb_size))
            .unwrap();

        ltrb.set_ring_segment_pointer(0x100);
        gm.write_obj_at_addr(ltrb, GuestAddress(0x300 + 2 * trb_size))
            .unwrap();
        gm
    }

    #[test]
    fn test_ring_buffer_controller() {
        let (tx, rx) = channel();
        let mem = setup_mem();
        let (l, j) = EventLoop::start("test".to_string(), None).unwrap();
        let l = Arc::new(l);
        let controller = RingBufferController::new_with_handler(
            "".to_string(),
            mem,
            l.clone(),
            TestHandler { sender: tx },
        )
        .unwrap();
        controller.set_dequeue_pointer(GuestAddress(0x100));
        controller.set_consumer_cycle_state(false);
        controller.start();
        // TRBs must arrive in ring order across both transfer descriptors and all segments.
        for expected in 1..=6 {
            assert_eq!(rx.recv_timeout(RECV_TIMEOUT).unwrap(), expected);
        }
        l.stop();
        j.join().unwrap();
    }
}