// devices/virtio/iommu.rs

// Copyright 2021 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

pub mod ipc_memory_mapper;
pub mod memory_mapper;
pub mod protocol;
pub(crate) mod sys;

use std::cell::RefCell;
use std::collections::btree_map::Entry;
use std::collections::BTreeMap;
use std::io;
use std::io::Write;
use std::mem::size_of;
use std::ops::RangeInclusive;
use std::rc::Rc;
use std::result;
use std::sync::Arc;

#[cfg(target_arch = "x86_64")]
use acpi_tables::sdt::SDT;
use anyhow::anyhow;
use anyhow::Context;
use base::debug;
use base::error;
use base::pagesize;
use base::AsRawDescriptor;
use base::Error as SysError;
use base::Event;
use base::MappedRegion;
use base::MemoryMapping;
use base::Protection;
use base::RawDescriptor;
use base::Result as SysResult;
use base::Tube;
use base::TubeError;
use base::WorkerThread;
use cros_async::AsyncError;
use cros_async::AsyncTube;
use cros_async::EventAsync;
use cros_async::Executor;
use data_model::Le64;
use futures::select;
use futures::FutureExt;
use remain::sorted;
use sync::Mutex;
use thiserror::Error;
use vm_control::VmMemoryRegionId;
use vm_memory::GuestAddress;
use vm_memory::GuestMemory;
use vm_memory::GuestMemoryError;
#[cfg(target_arch = "x86_64")]
use zerocopy::FromBytes;
#[cfg(target_arch = "x86_64")]
use zerocopy::Immutable;
use zerocopy::IntoBytes;
#[cfg(target_arch = "x86_64")]
use zerocopy::KnownLayout;

#[cfg(target_arch = "x86_64")]
use crate::pci::PciAddress;
use crate::virtio::async_utils;
use crate::virtio::copy_config;
use crate::virtio::iommu::memory_mapper::*;
use crate::virtio::iommu::protocol::*;
use crate::virtio::DescriptorChain;
use crate::virtio::DeviceType;
use crate::virtio::Interrupt;
use crate::virtio::Queue;
use crate::virtio::Reader;
use crate::virtio::VirtioDevice;
#[cfg(target_arch = "x86_64")]
use crate::virtio::Writer;

const QUEUE_SIZE: u16 = 256;
const NUM_QUEUES: usize = 2;
const QUEUE_SIZES: &[u16] = &[QUEUE_SIZE; NUM_QUEUES];
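
// The virtio-iommu spec defines two virtqueues: queue 0 is the request queue and
// queue 1 is the event queue. The worker in this file only services the request
// queue; the event queue is accepted in activate() but left unused.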

// Size of the virtio_iommu_probe_resv_mem property returned in response to a PROBE
// request; this is also reported to the guest via the probe_size config field.
#[cfg(target_arch = "x86_64")]
const IOMMU_PROBE_SIZE: usize = size_of::<virtio_iommu_probe_resv_mem>();

#[cfg(target_arch = "x86_64")]
const VIRTIO_IOMMU_VIOT_NODE_PCI_RANGE: u8 = 1;
#[cfg(target_arch = "x86_64")]
const VIRTIO_IOMMU_VIOT_NODE_VIRTIO_IOMMU_PCI: u8 = 3;

#[derive(Copy, Clone, Debug, Default, FromBytes, Immutable, IntoBytes, KnownLayout)]
#[repr(C, packed)]
#[cfg(target_arch = "x86_64")]
struct VirtioIommuViotHeader {
    node_count: u16,
    node_offset: u16,
    reserved: [u8; 8],
}

#[derive(Copy, Clone, Debug, Default, FromBytes, Immutable, IntoBytes, KnownLayout)]
#[repr(C, packed)]
#[cfg(target_arch = "x86_64")]
struct VirtioIommuViotVirtioPciNode {
    type_: u8,
    reserved: [u8; 1],
    length: u16,
    segment: u16,
    bdf: u16,
    reserved2: [u8; 8],
}

#[derive(Copy, Clone, Debug, Default, FromBytes, Immutable, IntoBytes, KnownLayout)]
#[repr(C, packed)]
#[cfg(target_arch = "x86_64")]
struct VirtioIommuViotPciRangeNode {
    type_: u8,
    reserved: [u8; 1],
    length: u16,
    endpoint_start: u32,
    segment_start: u16,
    segment_end: u16,
    bdf_start: u16,
    bdf_end: u16,
    output_node: u16,
    reserved2: [u8; 2],
    reserved3: [u8; 4],
}
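
// These packed structs mirror the node layouts of the ACPI VIOT table emitted by
// generate_acpi() below: a header, one virtio-pci IOMMU node for this device, and
// one PCI range node per endpoint (or hot-pluggable endpoint range) that routes
// translation requests to that IOMMU node via `output_node`.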

type Result<T> = result::Result<T, IommuError>;

#[sorted]
#[derive(Error, Debug)]
pub enum IommuError {
    #[error("async executor error: {0}")]
    AsyncExec(AsyncError),
    #[error("failed to create wait context: {0}")]
    CreateWaitContext(SysError),
    #[error("failed getting host address: {0}")]
    GetHostAddress(GuestMemoryError),
    #[error("failed to read from guest address: {0}")]
    GuestMemoryRead(io::Error),
    #[error("failed to write to guest address: {0}")]
    GuestMemoryWrite(io::Error),
    #[error("memory mapper failed: {0}")]
    MemoryMapper(anyhow::Error),
    #[error("Failed to read descriptor asynchronously: {0}")]
    ReadAsyncDesc(AsyncError),
    #[error("failed to read from virtio queue Event: {0}")]
    ReadQueueEvent(SysError),
    #[error("tube error: {0}")]
    Tube(TubeError),
    #[error("unexpected descriptor error")]
    UnexpectedDescriptor,
    #[error("failed to receive virtio-iommu control request: {0}")]
    VirtioIOMMUReqError(TubeError),
    #[error("failed to send virtio-iommu control response: {0}")]
    VirtioIOMMUResponseError(TubeError),
    #[error("failed to wait for events: {0}")]
    WaitError(SysError),
    #[error("write buffer length too small")]
    WriteBufferTooSmall,
}

// key: domain ID
// value: reference counter and MemoryMapperTrait
type DomainMap = BTreeMap<u32, (u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>)>;

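// A dmabuf region registered with the device. Entries are added to `dmabuf_mem`
// outside this file (by the tube handlers in `sys`); the `MemoryMapping` keeps the
// host mapping alive so its address can be passed to vfio_dma_map() when a MAP
// request falls inside the region (see process_dma_map_request()).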
struct DmabufRegionEntry {
    mmap: MemoryMapping,
    region_id: VmMemoryRegionId,
    size: u64,
}

// Shared state for the virtio-iommu device.
struct State {
    mem: GuestMemory,
    page_mask: u64,
    // Hot-pluggable PCI endpoint ranges
    // RangeInclusive: (start endpoint PCI address ..= end endpoint PCI address)
    #[cfg_attr(windows, allow(dead_code))]
    hp_endpoints_ranges: Vec<RangeInclusive<u32>>,
    // All PCI endpoints that are attached to an IOMMU domain
    // key: endpoint PCI address
    // value: attached domain ID
    endpoint_map: BTreeMap<u32, u32>,
    // All attached domains
    domain_map: DomainMap,
    // All pass-through endpoints that attach to this IOMMU device
    // key: endpoint PCI address
    // value: MemoryMapperTrait
    endpoints: BTreeMap<u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>>,
    // Contains dmabuf regions
    // key: guest physical address
    dmabuf_mem: BTreeMap<u64, DmabufRegionEntry>,
}
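
// `State` is not shared across threads: run() wraps it in an Rc<RefCell<..>>, and the
// request-queue handler, the translate-request handler, and the command-tube handler
// all borrow it on the same single-threaded Executor.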

impl State {
    // Detach the given endpoint if possible, and return whether or not the endpoint
    // was actually detached. If a successfully detached endpoint has exported
    // memory, returns an event that will be signaled once all exported memory is released.
    //
    // The device MUST ensure that after being detached from a domain, the endpoint
    // cannot access any mapping from that domain.
    //
    // Currently, we only support detaching an endpoint if it is the only endpoint attached
    // to its domain.
    fn detach_endpoint(
        endpoint_map: &mut BTreeMap<u32, u32>,
        domain_map: &mut DomainMap,
        endpoint: u32,
    ) -> (bool, Option<EventAsync>) {
        let mut evt = None;
        // The endpoint is attached to an IOMMU domain
        if let Some(attached_domain) = endpoint_map.get(&endpoint) {
            // Remove the entry or update the domain reference count
            if let Entry::Occupied(o) = domain_map.entry(*attached_domain) {
                let (refs, mapper) = o.get();
                if !mapper.lock().supports_detach() {
                    return (false, None);
                }

                match refs {
                    0 => unreachable!(),
                    1 => {
                        evt = mapper.lock().reset_domain();
                        o.remove();
                    }
                    _ => return (false, None),
                }
            }
        }

        endpoint_map.remove(&endpoint);
        (true, evt)
    }

    // Processes an attach request. This may require detaching the endpoint from
    // its current domain before attaching it to the new domain. If that happens
    // while the endpoint has exported memory, this function returns an event that
    // will be signaled once all exported memory is released.
    //
    // Note: if a VFIO group contains multiple devices, this can violate the following
    // requirement from the virtio-iommu spec: If the VIRTIO_IOMMU_F_BYPASS feature
    // is negotiated, all accesses from unattached endpoints are allowed and translated
    // by the IOMMU using the identity function. If the feature is not negotiated, any
    // memory access from an unattached endpoint fails.
    //
    // Once the virtio-iommu device has received a VIRTIO_IOMMU_T_ATTACH request for
    // the first endpoint in a VFIO group, any endpoints in that group that have not
    // yet been attached will also be able to access the domain.
    //
    // This violation is benign for current virtualization use cases. Since the device
    // topology in the guest matches the topology in the host, the guest doesn't expect
    // devices in the same VFIO group to be isolated from each other in the first place.
    fn process_attach_request(
        &mut self,
        reader: &mut Reader,
        tail: &mut virtio_iommu_req_tail,
    ) -> Result<(usize, Option<EventAsync>)> {
        let req: virtio_iommu_req_attach =
            reader.read_obj().map_err(IommuError::GuestMemoryRead)?;
        let mut fault_resolved_event = None;

        // If the reserved field of an ATTACH request is not zero,
        // the device MUST reject the request and set status to
        // VIRTIO_IOMMU_S_INVAL.
        if req.reserved.iter().any(|&x| x != 0) {
            tail.status = VIRTIO_IOMMU_S_INVAL;
            return Ok((0, None));
        }

        let domain: u32 = req.domain.into();
        let endpoint: u32 = req.endpoint.into();

        if let Some(mapper) = self.endpoints.get(&endpoint) {
            // The same mapper can't be used for two domains at the same time,
            // since that would result in conflicts/permission leaks between
            // the two domains.
            let mapper_id = {
                let m = mapper.lock();
                ((**m).type_id(), m.id())
            };
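            // (TypeId, id()) identifies the underlying mapper instance. Endpoints that
            // share a mapper (for example, devices in the same VFIO group) must not end
            // up in different domains, so reject the request if another endpoint backed
            // by the same mapper is already attached to a different domain.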
            for (other_endpoint, other_mapper) in self.endpoints.iter() {
                if *other_endpoint == endpoint {
                    continue;
                }
                let other_id = {
                    let m = other_mapper.lock();
                    ((**m).type_id(), m.id())
                };
                if mapper_id == other_id {
                    if self
                        .endpoint_map
                        .get(other_endpoint)
                        .is_some_and(|d| d != &domain)
                    {
                        tail.status = VIRTIO_IOMMU_S_UNSUPP;
                        return Ok((0, None));
                    }
                }
            }

            // If the endpoint identified by `endpoint` is already attached
            // to another domain, then the device SHOULD first detach it
            // from that domain and attach it to the one identified by domain.
            if self.endpoint_map.contains_key(&endpoint) {
                // In that case the device SHOULD behave as if the driver issued
                // a DETACH request with this endpoint, followed by the ATTACH
                // request. If the device cannot do so, it MUST reject the request
                // and set status to VIRTIO_IOMMU_S_UNSUPP.
                let (detached, evt) =
                    Self::detach_endpoint(&mut self.endpoint_map, &mut self.domain_map, endpoint);
                if !detached {
                    tail.status = VIRTIO_IOMMU_S_UNSUPP;
                    return Ok((0, None));
                }
                fault_resolved_event = evt;
            }

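            // Attach the endpoint: bump the domain's reference count (or create the
            // domain with a count of 1) and record the endpoint -> domain association.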
            let new_ref = match self.domain_map.get(&domain) {
                None => 1,
                Some(val) => val.0 + 1,
            };

            self.endpoint_map.insert(endpoint, domain);
            self.domain_map.insert(domain, (new_ref, mapper.clone()));
        } else {
            // If the endpoint identified by endpoint doesn’t exist,
            // the device MUST reject the request and set status to
            // VIRTIO_IOMMU_S_NOENT.
            tail.status = VIRTIO_IOMMU_S_NOENT;
        }

        Ok((0, fault_resolved_event))
    }

    fn process_detach_request(
        &mut self,
        reader: &mut Reader,
        tail: &mut virtio_iommu_req_tail,
    ) -> Result<(usize, Option<EventAsync>)> {
        let req: virtio_iommu_req_detach =
            reader.read_obj().map_err(IommuError::GuestMemoryRead)?;

        // If the endpoint identified by |req.endpoint| doesn’t exist,
        // the device MUST reject the request and set status to
        // VIRTIO_IOMMU_S_NOENT.
        let endpoint: u32 = req.endpoint.into();
        if !self.endpoints.contains_key(&endpoint) {
            tail.status = VIRTIO_IOMMU_S_NOENT;
            return Ok((0, None));
        }

        let (detached, evt) =
            Self::detach_endpoint(&mut self.endpoint_map, &mut self.domain_map, endpoint);
        if !detached {
            tail.status = VIRTIO_IOMMU_S_UNSUPP;
        }
        Ok((0, evt))
    }

    fn process_dma_map_request(
        &mut self,
        reader: &mut Reader,
        tail: &mut virtio_iommu_req_tail,
    ) -> Result<usize> {
        let req: virtio_iommu_req_map = reader.read_obj().map_err(IommuError::GuestMemoryRead)?;

        let phys_start = u64::from(req.phys_start);
        let virt_start = u64::from(req.virt_start);
        let virt_end = u64::from(req.virt_end);

        // Enforce the driver requirement: virt_end MUST be strictly greater than virt_start.
        if virt_start >= virt_end {
            tail.status = VIRTIO_IOMMU_S_INVAL;
            return Ok(0);
        }

        // If virt_start, phys_start or (virt_end + 1) is not aligned
        // on the page granularity, the device SHOULD reject the
        // request and set status to VIRTIO_IOMMU_S_RANGE
        if self.page_mask & phys_start != 0
            || self.page_mask & virt_start != 0
            || self.page_mask & (virt_end + 1) != 0
        {
            tail.status = VIRTIO_IOMMU_S_RANGE;
            return Ok(0);
        }

        // If the device doesn’t recognize a flags bit, it MUST reject
        // the request and set status to VIRTIO_IOMMU_S_INVAL.
        if u32::from(req.flags) & !VIRTIO_IOMMU_MAP_F_MASK != 0 {
            tail.status = VIRTIO_IOMMU_S_INVAL;
            return Ok(0);
        }

        let domain: u32 = req.domain.into();
        if !self.domain_map.contains_key(&domain) {
            // If domain does not exist, the device SHOULD reject
            // the request and set status to VIRTIO_IOMMU_S_NOENT.
            tail.status = VIRTIO_IOMMU_S_NOENT;
            return Ok(0);
        }

        // The device MUST NOT allow writes to a range mapped
        // without the VIRTIO_IOMMU_MAP_F_WRITE flag.
        let write_en = u32::from(req.flags) & VIRTIO_IOMMU_MAP_F_WRITE != 0;

        if let Some(mapper) = self.domain_map.get(&domain) {
            let gpa = phys_start;
            let iova = virt_start;
            let Some(size) = u64::checked_add(virt_end - virt_start, 1) else {
                // This implementation doesn't support the unlikely request for
                // size == u64::MAX + 1.
                tail.status = VIRTIO_IOMMU_S_DEVERR;
                return Ok(0);
            };

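            // If the requested GPA range lies entirely within a registered dmabuf
            // region, translate it to the host virtual address of that region so it
            // can be handed to vfio_dma_map(); otherwise fall back to a regular
            // guest-memory mapping via add_map().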
            let dmabuf_map =
                self.dmabuf_mem
                    .range(..=gpa)
                    .next_back()
                    .and_then(|(base_gpa, region)| {
                        if gpa + size <= base_gpa + region.size {
                            let offset = gpa - base_gpa;
                            Some(region.mmap.as_ptr() as u64 + offset)
                        } else {
                            None
                        }
                    });

            let prot = match write_en {
                true => Protection::read_write(),
                false => Protection::read(),
            };

            let vfio_map_result = match dmabuf_map {
                // SAFETY:
                // Safe because [dmabuf_map, dmabuf_map + size) refers to an external mmap'ed
                // region.
                Some(dmabuf_map) => unsafe {
                    mapper.1.lock().vfio_dma_map(iova, dmabuf_map, size, prot)
                },
                None => mapper.1.lock().add_map(MappingInfo {
                    iova,
                    gpa: GuestAddress(gpa),
                    size,
                    prot,
                }),
            };

            match vfio_map_result {
                Ok(AddMapResult::Ok) => (),
                Ok(AddMapResult::OverlapFailure) => {
                    // If a mapping already exists in the requested range,
                    // the device SHOULD reject the request and set status
                    // to VIRTIO_IOMMU_S_INVAL.
                    tail.status = VIRTIO_IOMMU_S_INVAL;
                }
                Err(e) => return Err(IommuError::MemoryMapper(e)),
            }
        }

        Ok(0)
    }

    fn process_dma_unmap_request(
        &mut self,
        reader: &mut Reader,
        tail: &mut virtio_iommu_req_tail,
    ) -> Result<(usize, Option<EventAsync>)> {
        let req: virtio_iommu_req_unmap = reader.read_obj().map_err(IommuError::GuestMemoryRead)?;

        let domain: u32 = req.domain.into();
        let fault_resolved_event = if let Some(mapper) = self.domain_map.get(&domain) {
            let size = u64::from(req.virt_end) - u64::from(req.virt_start) + 1;
            let res = mapper
                .1
                .lock()
                .remove_map(u64::from(req.virt_start), size)
                .map_err(IommuError::MemoryMapper)?;
            match res {
                RemoveMapResult::Success(evt) => evt,
                RemoveMapResult::OverlapFailure => {
                    // If a mapping affected by the range is not covered in its entirety by the
                    // range (the UNMAP request would split the mapping), then the device SHOULD
                    // set the request `status` to VIRTIO_IOMMU_S_RANGE, and SHOULD NOT remove
                    // any mapping.
                    tail.status = VIRTIO_IOMMU_S_RANGE;
                    None
                }
            }
        } else {
            // If domain does not exist, the device SHOULD set the
            // request status to VIRTIO_IOMMU_S_NOENT
            tail.status = VIRTIO_IOMMU_S_NOENT;
            None
        };

        Ok((0, fault_resolved_event))
    }

    #[cfg(target_arch = "x86_64")]
    fn process_probe_request(
        &mut self,
        reader: &mut Reader,
        writer: &mut Writer,
        tail: &mut virtio_iommu_req_tail,
    ) -> Result<usize> {
        let req: virtio_iommu_req_probe = reader.read_obj().map_err(IommuError::GuestMemoryRead)?;
        let endpoint: u32 = req.endpoint.into();

        // If the endpoint identified by endpoint doesn’t exist,
        // then the device SHOULD reject the request and set status
        // to VIRTIO_IOMMU_S_NOENT.
        if !self.endpoints.contains_key(&endpoint) {
            tail.status = VIRTIO_IOMMU_S_NOENT;
        }

        let properties_size = writer.available_bytes() - size_of::<virtio_iommu_req_tail>();

        // It's OK if properties_size is larger than probe_size.
        // A buffer that is too small (even zero bytes) is handled below by
        // rejecting the request.
        if properties_size < IOMMU_PROBE_SIZE {
            // If the properties list is smaller than probe_size, the device
            // SHOULD NOT write any property. It SHOULD reject the request
            // and set status to VIRTIO_IOMMU_S_INVAL.
            tail.status = VIRTIO_IOMMU_S_INVAL;
        } else if tail.status == VIRTIO_IOMMU_S_OK {
            const VIRTIO_IOMMU_PROBE_T_RESV_MEM: u16 = 1;
            const VIRTIO_IOMMU_RESV_MEM_T_MSI: u8 = 1;
            const PROBE_PROPERTY_SIZE: u16 = 4;
            const X86_MSI_IOVA_START: u64 = 0xfee0_0000;
            const X86_MSI_IOVA_END: u64 = 0xfeef_ffff;

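            // Report the x86 MSI doorbell window (0xfee00000..=0xfeefffff) as a
            // reserved MSI region, so the guest never places IOVA mappings over it.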
            let properties = virtio_iommu_probe_resv_mem {
                head: virtio_iommu_probe_property {
                    type_: VIRTIO_IOMMU_PROBE_T_RESV_MEM.into(),
                    length: (IOMMU_PROBE_SIZE as u16 - PROBE_PROPERTY_SIZE).into(),
                },
                subtype: VIRTIO_IOMMU_RESV_MEM_T_MSI,
                start: X86_MSI_IOVA_START.into(),
                end: X86_MSI_IOVA_END.into(),
                ..Default::default()
            };
            writer
                .write_all(properties.as_bytes())
                .map_err(IommuError::GuestMemoryWrite)?;
        }

        // If the device doesn’t fill all probe_size bytes with properties,
        // it SHOULD fill the remaining bytes of properties with zeroes.
        let remaining_bytes = writer.available_bytes() - size_of::<virtio_iommu_req_tail>();

        if remaining_bytes > 0 {
            let buffer: Vec<u8> = vec![0; remaining_bytes];
            writer
                .write_all(buffer.as_slice())
                .map_err(IommuError::GuestMemoryWrite)?;
        }

        Ok(properties_size)
    }

    fn execute_request(
        &mut self,
        avail_desc: &mut DescriptorChain,
    ) -> Result<(usize, Option<EventAsync>)> {
        let reader = &mut avail_desc.reader;
        let writer = &mut avail_desc.writer;

        // At a minimum, we need space to write the virtio_iommu_req_tail.
        if writer.available_bytes() < size_of::<virtio_iommu_req_tail>() {
            return Err(IommuError::WriteBufferTooSmall);
        }

        let req_head: virtio_iommu_req_head =
            reader.read_obj().map_err(IommuError::GuestMemoryRead)?;

        let mut tail = virtio_iommu_req_tail {
            status: VIRTIO_IOMMU_S_OK,
            ..Default::default()
        };

        let (reply_len, fault_resolved_event) = match req_head.type_ {
            VIRTIO_IOMMU_T_ATTACH => self.process_attach_request(reader, &mut tail)?,
            VIRTIO_IOMMU_T_DETACH => self.process_detach_request(reader, &mut tail)?,
            VIRTIO_IOMMU_T_MAP => (self.process_dma_map_request(reader, &mut tail)?, None),
            VIRTIO_IOMMU_T_UNMAP => self.process_dma_unmap_request(reader, &mut tail)?,
            #[cfg(target_arch = "x86_64")]
            VIRTIO_IOMMU_T_PROBE => (self.process_probe_request(reader, writer, &mut tail)?, None),
            _ => return Err(IommuError::UnexpectedDescriptor),
        };

        writer
            .write_all(tail.as_bytes())
            .map_err(IommuError::GuestMemoryWrite)?;
        Ok((
            reply_len + size_of::<virtio_iommu_req_tail>(),
            fault_resolved_event,
        ))
    }
}

async fn request_queue(
    state: &Rc<RefCell<State>>,
    mut queue: Queue,
    mut queue_event: EventAsync,
) -> Result<()> {
    loop {
        let mut avail_desc = queue
            .next_async(&mut queue_event)
            .await
            .map_err(IommuError::ReadAsyncDesc)?;

        let (len, fault_resolved_event) = match state.borrow_mut().execute_request(&mut avail_desc)
        {
            Ok(res) => res,
            Err(e) => {
                error!("execute_request failed: {}", e);

                // If a request type is not recognized, the device SHOULD NOT write
                // the buffer and SHOULD set the used length to zero
                (0, None)
            }
        };

        if let Some(fault_resolved_event) = fault_resolved_event {
            debug!("waiting for iommu fault resolution");
            fault_resolved_event
                .next_val()
                .await
                .expect("failed waiting for fault");
            debug!("iommu fault resolved");
        }

        queue.add_used_with_bytes_written(avail_desc, len as u32);
        queue.trigger_interrupt();
    }
}

fn run(
    state: State,
    iommu_device_tube: Tube,
    mut queues: BTreeMap<usize, Queue>,
    kill_evt: Event,
    translate_response_senders: Option<BTreeMap<u32, Tube>>,
    translate_request_rx: Option<Tube>,
) -> Result<()> {
    let state = Rc::new(RefCell::new(state));
    let ex = Executor::new().expect("Failed to create an executor");

    let req_queue = queues.remove(&0).unwrap();
    let req_evt = req_queue
        .event()
        .try_clone()
        .expect("Failed to clone queue event");
    let req_evt = EventAsync::new(req_evt, &ex).expect("Failed to create async event for queue");

    let f_kill = async_utils::await_and_exit(&ex, kill_evt);

    let request_tube = translate_request_rx
        .map(|t| AsyncTube::new(&ex, t).expect("Failed to create async tube for rx"));
    let response_tubes = translate_response_senders.map(|m| {
        m.into_iter()
            .map(|x| {
                (
                    x.0,
                    AsyncTube::new(&ex, x.1).expect("Failed to create async tube"),
                )
            })
            .collect()
    });

    let f_handle_translate_request =
        sys::handle_translate_request(&ex, &state, request_tube, response_tubes);
    let f_request = request_queue(&state, req_queue, req_evt);

    let command_tube = AsyncTube::new(&ex, iommu_device_tube).unwrap();
    // Future to handle command messages from the host, such as passing vfio containers.
    let f_cmd = sys::handle_command_tube(&state, command_tube);

    let done = async {
        select! {
            res = f_request.fuse() => res.context("error in handling request queue"),
            res = f_kill.fuse() => res.context("error in await_and_exit"),
            res = f_handle_translate_request.fuse() => {
                res.context("error in handle_translate_request")
            }
            res = f_cmd.fuse() => res.context("error in handling host request"),
        }
    };
    match ex.run_until(done) {
        Ok(Ok(())) => {}
        Ok(Err(e)) => error!("Error in worker: {:#}", e),
        Err(e) => return Err(IommuError::AsyncExec(e)),
    }

    Ok(())
}

/// Virtio device for IOMMU memory management.
pub struct Iommu {
    worker_thread: Option<WorkerThread<()>>,
    config: virtio_iommu_config,
    avail_features: u64,
    // Attached endpoints
    // key: endpoint PCI address
    // value: MemoryMapperTrait
    endpoints: BTreeMap<u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>>,
    // Hot-pluggable PCI endpoint ranges
    // RangeInclusive: (start endpoint PCI address ..= end endpoint PCI address)
    hp_endpoints_ranges: Vec<RangeInclusive<u32>>,
    translate_response_senders: Option<BTreeMap<u32, Tube>>,
    translate_request_rx: Option<Tube>,
    iommu_device_tube: Option<Tube>,
}

impl Iommu {
    /// Create a new virtio IOMMU device.
    pub fn new(
        base_features: u64,
        endpoints: BTreeMap<u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>>,
        iova_max_addr: u64,
        hp_endpoints_ranges: Vec<RangeInclusive<u32>>,
        translate_response_senders: Option<BTreeMap<u32, Tube>>,
        translate_request_rx: Option<Tube>,
        iommu_device_tube: Option<Tube>,
    ) -> SysResult<Iommu> {
        let mut page_size_mask = !((pagesize() as u64) - 1);
        for (_, container) in endpoints.iter() {
            page_size_mask &= container
                .lock()
                .get_mask()
                .map_err(|_e| SysError::new(libc::EIO))?;
        }

        if page_size_mask == 0 {
            return Err(SysError::new(libc::EIO));
        }
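
        // Each set bit N in page_size_mask advertises support for a 2^N-byte mapping
        // granule, intersected across all endpoint mappers above. For example, with a
        // 4 KiB host page size and no extra endpoint restrictions:
        //   page_size_mask == !0xfff == 0xffff_ffff_ffff_f000,
        // i.e. every granule of 4 KiB and larger is offered to the guest.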

        let input_range = virtio_iommu_range_64 {
            start: Le64::from(0),
            end: iova_max_addr.into(),
        };

        let config = virtio_iommu_config {
            page_size_mask: page_size_mask.into(),
            input_range,
            #[cfg(target_arch = "x86_64")]
            probe_size: (IOMMU_PROBE_SIZE as u32).into(),
            ..Default::default()
        };

        let mut avail_features: u64 = base_features;
        avail_features |= 1 << VIRTIO_IOMMU_F_MAP_UNMAP
            | 1 << VIRTIO_IOMMU_F_INPUT_RANGE
            | 1 << VIRTIO_IOMMU_F_MMIO;

        if cfg!(target_arch = "x86_64") {
            avail_features |= 1 << VIRTIO_IOMMU_F_PROBE;
        }

        Ok(Iommu {
            worker_thread: None,
            config,
            avail_features,
            endpoints,
            hp_endpoints_ranges,
            translate_response_senders,
            translate_request_rx,
            iommu_device_tube,
        })
    }
}

impl VirtioDevice for Iommu {
    fn keep_rds(&self) -> Vec<RawDescriptor> {
        let mut rds = Vec::new();

        for (_, mapper) in self.endpoints.iter() {
            rds.append(&mut mapper.lock().as_raw_descriptors());
        }
        if let Some(senders) = &self.translate_response_senders {
            for (_, tube) in senders.iter() {
                rds.push(tube.as_raw_descriptor());
            }
        }
        if let Some(rx) = &self.translate_request_rx {
            rds.push(rx.as_raw_descriptor());
        }

        if let Some(iommu_device_tube) = &self.iommu_device_tube {
            rds.push(iommu_device_tube.as_raw_descriptor());
        }

        rds
    }

    fn device_type(&self) -> DeviceType {
        DeviceType::Iommu
    }

    fn queue_max_sizes(&self) -> &[u16] {
        QUEUE_SIZES
    }

    fn features(&self) -> u64 {
        self.avail_features
    }

    fn read_config(&self, offset: u64, data: &mut [u8]) {
        let mut config: Vec<u8> = Vec::new();
        config.extend_from_slice(self.config.as_bytes());
        copy_config(data, 0, config.as_slice(), offset);
    }

    fn activate(
        &mut self,
        mem: GuestMemory,
        _interrupt: Interrupt,
        queues: BTreeMap<usize, Queue>,
    ) -> anyhow::Result<()> {
        if queues.len() != QUEUE_SIZES.len() {
            return Err(anyhow!(
                "expected {} queues, got {}",
                QUEUE_SIZES.len(),
                queues.len()
            ));
        }

        // The least significant set bit of page_size_mask defines the page
        // granularity of IOMMU mappings.
        let page_mask = (1u64 << u64::from(self.config.page_size_mask).trailing_zeros()) - 1;
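
        // For example, with a 4 KiB granule page_size_mask ends in ...f000, so
        // trailing_zeros() == 12 and page_mask == 0xfff; process_dma_map_request()
        // then uses `page_mask & addr != 0` to detect misaligned MAP requests.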
        let eps = self.endpoints.clone();
        let hp_endpoints_ranges = self.hp_endpoints_ranges.to_owned();

        let translate_response_senders = self.translate_response_senders.take();
        let translate_request_rx = self.translate_request_rx.take();

        let iommu_device_tube = self
            .iommu_device_tube
            .take()
            .context("failed to start virtio-iommu worker: No control tube")?;

        self.worker_thread = Some(WorkerThread::start("v_iommu", move |kill_evt| {
            let state = State {
                mem,
                page_mask,
                hp_endpoints_ranges,
                endpoint_map: BTreeMap::new(),
                domain_map: BTreeMap::new(),
                endpoints: eps,
                dmabuf_mem: BTreeMap::new(),
            };
            let result = run(
                state,
                iommu_device_tube,
                queues,
                kill_evt,
                translate_response_senders,
                translate_request_rx,
            );
            if let Err(e) = result {
                error!("virtio-iommu worker thread exited with error: {}", e);
            }
        }));
        Ok(())
    }

    #[cfg(target_arch = "x86_64")]
    fn generate_acpi(
        &mut self,
        pci_address: PciAddress,
        sdts: &mut Vec<SDT>,
    ) -> anyhow::Result<()> {
        const OEM_REVISION: u32 = 1;
        const VIOT_REVISION: u8 = 0;

        // there should only be one VIOT table
        if sdts.iter().any(|sdt| sdt.is_signature(b"VIOT")) {
            return Err(anyhow!("duplicate VIOT table"));
        }

        let mut viot = SDT::new(
            *b"VIOT",
            acpi_tables::HEADER_LEN,
            VIOT_REVISION,
            *b"CROSVM",
            *b"CROSVMDT",
            OEM_REVISION,
        );
        viot.append(VirtioIommuViotHeader {
            // # of PCI range nodes + 1 virtio-pci node
            node_count: (self.endpoints.len() + self.hp_endpoints_ranges.len() + 1) as u16,
            node_offset: (viot.len() + std::mem::size_of::<VirtioIommuViotHeader>()) as u16,
            ..Default::default()
        });

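        // BDF of this virtio-iommu device and the offset of its node within the table;
        // the PCI range nodes appended below point back to it via `output_node`.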
        let bdf = pci_address.to_u32() as u16;
        let iommu_offset = viot.len();

        viot.append(VirtioIommuViotVirtioPciNode {
            type_: VIRTIO_IOMMU_VIOT_NODE_VIRTIO_IOMMU_PCI,
            length: size_of::<VirtioIommuViotVirtioPciNode>() as u16,
            bdf,
            ..Default::default()
        });

        for (endpoint, _) in self.endpoints.iter() {
            viot.append(VirtioIommuViotPciRangeNode {
                type_: VIRTIO_IOMMU_VIOT_NODE_PCI_RANGE,
                length: size_of::<VirtioIommuViotPciRangeNode>() as u16,
                endpoint_start: *endpoint,
                bdf_start: *endpoint as u16,
                bdf_end: *endpoint as u16,
                output_node: iommu_offset as u16,
                ..Default::default()
            });
        }

        for endpoints_range in self.hp_endpoints_ranges.iter() {
            let (endpoint_start, endpoint_end) = endpoints_range.clone().into_inner();
            viot.append(VirtioIommuViotPciRangeNode {
                type_: VIRTIO_IOMMU_VIOT_NODE_PCI_RANGE,
                length: size_of::<VirtioIommuViotPciRangeNode>() as u16,
                endpoint_start,
                bdf_start: endpoint_start as u16,
                bdf_end: endpoint_end as u16,
                output_node: iommu_offset as u16,
                ..Default::default()
            });
        }

        sdts.push(viot);
        Ok(())
    }
}