cros_async/blocking/sys/linux/
block_on.rs

1// Copyright 2020 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use std::future::Future;
6use std::ptr;
7use std::sync::atomic::AtomicI32;
8use std::sync::atomic::Ordering;
9use std::sync::Arc;
10use std::task::Context;
11use std::task::Poll;
12
13use futures::pin_mut;
14use futures::task::waker_ref;
15use futures::task::ArcWake;
16
17// Randomly generated values to indicate the state of the current thread.
18const WAITING: i32 = 0x25de_74d1;
19const WOKEN: i32 = 0x72d3_2c9f;
20
21const FUTEX_WAIT_PRIVATE: libc::c_int = libc::FUTEX_WAIT | libc::FUTEX_PRIVATE_FLAG;
22const FUTEX_WAKE_PRIVATE: libc::c_int = libc::FUTEX_WAKE | libc::FUTEX_PRIVATE_FLAG;
23
24thread_local!(static PER_THREAD_WAKER: Arc<Waker> = Arc::new(Waker(AtomicI32::new(WAITING))));
25
/// Wake-up state for one thread.
///
/// Wraps the `AtomicI32` that holds either `WAITING` or `WOKEN`; a reference to that same
/// integer is what gets passed to the futex syscalls.
#[repr(transparent)]
struct Waker(AtomicI32);
28
29impl ArcWake for Waker {
30    fn wake_by_ref(arc_self: &Arc<Self>) {
31        let state = arc_self.0.swap(WOKEN, Ordering::Release);
32        if state == WAITING {
33            // SAFETY:
34            // The thread hasn't already been woken up so wake it up now. Safe because this doesn't
35            // modify any memory and we check the return value.
36            let res = unsafe {
37                libc::syscall(
38                    libc::SYS_futex,
39                    &arc_self.0,
40                    FUTEX_WAKE_PRIVATE,
41                    libc::INT_MAX,                        // val
42                    ptr::null::<*const libc::timespec>(), // timeout
43                    ptr::null::<*const libc::c_int>(),    // uaddr2
44                    0_i32,                                // val3
45                )
46            };
47            if res < 0 {
48                panic!(
49                    "unexpected error from FUTEX_WAKE_PRIVATE: {}",
50                    std::io::Error::last_os_error()
51                );
52            }
53        }
54    }
55}
56
57/// Run a future to completion on the current thread.
58///
59/// This method will block the current thread until `f` completes. Useful when you need to call an
60/// async fn from a non-async context.
61pub fn block_on<F: Future>(f: F) -> F::Output {
62    pin_mut!(f);
63
64    PER_THREAD_WAKER.with(|thread_waker| {
65        let waker = waker_ref(thread_waker);
66        let mut cx = Context::from_waker(&waker);
67
68        loop {
69            if let Poll::Ready(t) = f.as_mut().poll(&mut cx) {
70                return t;
71            }
72
73            let state = thread_waker.0.swap(WAITING, Ordering::Acquire);
74            if state == WAITING {
75                // SAFETY:
76                // If we weren't already woken up then wait until we are. Safe because this doesn't
77                // modify any memory and we check the return value.
78                let res = unsafe {
79                    libc::syscall(
80                        libc::SYS_futex,
81                        &thread_waker.0,
82                        FUTEX_WAIT_PRIVATE,
83                        state,
84                        ptr::null::<*const libc::timespec>(), // timeout
85                        ptr::null::<*const libc::c_int>(),    // uaddr2
86                        0_i32,                                // val3
87                    )
88                };
89
90                if res < 0 {
91                    let e = std::io::Error::last_os_error();
92                    match e.raw_os_error() {
93                        Some(libc::EAGAIN) | Some(libc::EINTR) => {}
94                        _ => panic!("unexpected error from FUTEX_WAIT_PRIVATE: {e}"),
95                    }
96                }
97
98                // Clear the state to prevent unnecessary extra loop iterations and also to allow
99                // nested usage of `block_on`.
100                thread_waker.0.store(WAITING, Ordering::Release);
101            }
102        }
103    })
104}
105
#[cfg(test)]
mod test {
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::mpsc::channel;
    use std::sync::mpsc::Sender;
    use std::sync::Arc;
    use std::task::Context;
    use std::task::Poll;
    use std::task::Waker;
    use std::thread;
    use std::time::Duration;

    use super::*;
    use crate::sync::SpinLock;

    /// State shared between a `Timer` and the background thread that fires it.
    struct TimerState {
        /// Set by the background thread once its sleep has elapsed.
        fired: bool,
        /// Waker from the most recent pending poll, if any.
        waker: Option<Waker>,
    }

    /// Future that resolves once its background thread has marked it as fired.
    struct Timer {
        state: Arc<SpinLock<TimerState>>,
    }

    impl Future for Timer {
        type Output = ();

        fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
            let mut shared = self.state.lock();
            if !shared.fired {
                // Not fired yet: remember the waker so the background thread can wake us.
                shared.waker = Some(cx.waker().clone());
                Poll::Pending
            } else {
                Poll::Ready(())
            }
        }
    }

    /// Spawns a thread that sleeps for `dur`, marks the returned `Timer` as fired, wakes the
    /// stored waker (if any) and finally sends on `notify` when one was supplied.
    fn start_timer(dur: Duration, notify: Option<Sender<()>>) -> Timer {
        let timer = Timer {
            state: Arc::new(SpinLock::new(TimerState {
                fired: false,
                waker: None,
            })),
        };

        let shared = Arc::clone(&timer.state);
        thread::spawn(move || {
            thread::sleep(dur);

            let mut guard = shared.lock();
            guard.fired = true;
            if let Some(waker) = guard.waker.take() {
                waker.wake();
            }
            // Release the lock before notifying so the receiver can poll the timer freely.
            drop(guard);

            if let Some(tx) = notify {
                tx.send(()).expect("Failed to send completion notification");
            }
        });

        timer
    }

    /// A timer that fires while `block_on` is waiting completes the call.
    #[test]
    fn it_works() {
        let timer = start_timer(Duration::from_millis(100), None);
        block_on(timer);
    }

    /// `block_on` may be called from inside a future that is itself driven by `block_on`.
    #[test]
    fn nested() {
        async fn inner() {
            let timer = start_timer(Duration::from_millis(100), None);
            block_on(timer);
        }

        block_on(inner());
    }

    /// A future that is already complete before its first poll returns without waiting.
    #[test]
    fn ready_before_poll() {
        let (tx, rx) = channel();

        let timer = start_timer(Duration::from_millis(50), Some(tx));

        rx.recv()
            .expect("Failed to receive completion notification");

        // We know the timer has already fired so the poll should complete immediately.
        block_on(timer);
    }
}