cros_async/blocking/
pool.rs

1// Copyright 2021 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use std::collections::VecDeque;
6use std::future::Future;
7use std::mem;
8use std::sync::mpsc::channel;
9use std::sync::mpsc::Receiver;
10use std::sync::mpsc::Sender;
11use std::sync::Arc;
12use std::thread;
13use std::thread::JoinHandle;
14use std::time::Duration;
15use std::time::Instant;
16
17use base::error;
18use base::warn;
19use futures::channel::oneshot;
20use slab::Slab;
21use sync::Condvar;
22use sync::Mutex;
23
/// How long `Drop for BlockingPool` waits for worker threads to exit before giving up on them.
const DEFAULT_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(10);
25
/// Mutable bookkeeping for the pool. Lives inside `Inner::state` and is only touched while the
/// mutex is held.
struct State {
    /// Work that has been queued by `spawn` but not yet picked up by a worker thread (FIFO).
    tasks: VecDeque<Box<dyn FnOnce() + Send>>,
    /// Number of worker threads that were spawned and have not yet run their exit path.
    num_threads: usize,
    /// Number of workers waiting on the condvar that no `spawn` call has claimed yet.
    num_idle: usize,
    /// Number of wakeups issued by `spawn` that no worker has consumed yet.
    num_notified: usize,
    /// Join handles of the worker threads, keyed by the index handed to each thread at spawn.
    worker_threads: Slab<JoinHandle<()>>,
    /// Receiving end of the exit channel. `shutdown` takes it, leaving `None` behind.
    exited_threads: Option<Receiver<usize>>,
    /// Sender that is cloned into every worker so it can report its index when it exits.
    exit: Sender<usize>,
    /// Set by `shutdown`. Workers stop looping and `spawn` refuses new work once this is true.
    shutting_down: bool,
}
36
37fn run_blocking_thread(idx: usize, inner: Arc<Inner>, exit: Sender<usize>) {
38    let mut state = inner.state.lock();
39    while !state.shutting_down {
40        if let Some(f) = state.tasks.pop_front() {
41            drop(state);
42            f();
43            state = inner.state.lock();
44            continue;
45        }
46
47        // No more tasks so wait for more work.
48        state.num_idle += 1;
49
50        let (guard, result) = inner
51            .condvar
52            .wait_timeout_while(state, inner.keepalive, |s| {
53                !s.shutting_down && s.num_notified == 0
54            });
55        state = guard;
56
57        // If `state.num_notified > 0` then this was a real wakeup.
58        if state.num_notified > 0 {
59            state.num_notified -= 1;
60            continue;
61        }
62
63        // Only decrement the idle count if we timed out. Otherwise, it was decremented when new
64        // work was added to `state.tasks`.
65        if result.timed_out() {
66            state.num_idle = state
67                .num_idle
68                .checked_sub(1)
69                .expect("`num_idle` underflow on timeout");
70            break;
71        }
72    }
73
74    state.num_threads -= 1;
75
76    // If we're shutting down then the BlockingPool will take care of joining all the threads.
77    // Otherwise, we need to join the last worker thread that exited here.
78    let last_exited_thread = if let Some(exited_threads) = state.exited_threads.as_mut() {
79        exited_threads
80            .try_recv()
81            .map(|idx| state.worker_threads.remove(idx))
82            .ok()
83    } else {
84        None
85    };
86
87    // Drop the lock before trying to join the last exited thread.
88    drop(state);
89
90    if let Some(handle) = last_exited_thread {
91        let _ = handle.join();
92    }
93
94    if let Err(e) = exit.send(idx) {
95        error!("Failed to send thread exit event on channel: {}", e);
96    }
97}
98
/// State shared between the `BlockingPool` handle and every worker thread.
struct Inner {
    /// Task queue and thread bookkeeping.
    state: Mutex<State>,
    /// Notified by `spawn` when work arrives for an idle worker and by `shutdown` to wake every
    /// waiting worker.
    condvar: Condvar,
    /// Upper bound on the number of worker threads that `spawn` will create.
    max_threads: usize,
    /// How long an idle worker waits for new work before it exits.
    keepalive: Duration,
}
105
106impl Inner {
107    pub fn spawn<F, R>(self: &Arc<Self>, f: F) -> impl Future<Output = R>
108    where
109        F: FnOnce() -> R + Send + 'static,
110        R: Send + 'static,
111    {
112        let mut state = self.state.lock();
113
114        // If we're shutting down then nothing is going to run this task.
115        if state.shutting_down {
116            error!("spawn called after shutdown");
117            return futures::future::Either::Left(async {
118                panic!("tried to poll BlockingPool task after shutdown")
119            });
120        }
121
122        let (send_chan, recv_chan) = oneshot::channel();
123        state.tasks.push_back(Box::new(|| {
124            let _ = send_chan.send(f());
125        }));
126
127        if state.num_idle == 0 {
128            // There are no idle threads.  Spawn a new one if possible.
129            if state.num_threads < self.max_threads {
130                state.num_threads += 1;
131                let exit = state.exit.clone();
132                let entry = state.worker_threads.vacant_entry();
133                let idx = entry.key();
134                let inner = self.clone();
135                entry.insert(
136                    thread::Builder::new()
137                        .name(format!("blockingPool{idx}"))
138                        .spawn(move || run_blocking_thread(idx, inner, exit))
139                        .unwrap(),
140                );
141            }
142        } else {
143            // We have idle threads, wake one up.
144            state.num_idle -= 1;
145            state.num_notified += 1;
146            self.condvar.notify_one();
147        }
148
149        futures::future::Either::Right(async {
150            recv_chan
151                .await
152                .expect("BlockingThread task unexpectedly cancelled")
153        })
154    }
155}
156
/// Error returned by `BlockingPool::shutdown` when the deadline passes before every worker thread
/// has been joined. The payload is the number of worker threads that were not joined.
#[derive(Debug, thiserror::Error)]
#[error("{0} BlockingPool threads did not exit in time and will be detached")]
pub struct ShutdownTimedOut(usize);
160
/// A thread pool for running work that may block.
///
/// It is generally discouraged to do any blocking work inside an async function. However, this is
/// sometimes unavoidable when dealing with interfaces that don't provide async variants. In this
/// case callers may use the `BlockingPool` to run the blocking work on a different thread and
/// `await` its result, which avoids blocking the thread that drives the async code.
///
/// Since the blocking work is sent to another thread, users should be careful when using the
/// `BlockingPool` for latency-sensitive operations. Additionally, the `BlockingPool` is intended to
/// be used for work that will eventually complete on its own. Users who want to spawn a thread
/// should just use `thread::spawn` directly.
///
/// There is no way to cancel work once it has been picked up by one of the worker threads in the
/// `BlockingPool`. Dropping the pool will block up to a timeout (10 seconds) to wait for any
/// active blocking work to finish; `shutdown` blocks up to the deadline given by the caller. Any
/// threads running tasks that have not completed by that time will be detached.
///
/// # Examples
///
/// Spawn a task to run in the `BlockingPool` and await on its result.
///
/// ```edition2018
/// use cros_async::BlockingPool;
///
/// # async fn do_it() {
///     let pool = BlockingPool::default();
///
///     let res = pool.spawn(move || {
///         // Do some CPU-intensive or blocking work here.
///
///         42
///     }).await;
///
///     assert_eq!(res, 42);
/// # }
/// # cros_async::block_on(do_it());
/// ```
pub struct BlockingPool {
    /// State shared with the worker threads.
    inner: Arc<Inner>,
}
202
203impl BlockingPool {
204    /// Create a new `BlockingPool`.
205    ///
206    /// The `BlockingPool` will never spawn more than `max_threads` threads to do work, regardless
207    /// of the number of tasks that are added to it. This value should be set relatively low (for
208    /// example, the number of CPUs on the machine) if the pool is intended to run CPU intensive
209    /// work or it should be set relatively high (128 or more) if the pool is intended to be used
210    /// for various IO operations that cannot be completed asynchronously. The default value is 256.
211    ///
212    /// Worker threads are spawned on demand when new work is added to the pool and will
213    /// automatically exit after being idle for some time so there is no overhead for setting
214    /// `max_threads` to a large value when there is little to no work assigned to the
215    /// `BlockingPool`. `keepalive` determines the idle duration after which the worker thread will
216    /// exit. The default value is 10 seconds.
217    pub fn new(max_threads: usize, keepalive: Duration) -> BlockingPool {
218        let (exit, exited_threads) = channel();
219        BlockingPool {
220            inner: Arc::new(Inner {
221                state: Mutex::new(State {
222                    tasks: VecDeque::new(),
223                    num_threads: 0,
224                    num_idle: 0,
225                    num_notified: 0,
226                    worker_threads: Slab::new(),
227                    exited_threads: Some(exited_threads),
228                    exit,
229                    shutting_down: false,
230                }),
231                condvar: Condvar::new(),
232                max_threads,
233                keepalive,
234            }),
235        }
236    }
237
238    /// Like new but with pre-allocating capacity for up to `max_threads`.
239    pub fn with_capacity(max_threads: usize, keepalive: Duration) -> BlockingPool {
240        let (exit, exited_threads) = channel();
241        BlockingPool {
242            inner: Arc::new(Inner {
243                state: Mutex::new(State {
244                    tasks: VecDeque::new(),
245                    num_threads: 0,
246                    num_idle: 0,
247                    num_notified: 0,
248                    worker_threads: Slab::with_capacity(max_threads),
249                    exited_threads: Some(exited_threads),
250                    exit,
251                    shutting_down: false,
252                }),
253                condvar: Condvar::new(),
254                max_threads,
255                keepalive,
256            }),
257        }
258    }
259
260    /// Spawn a task to run in the `BlockingPool`.
261    ///
262    /// Callers may `await` the returned `Future` to be notified when the work is completed.
263    /// Dropping the future will not cancel the task.
264    ///
265    /// # Panics
266    ///
267    /// `await`ing a `Task` after dropping the `BlockingPool` or calling `BlockingPool::shutdown`
268    /// will panic if the work was not completed before the pool was shut down.
269    pub fn spawn<F, R>(&self, f: F) -> impl Future<Output = R>
270    where
271        F: FnOnce() -> R + Send + 'static,
272        R: Send + 'static,
273    {
274        self.inner.spawn(f)
275    }
276
277    /// Shut down the `BlockingPool`.
278    ///
279    /// If `deadline` is provided then this will block until either all worker threads exit or the
280    /// deadline is exceeded. If `deadline` is not given then this will block indefinitely until all
281    /// worker threads exit. Any work that was added to the `BlockingPool` but not yet picked up by
282    /// a worker thread will not complete and `await`ing on the `Task` for that work will panic.
283    pub fn shutdown(&self, deadline: Option<Instant>) -> Result<(), ShutdownTimedOut> {
284        let mut state = self.inner.state.lock();
285
286        if state.shutting_down {
287            // We've already shut down this BlockingPool.
288            return Ok(());
289        }
290
291        state.shutting_down = true;
292        let exited_threads = state.exited_threads.take().expect("exited_threads missing");
293        let unfinished_tasks = std::mem::take(&mut state.tasks);
294        let mut worker_threads = mem::replace(&mut state.worker_threads, Slab::new());
295        drop(state);
296
297        self.inner.condvar.notify_all();
298
299        // Cancel any unfinished work after releasing the lock.
300        drop(unfinished_tasks);
301
302        // Now wait for all worker threads to exit.
303        if let Some(deadline) = deadline {
304            let mut now = Instant::now();
305            while now < deadline && !worker_threads.is_empty() {
306                if let Ok(idx) = exited_threads.recv_timeout(deadline - now) {
307                    let _ = worker_threads.remove(idx).join();
308                }
309                now = Instant::now();
310            }
311
312            // Any threads that have not yet joined will just be detached.
313            if !worker_threads.is_empty() {
314                return Err(ShutdownTimedOut(worker_threads.len()));
315            }
316
317            Ok(())
318        } else {
319            // Block indefinitely until all worker threads exit.
320            for handle in worker_threads.drain() {
321                let _ = handle.join();
322            }
323
324            Ok(())
325        }
326    }
327
328    #[cfg(test)]
329    pub(crate) fn shutting_down(&self) -> bool {
330        self.inner.state.lock().shutting_down
331    }
332}
333
334impl Default for BlockingPool {
335    fn default() -> BlockingPool {
336        BlockingPool::new(256, Duration::from_secs(10))
337    }
338}
339
340impl Drop for BlockingPool {
341    fn drop(&mut self) {
342        if let Err(e) = self.shutdown(Some(Instant::now() + DEFAULT_SHUTDOWN_TIMEOUT)) {
343            warn!("{}", e);
344        }
345    }
346}
347
// Unit tests for `BlockingPool`. Several of them inspect the private counters in
// `pool.inner.state` directly, which is possible because this module is nested in the same file.
#[cfg(test)]
mod test {
    use std::sync::Arc;
    use std::sync::Barrier;
    use std::thread;
    use std::time::Duration;
    use std::time::Instant;

