devices/virtio/iommu/sys/
linux.rs

1// Copyright 2022 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5pub mod vfio_wrapper;
6
7use std::cell::RefCell;
8use std::collections::BTreeMap;
9use std::fs::File;
10use std::rc::Rc;
11use std::sync::Arc;
12
13use base::error;
14use base::MemoryMappingBuilder;
15use base::TubeError;
16use cros_async::AsyncTube;
17use cros_async::Executor;
18use sync::Mutex;
19use vm_control::VirtioIOMMURequest;
20use vm_control::VirtioIOMMUResponse;
21use vm_control::VirtioIOMMUVfioCommand;
22use vm_control::VirtioIOMMUVfioResult;
23use vm_control::VmMemoryRegionId;
24
25use self::vfio_wrapper::VfioWrapper;
26use crate::virtio::iommu::ipc_memory_mapper::IommuRequest;
27use crate::virtio::iommu::ipc_memory_mapper::IommuResponse;
28use crate::virtio::iommu::DmabufRegionEntry;
29use crate::virtio::iommu::Result;
30use crate::virtio::iommu::State;
31use crate::virtio::IommuError;
32use crate::VfioContainer;
33
34impl State {
35    pub(in crate::virtio::iommu) fn handle_add_vfio_device(
36        &mut self,
37        endpoint_addr: u32,
38        wrapper: VfioWrapper,
39    ) -> VirtioIOMMUVfioResult {
40        let exists = |endpoint_addr: u32| -> bool {
41            for endpoints_range in self.hp_endpoints_ranges.iter() {
42                if endpoints_range.contains(&endpoint_addr) {
43                    return true;
44                }
45            }
46            false
47        };
48
49        if !exists(endpoint_addr) {
50            return VirtioIOMMUVfioResult::NotInPCIRanges;
51        }
52
53        self.endpoints
54            .insert(endpoint_addr, Arc::new(Mutex::new(Box::new(wrapper))));
55        VirtioIOMMUVfioResult::Ok
56    }
57
58    pub(in crate::virtio::iommu) fn handle_del_vfio_device(
59        &mut self,
60        pci_address: u32,
61    ) -> VirtioIOMMUVfioResult {
62        if self.endpoints.remove(&pci_address).is_none() {
63            error!("There is no vfio container of {}", pci_address);
64            return VirtioIOMMUVfioResult::NoSuchDevice;
65        }
66        if let Some(domain) = self.endpoint_map.remove(&pci_address) {
67            self.domain_map.remove(&domain);
68        }
69        VirtioIOMMUVfioResult::Ok
70    }
71
72    pub(in crate::virtio::iommu) fn handle_map_dmabuf(
73        &mut self,
74        region_id: VmMemoryRegionId,
75        gpa: u64,
76        size: u64,
77        dma_buf: File,
78    ) -> VirtioIOMMUVfioResult {
79        if gpa & self.page_mask != 0 {
80            error!("cannot map dmabuf to non-page-aligned guest physical address");
81            return VirtioIOMMUVfioResult::InvalidParam;
82        }
83        let mmap = match MemoryMappingBuilder::new(size as usize)
84            .from_file(&dma_buf)
85            .build()
86        {
87            Ok(v) => v,
88            Err(_) => {
89                error!("failed to mmap dma_buf");
90                return VirtioIOMMUVfioResult::InvalidParam;
91            }
92        };
93        self.dmabuf_mem.insert(
94            gpa,
95            DmabufRegionEntry {
96                mmap,
97                region_id,
98                size,
99            },
100        );
101
102        VirtioIOMMUVfioResult::Ok
103    }
104
105    pub(in crate::virtio::iommu) fn handle_unmap_dmabuf(
106        &mut self,
107        region_id: VmMemoryRegionId,
108    ) -> VirtioIOMMUVfioResult {
109        if let Some(range) = self
110            .dmabuf_mem
111            .iter()
112            .find(|(_, dmabuf_entry)| dmabuf_entry.region_id == region_id)
113            .map(|entry| *entry.0)
114        {
115            self.dmabuf_mem.remove(&range);
116            VirtioIOMMUVfioResult::Ok
117        } else {
118            VirtioIOMMUVfioResult::NoSuchMappedDmabuf
119        }
120    }
121
122    pub(in crate::virtio::iommu) fn handle_vfio(
123        &mut self,
124        vfio_cmd: VirtioIOMMUVfioCommand,
125    ) -> VirtioIOMMUResponse {
126        use VirtioIOMMUVfioCommand::*;
127        let vfio_result = match vfio_cmd {
128            VfioDeviceAdd {
129                wrapper_id,
130                container,
131                endpoint_addr,
132            } => match VfioContainer::new_from_container(container) {
133                Ok(vfio_container) => {
134                    let wrapper =
135                        VfioWrapper::new_with_id(vfio_container, wrapper_id, self.mem.clone());
136                    self.handle_add_vfio_device(endpoint_addr, wrapper)
137                }
138                Err(e) => {
139                    error!("failed to verify the new container: {}", e);
140                    VirtioIOMMUVfioResult::NoAvailableContainer
141                }
142            },
143            VfioDeviceDel { endpoint_addr } => self.handle_del_vfio_device(endpoint_addr),
144            VfioDmabufMap {
145                region_id,
146                gpa,
147                size,
148                dma_buf,
149            } => self.handle_map_dmabuf(region_id, gpa, size, File::from(dma_buf)),
150            VfioDmabufUnmap(region_id) => self.handle_unmap_dmabuf(region_id),
151        };
152        VirtioIOMMUResponse::VfioResponse(vfio_result)
153    }
154}
155
156pub(in crate::virtio::iommu) async fn handle_command_tube(
157    state: &Rc<RefCell<State>>,
158    command_tube: AsyncTube,
159) -> Result<()> {
160    loop {
161        match command_tube.next::<VirtioIOMMURequest>().await {
162            Ok(command) => {
163                let response: VirtioIOMMUResponse = match command {
164                    VirtioIOMMURequest::VfioCommand(vfio_cmd) => {
165                        state.borrow_mut().handle_vfio(vfio_cmd)
166                    }
167                };
168                if let Err(e) = command_tube.send(response).await {
169                    error!("{}", IommuError::VirtioIOMMUResponseError(e));
170                }
171            }
172            Err(e) => {
173                return Err(IommuError::VirtioIOMMUReqError(e));
174            }
175        }
176    }
177}
178
179pub(in crate::virtio::iommu) async fn handle_translate_request(
180    ex: &Executor,
181    state: &Rc<RefCell<State>>,
182    request_tube: Option<AsyncTube>,
183    response_tubes: Option<BTreeMap<u32, AsyncTube>>,
184) -> Result<()> {
185    let request_tube = match request_tube {
186        Some(r) => r,
187        None => {
188            futures::future::pending::<()>().await;
189            return Ok(());
190        }
191    };
192    let response_tubes = response_tubes.unwrap();
193    loop {
194        let req: IommuRequest = match request_tube.next().await {
195            Ok(req) => req,
196            Err(TubeError::Disconnected) => {
197                // This means the process on the other side of the tube went away. That's
198                // not a problem with virtio-iommu itself, so just exit this callback
199                // and wait for crosvm to exit.
200                return Ok(());
201            }
202            Err(e) => {
203                return Err(IommuError::Tube(e));
204            }
205        };
206        let resp = if let Some(mapper) = state.borrow().endpoints.get(&req.get_endpoint_id()) {
207            match req {
208                IommuRequest::Export { iova, size, .. } => {
209                    mapper.lock().export(iova, size).map(IommuResponse::Export)
210                }
211                IommuRequest::Release { iova, size, .. } => mapper
212                    .lock()
213                    .release(iova, size)
214                    .map(|_| IommuResponse::Release),
215                IommuRequest::StartExportSession { .. } => mapper
216                    .lock()
217                    .start_export_session(ex)
218                    .map(IommuResponse::StartExportSession),
219            }
220        } else {
221            error!("endpoint {} not found", req.get_endpoint_id());
222            continue;
223        };
224        let resp: IommuResponse = match resp {
225            Ok(resp) => resp,
226            Err(e) => IommuResponse::Err(format!("{e:?}")),
227        };
228        response_tubes
229            .get(&req.get_endpoint_id())
230            .unwrap()
231            .send(resp)
232            .await
233            .map_err(IommuError::Tube)?;
234    }
235}