1use std::mem::size_of;
6use std::sync::atomic::AtomicBool;
7use std::sync::atomic::Ordering;
8use std::sync::Arc;
9
10use base::debug;
11use base::error;
12use base::info;
13use bit_field::Error as BitFieldError;
14use remain::sorted;
15use sync::Mutex;
16use thiserror::Error;
17use vm_memory::GuestAddress;
18use vm_memory::GuestMemory;
19use vm_memory::GuestMemoryError;
20
21use super::interrupter::Interrupter;
22use super::transfer_ring_controller::TransferRingController;
23use super::transfer_ring_controller::TransferRingControllerError;
24use super::transfer_ring_controller::TransferRingControllers;
25use super::usb_hub;
26use super::usb_hub::UsbHub;
27use super::xhci_abi::AddressDeviceCommandTrb;
28use super::xhci_abi::AddressedTrb;
29use super::xhci_abi::ConfigureEndpointCommandTrb;
30use super::xhci_abi::DequeuePtr;
31use super::xhci_abi::DeviceContext;
32use super::xhci_abi::DeviceSlotState;
33use super::xhci_abi::EndpointContext;
34use super::xhci_abi::EndpointState;
35use super::xhci_abi::EvaluateContextCommandTrb;
36use super::xhci_abi::InputControlContext;
37use super::xhci_abi::SlotContext;
38use super::xhci_abi::StreamContextArray;
39use super::xhci_abi::TrbCompletionCode;
40use super::xhci_abi::DEVICE_CONTEXT_ENTRY_SIZE;
41use super::xhci_backend_device::XhciBackendDevice;
42use super::xhci_regs::valid_max_pstreams;
43use super::xhci_regs::valid_slot_id;
44use super::xhci_regs::MAX_PORTS;
45use super::xhci_regs::MAX_SLOTS;
46use crate::register_space::Register;
47use crate::usb::backend::device::BackendDevice;
48use crate::usb::backend::error::Error as BackendProviderError;
49use crate::usb::xhci::ring_buffer_stop_cb::fallible_closure;
50use crate::usb::xhci::ring_buffer_stop_cb::RingBufferStopCallback;
51use crate::utils::EventLoop;
52use crate::utils::FailHandle;
53
/// Errors produced while operating on xHCI device slots.
#[sorted]
#[derive(Error, Debug)]
pub enum Error {
    /// The backend device failed to allocate streams for a stream endpoint.
    #[error("failed to allocate streams: {0}")]
    AllocStreams(BackendProviderError),
    /// An address computed inside a device context overflowed.
    #[error("bad device context: {0}")]
    BadDeviceContextAddr(GuestAddress),
    /// The endpoint context at this guest address has unsupported stream settings.
    #[error("bad endpoint context: {0}")]
    BadEndpointContext(GuestAddress),
    /// An endpoint id (DCI) is out of range, has no transfer ring controller,
    /// or cannot be used for streams.
    #[error("device slot get a bad endpoint id: {0}")]
    BadEndpointId(u8),
    /// An address computed inside an input context overflowed.
    #[error("bad input context address: {0}")]
    BadInputContextAddr(GuestAddress),
    /// A root hub port id is out of range, or no port id has been assigned (0).
    #[error("device slot get a bad port id: {0}")]
    BadPortId(u8),
    /// A stream context has a stream context type other than the supported one.
    #[error("bad stream context type: {0}")]
    BadStreamContextType(u8),
    /// A command completion callback returned an error.
    #[error("callback failed")]
    CallbackFailed,
    /// Creating a transfer ring controller failed.
    #[error("failed to create transfer controller: {0}")]
    CreateTransferController(TransferRingControllerError),
    /// The backend device failed to free the streams of an endpoint.
    #[error("failed to free streams: {0}")]
    FreeStreams(BackendProviderError),
    /// Reading the endpoint state field of an endpoint context failed.
    #[error("failed to get endpoint state: {0}")]
    GetEndpointState(BitFieldError),
    /// The hub has no port with this id.
    #[error("failed to get port: {0}")]
    GetPort(u8),
    /// Reading the slot state field of a slot context failed.
    #[error("failed to get slot context state: {0}")]
    GetSlotContextState(BitFieldError),
    /// No transfer ring controller exists at this endpoint index.
    #[error("failed to get trc: {0}")]
    GetTrc(u8),
    /// A guest memory read failed.
    #[error("failed to read guest memory: {0}")]
    ReadGuestMemory(GuestMemoryError),
    /// The backend device failed to reset.
    #[error("failed to reset port: {0}")]
    ResetPort(BackendProviderError),
    /// The device slot was dropped before a deferred stop callback ran.
    #[error("failed to upgrade weak reference")]
    WeakReferenceUpgrade,
    /// A guest memory write failed.
    #[error("failed to write guest memory: {0}")]
    WriteGuestMemory(GuestMemoryError),
}
94
/// Result alias used throughout this module; the error type is this module's [`Error`].
type Result<T> = std::result::Result<T, Error>;
96
/// Number of transfer ring controller entries kept per device slot; the
/// controller for endpoint DCI `n` lives at index `n - 1`.
pub const TRANSFER_RING_CONTROLLERS_INDEX_END: usize = 31;
/// Exclusive upper bound on device context indices (DCIs) that name an endpoint.
pub const DCI_INDEX_END: u8 = TRANSFER_RING_CONTROLLERS_INDEX_END as u8 + 1;
/// DCI of the first endpoint after the default control endpoint (DCI 1).
pub const FIRST_TRANSFER_ENDPOINT_DCI: u8 = 2;

