use std::fmt;
use std::io;
use std::mem;
use std::mem::size_of;
use std::num::ParseIntError;
use std::os::raw::c_uchar;
use std::os::raw::c_uint;
use std::os::raw::c_ushort;
use std::os::unix::io::AsRawFd;
use std::os::unix::io::IntoRawFd;
use std::os::unix::io::RawFd;
use std::result;
use std::str::FromStr;
use libc::c_void;
use libc::sa_family_t;
use libc::size_t;
use libc::sockaddr;
use libc::socklen_t;
use libc::F_GETFL;
use libc::F_SETFL;
use libc::O_NONBLOCK;
use libc::VMADDR_CID_ANY;
use libc::VMADDR_CID_HOST;
use libc::VMADDR_CID_HYPERVISOR;
use thiserror::Error;
/// Address family number for vsock, typed as `sa_family_t` so it can be stored directly in
/// `sockaddr_vm::svm_family` and compared against addresses returned by the kernel.
// NOTE(review): `VsockSocket::new` uses `libc::AF_VSOCK` instead; the two must stay equal.
const AF_VSOCK: sa_family_t = 40;
/// Raw CID value that `VsockCid` maps to/from the `Local` variant.
const VMADDR_CID_LOCAL: c_uint = 1;
/// Port value `c_uint::MAX`; presumably the kernel's "any port" wildcard — confirm against
/// the vsock ABI headers.
pub const VMADDR_PORT_ANY: c_uint = c_uint::MAX;
/// Number of trailing bytes needed so that `sockaddr_vm` is exactly as large as `sockaddr`:
/// the size of `sockaddr` minus the family, reserved, port, and CID fields.
const PADDING: usize = size_of::<sockaddr>()
    - size_of::<sa_family_t>()
    - size_of::<c_ushort>()
    - (2 * size_of::<c_uint>());
