cros_async/sys/linux/
tokio_source.rs

1// Copyright 2024 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use std::io;
6use std::os::fd::AsRawFd;
7use std::os::fd::OwnedFd;
8use std::os::fd::RawFd;
9use std::sync::Arc;
10
11use base::add_fd_flags;
12use base::clone_descriptor;
13use base::linux::fallocate;
14use base::linux::FallocateMode;
15use base::AsRawDescriptor;
16use base::VolatileSlice;
17use tokio::io::unix::AsyncFd;
18
19use crate::mem::MemRegion;
20use crate::AsyncError;
21use crate::AsyncResult;
22use crate::BackingMemory;
23
24#[derive(Debug, thiserror::Error)]
25pub enum Error {
26    #[error("Failed to copy the FD for the polling context: '{0}'")]
27    DuplicatingFd(base::Error),
28    #[error("Failed to punch hole in file: '{0}'.")]
29    Fallocate(base::Error),
30    #[error("Failed to fdatasync: '{0}'")]
31    Fdatasync(io::Error),
32    #[error("Failed to fsync: '{0}'")]
33    Fsync(io::Error),
34    #[error("Failed to join task: '{0}'")]
35    Join(tokio::task::JoinError),
36    #[error("Cannot wait on file descriptor")]
37    NonWaitable,
38    #[error("Failed to read: '{0}'")]
39    Read(io::Error),
40    #[error("Failed to set nonblocking: '{0}'")]
41    SettingNonBlocking(base::Error),
42    #[error("Tokio Async FD error: '{0}'")]
43    TokioAsyncFd(io::Error),
44    #[error("Failed to write: '{0}'")]
45    Write(io::Error),
46}
47
48impl From<Error> for io::Error {
49    fn from(e: Error) -> Self {
50        use Error::*;
51        match e {
52            DuplicatingFd(e) => e.into(),
53            Fallocate(e) => e.into(),
54            Fdatasync(e) => e,
55            Fsync(e) => e,
56            Join(e) => io::Error::other(e),
57            NonWaitable => io::Error::other(e),
58            Read(e) => e,
59            SettingNonBlocking(e) => e.into(),
60            TokioAsyncFd(e) => e,
61            Write(e) => e,
62        }
63    }
64}
65
66enum FdType {
67    Async(AsyncFd<Arc<OwnedFd>>),
68    Blocking(Arc<OwnedFd>),
69}
70
71impl AsRawFd for FdType {
72    fn as_raw_fd(&self) -> RawFd {
73        match self {
74            FdType::Async(async_fd) => async_fd.as_raw_fd(),
75            FdType::Blocking(blocking) => blocking.as_raw_fd(),
76        }
77    }
78}
79
80impl From<Error> for AsyncError {
81    fn from(e: Error) -> AsyncError {
82        AsyncError::SysVariants(e.into())
83    }
84}
85
86fn do_fdatasync(raw: Arc<OwnedFd>) -> io::Result<()> {
87    let fd = raw.as_raw_fd();
88    // SAFETY: we partially own `raw`
89    match unsafe { libc::fdatasync(fd) } {
90        0 => Ok(()),
91        _ => Err(io::Error::last_os_error()),
92    }
93}
94
95fn do_fsync(raw: Arc<OwnedFd>) -> io::Result<()> {
96    let fd = raw.as_raw_fd();
97    // SAFETY: we partially own `raw`
98    match unsafe { libc::fsync(fd) } {
99        0 => Ok(()),
100        _ => Err(io::Error::last_os_error()),
101    }
102}
103
104fn do_read_vectored(
105    raw: Arc<OwnedFd>,
106    file_offset: Option<u64>,
107    io_vecs: &[VolatileSlice],
108) -> io::Result<usize> {
109    let ptr = io_vecs.as_ptr() as *const libc::iovec;
110    let len = io_vecs.len() as i32;
111    let fd = raw.as_raw_fd();
112    let res = match file_offset {
113        // SAFETY: we partially own `raw`, `io_vecs` is validated
114        Some(off) => unsafe { libc::preadv64(fd, ptr, len, off as libc::off64_t) },
115        // SAFETY: we partially own `raw`, `io_vecs` is validated
116        None => unsafe { libc::readv(fd, ptr, len) },
117    };
118    match res {
119        r if r >= 0 => Ok(res as usize),
120        _ => Err(io::Error::last_os_error()),
121    }
122}
123fn do_read(raw: Arc<OwnedFd>, file_offset: Option<u64>, buf: &mut [u8]) -> io::Result<usize> {
124    let fd = raw.as_raw_fd();
125    let ptr = buf.as_mut_ptr() as *mut libc::c_void;
126    let res = match file_offset {
127        // SAFETY: we partially own `raw`, `ptr` has space up to vec.len()
128        Some(off) => unsafe { libc::pread64(fd, ptr, buf.len(), off as libc::off64_t) },
129        // SAFETY: we partially own `raw`, `ptr` has space up to vec.len()
130        None => unsafe { libc::read(fd, ptr, buf.len()) },
131    };
132    match res {
133        r if r >= 0 => Ok(res as usize),
134        _ => Err(io::Error::last_os_error()),
135    }
136}
137
138fn do_write(raw: Arc<OwnedFd>, file_offset: Option<u64>, buf: &[u8]) -> io::Result<usize> {
139    let fd = raw.as_raw_fd();
140    let ptr = buf.as_ptr() as *const libc::c_void;
141    let res = match file_offset {
142        // SAFETY: we partially own `raw`, `ptr` has data up to vec.len()
143        Some(off) => unsafe { libc::pwrite64(fd, ptr, buf.len(), off as libc::off64_t) },
144        // SAFETY: we partially own `raw`, `ptr` has data up to vec.len()
145        None => unsafe { libc::write(fd, ptr, buf.len()) },
146    };
147    match res {
148        r if r >= 0 => Ok(res as usize),
149        _ => Err(io::Error::last_os_error()),
150    }
151}
152
153fn do_write_vectored(
154    raw: Arc<OwnedFd>,
155    file_offset: Option<u64>,
156    io_vecs: &[VolatileSlice],
157) -> io::Result<usize> {
158    let ptr = io_vecs.as_ptr() as *const libc::iovec;
159    let len = io_vecs.len() as i32;
160    let fd = raw.as_raw_fd();
161    let res = match file_offset {
162        // SAFETY: we partially own `raw`, `io_vecs` is validated
163        Some(off) => unsafe { libc::pwritev64(fd, ptr, len, off as libc::off64_t) },
164        // SAFETY: we partially own `raw`, `io_vecs` is validated
165        None => unsafe { libc::writev(fd, ptr, len) },
166    };
167    match res {
168        r if r >= 0 => Ok(res as usize),
169        _ => Err(io::Error::last_os_error()),
170    }
171}
172
173pub struct TokioSource<T> {
174    fd: FdType,
175    inner: T,
176    runtime: tokio::runtime::Handle,
177}
178impl<T: AsRawDescriptor> TokioSource<T> {
179    pub fn new(inner: T, runtime: tokio::runtime::Handle) -> Result<TokioSource<T>, Error> {
180        let _guard = runtime.enter(); // Required for AsyncFd
181        let safe_fd = clone_descriptor(&inner).map_err(Error::DuplicatingFd)?;
182        let fd_arc: Arc<OwnedFd> = Arc::new(safe_fd.into());
183        let fd = match AsyncFd::new(fd_arc.clone()) {
184            Ok(async_fd) => {
185                add_fd_flags(async_fd.get_ref().as_raw_descriptor(), libc::O_NONBLOCK)
186                    .map_err(Error::SettingNonBlocking)?;
187                FdType::Async(async_fd)
188            }
189            Err(e) if e.kind() == io::ErrorKind::PermissionDenied => FdType::Blocking(fd_arc),
190            Err(e) => return Err(Error::TokioAsyncFd(e)),
191        };
192        Ok(TokioSource { fd, inner, runtime })
193    }
194
195    pub fn as_source(&self) -> &T {
196        &self.inner
197    }
198
199    pub fn as_source_mut(&mut self) -> &mut T {
200        &mut self.inner
201    }
202
203    fn clone_fd(&self) -> Arc<OwnedFd> {
204        match &self.fd {
205            FdType::Async(async_fd) => async_fd.get_ref().clone(),
206            FdType::Blocking(blocking) => blocking.clone(),
207        }
208    }
209
210    pub async fn fdatasync(&self) -> AsyncResult<()> {
211        let fd = self.clone_fd();
212        Ok(self
213            .runtime
214            .spawn_blocking(move || do_fdatasync(fd))
215            .await
216            .map_err(Error::Join)?
217            .map_err(Error::Fdatasync)?)
218    }
219
220    pub async fn fsync(&self) -> AsyncResult<()> {
221        let fd = self.clone_fd();
222        Ok(self
223            .runtime
224            .spawn_blocking(move || do_fsync(fd))
225            .await
226            .map_err(Error::Join)?
227            .map_err(Error::Fsync)?)
228    }
229
230    pub fn into_source(self) -> T {
231        self.inner
232    }
233
234    pub async fn read_to_vec(
235        &self,
236        file_offset: Option<u64>,
237        mut vec: Vec<u8>,
238    ) -> AsyncResult<(usize, Vec<u8>)> {
239        Ok(match &self.fd {
240            FdType::Async(async_fd) => {
241                let res = async_fd
242                    .async_io(tokio::io::Interest::READABLE, |fd| {
243                        do_read(fd.clone(), file_offset, &mut vec)
244                    })
245                    .await
246                    .map_err(AsyncError::Io)?;
247                (res, vec)
248            }
249            FdType::Blocking(blocking) => {
250                let fd = blocking.clone();
251                self.runtime
252                    .spawn_blocking(move || {
253                        let size = do_read(fd, file_offset, &mut vec)?;
254                        Ok((size, vec))
255                    })
256                    .await
257                    .map_err(Error::Join)?
258                    .map_err(Error::Read)?
259            }
260        })
261    }
262
263    pub async fn read_to_mem(
264        &self,
265        file_offset: Option<u64>,
266        mem: Arc<dyn BackingMemory + Send + Sync>,
267        mem_offsets: impl IntoIterator<Item = MemRegion>,
268    ) -> AsyncResult<usize> {
269        let mem_offsets_vec: Vec<MemRegion> = mem_offsets.into_iter().collect();
270        Ok(match &self.fd {
271            FdType::Async(async_fd) => {
272                let iovecs = mem_offsets_vec
273                    .into_iter()
274                    .filter_map(|mem_range| mem.get_volatile_slice(mem_range).ok())
275                    .collect::<Vec<VolatileSlice>>();
276                async_fd
277                    .async_io(tokio::io::Interest::READABLE, |fd| {
278                        do_read_vectored(fd.clone(), file_offset, &iovecs)
279                    })
280                    .await
281                    .map_err(AsyncError::Io)?
282            }
283            FdType::Blocking(blocking) => {
284                let fd = blocking.clone();
285                self.runtime
286                    .spawn_blocking(move || {
287                        let iovecs = mem_offsets_vec
288                            .into_iter()
289                            .filter_map(|mem_range| mem.get_volatile_slice(mem_range).ok())
290                            .collect::<Vec<VolatileSlice>>();
291                        do_read_vectored(fd, file_offset, &iovecs)
292                    })
293                    .await
294                    .map_err(Error::Join)?
295                    .map_err(Error::Read)?
296            }
297        })
298    }
299
300    pub async fn punch_hole(&self, file_offset: u64, len: u64) -> AsyncResult<()> {
301        let fd = self.clone_fd();
302        Ok(self
303            .runtime
304            .spawn_blocking(move || fallocate(&*fd, FallocateMode::PunchHole, file_offset, len))
305            .await
306            .map_err(Error::Join)?
307            .map_err(Error::Fallocate)?)
308    }
309
310    pub async fn wait_readable(&self) -> AsyncResult<()> {
311        match &self.fd {
312            FdType::Async(async_fd) => async_fd
313                .readable()
314                .await
315                .map_err(crate::AsyncError::Io)?
316                .retain_ready(),
317            FdType::Blocking(_) => return Err(Error::NonWaitable.into()),
318        }
319        Ok(())
320    }
321
322    pub async fn write_from_mem(
323        &self,
324        file_offset: Option<u64>,
325        mem: Arc<dyn BackingMemory + Send + Sync>,
326        mem_offsets: impl IntoIterator<Item = MemRegion>,
327    ) -> AsyncResult<usize> {
328        let mem_offsets_vec: Vec<MemRegion> = mem_offsets.into_iter().collect();
329        Ok(match &self.fd {
330            FdType::Async(async_fd) => {
331                let iovecs = mem_offsets_vec
332                    .into_iter()
333                    .filter_map(|mem_range| mem.get_volatile_slice(mem_range).ok())
334                    .collect::<Vec<VolatileSlice>>();
335                async_fd
336                    .async_io(tokio::io::Interest::WRITABLE, |fd| {
337                        do_write_vectored(fd.clone(), file_offset, &iovecs)
338                    })
339                    .await
340                    .map_err(AsyncError::Io)?
341            }
342            FdType::Blocking(blocking) => {
343                let fd = blocking.clone();
344                self.runtime
345                    .spawn_blocking(move || {
346                        let iovecs = mem_offsets_vec
347                            .into_iter()
348                            .filter_map(|mem_range| mem.get_volatile_slice(mem_range).ok())
349                            .collect::<Vec<VolatileSlice>>();
350                        do_write_vectored(fd, file_offset, &iovecs)
351                    })
352                    .await
353                    .map_err(Error::Join)?
354                    .map_err(Error::Read)?
355            }
356        })
357    }
358
359    pub async fn write_from_vec(
360        &self,
361        file_offset: Option<u64>,
362        vec: Vec<u8>,
363    ) -> AsyncResult<(usize, Vec<u8>)> {
364        Ok(match &self.fd {
365            FdType::Async(async_fd) => {
366                let res = async_fd
367                    .async_io(tokio::io::Interest::WRITABLE, |fd| {
368                        do_write(fd.clone(), file_offset, &vec)
369                    })
370                    .await
371                    .map_err(AsyncError::Io)?;
372                (res, vec)
373            }
374            FdType::Blocking(blocking) => {
375                let fd = blocking.clone();
376                self.runtime
377                    .spawn_blocking(move || {
378                        let size = do_write(fd.clone(), file_offset, &vec)?;
379                        Ok((size, vec))
380                    })
381                    .await
382                    .map_err(Error::Join)?
383                    .map_err(Error::Read)?
384            }
385        })
386    }
387
388    pub async fn write_zeroes_at(&self, file_offset: u64, len: u64) -> AsyncResult<()> {
389        let fd = self.clone_fd();
390        Ok(self
391            .runtime
392            .spawn_blocking(move || fallocate(&*fd, FallocateMode::ZeroRange, file_offset, len))
393            .await
394            .map_err(Error::Join)?
395            .map_err(Error::Fallocate)?)
396    }
397}