cros_async/
io_source.rs

1// Copyright 2023 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use std::sync::Arc;
6
7use base::AsRawDescriptor;
8
9#[cfg(any(target_os = "android", target_os = "linux"))]
10use crate::sys::linux::PollSource;
11#[cfg(any(target_os = "android", target_os = "linux"))]
12use crate::sys::linux::UringSource;
13#[cfg(feature = "tokio")]
14use crate::sys::platform::tokio_source::TokioSource;
15#[cfg(windows)]
16use crate::sys::windows::HandleSource;
17#[cfg(windows)]
18use crate::sys::windows::OverlappedSource;
19use crate::AsyncResult;
20use crate::BackingMemory;
21use crate::MemRegion;
22
23/// Associates an IO object `F` with cros_async's runtime and exposes an API to perform async IO on
24/// that object's descriptor.
25pub enum IoSource<F: base::AsRawDescriptor> {
26    #[cfg(any(target_os = "android", target_os = "linux"))]
27    Uring(UringSource<F>),
28    #[cfg(any(target_os = "android", target_os = "linux"))]
29    Epoll(PollSource<F>),
30    #[cfg(windows)]
31    Handle(HandleSource<F>),
32    #[cfg(windows)]
33    Overlapped(OverlappedSource<F>),
34    #[cfg(feature = "tokio")]
35    Tokio(TokioSource<F>),
36}
37
38static_assertions::assert_impl_all!(IoSource<std::fs::File>: Send, Sync);
39
/// Forwards an async method call to the backend source held by an `IoSource` and awaits it.
///
/// `await_on_inner!(io_source, method, a, b)` expands to a `match` over the enabled variants,
/// where each arm evaluates `Backend::method(inner, a, b).await` with `inner` bound to that
/// variant's payload.
macro_rules! await_on_inner {
    ($source:ident, $func:ident $(, $arg:expr)*) => {
        match $source {
            #[cfg(any(target_os = "android", target_os = "linux"))]
            IoSource::Uring(inner) => UringSource::$func(inner, $($arg),*).await,
            #[cfg(any(target_os = "android", target_os = "linux"))]
            IoSource::Epoll(inner) => PollSource::$func(inner, $($arg),*).await,
            #[cfg(windows)]
            IoSource::Handle(inner) => HandleSource::$func(inner, $($arg),*).await,
            #[cfg(windows)]
            IoSource::Overlapped(inner) => OverlappedSource::$func(inner, $($arg),*).await,
            #[cfg(feature = "tokio")]
            IoSource::Tokio(inner) => TokioSource::$func(inner, $($arg),*).await,
        }
    };
}
59
/// Forwards a synchronous method call to the backend source held by an `IoSource`.
///
/// `on_inner!(io_source, method, a, b)` expands to a `match` over the enabled variants, where
/// each arm evaluates `Backend::method(inner, a, b)` with `inner` bound to that variant's
/// payload. The payload is bound by value or by reference depending on how `io_source` is
/// passed.
macro_rules! on_inner {
    ($source:ident, $func:ident $(, $arg:expr)*) => {
        match $source {
            #[cfg(any(target_os = "android", target_os = "linux"))]
            IoSource::Uring(inner) => UringSource::$func(inner, $($arg),*),
            #[cfg(any(target_os = "android", target_os = "linux"))]
            IoSource::Epoll(inner) => PollSource::$func(inner, $($arg),*),
            #[cfg(windows)]
            IoSource::Handle(inner) => HandleSource::$func(inner, $($arg),*),
            #[cfg(windows)]
            IoSource::Overlapped(inner) => OverlappedSource::$func(inner, $($arg),*),
            #[cfg(feature = "tokio")]
            IoSource::Tokio(inner) => TokioSource::$func(inner, $($arg),*),
        }
    };
}
79
80impl<F: AsRawDescriptor> IoSource<F> {
81    /// Reads at `file_offset` and fills the given `vec`.
82    pub async fn read_to_vec(
83        &self,
84        file_offset: Option<u64>,
85        vec: Vec<u8>,
86    ) -> AsyncResult<(usize, Vec<u8>)> {
87        await_on_inner!(self, read_to_vec, file_offset, vec)
88    }
89
90    /// Reads to the given `mem` at the given offsets from the file starting at `file_offset`.
91    pub async fn read_to_mem(
92        &self,
93        file_offset: Option<u64>,
94        mem: Arc<dyn BackingMemory + Send + Sync>,
95        mem_offsets: impl IntoIterator<Item = MemRegion>,
96    ) -> AsyncResult<usize> {
97        await_on_inner!(self, read_to_mem, file_offset, mem, mem_offsets)
98    }
99
100    /// Waits for the object to be readable.
101    pub async fn wait_readable(&self) -> AsyncResult<()> {
102        await_on_inner!(self, wait_readable)
103    }
104
105    /// Writes from the given `vec` to the file starting at `file_offset`.
106    pub async fn write_from_vec(
107        &self,
108        file_offset: Option<u64>,
109        vec: Vec<u8>,
110    ) -> AsyncResult<(usize, Vec<u8>)> {
111        await_on_inner!(self, write_from_vec, file_offset, vec)
112    }
113
114    /// Writes from the given `mem` at the given offsets to the file starting at `file_offset`.
115    pub async fn write_from_mem(
116        &self,
117        file_offset: Option<u64>,
118        mem: Arc<dyn BackingMemory + Send + Sync>,
119        mem_offsets: impl IntoIterator<Item = MemRegion>,
120    ) -> AsyncResult<usize> {
121        await_on_inner!(self, write_from_mem, file_offset, mem, mem_offsets)
122    }
123
124    /// Deallocates the given range of a file.
125    pub async fn punch_hole(&self, file_offset: u64, len: u64) -> AsyncResult<()> {
126        await_on_inner!(self, punch_hole, file_offset, len)
127    }
128
129    /// Fills the given range with zeroes.
130    pub async fn write_zeroes_at(&self, file_offset: u64, len: u64) -> AsyncResult<()> {
131        await_on_inner!(self, write_zeroes_at, file_offset, len)
132    }
133
134    /// Sync all completed write operations to the backing storage.
135    pub async fn fsync(&self) -> AsyncResult<()> {
136        await_on_inner!(self, fsync)
137    }
138
139    /// Sync all data of completed write operations to the backing storage, avoiding updating extra
140    /// metadata. Note that an implementation may simply implement fsync for fdatasync.
141    pub async fn fdatasync(&self) -> AsyncResult<()> {
142        await_on_inner!(self, fdatasync)
143    }
144
145    /// Yields the underlying IO source.
146    pub fn into_source(self) -> F {
147        on_inner!(self, into_source)
148    }
149
150    /// Provides a ref to the underlying IO source.
151    pub fn as_source(&self) -> &F {
152        on_inner!(self, as_source)
153    }
154
155    /// Provides a mutable ref to the underlying IO source.
156    pub fn as_source_mut(&mut self) -> &mut F {
157        on_inner!(self, as_source_mut)
158    }
159
160    /// Waits on a waitable handle.
161    ///
162    /// Needed for Windows currently, and subject to a potential future upstream.
163    #[cfg(windows)]
164    pub async fn wait_for_handle(&self) -> AsyncResult<()> {
165        await_on_inner!(self, wait_for_handle)
166    }
167}
168
#[cfg(test)]
mod tests {
    use std::fs::File;
    use std::io::Read;
    use std::io::Seek;
    use std::io::SeekFrom;
    use std::io::Write;
    use std::sync::Arc;