/// In-memory vsock socket address, passed to `bind`, `connect`, `getsockname`, and `accept4`
/// by casting a pointer to it into a `sockaddr` pointer.
///
/// `#[repr(C)]` fixes the field order and layout; do not reorder fields. `svm_zero` pads the
/// struct to the size of `sockaddr`. The derived `Default` sets every field, including the
/// reserved and padding bytes, to zero.
#[repr(C)]
#[derive(Default)]
struct sockaddr_vm {
    svm_family: sa_family_t,
    svm_reserved1: c_ushort,
    svm_port: c_uint,
    svm_cid: c_uint,
    svm_zero: [c_uchar; PADDING],
}
/// Error returned when a value cannot be converted into a vsock [`SocketAddr`], e.g. a string
/// that is not of the form `vsock:<cid>:<port>`.
#[derive(Error, Debug)]
#[error("failed to parse vsock address")]
pub struct AddrParseError;
/// A vsock context identifier (CID).
///
/// The named variants correspond to the reserved CID constants handled by the
/// `From<c_uint>` / `From<VsockCid>` conversions in this file; any other value is carried
/// in `Cid`.
// NOTE(review): variant order determines the derived `Ord`/`PartialOrd`; do not reorder.
// A `Cid(n)` whose `n` equals a reserved constant converts to the same raw value as the
// named variant, but the two do not compare equal (derived `PartialEq`/`Hash`).
#[derive(Debug, Copy, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub enum VsockCid {
    Any,
    Hypervisor,
    Local,
    Host,
    Cid(c_uint),
}
impl fmt::Display for VsockCid {
    /// Writes the variant name for reserved CIDs (`Any`, `Hypervisor`, `Local`, `Host`), or the
    /// numeric CID wrapped in single quotes (e.g. `'3'`) for everything else.
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        let name = match *self {
            VsockCid::Any => "Any",
            VsockCid::Hypervisor => "Hypervisor",
            VsockCid::Local => "Local",
            VsockCid::Host => "Host",
            VsockCid::Cid(raw) => return write!(fmt, "'{}'", raw),
        };
        fmt.write_str(name)
    }
}
impl From<c_uint> for VsockCid {
    fn from(c: c_uint) -> Self {
        match c {
            VMADDR_CID_ANY => VsockCid::Any,
            VMADDR_CID_HYPERVISOR => VsockCid::Hypervisor,
            VMADDR_CID_LOCAL => VsockCid::Local,
            VMADDR_CID_HOST => VsockCid::Host,
            _ => VsockCid::Cid(c),
        }
    }
}
impl FromStr for VsockCid {
    type Err = ParseIntError;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let c: c_uint = s.parse()?;
        Ok(c.into())
    }
}
impl From<VsockCid> for c_uint {
    /// Converts a CID back to its raw number: `Cid(n)` yields `n`, and the named variants yield
    /// their reserved constants.
    fn from(cid: VsockCid) -> Self {
        match cid {
            VsockCid::Cid(raw) => raw,
            VsockCid::Any => VMADDR_CID_ANY,
            VsockCid::Hypervisor => VMADDR_CID_HYPERVISOR,
            VsockCid::Local => VMADDR_CID_LOCAL,
            VsockCid::Host => VMADDR_CID_HOST,
        }
    }
}
/// A vsock socket address: a context identifier plus a port.
///
/// Parsing via `FromStr` accepts `vsock:<cid>:<port>`.
// NOTE(review): field order determines the derived `Ord` (CID first, then port); do not
// reorder.
#[derive(Debug, Copy, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct SocketAddr {
    /// Context identifier of the endpoint.
    pub cid: VsockCid,
    /// Port of the endpoint.
    pub port: c_uint,
}
/// Conversion into a vsock [`SocketAddr`].
///
/// The socket `bind`/`connect` functions in this file are generic over this trait, so they
/// accept a `SocketAddr`, a `"vsock:<cid>:<port>"` string, a `(VsockCid, port)` tuple, or a
/// reference to any of these.
pub trait ToSocketAddr {
    /// Returns the converted address, or [`AddrParseError`] if the value is not a valid address.
    fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>;
}
impl ToSocketAddr for SocketAddr {
    fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
        Ok(*self)
    }
}
impl ToSocketAddr for str {
    /// Parses the string with `SocketAddr`'s `FromStr` impl (`vsock:<cid>:<port>`).
    fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
        self.parse::<SocketAddr>()
    }
}
impl ToSocketAddr for (VsockCid, c_uint) {
    /// Treats the tuple as `(cid, port)`; this conversion never fails.
    fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
        Ok(SocketAddr {
            cid: self.0,
            port: self.1,
        })
    }
}
impl<T> ToSocketAddr for &T
where
    T: ToSocketAddr + ?Sized,
{
    /// Forwards to the referenced value, so borrowed addresses (e.g. `&str`) are accepted too.
    fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
        T::to_socket_addr(*self)
    }
}
impl FromStr for SocketAddr {
    type Err = AddrParseError;
    /// Parses `vsock:<cid>:<port>`, where `<cid>` is parsed by `VsockCid`'s `FromStr` and
    /// `<port>` as an unsigned integer.
    ///
    /// # Errors
    /// Returns [`AddrParseError`] for a wrong prefix, a field count other than three, or a
    /// CID/port that fails to parse.
    fn from_str(s: &str) -> Result<SocketAddr, AddrParseError> {
        let mut fields = s.split(':');
        // Pull one field beyond the expected three so that trailing fields are rejected.
        match (fields.next(), fields.next(), fields.next(), fields.next()) {
            (Some("vsock"), Some(cid), Some(port), None) => Ok(SocketAddr {
                cid: cid.parse().map_err(|_| AddrParseError)?,
                port: port.parse().map_err(|_| AddrParseError)?,
            }),
            _ => Err(AddrParseError),
        }
    }
}
impl fmt::Display for SocketAddr {
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        write!(fmt, "{}:{}", self.cid, self.port)
    }
}
/// Sets (`nonblocking == true`) or clears `O_NONBLOCK` on `fd`, leaving all other file
/// status flags as they were.
///
/// # Errors
/// Returns the OS error if reading (`F_GETFL`) or writing (`F_SETFL`) the flags fails.
///
/// # Safety
/// The caller must ensure `fd` is a descriptor it owns or is otherwise entitled to modify;
/// this changes its file status flags.
unsafe fn set_nonblocking(fd: RawFd, nonblocking: bool) -> io::Result<()> {
    let current = libc::fcntl(fd, F_GETFL, 0);
    if current < 0 {
        return Err(io::Error::last_os_error());
    }
    let updated = match nonblocking {
        true => current | O_NONBLOCK,
        false => current & !O_NONBLOCK,
    };
    match libc::fcntl(fd, F_SETFL, updated) {
        rc if rc < 0 => Err(io::Error::last_os_error()),
        _ => Ok(()),
    }
}
/// An owned vsock socket file descriptor.
///
/// The descriptor is closed when the value is dropped, unless ownership is released with
/// `into_raw_fd`.
#[derive(Debug)]
pub struct VsockSocket {
    fd: RawFd,
}
impl VsockSocket {
    /// Creates a new, unbound `SOCK_STREAM` vsock socket with close-on-exec set.
    ///
    /// # Errors
    /// Returns the OS error from `socket(2)` on failure.
    pub fn new() -> io::Result<Self> {
        // SAFETY: `socket` takes no pointers; the result is checked before it is used, and on
        // success ownership of the descriptor moves into the returned `VsockSocket`.
        let fd = unsafe {
            libc::socket(libc::AF_VSOCK, libc::SOCK_STREAM | libc::SOCK_CLOEXEC, 0)
        };
        if fd < 0 {
            Err(io::Error::last_os_error())
        } else {
            Ok(VsockSocket { fd })
        }
    }

