cros_async/
tokio_executor.rs

1// Copyright 2023 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use std::future::Future;
6use std::pin::Pin;
7use std::sync::Arc;
8use std::sync::OnceLock;
9
10use base::AsRawDescriptors;
11use base::RawDescriptor;
12use tokio::runtime::Runtime;
13use tokio::task::LocalSet;
14
15use crate::sys::platform::tokio_source::TokioSource;
16use crate::AsyncError;
17use crate::AsyncResult;
18use crate::ExecutorTrait;
19use crate::IntoAsync;
20use crate::IoSource;
21use crate::TaskHandle;
22
mod send_wrapper {
    //! A minimal "send wrapper": lets a `!Send`/`!Sync` value be stored inside `Send + Sync`
    //! containers while confining every use of the value to the thread that created the wrapper.

    use std::mem::ManuallyDrop;
    use std::thread;

    /// Wraps a `T` and records the thread that created it.
    ///
    /// Dereferencing or cloning the wrapper on any other thread panics. Dropping it on any other
    /// thread panics (unless that thread is already unwinding) and leaks the wrapped value
    /// instead of running `T`'s destructor on the wrong thread.
    pub(super) struct SendWrapper<T> {
        // `ManuallyDrop` lets our `Drop` impl decide whether `T`'s destructor may run. With a
        // plain `T` field, the compiler-generated drop glue would still drop it after our
        // `drop` panicked on the wrong thread.
        instance: ManuallyDrop<T>,
        thread_id: thread::ThreadId,
    }

    impl<T> SendWrapper<T> {
        /// Wraps `instance`, binding it to the calling thread.
        pub(super) fn new(instance: T) -> SendWrapper<T> {
            SendWrapper {
                instance: ManuallyDrop::new(instance),
                thread_id: thread::current().id(),
            }
        }

        /// Returns `true` when called on the thread that created this wrapper.
        fn on_owner_thread(&self) -> bool {
            self.thread_id == thread::current().id()
        }
    }

    // SAFETY: the wrapped value is only reachable through `Deref`, `Clone` and `Drop`. Each of
    // them checks that it runs on the creating thread, and otherwise panics (or, in `Drop`, leaks
    // the value). Moving the wrapper itself between threads never touches `T`.
    unsafe impl<T> Send for SendWrapper<T> {}
    // SAFETY: shared access from other threads goes through the same thread checks as above.
    unsafe impl<T> Sync for SendWrapper<T> {}

    impl<T: Clone> Clone for SendWrapper<T> {
        fn clone(&self) -> Self {
            if !self.on_owner_thread() {
                panic!("SendWrapper value was cloned on the wrong thread");
            }
            SendWrapper {
                instance: ManuallyDrop::new((*self.instance).clone()),
                thread_id: self.thread_id,
            }
        }
    }

    impl<T> Drop for SendWrapper<T> {
        fn drop(&mut self) {
            if self.on_owner_thread() {
                // SAFETY: `self` is being dropped, so `instance` is never accessed again.
                unsafe { ManuallyDrop::drop(&mut self.instance) };
            } else if !thread::panicking() {
                panic!("SendWrapper value was dropped on the wrong thread");
            }
            // On the wrong thread the value is intentionally leaked. If the thread is already
            // unwinding, we also skip the panic, since a second panic would abort the process.
        }
    }

    impl<T> std::ops::Deref for SendWrapper<T> {
        type Target = T;

        fn deref(&self) -> &T {
            if !self.on_owner_thread() {
                panic!("SendWrapper value was accessed on the wrong thread");
            }
            &self.instance
        }
    }
}
65
66#[derive(Clone)]
67pub struct TokioExecutor {
68    runtime: Arc<Runtime>,
69    local_set: Arc<OnceLock<send_wrapper::SendWrapper<LocalSet>>>,
70}
71
72impl TokioExecutor {
73    pub fn new() -> AsyncResult<Self> {
74        Ok(TokioExecutor {
75            runtime: Arc::new(Runtime::new().map_err(AsyncError::Io)?),
76            local_set: Arc::new(OnceLock::new()),
77        })
78    }
79}
80
81impl ExecutorTrait for TokioExecutor {
82    fn async_from<'a, F: IntoAsync + 'a>(&self, f: F) -> AsyncResult<IoSource<F>> {
83        Ok(IoSource::Tokio(TokioSource::new(
84            f,
85            self.runtime.handle().clone(),
86        )?))
87    }
88
89    fn run_until<F: Future>(&self, f: F) -> AsyncResult<F::Output> {
90        let local_set = self
91            .local_set
92            .get_or_init(|| send_wrapper::SendWrapper::new(LocalSet::new()));
93        Ok(self
94            .runtime
95            .block_on(async { local_set.run_until(f).await }))
96    }
97
98    fn spawn<F>(&self, f: F) -> TaskHandle<F::Output>
99    where
100        F: Future + Send + 'static,
101        F::Output: Send + 'static,
102    {
103        TaskHandle::Tokio(TokioTaskHandle {
104            join_handle: Some(self.runtime.spawn(f)),
105        })
106    }
107
108    fn spawn_blocking<F, R>(&self, f: F) -> TaskHandle<R>
109    where
110        F: FnOnce() -> R + Send + 'static,
111        R: Send + 'static,
112    {
113        TaskHandle::Tokio(TokioTaskHandle {
114            join_handle: Some(self.runtime.spawn_blocking(f)),
115        })
116    }
117
118    fn spawn_local<F>(&self, f: F) -> TaskHandle<F::Output>
119    where
120        F: Future + 'static,
121        F::Output: 'static,
122    {
123        let local_set = self
124            .local_set
125            .get_or_init(|| send_wrapper::SendWrapper::new(LocalSet::new()));
126        TaskHandle::Tokio(TokioTaskHandle {
127            join_handle: Some(local_set.spawn_local(f)),
128        })
129    }
130}
131
132impl AsRawDescriptors for TokioExecutor {
133    fn as_raw_descriptors(&self) -> Vec<RawDescriptor> {
134        todo!();
135    }
136}
137
138pub struct TokioTaskHandle<T> {
139    join_handle: Option<tokio::task::JoinHandle<T>>,
140}
141impl<R> TokioTaskHandle<R> {
142    pub async fn cancel(mut self) -> Option<R> {
143        match self.join_handle.take() {
144            Some(handle) => {
145                handle.abort();
146                handle.await.ok()
147            }
148            None => None,
149        }
150    }
151    pub fn detach(mut self) {
152        self.join_handle.take();
153    }
154}
155impl<R: 'static> Future for TokioTaskHandle<R> {
156    type Output = R;
157    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context) -> std::task::Poll<Self::Output> {
158        let self_mut = self.get_mut();
159        Pin::new(self_mut.join_handle.as_mut().unwrap())
160            .poll(cx)
161            .map(|v| v.unwrap())
162    }
163}
164impl<T> std::ops::Drop for TokioTaskHandle<T> {
165    fn drop(&mut self) {
166        if let Some(handle) = self.join_handle.take() {
167            handle.abort()
168        }
169    }
170}