devices/virtio/vhost_user_backend/
fs.rs1mod sys;
6
7use std::collections::BTreeMap;
8use std::path::PathBuf;
9use std::sync::Arc;
10use std::sync::RwLock;
11
12use anyhow::bail;
13use argh::FromArgs;
14use base::error;
15use base::info;
16use base::warn;
17use base::AsRawDescriptor;
18use base::FromRawDescriptor;
19use base::RawDescriptor;
20use base::SafeDescriptor;
21use base::Tube;
22use base::UnixSeqpacketListener;
23use base::WorkerThread;
24use data_model::Le32;
25use fuse::Server;
26use hypervisor::ProtectionType;
27use snapshot::AnySnapshot;
28use sync::Mutex;
29pub use sys::start_device as run_fs_device;
30use virtio_sys::virtio_fs::virtio_fs_config;
31use vm_control::FsAllowlistCommand;
32use vm_control::FsAllowlistResponse;
33use vm_memory::GuestMemory;
34use vmm_vhost::message::VhostUserProtocolFeatures;
35use vmm_vhost::VHOST_USER_F_PROTOCOL_FEATURES;
36use zerocopy::IntoBytes;
37
38use crate::virtio;
39use crate::virtio::copy_config;
40use crate::virtio::device_constants::fs::FS_MAX_TAG_LEN;
41use crate::virtio::fs::passthrough::PassthroughFs;
42use crate::virtio::fs::Config;
43use crate::virtio::fs::PathAllowlist;
44use crate::virtio::fs::Worker;
45use crate::virtio::vhost_user_backend::handler::Error as DeviceError;
46use crate::virtio::vhost_user_backend::handler::VhostUserDevice;
47use crate::virtio::Queue;
48
/// Maximum number of virtqueues this backend accepts (returned by `max_queue_num`).
///
/// NOTE(review): presumably one high-priority queue plus the single request queue that
/// `read_config` advertises (`num_request_queues = 1`) — confirm.
const MAX_QUEUE_NUM: usize = 2;

