cros_async/blocking/
pool.rs1use std::collections::VecDeque;
6use std::future::Future;
7use std::mem;
8use std::sync::mpsc::channel;
9use std::sync::mpsc::Receiver;
10use std::sync::mpsc::Sender;
11use std::sync::Arc;
12use std::thread;
13use std::thread::JoinHandle;
14use std::time::Duration;
15use std::time::Instant;
16
17use base::error;
18use base::warn;
19use futures::channel::oneshot;
20use slab::Slab;
21use sync::Condvar;
22use sync::Mutex;
23
/// How long `Drop for BlockingPool` waits for worker threads to exit before detaching them.
const DEFAULT_SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(10_000);
25
/// Mutable bookkeeping shared by the pool handle and its worker threads; always accessed
/// through `Inner::state`.
struct State {
    /// Queued tasks that no worker has started yet.
    tasks: VecDeque<Box<dyn FnOnce() + Send>>,
    /// Workers that have been spawned and have not yet left their run loop.
    num_threads: usize,
    /// Waiting workers that have not already been claimed by a pending notification.
    num_idle: usize,
    /// Wakeups issued by `spawn` that no waiting worker has consumed yet.
    num_notified: usize,
    /// Join handles of spawned workers, keyed by the index each worker was started with.
    worker_threads: Slab<JoinHandle<()>>,
    /// Receives the indices of exited workers so their handles can be joined. Taken (left as
    /// `None`) by `BlockingPool::shutdown`.
    exited_threads: Option<Receiver<usize>>,
    /// Sending side of `exited_threads`; every worker is given a clone.
    exit: Sender<usize>,
    /// Set by `BlockingPool::shutdown`; once true, `spawn` refuses new work.
    shutting_down: bool,
}
36
37fn run_blocking_thread(idx: usize, inner: Arc<Inner>, exit: Sender<usize>) {
38 let mut state = inner.state.lock();
39 while !state.shutting_down {
40 if let Some(f) = state.tasks.pop_front() {
41 drop(state);
42 f();
43 state = inner.state.lock();
44 continue;
45 }
46
47 state.num_idle += 1;
49
50 let (guard, result) = inner
51 .condvar
52 .wait_timeout_while(state, inner.keepalive, |s| {
53 !s.shutting_down && s.num_notified == 0
54 });
55 state = guard;
56
57 if state.num_notified > 0 {
59 state.num_notified -= 1;
60 continue;
61 }
62
63 if result.timed_out() {
66 state.num_idle = state
67 .num_idle
68 .checked_sub(1)
69 .expect("`num_idle` underflow on timeout");
70 break;
71 }
72 }
73
74 state.num_threads -= 1;
75
76 let last_exited_thread = if let Some(exited_threads) = state.exited_threads.as_mut() {
79 exited_threads
80 .try_recv()
81 .map(|idx| state.worker_threads.remove(idx))
82 .ok()
83 } else {
84 None
85 };
86
87 drop(state);
89
90 if let Some(handle) = last_exited_thread {
91 let _ = handle.join();
92 }
93
94 if let Err(e) = exit.send(idx) {
95 error!("Failed to send thread exit event on channel: {}", e);
96 }
97}
98
/// State shared through an `Arc` between a `BlockingPool` and all of its worker threads.
struct Inner {
    /// Mutable pool bookkeeping.
    state: Mutex<State>,
    /// Notified by `spawn` to wake an idle worker, and by `shutdown` to wake all workers.
    condvar: Condvar,
    /// Upper bound on the number of worker threads `spawn` will start.
    max_threads: usize,
    /// How long an idle worker waits for new work before exiting.
    keepalive: Duration,
}
105
106impl Inner {
107 pub fn spawn<F, R>(self: &Arc<Self>, f: F) -> impl Future<Output = R>
108 where
109 F: FnOnce() -> R + Send + 'static,
110 R: Send + 'static,
111 {
112 let mut state = self.state.lock();
113
114 if state.shutting_down {
116 error!("spawn called after shutdown");
117 return futures::future::Either::Left(async {
118 panic!("tried to poll BlockingPool task after shutdown")
119 });
120 }
121
122 let (send_chan, recv_chan) = oneshot::channel();
123 state.tasks.push_back(Box::new(|| {
124 let _ = send_chan.send(f());
125 }));
126
127 if state.num_idle == 0 {
128 if state.num_threads < self.max_threads {
130 state.num_threads += 1;
131 let exit = state.exit.clone();
132 let entry = state.worker_threads.vacant_entry();
133 let idx = entry.key();
134 let inner = self.clone();
135 entry.insert(
136 thread::Builder::new()
137 .name(format!("blockingPool{idx}"))
138 .spawn(move || run_blocking_thread(idx, inner, exit))
139 .unwrap(),
140 );
141 }
142 } else {
143 state.num_idle -= 1;
145 state.num_notified += 1;
146 self.condvar.notify_one();
147 }
148
149 futures::future::Either::Right(async {
150 recv_chan
151 .await
152 .expect("BlockingThread task unexpectedly cancelled")
153 })
154 }
155}
156
/// Error returned by [`BlockingPool::shutdown`] when worker threads are still running at the
/// deadline. The contained value is the number of threads left running (detached).
#[derive(Debug, thiserror::Error)]
#[error("{0} BlockingPool threads did not exit in time and will be detached")]
pub struct ShutdownTimedOut(usize);
160
/// A pool of threads for running work that may block.
///
/// Worker threads are started on demand, up to a configured maximum. A worker exits after
/// waiting the configured keepalive without receiving new work. [`BlockingPool::spawn`] returns
/// a future that resolves to the work's result. Dropping the pool calls
/// [`BlockingPool::shutdown`] with a default deadline.
pub struct BlockingPool {
    inner: Arc<Inner>,
}
202
203impl BlockingPool {
204 pub fn new(max_threads: usize, keepalive: Duration) -> BlockingPool {
218 let (exit, exited_threads) = channel();
219 BlockingPool {
220 inner: Arc::new(Inner {
221 state: Mutex::new(State {
222 tasks: VecDeque::new(),
223 num_threads: 0,
224 num_idle: 0,
225 num_notified: 0,
226 worker_threads: Slab::new(),
227 exited_threads: Some(exited_threads),
228 exit,
229 shutting_down: false,
230 }),
231 condvar: Condvar::new(),
232 max_threads,
233 keepalive,
234 }),
235 }
236 }
237
238 pub fn with_capacity(max_threads: usize, keepalive: Duration) -> BlockingPool {
240 let (exit, exited_threads) = channel();
241 BlockingPool {
242 inner: Arc::new(Inner {
243 state: Mutex::new(State {
244 tasks: VecDeque::new(),
245 num_threads: 0,
246 num_idle: 0,
247 num_notified: 0,
248 worker_threads: Slab::with_capacity(max_threads),
249 exited_threads: Some(exited_threads),
250 exit,
251 shutting_down: false,
252 }),
253 condvar: Condvar::new(),
254 max_threads,
255 keepalive,
256 }),
257 }
258 }
259
260 pub fn spawn<F, R>(&self, f: F) -> impl Future<Output = R>
270 where
271 F: FnOnce() -> R + Send + 'static,
272 R: Send + 'static,
273 {
274 self.inner.spawn(f)
275 }
276
277 pub fn shutdown(&self, deadline: Option<Instant>) -> Result<(), ShutdownTimedOut> {
284 let mut state = self.inner.state.lock();
285
286 if state.shutting_down {
287 return Ok(());
289 }
290
291 state.shutting_down = true;
292 let exited_threads = state.exited_threads.take().expect("exited_threads missing");
293 let unfinished_tasks = std::mem::take(&mut state.tasks);
294 let mut worker_threads = mem::replace(&mut state.worker_threads, Slab::new());
295 drop(state);
296
297 self.inner.condvar.notify_all();
298
299 drop(unfinished_tasks);
301
302 if let Some(deadline) = deadline {
304 let mut now = Instant::now();
305 while now < deadline && !worker_threads.is_empty() {
306 if let Ok(idx) = exited_threads.recv_timeout(deadline - now) {
307 let _ = worker_threads.remove(idx).join();
308 }
309 now = Instant::now();
310 }
311
312 if !worker_threads.is_empty() {
314 return Err(ShutdownTimedOut(worker_threads.len()));
315 }
316
317 Ok(())
318 } else {
319 for handle in worker_threads.drain() {
321 let _ = handle.join();
322 }
323
324 Ok(())
325 }
326 }
327
328 #[cfg(test)]
329 pub(crate) fn shutting_down(&self) -> bool {
330 self.inner.state.lock().shutting_down
331 }
332}
333
334impl Default for BlockingPool {
335 fn default() -> BlockingPool {
336 BlockingPool::new(256, Duration::from_secs(10))
337 }
338}
339
340impl Drop for BlockingPool {
341 fn drop(&mut self) {
342 if let Err(e) = self.shutdown(Some(Instant::now() + DEFAULT_SHUTDOWN_TIMEOUT)) {
343 warn!("{}", e);
344 }
345 }
346}
347
#[cfg(test)]
mod test {
    use std::sync::Arc;
    use std::sync::Barrier;
    use std::thread;
    use std::time::Duration;
    use std::time::Instant;

