use std::cell::UnsafeCell;
use std::hint;
use std::ops::Deref;
use std::ops::DerefMut;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;

const UNLOCKED: bool = false;
const LOCKED: bool = true;

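/// A mutual-exclusion spin lock protecting a value of type `T`.
///
/// The lock is acquired by busy-waiting rather than by parking the thread.
/// The struct is aligned to 128 bytes so that two locks never share a cache
/// line, reducing false sharing under contention.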
#[repr(align(128))]
pub struct SpinLock<T: ?Sized> {
    lock: AtomicBool,
    value: UnsafeCell<T>,
}

impl<T> SpinLock<T> {
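    /// Creates a new spin lock wrapping `value`.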
    pub fn new(value: T) -> SpinLock<T> {
        SpinLock {
            lock: AtomicBool::new(UNLOCKED),
            value: UnsafeCell::new(value),
        }
    }

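    /// Consumes the lock and returns the protected value.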
    pub fn into_inner(self) -> T {
        self.value.into_inner()
    }
}

impl<T: ?Sized> SpinLock<T> {
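    /// Acquires the lock, spinning until it becomes available, and returns a
    /// guard that releases the lock when dropped.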
    pub fn lock(&self) -> SpinLockGuard<T> {
        loop {
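            // Test-and-test-and-set: check with a cheap relaxed load first and
            // only attempt the atomic compare-exchange once the lock looks free.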
            let state = self.lock.load(Ordering::Relaxed);
            if state == UNLOCKED
                && self
                    .lock
                    .compare_exchange_weak(UNLOCKED, LOCKED, Ordering::Acquire, Ordering::Relaxed)
                    .is_ok()
            {
                break;
            }
            hint::spin_loop();
        }

        #[allow(clippy::undocumented_unsafe_blocks)]
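        // SAFETY: the loop above has just acquired the lock, so this thread has
        // exclusive access to the protected value until the guard releases the
        // lock in its `Drop` impl.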
        SpinLockGuard {
            lock: self,
            value: unsafe { &mut *self.value.get() },
        }
    }

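    /// Releases the lock; called from the guard's `Drop` implementation.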
    fn unlock(&self) {
        self.lock.store(UNLOCKED, Ordering::Release);
    }

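    /// Returns a mutable reference to the protected value.
    ///
    /// No locking is needed because the exclusive `&mut self` borrow guarantees
    /// there are no other references to the lock.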
    pub fn get_mut(&mut self) -> &mut T {
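        // SAFETY: `&mut self` guarantees exclusive access, so no other thread
        // can be holding the lock or touching the value.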
        unsafe { &mut *self.value.get() }
    }
}

#[allow(clippy::undocumented_unsafe_blocks)]
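// SAFETY: sending a `SpinLock<T>` to another thread just moves the `T` with it,
// so the lock is `Send` whenever `T` is `Send`.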
unsafe impl<T: ?Sized + Send> Send for SpinLock<T> {}
#[allow(clippy::undocumented_unsafe_blocks)]
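// SAFETY: the lock serializes all access to the inner value, so only one thread
// can reach the `T` at a time; `T: Send` is required because locking effectively
// hands the `T` to whichever thread holds the lock.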
unsafe impl<T: ?Sized + Send> Sync for SpinLock<T> {}

impl<T: Default> Default for SpinLock<T> {
    fn default() -> Self {
        Self::new(Default::default())
    }
}

impl<T> From<T> for SpinLock<T> {
    fn from(source: T) -> Self {
        Self::new(source)
    }
}

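/// RAII guard returned by [`SpinLock::lock`]; grants access to the protected
/// value and releases the lock when dropped.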
pub struct SpinLockGuard<'a, T: 'a + ?Sized> {
    lock: &'a SpinLock<T>,
    value: &'a mut T,
}

impl<T: ?Sized> Deref for SpinLockGuard<'_, T> {
    type Target = T;
    fn deref(&self) -> &T {
        self.value
    }
}

impl<T: ?Sized> DerefMut for SpinLockGuard<'_, T> {
    fn deref_mut(&mut self) -> &mut T {
        self.value
    }
}

impl<T: ?Sized> Drop for SpinLockGuard<'_, T> {
    fn drop(&mut self) {
        self.lock.unlock();
    }
}

#[cfg(test)]
mod test {
    use std::mem;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;
    use std::sync::Arc;
    use std::thread;

    use super::*;

    #[derive(PartialEq, Eq, Debug)]
    struct NonCopy(u32);

    #[test]
    fn it_works() {
        let sl = SpinLock::new(NonCopy(13));

        assert_eq!(*sl.lock(), NonCopy(13));
    }

    #[test]
    fn smoke() {
        let sl = SpinLock::new(NonCopy(7));

        mem::drop(sl.lock());
        mem::drop(sl.lock());
    }

    #[test]
    fn send() {
        let sl = SpinLock::new(NonCopy(19));

        thread::spawn(move || {
            let value = sl.lock();
            assert_eq!(*value, NonCopy(19));
        })
        .join()
        .unwrap();
    }

    #[test]
    fn high_contention() {
        const THREADS: usize = 23;
        const ITERATIONS: usize = 101;

        let mut threads = Vec::with_capacity(THREADS);

        let sl = Arc::new(SpinLock::new(0usize));
        for _ in 0..THREADS {
            let sl2 = sl.clone();
            threads.push(thread::spawn(move || {
                for _ in 0..ITERATIONS {
                    *sl2.lock() += 1;
                }
            }));
        }

        for t in threads.into_iter() {
            t.join().unwrap();
        }

        assert_eq!(*sl.lock(), THREADS * ITERATIONS);
    }

    #[test]
    fn get_mut() {
        let mut sl = SpinLock::new(NonCopy(13));
        *sl.get_mut() = NonCopy(17);

        assert_eq!(sl.into_inner(), NonCopy(17));
    }

    #[test]
    fn into_inner() {
        let sl = SpinLock::new(NonCopy(29));
        assert_eq!(sl.into_inner(), NonCopy(29));
    }

    #[test]
    fn into_inner_drop() {
        struct NeedsDrop(Arc<AtomicUsize>);
        impl Drop for NeedsDrop {
            fn drop(&mut self) {
                self.0.fetch_add(1, Ordering::AcqRel);
            }
        }

        let value = Arc::new(AtomicUsize::new(0));
        let needs_drop = SpinLock::new(NeedsDrop(value.clone()));
        assert_eq!(value.load(Ordering::Acquire), 0);

        {
            let inner = needs_drop.into_inner();
            assert_eq!(inner.0.load(Ordering::Acquire), 0);
        }

        assert_eq!(value.load(Ordering::Acquire), 1);
    }

    #[test]
    fn arc_nested() {
        let sl = SpinLock::new(1);
        let arc = Arc::new(SpinLock::new(sl));
        thread::spawn(move || {
            let nested = arc.lock();
            let lock2 = nested.lock();
            assert_eq!(*lock2, 1);
        })
        .join()
        .unwrap();
    }

    #[test]
    fn arc_access_in_unwind() {
        let arc = Arc::new(SpinLock::new(1));
        let arc2 = arc.clone();
        thread::spawn(move || {
            struct Unwinder {
                i: Arc<SpinLock<i32>>,
            }
            impl Drop for Unwinder {
                fn drop(&mut self) {
                    *self.i.lock() += 1;
                }
            }
            let _u = Unwinder { i: arc2 };
            panic!();
        })
        .join()
        .expect_err("thread did not panic");
        let lock = arc.lock();
        assert_eq!(*lock, 2);
    }

    #[test]
    fn unsized_value() {
        let sltex: &SpinLock<[i32]> = &SpinLock::new([1, 2, 3]);
        {
            let b = &mut *sltex.lock();
            b[0] = 4;
            b[2] = 5;
        }
        let expected: &[i32] = &[4, 2, 5];
        assert_eq!(&*sltex.lock(), expected);
    }
}