cros_async/sync/
waiter.rs

1// Copyright 2020 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use std::cell::UnsafeCell;
6use std::future::Future;
7use std::mem;
8use std::pin::Pin;
9use std::ptr::NonNull;
10use std::sync::atomic::AtomicBool;
11use std::sync::atomic::AtomicU8;
12use std::sync::atomic::Ordering;
13use std::sync::Arc;
14use std::task::Context;
15use std::task::Poll;
16use std::task::Waker;
17
18use intrusive_collections::intrusive_adapter;
19use intrusive_collections::linked_list::LinkedList;
20use intrusive_collections::linked_list::LinkedListOps;
21use intrusive_collections::DefaultLinkOps;
22use intrusive_collections::LinkOps;
23
24use super::super::sync::SpinLock;
25
/// A linked-list link whose "linked" flag is atomic, so that ownership of the link can be
/// claimed safely from multiple threads. It stands in for intrusive-collections'
/// `LinkedListLink`; see https://github.com/Amanieu/intrusive-rs/issues/47 for background.
///
/// The `prev`/`next` cells are plain `UnsafeCell`s. They may only be touched by whoever
/// currently owns the link, i.e. whoever flipped `linked` from `false` to `true`.
#[repr(align(128))]
pub struct AtomicLink {
    linked: AtomicBool,
    prev: UnsafeCell<Option<NonNull<AtomicLink>>>,
    next: UnsafeCell<Option<NonNull<AtomicLink>>>,
}

impl AtomicLink {
    /// Returns a fresh link that is not a member of any list and has no neighbours.
    fn new() -> Self {
        Self {
            linked: AtomicBool::new(false),
            prev: UnsafeCell::new(None),
            next: UnsafeCell::new(None),
        }
    }

