base/sys/linux/
shm.rs

1// Copyright 2017 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use std::ffi::CStr;
6use std::fs::File;
7use std::io::Seek;
8use std::io::SeekFrom;
9use std::sync::LazyLock;
10
11use libc::c_char;
12use libc::c_int;
13use libc::c_long;
14use libc::c_uint;
15use libc::close;
16use libc::fcntl;
17use libc::ftruncate64;
18use libc::off64_t;
19use libc::syscall;
20use libc::SYS_memfd_create;
21use libc::F_ADD_SEALS;
22use libc::F_GET_SEALS;
23use libc::F_SEAL_FUTURE_WRITE;
24use libc::F_SEAL_GROW;
25use libc::F_SEAL_SEAL;
26use libc::F_SEAL_SHRINK;
27use libc::F_SEAL_WRITE;
28use libc::MFD_ALLOW_SEALING;
29
30use crate::errno_result;
31use crate::shm::PlatformSharedMemory;
32use crate::trace;
33use crate::AsRawDescriptor;
34use crate::FromRawDescriptor;
35use crate::Result;
36use crate::SafeDescriptor;
37use crate::SharedMemory;
38
// Flag values for `memfd_create`, mirroring <sys/memfd.h>.

/// Close the memfd automatically when the process calls `exec`.
const MFD_CLOEXEC: c_uint = 0x1;
/// Create the memfd with a non-executable mode and seal that property. Not every kernel accepts
/// this flag, so it is probed at runtime before use.
const MFD_NOEXEC_SEAL: c_uint = 0x8;
42
43// SAFETY: It is caller's responsibility to ensure the args are valid and check the
44// return value of the function.
45unsafe fn memfd_create(name: *const c_char, flags: c_uint) -> c_int {
46    syscall(SYS_memfd_create as c_long, name, flags) as c_int
47}
48
49/// A set of memfd seals.
50///
51/// An enumeration of each bit can be found at `fcntl(2)`.
52#[derive(Copy, Clone, Default)]
53pub struct MemfdSeals(i32);
54
55impl MemfdSeals {
56    /// Returns an empty set of memfd seals.
57    #[inline]
58    pub fn new() -> MemfdSeals {
59        MemfdSeals(0)
60    }
61
62    /// Gets the raw bitmask of seals enumerated in `fcntl(2)`.
63    #[inline]
64    pub fn bitmask(self) -> i32 {
65        self.0
66    }
67
68    /// True if the grow seal bit is present.
69    #[inline]
70    pub fn grow_seal(self) -> bool {
71        self.0 & F_SEAL_GROW != 0
72    }
73
74    /// Sets the grow seal bit.
75    #[inline]
76    pub fn set_grow_seal(&mut self) {
77        self.0 |= F_SEAL_GROW;
78    }
79
80    /// True if the shrink seal bit is present.
81    #[inline]
82    pub fn shrink_seal(self) -> bool {
83        self.0 & F_SEAL_SHRINK != 0
84    }
85
86    /// Sets the shrink seal bit.
87    #[inline]
88    pub fn set_shrink_seal(&mut self) {
89        self.0 |= F_SEAL_SHRINK;
90    }
91
92    /// True if the write seal bit is present.
93    #[inline]
94    pub fn write_seal(self) -> bool {
95        self.0 & F_SEAL_WRITE != 0
96    }
97
98    /// Sets the write seal bit.
99    #[inline]
100    pub fn set_write_seal(&mut self) {
101        self.0 |= F_SEAL_WRITE;
102    }
103
104    /// True if the future write seal bit is present.
105    #[inline]
106    pub fn future_write_seal(self) -> bool {
107        self.0 & F_SEAL_FUTURE_WRITE != 0
108    }
109
110    /// Sets the future write seal bit.
111    #[inline]
112    pub fn set_future_write_seal(&mut self) {
113        self.0 |= F_SEAL_FUTURE_WRITE;
114    }
115
116    /// True of the seal seal bit is present.
117    #[inline]
118    pub fn seal_seal(self) -> bool {
119        self.0 & F_SEAL_SEAL != 0
120    }
121
122    /// Sets the seal seal bit.
123    #[inline]
124    pub fn set_seal_seal(&mut self) {
125        self.0 |= F_SEAL_SEAL;
126    }
127}
128
129static MFD_NOEXEC_SEAL_SUPPORTED: LazyLock<bool> = LazyLock::new(|| {
130    // SAFETY: We pass a valid zero-terminated C string and check the result.
131    let fd = unsafe {
132        // The memfd name used here does not need to be unique, since duplicates are allowed and
133        // will not cause failures.
134        memfd_create(
135            c"MFD_NOEXEC_SEAL_test".as_ptr() as *const c_char,
136            MFD_CLOEXEC | MFD_ALLOW_SEALING | MFD_NOEXEC_SEAL,
137        )
138    };
139    if fd < 0 {
140        trace!("MFD_NOEXEC_SEAL is not supported");
141        false
142    } else {
143        trace!("MFD_NOEXEC_SEAL is supported");
144        // SAFETY: We know `fd` is a valid file descriptor owned by us.
145        unsafe {
146            close(fd);
147        }
148        true
149    }
150});
151
152impl PlatformSharedMemory for SharedMemory {
153    /// Creates a new shared memory file descriptor with the specified `size` in bytes.
154    ///
155    /// `name` will appear in `/proc/self/fd/<shm fd>` for the purposes of debugging. The name does
156    /// not need to be unique.
157    ///
158    /// The file descriptor is opened with the close on exec flag and allows memfd sealing.
159    ///
160    /// If the `MFD_NOEXEC_SEAL` flag is supported, the resulting file will also be created with a
161    /// non-executable file mode (in other words, it cannot be passed to the `exec` family of system
162    /// calls).
163    fn new(debug_name: &CStr, size: u64) -> Result<SharedMemory> {
164        let mut flags = MFD_CLOEXEC | MFD_ALLOW_SEALING;
165        if *MFD_NOEXEC_SEAL_SUPPORTED {
166            flags |= MFD_NOEXEC_SEAL;
167        }
168
169        let shm_name = debug_name.as_ptr() as *const c_char;
170        // SAFETY:
171        // The following are safe because we give a valid C string and check the
172        // results of the memfd_create call.
173        let fd = unsafe { memfd_create(shm_name, flags) };
174        if fd < 0 {
175            return errno_result();
176        }
177        // SAFETY: Safe because fd is valid.
178        let descriptor = unsafe { SafeDescriptor::from_raw_descriptor(fd) };
179
180        // Set the size of the memfd.
181        // SAFETY: Safe because we check the return value to ftruncate64 and all the args to the
182        // function are valid.
183        let ret = unsafe { ftruncate64(descriptor.as_raw_descriptor(), size as off64_t) };
184        if ret < 0 {
185            return errno_result();
186        }
187
188        Ok(SharedMemory { descriptor, size })
189    }
190
191    /// Creates a SharedMemory instance from a SafeDescriptor owning a reference to a
192    /// shared memory descriptor. Ownership of the underlying descriptor is transferred to the
193    /// new SharedMemory object.
194    fn from_safe_descriptor(descriptor: SafeDescriptor, size: u64) -> Result<SharedMemory> {
195        Ok(SharedMemory { descriptor, size })
196    }
197}
198
199pub trait SharedMemoryLinux {
200    /// Constructs a `SharedMemory` instance from a `File` that represents shared memory.
201    ///
202    /// The size of the resulting shared memory will be determined using `File::seek`. If the given
203    /// file's size can not be determined this way, this will return an error.
204    fn from_file(file: File) -> Result<SharedMemory>;
205
206    /// Gets the memfd seals that have already been added to this.
207    ///
208    /// This may fail if this instance was not constructed from a memfd.
209    fn get_seals(&self) -> Result<MemfdSeals>;
210
211    /// Adds the given set of memfd seals.
212    ///
213    /// This may fail if this instance was not constructed from a memfd with sealing allowed or if
214    /// the seal seal (`F_SEAL_SEAL`) bit was already added.
215    fn add_seals(&mut self, seals: MemfdSeals) -> Result<()>;
216}
217
218impl SharedMemoryLinux for SharedMemory {
219    fn from_file(mut file: File) -> Result<SharedMemory> {
220        let file_size = file.seek(SeekFrom::End(0))?;
221        Ok(SharedMemory {
222            descriptor: file.into(),
223            size: file_size,
224        })
225    }
226
227    fn get_seals(&self) -> Result<MemfdSeals> {
228        // SAFETY: Safe because we check the return value to fcntl and all the args to the
229        // function are valid.
230        let ret = unsafe { fcntl(self.descriptor.as_raw_descriptor(), F_GET_SEALS) };
231        if ret < 0 {
232            return errno_result();
233        }
234        Ok(MemfdSeals(ret))
235    }
236
237    fn add_seals(&mut self, seals: MemfdSeals) -> Result<()> {
238        // SAFETY: Safe because we check the return value to fcntl and all the args to the
239        // function are valid.
240        let ret = unsafe { fcntl(self.descriptor.as_raw_descriptor(), F_ADD_SEALS, seals) };
241        if ret < 0 {
242            return errno_result();
243        }
244        Ok(())
245    }
246}
247
#[cfg(test)]
mod tests {
    use std::fs::read_link;

