cros_async/blocking/cancellable_pool.rs
1// Copyright 2022 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Provides an async blocking pool whose tasks can be cancelled.
6
7use std::collections::HashMap;
8use std::future::Future;
9use std::sync::Arc;
10use std::sync::LazyLock;
11use std::time::Duration;
12use std::time::Instant;
13
14use sync::Condvar;
15use sync::Mutex;
16use thiserror::Error as ThisError;
17
18use crate::BlockingPool;
19
20/// Global executor.
21///
22/// This is convenient, though not preferred. Pros/cons:
23/// + It avoids passing executor all the way to each call sites.
24/// + The call site can assume that executor will never shutdown.
25/// + Provides similar functionality as async_task with a few improvements around ability to cancel.
26/// - Globals are harder to reason about.
27static EXECUTOR: LazyLock<CancellableBlockingPool> =
28 LazyLock::new(|| CancellableBlockingPool::new(256, Duration::from_secs(10)));
29
30const DEFAULT_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5);
31
/// Lifecycle stage of a `CancellableBlockingPool`.
///
/// Variants are declared in lifecycle order. The derived `PartialOrd` compares variants by
/// declaration order, and `disarm` relies on `wind_down >= WindDownStates::Disarmed` to detect
/// that the pool was already disarmed (or is shutting down / shut down). Do not reorder.
#[derive(PartialEq, Eq, PartialOrd, Default)]
enum WindDownStates {
    /// Normal operation; this is the initial (default) state.
    #[default]
    Armed,
    /// `disarm` has run: registered tasks have been asked to cancel.
    Disarmed,
    /// `shutdown` is waiting for the underlying `BlockingPool` to finish.
    ShuttingDown,
    /// `shutdown` has returned, either successfully or after timing out.
    ShutDown,
}
40
/// Mutable pool state shared by every handle of a pool; stored in `Inner::state` behind a mutex.
#[derive(Default)]
struct State {
    /// Current lifecycle stage of the pool.
    wind_down: WindDownStates,

    /// Helps to generate unique id to associate `cancel` with task.
    /// Incremented every time a `cancel` routine is registered.
    current_cancellable_id: u64,

    /// A map of all the `cancel` routines of queued/in-flight tasks.
    /// Entries are removed by the task runner once the task finishes.
    cancellables: HashMap<u64, Box<dyn Fn() + Send + 'static>>,
}
51
/// Action to take when a timeout expires.
///
/// NOTE(review): this type is not referenced anywhere else in this file; confirm whether it is
/// still used elsewhere in the crate or is a leftover of a removed watchdog option.
#[derive(Debug, Clone, Copy)]
pub enum TimeoutAction {
    /// Do nothing on timeout.
    None,
    /// Panic the thread on timeout.
    Panic,
}
59
/// Errors reported when shutting down a `CancellableBlockingPool`.
#[derive(ThisError, Debug, PartialEq, Eq)]
pub enum Error {
    /// The underlying `BlockingPool` did not finish shutting down before the deadline.
    #[error("Timeout occurred while trying to join threads")]
    Timedout,
    /// Another call to shut the pool down is currently in progress.
    #[error("Shutdown is in progress")]
    ShutdownInProgress,
    /// The pool has already been shut down.
    #[error("Already shut down")]
    AlreadyShutdown,
}
69
/// Shared core of a `CancellableBlockingPool`; every clone of a pool points at the same `Inner`.
struct Inner {
    /// Worker-thread pool that actually runs the tasks.
    blocking_pool: BlockingPool,
    /// Lifecycle stage plus the registry of `cancel` routines.
    state: Mutex<State>,

