use std::fs::read;
use std::fs::write;
use std::fs::File;
use std::fs::OpenOptions;
use std::os::unix::fs::FileExt;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use std::thread;
use anyhow::anyhow;
use anyhow::bail;
use anyhow::Context;
use anyhow::Result;
use base::error;
use base::Tube;
use sync::Mutex;
use vm_control::HotPlugDeviceInfo;
use vm_control::HotPlugDeviceType;
use vm_control::VmRequest;
use vm_control::VmResponse;
use zerocopy::FromBytes;
use zerocopy::IntoBytes;
use crate::pci::pci_configuration::PciBridgeSubclass;
use crate::pci::pci_configuration::CAPABILITY_LIST_HEAD_OFFSET;
use crate::pci::pci_configuration::HEADER_TYPE_REG;
use crate::pci::pci_configuration::PCI_CAP_NEXT_POINTER;
use crate::pci::pcie::pci_bridge::PciBridgeBusRange;
use crate::pci::pcie::pci_bridge::BR_BUS_NUMBER_REG;
use crate::pci::pcie::pci_bridge::BR_MEM_BASE_MASK;
use crate::pci::pcie::pci_bridge::BR_MEM_BASE_SHIFT;
use crate::pci::pcie::pci_bridge::BR_MEM_LIMIT_MASK;
use crate::pci::pcie::pci_bridge::BR_MEM_MINIMUM;
use crate::pci::pcie::pci_bridge::BR_MEM_REG;
use crate::pci::pcie::pci_bridge::BR_PREF_MEM_64BIT;
use crate::pci::pcie::pci_bridge::BR_PREF_MEM_BASE_HIGH_REG;
use crate::pci::pcie::pci_bridge::BR_PREF_MEM_LIMIT_HIGH_REG;
use crate::pci::pcie::pci_bridge::BR_PREF_MEM_LOW_REG;
use crate::pci::pcie::pci_bridge::BR_WINDOW_ALIGNMENT;
use crate::pci::pcie::PcieDevicePortType;
use crate::pci::PciCapabilityID;
use crate::pci::PciClassCode;
/// Read/write handle on a host PCI device's configuration space, accessed
/// through the device's sysfs `config` file.
struct PciHostConfig {
    config_file: File,
}
impl PciHostConfig {
    /// Opens `<host_sysfs_path>/config` for reading and writing.
    ///
    /// # Errors
    /// Fails if the file cannot be opened; the error message names the path.
    fn new(host_sysfs_path: &Path) -> Result<Self> {
        let config_path = host_sysfs_path.join("config");
        let config_file = OpenOptions::new()
            .write(true)
            .read(true)
            .open(&config_path)
            .with_context(|| format!("failed to open: {}", config_path.display()))?;
        Ok(PciHostConfig { config_file })
    }

    /// Reads a `T` from config space at byte `offset`.
    ///
    /// `offset` must be a multiple of `size_of::<T>()`. Misalignment and I/O
    /// failures are logged, and `T::default()` is returned instead of an error.
    fn read_config<T: IntoBytes + FromBytes + Copy + Default>(&self, offset: u64) -> T {
        let length = std::mem::size_of::<T>();
        let mut val = T::default();
        if length == 0 {
            // Nothing to read, and the alignment check below would divide by zero.
            return val;
        }
        if offset % length as u64 != 0 {
            error!(
                "read_config, offset {} isn't aligned to length {}",
                offset, length
            );
        } else if let Err(e) = self.config_file.read_exact_at(val.as_mut_bytes(), offset) {
            error!("failed to read host sysfs config: {}", e);
        }
        val
    }

