1use std::sync::Arc;
8
9use anyhow::anyhow;
10use anyhow::bail;
11use anyhow::Context;
12use anyhow::Result;
13use base::error;
14use base::AsRawDescriptor;
15use base::AsRawDescriptors;
16use base::Event;
17use base::Protection;
18use base::RawDescriptor;
19use base::Tube;
20use serde::Deserialize;
21use serde::Serialize;
22use smallvec::SmallVec;
23use sync::Mutex;
24use vm_memory::GuestAddress;
25use vm_memory::GuestMemory;
26use zerocopy::FromBytes;
27use zerocopy::FromZeros;
28use zerocopy::Immutable;
29use zerocopy::IntoBytes;
30
31use crate::virtio::memory_mapper::MemRegion;
32
33#[derive(Serialize, Deserialize)]
34pub(super) enum IommuRequest {
35 Export {
36 endpoint_id: u32,
37 iova: u64,
38 size: u64,
39 },
40 Release {
41 endpoint_id: u32,
42 iova: u64,
43 size: u64,
44 },
45 StartExportSession {
46 endpoint_id: u32,
47 },
48}
49
50#[derive(Serialize, Deserialize)]
51pub(super) enum IommuResponse {
52 Export(Vec<MemRegion>),
53 Release,
54 StartExportSession(Event),
55 Err(String),
56}
57
58impl IommuRequest {
59 pub(super) fn get_endpoint_id(&self) -> u32 {
60 match self {
61 Self::Export { endpoint_id, .. } => *endpoint_id,
62 Self::Release { endpoint_id, .. } => *endpoint_id,
63 Self::StartExportSession { endpoint_id } => *endpoint_id,
64 }
65 }
66}
67
68pub struct IpcMemoryMapper {
71 request_tx: Tube,
72 response_rx: Tube,
73 endpoint_id: u32,
74}
75
76impl std::fmt::Debug for IpcMemoryMapper {
77 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
78 f.debug_struct("IpcMemoryMapper")
79 .field("endpoint_id", &self.endpoint_id)
80 .finish()
81 }
82}
83
84fn map_bad_resp(resp: IommuResponse) -> anyhow::Error {
85 match resp {
86 IommuResponse::Err(e) => anyhow!("remote error {}", e),
87 _ => anyhow!("response type mismatch"),
88 }
89}
90
91impl IpcMemoryMapper {
92 pub fn new(request_tx: Tube, response_rx: Tube, endpoint_id: u32) -> Self {
100 Self {
101 request_tx,
102 response_rx,
103 endpoint_id,
104 }
105 }
106
107 fn do_request(&self, req: IommuRequest) -> Result<IommuResponse> {
108 self.request_tx
109 .send(&req)
110 .context("failed to send request")?;
111 self.response_rx
112 .recv::<IommuResponse>()
113 .context("failed to get response")
114 }
115
116 pub fn export(&mut self, iova: u64, size: u64) -> Result<Vec<MemRegion>> {
118 let req = IommuRequest::Export {
119 endpoint_id: self.endpoint_id,
120 iova,
121 size,
122 };
123 match self.do_request(req)? {
124 IommuResponse::Export(vec) => Ok(vec),
125 e => Err(map_bad_resp(e)),
126 }
127 }
128
129 pub fn release(&mut self, iova: u64, size: u64) -> Result<()> {
131 let req = IommuRequest::Release {
132 endpoint_id: self.endpoint_id,
133 iova,
134 size,
135 };
136 match self.do_request(req)? {
137 IommuResponse::Release => Ok(()),
138 e => Err(map_bad_resp(e)),
139 }
140 }
141
142 pub fn start_export_session(&mut self) -> Result<Event> {
144 let req = IommuRequest::StartExportSession {
145 endpoint_id: self.endpoint_id,
146 };
147 match self.do_request(req)? {
148 IommuResponse::StartExportSession(evt) => Ok(evt),
149 e => Err(map_bad_resp(e)),
150 }
151 }
152}
153
154impl AsRawDescriptors for IpcMemoryMapper {
155 fn as_raw_descriptors(&self) -> Vec<RawDescriptor> {
156 vec![
157 self.request_tx.as_raw_descriptor(),
158 self.response_rx.as_raw_descriptor(),
159 ]
160 }
161}
162
163pub struct CreateIpcMapperRet {
164 pub mapper: IpcMemoryMapper,
165 pub response_tx: Tube,
166}
167
168pub fn create_ipc_mapper(endpoint_id: u32, request_tx: Tube) -> CreateIpcMapperRet {
177 let (response_tx, response_rx) = Tube::pair().expect("failed to create tube pair");
178 CreateIpcMapperRet {
179 mapper: IpcMemoryMapper::new(request_tx, response_rx, endpoint_id),
180 response_tx,
181 }
182}
183
184#[derive(Debug)]
185struct ExportedRegionInner {
186 regions: Vec<MemRegion>,
187 iova: u64,
188 size: u64,
189 iommu: Arc<Mutex<IpcMemoryMapper>>,
190}
191
192impl Drop for ExportedRegionInner {
193 fn drop(&mut self) {
194 if let Err(e) = self.iommu.lock().release(self.iova, self.size) {
195 error!("Error releasing region {:?}", e);
196 }
197 }
198}
199
200#[derive(Clone, Debug)]
202pub struct ExportedRegion {
203 inner: Arc<Mutex<ExportedRegionInner>>,
204}
205
206impl ExportedRegion {
207 pub fn new(
209 mem: &GuestMemory,
210 iommu: Arc<Mutex<IpcMemoryMapper>>,
211 iova: u64,
212 size: u64,
213 ) -> Result<Self> {
214 let regions = iommu
215 .lock()
216 .export(iova, size)
217 .context("failed to export")?;
218 for r in ®ions {
219 if !mem.is_valid_range(r.gpa, r.len) {
220 bail!("region not in memory range");
221 }
222 }
223 Ok(Self {
224 inner: Arc::new(Mutex::new(ExportedRegionInner {
225 regions,
226 iova,
227 size,
228 iommu,
229 })),
230 })
231 }
232
233 fn do_copy<C>(
235 &self,
236 iova: u64,
237 mut remaining: usize,
238 prot: Protection,
239 mut copy_fn: C,
240 ) -> Result<()>
241 where
242 C: FnMut(usize , GuestAddress, usize ) -> Result<usize>,
243 {
244 let inner = self.inner.lock();
245 let mut region_offset = iova.checked_sub(inner.iova).with_context(|| {
246 format!(
247 "out of bounds: src_iova={} region_iova={}",
248 iova, inner.iova
249 )
250 })?;
251 let mut offset = 0;
252 for r in &inner.regions {
253 if region_offset >= r.len {
254 region_offset -= r.len;
255 continue;
256 }
257
258 if !r.prot.allows(&prot) {
259 bail!("gpa is not accessible");
260 }
261
262 let len = (r.len as usize).min(remaining);
263 let copy_len = copy_fn(offset, r.gpa.unchecked_add(region_offset), len)?;
264 if len != copy_len {
265 bail!("incomplete copy: expected={}, actual={}", len, copy_len);
266 }
267
268 remaining -= len;
269 offset += len;
270 region_offset = 0;
271
272 if remaining == 0 {
273 return Ok(());
274 }
275 }
276
277 Err(anyhow!("not enough data: remaining={}", remaining))
278 }
279
280 pub fn read_obj_from_addr<T: IntoBytes + FromBytes + FromZeros>(
283 &self,
284 mem: &GuestMemory,
285 iova: u64,
286 ) -> anyhow::Result<T> {
287 let mut val = T::new_zeroed();
288 let buf = val.as_mut_bytes();
289 self.do_copy(iova, buf.len(), Protection::read(), |offset, gpa, len| {
290 mem.read_at_addr(&mut buf[offset..(offset + len)], gpa)
291 .context("failed to read from gpa")
292 })?;
293 Ok(val)
294 }
295
296 pub fn write_obj_at_addr<T: Immutable + IntoBytes>(
299 &self,
300 mem: &GuestMemory,
301 val: T,
302 iova: u64,
303 ) -> anyhow::Result<()> {
304 let buf = val.as_bytes();
305 self.do_copy(iova, buf.len(), Protection::write(), |offset, gpa, len| {
306 mem.write_at_addr(&buf[offset..(offset + len)], gpa)
307 .context("failed to write from gpa")
308 })?;
309 Ok(())
310 }
311
312 pub fn is_valid(&self, mem: &GuestMemory, iova: u64, size: u64) -> bool {
315 let inner = self.inner.lock();
316 let iova_end = iova.checked_add(size);
317 if iova_end.is_none() {
318 return false;
319 }
320 if iova < inner.iova || iova_end.unwrap() > (inner.iova + inner.size) {
321 return false;
322 }
323 self.inner
324 .lock()
325 .regions
326 .iter()
327 .all(|r| mem.range_overlap(r.gpa, r.gpa.unchecked_add(r.len)))
328 }
329
330 pub fn get_mem_regions(&self) -> SmallVec<[MemRegion; 1]> {
332 SmallVec::from_slice(&self.inner.lock().regions)
333 }
334}
335
#[cfg(test)]
mod tests {
    use std::thread;

