#![deny(missing_docs)]
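//! A simple multi-threaded worker pool ([`Worker`], [`Channel`], [`Task`]) and
//! a control primitive for a single long-running background job
//! ([`BackgroundJobControl`], [`BackgroundJob`]).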

use std::collections::VecDeque;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::thread;
use std::time::Duration;

use anyhow::Context;
use base::error;
use base::Event;
use base::EventWaitResult;
use sync::Condvar;
use sync::Mutex;

/// A task to be executed on a worker thread.
pub trait Task {
    /// Executes the task, consuming it.
    fn execute(self);
}

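/// A pool of worker threads that execute [`Task`]s pushed onto a shared
/// [`Channel`].
///
/// # Example
///
/// A minimal sketch of intended usage (editor's illustration; `PrintTask` is a
/// hypothetical type, not part of this module):
///
/// ```ignore
/// struct PrintTask(String);
///
/// impl Task for PrintTask {
///     fn execute(self) {
///         println!("{}", self.0);
///     }
/// }
///
/// // Two worker threads sharing a queue with capacity 8.
/// let worker = Worker::new(8, 2);
/// assert!(worker.channel.push(PrintTask("hello".to_string())));
/// // Block until the queue is drained and every worker is idle again.
/// worker.channel.wait_complete();
/// worker.close();
/// ```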
pub struct Worker<T> {
    /// The channel used to submit [`Task`]s to the worker threads.
    pub channel: Arc<Channel<T>>,
    handles: Vec<thread::JoinHandle<()>>,
}

impl<T: Task + Send + 'static> Worker<T> {
    /// Spawns `n_workers` worker threads sharing a [`Channel`] with capacity
    /// `len_channel`.
    pub fn new(len_channel: usize, n_workers: usize) -> Self {
        let channel = Arc::new(Channel::<T>::new(len_channel, n_workers));
        let mut handles = Vec::with_capacity(n_workers);
        for _ in 0..n_workers {
            let context = channel.clone();
            let handle = thread::spawn(move || {
                Self::worker_thread(context);
            });
            handles.push(handle);
        }
        Self { channel, handles }
    }

    fn worker_thread(context: Arc<Channel<T>>) {
        // Keep executing tasks until the channel is closed and drained.
        while let Some(task) = context.pop() {
            task.execute();
        }
    }

    /// Closes the channel and waits for all worker threads to terminate.
    pub fn close(self) {
        self.channel.close();
        for handle in self.handles {
            match handle.join() {
                Ok(()) => {}
                Err(e) => {
                    error!("failed to wait for worker thread: {:?}", e);
                }
            }
        }
    }
}

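/// A bounded multi-producer, multi-consumer queue shared between producers and
/// the worker threads of a [`Worker`].
///
/// A `Channel` is created by [`Worker::new`] and exposed through
/// [`Worker::channel`]; producers submit work with [`Channel::push`] and can
/// wait for it to finish with [`Channel::wait_complete`].
///
/// # Example
///
/// A sketch of producing work from another thread (editor's illustration;
/// `MyTask` is a hypothetical [`Task`] implementation):
///
/// ```ignore
/// let worker = Worker::new(16, 4);
/// let channel = worker.channel.clone();
/// let producer = std::thread::spawn(move || {
///     for _ in 0..100 {
///         // `push` blocks while the queue is full and only returns `false`
///         // once the channel has been closed.
///         assert!(channel.push(MyTask::new()));
///     }
/// });
/// producer.join().unwrap();
/// worker.channel.wait_complete();
/// worker.close();
/// ```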
pub struct Channel<T> {
    state: Mutex<ChannelState<T>>,
    consumer_wait: Condvar,
    producer_wait: Condvar,
    n_consumers: usize,
}

impl<T> Channel<T> {
    fn new(len: usize, n_consumers: usize) -> Self {
        Self {
            state: Mutex::new(ChannelState::new(len)),
            consumer_wait: Condvar::new(),
            producer_wait: Condvar::new(),
            n_consumers,
        }
    }

    /// Closes the channel and wakes up all waiting producers and consumers.
    fn close(&self) {
        let mut state = self.state.lock();
        state.is_closed = true;
        self.consumer_wait.notify_all();
        self.producer_wait.notify_all();
    }

    /// Pops a task, blocking while the queue is empty. Returns `None` once the
    /// channel is closed and drained.
    #[inline]
    fn pop(&self) -> Option<T> {
        let mut state = self.state.lock();
        loop {
            let was_full = state.queue.len() == state.capacity;
            if let Some(item) = state.queue.pop_front() {
                if was_full {
                    // The queue just transitioned from full to not-full; wake
                    // up a blocked producer.
                    self.producer_wait.notify_one();
                }
                return Some(item);
            } else {
                if state.is_closed {
                    return None;
                }
                state.n_waiting += 1;
                if state.n_waiting == self.n_consumers {
                    // All consumers are idle; wake up `wait_complete()`.
                    self.producer_wait.notify_all();
                }
                state = self.consumer_wait.wait(state);
                state.n_waiting -= 1;
            }
        }
    }

    /// Pushes a task, blocking while the queue is full.
    ///
    /// Returns `false` if the channel is closed.
    pub fn push(&self, item: T) -> bool {
        let mut state = self.state.lock();
        while state.queue.len() == state.capacity {
            if state.is_closed {
                return false;
            }
            state = self.producer_wait.wait(state);
        }
        if state.is_closed {
            return false;
        }
        state.queue.push_back(item);
        self.consumer_wait.notify_one();
        true
    }

    /// Blocks until the queue is empty and all consumers are waiting for new
    /// tasks, i.e. all submitted tasks have been executed.
    pub fn wait_complete(&self) {
        let mut state = self.state.lock();
        while !(state.queue.is_empty() && state.n_waiting == self.n_consumers) {
            state = self.producer_wait.wait(state);
        }
    }
}

struct ChannelState<T> {
    queue: VecDeque<T>,
    /// The maximum number of items the queue may hold at once.
    capacity: usize,
    /// The number of consumers currently waiting for a new item.
    n_waiting: usize,
    is_closed: bool,
}

impl<T> ChannelState<T> {
    fn new(capacity: usize) -> Self {
        Self {
            queue: VecDeque::with_capacity(capacity),
            capacity,
            n_waiting: 0,
            is_closed: false,
        }
    }
}

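/// Controls a single long-running background job: exposes an abort flag to the
/// job and a completion event to its owner.
///
/// # Example
///
/// A sketch of the intended flow (editor's illustration; the job body is
/// hypothetical):
///
/// ```ignore
/// let control = BackgroundJobControl::new()?;
///
/// // On the job thread:
/// let job = control.new_job();
/// while !job.is_aborted() {
///     // ... do one unit of work ...
/// }
/// drop(job); // Signals the completion event.
///
/// // On the controlling thread:
/// control.abort();
/// // Later, consume the completion event and re-arm the abort flag.
/// let job_completed = control.reset()?;
/// ```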
pub struct BackgroundJobControl {
    event: Event,
    abort_flag: AtomicBool,
}

impl BackgroundJobControl {
    /// Creates a new `BackgroundJobControl`.
    pub fn new() -> anyhow::Result<Self> {
        Ok(Self {
            event: Event::new()?,
            abort_flag: AtomicBool::new(false),
        })
    }

    /// Creates a new [`BackgroundJob`] tied to this control.
    pub fn new_job(&self) -> BackgroundJob<'_> {
        BackgroundJob {
            event: &self.event,
            abort_flag: &self.abort_flag,
        }
    }

    /// Requests that the current background job abort.
    pub fn abort(&self) {
        self.abort_flag.store(true, Ordering::Release);
    }

    /// Clears the abort flag and returns whether the completion event had been
    /// signaled (i.e. whether a background job has finished).
    pub fn reset(&self) -> anyhow::Result<bool> {
        self.abort_flag.store(false, Ordering::Release);
        Ok(matches!(
            self.event
                .wait_timeout(Duration::ZERO)
                .context("failed to get job complete event")?,
            EventWaitResult::Signaled
        ))
    }

    /// Returns the event that is signaled when a background job completes.
    pub fn get_completion_event(&self) -> &Event {
        &self.event
    }
}

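/// Handle held by a running background job.
///
/// The job should poll [`BackgroundJob::is_aborted`] between units of work and
/// stop early once it returns `true`. Dropping the handle signals the
/// completion event of the owning [`BackgroundJobControl`].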
pub struct BackgroundJob<'a> {
    event: &'a Event,
    abort_flag: &'a AtomicBool,
}

impl BackgroundJob<'_> {
    /// Returns whether the job has been asked to abort.
    pub fn is_aborted(&self) -> bool {
        self.abort_flag.load(Ordering::Acquire)
    }
}

impl Drop for BackgroundJob<'_> {
    fn drop(&mut self) {
        // Notify the owner that the job has finished (or was aborted).
        self.event.signal().expect("send job complete event");
    }
}

#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;

    #[derive(Clone, Copy)]
    struct Context {
        n_consume: usize,
        n_executed: usize,
    }

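    // `FakeTask::execute()` blocks until `consume()` grants it a permit, which
    // lets the tests control exactly when queued tasks complete.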
    struct FakeTask {
        context: Mutex<Context>,
        waker: Condvar,
    }

    impl FakeTask {
        fn new() -> Arc<Self> {
            Arc::new(Self {
                context: Mutex::new(Context {
                    n_consume: 0,
                    n_executed: 0,
                }),
                waker: Condvar::new(),
            })
        }

        fn consume(&self, count: usize) {
            let mut context = self.context.lock();
            context.n_consume += count;
            self.waker.notify_all();
        }

        fn n_executed(&self) -> usize {
            self.context.lock().n_executed
        }
    }

    impl Task for Arc<FakeTask> {
        fn execute(self) {
            let mut context = self.context.lock();
            while context.n_consume == 0 {
                context = self.waker.wait(context);
            }
            context.n_consume -= 1;
            context.n_executed += 1;
        }
    }

    fn wait_thread_with_timeout<T>(join_handle: thread::JoinHandle<T>, timeout_millis: u64) -> T {
        for _ in 0..timeout_millis {
            if join_handle.is_finished() {
                return join_handle.join().unwrap();
            }
            thread::sleep(Duration::from_millis(1));
        }
        panic!("thread join timeout");
    }

    fn poll_until_with_timeout<F>(f: F, timeout_millis: u64)
    where
        F: Fn() -> bool,
    {
        for _ in 0..timeout_millis {
            if f() {
                break;
            }
            thread::sleep(Duration::from_millis(1));
        }
    }

    #[test]
    fn test_worker() {
        let worker = Worker::new(2, 4);
        let task = FakeTask::new();
        let channel = worker.channel.clone();

        for _ in 0..4 {
            assert!(channel.push(task.clone()));
        }

        assert_eq!(task.n_executed(), 0);
        task.consume(4);
        worker.channel.wait_complete();
        assert_eq!(task.n_executed(), 4);
        worker.close();
    }

    #[test]
    fn test_worker_push_after_close() {
        let worker = Worker::new(2, 4);
        let task = FakeTask::new();
        let channel = worker.channel.clone();

        worker.close();

        assert!(!channel.push(task));
    }

    #[test]
    fn test_worker_push_block() {
        let worker = Worker::new(2, 4);
        let task = FakeTask::new();
        let channel = worker.channel.clone();

        // Fill the 4 worker threads and the queue of capacity 2.
        let task_cloned = task.clone();
        wait_thread_with_timeout(
            thread::spawn(move || {
                for _ in 0..6 {
                    assert!(channel.push(task_cloned.clone()));
                }
            }),
            100,
        );
        // The 7th push blocks until a slot opens up in the queue.
        let channel = worker.channel.clone();
        let task_cloned = task.clone();
        let push_thread = thread::spawn(move || {
            assert!(channel.push(task_cloned));
        });
        thread::sleep(Duration::from_millis(10));
        assert!(!push_thread.is_finished());

        task.consume(1);
        wait_thread_with_timeout(push_thread, 100);

        task.consume(6);
        #[allow(clippy::redundant_clone)]
        let task_clone = task.clone();
        poll_until_with_timeout(|| task_clone.n_executed() == 7, 100);
        assert_eq!(task.n_executed(), 7);
        worker.close();
    }

    #[test]
    fn test_worker_close_on_push_blocked() {
        let worker = Worker::new(2, 4);
        let task = FakeTask::new();
        let channel = worker.channel.clone();

        // Fill the 4 worker threads and the queue of capacity 2.
        let task_cloned = task.clone();
        wait_thread_with_timeout(
            thread::spawn(move || {
                for _ in 0..6 {
                    assert!(channel.push(task_cloned.clone()));
                }
            }),
            100,
        );
        let channel = worker.channel.clone();
        let task_cloned = task.clone();
        let push_thread = thread::spawn(move || channel.push(task_cloned));
        thread::sleep(Duration::from_millis(10));
        let close_thread = thread::spawn(move || {
            worker.close();
        });
        // Closing the channel cancels the blocked push.
        let push_result = wait_thread_with_timeout(push_thread, 100);
        assert!(!push_result);

        // Let the in-flight tasks finish so that the worker threads can exit.
        task.consume(6);
        wait_thread_with_timeout(close_thread, 100);
    }

    #[test]
    fn new_background_job_event() {
        assert!(BackgroundJobControl::new().is_ok());
    }

    #[test]
    fn background_job_is_not_aborted_default() {
        let event = BackgroundJobControl::new().unwrap();

        let job = event.new_job();

        assert!(!job.is_aborted());
    }

    #[test]
    fn abort_background_job() {
        let event = BackgroundJobControl::new().unwrap();

        let job = event.new_job();
        event.abort();

        assert!(job.is_aborted());
    }

    #[test]
    fn reset_background_job() {
        let event = BackgroundJobControl::new().unwrap();

        event.abort();
        event.reset().unwrap();
        let job = event.new_job();

        assert!(!job.is_aborted());
    }

    #[test]
    fn reset_background_job_event() {
        let event = BackgroundJobControl::new().unwrap();

        let job = event.new_job();
        drop(job);

        assert!(event.reset().unwrap());
    }

    #[test]
    fn reset_background_job_event_twice() {
        let event = BackgroundJobControl::new().unwrap();

        let job = event.new_job();
        drop(job);

        event.reset().unwrap();
        assert!(!event.reset().unwrap());
    }

    #[test]
    fn reset_background_job_event_no_jobs() {
        let event = BackgroundJobControl::new().unwrap();

        assert!(!event.reset().unwrap());
    }
}