    use futures::executor::block_on;
    use futures::stream::FuturesUnordered;
    use futures::StreamExt;
    use sync::Condvar;
    use sync::Mutex;

    // Three module levels up from `test`; presumably the crate-root re-export of
    // `BlockingPool` — confirm against the module tree.
    use super::super::super::BlockingPool;

    /// A task's return value is delivered through the returned future.
    #[test]
    fn blocking_sleep() {
        let pool = BlockingPool::default();

        let res = block_on(pool.spawn(|| 42));
        assert_eq!(res, 42);
    }

    /// Dropping the returned future neither cancels the work nor waits for it.
    #[test]
    fn drop_doesnt_block() {
        let pool = BlockingPool::default();
        let (tx, rx) = std::sync::mpsc::sync_channel(0);
        // If the work were cancelled, `recv` below would fail. If dropping the future blocked
        // on the work, this thread would never reach `recv`, so the zero-capacity `send` could
        // never complete.
        std::mem::drop(pool.spawn(move || tx.send(()).unwrap()));
        rx.recv().unwrap();
    }

    /// Bursts of tiny tasks with a 1ms keepalive exercise the idle/notify bookkeeping while
    /// workers are timing out.
    #[test]
    fn fast_tasks_with_short_keepalive() {
        let pool = BlockingPool::new(256, Duration::from_millis(1));

        let streams = FuturesUnordered::new();
        for _ in 0..2 {
            for _ in 0..256 {
                let task = pool.spawn(|| ());
                streams.push(task);
            }

            thread::sleep(Duration::from_millis(1));
        }

        block_on(streams.collect::<Vec<_>>());

        // Nothing else to check: the test passes if every task completes instead of hanging.
    }

