1use std::mem::size_of;
6use std::sync::atomic::AtomicBool;
7use std::sync::atomic::Ordering;
8use std::sync::Arc;
9
10use base::debug;
11use base::error;
12use base::info;
13use bit_field::Error as BitFieldError;
14use remain::sorted;
15use sync::Mutex;
16use thiserror::Error;
17use vm_memory::GuestAddress;
18use vm_memory::GuestMemory;
19use vm_memory::GuestMemoryError;
20
21use super::interrupter::Interrupter;
22use super::transfer_ring_controller::TransferRingController;
23use super::transfer_ring_controller::TransferRingControllerError;
24use super::transfer_ring_controller::TransferRingControllers;
25use super::usb_hub;
26use super::usb_hub::UsbHub;
27use super::xhci_abi::AddressDeviceCommandTrb;
28use super::xhci_abi::ConfigureEndpointCommandTrb;
29use super::xhci_abi::DequeuePtr;
30use super::xhci_abi::DeviceContext;
31use super::xhci_abi::DeviceSlotState;
32use super::xhci_abi::EndpointContext;
33use super::xhci_abi::EndpointState;
34use super::xhci_abi::EvaluateContextCommandTrb;
35use super::xhci_abi::InputControlContext;
36use super::xhci_abi::SlotContext;
37use super::xhci_abi::StreamContextArray;
38use super::xhci_abi::TrbCompletionCode;
39use super::xhci_abi::DEVICE_CONTEXT_ENTRY_SIZE;
40use super::xhci_backend_device::XhciBackendDevice;
41use super::xhci_regs::valid_max_pstreams;
42use super::xhci_regs::valid_slot_id;
43use super::xhci_regs::MAX_PORTS;
44use super::xhci_regs::MAX_SLOTS;
45use crate::register_space::Register;
46use crate::usb::backend::error::Error as BackendProviderError;
47use crate::usb::xhci::ring_buffer_stop_cb::fallible_closure;
48use crate::usb::xhci::ring_buffer_stop_cb::RingBufferStopCallback;
49use crate::utils::EventLoop;
50use crate::utils::FailHandle;
51
52#[sorted]
53#[derive(Error, Debug)]
54pub enum Error {
55 #[error("failed to allocate streams: {0}")]
56 AllocStreams(BackendProviderError),
57 #[error("bad device context: {0}")]
58 BadDeviceContextAddr(GuestAddress),
59 #[error("bad endpoint context: {0}")]
60 BadEndpointContext(GuestAddress),
61 #[error("device slot get a bad endpoint id: {0}")]
62 BadEndpointId(u8),
63 #[error("bad input context address: {0}")]
64 BadInputContextAddr(GuestAddress),
65 #[error("device slot get a bad port id: {0}")]
66 BadPortId(u8),
67 #[error("bad stream context type: {0}")]
68 BadStreamContextType(u8),
69 #[error("callback failed")]
70 CallbackFailed,
71 #[error("failed to create transfer controller: {0}")]
72 CreateTransferController(TransferRingControllerError),
73 #[error("failed to free streams: {0}")]
74 FreeStreams(BackendProviderError),
75 #[error("failed to get endpoint state: {0}")]
76 GetEndpointState(BitFieldError),
77 #[error("failed to get port: {0}")]
78 GetPort(u8),
79 #[error("failed to get slot context state: {0}")]
80 GetSlotContextState(BitFieldError),
81 #[error("failed to get trc: {0}")]
82 GetTrc(u8),
83 #[error("failed to read guest memory: {0}")]
84 ReadGuestMemory(GuestMemoryError),
85 #[error("failed to reset port: {0}")]
86 ResetPort(BackendProviderError),
87 #[error("failed to upgrade weak reference")]
88 WeakReferenceUpgrade,
89 #[error("failed to write guest memory: {0}")]
90 WriteGuestMemory(GuestMemoryError),
91}
92
93type Result<T> = std::result::Result<T, Error>;
94
/// Number of transfer ring controller entries kept by each device slot.
pub const TRANSFER_RING_CONTROLLERS_INDEX_END: usize = 31;
/// Exclusive upper bound of the device context indices (DCI) accepted by this module.
pub const DCI_INDEX_END: u8 = (TRANSFER_RING_CONTROLLERS_INDEX_END + 1) as u8;
/// First DCI that `DeviceSlot::reset_slot` drops; lower indices are left in place.
pub const FIRST_TRANSFER_ENDPOINT_DCI: u8 = 2;

