#[cfg(any(feature = "anyhow", feature = "eyre"))]
use std::any::TypeId;
use std::{
borrow::Cow,
fmt,
future::Future,
pin::Pin,
sync::{
atomic::{AtomicU64, Ordering::SeqCst},
RwLock,
},
};
use anymap2::any::CloneAnySendSync;
use futures_util::stream::{FuturesUnordered, StreamExt};
use matrix_sdk_base::{
deserialized_responses::{EncryptionInfo, SyncTimelineEvent},
SendOutsideWasm, SyncOutsideWasm,
};
use ruma::{events::AnySyncStateEvent, push::Action, serde::Raw, OwnedRoomId};
use serde::{de::DeserializeOwned, Deserialize};
use serde_json::value::RawValue as RawJsonValue;
use tracing::{debug, error, field::debug, instrument, warn};
use self::maps::EventHandlerMaps;
use crate::{Client, Room};
mod context;
mod maps;
mod static_events;
pub use self::context::{Ctx, EventHandlerContext, RawEvent};
/// Boxed, pinned future returned by a type-erased event handler; it resolves to `()`.
/// On non-wasm32 targets the future must additionally be `Send`.
#[cfg(not(target_arch = "wasm32"))]
type EventHandlerFut = Pin<Box<dyn Future<Output = ()> + Send>>;
/// wasm32 flavour of the handler future: same shape, without the `Send` bound.
#[cfg(target_arch = "wasm32")]
type EventHandlerFut = Pin<Box<dyn Future<Output = ()>>>;
/// Type-erased handler callback: receives the [`EventHandlerData`] for one event and
/// returns the future to await. On non-wasm32 targets it must be `Send + Sync`.
#[cfg(not(target_arch = "wasm32"))]
type EventHandlerFn = dyn Fn(EventHandlerData<'_>) -> EventHandlerFut + Send + Sync;
/// wasm32 flavour of the handler callback, without the `Send + Sync` bounds.
#[cfg(target_arch = "wasm32")]
type EventHandlerFn = dyn Fn(EventHandlerData<'_>) -> EventHandlerFut;
/// `anymap2` map over `dyn CloneAnySendSync + Send + Sync` values; this is the type of
/// `EventHandlerStore::context`.
type AnyMap = anymap2::Map<dyn CloneAnySendSync + Send + Sync>;
/// Storage for registered event handlers and for the context values handed to them.
///
/// `Client` reaches it through `self.inner.event_handlers`.
#[derive(Default)]
pub(crate) struct EventHandlerStore {
/// All registered handler callbacks; read when dispatching, written on add/remove.
handlers: RwLock<EventHandlerMaps>,
/// Context values inserted through `add_context`.
context: RwLock<AnyMap>,
/// Source of handler ids: `add_event_handler_impl` takes the next id with `fetch_add(1, SeqCst)`.
counter: AtomicU64,
}
impl EventHandlerStore {
    /// Registers `handler_fn` in the handler maps under `handle`.
    ///
    /// # Panics
    ///
    /// Panics if the `handlers` lock is poisoned.
    pub fn add_handler(&self, handle: EventHandlerHandle, handler_fn: Box<EventHandlerFn>) {
        let mut maps = self.handlers.write().unwrap();
        maps.add(handle, handler_fn);
    }

    /// Inserts `ctx` into the context map.
    ///
    /// # Panics
    ///
    /// Panics if the `context` lock is poisoned.
    pub fn add_context<T>(&self, ctx: T)
    where
        T: Clone + Send + Sync + 'static,
    {
        let mut values = self.context.write().unwrap();
        values.insert(ctx);
    }

    /// Removes the handler identified by `handle` from the handler maps.
    ///
    /// # Panics
    ///
    /// Panics if the `handlers` lock is poisoned.
    pub fn remove(&self, handle: EventHandlerHandle) {
        let mut maps = self.handlers.write().unwrap();
        maps.remove(handle);
    }

    /// Returns the `len()` of the handler maps (test-only helper).
    #[cfg(test)]
    fn len(&self) -> usize {
        let maps = self.handlers.read().unwrap();
        maps.len()
    }
}
/// Category under which an event handler is registered and looked up.
///
/// The variant order is significant because `PartialOrd`/`Ord` are derived.
#[doc(hidden)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum HandlerKind {
    GlobalAccountData,
    RoomAccountData,
    EphemeralRoomData,
    /// Dispatched for every timeline event, after the more specific kinds.
    Timeline,
    /// Timeline event without a `state_key`, redacted or not.
    MessageLike,
    /// Timeline event without a `state_key` that carries no `redacted_because`.
    OriginalMessageLike,
    /// Timeline event without a `state_key` that carries `redacted_because`.
    RedactedMessageLike,
    /// State event, redacted or not.
    State,
    /// State event that carries no `redacted_because`.
    OriginalState,
    /// State event that carries `redacted_because`.
    RedactedState,
    StrippedState,
    ToDevice,
    Presence,
}

impl HandlerKind {
    /// Picks the redaction-specific message-like kind for the given `redacted` flag.
    fn message_like_redacted(redacted: bool) -> Self {
        match redacted {
            true => Self::RedactedMessageLike,
            false => Self::OriginalMessageLike,
        }
    }

    /// Picks the redaction-specific state kind for the given `redacted` flag.
    fn state_redacted(redacted: bool) -> Self {
        match redacted {
            true => Self::RedactedState,
            false => Self::OriginalState,
        }
    }
}
/// Implemented by event types that an event handler can be registered for.
///
/// Both constants are copied into the [`EventHandlerHandle`] when a handler is added.
pub trait SyncEvent {
/// Handler category this event type is dispatched under.
#[doc(hidden)]
const KIND: HandlerKind;
/// Event type string, or `None` when the implementor does not name a single type.
#[doc(hidden)]
const TYPE: Option<&'static str>;
}
/// A boxed handler callback paired with the id it was registered under.
pub(crate) struct EventHandlerWrapper {
/// The type-erased callback.
handler_fn: Box<EventHandlerFn>,
/// Id of the handler (same value space as `EventHandlerHandle::handler_id`).
pub handler_id: u64,
}
/// Identifies one registered event handler.
///
/// Returned when a handler is added; passing it to `EventHandlerStore::remove`
/// (or wrapping it in an [`EventHandlerDropGuard`]) unregisters that handler.
#[derive(Clone, Debug)]
pub struct EventHandlerHandle {
/// `SyncEvent::KIND` of the event type the handler was registered for.
pub(crate) ev_kind: HandlerKind,
/// `SyncEvent::TYPE` of the event type the handler was registered for.
pub(crate) ev_type: Option<&'static str>,
/// Room the handler was registered for, `None` if it was not bound to a room.
pub(crate) room_id: Option<OwnedRoomId>,
/// Id taken from the store's counter at registration time.
pub(crate) handler_id: u64,
}
/// A callable that can handle events of type `Ev`.
///
/// `Ctx` is the tuple of context argument types the handler takes after the event;
/// the `impl_event_handler!` macro implements this trait for functions/closures with
/// zero to eight such arguments.
pub trait EventHandler<Ev, Ctx>: Clone + SendOutsideWasm + SyncOutsideWasm + 'static {
/// Future returned by the handler.
#[doc(hidden)]
type Future: EventHandlerFuture;
/// Invokes the handler for `ev`.
///
/// Returns `None` when a context argument could not be built from `data`; the caller
/// in `add_event_handler_impl` logs that case as an error.
#[doc(hidden)]
fn handle_event(self, ev: Ev, data: EventHandlerData<'_>) -> Option<Self::Future>;
}
/// Helper trait naming the output type of a handler future.
///
/// It restricts handler futures to those that are `'static`, `Send` outside of wasm,
/// and whose output implements [`EventHandlerResult`].
#[doc(hidden)]
pub trait EventHandlerFuture:
Future<Output = <Self as EventHandlerFuture>::Output> + SendOutsideWasm + 'static
{
/// The future's output, which must know how to report its own error.
type Output: EventHandlerResult;
}
/// Blanket implementation: every suitable future whose output implements
/// [`EventHandlerResult`] is an [`EventHandlerFuture`].
impl<T> EventHandlerFuture for T
where
T: Future + SendOutsideWasm + 'static,
<T as Future>::Output: EventHandlerResult,
{
type Output = <T as Future>::Output;
}
/// Everything a handler invocation gets to see for a single event.
///
/// One value is built per matching handler in `call_event_handlers`.
#[doc(hidden)]
#[derive(Debug)]
pub struct EventHandlerData<'a> {
/// Clone of the client that is dispatching the event.
client: Client,
/// Clone of the room the event belongs to, if a room was given.
room: Option<Room>,
/// Raw JSON of the event; deserialized into the handler's event type before the call.
raw: &'a RawJsonValue,
/// Encryption info of the event; only set for timeline events in this file.
encryption_info: Option<&'a EncryptionInfo>,
/// Push actions of the event; empty for non-timeline events in this file.
push_actions: &'a [Action],
/// Handle of the handler that is being invoked.
handle: EventHandlerHandle,
}
/// Return value of an event handler future.
///
/// Implemented in this file for `()` and for `Result<(), E>` where `E` is
/// `Debug + Display + 'static`.
pub trait EventHandlerResult: Sized {
/// Reports the error contained in `self`, if any.
///
/// `event_type` is the `SyncEvent::TYPE` of the handled event and may be `None`.
#[doc(hidden)]
fn print_error(&self, event_type: Option<&str>);
}
/// Handlers returning `()` cannot fail, so there is nothing to report.
impl EventHandlerResult for () {
fn print_error(&self, _event_type: Option<&str>) {}
}
impl<E: fmt::Debug + fmt::Display + 'static> EventHandlerResult for Result<(), E> {
fn print_error(&self, event_type: Option<&str>) {
let msg_fragment = match event_type {
Some(event_type) => format!(" for `{event_type}`"),
None => "".to_owned(),
};
match self {
#[cfg(feature = "anyhow")]
Err(e) if TypeId::of::<E>() == TypeId::of::<anyhow::Error>() => {
error!("Event handler{msg_fragment} failed: {e:?}");
}
#[cfg(feature = "eyre")]
Err(e) if TypeId::of::<E>() == TypeId::of::<eyre::Report>() => {
error!("Event handler{msg_fragment} failed: {e:?}");
}
Err(e) => {
error!("Event handler{msg_fragment} failed: {e}");
}
Ok(_) => {}
}
}
}
/// Minimal view of an event's `unsigned` object, used to detect redacted events.
#[derive(Deserialize)]
struct UnsignedDetails {
/// Only the presence of this key is checked; its content is skipped via `IgnoredAny`.
redacted_because: Option<serde::de::IgnoredAny>,
}
impl Client {
/// Registers `handler` for events of type `Ev`, optionally restricted to `room_id`,
/// and returns the handle that identifies the registration.
///
/// The handler is wrapped in a type-erased closure that, for every dispatched event:
///
/// 1. deserializes the raw JSON into `Ev` and, on success, calls the handler — both
///    happen synchronously, while the borrowed [`EventHandlerData`] is still alive;
/// 2. returns a future that awaits the handler's future and reports its error, or
///    logs why the handler could not run (invalid context argument, or
///    deserialization failure).
pub(crate) fn add_event_handler_impl<Ev, Ctx, H>(
    &self,
    handler: H,
    room_id: Option<OwnedRoomId>,
) -> EventHandlerHandle
where
    Ev: SyncEvent + DeserializeOwned + Send + 'static,
    H: EventHandler<Ev, Ctx>,
{
    let handler_fn: Box<EventHandlerFn> = Box::new(move |data| {
        // `data` borrows from the caller, so it must be consumed here rather than
        // inside the boxed future returned below.
        let maybe_fut = serde_json::from_str(data.raw.get())
            .map(|ev| handler.clone().handle_event(ev, data));

        Box::pin(async move {
            match maybe_fut {
                Ok(Some(fut)) => {
                    fut.await.print_error(Ev::TYPE);
                }
                Ok(None) => {
                    error!(
                        event_type = Ev::TYPE, event_kind = ?Ev::KIND,
                        "Event handler has an invalid context argument",
                    );
                }
                Err(e) => {
                    // The trailing `\` is a string continuation: without it the raw
                    // line break would be embedded in the message in addition to `\n`.
                    warn!(
                        event_type = Ev::TYPE, event_kind = ?Ev::KIND,
                        "Failed to deserialize event, skipping event handler.\n\
                         Deserialization error: {e}",
                    );
                }
            }
        })
    });

    // Ids come from a shared atomic counter, one per registration.
    let handler_id = self.inner.event_handlers.counter.fetch_add(1, SeqCst);
    let handle =
        EventHandlerHandle { ev_kind: Ev::KIND, ev_type: Ev::TYPE, room_id, handler_id };

    self.inner.event_handlers.add_handler(handle.clone(), handler_fn);

    handle
}
/// Dispatches every event in `events` to the handlers registered for `kind`.
///
/// Only the `type` field is extracted from each raw event here; no encryption info
/// and no push actions are passed on.
///
/// # Errors
///
/// Returns the deserialization error of the first event whose `type` cannot be
/// read; events after it are not dispatched.
pub(crate) async fn handle_sync_events<T>(
    &self,
    kind: HandlerKind,
    room: Option<&Room>,
    events: &[Raw<T>],
) -> serde_json::Result<()> {
    /// Borrowing view of an event that exposes just its `type`.
    #[derive(Deserialize)]
    struct ExtractType<'a> {
        #[serde(borrow, rename = "type")]
        event_type: Cow<'a, str>,
    }

