devices/virtio/fs/
allowlist.rs1use std::collections::HashMap;
6use std::collections::HashSet;
7use std::ffi::OsStr;
8use std::ffi::OsString;
9use std::path::Path;
10use std::path::PathBuf;
11
/// Normalizes `path` purely lexically (no filesystem access) into an absolute path.
///
/// `.` components are dropped and each `..` cancels the preceding component. A root
/// component discards everything collected so far, and relative input is anchored at `/`.
/// Windows prefixes are ignored.
///
/// Returns `None` if a `..` would climb above the root, e.g. `/..` or `/a/../..`.
fn normalize_lexically(path: &Path) -> Option<PathBuf> {
    use std::path::Component;

    let mut stack = Vec::new();
    for part in path.components() {
        match part {
            Component::Normal(name) => stack.push(name),
            Component::ParentDir => {
                // Nothing left to pop means the path escapes the root.
                stack.pop()?;
            }
            Component::RootDir => stack.clear(),
            Component::CurDir | Component::Prefix(_) => {}
        }
    }

    let mut normalized = PathBuf::from("/");
    normalized.extend(stack);
    Some(normalized)
}
46
/// Access the allowlist grants to a single path.
///
/// Levels are ordered from least to most permissive, and the allowlist compares them with
/// `<`/`>=`. The explicit discriminants spell that order out; keep them increasing.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum AccessLevel {
    /// The path is not accessible at all.
    None = 0,
    /// The path may be looked up but not written; it lies on the way to an allowed path.
    Traverse = 1,
    /// The path and everything beneath it are accessible and writable.
    Full = 2,
}
60
/// Describes which entries of a directory listing may be exposed.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare filters directly instead of
/// pattern-matching on them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReadDirFilter {
    /// Every entry may be listed.
    AllowAll,
    /// Only entries whose names are in the set may be listed.
    AllowOnly(HashSet<OsString>),
    /// No entry may be listed.
    DenyAll,
}
71
72#[derive(Debug, Clone)]
73struct TrieNode {
74 access_level: AccessLevel,
75 children: HashMap<OsString, TrieNode>,
76 active_children_count: usize,
77}
78
79impl Default for TrieNode {
80 fn default() -> Self {
81 Self {
82 access_level: AccessLevel::None,
83 children: HashMap::new(),
84 active_children_count: 0,
85 }
86 }
87}
88
89impl TrieNode {
90 fn is_active(&self) -> bool {
92 self.access_level > AccessLevel::None || self.active_children_count > 0
93 }
94
95 fn has_active_descendants(&self) -> bool {
97 self.active_children_count > 0
98 }
99}
100
101#[derive(Debug, Clone, Default)]
155pub struct PathAllowlist {
156 root: TrieNode,
157}
158
159impl PathAllowlist {
160 pub fn new() -> Self {
162 Self {
163 root: TrieNode::default(),
164 }
165 }
166
167 fn parse_components(path: &Path) -> Vec<&OsStr> {
169 path.components()
170 .filter_map(|c| match c {
171 std::path::Component::Normal(s) => Some(s),
172 _ => None,
173 })
174 .collect()
175 }
176
177 pub fn add_path<P: AsRef<Path>>(&mut self, path: P) -> bool {
181 let normalized = match normalize_lexically(path.as_ref()) {
182 Some(p) => p,
183 None => return false,
184 };
185 let components = Self::parse_components(&normalized);
186
187 fn add_rec(node: &mut TrieNode, components: &[&OsStr]) -> bool {
190 let was_active = node.is_active();
191
192 if components.is_empty() {
193 node.access_level = AccessLevel::Full;
194 return !was_active && node.is_active();
195 }
196
197 let first = components[0];
198 if node.access_level == AccessLevel::None {
199 node.access_level = AccessLevel::Traverse;
200 }
201
202 let child = node.children.entry(first.to_os_string()).or_default();
203 let child_active_changed = add_rec(child, &components[1..]);
204
205 if child_active_changed {
206 node.active_children_count += 1;
207 }
208
209 !was_active && node.is_active()
210 }
211
212 add_rec(&mut self.root, &components);
213 true
214 }
215
216 pub fn remove_path<P: AsRef<Path>>(&mut self, path: P) -> bool {
220 let normalized = match normalize_lexically(path.as_ref()) {
221 Some(p) => p,
222 None => return false,
223 };
224 let components = Self::parse_components(&normalized);
225
226 fn remove_rec(node: &mut TrieNode, components: &[&OsStr]) -> (bool, bool) {
231 let was_active = node.is_active();
232
233 if components.is_empty() {
234 if node.access_level != AccessLevel::Full {
235 return (false, false);
236 }
237
238 if node.has_active_descendants() {
239 node.access_level = AccessLevel::Traverse;
242 } else {
243 node.access_level = AccessLevel::None;
244 }
245 let became_inactive = was_active && !node.is_active();
246 return (true, became_inactive);
247 }
248
249 let first = components[0];
250 let mut removed = false;
251
252 if let Some(child) = node.children.get_mut(first) {
253 let (child_removed, child_became_inactive) = remove_rec(child, &components[1..]);
254 removed = child_removed;
255
256 if child_became_inactive {
257 node.active_children_count -= 1;
258 }
259
260 if !child.is_active() {
262 node.children.remove(first);
263 }
264 }
265
266 if node.access_level == AccessLevel::Traverse && !node.has_active_descendants() {
268 node.access_level = AccessLevel::None;
269 }
270
271 let became_inactive = was_active && !node.is_active();
272 (removed, became_inactive)
273 }
274
275 let (removed, _) = remove_rec(&mut self.root, &components);
276 removed
277 }
278
279 fn get_access_level(&self, path: &Path) -> AccessLevel {
281 let normalized = match normalize_lexically(path) {
282 Some(p) => p,
283 None => return AccessLevel::None,
284 };
285 let components = Self::parse_components(&normalized);
286
287 let mut current = &self.root;
288 if current.access_level == AccessLevel::Full {
289 return AccessLevel::Full;
290 }
291
292 for comp in components {
293 if let Some(next) = current.children.get(comp) {
294 current = next;
295 if current.access_level == AccessLevel::Full {
296 return AccessLevel::Full;
297 }
298 } else {
299 return AccessLevel::None;
300 }
301 }
302
303 current.access_level
304 }
305
306 pub fn is_accessible<P: AsRef<Path>>(&self, path: P) -> bool {
310 if !self.root.is_active() {
311 return false;
312 }
313 self.get_access_level(path.as_ref()) >= AccessLevel::Traverse
314 }
315
316 pub fn is_writable<P: AsRef<Path>>(&self, path: P) -> bool {
320 if !self.root.is_active() {
321 return false;
322 }
323 self.get_access_level(path.as_ref()) == AccessLevel::Full
324 }
325
326 pub fn get_read_dir_filter<P: AsRef<Path>>(&self, parent_path: P) -> ReadDirFilter {
332 if !self.root.is_active() {
333 return ReadDirFilter::DenyAll;
334 }
335
336 let normalized = match normalize_lexically(parent_path.as_ref()) {
337 Some(p) => p,
338 None => return ReadDirFilter::DenyAll,
339 };
340 let components = Self::parse_components(&normalized);
341
342 let mut current = &self.root;
343 if current.access_level == AccessLevel::Full {
344 return ReadDirFilter::AllowAll;
345 }
346
347 for comp in components {
348 if let Some(next) = current.children.get(comp) {
349 current = next;
350 if current.access_level == AccessLevel::Full {
351 return ReadDirFilter::AllowAll;
352 }
353 } else {
354 return ReadDirFilter::DenyAll;
355 }
356 }
357
358 match current.access_level {
359 AccessLevel::Full => ReadDirFilter::AllowAll,
360 AccessLevel::Traverse => {
361 let allowed_entries = current
362 .children
363 .iter()
364 .filter(|(_, child)| child.is_active())
365 .map(|(name, _)| name.clone())
366 .collect::<HashSet<_>>();
367 ReadDirFilter::AllowOnly(allowed_entries)
368 }
369 AccessLevel::None => ReadDirFilter::DenyAll,
370 }
371 }
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_normalize_lexically() {
        assert_eq!(
            normalize_lexically(Path::new("/a/b/c")),
            Some(PathBuf::from("/a/b/c"))
        );
        assert_eq!(
            normalize_lexically(Path::new("/a/../b")),
            Some(PathBuf::from("/b"))
        );
        assert_eq!(
            normalize_lexically(Path::new("/a/./b")),
            Some(PathBuf::from("/a/b"))
        );
        assert_eq!(
            normalize_lexically(Path::new("/")),
            Some(PathBuf::from("/"))
        );
        assert_eq!(normalize_lexically(Path::new("")), Some(PathBuf::from("/")));
        assert_eq!(
            normalize_lexically(Path::new("a/b")),
            Some(PathBuf::from("/a/b"))
        );
        assert_eq!(
            normalize_lexically(Path::new("/a/b/../c/./d")),
            Some(PathBuf::from("/a/c/d"))
        );

        // A `..` that would climb above the root is rejected.
        assert_eq!(normalize_lexically(Path::new("..")), None);
        assert_eq!(normalize_lexically(Path::new("/..")), None);
        assert_eq!(normalize_lexically(Path::new("/a/../..")), None);
        assert_eq!(normalize_lexically(Path::new("/a/b/../../..")), None);
    }

    #[test]
    fn test_path_allowlist_empty() {
        let allowlist = PathAllowlist::new();
        assert!(!allowlist.is_accessible("/a/b"));
        assert!(!allowlist.is_writable("/a/b"));
    }

    #[test]
    fn test_path_allowlist_allowed_rules() {
        let mut allowlist = PathAllowlist::new();
        allowlist.add_path("/a/b");

        assert!(allowlist.is_accessible("/a/b"));
        assert!(allowlist.is_accessible("/a/b/c"));
        assert!(allowlist.is_accessible("/a/b/c/d"));
        assert!(allowlist.is_accessible("/a"));
        assert!(allowlist.is_accessible("/"));

        assert!(!allowlist.is_accessible("/d"));
        assert!(!allowlist.is_accessible("/a/c"));
    }

    #[test]
    fn test_path_allowlist_writable_rules() {
        let mut allowlist = PathAllowlist::new();
        allowlist.add_path("/a/b");

        assert!(allowlist.is_writable("/a/b"));
        assert!(allowlist.is_writable("/a/b/c"));

        assert!(!allowlist.is_writable("/a"));
        assert!(!allowlist.is_writable("/"));

        assert!(!allowlist.is_writable("/d"));
    }

    #[test]
    fn test_path_allowlist_multiple_paths() {
        let mut allowlist = PathAllowlist::new();
        allowlist.add_path("/a/b");
        allowlist.add_path("/c/d");

        assert!(allowlist.is_accessible("/a/b"));
        assert!(allowlist.is_accessible("/c/d"));
        assert!(allowlist.is_accessible("/a"));
        assert!(allowlist.is_accessible("/c"));

        assert!(!allowlist.is_accessible("/e"));
    }

    #[test]
    fn test_path_allowlist_remove_path() {
        let mut allowlist = PathAllowlist::new();
        allowlist.add_path("/a/b");
        assert!(allowlist.is_accessible("/a/b"));

        assert!(allowlist.remove_path("/a/b"));
        assert!(!allowlist.is_accessible("/a/b"));

        assert!(!allowlist.remove_path("/a/b"));
    }

    #[test]
    fn test_path_allowlist_remove_parent_keeps_child() {
        let mut allowlist = PathAllowlist::new();
        allowlist.add_path("/a/b");
        allowlist.add_path("/a/b/c");

        assert!(allowlist.is_accessible("/a/b"));
        assert!(allowlist.is_writable("/a/b"));
        assert!(allowlist.is_accessible("/a/b/c"));
        assert!(allowlist.is_writable("/a/b/c"));

        assert!(allowlist.remove_path("/a/b"));

        assert!(allowlist.is_accessible("/a/b/c"));
        assert!(allowlist.is_writable("/a/b/c"));

        // The removed parent stays traversable because it still leads to `/a/b/c`.
        assert!(allowlist.is_accessible("/a/b"));
        assert!(!allowlist.is_writable("/a/b"));

        assert!(allowlist.is_accessible("/a"));
        assert!(allowlist.is_accessible("/"));

        assert!(!allowlist.remove_path("/a/b"));
    }

    #[test]
    fn test_path_allowlist_remove_child_inherited() {
        let mut allowlist = PathAllowlist::new();
        allowlist.add_path("/a/b");
        allowlist.add_path("/a/b/c");

        assert!(allowlist.remove_path("/a/b/c"));

        assert!(allowlist.is_accessible("/a/b"));
        assert!(allowlist.is_writable("/a/b"));

        // `/a/b/c` is still covered by `/a/b`.
        assert!(allowlist.is_accessible("/a/b/c"));
        assert!(allowlist.is_writable("/a/b/c"));

        assert!(!allowlist.remove_path("/a/b/c"));
    }

    #[test]
    fn test_path_allowlist_remove_one_of_multiple_children() {
        let mut allowlist = PathAllowlist::new();
        allowlist.add_path("/a/b/c");
        allowlist.add_path("/a/b/d");

        assert!(allowlist.is_accessible("/a/b/c"));
        assert!(allowlist.is_writable("/a/b/c"));
        assert!(allowlist.is_accessible("/a/b/d"));
        assert!(allowlist.is_writable("/a/b/d"));
        assert!(allowlist.is_accessible("/a/b"));
        assert!(!allowlist.is_writable("/a/b"));

        assert!(allowlist.remove_path("/a/b/c"));

        assert!(allowlist.is_accessible("/a/b/d"));
        assert!(allowlist.is_writable("/a/b/d"));

        assert!(!allowlist.is_accessible("/a/b/c"));
        assert!(!allowlist.is_writable("/a/b/c"));

        assert!(allowlist.is_accessible("/a/b"));
        assert!(!allowlist.is_writable("/a/b"));

        assert!(!allowlist.remove_path("/a/b/c"));
    }

    #[test]
    fn test_path_allowlist_remove_non_existent() {
        let mut allowlist = PathAllowlist::new();
        allowlist.add_path("/a/b");

        assert!(!allowlist.remove_path("/a/c"));

        assert!(allowlist.is_accessible("/a/b"));
        assert!(allowlist.is_writable("/a/b"));
    }

    #[cfg(unix)]
    #[test]
    fn test_path_allowlist_non_utf8() {
        use std::os::unix::ffi::OsStrExt;

        let mut allowlist = PathAllowlist::new();
        let path_ff = OsStr::from_bytes(b"/a/b\xff");
        let path_fe = OsStr::from_bytes(b"/a/b\xfe");

        allowlist.add_path(Path::new(path_ff));

        assert!(allowlist.is_accessible(Path::new(path_ff)));
        assert!(allowlist.is_writable(Path::new(path_ff)));

        assert!(!allowlist.is_accessible(Path::new(path_fe)));
        assert!(!allowlist.is_writable(Path::new(path_fe)));
    }

    #[test]
    fn test_path_allowlist_invalid_paths() {
        let mut allowlist = PathAllowlist::new();

        assert!(!allowlist.add_path("/a/../.."));
        assert!(!allowlist.is_accessible("/"));

        assert!(allowlist.add_path("/a"));
        assert!(allowlist.is_accessible("/a"));
        assert!(!allowlist.remove_path("/a/../.."));
        assert!(allowlist.is_accessible("/a"));
    }

    #[test]
    fn test_path_allowlist_get_read_dir_filter() {
        let mut allowlist = PathAllowlist::new();

        assert!(matches!(
            allowlist.get_read_dir_filter("/"),
            ReadDirFilter::DenyAll
        ));
        assert!(matches!(
            allowlist.get_read_dir_filter("/a"),
            ReadDirFilter::DenyAll
        ));

        allowlist.add_path("/a/b");

        match allowlist.get_read_dir_filter("/") {
            ReadDirFilter::AllowOnly(set) => {
                assert_eq!(set.len(), 1);
                assert!(set.contains(OsStr::new("a")));
            }
            _ => panic!("expected AllowOnly"),
        }

        match allowlist.get_read_dir_filter("/a") {
            ReadDirFilter::AllowOnly(set) => {
                assert_eq!(set.len(), 1);
                assert!(set.contains(OsStr::new("b")));
            }
            _ => panic!("expected AllowOnly"),
        }

        assert!(matches!(
            allowlist.get_read_dir_filter("/a/b"),
            ReadDirFilter::AllowAll
        ));

        assert!(matches!(
            allowlist.get_read_dir_filter("/a/b/c"),
            ReadDirFilter::AllowAll
        ));

        // A fully allowed ancestor overrides a more specific traversable entry.
        let mut allowlist2 = PathAllowlist::new();
        allowlist2.add_path("/a/b/c");
        allowlist2.add_path("/a");
        assert!(matches!(
            allowlist2.get_read_dir_filter("/a/b"),
            ReadDirFilter::AllowAll
        ));

        assert!(matches!(
            allowlist.get_read_dir_filter("/d"),
            ReadDirFilter::DenyAll
        ));
    }

    #[test]
    fn test_path_allowlist_readd_after_remove() {
        let mut allowlist = PathAllowlist::new();
        assert!(allowlist.add_path("/a/b"));
        assert!(allowlist.remove_path("/a/b"));

        // Removing the only allowed path prunes every ancestor as well.
        assert!(!allowlist.is_accessible("/"));
        assert!(!allowlist.is_accessible("/a"));

        assert!(allowlist.add_path("/a/b"));
        assert!(allowlist.is_writable("/a/b"));
        assert!(allowlist.is_accessible("/a"));
        assert!(!allowlist.is_writable("/a"));
    }

    #[test]
    fn test_path_allowlist_add_duplicate() {
        let mut allowlist = PathAllowlist::new();
        assert!(allowlist.add_path("/a"));
        assert!(allowlist.add_path("/a"));
        assert!(allowlist.is_writable("/a"));

        // Adding twice does not require removing twice.
        assert!(allowlist.remove_path("/a"));
        assert!(!allowlist.is_accessible("/a"));
        assert!(!allowlist.is_accessible("/"));
        assert!(!allowlist.remove_path("/a"));
    }

    #[test]
    fn test_path_allowlist_root_path() {
        let mut allowlist = PathAllowlist::new();
        assert!(allowlist.add_path("/"));

        assert!(allowlist.is_writable("/"));
        assert!(allowlist.is_writable("/any/path"));
        assert!(matches!(
            allowlist.get_read_dir_filter("/x"),
            ReadDirFilter::AllowAll
        ));

        assert!(allowlist.remove_path("/"));
        assert!(!allowlist.is_accessible("/"));
        assert!(!allowlist.remove_path("/"));
    }

    #[test]
    fn test_path_allowlist_normalizes_input() {
        let mut allowlist = PathAllowlist::new();
        assert!(allowlist.add_path("/a/./b/../c"));

        assert!(allowlist.is_writable("/a/c"));
        assert!(!allowlist.is_accessible("/a/b"));
        assert!(allowlist.is_writable("/a/b/../c/d"));

        // Relative input is anchored at `/`.
        assert!(allowlist.remove_path("a/c"));
        assert!(!allowlist.is_accessible("/a/c"));
    }

    #[test]
    fn test_path_allowlist_read_dir_filter_after_remove() {
        let mut allowlist = PathAllowlist::new();
        allowlist.add_path("/a/b");
        allowlist.add_path("/a/c");

        assert!(allowlist.remove_path("/a/b"));

        match allowlist.get_read_dir_filter("/a") {
            ReadDirFilter::AllowOnly(set) => {
                assert_eq!(set.len(), 1);
                assert!(set.contains(OsStr::new("c")));
            }
            other => panic!("expected AllowOnly, got {:?}", other),
        }
        assert!(matches!(
            allowlist.get_read_dir_filter("/a/b"),
            ReadDirFilter::DenyAll
        ));
    }
}