    /// Writes `data` to config space at byte `offset`.
    ///
    /// `offset` must be a multiple of `data.len()`. An empty `data` is a no-op.
    /// Misalignment and I/O failures are logged, not returned.
    // NOTE(review): presumably there are no current callers, hence the allow.
    #[allow(dead_code)]
    fn write_config(&self, offset: u64, data: &[u8]) {
        if data.is_empty() {
            // Nothing to write, and `offset % 0` below would panic.
            return;
        }
        if offset % data.len() as u64 != 0 {
            error!(
                "write_config, offset {} isn't aligned to length {}",
                offset,
                data.len()
            );
            return;
        }
        if let Err(e) = self.config_file.write_all_at(data, offset) {
            error!("failed to write host sysfs config: {}", e);
        }
    }
}
fn visit_children(dir: &Path, children: &mut Vec<HotPlugDeviceInfo>) -> Result<()> {
    if !dir.is_dir() {
        bail!("{} isn't directory", dir.display());
    }
    let entries = dir
        .read_dir()
        .with_context(|| format!("failed to read dir {}", dir.display()))?;
    let mut devices = Vec::new();
    for entry in entries {
        let sub_dir = match entry {
            Ok(sub) => sub,
            _ => continue,
        };
        if !sub_dir.path().is_dir() {
            continue;
        }
        let name = sub_dir
            .file_name()
            .into_string()
            .map_err(|_| anyhow!("failed to get dir name"))?;
        if name.len() != 12 || !name.starts_with("0000:") {
            continue;
        }
        let child_path = dir.join(name);
        devices.push(child_path);
    }
    devices.reverse();
    let mut iter = devices.iter().peekable();
    while let Some(device) = iter.next() {
        let class_path = device.join("class");
        let class_id = read(class_path.as_path())
            .with_context(|| format!("failed to read {}", class_path.display()))?;
        let hp_interrupt = iter.peek().is_none();
        if !class_id.starts_with("0x0604".as_bytes()) {
            children.push(HotPlugDeviceInfo {
                device_type: HotPlugDeviceType::EndPoint,
                path: device.to_path_buf(),
                hp_interrupt,
            });
            return Ok(());
        } else {
            let host_config = PciHostConfig::new(device)?;
            let mut cap_pointer: u8 = host_config.read_config(CAPABILITY_LIST_HEAD_OFFSET as u64);
            while cap_pointer != 0x0 {
                let cap_id: u8 = host_config.read_config(cap_pointer as u64);
                if cap_id == PciCapabilityID::PciExpress as u8 {
                    break;
                }
                cap_pointer = host_config.read_config(cap_pointer as u64 + 0x1);
            }
            if cap_pointer == 0x0 {
                bail!(
                    "Failed to get pcie express capability for {}",
                    device.display()
                );
            }
            let express_cap_reg: u16 = host_config.read_config(cap_pointer as u64 + 0x2);
            match (express_cap_reg & 0xf0) >> 4 {
                x if x == PcieDevicePortType::UpstreamPort as u16 => {
                    children.push(HotPlugDeviceInfo {
                        device_type: HotPlugDeviceType::UpstreamPort,
                        path: device.to_path_buf(),
                        hp_interrupt,
                    })
                }
                x if x == PcieDevicePortType::DownstreamPort as u16 => {
                    children.push(HotPlugDeviceInfo {
                        device_type: HotPlugDeviceType::DownstreamPort,
                        path: device.to_path_buf(),
                        hp_interrupt,
                    })
                }
                _ => (),
            }
        }
    }
    for device in devices.iter() {
        visit_children(device.as_path(), children)?;
    }
    Ok(())
}
struct HotplugWorker {
    host_name: String,
}
impl HotplugWorker {
    fn run(&self, vm_socket: Arc<Mutex<Tube>>, child_exist: Arc<Mutex<bool>>) -> Result<()> {
        let mut host_sysfs = PathBuf::new();
        host_sysfs.push("/sys/bus/pci/devices/");
        host_sysfs.push(self.host_name.clone());
        let rescan_path = host_sysfs.join("rescan");
        write(rescan_path.as_path(), "1")
            .with_context(|| format!("failed to write {}", rescan_path.display()))?;
        let mut child_exist = child_exist.lock();
        if *child_exist {
            return Ok(());
        }
        let mut children: Vec<HotPlugDeviceInfo> = Vec::new();
        visit_children(host_sysfs.as_path(), &mut children)?;
        children.reverse();
        while let Some(child) = children.pop() {
            if let HotPlugDeviceType::EndPoint = child.device_type {
                let vendor_path = child.path.join("vendor");
                let vendor_id = read(vendor_path.as_path())
                    .with_context(|| format!("failed to read {}", vendor_path.display()))?;
                let prefix: &str = "0x";
                let vendor = match vendor_id.strip_prefix(prefix.as_bytes()) {
                    Some(v) => v.to_vec(),
                    None => vendor_id,
                };
                let device_path = child.path.join("device");
                let device_id = read(device_path.as_path())
                    .with_context(|| format!("failed to read {}", device_path.display()))?;
                let device = match device_id.strip_prefix(prefix.as_bytes()) {
                    Some(d) => d.to_vec(),
                    None => device_id,
                };
                let new_id = [
                    String::from_utf8_lossy(&vendor),
                    String::from_utf8_lossy(&device),
                ]
                .join(" ");
                if Path::new("/sys/bus/pci/drivers/vfio-pci-pm/new_id").exists() {
                    let _ = write("/sys/bus/pci/drivers/vfio-pci-pm/new_id", &new_id);
                }
                if !child.path.join("driver/unbind").exists() {
                    write("/sys/bus/pci/drivers/vfio-pci/new_id", &new_id).with_context(|| {
                        format!("failed to write {} into vfio-pci/new_id", new_id)
                    })?;
                }
            }
            let request = VmRequest::HotPlugVfioCommand {
                device: child.clone(),
                add: true,
            };
            let vm_socket = vm_socket.lock();
            vm_socket
                .send(&request)
                .with_context(|| format!("failed to send hotplug request for {:?}", child))?;
            let response = vm_socket
                .recv::<VmResponse>()
                .with_context(|| format!("failed to receive hotplug response for {:?}", child))?;
            match response {
                VmResponse::Ok => {}
                _ => bail!("unexpected hotplug response: {response}"),
            };
            if !*child_exist {
                *child_exist = true;
            }
        }
        Ok(())
    }
}
/// Config-space byte offset of the 16-bit Device ID register.
const PCI_CONFIG_DEVICE_ID: u64 = 0x2;
/// Config-space byte offset of the Sub-Class Code byte.
const PCI_SUB_CLASS_CODE: u64 = 0xa;
/// Config-space byte offset of the Base Class Code byte.
const PCI_BASE_CLASS_CODE: u64 = 0xb;
pub struct PcieHostPort {
    host_config: PciHostConfig,
    host_name: String,
    hotplug_in_process: Arc<Mutex<bool>>,
    hotplug_child_exist: Arc<Mutex<bool>>,
    vm_socket: Arc<Mutex<Tube>>,
}
impl PcieHostPort {
    pub fn new(host_sysfs_path: &Path, socket: Tube) -> Result<Self> {
        let host_config = PciHostConfig::new(host_sysfs_path)?;
        let host_name = host_sysfs_path
            .file_name()
            .unwrap()
            .to_str()
            .unwrap()
            .to_owned();
        let base_class: u8 = host_config.read_config(PCI_BASE_CLASS_CODE);
        if base_class != PciClassCode::BridgeDevice.get_register_value() {
            return Err(anyhow!("host {} isn't bridge", host_name));
        }
        let sub_class: u8 = host_config.read_config(PCI_SUB_CLASS_CODE);
        if sub_class != PciBridgeSubclass::PciToPciBridge as u8 {
            return Err(anyhow!("host {} isn't pci to pci bridge", host_name));
        }
        let mut pcie_cap_reg: u8 = 0;
        let mut cap_next: u8 = host_config.read_config(CAPABILITY_LIST_HEAD_OFFSET as u64);
        let mut counter: u16 = 0;
        while cap_next != 0 && counter < 256 {
            let cap_id: u8 = host_config.read_config(cap_next.into());
            if cap_id == PciCapabilityID::PciExpress as u8 {
                pcie_cap_reg = cap_next;
                break;
            }
            let offset = cap_next as u64 + PCI_CAP_NEXT_POINTER as u64;
            cap_next = host_config.read_config(offset);
            counter += 1;
        }
        if pcie_cap_reg == 0 {
            return Err(anyhow!("host {} isn't pcie device", host_name));
        }
        Ok(PcieHostPort {
            host_config,
            host_name,
            hotplug_in_process: Arc::new(Mutex::new(false)),
            hotplug_child_exist: Arc::new(Mutex::new(false)),
            vm_socket: Arc::new(Mutex::new(socket)),
        })
    }
    pub fn get_bus_range(&self) -> PciBridgeBusRange {
        let bus_num: u32 = self.host_config.read_config((BR_BUS_NUMBER_REG * 4) as u64);
        let primary = (bus_num & 0xFF) as u8;
        let secondary = ((bus_num >> 8) & 0xFF) as u8;
        let subordinate = ((bus_num >> 16) & 0xFF) as u8;
        PciBridgeBusRange {
            primary,
            secondary,
            subordinate,
        }
    }
    pub fn read_device_id(&self) -> u16 {
        self.host_config.read_config::<u16>(PCI_CONFIG_DEVICE_ID)
    }
    pub fn host_name(&self) -> String {
        self.host_name.clone()
    }
    pub fn read_config(&self, reg_idx: usize, data: &mut u32) {
        if reg_idx == HEADER_TYPE_REG {
            *data = self.host_config.read_config((HEADER_TYPE_REG as u64) * 4)
        }
    }
    pub fn write_config(&mut self, _reg_idx: usize, _offset: u64, _data: &[u8]) {}
    pub fn get_bridge_window_size(&self) -> (u64, u64) {
        let br_memory: u32 = self.host_config.read_config(BR_MEM_REG as u64 * 4);
        let mem_base = (br_memory & BR_MEM_BASE_MASK) << BR_MEM_BASE_SHIFT;
        let mem_limit = br_memory & BR_MEM_LIMIT_MASK;
        let mem_size = if mem_limit > mem_base {
            (mem_limit - mem_base) as u64 + BR_WINDOW_ALIGNMENT
        } else {
            BR_MEM_MINIMUM
        };
        let br_pref_mem_low: u32 = self.host_config.read_config(BR_PREF_MEM_LOW_REG as u64 * 4);
        let pref_mem_base_low = (br_pref_mem_low & BR_MEM_BASE_MASK) << BR_MEM_BASE_SHIFT;
        let pref_mem_limit_low = br_pref_mem_low & BR_MEM_LIMIT_MASK;
        let mut pref_mem_base: u64 = pref_mem_base_low as u64;
        let mut pref_mem_limit: u64 = pref_mem_limit_low as u64;
        if br_pref_mem_low & BR_PREF_MEM_64BIT == BR_PREF_MEM_64BIT {
            let pref_mem_base_high: u32 = self
                .host_config
                .read_config(BR_PREF_MEM_BASE_HIGH_REG as u64 * 4);
            let pref_mem_limit_high: u32 = self
                .host_config
                .read_config(BR_PREF_MEM_LIMIT_HIGH_REG as u64 * 4);
            pref_mem_base = ((pref_mem_base_high as u64) << 32) | (pref_mem_base_low as u64);
            pref_mem_limit = ((pref_mem_limit_high as u64) << 32) | (pref_mem_limit_low as u64);
        }
        let pref_mem_size = if pref_mem_limit > pref_mem_base {
            pref_mem_limit - pref_mem_base + BR_WINDOW_ALIGNMENT
        } else {
            BR_MEM_MINIMUM
        };
        (mem_size, pref_mem_size)
    }
    pub fn hotplug_probe(&mut self) {
        if *self.hotplug_in_process.lock() {
            return;
        }
        let hotplug_process = self.hotplug_in_process.clone();
        let child_exist = self.hotplug_child_exist.clone();
        let socket = self.vm_socket.clone();
        let name = self.host_name.clone();
        let _ = thread::Builder::new()
            .name("pcie_hotplug".to_string())
            .spawn(move || {
                let mut hotplug = hotplug_process.lock();
                *hotplug = true;
                let hotplug_worker = HotplugWorker { host_name: name };
                let _ = hotplug_worker.run(socket, child_exist);
                *hotplug = false;
            });
    }
    pub fn hot_unplug(&mut self) {
        *self.hotplug_child_exist.lock() = false;
    }
}