use std::{future, ops::Deref, sync::Arc};
use futures_core::Stream;
use futures_util::StreamExt;
use matrix_sdk_common::store_locks::CrossProcessStoreLock;
use ruma::{DeviceId, OwnedDeviceId, OwnedUserId, UserId};
use tokio::sync::{broadcast, Mutex};
use tokio_stream::wrappers::{errors::BroadcastStreamRecvError, BroadcastStream};
use tracing::{debug, trace, warn};
use super::{caches::SessionStore, DeviceChanges, IdentityChanges, LockableCryptoStore};
use crate::{
olm::InboundGroupSession,
store,
store::{Changes, DynCryptoStore, IntoCryptoStore, RoomKeyInfo, RoomKeyWithheldInfo},
CryptoStoreError, GossippedSecret, OwnUserIdentityData, Session, UserIdentityData,
};
/// Wrapper around a [`DynCryptoStore`] that keeps an in-memory cache of
/// sessions and broadcasts what gets written to the store.
#[derive(Debug)]
pub(crate) struct CryptoStoreWrapper {
/// Our own user ID; used to look up our own identity and to detect changes
/// to our own device.
user_id: OwnedUserId,
/// Our own device ID; a change to this device clears `sessions`.
device_id: OwnedDeviceId,
/// The underlying store that all data is persisted to.
store: Arc<DynCryptoStore>,
/// Cache of sessions keyed by sender key, filled by `get_sessions` and
/// `save_changes`.
sessions: SessionStore,
/// Announces room keys saved via `save_changes` or
/// `save_inbound_group_sessions`.
room_keys_received_sender: broadcast::Sender<Vec<RoomKeyInfo>>,
/// Announces withheld-key information saved via `save_changes`.
room_keys_withheld_received_sender: broadcast::Sender<Vec<RoomKeyWithheldInfo>>,
/// Announces each gossipped secret saved via `save_changes`, one at a time.
secrets_broadcaster: broadcast::Sender<GossippedSecret>,
/// Announces our own identity (if present) together with the identity and
/// device changes that were just saved.
identities_broadcaster:
broadcast::Sender<(Option<OwnUserIdentityData>, IdentityChanges, DeviceChanges)>,
}
impl CryptoStoreWrapper {
pub(crate) fn new(user_id: &UserId, device_id: &DeviceId, store: impl IntoCryptoStore) -> Self {
    // Capacity of each broadcast channel. A subscriber that falls further behind
    // than this gets a `Lagged` error, which `filter_errors_out_of_stream` logs
    // and skips.
    const ROOM_KEYS_CAPACITY: usize = 10;
    const WITHHELD_CAPACITY: usize = 10;
    const SECRETS_CAPACITY: usize = 10;
    const IDENTITIES_CAPACITY: usize = 20;

    Self {
        user_id: user_id.to_owned(),
        device_id: device_id.to_owned(),
        store: store.into_crypto_store(),
        sessions: SessionStore::new(),
        room_keys_received_sender: broadcast::Sender::new(ROOM_KEYS_CAPACITY),
        room_keys_withheld_received_sender: broadcast::Sender::new(WITHHELD_CAPACITY),
        secrets_broadcaster: broadcast::Sender::new(SECRETS_CAPACITY),
        identities_broadcaster: broadcast::Sender::new(IDENTITIES_CAPACITY),
    }
}
/// Save `changes` to the underlying store, keep the session cache in sync
/// and notify subscribers about what was saved.
///
/// In order, this:
/// 1. collects the room-key and withheld-key notifications from `changes`;
/// 2. records whether our own identity was verified before the change;
/// 3. updates the session cache: clears it entirely if our own device is
///    among the changed devices, otherwise adds `changes.sessions` to it;
/// 4. saves `changes` to the store;
/// 5. logs our own identity at DEBUG level, then broadcasts room keys,
///    withheld info, secrets and, if any devices or identities changed,
///    our own identity together with the identity/device changes.
///
/// If our own identity was unverified before the change and is verified
/// afterwards,
/// `check_all_identities_and_update_was_previously_verified_flag_if_needed`
/// runs before the final identities broadcast.
///
/// # Errors
///
/// Returns an error if reading our own identity, saving the changes, or the
/// follow-up identity check fails.
///
/// NOTE(review): the session cache is updated before the store save and is
/// not rolled back if the save fails. If the follow-up identity check fails,
/// the final identities broadcast is skipped even though the changes were
/// already saved.
pub async fn save_changes(&self, changes: Changes) -> store::Result<()> {
let room_key_updates: Vec<_> =
changes.inbound_group_sessions.iter().map(RoomKeyInfo::from).collect();
// Flatten the room_id -> session_id -> event map into one notification per
// withheld session.
let withheld_session_updates: Vec<_> = changes
.withheld_session_info
.iter()
.flat_map(|(room_id, session_map)| {
session_map.iter().map(|(session_id, withheld_event)| RoomKeyWithheldInfo {
room_id: room_id.to_owned(),
session_id: session_id.to_owned(),
withheld_event: withheld_event.clone(),
})
})
.collect();
// Remember our own identity's verification state before the change, so a
// transition to "verified" can be detected after saving.
let own_identity_was_verified_before_change = self
.store
.get_user_identity(self.user_id.as_ref())
.await?
.as_ref()
.and_then(|i| i.own())
.is_some_and(|own| own.is_verified());
// `changes` is moved into the store below; keep copies for the broadcasts.
let secrets = changes.secrets.to_owned();
let devices = changes.devices.to_owned();
let identities = changes.identities.to_owned();
// If our own device changed, drop the whole session cache. NOTE(review):
// presumably this is because cached sessions hold data derived from our own
// device keys; confirm. Otherwise, add the sessions being saved to the cache.
if devices
.changed
.iter()
.any(|d| d.user_id() == self.user_id && d.device_id() == self.device_id)
{
self.sessions.clear().await;
} else {
for session in &changes.sessions {
self.sessions.add(session.clone()).await;
}
}
self.store.save_changes(changes).await?;
// Only compute the key summaries when DEBUG logging is enabled.
if tracing::level_enabled!(tracing::Level::DEBUG) {
for updated_identity in
identities.new.iter().chain(identities.changed.iter()).filter_map(|id| id.own())
{
let master_key = updated_identity.master_key().get_first_key();
let user_signing_key = updated_identity.user_signing_key().get_first_key();
let self_signing_key = updated_identity.self_signing_key().get_first_key();
debug!(
?master_key,
?user_signing_key,
?self_signing_key,
previously_verified = updated_identity.was_previously_verified(),
verified = updated_identity.is_verified(),
"Stored our own identity"
);
}
}
// Broadcast send results are deliberately ignored (`let _`).
if !room_key_updates.is_empty() {
let _ = self.room_keys_received_sender.send(room_key_updates);
}
if !withheld_session_updates.is_empty() {
let _ = self.room_keys_withheld_received_sender.send(withheld_session_updates);
}
for secret in secrets {
let _ = self.secrets_broadcaster.send(secret);
}
if !devices.is_empty() || !identities.is_empty() {
// Re-read our own identity from the store so subscribers see its saved state.
let maybe_own_identity =
self.store.get_user_identity(&self.user_id).await?.and_then(|i| i.into_own());
if let Some(own_identity_after) = maybe_own_identity.as_ref() {
// Our identity just became verified: re-check other users' identities.
if !own_identity_was_verified_before_change && own_identity_after.is_verified() {
debug!("Own identity is now verified, check all known identities for verification status changes");
self.check_all_identities_and_update_was_previously_verified_flag_if_needed(
own_identity_after,
)
.await?;
}
}
let _ = self.identities_broadcaster.send((maybe_own_identity, identities, devices));
}
Ok(())
}
/// Re-check every tracked user's identity after our own identity has become
/// verified.
///
/// Each other-user identity that is not yet flagged as previously verified,
/// but is signed by `own_identity_after`, gets marked as previously verified.
/// All identities updated this way are saved in one batch and announced on
/// the identities stream together with `own_identity_after`. If no identity
/// changed, nothing is saved or broadcast.
///
/// # Errors
///
/// Returns an error if loading the tracked users or their identities fails,
/// or if saving the updated identities fails.
async fn check_all_identities_and_update_was_previously_verified_flag_if_needed(
    &self,
    own_identity_after: &OwnUserIdentityData,
) -> Result<(), CryptoStoreError> {
    let mut updated_identities: Vec<UserIdentityData> = Vec::new();

    for tracked_user in self.store.load_tracked_users().await? {
        let identity = self.store.get_user_identity(tracked_user.user_id.as_ref()).await?;

        // Only identities belonging to other users are relevant here.
        let Some(other_identity) = identity.as_ref().and_then(|i| i.other()) else {
            continue;
        };

        // Skip identities that are already flagged or that we have not signed.
        if other_identity.was_previously_verified()
            || !own_identity_after.is_identity_signed(other_identity)
        {
            continue;
        }

        trace!(?tracked_user.user_id, "Marking set verified_latch to true.");
        other_identity.mark_as_previously_verified();
        updated_identities.push(other_identity.clone().into());
    }

    if updated_identities.is_empty() {
        return Ok(());
    }

    let identity_changes = IdentityChanges { changed: updated_identities, ..Default::default() };
    let changes = Changes { identities: identity_changes.clone(), ..Default::default() };
    self.store.save_changes(changes).await?;

    let _ = self.identities_broadcaster.send((
        Some(own_identity_after.clone()),
        identity_changes,
        DeviceChanges::default(),
    ));

    Ok(())
}
/// Get the sessions for `sender_key`, loading them from the store on a cache
/// miss.
///
/// In the tests, `sender_key` is a base64-encoded Curve25519 identity key.
/// On a miss, the loaded sessions (an empty list if the store has none) are
/// inserted into the cache. The result is therefore always `Some`.
///
/// # Errors
///
/// Returns an error if loading the sessions from the store fails.
pub async fn get_sessions(
    &self,
    sender_key: &str,
) -> store::Result<Option<Arc<Mutex<Vec<Session>>>>> {
    // Fast path: the sessions are already cached.
    if let Some(cached) = self.sessions.get(sender_key).await {
        return Ok(Some(cached));
    }

    // Slow path: hold the write lock across the store load. Another caller may
    // have filled the entry while we waited for the lock, so check again first.
    let mut entries = self.sessions.entries.write().await;

    if let Some(cached) = entries.get(sender_key) {
        return Ok(Some(cached.clone()));
    }

    let loaded = self.store.get_sessions(sender_key).await?.unwrap_or_default();
    let loaded = Arc::new(Mutex::new(loaded));
    entries.insert(sender_key.to_owned(), loaded.clone());

    Ok(Some(loaded))
}
/// Save inbound group sessions to the store and announce them on the
/// room-keys stream.
///
/// `backed_up_to_version` is forwarded unchanged to the underlying store.
/// Nothing is broadcast when `sessions` is empty.
///
/// # Errors
///
/// Returns an error if the store fails to save the sessions. In that case
/// nothing is broadcast.
pub async fn save_inbound_group_sessions(
    &self,
    sessions: Vec<InboundGroupSession>,
    backed_up_to_version: Option<&str>,
) -> store::Result<()> {
    // Build the notifications first, because `sessions` is moved into the store.
    let updates = sessions.iter().map(RoomKeyInfo::from).collect::<Vec<_>>();

    self.store.save_inbound_group_sessions(sessions, backed_up_to_version).await?;

    // Only announce once the save has succeeded. The send result is
    // deliberately ignored.
    if !updates.is_empty() {
        let _ = self.room_keys_received_sender.send(updates);
    }

    Ok(())
}
/// Subscribe to room keys as they are saved to the store.
///
/// Each item is the batch of [`RoomKeyInfo`] from one `save_changes` or
/// `save_inbound_group_sessions` call. If this subscriber lags behind, a
/// warning is logged and the missed batches are skipped.
pub fn room_keys_received_stream(&self) -> impl Stream<Item = Vec<RoomKeyInfo>> {
    let receiver = self.room_keys_received_sender.subscribe();
    Self::filter_errors_out_of_stream(BroadcastStream::new(receiver), "room_keys_received_stream")
}
/// Subscribe to withheld-key information as it is saved to the store.
///
/// Each item is the batch of [`RoomKeyWithheldInfo`] from one `save_changes`
/// call. If this subscriber lags behind, a warning is logged and the missed
/// batches are skipped.
pub fn room_keys_withheld_received_stream(
    &self,
) -> impl Stream<Item = Vec<RoomKeyWithheldInfo>> {
    let receiver = self.room_keys_withheld_received_sender.subscribe();
    Self::filter_errors_out_of_stream(
        BroadcastStream::new(receiver),
        "room_keys_withheld_received_stream",
    )
}
/// Subscribe to gossipped secrets as they are saved to the store.
///
/// `save_changes` emits one item per secret. If this subscriber lags behind,
/// a warning is logged and the missed secrets are skipped.
pub fn secrets_stream(&self) -> impl Stream<Item = GossippedSecret> {
    let receiver = self.secrets_broadcaster.subscribe();
    Self::filter_errors_out_of_stream(BroadcastStream::new(receiver), "secrets_stream")
}
/// Subscribe to identity and device changes as they are saved to the store.
///
/// Each item contains our own identity, if present, along with the
/// [`IdentityChanges`] and [`DeviceChanges`] that were saved. If this
/// subscriber lags behind, a warning is logged and the missed items are
/// skipped.
pub(super) fn identities_stream(
    &self,
) -> impl Stream<Item = (Option<OwnUserIdentityData>, IdentityChanges, DeviceChanges)> {
    let receiver = self.identities_broadcaster.subscribe();
    Self::filter_errors_out_of_stream(BroadcastStream::new(receiver), "identities_stream")
}
/// Turn a [`BroadcastStream`] into a stream of plain items.
///
/// `Lagged` errors, raised when the receiver fell behind the channel's
/// capacity, are logged as warnings tagged with `stream_name` and then
/// dropped. Successfully received items pass through unchanged.
fn filter_errors_out_of_stream<ItemType>(
    stream: BroadcastStream<ItemType>,
    stream_name: &str,
) -> impl Stream<Item = ItemType>
where
    ItemType: 'static + Clone + Send,
{
    // The returned stream outlives the borrowed name, so take an owned copy.
    let stream_name = stream_name.to_owned();

    stream.filter_map(move |received| {
        let item = match received {
            Ok(item) => Some(item),
            Err(BroadcastStreamRecvError::Lagged(lag)) => {
                warn!("{stream_name} missed {lag} updates");
                None
            }
        };
        future::ready(item)
    })
}
pub(crate) fn create_store_lock(
&self,
lock_key: String,
lock_value: String,
) -> CrossProcessStoreLock<LockableCryptoStore> {
CrossProcessStoreLock::new(LockableCryptoStore(self.store.clone()), lock_key, lock_value)
}
}
impl Deref for CryptoStoreWrapper {
type Target = DynCryptoStore;
fn deref(&self) -> &Self::Target {
self.store.deref()
}
}
#[cfg(test)]
mod test {
use matrix_sdk_test::async_test;
use ruma::user_id;
use super::*;
use crate::machine::test_helpers::get_machine_pair_with_setup_sessions_test_helper;
// Saving a change to our own device through `save_changes` must clear the
// session cache of the wrapped store.
#[async_test]
async fn test_cache_cleared_after_device_update() {
let user_id = user_id!("@alice:example.com");
let (first, second) =
get_machine_pair_with_setup_sessions_test_helper(user_id, user_id, false).await;
// Sessions are cached under the peer's base64-encoded Curve25519 key.
let sender_key = second.identity_keys().curve25519.to_base64();
// Precondition: the helper left a session for `second` in the cache.
first
.store()
.inner
.store
.sessions
.get(&sender_key)
.await
.expect("We should have a session in the cache.");
// Fetch our own device so it can be saved back as a "changed" device.
let device_data = first
.get_device(user_id, first.device_id(), None)
.await
.unwrap()
.expect("We should have access to our own device.")
.inner;
first
.store()
.save_changes(Changes {
devices: DeviceChanges { changed: vec![device_data], ..Default::default() },
..Default::default()
})
.await
.unwrap();
// The own-device change must have emptied the cache entry.
assert!(
first.store().inner.store.sessions.get(&sender_key).await.is_none(),
"The session should no longer be in the cache after our own device keys changed"
);
}
}