    use tempfile::tempfile;

    use super::*;
    use crate::mem::VecIoWrapper;
    #[cfg(any(target_os = "android", target_os = "linux"))]
    use crate::sys::linux::uring_executor::is_uring_stable;
    use crate::sys::ExecutorKindSys;
    use crate::Executor;
    use crate::ExecutorKind;
    use crate::MemRegion;

    /// Executor kinds exercised on this host: always the fd executor, plus uring when
    /// `is_uring_stable()` reports true.
    #[cfg(any(target_os = "android", target_os = "linux"))]
    fn all_kinds() -> Vec<ExecutorKind> {
        let fd_kind: ExecutorKind = ExecutorKindSys::Fd.into();
        if is_uring_stable() {
            vec![fd_kind, ExecutorKindSys::Uring.into()]
        } else {
            vec![fd_kind]
        }
    }

    /// Executor kinds exercised on Windows.
    #[cfg(windows)]
    fn all_kinds() -> Vec<ExecutorKind> {
        // TODO: Test OverlappedSource. It requires files to be opened specially, so this test
        // fixture needs to be refactored first.
        vec![ExecutorKindSys::Handle.into()]
    }

    /// Creates a temporary file holding `bytes`, with its position rewound to the start.
    fn tmpfile_with_contents(bytes: &[u8]) -> File {
        let mut file = tempfile().unwrap();
        file.write_all(bytes).unwrap();
        file.flush().unwrap();
        file.seek(SeekFrom::Start(0)).unwrap();
        file
    }

