use std::sync::Arc;
use anyhow::anyhow;
use anyhow::bail;
use anyhow::Context;
use anyhow::Result;
use base::error;
use base::AsRawDescriptor;
use base::AsRawDescriptors;
use base::Event;
use base::Protection;
use base::RawDescriptor;
use base::Tube;
use serde::Deserialize;
use serde::Serialize;
use smallvec::SmallVec;
use sync::Mutex;
use vm_memory::GuestAddress;
use vm_memory::GuestMemory;
use zerocopy::AsBytes;
use zerocopy::FromBytes;
use crate::virtio::memory_mapper::MemRegion;
/// Request sent by an [`IpcMemoryMapper`] over its request tube.
///
/// Every variant carries the `endpoint_id` of the mapper that issued it, so a single
/// request tube can be shared by several endpoints.
///
/// NOTE(review): this type is serialized through `Tube`; depending on the serde format in
/// use, variant order may be part of the wire encoding — avoid reordering variants.
#[derive(Serialize, Deserialize)]
pub(super) enum IommuRequest {
    /// Export the `size`-byte IOVA range starting at `iova`; answered with
    /// `IommuResponse::Export` listing the backing guest memory regions.
    Export {
        endpoint_id: u32,
        iova: u64,
        size: u64,
    },
    /// Release a range previously exported with `Export`; answered with
    /// `IommuResponse::Release`.
    Release {
        endpoint_id: u32,
        iova: u64,
        size: u64,
    },
    /// Start an export session; answered with `IommuResponse::StartExportSession`
    /// carrying an `Event` (what the event signals is decided by the remote side).
    StartExportSession {
        endpoint_id: u32,
    },
}
/// Reply to an [`IommuRequest`], read back by [`IpcMemoryMapper`] from its response tube.
///
/// NOTE(review): serialized through `Tube`; depending on the serde format, variant order
/// may be part of the wire encoding — avoid reordering variants.
#[derive(Serialize, Deserialize)]
pub(super) enum IommuResponse {
    /// Guest memory regions backing an exported IOVA range.
    Export(Vec<MemRegion>),
    /// Acknowledges a `Release` request.
    Release,
    /// Event handed back for a `StartExportSession` request.
    StartExportSession(Event),
    /// Error reported by the remote side; `map_bad_resp` surfaces it as `remote error ...`.
    Err(String),
}
impl IommuRequest {
    /// Returns the id of the endpoint that issued this request.
    ///
    /// Every variant carries an `endpoint_id`, so this never fails.
    pub(super) fn get_endpoint_id(&self) -> u32 {
        match *self {
            Self::Export { endpoint_id, .. }
            | Self::Release { endpoint_id, .. }
            | Self::StartExportSession { endpoint_id } => endpoint_id,
        }
    }
}
/// Client side of the IPC IOMMU protocol for a single endpoint.
///
/// Requests are sent on `request_tx` and each reply is read from `response_rx`.
pub struct IpcMemoryMapper {
    /// Endpoint id stamped into every request this mapper sends.
    endpoint_id: u32,
    /// Tube on which `IommuRequest`s are sent.
    request_tx: Tube,
    /// Tube from which `IommuResponse`s are received.
    response_rx: Tube,
}
impl std::fmt::Debug for IpcMemoryMapper {
    /// Formats the mapper showing only its endpoint id; the tubes are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let mut builder = f.debug_struct("IpcMemoryMapper");
        builder.field("endpoint_id", &self.endpoint_id);
        builder.finish()
    }
}
fn map_bad_resp(resp: IommuResponse) -> anyhow::Error {
match resp {
IommuResponse::Err(e) => anyhow!("remote error {}", e),
_ => anyhow!("response type mismatch"),
}
}
impl IpcMemoryMapper {
    /// Creates a mapper for `endpoint_id`.
    ///
    /// Requests go out on `request_tx` and replies are read from `response_rx`.
    pub fn new(request_tx: Tube, response_rx: Tube, endpoint_id: u32) -> Self {
        Self {
            endpoint_id,
            request_tx,
            response_rx,
        }
    }

