1use std::mem::size_of;
6use std::sync::atomic::AtomicBool;
7use std::sync::atomic::Ordering;
8use std::sync::Arc;
9
10use base::debug;
11use base::error;
12use base::info;
13use bit_field::Error as BitFieldError;
14use remain::sorted;
15use sync::Mutex;
16use thiserror::Error;
17use vm_memory::GuestAddress;
18use vm_memory::GuestMemory;
19use vm_memory::GuestMemoryError;
20
21use super::interrupter::Error as InterrupterError;
22use super::interrupter::Interrupter;
23use super::transfer_ring_controller::TransferRingController;
24use super::transfer_ring_controller::TransferRingControllerError;
25use super::transfer_ring_controller::TransferRingControllers;
26use super::usb_hub;
27use super::usb_hub::UsbHub;
28use super::xhci_abi::AddressDeviceCommandTrb;
29use super::xhci_abi::ConfigureEndpointCommandTrb;
30use super::xhci_abi::DequeuePtr;
31use super::xhci_abi::DeviceContext;
32use super::xhci_abi::DeviceSlotState;
33use super::xhci_abi::EndpointContext;
34use super::xhci_abi::EndpointState;
35use super::xhci_abi::EvaluateContextCommandTrb;
36use super::xhci_abi::InputControlContext;
37use super::xhci_abi::SlotContext;
38use super::xhci_abi::StreamContextArray;
39use super::xhci_abi::TrbCompletionCode;
40use super::xhci_abi::DEVICE_CONTEXT_ENTRY_SIZE;
41use super::xhci_backend_device::XhciBackendDevice;
42use super::xhci_regs::valid_max_pstreams;
43use super::xhci_regs::valid_slot_id;
44use super::xhci_regs::MAX_PORTS;
45use super::xhci_regs::MAX_SLOTS;
46use crate::register_space::Register;
47use crate::usb::backend::error::Error as BackendProviderError;
48use crate::usb::xhci::ring_buffer_stop_cb::fallible_closure;
49use crate::usb::xhci::ring_buffer_stop_cb::RingBufferStopCallback;
50use crate::utils::EventLoop;
51use crate::utils::FailHandle;
52
/// Errors produced by device-slot operations.
///
/// Variants are kept in alphabetical order, enforced by `#[sorted]`;
/// do not reorder them manually.
#[sorted]
#[derive(Error, Debug)]
pub enum Error {
    #[error("failed to allocate streams: {0}")]
    AllocStreams(BackendProviderError),
    #[error("bad device context: {0}")]
    BadDeviceContextAddr(GuestAddress),
    #[error("bad endpoint context: {0}")]
    BadEndpointContext(GuestAddress),
    #[error("device slot get a bad endpoint id: {0}")]
    BadEndpointId(u8),
    #[error("bad input context address: {0}")]
    BadInputContextAddr(GuestAddress),
    #[error("device slot get a bad port id: {0}")]
    BadPortId(u8),
    #[error("bad stream context type: {0}")]
    BadStreamContextType(u8),
    #[error("callback failed")]
    CallbackFailed,
    #[error("failed to create transfer controller: {0}")]
    CreateTransferController(TransferRingControllerError),
    #[error("failed to send Force Stopped Event: {0}")]
    ForceStoppedEvent(InterrupterError),
    #[error("failed to free streams: {0}")]
    FreeStreams(BackendProviderError),
    #[error("failed to get endpoint state: {0}")]
    GetEndpointState(BitFieldError),
    #[error("failed to get port: {0}")]
    GetPort(u8),
    #[error("failed to get slot context state: {0}")]
    GetSlotContextState(BitFieldError),
    #[error("failed to get trc: {0}")]
    GetTrc(u8),
    #[error("failed to read guest memory: {0}")]
    ReadGuestMemory(GuestMemoryError),
    #[error("failed to reset port: {0}")]
    ResetPort(BackendProviderError),
    #[error("failed to upgrade weak reference")]
    WeakReferenceUpgrade,
    #[error("failed to write guest memory: {0}")]
    WriteGuestMemory(GuestMemoryError),
}
95
/// Local result alias using this module's [`Error`] type.
type Result<T> = std::result::Result<T, Error>;
97
/// Number of transfer-ring-controller slots kept per device slot; one per
/// device context index in `1..DCI_INDEX_END`.
pub const TRANSFER_RING_CONTROLLERS_INDEX_END: usize = 31;
/// Exclusive upper bound for a device context index (DCI).
pub const DCI_INDEX_END: u8 = (TRANSFER_RING_CONTROLLERS_INDEX_END + 1) as u8;
/// The first DCI that refers to a transfer endpoint (DCI 1 is endpoint 0).
pub const FIRST_TRANSFER_ENDPOINT_DCI: u8 = 2;

