use std::mem::size_of;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::sync::Arc;

use base::debug;
use base::error;
use base::info;
use bit_field::Error as BitFieldError;
use remain::sorted;
use sync::Mutex;
use thiserror::Error;
use vm_memory::GuestAddress;
use vm_memory::GuestMemory;
use vm_memory::GuestMemoryError;

use super::interrupter::Interrupter;
use super::transfer_ring_controller::TransferRingController;
use super::transfer_ring_controller::TransferRingControllerError;
use super::transfer_ring_controller::TransferRingControllers;
use super::usb_hub;
use super::usb_hub::UsbHub;
use super::xhci_abi::AddressDeviceCommandTrb;
use super::xhci_abi::ConfigureEndpointCommandTrb;
use super::xhci_abi::DequeuePtr;
use super::xhci_abi::DeviceContext;
use super::xhci_abi::DeviceSlotState;
use super::xhci_abi::EndpointContext;
use super::xhci_abi::EndpointState;
use super::xhci_abi::EvaluateContextCommandTrb;
use super::xhci_abi::InputControlContext;
use super::xhci_abi::SlotContext;
use super::xhci_abi::StreamContextArray;
use super::xhci_abi::TrbCompletionCode;
use super::xhci_abi::DEVICE_CONTEXT_ENTRY_SIZE;
use super::xhci_backend_device::XhciBackendDevice;
use super::xhci_regs::valid_max_pstreams;
use super::xhci_regs::valid_slot_id;
use super::xhci_regs::MAX_PORTS;
use super::xhci_regs::MAX_SLOTS;
use crate::register_space::Register;
use crate::usb::backend::error::Error as BackendProviderError;
use crate::usb::xhci::ring_buffer_stop_cb::fallible_closure;
use crate::usb::xhci::ring_buffer_stop_cb::RingBufferStopCallback;
use crate::utils::EventLoop;
use crate::utils::FailHandle;

#[sorted]
#[derive(Error, Debug)]
pub enum Error {
    #[error("failed to allocate streams: {0}")]
    AllocStreams(BackendProviderError),
    #[error("bad device context: {0}")]
    BadDeviceContextAddr(GuestAddress),
    #[error("bad endpoint context: {0}")]
    BadEndpointContext(GuestAddress),
    #[error("device slot got a bad endpoint id: {0}")]
    BadEndpointId(u8),
    #[error("bad input context address: {0}")]
    BadInputContextAddr(GuestAddress),
    #[error("device slot got a bad port id: {0}")]
    BadPortId(u8),
    #[error("bad stream context type: {0}")]
    BadStreamContextType(u8),
    #[error("callback failed")]
    CallbackFailed,
    #[error("failed to create transfer controller: {0}")]
    CreateTransferController(TransferRingControllerError),
    #[error("failed to free streams: {0}")]
    FreeStreams(BackendProviderError),
    #[error("failed to get endpoint state: {0}")]
    GetEndpointState(BitFieldError),
    #[error("failed to get port: {0}")]
    GetPort(u8),
    #[error("failed to get slot context state: {0}")]
    GetSlotContextState(BitFieldError),
    #[error("failed to get trc: {0}")]
    GetTrc(u8),
    #[error("failed to read guest memory: {0}")]
    ReadGuestMemory(GuestMemoryError),
    #[error("failed to reset port: {0}")]
    ResetPort(BackendProviderError),
    #[error("failed to upgrade weak reference")]
    WeakReferenceUpgrade,
    #[error("failed to write guest memory: {0}")]
    WriteGuestMemory(GuestMemoryError),
}

type Result<T> = std::result::Result<T, Error>;

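/// A device slot has at most 31 endpoints (device context index 1..=31), so it owns at most 31
/// transfer ring controllers.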
pub const TRANSFER_RING_CONTROLLERS_INDEX_END: usize = 31;
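/// Exclusive upper bound for device context index (DCI) values; valid DCIs are 1..=31.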
pub const DCI_INDEX_END: u8 = (TRANSFER_RING_CONTROLLERS_INDEX_END + 1) as u8;
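/// DCI 1 is the default control endpoint; transfer endpoints start at DCI 2.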
pub const FIRST_TRANSFER_ENDPOINT_DCI: u8 = 2;
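/// Returns true if `endpoint_id` is a valid device context index (1..=31).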
fn valid_endpoint_id(endpoint_id: u8) -> bool {
    endpoint_id < DCI_INDEX_END && endpoint_id > 0
}
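/// Handle to all device slots of the controller and the hub the backend devices attach to.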
#[derive(Clone)]
pub struct DeviceSlots {
    fail_handle: Arc<dyn FailHandle>,
    hub: Arc<UsbHub>,
    slots: Vec<Arc<DeviceSlot>>,
}

impl DeviceSlots {
    pub fn new(
        fail_handle: Arc<dyn FailHandle>,
        dcbaap: Register<u64>,
        hub: Arc<UsbHub>,
        interrupter: Arc<Mutex<Interrupter>>,
        event_loop: Arc<EventLoop>,
        mem: GuestMemory,
    ) -> DeviceSlots {
        let mut slots = Vec::new();
        for slot_id in 1..=MAX_SLOTS {
            slots.push(Arc::new(DeviceSlot::new(
                slot_id,
                dcbaap.clone(),
                hub.clone(),
                interrupter.clone(),
                event_loop.clone(),
                mem.clone(),
            )));
        }
        DeviceSlots {
            fail_handle,
            hub,
            slots,
        }
    }

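    /// Returns the device slot with the given one-based `slot_id`, or `None` if the id is out of
    /// range.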
    pub fn slot(&self, slot_id: u8) -> Option<Arc<DeviceSlot>> {
        if valid_slot_id(slot_id) {
            Some(self.slots[slot_id as usize - 1].clone())
        } else {
            error!(
                "invalid slot id {}, max slot id = {}",
                slot_id, MAX_SLOTS
            );
            None
        }
    }

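    /// Resets the backend device attached to hub port `port_id`, if one is present.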
    pub fn reset_port(&self, port_id: u8) -> Result<()> {
        if let Some(port) = self.hub.get_port(port_id) {
            if let Some(backend_device) = port.backend_device().as_mut() {
                backend_device.lock().reset().map_err(Error::ResetPort)?;
            }
        }

        Ok(())
    }

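    /// Stops all transfer ring controllers, resets every slot and the host hub, and then invokes
    /// `callback`.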
    pub fn stop_all_and_reset<C: FnMut() + 'static + Send>(&self, mut callback: C) {
        info!("xhci: stopping all device slots and resetting host hub");
        let slots = self.slots.clone();
        let hub = self.hub.clone();
        let auto_callback = RingBufferStopCallback::new(fallible_closure(
            self.fail_handle.clone(),
            move || -> std::result::Result<(), usb_hub::Error> {
                for slot in &slots {
                    slot.reset();
                }
                hub.reset()?;
                callback();
                Ok(())
            },
        ));
        self.stop_all(auto_callback);
    }

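    /// Stops all transfer ring controllers of all slots; `auto_callback` runs once every
    /// controller has actually stopped.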
    pub fn stop_all(&self, auto_callback: RingBufferStopCallback) {
        for slot in &self.slots {
            slot.stop_all_trc(auto_callback.clone());
        }
    }

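    /// Disables the slot asynchronously; `cb` is invoked with the resulting completion code once
    /// the slot is disabled.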
    pub fn disable_slot<
        C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
    >(
        &self,
        slot_id: u8,
        cb: C,
    ) -> Result<()> {
        xhci_trace!("device slot {} is being disabled", slot_id);
        DeviceSlot::disable(
            self.fail_handle.clone(),
            &self.slots[slot_id as usize - 1],
            cb,
        )
    }

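    /// Resets the slot asynchronously; `cb` is invoked with the resulting completion code once
    /// the reset is done.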
    pub fn reset_slot<
        C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
    >(
        &self,
        slot_id: u8,
        cb: C,
    ) -> Result<()> {
        xhci_trace!("device slot {} is resetting", slot_id);
        DeviceSlot::reset_slot(
            self.fail_handle.clone(),
            &self.slots[slot_id as usize - 1],
            cb,
        )
    }

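    /// Stops the given endpoint of the given slot; `cb` is invoked with the resulting completion
    /// code.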
    pub fn stop_endpoint<
        C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
    >(
        &self,
        slot_id: u8,
        endpoint_id: u8,
        cb: C,
    ) -> Result<()> {
        self.slots[slot_id as usize - 1].stop_endpoint(self.fail_handle.clone(), endpoint_id, cb)
    }

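    /// Resets the given endpoint of the given slot; `cb` is invoked with the resulting completion
    /// code.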
    pub fn reset_endpoint<
        C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
    >(
        &self,
        slot_id: u8,
        endpoint_id: u8,
        cb: C,
    ) -> Result<()> {
        self.slots[slot_id as usize - 1].reset_endpoint(self.fail_handle.clone(), endpoint_id, cb)
    }
}

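/// Root hub port id assigned to a device slot. Zero means unset; valid values are 1..=MAX_PORTS.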
struct PortId(Mutex<u8>);

impl PortId {
    fn new() -> Self {
        PortId(Mutex::new(0))
    }

    fn set(&self, value: u8) -> Result<()> {
        if !(1..=MAX_PORTS).contains(&value) {
            return Err(Error::BadPortId(value));
        }
        *self.0.lock() = value;
        Ok(())
    }

    fn reset(&self) {
        *self.0.lock() = 0;
    }

    fn get(&self) -> Result<u8> {
        let val = *self.0.lock();
        if val == 0 {
            return Err(Error::BadPortId(val));
        }
        Ok(val)
    }
}

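/// State of a single device slot: its assigned port, guest device context location, and the
/// transfer ring controllers for its endpoints.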
pub struct DeviceSlot {
    slot_id: u8,
    port_id: PortId,
    dcbaap: Register<u64>,
    hub: Arc<UsbHub>,
    interrupter: Arc<Mutex<Interrupter>>,
    event_loop: Arc<EventLoop>,
    mem: GuestMemory,
    enabled: AtomicBool,
    transfer_ring_controllers: Mutex<Vec<Option<TransferRingControllers>>>,
}

impl DeviceSlot {
    pub fn new(
        slot_id: u8,
        dcbaap: Register<u64>,
        hub: Arc<UsbHub>,
        interrupter: Arc<Mutex<Interrupter>>,
        event_loop: Arc<EventLoop>,
        mem: GuestMemory,
    ) -> Self {
        let mut transfer_ring_controllers = Vec::new();
        transfer_ring_controllers.resize_with(TRANSFER_RING_CONTROLLERS_INDEX_END, || None);
        DeviceSlot {
            slot_id,
            port_id: PortId::new(),
            dcbaap,
            hub,
            interrupter,
            event_loop,
            mem,
            enabled: AtomicBool::new(false),
            transfer_ring_controllers: Mutex::new(transfer_ring_controllers),
        }
    }

    fn get_trc(&self, i: usize, stream_id: u16) -> Option<Arc<TransferRingController>> {
        let trcs = self.transfer_ring_controllers.lock();
        match &trcs[i] {
            Some(TransferRingControllers::Endpoint(trc)) => Some(trc.clone()),
            Some(TransferRingControllers::Stream(trcs)) => {
                let stream_id = stream_id as usize;
                if stream_id > 0 && stream_id <= trcs.len() {
                    Some(trcs[stream_id - 1].clone())
                } else {
                    None
                }
            }
            None => None,
        }
    }

    fn get_trcs(&self, i: usize) -> Option<TransferRingControllers> {
        let trcs = self.transfer_ring_controllers.lock();
        trcs[i].clone()
    }

    fn set_trcs(&self, i: usize, trc: Option<TransferRingControllers>) {
        let mut trcs = self.transfer_ring_controllers.lock();
        trcs[i] = trc;
    }

    fn trc_len(&self) -> usize {
        self.transfer_ring_controllers.lock().len()
    }

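    /// Handles a doorbell write for this slot. `target` is the device context index (DCI) of the
    /// rung endpoint and `stream_id` selects the stream for endpoints that use streams. Returns
    /// `Ok(false)` if the target is invalid or the endpoint has no transfer ring controller yet.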
    pub fn ring_doorbell(&self, target: u8, stream_id: u16) -> Result<bool> {
        if !valid_endpoint_id(target) {
            error!(
                "device slot {}: invalid target written to doorbell register. target: {}",
                self.slot_id, target
            );
            return Ok(false);
        }
        xhci_trace!(
            "device slot {}: ring_doorbell target = {} stream_id = {}",
            self.slot_id,
            target,
            stream_id
        );
        let endpoint_index = (target - 1) as usize;
        let transfer_ring_controller = match self.get_trc(endpoint_index, stream_id) {
            Some(tr) => tr,
            None => {
                error!("device endpoint is not initialized");
                return Ok(false);
            }
        };
        let mut context = self.get_device_context()?;
        let endpoint_state = context.endpoint_context[endpoint_index]
            .get_endpoint_state()
            .map_err(Error::GetEndpointState)?;
        if endpoint_state == EndpointState::Running || endpoint_state == EndpointState::Stopped {
            if endpoint_state == EndpointState::Stopped {
                context.endpoint_context[endpoint_index].set_endpoint_state(EndpointState::Running);
                self.set_device_context(context)?;
            }
            transfer_ring_controller.start();
        } else {
            error!("doorbell rung when endpoint state is {:?}", endpoint_state);
        }
        Ok(true)
    }

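    /// Enables the slot. Returns true if this call actually enabled it, false if it was already
    /// enabled.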
    pub fn enable(&self) -> bool {
        let was_already_enabled = self.enabled.swap(true, Ordering::SeqCst);
        !was_already_enabled
    }

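    /// Disables the slot: stops all transfer ring controllers, then updates the device context
    /// and invokes `callback` with the completion code. If the slot is not enabled, `callback`
    /// is invoked immediately with `SlotNotEnabledError`.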
    pub fn disable<C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send>(
        fail_handle: Arc<dyn FailHandle>,
        slot: &Arc<DeviceSlot>,
        mut callback: C,
    ) -> Result<()> {
        if slot.enabled.swap(false, Ordering::SeqCst) {
            let slot_weak = Arc::downgrade(slot);
            let auto_callback =
                RingBufferStopCallback::new(fallible_closure(fail_handle, move || {
                    let slot = slot_weak.upgrade().ok_or(Error::WeakReferenceUpgrade)?;
                    let mut device_context = slot.get_device_context()?;
                    device_context
                        .slot_context
                        .set_slot_state(DeviceSlotState::DisabledOrEnabled);
                    slot.set_device_context(device_context)?;
                    slot.reset();
                    debug!(
                        "device slot {}: all trc disabled, sending trb",
                        slot.slot_id
                    );
                    callback(TrbCompletionCode::Success).map_err(|_| Error::CallbackFailed)
                }));
            slot.stop_all_trc(auto_callback);
            Ok(())
        } else {
            callback(TrbCompletionCode::SlotNotEnabledError).map_err(|_| Error::CallbackFailed)
        }
    }

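    /// Handles an Address Device command TRB: copies the slot and default control endpoint
    /// contexts from the input context, records the root hub port, creates the control
    /// endpoint's transfer ring controller, and, unless block-set-address is requested, assigns
    /// the USB device address on the backend.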
    pub fn set_address(
        self: &Arc<Self>,
        trb: &AddressDeviceCommandTrb,
    ) -> Result<TrbCompletionCode> {
        if !self.enabled.load(Ordering::SeqCst) {
            error!(
                "trying to set address to a disabled device slot {}",
                self.slot_id
            );
            return Ok(TrbCompletionCode::SlotNotEnabledError);
        }
        let device_context = self.get_device_context()?;
        let state = device_context
            .slot_context
            .get_slot_state()
            .map_err(Error::GetSlotContextState)?;
        match state {
            DeviceSlotState::DisabledOrEnabled => {}
            DeviceSlotState::Default if !trb.get_block_set_address_request() => {}
            _ => {
                error!("slot {} has unexpected slot state", self.slot_id);
                return Ok(TrbCompletionCode::ContextStateError);
            }
        }

        let input_context_ptr = GuestAddress(trb.get_input_context_pointer());
        self.copy_context(input_context_ptr, 0)?;
        self.copy_context(input_context_ptr, 1)?;

        let mut device_context = self.get_device_context()?;
        let port_id = device_context.slot_context.get_root_hub_port_number();
        self.port_id.set(port_id)?;
        debug!(
            "port id {} is assigned to slot id {}",
            port_id, self.slot_id
        );

        let trc = TransferRingController::new(
            self.mem.clone(),
            self.hub.get_port(port_id).ok_or(Error::GetPort(port_id))?,
            self.event_loop.clone(),
            self.interrupter.clone(),
            self.slot_id,
            1,
            Arc::downgrade(self),
            None,
        )
        .map_err(Error::CreateTransferController)?;
        self.set_trcs(0, Some(TransferRingControllers::Endpoint(trc)));

        if trb.get_block_set_address_request() {
            device_context
                .slot_context
                .set_slot_state(DeviceSlotState::Default);
        } else {
            let port = self.hub.get_port(port_id).ok_or(Error::GetPort(port_id))?;
            match port.backend_device().as_mut() {
                Some(backend) => {
                    backend.lock().set_address(self.slot_id as u32);
                }
                None => {
                    return Ok(TrbCompletionCode::TransactionError);
                }
            }

            device_context
                .slot_context
                .set_usb_device_address(self.slot_id);
            device_context
                .slot_context
                .set_slot_state(DeviceSlotState::Addressed);
        }

        self.get_trc(0, 0)
            .ok_or(Error::GetTrc(0))?
            .set_dequeue_pointer(
                device_context.endpoint_context[0]
                    .get_tr_dequeue_pointer()
                    .get_gpa(),
            );

        self.get_trc(0, 0)
            .ok_or(Error::GetTrc(0))?
            .set_consumer_cycle_state(device_context.endpoint_context[0].get_dequeue_cycle_state());

        device_context.endpoint_context[0].set_endpoint_state(EndpointState::Running);
        self.set_device_context(device_context)?;
        Ok(TrbCompletionCode::Success)
    }

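    /// Handles a Configure Endpoint command TRB: drops and adds endpoints according to the input
    /// control context (or drops all transfer endpoints when the deconfigure flag is set) and
    /// updates the slot state.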
    pub fn configure_endpoint(
        self: &Arc<Self>,
        trb: &ConfigureEndpointCommandTrb,
    ) -> Result<TrbCompletionCode> {
        let input_control_context = if trb.get_deconfigure() {
            let mut c = InputControlContext::new();
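            // Deconfigure is treated as an input control context with no add context flags and
            // drop context flags 2..=31 set (the 0xfffffffc mask below).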
            c.set_add_context_flags(0);
            c.set_drop_context_flags(0xfffffffc);
            c
        } else {
            self.mem
                .read_obj_from_addr(GuestAddress(trb.get_input_context_pointer()))
                .map_err(Error::ReadGuestMemory)?
        };

        for device_context_index in 1..DCI_INDEX_END {
            if input_control_context.drop_context_flag(device_context_index) {
                self.drop_one_endpoint(device_context_index)?;
            }
            if input_control_context.add_context_flag(device_context_index) {
                self.copy_context(
                    GuestAddress(trb.get_input_context_pointer()),
                    device_context_index,
                )?;
                self.add_one_endpoint(device_context_index)?;
            }
        }

        if trb.get_deconfigure() {
            self.set_state(DeviceSlotState::Addressed)?;
        } else {
            self.set_state(DeviceSlotState::Configured)?;
        }
        Ok(TrbCompletionCode::Success)
    }

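    /// Handles an Evaluate Context command TRB: applies the interrupter target and max exit
    /// latency from the input slot context and the max packet size from the input control
    /// endpoint context.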
    pub fn evaluate_context(&self, trb: &EvaluateContextCommandTrb) -> Result<TrbCompletionCode> {
        if !self.enabled.load(Ordering::SeqCst) {
            return Ok(TrbCompletionCode::SlotNotEnabledError);
        }
        let input_control_context: InputControlContext = self
            .mem
            .read_obj_from_addr(GuestAddress(trb.get_input_context_pointer()))
            .map_err(Error::ReadGuestMemory)?;

        let mut device_context = self.get_device_context()?;
        if input_control_context.add_context_flag(0) {
            let input_slot_context: SlotContext = self
                .mem
                .read_obj_from_addr(GuestAddress(
                    trb.get_input_context_pointer() + DEVICE_CONTEXT_ENTRY_SIZE as u64,
                ))
                .map_err(Error::ReadGuestMemory)?;
            device_context
                .slot_context
                .set_interrupter_target(input_slot_context.get_interrupter_target());

            device_context
                .slot_context
                .set_max_exit_latency(input_slot_context.get_max_exit_latency());
        }

        if input_control_context.add_context_flag(1) {
            let ep0_context: EndpointContext = self
                .mem
                .read_obj_from_addr(GuestAddress(
                    trb.get_input_context_pointer() + 2 * DEVICE_CONTEXT_ENTRY_SIZE as u64,
                ))
                .map_err(Error::ReadGuestMemory)?;
            device_context.endpoint_context[0]
                .set_max_packet_size(ep0_context.get_max_packet_size());
        }
        self.set_device_context(device_context)?;
        Ok(TrbCompletionCode::Success)
    }

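    /// Resets the slot asynchronously: stops all transfer ring controllers, drops every transfer
    /// endpoint, and returns the slot context to the Default state before invoking `callback`.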
    pub fn reset_slot<
        C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
    >(
        fail_handle: Arc<dyn FailHandle>,
        slot: &Arc<DeviceSlot>,
        mut callback: C,
    ) -> Result<()> {
        let weak_s = Arc::downgrade(slot);
        let auto_callback =
            RingBufferStopCallback::new(fallible_closure(fail_handle, move || -> Result<()> {
                let s = weak_s.upgrade().ok_or(Error::WeakReferenceUpgrade)?;
                for i in FIRST_TRANSFER_ENDPOINT_DCI..DCI_INDEX_END {
                    s.drop_one_endpoint(i)?;
                }
                let mut ctx = s.get_device_context()?;
                ctx.slot_context.set_slot_state(DeviceSlotState::Default);
                ctx.slot_context.set_context_entries(1);
                ctx.slot_context.set_root_hub_port_number(0);
                s.set_device_context(ctx)?;
                callback(TrbCompletionCode::Success).map_err(|_| Error::CallbackFailed)?;
                Ok(())
            }));
        slot.stop_all_trc(auto_callback);
        Ok(())
    }

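    /// Stops every transfer ring controller owned by this slot; `auto_callback` runs once they
    /// have all stopped.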
    pub fn stop_all_trc(&self, auto_callback: RingBufferStopCallback) {
        for i in 0..self.trc_len() {
            if let Some(trcs) = self.get_trcs(i) {
                match trcs {
                    TransferRingControllers::Endpoint(trc) => {
                        trc.stop(auto_callback.clone());
                    }
                    TransferRingControllers::Stream(trcs) => {
                        for trc in trcs {
                            trc.stop(auto_callback.clone());
                        }
                    }
                }
            }
        }
    }

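    /// Stops the endpoint with the given DCI, writes its dequeue pointer and cycle state back to
    /// the device context (or stream context array), and marks the endpoint Stopped. `cb` is
    /// invoked with the resulting completion code.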
    pub fn stop_endpoint<
        C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
    >(
        &self,
        fail_handle: Arc<dyn FailHandle>,
        endpoint_id: u8,
        mut cb: C,
    ) -> Result<()> {
        if !valid_endpoint_id(endpoint_id) {
            error!("trb contains an invalid endpoint id");
            return cb(TrbCompletionCode::TrbError).map_err(|_| Error::CallbackFailed);
        }
        let index = endpoint_id - 1;
        let mut device_context = self.get_device_context()?;
        let endpoint_context = &mut device_context.endpoint_context[index as usize];
        match self.get_trcs(index as usize) {
            Some(TransferRingControllers::Endpoint(trc)) => {
                let auto_cb = RingBufferStopCallback::new(fallible_closure(
                    fail_handle,
                    move || -> Result<()> {
                        cb(TrbCompletionCode::Success).map_err(|_| Error::CallbackFailed)
                    },
                ));
                trc.stop(auto_cb);
                let dequeue_pointer = trc.get_dequeue_pointer();
                let dcs = trc.get_consumer_cycle_state();
                endpoint_context.set_tr_dequeue_pointer(DequeuePtr::new(dequeue_pointer));
                endpoint_context.set_dequeue_cycle_state(dcs);
            }
            Some(TransferRingControllers::Stream(trcs)) => {
                let stream_context_array_addr = endpoint_context.get_tr_dequeue_pointer().get_gpa();
                let mut stream_context_array: StreamContextArray = self
                    .mem
                    .read_obj_from_addr(stream_context_array_addr)
                    .map_err(Error::ReadGuestMemory)?;
                let auto_cb = RingBufferStopCallback::new(fallible_closure(
                    fail_handle,
                    move || -> Result<()> {
                        cb(TrbCompletionCode::Success).map_err(|_| Error::CallbackFailed)
                    },
                ));
                for (i, trc) in trcs.iter().enumerate() {
                    let dequeue_pointer = trc.get_dequeue_pointer();
                    let dcs = trc.get_consumer_cycle_state();
                    trc.stop(auto_cb.clone());
                    stream_context_array.stream_contexts[i + 1]
                        .set_tr_dequeue_pointer(DequeuePtr::new(dequeue_pointer));
                    stream_context_array.stream_contexts[i + 1].set_dequeue_cycle_state(dcs);
                }
                self.mem
                    .write_obj_at_addr(stream_context_array, stream_context_array_addr)
                    .map_err(Error::WriteGuestMemory)?;
            }
            None => {
                error!("endpoint at index {} is not started", index);
                cb(TrbCompletionCode::ContextStateError).map_err(|_| Error::CallbackFailed)?;
            }
        }
        endpoint_context.set_endpoint_state(EndpointState::Stopped);
        self.set_device_context(device_context)?;
        Ok(())
    }

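    /// Resets a halted endpoint: writes the transfer ring state back to the device context (or
    /// stream context array) and moves the endpoint to the Stopped state. `cb` is invoked with
    /// the resulting completion code; a non-halted endpoint reports ContextStateError.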
    pub fn reset_endpoint<
        C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
    >(
        &self,
        fail_handle: Arc<dyn FailHandle>,
        endpoint_id: u8,
        mut cb: C,
    ) -> Result<()> {
        if !valid_endpoint_id(endpoint_id) {
            error!("trb contains an invalid endpoint id");
            return cb(TrbCompletionCode::TrbError).map_err(|_| Error::CallbackFailed);
        }
        let index = endpoint_id - 1;
        let mut device_context = self.get_device_context()?;
        let endpoint_context = &mut device_context.endpoint_context[index as usize];
        if endpoint_context
            .get_endpoint_state()
            .map_err(Error::GetEndpointState)?
            != EndpointState::Halted
        {
            error!("endpoint at index {} is not halted", index);
            return cb(TrbCompletionCode::ContextStateError).map_err(|_| Error::CallbackFailed);
        }
        match self.get_trcs(index as usize) {
            Some(TransferRingControllers::Endpoint(trc)) => {
                let auto_cb = RingBufferStopCallback::new(fallible_closure(
                    fail_handle,
                    move || -> Result<()> {
                        cb(TrbCompletionCode::Success).map_err(|_| Error::CallbackFailed)
                    },
                ));
                trc.stop(auto_cb);
                let dequeue_pointer = trc.get_dequeue_pointer();
                let dcs = trc.get_consumer_cycle_state();
                endpoint_context.set_tr_dequeue_pointer(DequeuePtr::new(dequeue_pointer));
                endpoint_context.set_dequeue_cycle_state(dcs);
            }
            Some(TransferRingControllers::Stream(trcs)) => {
                let stream_context_array_addr = endpoint_context.get_tr_dequeue_pointer().get_gpa();
                let mut stream_context_array: StreamContextArray = self
                    .mem
                    .read_obj_from_addr(stream_context_array_addr)
                    .map_err(Error::ReadGuestMemory)?;
                let auto_cb = RingBufferStopCallback::new(fallible_closure(
                    fail_handle,
                    move || -> Result<()> {
                        cb(TrbCompletionCode::Success).map_err(|_| Error::CallbackFailed)
                    },
                ));
                for (i, trc) in trcs.iter().enumerate() {
                    let dequeue_pointer = trc.get_dequeue_pointer();
                    let dcs = trc.get_consumer_cycle_state();
                    trc.stop(auto_cb.clone());
                    stream_context_array.stream_contexts[i + 1]
                        .set_tr_dequeue_pointer(DequeuePtr::new(dequeue_pointer));
                    stream_context_array.stream_contexts[i + 1].set_dequeue_cycle_state(dcs);
                }
                self.mem
                    .write_obj_at_addr(stream_context_array, stream_context_array_addr)
                    .map_err(Error::WriteGuestMemory)?;
            }
            None => {
                error!("endpoint at index {} is not started", index);
                cb(TrbCompletionCode::ContextStateError).map_err(|_| Error::CallbackFailed)?;
            }
        }
        endpoint_context.set_endpoint_state(EndpointState::Stopped);
        self.set_device_context(device_context)?;
        Ok(())
    }

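    /// Handles a Set TR Dequeue Pointer command for the given endpoint and stream: updates both
    /// the transfer ring controller and the endpoint context with the new dequeue pointer.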
    pub fn set_tr_dequeue_ptr(
        &self,
        endpoint_id: u8,
        stream_id: u16,
        ptr: u64,
    ) -> Result<TrbCompletionCode> {
        if !valid_endpoint_id(endpoint_id) {
            error!("trb contains an invalid endpoint id");
            return Ok(TrbCompletionCode::TrbError);
        }
        let index = (endpoint_id - 1) as usize;
        match self.get_trc(index, stream_id) {
            Some(trc) => {
                trc.set_dequeue_pointer(GuestAddress(ptr));
                let mut ctx = self.get_device_context()?;
                ctx.endpoint_context[index]
                    .set_tr_dequeue_pointer(DequeuePtr::new(GuestAddress(ptr)));
                self.set_device_context(ctx)?;
                Ok(TrbCompletionCode::Success)
            }
            None => {
                error!("set tr dequeue ptr failed: no trc started for this endpoint");
                Ok(TrbCompletionCode::ContextStateError)
            }
        }
    }

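    /// Clears all transfer ring controllers, disables the slot, and resets the assigned port id.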
    fn reset(&self) {
        for i in 0..self.trc_len() {
            self.set_trcs(i, None);
        }
        debug!("resetting device slot {}!", self.slot_id);
        self.enabled.store(false, Ordering::SeqCst);
        self.port_id.reset();
    }

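    /// Builds one transfer ring controller per primary stream from the guest's stream context
    /// array at `stream_context_array_addr`. The array has 2^(max_pstreams + 1) entries; entry 0
    /// is skipped and every used entry must have stream context type 1.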
    fn create_stream_trcs(
        self: &Arc<Self>,
        stream_context_array_addr: GuestAddress,
        max_pstreams: u8,
        device_context_index: u8,
    ) -> Result<TransferRingControllers> {
        let pstreams = 1usize << (max_pstreams + 1);
        let stream_context_array: StreamContextArray = self
            .mem
            .read_obj_from_addr(stream_context_array_addr)
            .map_err(Error::ReadGuestMemory)?;
        let mut trcs = Vec::new();

        for i in 1..pstreams {
            let stream_context = &stream_context_array.stream_contexts[i];
            let context_type = stream_context.get_stream_context_type();
            if context_type != 1 {
                return Err(Error::BadStreamContextType(context_type));
            }
            let trc = TransferRingController::new(
                self.mem.clone(),
                self.hub
                    .get_port(self.port_id.get()?)
                    .ok_or(Error::GetPort(self.port_id.get()?))?,
                self.event_loop.clone(),
                self.interrupter.clone(),
                self.slot_id,
                device_context_index,
                Arc::downgrade(self),
                Some(i as u16),
            )
            .map_err(Error::CreateTransferController)?;
            trc.set_dequeue_pointer(stream_context.get_tr_dequeue_pointer().get_gpa());
            trc.set_consumer_cycle_state(stream_context.get_dequeue_cycle_state());
            trcs.push(trc);
        }
        Ok(TransferRingControllers::Stream(trcs))
    }

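    /// Creates the transfer ring controller(s) for the endpoint at `device_context_index` and
    /// marks it Running. Endpoints with MaxPStreams > 0 get one controller per stream and the
    /// backend is asked to allocate the streams; otherwise a single controller is created from
    /// the endpoint context's dequeue pointer and cycle state.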
    fn add_one_endpoint(self: &Arc<Self>, device_context_index: u8) -> Result<()> {
        xhci_trace!(
            "adding one endpoint, device context index {}",
            device_context_index
        );
        let mut device_context = self.get_device_context()?;
        let transfer_ring_index = (device_context_index - 1) as usize;
        let endpoint_context = &mut device_context.endpoint_context[transfer_ring_index];
        let max_pstreams = endpoint_context.get_max_primary_streams();
        let tr_dequeue_pointer = endpoint_context.get_tr_dequeue_pointer().get_gpa();
        let endpoint_context_addr = self
            .get_device_context_addr()?
            .unchecked_add(size_of::<SlotContext>() as u64)
            .unchecked_add(size_of::<EndpointContext>() as u64 * transfer_ring_index as u64);
        let trcs = if max_pstreams > 0 {
            if !valid_max_pstreams(max_pstreams) {
                return Err(Error::BadEndpointContext(endpoint_context_addr));
            }
            let endpoint_type = endpoint_context.get_endpoint_type();
            if endpoint_type != 2 && endpoint_type != 6 {
                return Err(Error::BadEndpointId(transfer_ring_index as u8));
            }
            if endpoint_context.get_linear_stream_array() != 1 {
                return Err(Error::BadEndpointContext(endpoint_context_addr));
            }

            let trcs =
                self.create_stream_trcs(tr_dequeue_pointer, max_pstreams, device_context_index)?;

            if let Some(port) = self.hub.get_port(self.port_id.get()?) {
                if let Some(backend_device) = port.backend_device().as_mut() {
                    let mut endpoint_address = device_context_index / 2;
                    if device_context_index % 2 == 1 {
                        endpoint_address |= 1u8 << 7;
                    }
                    let streams = 1 << (max_pstreams + 1);
                    backend_device
                        .lock()
                        .alloc_streams(endpoint_address, streams - 1)
                        .map_err(Error::AllocStreams)?;
                }
            }
            trcs
        } else {
            let trc = TransferRingController::new(
                self.mem.clone(),
                self.hub
                    .get_port(self.port_id.get()?)
                    .ok_or(Error::GetPort(self.port_id.get()?))?,
                self.event_loop.clone(),
                self.interrupter.clone(),
                self.slot_id,
                device_context_index,
                Arc::downgrade(self),
                None,
            )
            .map_err(Error::CreateTransferController)?;
            trc.set_dequeue_pointer(tr_dequeue_pointer);
            trc.set_consumer_cycle_state(endpoint_context.get_dequeue_cycle_state());
            TransferRingControllers::Endpoint(trc)
        };
        self.set_trcs(transfer_ring_index, Some(trcs));
        endpoint_context.set_endpoint_state(EndpointState::Running);
        self.set_device_context(device_context)
    }

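    /// Removes the transfer ring controller(s) for the endpoint at `device_context_index`, frees
    /// any backend streams, and marks the endpoint Disabled.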
    fn drop_one_endpoint(self: &Arc<Self>, device_context_index: u8) -> Result<()> {
        let endpoint_index = (device_context_index - 1) as usize;
        let mut device_context = self.get_device_context()?;
        let endpoint_context = &mut device_context.endpoint_context[endpoint_index];
        if endpoint_context.get_max_primary_streams() > 0 {
            if let Some(port) = self.hub.get_port(self.port_id.get()?) {
                if let Some(backend_device) = port.backend_device().as_mut() {
                    let mut endpoint_address = device_context_index / 2;
                    if device_context_index % 2 == 1 {
                        endpoint_address |= 1u8 << 7;
                    }
                    backend_device
                        .lock()
                        .free_streams(endpoint_address)
                        .map_err(Error::FreeStreams)?;
                }
            }
        }
        self.set_trcs(endpoint_index, None);
        endpoint_context.set_endpoint_state(EndpointState::Disabled);
        self.set_device_context(device_context)
    }

    fn get_device_context(&self) -> Result<DeviceContext> {
        let ctx = self
            .mem
            .read_obj_from_addr(self.get_device_context_addr()?)
            .map_err(Error::ReadGuestMemory)?;
        Ok(ctx)
    }

    fn set_device_context(&self, device_context: DeviceContext) -> Result<()> {
        self.mem
            .write_obj_at_addr(device_context, self.get_device_context_addr()?)
            .map_err(Error::WriteGuestMemory)
    }

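    /// Copies the context entry at `device_context_index` from the guest's input context into
    /// this slot's device context. The input context has one extra leading entry (the input
    /// control context), hence the `+ 1` when computing the read offset.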
    fn copy_context(
        &self,
        input_context_ptr: GuestAddress,
        device_context_index: u8,
    ) -> Result<()> {
        let ctx: EndpointContext = self
            .mem
            .read_obj_from_addr(
                input_context_ptr
                    .checked_add(
                        (device_context_index as u64 + 1) * DEVICE_CONTEXT_ENTRY_SIZE as u64,
                    )
                    .ok_or(Error::BadInputContextAddr(input_context_ptr))?,
            )
            .map_err(Error::ReadGuestMemory)?;
        xhci_trace!("copy_context {:?}", ctx);
        let device_context_ptr = self.get_device_context_addr()?;
        self.mem
            .write_obj_at_addr(
                ctx,
                device_context_ptr
                    .checked_add(device_context_index as u64 * DEVICE_CONTEXT_ENTRY_SIZE as u64)
                    .ok_or(Error::BadDeviceContextAddr(device_context_ptr))?,
            )
            .map_err(Error::WriteGuestMemory)
    }

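    /// Reads this slot's device context address from the device context base address array
    /// (DCBAA) entry indexed by `slot_id`.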
    fn get_device_context_addr(&self) -> Result<GuestAddress> {
        let addr: u64 = self
            .mem
            .read_obj_from_addr(GuestAddress(
                self.dcbaap.get_value() + size_of::<u64>() as u64 * self.slot_id as u64,
            ))
            .map_err(Error::ReadGuestMemory)?;
        Ok(GuestAddress(addr))
    }

    fn set_state(&self, state: DeviceSlotState) -> Result<()> {
        let mut ctx = self.get_device_context()?;
        ctx.slot_context.set_slot_state(state);
        self.set_device_context(ctx)
    }

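    /// Marks the endpoint with the given DCI as Halted after writing the current dequeue pointer
    /// and cycle state back to the endpoint context or stream context array.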
    pub fn halt_endpoint(&self, endpoint_id: u8) -> Result<()> {
        if !valid_endpoint_id(endpoint_id) {
            return Err(Error::BadEndpointId(endpoint_id));
        }
        let index = endpoint_id - 1;
        let mut device_context = self.get_device_context()?;
        let endpoint_context = &mut device_context.endpoint_context[index as usize];
        match self.get_trcs(index as usize) {
            Some(trcs) => match trcs {
                TransferRingControllers::Endpoint(trc) => {
                    endpoint_context
                        .set_tr_dequeue_pointer(DequeuePtr::new(trc.get_dequeue_pointer()));
                    endpoint_context.set_dequeue_cycle_state(trc.get_consumer_cycle_state());
                }
                TransferRingControllers::Stream(trcs) => {
                    let stream_context_array_addr =
                        endpoint_context.get_tr_dequeue_pointer().get_gpa();
                    let mut stream_context_array: StreamContextArray = self
                        .mem
                        .read_obj_from_addr(stream_context_array_addr)
                        .map_err(Error::ReadGuestMemory)?;
                    for (i, trc) in trcs.iter().enumerate() {
                        stream_context_array.stream_contexts[i + 1]
                            .set_tr_dequeue_pointer(DequeuePtr::new(trc.get_dequeue_pointer()));
                        stream_context_array.stream_contexts[i + 1]
                            .set_dequeue_cycle_state(trc.get_consumer_cycle_state());
                    }
                    self.mem
                        .write_obj_at_addr(stream_context_array, stream_context_array_addr)
                        .map_err(Error::WriteGuestMemory)?;
                }
            },
            None => {
                error!("trc for endpoint {} not found", endpoint_id);
                return Err(Error::BadEndpointId(endpoint_id));
            }
        }
        endpoint_context.set_endpoint_state(EndpointState::Halted);
        self.set_device_context(device_context)?;
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use std::thread;

    use base::Event;

    use super::*;
    use crate::usb::xhci::xhci_controller::XhciFailHandle;
    use crate::usb::xhci::XhciRegs;

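    /// Test fixture owning a `DeviceSlots` and the event loop thread that backs it.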
    struct TestDeviceSlots {
        pub device_slots: DeviceSlots,
        event_loop: Arc<EventLoop>,
        join_handle: Option<thread::JoinHandle<()>>,
    }

    impl TestDeviceSlots {
        fn cleanup(&mut self) {
            if let Some(join_handle) = self.join_handle.take() {
                self.event_loop.stop();
                join_handle.join().unwrap();
            }
        }
    }

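    /// Builds a `DeviceSlots` backed by dummy registers, empty guest memory, and a freshly
    /// started test event loop. Call `cleanup()` on the returned fixture to stop that thread.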
    fn setup_test_device_slots() -> TestDeviceSlots {
        let test_reg32 = register!(
            name: "test",
            ty: u32,
            offset: 0x0,
            reset_value: 0,
            guest_writeable_mask: 0x0,
            guest_write_1_to_clear_mask: 0,
        );
        let test_reg64 = register!(
            name: "test",
            ty: u64,
            offset: 0x0,
            reset_value: 0,
            guest_writeable_mask: 0x0,
            guest_write_1_to_clear_mask: 0,
        );
        let xhci_regs = XhciRegs {
            usbcmd: test_reg32.clone(),
            usbsts: test_reg32.clone(),
            dnctrl: test_reg32.clone(),
            crcr: test_reg64.clone(),
            dcbaap: test_reg64.clone(),
            config: test_reg64.clone(),
            portsc: vec![test_reg32.clone(); 16],
            doorbells: Vec::new(),
            iman: test_reg32.clone(),
            imod: test_reg32.clone(),
            erstsz: test_reg32.clone(),
            erstba: test_reg64.clone(),
            erdp: test_reg64.clone(),
        };
        let fail_handle: Arc<dyn FailHandle> = Arc::new(XhciFailHandle::new(&xhci_regs));
        let mem = GuestMemory::new(&[]).unwrap();
        let event = Event::new().unwrap();
        let interrupter = Arc::new(Mutex::new(Interrupter::new(mem.clone(), event, &xhci_regs)));
        let hub = Arc::new(UsbHub::new(&xhci_regs, interrupter.clone()));
        let (event_loop, join_handle) =
            EventLoop::start("test".to_string(), Some(fail_handle.clone())).unwrap();
        let event_loop = Arc::new(event_loop);

        let device_slots = DeviceSlots::new(
            fail_handle.clone(),
            test_reg64.clone(),
            hub,
            interrupter,
            event_loop.clone(),
            mem,
        );
        TestDeviceSlots {
            device_slots,
            event_loop,
            join_handle: Some(join_handle),
        }
    }

    #[test]
    fn valid_slot() {
        let mut test_device_slots = setup_test_device_slots();
        for i in 1..=MAX_SLOTS {
            let slot = test_device_slots.device_slots.slot(i);
            assert!(slot.is_some());
        }

        test_device_slots.cleanup();
    }

    #[test]
    fn invalid_slot() {
        let mut test_device_slots = setup_test_device_slots();
        let slot = test_device_slots.device_slots.slot(0);
        assert!(slot.is_none());
        let slot = test_device_slots.device_slots.slot(MAX_SLOTS + 1);
        assert!(slot.is_none());

        test_device_slots.cleanup();
    }

    #[test]
    fn slot_is_disabled_first() {
        let mut test_device_slots = setup_test_device_slots();
        for i in 1..=MAX_SLOTS {
            let _ = test_device_slots
                .device_slots
                .disable_slot(i, move |completion_code| {
                    assert_eq!(completion_code, TrbCompletionCode::SlotNotEnabledError);
                    Ok(())
                });
        }

        test_device_slots.cleanup();
    }

    #[test]
    fn slot_enable_disable_enable() {
        let mut test_device_slots = setup_test_device_slots();
        for i in 1..=MAX_SLOTS {
            assert!(test_device_slots.device_slots.slot(i).unwrap().enable());
            let _ = test_device_slots
                .device_slots
                .disable_slot(i, move |completion_code| {
                    assert_eq!(completion_code, TrbCompletionCode::Success);
                    Ok(())
                });
            assert!(test_device_slots.device_slots.slot(i).unwrap().enable());
        }

        test_device_slots.cleanup();
    }

    #[test]
    fn slot_enable_disable_disable() {
        let mut test_device_slots = setup_test_device_slots();
        for i in 1..=MAX_SLOTS {
            assert!(test_device_slots.device_slots.slot(i).unwrap().enable());
            let _ = test_device_slots
                .device_slots
                .disable_slot(i, move |completion_code| {
                    assert_eq!(completion_code, TrbCompletionCode::Success);
                    Ok(())
                });
            let _ = test_device_slots
                .device_slots
                .disable_slot(i, move |completion_code| {
                    assert_eq!(completion_code, TrbCompletionCode::SlotNotEnabledError);
                    Ok(())
                });
        }

        test_device_slots.cleanup();
    }

    #[test]
    fn slot_find_disabled() {
        let mut test_device_slots = setup_test_device_slots();
        for i in 1..=MAX_SLOTS {
            assert!(test_device_slots.device_slots.slot(i).unwrap().enable());
        }
        let free_slot = 5;
        let _ = test_device_slots
            .device_slots
            .disable_slot(free_slot, move |completion_code| {
                assert_eq!(completion_code, TrbCompletionCode::Success);
                Ok(())
            });
        let mut found = false;
        for i in 1..=MAX_SLOTS {
            if test_device_slots.device_slots.slot(i).unwrap().enable() {
                assert_eq!(free_slot, i);
                found = true;
                break;
            }
        }
        assert!(found);

        test_device_slots.cleanup();
    }
}