/// vhost-user backend state for a virtio-fs device served by a `PassthroughFs`.
pub(crate) struct FsBackend {
    /// FUSE server shared (through `Arc` clones) by every queue worker thread.
    server: Arc<fuse::Server<PassthroughFs>>,
    /// Tag reported to the guest in the config space; `new` rejects tags longer than
    /// `FS_MAX_TAG_LEN` bytes.
    tag: String,
    /// Virtio feature bits returned from `features()`.
    avail_features: u64,
    /// Running worker threads keyed by queue index; stopping one yields its `Queue` back.
    workers: BTreeMap<usize, WorkerThread<Queue>>,
    /// Descriptors collected in `new`: stdio (0, 1, 2), those reported by the file system, and
    /// the allowlist socket if present. NOTE(review): presumably consumed by the caller to keep
    /// them open across sandbox setup — confirm.
    keep_rds: Vec<RawDescriptor>,
    /// Value reported by `unmap_guest_memory_on_fork()`; always `false` off Linux/Android.
    unmap_guest_memory_on_fork: bool,
    /// Listening socket for allowlist commands; moved out by `start_allowlist_listener`.
    allowlist_socket_fd: Option<SafeDescriptor>,
    /// Allowlist shared with the file system; `Some` exactly when an allowlist socket was given.
    allowlist: Option<Arc<RwLock<PathAllowlist>>>,
}
61
62fn handle_client_session(tube: Tube, allowlist: &Arc<RwLock<PathAllowlist>>) {
76 loop {
77 match tube.recv::<FsAllowlistCommand>() {
78 Ok(cmd) => {
79 let result = match cmd {
80 FsAllowlistCommand::AddPaths { paths } => {
81 info!("Allowlist socket: Add paths {:?}", paths);
82 let mut al_guard = allowlist.write().expect(
83 "Allowlist lock poisoned during write (add_paths). Terminating.",
84 );
85 let mut al_clone = al_guard.clone();
86 let mut success = true;
87 for path in &paths {
88 if !al_clone.add_path(path) {
89 error!("Allowlist socket: Failed to add invalid path: {:?}", path);
90 success = false;
91 break;
92 }
93 }
94 if success {
95 *al_guard = al_clone;
96 FsAllowlistResponse::Ok
97 } else {
98 FsAllowlistResponse::Err("Failed to add one or more paths".to_string())
99 }
100 }
101 FsAllowlistCommand::RemovePaths { paths } => {
102 info!("Allowlist socket: Remove paths {:?}", paths);
103 let mut al_guard = allowlist.write().expect(
104 "Allowlist lock poisoned during write (remove_paths). Terminating.",
105 );
106 let mut al_clone = al_guard.clone();
107 let mut success = true;
108 for path in &paths {
109 if !al_clone.remove_path(path) {
110 error!("Allowlist socket: Failed to remove path: {:?}", path);
111 success = false;
112 break;
113 }
114 }
115 if success {
116 *al_guard = al_clone;
117 FsAllowlistResponse::Ok
118 } else {
119 FsAllowlistResponse::Err(
120 "Failed to remove one or more paths".to_string(),
121 )
122 }
123 }
124 };
125
126 if let Err(e) = tube.send(&result) {
127 error!("Allowlist socket: Failed to send response: {}", e);
128 }
129 }
130 Err(base::TubeError::Disconnected) => {
131 info!("Allowlist socket: Client disconnected");
132 break;
133 }
134 Err(e) => {
135 error!("Allowlist socket: Error reading from control socket: {}", e);
136 break;
137 }
138 }
139 }
140}
141
142fn run_allowlist_listener(fd: SafeDescriptor, allowlist: Arc<RwLock<PathAllowlist>>) {
143 let path = format!("/proc/self/fd/{}", fd.as_raw_descriptor());
144 let listener = match UnixSeqpacketListener::bind(&path) {
145 Ok(l) => l,
146 Err(e) => {
147 error!(
148 "Allowlist socket: Failed to re-create listener from fd: {}",
149 e
150 );
151 return;
152 }
153 };
154
155 loop {
156 match listener.accept() {
157 Ok(seqpacket) => {
158 let tube = match Tube::try_from(seqpacket) {
159 Ok(t) => t,
160 Err(e) => {
161 error!("Allowlist socket: Failed to create Tube: {}", e);
162 continue;
163 }
164 };
165 handle_client_session(tube, &allowlist);
166 }
167 Err(e) => {
168 error!("Allowlist socket: Accept failed: {}", e);
169 break;
170 }
171 }
172 }
173}
174
175impl FsBackend {
176 #[allow(unused_variables)]
177 pub fn new(
178 tag: &str,
179 shared_dir: &str,
180 skip_pivot_root: bool,
181 cfg: Option<Config>,
182 allowlist_socket_fd: Option<RawDescriptor>,
183 ) -> anyhow::Result<Self> {
184 if tag.len() > FS_MAX_TAG_LEN {
185 bail!(
186 "fs tag is too long: {} (max supported: {})",
187 tag.len(),
188 FS_MAX_TAG_LEN
189 );
190 }
191
192 let avail_features = virtio::base_features(ProtectionType::Unprotected)
193 | 1 << VHOST_USER_F_PROTOCOL_FEATURES;
194
195 let cfg = cfg.unwrap_or_default();
196
197 #[cfg(any(target_os = "android", target_os = "linux"))]
198 let unmap_guest_memory_on_fork = cfg.unmap_guest_memory_on_fork;
199 #[cfg(not(any(target_os = "android", target_os = "linux")))]
200 let unmap_guest_memory_on_fork = false;
201
202 #[allow(unused_mut)]
204 let mut fs = PassthroughFs::new(tag, cfg)?;
205 #[cfg(feature = "fs_runtime_ugid_map")]
206 if skip_pivot_root {
207 fs.set_root_dir(shared_dir.to_string())?;
208 }
209
210 let allowlist_socket_fd = allowlist_socket_fd.map(|fd| {
211 unsafe { SafeDescriptor::from_raw_descriptor(fd) }
213 });
214
215 let allowlist = if allowlist_socket_fd.is_some() {
216 let al = Arc::new(RwLock::new(PathAllowlist::new()));
217 fs.set_allowlist(Some(al.clone()));
218 Some(al)
219 } else {
220 None
221 };
222
223 let mut keep_rds: Vec<RawDescriptor> = [0, 1, 2].to_vec();
224 keep_rds.append(&mut fs.keep_rds());
225 if let Some(ref fd) = allowlist_socket_fd {
226 keep_rds.push(fd.as_raw_descriptor());
227 }
228
229 let server = Arc::new(Server::new(fs));
230
231 Ok(FsBackend {
232 server,
233 tag: tag.to_owned(),
234 avail_features,
235 workers: Default::default(),
236 keep_rds,
237 unmap_guest_memory_on_fork,
238 allowlist_socket_fd,
239 allowlist,
240 })
241 }
242
243 pub fn start_allowlist_listener(&mut self) {
244 if let Some(fd) = self.allowlist_socket_fd.take() {
245 if let Some(allowlist) = &self.allowlist {
246 let allowlist = allowlist.clone();
247 let result = std::thread::Builder::new()
248 .name("fs_allowlist_listener".to_string())
249 .spawn(move || {
250 run_allowlist_listener(fd, allowlist);
251 });
252 if let Err(e) = result {
253 error!("Failed to spawn allowlist listener thread: {}", e);
254 }
255 }
256 }
257 }
258}
259
260impl VhostUserDevice for FsBackend {
261 fn max_queue_num(&self) -> usize {
262 MAX_QUEUE_NUM
263 }
264
265 fn features(&self) -> u64 {
266 self.avail_features
267 }
268
269 fn protocol_features(&self) -> VhostUserProtocolFeatures {
270 VhostUserProtocolFeatures::CONFIG | VhostUserProtocolFeatures::MQ
271 }
272
273 fn read_config(&self, offset: u64, data: &mut [u8]) {
274 let mut config = virtio_fs_config {
275 tag: [0; FS_MAX_TAG_LEN],
276 num_request_queues: Le32::from(1),
277 };
278 config.tag[..self.tag.len()].copy_from_slice(self.tag.as_bytes());
279 copy_config(data, 0, config.as_bytes(), offset);
280 }
281
282 fn reset(&mut self) {
283 for worker in std::mem::take(&mut self.workers).into_values() {
284 let _ = worker.stop();
285 }
286 }
287
288 fn start_queue(
289 &mut self,
290 idx: usize,
291 queue: virtio::Queue,
292 _mem: GuestMemory,
293 ) -> anyhow::Result<()> {
294 if self.workers.contains_key(&idx) {
295 warn!("Starting new queue handler without stopping old handler");
296 self.stop_queue(idx)?;
297 }
298
299 let (_, fs_device_tube) = Tube::pair()?;
300 let tube = Arc::new(Mutex::new(fs_device_tube));
301
302 let server = self.server.clone();
303
304 let slot: u32 = 0;
306
307 let worker = WorkerThread::start(format!("v_fs:{}:{}", self.tag, idx), move |kill_evt| {
308 let mut worker = Worker::new(queue, server, tube, slot);
309 if let Err(e) = worker.run(kill_evt) {
310 error!("vhost-user-fs worker failed: {e:#}");
311 }
312 worker.queue
313 });
314 self.workers.insert(idx, worker);
315
316 Ok(())
317 }
318
319 fn stop_queue(&mut self, idx: usize) -> anyhow::Result<virtio::Queue> {
320 info!("Stopping vhost-user fs queue [{idx}]");
322 if let Some(worker) = self.workers.remove(&idx) {
323 let queue = worker.stop();
324 Ok(queue)
325 } else {
326 Err(anyhow::Error::new(DeviceError::WorkerNotFound))
327 }
328 }
329
330 fn unmap_guest_memory_on_fork(&self) -> bool {
331 self.unmap_guest_memory_on_fork
332 }
333
334 fn enter_suspended_state(&mut self) -> anyhow::Result<()> {
335 Ok(())
337 }
338
339 fn snapshot(&mut self) -> anyhow::Result<AnySnapshot> {
340 bail!("snapshot not implemented for vhost-user fs");
341 }
342
343 fn restore(&mut self, _data: AnySnapshot) -> anyhow::Result<()> {
344 bail!("snapshot not implemented for vhost-user fs");
345 }
346}
347
348#[derive(FromArgs)]
349#[argh(subcommand, name = "fs")]
350pub struct Options {
352 #[argh(option, arg_name = "PATH", hidden_help)]
353 socket: Option<String>,
355 #[argh(option, arg_name = "PATH")]
356 socket_path: Option<String>,
359 #[argh(option, arg_name = "FD")]
360 fd: Option<RawDescriptor>,
363 #[argh(option, arg_name = "PATH")]
364 allowlist_socket_path: Option<PathBuf>,
368
369 #[argh(option, arg_name = "TAG")]
370 tag: String,
372 #[argh(option, arg_name = "DIR")]
373 shared_dir: PathBuf,
375 #[argh(option, arg_name = "UIDMAP")]
376 uid_map: Option<String>,
378 #[argh(option, arg_name = "GIDMAP")]
379 gid_map: Option<String>,
381 #[argh(option, arg_name = "CFG")]
382 cfg: Option<Config>,
387 #[argh(option, arg_name = "UID", default = "0")]
388 uid: u32,
401 #[argh(option, arg_name = "GID", default = "0")]
402 gid: u32,
405 #[argh(switch)]
406 disable_sandbox: bool,
412 #[argh(option, arg_name = "skip_pivot_root", default = "false")]
413 #[allow(dead_code)]
424 skip_pivot_root: bool,
425}
426
#[cfg(test)]
mod tests {
    use std::os::fd::OwnedFd;

