devices/virtio/vhost_user_backend/
gpu.rs1pub mod sys;
6
7use std::cell::RefCell;
8use std::rc::Rc;
9use std::sync::Arc;
10
11use anyhow::anyhow;
12use anyhow::bail;
13use anyhow::Context;
14use base::error;
15use base::warn;
16use base::Tube;
17use cros_async::EventAsync;
18use cros_async::Executor;
19use cros_async::TaskHandle;
20use futures::FutureExt;
21use futures::StreamExt;
22use snapshot::AnySnapshot;
23use sync::Mutex;
24pub use sys::run_gpu_device;
25pub use sys::Options;
26use vm_memory::GuestMemory;
27use vmm_vhost::message::VhostUserProtocolFeatures;
28use vmm_vhost::VHOST_USER_F_PROTOCOL_FEATURES;
29
30use crate::virtio::device_constants::gpu::NUM_QUEUES;
31use crate::virtio::gpu;
32use crate::virtio::gpu::QueueReader;
33use crate::virtio::vhost_user_backend::handler::Error as DeviceError;
34use crate::virtio::vhost_user_backend::handler::VhostBackendReqConnection;
35use crate::virtio::vhost_user_backend::handler::VhostUserDevice;
36use crate::virtio::vhost_user_backend::handler::WorkerState;
37use crate::virtio::DescriptorChain;
38use crate::virtio::Gpu;
39use crate::virtio::Queue;
40use crate::virtio::SharedMemoryMapper;
41use crate::virtio::SharedMemoryRegion;
42use crate::virtio::VirtioDevice;
43
44const MAX_QUEUE_NUM: usize = NUM_QUEUES;
45
46#[derive(Clone)]
47struct SharedReader {
48 queue: Arc<Mutex<Queue>>,
49}
50
51impl gpu::QueueReader for SharedReader {
52 fn pop(&self) -> Option<DescriptorChain> {
53 self.queue.lock().pop()
54 }
55
56 fn add_used(&self, desc_chain: DescriptorChain, len: u32) {
57 self.queue
58 .lock()
59 .add_used_with_bytes_written(desc_chain, len)
60 }
61
62 fn signal_used(&self) {
63 self.queue.lock().trigger_interrupt();
64 }
65}
66
67async fn run_ctrl_queue(
68 reader: SharedReader,
69 mem: GuestMemory,
70 kick_evt: EventAsync,
71 state: Rc<RefCell<gpu::Frontend>>,
72) {
73 loop {
74 if let Err(e) = kick_evt.next_val().await {
75 error!("Failed to read kick event for ctrl queue: {}", e);
76 break;
77 }
78
79 let mut state = state.borrow_mut();
80 let needs_interrupt = state.process_queue(&mem, &reader);
81
82 if needs_interrupt {
83 reader.signal_used();
84 }
85 }
86}
87
88struct GpuBackend {
89 ex: Executor,
90 gpu: Rc<RefCell<Gpu>>,
91 resource_bridges: Arc<Mutex<Vec<Tube>>>,
92 state: Option<Rc<RefCell<gpu::Frontend>>>,
93 fence_state: Arc<Mutex<gpu::FenceState>>,
94 queue_workers: [Option<WorkerState<Arc<Mutex<Queue>>, ()>>; MAX_QUEUE_NUM],
95 platform_worker_tx: futures::channel::mpsc::UnboundedSender<TaskHandle<()>>,
97 platform_worker_rx: futures::channel::mpsc::UnboundedReceiver<TaskHandle<()>>,
98 shmem_mapper: Arc<Mutex<Option<Box<dyn SharedMemoryMapper>>>>,
99}
100
101impl GpuBackend {
102 fn stop_non_queue_workers(&mut self) -> anyhow::Result<()> {
103 self.ex
104 .run_until(async {
105 while let Some(Some(handle)) = self.platform_worker_rx.next().now_or_never() {
106 handle.cancel().await;
107 }
108 })
109 .context("stopping the non-queue workers for GPU")?;
110 Ok(())
111 }
112}
113
114impl VhostUserDevice for GpuBackend {
115 fn max_queue_num(&self) -> usize {
116 MAX_QUEUE_NUM
117 }
118
119 fn features(&self) -> u64 {
120 self.gpu.borrow().features() | 1 << VHOST_USER_F_PROTOCOL_FEATURES
121 }
122
123 fn ack_features(&mut self, value: u64) -> anyhow::Result<()> {
124 self.gpu.borrow_mut().ack_features(value);
125 Ok(())
126 }
127
128 fn protocol_features(&self) -> VhostUserProtocolFeatures {
129 VhostUserProtocolFeatures::CONFIG
130 | VhostUserProtocolFeatures::BACKEND_REQ
131 | VhostUserProtocolFeatures::MQ
132 | VhostUserProtocolFeatures::SHARED_MEMORY_REGIONS
133 | VhostUserProtocolFeatures::DEVICE_STATE
134 }
135
136 fn read_config(&self, offset: u64, dst: &mut [u8]) {
137 self.gpu.borrow().read_config(offset, dst)
138 }
139
140 fn write_config(&self, offset: u64, data: &[u8]) {
141 self.gpu.borrow_mut().write_config(offset, data)
142 }
143
144 fn start_queue(&mut self, idx: usize, queue: Queue, mem: GuestMemory) -> anyhow::Result<()> {
145 if self.queue_workers[idx].is_some() {
146 warn!("Starting new queue handler without stopping old handler");
147 self.stop_queue(idx)?;
148 }
149
150 let doorbell = queue.interrupt().clone();
151
152 let queue = Arc::new(Mutex::new(queue));
158
159 let queue_task = match idx {
161 0 => {
162 let kick_evt = queue
164 .lock()
165 .event()
166 .try_clone()
167 .context("failed to clone queue event")?;
168 let kick_evt = EventAsync::new(kick_evt, &self.ex)
169 .context("failed to create EventAsync for kick_evt")?;
170 let reader = SharedReader {
171 queue: queue.clone(),
172 };
173
174 let state = if let Some(s) = self.state.as_ref() {
175 s.clone()
176 } else {
177 let fence_handler_resources =
178 Arc::new(Mutex::new(Some(gpu::FenceHandlerActivationResources {
179 mem: mem.clone(),
180 ctrl_queue: reader.clone(),
181 })));
182 let fence_handler = gpu::create_fence_handler(
183 fence_handler_resources,
184 self.fence_state.clone(),
185 );
186
187 let state = Rc::new(RefCell::new(
188 self.gpu
189 .borrow_mut()
190 .initialize_frontend(
191 self.fence_state.clone(),
192 fence_handler,
193 Arc::clone(&self.shmem_mapper),
194 )
195 .ok_or_else(|| anyhow!("failed to initialize gpu frontend"))?,
196 ));
197 self.state = Some(state.clone());
198 state
199 };
200
201 self.start_platform_workers(doorbell)?;
203
204 self.ex
206 .spawn_local(run_ctrl_queue(reader, mem, kick_evt, state))
207 }
208 1 => {
209 self.ex.spawn_local(async {})
213 }
214 _ => bail!("attempted to start unknown queue: {}", idx),
215 };
216
217 self.queue_workers[idx] = Some(WorkerState { queue_task, queue });
218 Ok(())
219 }
220
221 fn stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue> {
222 if let Some(worker) = self.queue_workers.get_mut(idx).and_then(Option::take) {
223 let _ = self.ex.run_until(worker.queue_task.cancel());
225
226 if idx == 0 {
227 self.stop_non_queue_workers()?;
229
230 self.state = None;
235 }
236
237 let queue = match Arc::try_unwrap(worker.queue) {
238 Ok(queue_mutex) => queue_mutex.into_inner(),
239 Err(_) => panic!("failed to recover queue from worker"),
240 };
241
242 Ok(queue)
243 } else {
244 Err(anyhow::Error::new(DeviceError::WorkerNotFound))
245 }
246 }
247
248 fn enter_suspended_state(&mut self) -> anyhow::Result<()> {
249 self.stop_non_queue_workers()?;
250 Ok(())
251 }
252
253 fn reset(&mut self) {
254 self.stop_non_queue_workers()
255 .expect("Failed to stop platform workers.");
256
257 for queue_num in 0..self.max_queue_num() {
258 if self.queue_workers[queue_num].is_some() {
261 if let Err(e) = self.stop_queue(queue_num) {
262 error!("Failed to stop_queue during reset: {}", e);
263 }
264 }
265 }
266 }
267
268 fn get_shared_memory_region(&self) -> Option<SharedMemoryRegion> {
269 self.gpu.borrow().get_shared_memory_region()
270 }
271
272 fn set_backend_req_connection(&mut self, conn: VhostBackendReqConnection) {
273 if self
274 .shmem_mapper
275 .lock()
276 .replace(conn.shmem_mapper().unwrap())
277 .is_some()
278 {
279 warn!("Connection already established. Overwriting shmem_mapper");
280 }
281 }
282
283 fn snapshot(&mut self) -> anyhow::Result<AnySnapshot> {
284 AnySnapshot::to_any(())
287 }
288
289 fn restore(&mut self, data: AnySnapshot) -> anyhow::Result<()> {
290 let () = AnySnapshot::from_any(data)?;
291 Ok(())
292 }
293}
294
295impl Drop for GpuBackend {
296 fn drop(&mut self) {
297 self.reset();
301 }
302}