    /// This condvar gets notified when `cancellables` is empty after removing an
    /// entry.
    cancellables_cv: Condvar,
}
78
79impl Inner {
80 pub fn spawn<F, R>(self: &Arc<Self>, f: F) -> impl Future<Output = R>
81 where
82 F: FnOnce() -> R + Send + 'static,
83 R: Send + 'static,
84 {
85 self.blocking_pool.spawn(f)
86 }
87
88 /// Adds cancel to a cancellables and returns an `id` with which `cancel` can be
89 /// accessed/removed.
90 fn add_cancellable(&self, cancel: Box<dyn Fn() + Send + 'static>) -> u64 {
91 let mut state = self.state.lock();
92 let id = state.current_cancellable_id;
93 state.current_cancellable_id += 1;
94 state.cancellables.insert(id, cancel);
95 id
96 }
97}
98
/// A thread pool for running work that may block.
///
/// This is a wrapper around `BlockingPool` with an ability to cancel queued and in-flight tasks.
/// See [BlockingPool] for more info.
///
/// Cloning a `CancellableBlockingPool` yields another handle to the same underlying pool.
/// NOTE(review): every handle runs `shutdown` when it is dropped, so dropping any clone shuts
/// down the pool for all other handles as well.
///
/// # Examples
///
/// Spawn a task to run in the `CancellableBlockingPool` and await on its result.
///
/// ```edition2018
/// use cros_async::CancellableBlockingPool;
///
/// # async fn do_it() {
/// let pool = CancellableBlockingPool::default();
/// let CANCELLED = 0;
///
/// let res = pool.spawn(move || {
///     // Do some CPU-intensive or blocking work here.
///
///     42
/// }, move || CANCELLED).await;
///
/// assert_eq!(res, 42);
/// # }
/// # futures::executor::block_on(do_it());
/// ```
#[derive(Clone)]
pub struct CancellableBlockingPool {
    /// State shared by all clones of this handle.
    inner: Arc<Inner>,
}
129
130impl CancellableBlockingPool {
131 const RETRY_COUNT: usize = 10;
132 const SLEEP_DURATION: Duration = Duration::from_millis(100);
133
134 /// Create a new `CancellableBlockingPool`.
135 ///
136 /// When we try to shutdown or drop `CancellableBlockingPool`, it may happen that a hung thread
137 /// might prevent `CancellableBlockingPool` pool from getting dropped. On failure to shutdown in
138 /// `watchdog_opts.timeout` duration, `CancellableBlockingPool` can take an action specified by
139 /// `watchdog_opts.action`.
140 ///
141 /// See also: [BlockingPool::new()](BlockingPool::new)
142 pub fn new(max_threads: usize, keepalive: Duration) -> CancellableBlockingPool {
143 CancellableBlockingPool {
144 inner: Arc::new(Inner {
145 blocking_pool: BlockingPool::new(max_threads, keepalive),
146 state: Default::default(),
147 cancellables_cv: Condvar::new(),
148 }),
149 }
150 }
151
152 /// Like [Self::new] but with pre-allocating capacity for up to `max_threads`.
153 pub fn with_capacity(max_threads: usize, keepalive: Duration) -> CancellableBlockingPool {
154 CancellableBlockingPool {
155 inner: Arc::new(Inner {
156 blocking_pool: BlockingPool::with_capacity(max_threads, keepalive),
157 state: Mutex::new(State::default()),
158 cancellables_cv: Condvar::new(),
159 }),
160 }
161 }
162
163 /// Spawn a task to run in the `CancellableBlockingPool`.
164 ///
165 /// Callers may `await` the returned `Task` to be notified when the work is completed.
166 /// Dropping the future will not cancel the task.
167 ///
168 /// `cancel` helps to cancel a queued or in-flight operation `f`.
169 /// `cancel` may be called more than once if `f` doesn't respond to `cancel`.
170 /// `cancel` is not called if `f` completes successfully. For example,
171 /// # Examples
172 ///
173 /// ```edition2018
174 /// use {cros_async::CancellableBlockingPool, std::sync::{Arc, Mutex, Condvar}};
175 ///
176 /// # async fn cancel_it() {
177 /// let pool = CancellableBlockingPool::default();
178 /// let cancelled: i32 = 1;
179 /// let success: i32 = 2;
180 ///
181 /// let shared = Arc::new((Mutex::new(0), Condvar::new()));
182 /// let shared2 = shared.clone();
183 /// let shared3 = shared.clone();
184 ///
185 /// let res = pool
186 /// .spawn(
187 /// move || {
188 /// let guard = shared.0.lock().unwrap();
189 /// let mut guard = shared.1.wait_while(guard, |state| *state == 0).unwrap();
190 /// if *guard != cancelled {
191 /// *guard = success;
192 /// }
193 /// },
194 /// move || {
195 /// *shared2.0.lock().unwrap() = cancelled;
196 /// shared2.1.notify_all();
197 /// },
198 /// )
199 /// .await;
200 /// pool.shutdown();
201 ///
202 /// assert_eq!(*shared3.0.lock().unwrap(), cancelled);
203 /// # }
204 /// ```
205 pub fn spawn<F, R, G>(&self, f: F, cancel: G) -> impl Future<Output = R>
206 where
207 F: FnOnce() -> R + Send + 'static,
208 R: Send + 'static,
209 G: Fn() -> R + Send + 'static,
210 {
211 let inner = self.inner.clone();
212 let cancelled = Arc::new(Mutex::new(None));
213 let cancelled_spawn = cancelled.clone();
214 let id = inner.add_cancellable(Box::new(move || {
215 let mut c = cancelled.lock();
216 *c = Some(cancel());
217 }));
218
219 self.inner.spawn(move || {
220 if let Some(res) = cancelled_spawn.lock().take() {
221 return res;
222 }
223 let ret = f();
224 let mut state = inner.state.lock();
225 state.cancellables.remove(&id);
226 if state.cancellables.is_empty() {
227 inner.cancellables_cv.notify_one();
228 }
229 ret
230 })
231 }
232
233 /// Iterates over all the queued tasks and marks them as cancelled.
234 fn drain_cancellables(&self) {
235 let mut state = self.inner.state.lock();
236 // Iterate a few times to try cancelling all the tasks.
237 for _ in 0..Self::RETRY_COUNT {
238 // Nothing left to do.
239 if state.cancellables.is_empty() {
240 return;
241 }
242
243 // We only cancel the task and do not remove it from the cancellables. It is runner's
244 // job to remove from state.cancellables.
245 for cancel in state.cancellables.values() {
246 cancel();
247 }
248 // Hold the state lock in a block before sleeping so that woken up threads can get to
249 // hold the lock.
250 // Wait for a while so that the threads get a chance complete task in flight.
251 let (state1, _cv_timeout) = self
252 .inner
253 .cancellables_cv
254 .wait_timeout(state, Self::SLEEP_DURATION);
255 state = state1;
256 }
257 }
258
259 /// Marks all the queued and in-flight tasks as cancelled. Any tasks queued after `disarm`ing
260 /// will be cancelled.
261 /// Does not wait for all the tasks to get cancelled.
262 pub fn disarm(&self) {
263 {
264 let mut state = self.inner.state.lock();
265
266 if state.wind_down >= WindDownStates::Disarmed {
267 return;
268 }
269
270 // At this point any new incoming request will be cancelled when run.
271 state.wind_down = WindDownStates::Disarmed;
272 }
273 self.drain_cancellables();
274 }
275
276 /// Shut down the `CancellableBlockingPool`.
277 ///
278 /// This will block until all work that has been started by the worker threads is finished. Any
279 /// work that was added to the `CancellableBlockingPool` but not yet picked up by a worker
280 /// thread will not complete and `await`ing on the `Task` for that work will panic.
281 pub fn shutdown(&self) -> Result<(), Error> {
282 self.shutdown_with_timeout(DEFAULT_SHUTDOWN_TIMEOUT)
283 }
284
285 fn shutdown_with_timeout(&self, timeout: Duration) -> Result<(), Error> {
286 self.disarm();
287 {
288 let mut state = self.inner.state.lock();
289 if state.wind_down == WindDownStates::ShuttingDown {
290 return Err(Error::ShutdownInProgress);
291 }
292 if state.wind_down == WindDownStates::ShutDown {
293 return Err(Error::AlreadyShutdown);
294 }
295 state.wind_down = WindDownStates::ShuttingDown;
296 }
297
298 let res = self
299 .inner
300 .blocking_pool
301 .shutdown(/* deadline: */ Some(Instant::now() + timeout));
302
303 self.inner.state.lock().wind_down = WindDownStates::ShutDown;
304 match res {
305 Ok(_) => Ok(()),
306 Err(_) => Err(Error::Timedout),
307 }
308 }
309}
310
311impl Default for CancellableBlockingPool {
312 fn default() -> CancellableBlockingPool {
313 CancellableBlockingPool::new(256, Duration::from_secs(10))
314 }
315}
316
317impl Drop for CancellableBlockingPool {
318 fn drop(&mut self) {
319 if let Err(e) = self.shutdown() {
320 base::error!("CancellableBlockingPool::shutdown failed: {}", e);
321 }
322 }
323}
324
325/// Spawn a task to run in the `CancellableBlockingPool` static executor.
326///
327/// `cancel` in-flight operation. cancel is called on operation during `disarm` or during
328/// `shutdown`. Cancel may be called multiple times if running task doesn't get cancelled on first
329/// attempt.
330///
331/// Callers may `await` the returned `Task` to be notified when the work is completed.
332///
333/// See also: `spawn`.
334pub fn unblock<F, R, G>(f: F, cancel: G) -> impl Future<Output = R>
335where
336 F: FnOnce() -> R + Send + 'static,
337 R: Send + 'static,
338 G: Fn() -> R + Send + 'static,
339{
340 EXECUTOR.spawn(f, cancel)
341}
342
343/// Marks all the queued and in-flight tasks as cancelled. Any tasks queued after `disarm`ing
344/// will be cancelled.
345/// Doesn't not wait for all the tasks to get cancelled.
346pub fn unblock_disarm() {
347 EXECUTOR.disarm()
348}
349
#[cfg(test)]
mod test {
    use std::sync::Arc;
    use std::sync::Barrier;
    use std::thread;
    use std::time::Duration;

