cros_async/
tokio_executor.rs1use std::future::Future;
6use std::pin::Pin;
7use std::sync::Arc;
8use std::sync::OnceLock;
9
10use base::AsRawDescriptors;
11use base::RawDescriptor;
12use tokio::runtime::Runtime;
13use tokio::task::LocalSet;
14
15use crate::sys::platform::tokio_source::TokioSource;
16use crate::AsyncError;
17use crate::AsyncResult;
18use crate::ExecutorTrait;
19use crate::IntoAsync;
20use crate::IoSource;
21use crate::TaskHandle;
22
mod send_wrapper {
    use std::mem::ManuallyDrop;
    use std::ops::Deref;
    use std::thread;

    /// Wrapper that may be moved to or shared with other threads regardless of `T`, while
    /// confining every use of the wrapped value to the thread that created the wrapper.
    ///
    /// The id of the creating thread is recorded in `new`. Dereferencing, cloning, or dropping
    /// the wrapper on any other thread panics without touching the wrapped value.
    pub(super) struct SendWrapper<T> {
        // Stored as `ManuallyDrop` so that a panic in `Drop::drop` on a foreign thread does not
        // let the unwinding drop glue run `T`'s destructor on that thread.
        instance: ManuallyDrop<T>,
        thread_id: thread::ThreadId,
    }

    impl<T> SendWrapper<T> {
        /// Wraps `instance` and binds it to the calling thread.
        pub(super) fn new(instance: T) -> SendWrapper<T> {
            SendWrapper {
                instance: ManuallyDrop::new(instance),
                thread_id: thread::current().id(),
            }
        }
    }

    impl<T: Clone> Clone for SendWrapper<T> {
        /// Clones the wrapped value.
        ///
        /// # Panics
        /// Panics when called on a thread other than the one that created the wrapper. A derived
        /// `Clone` would skip this check and run `T::clone` on whatever thread holds a reference.
        fn clone(&self) -> Self {
            // Going through `Deref` performs the thread check before `T` is touched.
            let inner: &T = Deref::deref(self);
            SendWrapper::new(inner.clone())
        }
    }

    // SAFETY: the wrapped value is only reachable through `Deref`, `Clone` and `Drop`, each of
    // which panics before touching it when invoked on a thread other than the creating one.
    unsafe impl<T> Send for SendWrapper<T> {}
    // SAFETY: same argument as for `Send`; shared references obtained on another thread cannot
    // reach the wrapped value without hitting the thread check.
    unsafe impl<T> Sync for SendWrapper<T> {}

    impl<T> Drop for SendWrapper<T> {
        /// Drops the wrapped value on its owning thread.
        ///
        /// # Panics
        /// Panics when run on any other thread; in that case the wrapped value is leaked rather
        /// than destroyed on the wrong thread.
        fn drop(&mut self) {
            if self.thread_id != thread::current().id() {
                panic!("SendWrapper value was dropped on the wrong thread");
            }
            // SAFETY: `instance` was initialised in `new`, this is the only place it is dropped,
            // and `drop` runs at most once, so the value is dropped exactly once here.
            unsafe { ManuallyDrop::drop(&mut self.instance) };
        }
    }

    impl<T> Deref for SendWrapper<T> {
        type Target = T;

        /// Borrows the wrapped value.
        ///
        /// # Panics
        /// Panics when called on a thread other than the one that created the wrapper.
        fn deref(&self) -> &T {
            if self.thread_id != thread::current().id() {
                panic!("SendWrapper value was accessed on the wrong thread");
            }
            &self.instance
        }
    }
}
65
66#[derive(Clone)]
67pub struct TokioExecutor {
68 runtime: Arc<Runtime>,
69 local_set: Arc<OnceLock<send_wrapper::SendWrapper<LocalSet>>>,
70}
71
72impl TokioExecutor {
73 pub fn new() -> AsyncResult<Self> {
74 Ok(TokioExecutor {
75 runtime: Arc::new(Runtime::new().map_err(AsyncError::Io)?),
76 local_set: Arc::new(OnceLock::new()),
77 })
78 }
79}
80
81impl ExecutorTrait for TokioExecutor {
82 fn async_from<'a, F: IntoAsync + 'a>(&self, f: F) -> AsyncResult<IoSource<F>> {
83 Ok(IoSource::Tokio(TokioSource::new(
84 f,
85 self.runtime.handle().clone(),
86 )?))
87 }
88
89 fn run_until<F: Future>(&self, f: F) -> AsyncResult<F::Output> {
90 let local_set = self
91 .local_set
92 .get_or_init(|| send_wrapper::SendWrapper::new(LocalSet::new()));
93 Ok(self
94 .runtime
95 .block_on(async { local_set.run_until(f).await }))
96 }
97
98 fn spawn<F>(&self, f: F) -> TaskHandle<F::Output>
99 where
100 F: Future + Send + 'static,
101 F::Output: Send + 'static,
102 {
103 TaskHandle::Tokio(TokioTaskHandle {
104 join_handle: Some(self.runtime.spawn(f)),
105 })
106 }
107
108 fn spawn_blocking<F, R>(&self, f: F) -> TaskHandle<R>
109 where
110 F: FnOnce() -> R + Send + 'static,
111 R: Send + 'static,
112 {
113 TaskHandle::Tokio(TokioTaskHandle {
114 join_handle: Some(self.runtime.spawn_blocking(f)),
115 })
116 }
117
118 fn spawn_local<F>(&self, f: F) -> TaskHandle<F::Output>
119 where
120 F: Future + 'static,
121 F::Output: 'static,
122 {
123 let local_set = self
124 .local_set
125 .get_or_init(|| send_wrapper::SendWrapper::new(LocalSet::new()));
126 TaskHandle::Tokio(TokioTaskHandle {
127 join_handle: Some(local_set.spawn_local(f)),
128 })
129 }
130}
131
132impl AsRawDescriptors for TokioExecutor {
133 fn as_raw_descriptors(&self) -> Vec<RawDescriptor> {
134 todo!();
135 }
136}
137
138pub struct TokioTaskHandle<T> {
139 join_handle: Option<tokio::task::JoinHandle<T>>,
140}
141impl<R> TokioTaskHandle<R> {
142 pub async fn cancel(mut self) -> Option<R> {
143 match self.join_handle.take() {
144 Some(handle) => {
145 handle.abort();
146 handle.await.ok()
147 }
148 None => None,
149 }
150 }
151 pub fn detach(mut self) {
152 self.join_handle.take();
153 }
154}
155impl<R: 'static> Future for TokioTaskHandle<R> {
156 type Output = R;
157 fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context) -> std::task::Poll<Self::Output> {
158 let self_mut = self.get_mut();
159 Pin::new(self_mut.join_handle.as_mut().unwrap())
160 .poll(cx)
161 .map(|v| v.unwrap())
162 }
163}
164impl<T> std::ops::Drop for TokioTaskHandle<T> {
165 fn drop(&mut self) {
166 if let Some(handle) = self.join_handle.take() {
167 handle.abort()
168 }
169 }
170}