    use base::UnixSeqpacket;

    use super::*;

    /// End-to-end check of the allowlist socket protocol. It covers:
    /// - adding paths and removing them;
    /// - a batch containing an invalid path, which must be rejected as a whole;
    /// - shutting down the listener thread.
    #[test]
    fn test_run_allowlist_listener() {
        let temp_dir = tempfile::TempDir::new().unwrap();
        let socket_path = temp_dir.path().join("test.sock");
        let listener = UnixSeqpacketListener::bind(&socket_path).unwrap();
        let allowlist = Arc::new(RwLock::new(PathAllowlist::default()));

        let fd = SafeDescriptor::from(OwnedFd::from(listener));
        // Kept so the test can shut the listening socket down after the thread owns `fd`.
        let fd_clone = fd.try_clone().unwrap();

        let allowlist_clone = allowlist.clone();
        let handle = std::thread::spawn(move || {
            run_allowlist_listener(fd, allowlist_clone);
        });

        let client_socket = UnixSeqpacket::connect(&socket_path).unwrap();
        let client_tube = Tube::try_from(client_socket).unwrap();

        // Adding valid paths succeeds and makes them accessible.
        client_tube
            .send(&FsAllowlistCommand::AddPaths {
                paths: vec!["/allowed_path1".into(), "/allowed_path2".into()],
            })
            .unwrap();
        let resp: FsAllowlistResponse = client_tube.recv().unwrap();
        assert!(matches!(resp, FsAllowlistResponse::Ok));

        {
            let al = allowlist.read().unwrap();
            assert!(al.is_accessible("/allowed_path1"));
            assert!(al.is_accessible("/allowed_path2"));
        }

        // Removing them succeeds and makes them inaccessible again.
        client_tube
            .send(&FsAllowlistCommand::RemovePaths {
                paths: vec!["/allowed_path1".into(), "/allowed_path2".into()],
            })
            .unwrap();
        let resp: FsAllowlistResponse = client_tube.recv().unwrap();
        assert!(matches!(resp, FsAllowlistResponse::Ok));

        {
            let al = allowlist.read().unwrap();
            assert!(!al.is_accessible("/allowed_path1"));
            assert!(!al.is_accessible("/allowed_path2"));
        }

        // A batch with an invalid path is rejected, and the valid path before it is rolled back.
        client_tube
            .send(&FsAllowlistCommand::AddPaths {
                paths: vec!["/valid_but_rolled_back".into(), "/a/../../..".into()],
            })
            .unwrap();
        let resp: FsAllowlistResponse = client_tube.recv().unwrap();
        assert!(matches!(resp, FsAllowlistResponse::Err(_)));

        {
            let al = allowlist.read().unwrap();
            assert!(!al.is_accessible("/valid_but_rolled_back"));
        }

        // Dropping the client ends its session. Shutting the listening socket down then makes
        // `accept` fail, so the listener thread exits and can be joined.
        drop(client_tube);

        // SAFETY: `shutdown` only acts on the descriptor and touches no memory. `fd_clone` is a
        // valid socket descriptor that stays owned for the rest of this test.
        let ret = unsafe { libc::shutdown(fd_clone.as_raw_descriptor(), libc::SHUT_RDWR) };
        // Fail loudly here rather than hanging in `join` if the listener cannot be woken.
        assert_eq!(
            ret,
            0,
            "shutdown failed: {}",
            std::io::Error::last_os_error()
        );

        let join_res = handle.join();
        assert!(join_res.is_ok(), "Listener thread panicked!");
    }
}