/// Returns true when `endpoint_id` is a usable device context index,
/// i.e. in the range `1..DCI_INDEX_END`.
fn valid_endpoint_id(endpoint_id: u8) -> bool {
    (1..DCI_INDEX_END).contains(&endpoint_id)
}
114
/// Shared handle to every device slot of the xHCI controller.
#[derive(Clone)]
pub struct DeviceSlots {
    // Used by deferred ring-stop callbacks to report fatal failures.
    fail_handle: Arc<dyn FailHandle>,
    // Hub used to look up ports and their attached backend devices.
    hub: Arc<UsbHub>,
    // Slot ids are 1-based: slot `n` is stored at index `n - 1`.
    slots: Vec<Arc<DeviceSlot>>,
}
121
122impl DeviceSlots {
123 pub fn new(
124 fail_handle: Arc<dyn FailHandle>,
125 dcbaap: Register<u64>,
126 hub: Arc<UsbHub>,
127 interrupter: Arc<Mutex<Interrupter>>,
128 event_loop: Arc<EventLoop>,
129 mem: GuestMemory,
130 ) -> DeviceSlots {
131 let mut slots = Vec::new();
132 for slot_id in 1..=MAX_SLOTS {
133 slots.push(Arc::new(DeviceSlot::new(
134 slot_id,
135 dcbaap.clone(),
136 hub.clone(),
137 interrupter.clone(),
138 event_loop.clone(),
139 mem.clone(),
140 )));
141 }
142 DeviceSlots {
143 fail_handle,
144 hub,
145 slots,
146 }
147 }
148
149 pub fn slot(&self, slot_id: u8) -> Option<Arc<DeviceSlot>> {
151 if valid_slot_id(slot_id) {
152 Some(self.slots[slot_id as usize - 1].clone())
153 } else {
154 error!(
155 "trying to index a wrong slot id {}, max slot = {}",
156 slot_id, MAX_SLOTS
157 );
158 None
159 }
160 }
161
162 pub fn reset_port(&self, port_id: u8) -> Result<()> {
164 if let Some(port) = self.hub.get_port(port_id) {
165 if let Some(backend_device) = port.backend_device().as_mut() {
166 backend_device.lock().reset().map_err(Error::ResetPort)?;
167 }
168 }
169
170 Ok(())
172 }
173
    /// Stops all transfer rings; once they have stopped, resets every slot
    /// and the hub, then runs `callback`.
    pub fn stop_all_and_reset<C: FnMut() + 'static + Send>(&self, mut callback: C) {
        info!("xhci: stopping all device slots and resetting host hub");
        let slots = self.slots.clone();
        let hub = self.hub.clone();
        // NOTE(review): the closure appears to run only after every ring has
        // reported stopped (via RingBufferStopCallback) — confirm against
        // ring_buffer_stop_cb. Failures inside it are routed to fail_handle
        // by fallible_closure.
        let auto_callback = RingBufferStopCallback::new(fallible_closure(
            self.fail_handle.clone(),
            move || -> std::result::Result<(), usb_hub::Error> {
                for slot in &slots {
                    slot.reset();
                }
                hub.reset()?;
                callback();
                Ok(())
            },
        ));
        self.stop_all(auto_callback);
    }
