pub mod sys;
use std::cell::RefCell;
use std::rc::Rc;
use std::sync::Arc;
use anyhow::anyhow;
use anyhow::bail;
use anyhow::Context;
use base::error;
use base::warn;
use base::Tube;
use cros_async::EventAsync;
use cros_async::Executor;
use cros_async::TaskHandle;
use futures::FutureExt;
use futures::StreamExt;
use sync::Mutex;
pub use sys::run_gpu_device;
pub use sys::Options;
use vm_memory::GuestMemory;
use vmm_vhost::message::VhostUserProtocolFeatures;
use vmm_vhost::VHOST_USER_F_PROTOCOL_FEATURES;
use crate::virtio::device_constants::gpu::NUM_QUEUES;
use crate::virtio::gpu;
use crate::virtio::gpu::QueueReader;
use crate::virtio::vhost::user::device::handler::Error as DeviceError;
use crate::virtio::vhost::user::device::handler::VhostBackendReqConnection;
use crate::virtio::vhost::user::device::handler::VhostUserDevice;
use crate::virtio::vhost::user::device::handler::WorkerState;
use crate::virtio::DescriptorChain;
use crate::virtio::Gpu;
use crate::virtio::Queue;
use crate::virtio::SharedMemoryMapper;
use crate::virtio::SharedMemoryRegion;
use crate::virtio::VirtioDevice;
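
// The virtio-gpu device has a fixed number of queues: control and cursor.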
const MAX_QUEUE_NUM: usize = NUM_QUEUES;
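
/// A cloneable handle to the control queue, shared between the queue worker
/// and the fence handler so both can pop descriptors and signal the guest.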
#[derive(Clone)]
struct SharedReader {
queue: Arc<Mutex<Queue>>,
}
impl QueueReader for SharedReader {
fn pop(&self) -> Option<DescriptorChain> {
self.queue.lock().pop()
}
fn add_used(&self, desc_chain: DescriptorChain, len: u32) {
self.queue.lock().add_used(desc_chain, len)
}
fn signal_used(&self) {
self.queue.lock().trigger_interrupt();
}
}
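
/// Driver loop for the control queue: waits for each kick, lets the GPU
/// frontend process the available descriptors, and signals the guest when a
/// used-buffer interrupt is needed.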
async fn run_ctrl_queue(
reader: SharedReader,
mem: GuestMemory,
kick_evt: EventAsync,
state: Rc<RefCell<gpu::Frontend>>,
) {
loop {
if let Err(e) = kick_evt.next_val().await {
error!("Failed to read kick event for ctrl queue: {}", e);
break;
}
let mut state = state.borrow_mut();
let needs_interrupt = state.process_queue(&mem, &reader);
if needs_interrupt {
reader.signal_used();
}
}
}
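
/// vhost-user backend for a virtio-gpu device.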
struct GpuBackend {
ex: Executor,
gpu: Rc<RefCell<Gpu>>,
resource_bridges: Arc<Mutex<Vec<Tube>>>,
state: Option<Rc<RefCell<gpu::Frontend>>>,
fence_state: Arc<Mutex<gpu::FenceState>>,
queue_workers: [Option<WorkerState<Arc<Mutex<Queue>>, ()>>; MAX_QUEUE_NUM],
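    // Handles for platform-specific worker tasks are sent over this channel so
    // stop_non_queue_workers() can cancel them later.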
platform_worker_tx: futures::channel::mpsc::UnboundedSender<TaskHandle<()>>,
platform_worker_rx: futures::channel::mpsc::UnboundedReceiver<TaskHandle<()>>,
shmem_mapper: Arc<Mutex<Option<Box<dyn SharedMemoryMapper>>>>,
}
impl GpuBackend {
fn stop_non_queue_workers(&mut self) -> anyhow::Result<()> {
self.ex
.run_until(async {
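                // Drain and cancel every platform worker handle queued so far;
                // now_or_never() avoids blocking on an empty channel.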
while let Some(Some(handle)) = self.platform_worker_rx.next().now_or_never() {
handle.cancel().await;
}
})
.context("stopping the non-queue workers for GPU")?;
Ok(())
}
}
impl VhostUserDevice for GpuBackend {
fn max_queue_num(&self) -> usize {
MAX_QUEUE_NUM
}
fn features(&self) -> u64 {
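        // On top of the underlying device's features, advertise the bit that
        // enables vhost-user protocol feature negotiation.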
self.gpu.borrow().features() | 1 << VHOST_USER_F_PROTOCOL_FEATURES
}
fn ack_features(&mut self, value: u64) -> anyhow::Result<()> {
self.gpu.borrow_mut().ack_features(value);
Ok(())
}
fn protocol_features(&self) -> VhostUserProtocolFeatures {
VhostUserProtocolFeatures::CONFIG
| VhostUserProtocolFeatures::BACKEND_REQ
| VhostUserProtocolFeatures::MQ
| VhostUserProtocolFeatures::SHARED_MEMORY_REGIONS
| VhostUserProtocolFeatures::DEVICE_STATE
}
fn read_config(&self, offset: u64, dst: &mut [u8]) {
self.gpu.borrow().read_config(offset, dst)
}
fn write_config(&self, offset: u64, data: &[u8]) {
self.gpu.borrow_mut().write_config(offset, data)
}
fn start_queue(&mut self, idx: usize, queue: Queue, mem: GuestMemory) -> anyhow::Result<()> {
if self.queue_workers[idx].is_some() {
warn!("Starting new queue handler without stopping old handler");
self.stop_queue(idx)?;
}
let doorbell = queue.interrupt().clone();
let queue = Arc::new(Mutex::new(queue));
let queue_task = match idx {
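            // Queue 0 is the control queue.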
0 => {
let kick_evt = queue
.lock()
.event()
.try_clone()
.context("failed to clone queue event")?;
let kick_evt = EventAsync::new(kick_evt, &self.ex)
.context("failed to create EventAsync for kick_evt")?;
let reader = SharedReader {
queue: queue.clone(),
};
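
                // Create the gpu::Frontend lazily on the first control queue
                // activation, and reuse it on subsequent activations.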
let state = if let Some(s) = self.state.as_ref() {
s.clone()
} else {
let fence_handler_resources =
Arc::new(Mutex::new(Some(gpu::FenceHandlerActivationResources {
mem: mem.clone(),
ctrl_queue: reader.clone(),
})));
let fence_handler = gpu::create_fence_handler(
fence_handler_resources,
self.fence_state.clone(),
);
let state = Rc::new(RefCell::new(
self.gpu
.borrow_mut()
.initialize_frontend(
self.fence_state.clone(),
fence_handler,
Arc::clone(&self.shmem_mapper),
)
.ok_or_else(|| anyhow!("failed to initialize gpu frontend"))?,
));
self.state = Some(state.clone());
state
};
self.start_platform_workers(doorbell)?;
self.ex
.spawn_local(run_ctrl_queue(reader, mem, kick_evt, state))
}
1 => {
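                // Queue 1 is the cursor queue, which this backend does not
                // service; spawn a no-op task so the worker slot is still
                // populated.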
self.ex.spawn_local(async {})
}
_ => bail!("attempted to start unknown queue: {}", idx),
};
self.queue_workers[idx] = Some(WorkerState { queue_task, queue });
Ok(())
}
fn stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue> {
if let Some(worker) = self.queue_workers.get_mut(idx).and_then(Option::take) {
let _ = self.ex.run_until(worker.queue_task.cancel());
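            // The control queue worker drives the GPU frontend and the
            // platform workers, so tear those down along with it.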
if idx == 0 {
self.stop_non_queue_workers()?;
self.state = None;
}
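            // The worker task has already been cancelled, so this should be
            // the only remaining reference to the queue.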
let queue = match Arc::try_unwrap(worker.queue) {
Ok(queue_mutex) => queue_mutex.into_inner(),
Err(_) => panic!("failed to recover queue from worker"),
};
Ok(queue)
} else {
Err(anyhow::Error::new(DeviceError::WorkerNotFound))
}
}
fn enter_suspended_state(&mut self) -> anyhow::Result<()> {
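        // Queue workers are stopped individually via stop_queue(), so only the
        // platform workers need to be stopped here.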
self.stop_non_queue_workers()?;
Ok(())
}
fn reset(&mut self) {
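        // Tear down the platform workers first, then every active queue
        // worker.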
self.stop_non_queue_workers()
.expect("Failed to stop platform workers.");
for queue_num in 0..self.max_queue_num() {
if self.queue_workers[queue_num].is_some() {
if let Err(e) = self.stop_queue(queue_num) {
error!("Failed to stop_queue during reset: {}", e);
}
}
}
}
fn get_shared_memory_region(&self) -> Option<SharedMemoryRegion> {
self.gpu.borrow().get_shared_memory_region()
}
fn set_backend_req_connection(&mut self, conn: Arc<VhostBackendReqConnection>) {
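        // Take the shared-memory mapper provided by the new connection; if one
        // was already installed, a second connection has been established.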
if self
.shmem_mapper
.lock()
.replace(conn.take_shmem_mapper().unwrap())
.is_some()
{
warn!("Connection already established. Overwriting shmem_mapper");
}
}
fn snapshot(&mut self) -> anyhow::Result<serde_json::Value> {
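        // No GPU state is snapshotted yet, so emit a null placeholder that
        // restore() checks for below.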
Ok(serde_json::Value::Null)
}
fn restore(&mut self, data: serde_json::Value) -> anyhow::Result<()> {
anyhow::ensure!(
data.is_null(),
"unexpected snapshot data: should be null, got {}",
data
);
Ok(())
}
}
impl Drop for GpuBackend {
fn drop(&mut self) {
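        // Stop all workers so that no spawned task outlives the executor and
        // the state owned by this backend.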
self.reset();
}
}