/// Returns true when `endpoint_id` is a usable endpoint DCI, i.e. it lies in
/// `1..DCI_INDEX_END`.
fn valid_endpoint_id(endpoint_id: u8) -> bool {
    (1..DCI_INDEX_END).contains(&endpoint_id)
}
113
/// All device slots of the xHCI controller, plus the shared state needed to run
/// slot-level commands against them.
#[derive(Clone)]
pub struct DeviceSlots {
    // Passed to `fallible_closure` so failures inside deferred callbacks are reported.
    fail_handle: Arc<dyn FailHandle>,
    // Root hub; also reset by `stop_all_and_reset`.
    hub: Arc<UsbHub>,
    // Slot with 1-based id `n` is stored at index `n - 1`.
    slots: Vec<Arc<DeviceSlot>>,
}
120
121impl DeviceSlots {
122 pub fn new(
123 fail_handle: Arc<dyn FailHandle>,
124 dcbaap: Register<u64>,
125 hub: Arc<UsbHub>,
126 interrupter: Arc<Mutex<Interrupter>>,
127 event_loop: Arc<EventLoop>,
128 mem: GuestMemory,
129 ) -> DeviceSlots {
130 let mut slots = Vec::new();
131 for slot_id in 1..=MAX_SLOTS {
132 slots.push(Arc::new(DeviceSlot::new(
133 slot_id,
134 dcbaap.clone(),
135 hub.clone(),
136 interrupter.clone(),
137 event_loop.clone(),
138 mem.clone(),
139 )));
140 }
141 DeviceSlots {
142 fail_handle,
143 hub,
144 slots,
145 }
146 }
147
148 pub fn slot(&self, slot_id: u8) -> Option<Arc<DeviceSlot>> {
150 if valid_slot_id(slot_id) {
151 Some(self.slots[slot_id as usize - 1].clone())
152 } else {
153 error!(
154 "trying to index a wrong slot id {}, max slot = {}",
155 slot_id, MAX_SLOTS
156 );
157 None
158 }
159 }
160
161 pub fn reset_port(&self, port_id: u8) -> Result<()> {
163 if let Some(port) = self.hub.get_port(port_id) {
164 if let Some(backend_device) = port.backend_device().as_mut() {
165 backend_device.lock().reset().map_err(Error::ResetPort)?;
166 }
167 }
168
169 Ok(())
171 }
172
173 pub fn stop_all_and_reset<C: FnMut() + 'static + Send>(&self, mut callback: C) {
175 info!("xhci: stopping all device slots and resetting host hub");
176 let slots = self.slots.clone();
177 let hub = self.hub.clone();
178 let auto_callback = RingBufferStopCallback::new(fallible_closure(
179 self.fail_handle.clone(),
180 move || -> std::result::Result<(), usb_hub::Error> {
181 for slot in &slots {
182 slot.reset();
183 }
184 hub.reset()?;
185 callback();
186 Ok(())
187 },
188 ));
189 self.stop_all(auto_callback);
190 }
191
192 pub fn stop_all(&self, auto_callback: RingBufferStopCallback) {
195 for slot in &self.slots {
196 slot.stop_all_trc(auto_callback.clone());
197 }
198 }
199
200 pub fn disable_slot<
203 C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
204 >(
205 &self,
206 slot_id: u8,
207 cb: C,
208 ) -> Result<()> {
209 xhci_trace!("device slot {} is being disabled", slot_id);
210 DeviceSlot::disable(
211 self.fail_handle.clone(),
212 &self.slots[slot_id as usize - 1],
213 cb,
214 )
215 }
216
217 pub fn reset_slot<
219 C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
220 >(
221 &self,
222 slot_id: u8,
223 cb: C,
224 ) -> Result<()> {
225 xhci_trace!("device slot {} is resetting", slot_id);
226 DeviceSlot::reset_slot(
227 self.fail_handle.clone(),
228 &self.slots[slot_id as usize - 1],
229 cb,
230 )
231 }
232
233 pub fn stop_endpoint<
234 C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
235 >(
236 &self,
237 slot_id: u8,
238 endpoint_id: u8,
239 cb: C,
240 ) -> Result<()> {
241 self.slots[slot_id as usize - 1].clone().stop_endpoint(
242 self.fail_handle.clone(),
243 endpoint_id,
244 cb,
245 )
246 }
247
248 pub fn reset_endpoint<
249 C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
250 >(
251 &self,
252 slot_id: u8,
253 endpoint_id: u8,
254 cb: C,
255 ) -> Result<()> {
256 self.slots[slot_id as usize - 1].clone().reset_endpoint(
257 self.fail_handle.clone(),
258 endpoint_id,
259 cb,
260 )
261 }
262}
263
264struct PortId(Mutex<u8>);
266
267impl PortId {
268 fn new() -> Self {
269 PortId(Mutex::new(0))
270 }
271
272 fn set(&self, value: u8) -> Result<()> {
273 if !(1..=MAX_PORTS).contains(&value) {
274 return Err(Error::BadPortId(value));
275 }
276 *self.0.lock() = value;
277 Ok(())
278 }
279
280 fn reset(&self) {
281 *self.0.lock() = 0;
282 }
283
284 fn get(&self) -> Result<u8> {
285 let val = *self.0.lock();
286 if val == 0 {
287 return Err(Error::BadPortId(val));
288 }
289 Ok(val)
290 }
291}
292
/// State of a single xHCI device slot.
pub struct DeviceSlot {
    // 1-based slot id; also selects this slot's entry in the DCBAA.
    slot_id: u8,
    // Root hub port recorded by Address Device; 0 when unassigned.
    port_id: PortId,
    // DCBAAP register, read to locate this slot's device context.
    dcbaap: Register<u64>,
    hub: Arc<UsbHub>,
    interrupter: Arc<Mutex<Interrupter>>,
    event_loop: Arc<EventLoop>,
    mem: GuestMemory,
    // Set by `enable`, cleared by `disable` and `reset`.
    enabled: AtomicBool,
    // One entry per endpoint, indexed by DCI - 1; `None` when not set up.
    transfer_ring_controllers: Mutex<Vec<Option<TransferRingControllers>>>,
}
304
305impl DeviceSlot {
306 pub fn new(
308 slot_id: u8,
309 dcbaap: Register<u64>,
310 hub: Arc<UsbHub>,
311 interrupter: Arc<Mutex<Interrupter>>,
312 event_loop: Arc<EventLoop>,
313 mem: GuestMemory,
314 ) -> Self {
315 let mut transfer_ring_controllers = Vec::new();
316 transfer_ring_controllers.resize_with(TRANSFER_RING_CONTROLLERS_INDEX_END, || None);
317 DeviceSlot {
318 slot_id,
319 port_id: PortId::new(),
320 dcbaap,
321 hub,
322 interrupter,
323 event_loop,
324 mem,
325 enabled: AtomicBool::new(false),
326 transfer_ring_controllers: Mutex::new(transfer_ring_controllers),
327 }
328 }
329
330 fn get_trc(&self, i: usize, stream_id: u16) -> Option<Arc<TransferRingController>> {
331 let trcs = self.transfer_ring_controllers.lock();
332 match &trcs[i] {
333 Some(TransferRingControllers::Endpoint(trc)) => Some(trc.clone()),
334 Some(TransferRingControllers::Stream(trcs)) => {
335 let stream_id = stream_id as usize;
336 if stream_id > 0 && stream_id <= trcs.len() {
337 Some(trcs[stream_id - 1].clone())
338 } else {
339 None
340 }
341 }
342 None => None,
343 }
344 }
345
346 fn get_trcs(&self, i: usize) -> Option<TransferRingControllers> {
347 let trcs = self.transfer_ring_controllers.lock();
348 trcs[i].clone()
349 }
350
351 fn set_trcs(&self, i: usize, trc: Option<TransferRingControllers>) {
352 let mut trcs = self.transfer_ring_controllers.lock();
353 trcs[i] = trc;
354 }
355
356 fn trc_len(&self) -> usize {
357 self.transfer_ring_controllers.lock().len()
358 }
359
360 pub fn ring_doorbell(&self, target: u8, stream_id: u16) -> Result<bool> {
375 if !valid_endpoint_id(target) {
376 error!(
377 "device slot {}: Invalid target written to doorbell register. target: {}",
378 self.slot_id, target
379 );
380 return Ok(false);
381 }
382 xhci_trace!(
383 "device slot {}: ring_doorbell target = {} stream_id = {}",
384 self.slot_id,
385 target,
386 stream_id
387 );
388 let endpoint_index = (target - 1) as usize;
390 let transfer_ring_controller = match self.get_trc(endpoint_index, stream_id) {
391 Some(tr) => tr,
392 None => {
393 error!("Device endpoint is not inited");
394 return Ok(false);
395 }
396 };
397 let mut context = self.get_device_context()?;
398 let endpoint_state = context.endpoint_context[endpoint_index]
399 .get_endpoint_state()
400 .map_err(Error::GetEndpointState)?;
401 if endpoint_state == EndpointState::Running || endpoint_state == EndpointState::Stopped {
402 if endpoint_state == EndpointState::Stopped {
403 context.endpoint_context[endpoint_index].set_endpoint_state(EndpointState::Running);
404 self.set_device_context(context)?;
405 }
406 transfer_ring_controller.start();
408 } else {
409 error!("doorbell rung when endpoint state is {:?}", endpoint_state);
410 }
411 Ok(true)
412 }
413
414 pub fn enable(&self) -> bool {
418 let was_already_enabled = self.enabled.swap(true, Ordering::SeqCst);
419 !was_already_enabled
420 }
421
422 pub fn disable<C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send>(
425 fail_handle: Arc<dyn FailHandle>,
426 slot: &Arc<DeviceSlot>,
427 mut callback: C,
428 ) -> Result<()> {
429 if slot.enabled.swap(false, Ordering::SeqCst) {
430 let slot_weak = Arc::downgrade(slot);
431 let auto_callback =
432 RingBufferStopCallback::new(fallible_closure(fail_handle, move || {
433 let slot = slot_weak.upgrade().ok_or(Error::WeakReferenceUpgrade)?;
436 let mut device_context = slot.get_device_context()?;
437 device_context
438 .slot_context
439 .set_slot_state(DeviceSlotState::DisabledOrEnabled);
440 slot.set_device_context(device_context)?;
441 slot.reset();
442 debug!(
443 "device slot {}: all trc disabled, sending trb",
444 slot.slot_id
445 );
446 callback(TrbCompletionCode::Success).map_err(|_| Error::CallbackFailed)
447 }));
448 slot.stop_all_trc(auto_callback);
449 Ok(())
450 } else {
451 callback(TrbCompletionCode::SlotNotEnabledError).map_err(|_| Error::CallbackFailed)
452 }
453 }
454
455 pub fn set_address(
457 self: &Arc<Self>,
458 trb: &AddressDeviceCommandTrb,
459 ) -> Result<TrbCompletionCode> {
460 if !self.enabled.load(Ordering::SeqCst) {
461 error!(
462 "trying to set address to a disabled device slot {}",
463 self.slot_id
464 );
465 return Ok(TrbCompletionCode::SlotNotEnabledError);
466 }
467 let device_context = self.get_device_context()?;
468 let state = device_context
469 .slot_context
470 .get_slot_state()
471 .map_err(Error::GetSlotContextState)?;
472 match state {
473 DeviceSlotState::DisabledOrEnabled => {}
474 DeviceSlotState::Default if !trb.get_block_set_address_request() => {}
475 _ => {
476 error!("slot {} has unexpected slot state", self.slot_id);
477 return Ok(TrbCompletionCode::ContextStateError);
478 }
479 }
480
481 let input_context_ptr = GuestAddress(trb.get_input_context_pointer());
484 self.copy_context(input_context_ptr, 0)?;
486 self.copy_context(input_context_ptr, 1)?;
488
489 let mut device_context = self.get_device_context()?;
491 let port_id = device_context.slot_context.get_root_hub_port_number();
492 self.port_id.set(port_id)?;
493 debug!(
494 "port id {} is assigned to slot id {}",
495 port_id, self.slot_id
496 );
497
498 let trc = TransferRingController::new(
500 self.mem.clone(),
501 self.hub.get_port(port_id).ok_or(Error::GetPort(port_id))?,
502 self.event_loop.clone(),
503 self.interrupter.clone(),
504 self.slot_id,
505 1,
506 Arc::downgrade(self),
507 None,
508 )
509 .map_err(Error::CreateTransferController)?;
510 self.set_trcs(0, Some(TransferRingControllers::Endpoint(trc)));
511
512 if trb.get_block_set_address_request() {
514 device_context
515 .slot_context
516 .set_slot_state(DeviceSlotState::Default);
517 } else {
518 let port = self.hub.get_port(port_id).ok_or(Error::GetPort(port_id))?;
519 match port.backend_device().as_mut() {
520 Some(backend) => {
521 backend.lock().set_address(self.slot_id as u32);
522 }
523 None => {
524 return Ok(TrbCompletionCode::TransactionError);
525 }
526 }
527
528 device_context
529 .slot_context
530 .set_usb_device_address(self.slot_id);
531 device_context
532 .slot_context
533 .set_slot_state(DeviceSlotState::Addressed);
534 }
535
536 self.get_trc(0, 0)
538 .ok_or(Error::GetTrc(0))?
539 .set_dequeue_pointer(
540 device_context.endpoint_context[0]
541 .get_tr_dequeue_pointer()
542 .get_gpa(),
543 );
544
545 self.get_trc(0, 0)
546 .ok_or(Error::GetTrc(0))?
547 .set_consumer_cycle_state(device_context.endpoint_context[0].get_dequeue_cycle_state());
548
549 device_context.endpoint_context[0].set_endpoint_state(EndpointState::Running);
551 self.set_device_context(device_context)?;
552 Ok(TrbCompletionCode::Success)
553 }
554
555 pub fn configure_endpoint(
557 self: &Arc<Self>,
558 trb: &ConfigureEndpointCommandTrb,
559 ) -> Result<TrbCompletionCode> {
560 let input_control_context = if trb.get_deconfigure() {
561 let mut c = InputControlContext::new();
566 c.set_add_context_flags(0);
567 c.set_drop_context_flags(0xfffffffc);
568 c
569 } else {
570 self.mem
571 .read_obj_from_addr(GuestAddress(trb.get_input_context_pointer()))
572 .map_err(Error::ReadGuestMemory)?
573 };
574
575 for device_context_index in 1..DCI_INDEX_END {
576 if input_control_context.drop_context_flag(device_context_index) {
577 self.drop_one_endpoint(device_context_index)?;
578 }
579 if input_control_context.add_context_flag(device_context_index) {
580 self.copy_context(
581 GuestAddress(trb.get_input_context_pointer()),
582 device_context_index,
583 )?;
584 self.add_one_endpoint(device_context_index)?;
585 }
586 }
587
588 if trb.get_deconfigure() {
589 self.set_state(DeviceSlotState::Addressed)?;
590 } else {
591 self.set_state(DeviceSlotState::Configured)?;
592 }
593 Ok(TrbCompletionCode::Success)
594 }
595
596 pub fn evaluate_context(&self, trb: &EvaluateContextCommandTrb) -> Result<TrbCompletionCode> {
599 if !self.enabled.load(Ordering::SeqCst) {
600 return Ok(TrbCompletionCode::SlotNotEnabledError);
601 }
602 let input_control_context: InputControlContext = self
606 .mem
607 .read_obj_from_addr(GuestAddress(trb.get_input_context_pointer()))
608 .map_err(Error::ReadGuestMemory)?;
609
610 let mut device_context = self.get_device_context()?;
611 if input_control_context.add_context_flag(0) {
612 let input_slot_context: SlotContext = self
613 .mem
614 .read_obj_from_addr(GuestAddress(
615 trb.get_input_context_pointer() + DEVICE_CONTEXT_ENTRY_SIZE as u64,
616 ))
617 .map_err(Error::ReadGuestMemory)?;
618 device_context
619 .slot_context
620 .set_interrupter_target(input_slot_context.get_interrupter_target());
621
622 device_context
623 .slot_context
624 .set_max_exit_latency(input_slot_context.get_max_exit_latency());
625 }
626
627 if input_control_context.add_context_flag(1) {
630 let ep0_context: EndpointContext = self
631 .mem
632 .read_obj_from_addr(GuestAddress(
633 trb.get_input_context_pointer() + 2 * DEVICE_CONTEXT_ENTRY_SIZE as u64,
634 ))
635 .map_err(Error::ReadGuestMemory)?;
636 device_context.endpoint_context[0]
637 .set_max_packet_size(ep0_context.get_max_packet_size());
638 }
639 self.set_device_context(device_context)?;
640 Ok(TrbCompletionCode::Success)
641 }
642
643 pub fn reset_slot<
646 C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
647 >(
648 fail_handle: Arc<dyn FailHandle>,
649 slot: &Arc<DeviceSlot>,
650 mut callback: C,
651 ) -> Result<()> {
652 let weak_s = Arc::downgrade(slot);
653 let auto_callback =
654 RingBufferStopCallback::new(fallible_closure(fail_handle, move || -> Result<()> {
655 let s = weak_s.upgrade().ok_or(Error::WeakReferenceUpgrade)?;
656 for i in FIRST_TRANSFER_ENDPOINT_DCI..DCI_INDEX_END {
657 s.drop_one_endpoint(i)?;
658 }
659 let mut ctx = s.get_device_context()?;
660 ctx.slot_context.set_slot_state(DeviceSlotState::Default);
661 ctx.slot_context.set_context_entries(1);
662 ctx.slot_context.set_root_hub_port_number(0);
663 s.set_device_context(ctx)?;
664 callback(TrbCompletionCode::Success).map_err(|_| Error::CallbackFailed)?;
665 Ok(())
666 }));
667 slot.stop_all_trc(auto_callback);
668 Ok(())
669 }
670
671 pub fn stop_all_trc(&self, auto_callback: RingBufferStopCallback) {
673 for i in 0..self.trc_len() {
674 if let Some(trcs) = self.get_trcs(i) {
675 self.stop_trcs_helper(trcs, auto_callback.clone());
676 }
677 }
678 }
679
680 fn stop_trcs_helper(&self, trcs: TransferRingControllers, stop_cb: RingBufferStopCallback) {
681 match trcs {
682 TransferRingControllers::Endpoint(trc) => {
683 trc.stop(stop_cb);
684 }
685 TransferRingControllers::Stream(trcs) => {
686 for trc in trcs {
687 trc.stop(stop_cb.clone());
688 }
689 }
690 }
691 }
692
693 pub fn stop_endpoint<
695 C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
696 >(
697 self: Arc<Self>,
698 fail_handle: Arc<dyn FailHandle>,
699 endpoint_id: u8,
700 mut cb: C,
701 ) -> Result<()> {
702 if !valid_endpoint_id(endpoint_id) {
703 error!("trb indexing wrong endpoint id");
704 return cb(TrbCompletionCode::TrbError).map_err(|_| Error::CallbackFailed);
705 }
706 let index = endpoint_id - 1;
707 let trcs = match self.get_trcs(index as usize) {
708 Some(trcs) => trcs,
709 None => {
710 error!("endpoint at index {} is not started", index);
711 return cb(TrbCompletionCode::ContextStateError).map_err(|_| Error::CallbackFailed);
712 }
713 };
714
715 let slot_weak = Arc::downgrade(&self);
716 let trcs_for_cb = trcs.clone();
717 let stop_cb =
723 RingBufferStopCallback::new(fallible_closure(fail_handle, move || -> Result<()> {
724 let slot = slot_weak.upgrade().ok_or(Error::WeakReferenceUpgrade)?;
725 slot.update_contexts(endpoint_id, &trcs_for_cb, EndpointState::Stopped)?;
726 cb(TrbCompletionCode::Success).map_err(|_| Error::CallbackFailed)
727 }));
728 self.stop_trcs_helper(trcs, stop_cb);
729 Ok(())
730 }
731
732 pub fn reset_endpoint<
734 C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
735 >(
736 self: Arc<Self>,
737 fail_handle: Arc<dyn FailHandle>,
738 endpoint_id: u8,
739 mut cb: C,
740 ) -> Result<()> {
741 if !valid_endpoint_id(endpoint_id) {
742 error!("trb indexing wrong endpoint id");
743 return cb(TrbCompletionCode::TrbError).map_err(|_| Error::CallbackFailed);
744 }
745 let index = endpoint_id - 1;
746 let trcs = match self.get_trcs(index as usize) {
747 Some(trcs) => trcs,
748 None => {
749 error!("endpoint at index {} is not started", index);
750 return cb(TrbCompletionCode::ContextStateError).map_err(|_| Error::CallbackFailed);
751 }
752 };
753
754 let mut device_context = self.get_device_context()?;
755 let endpoint_context = &mut device_context.endpoint_context[index as usize];
756 if endpoint_context
757 .get_endpoint_state()
758 .map_err(Error::GetEndpointState)?
759 != EndpointState::Halted
760 {
761 error!("endpoint at index {} is not halted", index);
762 return cb(TrbCompletionCode::ContextStateError).map_err(|_| Error::CallbackFailed);
763 }
764
765 let slot_weak = Arc::downgrade(&self);
768 let trcs_for_cb = trcs.clone();
769 let stop_cb =
770 RingBufferStopCallback::new(fallible_closure(fail_handle, move || -> Result<()> {
771 let slot = slot_weak.upgrade().ok_or(Error::WeakReferenceUpgrade)?;
772 slot.update_contexts(endpoint_id, &trcs_for_cb, EndpointState::Stopped)?;
773 cb(TrbCompletionCode::Success).map_err(|_| Error::CallbackFailed)
774 }));
775 self.stop_trcs_helper(trcs, stop_cb);
776 Ok(())
777 }
778
779 pub fn set_tr_dequeue_ptr(
781 &self,
782 endpoint_id: u8,
783 stream_id: u16,
784 ptr: u64,
785 ) -> Result<TrbCompletionCode> {
786 if !valid_endpoint_id(endpoint_id) {
787 error!("trb indexing wrong endpoint id");
788 return Ok(TrbCompletionCode::TrbError);
789 }
790 let index = (endpoint_id - 1) as usize;
791 if index == 0 {
792 if let Some(port) = self.hub.get_port(self.port_id.get()?) {
793 if let Some(backend_device) = port.backend_device().as_mut() {
794 backend_device.lock().reset_control_transfer_state();
795 }
796 }
797 }
798 match self.get_trc(index, stream_id) {
799 Some(trc) => {
800 trc.set_dequeue_pointer(GuestAddress(ptr));
801 let mut ctx = self.get_device_context()?;
802 ctx.endpoint_context[index]
803 .set_tr_dequeue_pointer(DequeuePtr::new(GuestAddress(ptr)));
804 self.set_device_context(ctx)?;
805 Ok(TrbCompletionCode::Success)
806 }
807 None => {
808 error!("set tr dequeue ptr failed due to no trc started");
809 Ok(TrbCompletionCode::ContextStateError)
810 }
811 }
812 }
813
814 fn reset(&self) {
818 for i in 0..self.trc_len() {
819 self.set_trcs(i, None);
820 }
821 debug!("resetting device slot {}!", self.slot_id);
822 self.enabled.store(false, Ordering::SeqCst);
823 self.port_id.reset();
824 }
825
826 fn create_stream_trcs(
827 self: &Arc<Self>,
828 stream_context_array_addr: GuestAddress,
829 max_pstreams: u8,
830 device_context_index: u8,
831 ) -> Result<TransferRingControllers> {
832 let pstreams = 1usize << (max_pstreams + 1);
833 let stream_context_array: StreamContextArray = self
834 .mem
835 .read_obj_from_addr(stream_context_array_addr)
836 .map_err(Error::ReadGuestMemory)?;
837 let mut trcs = Vec::new();
838
839 for i in 1..pstreams {
841 let stream_context = &stream_context_array.stream_contexts[i];
842 let context_type = stream_context.get_stream_context_type();
843 if context_type != 1 {
844 return Err(Error::BadStreamContextType(context_type));
846 }
847 let trc = TransferRingController::new(
848 self.mem.clone(),
849 self.hub
850 .get_port(self.port_id.get()?)
851 .ok_or(Error::GetPort(self.port_id.get()?))?,
852 self.event_loop.clone(),
853 self.interrupter.clone(),
854 self.slot_id,
855 device_context_index,
856 Arc::downgrade(self),
857 Some(i as u16),
858 )
859 .map_err(Error::CreateTransferController)?;
860 trc.set_dequeue_pointer(stream_context.get_tr_dequeue_pointer().get_gpa());
861 trc.set_consumer_cycle_state(stream_context.get_dequeue_cycle_state());
862 trcs.push(trc);
863 }
864 Ok(TransferRingControllers::Stream(trcs))
865 }
866
867 fn add_one_endpoint(self: &Arc<Self>, device_context_index: u8) -> Result<()> {
868 xhci_trace!(
869 "adding one endpoint, device context index {}",
870 device_context_index
871 );
872 let mut device_context = self.get_device_context()?;
873 let transfer_ring_index = (device_context_index - 1) as usize;
874 let endpoint_context = &mut device_context.endpoint_context[transfer_ring_index];
875 let max_pstreams = endpoint_context.get_max_primary_streams();
876 let tr_dequeue_pointer = endpoint_context.get_tr_dequeue_pointer().get_gpa();
877 let endpoint_context_addr = self
878 .get_device_context_addr()?
879 .unchecked_add(size_of::<SlotContext>() as u64)
880 .unchecked_add(size_of::<EndpointContext>() as u64 * transfer_ring_index as u64);
881 let trcs = if max_pstreams > 0 {
882 if !valid_max_pstreams(max_pstreams) {
883 return Err(Error::BadEndpointContext(endpoint_context_addr));
884 }
885 let endpoint_type = endpoint_context.get_endpoint_type();
886 if endpoint_type != 2 && endpoint_type != 6 {
887 return Err(Error::BadEndpointId(transfer_ring_index as u8));
889 }
890 if endpoint_context.get_linear_stream_array() != 1 {
891 return Err(Error::BadEndpointContext(endpoint_context_addr));
893 }
894
895 let trcs =
896 self.create_stream_trcs(tr_dequeue_pointer, max_pstreams, device_context_index)?;
897
898 if let Some(port) = self.hub.get_port(self.port_id.get()?) {
899 if let Some(backend_device) = port.backend_device().as_mut() {
900 let mut endpoint_address = device_context_index / 2;
901 if device_context_index % 2 == 1 {
902 endpoint_address |= 1u8 << 7;
903 }
904 let streams = 1 << (max_pstreams + 1);
905 backend_device
907 .lock()
908 .alloc_streams(endpoint_address, streams - 1)
909 .map_err(Error::AllocStreams)?;
910 }
911 }
912 trcs
913 } else {
914 let trc = TransferRingController::new(
915 self.mem.clone(),
916 self.hub
917 .get_port(self.port_id.get()?)
918 .ok_or(Error::GetPort(self.port_id.get()?))?,
919 self.event_loop.clone(),
920 self.interrupter.clone(),
921 self.slot_id,
922 device_context_index,
923 Arc::downgrade(self),
924 None,
925 )
926 .map_err(Error::CreateTransferController)?;
927 trc.set_dequeue_pointer(tr_dequeue_pointer);
928 trc.set_consumer_cycle_state(endpoint_context.get_dequeue_cycle_state());
929 TransferRingControllers::Endpoint(trc)
930 };
931 self.set_trcs(transfer_ring_index, Some(trcs));
932 endpoint_context.set_endpoint_state(EndpointState::Running);
933 self.set_device_context(device_context)
934 }
935
936 fn drop_one_endpoint(self: &Arc<Self>, device_context_index: u8) -> Result<()> {
937 let endpoint_index = (device_context_index - 1) as usize;
938 let mut device_context = self.get_device_context()?;
939 let endpoint_context = &mut device_context.endpoint_context[endpoint_index];
940 if endpoint_context.get_max_primary_streams() > 0 {
941 if let Some(port) = self.hub.get_port(self.port_id.get()?) {
942 if let Some(backend_device) = port.backend_device().as_mut() {
943 let mut endpoint_address = device_context_index / 2;
944 if device_context_index % 2 == 1 {
945 endpoint_address |= 1u8 << 7;
946 }
947 backend_device
948 .lock()
949 .free_streams(endpoint_address)
950 .map_err(Error::FreeStreams)?;
951 }
952 }
953 }
954 self.set_trcs(endpoint_index, None);
955 endpoint_context.set_endpoint_state(EndpointState::Disabled);
956 self.set_device_context(device_context)
957 }
958
959 fn get_device_context(&self) -> Result<DeviceContext> {
960 let ctx = self
961 .mem
962 .read_obj_from_addr(self.get_device_context_addr()?)
963 .map_err(Error::ReadGuestMemory)?;
964 Ok(ctx)
965 }
966
967 fn set_device_context(&self, device_context: DeviceContext) -> Result<()> {
968 self.mem
969 .write_obj_at_addr(device_context, self.get_device_context_addr()?)
970 .map_err(Error::WriteGuestMemory)
971 }
972
973 fn copy_context(
974 &self,
975 input_context_ptr: GuestAddress,
976 device_context_index: u8,
977 ) -> Result<()> {
978 let ctx: EndpointContext = self
981 .mem
982 .read_obj_from_addr(
983 input_context_ptr
984 .checked_add(
985 (device_context_index as u64 + 1) * DEVICE_CONTEXT_ENTRY_SIZE as u64,
986 )
987 .ok_or(Error::BadInputContextAddr(input_context_ptr))?,
988 )
989 .map_err(Error::ReadGuestMemory)?;
990 xhci_trace!("copy_context {:?}", ctx);
991 let device_context_ptr = self.get_device_context_addr()?;
992 self.mem
993 .write_obj_at_addr(
994 ctx,
995 device_context_ptr
996 .checked_add(device_context_index as u64 * DEVICE_CONTEXT_ENTRY_SIZE as u64)
997 .ok_or(Error::BadDeviceContextAddr(device_context_ptr))?,
998 )
999 .map_err(Error::WriteGuestMemory)
1000 }
1001
1002 fn get_device_context_addr(&self) -> Result<GuestAddress> {
1003 let addr: u64 = self
1004 .mem
1005 .read_obj_from_addr(GuestAddress(
1006 self.dcbaap.get_value() + size_of::<u64>() as u64 * self.slot_id as u64,
1007 ))
1008 .map_err(Error::ReadGuestMemory)?;
1009 Ok(GuestAddress(addr))
1010 }
1011
1012 fn set_state(&self, state: DeviceSlotState) -> Result<()> {
1013 let mut ctx = self.get_device_context()?;
1014 ctx.slot_context.set_slot_state(state);
1015 self.set_device_context(ctx)
1016 }
1017
1018 pub fn halt_endpoint(self: Arc<Self>, endpoint_id: u8) -> Result<()> {
1025 if !valid_endpoint_id(endpoint_id) {
1026 return Err(Error::BadEndpointId(endpoint_id));
1027 }
1028 let index = endpoint_id - 1;
1029 let trcs = match self.get_trcs(index as usize) {
1030 Some(trcs) => trcs,
1031 None => {
1032 error!("trc for endpoint {} not found", endpoint_id);
1033 return Err(Error::BadEndpointId(endpoint_id));
1034 }
1035 };
1036
1037 self.update_contexts(endpoint_id, &trcs, EndpointState::Halted)
1038 }
1039
1040 fn update_contexts(
1041 &self,
1042 endpoint_id: u8,
1043 trcs: &TransferRingControllers,
1044 new_state: EndpointState,
1045 ) -> Result<()> {
1046 let index = (endpoint_id - 1) as usize;
1047 let mut device_context = self.get_device_context()?;
1048 let endpoint_context = &mut device_context.endpoint_context[index];
1049 match &trcs {
1050 TransferRingControllers::Endpoint(trc) => {
1051 let (dequeue_pointer, dcs) = trc.get_stopped_dequeue_state();
1052 endpoint_context.set_tr_dequeue_pointer(DequeuePtr::new(dequeue_pointer));
1053 endpoint_context.set_dequeue_cycle_state(dcs);
1054 }
1055 TransferRingControllers::Stream(trcs) => {
1056 let addr = endpoint_context.get_tr_dequeue_pointer().get_gpa();
1057 let mut array: StreamContextArray = self
1058 .mem
1059 .read_obj_from_addr(addr)
1060 .map_err(Error::ReadGuestMemory)?;
1061 for (i, trc) in trcs.iter().enumerate() {
1062 let (dequeue_pointer, dcs) = trc.get_stopped_dequeue_state();
1063 array.stream_contexts[i + 1]
1064 .set_tr_dequeue_pointer(DequeuePtr::new(dequeue_pointer));
1065 array.stream_contexts[i + 1].set_dequeue_cycle_state(dcs);
1066 }
1067 self.mem
1068 .write_obj_at_addr(array, addr)
1069 .map_err(Error::WriteGuestMemory)?;
1070 }
1071 }
1072 endpoint_context.set_endpoint_state(new_state);
1073 self.set_device_context(device_context)
1074 }
1075
1076 pub fn report_trb_completion(
1077 &self,
1078 endpoint_id: u8,
1079 stream_id: Option<u16>,
1080 trb: &AddressedTrb,
1081 ) {
1082 if !valid_endpoint_id(endpoint_id) {
1086 error!(
1087 "Ignoring TRB completion for invalid endpoint ID {}",
1088 endpoint_id
1089 );
1090 return;
1091 }
1092
1093 let index = endpoint_id - 1;
1094 let stream_id = stream_id.unwrap_or(0);
1095 match self.get_trc(index as usize, stream_id) {
1096 Some(trc) => {
1097 trc.report_completed_trb(trb);
1098 }
1099 None => {
1100 error!(
1101 "No transfer ring controller for endpoint {} stream {}",
1102 endpoint_id, stream_id
1103 );
1104 }
1105 }
1106 }
1107
1108 pub fn get_max_esit_payload(&self, endpoint_id: u8) -> Result<u32> {
1109 if !valid_endpoint_id(endpoint_id) {
1110 return Err(Error::BadEndpointId(endpoint_id));
1111 }
1112
1113 let index = endpoint_id - 1;
1114 let device_context = self.get_device_context()?;
1115 let endpoint_context = device_context.endpoint_context[index as usize];
1116
1117 let lo = endpoint_context.get_max_esit_payload_lo() as u32;
1118 let hi = endpoint_context.get_max_esit_payload_hi() as u32;
1119 Ok(hi << 16 | lo)
1120 }
1121}
1122
#[cfg(test)]
mod tests {
    use std::thread;