    use futures::executor::block_on;
    use sync::Condvar;
    use sync::Mutex;

    use crate::blocking::Error;
    use crate::CancellableBlockingPool;

    #[test]
    fn disarm_with_pending_work() {
        // Create a pool with only one thread.
        let pool = CancellableBlockingPool::new(1, Duration::from_secs(10));

        let mu = Arc::new(Mutex::new(false));
        let cv = Arc::new(Condvar::new());
        let blocker_is_running = Arc::new(Barrier::new(2));

        // First spawn a thread that blocks the pool.
        let task_mu = mu.clone();
        let task_cv = cv.clone();
        let task_blocker_is_running = blocker_is_running.clone();
        let _blocking_task = pool.spawn(
            move || {
                task_blocker_is_running.wait();
                let mut ready = task_mu.lock();
                while !*ready {
                    ready = task_cv.wait(ready);
                }
            },
            move || {},
        );

        // Wait for the worker to start running the blocking thread.
        blocker_is_running.wait();

        // This task's `f` will never run because we disarm the pool before a worker picks it up.
        let unfinished = pool.spawn(|| 5, || 0);

        // Disarming should cancel the task.
        pool.disarm();

        // Shutdown the blocking thread. This will allow a worker to pick up the task that has
        // to be cancelled.
        *mu.lock() = true;
        cv.notify_all();

        // We expect the cancelled value to be returned.
        assert_eq!(block_on(unfinished), 0);

        // Now the pool is empty and can be shutdown without blocking.
        pool.shutdown().unwrap();
    }

