swap/
worker.rs

// Copyright 2022 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

//! Multi-thread worker.

7#![deny(missing_docs)]
8
9use std::collections::VecDeque;
10use std::sync::atomic::AtomicBool;
11use std::sync::atomic::Ordering;
12use std::sync::Arc;
13use std::thread;
14use std::time::Duration;
15
16use anyhow::Context;
17use base::error;
18use base::Event;
19use base::EventWaitResult;
20use sync::Condvar;
21use sync::Mutex;
22
/// A unit of work run by the [Worker] threads.
pub trait Task {
    /// Runs the task, consuming it.
    fn execute(self);
}
28
29/// Multi thread based worker executing a single type [Task].
30///
31/// See the doc of [Channel] as well for the behaviors of it.
32pub struct Worker<T> {
33    /// Shared [Channel] with the worker threads.
34    pub channel: Arc<Channel<T>>,
35    handles: Vec<thread::JoinHandle<()>>,
36}
37
38impl<T: Task + Send + 'static> Worker<T> {
39    /// Spawns the numbers of worker threads.
40    pub fn new(len_channel: usize, n_workers: usize) -> Self {
41        let channel = Arc::new(Channel::<T>::new(len_channel, n_workers));
42        let mut handles = Vec::with_capacity(n_workers);
43        for _ in 0..n_workers {
44            let context = channel.clone();
45            let handle = thread::spawn(move || {
46                Self::worker_thread(context);
47            });
48            handles.push(handle);
49        }
50        Self { channel, handles }
51    }
52
53    fn worker_thread(context: Arc<Channel<T>>) {
54        while let Some(task) = context.pop() {
55            task.execute();
56        }
57    }
58
59    /// Closes the channel and wait for worker threads shutdown.
60    ///
61    /// This also waits for all the tasks in the channel to be executed.
62    pub fn close(self) {
63        self.channel.close();
64        for handle in self.handles {
65            match handle.join() {
66                Ok(()) => {}
67                Err(e) => {
68                    error!("failed to wait for worker thread: {:?}", e);
69                }
70            }
71        }
72    }
73}
74
75/// MPMC (Multi Producers Multi Consumers) queue integrated with [Worker].
76///
77/// [Channel] offers [Channel::wait_complete()] to guarantee all the tasks are executed.
78///
79/// This only exposes methods for producers.
80pub struct Channel<T> {
81    state: Mutex<ChannelState<T>>,
82    consumer_wait: Condvar,
83    producer_wait: Condvar,
84    n_consumers: usize,
85}
86
87impl<T> Channel<T> {
88    fn new(len: usize, n_consumers: usize) -> Self {
89        Self {
90            state: Mutex::new(ChannelState::new(len)),
91            consumer_wait: Condvar::new(),
92            producer_wait: Condvar::new(),
93            n_consumers,
94        }
95    }
96
97    fn close(&self) {
98        let mut state = self.state.lock();
99        state.is_closed = true;
100        self.consumer_wait.notify_all();
101        self.producer_wait.notify_all();
102    }
103
104    /// Pops a task from the channel.
105    ///
106    /// If the queue is closed and also **empty**, this returns [None]. This returns all the tasks
107    /// in the queue even while this is closed.
108    #[inline]
109    fn pop(&self) -> Option<T> {
110        let mut state = self.state.lock();
111        loop {
112            let was_full = state.queue.len() == state.capacity;
113            if let Some(item) = state.queue.pop_front() {
114                if was_full {
115                    // notification for a producer waiting for `push()`.
116                    self.producer_wait.notify_one();
117                }
118                return Some(item);
119            } else {
120                if state.is_closed {
121                    return None;
122                }
123                state.n_waiting += 1;
124                if state.n_waiting == self.n_consumers {
125                    // notification for producers waiting for `wait_complete()`.
126                    self.producer_wait.notify_all();
127                }
128                state = self.consumer_wait.wait(state);
129                state.n_waiting -= 1;
130            }
131        }
132    }
133
134    /// Push a task.
135    ///
136    /// This blocks if the channel is full.
137    ///
138    /// If the channel is closed, this returns `false`.
139    pub fn push(&self, item: T) -> bool {
140        let mut state = self.state.lock();
141        // Wait until the queue has room to push a task.
142        while state.queue.len() == state.capacity {
143            if state.is_closed {
144                return false;
145            }
146            state = self.producer_wait.wait(state);
147        }
148        if state.is_closed {
149            return false;
150        }
151        state.queue.push_back(item);
152        self.consumer_wait.notify_one();
153        true
154    }
155
156    /// Wait until all the tasks have been executed.
157    ///
158    /// This guarantees that all the tasks in this channel are not only consumed but also executed.
159    pub fn wait_complete(&self) {
160        let mut state = self.state.lock();
161        while !(state.queue.is_empty() && state.n_waiting == self.n_consumers) {
162            state = self.producer_wait.wait(state);
163        }
164    }
165}
166
167struct ChannelState<T> {
168    queue: VecDeque<T>,
169    capacity: usize,
170    n_waiting: usize,
171    is_closed: bool,
172}
173
174impl<T> ChannelState<T> {
175    fn new(capacity: usize) -> Self {
176        Self {
177            queue: VecDeque::with_capacity(capacity),
178            capacity,
179            n_waiting: 0,
180            is_closed: false,
181        }
182    }
183}
184
185/// The event channel for background jobs.
186///
187/// This sends an abort request from the main thread to the job thread via atomic boolean flag.
188///
189/// This notifies the main thread that the job thread is completed via [Event].
190pub struct BackgroundJobControl {
191    event: Event,
192    abort_flag: AtomicBool,
193}
194
195impl BackgroundJobControl {
196    /// Creates [BackgroundJobControl].
197    pub fn new() -> anyhow::Result<Self> {
198        Ok(Self {
199            event: Event::new()?,
200            abort_flag: AtomicBool::new(false),
201        })
202    }
203
204    /// Creates [BackgroundJob].
205    pub fn new_job(&self) -> BackgroundJob<'_> {
206        BackgroundJob {
207            event: &self.event,
208            abort_flag: &self.abort_flag,
209        }
210    }
211
212    /// Abort the background job.
213    pub fn abort(&self) {
214        self.abort_flag.store(true, Ordering::Release);
215    }
216
217    /// Reset the internal state for a next job.
218    ///
219    /// Returns false, if the event is already reset and no event exists.
220    pub fn reset(&self) -> anyhow::Result<bool> {
221        self.abort_flag.store(false, Ordering::Release);
222        Ok(matches!(
223            self.event
224                .wait_timeout(Duration::ZERO)
225                .context("failed to get job complete event")?,
226            EventWaitResult::Signaled
227        ))
228    }
229
230    /// Returns the event to notify the completion of background job.
231    pub fn get_completion_event(&self) -> &Event {
232        &self.event
233    }
234}
235
236/// Background job context.
237///
238/// When dropped, this sends an event to the main thread via [Event].
239pub struct BackgroundJob<'a> {
240    event: &'a Event,
241    abort_flag: &'a AtomicBool,
242}
243
244impl BackgroundJob<'_> {
245    /// Returns whether the background job is aborted or not.
246    pub fn is_aborted(&self) -> bool {
247        self.abort_flag.load(Ordering::Acquire)
248    }
249}
250
251impl Drop for BackgroundJob<'_> {
252    fn drop(&mut self) {
253        self.event.signal().expect("send job complete event");
254    }
255}
256
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;

    /// Counters shared between a [FakeTask] and the test body.
    #[derive(Clone, Copy)]
    struct Context {
        /// Executions the test has allowed that have not happened yet.
        n_consume: usize,
        /// Executions that have completed.
        n_executed: usize,
    }

    /// A task whose `execute()` blocks until the test grants it a credit via `consume()`.
    struct FakeTask {
        context: Mutex<Context>,
        waker: Condvar,
    }

    impl FakeTask {
        fn new() -> Arc<Self> {
            let initial = Context {
                n_consume: 0,
                n_executed: 0,
            };
            Arc::new(Self {
                context: Mutex::new(initial),
                waker: Condvar::new(),
            })
        }

        /// Grants `count` more executions and wakes every blocked `execute()`.
        fn consume(&self, count: usize) {
            let mut counters = self.context.lock();
            counters.n_consume += count;
            self.waker.notify_all();
        }

        /// Number of executions completed so far.
        fn n_executed(&self) -> usize {
            let counters = self.context.lock();
            counters.n_executed
        }
    }

    impl Task for Arc<FakeTask> {
        /// Blocks until a credit is available, then spends it and records the execution.
        fn execute(self) {
            let mut counters = self.context.lock();
            while counters.n_consume == 0 {
                counters = self.waker.wait(counters);
            }
            counters.n_consume -= 1;
            counters.n_executed += 1;
        }
    }

    /// Joins `join_handle`, panicking if the thread has not finished within about
    /// `timeout_millis` milliseconds.
    fn wait_thread_with_timeout<T>(join_handle: thread::JoinHandle<T>, timeout_millis: u64) -> T {
        for _ in 0..timeout_millis {
            if join_handle.is_finished() {
                return join_handle.join().unwrap();
            }
            thread::sleep(Duration::from_millis(1));
        }
        panic!("thread join timeout");
    }

    /// Polls `f` every millisecond until it returns true or about `timeout_millis` milliseconds
    /// have passed. It returns silently on timeout; callers assert afterwards.
    fn poll_until_with_timeout<F>(f: F, timeout_millis: u64)
    where
        F: Fn() -> bool,
    {
        for _ in 0..timeout_millis {
            if f() {
                return;
            }
            thread::sleep(Duration::from_millis(1));
        }
    }

    #[test]
    fn test_worker() {
        let worker = Worker::new(2, 4);
        let task = FakeTask::new();
        let channel = worker.channel.clone();

        (0..4).for_each(|_| assert!(channel.push(task.clone())));

        assert_eq!(task.n_executed(), 0);
        task.consume(4);
        worker.channel.wait_complete();
        assert_eq!(task.n_executed(), 4);
        worker.close();
    }

    #[test]
    fn test_worker_push_after_close() {
        let worker = Worker::new(2, 4);
        let task = FakeTask::new();
        let channel = worker.channel.clone();

        worker.close();

        let pushed = channel.push(task);
        assert!(!pushed);
    }

    #[test]
    fn test_worker_push_block() {
        let worker = Worker::new(2, 4);
        let task = FakeTask::new();

        // Occupy all 4 workers and both queue slots. Push from another thread so that a
        // regression cannot block the test forever.
        let filler = {
            let channel = worker.channel.clone();
            let task = task.clone();
            thread::spawn(move || {
                for _ in 0..6 {
                    assert!(channel.push(task.clone()));
                }
            })
        };
        wait_thread_with_timeout(filler, 100);

        // The queue is full, so one more push must block...
        let blocked_push = {
            let channel = worker.channel.clone();
            let task = task.clone();
            thread::spawn(move || {
                assert!(channel.push(task));
            })
        };
        thread::sleep(Duration::from_millis(10));
        assert!(!blocked_push.is_finished());

        // ...until a worker finishes a task and frees a slot.
        task.consume(1);
        wait_thread_with_timeout(blocked_push, 100);

        task.consume(6);
        poll_until_with_timeout(|| task.n_executed() == 7, 100);
        assert_eq!(task.n_executed(), 7);
        worker.close();
    }

    #[test]
    fn test_worker_close_on_push_blocked() {
        let worker = Worker::new(2, 4);
        let task = FakeTask::new();

        // Occupy all 4 workers and both queue slots. Push from another thread so that a
        // regression cannot block the test forever.
        let filler = {
            let channel = worker.channel.clone();
            let task = task.clone();
            thread::spawn(move || {
                for _ in 0..6 {
                    assert!(channel.push(task.clone()));
                }
            })
        };
        wait_thread_with_timeout(filler, 100);

        let blocked_push = {
            let channel = worker.channel.clone();
            let task = task.clone();
            thread::spawn(move || channel.push(task))
        };
        // Give the pushing thread time to block on the full queue.
        thread::sleep(Duration::from_millis(10));
        // `close()` blocks until every queued task is executed, so run it on its own thread.
        let closer = thread::spawn(move || worker.close());
        // Closing the channel makes the blocked push give up.
        let push_result = wait_thread_with_timeout(blocked_push, 100);
        assert!(!push_result);

        // Let the workers drain the queue so that `close()` can finish.
        task.consume(6);
        wait_thread_with_timeout(closer, 100);
    }

    #[test]
    fn new_background_job_event() {
        let control = BackgroundJobControl::new();
        assert!(control.is_ok());
    }

    #[test]
    fn background_job_is_not_aborted_default() {
        let control = BackgroundJobControl::new().unwrap();
        let job = control.new_job();
        assert!(!job.is_aborted());
    }

    #[test]
    fn abort_background_job() {
        let control = BackgroundJobControl::new().unwrap();
        let job = control.new_job();
        control.abort();
        assert!(job.is_aborted());
    }

    #[test]
    fn reset_background_job() {
        let control = BackgroundJobControl::new().unwrap();
        control.abort();
        control.reset().unwrap();
        let job = control.new_job();
        assert!(!job.is_aborted());
    }

    #[test]
    fn reset_background_job_event() {
        let control = BackgroundJobControl::new().unwrap();
        drop(control.new_job());
        assert!(control.reset().unwrap());
    }

    #[test]
    fn reset_background_job_event_twice() {
        let control = BackgroundJobControl::new().unwrap();
        drop(control.new_job());
        control.reset().unwrap();
        assert!(!control.reset().unwrap());
    }

    #[test]
    fn reset_background_job_event_no_jobs() {
        let control = BackgroundJobControl::new().unwrap();
        assert!(!control.reset().unwrap());
    }
}