devices/usb/xhci/device_slot.rs

// Copyright 2019 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

use std::mem::size_of;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::sync::Arc;

use base::debug;
use base::error;
use base::info;
use bit_field::Error as BitFieldError;
use remain::sorted;
use sync::Mutex;
use thiserror::Error;
use vm_memory::GuestAddress;
use vm_memory::GuestMemory;
use vm_memory::GuestMemoryError;

use super::interrupter::Interrupter;
use super::transfer_ring_controller::TransferRingController;
use super::transfer_ring_controller::TransferRingControllerError;
use super::transfer_ring_controller::TransferRingControllers;
use super::usb_hub;
use super::usb_hub::UsbHub;
use super::xhci_abi::AddressDeviceCommandTrb;
use super::xhci_abi::ConfigureEndpointCommandTrb;
use super::xhci_abi::DequeuePtr;
use super::xhci_abi::DeviceContext;
use super::xhci_abi::DeviceSlotState;
use super::xhci_abi::EndpointContext;
use super::xhci_abi::EndpointState;
use super::xhci_abi::EvaluateContextCommandTrb;
use super::xhci_abi::InputControlContext;
use super::xhci_abi::SlotContext;
use super::xhci_abi::StreamContextArray;
use super::xhci_abi::TrbCompletionCode;
use super::xhci_abi::DEVICE_CONTEXT_ENTRY_SIZE;
use super::xhci_backend_device::XhciBackendDevice;
use super::xhci_regs::valid_max_pstreams;
use super::xhci_regs::valid_slot_id;
use super::xhci_regs::MAX_PORTS;
use super::xhci_regs::MAX_SLOTS;
use crate::register_space::Register;
use crate::usb::backend::error::Error as BackendProviderError;
use crate::usb::xhci::ring_buffer_stop_cb::fallible_closure;
use crate::usb::xhci::ring_buffer_stop_cb::RingBufferStopCallback;
use crate::utils::EventLoop;
use crate::utils::FailHandle;

#[sorted]
#[derive(Error, Debug)]
pub enum Error {
    #[error("failed to allocate streams: {0}")]
    AllocStreams(BackendProviderError),
    #[error("bad device context address: {0}")]
    BadDeviceContextAddr(GuestAddress),
    #[error("bad endpoint context: {0}")]
    BadEndpointContext(GuestAddress),
    #[error("device slot got a bad endpoint id: {0}")]
    BadEndpointId(u8),
    #[error("bad input context address: {0}")]
    BadInputContextAddr(GuestAddress),
    #[error("device slot got a bad port id: {0}")]
    BadPortId(u8),
    #[error("bad stream context type: {0}")]
    BadStreamContextType(u8),
    #[error("callback failed")]
    CallbackFailed,
    #[error("failed to create transfer controller: {0}")]
    CreateTransferController(TransferRingControllerError),
    #[error("failed to free streams: {0}")]
    FreeStreams(BackendProviderError),
    #[error("failed to get endpoint state: {0}")]
    GetEndpointState(BitFieldError),
    #[error("failed to get port: {0}")]
    GetPort(u8),
    #[error("failed to get slot context state: {0}")]
    GetSlotContextState(BitFieldError),
    #[error("failed to get trc: {0}")]
    GetTrc(u8),
    #[error("failed to read guest memory: {0}")]
    ReadGuestMemory(GuestMemoryError),
    #[error("failed to reset port: {0}")]
    ResetPort(BackendProviderError),
    #[error("failed to upgrade weak reference")]
    WeakReferenceUpgrade,
    #[error("failed to write guest memory: {0}")]
    WriteGuestMemory(GuestMemoryError),
}

type Result<T> = std::result::Result<T, Error>;

/// See xHCI spec section 4.5.1 for the Device Context Index (DCI).
/// index 0: Control endpoint. Device Context Index: 1.
/// index 1: Endpoint 1 out. Device Context Index: 2.
/// index 2: Endpoint 1 in. Device Context Index: 3.
/// index 3: Endpoint 2 out. Device Context Index: 4.
/// ...
/// index 30: Endpoint 15 in. Device Context Index: 31.
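/// In general, for a transfer endpoint, DCI = endpoint number * 2 + direction (0 for OUT,
/// 1 for IN), and the index into the controller array is DCI - 1.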
pub const TRANSFER_RING_CONTROLLERS_INDEX_END: usize = 31;
/// End of device context index.
pub const DCI_INDEX_END: u8 = (TRANSFER_RING_CONTROLLERS_INDEX_END + 1) as u8;
/// Device context index of first transfer endpoint.
pub const FIRST_TRANSFER_ENDPOINT_DCI: u8 = 2;

fn valid_endpoint_id(endpoint_id: u8) -> bool {
    endpoint_id < DCI_INDEX_END && endpoint_id > 0
}

#[derive(Clone)]
pub struct DeviceSlots {
    fail_handle: Arc<dyn FailHandle>,
    hub: Arc<UsbHub>,
    slots: Vec<Arc<DeviceSlot>>,
}

impl DeviceSlots {
    pub fn new(
        fail_handle: Arc<dyn FailHandle>,
        dcbaap: Register<u64>,
        hub: Arc<UsbHub>,
        interrupter: Arc<Mutex<Interrupter>>,
        event_loop: Arc<EventLoop>,
        mem: GuestMemory,
    ) -> DeviceSlots {
        let mut slots = Vec::new();
        for slot_id in 1..=MAX_SLOTS {
            slots.push(Arc::new(DeviceSlot::new(
                slot_id,
                dcbaap.clone(),
                hub.clone(),
                interrupter.clone(),
                event_loop.clone(),
                mem.clone(),
            )));
        }
        DeviceSlots {
            fail_handle,
            hub,
            slots,
        }
    }

    /// Note that slot ids start from 1, while slot indices start from 0.
    pub fn slot(&self, slot_id: u8) -> Option<Arc<DeviceSlot>> {
        if valid_slot_id(slot_id) {
            Some(self.slots[slot_id as usize - 1].clone())
        } else {
            error!(
                "invalid slot id {}, max slot id = {}",
                slot_id, MAX_SLOTS
            );
            None
        }
    }

    /// Reset the device connected to a specific port.
    pub fn reset_port(&self, port_id: u8) -> Result<()> {
        if let Some(port) = self.hub.get_port(port_id) {
            if let Some(backend_device) = port.backend_device().as_mut() {
                backend_device.lock().reset().map_err(Error::ResetPort)?;
            }
        }

        // If there is no device on the port, there is nothing to reset.
        Ok(())
    }

    /// Stop all device slots and reset them.
    pub fn stop_all_and_reset<C: FnMut() + 'static + Send>(&self, mut callback: C) {
        info!("xhci: stopping all device slots and resetting host hub");
        let slots = self.slots.clone();
        let hub = self.hub.clone();
        let auto_callback = RingBufferStopCallback::new(fallible_closure(
            self.fail_handle.clone(),
            move || -> std::result::Result<(), usb_hub::Error> {
                for slot in &slots {
                    slot.reset();
                }
                hub.reset()?;
                callback();
                Ok(())
            },
        ));
        self.stop_all(auto_callback);
    }

    /// Stop all devices. The auto callback will be executed when all transfer ring controllers
    /// are stopped. This may happen asynchronously if there are pending transfers.
    pub fn stop_all(&self, auto_callback: RingBufferStopCallback) {
        for slot in &self.slots {
            slot.stop_all_trc(auto_callback.clone());
        }
    }

    /// Disable a slot. This may happen asynchronously if there are pending transfers. The
    /// callback will be invoked when the slot is actually disabled.
    pub fn disable_slot<
        C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
    >(
        &self,
        slot_id: u8,
        cb: C,
    ) -> Result<()> {
        xhci_trace!("device slot {} is being disabled", slot_id);
        DeviceSlot::disable(
            self.fail_handle.clone(),
            &self.slots[slot_id as usize - 1],
            cb,
        )
    }

    /// Reset a slot. This is a shortcut call for DeviceSlot::reset_slot.
    pub fn reset_slot<
        C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
    >(
        &self,
        slot_id: u8,
        cb: C,
    ) -> Result<()> {
        xhci_trace!("device slot {} is being reset", slot_id);
        DeviceSlot::reset_slot(
            self.fail_handle.clone(),
            &self.slots[slot_id as usize - 1],
            cb,
        )
    }

    pub fn stop_endpoint<
        C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
    >(
        &self,
        slot_id: u8,
        endpoint_id: u8,
        cb: C,
    ) -> Result<()> {
        self.slots[slot_id as usize - 1].stop_endpoint(self.fail_handle.clone(), endpoint_id, cb)
    }

    pub fn reset_endpoint<
        C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
    >(
        &self,
        slot_id: u8,
        endpoint_id: u8,
        cb: C,
    ) -> Result<()> {
        self.slots[slot_id as usize - 1].reset_endpoint(self.fail_handle.clone(), endpoint_id, cb)
    }
}

// USB port id. Valid ids range from 1 to MAX_PORTS.
struct PortId(Mutex<u8>);

impl PortId {
    fn new() -> Self {
        PortId(Mutex::new(0))
    }

    fn set(&self, value: u8) -> Result<()> {
        if !(1..=MAX_PORTS).contains(&value) {
            return Err(Error::BadPortId(value));
        }
        *self.0.lock() = value;
        Ok(())
    }

    fn reset(&self) {
        *self.0.lock() = 0;
    }

    fn get(&self) -> Result<u8> {
        let val = *self.0.lock();
        if val == 0 {
            return Err(Error::BadPortId(val));
        }
        Ok(val)
    }
}

pub struct DeviceSlot {
    slot_id: u8,
    port_id: PortId, // Valid port ids range from 1 to MAX_PORTS.
    dcbaap: Register<u64>,
    hub: Arc<UsbHub>,
    interrupter: Arc<Mutex<Interrupter>>,
    event_loop: Arc<EventLoop>,
    mem: GuestMemory,
    enabled: AtomicBool,
    transfer_ring_controllers: Mutex<Vec<Option<TransferRingControllers>>>,
}

impl DeviceSlot {
    /// Create a new device slot.
    pub fn new(
        slot_id: u8,
        dcbaap: Register<u64>,
        hub: Arc<UsbHub>,
        interrupter: Arc<Mutex<Interrupter>>,
        event_loop: Arc<EventLoop>,
        mem: GuestMemory,
    ) -> Self {
        let mut transfer_ring_controllers = Vec::new();
        transfer_ring_controllers.resize_with(TRANSFER_RING_CONTROLLERS_INDEX_END, || None);
        DeviceSlot {
            slot_id,
            port_id: PortId::new(),
            dcbaap,
            hub,
            interrupter,
            event_loop,
            mem,
            enabled: AtomicBool::new(false),
            transfer_ring_controllers: Mutex::new(transfer_ring_controllers),
        }
    }

    fn get_trc(&self, i: usize, stream_id: u16) -> Option<Arc<TransferRingController>> {
        let trcs = self.transfer_ring_controllers.lock();
        match &trcs[i] {
            Some(TransferRingControllers::Endpoint(trc)) => Some(trc.clone()),
            Some(TransferRingControllers::Stream(trcs)) => {
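                // Stream ids are 1-based; index 0 of `trcs` holds the controller for stream 1.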
                let stream_id = stream_id as usize;
                if stream_id > 0 && stream_id <= trcs.len() {
                    Some(trcs[stream_id - 1].clone())
                } else {
                    None
                }
            }
            None => None,
        }
    }

    fn get_trcs(&self, i: usize) -> Option<TransferRingControllers> {
        let trcs = self.transfer_ring_controllers.lock();
        trcs[i].clone()
    }

    fn set_trcs(&self, i: usize, trc: Option<TransferRingControllers>) {
        let mut trcs = self.transfer_ring_controllers.lock();
        trcs[i] = trc;
    }

    fn trc_len(&self) -> usize {
        self.transfer_ring_controllers.lock().len()
    }

    /// The arguments are identical to the fields in each doorbell register. The
    /// target value:
    /// 0: Reserved
    /// 1: Control endpoint
    /// 2: Endpoint 1 out
    /// 3: Endpoint 1 in
    /// 4: Endpoint 2 out
    /// ...
    /// 31: Endpoint 15 in
    ///
    /// Stream ID will be useful when the host controller supports streams.
    /// The stream ID must be zero for endpoints that do not have streams
    /// configured.
    /// This function will return false if it fails to trigger transfer ring start.
    pub fn ring_doorbell(&self, target: u8, stream_id: u16) -> Result<bool> {
        if !valid_endpoint_id(target) {
            error!(
                "device slot {}: Invalid target written to doorbell register. target: {}",
                self.slot_id, target
            );
            return Ok(false);
        }
        xhci_trace!(
            "device slot {}: ring_doorbell target = {} stream_id = {}",
            self.slot_id,
            target,
            stream_id
        );
        // The doorbell target is the DCI (see spec 4.5.1); the transfer ring controller index is
        // DCI - 1.
        let endpoint_index = (target - 1) as usize;
        let transfer_ring_controller = match self.get_trc(endpoint_index, stream_id) {
            Some(tr) => tr,
            None => {
                error!("device endpoint is not initialized");
                return Ok(false);
            }
        };
        let mut context = self.get_device_context()?;
        let endpoint_state = context.endpoint_context[endpoint_index]
            .get_endpoint_state()
            .map_err(Error::GetEndpointState)?;
        if endpoint_state == EndpointState::Running || endpoint_state == EndpointState::Stopped {
            if endpoint_state == EndpointState::Stopped {
                context.endpoint_context[endpoint_index].set_endpoint_state(EndpointState::Running);
                self.set_device_context(context)?;
            }
            // The endpoint is started; start the transfer ring.
            transfer_ring_controller.start();
        } else {
            error!("doorbell rung when endpoint state is {:?}", endpoint_state);
        }
        Ok(true)
    }

    /// Enable the slot. This function returns false if it's already enabled.
    pub fn enable(&self) -> bool {
        let was_already_enabled = self.enabled.swap(true, Ordering::SeqCst);
        if was_already_enabled {
            error!("device slot is already enabled");
        }
        !was_already_enabled
    }

    /// Disable this device slot. If the slot is not enabled, the callback will be invoked
    /// immediately with an error. Otherwise, the callback will be invoked when all transfer
    /// ring controllers are stopped.
    pub fn disable<C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send>(
        fail_handle: Arc<dyn FailHandle>,
        slot: &Arc<DeviceSlot>,
        mut callback: C,
    ) -> Result<()> {
        if slot.enabled.load(Ordering::SeqCst) {
            let slot_weak = Arc::downgrade(slot);
            let auto_callback =
                RingBufferStopCallback::new(fallible_closure(fail_handle, move || {
                    // Slot should still be alive when the callback is invoked. If it's not, there
                    // must be a bug somewhere.
                    let slot = slot_weak.upgrade().ok_or(Error::WeakReferenceUpgrade)?;
                    let mut device_context = slot.get_device_context()?;
                    device_context
                        .slot_context
                        .set_slot_state(DeviceSlotState::DisabledOrEnabled);
                    slot.set_device_context(device_context)?;
                    slot.reset();
                    debug!(
                        "device slot {}: all trcs stopped, sending trb",
                        slot.slot_id
                    );
                    callback(TrbCompletionCode::Success).map_err(|_| Error::CallbackFailed)
                }));
            slot.stop_all_trc(auto_callback);
            Ok(())
        } else {
            callback(TrbCompletionCode::SlotNotEnabledError).map_err(|_| Error::CallbackFailed)
        }
    }

    // Assigns the device address and initializes the slot and endpoint 0 contexts.
    pub fn set_address(
        self: &Arc<Self>,
        trb: &AddressDeviceCommandTrb,
    ) -> Result<TrbCompletionCode> {
        if !self.enabled.load(Ordering::SeqCst) {
            error!(
                "trying to set address to a disabled device slot {}",
                self.slot_id
            );
            return Ok(TrbCompletionCode::SlotNotEnabledError);
        }
        let device_context = self.get_device_context()?;
        let state = device_context
            .slot_context
            .get_slot_state()
            .map_err(Error::GetSlotContextState)?;
        match state {
            DeviceSlotState::DisabledOrEnabled => {}
            DeviceSlotState::Default if !trb.get_block_set_address_request() => {}
            _ => {
                error!("slot {} has unexpected slot state", self.slot_id);
                return Ok(TrbCompletionCode::ContextStateError);
            }
        }

        // Copy all fields of the slot context and endpoint 0 context from the input context
        // to the output context.
        let input_context_ptr = GuestAddress(trb.get_input_context_pointer());
        // Copy slot context.
        self.copy_context(input_context_ptr, 0)?;
        // Copy control endpoint context.
        self.copy_context(input_context_ptr, 1)?;

        // Read back device context.
        let mut device_context = self.get_device_context()?;
        let port_id = device_context.slot_context.get_root_hub_port_number();
        self.port_id.set(port_id)?;
        debug!(
            "port id {} is assigned to slot id {}",
            port_id, self.slot_id
        );

        // Initialize the control endpoint. Endpoint id = 1.
        let trc = TransferRingController::new(
            self.mem.clone(),
            self.hub.get_port(port_id).ok_or(Error::GetPort(port_id))?,
            self.event_loop.clone(),
            self.interrupter.clone(),
            self.slot_id,
            1,
            Arc::downgrade(self),
            None,
        )
        .map_err(Error::CreateTransferController)?;
        self.set_trcs(0, Some(TransferRingControllers::Endpoint(trc)));

        // Assign slot ID as device address if block_set_address_request is not set.
        if trb.get_block_set_address_request() {
            device_context
                .slot_context
                .set_slot_state(DeviceSlotState::Default);
        } else {
            let port = self.hub.get_port(port_id).ok_or(Error::GetPort(port_id))?;
            match port.backend_device().as_mut() {
                Some(backend) => {
                    backend.lock().set_address(self.slot_id as u32);
                }
                None => {
                    return Ok(TrbCompletionCode::TransactionError);
                }
            }

            device_context
                .slot_context
                .set_usb_device_address(self.slot_id);
            device_context
                .slot_context
                .set_slot_state(DeviceSlotState::Addressed);
        }

        // TODO(jkwang) trc should always exist. Fix this.
        self.get_trc(0, 0)
            .ok_or(Error::GetTrc(0))?
            .set_dequeue_pointer(
                device_context.endpoint_context[0]
                    .get_tr_dequeue_pointer()
                    .get_gpa(),
            );

        self.get_trc(0, 0)
            .ok_or(Error::GetTrc(0))?
            .set_consumer_cycle_state(device_context.endpoint_context[0].get_dequeue_cycle_state());

        // Set endpoint 0 to the Running state.
        device_context.endpoint_context[0].set_endpoint_state(EndpointState::Running);
        self.set_device_context(device_context)?;
        Ok(TrbCompletionCode::Success)
    }

    // Adds or drops multiple endpoints in the device slot.
    pub fn configure_endpoint(
        self: &Arc<Self>,
        trb: &ConfigureEndpointCommandTrb,
    ) -> Result<TrbCompletionCode> {
        let input_control_context = if trb.get_deconfigure() {
            // From section 4.6.6 of the xHCI spec:
            // Setting the deconfigure (DC) flag to '1' in the Configure Endpoint Command
            // TRB is equivalent to setting Input Context Drop Context flags 2-31 to '1'
            // and Add Context 2-31 flags to '0'.
            let mut c = InputControlContext::new();
            c.set_add_context_flags(0);
            c.set_drop_context_flags(0xfffffffc);
            c
        } else {
            self.mem
                .read_obj_from_addr(GuestAddress(trb.get_input_context_pointer()))
                .map_err(Error::ReadGuestMemory)?
        };

        for device_context_index in 1..DCI_INDEX_END {
            if input_control_context.drop_context_flag(device_context_index) {
                self.drop_one_endpoint(device_context_index)?;
            }
            if input_control_context.add_context_flag(device_context_index) {
                self.copy_context(
                    GuestAddress(trb.get_input_context_pointer()),
                    device_context_index,
                )?;
                self.add_one_endpoint(device_context_index)?;
            }
        }

        if trb.get_deconfigure() {
            self.set_state(DeviceSlotState::Addressed)?;
        } else {
            self.set_state(DeviceSlotState::Configured)?;
        }
        Ok(TrbCompletionCode::Success)
    }

    // Evaluates the device context by reading new values for certain fields of
    // the slot context and/or control endpoint context.
    pub fn evaluate_context(&self, trb: &EvaluateContextCommandTrb) -> Result<TrbCompletionCode> {
        if !self.enabled.load(Ordering::SeqCst) {
            return Ok(TrbCompletionCode::SlotNotEnabledError);
        }
        // TODO(jkwang) verify this
        // The spec has multiple contradictions about validating context parameters in sections
        // 4.6.7, 6.2.3.3. To keep things as simple as possible we do no further validation here.
        let input_control_context: InputControlContext = self
            .mem
            .read_obj_from_addr(GuestAddress(trb.get_input_context_pointer()))
            .map_err(Error::ReadGuestMemory)?;

        let mut device_context = self.get_device_context()?;
        if input_control_context.add_context_flag(0) {
            let input_slot_context: SlotContext = self
                .mem
                .read_obj_from_addr(GuestAddress(
                    trb.get_input_context_pointer() + DEVICE_CONTEXT_ENTRY_SIZE as u64,
                ))
                .map_err(Error::ReadGuestMemory)?;
            device_context
                .slot_context
                .set_interrupter_target(input_slot_context.get_interrupter_target());

            device_context
                .slot_context
                .set_max_exit_latency(input_slot_context.get_max_exit_latency());
        }

        // From 6.2.3.3: "Endpoint Contexts 2 through 31 shall not be evaluated by the Evaluate
        // Context Command".
        if input_control_context.add_context_flag(1) {
            let ep0_context: EndpointContext = self
                .mem
                .read_obj_from_addr(GuestAddress(
                    trb.get_input_context_pointer() + 2 * DEVICE_CONTEXT_ENTRY_SIZE as u64,
                ))
                .map_err(Error::ReadGuestMemory)?;
            device_context.endpoint_context[0]
                .set_max_packet_size(ep0_context.get_max_packet_size());
        }
        self.set_device_context(device_context)?;
        Ok(TrbCompletionCode::Success)
    }

    /// Reset the device slot to its default state and deconfigure all but the
    /// control endpoint.
    pub fn reset_slot<
        C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
    >(
        fail_handle: Arc<dyn FailHandle>,
        slot: &Arc<DeviceSlot>,
        mut callback: C,
    ) -> Result<()> {
        let weak_s = Arc::downgrade(slot);
        let auto_callback =
            RingBufferStopCallback::new(fallible_closure(fail_handle, move || -> Result<()> {
                let s = weak_s.upgrade().ok_or(Error::WeakReferenceUpgrade)?;
                for i in FIRST_TRANSFER_ENDPOINT_DCI..DCI_INDEX_END {
                    s.drop_one_endpoint(i)?;
                }
                let mut ctx = s.get_device_context()?;
                ctx.slot_context.set_slot_state(DeviceSlotState::Default);
                ctx.slot_context.set_context_entries(1);
                ctx.slot_context.set_root_hub_port_number(0);
                s.set_device_context(ctx)?;
                callback(TrbCompletionCode::Success).map_err(|_| Error::CallbackFailed)?;
                Ok(())
            }));
        slot.stop_all_trc(auto_callback);
        Ok(())
    }

    /// Stop all transfer ring controllers.
    pub fn stop_all_trc(&self, auto_callback: RingBufferStopCallback) {
        for i in 0..self.trc_len() {
            if let Some(trcs) = self.get_trcs(i) {
                match trcs {
                    TransferRingControllers::Endpoint(trc) => {
                        trc.stop(auto_callback.clone());
                    }
                    TransferRingControllers::Stream(trcs) => {
                        for trc in trcs {
                            trc.stop(auto_callback.clone());
                        }
                    }
                }
            }
        }
    }

    /// Stop an endpoint.
    pub fn stop_endpoint<
        C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
    >(
        &self,
        fail_handle: Arc<dyn FailHandle>,
        endpoint_id: u8,
        mut cb: C,
    ) -> Result<()> {
        if !valid_endpoint_id(endpoint_id) {
            error!("invalid endpoint id in trb");
            return cb(TrbCompletionCode::TrbError).map_err(|_| Error::CallbackFailed);
        }
        let index = endpoint_id - 1;
        let mut device_context = self.get_device_context()?;
        let endpoint_context = &mut device_context.endpoint_context[index as usize];
        match self.get_trcs(index as usize) {
            Some(TransferRingControllers::Endpoint(trc)) => {
                let auto_cb = RingBufferStopCallback::new(fallible_closure(
                    fail_handle,
                    move || -> Result<()> {
                        cb(TrbCompletionCode::Success).map_err(|_| Error::CallbackFailed)
                    },
                ));
                trc.stop(auto_cb);
                let dequeue_pointer = trc.get_dequeue_pointer();
                let dcs = trc.get_consumer_cycle_state();
                endpoint_context.set_tr_dequeue_pointer(DequeuePtr::new(dequeue_pointer));
                endpoint_context.set_dequeue_cycle_state(dcs);
            }
            Some(TransferRingControllers::Stream(trcs)) => {
                let stream_context_array_addr = endpoint_context.get_tr_dequeue_pointer().get_gpa();
                let mut stream_context_array: StreamContextArray = self
                    .mem
                    .read_obj_from_addr(stream_context_array_addr)
                    .map_err(Error::ReadGuestMemory)?;
                let auto_cb = RingBufferStopCallback::new(fallible_closure(
                    fail_handle,
                    move || -> Result<()> {
                        cb(TrbCompletionCode::Success).map_err(|_| Error::CallbackFailed)
                    },
                ));
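                // `trcs[i]` corresponds to stream ID i + 1, since stream ID 0 is reserved and
                // has no transfer ring controller.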
                for (i, trc) in trcs.iter().enumerate() {
                    let dequeue_pointer = trc.get_dequeue_pointer();
                    let dcs = trc.get_consumer_cycle_state();
                    trc.stop(auto_cb.clone());
                    stream_context_array.stream_contexts[i + 1]
                        .set_tr_dequeue_pointer(DequeuePtr::new(dequeue_pointer));
                    stream_context_array.stream_contexts[i + 1].set_dequeue_cycle_state(dcs);
                }
                self.mem
                    .write_obj_at_addr(stream_context_array, stream_context_array_addr)
                    .map_err(Error::WriteGuestMemory)?;
            }
            None => {
                error!("endpoint at index {} is not started", index);
                cb(TrbCompletionCode::ContextStateError).map_err(|_| Error::CallbackFailed)?;
            }
        }
        endpoint_context.set_endpoint_state(EndpointState::Stopped);
        self.set_device_context(device_context)?;
        Ok(())
    }

    /// Reset an endpoint.
    pub fn reset_endpoint<
        C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
    >(
        &self,
        fail_handle: Arc<dyn FailHandle>,
        endpoint_id: u8,
        mut cb: C,
    ) -> Result<()> {
        if !valid_endpoint_id(endpoint_id) {
            error!("invalid endpoint id in trb");
            return cb(TrbCompletionCode::TrbError).map_err(|_| Error::CallbackFailed);
        }
        let index = endpoint_id - 1;
        let mut device_context = self.get_device_context()?;
        let endpoint_context = &mut device_context.endpoint_context[index as usize];
        if endpoint_context
            .get_endpoint_state()
            .map_err(Error::GetEndpointState)?
            != EndpointState::Halted
        {
            error!("endpoint at index {} is not halted", index);
            return cb(TrbCompletionCode::ContextStateError).map_err(|_| Error::CallbackFailed);
        }
        match self.get_trcs(index as usize) {
            Some(TransferRingControllers::Endpoint(trc)) => {
                let auto_cb = RingBufferStopCallback::new(fallible_closure(
                    fail_handle,
                    move || -> Result<()> {
                        cb(TrbCompletionCode::Success).map_err(|_| Error::CallbackFailed)
                    },
                ));
                trc.stop(auto_cb);
                let dequeue_pointer = trc.get_dequeue_pointer();
                let dcs = trc.get_consumer_cycle_state();
                endpoint_context.set_tr_dequeue_pointer(DequeuePtr::new(dequeue_pointer));
                endpoint_context.set_dequeue_cycle_state(dcs);
            }
            Some(TransferRingControllers::Stream(trcs)) => {
                let stream_context_array_addr = endpoint_context.get_tr_dequeue_pointer().get_gpa();
                let mut stream_context_array: StreamContextArray = self
                    .mem
                    .read_obj_from_addr(stream_context_array_addr)
                    .map_err(Error::ReadGuestMemory)?;
                let auto_cb = RingBufferStopCallback::new(fallible_closure(
                    fail_handle,
                    move || -> Result<()> {
                        cb(TrbCompletionCode::Success).map_err(|_| Error::CallbackFailed)
                    },
                ));
                for (i, trc) in trcs.iter().enumerate() {
                    let dequeue_pointer = trc.get_dequeue_pointer();
                    let dcs = trc.get_consumer_cycle_state();
                    trc.stop(auto_cb.clone());
                    stream_context_array.stream_contexts[i + 1]
                        .set_tr_dequeue_pointer(DequeuePtr::new(dequeue_pointer));
                    stream_context_array.stream_contexts[i + 1].set_dequeue_cycle_state(dcs);
                }
                self.mem
                    .write_obj_at_addr(stream_context_array, stream_context_array_addr)
                    .map_err(Error::WriteGuestMemory)?;
            }
            None => {
                error!("endpoint at index {} is not started", index);
                cb(TrbCompletionCode::ContextStateError).map_err(|_| Error::CallbackFailed)?;
            }
        }
        endpoint_context.set_endpoint_state(EndpointState::Stopped);
        self.set_device_context(device_context)?;
        Ok(())
    }

    /// Set the transfer ring dequeue pointer.
    pub fn set_tr_dequeue_ptr(
        &self,
        endpoint_id: u8,
        stream_id: u16,
        ptr: u64,
    ) -> Result<TrbCompletionCode> {
        if !valid_endpoint_id(endpoint_id) {
            error!("invalid endpoint id in trb");
            return Ok(TrbCompletionCode::TrbError);
        }
        let index = (endpoint_id - 1) as usize;
        match self.get_trc(index, stream_id) {
            Some(trc) => {
                trc.set_dequeue_pointer(GuestAddress(ptr));
                let mut ctx = self.get_device_context()?;
                ctx.endpoint_context[index]
                    .set_tr_dequeue_pointer(DequeuePtr::new(GuestAddress(ptr)));
                self.set_device_context(ctx)?;
                Ok(TrbCompletionCode::Success)
            }
            None => {
                error!("set tr dequeue ptr failed: no trc is started");
                Ok(TrbCompletionCode::ContextStateError)
            }
        }
    }

    // Note that `reset` and `reset_slot` are different: `reset_slot` handles the command ring
    // `reset slot` command and resets the slot state, while `reset` handles an xHCI reset and
    // tears everything down.
    fn reset(&self) {
        for i in 0..self.trc_len() {
            self.set_trcs(i, None);
        }
        debug!("resetting device slot {}!", self.slot_id);
        self.enabled.store(false, Ordering::SeqCst);
        self.port_id.reset();
    }

    fn create_stream_trcs(
        self: &Arc<Self>,
        stream_context_array_addr: GuestAddress,
        max_pstreams: u8,
        device_context_index: u8,
    ) -> Result<TransferRingControllers> {
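        // Per the xHCI spec (Max Primary Streams field), the primary stream context array has
        // 2^(max_pstreams + 1) entries.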
        let pstreams = 1usize << (max_pstreams + 1);
        let stream_context_array: StreamContextArray = self
            .mem
            .read_obj_from_addr(stream_context_array_addr)
            .map_err(Error::ReadGuestMemory)?;
        let mut trcs = Vec::new();

        // Stream ID 0 is reserved (xHCI spec Section 4.12.2)
        for i in 1..pstreams {
            let stream_context = &stream_context_array.stream_contexts[i];
            let context_type = stream_context.get_stream_context_type();
            if context_type != 1 {
                // We only support Linear Stream Context Array for now
                return Err(Error::BadStreamContextType(context_type));
            }
            let trc = TransferRingController::new(
                self.mem.clone(),
                self.hub
                    .get_port(self.port_id.get()?)
                    .ok_or(Error::GetPort(self.port_id.get()?))?,
                self.event_loop.clone(),
                self.interrupter.clone(),
                self.slot_id,
                device_context_index,
                Arc::downgrade(self),
                Some(i as u16),
            )
            .map_err(Error::CreateTransferController)?;
            trc.set_dequeue_pointer(stream_context.get_tr_dequeue_pointer().get_gpa());
            trc.set_consumer_cycle_state(stream_context.get_dequeue_cycle_state());
            trcs.push(trc);
        }
        Ok(TransferRingControllers::Stream(trcs))
    }

    fn add_one_endpoint(self: &Arc<Self>, device_context_index: u8) -> Result<()> {
        xhci_trace!(
            "adding one endpoint, device context index {}",
            device_context_index
        );
        let mut device_context = self.get_device_context()?;
        let transfer_ring_index = (device_context_index - 1) as usize;
        let endpoint_context = &mut device_context.endpoint_context[transfer_ring_index];
        let max_pstreams = endpoint_context.get_max_primary_streams();
        let tr_dequeue_pointer = endpoint_context.get_tr_dequeue_pointer().get_gpa();
        let endpoint_context_addr = self
            .get_device_context_addr()?
            .unchecked_add(size_of::<SlotContext>() as u64)
            .unchecked_add(size_of::<EndpointContext>() as u64 * transfer_ring_index as u64);
        let trcs = if max_pstreams > 0 {
            if !valid_max_pstreams(max_pstreams) {
                return Err(Error::BadEndpointContext(endpoint_context_addr));
            }
            let endpoint_type = endpoint_context.get_endpoint_type();
            if endpoint_type != 2 && endpoint_type != 6 {
                // Streams are only supported on bulk endpoints (endpoint type 2 = Bulk OUT,
                // 6 = Bulk IN).
                return Err(Error::BadEndpointId(transfer_ring_index as u8));
            }
            if endpoint_context.get_linear_stream_array() != 1 {
                // We only support Linear Stream Context Array for now
                return Err(Error::BadEndpointContext(endpoint_context_addr));
            }

            let trcs =
                self.create_stream_trcs(tr_dequeue_pointer, max_pstreams, device_context_index)?;

            if let Some(port) = self.hub.get_port(self.port_id.get()?) {
                if let Some(backend_device) = port.backend_device().as_mut() {
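                    // Convert the DCI into a USB endpoint address: the endpoint number is
                    // DCI / 2, and odd DCIs are IN endpoints, so bit 7 (the direction bit)
                    // is set for them.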
                    let mut endpoint_address = device_context_index / 2;
                    if device_context_index % 2 == 1 {
                        endpoint_address |= 1u8 << 7;
                    }
                    let streams = 1 << (max_pstreams + 1);
                    // Subtract 1 to exclude the reserved Stream ID 0.
                    backend_device
                        .lock()
                        .alloc_streams(endpoint_address, streams - 1)
                        .map_err(Error::AllocStreams)?;
                }
            }
            trcs
        } else {
            let trc = TransferRingController::new(
                self.mem.clone(),
                self.hub
                    .get_port(self.port_id.get()?)
                    .ok_or(Error::GetPort(self.port_id.get()?))?,
                self.event_loop.clone(),
                self.interrupter.clone(),
                self.slot_id,
                device_context_index,
                Arc::downgrade(self),
                None,
            )
            .map_err(Error::CreateTransferController)?;
            trc.set_dequeue_pointer(tr_dequeue_pointer);
            trc.set_consumer_cycle_state(endpoint_context.get_dequeue_cycle_state());
            TransferRingControllers::Endpoint(trc)
        };
        self.set_trcs(transfer_ring_index, Some(trcs));
        endpoint_context.set_endpoint_state(EndpointState::Running);
        self.set_device_context(device_context)
    }

    fn drop_one_endpoint(self: &Arc<Self>, device_context_index: u8) -> Result<()> {
        let endpoint_index = (device_context_index - 1) as usize;
        let mut device_context = self.get_device_context()?;
        let endpoint_context = &mut device_context.endpoint_context[endpoint_index];
        if endpoint_context.get_max_primary_streams() > 0 {
            if let Some(port) = self.hub.get_port(self.port_id.get()?) {
                if let Some(backend_device) = port.backend_device().as_mut() {
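                    // Same DCI-to-endpoint-address conversion as in add_one_endpoint above.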
                    let mut endpoint_address = device_context_index / 2;
                    if device_context_index % 2 == 1 {
                        endpoint_address |= 1u8 << 7;
                    }
                    backend_device
                        .lock()
                        .free_streams(endpoint_address)
                        .map_err(Error::FreeStreams)?;
                }
            }
        }
        self.set_trcs(endpoint_index, None);
        endpoint_context.set_endpoint_state(EndpointState::Disabled);
        self.set_device_context(device_context)
    }

    fn get_device_context(&self) -> Result<DeviceContext> {
        let ctx = self
            .mem
            .read_obj_from_addr(self.get_device_context_addr()?)
            .map_err(Error::ReadGuestMemory)?;
        Ok(ctx)
    }

    fn set_device_context(&self, device_context: DeviceContext) -> Result<()> {
        self.mem
            .write_obj_at_addr(device_context, self.get_device_context_addr()?)
            .map_err(Error::WriteGuestMemory)
    }

    fn copy_context(
        &self,
        input_context_ptr: GuestAddress,
        device_context_index: u8,
    ) -> Result<()> {
        // Note that the entry could be a slot context or an endpoint context; they have the same
        // size, so reading it as an EndpointContext makes no difference here.
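        // The input context has an extra Input Control Context entry at the front, so entry N of
        // the device context is at offset (N + 1) * DEVICE_CONTEXT_ENTRY_SIZE in the input
        // context.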
        let ctx: EndpointContext = self
            .mem
            .read_obj_from_addr(
                input_context_ptr
                    .checked_add(
                        (device_context_index as u64 + 1) * DEVICE_CONTEXT_ENTRY_SIZE as u64,
                    )
                    .ok_or(Error::BadInputContextAddr(input_context_ptr))?,
            )
            .map_err(Error::ReadGuestMemory)?;
        xhci_trace!("copy_context {:?}", ctx);
        let device_context_ptr = self.get_device_context_addr()?;
        self.mem
            .write_obj_at_addr(
                ctx,
                device_context_ptr
                    .checked_add(device_context_index as u64 * DEVICE_CONTEXT_ENTRY_SIZE as u64)
                    .ok_or(Error::BadDeviceContextAddr(device_context_ptr))?,
            )
            .map_err(Error::WriteGuestMemory)
    }

    fn get_device_context_addr(&self) -> Result<GuestAddress> {
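        // The DCBAA entry at index slot_id holds the guest address of this slot's device
        // context; entry 0 is reserved for the scratchpad buffer array.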
        let addr: u64 = self
            .mem
            .read_obj_from_addr(GuestAddress(
                self.dcbaap.get_value() + size_of::<u64>() as u64 * self.slot_id as u64,
            ))
            .map_err(Error::ReadGuestMemory)?;
        Ok(GuestAddress(addr))
    }

    fn set_state(&self, state: DeviceSlotState) -> Result<()> {
        let mut ctx = self.get_device_context()?;
        ctx.slot_context.set_slot_state(state);
        self.set_device_context(ctx)
    }

    pub fn halt_endpoint(&self, endpoint_id: u8) -> Result<()> {
        if !valid_endpoint_id(endpoint_id) {
            return Err(Error::BadEndpointId(endpoint_id));
        }
        let index = endpoint_id - 1;
        let mut device_context = self.get_device_context()?;
        let endpoint_context = &mut device_context.endpoint_context[index as usize];
        match self.get_trcs(index as usize) {
            Some(trcs) => match trcs {
                TransferRingControllers::Endpoint(trc) => {
                    endpoint_context
                        .set_tr_dequeue_pointer(DequeuePtr::new(trc.get_dequeue_pointer()));
                    endpoint_context.set_dequeue_cycle_state(trc.get_consumer_cycle_state());
                }
                TransferRingControllers::Stream(trcs) => {
                    let stream_context_array_addr =
                        endpoint_context.get_tr_dequeue_pointer().get_gpa();
                    let mut stream_context_array: StreamContextArray = self
                        .mem
                        .read_obj_from_addr(stream_context_array_addr)
                        .map_err(Error::ReadGuestMemory)?;
                    for (i, trc) in trcs.iter().enumerate() {
                        stream_context_array.stream_contexts[i + 1]
                            .set_tr_dequeue_pointer(DequeuePtr::new(trc.get_dequeue_pointer()));
                        stream_context_array.stream_contexts[i + 1]
                            .set_dequeue_cycle_state(trc.get_consumer_cycle_state());
                    }
                    self.mem
                        .write_obj_at_addr(stream_context_array, stream_context_array_addr)
                        .map_err(Error::WriteGuestMemory)?;
                }
            },
            None => {
                error!("trc for endpoint {} not found", endpoint_id);
                return Err(Error::BadEndpointId(endpoint_id));
            }
        }
        endpoint_context.set_endpoint_state(EndpointState::Halted);
        self.set_device_context(device_context)?;
        Ok(())
    }
}