#![deny(missing_docs)]
use std::num::Wrapping;
use std::sync::atomic::fence;
use std::sync::atomic::AtomicU16;
use std::sync::atomic::Ordering;
use anyhow::bail;
use anyhow::Result;
use base::error;
use base::warn;
use base::Event;
use serde::Deserialize;
use serde::Serialize;
use snapshot::AnySnapshot;
use virtio_sys::virtio_ring::VIRTIO_RING_F_EVENT_IDX;
use vm_memory::GuestAddress;
use vm_memory::GuestMemory;
use crate::virtio::descriptor_chain::DescriptorChain;
use crate::virtio::descriptor_chain::VIRTQ_DESC_F_AVAIL;
use crate::virtio::descriptor_chain::VIRTQ_DESC_F_USED;
use crate::virtio::descriptor_chain::VIRTQ_DESC_F_WRITE;
use crate::virtio::queue::packed_descriptor_chain::PackedDesc;
use crate::virtio::queue::packed_descriptor_chain::PackedDescEvent;
use crate::virtio::queue::packed_descriptor_chain::PackedDescriptorChain;
use crate::virtio::queue::packed_descriptor_chain::PackedNotificationType;
use crate::virtio::queue::packed_descriptor_chain::RING_EVENT_FLAGS_DESC;
use crate::virtio::Interrupt;
use crate::virtio::QueueConfig;
/// A position in a packed virtqueue ring: an offset into the descriptor ring together with the
/// ring wrap counter that distinguishes successive passes over the ring.
///
/// Field order matters for the derived `Serialize`/`Deserialize` impls; do not reorder.
#[derive(Copy, Clone, Debug, PartialEq, Serialize, Deserialize)]
struct PackedQueueIndex {
    /// Ring wrap counter; `add_index` flips it every time the offset passes the end of the ring.
    wrap_counter: bool,
    /// Offset into the descriptor ring (kept below the queue size by `add_index`).
    index: Wrapping<u16>,
}
impl PackedQueueIndex {
    /// Builds a ring position from an explicit wrap counter and ring offset.
    pub fn new(wrap_counter: bool, index: u16) -> Self {
        let index = Wrapping(index);
        Self {
            wrap_counter,
            index,
        }
    }

    /// Decodes an event-suppression `desc` value: bits 0..=14 hold the ring offset and bit 15
    /// holds the wrap counter.
    pub fn new_from_desc(desc: u16) -> Self {
        const WRAP_BIT: u16 = 1 << 15;
        let wrap_counter = (desc & WRAP_BIT) != 0;
        Self::new(wrap_counter, desc & !WRAP_BIT)
    }

    /// Encodes this position as an event-suppression entry whose flag is
    /// `RING_EVENT_FLAGS_DESC`, packing the wrap counter into bit 15 of `desc`.
    pub fn to_desc(self) -> PackedDescEvent {
        let wrap_bit: u16 = if self.wrap_counter { 1 << 15 } else { 0 };
        PackedDescEvent {
            desc: (self.index.0 | wrap_bit).into(),
            flag: RING_EVENT_FLAGS_DESC.into(),
        }
    }