    /// Sends `req` on the request tube, then reads one `IommuResponse` from the response
    /// tube.
    fn do_request(&self, req: IommuRequest) -> Result<IommuResponse> {
        let sent = self.request_tx.send(&req);
        sent.context("failed to send request")?;
        let reply = self.response_rx.recv::<IommuResponse>();
        reply.context("failed to get response")
    }

    /// Asks the remote side to export `size` bytes of IOVA space starting at `iova`.
    ///
    /// Returns the guest memory regions that back the range.
    pub fn export(&mut self, iova: u64, size: u64) -> Result<Vec<MemRegion>> {
        let endpoint_id = self.endpoint_id;
        let reply = self.do_request(IommuRequest::Export {
            endpoint_id,
            iova,
            size,
        })?;
        match reply {
            IommuResponse::Export(regions) => Ok(regions),
            unexpected => Err(map_bad_resp(unexpected)),
        }
    }

    /// Releases a range previously obtained from [`Self::export`].
    pub fn release(&mut self, iova: u64, size: u64) -> Result<()> {
        let endpoint_id = self.endpoint_id;
        let reply = self.do_request(IommuRequest::Release {
            endpoint_id,
            iova,
            size,
        })?;
        match reply {
            IommuResponse::Release => Ok(()),
            unexpected => Err(map_bad_resp(unexpected)),
        }
    }

    /// Starts an export session for this endpoint.
    ///
    /// Returns the `Event` supplied by the remote side.
    pub fn start_export_session(&mut self) -> Result<Event> {
        let endpoint_id = self.endpoint_id;
        let reply = self.do_request(IommuRequest::StartExportSession { endpoint_id })?;
        match reply {
            IommuResponse::StartExportSession(event) => Ok(event),
            unexpected => Err(map_bad_resp(unexpected)),
        }
    }
}
impl AsRawDescriptors for IpcMemoryMapper {
    /// Returns the descriptors of both tubes: the request tube first, then the response
    /// tube.
    fn as_raw_descriptors(&self) -> Vec<RawDescriptor> {
        [&self.request_tx, &self.response_rx]
            .iter()
            .map(|tube| tube.as_raw_descriptor())
            .collect()
    }
}
/// Return value of [`create_ipc_mapper`].
///
/// NOTE: field declaration order also determines drop order.
pub struct CreateIpcMapperRet {
    /// Mapper for the endpoint. It owns the request tube and the `response_rx` end of the
    /// freshly created response tube pair.
    pub mapper: IpcMemoryMapper,
    /// The other end of the response tube pair. Replies for `mapper` are sent on it (in
    /// the tests, the IOMMU side sends `IommuResponse`s here).
    pub response_tx: Tube,
}
/// Builds an [`IpcMemoryMapper`] for `endpoint_id` together with a new response tube pair.
///
/// The mapper keeps `request_tx` and one end of the pair. The other end is returned as
/// `response_tx` for whoever answers the mapper's requests.
///
/// # Panics
/// Panics if the tube pair cannot be created.
pub fn create_ipc_mapper(endpoint_id: u32, request_tx: Tube) -> CreateIpcMapperRet {
    let (response_tx, response_rx) = Tube::pair().expect("failed to create tube pair");
    let mapper = IpcMemoryMapper::new(request_tx, response_rx, endpoint_id);
    CreateIpcMapperRet {
        mapper,
        response_tx,
    }
}
/// Shared state behind an [`ExportedRegion`]: the exported range and the mapper it came from.
///
/// Dropping this value releases the range through `iommu` (see the `Drop` impl below).
#[derive(Debug)]
struct ExportedRegionInner {
    /// Guest memory regions returned by the export request, in the order received.
    /// Together they cover the IOVA range starting at `iova`.
    regions: Vec<MemRegion>,
    /// Start of the exported IOVA range.
    iova: u64,
    /// Length in bytes of the exported IOVA range.
    size: u64,
    /// Mapper used to release the range on drop.
    iommu: Arc<Mutex<IpcMemoryMapper>>,
}
impl Drop for ExportedRegionInner {
    /// Releases the exported IOVA range through the mapper.
    ///
    /// `drop` cannot return an error, so a failure is only logged.
    fn drop(&mut self) {
        let (iova, size) = (self.iova, self.size);
        let released = self.iommu.lock().release(iova, size);
        if let Err(e) = released {
            error!("Error releasing region {:?}", e);
        }
    }
}
/// Handle to an IOVA range exported through an [`IpcMemoryMapper`].
///
/// Clones share the same state through an `Arc`. The range is released when the last clone
/// is dropped, which runs `ExportedRegionInner`'s `Drop`.
#[derive(Clone, Debug)]
pub struct ExportedRegion {
    inner: Arc<Mutex<ExportedRegionInner>>,
}
impl ExportedRegion {
    /// Exports `size` bytes of IOVA space starting at `iova` through `iommu` and wraps the
    /// resulting regions.
    ///
    /// Every returned region must lie inside `mem`. If one does not, an error is returned
    /// and the freshly exported range is released again, so it is not leaked.
    ///
    /// # Errors
    /// Fails if the export request fails or if any returned region is outside `mem`.
    pub fn new(
        mem: &GuestMemory,
        iommu: Arc<Mutex<IpcMemoryMapper>>,
        iova: u64,
        size: u64,
    ) -> Result<Self> {
        let regions = iommu
            .lock()
            .export(iova, size)
            .context("failed to export")?;
        // Build the owning value before validating. If validation bails, dropping `inner`
        // releases the range we just exported. The mapper guard from the statement above
        // is already gone, so the release in `Drop` can take the lock again.
        let inner = ExportedRegionInner {
            regions,
            iova,
            size,
            iommu,
        };
        for r in &inner.regions {
            if !mem.is_valid_range(r.gpa, r.len) {
                bail!("region not in memory range");
            }
        }
        Ok(Self {
            inner: Arc::new(Mutex::new(inner)),
        })
    }