    /// Converts `addr` into the `sockaddr_vm` passed to the kernel (shared by `bind` and
    /// `connect`).
    ///
    /// Any address-conversion failure is reported as `EINVAL`.
    fn sockaddr_vm_from<A: ToSocketAddr>(addr: A) -> io::Result<sockaddr_vm> {
        let resolved = addr
            .to_socket_addr()
            .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?;
        // The kernel is told the address occupies `size_of::<sockaddr_vm>()` bytes; the padding
        // must keep it the same size as a generic `sockaddr`.
        assert_eq!(size_of::<sockaddr_vm>(), size_of::<sockaddr>());
        Ok(sockaddr_vm {
            svm_family: AF_VSOCK,
            svm_cid: resolved.cid.into(),
            svm_port: resolved.port,
            ..Default::default()
        })
    }

    /// Binds the socket to `addr`.
    ///
    /// # Errors
    /// Returns `EINVAL` if `addr` cannot be converted, or the OS error from `bind(2)`.
    pub fn bind<A: ToSocketAddr>(&mut self, addr: A) -> io::Result<()> {
        let svm = Self::sockaddr_vm_from(addr)?;
        // SAFETY: `svm` is a live, fully initialized `sockaddr_vm` and the length passed is its
        // exact size; the pointer is `*const`, so it is only read. The result is checked below.
        let ret = unsafe {
            libc::bind(
                self.fd,
                &svm as *const sockaddr_vm as *const sockaddr,
                size_of::<sockaddr_vm>() as socklen_t,
            )
        };
        if ret < 0 {
            Err(io::Error::last_os_error())
        } else {
            Ok(())
        }
    }

    /// Connects the socket to `addr` and turns it into a [`VsockStream`].
    ///
    /// # Errors
    /// Returns `EINVAL` if `addr` cannot be converted, or the OS error from `connect(2)`. On
    /// error the socket is dropped, which closes its descriptor.
    pub fn connect<A: ToSocketAddr>(self, addr: A) -> io::Result<VsockStream> {
        let svm = Self::sockaddr_vm_from(addr)?;
        // SAFETY: `svm` is a live, fully initialized `sockaddr_vm` and the length passed is its
        // exact size; the pointer is `*const`, so it is only read. The result is checked below.
        let ret = unsafe {
            libc::connect(
                self.fd,
                &svm as *const sockaddr_vm as *const sockaddr,
                size_of::<sockaddr_vm>() as socklen_t,
            )
        };
        if ret < 0 {
            Err(io::Error::last_os_error())
        } else {
            Ok(VsockStream { sock: self })
        }
    }

    /// Starts listening on the (already bound) socket with a backlog of 1 and turns it into a
    /// [`VsockListener`].
    ///
    /// # Errors
    /// Returns the OS error from `listen(2)`; the socket is dropped (and closed) in that case.
    pub fn listen(self) -> io::Result<VsockListener> {
        // SAFETY: `listen` takes no pointers and the result is checked.
        let ret = unsafe { libc::listen(self.fd, 1) };
        if ret < 0 {
            return Err(io::Error::last_os_error());
        }
        Ok(VsockListener { sock: self })
    }

