use std::sync::Arc;
use anyhow::anyhow;
use anyhow::bail;
use anyhow::Context;
use anyhow::Result;
use base::error;
use base::AsRawDescriptor;
use base::AsRawDescriptors;
use base::Event;
use base::Protection;
use base::RawDescriptor;
use base::Tube;
use serde::Deserialize;
use serde::Serialize;
use smallvec::SmallVec;
use sync::Mutex;
use vm_memory::GuestAddress;
use vm_memory::GuestMemory;
use zerocopy::FromBytes;
use zerocopy::FromZeros;
use zerocopy::Immutable;
use zerocopy::IntoBytes;
use crate::virtio::memory_mapper::MemRegion;
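/// Requests sent from an `IpcMemoryMapper` endpoint to the virtio-iommu.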
#[derive(Serialize, Deserialize)]
pub(super) enum IommuRequest {
    Export {
        endpoint_id: u32,
        iova: u64,
        size: u64,
    },
    Release {
        endpoint_id: u32,
        iova: u64,
        size: u64,
    },
    StartExportSession {
        endpoint_id: u32,
    },
}
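/// Responses to `IommuRequest`s, sent back by the virtio-iommu.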
#[derive(Serialize, Deserialize)]
pub(super) enum IommuResponse {
    Export(Vec<MemRegion>),
    Release,
    StartExportSession(Event),
    Err(String),
}
impl IommuRequest {
    pub(super) fn get_endpoint_id(&self) -> u32 {
        match self {
            Self::Export { endpoint_id, .. } => *endpoint_id,
            Self::Release { endpoint_id, .. } => *endpoint_id,
            Self::StartExportSession { endpoint_id } => *endpoint_id,
        }
    }
}
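/// Proxy that forwards IOMMU export/release requests from a device to the
/// virtio-iommu over a pair of `Tube`s and returns the responses.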
pub struct IpcMemoryMapper {
    request_tx: Tube,
    response_rx: Tube,
    endpoint_id: u32,
}
impl std::fmt::Debug for IpcMemoryMapper {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.debug_struct("IpcMemoryMapper")
            .field("endpoint_id", &self.endpoint_id)
            .finish()
    }
}
fn map_bad_resp(resp: IommuResponse) -> anyhow::Error {
    match resp {
        IommuResponse::Err(e) => anyhow!("remote error {}", e),
        _ => anyhow!("response type mismatch"),
    }
}
impl IpcMemoryMapper {
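    /// Creates a new `IpcMemoryMapper`.
    ///
    /// # Arguments
    ///
    /// * `request_tx` - Tube for sending `IommuRequest`s to the virtio-iommu.
    /// * `response_rx` - Tube for receiving the corresponding `IommuResponse`s.
    /// * `endpoint_id` - Identifies this endpoint to the virtio-iommu.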
    pub fn new(request_tx: Tube, response_rx: Tube, endpoint_id: u32) -> Self {
        Self {
            request_tx,
            response_rx,
            endpoint_id,
        }
    }
    fn do_request(&self, req: IommuRequest) -> Result<IommuResponse> {
        self.request_tx
            .send(&req)
            .context("failed to send request")?;
        self.response_rx
            .recv::<IommuResponse>()
            .context("failed to get response")
    }
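    /// Exports the guest-physical regions backing the IOVA range
    /// `[iova, iova + size)` for this endpoint.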
    pub fn export(&mut self, iova: u64, size: u64) -> Result<Vec<MemRegion>> {
        let req = IommuRequest::Export {
            endpoint_id: self.endpoint_id,
            iova,
            size,
        };
        match self.do_request(req)? {
            IommuResponse::Export(vec) => Ok(vec),
            e => Err(map_bad_resp(e)),
        }
    }
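    /// Releases an IOVA range previously exported with `export`.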
    pub fn release(&mut self, iova: u64, size: u64) -> Result<()> {
        let req = IommuRequest::Release {
            endpoint_id: self.endpoint_id,
            iova,
            size,
        };
        match self.do_request(req)? {
            IommuResponse::Release => Ok(()),
            e => Err(map_bad_resp(e)),
        }
    }
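    /// Asks the virtio-iommu to start an export session for this endpoint and
    /// returns the session `Event` from the response.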
    pub fn start_export_session(&mut self) -> Result<Event> {
        let req = IommuRequest::StartExportSession {
            endpoint_id: self.endpoint_id,
        };
        match self.do_request(req)? {
            IommuResponse::StartExportSession(evt) => Ok(evt),
            e => Err(map_bad_resp(e)),
        }
    }
}
impl AsRawDescriptors for IpcMemoryMapper {
    fn as_raw_descriptors(&self) -> Vec<RawDescriptor> {
        vec![
            self.request_tx.as_raw_descriptor(),
            self.response_rx.as_raw_descriptor(),
        ]
    }
}
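/// Return value of `create_ipc_mapper`: the mapper for the device side and
/// the `Tube` on which the virtio-iommu sends its responses.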
pub struct CreateIpcMapperRet {
    pub mapper: IpcMemoryMapper,
    pub response_tx: Tube,
}
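/// Creates an `IpcMemoryMapper` for `endpoint_id` that sends requests over
/// `request_tx`, along with the response `Tube` to hand to the virtio-iommu.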
pub fn create_ipc_mapper(endpoint_id: u32, request_tx: Tube) -> CreateIpcMapperRet {
    let (response_tx, response_rx) = Tube::pair().expect("failed to create tube pair");
    CreateIpcMapperRet {
        mapper: IpcMemoryMapper::new(request_tx, response_rx, endpoint_id),
        response_tx,
    }
}
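/// Shared state behind an `ExportedRegion`; releases the exported IOVA range
/// through the `IpcMemoryMapper` when the last clone is dropped.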
#[derive(Debug)]
struct ExportedRegionInner {
    regions: Vec<MemRegion>,
    iova: u64,
    size: u64,
    iommu: Arc<Mutex<IpcMemoryMapper>>,
}
impl Drop for ExportedRegionInner {
    fn drop(&mut self) {
        if let Err(e) = self.iommu.lock().release(self.iova, self.size) {
            error!("Error releasing region {:?}", e);
        }
    }
}
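/// A region of IOVA space exported from the virtio-iommu. The export is
/// released again once all clones of the handle have been dropped.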
#[derive(Clone, Debug)]
pub struct ExportedRegion {
    inner: Arc<Mutex<ExportedRegionInner>>,
}
impl ExportedRegion {
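    /// Exports `[iova, iova + size)` via `iommu` and verifies that every
    /// returned guest-physical region lies within `mem`.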
    pub fn new(
        mem: &GuestMemory,
        iommu: Arc<Mutex<IpcMemoryMapper>>,
        iova: u64,
        size: u64,
    ) -> Result<Self> {
        let regions = iommu
            .lock()
            .export(iova, size)
            .context("failed to export")?;
        for r in &regions {
            if !mem.is_valid_range(r.gpa, r.len) {
                bail!("region not in memory range");
            }
        }
        Ok(Self {
            inner: Arc::new(Mutex::new(ExportedRegionInner {
                regions,
                iova,
                size,
                iommu,
            })),
        })
    }
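    /// Walks the exported regions, calling `copy_fn(offset, gpa, len)` for each
    /// chunk of the IOVA range starting at `iova` until `remaining` bytes have
    /// been copied.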
    fn do_copy<C>(
        &self,
        iova: u64,
        mut remaining: usize,
        prot: Protection,
        mut copy_fn: C,
    ) -> Result<()>
    where
        C: FnMut(usize /* offset */, GuestAddress, usize /* len */) -> Result<usize>,
    {
        let inner = self.inner.lock();
        let mut region_offset = iova.checked_sub(inner.iova).with_context(|| {
            format!(
                "out of bounds: src_iova={} region_iova={}",
                iova, inner.iova
            )
        })?;
        let mut offset = 0;
        for r in &inner.regions {
            if region_offset >= r.len {
                region_offset -= r.len;
                continue;
            }
            if !r.prot.allows(&prot) {
                bail!("gpa is not accessible");
            }
            // Copy at most what remains of this region past `region_offset`.
            let len = ((r.len - region_offset) as usize).min(remaining);
            let copy_len = copy_fn(offset, r.gpa.unchecked_add(region_offset), len)?;
            if len != copy_len {
                bail!("incomplete copy: expected={}, actual={}", len, copy_len);
            }
            remaining -= len;
            offset += len;
            region_offset = 0;
            if remaining == 0 {
                return Ok(());
            }
        }
        Err(anyhow!("not enough data: remaining={}", remaining))
    }
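    /// Reads a `T` from the exported region at `iova`. Fails if the object does
    /// not lie entirely within the region or part of it is not readable.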
    pub fn read_obj_from_addr<T: IntoBytes + FromBytes + FromZeros>(
        &self,
        mem: &GuestMemory,
        iova: u64,
    ) -> anyhow::Result<T> {
        let mut val = T::new_zeroed();
        let buf = val.as_mut_bytes();
        self.do_copy(iova, buf.len(), Protection::read(), |offset, gpa, len| {
            mem.read_at_addr(&mut buf[offset..(offset + len)], gpa)
                .context("failed to read from gpa")
        })?;
        Ok(val)
    }
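    /// Writes `val` into the exported region at `iova`. Fails if the object
    /// does not lie entirely within the region or part of it is not writable.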
    pub fn write_obj_at_addr<T: Immutable + IntoBytes>(
        &self,
        mem: &GuestMemory,
        val: T,
        iova: u64,
    ) -> anyhow::Result<()> {
        let buf = val.as_bytes();
        self.do_copy(iova, buf.len(), Protection::write(), |offset, gpa, len| {
            mem.write_at_addr(&buf[offset..(offset + len)], gpa)
                .context("failed to write to gpa")
        })?;
        Ok(())
    }
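    /// Returns true if `[iova, iova + size)` lies within this exported region
    /// and each backing guest-physical region still overlaps guest memory.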
    pub fn is_valid(&self, mem: &GuestMemory, iova: u64, size: u64) -> bool {
        let inner = self.inner.lock();
        let Some(iova_end) = iova.checked_add(size) else {
            return false;
        };
        if iova < inner.iova || iova_end > (inner.iova + inner.size) {
            return false;
        }
        inner
            .regions
            .iter()
            .all(|r| mem.range_overlap(r.gpa, r.gpa.unchecked_add(r.len)))
    }
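    /// Returns the guest-physical regions backing this export.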
    pub fn get_mem_regions(&self) -> SmallVec<[MemRegion; 1]> {
        SmallVec::from_slice(&self.inner.lock().regions)
    }
}
#[cfg(test)]
mod tests {
    use std::thread;
    use base::Protection;
    use vm_memory::GuestAddress;
    use super::*;
    #[test]
    fn test() {
        let (request_tx, request_rx) = Tube::pair().expect("failed to create tube pair");
        let CreateIpcMapperRet {
            mut mapper,
            response_tx,
        } = create_ipc_mapper(3, request_tx);
        let user_handle = thread::spawn(move || {
            assert!(mapper
                .export(0x555, 1)
                .unwrap()
                .iter()
                .zip(&vec![MemRegion {
                    gpa: GuestAddress(0x777),
                    len: 1,
                    prot: Protection::read_write(),
                },])
                .all(|(a, b)| a == b));
        });
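        // Play the role of the virtio-iommu: answer the single export request
        // that the mapper thread sends above.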
        let iommu_handle = thread::spawn(move || {
            let (endpoint_id, iova, size) = match request_rx.recv().unwrap() {
                IommuRequest::Export {
                    endpoint_id,
                    iova,
                    size,
                } => (endpoint_id, iova, size),
                _ => unreachable!(),
            };
            assert_eq!(endpoint_id, 3);
            assert_eq!(iova, 0x555);
            assert_eq!(size, 1);
            response_tx
                .send(&IommuResponse::Export(vec![MemRegion {
                    gpa: GuestAddress(0x777),
                    len: 1,
                    prot: Protection::read_write(),
                }]))
                .unwrap();
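            // Keep `response_tx` alive until the mapper thread has received
            // the response.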
            user_handle.join().unwrap();
        });
        iommu_handle.join().unwrap();
    }
}