    /// Walks the exported regions starting at `iova` and calls `copy_fn` once per
    /// contiguous guest-physical chunk until `remaining` bytes have been covered.
    ///
    /// `copy_fn(offset, gpa, len)` receives three arguments:
    /// - `offset`: the byte offset into the caller's buffer.
    /// - `gpa`: the guest address to access.
    /// - `len`: the chunk length.
    ///
    /// It must return how many bytes it copied; a short copy is an error.
    ///
    /// # Errors
    /// Fails in any of these cases:
    /// - `iova` lies before the exported range.
    /// - A touched region does not allow `prot`.
    /// - A copy is short.
    /// - The regions run out before `remaining` bytes were covered.
    fn do_copy<C>(
        &self,
        iova: u64,
        mut remaining: usize,
        prot: Protection,
        mut copy_fn: C,
    ) -> Result<()>
    where
        C: FnMut(usize /* offset */, GuestAddress, usize /* len */) -> Result<usize>,
    {
        let inner = self.inner.lock();
        let mut region_offset = iova.checked_sub(inner.iova).with_context(|| {
            format!(
                "out of bounds: src_iova={} region_iova={}",
                iova, inner.iova
            )
        })?;
        let mut offset = 0;
        for r in &inner.regions {
            if region_offset >= r.len {
                region_offset -= r.len;
                continue;
            }
            if !r.prot.allows(&prot) {
                bail!("gpa is not accessible");
            }
            // Only the bytes from `region_offset` to the end of this region are contiguous
            // in guest-physical space. Copying `r.len` bytes from `gpa + region_offset`
            // would run past the region's end.
            let available = r.len - region_offset;
            let len = if available < remaining as u64 {
                available as usize
            } else {
                remaining
            };
            let copy_len = copy_fn(offset, r.gpa.unchecked_add(region_offset), len)?;
            if len != copy_len {
                bail!("incomplete copy: expected={}, actual={}", len, copy_len);
            }
            remaining -= len;
            offset += len;
            region_offset = 0;
            if remaining == 0 {
                return Ok(());
            }
        }
        Err(anyhow!("not enough data: remaining={}", remaining))
    }

