cros_async/sync/cv.rs

// Copyright 2020 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

use std::cell::UnsafeCell;
use std::hint;
use std::mem;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::sync::Arc;

use super::super::sync::mu::RawRwLock;
use super::super::sync::mu::RwLockReadGuard;
use super::super::sync::mu::RwLockWriteGuard;
use super::super::sync::waiter::Kind as WaiterKind;
use super::super::sync::waiter::Waiter;
use super::super::sync::waiter::WaiterAdapter;
use super::super::sync::waiter::WaiterList;
use super::super::sync::waiter::WaitingFor;

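// Bits packed into `Condvar::state`: `SPINLOCK` serializes access to the `UnsafeCell` fields
// below, while `HAS_WAITERS` is a fast-path hint that the wait list is non-empty.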
const SPINLOCK: usize = 1 << 0;
const HAS_WAITERS: usize = 1 << 1;

/// A primitive to wait for an event to occur without consuming CPU time.
///
/// Condition variables are used in combination with a `RwLock` when a thread wants to wait for some
/// condition to become true. The condition must always be verified while holding the `RwLock` lock.
/// It is an error to use a `Condvar` with more than one `RwLock` while there are threads waiting on
/// the `Condvar`.
///
/// # Examples
///
/// ```edition2018
/// use std::sync::Arc;
/// use std::thread;
/// use std::sync::mpsc::channel;
///
/// use cros_async::{
///     block_on,
///     sync::{Condvar, RwLock},
/// };
///
/// const N: usize = 13;
///
/// // Spawn a few threads to increment a shared variable (non-atomically), and
/// // let all threads waiting on the Condvar know once the increments are done.
/// let data = Arc::new(RwLock::new(0));
/// let cv = Arc::new(Condvar::new());
///
/// for _ in 0..N {
///     let (data, cv) = (data.clone(), cv.clone());
///     thread::spawn(move || {
///         let mut data = block_on(data.lock());
///         *data += 1;
///         if *data == N {
///             cv.notify_all();
///         }
///     });
/// }
///
/// let mut val = block_on(data.lock());
/// while *val != N {
///     val = block_on(cv.wait(val));
/// }
/// ```
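// Note: the 128-byte alignment presumably keeps each `Condvar` on its own cache line(s) to avoid
// false sharing. `state` packs the SPINLOCK and HAS_WAITERS bits, `waiters` is the list of tasks
// blocked on this `Condvar`, and `mu` records the address of the `RawRwLock` currently associated
// with this `Condvar` (0 when unset). The `UnsafeCell` fields are only accessed while the spin
// lock bit is held.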
#[repr(align(128))]
pub struct Condvar {
    state: AtomicUsize,
    waiters: UnsafeCell<WaiterList>,
    mu: UnsafeCell<usize>,
}

impl Condvar {
    /// Creates a new condition variable ready to be waited on and notified.
    pub fn new() -> Condvar {
        Condvar {
            state: AtomicUsize::new(0),
            waiters: UnsafeCell::new(WaiterList::new(WaiterAdapter::new())),
            mu: UnsafeCell::new(0),
        }
    }

    /// Block the current thread until this `Condvar` is notified by another thread.
    ///
    /// This method will atomically unlock the `RwLock` held by `guard` and then block the current
    /// thread. Any call to `notify_one` or `notify_all` after the `RwLock` is unlocked may wake up
    /// the thread.
    ///
    /// To allow for more efficient scheduling, this call may return even when the programmer
    /// doesn't expect the thread to be woken. Therefore, calls to `wait()` should be used inside a
    /// loop that checks the predicate before continuing.
    ///
    /// Callers that are not in an async context may wish to use the `block_on` method to block the
    /// thread until the `Condvar` is notified.
    ///
    /// # Panics
    ///
    /// This method will panic if used with more than one `RwLock` at the same time.
    ///
    /// # Examples
    ///
    /// ```
    /// # use std::sync::Arc;
    /// # use std::thread;
    ///
    /// # use cros_async::{
    /// #     block_on,
    /// #     sync::{Condvar, RwLock},
    /// # };
    ///
    /// # let mu = Arc::new(RwLock::new(false));
    /// # let cv = Arc::new(Condvar::new());
    /// # let (mu2, cv2) = (mu.clone(), cv.clone());
    ///
    /// # let t = thread::spawn(move || {
    /// #     *block_on(mu2.lock()) = true;
    /// #     cv2.notify_all();
    /// # });
    ///
    /// let mut ready = block_on(mu.lock());
    /// while !*ready {
    ///     ready = block_on(cv.wait(ready));
    /// }
    ///
    /// # t.join().expect("failed to join thread");
    /// ```
    // Clippy doesn't like the lifetime parameters here but doing what it suggests leads to code
    // that doesn't compile.
    #[allow(clippy::needless_lifetimes)]
    pub async fn wait<'g, T>(&self, guard: RwLockWriteGuard<'g, T>) -> RwLockWriteGuard<'g, T> {
        let waiter = Arc::new(Waiter::new(
            WaiterKind::Exclusive,
            cancel_waiter,
            self as *const Condvar as usize,
            WaitingFor::Condvar,
        ));

        self.add_waiter(waiter.clone(), guard.as_raw_rwlock());

        // Get a reference to the rwlock and then drop the lock.
        let mu = guard.into_inner();

        // Wait to be woken up.
        waiter.wait().await;

        // Now re-acquire the lock.
        mu.lock_from_cv().await
    }

    /// Like `wait()` but takes and returns a `RwLockReadGuard` instead.
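    ///
    /// # Examples
    ///
    /// A minimal illustrative sketch (not part of the original docs), mirroring the `wait`
    /// example above but holding a read guard:
    ///
    /// ```
    /// # use std::sync::Arc;
    /// # use std::thread;
    /// # use cros_async::{
    /// #     block_on,
    /// #     sync::{Condvar, RwLock},
    /// # };
    /// # let mu = Arc::new(RwLock::new(false));
    /// # let cv = Arc::new(Condvar::new());
    /// # let (mu2, cv2) = (mu.clone(), cv.clone());
    /// # let t = thread::spawn(move || {
    /// #     *block_on(mu2.lock()) = true;
    /// #     cv2.notify_all();
    /// # });
    /// let mut ready = block_on(mu.read_lock());
    /// while !*ready {
    ///     ready = block_on(cv.wait_read(ready));
    /// }
    /// # t.join().expect("failed to join thread");
    /// ```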
    // Clippy doesn't like the lifetime parameters here but doing what it suggests leads to code
    // that doesn't compile.
    #[allow(clippy::needless_lifetimes)]
    pub async fn wait_read<'g, T>(&self, guard: RwLockReadGuard<'g, T>) -> RwLockReadGuard<'g, T> {
        let waiter = Arc::new(Waiter::new(
            WaiterKind::Shared,
            cancel_waiter,
            self as *const Condvar as usize,
            WaitingFor::Condvar,
        ));

        self.add_waiter(waiter.clone(), guard.as_raw_rwlock());

        // Get a reference to the rwlock and then drop the lock.
        let mu = guard.into_inner();

        // Wait to be woken up.
        waiter.wait().await;

        // Now re-acquire the lock.
        mu.read_lock_from_cv().await
    }

    fn add_waiter(&self, waiter: Arc<Waiter>, raw_rwlock: &RawRwLock) {
        // Acquire the spin lock.
        let mut oldstate = self.state.load(Ordering::Relaxed);
        while (oldstate & SPINLOCK) != 0
            || self
                .state
                .compare_exchange_weak(
                    oldstate,
                    oldstate | SPINLOCK | HAS_WAITERS,
                    Ordering::Acquire,
                    Ordering::Relaxed,
                )
                .is_err()
        {
            hint::spin_loop();
            oldstate = self.state.load(Ordering::Relaxed);
        }

        // SAFETY:
        // Safe because the spin lock guarantees exclusive access and the reference does not escape
        // this function.
        let mu = unsafe { &mut *self.mu.get() };
        let muptr = raw_rwlock as *const RawRwLock as usize;

        match *mu {
            0 => *mu = muptr,
            p if p == muptr => {}
            _ => panic!("Attempting to use Condvar with more than one RwLock at the same time"),
        }

        // SAFETY:
        // Safe because the spin lock guarantees exclusive access.
        unsafe { (*self.waiters.get()).push_back(waiter) };

        // Release the spin lock. Use a direct store here because no other thread can modify
        // `self.state` while we hold the spin lock. Keep the `HAS_WAITERS` bit that we set earlier
        // because we just added a waiter.
        self.state.store(HAS_WAITERS, Ordering::Release);
    }

    /// Notify at most one thread currently waiting on the `Condvar`.
    ///
    /// If there is a thread currently waiting on the `Condvar` it will be woken up from its call to
    /// `wait`.
    ///
    /// The condition protected by the `RwLock` associated with this `Condvar` should be updated
    /// while holding the lock before calling this method; it is inherently racy to call
    /// `notify_one` or `notify_all` without first acquiring the `RwLock` lock.
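    ///
    /// # Examples
    ///
    /// A minimal illustrative sketch (not part of the original docs): update the condition while
    /// holding the lock, then notify a single waiter.
    ///
    /// ```
    /// # use std::sync::Arc;
    /// # use std::thread;
    /// # use cros_async::{
    /// #     block_on,
    /// #     sync::{Condvar, RwLock},
    /// # };
    /// let mu = Arc::new(RwLock::new(false));
    /// let cv = Arc::new(Condvar::new());
    /// # let (mu2, cv2) = (mu.clone(), cv.clone());
    /// # let t = thread::spawn(move || {
    /// #     let mut ready = block_on(mu2.lock());
    /// #     while !*ready {
    /// #         ready = block_on(cv2.wait(ready));
    /// #     }
    /// # });
    /// *block_on(mu.lock()) = true;
    /// cv.notify_one();
    /// # t.join().expect("failed to join thread");
    /// ```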
    pub fn notify_one(&self) {
        let mut oldstate = self.state.load(Ordering::Relaxed);
        if (oldstate & HAS_WAITERS) == 0 {
            // No waiters.
            return;
        }

        while (oldstate & SPINLOCK) != 0
            || self
                .state
                .compare_exchange_weak(
                    oldstate,
                    oldstate | SPINLOCK,
                    Ordering::Acquire,
                    Ordering::Relaxed,
                )
                .is_err()
        {
            hint::spin_loop();
            oldstate = self.state.load(Ordering::Relaxed);
        }

        // SAFETY:
        // Safe because the spin lock guarantees exclusive access and the reference does not escape
        // this function.
        let waiters = unsafe { &mut *self.waiters.get() };
        let wake_list = get_wake_list(waiters);

        let newstate = if waiters.is_empty() {
            // SAFETY:
            // Also clear the rwlock associated with this Condvar since there are no longer any
            // waiters. Safe because the spin lock guarantees exclusive access.
            unsafe { *self.mu.get() = 0 };

            // We are releasing the spin lock and there are no more waiters so we can clear all bits
            // in `self.state`.
            0
        } else {
            // There are still waiters so we need to keep the HAS_WAITERS bit in the state.
            HAS_WAITERS
        };

        // Release the spin lock.
        self.state.store(newstate, Ordering::Release);

        // Now wake any waiters in the wake list.
        for w in wake_list {
            w.wake();
        }
    }

    /// Notify all threads currently waiting on the `Condvar`.
    ///
    /// All threads currently waiting on the `Condvar` will be woken up from their call to `wait`.
    ///
    /// The condition protected by the `RwLock` associated with this `Condvar` should be updated
    /// while holding the lock before calling this method; it is inherently racy to call
    /// `notify_one` or `notify_all` without first acquiring the `RwLock` lock.
    pub fn notify_all(&self) {
        let mut oldstate = self.state.load(Ordering::Relaxed);
        if (oldstate & HAS_WAITERS) == 0 {
            // No waiters.
            return;
        }

        while (oldstate & SPINLOCK) != 0
            || self
                .state
                .compare_exchange_weak(
                    oldstate,
                    oldstate | SPINLOCK,
                    Ordering::Acquire,
                    Ordering::Relaxed,
                )
                .is_err()
        {
            hint::spin_loop();
            oldstate = self.state.load(Ordering::Relaxed);
        }

        // SAFETY:
        // Safe because the spin lock guarantees exclusive access to `self.waiters`.
        let wake_list = unsafe { (*self.waiters.get()).take() };

        // SAFETY:
        // Clear the rwlock associated with this Condvar since there are no longer any waiters. Safe
        // because the spin lock guarantees exclusive access.
        unsafe { *self.mu.get() = 0 };

        // Mark the waiters we took as no longer waiting for the Condvar.
        for w in &wake_list {
            w.set_waiting_for(WaitingFor::None);
        }

        // Release the spin lock. We can clear all bits in the state since we took all the waiters.
        self.state.store(0, Ordering::Release);

        // Now wake any waiters in the wake list.
        for w in wake_list {
            w.wake();
        }
    }

    fn cancel_waiter(&self, waiter: &Waiter, wake_next: bool) {
        let mut oldstate = self.state.load(Ordering::Relaxed);
        while oldstate & SPINLOCK != 0
            || self
                .state
                .compare_exchange_weak(
                    oldstate,
                    oldstate | SPINLOCK,
                    Ordering::Acquire,
                    Ordering::Relaxed,
                )
                .is_err()
        {
            hint::spin_loop();
            oldstate = self.state.load(Ordering::Relaxed);
        }

        // SAFETY:
        // Safe because the spin lock provides exclusive access and the reference does not escape
        // this function.
        let waiters = unsafe { &mut *self.waiters.get() };

        let waiting_for = waiter.is_waiting_for();
        // Don't drop the old waiter now as we're still holding the spin lock.
        let old_waiter = if waiter.is_linked() && waiting_for == WaitingFor::Condvar {
            // SAFETY:
            // Safe because we know that the waiter is still linked and is waiting for the Condvar,
            // which guarantees that it is still in `self.waiters`.
            let mut cursor = unsafe { waiters.cursor_mut_from_ptr(waiter as *const Waiter) };
            cursor.remove()
        } else {
            None
        };

        let wake_list = if wake_next || waiting_for == WaitingFor::None {
            // Either the waiter was already woken or it's been removed from the condvar's waiter
            // list and is going to be woken. Either way, we need to wake up another thread.
            get_wake_list(waiters)
        } else {
            WaiterList::new(WaiterAdapter::new())
        };

        let set_on_release = if waiters.is_empty() {
            // SAFETY:
            // Clear the rwlock associated with this Condvar since there are no longer any waiters.
            // Safe because the spin lock guarantees exclusive access.
            unsafe { *self.mu.get() = 0 };

            0
        } else {
            HAS_WAITERS
        };

        self.state.store(set_on_release, Ordering::Release);

        // Now wake any waiters still left in the wake list.
        for w in wake_list {
            w.wake();
        }

        mem::drop(old_waiter);
    }
}

// TODO(b/315998194): Add safety comment
#[allow(clippy::undocumented_unsafe_blocks)]
unsafe impl Send for Condvar {}
// TODO(b/315998194): Add safety comment
#[allow(clippy::undocumented_unsafe_blocks)]
unsafe impl Sync for Condvar {}

impl Default for Condvar {
    fn default() -> Self {
        Self::new()
    }
}

// Scan `waiters` and return all waiters that should be woken up.
//
// If the first waiter is trying to acquire a shared lock, then all waiters in the list that are
// waiting for a shared lock are also woken up. In addition one writer is woken up, if possible.
//
// If the first waiter is trying to acquire an exclusive lock, then only that waiter is returned and
// the rest of the list is not scanned.
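//
// Illustrative example (not from the original comments): for a wait list of
// [Shared, Shared, Exclusive, Shared, Exclusive], the first four waiters are returned (all of the
// readers plus the first writer) while the trailing Exclusive waiter stays queued; for a list of
// [Exclusive, Shared, ...], only the leading Exclusive waiter is returned.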
fn get_wake_list(waiters: &mut WaiterList) -> WaiterList {
    let mut to_wake = WaiterList::new(WaiterAdapter::new());
    let mut cursor = waiters.front_mut();

    let mut waking_readers = false;
    let mut all_readers = true;
    while let Some(w) = cursor.get() {
        match w.kind() {
            WaiterKind::Exclusive if !waking_readers => {
                // This is the first waiter and it's a writer. No need to check the other waiters.
                // Also mark the waiter as having been removed from the Condvar's waiter list.
                let waiter = cursor.remove().unwrap();
                waiter.set_waiting_for(WaitingFor::None);
                to_wake.push_back(waiter);
                break;
            }

            WaiterKind::Shared => {
                // This is a reader and the first waiter in the list was not a writer so wake up all
                // the readers in the wait list.
                let waiter = cursor.remove().unwrap();
                waiter.set_waiting_for(WaitingFor::None);
                to_wake.push_back(waiter);
                waking_readers = true;
            }

            WaiterKind::Exclusive => {
                debug_assert!(waking_readers);
                if all_readers {
                    // We are waking readers but we need to ensure that at least one writer is woken
                    // up. Since we haven't yet woken up a writer, wake up this one.
                    let waiter = cursor.remove().unwrap();
                    waiter.set_waiting_for(WaitingFor::None);
                    to_wake.push_back(waiter);
                    all_readers = false;
                } else {
                    // We are waking readers and have already woken one writer. Skip this one.
                    cursor.move_next();
                }
            }
        }
    }

    to_wake
}

fn cancel_waiter(cv: usize, waiter: &Waiter, wake_next: bool) {
    let condvar = cv as *const Condvar;

    // SAFETY:
    // Safe because the thread that owns the waiter being canceled must also own a reference to the
    // Condvar, which guarantees that this pointer is valid.
    unsafe { (*condvar).cancel_waiter(waiter, wake_next) }
}

// TODO(b/194338842): Fix tests for windows
#[cfg(any(target_os = "android", target_os = "linux"))]
#[cfg(test)]
mod test {
    use std::future::Future;
    use std::mem;
    use std::ptr;
    use std::rc::Rc;
    use std::sync::mpsc::channel;
    use std::sync::mpsc::Sender;
    use std::sync::Arc;
    use std::task::Context;
    use std::task::Poll;
    use std::thread;
    use std::thread::JoinHandle;
    use std::time::Duration;

    use futures::channel::oneshot;
    use futures::select;
    use futures::task::waker_ref;
    use futures::task::ArcWake;
    use futures::FutureExt;
    use futures_executor::LocalPool;
    use futures_executor::LocalSpawner;
    use futures_executor::ThreadPool;
    use futures_util::task::LocalSpawnExt;

    use super::super::super::block_on;
    use super::super::super::sync::RwLock;
    use super::*;

    // Dummy waker used when we want to manually drive futures.
    struct TestWaker;
    impl ArcWake for TestWaker {
        fn wake_by_ref(_arc_self: &Arc<Self>) {}
    }

    #[test]
    fn smoke() {
        let cv = Condvar::new();
        cv.notify_one();
        cv.notify_all();
    }

    #[test]
    fn notify_one() {
        let mu = Arc::new(RwLock::new(()));
        let cv = Arc::new(Condvar::new());

        let mu2 = mu.clone();
        let cv2 = cv.clone();

        let guard = block_on(mu.lock());
        thread::spawn(move || {
            let _g = block_on(mu2.lock());
            cv2.notify_one();
        });

        let guard = block_on(cv.wait(guard));
        mem::drop(guard);
    }

    #[test]
    fn multi_rwlock() {
        const NUM_THREADS: usize = 5;

        let mu = Arc::new(RwLock::new(false));
        let cv = Arc::new(Condvar::new());

        let mut threads = Vec::with_capacity(NUM_THREADS);
        for _ in 0..NUM_THREADS {
            let mu = mu.clone();
            let cv = cv.clone();

            threads.push(thread::spawn(move || {
                let mut ready = block_on(mu.lock());
                while !*ready {
                    ready = block_on(cv.wait(ready));
                }
            }));
        }

        let mut g = block_on(mu.lock());
        *g = true;
        mem::drop(g);
        cv.notify_all();

        threads
            .into_iter()
            .try_for_each(JoinHandle::join)
            .expect("Failed to join threads");

        // Now use the Condvar with a different rwlock.
        let alt_mu = Arc::new(RwLock::new(None));
        let alt_mu2 = alt_mu.clone();
        let cv2 = cv.clone();
        let handle = thread::spawn(move || {
            let mut g = block_on(alt_mu2.lock());
            while g.is_none() {
                g = block_on(cv2.wait(g));
            }
        });

        let mut alt_g = block_on(alt_mu.lock());
        *alt_g = Some(());
        mem::drop(alt_g);
        cv.notify_all();

        handle
            .join()
            .expect("Failed to join thread alternate rwlock");
    }

    #[test]
    fn notify_one_single_thread_async() {
        async fn notify(mu: Rc<RwLock<()>>, cv: Rc<Condvar>) {
            let _g = mu.lock().await;
            cv.notify_one();
        }

        async fn wait(mu: Rc<RwLock<()>>, cv: Rc<Condvar>, spawner: LocalSpawner) {
            let mu2 = Rc::clone(&mu);
            let cv2 = Rc::clone(&cv);

            let g = mu.lock().await;
            // Has to be spawned _after_ acquiring the lock to prevent a race
            // where the notify happens before the waiter has acquired the lock.
            spawner
                .spawn_local(notify(mu2, cv2))
                .expect("Failed to spawn `notify` task");
            let _g = cv.wait(g).await;
        }

        let mut ex = LocalPool::new();
        let spawner = ex.spawner();

        let mu = Rc::new(RwLock::new(()));
        let cv = Rc::new(Condvar::new());

        spawner
            .spawn_local(wait(mu, cv, spawner.clone()))
            .expect("Failed to spawn `wait` task");

        ex.run();
    }

    #[test]
    fn notify_one_multi_thread_async() {
        async fn notify(mu: Arc<RwLock<()>>, cv: Arc<Condvar>) {
            let _g = mu.lock().await;
            cv.notify_one();
        }

        async fn wait(mu: Arc<RwLock<()>>, cv: Arc<Condvar>, tx: Sender<()>, pool: ThreadPool) {
            let mu2 = Arc::clone(&mu);
            let cv2 = Arc::clone(&cv);

            let g = mu.lock().await;
            // Has to be spawned _after_ acquiring the lock to prevent a race
            // where the notify happens before the waiter has acquired the lock.
            pool.spawn_ok(notify(mu2, cv2));
            let _g = cv.wait(g).await;

            tx.send(()).expect("Failed to send completion notification");
        }

        let ex = ThreadPool::new().expect("Failed to create ThreadPool");

        let mu = Arc::new(RwLock::new(()));
        let cv = Arc::new(Condvar::new());

        let (tx, rx) = channel();
        ex.spawn_ok(wait(mu, cv, tx, ex.clone()));

        rx.recv_timeout(Duration::from_secs(5))
            .expect("Failed to receive completion notification");
    }

    #[test]
    fn notify_one_with_cancel() {
        const TASKS: usize = 17;
        const OBSERVERS: usize = 7;
        const ITERATIONS: usize = 103;

        async fn observe(mu: &Arc<RwLock<usize>>, cv: &Arc<Condvar>) {
            let mut count = mu.read_lock().await;
            while *count == 0 {
                count = cv.wait_read(count).await;
            }
            // SAFETY: Safe because `count` points to valid, properly aligned memory.
            let _ = unsafe { ptr::read_volatile(&*count as *const usize) };
        }

        async fn decrement(mu: &Arc<RwLock<usize>>, cv: &Arc<Condvar>) {
            let mut count = mu.lock().await;
            while *count == 0 {
                count = cv.wait(count).await;
            }
            *count -= 1;
        }

        async fn increment(mu: Arc<RwLock<usize>>, cv: Arc<Condvar>, done: Sender<()>) {
            for _ in 0..TASKS * OBSERVERS * ITERATIONS {
                *mu.lock().await += 1;
                cv.notify_one();
            }

            done.send(()).expect("Failed to send completion message");
        }

        async fn observe_either(
            mu: Arc<RwLock<usize>>,
            cv: Arc<Condvar>,
            alt_mu: Arc<RwLock<usize>>,
            alt_cv: Arc<Condvar>,
            done: Sender<()>,
        ) {
            for _ in 0..ITERATIONS {
                select! {
                    () = observe(&mu, &cv).fuse() => {},
                    () = observe(&alt_mu, &alt_cv).fuse() => {},
                }
            }

            done.send(()).expect("Failed to send completion message");
        }

        async fn decrement_either(
            mu: Arc<RwLock<usize>>,
            cv: Arc<Condvar>,
            alt_mu: Arc<RwLock<usize>>,
            alt_cv: Arc<Condvar>,
            done: Sender<()>,
        ) {
            for _ in 0..ITERATIONS {
                select! {
                    () = decrement(&mu, &cv).fuse() => {},
                    () = decrement(&alt_mu, &alt_cv).fuse() => {},
                }
            }

            done.send(()).expect("Failed to send completion message");
        }

        let ex = ThreadPool::new().expect("Failed to create ThreadPool");

        let mu = Arc::new(RwLock::new(0usize));
        let alt_mu = Arc::new(RwLock::new(0usize));

        let cv = Arc::new(Condvar::new());
        let alt_cv = Arc::new(Condvar::new());

        let (tx, rx) = channel();
        for _ in 0..TASKS {
            ex.spawn_ok(decrement_either(
                Arc::clone(&mu),
                Arc::clone(&cv),
                Arc::clone(&alt_mu),
                Arc::clone(&alt_cv),
                tx.clone(),
            ));
        }

        for _ in 0..OBSERVERS {
            ex.spawn_ok(observe_either(
                Arc::clone(&mu),
                Arc::clone(&cv),
                Arc::clone(&alt_mu),
                Arc::clone(&alt_cv),
                tx.clone(),
            ));
        }

        ex.spawn_ok(increment(Arc::clone(&mu), Arc::clone(&cv), tx.clone()));
        ex.spawn_ok(increment(Arc::clone(&alt_mu), Arc::clone(&alt_cv), tx));

        for _ in 0..TASKS + OBSERVERS + 2 {
            if let Err(e) = rx.recv_timeout(Duration::from_secs(20)) {
                panic!("Error while waiting for threads to complete: {e}");
            }
        }

        assert_eq!(
            *block_on(mu.read_lock()) + *block_on(alt_mu.read_lock()),
            (TASKS * OBSERVERS * ITERATIONS * 2) - (TASKS * ITERATIONS)
        );
        assert_eq!(cv.state.load(Ordering::Relaxed), 0);
        assert_eq!(alt_cv.state.load(Ordering::Relaxed), 0);
    }

    #[test]
    fn notify_all_with_cancel() {
        const TASKS: usize = 17;
        const ITERATIONS: usize = 103;

        async fn decrement(mu: &Arc<RwLock<usize>>, cv: &Arc<Condvar>) {
            let mut count = mu.lock().await;
            while *count == 0 {
                count = cv.wait(count).await;
            }
            *count -= 1;
        }

        async fn increment(mu: Arc<RwLock<usize>>, cv: Arc<Condvar>, done: Sender<()>) {
            for _ in 0..TASKS * ITERATIONS {
                *mu.lock().await += 1;
                cv.notify_all();
            }

            done.send(()).expect("Failed to send completion message");
        }

        async fn decrement_either(
            mu: Arc<RwLock<usize>>,
            cv: Arc<Condvar>,
            alt_mu: Arc<RwLock<usize>>,
            alt_cv: Arc<Condvar>,
            done: Sender<()>,
        ) {
            for _ in 0..ITERATIONS {
                select! {
                    () = decrement(&mu, &cv).fuse() => {},
                    () = decrement(&alt_mu, &alt_cv).fuse() => {},
                }
            }

            done.send(()).expect("Failed to send completion message");
        }

        let ex = ThreadPool::new().expect("Failed to create ThreadPool");

        let mu = Arc::new(RwLock::new(0usize));
        let alt_mu = Arc::new(RwLock::new(0usize));

        let cv = Arc::new(Condvar::new());
        let alt_cv = Arc::new(Condvar::new());

        let (tx, rx) = channel();
        for _ in 0..TASKS {
            ex.spawn_ok(decrement_either(
                Arc::clone(&mu),
                Arc::clone(&cv),
                Arc::clone(&alt_mu),
                Arc::clone(&alt_cv),
                tx.clone(),
            ));
        }

        ex.spawn_ok(increment(Arc::clone(&mu), Arc::clone(&cv), tx.clone()));
        ex.spawn_ok(increment(Arc::clone(&alt_mu), Arc::clone(&alt_cv), tx));

        for _ in 0..TASKS + 2 {
            if let Err(e) = rx.recv_timeout(Duration::from_secs(10)) {
                panic!("Error while waiting for threads to complete: {e}");
            }
        }

        assert_eq!(
            *block_on(mu.read_lock()) + *block_on(alt_mu.read_lock()),
            TASKS * ITERATIONS
        );
        assert_eq!(cv.state.load(Ordering::Relaxed), 0);
        assert_eq!(alt_cv.state.load(Ordering::Relaxed), 0);
    }

    #[test]
    fn notify_all() {
        const THREADS: usize = 13;

        let mu = Arc::new(RwLock::new(0));
        let cv = Arc::new(Condvar::new());
        let (tx, rx) = channel();

        let mut threads = Vec::with_capacity(THREADS);
        for _ in 0..THREADS {
            let mu2 = mu.clone();
            let cv2 = cv.clone();
            let tx2 = tx.clone();

            threads.push(thread::spawn(move || {
                let mut count = block_on(mu2.lock());
                *count += 1;
                if *count == THREADS {
                    tx2.send(()).unwrap();
                }

                while *count != 0 {
                    count = block_on(cv2.wait(count));
                }
            }));
        }

        mem::drop(tx);

        // Wait till all threads have started.
        rx.recv_timeout(Duration::from_secs(5)).unwrap();

        let mut count = block_on(mu.lock());
        *count = 0;
        mem::drop(count);
        cv.notify_all();

        for t in threads {
            t.join().unwrap();
        }
    }

    #[test]
    fn notify_all_single_thread_async() {
        const TASKS: usize = 13;

        async fn reset(mu: Rc<RwLock<usize>>, cv: Rc<Condvar>) {
            let mut count = mu.lock().await;
            *count = 0;
            cv.notify_all();
        }

        async fn watcher(mu: Rc<RwLock<usize>>, cv: Rc<Condvar>, spawner: LocalSpawner) {
            let mut count = mu.lock().await;
            *count += 1;
            if *count == TASKS {
                spawner
                    .spawn_local(reset(mu.clone(), cv.clone()))
                    .expect("Failed to spawn reset task");
            }

            while *count != 0 {
                count = cv.wait(count).await;
            }
        }

        let mut ex = LocalPool::new();
        let spawner = ex.spawner();

        let mu = Rc::new(RwLock::new(0));
        let cv = Rc::new(Condvar::new());

        for _ in 0..TASKS {
            spawner
                .spawn_local(watcher(mu.clone(), cv.clone(), spawner.clone()))
                .expect("Failed to spawn watcher task");
        }

        ex.run();
    }

    #[test]
    fn notify_all_multi_thread_async() {
        const TASKS: usize = 13;

        async fn reset(mu: Arc<RwLock<usize>>, cv: Arc<Condvar>) {
            let mut count = mu.lock().await;
            *count = 0;
            cv.notify_all();
        }

        async fn watcher(
            mu: Arc<RwLock<usize>>,
            cv: Arc<Condvar>,
            pool: ThreadPool,
            tx: Sender<()>,
        ) {
            let mut count = mu.lock().await;
            *count += 1;
            if *count == TASKS {
                pool.spawn_ok(reset(mu.clone(), cv.clone()));
            }

            while *count != 0 {
                count = cv.wait(count).await;
            }

            tx.send(()).expect("Failed to send completion notification");
        }

        let pool = ThreadPool::new().expect("Failed to create ThreadPool");

        let mu = Arc::new(RwLock::new(0));
        let cv = Arc::new(Condvar::new());

        let (tx, rx) = channel();
        for _ in 0..TASKS {
            pool.spawn_ok(watcher(mu.clone(), cv.clone(), pool.clone(), tx.clone()));
        }

        for _ in 0..TASKS {
            rx.recv_timeout(Duration::from_secs(5))
                .expect("Failed to receive completion notification");
        }
    }

    #[test]
    fn wake_all_readers() {
        async fn read(mu: Arc<RwLock<bool>>, cv: Arc<Condvar>) {
            let mut ready = mu.read_lock().await;
            while !*ready {
                ready = cv.wait_read(ready).await;
            }
        }

        let mu = Arc::new(RwLock::new(false));
        let cv = Arc::new(Condvar::new());
        let mut readers = [
            Box::pin(read(mu.clone(), cv.clone())),
            Box::pin(read(mu.clone(), cv.clone())),
            Box::pin(read(mu.clone(), cv.clone())),
            Box::pin(read(mu.clone(), cv.clone())),
        ];

        let arc_waker = Arc::new(TestWaker);
        let waker = waker_ref(&arc_waker);
        let mut cx = Context::from_waker(&waker);

        // First have all the readers wait on the Condvar.
        for r in &mut readers {
            if let Poll::Ready(()) = r.as_mut().poll(&mut cx) {
                panic!("reader unexpectedly ready");
            }
        }

        assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS);

        // Now make the condition true and notify the condvar. Even though we will call notify_one,
        // all the readers should be woken up.
        *block_on(mu.lock()) = true;
        cv.notify_one();

        assert_eq!(cv.state.load(Ordering::Relaxed), 0);

        // All readers should now be able to complete.
        for r in &mut readers {
            if r.as_mut().poll(&mut cx).is_pending() {
                panic!("reader unable to complete");
            }
        }
    }

    #[test]
    fn cancel_before_notify() {
        async fn dec(mu: Arc<RwLock<usize>>, cv: Arc<Condvar>) {
            let mut count = mu.lock().await;

            while *count == 0 {
                count = cv.wait(count).await;
            }

            *count -= 1;
        }

        let mu = Arc::new(RwLock::new(0));
        let cv = Arc::new(Condvar::new());

        let arc_waker = Arc::new(TestWaker);
        let waker = waker_ref(&arc_waker);
        let mut cx = Context::from_waker(&waker);

        let mut fut1 = Box::pin(dec(mu.clone(), cv.clone()));
        let mut fut2 = Box::pin(dec(mu.clone(), cv.clone()));

        if let Poll::Ready(()) = fut1.as_mut().poll(&mut cx) {
            panic!("future unexpectedly ready");
        }
        if let Poll::Ready(()) = fut2.as_mut().poll(&mut cx) {
            panic!("future unexpectedly ready");
        }
        assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS);

        *block_on(mu.lock()) = 2;
        // Drop fut1 before notifying the cv.
        mem::drop(fut1);
        cv.notify_one();

        // fut2 should now be ready to complete.
        assert_eq!(cv.state.load(Ordering::Relaxed), 0);

        if fut2.as_mut().poll(&mut cx).is_pending() {
            panic!("future unable to complete");
        }

        assert_eq!(*block_on(mu.lock()), 1);
    }

    #[test]
    fn cancel_after_notify_one() {
        async fn dec(mu: Arc<RwLock<usize>>, cv: Arc<Condvar>) {
            let mut count = mu.lock().await;

            while *count == 0 {
                count = cv.wait(count).await;
            }

            *count -= 1;
        }

        let mu = Arc::new(RwLock::new(0));
        let cv = Arc::new(Condvar::new());

        let arc_waker = Arc::new(TestWaker);
        let waker = waker_ref(&arc_waker);
        let mut cx = Context::from_waker(&waker);

        let mut fut1 = Box::pin(dec(mu.clone(), cv.clone()));
        let mut fut2 = Box::pin(dec(mu.clone(), cv.clone()));

        if let Poll::Ready(()) = fut1.as_mut().poll(&mut cx) {
            panic!("future unexpectedly ready");
        }
        if let Poll::Ready(()) = fut2.as_mut().poll(&mut cx) {
            panic!("future unexpectedly ready");
        }
        assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS);

        *block_on(mu.lock()) = 2;
        cv.notify_one();

        // fut1 should now be ready to complete. Drop it before polling. This should wake up fut2.
        mem::drop(fut1);
        assert_eq!(cv.state.load(Ordering::Relaxed), 0);

        if fut2.as_mut().poll(&mut cx).is_pending() {
            panic!("future unable to complete");
        }

        assert_eq!(*block_on(mu.lock()), 1);
    }

    #[test]
    fn cancel_after_notify_all() {
        async fn dec(mu: Arc<RwLock<usize>>, cv: Arc<Condvar>) {
            let mut count = mu.lock().await;

            while *count == 0 {
                count = cv.wait(count).await;
            }

            *count -= 1;
        }

        let mu = Arc::new(RwLock::new(0));
        let cv = Arc::new(Condvar::new());

        let arc_waker = Arc::new(TestWaker);
        let waker = waker_ref(&arc_waker);
        let mut cx = Context::from_waker(&waker);

        let mut fut1 = Box::pin(dec(mu.clone(), cv.clone()));
        let mut fut2 = Box::pin(dec(mu.clone(), cv.clone()));

        if let Poll::Ready(()) = fut1.as_mut().poll(&mut cx) {
            panic!("future unexpectedly ready");
        }
        if let Poll::Ready(()) = fut2.as_mut().poll(&mut cx) {
            panic!("future unexpectedly ready");
        }
        assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS);

        let mut count = block_on(mu.lock());
        *count = 2;

        // Notify the cv while holding the lock. This should wake up both waiters.
        cv.notify_all();
        assert_eq!(cv.state.load(Ordering::Relaxed), 0);

        mem::drop(count);

        mem::drop(fut1);

        if fut2.as_mut().poll(&mut cx).is_pending() {
            panic!("future unable to complete");
        }

        assert_eq!(*block_on(mu.lock()), 1);
    }

    #[test]
    fn timed_wait() {
        async fn wait_deadline(
            mu: Arc<RwLock<usize>>,
            cv: Arc<Condvar>,
            timeout: oneshot::Receiver<()>,
        ) {
            let mut count = mu.lock().await;

            if *count == 0 {
                let mut rx = timeout.fuse();

                while *count == 0 {
                    select! {
                        res = rx => {
                            if let Err(e) = res {
                                panic!("Error while receiving timeout notification: {e}");
                            }

                            return;
                        },
                        c = cv.wait(count).fuse() => count = c,
                    }
                }
            }

            *count += 1;
        }

        let mu = Arc::new(RwLock::new(0));
        let cv = Arc::new(Condvar::new());

        let arc_waker = Arc::new(TestWaker);
        let waker = waker_ref(&arc_waker);
        let mut cx = Context::from_waker(&waker);

        let (tx, rx) = oneshot::channel();
        let mut wait = Box::pin(wait_deadline(mu.clone(), cv.clone(), rx));

        if let Poll::Ready(()) = wait.as_mut().poll(&mut cx) {
            panic!("wait_deadline unexpectedly ready");
        }

        assert_eq!(cv.state.load(Ordering::Relaxed), HAS_WAITERS);

        // Signal the timeout channel, which should cancel the wait.
        tx.send(()).expect("Failed to send wakeup");

        // The wait should now complete because the timeout channel has fired.
        if wait.as_mut().poll(&mut cx).is_pending() {
            panic!("wait_deadline unable to complete in time");
        }

        assert_eq!(cv.state.load(Ordering::Relaxed), 0);
        assert_eq!(*block_on(mu.lock()), 0);
    }
}