    for raw_event in events {
        let ExtractType { event_type } = raw_event.deserialize_as()?;
        let json = raw_event.json();
        self.call_event_handlers(room, json, kind, &event_type, None, &[]).await;
    }

    Ok(())
}
pub(crate) async fn handle_sync_state_events(
&self,
room: Option<&Room>,
state_events: &[Raw<AnySyncStateEvent>],
) -> serde_json::Result<()> {
#[derive(Deserialize)]
struct StateEventDetails<'a> {
#[serde(borrow, rename = "type")]
event_type: Cow<'a, str>,
unsigned: Option<UnsignedDetails>,
}
self.handle_sync_events(HandlerKind::State, room, state_events).await?;
for raw_event in state_events {
let StateEventDetails { event_type, unsigned } = raw_event.deserialize_as()?;
let redacted = unsigned.and_then(|u| u.redacted_because).is_some();
let handler_kind = HandlerKind::state_redacted(redacted);
self.call_event_handlers(room, raw_event.json(), handler_kind, &event_type, None, &[])
.await;
}
Ok(())
}
pub(crate) async fn handle_sync_timeline_events(
&self,
room: Option<&Room>,
timeline_events: &[SyncTimelineEvent],
) -> serde_json::Result<()> {
#[derive(Deserialize)]
struct TimelineEventDetails<'a> {
#[serde(borrow, rename = "type")]
event_type: Cow<'a, str>,
state_key: Option<serde::de::IgnoredAny>,
unsigned: Option<UnsignedDetails>,
}
for item in timeline_events {
let TimelineEventDetails { event_type, state_key, unsigned } =
item.raw().deserialize_as()?;
let redacted = unsigned.and_then(|u| u.redacted_because).is_some();
let (handler_kind_g, handler_kind_r) = match state_key {
Some(_) => (HandlerKind::State, HandlerKind::state_redacted(redacted)),
None => (HandlerKind::MessageLike, HandlerKind::message_like_redacted(redacted)),
};
let raw_event = item.raw().json();
let encryption_info = item.encryption_info();
let push_actions = &item.push_actions;
self.call_event_handlers(
room,
raw_event,
handler_kind_g,
&event_type,
encryption_info,
push_actions,
)
.await;
self.call_event_handlers(
room,
raw_event,
handler_kind_r,
&event_type,
encryption_info,
push_actions,
)
.await;
let kind = HandlerKind::Timeline;
self.call_event_handlers(
room,
raw_event,
kind,
&event_type,
encryption_info,
push_actions,
)
.await;
}
Ok(())
}
/// Invokes every handler returned by `get_handlers(event_kind, event_type, room_id)`
/// for one raw event and waits until all of their futures have completed.
///
/// The tracing span carries `event_kind`, `event_type` and, when a room is given,
/// `room_id`.
///
/// # Panics
///
/// Panics if the `handlers` lock is poisoned.
#[instrument(skip_all, fields(?event_kind, ?event_type, room_id))]
async fn call_event_handlers(
&self,
room: Option<&Room>,
raw: &RawJsonValue,
event_kind: HandlerKind,
event_type: &str,
encryption_info: Option<&EncryptionInfo>,
push_actions: &[Action],
) {
let room_id = room.map(|r| r.room_id());
if let Some(room_id) = room_id {
// `room_id` is declared as an empty span field in `#[instrument]`; fill it in now.
tracing::Span::current().record("room_id", debug(room_id));
}
// Build all handler futures up front. The read guard is a temporary of this `let`
// statement, so it is released once `collect()` is done — before anything is awaited.
let mut futures: FuturesUnordered<_> = self
.inner
.event_handlers
.handlers
.read()
.unwrap()
.get_handlers(event_kind, event_type, room_id)
.map(|(handle, handler_fn)| {
// Each handler gets its own clone of the client and room.
let data = EventHandlerData {
client: self.clone(),
room: room.cloned(),
raw,
encryption_info,
push_actions,
handle,
};
(handler_fn)(data)
})
.collect();
if !futures.is_empty() {
debug!(amount = futures.len(), "Calling event handlers");
// Drain the set; handler futures yield `()`, so there is nothing to collect.
while let Some(()) = futures.next().await {}
}
}
}
/// Guard that unregisters an event handler when it is dropped.
#[derive(Debug)]
pub struct EventHandlerDropGuard {
/// Handle of the handler to remove on drop.
handle: EventHandlerHandle,
/// Client the handler is removed from.
client: Client,
}
impl EventHandlerDropGuard {
pub(crate) fn new(handle: EventHandlerHandle, client: Client) -> Self {
Self { handle, client }
}
}
impl Drop for EventHandlerDropGuard {
    /// Unregisters the guarded handler from the client.
    fn drop(&mut self) {
        // `remove_event_handler` takes the handle by value and `drop` only has
        // `&mut self`, hence the clone.
        let handle = self.handle.clone();
        self.client.remove_event_handler(handle);
    }
}
// Implements `EventHandler<Ev, (A, B, ...)>` for every function or closure that takes
// the event followed by the listed context argument types.
//
// Each context argument is built with `EventHandlerContext::from_data`; the `?` makes
// `handle_event` return `None` as soon as one of them yields `None`.
macro_rules! impl_event_handler {
($($ty:ident),* $(,)?) => {
impl<Ev, Fun, Fut, $($ty),*> EventHandler<Ev, ($($ty,)*)> for Fun
where
Ev: SyncEvent,
Fun: FnOnce(Ev, $($ty),*) -> Fut + Clone + SendOutsideWasm + SyncOutsideWasm + 'static,
Fut: EventHandlerFuture,
$($ty: EventHandlerContext),*
{
type Future = Fut;
// `_d` is underscore-prefixed because the zero-argument expansion never reads it.
fn handle_event(self, ev: Ev, _d: EventHandlerData<'_>) -> Option<Self::Future> {
Some((self)(ev, $($ty::from_data(&_d)?),*))
}
}
};
}
// Expansions for handlers with zero up to eight context arguments.
impl_event_handler!();
impl_event_handler!(A);
impl_event_handler!(A, B);
impl_event_handler!(A, B, C);
impl_event_handler!(A, B, C, D);
impl_event_handler!(A, B, C, D, E);
impl_event_handler!(A, B, C, D, E, F);
impl_event_handler!(A, B, C, D, E, F, G);
impl_event_handler!(A, B, C, D, E, F, G, H);
#[cfg(test)]
mod tests {
use matrix_sdk_test::{
async_test, InvitedRoomBuilder, JoinedRoomBuilder, DEFAULT_TEST_ROOM_ID,
};
#[cfg(target_arch = "wasm32")]
wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);
use std::{
future,
sync::{
atomic::{AtomicU8, Ordering::SeqCst},
Arc,
},
};
use matrix_sdk_test::{
sync_timeline_event, EphemeralTestEvent, StateTestEvent, StrippedStateTestEvent,
SyncResponseBuilder,
};
use once_cell::sync::Lazy;
use ruma::{
events::{
room::{
member::{OriginalSyncRoomMemberEvent, StrippedRoomMemberEvent},
name::OriginalSyncRoomNameEvent,
power_levels::OriginalSyncRoomPowerLevelsEvent,
},
typing::SyncTypingEvent,
AnySyncStateEvent, AnySyncTimelineEvent,
},
room_id,
serde::Raw,
};
use serde_json::json;
use crate::{
event_handler::Ctx,
test_utils::{logged_in_client, no_retry_test_client},
Client, Room,
};
/// Shared fixture: a raw, unredacted `m.room.member` timeline event. It has a
/// `state_key` and no `redacted_because` in `unsigned`.
static MEMBER_EVENT: Lazy<Raw<AnySyncTimelineEvent>> = Lazy::new(|| {
sync_timeline_event!({
"content": {
"avatar_url": null,
"displayname": "example",
"membership": "join"
},
"event_id": "$151800140517rfvjc:localhost",
"membership": "join",
"origin_server_ts": 151800140,
"sender": "@example:localhost",
"state_key": "@example:localhost",
"type": "m.room.member",
"prev_content": {
"avatar_url": null,
"displayname": "example",
"membership": "invite"
},
"unsigned": {
"age": 297036,
"replaces_state": "$151800111315tsynI:localhost"
}
})
});
/// Registers four global handlers (member, typing, power levels, stripped member) and
/// checks that each runs exactly once for a sync response that contains one matching
/// event per handler.
#[async_test]
async fn test_add_event_handler() -> crate::Result<()> {
let client = logged_in_client(None).await;
// One counter per handler; each handler bumps its own counter when invoked.
let member_count = Arc::new(AtomicU8::new(0));
let typing_count = Arc::new(AtomicU8::new(0));
let power_levels_count = Arc::new(AtomicU8::new(0));
let invited_member_count = Arc::new(AtomicU8::new(0));
client.add_event_handler({
let member_count = member_count.clone();
move |_ev: OriginalSyncRoomMemberEvent, _room: Room| async move {
member_count.fetch_add(1, SeqCst);
}
});
client.add_event_handler({
let typing_count = typing_count.clone();
move |_ev: SyncTypingEvent| async move {
typing_count.fetch_add(1, SeqCst);
}
});
client.add_event_handler({
let power_levels_count = power_levels_count.clone();
move |_ev: OriginalSyncRoomPowerLevelsEvent, _client: Client, _room: Room| async move {
power_levels_count.fetch_add(1, SeqCst);
}
});
client.add_event_handler({
let invited_member_count = invited_member_count.clone();
move |_ev: StrippedRoomMemberEvent| async move {
invited_member_count.fetch_add(1, SeqCst);
}
});
// One joined room (timeline member event, typing, power levels) and one invited
// room carrying a stripped member event.
let response = SyncResponseBuilder::default()
.add_joined_room(
JoinedRoomBuilder::default()
.add_timeline_event(MEMBER_EVENT.clone())
.add_ephemeral_event(EphemeralTestEvent::Typing)
.add_state_event(StateTestEvent::PowerLevels),
)
.add_invited_room(
InvitedRoomBuilder::new(room_id!("!test_invited:example.org")).add_state_event(
StrippedStateTestEvent::Custom(json!({
"content": {
"avatar_url": "mxc://example.org/SEsfnsuifSDFSSEF",
"displayname": "Alice",
"membership": "invite",
},
"event_id": "$143273582443PhrSn:example.org",
"origin_server_ts": 1432735824653u64,
"room_id": "!jEsUZKDJdhlrceRyVU:example.org",
"sender": "@example:example.org",
"state_key": "@alice:example.org",
"type": "m.room.member",
"unsigned": {
"age": 1234,
"invite_room_state": [
{
"content": {
"name": "Example Room"
},
"sender": "@bob:example.org",
"state_key": "",
"type": "m.room.name"
},
{
"content": {
"join_rule": "invite"
},
"sender": "@bob:example.org",
"state_key": "",
"type": "m.room.join_rules"
}
]
}
})),
),
)
.build_sync_response();
client.process_sync(response).await?;
// Every handler must have fired exactly once.
assert_eq!(member_count.load(SeqCst), 1);
assert_eq!(typing_count.load(SeqCst), 1);
assert_eq!(power_levels_count.load(SeqCst), 1);
assert_eq!(invited_member_count.load(SeqCst), 1);
Ok(())
}
/// Checks that room-scoped handlers only run for events of the room they were
/// registered for.
#[async_test]
// NOTE(review): presumably needed because the `async { unreachable!(..) }` handler below
// relies on the never type falling back to `()` — confirm.
#[allow(dependency_on_unit_never_type_fallback)]
async fn test_add_room_event_handler() -> crate::Result<()> {
let client = logged_in_client(None).await;
let room_id_a = room_id!("!foo:example.org");
let room_id_b = room_id!("!bar:matrix.org");
let member_count = Arc::new(AtomicU8::new(0));
let power_levels_count = Arc::new(AtomicU8::new(0));
// Member handler for room A.
client.add_room_event_handler(room_id_a, {
let member_count = member_count.clone();
move |_ev: OriginalSyncRoomMemberEvent, _room: Room| {
member_count.fetch_add(1, SeqCst);
future::ready(())
}
});
// Member handler for room B; shares the counter with the room A handler.
client.add_room_event_handler(room_id_b, {
let member_count = member_count.clone();
move |_ev: OriginalSyncRoomMemberEvent, _room: Room| {
member_count.fetch_add(1, SeqCst);
future::ready(())
}
});
// Power levels handler for room A only; room B's power levels event must not count.
client.add_room_event_handler(room_id_a, {
let power_levels_count = power_levels_count.clone();
move |_ev: OriginalSyncRoomPowerLevelsEvent, _client: Client, _room: Room| {
power_levels_count.fetch_add(1, SeqCst);
future::ready(())
}
});
// Room name handler for room B; only room A receives a room name event below.
client.add_room_event_handler(room_id_b, move |_ev: OriginalSyncRoomNameEvent| async {
unreachable!("No room event in room B")
});
let response = SyncResponseBuilder::default()
.add_joined_room(
JoinedRoomBuilder::new(room_id_a)
.add_timeline_event(MEMBER_EVENT.clone())
.add_state_event(StateTestEvent::PowerLevels)
.add_state_event(StateTestEvent::RoomName),
)
.add_joined_room(
JoinedRoomBuilder::new(room_id_b)
.add_timeline_event(MEMBER_EVENT.clone())
.add_state_event(StateTestEvent::PowerLevels),
)
.build_sync_response();
client.process_sync(response).await?;
// One member event per room, one power levels event counted for room A.
assert_eq!(member_count.load(SeqCst), 2);
assert_eq!(power_levels_count.load(SeqCst), 1);
Ok(())
}
/// Checks that removed handlers (one global, one room-scoped) are no longer invoked
/// while the handlers registered before and after them still run.
#[async_test]
// NOTE(review): presumably needed because the `async { panic!(..) }` handlers below rely
// on the never type falling back to `()` — confirm.
#[allow(dependency_on_unit_never_type_fallback)]
async fn test_remove_event_handler() -> crate::Result<()> {
let client = logged_in_client(None).await;
let member_count = Arc::new(AtomicU8::new(0));
client.add_event_handler({
let member_count = member_count.clone();
move |_ev: OriginalSyncRoomMemberEvent| async move {
member_count.fetch_add(1, SeqCst);
}
});
// These two handlers are removed again before the sync response is processed.
let handle_a = client.add_event_handler(move |_ev: OriginalSyncRoomMemberEvent| async {
panic!("handler should have been removed");
});
let handle_b = client.add_room_event_handler(
#[allow(unknown_lints, clippy::explicit_auto_deref)] *DEFAULT_TEST_ROOM_ID,
move |_ev: OriginalSyncRoomMemberEvent| async {
panic!("handler should have been removed");
},
);
client.add_event_handler({
let member_count = member_count.clone();
move |_ev: OriginalSyncRoomMemberEvent| async move {
member_count.fetch_add(1, SeqCst);
}
});
let response = SyncResponseBuilder::default()
.add_joined_room(JoinedRoomBuilder::default().add_timeline_event(MEMBER_EVENT.clone()))
.build_sync_response();
client.remove_event_handler(handle_a);
client.remove_event_handler(handle_b);
client.process_sync(response).await?;
// Only the two counting handlers ran.
assert_eq!(member_count.load(SeqCst), 2);
Ok(())
}
/// Checks that dropping an `EventHandlerDropGuard` unregisters its handler.
#[async_test]
async fn test_event_handler_drop_guard() {
    let client = no_retry_test_client(None).await;

    let handle = client.add_event_handler(|_ev: OriginalSyncRoomMemberEvent| async {});
    assert_eq!(client.inner.event_handlers.len(), 1);

    // Creating the guard must not change the number of registered handlers...
    let guard = client.event_handler_drop_guard(handle);
    assert_eq!(client.inner.event_handlers.len(), 1);

    // ...but dropping it removes the handler.
    drop(guard);
    assert_eq!(client.inner.event_handlers.len(), 0);
}
/// Compile-level check: a handler may take the `Client` as a context argument, await a
/// client request inside its future and return an `anyhow` result. No sync response is
/// processed here, so the handler body itself never runs.
#[async_test]
async fn test_use_client_in_handler() {
let client = no_retry_test_client(None).await;
client.add_event_handler(|_ev: OriginalSyncRoomMemberEvent, client: Client| async move {
let _caps = client.get_capabilities().await?;
anyhow::Ok(())
});
}
/// Checks that a handler taking the event as `Raw<OriginalSyncRoomMemberEvent>` is
/// invoked, and that a value registered with `add_event_handler_context` reaches the
/// handler through `Ctx`.
#[async_test]
async fn test_raw_event_handler() -> crate::Result<()> {
let client = logged_in_client(None).await;
let counter = Arc::new(AtomicU8::new(0));
// The counter is handed to the handler as context instead of being captured.
client.add_event_handler_context(counter.clone());
client.add_event_handler(
|_ev: Raw<OriginalSyncRoomMemberEvent>, counter: Ctx<Arc<AtomicU8>>| async move {
counter.fetch_add(1, SeqCst);
},
);
let response = SyncResponseBuilder::default()
.add_joined_room(JoinedRoomBuilder::default().add_timeline_event(MEMBER_EVENT.clone()))
.build_sync_response();
client.process_sync(response).await?;
assert_eq!(counter.load(SeqCst), 1);
Ok(())
}
/// Checks that a handler registered for the `AnySyncStateEvent` enum is invoked once
/// for a member event delivered through the timeline.
#[async_test]
async fn test_enum_event_handler() -> crate::Result<()> {
let client = logged_in_client(None).await;
let counter = Arc::new(AtomicU8::new(0));
// The counter is handed to the handler as context instead of being captured.
client.add_event_handler_context(counter.clone());
client.add_event_handler(
|_ev: AnySyncStateEvent, counter: Ctx<Arc<AtomicU8>>| async move {
counter.fetch_add(1, SeqCst);
},
);
let response = SyncResponseBuilder::default()
.add_joined_room(JoinedRoomBuilder::default().add_timeline_event(MEMBER_EVENT.clone()))
.build_sync_response();
client.process_sync(response).await?;
assert_eq!(counter.load(SeqCst), 1);
Ok(())
}
}