    /// Returns the local port of the socket as reported by `getsockname(2)`.
    ///
    /// # Errors
    /// Returns the OS error from `getsockname(2)`.
    ///
    /// # Panics
    /// Panics if the kernel reports an address length other than `size_of::<sockaddr_vm>()`.
    pub fn local_port(&self) -> io::Result<u32> {
        let mut svm: sockaddr_vm = Default::default();
        let mut addrlen = size_of::<sockaddr_vm>() as socklen_t;
        // SAFETY: `svm` and `addrlen` are live locals, and `addrlen` is initialized to the size
        // of `svm`, bounding how much the kernel may write. The result is checked below.
        let ret = unsafe {
            libc::getsockname(
                self.fd,
                &mut svm as *mut sockaddr_vm as *mut sockaddr,
                &mut addrlen as *mut socklen_t,
            )
        };
        if ret < 0 {
            Err(io::Error::last_os_error())
        } else {
            assert_eq!(addrlen as usize, size_of::<sockaddr_vm>());
            Ok(svm.svm_port)
        }
    }

    /// Duplicates the descriptor (with close-on-exec set) into a new, independently owned
    /// `VsockSocket` referring to the same underlying socket.
    ///
    /// # Errors
    /// Returns the OS error from `fcntl(F_DUPFD_CLOEXEC)`.
    pub fn try_clone(&self) -> io::Result<Self> {
        // SAFETY: `fcntl(F_DUPFD_CLOEXEC)` takes no pointers; the new descriptor is checked
        // and then owned by the returned value.
        let dup_fd = unsafe { libc::fcntl(self.fd, libc::F_DUPFD_CLOEXEC, 0) };
        if dup_fd < 0 {
            Err(io::Error::last_os_error())
        } else {
            Ok(Self { fd: dup_fd })
        }
    }

    /// Sets or clears non-blocking mode on the socket.
    ///
    /// # Errors
    /// Returns the OS error if reading or updating the descriptor flags fails.
    pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
        // SAFETY: `self.fd` is the descriptor owned by this socket.
        unsafe { set_nonblocking(self.fd, nonblocking) }
    }
}
impl IntoRawFd for VsockSocket {
    /// Releases ownership of the descriptor without closing it; the caller becomes responsible
    /// for closing it.
    fn into_raw_fd(self) -> RawFd {
        // Wrapping in `ManuallyDrop` suppresses the `Drop` impl that would close the descriptor.
        let this = mem::ManuallyDrop::new(self);
        this.fd
    }
}
impl AsRawFd for VsockSocket {
    /// Returns the underlying descriptor; ownership (and closing) stays with this socket.
    fn as_raw_fd(&self) -> RawFd {
        let Self { fd } = self;
        *fd
    }
}
impl Drop for VsockSocket {
    /// Closes the owned descriptor. The result of `close(2)` is discarded because `drop` has no
    /// way to report it.
    fn drop(&mut self) {
        // SAFETY: this value owns `self.fd`, and the value is being destroyed, so the
        // descriptor is not used through it again.
        let _ = unsafe { libc::close(self.fd) };
    }
}
/// A connected vsock stream socket.
///
/// Implements `io::Read`/`io::Write` with direct `read(2)`/`write(2)` calls on the owned
/// descriptor; the descriptor is closed when the stream is dropped.
#[derive(Debug)]
pub struct VsockStream {
    sock: VsockSocket,
}
impl VsockStream {
    /// Creates a new vsock socket and connects it to `addr`.
    ///
    /// # Errors
    /// Fails if socket creation fails, `addr` cannot be converted (`EINVAL`), or the
    /// connection attempt fails.
    pub fn connect<A: ToSocketAddr>(addr: A) -> io::Result<VsockStream> {
        VsockSocket::new()?.connect(addr)
    }

    /// Returns the local port of the connection.
    pub fn local_port(&self) -> io::Result<u32> {
        VsockSocket::local_port(&self.sock)
    }

    /// Returns a new stream whose descriptor is a duplicate of this stream's descriptor.
    pub fn try_clone(&self) -> io::Result<VsockStream> {
        let sock = self.sock.try_clone()?;
        Ok(VsockStream { sock })
    }