    #[test]
    fn readvec() {
        async fn exercise<F: AsRawDescriptor>(source: IoSource<F>) {
            let buf = vec![0x55u8; 32];
            let buf_ptr = buf.as_ptr();
            let (count, buf) = source.read_to_vec(None, buf).await.unwrap();
            // The very same allocation must come back to the caller.
            assert_eq!(buf_ptr, buf.as_ptr());
            assert_eq!(count, 4);
            assert_eq!(&buf[..4], "data".as_bytes());
        }

        for kind in all_kinds() {
            let file = tmpfile_with_contents("data".as_bytes());
            let ex = Executor::with_executor_kind(kind).unwrap();
            let source = ex.async_from(file).unwrap();
            ex.run_until(exercise(source)).unwrap();
        }
    }

    #[test]
    fn writevec() {
        async fn exercise<F: AsRawDescriptor>(source: IoSource<F>) {
            let buf = "data".as_bytes().to_vec();
            let buf_ptr = buf.as_ptr();
            let (count, buf) = source.write_from_vec(None, buf).await.unwrap();
            assert_eq!(count, 4);
            assert_eq!(buf_ptr, buf.as_ptr());
        }

        for kind in all_kinds() {
            let mut file = tmpfile_with_contents(&[]);
            let ex = Executor::with_executor_kind(kind).unwrap();
            // Hand the executor a duplicate so the original can be read back afterwards.
            let source = ex.async_from(file.try_clone().unwrap()).unwrap();
            ex.run_until(exercise(source)).unwrap();

            file.rewind().unwrap();
            assert_eq!(std::io::read_to_string(file).unwrap(), "data");
        }
    }

    #[test]
    fn readmem() {
        async fn exercise<F: AsRawDescriptor>(source: IoSource<F>) {
            let mem = Arc::new(VecIoWrapper::from(vec![b' '; 10]));
            let regions = [
                MemRegion { offset: 0, len: 2 },
                MemRegion { offset: 4, len: 1 },
            ];
            let count = source
                .read_to_mem(None, Arc::<VecIoWrapper>::clone(&mem), regions)
                .await
                .unwrap();
            assert_eq!(count, 3);
            let Ok(wrapper) = Arc::try_unwrap(mem) else {
                panic!("Too many vec refs");
            };
            let contents: Vec<u8> = wrapper.into();
            assert_eq!(std::str::from_utf8(&contents).unwrap(), "da  t     ");
        }

        for kind in all_kinds() {
            let file = tmpfile_with_contents("data".as_bytes());
            let ex = Executor::with_executor_kind(kind).unwrap();
            let source = ex.async_from(file).unwrap();
            ex.run_until(exercise(source)).unwrap();
        }
    }

    #[test]
    fn writemem() {
        async fn exercise<F: AsRawDescriptor>(source: IoSource<F>) {
            let mem = Arc::new(VecIoWrapper::from("data".as_bytes().to_vec()));
            let regions = [
                MemRegion { offset: 0, len: 1 },
                MemRegion { offset: 2, len: 2 },
            ];
            let written = source
                .write_from_mem(None, Arc::<VecIoWrapper>::clone(&mem), regions)
                .await
                .unwrap();
            assert_eq!(written, 3);
        }

        for kind in all_kinds() {
            let mut file = tmpfile_with_contents(&[]);
            let ex = Executor::with_executor_kind(kind).unwrap();
            let source = ex.async_from(file.try_clone().unwrap()).unwrap();
            ex.run_until(exercise(source)).unwrap();

            file.rewind().unwrap();
            assert_eq!(std::io::read_to_string(file).unwrap(), "dta");
        }
    }