    /// Work beyond `max_threads` waits in the queue and still completes.
    #[test]
    fn more_tasks_than_threads() {
        let pool = BlockingPool::new(4, Duration::from_secs(10));

        let stream = (0..19)
            .map(|_| pool.spawn(|| thread::sleep(Duration::from_millis(5))))
            .collect::<FuturesUnordered<_>>();

        let results = block_on(stream.collect::<Vec<_>>());
        assert_eq!(results.len(), 19);
    }

    /// A shutdown with a generous deadline joins every worker.
    #[test]
    fn shutdown() {
        let pool = BlockingPool::default();

        let stream = (0..19)
            .map(|_| pool.spawn(|| thread::sleep(Duration::from_millis(5))))
            .collect::<FuturesUnordered<_>>();

        let results = block_on(stream.collect::<Vec<_>>());
        assert_eq!(results.len(), 19);

        pool.shutdown(Some(Instant::now() + Duration::from_secs(10)))
            .unwrap();
        let state = pool.inner.state.lock();
        assert_eq!(state.num_threads, 0);
    }

    /// Workers exit on their own once they have been idle for the keepalive period.
    #[test]
    fn keepalive_timeout() {
        // A very short keepalive makes workers exit soon after they run out of work.
        let pool = BlockingPool::new(7, Duration::from_millis(1));

        let stream = (0..19)
            .map(|_| pool.spawn(|| thread::sleep(Duration::from_millis(5))))
            .collect::<FuturesUnordered<_>>();

        let results = block_on(stream.collect::<Vec<_>>());
        assert_eq!(results.len(), 19);

        // Poll (for up to 10s) until every worker has exited.
        let deadline = Instant::now() + Duration::from_secs(10);
        while Instant::now() < deadline {
            thread::sleep(Duration::from_millis(100));
            let state = pool.inner.state.lock();
            if state.num_threads == 0 {
                break;
            }
        }

        {
            let state = pool.inner.state.lock();
            assert_eq!(state.num_threads, 0);
            assert_eq!(state.num_idle, 0);
        }
    }