    /// Sets or clears non-blocking mode on the connection.
    pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
        VsockSocket::set_nonblocking(&mut self.sock, nonblocking)
    }
}
impl io::Read for VsockStream {
    /// Performs a single `read(2)` of at most `buf.len()` bytes and returns the byte count the
    /// syscall reports.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let fd = self.sock.as_raw_fd();
        // SAFETY: `buf` is a valid, writable region of exactly `buf.len()` bytes that stays
        // borrowed for the duration of the call.
        let n = unsafe { libc::read(fd, buf.as_mut_ptr().cast::<c_void>(), buf.len() as size_t) };
        if n >= 0 {
            Ok(n as usize)
        } else {
            Err(io::Error::last_os_error())
        }
    }
}
impl io::Write for VsockStream {
    /// Performs a single `write(2)` of at most `buf.len()` bytes and returns the byte count the
    /// syscall reports.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let fd = self.sock.as_raw_fd();
        // SAFETY: `buf` is a valid, readable region of exactly `buf.len()` bytes that stays
        // borrowed for the duration of the call.
        let n = unsafe { libc::write(fd, buf.as_ptr().cast::<c_void>(), buf.len() as size_t) };
        if n >= 0 {
            Ok(n as usize)
        } else {
            Err(io::Error::last_os_error())
        }
    }

    /// Writes go straight to the descriptor with no userspace buffer, so there is nothing to
    /// flush.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}
impl AsRawFd for VsockStream {
    /// Returns the descriptor of the underlying socket; the stream keeps ownership.
    fn as_raw_fd(&self) -> RawFd {
        AsRawFd::as_raw_fd(&self.sock)
    }
}
impl IntoRawFd for VsockStream {
    fn into_raw_fd(self) -> RawFd {
        self.sock.into_raw_fd()
    }
}
/// A vsock socket in the listening state, created by `VsockListener::bind` or
/// `VsockSocket::listen`; use `accept` to obtain connected [`VsockStream`]s.
#[derive(Debug)]
pub struct VsockListener {
    sock: VsockSocket,
}
impl VsockListener {
    /// Creates a socket, binds it to `addr`, and starts listening on it.
    ///
    /// # Errors
    /// Fails if socket creation, address conversion (`EINVAL`), `bind(2)`, or `listen(2)` fails.
    pub fn bind<A: ToSocketAddr>(addr: A) -> io::Result<VsockListener> {
        let mut sock = VsockSocket::new()?;
        sock.bind(addr)?;
        sock.listen()
    }

    /// Returns the local port the listener is bound to.
    pub fn local_port(&self) -> io::Result<u32> {
        self.sock.local_port()
    }

    /// Accepts one pending connection, returning the connected stream (with close-on-exec set)
    /// and the peer's address.
    ///
    /// # Errors
    /// Returns the OS error from `accept4(2)`, or `InvalidData` if the peer address is not a
    /// vsock address. In the latter case the accepted connection is closed, not leaked.
    pub fn accept(&self) -> io::Result<(VsockStream, SocketAddr)> {
        let mut svm: sockaddr_vm = Default::default();
        let mut socklen: socklen_t = size_of::<sockaddr_vm>() as socklen_t;
        // SAFETY: `svm` and `socklen` are live locals, and `socklen` is initialized to the size
        // of `svm`, bounding how much the kernel may write. The result is checked below.
        let fd = unsafe {
            libc::accept4(
                self.sock.as_raw_fd(),
                &mut svm as *mut sockaddr_vm as *mut sockaddr,
                &mut socklen as *mut socklen_t,
                libc::SOCK_CLOEXEC,
            )
        };
        if fd < 0 {
            return Err(io::Error::last_os_error());
        }
        // Take ownership of the new descriptor right away so that `VsockSocket`'s `Drop` closes
        // it if we bail out below.
        let stream = VsockStream {
            sock: VsockSocket { fd },
        };
        if svm.svm_family != AF_VSOCK {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!("unexpected address family: {}", svm.svm_family),
            ));
        }
        let peer = SocketAddr {
            cid: svm.svm_cid.into(),
            port: svm.svm_port,
        };
        Ok((stream, peer))
    }

    /// Sets or clears non-blocking mode on the listening socket (affects `accept`).
    pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
        self.sock.set_nonblocking(nonblocking)
    }
}
impl AsRawFd for VsockListener {
    /// Returns the descriptor of the listening socket; the listener keeps ownership.
    fn as_raw_fd(&self) -> RawFd {
        AsRawFd::as_raw_fd(&self.sock)
    }
}