    use base::Event;

    use super::*;
    use crate::usb::xhci::xhci_controller::XhciFailHandle;
    use crate::usb::xhci::XhciRegs;

    /// Test fixture bundling a `DeviceSlots` with the event loop it was created with.
    ///
    /// Tests call `cleanup()` explicitly when they finish. The fixture also stops and joins
    /// the event loop thread when dropped, so a test that panics before reaching `cleanup()`
    /// does not leave the thread running.
    struct TestDeviceSlots {
        pub device_slots: DeviceSlots,
        event_loop: Arc<EventLoop>,
        // `Some` until the event loop thread has been joined.
        join_handle: Option<thread::JoinHandle<()>>,
    }

    impl TestDeviceSlots {
        /// Stops the event loop and joins its thread; later calls are no-ops.
        ///
        /// Panics if the event loop thread panicked (the `join` result is unwrapped).
        fn cleanup(&mut self) {
            if let Some(join_handle) = self.join_handle.take() {
                self.event_loop.stop();
                join_handle.join().unwrap();
            }
        }
    }

    impl Drop for TestDeviceSlots {
        fn drop(&mut self) {
            if thread::panicking() {
                // Already unwinding from a failed test. Still stop and join the event loop
                // thread, but ignore a join error: panicking again while unwinding would
                // abort the whole test binary.
                if let Some(join_handle) = self.join_handle.take() {
                    self.event_loop.stop();
                    let _ = join_handle.join();
                }
            } else {
                self.cleanup();
            }
        }
    }