    use libc::EINVAL;

    use crate::linux::SharedMemoryLinux;
    use crate::pagesize;
    use crate::AsRawDescriptor;
    use crate::Error;
    use crate::MemoryMappingBuilder;
    use crate::Result;
    use crate::SharedMemory;
    use crate::VolatileMemory;

    /// Reads the name from the underlying file as a `String`.
    ///
    /// If the underlying file was not created with `SharedMemory::new` or with `memfd_create`, the
    /// results are undefined. Because this returns a `String`, the name's bytes are interpreted as
    /// utf-8.
    fn read_name(shm: &SharedMemory) -> Result<String> {
        let fd_path = format!("/proc/self/fd/{}", shm.as_raw_descriptor());
        let link_name = read_link(fd_path)?;
        link_name
            .to_str()
            .map(|s| {
                // Strip exactly one "/memfd:" prefix and one " (deleted)" suffix, so names that
                // themselves contain those strings are preserved.
                let s = s.strip_prefix("/memfd:").unwrap_or(s);
                s.strip_suffix(" (deleted)").unwrap_or(s).to_owned()
            })
            .ok_or_else(|| Error::new(EINVAL))
    }

    #[test]
    fn new() {
        const TEST_NAME: &str = "Name McCool Person";
        let shm = SharedMemory::new(TEST_NAME, 0).expect("failed to create shared memory");
        assert_eq!(read_name(&shm), Ok(TEST_NAME.to_owned()));
    }