    /// Reads a `T` from the exported range at `iova`.
    ///
    /// The read may span several guest memory regions, each of which must allow reads.
    pub fn read_obj_from_addr<T: FromBytes>(
        &self,
        mem: &GuestMemory,
        iova: u64,
    ) -> anyhow::Result<T> {
        let mut buf = vec![0u8; std::mem::size_of::<T>()];
        self.do_copy(iova, buf.len(), Protection::read(), |offset, gpa, len| {
            mem.read_at_addr(&mut buf[offset..(offset + len)], gpa)
                .context("failed to read from gpa")
        })?;
        T::read_from(buf.as_bytes()).context("failed to construct obj")
    }

    /// Writes `val` into the exported range at `iova`.
    ///
    /// The write may span several guest memory regions, each of which must allow writes.
    pub fn write_obj_at_addr<T: AsBytes>(
        &self,
        mem: &GuestMemory,
        val: T,
        iova: u64,
    ) -> anyhow::Result<()> {
        let buf = val.as_bytes();
        self.do_copy(iova, buf.len(), Protection::write(), |offset, gpa, len| {
            mem.write_at_addr(&buf[offset..(offset + len)], gpa)
                .context("failed to write to gpa")
        })?;
        Ok(())
    }

    /// Returns `true` if two conditions hold:
    /// - The `size`-byte range at `iova` lies within the exported IOVA range.
    /// - Every exported region overlaps `mem`, as reported by
    ///   `GuestMemory::range_overlap`.
    pub fn is_valid(&self, mem: &GuestMemory, iova: u64, size: u64) -> bool {
        // Take the lock once and reuse the guard. Locking `self.inner` again while
        // `inner` is still held would deadlock or panic on a non-reentrant mutex.
        let inner = self.inner.lock();
        let iova_end = match iova.checked_add(size) {
            Some(end) => end,
            None => return false,
        };
        // Saturate rather than overflow when the exported range reaches the top of the
        // address space.
        let region_end = inner.iova.saturating_add(inner.size);
        if iova < inner.iova || iova_end > region_end {
            return false;
        }
        inner
            .regions
            .iter()
            .all(|r| mem.range_overlap(r.gpa, r.gpa.unchecked_add(r.len)))
    }

    /// Returns a copy of the exported guest memory regions.
    pub fn get_mem_regions(&self) -> SmallVec<[MemRegion; 1]> {
        SmallVec::from_slice(&self.inner.lock().regions)
    }
}
#[cfg(test)]
mod tests {
    use std::thread;

    use base::Protection;
    use vm_memory::GuestAddress;

    use super::*;

    /// Round-trips an `Export` request between a mapper thread and a fake IOMMU thread.
    ///
    /// Checks both the request fields seen by the IOMMU and the regions returned to the
    /// mapper.
    #[test]
    fn test() {
        let (request_tx, request_rx) = Tube::pair().expect("failed to create tube pair");
        let CreateIpcMapperRet {
            mut mapper,
            response_tx,
        } = create_ipc_mapper(3, request_tx);
        let user_handle = thread::spawn(move || {
            // Compare the whole vector. Zipping with the expected list would pass
            // vacuously if `export` returned fewer regions (e.g. none at all).
            assert_eq!(
                mapper.export(0x555, 1).unwrap(),
                vec![MemRegion {
                    gpa: GuestAddress(0x777),
                    len: 1,
                    prot: Protection::read_write(),
                }]
            );
        });
        let iommu_handle = thread::spawn(move || {
            let (endpoint_id, iova, size) = match request_rx.recv().unwrap() {
                IommuRequest::Export {
                    endpoint_id,
                    iova,
                    size,
                } => (endpoint_id, iova, size),
                _ => unreachable!(),
            };
            assert_eq!(endpoint_id, 3);
            assert_eq!(iova, 0x555);
            assert_eq!(size, 1);
            response_tx
                .send(&IommuResponse::Export(vec![MemRegion {
                    gpa: GuestAddress(0x777),
                    len: 1,
                    prot: Protection::read_write(),
                }]))
                .unwrap();
            user_handle.join().unwrap();
        });
        iommu_handle.join().unwrap();
    }
}