    /// Builds a `DeviceSlots` backed by dummy registers, empty guest memory and a freshly
    /// started event loop.
    fn setup_test_device_slots() -> TestDeviceSlots {
        let test_reg32 = register!(
            name: "test",
            ty: u32,
            offset: 0x0,
            reset_value: 0,
            guest_writeable_mask: 0x0,
            guest_write_1_to_clear_mask: 0,
        );
        let test_reg64 = register!(
            name: "test",
            ty: u64,
            offset: 0x0,
            reset_value: 0,
            guest_writeable_mask: 0x0,
            guest_write_1_to_clear_mask: 0,
        );
        let xhci_regs = XhciRegs {
            usbcmd: test_reg32.clone(),
            usbsts: test_reg32.clone(),
            dnctrl: test_reg32.clone(),
            crcr: test_reg64.clone(),
            dcbaap: test_reg64.clone(),
            config: test_reg64.clone(),
            portsc: vec![test_reg32.clone(); 16],
            doorbells: Vec::new(),
            iman: test_reg32.clone(),
            imod: test_reg32.clone(),
            erstsz: test_reg32.clone(),
            erstba: test_reg64.clone(),
            erdp: test_reg64.clone(),
        };
        let fail_handle: Arc<dyn FailHandle> = Arc::new(XhciFailHandle::new(&xhci_regs));
        let mem = GuestMemory::new(&[]).unwrap();
        let event = Event::new().unwrap();
        let interrupter = Arc::new(Mutex::new(Interrupter::new(mem.clone(), event, &xhci_regs)));
        let hub = Arc::new(UsbHub::new(&xhci_regs, interrupter.clone()));
        let (event_loop, join_handle) =
            EventLoop::start("test".to_string(), Some(fail_handle.clone())).unwrap();
        let event_loop = Arc::new(event_loop);

        let device_slots = DeviceSlots::new(
            fail_handle.clone(),
            test_reg64.clone(),
            hub,
            interrupter,
            event_loop.clone(),
            mem,
        );
        TestDeviceSlots {
            device_slots,
            event_loop,
            join_handle: Some(join_handle),
        }
    }