    use futures::executor::block_on;
    use futures::stream::FuturesUnordered;
    use futures::StreamExt;
    use sync::Condvar;
    use sync::Mutex;

    use super::super::super::BlockingPool;

    // Despite the name the task does not sleep: this checks that a trivial task's return value
    // makes it back to the awaiting future.
    #[test]
    fn blocking_sleep() {
        let pool = BlockingPool::default();

        let res = block_on(pool.spawn(|| 42));
        assert_eq!(res, 42);
    }

    // Dropping the future returned by `spawn` must neither cancel the task nor block the caller.
    #[test]
    fn drop_doesnt_block() {
        let pool = BlockingPool::default();
        let (tx, rx) = std::sync::mpsc::sync_channel(0);
        // The blocking work should continue even though we drop the future.
        //
        // If we cancelled the work, then the recv call would fail. If we blocked on the work, then
        // the send would never complete because the channel is size zero and so waits for a
        // matching recv call.
        std::mem::drop(pool.spawn(move || tx.send(()).unwrap()));
        rx.recv().unwrap();
    }

    // Stresses the race between workers timing out (1ms keepalive) and new work being queued.
    #[test]
    fn fast_tasks_with_short_keepalive() {
        let pool = BlockingPool::new(256, Duration::from_millis(1));

        let streams = FuturesUnordered::new();
        for _ in 0..2 {
            for _ in 0..256 {
                let task = pool.spawn(|| ());
                streams.push(task);
            }

            // Give the workers a chance to hit their keepalive before the second batch.
            thread::sleep(Duration::from_millis(1));
        }

        block_on(streams.collect::<Vec<_>>());

        // The test passes if there are no panics, which would happen if one of the worker threads
        // triggered an underflow on `pool.inner.state.num_idle`.
    }

