#![deny(missing_docs)]
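//! A multi-threaded worker pool for executing [`Task`]s on a fixed number of
//! threads, plus helpers for controlling abortable background jobs.
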
use std::collections::VecDeque;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::thread;
use std::time::Duration;
use anyhow::Context;
use base::error;
use base::Event;
use base::EventWaitResult;
use sync::Condvar;
use sync::Mutex;
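
/// A unit of work that can be executed on a worker thread.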
pub trait Task {
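/// Executes the task, consuming it.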
fn execute(self);
}
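
/// A pool of worker threads that execute [`Task`]s pushed to a shared [`Channel`].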
pub struct Worker<T> {
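/// The channel used to submit tasks to the worker threads.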
pub channel: Arc<Channel<T>>,
handles: Vec<thread::JoinHandle<()>>,
}
impl<T: Task + Send + 'static> Worker<T> {
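/// Spawns `n_workers` worker threads sharing a bounded [`Channel`] with capacity
/// `len_channel`, and returns the handle that owns them.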
pub fn new(len_channel: usize, n_workers: usize) -> Self {
let channel = Arc::new(Channel::<T>::new(len_channel, n_workers));
let mut handles = Vec::with_capacity(n_workers);
for _ in 0..n_workers {
let context = channel.clone();
let handle = thread::spawn(move || {
Self::worker_thread(context);
});
handles.push(handle);
}
Self { channel, handles }
}
fn worker_thread(context: Arc<Channel<T>>) {
while let Some(task) = context.pop() {
task.execute();
}
}
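
/// Closes the channel and joins all worker threads.
///
/// Tasks that are already queued are still executed before the threads exit.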
pub fn close(self) {
self.channel.close();
for handle in self.handles {
if let Err(e) = handle.join() {
error!("failed to wait for worker thread: {:?}", e);
}
}
}
}
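
/// A bounded multi-producer, multi-consumer queue used to send tasks to the worker threads.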
pub struct Channel<T> {
state: Mutex<ChannelState<T>>,
consumer_wait: Condvar,
producer_wait: Condvar,
n_consumers: usize,
}
impl<T> Channel<T> {
fn new(len: usize, n_consumers: usize) -> Self {
Self {
state: Mutex::new(ChannelState::new(len)),
consumer_wait: Condvar::new(),
producer_wait: Condvar::new(),
n_consumers,
}
}
fn close(&self) {
let mut state = self.state.lock();
state.is_closed = true;
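// Wake every blocked consumer and producer so they observe `is_closed`.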
self.consumer_wait.notify_all();
self.producer_wait.notify_all();
}
#[inline]
fn pop(&self) -> Option<T> {
let mut state = self.state.lock();
loop {
let was_full = state.queue.len() == state.capacity;
if let Some(item) = state.queue.pop_front() {
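// The pop just freed a slot in a previously full queue; wake one blocked producer.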
if was_full {
self.producer_wait.notify_one();
}
return Some(item);
} else {
if state.is_closed {
return None;
}
state.n_waiting += 1;
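// Every worker is now idle and the queue is empty; wake wait_complete() callers.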
if state.n_waiting == self.n_consumers {
self.producer_wait.notify_all();
}
state = self.consumer_wait.wait(state);
state.n_waiting -= 1;
}
}
}
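
/// Pushes an item into the queue, blocking while the queue is full.
///
/// Returns `false` if the channel has been closed and the item was not enqueued,
/// `true` otherwise.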
pub fn push(&self, item: T) -> bool {
let mut state = self.state.lock();
while state.queue.len() == state.capacity {
if state.is_closed {
return false;
}
state = self.producer_wait.wait(state);
}
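// The channel may have been closed while this producer was waiting; never
// enqueue new items after close.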
if state.is_closed {
return false;
}
state.queue.push_back(item);
self.consumer_wait.notify_one();
true
}
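
/// Blocks until the queue is empty and every worker thread is idle, waiting for the next task.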
pub fn wait_complete(&self) {
let mut state = self.state.lock();
while !(state.queue.is_empty() && state.n_waiting == self.n_consumers) {
state = self.producer_wait.wait(state);
}
}
}
struct ChannelState<T> {
queue: VecDeque<T>,
capacity: usize,
n_waiting: usize,
is_closed: bool,
}
impl<T> ChannelState<T> {
fn new(capacity: usize) -> Self {
Self {
queue: VecDeque::with_capacity(capacity),
capacity,
n_waiting: 0,
is_closed: false,
}
}
}
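
/// Controls background jobs: requests aborts and observes completion via an [`Event`].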
pub struct BackgroundJobControl {
event: Event,
abort_flag: AtomicBool,
}
impl BackgroundJobControl {
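/// Creates a new [`BackgroundJobControl`].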
pub fn new() -> anyhow::Result<Self> {
Ok(Self {
event: Event::new()?,
abort_flag: AtomicBool::new(false),
})
}
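
/// Creates a new [`BackgroundJob`]. The completion event is signaled when the
/// returned job is dropped.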
pub fn new_job(&self) -> BackgroundJob<'_> {
BackgroundJob {
event: &self.event,
abort_flag: &self.abort_flag,
}
}
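
/// Requests that the current background job abort; the job observes this via
/// [`BackgroundJob::is_aborted`].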
pub fn abort(&self) {
self.abort_flag.store(true, Ordering::Release);
}
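
/// Clears the abort flag and consumes any pending completion event.
///
/// Returns `true` if a job completed (i.e. the completion event was signaled)
/// since the last reset.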
pub fn reset(&self) -> anyhow::Result<bool> {
self.abort_flag.store(false, Ordering::Release);
Ok(matches!(
self.event
.wait_timeout(Duration::ZERO)
.context("failed to get job complete event")?,
EventWaitResult::Signaled
))
}
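
/// Returns the [`Event`] that is signaled whenever a [`BackgroundJob`] completes.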
pub fn get_completion_event(&self) -> &Event {
&self.event
}
}
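
/// Handle passed to a background job; the job checks it for abort requests and
/// signals the completion event when dropped.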
pub struct BackgroundJob<'a> {
event: &'a Event,
abort_flag: &'a AtomicBool,
}
impl BackgroundJob<'_> {
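/// Returns `true` if [`BackgroundJobControl::abort`] has been called.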
pub fn is_aborted(&self) -> bool {
self.abort_flag.load(Ordering::Acquire)
}
}
impl Drop for BackgroundJob<'_> {
fn drop(&mut self) {
self.event.signal().expect("failed to signal job completion event");
}
}
#[cfg(test)]
mod tests {
use std::time::Duration;
use super::*;
#[derive(Clone, Copy)]
struct Context {
n_consume: usize,
n_executed: usize,
}
struct FakeTask {
context: Mutex<Context>,
waker: Condvar,
}
impl FakeTask {
fn new() -> Arc<Self> {
Arc::new(Self {
context: Mutex::new(Context {
n_consume: 0,
n_executed: 0,
}),
waker: Condvar::new(),
})
}
fn consume(&self, count: usize) {
let mut context = self.context.lock();
context.n_consume += count;
self.waker.notify_all();
}
fn n_executed(&self) -> usize {
self.context.lock().n_executed
}
}
impl Task for Arc<FakeTask> {
fn execute(self) {
let mut context = self.context.lock();
while context.n_consume == 0 {
context = self.waker.wait(context);
}
context.n_consume -= 1;
context.n_executed += 1;
}
}
fn wait_thread_with_timeout<T>(join_handle: thread::JoinHandle<T>, timeout_millis: u64) -> T {
for _ in 0..timeout_millis {
if join_handle.is_finished() {
return join_handle.join().unwrap();
}
thread::sleep(Duration::from_millis(1));
}
panic!("thread join timeout");
}
fn poll_until_with_timeout<F>(f: F, timeout_millis: u64)
where
F: Fn() -> bool,
{
for _ in 0..timeout_millis {
if f() {
break;
}
thread::sleep(Duration::from_millis(1));
}
}
#[test]
fn test_worker() {
let worker = Worker::new(2, 4);
let task = FakeTask::new();
let channel = worker.channel.clone();
for _ in 0..4 {
assert!(channel.push(task.clone()));
}
assert_eq!(task.n_executed(), 0);
task.consume(4);
worker.channel.wait_complete();
assert_eq!(task.n_executed(), 4);
worker.close();
}
#[test]
fn test_worker_push_after_close() {
let worker = Worker::new(2, 4);
let task = FakeTask::new();
let channel = worker.channel.clone();
worker.close();
assert!(!channel.push(task));
}
#[test]
fn test_worker_push_block() {
let worker = Worker::new(2, 4);
let task = FakeTask::new();
let channel = worker.channel.clone();
let task_cloned = task.clone();
wait_thread_with_timeout(
thread::spawn(move || {
for _ in 0..6 {
assert!(channel.push(task_cloned.clone()));
}
}),
100,
);
let channel = worker.channel.clone();
let task_cloned = task.clone();
let push_thread = thread::spawn(move || {
assert!(channel.push(task_cloned));
});
thread::sleep(Duration::from_millis(10));
assert!(!push_thread.is_finished());
task.consume(1);
wait_thread_with_timeout(push_thread, 100);
task.consume(6);
#[allow(clippy::redundant_clone)]
let task_clone = task.clone();
poll_until_with_timeout(|| task_clone.n_executed() == 7, 100);
assert_eq!(task.n_executed(), 7);
worker.close();
}
#[test]
fn test_worker_close_on_push_blocked() {
let worker = Worker::new(2, 4);
let task = FakeTask::new();
let channel = worker.channel.clone();
let task_cloned = task.clone();
wait_thread_with_timeout(
thread::spawn(move || {
for _ in 0..6 {
assert!(channel.push(task_cloned.clone()));
}
}),
100,
);
let channel = worker.channel.clone();
let task_cloned = task.clone();
let push_thread = thread::spawn(move || channel.push(task_cloned));
thread::sleep(Duration::from_millis(10));
let close_thread = thread::spawn(move || {
worker.close();
});
let push_result = wait_thread_with_timeout(push_thread, 100);
assert!(!push_result);
task.consume(6);
wait_thread_with_timeout(close_thread, 100);
}
#[test]
fn new_background_job_event() {
assert!(BackgroundJobControl::new().is_ok());
}
#[test]
fn background_job_is_not_aborted_default() {
let event = BackgroundJobControl::new().unwrap();
let job = event.new_job();
assert!(!job.is_aborted());
}
#[test]
fn abort_background_job() {
let event = BackgroundJobControl::new().unwrap();
let job = event.new_job();
event.abort();
assert!(job.is_aborted());
}
#[test]
fn reset_background_job() {
let event = BackgroundJobControl::new().unwrap();
event.abort();
event.reset().unwrap();
let job = event.new_job();
assert!(!job.is_aborted());
}
#[test]
fn reset_background_job_event() {
let event = BackgroundJobControl::new().unwrap();
let job = event.new_job();
drop(job);
assert!(event.reset().unwrap());
}
#[test]
fn reset_background_job_event_twice() {
let event = BackgroundJobControl::new().unwrap();
let job = event.new_job();
drop(job);
event.reset().unwrap();
assert!(!event.reset().unwrap());
}
#[test]
fn reset_background_job_event_no_jobs() {
let event = BackgroundJobControl::new().unwrap();
assert!(!event.reset().unwrap());
}
}