    #[test]
    fn shutdown_with_blocked_work_should_timeout() {
        let pool = CancellableBlockingPool::new(1, Duration::from_secs(10));

        let running = Arc::new((Mutex::new(false), Condvar::new()));
        let running1 = running.clone();
        let _blocking_task = pool.spawn(
            move || {
                *running1.0.lock() = true;
                running1.1.notify_one();
                thread::sleep(Duration::from_secs(10000));
            },
            move || {},
        );

        let mut is_running = running.0.lock();
        while !*is_running {
            is_running = running.1.wait(is_running);
        }

        // This shutdown will wait for the full timeout period, so use a short timeout.
        assert_eq!(
            pool.shutdown_with_timeout(Duration::from_millis(1)),
            Err(Error::Timedout)
        );
    }

    #[test]
    fn multiple_shutdown_returns_error() {
        let pool = CancellableBlockingPool::new(1, Duration::from_secs(10));
        let _ = pool.shutdown();
        assert_eq!(pool.shutdown(), Err(Error::AlreadyShutdown));
    }

    #[test]
    fn shutdown_in_progress() {
        let pool = CancellableBlockingPool::new(1, Duration::from_secs(10));

        let running = Arc::new((Mutex::new(false), Condvar::new()));
        let running1 = running.clone();
        let _blocking_task = pool.spawn(
            move || {
                *running1.0.lock() = true;
                running1.1.notify_one();
                thread::sleep(Duration::from_secs(10000));
            },
            move || {},
        );

        let mut is_running = running.0.lock();
        while !*is_running {
            is_running = running.1.wait(is_running);
        }

        let pool_clone = pool.clone();
        let observer = thread::spawn(move || {
            while !pool_clone.inner.blocking_pool.shutting_down() {
                std::hint::spin_loop();
            }
            assert_eq!(pool_clone.shutdown(), Err(Error::ShutdownInProgress));
        });

        // This shutdown will wait for the full timeout period, so use a short timeout.
        // However, it also needs to wait long enough for the thread spawned above to observe the
        // shutting_down state, so don't make it too short.
        assert_eq!(
            pool.shutdown_with_timeout(Duration::from_millis(200)),
            Err(Error::Timedout)
        );

        // Without joining, a failed assertion in the observer thread would be silently ignored.
        observer.join().expect("observer thread panicked");
    }
}