    #[test]
    fn fsync() {
        async fn exercise<F: AsRawDescriptor>(source: IoSource<F>) {
            let buf = vec![0x55u8; 32];
            let buf_ptr = buf.as_ptr();
            let (written, returned) = source.write_from_vec(None, buf).await.unwrap();
            assert_eq!(written, 32);
            assert_eq!(buf_ptr, returned.as_ptr());
            source.fsync().await.unwrap();
        }

        for kind in all_kinds() {
            let file = tempfile().unwrap();
            let ex = Executor::with_executor_kind(kind).unwrap();
            let source = ex.async_from(file).unwrap();

            ex.run_until(exercise(source)).unwrap();
        }
    }

    #[test]
    fn readmulti() {
        async fn exercise<F: AsRawDescriptor>(source: IoSource<F>) {
            let first = source.read_to_vec(None, vec![0x55u8; 32]);
            let second = source.read_to_vec(Some(32), vec![0x55u8; 32]);
            let (first_res, second_res) = futures::future::join(first, second).await;

            let (count1, buf1) = first_res.unwrap();
            let (count2, buf2) = second_res.unwrap();

            assert!(buf1.iter().take(count1).all(|&b| b == 0xAA));
            assert!(buf2.iter().take(count2).all(|&b| b == 0xBB));
        }

        for kind in all_kinds() {
            let mut file = tempfile().unwrap();
            file.write_all(&[0xAA; 32]).unwrap();
            file.write_all(&[0xBB; 32]).unwrap();
            file.rewind().unwrap();

            let ex = Executor::with_executor_kind(kind).unwrap();
            let source = ex.async_from(file).unwrap();

            ex.run_until(exercise(source)).unwrap();
        }
    }

    #[test]
    fn writemulti() {
        async fn exercise<F: AsRawDescriptor>(source: IoSource<F>) {
            let first = source.write_from_vec(None, vec![0x55u8; 32]);
            let second = source.write_from_vec(Some(32), vec![0x55u8; 32]);
            let (first_res, second_res) = futures::future::join(first, second).await;
            assert_eq!(32, first_res.unwrap().0);
            assert_eq!(32, second_res.unwrap().0);
        }

        for kind in all_kinds() {
            let file = tempfile().unwrap();
            let ex = Executor::with_executor_kind(kind).unwrap();
            let source = ex.async_from(file).unwrap();

            ex.run_until(exercise(source)).unwrap();
        }
    }

    #[test]
    fn read_current_file_position() {
        async fn exercise<F: AsRawDescriptor>(source: IoSource<F>) {
            // Two reads without an explicit offset must consume consecutive chunks.
            let (count1, chunk1) = source.read_to_vec(None, vec![0u8; 32]).await.unwrap();
            let (count2, chunk2) = source.read_to_vec(None, vec![0u8; 32]).await.unwrap();
            assert_eq!(count1, 32);
            assert_eq!(count2, 32);
            assert_eq!(chunk1, [0x55u8; 32]);
            assert_eq!(chunk2, [0xffu8; 32]);
        }

        for kind in all_kinds() {
            let mut file = tempfile().unwrap();
            file.write_all(&[0x55u8; 32]).unwrap();
            file.write_all(&[0xffu8; 32]).unwrap();
            file.rewind().unwrap();

            let ex = Executor::with_executor_kind(kind).unwrap();
            let source = ex.async_from(file).unwrap();

            ex.run_until(exercise(source)).unwrap();
        }
    }

    #[test]
    fn write_current_file_position() {
        async fn exercise<F: AsRawDescriptor>(source: IoSource<F>) {
            // Two writes without an explicit offset must land back to back.
            for fill in [0x55u8, 0xffu8] {
                let (written, _) = source.write_from_vec(None, vec![fill; 32]).await.unwrap();
                assert_eq!(written, 32);
            }
        }

        for kind in all_kinds() {
            let mut file = tempfile().unwrap();
            let ex = Executor::with_executor_kind(kind).unwrap();
            let source = ex.async_from(file.try_clone().unwrap()).unwrap();

            ex.run_until(exercise(source)).unwrap();

            file.rewind().unwrap();
            let mut chunk1 = [0u8; 32];
            let mut chunk2 = [0u8; 32];
            file.read_exact(&mut chunk1).unwrap();
            file.read_exact(&mut chunk2).unwrap();
            assert_eq!(chunk1, [0x55u8; 32]);
            assert_eq!(chunk2, [0xffu8; 32]);
        }
    }
}