    #[test]
    fn new_huge() {
        let shm = SharedMemory::new("test", 0x7fff_ffff_ffff_ffff)
            .expect("failed to create shared memory");
        assert_eq!(shm.size(), 0x7fff_ffff_ffff_ffff);
    }

    #[test]
    fn new_sealed() {
        let mut shm = SharedMemory::new("test", 0).expect("failed to create shared memory");
        let mut seals = shm.get_seals().expect("failed to get seals");
        assert!(!seals.seal_seal());
        seals.set_seal_seal();
        shm.add_seals(seals).expect("failed to add seals");
        seals = shm.get_seals().expect("failed to get seals");
        assert!(seals.seal_seal());
        // Adding more seals should be rejected by the kernel.
        shm.add_seals(seals).unwrap_err();
    }

    #[test]
    fn add_grow_seal() {
        let mut shm = SharedMemory::new("test", 0).expect("failed to create shared memory");
        let mut seals = shm.get_seals().expect("failed to get seals");
        assert!(!seals.grow_seal());
        seals.set_grow_seal();
        shm.add_seals(seals).expect("failed to add seals");
        let seals = shm.get_seals().expect("failed to get seals");
        assert!(seals.grow_seal());
        // Sealing with F_SEAL_GROW must not imply F_SEAL_SEAL.
        assert!(!seals.seal_seal());
    }

    #[test]
    fn mmap_page() {
        let page_size = pagesize();
        let shm =
            SharedMemory::new("test", page_size as u64).expect("failed to create shared memory");

        let mmap1 = MemoryMappingBuilder::new(shm.size() as usize)
            .from_shared_memory(&shm)
            .build()
            .expect("failed to map shared memory");
        let mmap2 = MemoryMappingBuilder::new(shm.size() as usize)
            .from_shared_memory(&shm)
            .build()
            .expect("failed to map shared memory");

        assert_ne!(
            mmap1.get_slice(0, 1).unwrap().as_ptr(),
            mmap2.get_slice(0, 1).unwrap().as_ptr()
        );

        mmap1
            .get_slice(0, page_size)
            .expect("failed to get mmap slice")
            .write_bytes(0x45);

        for i in 0..page_size {
            assert_eq!(mmap2.read_obj::<u8>(i).unwrap(), 0x45u8);
        }
    }

    #[test]
    fn mmap_page_offset() {
        let shm = SharedMemory::new("test", 2 * pagesize() as u64)
            .expect("failed to create shared memory");

        let mmap1 = MemoryMappingBuilder::new(shm.size() as usize)
            .from_shared_memory(&shm)
            .offset(pagesize() as u64)
            .build()
            .expect("failed to map shared memory");
        let mmap2 = MemoryMappingBuilder::new(shm.size() as usize)
            .from_shared_memory(&shm)
            .build()
            .expect("failed to map shared memory");

        mmap1
            .get_slice(0, pagesize())
            .expect("failed to get mmap slice")
            .write_bytes(0x45);

        for i in 0..pagesize() {
            assert_eq!(mmap2.read_obj::<u8>(i).unwrap(), 0);
        }
        for i in pagesize()..(2 * pagesize()) {
            assert_eq!(mmap2.read_obj::<u8>(i).unwrap(), 0x45u8);
        }
    }
}