#![deny(missing_docs)]
use anyhow::bail;
use anyhow::Context;
use anyhow::Result;
use base::trace;
use data_model::Le16;
use data_model::Le32;
use data_model::Le64;
use vm_memory::GuestAddress;
use vm_memory::GuestMemory;
use zerocopy::FromBytes;
use zerocopy::Immutable;
use zerocopy::IntoBytes;
use zerocopy::KnownLayout;
use crate::virtio::descriptor_chain::Descriptor;
use crate::virtio::descriptor_chain::DescriptorAccess;
use crate::virtio::descriptor_chain::DescriptorChainIter;
use crate::virtio::descriptor_chain::VIRTQ_DESC_F_NEXT;
use crate::virtio::descriptor_chain::VIRTQ_DESC_F_WRITE;
/// One entry of a split virtqueue descriptor table, as laid out in guest memory.
///
/// The struct is `#[repr(C)]` with little-endian fields (8 + 4 + 2 + 2 = 16 bytes). It is read
/// directly from guest memory via `zerocopy`, so field order and types must not change.
#[derive(Copy, Clone, Debug, FromBytes, Immutable, IntoBytes, KnownLayout)]
#[repr(C)]
pub struct Desc {
    /// Guest address of the buffer described by this entry.
    pub addr: Le64,
    /// Length in bytes of the buffer at `addr`.
    pub len: Le32,
    /// Descriptor flags; `SplitDescriptorChain` accepts only `VIRTQ_DESC_F_WRITE` and
    /// `VIRTQ_DESC_F_NEXT`.
    pub flags: Le16,
    /// Index of the following descriptor in the table; only meaningful when
    /// `VIRTQ_DESC_F_NEXT` is set in `flags`.
    pub next: Le16,
}
/// Walks a chain of descriptors in a split virtqueue descriptor table.
///
/// Descriptors are produced one at a time through [`DescriptorChainIter::next`], following
/// each entry's `next` link for as long as `VIRTQ_DESC_F_NEXT` is set.
pub struct SplitDescriptorChain<'m> {
    /// Guest memory containing the descriptor table.
    mem: &'m GuestMemory,
    /// Guest address of descriptor table entry 0.
    desc_table: GuestAddress,
    /// Number of entries in the descriptor table; bounds both indices and chain length.
    queue_size: u16,
    /// Index of the next descriptor to read, or `None` once the chain has ended.
    index: Option<u16>,
    /// How many descriptors have been visited so far.
    count: u16,
}
impl<'m> SplitDescriptorChain<'m> {
    pub fn new(
        mem: &'m GuestMemory,
        desc_table: GuestAddress,
        queue_size: u16,
        index: u16,
    ) -> SplitDescriptorChain<'m> {
        trace!("starting split descriptor chain head={index}");
        SplitDescriptorChain {
            index: Some(index),
            count: 0,
            queue_size,
            mem,
            desc_table,
        }
    }
}
impl DescriptorChainIter for SplitDescriptorChain<'_> {
    /// Reads the current descriptor and advances to the next one in the chain.
    ///
    /// Returns `Ok(None)` after a descriptor without `VIRTQ_DESC_F_NEXT` has been returned.
    ///
    /// # Errors
    ///
    /// Fails when:
    /// - the current index is not below `queue_size`;
    /// - `queue_size` descriptors have already been visited, which means the `next` links
    ///   form a loop;
    /// - the descriptor's address overflows or it cannot be read from guest memory;
    /// - the descriptor has any flag other than `VIRTQ_DESC_F_WRITE` / `VIRTQ_DESC_F_NEXT`.
    fn next(&mut self) -> Result<Option<Descriptor>> {
        let index = match self.index {
            Some(index) => index,
            None => return Ok(None),
        };

        if index >= self.queue_size {
            bail!(
                "out of bounds descriptor index {} for queue size {}",
                index,
                self.queue_size
            );
        }

        // A chain that visits each table entry at most once can be at most `queue_size`
        // long; reaching that bound means some entry is being revisited.
        if self.count >= self.queue_size {
            bail!("descriptor chain loop detected");
        }
        self.count += 1;

        // Table entries are consecutive `Desc` structs; take the stride from the type's
        // layout rather than a hard-coded byte count. u16 index * 16 cannot overflow u64.
        let desc_size = std::mem::size_of::<Desc>() as u64;
        let desc_addr = self
            .desc_table
            .checked_add(u64::from(index) * desc_size)
            .context("integer overflow")?;
        let desc = self
            .mem
            .read_obj_from_addr::<Desc>(desc_addr)
            .with_context(|| format!("failed to read desc {:#x}", desc_addr.offset()))?;

        let address: u64 = desc.addr.into();
        let len: u32 = desc.len.into();
        let flags: u16 = desc.flags.into();
        let next: u16 = desc.next.into();

        // With `#`, the width includes the "0x" prefix: 18 = 0x + 16 digits for u64,
        // 10 = 0x + 8 digits for u32.
        trace!("{index:5}: addr={address:#018x} len={len:#010x} flags={flags:#x}");

        // Any other flag (e.g. indirect descriptors) is not supported by this walker.
        let unexpected_flags = flags & !(VIRTQ_DESC_F_WRITE | VIRTQ_DESC_F_NEXT);
        if unexpected_flags != 0 {
            bail!("unexpected flags in descriptor {index}: {unexpected_flags:#x}")
        }

        let access = if flags & VIRTQ_DESC_F_WRITE != 0 {
            DescriptorAccess::DeviceWrite
        } else {
            DescriptorAccess::DeviceRead
        };

        // Follow the `next` link only when the chain says it continues.
        self.index = if flags & VIRTQ_DESC_F_NEXT != 0 {
            Some(next)
        } else {
            None
        };

        Ok(Some(Descriptor {
            address,
            len,
            access,
        }))
    }

    /// Number of descriptors visited so far.
    ///
    /// The counter is incremented before each descriptor is read, so it also includes a
    /// descriptor whose read failed.
    fn count(&self) -> u16 {
        self.count
    }

    /// Split descriptor chains report no id through this method; always returns `None`.
    fn id(&self) -> Option<u16> {
        None
    }
}