    // With only 4 workers allowed, all 19 queued tasks must still run to completion.
    #[test]
    fn more_tasks_than_threads() {
        let pool = BlockingPool::new(4, Duration::from_secs(10));

        let stream = (0..19)
            .map(|_| pool.spawn(|| thread::sleep(Duration::from_millis(5))))
            .collect::<FuturesUnordered<_>>();

        let results = block_on(stream.collect::<Vec<_>>());
        assert_eq!(results.len(), 19);
    }

    // After all work has finished, `shutdown` with a generous deadline must join every worker.
    #[test]
    fn shutdown() {
        let pool = BlockingPool::default();

        let stream = (0..19)
            .map(|_| pool.spawn(|| thread::sleep(Duration::from_millis(5))))
            .collect::<FuturesUnordered<_>>();

        let results = block_on(stream.collect::<Vec<_>>());
        assert_eq!(results.len(), 19);

        pool.shutdown(Some(Instant::now() + Duration::from_secs(10)))
            .unwrap();
        let state = pool.inner.state.lock();
        assert_eq!(state.num_threads, 0);
    }

    // Idle workers must exit on their own once the keepalive elapses, leaving the counters at 0.
    #[test]
    fn keepalive_timeout() {
        // Set the keepalive to a very low value so that threads will exit soon after they run out
        // of work.
        let pool = BlockingPool::new(7, Duration::from_millis(1));

        let stream = (0..19)
            .map(|_| pool.spawn(|| thread::sleep(Duration::from_millis(5))))
            .collect::<FuturesUnordered<_>>();

        let results = block_on(stream.collect::<Vec<_>>());
        assert_eq!(results.len(), 19);

        // Wait for all threads to exit.
        let deadline = Instant::now() + Duration::from_secs(10);
        while Instant::now() < deadline {
            thread::sleep(Duration::from_millis(100));
            let state = pool.inner.state.lock();
            if state.num_threads == 0 {
                break;
            }
        }

        {
            let state = pool.inner.state.lock();
            assert_eq!(state.num_threads, 0);
            assert_eq!(state.num_idle, 0);
        }
    }