    use base::Protection;
    use vm_memory::GuestAddress;

    use super::*;

    /// Round-trips one `Export` request between a mapper ("user" thread) and
    /// a hand-written responder ("iommu" thread), checking both the request
    /// fields seen by the responder and the regions seen by the mapper.
    #[test]
    fn test() {
        let (request_tx, request_rx) = Tube::pair().expect("failed to create tube pair");
        let CreateIpcMapperRet {
            mut mapper,
            response_tx,
        } = create_ipc_mapper(3, request_tx);
        let user_handle = thread::spawn(move || {
            let expected = vec![MemRegion {
                gpa: GuestAddress(0x777),
                len: 1,
                prot: Protection::read_write(),
            }];
            let regions = mapper.export(0x555, 1).unwrap();
            // Check the length explicitly: `zip` stops at the shorter side, so
            // comparing element-wise alone would pass on an empty reply.
            assert_eq!(regions.len(), expected.len());
            assert!(regions.iter().zip(&expected).all(|(a, b)| a == b));
        });
        let iommu_handle = thread::spawn(move || {
            let (endpoint_id, iova, size) = match request_rx.recv().unwrap() {
                IommuRequest::Export {
                    endpoint_id,
                    iova,
                    size,
                } => (endpoint_id, iova, size),
                _ => unreachable!(),
            };
            assert_eq!(endpoint_id, 3);
            assert_eq!(iova, 0x555);
            assert_eq!(size, 1);
            response_tx
                .send(&IommuResponse::Export(vec![MemRegion {
                    gpa: GuestAddress(0x777),
                    len: 1,
                    prot: Protection::read_write(),
                }]))
                .unwrap();
            // Join while the tubes owned by this thread are still alive, so
            // the user thread can finish reading the response first.
            user_handle.join().unwrap();
        });
        iommu_handle.join().unwrap();
    }
}