    /// Reports whether the link is currently claimed by a list.
    fn is_linked(&self) -> bool {
        self.linked.load(Ordering::Relaxed)
    }
}
48
49impl DefaultLinkOps for AtomicLink {
50    type Ops = AtomicLinkOps;
51
52    const NEW: Self::Ops = AtomicLinkOps;
53}
54
55// SAFETY:
56// Safe because the only way to mutate `AtomicLink` is via the `LinkedListOps` trait whose methods
57// are all unsafe and require that the caller has first called `acquire_link` (and had it return
58// true) to use them safely.
59unsafe impl Send for AtomicLink {}
60// SAFETY: See safety comment for impl Send
61unsafe impl Sync for AtomicLink {}
62
/// Stateless implementation of the intrusive-collections link operations for `AtomicLink`.
#[derive(Clone, Copy, Default)]
pub struct AtomicLinkOps;
65
66// TODO(b/315998194): Add safety comment
67#[allow(clippy::undocumented_unsafe_blocks)]
68unsafe impl LinkOps for AtomicLinkOps {
69    type LinkPtr = NonNull<AtomicLink>;
70
71    unsafe fn acquire_link(&mut self, ptr: Self::LinkPtr) -> bool {
72        !ptr.as_ref().linked.swap(true, Ordering::Acquire)
73    }
74
75    unsafe fn release_link(&mut self, ptr: Self::LinkPtr) {
76        ptr.as_ref().linked.store(false, Ordering::Release)
77    }
78}
79
80// TODO(b/315998194): Add safety comment
81#[allow(clippy::undocumented_unsafe_blocks)]
82unsafe impl LinkedListOps for AtomicLinkOps {
83    unsafe fn next(&self, ptr: Self::LinkPtr) -> Option<Self::LinkPtr> {
84        *ptr.as_ref().next.get()
85    }
86
87    unsafe fn prev(&self, ptr: Self::LinkPtr) -> Option<Self::LinkPtr> {
88        *ptr.as_ref().prev.get()
89    }
90
91    unsafe fn set_next(&mut self, ptr: Self::LinkPtr, next: Option<Self::LinkPtr>) {
92        *ptr.as_ref().next.get() = next;
93    }
94
95    unsafe fn set_prev(&mut self, ptr: Self::LinkPtr, prev: Option<Self::LinkPtr>) {
96        *ptr.as_ref().prev.get() = prev;
97    }
98}
99
/// The kind of lock a `Waiter` is waiting to acquire.
#[derive(Copy, Clone)]
pub enum Kind {
    /// Waiting for a shared lock.
    Shared,
    /// Waiting for an exclusive lock.
    Exclusive,
}
105
/// Progress of a single wait; stored in `Waiter::state` behind a spin lock.
enum State {
    /// Initial state, set by `Waiter::new` and restored by `Waiter::reset`.
    Init,
    /// A `WaitFuture` has been polled and returned `Pending`; holds the waker to notify.
    Waiting(Waker),
    /// `Waiter::wake` has run; the next poll of the `WaitFuture` completes.
    Woken,
    /// The `WaitFuture` returned `Poll::Ready`; polling it again is a bug.
    Finished,
    /// Placeholder written while `WaitFuture::poll` holds the lock and decides the next state.
    Processing,
}
113
/// Identifies the waiter list a `Waiter` belongs to. The `Mutex` and `Condvar` implementations
/// are responsible for updating this value whenever they add a `Waiter` to, or remove one from,
/// their respective waiter lists.
#[repr(u8)]
#[derive(Debug, PartialEq, Eq)]
pub enum WaitingFor {
    /// Not linked into any waiter list, or linked only into a temporary list.
    None = 0,
    /// Linked into the `Mutex`'s waiter list.
    Mutex = 1,
    /// Linked into the `Condvar`'s waiter list.
    Condvar = 2,
}
127
128// Represents a thread currently blocked on a Condvar or on acquiring a Mutex.
129pub struct Waiter {
130    link: AtomicLink,
131    state: SpinLock<State>,
132    cancel: fn(usize, &Waiter, bool),
133    cancel_data: usize,
134    kind: Kind,
135    waiting_for: AtomicU8,
136}
137
138impl Waiter {
139    // Create a new, initialized Waiter.
140    //
141    // `kind` should indicate whether this waiter represent a thread that is waiting for a shared
142    // lock or an exclusive lock.
143    //
144    // `cancel` is the function that is called when a `WaitFuture` (returned by the `wait()`
145    // function) is dropped before it can complete. `cancel_data` is used as the first parameter of
146    // the `cancel` function. The second parameter is the `Waiter` that was canceled and the third
147    // parameter indicates whether the `WaitFuture` was dropped after it was woken (but before it
148    // was polled to completion). A value of `false` for the third parameter may already be stale
149    // by the time the cancel function runs and so does not guarantee that the waiter was not woken.
150    // In this case, implementations should still check if the Waiter was woken. However, a value of
151    // `true` guarantees that the waiter was already woken up so no additional checks are necessary.
152    // In this case, the cancel implementation should wake up the next waiter in its wait list, if
153    // any.
154    //
155    // `waiting_for` indicates the waiter list to which this `Waiter` will be added. See the
156    // documentation of the `WaitingFor` enum for the meaning of the different values.
157    pub fn new(
158        kind: Kind,
159        cancel: fn(usize, &Waiter, bool),
160        cancel_data: usize,
161        waiting_for: WaitingFor,
162    ) -> Waiter {
163        Waiter {
164            link: AtomicLink::new(),
165            state: SpinLock::new(State::Init),
166            cancel,
167            cancel_data,
168            kind,
169            waiting_for: AtomicU8::new(waiting_for as u8),
170        }
171    }
172
173    // The kind of lock that this `Waiter` is waiting to acquire.
174    pub fn kind(&self) -> Kind {
175        self.kind
176    }
177
178    // Returns true if this `Waiter` is currently linked into a waiter list.
179    pub fn is_linked(&self) -> bool {
180        self.link.is_linked()
181    }
182
183    // Indicates the waiter list to which this `Waiter` belongs.
184    pub fn is_waiting_for(&self) -> WaitingFor {
185        match self.waiting_for.load(Ordering::Acquire) {
186            0 => WaitingFor::None,
187            1 => WaitingFor::Mutex,
188            2 => WaitingFor::Condvar,
189            v => panic!("Unknown value for `WaitingFor`: {v}"),
190        }
191    }
192
193    // Change the waiter list to which this `Waiter` belongs. This will panic if called when the
194    // `Waiter` is still linked into a waiter list.
195    pub fn set_waiting_for(&self, waiting_for: WaitingFor) {
196        self.waiting_for.store(waiting_for as u8, Ordering::Release);
197    }
198
199    // Reset the Waiter back to its initial state. Panics if this `Waiter` is still linked into a
200    // waiter list.
201    pub fn reset(&self, waiting_for: WaitingFor) {
202        debug_assert!(!self.is_linked(), "Cannot reset `Waiter` while linked");
203        self.set_waiting_for(waiting_for);
204
205        let mut state = self.state.lock();
206        if let State::Waiting(waker) = mem::replace(&mut *state, State::Init) {
207            mem::drop(state);
208            mem::drop(waker);
209        }
210    }
211
212    // Wait until woken up by another thread.
213    pub fn wait(&self) -> WaitFuture<'_> {
214        WaitFuture { waiter: self }
215    }
216
217    // Wake up the thread associated with this `Waiter`. Panics if `waiting_for()` does not return
218    // `WaitingFor::None` or if `is_linked()` returns true.
219    pub fn wake(&self) {
220        debug_assert!(!self.is_linked(), "Cannot wake `Waiter` while linked");
221        debug_assert_eq!(self.is_waiting_for(), WaitingFor::None);
222
223        let mut state = self.state.lock();
224
225        if let State::Waiting(waker) = mem::replace(&mut *state, State::Woken) {
226            mem::drop(state);
227            waker.wake();
228        }
229    }
230}
231
232pub struct WaitFuture<'w> {
233    waiter: &'w Waiter,
234}
235
236impl Future for WaitFuture<'_> {
237    type Output = ();
238
239    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
240        let mut state = self.waiter.state.lock();
241
242        match mem::replace(&mut *state, State::Processing) {
243            State::Init => {
244                *state = State::Waiting(cx.waker().clone());
245
246                Poll::Pending
247            }
248            State::Waiting(old_waker) => {
249                *state = State::Waiting(cx.waker().clone());
250                mem::drop(state);
251                mem::drop(old_waker);
252
253                Poll::Pending
254            }
255            State::Woken => {
256                *state = State::Finished;
257                Poll::Ready(())
258            }
259            State::Finished => {
260                panic!("Future polled after returning Poll::Ready");
261            }
262            State::Processing => {
263                panic!("Unexpected waker state");
264            }
265        }
266    }
267}
268
269impl Drop for WaitFuture<'_> {
270    fn drop(&mut self) {
271        let state = self.waiter.state.lock();
272
273        match *state {
274            State::Finished => {}
275            State::Processing => panic!("Unexpected waker state"),
276            State::Woken => {
277                mem::drop(state);
278
279                // We were woken but not polled.  Wake up the next waiter.
280                (self.waiter.cancel)(self.waiter.cancel_data, self.waiter, true);
281            }
282            _ => {
283                mem::drop(state);
284
285                // Not woken.  No need to wake up any waiters.
286                (self.waiter.cancel)(self.waiter.cancel_data, self.waiter, false);
287            }
288        }
289    }
290}
291
// Intrusive-list adapter that links `Arc<Waiter>` values together through their `link`
// (`AtomicLink`) field.
intrusive_adapter!(pub WaiterAdapter = Arc<Waiter>: Waiter { link: AtomicLink });
293
294pub type WaiterList = LinkedList<WaiterAdapter>;