cros_async/blocking/
cancellable_pool.rs

1// Copyright 2022 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Provides an async blocking pool whose tasks can be cancelled.
6
7use std::collections::HashMap;
8use std::future::Future;
9use std::sync::Arc;
10use std::sync::LazyLock;
11use std::time::Duration;
12use std::time::Instant;
13
14use sync::Condvar;
15use sync::Mutex;
16use thiserror::Error as ThisError;
17
18use crate::BlockingPool;
19
20/// Global executor.
21///
22/// This is convenient, though not preferred. Pros/cons:
23/// + It avoids passing executor all the way to each call sites.
24/// + The call site can assume that executor will never shutdown.
25/// + Provides similar functionality as async_task with a few improvements around ability to cancel.
26/// - Globals are harder to reason about.
27static EXECUTOR: LazyLock<CancellableBlockingPool> =
28    LazyLock::new(|| CancellableBlockingPool::new(256, Duration::from_secs(10)));
29
30const DEFAULT_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5);
31
/// Wind-down phase of a pool, declared in the order a pool moves through them.
///
/// The derived `PartialOrd` follows declaration order and `disarm` compares against
/// `Disarmed` with `>=`, so the variants must not be reordered.
#[derive(Default, PartialEq, Eq, PartialOrd)]
enum WindDownStates {
    /// Initial state; neither `disarm` nor `shutdown` has run yet.
    #[default]
    Armed,
    /// `disarm` has run.
    Disarmed,
    /// A shutdown of the underlying pool is currently in progress.
    ShuttingDown,
    /// The shutdown attempt has finished, whether it succeeded or timed out.
    ShutDown,
}
40
41#[derive(Default)]
42struct State {
43    wind_down: WindDownStates,
44
45    /// Helps to generate unique id to associate `cancel` with task.
46    current_cancellable_id: u64,
47
48    /// A map of all the `cancel` routines of queued/in-flight tasks.
49    cancellables: HashMap<u64, Box<dyn Fn() + Send + 'static>>,
50}
51
/// What to do when an operation times out.
///
/// NOTE(review): nothing in this file consumes this type; presumably it is part of the public
/// API for callers elsewhere — confirm before removing.
#[derive(Clone, Copy, Debug)]
pub enum TimeoutAction {
    /// Take no action on timeout.
    None,
    /// Panic the thread on timeout.
    Panic,
}
59
60#[derive(ThisError, Debug, PartialEq, Eq)]
61pub enum Error {
62    #[error("Timeout occurred while trying to join threads")]
63    Timedout,
64    #[error("Shutdown is in progress")]
65    ShutdownInProgress,
66    #[error("Already shut down")]
67    AlreadyShutdown,
68}
69
70struct Inner {
71    blocking_pool: BlockingPool,
72    state: Mutex<State>,
73
74    /// This condvar gets notified when `cancellables` is empty after removing an
75    /// entry.
76    cancellables_cv: Condvar,
77}
78
79impl Inner {
80    pub fn spawn<F, R>(self: &Arc<Self>, f: F) -> impl Future<Output = R>
81    where
82        F: FnOnce() -> R + Send + 'static,
83        R: Send + 'static,
84    {
85        self.blocking_pool.spawn(f)
86    }
87
88    /// Adds cancel to a cancellables and returns an `id` with which `cancel` can be
89    /// accessed/removed.
90    fn add_cancellable(&self, cancel: Box<dyn Fn() + Send + 'static>) -> u64 {
91        let mut state = self.state.lock();
92        let id = state.current_cancellable_id;
93        state.current_cancellable_id += 1;
94        state.cancellables.insert(id, cancel);
95        id
96    }
97}
98
99/// A thread pool for running work that may block.
100///
101/// This is a wrapper around `BlockingPool` with an ability to cancel queued tasks.
102/// See [BlockingPool] for more info.
103///
104/// # Examples
105///
106/// Spawn a task to run in the `CancellableBlockingPool` and await on its result.
107///
108/// ```edition2018
109/// use cros_async::CancellableBlockingPool;
110///
111/// # async fn do_it() {
112///     let pool = CancellableBlockingPool::default();
113///     let CANCELLED = 0;
114///
115///     let res = pool.spawn(move || {
116///         // Do some CPU-intensive or blocking work here.
117///
118///         42
119///     }, move || CANCELLED).await;
120///
121///     assert_eq!(res, 42);
122/// # }
123/// # futures::executor::block_on(do_it());
124/// ```
125#[derive(Clone)]
126pub struct CancellableBlockingPool {
127    inner: Arc<Inner>,
128}
129
130impl CancellableBlockingPool {
131    const RETRY_COUNT: usize = 10;
132    const SLEEP_DURATION: Duration = Duration::from_millis(100);
133
134    /// Create a new `CancellableBlockingPool`.
135    ///
136    /// When we try to shutdown or drop `CancellableBlockingPool`, it may happen that a hung thread
137    /// might prevent `CancellableBlockingPool` pool from getting dropped. On failure to shutdown in
138    /// `watchdog_opts.timeout` duration, `CancellableBlockingPool` can take an action specified by
139    /// `watchdog_opts.action`.
140    ///
141    /// See also: [BlockingPool::new()](BlockingPool::new)
142    pub fn new(max_threads: usize, keepalive: Duration) -> CancellableBlockingPool {
143        CancellableBlockingPool {
144            inner: Arc::new(Inner {
145                blocking_pool: BlockingPool::new(max_threads, keepalive),
146                state: Default::default(),
147                cancellables_cv: Condvar::new(),
148            }),
149        }
150    }
151
152    /// Like [Self::new] but with pre-allocating capacity for up to `max_threads`.
153    pub fn with_capacity(max_threads: usize, keepalive: Duration) -> CancellableBlockingPool {
154        CancellableBlockingPool {
155            inner: Arc::new(Inner {
156                blocking_pool: BlockingPool::with_capacity(max_threads, keepalive),
157                state: Mutex::new(State::default()),
158                cancellables_cv: Condvar::new(),
159            }),
160        }
161    }
162
163    /// Spawn a task to run in the `CancellableBlockingPool`.
164    ///
165    /// Callers may `await` the returned `Task` to be notified when the work is completed.
166    /// Dropping the future will not cancel the task.
167    ///
168    /// `cancel` helps to cancel a queued or in-flight operation `f`.
169    /// `cancel` may be called more than once if `f` doesn't respond to `cancel`.
170    /// `cancel` is not called if `f` completes successfully. For example,
171    /// # Examples
172    ///
173    /// ```edition2018
174    /// use {cros_async::CancellableBlockingPool, std::sync::{Arc, Mutex, Condvar}};
175    ///
176    /// # async fn cancel_it() {
177    ///    let pool = CancellableBlockingPool::default();
178    ///    let cancelled: i32 = 1;
179    ///    let success: i32 = 2;
180    ///
181    ///    let shared = Arc::new((Mutex::new(0), Condvar::new()));
182    ///    let shared2 = shared.clone();
183    ///    let shared3 = shared.clone();
184    ///
185    ///    let res = pool
186    ///        .spawn(
187    ///            move || {
188    ///                let guard = shared.0.lock().unwrap();
189    ///                let mut guard = shared.1.wait_while(guard, |state| *state == 0).unwrap();
190    ///                if *guard != cancelled {
191    ///                    *guard = success;
192    ///                }
193    ///            },
194    ///            move || {
195    ///                *shared2.0.lock().unwrap() = cancelled;
196    ///                shared2.1.notify_all();
197    ///            },
198    ///        )
199    ///        .await;
200    ///    pool.shutdown();
201    ///
202    ///    assert_eq!(*shared3.0.lock().unwrap(), cancelled);
203    /// # }
204    /// ```
205    pub fn spawn<F, R, G>(&self, f: F, cancel: G) -> impl Future<Output = R>
206    where
207        F: FnOnce() -> R + Send + 'static,
208        R: Send + 'static,
209        G: Fn() -> R + Send + 'static,
210    {
211        let inner = self.inner.clone();
212        let cancelled = Arc::new(Mutex::new(None));
213        let cancelled_spawn = cancelled.clone();
214        let id = inner.add_cancellable(Box::new(move || {
215            let mut c = cancelled.lock();
216            *c = Some(cancel());
217        }));
218
219        self.inner.spawn(move || {
220            if let Some(res) = cancelled_spawn.lock().take() {
221                return res;
222            }
223            let ret = f();
224            let mut state = inner.state.lock();
225            state.cancellables.remove(&id);
226            if state.cancellables.is_empty() {
227                inner.cancellables_cv.notify_one();
228            }
229            ret
230        })
231    }
232
233    /// Iterates over all the queued tasks and marks them as cancelled.
234    fn drain_cancellables(&self) {
235        let mut state = self.inner.state.lock();
236        // Iterate a few times to try cancelling all the tasks.
237        for _ in 0..Self::RETRY_COUNT {
238            // Nothing left to do.
239            if state.cancellables.is_empty() {
240                return;
241            }
242
243            // We only cancel the task and do not remove it from the cancellables. It is runner's
244            // job to remove from state.cancellables.
245            for cancel in state.cancellables.values() {
246                cancel();
247            }
248            // Hold the state lock in a block before sleeping so that woken up threads can get to
249            // hold the lock.
250            // Wait for a while so that the threads get a chance complete task in flight.
251            let (state1, _cv_timeout) = self
252                .inner
253                .cancellables_cv
254                .wait_timeout(state, Self::SLEEP_DURATION);
255            state = state1;
256        }
257    }
258
259    /// Marks all the queued and in-flight tasks as cancelled. Any tasks queued after `disarm`ing
260    /// will be cancelled.
261    /// Does not wait for all the tasks to get cancelled.
262    pub fn disarm(&self) {
263        {
264            let mut state = self.inner.state.lock();
265
266            if state.wind_down >= WindDownStates::Disarmed {
267                return;
268            }
269
270            // At this point any new incoming request will be cancelled when run.
271            state.wind_down = WindDownStates::Disarmed;
272        }
273        self.drain_cancellables();
274    }
275
276    /// Shut down the `CancellableBlockingPool`.
277    ///
278    /// This will block until all work that has been started by the worker threads is finished. Any
279    /// work that was added to the `CancellableBlockingPool` but not yet picked up by a worker
280    /// thread will not complete and `await`ing on the `Task` for that work will panic.
281    pub fn shutdown(&self) -> Result<(), Error> {
282        self.shutdown_with_timeout(DEFAULT_SHUTDOWN_TIMEOUT)
283    }
284
285    fn shutdown_with_timeout(&self, timeout: Duration) -> Result<(), Error> {
286        self.disarm();
287        {
288            let mut state = self.inner.state.lock();
289            if state.wind_down == WindDownStates::ShuttingDown {
290                return Err(Error::ShutdownInProgress);
291            }
292            if state.wind_down == WindDownStates::ShutDown {
293                return Err(Error::AlreadyShutdown);
294            }
295            state.wind_down = WindDownStates::ShuttingDown;
296        }
297
298        let res = self
299            .inner
300            .blocking_pool
301            .shutdown(/* deadline: */ Some(Instant::now() + timeout));
302
303        self.inner.state.lock().wind_down = WindDownStates::ShutDown;
304        match res {
305            Ok(_) => Ok(()),
306            Err(_) => Err(Error::Timedout),
307        }
308    }
309}
310
311impl Default for CancellableBlockingPool {
312    fn default() -> CancellableBlockingPool {
313        CancellableBlockingPool::new(256, Duration::from_secs(10))
314    }
315}
316
317impl Drop for CancellableBlockingPool {
318    fn drop(&mut self) {
319        if let Err(e) = self.shutdown() {
320            base::error!("CancellableBlockingPool::shutdown failed: {}", e);
321        }
322    }
323}
324
325/// Spawn a task to run in the `CancellableBlockingPool` static executor.
326///
327/// `cancel` in-flight operation. cancel is called on operation during `disarm` or during
328/// `shutdown`.  Cancel may be called multiple times if running task doesn't get cancelled on first
329/// attempt.
330///
331/// Callers may `await` the returned `Task` to be notified when the work is completed.
332///
333/// See also: `spawn`.
334pub fn unblock<F, R, G>(f: F, cancel: G) -> impl Future<Output = R>
335where
336    F: FnOnce() -> R + Send + 'static,
337    R: Send + 'static,
338    G: Fn() -> R + Send + 'static,
339{
340    EXECUTOR.spawn(f, cancel)
341}
342
343/// Marks all the queued and in-flight tasks as cancelled. Any tasks queued after `disarm`ing
344/// will be cancelled.
345/// Doesn't not wait for all the tasks to get cancelled.
346pub fn unblock_disarm() {
347    EXECUTOR.disarm()
348}
349
350#[cfg(test)]
351mod test {
352    use std::sync::Arc;
353    use std::sync::Barrier;
354    use std::thread;
355    use std::time::Duration;
356
357    use futures::executor::block_on;
358    use sync::Condvar;
359    use sync::Mutex;
360
361    use crate::blocking::Error;
362    use crate::CancellableBlockingPool;
363
364    #[test]
365    fn disarm_with_pending_work() {
366        // Create a pool with only one thread.
367        let pool = CancellableBlockingPool::new(1, Duration::from_secs(10));
368
369        let mu = Arc::new(Mutex::new(false));
370        let cv = Arc::new(Condvar::new());
371        let blocker_is_running = Arc::new(Barrier::new(2));
372
373        // First spawn a thread that blocks the pool.
374        let task_mu = mu.clone();
375        let task_cv = cv.clone();
376        let task_blocker_is_running = blocker_is_running.clone();
377        let _blocking_task = pool.spawn(
378            move || {
379                task_blocker_is_running.wait();
380                let mut ready = task_mu.lock();
381                while !*ready {
382                    ready = task_cv.wait(ready);
383                }
384            },
385            move || {},
386        );
387
388        // Wait for the worker to start running the blocking thread.
389        blocker_is_running.wait();
390
391        // This task will never finish because we will disarm the pool first.
392        let unfinished = pool.spawn(|| 5, || 0);
393
394        // Disarming should cancel the task.
395        pool.disarm();
396
397        // Shutdown the blocking thread. This will allow a worker to pick up the task that has
398        // to be cancelled.
399        *mu.lock() = true;
400        cv.notify_all();
401
402        // We expect the cancelled value to be returned.
403        assert_eq!(block_on(unfinished), 0);
404
405        // Now the pool is empty and can be shutdown without blocking.
406        pool.shutdown().unwrap();
407    }
408
409    #[test]
410    fn shutdown_with_blocked_work_should_timeout() {
411        let pool = CancellableBlockingPool::new(1, Duration::from_secs(10));
412
413        let running = Arc::new((Mutex::new(false), Condvar::new()));
414        let running1 = running.clone();
415        let _blocking_task = pool.spawn(
416            move || {
417                *running1.0.lock() = true;
418                running1.1.notify_one();
419                thread::sleep(Duration::from_secs(10000));
420            },
421            move || {},
422        );
423
424        let mut is_running = running.0.lock();
425        while !*is_running {
426            is_running = running.1.wait(is_running);
427        }
428
429        // This shutdown will wait for the full timeout period, so use a short timeout.
430        assert_eq!(
431            pool.shutdown_with_timeout(Duration::from_millis(1)),
432            Err(Error::Timedout)
433        );
434    }
435
436    #[test]
437    fn multiple_shutdown_returns_error() {
438        let pool = CancellableBlockingPool::new(1, Duration::from_secs(10));
439        let _ = pool.shutdown();
440        assert_eq!(pool.shutdown(), Err(Error::AlreadyShutdown));
441    }
442
443    #[test]
444    fn shutdown_in_progress() {
445        let pool = CancellableBlockingPool::new(1, Duration::from_secs(10));
446
447        let running = Arc::new((Mutex::new(false), Condvar::new()));
448        let running1 = running.clone();
449        let _blocking_task = pool.spawn(
450            move || {
451                *running1.0.lock() = true;
452                running1.1.notify_one();
453                thread::sleep(Duration::from_secs(10000));
454            },
455            move || {},
456        );
457
458        let mut is_running = running.0.lock();
459        while !*is_running {
460            is_running = running.1.wait(is_running);
461        }
462
463        let pool_clone = pool.clone();
464        thread::spawn(move || {
465            while !pool_clone.inner.blocking_pool.shutting_down() {}
466            assert_eq!(pool_clone.shutdown(), Err(Error::ShutdownInProgress));
467        });
468
469        // This shutdown will wait for the full timeout period, so use a short timeout.
470        // However, it also needs to wait long enough for the thread spawned above to observe the
471        // shutting_down state, so don't make it too short.
472        assert_eq!(
473            pool.shutdown_with_timeout(Duration::from_millis(200)),
474            Err(Error::Timedout)
475        );
476    }
477}