/// Returns `true` when `endpoint_id` lies in `1..DCI_INDEX_END`.
fn valid_endpoint_id(endpoint_id: u8) -> bool {
    (1..DCI_INDEX_END).contains(&endpoint_id)
}
111
112#[derive(Clone)]
113pub struct DeviceSlots {
114 fail_handle: Arc<dyn FailHandle>,
115 hub: Arc<UsbHub>,
116 slots: Vec<Arc<DeviceSlot>>,
117}
118
119impl DeviceSlots {
120 pub fn new(
121 fail_handle: Arc<dyn FailHandle>,
122 dcbaap: Register<u64>,
123 hub: Arc<UsbHub>,
124 interrupter: Arc<Mutex<Interrupter>>,
125 event_loop: Arc<EventLoop>,
126 mem: GuestMemory,
127 ) -> DeviceSlots {
128 let mut slots = Vec::new();
129 for slot_id in 1..=MAX_SLOTS {
130 slots.push(Arc::new(DeviceSlot::new(
131 slot_id,
132 dcbaap.clone(),
133 hub.clone(),
134 interrupter.clone(),
135 event_loop.clone(),
136 mem.clone(),
137 )));
138 }
139 DeviceSlots {
140 fail_handle,
141 hub,
142 slots,
143 }
144 }
145
146 pub fn slot(&self, slot_id: u8) -> Option<Arc<DeviceSlot>> {
148 if valid_slot_id(slot_id) {
149 Some(self.slots[slot_id as usize - 1].clone())
150 } else {
151 error!(
152 "trying to index a wrong slot id {}, max slot = {}",
153 slot_id, MAX_SLOTS
154 );
155 None
156 }
157 }
158
159 pub fn reset_port(&self, port_id: u8) -> Result<()> {
161 if let Some(port) = self.hub.get_port(port_id) {
162 if let Some(backend_device) = port.backend_device().as_mut() {
163 backend_device.lock().reset().map_err(Error::ResetPort)?;
164 }
165 }
166
167 Ok(())
169 }
170
171 pub fn stop_all_and_reset<C: FnMut() + 'static + Send>(&self, mut callback: C) {
173 info!("xhci: stopping all device slots and resetting host hub");
174 let slots = self.slots.clone();
175 let hub = self.hub.clone();
176 let auto_callback = RingBufferStopCallback::new(fallible_closure(
177 self.fail_handle.clone(),
178 move || -> std::result::Result<(), usb_hub::Error> {
179 for slot in &slots {
180 slot.reset();
181 }
182 hub.reset()?;
183 callback();
184 Ok(())
185 },
186 ));
187 self.stop_all(auto_callback);
188 }
189
190 pub fn stop_all(&self, auto_callback: RingBufferStopCallback) {
193 for slot in &self.slots {
194 slot.stop_all_trc(auto_callback.clone());
195 }
196 }
197
198 pub fn disable_slot<
201 C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
202 >(
203 &self,
204 slot_id: u8,
205 cb: C,
206 ) -> Result<()> {
207 xhci_trace!("device slot {} is being disabled", slot_id);
208 DeviceSlot::disable(
209 self.fail_handle.clone(),
210 &self.slots[slot_id as usize - 1],
211 cb,
212 )
213 }
214
215 pub fn reset_slot<
217 C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
218 >(
219 &self,
220 slot_id: u8,
221 cb: C,
222 ) -> Result<()> {
223 xhci_trace!("device slot {} is resetting", slot_id);
224 DeviceSlot::reset_slot(
225 self.fail_handle.clone(),
226 &self.slots[slot_id as usize - 1],
227 cb,
228 )
229 }
230
231 pub fn stop_endpoint<
232 C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
233 >(
234 &self,
235 slot_id: u8,
236 endpoint_id: u8,
237 cb: C,
238 ) -> Result<()> {
239 self.slots[slot_id as usize - 1].stop_endpoint(self.fail_handle.clone(), endpoint_id, cb)
240 }
241
242 pub fn reset_endpoint<
243 C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
244 >(
245 &self,
246 slot_id: u8,
247 endpoint_id: u8,
248 cb: C,
249 ) -> Result<()> {
250 self.slots[slot_id as usize - 1].reset_endpoint(self.fail_handle.clone(), endpoint_id, cb)
251 }
252}
253
254struct PortId(Mutex<u8>);
256
257impl PortId {
258 fn new() -> Self {
259 PortId(Mutex::new(0))
260 }
261
262 fn set(&self, value: u8) -> Result<()> {
263 if !(1..=MAX_PORTS).contains(&value) {
264 return Err(Error::BadPortId(value));
265 }
266 *self.0.lock() = value;
267 Ok(())
268 }
269
270 fn reset(&self) {
271 *self.0.lock() = 0;
272 }
273
274 fn get(&self) -> Result<u8> {
275 let val = *self.0.lock();
276 if val == 0 {
277 return Err(Error::BadPortId(val));
278 }
279 Ok(val)
280 }
281}
282
283pub struct DeviceSlot {
284 slot_id: u8,
285 port_id: PortId, dcbaap: Register<u64>,
287 hub: Arc<UsbHub>,
288 interrupter: Arc<Mutex<Interrupter>>,
289 event_loop: Arc<EventLoop>,
290 mem: GuestMemory,
291 enabled: AtomicBool,
292 transfer_ring_controllers: Mutex<Vec<Option<TransferRingControllers>>>,
293}
294
295impl DeviceSlot {
296 pub fn new(
298 slot_id: u8,
299 dcbaap: Register<u64>,
300 hub: Arc<UsbHub>,
301 interrupter: Arc<Mutex<Interrupter>>,
302 event_loop: Arc<EventLoop>,
303 mem: GuestMemory,
304 ) -> Self {
305 let mut transfer_ring_controllers = Vec::new();
306 transfer_ring_controllers.resize_with(TRANSFER_RING_CONTROLLERS_INDEX_END, || None);
307 DeviceSlot {
308 slot_id,
309 port_id: PortId::new(),
310 dcbaap,
311 hub,
312 interrupter,
313 event_loop,
314 mem,
315 enabled: AtomicBool::new(false),
316 transfer_ring_controllers: Mutex::new(transfer_ring_controllers),
317 }
318 }
319
320 fn get_trc(&self, i: usize, stream_id: u16) -> Option<Arc<TransferRingController>> {
321 let trcs = self.transfer_ring_controllers.lock();
322 match &trcs[i] {
323 Some(TransferRingControllers::Endpoint(trc)) => Some(trc.clone()),
324 Some(TransferRingControllers::Stream(trcs)) => {
325 let stream_id = stream_id as usize;
326 if stream_id > 0 && stream_id <= trcs.len() {
327 Some(trcs[stream_id - 1].clone())
328 } else {
329 None
330 }
331 }
332 None => None,
333 }
334 }
335
336 fn get_trcs(&self, i: usize) -> Option<TransferRingControllers> {
337 let trcs = self.transfer_ring_controllers.lock();
338 trcs[i].clone()
339 }
340
341 fn set_trcs(&self, i: usize, trc: Option<TransferRingControllers>) {
342 let mut trcs = self.transfer_ring_controllers.lock();
343 trcs[i] = trc;
344 }
345
346 fn trc_len(&self) -> usize {
347 self.transfer_ring_controllers.lock().len()
348 }
349
350 pub fn ring_doorbell(&self, target: u8, stream_id: u16) -> Result<bool> {
365 if !valid_endpoint_id(target) {
366 error!(
367 "device slot {}: Invalid target written to doorbell register. target: {}",
368 self.slot_id, target
369 );
370 return Ok(false);
371 }
372 xhci_trace!(
373 "device slot {}: ring_doorbell target = {} stream_id = {}",
374 self.slot_id,
375 target,
376 stream_id
377 );
378 let endpoint_index = (target - 1) as usize;
380 let transfer_ring_controller = match self.get_trc(endpoint_index, stream_id) {
381 Some(tr) => tr,
382 None => {
383 error!("Device endpoint is not inited");
384 return Ok(false);
385 }
386 };
387 let mut context = self.get_device_context()?;
388 let endpoint_state = context.endpoint_context[endpoint_index]
389 .get_endpoint_state()
390 .map_err(Error::GetEndpointState)?;
391 if endpoint_state == EndpointState::Running || endpoint_state == EndpointState::Stopped {
392 if endpoint_state == EndpointState::Stopped {
393 context.endpoint_context[endpoint_index].set_endpoint_state(EndpointState::Running);
394 self.set_device_context(context)?;
395 }
396 transfer_ring_controller.start();
398 } else {
399 error!("doorbell rung when endpoint state is {:?}", endpoint_state);
400 }
401 Ok(true)
402 }
403
404 pub fn enable(&self) -> bool {
406 let was_already_enabled = self.enabled.swap(true, Ordering::SeqCst);
407 if was_already_enabled {
408 error!("device slot is already enabled");
409 }
410 !was_already_enabled
411 }
412
413 pub fn disable<C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send>(
416 fail_handle: Arc<dyn FailHandle>,
417 slot: &Arc<DeviceSlot>,
418 mut callback: C,
419 ) -> Result<()> {
420 if slot.enabled.load(Ordering::SeqCst) {
421 let slot_weak = Arc::downgrade(slot);
422 let auto_callback =
423 RingBufferStopCallback::new(fallible_closure(fail_handle, move || {
424 let slot = slot_weak.upgrade().ok_or(Error::WeakReferenceUpgrade)?;
427 let mut device_context = slot.get_device_context()?;
428 device_context
429 .slot_context
430 .set_slot_state(DeviceSlotState::DisabledOrEnabled);
431 slot.set_device_context(device_context)?;
432 slot.reset();
433 debug!(
434 "device slot {}: all trc disabled, sending trb",
435 slot.slot_id
436 );
437 callback(TrbCompletionCode::Success).map_err(|_| Error::CallbackFailed)
438 }));
439 slot.stop_all_trc(auto_callback);
440 Ok(())
441 } else {
442 callback(TrbCompletionCode::SlotNotEnabledError).map_err(|_| Error::CallbackFailed)
443 }
444 }
445
446 pub fn set_address(
448 self: &Arc<Self>,
449 trb: &AddressDeviceCommandTrb,
450 ) -> Result<TrbCompletionCode> {
451 if !self.enabled.load(Ordering::SeqCst) {
452 error!(
453 "trying to set address to a disabled device slot {}",
454 self.slot_id
455 );
456 return Ok(TrbCompletionCode::SlotNotEnabledError);
457 }
458 let device_context = self.get_device_context()?;
459 let state = device_context
460 .slot_context
461 .get_slot_state()
462 .map_err(Error::GetSlotContextState)?;
463 match state {
464 DeviceSlotState::DisabledOrEnabled => {}
465 DeviceSlotState::Default if !trb.get_block_set_address_request() => {}
466 _ => {
467 error!("slot {} has unexpected slot state", self.slot_id);
468 return Ok(TrbCompletionCode::ContextStateError);
469 }
470 }
471
472 let input_context_ptr = GuestAddress(trb.get_input_context_pointer());
475 self.copy_context(input_context_ptr, 0)?;
477 self.copy_context(input_context_ptr, 1)?;
479
480 let mut device_context = self.get_device_context()?;
482 let port_id = device_context.slot_context.get_root_hub_port_number();
483 self.port_id.set(port_id)?;
484 debug!(
485 "port id {} is assigned to slot id {}",
486 port_id, self.slot_id
487 );
488
489 let trc = TransferRingController::new(
491 self.mem.clone(),
492 self.hub.get_port(port_id).ok_or(Error::GetPort(port_id))?,
493 self.event_loop.clone(),
494 self.interrupter.clone(),
495 self.slot_id,
496 1,
497 Arc::downgrade(self),
498 None,
499 )
500 .map_err(Error::CreateTransferController)?;
501 self.set_trcs(0, Some(TransferRingControllers::Endpoint(trc)));
502
503 if trb.get_block_set_address_request() {
505 device_context
506 .slot_context
507 .set_slot_state(DeviceSlotState::Default);
508 } else {
509 let port = self.hub.get_port(port_id).ok_or(Error::GetPort(port_id))?;
510 match port.backend_device().as_mut() {
511 Some(backend) => {
512 backend.lock().set_address(self.slot_id as u32);
513 }
514 None => {
515 return Ok(TrbCompletionCode::TransactionError);
516 }
517 }
518
519 device_context
520 .slot_context
521 .set_usb_device_address(self.slot_id);
522 device_context
523 .slot_context
524 .set_slot_state(DeviceSlotState::Addressed);
525 }
526
527 self.get_trc(0, 0)
529 .ok_or(Error::GetTrc(0))?
530 .set_dequeue_pointer(
531 device_context.endpoint_context[0]
532 .get_tr_dequeue_pointer()
533 .get_gpa(),
534 );
535
536 self.get_trc(0, 0)
537 .ok_or(Error::GetTrc(0))?
538 .set_consumer_cycle_state(device_context.endpoint_context[0].get_dequeue_cycle_state());
539
540 device_context.endpoint_context[0].set_endpoint_state(EndpointState::Running);
542 self.set_device_context(device_context)?;
543 Ok(TrbCompletionCode::Success)
544 }
545
546 pub fn configure_endpoint(
548 self: &Arc<Self>,
549 trb: &ConfigureEndpointCommandTrb,
550 ) -> Result<TrbCompletionCode> {
551 let input_control_context = if trb.get_deconfigure() {
552 let mut c = InputControlContext::new();
557 c.set_add_context_flags(0);
558 c.set_drop_context_flags(0xfffffffc);
559 c
560 } else {
561 self.mem
562 .read_obj_from_addr(GuestAddress(trb.get_input_context_pointer()))
563 .map_err(Error::ReadGuestMemory)?
564 };
565
566 for device_context_index in 1..DCI_INDEX_END {
567 if input_control_context.drop_context_flag(device_context_index) {
568 self.drop_one_endpoint(device_context_index)?;
569 }
570 if input_control_context.add_context_flag(device_context_index) {
571 self.copy_context(
572 GuestAddress(trb.get_input_context_pointer()),
573 device_context_index,
574 )?;
575 self.add_one_endpoint(device_context_index)?;
576 }
577 }
578
579 if trb.get_deconfigure() {
580 self.set_state(DeviceSlotState::Addressed)?;
581 } else {
582 self.set_state(DeviceSlotState::Configured)?;
583 }
584 Ok(TrbCompletionCode::Success)
585 }
586
587 pub fn evaluate_context(&self, trb: &EvaluateContextCommandTrb) -> Result<TrbCompletionCode> {
590 if !self.enabled.load(Ordering::SeqCst) {
591 return Ok(TrbCompletionCode::SlotNotEnabledError);
592 }
593 let input_control_context: InputControlContext = self
597 .mem
598 .read_obj_from_addr(GuestAddress(trb.get_input_context_pointer()))
599 .map_err(Error::ReadGuestMemory)?;
600
601 let mut device_context = self.get_device_context()?;
602 if input_control_context.add_context_flag(0) {
603 let input_slot_context: SlotContext = self
604 .mem
605 .read_obj_from_addr(GuestAddress(
606 trb.get_input_context_pointer() + DEVICE_CONTEXT_ENTRY_SIZE as u64,
607 ))
608 .map_err(Error::ReadGuestMemory)?;
609 device_context
610 .slot_context
611 .set_interrupter_target(input_slot_context.get_interrupter_target());
612
613 device_context
614 .slot_context
615 .set_max_exit_latency(input_slot_context.get_max_exit_latency());
616 }
617
618 if input_control_context.add_context_flag(1) {
621 let ep0_context: EndpointContext = self
622 .mem
623 .read_obj_from_addr(GuestAddress(
624 trb.get_input_context_pointer() + 2 * DEVICE_CONTEXT_ENTRY_SIZE as u64,
625 ))
626 .map_err(Error::ReadGuestMemory)?;
627 device_context.endpoint_context[0]
628 .set_max_packet_size(ep0_context.get_max_packet_size());
629 }
630 self.set_device_context(device_context)?;
631 Ok(TrbCompletionCode::Success)
632 }
633
634 pub fn reset_slot<
637 C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
638 >(
639 fail_handle: Arc<dyn FailHandle>,
640 slot: &Arc<DeviceSlot>,
641 mut callback: C,
642 ) -> Result<()> {
643 let weak_s = Arc::downgrade(slot);
644 let auto_callback =
645 RingBufferStopCallback::new(fallible_closure(fail_handle, move || -> Result<()> {
646 let s = weak_s.upgrade().ok_or(Error::WeakReferenceUpgrade)?;
647 for i in FIRST_TRANSFER_ENDPOINT_DCI..DCI_INDEX_END {
648 s.drop_one_endpoint(i)?;
649 }
650 let mut ctx = s.get_device_context()?;
651 ctx.slot_context.set_slot_state(DeviceSlotState::Default);
652 ctx.slot_context.set_context_entries(1);
653 ctx.slot_context.set_root_hub_port_number(0);
654 s.set_device_context(ctx)?;
655 callback(TrbCompletionCode::Success).map_err(|_| Error::CallbackFailed)?;
656 Ok(())
657 }));
658 slot.stop_all_trc(auto_callback);
659 Ok(())
660 }
661
662 pub fn stop_all_trc(&self, auto_callback: RingBufferStopCallback) {
664 for i in 0..self.trc_len() {
665 if let Some(trcs) = self.get_trcs(i) {
666 match trcs {
667 TransferRingControllers::Endpoint(trc) => {
668 trc.stop(auto_callback.clone());
669 }
670 TransferRingControllers::Stream(trcs) => {
671 for trc in trcs {
672 trc.stop(auto_callback.clone());
673 }
674 }
675 }
676 }
677 }
678 }
679
680 pub fn stop_endpoint<
682 C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
683 >(
684 &self,
685 fail_handle: Arc<dyn FailHandle>,
686 endpoint_id: u8,
687 mut cb: C,
688 ) -> Result<()> {
689 if !valid_endpoint_id(endpoint_id) {
690 error!("trb indexing wrong endpoint id");
691 return cb(TrbCompletionCode::TrbError).map_err(|_| Error::CallbackFailed);
692 }
693 let index = endpoint_id - 1;
694 let mut device_context = self.get_device_context()?;
695 let endpoint_context = &mut device_context.endpoint_context[index as usize];
696 match self.get_trcs(index as usize) {
697 Some(TransferRingControllers::Endpoint(trc)) => {
698 let auto_cb = RingBufferStopCallback::new(fallible_closure(
699 fail_handle,
700 move || -> Result<()> {
701 cb(TrbCompletionCode::Success).map_err(|_| Error::CallbackFailed)
702 },
703 ));
704 trc.stop(auto_cb);
705 let dequeue_pointer = trc.get_dequeue_pointer();
706 let dcs = trc.get_consumer_cycle_state();
707 endpoint_context.set_tr_dequeue_pointer(DequeuePtr::new(dequeue_pointer));
708 endpoint_context.set_dequeue_cycle_state(dcs);
709 }
710 Some(TransferRingControllers::Stream(trcs)) => {
711 let stream_context_array_addr = endpoint_context.get_tr_dequeue_pointer().get_gpa();
712 let mut stream_context_array: StreamContextArray = self
713 .mem
714 .read_obj_from_addr(stream_context_array_addr)
715 .map_err(Error::ReadGuestMemory)?;
716 let auto_cb = RingBufferStopCallback::new(fallible_closure(
717 fail_handle,
718 move || -> Result<()> {
719 cb(TrbCompletionCode::Success).map_err(|_| Error::CallbackFailed)
720 },
721 ));
722 for (i, trc) in trcs.iter().enumerate() {
723 let dequeue_pointer = trc.get_dequeue_pointer();
724 let dcs = trc.get_consumer_cycle_state();
725 trc.stop(auto_cb.clone());
726 stream_context_array.stream_contexts[i + 1]
727 .set_tr_dequeue_pointer(DequeuePtr::new(dequeue_pointer));
728 stream_context_array.stream_contexts[i + 1].set_dequeue_cycle_state(dcs);
729 }
730 self.mem
731 .write_obj_at_addr(stream_context_array, stream_context_array_addr)
732 .map_err(Error::WriteGuestMemory)?;
733 }
734 None => {
735 error!("endpoint at index {} is not started", index);
736 cb(TrbCompletionCode::ContextStateError).map_err(|_| Error::CallbackFailed)?;
737 }
738 }
739 endpoint_context.set_endpoint_state(EndpointState::Stopped);
740 self.set_device_context(device_context)?;
741 Ok(())
742 }
743
744 pub fn reset_endpoint<
746 C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
747 >(
748 &self,
749 fail_handle: Arc<dyn FailHandle>,
750 endpoint_id: u8,
751 mut cb: C,
752 ) -> Result<()> {
753 if !valid_endpoint_id(endpoint_id) {
754 error!("trb indexing wrong endpoint id");
755 return cb(TrbCompletionCode::TrbError).map_err(|_| Error::CallbackFailed);
756 }
757 let index = endpoint_id - 1;
758 let mut device_context = self.get_device_context()?;
759 let endpoint_context = &mut device_context.endpoint_context[index as usize];
760 if endpoint_context
761 .get_endpoint_state()
762 .map_err(Error::GetEndpointState)?
763 != EndpointState::Halted
764 {
765 error!("endpoint at index {} is not halted", index);
766 return cb(TrbCompletionCode::ContextStateError).map_err(|_| Error::CallbackFailed);
767 }
768 match self.get_trcs(index as usize) {
769 Some(TransferRingControllers::Endpoint(trc)) => {
770 let auto_cb = RingBufferStopCallback::new(fallible_closure(
771 fail_handle,
772 move || -> Result<()> {
773 cb(TrbCompletionCode::Success).map_err(|_| Error::CallbackFailed)
774 },
775 ));
776 trc.stop(auto_cb);
777 let dequeue_pointer = trc.get_dequeue_pointer();
778 let dcs = trc.get_consumer_cycle_state();
779 endpoint_context.set_tr_dequeue_pointer(DequeuePtr::new(dequeue_pointer));
780 endpoint_context.set_dequeue_cycle_state(dcs);
781 }
782 Some(TransferRingControllers::Stream(trcs)) => {
783 let stream_context_array_addr = endpoint_context.get_tr_dequeue_pointer().get_gpa();
784 let mut stream_context_array: StreamContextArray = self
785 .mem
786 .read_obj_from_addr(stream_context_array_addr)
787 .map_err(Error::ReadGuestMemory)?;
788 let auto_cb = RingBufferStopCallback::new(fallible_closure(
789 fail_handle,
790 move || -> Result<()> {
791 cb(TrbCompletionCode::Success).map_err(|_| Error::CallbackFailed)
792 },
793 ));
794 for (i, trc) in trcs.iter().enumerate() {
795 let dequeue_pointer = trc.get_dequeue_pointer();
796 let dcs = trc.get_consumer_cycle_state();
797 trc.stop(auto_cb.clone());
798 stream_context_array.stream_contexts[i + 1]
799 .set_tr_dequeue_pointer(DequeuePtr::new(dequeue_pointer));
800 stream_context_array.stream_contexts[i + 1].set_dequeue_cycle_state(dcs);
801 }
802 self.mem
803 .write_obj_at_addr(stream_context_array, stream_context_array_addr)
804 .map_err(Error::WriteGuestMemory)?;
805 }
806 None => {
807 error!("endpoint at index {} is not started", index);
808 cb(TrbCompletionCode::ContextStateError).map_err(|_| Error::CallbackFailed)?;
809 }
810 }
811 endpoint_context.set_endpoint_state(EndpointState::Stopped);
812 self.set_device_context(device_context)?;
813 Ok(())
814 }
815
816 pub fn set_tr_dequeue_ptr(
818 &self,
819 endpoint_id: u8,
820 stream_id: u16,
821 ptr: u64,
822 ) -> Result<TrbCompletionCode> {
823 if !valid_endpoint_id(endpoint_id) {
824 error!("trb indexing wrong endpoint id");
825 return Ok(TrbCompletionCode::TrbError);
826 }
827 let index = (endpoint_id - 1) as usize;
828 match self.get_trc(index, stream_id) {
829 Some(trc) => {
830 trc.set_dequeue_pointer(GuestAddress(ptr));
831 let mut ctx = self.get_device_context()?;
832 ctx.endpoint_context[index]
833 .set_tr_dequeue_pointer(DequeuePtr::new(GuestAddress(ptr)));
834 self.set_device_context(ctx)?;
835 Ok(TrbCompletionCode::Success)
836 }
837 None => {
838 error!("set tr dequeue ptr failed due to no trc started");
839 Ok(TrbCompletionCode::ContextStateError)
840 }
841 }
842 }
843
844 fn reset(&self) {
848 for i in 0..self.trc_len() {
849 self.set_trcs(i, None);
850 }
851 debug!("resetting device slot {}!", self.slot_id);
852 self.enabled.store(false, Ordering::SeqCst);
853 self.port_id.reset();
854 }
855
856 fn create_stream_trcs(
857 self: &Arc<Self>,
858 stream_context_array_addr: GuestAddress,
859 max_pstreams: u8,
860 device_context_index: u8,
861 ) -> Result<TransferRingControllers> {
862 let pstreams = 1usize << (max_pstreams + 1);
863 let stream_context_array: StreamContextArray = self
864 .mem
865 .read_obj_from_addr(stream_context_array_addr)
866 .map_err(Error::ReadGuestMemory)?;
867 let mut trcs = Vec::new();
868
869 for i in 1..pstreams {
871 let stream_context = &stream_context_array.stream_contexts[i];
872 let context_type = stream_context.get_stream_context_type();
873 if context_type != 1 {
874 return Err(Error::BadStreamContextType(context_type));
876 }
877 let trc = TransferRingController::new(
878 self.mem.clone(),
879 self.hub
880 .get_port(self.port_id.get()?)
881 .ok_or(Error::GetPort(self.port_id.get()?))?,
882 self.event_loop.clone(),
883 self.interrupter.clone(),
884 self.slot_id,
885 device_context_index,
886 Arc::downgrade(self),
887 Some(i as u16),
888 )
889 .map_err(Error::CreateTransferController)?;
890 trc.set_dequeue_pointer(stream_context.get_tr_dequeue_pointer().get_gpa());
891 trc.set_consumer_cycle_state(stream_context.get_dequeue_cycle_state());
892 trcs.push(trc);
893 }
894 Ok(TransferRingControllers::Stream(trcs))
895 }
896
897 fn add_one_endpoint(self: &Arc<Self>, device_context_index: u8) -> Result<()> {
898 xhci_trace!(
899 "adding one endpoint, device context index {}",
900 device_context_index
901 );
902 let mut device_context = self.get_device_context()?;
903 let transfer_ring_index = (device_context_index - 1) as usize;
904 let endpoint_context = &mut device_context.endpoint_context[transfer_ring_index];
905 let max_pstreams = endpoint_context.get_max_primary_streams();
906 let tr_dequeue_pointer = endpoint_context.get_tr_dequeue_pointer().get_gpa();
907 let endpoint_context_addr = self
908 .get_device_context_addr()?
909 .unchecked_add(size_of::<SlotContext>() as u64)
910 .unchecked_add(size_of::<EndpointContext>() as u64 * transfer_ring_index as u64);
911 let trcs = if max_pstreams > 0 {
912 if !valid_max_pstreams(max_pstreams) {
913 return Err(Error::BadEndpointContext(endpoint_context_addr));
914 }
915 let endpoint_type = endpoint_context.get_endpoint_type();
916 if endpoint_type != 2 && endpoint_type != 6 {
917 return Err(Error::BadEndpointId(transfer_ring_index as u8));
919 }
920 if endpoint_context.get_linear_stream_array() != 1 {
921 return Err(Error::BadEndpointContext(endpoint_context_addr));
923 }
924
925 let trcs =
926 self.create_stream_trcs(tr_dequeue_pointer, max_pstreams, device_context_index)?;
927
928 if let Some(port) = self.hub.get_port(self.port_id.get()?) {
929 if let Some(backend_device) = port.backend_device().as_mut() {
930 let mut endpoint_address = device_context_index / 2;
931 if device_context_index % 2 == 1 {
932 endpoint_address |= 1u8 << 7;
933 }
934 let streams = 1 << (max_pstreams + 1);
935 backend_device
937 .lock()
938 .alloc_streams(endpoint_address, streams - 1)
939 .map_err(Error::AllocStreams)?;
940 }
941 }
942 trcs
943 } else {
944 let trc = TransferRingController::new(
945 self.mem.clone(),
946 self.hub
947 .get_port(self.port_id.get()?)
948 .ok_or(Error::GetPort(self.port_id.get()?))?,
949 self.event_loop.clone(),
950 self.interrupter.clone(),
951 self.slot_id,
952 device_context_index,
953 Arc::downgrade(self),
954 None,
955 )
956 .map_err(Error::CreateTransferController)?;
957 trc.set_dequeue_pointer(tr_dequeue_pointer);
958 trc.set_consumer_cycle_state(endpoint_context.get_dequeue_cycle_state());
959 TransferRingControllers::Endpoint(trc)
960 };
961 self.set_trcs(transfer_ring_index, Some(trcs));
962 endpoint_context.set_endpoint_state(EndpointState::Running);
963 self.set_device_context(device_context)
964 }
965
966 fn drop_one_endpoint(self: &Arc<Self>, device_context_index: u8) -> Result<()> {
967 let endpoint_index = (device_context_index - 1) as usize;
968 let mut device_context = self.get_device_context()?;
969 let endpoint_context = &mut device_context.endpoint_context[endpoint_index];
970 if endpoint_context.get_max_primary_streams() > 0 {
971 if let Some(port) = self.hub.get_port(self.port_id.get()?) {
972 if let Some(backend_device) = port.backend_device().as_mut() {
973 let mut endpoint_address = device_context_index / 2;
974 if device_context_index % 2 == 1 {
975 endpoint_address |= 1u8 << 7;
976 }
977 backend_device
978 .lock()
979 .free_streams(endpoint_address)
980 .map_err(Error::FreeStreams)?;
981 }
982 }
983 }
984 self.set_trcs(endpoint_index, None);
985 endpoint_context.set_endpoint_state(EndpointState::Disabled);
986 self.set_device_context(device_context)
987 }
988
989 fn get_device_context(&self) -> Result<DeviceContext> {
990 let ctx = self
991 .mem
992 .read_obj_from_addr(self.get_device_context_addr()?)
993 .map_err(Error::ReadGuestMemory)?;
994 Ok(ctx)
995 }
996
997 fn set_device_context(&self, device_context: DeviceContext) -> Result<()> {
998 self.mem
999 .write_obj_at_addr(device_context, self.get_device_context_addr()?)
1000 .map_err(Error::WriteGuestMemory)
1001 }
1002
1003 fn copy_context(
1004 &self,
1005 input_context_ptr: GuestAddress,
1006 device_context_index: u8,
1007 ) -> Result<()> {
1008 let ctx: EndpointContext = self
1011 .mem
1012 .read_obj_from_addr(
1013 input_context_ptr
1014 .checked_add(
1015 (device_context_index as u64 + 1) * DEVICE_CONTEXT_ENTRY_SIZE as u64,
1016 )
1017 .ok_or(Error::BadInputContextAddr(input_context_ptr))?,
1018 )
1019 .map_err(Error::ReadGuestMemory)?;
1020 xhci_trace!("copy_context {:?}", ctx);
1021 let device_context_ptr = self.get_device_context_addr()?;
1022 self.mem
1023 .write_obj_at_addr(
1024 ctx,
1025 device_context_ptr
1026 .checked_add(device_context_index as u64 * DEVICE_CONTEXT_ENTRY_SIZE as u64)
1027 .ok_or(Error::BadDeviceContextAddr(device_context_ptr))?,
1028 )
1029 .map_err(Error::WriteGuestMemory)
1030 }
1031
1032 fn get_device_context_addr(&self) -> Result<GuestAddress> {
1033 let addr: u64 = self
1034 .mem
1035 .read_obj_from_addr(GuestAddress(
1036 self.dcbaap.get_value() + size_of::<u64>() as u64 * self.slot_id as u64,
1037 ))
1038 .map_err(Error::ReadGuestMemory)?;
1039 Ok(GuestAddress(addr))
1040 }
1041
1042 fn set_state(&self, state: DeviceSlotState) -> Result<()> {
1043 let mut ctx = self.get_device_context()?;
1044 ctx.slot_context.set_slot_state(state);
1045 self.set_device_context(ctx)
1046 }
1047
1048 pub fn halt_endpoint(&self, endpoint_id: u8) -> Result<()> {
1049 if !valid_endpoint_id(endpoint_id) {
1050 return Err(Error::BadEndpointId(endpoint_id));
1051 }
1052 let index = endpoint_id - 1;
1053 let mut device_context = self.get_device_context()?;
1054 let endpoint_context = &mut device_context.endpoint_context[index as usize];
1055 match self.get_trcs(index as usize) {
1056 Some(trcs) => match trcs {
1057 TransferRingControllers::Endpoint(trc) => {
1058 endpoint_context
1059 .set_tr_dequeue_pointer(DequeuePtr::new(trc.get_dequeue_pointer()));
1060 endpoint_context.set_dequeue_cycle_state(trc.get_consumer_cycle_state());
1061 }
1062 TransferRingControllers::Stream(trcs) => {
1063 let stream_context_array_addr =
1064 endpoint_context.get_tr_dequeue_pointer().get_gpa();
1065 let mut stream_context_array: StreamContextArray = self
1066 .mem
1067 .read_obj_from_addr(stream_context_array_addr)
1068 .map_err(Error::ReadGuestMemory)?;
1069 for (i, trc) in trcs.iter().enumerate() {
1070 stream_context_array.stream_contexts[i + 1]
1071 .set_tr_dequeue_pointer(DequeuePtr::new(trc.get_dequeue_pointer()));
1072 stream_context_array.stream_contexts[i + 1]
1073 .set_dequeue_cycle_state(trc.get_consumer_cycle_state());
1074 }
1075 self.mem
1076 .write_obj_at_addr(stream_context_array, stream_context_array_addr)
1077 .map_err(Error::WriteGuestMemory)?;
1078 }
1079 },
1080 None => {
1081 error!("trc for endpoint {} not found", endpoint_id);
1082 return Err(Error::BadEndpointId(endpoint_id));
1083 }
1084 }
1085 endpoint_context.set_endpoint_state(EndpointState::Halted);
1086 self.set_device_context(device_context)?;
1087 Ok(())
1088 }
1089}