    #[test]
    fn valid_slot() {
        let mut test_device_slots = setup_test_device_slots();
        for i in 1..=MAX_SLOTS {
            let slot = test_device_slots.device_slots.slot(i);
            assert!(slot.is_some());
        }

        test_device_slots.cleanup();
    }

    #[test]
    fn invalid_slot() {
        let mut test_device_slots = setup_test_device_slots();
        let slot = test_device_slots.device_slots.slot(0);
        assert!(slot.is_none());
        let slot = test_device_slots.device_slots.slot(MAX_SLOTS + 1);
        assert!(slot.is_none());

        test_device_slots.cleanup();
    }

    #[test]
    fn slot_is_disabled_first() {
        let mut test_device_slots = setup_test_device_slots();
        for i in 1..=MAX_SLOTS {
            let _ = test_device_slots
                .device_slots
                .disable_slot(i, move |completion_code| {
                    assert_eq!(completion_code, TrbCompletionCode::SlotNotEnabledError);
                    Ok(())
                });
        }

        test_device_slots.cleanup();
    }

    #[test]
    fn slot_enable_disable_enable() {
        let mut test_device_slots = setup_test_device_slots();
        for i in 1..=MAX_SLOTS {
            assert!(test_device_slots.device_slots.slot(i).unwrap().enable());
            let _ = test_device_slots
                .device_slots
                .disable_slot(i, move |completion_code| {
                    assert_eq!(completion_code, TrbCompletionCode::Success);
                    Ok(())
                });
            assert!(test_device_slots.device_slots.slot(i).unwrap().enable());
        }

        test_device_slots.cleanup();
    }

