devices/usb/xhci/
device_slot.rs

1// Copyright 2019 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use std::mem::size_of;
6use std::sync::atomic::AtomicBool;
7use std::sync::atomic::Ordering;
8use std::sync::Arc;
9
10use base::debug;
11use base::error;
12use base::info;
13use bit_field::Error as BitFieldError;
14use remain::sorted;
15use sync::Mutex;
16use thiserror::Error;
17use vm_memory::GuestAddress;
18use vm_memory::GuestMemory;
19use vm_memory::GuestMemoryError;
20
21use super::interrupter::Error as InterrupterError;
22use super::interrupter::Interrupter;
23use super::transfer_ring_controller::TransferRingController;
24use super::transfer_ring_controller::TransferRingControllerError;
25use super::transfer_ring_controller::TransferRingControllers;
26use super::usb_hub;
27use super::usb_hub::UsbHub;
28use super::xhci_abi::AddressDeviceCommandTrb;
29use super::xhci_abi::ConfigureEndpointCommandTrb;
30use super::xhci_abi::DequeuePtr;
31use super::xhci_abi::DeviceContext;
32use super::xhci_abi::DeviceSlotState;
33use super::xhci_abi::EndpointContext;
34use super::xhci_abi::EndpointState;
35use super::xhci_abi::EvaluateContextCommandTrb;
36use super::xhci_abi::InputControlContext;
37use super::xhci_abi::SlotContext;
38use super::xhci_abi::StreamContextArray;
39use super::xhci_abi::TrbCompletionCode;
40use super::xhci_abi::DEVICE_CONTEXT_ENTRY_SIZE;
41use super::xhci_backend_device::XhciBackendDevice;
42use super::xhci_regs::valid_max_pstreams;
43use super::xhci_regs::valid_slot_id;
44use super::xhci_regs::MAX_PORTS;
45use super::xhci_regs::MAX_SLOTS;
46use crate::register_space::Register;
47use crate::usb::backend::error::Error as BackendProviderError;
48use crate::usb::xhci::ring_buffer_stop_cb::fallible_closure;
49use crate::usb::xhci::ring_buffer_stop_cb::RingBufferStopCallback;
50use crate::utils::EventLoop;
51use crate::utils::FailHandle;
52
/// Errors produced by device slot operations.
///
/// Variants are kept in alphabetical order, which the `#[sorted]` attribute enforces.
#[sorted]
#[derive(Error, Debug)]
pub enum Error {
    /// The backend failed to allocate streams.
    #[error("failed to allocate streams: {0}")]
    AllocStreams(BackendProviderError),
    /// The device context address is not usable.
    #[error("bad device context: {0}")]
    BadDeviceContextAddr(GuestAddress),
    /// The endpoint context at the given address is not usable.
    #[error("bad endpoint context: {0}")]
    BadEndpointContext(GuestAddress),
    /// An endpoint id outside the valid range was supplied.
    #[error("device slot get a bad endpoint id: {0}")]
    BadEndpointId(u8),
    /// The input context address is not usable.
    #[error("bad input context address: {0}")]
    BadInputContextAddr(GuestAddress),
    /// A port id outside `1..=MAX_PORTS` was supplied, or no port id has been assigned yet.
    #[error("device slot get a bad port id: {0}")]
    BadPortId(u8),
    /// A stream context carried an unsupported type value.
    #[error("bad stream context type: {0}")]
    BadStreamContextType(u8),
    /// A command completion callback returned an error.
    #[error("callback failed")]
    CallbackFailed,
    /// Creating a transfer ring controller failed.
    #[error("failed to create transfer controller: {0}")]
    CreateTransferController(TransferRingControllerError),
    /// The interrupter failed to deliver the transfer event sent when an endpoint is stopped.
    #[error("failed to send Force Stopped Event: {0}")]
    ForceStoppedEvent(InterrupterError),
    /// The backend failed to free streams.
    #[error("failed to free streams: {0}")]
    FreeStreams(BackendProviderError),
    /// The endpoint state field of an endpoint context could not be decoded.
    #[error("failed to get endpoint state: {0}")]
    GetEndpointState(BitFieldError),
    /// The hub has no port with the given id.
    #[error("failed to get port: {0}")]
    GetPort(u8),
    /// The slot state field of a slot context could not be decoded.
    #[error("failed to get slot context state: {0}")]
    GetSlotContextState(BitFieldError),
    /// No transfer ring controller exists at the given index.
    #[error("failed to get trc: {0}")]
    GetTrc(u8),
    /// Reading from guest memory failed.
    #[error("failed to read guest memory: {0}")]
    ReadGuestMemory(GuestMemoryError),
    /// The backend device failed to reset.
    #[error("failed to reset port: {0}")]
    ResetPort(BackendProviderError),
    /// A weak reference to a device slot could not be upgraded.
    #[error("failed to upgrade weak reference")]
    WeakReferenceUpgrade,
    /// Writing to guest memory failed.
    #[error("failed to write guest memory: {0}")]
    WriteGuestMemory(GuestMemoryError),
}
95
96type Result<T> = std::result::Result<T, Error>;
97
/// Number of transfer ring controller entries kept per device slot; valid indices are 0..=30.
///
/// Each index corresponds to the Device Context Index (DCI) `index + 1`.
/// See spec 4.5.1 for dci.
/// index 0: Control endpoint. Device Context Index: 1.
/// index 1: Endpoint 1 out. Device Context Index: 2
/// index 2: Endpoint 1 in. Device Context Index: 3.
/// index 3: Endpoint 2 out. Device Context Index: 4
/// ...
/// index 30: Endpoint 15 in. Device Context Index: 31
pub const TRANSFER_RING_CONTROLLERS_INDEX_END: usize = 31;
/// End of device context index (exclusive): valid device context indices are 1..=31.
pub const DCI_INDEX_END: u8 = (TRANSFER_RING_CONTROLLERS_INDEX_END + 1) as u8;
/// Device context index of first transfer endpoint.
pub const FIRST_TRANSFER_ENDPOINT_DCI: u8 = 2;
110
111fn valid_endpoint_id(endpoint_id: u8) -> bool {
112    endpoint_id < DCI_INDEX_END && endpoint_id > 0
113}
114
/// The full set of device slots together with the hub they are attached to.
///
/// Cloning is cheap: every field is reference counted or a vector of reference counted slots.
#[derive(Clone)]
pub struct DeviceSlots {
    // Handed to the fallible closures created for asynchronous stop callbacks.
    fail_handle: Arc<dyn FailHandle>,
    // Hub used to look up ports and to reset all ports.
    hub: Arc<UsbHub>,
    // One entry per slot id in 1..=MAX_SLOTS; the slot with id `n` lives at index `n - 1`.
    slots: Vec<Arc<DeviceSlot>>,
}
121
122impl DeviceSlots {
123    pub fn new(
124        fail_handle: Arc<dyn FailHandle>,
125        dcbaap: Register<u64>,
126        hub: Arc<UsbHub>,
127        interrupter: Arc<Mutex<Interrupter>>,
128        event_loop: Arc<EventLoop>,
129        mem: GuestMemory,
130    ) -> DeviceSlots {
131        let mut slots = Vec::new();
132        for slot_id in 1..=MAX_SLOTS {
133            slots.push(Arc::new(DeviceSlot::new(
134                slot_id,
135                dcbaap.clone(),
136                hub.clone(),
137                interrupter.clone(),
138                event_loop.clone(),
139                mem.clone(),
140            )));
141        }
142        DeviceSlots {
143            fail_handle,
144            hub,
145            slots,
146        }
147    }
148
149    /// Note that slot id starts from 1. Slot index start from 0.
150    pub fn slot(&self, slot_id: u8) -> Option<Arc<DeviceSlot>> {
151        if valid_slot_id(slot_id) {
152            Some(self.slots[slot_id as usize - 1].clone())
153        } else {
154            error!(
155                "trying to index a wrong slot id {}, max slot = {}",
156                slot_id, MAX_SLOTS
157            );
158            None
159        }
160    }
161
162    /// Reset the device connected to a specific port.
163    pub fn reset_port(&self, port_id: u8) -> Result<()> {
164        if let Some(port) = self.hub.get_port(port_id) {
165            if let Some(backend_device) = port.backend_device().as_mut() {
166                backend_device.lock().reset().map_err(Error::ResetPort)?;
167            }
168        }
169
170        // No device on port, so nothing to reset.
171        Ok(())
172    }
173
174    /// Stop all device slots and reset them.
175    pub fn stop_all_and_reset<C: FnMut() + 'static + Send>(&self, mut callback: C) {
176        info!("xhci: stopping all device slots and resetting host hub");
177        let slots = self.slots.clone();
178        let hub = self.hub.clone();
179        let auto_callback = RingBufferStopCallback::new(fallible_closure(
180            self.fail_handle.clone(),
181            move || -> std::result::Result<(), usb_hub::Error> {
182                for slot in &slots {
183                    slot.reset();
184                }
185                hub.reset()?;
186                callback();
187                Ok(())
188            },
189        ));
190        self.stop_all(auto_callback);
191    }
192
193    /// Stop all devices. The auto callback will be executed when all trc is stopped. It could
194    /// happen asynchronously, if there are any pending transfers.
195    pub fn stop_all(&self, auto_callback: RingBufferStopCallback) {
196        for slot in &self.slots {
197            slot.stop_all_trc(auto_callback.clone());
198        }
199    }
200
201    /// Disable a slot. This might happen asynchronously, if there is any pending transfers. The
202    /// callback will be invoked when slot is actually disabled.
203    pub fn disable_slot<
204        C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
205    >(
206        &self,
207        slot_id: u8,
208        cb: C,
209    ) -> Result<()> {
210        xhci_trace!("device slot {} is being disabled", slot_id);
211        DeviceSlot::disable(
212            self.fail_handle.clone(),
213            &self.slots[slot_id as usize - 1],
214            cb,
215        )
216    }
217
218    /// Reset a slot. This is a shortcut call for DeviceSlot::reset_slot.
219    pub fn reset_slot<
220        C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
221    >(
222        &self,
223        slot_id: u8,
224        cb: C,
225    ) -> Result<()> {
226        xhci_trace!("device slot {} is resetting", slot_id);
227        DeviceSlot::reset_slot(
228            self.fail_handle.clone(),
229            &self.slots[slot_id as usize - 1],
230            cb,
231        )
232    }
233
234    pub fn stop_endpoint<
235        C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
236    >(
237        &self,
238        slot_id: u8,
239        endpoint_id: u8,
240        cb: C,
241    ) -> Result<()> {
242        self.slots[slot_id as usize - 1].stop_endpoint(self.fail_handle.clone(), endpoint_id, cb)
243    }
244
245    pub fn reset_endpoint<
246        C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
247    >(
248        &self,
249        slot_id: u8,
250        endpoint_id: u8,
251        cb: C,
252    ) -> Result<()> {
253        self.slots[slot_id as usize - 1].reset_endpoint(self.fail_handle.clone(), endpoint_id, cb)
254    }
255}
256
257// Usb port id. Valid ids starts from 1, to MAX_PORTS.
258struct PortId(Mutex<u8>);
259
260impl PortId {
261    fn new() -> Self {
262        PortId(Mutex::new(0))
263    }
264
265    fn set(&self, value: u8) -> Result<()> {
266        if !(1..=MAX_PORTS).contains(&value) {
267            return Err(Error::BadPortId(value));
268        }
269        *self.0.lock() = value;
270        Ok(())
271    }
272
273    fn reset(&self) {
274        *self.0.lock() = 0;
275    }
276
277    fn get(&self) -> Result<u8> {
278        let val = *self.0.lock();
279        if val == 0 {
280            return Err(Error::BadPortId(val));
281        }
282        Ok(val)
283    }
284}
285
/// State of a single device slot: its identity, the shared controller resources it uses and
/// the transfer ring controllers of its endpoints.
pub struct DeviceSlot {
    // Id of this slot; slot ids start from 1.
    slot_id: u8,
    port_id: PortId, // Valid port id starts from 1, to MAX_PORTS.
    // NOTE(review): presumably the Device Context Base Address Array Pointer register; its
    // use is outside this view - confirm.
    dcbaap: Register<u64>,
    // Hub used to look up the port this slot is attached to.
    hub: Arc<UsbHub>,
    // Interrupter used to send transfer events and handed to new transfer ring controllers.
    interrupter: Arc<Mutex<Interrupter>>,
    // Event loop handed to new transfer ring controllers.
    event_loop: Arc<EventLoop>,
    // Guest memory from which contexts are read and to which they are written.
    mem: GuestMemory,
    // True once `enable` has been called; starts out false.
    enabled: AtomicBool,
    // One entry per endpoint index (device context index - 1); `None` until the endpoint's
    // controller(s) are created.
    transfer_ring_controllers: Mutex<Vec<Option<TransferRingControllers>>>,
}
297
298impl DeviceSlot {
299    /// Create a new device slot.
300    pub fn new(
301        slot_id: u8,
302        dcbaap: Register<u64>,
303        hub: Arc<UsbHub>,
304        interrupter: Arc<Mutex<Interrupter>>,
305        event_loop: Arc<EventLoop>,
306        mem: GuestMemory,
307    ) -> Self {
308        let mut transfer_ring_controllers = Vec::new();
309        transfer_ring_controllers.resize_with(TRANSFER_RING_CONTROLLERS_INDEX_END, || None);
310        DeviceSlot {
311            slot_id,
312            port_id: PortId::new(),
313            dcbaap,
314            hub,
315            interrupter,
316            event_loop,
317            mem,
318            enabled: AtomicBool::new(false),
319            transfer_ring_controllers: Mutex::new(transfer_ring_controllers),
320        }
321    }
322
323    fn get_trc(&self, i: usize, stream_id: u16) -> Option<Arc<TransferRingController>> {
324        let trcs = self.transfer_ring_controllers.lock();
325        match &trcs[i] {
326            Some(TransferRingControllers::Endpoint(trc)) => Some(trc.clone()),
327            Some(TransferRingControllers::Stream(trcs)) => {
328                let stream_id = stream_id as usize;
329                if stream_id > 0 && stream_id <= trcs.len() {
330                    Some(trcs[stream_id - 1].clone())
331                } else {
332                    None
333                }
334            }
335            None => None,
336        }
337    }
338
339    fn get_trcs(&self, i: usize) -> Option<TransferRingControllers> {
340        let trcs = self.transfer_ring_controllers.lock();
341        trcs[i].clone()
342    }
343
344    fn set_trcs(&self, i: usize, trc: Option<TransferRingControllers>) {
345        let mut trcs = self.transfer_ring_controllers.lock();
346        trcs[i] = trc;
347    }
348
349    fn trc_len(&self) -> usize {
350        self.transfer_ring_controllers.lock().len()
351    }
352
353    /// The arguments are identical to the fields in each doorbell register. The
354    /// target value:
355    /// 1: Reserved
356    /// 2: Control endpoint
357    /// 3: Endpoint 1 out
358    /// 4: Endpoint 1 in
359    /// 5: Endpoint 2 out
360    /// ...
361    /// 32: Endpoint 15 in
362    ///
363    /// Steam ID will be useful when host controller support streams.
364    /// The stream ID must be zero for endpoints that do not have streams
365    /// configured.
366    /// This function will return false if it fails to trigger transfer ring start.
367    pub fn ring_doorbell(&self, target: u8, stream_id: u16) -> Result<bool> {
368        if !valid_endpoint_id(target) {
369            error!(
370                "device slot {}: Invalid target written to doorbell register. target: {}",
371                self.slot_id, target
372            );
373            return Ok(false);
374        }
375        xhci_trace!(
376            "device slot {}: ring_doorbell target = {} stream_id = {}",
377            self.slot_id,
378            target,
379            stream_id
380        );
381        // See DCI in spec.
382        let endpoint_index = (target - 1) as usize;
383        let transfer_ring_controller = match self.get_trc(endpoint_index, stream_id) {
384            Some(tr) => tr,
385            None => {
386                error!("Device endpoint is not inited");
387                return Ok(false);
388            }
389        };
390        let mut context = self.get_device_context()?;
391        let endpoint_state = context.endpoint_context[endpoint_index]
392            .get_endpoint_state()
393            .map_err(Error::GetEndpointState)?;
394        if endpoint_state == EndpointState::Running || endpoint_state == EndpointState::Stopped {
395            if endpoint_state == EndpointState::Stopped {
396                context.endpoint_context[endpoint_index].set_endpoint_state(EndpointState::Running);
397                self.set_device_context(context)?;
398            }
399            // endpoint is started, start transfer ring
400            transfer_ring_controller.start();
401        } else {
402            error!("doorbell rung when endpoint state is {:?}", endpoint_state);
403        }
404        Ok(true)
405    }
406
407    /// Enable the slot. This function returns false if it's already enabled.
408    pub fn enable(&self) -> bool {
409        let was_already_enabled = self.enabled.swap(true, Ordering::SeqCst);
410        if was_already_enabled {
411            error!("device slot is already enabled");
412        }
413        !was_already_enabled
414    }
415
416    /// Disable this device slot. If the slot is not enabled, callback will be invoked immediately
417    /// with error. Otherwise, callback will be invoked when all trc is stopped.
418    pub fn disable<C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send>(
419        fail_handle: Arc<dyn FailHandle>,
420        slot: &Arc<DeviceSlot>,
421        mut callback: C,
422    ) -> Result<()> {
423        if slot.enabled.load(Ordering::SeqCst) {
424            let slot_weak = Arc::downgrade(slot);
425            let auto_callback =
426                RingBufferStopCallback::new(fallible_closure(fail_handle, move || {
427                    // Slot should still be alive when the callback is invoked. If it's not, there
428                    // must be a bug somewhere.
429                    let slot = slot_weak.upgrade().ok_or(Error::WeakReferenceUpgrade)?;
430                    let mut device_context = slot.get_device_context()?;
431                    device_context
432                        .slot_context
433                        .set_slot_state(DeviceSlotState::DisabledOrEnabled);
434                    slot.set_device_context(device_context)?;
435                    slot.reset();
436                    debug!(
437                        "device slot {}: all trc disabled, sending trb",
438                        slot.slot_id
439                    );
440                    callback(TrbCompletionCode::Success).map_err(|_| Error::CallbackFailed)
441                }));
442            slot.stop_all_trc(auto_callback);
443            Ok(())
444        } else {
445            callback(TrbCompletionCode::SlotNotEnabledError).map_err(|_| Error::CallbackFailed)
446        }
447    }
448
449    // Assigns the device address and initializes slot and endpoint 0 context.
450    pub fn set_address(
451        self: &Arc<Self>,
452        trb: &AddressDeviceCommandTrb,
453    ) -> Result<TrbCompletionCode> {
454        if !self.enabled.load(Ordering::SeqCst) {
455            error!(
456                "trying to set address to a disabled device slot {}",
457                self.slot_id
458            );
459            return Ok(TrbCompletionCode::SlotNotEnabledError);
460        }
461        let device_context = self.get_device_context()?;
462        let state = device_context
463            .slot_context
464            .get_slot_state()
465            .map_err(Error::GetSlotContextState)?;
466        match state {
467            DeviceSlotState::DisabledOrEnabled => {}
468            DeviceSlotState::Default if !trb.get_block_set_address_request() => {}
469            _ => {
470                error!("slot {} has unexpected slot state", self.slot_id);
471                return Ok(TrbCompletionCode::ContextStateError);
472            }
473        }
474
475        // Copy all fields of the slot context and endpoint 0 context from the input context
476        // to the output context.
477        let input_context_ptr = GuestAddress(trb.get_input_context_pointer());
478        // Copy slot context.
479        self.copy_context(input_context_ptr, 0)?;
480        // Copy control endpoint context.
481        self.copy_context(input_context_ptr, 1)?;
482
483        // Read back device context.
484        let mut device_context = self.get_device_context()?;
485        let port_id = device_context.slot_context.get_root_hub_port_number();
486        self.port_id.set(port_id)?;
487        debug!(
488            "port id {} is assigned to slot id {}",
489            port_id, self.slot_id
490        );
491
492        // Initialize the control endpoint. Endpoint id = 1.
493        let trc = TransferRingController::new(
494            self.mem.clone(),
495            self.hub.get_port(port_id).ok_or(Error::GetPort(port_id))?,
496            self.event_loop.clone(),
497            self.interrupter.clone(),
498            self.slot_id,
499            1,
500            Arc::downgrade(self),
501            None,
502        )
503        .map_err(Error::CreateTransferController)?;
504        self.set_trcs(0, Some(TransferRingControllers::Endpoint(trc)));
505
506        // Assign slot ID as device address if block_set_address_request is not set.
507        if trb.get_block_set_address_request() {
508            device_context
509                .slot_context
510                .set_slot_state(DeviceSlotState::Default);
511        } else {
512            let port = self.hub.get_port(port_id).ok_or(Error::GetPort(port_id))?;
513            match port.backend_device().as_mut() {
514                Some(backend) => {
515                    backend.lock().set_address(self.slot_id as u32);
516                }
517                None => {
518                    return Ok(TrbCompletionCode::TransactionError);
519                }
520            }
521
522            device_context
523                .slot_context
524                .set_usb_device_address(self.slot_id);
525            device_context
526                .slot_context
527                .set_slot_state(DeviceSlotState::Addressed);
528        }
529
530        // TODO(jkwang) trc should always exists. Fix this.
531        self.get_trc(0, 0)
532            .ok_or(Error::GetTrc(0))?
533            .set_dequeue_pointer(
534                device_context.endpoint_context[0]
535                    .get_tr_dequeue_pointer()
536                    .get_gpa(),
537            );
538
539        self.get_trc(0, 0)
540            .ok_or(Error::GetTrc(0))?
541            .set_consumer_cycle_state(device_context.endpoint_context[0].get_dequeue_cycle_state());
542
543        // Setting endpoint 0 to running
544        device_context.endpoint_context[0].set_endpoint_state(EndpointState::Running);
545        self.set_device_context(device_context)?;
546        Ok(TrbCompletionCode::Success)
547    }
548
549    // Adds or drops multiple endpoints in the device slot.
550    pub fn configure_endpoint(
551        self: &Arc<Self>,
552        trb: &ConfigureEndpointCommandTrb,
553    ) -> Result<TrbCompletionCode> {
554        let input_control_context = if trb.get_deconfigure() {
555            // From section 4.6.6 of the xHCI spec:
556            // Setting the deconfigure (DC) flag to '1' in the Configure Endpoint Command
557            // TRB is equivalent to setting Input Context Drop Context flags 2-31 to '1'
558            // and Add Context 2-31 flags to '0'.
559            let mut c = InputControlContext::new();
560            c.set_add_context_flags(0);
561            c.set_drop_context_flags(0xfffffffc);
562            c
563        } else {
564            self.mem
565                .read_obj_from_addr(GuestAddress(trb.get_input_context_pointer()))
566                .map_err(Error::ReadGuestMemory)?
567        };
568
569        for device_context_index in 1..DCI_INDEX_END {
570            if input_control_context.drop_context_flag(device_context_index) {
571                self.drop_one_endpoint(device_context_index)?;
572            }
573            if input_control_context.add_context_flag(device_context_index) {
574                self.copy_context(
575                    GuestAddress(trb.get_input_context_pointer()),
576                    device_context_index,
577                )?;
578                self.add_one_endpoint(device_context_index)?;
579            }
580        }
581
582        if trb.get_deconfigure() {
583            self.set_state(DeviceSlotState::Addressed)?;
584        } else {
585            self.set_state(DeviceSlotState::Configured)?;
586        }
587        Ok(TrbCompletionCode::Success)
588    }
589
590    // Evaluates the device context by reading new values for certain fields of
591    // the slot context and/or control endpoint context.
592    pub fn evaluate_context(&self, trb: &EvaluateContextCommandTrb) -> Result<TrbCompletionCode> {
593        if !self.enabled.load(Ordering::SeqCst) {
594            return Ok(TrbCompletionCode::SlotNotEnabledError);
595        }
596        // TODO(jkwang) verify this
597        // The spec has multiple contradictions about validating context parameters in sections
598        // 4.6.7, 6.2.3.3. To keep things as simple as possible we do no further validation here.
599        let input_control_context: InputControlContext = self
600            .mem
601            .read_obj_from_addr(GuestAddress(trb.get_input_context_pointer()))
602            .map_err(Error::ReadGuestMemory)?;
603
604        let mut device_context = self.get_device_context()?;
605        if input_control_context.add_context_flag(0) {
606            let input_slot_context: SlotContext = self
607                .mem
608                .read_obj_from_addr(GuestAddress(
609                    trb.get_input_context_pointer() + DEVICE_CONTEXT_ENTRY_SIZE as u64,
610                ))
611                .map_err(Error::ReadGuestMemory)?;
612            device_context
613                .slot_context
614                .set_interrupter_target(input_slot_context.get_interrupter_target());
615
616            device_context
617                .slot_context
618                .set_max_exit_latency(input_slot_context.get_max_exit_latency());
619        }
620
621        // From 6.2.3.3: "Endpoint Contexts 2 throught 31 shall not be evaluated by the Evaluate
622        // Context Command".
623        if input_control_context.add_context_flag(1) {
624            let ep0_context: EndpointContext = self
625                .mem
626                .read_obj_from_addr(GuestAddress(
627                    trb.get_input_context_pointer() + 2 * DEVICE_CONTEXT_ENTRY_SIZE as u64,
628                ))
629                .map_err(Error::ReadGuestMemory)?;
630            device_context.endpoint_context[0]
631                .set_max_packet_size(ep0_context.get_max_packet_size());
632        }
633        self.set_device_context(device_context)?;
634        Ok(TrbCompletionCode::Success)
635    }
636
637    /// Reset the device slot to default state and deconfigures all but the
638    /// control endpoint.
639    pub fn reset_slot<
640        C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
641    >(
642        fail_handle: Arc<dyn FailHandle>,
643        slot: &Arc<DeviceSlot>,
644        mut callback: C,
645    ) -> Result<()> {
646        let weak_s = Arc::downgrade(slot);
647        let auto_callback =
648            RingBufferStopCallback::new(fallible_closure(fail_handle, move || -> Result<()> {
649                let s = weak_s.upgrade().ok_or(Error::WeakReferenceUpgrade)?;
650                for i in FIRST_TRANSFER_ENDPOINT_DCI..DCI_INDEX_END {
651                    s.drop_one_endpoint(i)?;
652                }
653                let mut ctx = s.get_device_context()?;
654                ctx.slot_context.set_slot_state(DeviceSlotState::Default);
655                ctx.slot_context.set_context_entries(1);
656                ctx.slot_context.set_root_hub_port_number(0);
657                s.set_device_context(ctx)?;
658                callback(TrbCompletionCode::Success).map_err(|_| Error::CallbackFailed)?;
659                Ok(())
660            }));
661        slot.stop_all_trc(auto_callback);
662        Ok(())
663    }
664
665    /// Stop all transfer ring controllers.
666    pub fn stop_all_trc(&self, auto_callback: RingBufferStopCallback) {
667        for i in 0..self.trc_len() {
668            if let Some(trcs) = self.get_trcs(i) {
669                match trcs {
670                    TransferRingControllers::Endpoint(trc) => {
671                        trc.stop(auto_callback.clone());
672                    }
673                    TransferRingControllers::Stream(trcs) => {
674                        for trc in trcs {
675                            trc.stop(auto_callback.clone());
676                        }
677                    }
678                }
679            }
680        }
681    }
682
683    fn force_stoppped_event(&self, slot_id: u8, endpoint_id: u8, dequeue_ptr: u64) -> Result<()> {
684        self.interrupter
685            .lock()
686            .send_transfer_event_trb(
687                TrbCompletionCode::StoppedLengthInvalid,
688                dequeue_ptr,
689                0,
690                false,
691                slot_id,
692                endpoint_id,
693            )
694            .map_err(Error::ForceStoppedEvent)
695    }
696
697    /// Stop an endpoint.
698    pub fn stop_endpoint<
699        C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
700    >(
701        &self,
702        fail_handle: Arc<dyn FailHandle>,
703        endpoint_id: u8,
704        mut cb: C,
705    ) -> Result<()> {
706        if !valid_endpoint_id(endpoint_id) {
707            error!("trb indexing wrong endpoint id");
708            return cb(TrbCompletionCode::TrbError).map_err(|_| Error::CallbackFailed);
709        }
710        let index = endpoint_id - 1;
711        let mut device_context = self.get_device_context()?;
712        let endpoint_context = &mut device_context.endpoint_context[index as usize];
713        match self.get_trcs(index as usize) {
714            Some(TransferRingControllers::Endpoint(trc)) => {
715                let auto_cb = RingBufferStopCallback::new(fallible_closure(
716                    fail_handle,
717                    move || -> Result<()> {
718                        cb(TrbCompletionCode::Success).map_err(|_| Error::CallbackFailed)
719                    },
720                ));
721                trc.stop(auto_cb);
722                let dequeue_pointer = trc.get_dequeue_pointer();
723                let dcs = trc.get_consumer_cycle_state();
724                endpoint_context.set_tr_dequeue_pointer(DequeuePtr::new(dequeue_pointer));
725                endpoint_context.set_dequeue_cycle_state(dcs);
726                self.force_stoppped_event(self.slot_id, endpoint_id, dequeue_pointer.offset())?;
727            }
728            Some(TransferRingControllers::Stream(trcs)) => {
729                let stream_context_array_addr = endpoint_context.get_tr_dequeue_pointer().get_gpa();
730                let mut stream_context_array: StreamContextArray = self
731                    .mem
732                    .read_obj_from_addr(stream_context_array_addr)
733                    .map_err(Error::ReadGuestMemory)?;
734                let auto_cb = RingBufferStopCallback::new(fallible_closure(
735                    fail_handle,
736                    move || -> Result<()> {
737                        cb(TrbCompletionCode::Success).map_err(|_| Error::CallbackFailed)
738                    },
739                ));
740                for (i, trc) in trcs.iter().enumerate() {
741                    let dequeue_pointer = trc.get_dequeue_pointer();
742                    let dcs = trc.get_consumer_cycle_state();
743                    trc.stop(auto_cb.clone());
744                    stream_context_array.stream_contexts[i + 1]
745                        .set_tr_dequeue_pointer(DequeuePtr::new(dequeue_pointer));
746                    stream_context_array.stream_contexts[i + 1].set_dequeue_cycle_state(dcs);
747                }
748                self.mem
749                    .write_obj_at_addr(stream_context_array, stream_context_array_addr)
750                    .map_err(Error::WriteGuestMemory)?;
751            }
752            None => {
753                error!("endpoint at index {} is not started", index);
754                cb(TrbCompletionCode::ContextStateError).map_err(|_| Error::CallbackFailed)?;
755            }
756        }
757        endpoint_context.set_endpoint_state(EndpointState::Stopped);
758        self.set_device_context(device_context)?;
759        Ok(())
760    }
761
762    /// Reset an endpoint.
763    pub fn reset_endpoint<
764        C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
765    >(
766        &self,
767        fail_handle: Arc<dyn FailHandle>,
768        endpoint_id: u8,
769        mut cb: C,
770    ) -> Result<()> {
771        if !valid_endpoint_id(endpoint_id) {
772            error!("trb indexing wrong endpoint id");
773            return cb(TrbCompletionCode::TrbError).map_err(|_| Error::CallbackFailed);
774        }
775        let index = endpoint_id - 1;
776        let mut device_context = self.get_device_context()?;
777        let endpoint_context = &mut device_context.endpoint_context[index as usize];
778        if endpoint_context
779            .get_endpoint_state()
780            .map_err(Error::GetEndpointState)?
781            != EndpointState::Halted
782        {
783            error!("endpoint at index {} is not halted", index);
784            return cb(TrbCompletionCode::ContextStateError).map_err(|_| Error::CallbackFailed);
785        }
786        match self.get_trcs(index as usize) {
787            Some(TransferRingControllers::Endpoint(trc)) => {
788                let auto_cb = RingBufferStopCallback::new(fallible_closure(
789                    fail_handle,
790                    move || -> Result<()> {
791                        cb(TrbCompletionCode::Success).map_err(|_| Error::CallbackFailed)
792                    },
793                ));
794                trc.stop(auto_cb);
795                let dequeue_pointer = trc.get_dequeue_pointer();
796                let dcs = trc.get_consumer_cycle_state();
797                endpoint_context.set_tr_dequeue_pointer(DequeuePtr::new(dequeue_pointer));
798                endpoint_context.set_dequeue_cycle_state(dcs);
799            }
800            Some(TransferRingControllers::Stream(trcs)) => {
801                let stream_context_array_addr = endpoint_context.get_tr_dequeue_pointer().get_gpa();
802                let mut stream_context_array: StreamContextArray = self
803                    .mem
804                    .read_obj_from_addr(stream_context_array_addr)
805                    .map_err(Error::ReadGuestMemory)?;
806                let auto_cb = RingBufferStopCallback::new(fallible_closure(
807                    fail_handle,
808                    move || -> Result<()> {
809                        cb(TrbCompletionCode::Success).map_err(|_| Error::CallbackFailed)
810                    },
811                ));
812                for (i, trc) in trcs.iter().enumerate() {
813                    let dequeue_pointer = trc.get_dequeue_pointer();
814                    let dcs = trc.get_consumer_cycle_state();
815                    trc.stop(auto_cb.clone());
816                    stream_context_array.stream_contexts[i + 1]
817                        .set_tr_dequeue_pointer(DequeuePtr::new(dequeue_pointer));
818                    stream_context_array.stream_contexts[i + 1].set_dequeue_cycle_state(dcs);
819                }
820                self.mem
821                    .write_obj_at_addr(stream_context_array, stream_context_array_addr)
822                    .map_err(Error::WriteGuestMemory)?;
823            }
824            None => {
825                error!("endpoint at index {} is not started", index);
826                cb(TrbCompletionCode::ContextStateError).map_err(|_| Error::CallbackFailed)?;
827            }
828        }
829        endpoint_context.set_endpoint_state(EndpointState::Stopped);
830        self.set_device_context(device_context)?;
831        Ok(())
832    }
833
834    /// Set transfer ring dequeue pointer.
835    pub fn set_tr_dequeue_ptr(
836        &self,
837        endpoint_id: u8,
838        stream_id: u16,
839        ptr: u64,
840    ) -> Result<TrbCompletionCode> {
841        if !valid_endpoint_id(endpoint_id) {
842            error!("trb indexing wrong endpoint id");
843            return Ok(TrbCompletionCode::TrbError);
844        }
845        let index = (endpoint_id - 1) as usize;
846        match self.get_trc(index, stream_id) {
847            Some(trc) => {
848                trc.set_dequeue_pointer(GuestAddress(ptr));
849                let mut ctx = self.get_device_context()?;
850                ctx.endpoint_context[index]
851                    .set_tr_dequeue_pointer(DequeuePtr::new(GuestAddress(ptr)));
852                self.set_device_context(ctx)?;
853                Ok(TrbCompletionCode::Success)
854            }
855            None => {
856                error!("set tr dequeue ptr failed due to no trc started");
857                Ok(TrbCompletionCode::ContextStateError)
858            }
859        }
860    }
861
862    // Reset and reset_slot are different.
863    // Reset_slot handles command ring `reset slot` command. It will reset the slot state.
864    // Reset handles xhci reset. It will destroy everything.
865    fn reset(&self) {
866        for i in 0..self.trc_len() {
867            self.set_trcs(i, None);
868        }
869        debug!("resetting device slot {}!", self.slot_id);
870        self.enabled.store(false, Ordering::SeqCst);
871        self.port_id.reset();
872    }
873
874    fn create_stream_trcs(
875        self: &Arc<Self>,
876        stream_context_array_addr: GuestAddress,
877        max_pstreams: u8,
878        device_context_index: u8,
879    ) -> Result<TransferRingControllers> {
880        let pstreams = 1usize << (max_pstreams + 1);
881        let stream_context_array: StreamContextArray = self
882            .mem
883            .read_obj_from_addr(stream_context_array_addr)
884            .map_err(Error::ReadGuestMemory)?;
885        let mut trcs = Vec::new();
886
887        // Stream ID 0 is reserved (xHCI spec Section 4.12.2)
888        for i in 1..pstreams {
889            let stream_context = &stream_context_array.stream_contexts[i];
890            let context_type = stream_context.get_stream_context_type();
891            if context_type != 1 {
892                // We only support Linear Stream Context Array for now
893                return Err(Error::BadStreamContextType(context_type));
894            }
895            let trc = TransferRingController::new(
896                self.mem.clone(),
897                self.hub
898                    .get_port(self.port_id.get()?)
899                    .ok_or(Error::GetPort(self.port_id.get()?))?,
900                self.event_loop.clone(),
901                self.interrupter.clone(),
902                self.slot_id,
903                device_context_index,
904                Arc::downgrade(self),
905                Some(i as u16),
906            )
907            .map_err(Error::CreateTransferController)?;
908            trc.set_dequeue_pointer(stream_context.get_tr_dequeue_pointer().get_gpa());
909            trc.set_consumer_cycle_state(stream_context.get_dequeue_cycle_state());
910            trcs.push(trc);
911        }
912        Ok(TransferRingControllers::Stream(trcs))
913    }
914
915    fn add_one_endpoint(self: &Arc<Self>, device_context_index: u8) -> Result<()> {
916        xhci_trace!(
917            "adding one endpoint, device context index {}",
918            device_context_index
919        );
920        let mut device_context = self.get_device_context()?;
921        let transfer_ring_index = (device_context_index - 1) as usize;
922        let endpoint_context = &mut device_context.endpoint_context[transfer_ring_index];
923        let max_pstreams = endpoint_context.get_max_primary_streams();
924        let tr_dequeue_pointer = endpoint_context.get_tr_dequeue_pointer().get_gpa();
925        let endpoint_context_addr = self
926            .get_device_context_addr()?
927            .unchecked_add(size_of::<SlotContext>() as u64)
928            .unchecked_add(size_of::<EndpointContext>() as u64 * transfer_ring_index as u64);
929        let trcs = if max_pstreams > 0 {
930            if !valid_max_pstreams(max_pstreams) {
931                return Err(Error::BadEndpointContext(endpoint_context_addr));
932            }
933            let endpoint_type = endpoint_context.get_endpoint_type();
934            if endpoint_type != 2 && endpoint_type != 6 {
935                // Stream is only supported on a bulk endpoint
936                return Err(Error::BadEndpointId(transfer_ring_index as u8));
937            }
938            if endpoint_context.get_linear_stream_array() != 1 {
939                // We only support Linear Stream Context Array for now
940                return Err(Error::BadEndpointContext(endpoint_context_addr));
941            }
942
943            let trcs =
944                self.create_stream_trcs(tr_dequeue_pointer, max_pstreams, device_context_index)?;
945
946            if let Some(port) = self.hub.get_port(self.port_id.get()?) {
947                if let Some(backend_device) = port.backend_device().as_mut() {
948                    let mut endpoint_address = device_context_index / 2;
949                    if device_context_index % 2 == 1 {
950                        endpoint_address |= 1u8 << 7;
951                    }
952                    let streams = 1 << (max_pstreams + 1);
953                    // Subtracting 1 is to ignore Stream ID 0
954                    backend_device
955                        .lock()
956                        .alloc_streams(endpoint_address, streams - 1)
957                        .map_err(Error::AllocStreams)?;
958                }
959            }
960            trcs
961        } else {
962            let trc = TransferRingController::new(
963                self.mem.clone(),
964                self.hub
965                    .get_port(self.port_id.get()?)
966                    .ok_or(Error::GetPort(self.port_id.get()?))?,
967                self.event_loop.clone(),
968                self.interrupter.clone(),
969                self.slot_id,
970                device_context_index,
971                Arc::downgrade(self),
972                None,
973            )
974            .map_err(Error::CreateTransferController)?;
975            trc.set_dequeue_pointer(tr_dequeue_pointer);
976            trc.set_consumer_cycle_state(endpoint_context.get_dequeue_cycle_state());
977            TransferRingControllers::Endpoint(trc)
978        };
979        self.set_trcs(transfer_ring_index, Some(trcs));
980        endpoint_context.set_endpoint_state(EndpointState::Running);
981        self.set_device_context(device_context)
982    }
983
984    fn drop_one_endpoint(self: &Arc<Self>, device_context_index: u8) -> Result<()> {
985        let endpoint_index = (device_context_index - 1) as usize;
986        let mut device_context = self.get_device_context()?;
987        let endpoint_context = &mut device_context.endpoint_context[endpoint_index];
988        if endpoint_context.get_max_primary_streams() > 0 {
989            if let Some(port) = self.hub.get_port(self.port_id.get()?) {
990                if let Some(backend_device) = port.backend_device().as_mut() {
991                    let mut endpoint_address = device_context_index / 2;
992                    if device_context_index % 2 == 1 {
993                        endpoint_address |= 1u8 << 7;
994                    }
995                    backend_device
996                        .lock()
997                        .free_streams(endpoint_address)
998                        .map_err(Error::FreeStreams)?;
999                }
1000            }
1001        }
1002        self.set_trcs(endpoint_index, None);
1003        endpoint_context.set_endpoint_state(EndpointState::Disabled);
1004        self.set_device_context(device_context)
1005    }
1006
1007    fn get_device_context(&self) -> Result<DeviceContext> {
1008        let ctx = self
1009            .mem
1010            .read_obj_from_addr(self.get_device_context_addr()?)
1011            .map_err(Error::ReadGuestMemory)?;
1012        Ok(ctx)
1013    }
1014
1015    fn set_device_context(&self, device_context: DeviceContext) -> Result<()> {
1016        self.mem
1017            .write_obj_at_addr(device_context, self.get_device_context_addr()?)
1018            .map_err(Error::WriteGuestMemory)
1019    }
1020
1021    fn copy_context(
1022        &self,
1023        input_context_ptr: GuestAddress,
1024        device_context_index: u8,
1025    ) -> Result<()> {
1026        // Note that it could be slot context or device context. They have the same size. Won't
1027        // make a difference here.
1028        let ctx: EndpointContext = self
1029            .mem
1030            .read_obj_from_addr(
1031                input_context_ptr
1032                    .checked_add(
1033                        (device_context_index as u64 + 1) * DEVICE_CONTEXT_ENTRY_SIZE as u64,
1034                    )
1035                    .ok_or(Error::BadInputContextAddr(input_context_ptr))?,
1036            )
1037            .map_err(Error::ReadGuestMemory)?;
1038        xhci_trace!("copy_context {:?}", ctx);
1039        let device_context_ptr = self.get_device_context_addr()?;
1040        self.mem
1041            .write_obj_at_addr(
1042                ctx,
1043                device_context_ptr
1044                    .checked_add(device_context_index as u64 * DEVICE_CONTEXT_ENTRY_SIZE as u64)
1045                    .ok_or(Error::BadDeviceContextAddr(device_context_ptr))?,
1046            )
1047            .map_err(Error::WriteGuestMemory)
1048    }
1049
1050    fn get_device_context_addr(&self) -> Result<GuestAddress> {
1051        let addr: u64 = self
1052            .mem
1053            .read_obj_from_addr(GuestAddress(
1054                self.dcbaap.get_value() + size_of::<u64>() as u64 * self.slot_id as u64,
1055            ))
1056            .map_err(Error::ReadGuestMemory)?;
1057        Ok(GuestAddress(addr))
1058    }
1059
1060    fn set_state(&self, state: DeviceSlotState) -> Result<()> {
1061        let mut ctx = self.get_device_context()?;
1062        ctx.slot_context.set_slot_state(state);
1063        self.set_device_context(ctx)
1064    }
1065
1066    pub fn halt_endpoint(&self, endpoint_id: u8) -> Result<()> {
1067        if !valid_endpoint_id(endpoint_id) {
1068            return Err(Error::BadEndpointId(endpoint_id));
1069        }
1070        let index = endpoint_id - 1;
1071        let mut device_context = self.get_device_context()?;
1072        let endpoint_context = &mut device_context.endpoint_context[index as usize];
1073        match self.get_trcs(index as usize) {
1074            Some(trcs) => match trcs {
1075                TransferRingControllers::Endpoint(trc) => {
1076                    endpoint_context
1077                        .set_tr_dequeue_pointer(DequeuePtr::new(trc.get_dequeue_pointer()));
1078                    endpoint_context.set_dequeue_cycle_state(trc.get_consumer_cycle_state());
1079                }
1080                TransferRingControllers::Stream(trcs) => {
1081                    let stream_context_array_addr =
1082                        endpoint_context.get_tr_dequeue_pointer().get_gpa();
1083                    let mut stream_context_array: StreamContextArray = self
1084                        .mem
1085                        .read_obj_from_addr(stream_context_array_addr)
1086                        .map_err(Error::ReadGuestMemory)?;
1087                    for (i, trc) in trcs.iter().enumerate() {
1088                        stream_context_array.stream_contexts[i + 1]
1089                            .set_tr_dequeue_pointer(DequeuePtr::new(trc.get_dequeue_pointer()));
1090                        stream_context_array.stream_contexts[i + 1]
1091                            .set_dequeue_cycle_state(trc.get_consumer_cycle_state());
1092                    }
1093                    self.mem
1094                        .write_obj_at_addr(stream_context_array, stream_context_array_addr)
1095                        .map_err(Error::WriteGuestMemory)?;
1096                }
1097            },
1098            None => {
1099                error!("trc for endpoint {} not found", endpoint_id);
1100                return Err(Error::BadEndpointId(endpoint_id));
1101            }
1102        }
1103        endpoint_context.set_endpoint_state(EndpointState::Halted);
1104        self.set_device_context(device_context)?;
1105        Ok(())
1106    }
1107
1108    pub fn get_max_esit_payload(&self, endpoint_id: u8) -> Result<u32> {
1109        let index = endpoint_id - 1;
1110        let device_context = self.get_device_context()?;
1111        let endpoint_context = device_context.endpoint_context[index as usize];
1112
1113        let lo = endpoint_context.get_max_esit_payload_lo() as u32;
1114        let hi = endpoint_context.get_max_esit_payload_hi() as u32;
1115        Ok(hi << 16 | lo)
1116    }
1117}