    /// Work still queued at shutdown is dropped, so polling its future panics.
    #[test]
    #[should_panic]
    fn shutdown_with_pending_work() {
        let pool = BlockingPool::new(1, Duration::from_secs(10));

        let mu = Arc::new(Mutex::new(false));
        let cv = Arc::new(Condvar::new());

        // Occupy the pool's only worker with a task that blocks until `mu` becomes true.
        let task_mu = mu.clone();
        let task_cv = cv.clone();
        let _blocking_task = pool.spawn(move || {
            let mut ready = task_mu.lock();
            while !*ready {
                ready = task_cv.wait(ready);
            }
        });

        // Queued behind the blocking task, so shutdown drops it before it ever runs.
        let unfinished = pool.spawn(|| 5);

        // Release the blocking task once the pool reports that it is shutting down; otherwise
        // `shutdown(None)` below would wait forever for the busy worker.
        let inner = pool.inner.clone();
        thread::spawn(move || {
            let mut state = inner.state.lock();
            while !state.shutting_down {
                state = inner.condvar.wait(state);
            }

            *mu.lock() = true;
            cv.notify_all();
        });
        pool.shutdown(None).unwrap();

        // The dropped task's sender is gone, so awaiting its result panics here.
        assert_eq!(block_on(unfinished), 5);
    }

    /// A worker still busy at the shutdown deadline is detached, and exits later on its own.
    #[test]
    fn unfinished_worker_thread() {
        let pool = BlockingPool::default();

        let ready = Arc::new(Mutex::new(false));
        let cv = Arc::new(Condvar::new());
        let barrier = Arc::new(Barrier::new(2));

        let thread_ready = ready.clone();
        let thread_barrier = barrier.clone();
        let thread_cv = cv.clone();

        let task = pool.spawn(move || {
            thread_barrier.wait();
            let mut ready = thread_ready.lock();
            while !*ready {
                ready = thread_cv.wait(ready);
            }
        });

        // Shut down only after the worker has actually started the task, so the 5ms deadline
        // is guaranteed to expire with the worker still running.
        barrier.wait();
        pool.shutdown(Some(Instant::now() + Duration::from_millis(5)))
            .unwrap_err();

        // The detached worker is still counted.
        let num_threads = pool.inner.state.lock().num_threads;
        assert_eq!(num_threads, 1);

        // Unblock the task so the detached worker can finish instead of leaking.
        *ready.lock() = true;
        cv.notify_all();

        block_on(task);

        // Poll (for up to 10s) until the detached worker has left its run loop.
        let deadline = Instant::now() + Duration::from_secs(10);
        while Instant::now() < deadline {
            thread::sleep(Duration::from_millis(100));
            let state = pool.inner.state.lock();
            if state.num_threads == 0 {
                break;
            }
        }

        {
            let state = pool.inner.state.lock();
            assert_eq!(state.num_threads, 0);
            assert_eq!(state.num_idle, 0);
        }
    }
}