192
193 pub fn stop_all(&self, auto_callback: RingBufferStopCallback) {
196 for slot in &self.slots {
197 slot.stop_all_trc(auto_callback.clone());
198 }
199 }
200
201 pub fn disable_slot<
204 C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
205 >(
206 &self,
207 slot_id: u8,
208 cb: C,
209 ) -> Result<()> {
210 xhci_trace!("device slot {} is being disabled", slot_id);
211 DeviceSlot::disable(
212 self.fail_handle.clone(),
213 &self.slots[slot_id as usize - 1],
214 cb,
215 )
216 }
217
218 pub fn reset_slot<
220 C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
221 >(
222 &self,
223 slot_id: u8,
224 cb: C,
225 ) -> Result<()> {
226 xhci_trace!("device slot {} is resetting", slot_id);
227 DeviceSlot::reset_slot(
228 self.fail_handle.clone(),
229 &self.slots[slot_id as usize - 1],
230 cb,
231 )
232 }
233
234 pub fn stop_endpoint<
235 C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
236 >(
237 &self,
238 slot_id: u8,
239 endpoint_id: u8,
240 cb: C,
241 ) -> Result<()> {
242 self.slots[slot_id as usize - 1].stop_endpoint(self.fail_handle.clone(), endpoint_id, cb)
243 }
244
245 pub fn reset_endpoint<
246 C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
247 >(
248 &self,
249 slot_id: u8,
250 endpoint_id: u8,
251 cb: C,
252 ) -> Result<()> {
253 self.slots[slot_id as usize - 1].reset_endpoint(self.fail_handle.clone(), endpoint_id, cb)
254 }
255}
256
257struct PortId(Mutex<u8>);
259
260impl PortId {
261 fn new() -> Self {
262 PortId(Mutex::new(0))
263 }
264
265 fn set(&self, value: u8) -> Result<()> {
266 if !(1..=MAX_PORTS).contains(&value) {
267 return Err(Error::BadPortId(value));
268 }
269 *self.0.lock() = value;
270 Ok(())
271 }
272
273 fn reset(&self) {
274 *self.0.lock() = 0;
275 }
276
277 fn get(&self) -> Result<u8> {
278 let val = *self.0.lock();
279 if val == 0 {
280 return Err(Error::BadPortId(val));
281 }
282 Ok(val)
283 }
284}
285
286pub struct DeviceSlot {
287 slot_id: u8,
288 port_id: PortId, dcbaap: Register<u64>,
290 hub: Arc<UsbHub>,
291 interrupter: Arc<Mutex<Interrupter>>,
292 event_loop: Arc<EventLoop>,
293 mem: GuestMemory,
294 enabled: AtomicBool,
295 transfer_ring_controllers: Mutex<Vec<Option<TransferRingControllers>>>,
296}
297
298impl DeviceSlot {
299 pub fn new(
301 slot_id: u8,
302 dcbaap: Register<u64>,
303 hub: Arc<UsbHub>,
304 interrupter: Arc<Mutex<Interrupter>>,
305 event_loop: Arc<EventLoop>,
306 mem: GuestMemory,
307 ) -> Self {
308 let mut transfer_ring_controllers = Vec::new();
309 transfer_ring_controllers.resize_with(TRANSFER_RING_CONTROLLERS_INDEX_END, || None);
310 DeviceSlot {
311 slot_id,
312 port_id: PortId::new(),
313 dcbaap,
314 hub,
315 interrupter,
316 event_loop,
317 mem,
318 enabled: AtomicBool::new(false),
319 transfer_ring_controllers: Mutex::new(transfer_ring_controllers),
320 }
321 }
322
323 fn get_trc(&self, i: usize, stream_id: u16) -> Option<Arc<TransferRingController>> {
324 let trcs = self.transfer_ring_controllers.lock();
325 match &trcs[i] {
326 Some(TransferRingControllers::Endpoint(trc)) => Some(trc.clone()),
327 Some(TransferRingControllers::Stream(trcs)) => {
328 let stream_id = stream_id as usize;
329 if stream_id > 0 && stream_id <= trcs.len() {
330 Some(trcs[stream_id - 1].clone())
331 } else {
332 None
333 }
334 }
335 None => None,
336 }
337 }
338
339 fn get_trcs(&self, i: usize) -> Option<TransferRingControllers> {
340 let trcs = self.transfer_ring_controllers.lock();
341 trcs[i].clone()
342 }
343
344 fn set_trcs(&self, i: usize, trc: Option<TransferRingControllers>) {
345 let mut trcs = self.transfer_ring_controllers.lock();
346 trcs[i] = trc;
347 }
348
    /// Number of controller slots (always TRANSFER_RING_CONTROLLERS_INDEX_END).
    fn trc_len(&self) -> usize {
        self.transfer_ring_controllers.lock().len()
    }
352
    /// Handles a doorbell write for this slot. `target` is the device
    /// context index (1..=31), `stream_id` selects a stream for stream
    /// endpoints. Returns Ok(false) for invalid/uninitialized targets so the
    /// register write itself is not treated as fatal.
    pub fn ring_doorbell(&self, target: u8, stream_id: u16) -> Result<bool> {
        if !valid_endpoint_id(target) {
            error!(
                "device slot {}: Invalid target written to doorbell register. target: {}",
                self.slot_id, target
            );
            return Ok(false);
        }
        xhci_trace!(
            "device slot {}: ring_doorbell target = {} stream_id = {}",
            self.slot_id,
            target,
            stream_id
        );
        // DCI `target` maps to controller index `target - 1`.
        let endpoint_index = (target - 1) as usize;
        let transfer_ring_controller = match self.get_trc(endpoint_index, stream_id) {
            Some(tr) => tr,
            None => {
                error!("Device endpoint is not inited");
                return Ok(false);
            }
        };
        let mut context = self.get_device_context()?;
        let endpoint_state = context.endpoint_context[endpoint_index]
            .get_endpoint_state()
            .map_err(Error::GetEndpointState)?;
        if endpoint_state == EndpointState::Running || endpoint_state == EndpointState::Stopped {
            if endpoint_state == EndpointState::Stopped {
                // A doorbell on a stopped endpoint transitions it to Running.
                context.endpoint_context[endpoint_index].set_endpoint_state(EndpointState::Running);
                self.set_device_context(context)?;
            }
            transfer_ring_controller.start();
        } else {
            // Ringing in any other state (e.g. Halted) is ignored with a log.
            error!("doorbell rung when endpoint state is {:?}", endpoint_state);
        }
        Ok(true)
    }
406
407 pub fn enable(&self) -> bool {
409 let was_already_enabled = self.enabled.swap(true, Ordering::SeqCst);
410 if was_already_enabled {
411 error!("device slot is already enabled");
412 }
413 !was_already_enabled
414 }
415
    /// Disables `slot`: stops all its rings, then (asynchronously) marks the
    /// slot context DisabledOrEnabled, resets the slot, and reports Success
    /// through `callback`. If the slot was not enabled, reports
    /// SlotNotEnabledError immediately.
    pub fn disable<C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send>(
        fail_handle: Arc<dyn FailHandle>,
        slot: &Arc<DeviceSlot>,
        mut callback: C,
    ) -> Result<()> {
        if slot.enabled.load(Ordering::SeqCst) {
            // Hold only a weak reference in the deferred closure so the slot
            // is not kept alive by a pending stop callback.
            let slot_weak = Arc::downgrade(slot);
            let auto_callback =
                RingBufferStopCallback::new(fallible_closure(fail_handle, move || {
                    let slot = slot_weak.upgrade().ok_or(Error::WeakReferenceUpgrade)?;
                    let mut device_context = slot.get_device_context()?;
                    device_context
                        .slot_context
                        .set_slot_state(DeviceSlotState::DisabledOrEnabled);
                    slot.set_device_context(device_context)?;
                    slot.reset();
                    debug!(
                        "device slot {}: all trc disabled, sending trb",
                        slot.slot_id
                    );
                    callback(TrbCompletionCode::Success).map_err(|_| Error::CallbackFailed)
                }));
            slot.stop_all_trc(auto_callback);
            Ok(())
        } else {
            callback(TrbCompletionCode::SlotNotEnabledError).map_err(|_| Error::CallbackFailed)
        }
    }
448
    /// Handles an Address Device command TRB for this slot.
    ///
    /// Copies the slot and ep0 contexts from the guest's input context,
    /// binds the slot to its root-hub port, creates the ep0 transfer ring
    /// controller, and (unless BSR is set) addresses the backend device.
    /// Returns the TRB completion code to post back to the guest.
    pub fn set_address(
        self: &Arc<Self>,
        trb: &AddressDeviceCommandTrb,
    ) -> Result<TrbCompletionCode> {
        if !self.enabled.load(Ordering::SeqCst) {
            error!(
                "trying to set address to a disabled device slot {}",
                self.slot_id
            );
            return Ok(TrbCompletionCode::SlotNotEnabledError);
        }
        let device_context = self.get_device_context()?;
        let state = device_context
            .slot_context
            .get_slot_state()
            .map_err(Error::GetSlotContextState)?;
        // Valid states: Disabled/Enabled always; Default only when the
        // Block Set Address Request (BSR) flag is clear.
        match state {
            DeviceSlotState::DisabledOrEnabled => {}
            DeviceSlotState::Default if !trb.get_block_set_address_request() => {}
            _ => {
                error!("slot {} has unexpected slot state", self.slot_id);
                return Ok(TrbCompletionCode::ContextStateError);
            }
        }

        // Copy the slot context (entry 0) and ep0 context (entry 1) from the
        // guest-provided input context into the device context.
        let input_context_ptr = GuestAddress(trb.get_input_context_pointer());
        self.copy_context(input_context_ptr, 0)?;
        self.copy_context(input_context_ptr, 1)?;

        // Re-read the device context now that it holds the copied entries.
        let mut device_context = self.get_device_context()?;
        let port_id = device_context.slot_context.get_root_hub_port_number();
        self.port_id.set(port_id)?;
        debug!(
            "port id {} is assigned to slot id {}",
            port_id, self.slot_id
        );

        // Create the transfer ring controller for endpoint 0 (DCI 1).
        let trc = TransferRingController::new(
            self.mem.clone(),
            self.hub.get_port(port_id).ok_or(Error::GetPort(port_id))?,
            self.event_loop.clone(),
            self.interrupter.clone(),
            self.slot_id,
            1,
            Arc::downgrade(self),
            None,
        )
        .map_err(Error::CreateTransferController)?;
        self.set_trcs(0, Some(TransferRingControllers::Endpoint(trc)));

        if trb.get_block_set_address_request() {
            // BSR set: only move the slot to Default, do not address the device.
            device_context
                .slot_context
                .set_slot_state(DeviceSlotState::Default);
        } else {
            let port = self.hub.get_port(port_id).ok_or(Error::GetPort(port_id))?;
            match port.backend_device().as_mut() {
                Some(backend) => {
                    backend.lock().set_address(self.slot_id as u32);
                }
                None => {
                    return Ok(TrbCompletionCode::TransactionError);
                }
            }

            device_context
                .slot_context
                .set_usb_device_address(self.slot_id);
            device_context
                .slot_context
                .set_slot_state(DeviceSlotState::Addressed);
        }

        // Seed the ep0 ring's dequeue pointer and cycle state from the
        // (just copied) ep0 endpoint context.
        self.get_trc(0, 0)
            .ok_or(Error::GetTrc(0))?
            .set_dequeue_pointer(
                device_context.endpoint_context[0]
                    .get_tr_dequeue_pointer()
                    .get_gpa(),
            );

        self.get_trc(0, 0)
            .ok_or(Error::GetTrc(0))?
            .set_consumer_cycle_state(device_context.endpoint_context[0].get_dequeue_cycle_state());

        device_context.endpoint_context[0].set_endpoint_state(EndpointState::Running);
        self.set_device_context(device_context)?;
        Ok(TrbCompletionCode::Success)
    }
548
    /// Handles a Configure Endpoint command TRB: drops and/or adds endpoints
    /// per the input control context flags, then moves the slot to
    /// Configured (or back to Addressed for a deconfigure).
    pub fn configure_endpoint(
        self: &Arc<Self>,
        trb: &ConfigureEndpointCommandTrb,
    ) -> Result<TrbCompletionCode> {
        let input_control_context = if trb.get_deconfigure() {
            // Deconfigure: synthesize a control context that drops every
            // transfer endpoint (DCIs 2..=31; bits 0 and 1 stay clear).
            let mut c = InputControlContext::new();
            c.set_add_context_flags(0);
            c.set_drop_context_flags(0xfffffffc);
            c
        } else {
            self.mem
                .read_obj_from_addr(GuestAddress(trb.get_input_context_pointer()))
                .map_err(Error::ReadGuestMemory)?
        };

        // Process drops before adds for each DCI, as a context may be both
        // dropped and re-added in one command.
        for device_context_index in 1..DCI_INDEX_END {
            if input_control_context.drop_context_flag(device_context_index) {
                self.drop_one_endpoint(device_context_index)?;
            }
            if input_control_context.add_context_flag(device_context_index) {
                self.copy_context(
                    GuestAddress(trb.get_input_context_pointer()),
                    device_context_index,
                )?;
                self.add_one_endpoint(device_context_index)?;
            }
        }

        if trb.get_deconfigure() {
            self.set_state(DeviceSlotState::Addressed)?;
        } else {
            self.set_state(DeviceSlotState::Configured)?;
        }
        Ok(TrbCompletionCode::Success)
    }
589
    /// Handles an Evaluate Context command TRB: selectively updates the slot
    /// context (interrupter target, max exit latency) and/or the ep0 context
    /// (max packet size) from the guest's input context.
    pub fn evaluate_context(&self, trb: &EvaluateContextCommandTrb) -> Result<TrbCompletionCode> {
        if !self.enabled.load(Ordering::SeqCst) {
            return Ok(TrbCompletionCode::SlotNotEnabledError);
        }
        let input_control_context: InputControlContext = self
            .mem
            .read_obj_from_addr(GuestAddress(trb.get_input_context_pointer()))
            .map_err(Error::ReadGuestMemory)?;

        let mut device_context = self.get_device_context()?;
        // Add-flag bit 0: input slot context (entry 1 of the input context).
        if input_control_context.add_context_flag(0) {
            let input_slot_context: SlotContext = self
                .mem
                .read_obj_from_addr(GuestAddress(
                    trb.get_input_context_pointer() + DEVICE_CONTEXT_ENTRY_SIZE as u64,
                ))
                .map_err(Error::ReadGuestMemory)?;
            device_context
                .slot_context
                .set_interrupter_target(input_slot_context.get_interrupter_target());

            device_context
                .slot_context
                .set_max_exit_latency(input_slot_context.get_max_exit_latency());
        }

        // Add-flag bit 1: ep0 context (entry 2 of the input context).
        if input_control_context.add_context_flag(1) {
            let ep0_context: EndpointContext = self
                .mem
                .read_obj_from_addr(GuestAddress(
                    trb.get_input_context_pointer() + 2 * DEVICE_CONTEXT_ENTRY_SIZE as u64,
                ))
                .map_err(Error::ReadGuestMemory)?;
            device_context.endpoint_context[0]
                .set_max_packet_size(ep0_context.get_max_packet_size());
        }
        self.set_device_context(device_context)?;
        Ok(TrbCompletionCode::Success)
    }
636
    /// Resets `slot`: stops all rings, then (asynchronously) drops every
    /// transfer endpoint, returns the slot context to Default with one
    /// context entry and no port, and reports Success through `callback`.
    pub fn reset_slot<
        C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
    >(
        fail_handle: Arc<dyn FailHandle>,
        slot: &Arc<DeviceSlot>,
        mut callback: C,
    ) -> Result<()> {
        // Weak reference: a pending stop callback must not keep the slot alive.
        let weak_s = Arc::downgrade(slot);
        let auto_callback =
            RingBufferStopCallback::new(fallible_closure(fail_handle, move || -> Result<()> {
                let s = weak_s.upgrade().ok_or(Error::WeakReferenceUpgrade)?;
                // Drop only transfer endpoints (DCI >= 2); ep0 stays.
                for i in FIRST_TRANSFER_ENDPOINT_DCI..DCI_INDEX_END {
                    s.drop_one_endpoint(i)?;
                }
                let mut ctx = s.get_device_context()?;
                ctx.slot_context.set_slot_state(DeviceSlotState::Default);
                ctx.slot_context.set_context_entries(1);
                ctx.slot_context.set_root_hub_port_number(0);
                s.set_device_context(ctx)?;
                callback(TrbCompletionCode::Success).map_err(|_| Error::CallbackFailed)?;
                Ok(())
            }));
        slot.stop_all_trc(auto_callback);
        Ok(())
    }
664
665 pub fn stop_all_trc(&self, auto_callback: RingBufferStopCallback) {
667 for i in 0..self.trc_len() {
668 if let Some(trcs) = self.get_trcs(i) {
669 match trcs {
670 TransferRingControllers::Endpoint(trc) => {
671 trc.stop(auto_callback.clone());
672 }
673 TransferRingControllers::Stream(trcs) => {
674 for trc in trcs {
675 trc.stop(auto_callback.clone());
676 }
677 }
678 }
679 }
680 }
681 }
682
    /// Posts a transfer event with completion code StoppedLengthInvalid for
    /// the given endpoint, pointing at `dequeue_ptr`.
    ///
    /// NOTE(review): "stoppped" is a typo, kept because stop_endpoint() calls
    /// it by this name; rename both together in a follow-up.
    fn force_stoppped_event(&self, slot_id: u8, endpoint_id: u8, dequeue_ptr: u64) -> Result<()> {
        self.interrupter
            .lock()
            .send_transfer_event_trb(
                TrbCompletionCode::StoppedLengthInvalid,
                dequeue_ptr,
                0,
                false,
                slot_id,
                endpoint_id,
            )
            .map_err(Error::ForceStoppedEvent)
    }
696
    /// Stops the given endpoint: halts its ring(s), saves the current
    /// dequeue pointer / cycle state back into the guest contexts, marks the
    /// endpoint Stopped, and reports Success (or an error code) via `cb`.
    pub fn stop_endpoint<
        C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
    >(
        &self,
        fail_handle: Arc<dyn FailHandle>,
        endpoint_id: u8,
        mut cb: C,
    ) -> Result<()> {
        if !valid_endpoint_id(endpoint_id) {
            error!("trb indexing wrong endpoint id");
            return cb(TrbCompletionCode::TrbError).map_err(|_| Error::CallbackFailed);
        }
        // DCI `endpoint_id` maps to controller/context index `endpoint_id - 1`.
        let index = endpoint_id - 1;
        let mut device_context = self.get_device_context()?;
        let endpoint_context = &mut device_context.endpoint_context[index as usize];
        match self.get_trcs(index as usize) {
            Some(TransferRingControllers::Endpoint(trc)) => {
                // Success is reported only after the ring actually stops.
                let auto_cb = RingBufferStopCallback::new(fallible_closure(
                    fail_handle,
                    move || -> Result<()> {
                        cb(TrbCompletionCode::Success).map_err(|_| Error::CallbackFailed)
                    },
                ));
                trc.stop(auto_cb);
                // Persist the ring position into the endpoint context.
                let dequeue_pointer = trc.get_dequeue_pointer();
                let dcs = trc.get_consumer_cycle_state();
                endpoint_context.set_tr_dequeue_pointer(DequeuePtr::new(dequeue_pointer));
                endpoint_context.set_dequeue_cycle_state(dcs);
                self.force_stoppped_event(self.slot_id, endpoint_id, dequeue_pointer.offset())?;
            }
            Some(TransferRingControllers::Stream(trcs)) => {
                // For stream endpoints, the endpoint context's dequeue field
                // holds the stream context array address instead.
                let stream_context_array_addr = endpoint_context.get_tr_dequeue_pointer().get_gpa();
                let mut stream_context_array: StreamContextArray = self
                    .mem
                    .read_obj_from_addr(stream_context_array_addr)
                    .map_err(Error::ReadGuestMemory)?;
                let auto_cb = RingBufferStopCallback::new(fallible_closure(
                    fail_handle,
                    move || -> Result<()> {
                        cb(TrbCompletionCode::Success).map_err(|_| Error::CallbackFailed)
                    },
                ));
                // trcs[i] corresponds to stream context i + 1 (stream id 0 is
                // reserved), matching create_stream_trcs.
                for (i, trc) in trcs.iter().enumerate() {
                    let dequeue_pointer = trc.get_dequeue_pointer();
                    let dcs = trc.get_consumer_cycle_state();
                    trc.stop(auto_cb.clone());
                    stream_context_array.stream_contexts[i + 1]
                        .set_tr_dequeue_pointer(DequeuePtr::new(dequeue_pointer));
                    stream_context_array.stream_contexts[i + 1].set_dequeue_cycle_state(dcs);
                }
                self.mem
                    .write_obj_at_addr(stream_context_array, stream_context_array_addr)
                    .map_err(Error::WriteGuestMemory)?;
            }
            None => {
                error!("endpoint at index {} is not started", index);
                cb(TrbCompletionCode::ContextStateError).map_err(|_| Error::CallbackFailed)?;
            }
        }
        // The endpoint context is marked Stopped in all branches, including
        // the not-started one above.
        endpoint_context.set_endpoint_state(EndpointState::Stopped);
        self.set_device_context(device_context)?;
        Ok(())
    }
761
    /// Resets a halted endpoint: requires the endpoint context to be in the
    /// Halted state, stops its ring(s), saves the dequeue pointer / cycle
    /// state back into the guest contexts, marks the endpoint Stopped, and
    /// reports the outcome via `cb`.
    pub fn reset_endpoint<
        C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
    >(
        &self,
        fail_handle: Arc<dyn FailHandle>,
        endpoint_id: u8,
        mut cb: C,
    ) -> Result<()> {
        if !valid_endpoint_id(endpoint_id) {
            error!("trb indexing wrong endpoint id");
            return cb(TrbCompletionCode::TrbError).map_err(|_| Error::CallbackFailed);
        }
        // DCI `endpoint_id` maps to controller/context index `endpoint_id - 1`.
        let index = endpoint_id - 1;
        let mut device_context = self.get_device_context()?;
        let endpoint_context = &mut device_context.endpoint_context[index as usize];
        // Reset Endpoint is only legal on a halted endpoint.
        if endpoint_context
            .get_endpoint_state()
            .map_err(Error::GetEndpointState)?
            != EndpointState::Halted
        {
            error!("endpoint at index {} is not halted", index);
            return cb(TrbCompletionCode::ContextStateError).map_err(|_| Error::CallbackFailed);
        }
        match self.get_trcs(index as usize) {
            Some(TransferRingControllers::Endpoint(trc)) => {
                // Success is reported only after the ring actually stops.
                let auto_cb = RingBufferStopCallback::new(fallible_closure(
                    fail_handle,
                    move || -> Result<()> {
                        cb(TrbCompletionCode::Success).map_err(|_| Error::CallbackFailed)
                    },
                ));
                trc.stop(auto_cb);
                let dequeue_pointer = trc.get_dequeue_pointer();
                let dcs = trc.get_consumer_cycle_state();
                endpoint_context.set_tr_dequeue_pointer(DequeuePtr::new(dequeue_pointer));
                endpoint_context.set_dequeue_cycle_state(dcs);
            }
            Some(TransferRingControllers::Stream(trcs)) => {
                // For stream endpoints, the endpoint context's dequeue field
                // holds the stream context array address instead.
                let stream_context_array_addr = endpoint_context.get_tr_dequeue_pointer().get_gpa();
                let mut stream_context_array: StreamContextArray = self
                    .mem
                    .read_obj_from_addr(stream_context_array_addr)
                    .map_err(Error::ReadGuestMemory)?;
                let auto_cb = RingBufferStopCallback::new(fallible_closure(
                    fail_handle,
                    move || -> Result<()> {
                        cb(TrbCompletionCode::Success).map_err(|_| Error::CallbackFailed)
                    },
                ));
                // trcs[i] corresponds to stream context i + 1 (stream id 0 is
                // reserved), matching create_stream_trcs.
                for (i, trc) in trcs.iter().enumerate() {
                    let dequeue_pointer = trc.get_dequeue_pointer();
                    let dcs = trc.get_consumer_cycle_state();
                    trc.stop(auto_cb.clone());
                    stream_context_array.stream_contexts[i + 1]
                        .set_tr_dequeue_pointer(DequeuePtr::new(dequeue_pointer));
                    stream_context_array.stream_contexts[i + 1].set_dequeue_cycle_state(dcs);
                }
                self.mem
                    .write_obj_at_addr(stream_context_array, stream_context_array_addr)
                    .map_err(Error::WriteGuestMemory)?;
            }
            None => {
                error!("endpoint at index {} is not started", index);
                cb(TrbCompletionCode::ContextStateError).map_err(|_| Error::CallbackFailed)?;
            }
        }
        // After reset the endpoint is left in the Stopped state.
        endpoint_context.set_endpoint_state(EndpointState::Stopped);
        self.set_device_context(device_context)?;
        Ok(())
    }
833
834 pub fn set_tr_dequeue_ptr(
836 &self,
837 endpoint_id: u8,
838 stream_id: u16,
839 ptr: u64,
840 ) -> Result<TrbCompletionCode> {
841 if !valid_endpoint_id(endpoint_id) {
842 error!("trb indexing wrong endpoint id");
843 return Ok(TrbCompletionCode::TrbError);
844 }
845 let index = (endpoint_id - 1) as usize;
846 match self.get_trc(index, stream_id) {
847 Some(trc) => {
848 trc.set_dequeue_pointer(GuestAddress(ptr));
849 let mut ctx = self.get_device_context()?;
850 ctx.endpoint_context[index]
851 .set_tr_dequeue_pointer(DequeuePtr::new(GuestAddress(ptr)));
852 self.set_device_context(ctx)?;
853 Ok(TrbCompletionCode::Success)
854 }
855 None => {
856 error!("set tr dequeue ptr failed due to no trc started");
857 Ok(TrbCompletionCode::ContextStateError)
858 }
859 }
860 }
861
862 fn reset(&self) {
866 for i in 0..self.trc_len() {
867 self.set_trcs(i, None);
868 }
869 debug!("resetting device slot {}!", self.slot_id);
870 self.enabled.store(false, Ordering::SeqCst);
871 self.port_id.reset();
872 }
873
    /// Builds one transfer ring controller per primary stream from the guest
    /// stream context array, seeding each with its saved dequeue pointer and
    /// cycle state. Returned trcs[i] corresponds to stream id i + 1.
    fn create_stream_trcs(
        self: &Arc<Self>,
        stream_context_array_addr: GuestAddress,
        max_pstreams: u8,
        device_context_index: u8,
    ) -> Result<TransferRingControllers> {
        // Number of primary stream array entries: 2^(MaxPStreams + 1).
        let pstreams = 1usize << (max_pstreams + 1);
        let stream_context_array: StreamContextArray = self
            .mem
            .read_obj_from_addr(stream_context_array_addr)
            .map_err(Error::ReadGuestMemory)?;
        let mut trcs = Vec::new();

        // Entry 0 is reserved; stream ids start at 1.
        for i in 1..pstreams {
            let stream_context = &stream_context_array.stream_contexts[i];
            let context_type = stream_context.get_stream_context_type();
            // Only linear stream arrays (SCT == 1) are supported here.
            if context_type != 1 {
                return Err(Error::BadStreamContextType(context_type));
            }
            let trc = TransferRingController::new(
                self.mem.clone(),
                self.hub
                    .get_port(self.port_id.get()?)
                    .ok_or(Error::GetPort(self.port_id.get()?))?,
                self.event_loop.clone(),
                self.interrupter.clone(),
                self.slot_id,
                device_context_index,
                Arc::downgrade(self),
                Some(i as u16),
            )
            .map_err(Error::CreateTransferController)?;
            trc.set_dequeue_pointer(stream_context.get_tr_dequeue_pointer().get_gpa());
            trc.set_consumer_cycle_state(stream_context.get_dequeue_cycle_state());
            trcs.push(trc);
        }
        Ok(TransferRingControllers::Stream(trcs))
    }
914
    /// Configures a single endpoint (by device context index): creates its
    /// transfer ring controller(s), allocates backend streams for stream
    /// endpoints, and marks the endpoint context Running.
    fn add_one_endpoint(self: &Arc<Self>, device_context_index: u8) -> Result<()> {
        xhci_trace!(
            "adding one endpoint, device context index {}",
            device_context_index
        );
        let mut device_context = self.get_device_context()?;
        // DCI n maps to controller index n - 1.
        let transfer_ring_index = (device_context_index - 1) as usize;
        let endpoint_context = &mut device_context.endpoint_context[transfer_ring_index];
        let max_pstreams = endpoint_context.get_max_primary_streams();
        let tr_dequeue_pointer = endpoint_context.get_tr_dequeue_pointer().get_gpa();
        // Guest address of this endpoint's context, used in error values.
        let endpoint_context_addr = self
            .get_device_context_addr()?
            .unchecked_add(size_of::<SlotContext>() as u64)
            .unchecked_add(size_of::<EndpointContext>() as u64 * transfer_ring_index as u64);
        let trcs = if max_pstreams > 0 {
            // Stream endpoint path.
            if !valid_max_pstreams(max_pstreams) {
                return Err(Error::BadEndpointContext(endpoint_context_addr));
            }
            // Streams are only valid for bulk endpoints (EP types 2 = bulk
            // out, 6 = bulk in).
            let endpoint_type = endpoint_context.get_endpoint_type();
            if endpoint_type != 2 && endpoint_type != 6 {
                return Err(Error::BadEndpointId(transfer_ring_index as u8));
            }
            // Only linear stream arrays are supported.
            if endpoint_context.get_linear_stream_array() != 1 {
                return Err(Error::BadEndpointContext(endpoint_context_addr));
            }

            // For streams, the dequeue pointer field holds the stream
            // context array address.
            let trcs =
                self.create_stream_trcs(tr_dequeue_pointer, max_pstreams, device_context_index)?;

            if let Some(port) = self.hub.get_port(self.port_id.get()?) {
                if let Some(backend_device) = port.backend_device().as_mut() {
                    // Derive the USB endpoint address: number = DCI / 2,
                    // direction bit 7 set for IN (odd DCI).
                    let mut endpoint_address = device_context_index / 2;
                    if device_context_index % 2 == 1 {
                        endpoint_address |= 1u8 << 7;
                    }
                    let streams = 1 << (max_pstreams + 1);
                    // streams - 1: entry 0 of the array is reserved.
                    backend_device
                        .lock()
                        .alloc_streams(endpoint_address, streams - 1)
                        .map_err(Error::AllocStreams)?;
                }
            }
            trcs
        } else {
            // Plain (non-stream) endpoint path.
            let trc = TransferRingController::new(
                self.mem.clone(),
                self.hub
                    .get_port(self.port_id.get()?)
                    .ok_or(Error::GetPort(self.port_id.get()?))?,
                self.event_loop.clone(),
                self.interrupter.clone(),
                self.slot_id,
                device_context_index,
                Arc::downgrade(self),
                None,
            )
            .map_err(Error::CreateTransferController)?;
            trc.set_dequeue_pointer(tr_dequeue_pointer);
            trc.set_consumer_cycle_state(endpoint_context.get_dequeue_cycle_state());
            TransferRingControllers::Endpoint(trc)
        };
        self.set_trcs(transfer_ring_index, Some(trcs));
        endpoint_context.set_endpoint_state(EndpointState::Running);
        self.set_device_context(device_context)
    }
983
    /// Removes a single endpoint (by device context index): frees backend
    /// streams if it was a stream endpoint, drops its transfer ring
    /// controller(s), and marks the endpoint context Disabled.
    fn drop_one_endpoint(self: &Arc<Self>, device_context_index: u8) -> Result<()> {
        // DCI n maps to controller index n - 1.
        let endpoint_index = (device_context_index - 1) as usize;
        let mut device_context = self.get_device_context()?;
        let endpoint_context = &mut device_context.endpoint_context[endpoint_index];
        if endpoint_context.get_max_primary_streams() > 0 {
            if let Some(port) = self.hub.get_port(self.port_id.get()?) {
                if let Some(backend_device) = port.backend_device().as_mut() {
                    // Derive the USB endpoint address: number = DCI / 2,
                    // direction bit 7 set for IN (odd DCI).
                    let mut endpoint_address = device_context_index / 2;
                    if device_context_index % 2 == 1 {
                        endpoint_address |= 1u8 << 7;
                    }
                    backend_device
                        .lock()
                        .free_streams(endpoint_address)
                        .map_err(Error::FreeStreams)?;
                }
            }
        }
        self.set_trcs(endpoint_index, None);
        endpoint_context.set_endpoint_state(EndpointState::Disabled);
        self.set_device_context(device_context)
    }
1006
1007 fn get_device_context(&self) -> Result<DeviceContext> {
1008 let ctx = self
1009 .mem
1010 .read_obj_from_addr(self.get_device_context_addr()?)
1011 .map_err(Error::ReadGuestMemory)?;
1012 Ok(ctx)
1013 }
1014
1015 fn set_device_context(&self, device_context: DeviceContext) -> Result<()> {
1016 self.mem
1017 .write_obj_at_addr(device_context, self.get_device_context_addr()?)
1018 .map_err(Error::WriteGuestMemory)
1019 }
1020
1021 fn copy_context(
1022 &self,
1023 input_context_ptr: GuestAddress,
1024 device_context_index: u8,
1025 ) -> Result<()> {
1026 let ctx: EndpointContext = self
1029 .mem
1030 .read_obj_from_addr(
1031 input_context_ptr
1032 .checked_add(
1033 (device_context_index as u64 + 1) * DEVICE_CONTEXT_ENTRY_SIZE as u64,
1034 )
1035 .ok_or(Error::BadInputContextAddr(input_context_ptr))?,
1036 )
1037 .map_err(Error::ReadGuestMemory)?;
1038 xhci_trace!("copy_context {:?}", ctx);
1039 let device_context_ptr = self.get_device_context_addr()?;
1040 self.mem
1041 .write_obj_at_addr(
1042 ctx,
1043 device_context_ptr
1044 .checked_add(device_context_index as u64 * DEVICE_CONTEXT_ENTRY_SIZE as u64)
1045 .ok_or(Error::BadDeviceContextAddr(device_context_ptr))?,
1046 )
1047 .map_err(Error::WriteGuestMemory)
1048 }
1049
1050 fn get_device_context_addr(&self) -> Result<GuestAddress> {
1051 let addr: u64 = self
1052 .mem
1053 .read_obj_from_addr(GuestAddress(
1054 self.dcbaap.get_value() + size_of::<u64>() as u64 * self.slot_id as u64,
1055 ))
1056 .map_err(Error::ReadGuestMemory)?;
1057 Ok(GuestAddress(addr))
1058 }
1059
1060 fn set_state(&self, state: DeviceSlotState) -> Result<()> {
1061 let mut ctx = self.get_device_context()?;
1062 ctx.slot_context.set_slot_state(state);
1063 self.set_device_context(ctx)
1064 }
1065
    /// Marks the given endpoint Halted, saving the current dequeue pointer /
    /// cycle state of its ring(s) back into the guest contexts so a later
    /// Reset Endpoint can resume from the right position.
    pub fn halt_endpoint(&self, endpoint_id: u8) -> Result<()> {
        if !valid_endpoint_id(endpoint_id) {
            return Err(Error::BadEndpointId(endpoint_id));
        }
        // DCI `endpoint_id` maps to controller/context index `endpoint_id - 1`.
        let index = endpoint_id - 1;
        let mut device_context = self.get_device_context()?;
        let endpoint_context = &mut device_context.endpoint_context[index as usize];
        match self.get_trcs(index as usize) {
            Some(trcs) => match trcs {
                TransferRingControllers::Endpoint(trc) => {
                    endpoint_context
                        .set_tr_dequeue_pointer(DequeuePtr::new(trc.get_dequeue_pointer()));
                    endpoint_context.set_dequeue_cycle_state(trc.get_consumer_cycle_state());
                }
                TransferRingControllers::Stream(trcs) => {
                    // For stream endpoints, the endpoint context's dequeue
                    // field holds the stream context array address.
                    let stream_context_array_addr =
                        endpoint_context.get_tr_dequeue_pointer().get_gpa();
                    let mut stream_context_array: StreamContextArray = self
                        .mem
                        .read_obj_from_addr(stream_context_array_addr)
                        .map_err(Error::ReadGuestMemory)?;
                    // trcs[i] corresponds to stream context i + 1 (stream id
                    // 0 is reserved), matching create_stream_trcs.
                    for (i, trc) in trcs.iter().enumerate() {
                        stream_context_array.stream_contexts[i + 1]
                            .set_tr_dequeue_pointer(DequeuePtr::new(trc.get_dequeue_pointer()));
                        stream_context_array.stream_contexts[i + 1]
                            .set_dequeue_cycle_state(trc.get_consumer_cycle_state());
                    }
                    self.mem
                        .write_obj_at_addr(stream_context_array, stream_context_array_addr)
                        .map_err(Error::WriteGuestMemory)?;
                }
            },
            None => {
                error!("trc for endpoint {} not found", endpoint_id);
                return Err(Error::BadEndpointId(endpoint_id));
            }
        }
        endpoint_context.set_endpoint_state(EndpointState::Halted);
        self.set_device_context(device_context)?;
        Ok(())
    }
1107
1108 pub fn get_max_esit_payload(&self, endpoint_id: u8) -> Result<u32> {
1109 let index = endpoint_id - 1;
1110 let device_context = self.get_device_context()?;
1111 let endpoint_context = device_context.endpoint_context[index as usize];
1112
1113 let lo = endpoint_context.get_max_esit_payload_lo() as u32;
1114 let hi = endpoint_context.get_max_esit_payload_hi() as u32;
1115 Ok(hi << 16 | lo)
1116 }
1117}