    // A task that is still queued when the pool shuts down is dropped, so awaiting it panics.
    #[test]
    #[should_panic]
    fn shutdown_with_pending_work() {
        // A single worker guarantees the second task stays in the queue behind the first.
        let pool = BlockingPool::new(1, Duration::from_secs(10));

        let mu = Arc::new(Mutex::new(false));
        let cv = Arc::new(Condvar::new());

        // First spawn a thread that blocks the pool.
        let task_mu = mu.clone();
        let task_cv = cv.clone();
        let _blocking_task = pool.spawn(move || {
            let mut ready = task_mu.lock();
            while !*ready {
                ready = task_cv.wait(ready);
            }
        });

        // This task will never finish because we will shut down the pool first.
        let unfinished = pool.spawn(|| 5);

        // Spawn a thread to unblock the work we started earlier once it sees that the pool is
        // shutting down.
        let inner = pool.inner.clone();
        thread::spawn(move || {
            let mut state = inner.state.lock();
            while !state.shutting_down {
                state = inner.condvar.wait(state);
            }

            *mu.lock() = true;
            cv.notify_all();
        });
        pool.shutdown(None).unwrap();

        // This should panic.
        assert_eq!(block_on(unfinished), 5);
    }

    // `shutdown` must report a timeout while a worker is stuck in a task, and that worker must
    // still clean up after itself once the task finally completes.
    #[test]
    fn unfinished_worker_thread() {
        let pool = BlockingPool::default();

        let ready = Arc::new(Mutex::new(false));
        let cv = Arc::new(Condvar::new());
        let barrier = Arc::new(Barrier::new(2));

        let thread_ready = ready.clone();
        let thread_barrier = barrier.clone();
        let thread_cv = cv.clone();

        let task = pool.spawn(move || {
            thread_barrier.wait();
            let mut ready = thread_ready.lock();
            while !*ready {
                ready = thread_cv.wait(ready);
            }
        });

        // Wait to shut down the pool until after the worker thread has started.
        barrier.wait();
        pool.shutdown(Some(Instant::now() + Duration::from_millis(5)))
            .unwrap_err();

        // The worker is still blocked inside the task, so it has not run its exit path yet.
        let num_threads = pool.inner.state.lock().num_threads;
        assert_eq!(num_threads, 1);

        // Now wake up the blocked task so we don't leak the thread.
        *ready.lock() = true;
        cv.notify_all();

        block_on(task);

        // Poll until the detached worker has finished its exit bookkeeping.
        let deadline = Instant::now() + Duration::from_secs(10);
        while Instant::now() < deadline {
            thread::sleep(Duration::from_millis(100));
            let state = pool.inner.state.lock();
            if state.num_threads == 0 {
                break;
            }
        }

        {
            let state = pool.inner.state.lock();
            assert_eq!(state.num_threads, 0);
            assert_eq!(state.num_idle, 0);
        }
    }
}