#[cfg(target_arch = "wasm32")]
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
#[cfg(target_arch = "wasm32")]
pub use futures_util::future::Aborted as JoinError;
#[cfg(target_arch = "wasm32")]
use futures_util::{
future::{AbortHandle, Abortable, RemoteHandle},
FutureExt,
};
#[cfg(not(target_arch = "wasm32"))]
pub use tokio::task::{spawn, JoinError, JoinHandle};
/// Spawns `future` as a background task on `wasm32` and returns a handle to
/// it.
///
/// The future is handed to `wasm_bindgen_futures::spawn_local`, wrapped so
/// that:
///
/// * its output is forwarded to the returned [`JoinHandle`], and
/// * it can be stopped early through [`JoinHandle::abort`].
///
/// Dropping the returned handle does **not** cancel the task; this mirrors the
/// behaviour of `tokio::task::spawn`, which is what this module re-exports on
/// non-wasm targets.
#[cfg(target_arch = "wasm32")]
pub fn spawn<F, T>(future: F) -> JoinHandle<T>
where
    F: Future<Output = T> + 'static,
{
    let (future, remote_handle) = future.remote_handle();
    let (abort_handle, abort_registration) = AbortHandle::new_pair();
    let future = Abortable::new(future, abort_registration);

    wasm_bindgen_futures::spawn_local(async {
        // The result is either `Ok(())` (the output itself travels through the
        // remote handle) or `Err(Aborted)`; neither needs handling here.
        let _ = future.await;
    });

    JoinHandle { remote_handle: Some(remote_handle), abort_handle }
}

/// Handle to a task started with [`spawn`] on `wasm32`.
///
/// Awaiting the handle yields `Ok(output)` once the task has finished, or
/// `Err(JoinError)` if the task was aborted.
#[cfg(target_arch = "wasm32")]
#[derive(Debug)]
pub struct JoinHandle<T> {
    /// Receives the task's output. Only `None` once `Drop` has detached it.
    remote_handle: Option<RemoteHandle<T>>,
    /// Used to stop the task and to find out whether it was stopped.
    abort_handle: AbortHandle,
}

#[cfg(target_arch = "wasm32")]
impl<T> JoinHandle<T> {
    /// Requests that the task be aborted.
    ///
    /// After this call, awaiting the handle resolves to `Err(JoinError)`.
    pub fn abort(&self) {
        self.abort_handle.abort();
    }
}

#[cfg(target_arch = "wasm32")]
impl<T> Drop for JoinHandle<T> {
    fn drop(&mut self) {
        // A `RemoteHandle` cancels its remote future when dropped. Detach it
        // explicitly so that discarding a `JoinHandle` leaves the task
        // running, as it does with tokio's `JoinHandle`.
        if let Some(handle) = self.remote_handle.take() {
            handle.forget();
        }
    }
}

#[cfg(target_arch = "wasm32")]
impl<T: 'static> Future for JoinHandle<T> {
    type Output = Result<T, JoinError>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Check for abortion first: once the aborted task has been dropped by
        // the executor, the remote handle must not be polled any more.
        if self.abort_handle.is_aborted() {
            return Poll::Ready(Err(JoinError));
        }

        match self.remote_handle.as_mut() {
            Some(handle) => Pin::new(handle).poll(cx).map(Ok),
            // The remote handle is only removed in `Drop`, so this arm is
            // purely defensive; report it as an error instead of panicking.
            None => Poll::Ready(Err(JoinError)),
        }
    }
}
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use matrix_sdk_test::async_test;

    use super::spawn;

    /// The output of a spawned task is delivered through its join handle.
    #[async_test]
    async fn test_spawn() {
        let handle = spawn(async { 42 });

        assert_matches!(handle.await, Ok(42));
    }

    /// Awaiting a handle that was aborted right after spawning yields an
    /// error.
    #[async_test]
    async fn test_abort() {
        let handle = spawn(async { 42 });
        handle.abort();

        let outcome = handle.await;
        assert!(outcome.is_err());
    }
}