    /// Moves the position forward by `index_value` slots in a ring of `size` entries, flipping
    /// the wrap counter when the end of the ring is crossed.
    ///
    /// Assumes `index_value <= size`, so at most one wrap can occur per call.
    fn add_index(&mut self, index_value: u16, size: u16) {
        let advanced = self.index.0 + index_value;
        if advanced >= size {
            // Crossed the end of the ring: restart from the beginning on the next lap.
            self.index = Wrapping(advanced - size);
            self.wrap_counter = !self.wrap_counter;
        } else {
            self.index = Wrapping(advanced);
        }
    }
}
impl Default for PackedQueueIndex {
    fn default() -> Self {
        Self::new(true, 0)
    }
}
/// A virtio packed virtqueue as seen from the device side.
///
/// Tracks the device's position in the descriptor ring for consuming available descriptors
/// (`avail_index`) and for publishing used descriptors (`use_index`), plus the guest addresses of
/// the descriptor ring and the two event suppression areas.
#[derive(Debug)]
pub struct PackedQueue {
    /// Guest memory containing the ring and event suppression areas.
    mem: GuestMemory,
    /// Event exposed through `event()`; presumably the driver->device notification (kick)
    /// event — TODO confirm with callers.
    event: Event,
    /// Interrupt used by `trigger_interrupt` (via `signal_used_queue(vector)`).
    interrupt: Interrupt,
    /// Number of descriptors in the ring.
    size: u16,
    /// Vector passed to `Interrupt::signal_used_queue`.
    vector: u16,
    /// Next ring position the device will examine for an available descriptor.
    avail_index: PackedQueueIndex,
    /// Next ring position at which the device writes a used descriptor.
    use_index: PackedQueueIndex,
    /// Snapshot of `use_index` taken when interrupt suppression was last evaluated.
    signalled_used_index: PackedQueueIndex,
    /// Acknowledged virtio feature bits.
    features: u64,
    /// Guest address of the descriptor ring.
    desc_table: GuestAddress,
    /// Guest address of the device event suppression area (written by the device).
    device_event_suppression: GuestAddress,
    /// Guest address of the driver event suppression area (read by the device).
    driver_event_suppression: GuestAddress,
}
/// Serializable state of a [`PackedQueue`], mirroring its non-resource fields.
///
/// Note: `PackedQueue::snapshot` and `PackedQueue::restore` currently always fail, so this type
/// is not produced or consumed by them yet. Field order matters for the derived serde impls.
#[derive(Serialize, Deserialize)]
pub struct PackedQueueSnapshot {
    /// Number of descriptors in the ring.
    size: u16,
    /// Interrupt vector of the queue.
    vector: u16,
    /// Next ring position to examine for available descriptors.
    avail_index: PackedQueueIndex,
    /// Next ring position at which a used descriptor is written.
    use_index: PackedQueueIndex,
    /// `use_index` snapshot used for interrupt suppression.
    signalled_used_index: PackedQueueIndex,
    /// Acknowledged virtio feature bits.
    features: u64,
    /// Guest address of the descriptor ring.
    desc_table: GuestAddress,
    /// Guest address of the device event suppression area.
    device_event_suppression: GuestAddress,
    /// Guest address of the driver event suppression area.
    driver_event_suppression: GuestAddress,
}
impl PackedQueue {
    pub fn new(
        config: &QueueConfig,
        mem: &GuestMemory,
        event: Event,
        interrupt: Interrupt,
    ) -> Result<Self> {
        let size = config.size();
        let desc_table = config.desc_table();
        let driver_area = config.avail_ring();
        let device_area = config.used_ring();
        let ring_sizes = Self::area_sizes(size, desc_table, driver_area, device_area);
        let rings = ring_sizes.iter().zip(vec![
            "descriptor table",
            "driver_event_suppression",
            "device_event_suppression",
        ]);
        for ((addr, size), name) in rings {
            if addr.checked_add(*size as u64).is_none() {
                bail!(
                    "virtio queue {} goes out of bounds: start:0x{:08x} size:0x{:08x}",
                    name,
                    addr.offset(),
                    size,
                );
            }
        }
        Ok(PackedQueue {
            mem: mem.clone(),
            event,
            interrupt,
            size,
            vector: config.vector(),
            desc_table: config.desc_table(),
            driver_event_suppression: config.avail_ring(),
            device_event_suppression: config.used_ring(),
            features: config.acked_features(),
            avail_index: PackedQueueIndex::default(),
            use_index: PackedQueueIndex::default(),
            signalled_used_index: PackedQueueIndex::default(),
        })
    }
    pub fn vhost_user_reclaim(&mut self, _vring_base: u16) {
        unimplemented!()
    }
    pub fn next_avail_to_process(&self) -> u16 {
        self.avail_index.index.0
    }
    pub fn size(&self) -> u16 {
        self.size
    }
    pub fn vector(&self) -> u16 {
        self.vector
    }
    pub fn desc_table(&self) -> GuestAddress {
        self.desc_table
    }
    pub fn avail_ring(&self) -> GuestAddress {
        self.driver_event_suppression
    }
    pub fn used_ring(&self) -> GuestAddress {
        self.device_event_suppression
    }
    pub fn event(&self) -> &Event {
        &self.event
    }
    pub fn interrupt(&self) -> &Interrupt {
        &self.interrupt
    }
    fn area_sizes(
        queue_size: u16,
        desc_table: GuestAddress,
        driver_area: GuestAddress,
        device_area: GuestAddress,
    ) -> Vec<(GuestAddress, usize)> {
        vec![
            (desc_table, 16 * queue_size as usize),
            (driver_area, 4),
            (device_area, 4),
        ]
    }
    fn set_avail_event(&mut self, event: PackedDescEvent) {
        fence(Ordering::SeqCst);
        self.mem
            .write_obj_at_addr_volatile(event, self.device_event_suppression)
            .unwrap();
    }
    fn get_driver_event(&self) -> PackedDescEvent {
        fence(Ordering::SeqCst);
        let desc: PackedDescEvent = self
            .mem
            .read_obj_from_addr_volatile(self.driver_event_suppression)
            .unwrap();
        desc
    }
    pub fn peek(&mut self) -> Option<DescriptorChain> {
        let desc_addr = self
            .desc_table
            .checked_add((self.avail_index.index.0 as u64) * 16)
            .expect("peeked address will not overflow");
        let desc = self
            .mem
            .read_obj_from_addr::<PackedDesc>(desc_addr)
            .inspect_err(|_e| {
                error!("failed to read desc {:#x}", desc_addr.offset());
            })
            .ok()?;
        if !desc.is_available(self.avail_index.wrap_counter as u16) {
            return None;
        }
        fence(Ordering::SeqCst);
        let chain = PackedDescriptorChain::new(
            &self.mem,
            self.desc_table,
            self.size,
            self.avail_index.wrap_counter,
            self.avail_index.index.0,
        );
        match DescriptorChain::new(chain, &self.mem, self.avail_index.index.0) {
            Ok(descriptor_chain) => Some(descriptor_chain),
            Err(e) => {
                error!("{:#}", e);
                None
            }
        }
    }
    pub(super) fn pop_peeked(&mut self, descriptor_chain: &DescriptorChain) {
        self.avail_index
            .add_index(descriptor_chain.count, self.size());
        if self.features & ((1u64) << VIRTIO_RING_F_EVENT_IDX) != 0 {
            self.set_avail_event(self.avail_index.to_desc());
        }
    }
    pub fn add_used_with_bytes_written_batch(
        &mut self,
        desc_chains: impl IntoIterator<Item = (DescriptorChain, u32)>,
    ) {
        let desc_table_size = size_of::<PackedDesc>() * usize::from(self.size);
        let desc_table_vslice = self
            .mem
            .get_slice_at_addr(self.desc_table, desc_table_size)
            .unwrap();
        let desc_table_ptr = desc_table_vslice.as_mut_ptr() as *mut PackedDesc;
        for (desc_chain, len) in desc_chains {
            debug_assert!(desc_chain.index() < self.size);
            let chain_id = desc_chain
                .id
                .expect("Packed descriptor chain should have id");
            let wrap_counter = self.use_index.wrap_counter;
            let mut flags: u16 = 0;
            if wrap_counter {
                flags = flags | VIRTQ_DESC_F_USED | VIRTQ_DESC_F_AVAIL;
            }
            if len > 0 {
                flags |= VIRTQ_DESC_F_WRITE;
            }
            let desc_ptr = unsafe { desc_table_ptr.add(usize::from(self.use_index.index.0)) };
            unsafe {
                std::ptr::write_volatile(std::ptr::addr_of_mut!((*desc_ptr).len), len.into());
                std::ptr::write_volatile(std::ptr::addr_of_mut!((*desc_ptr).id), chain_id.into());
            }
            fence(Ordering::Release);
            let desc_flags_atomic = unsafe {
                AtomicU16::from_ptr(std::ptr::addr_of_mut!((*desc_ptr).flags) as *mut u16)
            };
            desc_flags_atomic.store(u16::to_le(flags), Ordering::Relaxed);
            self.use_index.add_index(desc_chain.count, self.size());
        }
    }
    fn queue_wants_interrupt(&mut self) -> bool {
        let driver_event = self.get_driver_event();
        match driver_event.notification_type() {
            PackedNotificationType::Enable => true,
            PackedNotificationType::Disable => false,
            PackedNotificationType::Desc(desc) => {
                if self.features & ((1u64) << VIRTIO_RING_F_EVENT_IDX) == 0 {
                    warn!("This is undefined behavior. We should actually send error in this case");
                    return true;
                }
                let old = self.signalled_used_index;
                self.signalled_used_index = self.use_index;
                let event_index: PackedQueueIndex = PackedQueueIndex::new_from_desc(desc);
                let event_idx = event_index.index;
                let old_idx = old.index;
                let new_idx = self.use_index.index;
                (new_idx - event_idx - Wrapping(1)) < (new_idx - old_idx)
            }
        };
        true
    }
    pub fn trigger_interrupt(&mut self) -> bool {
        if self.queue_wants_interrupt() {
            self.interrupt.signal_used_queue(self.vector);
            true
        } else {
            false
        }
    }
    pub fn ack_features(&mut self, features: u64) {
        self.features |= features;
    }
    pub fn snapshot(&self) -> Result<AnySnapshot> {
        bail!("Snapshot for packed virtqueue not implemented.");
    }
    pub fn restore(
        _queue_value: AnySnapshot,
        _mem: &GuestMemory,
        _event: Event,
        _interrupt: Interrupt,
    ) -> Result<PackedQueue> {
        bail!("Restore for packed virtqueue not implemented.");
    }
}