    #[test]
    fn slot_enable_disable_disable() {
        let mut test_device_slots = setup_test_device_slots();
        for i in 1..=MAX_SLOTS {
            assert!(test_device_slots.device_slots.slot(i).unwrap().enable());
            let _ = test_device_slots
                .device_slots
                .disable_slot(i, move |completion_code| {
                    assert_eq!(completion_code, TrbCompletionCode::Success);
                    Ok(())
                });
            let _ = test_device_slots
                .device_slots
                .disable_slot(i, move |completion_code| {
                    assert_eq!(completion_code, TrbCompletionCode::SlotNotEnabledError);
                    Ok(())
                });
        }

        test_device_slots.cleanup();
    }

    #[test]
    fn slot_find_disabled() {
        let mut test_device_slots = setup_test_device_slots();
        for i in 1..=MAX_SLOTS {
            assert!(test_device_slots.device_slots.slot(i).unwrap().enable());
        }
        let free_slot = 5;
        let _ = test_device_slots
            .device_slots
            .disable_slot(free_slot, move |completion_code| {
                assert_eq!(completion_code, TrbCompletionCode::Success);
                Ok(())
            });
        // Only the slot that was just disabled can be enabled again; `find` stops at the
        // first slot whose `enable()` succeeds.
        let enabled_slot = (1..=MAX_SLOTS)
            .find(|&i| test_device_slots.device_slots.slot(i).unwrap().enable());
        assert_eq!(enabled_slot, Some(free_slot));

        test_device_slots.cleanup();
    }
}