use std::{
collections::{BTreeMap, BTreeSet, HashMap},
default::Default,
ops::Deref,
};
use itertools::{Either, Itertools};
use ruma::{DeviceId, OwnedDeviceId, OwnedUserId, UserId};
use serde::{Deserialize, Serialize};
use tracing::{debug, instrument, trace};
use super::OutboundGroupSession;
use crate::{
error::{OlmResult, SessionRecipientCollectionError},
store::Store,
types::events::room_key_withheld::WithheldCode,
DeviceData, EncryptionSettings, LocalTrust, OlmError, OwnUserIdentityData, UserIdentityData,
};
#[cfg(doc)]
use crate::{Device, UserIdentity};
/// Strategy that decides which devices receive a room key when it is shared.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[cfg_attr(feature = "uniffi", derive(uniffi::Enum))]
pub enum CollectStrategy {
/// Decide device by device.
///
/// Regardless of the flags below, blacklisted devices are recorded as
/// withheld with `WithheldCode::Blacklisted`, and devices whose local trust
/// state is `LocalTrust::Ignored` always receive the key.
DeviceBasedStrategy {
/// If `true`, devices that are not verified do not receive the key and
/// are recorded as withheld with `WithheldCode::Unverified`.
only_allow_trusted_devices: bool,
/// If `true`, recipient collection fails with an error when a user whose
/// identity was previously verified is no longer verified, or when a
/// verified user has a device that is not cross-signed by that user.
///
/// Falls back to `false` when the field is missing during
/// deserialization.
#[serde(default)]
error_on_verified_user_problem: bool,
},
/// Share only with devices that are cross-signed by their owner's
/// identity; every other device is recorded as withheld with
/// `WithheldCode::Unverified`.
///
/// Collecting recipients with this strategy fails if our own identity is
/// missing or not verified, or if a previously verified user is no longer
/// verified.
IdentityBasedStrategy,
}
impl CollectStrategy {
pub const fn new_identity_based() -> Self {
CollectStrategy::IdentityBasedStrategy
}
}
impl Default for CollectStrategy {
fn default() -> Self {
CollectStrategy::DeviceBasedStrategy {
only_allow_trusted_devices: false,
error_on_verified_user_problem: false,
}
}
}
/// Outcome of `collect_session_recipients`.
#[derive(Debug)]
pub(crate) struct CollectRecipientsResult {
/// Whether the outbound group session should be rotated before the key is
/// shared with the collected devices.
pub should_rotate: bool,
/// Devices that should receive the room key, grouped by the user that owns
/// them. A considered user may map to an empty list.
pub devices: BTreeMap<OwnedUserId, Vec<DeviceData>>,
/// Devices that will not receive the room key, each paired with the
/// withheld code that explains why.
pub withheld_devices: Vec<(DeviceData, WithheldCode)>,
}
#[instrument(skip_all)]
pub(crate) async fn collect_session_recipients(
store: &Store,
users: impl Iterator<Item = &UserId>,
settings: &EncryptionSettings,
outbound: &OutboundGroupSession,
) -> OlmResult<CollectRecipientsResult> {
let users: BTreeSet<&UserId> = users.collect();
let mut devices: BTreeMap<OwnedUserId, Vec<DeviceData>> = Default::default();
let mut withheld_devices: Vec<(DeviceData, WithheldCode)> = Default::default();
let mut verified_users_with_new_identities: Vec<OwnedUserId> = Default::default();
trace!(?users, ?settings, "Calculating group session recipients");
let users_shared_with: BTreeSet<OwnedUserId> =
outbound.shared_with_set.read().unwrap().keys().cloned().collect();
let users_shared_with: BTreeSet<&UserId> = users_shared_with.iter().map(Deref::deref).collect();
let user_left = !users_shared_with.difference(&users).collect::<BTreeSet<_>>().is_empty();
let visibility_changed = outbound.settings().history_visibility != settings.history_visibility;
let algorithm_changed = outbound.settings().algorithm != settings.algorithm;
let mut should_rotate = user_left || visibility_changed || algorithm_changed;
let own_identity = store.get_user_identity(store.user_id()).await?.and_then(|i| i.into_own());
match settings.sharing_strategy {
CollectStrategy::DeviceBasedStrategy {
only_allow_trusted_devices,
error_on_verified_user_problem,
} => {
let mut unsigned_devices_of_verified_users: BTreeMap<OwnedUserId, Vec<OwnedDeviceId>> =
Default::default();
for user_id in users {
trace!("Considering recipient devices for user {}", user_id);
let user_devices = store.get_device_data_for_user_filtered(user_id).await?;
let device_owner_identity =
if only_allow_trusted_devices || error_on_verified_user_problem {
store.get_user_identity(user_id).await?
} else {
None
};
if error_on_verified_user_problem
&& has_identity_verification_violation(
own_identity.as_ref(),
device_owner_identity.as_ref(),
)
{
verified_users_with_new_identities.push(user_id.to_owned());
continue;
}
let recipient_devices = split_devices_for_user(
user_devices,
&own_identity,
&device_owner_identity,
only_allow_trusted_devices,
error_on_verified_user_problem,
);
if !recipient_devices.unsigned_of_verified_user.is_empty() {
unsigned_devices_of_verified_users.insert(
user_id.to_owned(),
recipient_devices
.unsigned_of_verified_user
.into_iter()
.map(|d| d.device_id().to_owned())
.collect(),
);
}
if !should_rotate {
should_rotate = is_session_overshared_for_user(
outbound,
user_id,
&recipient_devices.allowed_devices,
)
}
devices
.entry(user_id.to_owned())
.or_default()
.extend(recipient_devices.allowed_devices);
withheld_devices.extend(recipient_devices.denied_devices_with_code);
}
if !unsigned_devices_of_verified_users.is_empty() {
return Err(OlmError::SessionRecipientCollectionError(
SessionRecipientCollectionError::VerifiedUserHasUnsignedDevice(
unsigned_devices_of_verified_users,
),
));
}
}
CollectStrategy::IdentityBasedStrategy => {
match &own_identity {
None => {
return Err(OlmError::SessionRecipientCollectionError(
SessionRecipientCollectionError::CrossSigningNotSetup,
))
}
Some(identity) if !identity.is_verified() => {
return Err(OlmError::SessionRecipientCollectionError(
SessionRecipientCollectionError::SendingFromUnverifiedDevice,
))
}
Some(_) => (),
}
for user_id in users {
trace!("Considering recipient devices for user {}", user_id);
let user_devices = store.get_device_data_for_user_filtered(user_id).await?;
let device_owner_identity = store.get_user_identity(user_id).await?;
if has_identity_verification_violation(
own_identity.as_ref(),
device_owner_identity.as_ref(),
) {
verified_users_with_new_identities.push(user_id.to_owned());
continue;
}
let recipient_devices = split_recipients_withhelds_for_user_based_on_identity(
user_devices,
&device_owner_identity,
);
if !should_rotate {
should_rotate = is_session_overshared_for_user(
outbound,
user_id,
&recipient_devices.allowed_devices,
)
}
devices
.entry(user_id.to_owned())
.or_default()
.extend(recipient_devices.allowed_devices);
withheld_devices.extend(recipient_devices.denied_devices_with_code);
}
}
}
if !verified_users_with_new_identities.is_empty() {
return Err(OlmError::SessionRecipientCollectionError(
SessionRecipientCollectionError::VerifiedUserChangedIdentity(
verified_users_with_new_identities,
),
));
}
if should_rotate {
debug!(
should_rotate,
user_left,
visibility_changed,
algorithm_changed,
"Rotating room key to protect room history",
);
}
trace!(should_rotate, "Done calculating group session recipients");
Ok(CollectRecipientsResult { should_rotate, devices, withheld_devices })
}
/// Tells whether `outbound_session` was already shared with a device of
/// `user_id` that is not among `recipient_devices` any more.
///
/// Returns `false` when the session has not been shared with the user at
/// all. When it returns `true` the offending device ids are logged.
fn is_session_overshared_for_user(
    outbound_session: &OutboundGroupSession,
    user_id: &UserId,
    recipient_devices: &[DeviceData],
) -> bool {
    let still_allowed: BTreeSet<&DeviceId> =
        recipient_devices.iter().map(|device| device.device_id()).collect();

    let shared_with = outbound_session.shared_with_set.read().unwrap();
    let already_received: BTreeSet<&DeviceId> = match shared_with.get(user_id) {
        Some(per_device) => per_device.keys().map(|id| id.as_ref()).collect(),
        // Nothing was shared with this user, so nothing can be overshared.
        None => return false,
    };

    // Devices that hold the key but should not get it any more.
    let no_longer_allowed: BTreeSet<_> = already_received.difference(&still_allowed).collect();
    if no_longer_allowed.is_empty() {
        return false;
    }

    debug!(
        "Rotating a room key due to these devices being deleted/blacklisted {:?}",
        no_longer_allowed,
    );
    true
}
/// Result of sorting one user's devices under the device-based strategy, as
/// produced by `split_devices_for_user`.
#[derive(Default)]
struct DeviceBasedRecipientDevices {
/// Devices that should receive the room key.
allowed_devices: Vec<DeviceData>,
/// Devices that must not receive the room key, with the withheld code
/// explaining why.
denied_devices_with_code: Vec<(DeviceData, WithheldCode)>,
/// Devices of a verified user that are not cross-signed by that user. Only
/// populated when `error_on_verified_user_problem` is set.
unsigned_of_verified_user: Vec<DeviceData>,
}
fn split_devices_for_user(
user_devices: HashMap<OwnedDeviceId, DeviceData>,
own_identity: &Option<OwnUserIdentityData>,
device_owner_identity: &Option<UserIdentityData>,
only_allow_trusted_devices: bool,
error_on_verified_user_problem: bool,
) -> DeviceBasedRecipientDevices {
let mut recipient_devices: DeviceBasedRecipientDevices = Default::default();
for d in user_devices.into_values() {
if d.is_blacklisted() {
recipient_devices.denied_devices_with_code.push((d, WithheldCode::Blacklisted));
} else if d.local_trust_state() == LocalTrust::Ignored {
recipient_devices.allowed_devices.push(d);
} else if only_allow_trusted_devices && !d.is_verified(own_identity, device_owner_identity)
{
recipient_devices.denied_devices_with_code.push((d, WithheldCode::Unverified));
} else if error_on_verified_user_problem
&& is_unsigned_device_of_verified_user(
own_identity.as_ref(),
device_owner_identity.as_ref(),
&d,
)
{
recipient_devices.unsigned_of_verified_user.push(d)
} else {
recipient_devices.allowed_devices.push(d);
}
}
recipient_devices
}
/// Result of sorting one user's devices under the identity-based strategy, as
/// produced by `split_recipients_withhelds_for_user_based_on_identity`.
#[derive(Default)]
struct IdentityBasedRecipientDevices {
/// Devices that should receive the room key.
allowed_devices: Vec<DeviceData>,
/// Devices that must not receive the room key, with the withheld code
/// explaining why.
denied_devices_with_code: Vec<(DeviceData, WithheldCode)>,
}
/// Sorts the devices of a single user into allowed and denied buckets for
/// the identity-based strategy.
///
/// Without a known identity for the owner every device is denied. Otherwise
/// a device is allowed exactly when it is cross-signed by the owner's
/// identity. Denied devices always carry `WithheldCode::Unverified`.
fn split_recipients_withhelds_for_user_based_on_identity(
    user_devices: HashMap<OwnedDeviceId, DeviceData>,
    device_owner_identity: &Option<UserIdentityData>,
) -> IdentityBasedRecipientDevices {
    let Some(owner_identity) = device_owner_identity else {
        // No identity to check the devices against: withhold from all.
        return IdentityBasedRecipientDevices {
            allowed_devices: Vec::default(),
            denied_devices_with_code: user_devices
                .into_values()
                .map(|device| (device, WithheldCode::Unverified))
                .collect(),
        };
    };

    let (allowed_devices, denied_devices_with_code): (
        Vec<DeviceData>,
        Vec<(DeviceData, WithheldCode)>,
    ) = user_devices.into_values().partition_map(|device| {
        if device.is_cross_signed_by_owner(owner_identity) {
            Either::Left(device)
        } else {
            Either::Right((device, WithheldCode::Unverified))
        }
    });

    IdentityBasedRecipientDevices { allowed_devices, denied_devices_with_code }
}
/// Returns `true` if the owner of `device_data` has a known identity that we
/// consider verified, while the device itself is not cross-signed by that
/// identity. Returns `false` when the owner has no known identity.
fn is_unsigned_device_of_verified_user(
    own_identity: Option<&OwnUserIdentityData>,
    device_owner_identity: Option<&UserIdentityData>,
    device_data: &DeviceData,
) -> bool {
    match device_owner_identity {
        Some(owner) => {
            is_user_verified(own_identity, owner) && !device_data.is_cross_signed_by_owner(owner)
        }
        None => false,
    }
}
/// Returns `true` if the user has a known identity that was verified at some
/// point in the past but is not verified right now. Returns `false` when the
/// user has no known identity.
fn has_identity_verification_violation(
    own_identity: Option<&OwnUserIdentityData>,
    device_owner_identity: Option<&UserIdentityData>,
) -> bool {
    match device_owner_identity {
        Some(owner) => owner.was_previously_verified() && !is_user_verified(own_identity, owner),
        None => false,
    }
}
/// Returns `true` if `user_identity` is currently verified.
///
/// Our own identity is asked directly. Another user's identity is verified
/// only if `own_identity` is available and reports it as verified.
fn is_user_verified(
    own_identity: Option<&OwnUserIdentityData>,
    user_identity: &UserIdentityData,
) -> bool {
    match (user_identity, own_identity) {
        (UserIdentityData::Own(own), _) => own.is_verified(),
        (UserIdentityData::Other(other), Some(ours)) => ours.is_identity_verified(other),
        // Without our own identity nobody else can be verified.
        (UserIdentityData::Other(_), None) => false,
    }
}
#[cfg(test)]
mod tests {
use std::{collections::BTreeMap, iter, sync::Arc};
use assert_matches::assert_matches;
use assert_matches2::assert_let;
use matrix_sdk_test::{
async_test, test_json,
test_json::keys_query_sets::{
IdentityChangeDataSet, KeyDistributionTestData, MaloIdentityChangeDataSet,
VerificationViolationTestData,
},
};
use ruma::{
device_id, events::room::history_visibility::HistoryVisibility, room_id, TransactionId,
};
use serde_json::json;
use crate::{
error::SessionRecipientCollectionError,
olm::OutboundGroupSession,
session_manager::{
group_sessions::share_strategy::collect_session_recipients, CollectStrategy,
},
testing::simulate_key_query_response_for_verification,
types::events::room_key_withheld::WithheldCode,
CrossSigningKeyExport, EncryptionSettings, LocalTrust, OlmError, OlmMachine,
};
/// Builds an `OlmMachine` for the `KeyDistributionTestData` account, imports
/// its private cross-signing keys and feeds it the key query responses of
/// dan, dave and good.
async fn set_up_test_machine() -> OlmMachine {
    let machine = OlmMachine::new(
        KeyDistributionTestData::me_id(),
        KeyDistributionTestData::me_device_id(),
    )
    .await;

    // Upload our own keys first, then import the matching private
    // cross-signing keys.
    let own_keys = KeyDistributionTestData::me_keys_query_response();
    machine.mark_request_as_sent(&TransactionId::new(), &own_keys).await.unwrap();

    let cross_signing_keys = CrossSigningKeyExport {
        master_key: KeyDistributionTestData::MASTER_KEY_PRIVATE_EXPORT.to_owned().into(),
        self_signing_key: KeyDistributionTestData::SELF_SIGNING_KEY_PRIVATE_EXPORT
            .to_owned()
            .into(),
        user_signing_key: KeyDistributionTestData::USER_SIGNING_KEY_PRIVATE_EXPORT
            .to_owned()
            .into(),
    };
    machine.import_cross_signing_keys(cross_signing_keys).await.unwrap();

    // Make the other users and their devices known to the machine.
    let dan_keys = KeyDistributionTestData::dan_keys_query_response();
    machine.mark_request_as_sent(&TransactionId::new(), &dan_keys).await.unwrap();

    let dave_keys = KeyDistributionTestData::dave_keys_query_response();
    machine.mark_request_as_sent(&TransactionId::new(), &dave_keys).await.unwrap();

    let good_keys = KeyDistributionTestData::good_keys_query_response();
    machine.mark_request_as_sent(&TransactionId::new(), &good_keys).await.unwrap();

    machine
}
/// With both device-based restrictions off, every known device of every
/// user receives the key and no rotation is requested.
#[async_test]
async fn test_share_with_per_device_strategy_to_all() {
    let machine = set_up_test_machine().await;

    let sharing_strategy = CollectStrategy::DeviceBasedStrategy {
        only_allow_trusted_devices: false,
        error_on_verified_user_problem: false,
    };
    let settings = EncryptionSettings { sharing_strategy, ..Default::default() };
    let session = create_test_outbound_group_session(&machine, &settings);

    let recipients = vec![
        KeyDistributionTestData::dan_id(),
        KeyDistributionTestData::dave_id(),
        KeyDistributionTestData::good_id(),
    ];
    let outcome =
        collect_session_recipients(machine.store(), recipients.into_iter(), &settings, &session)
            .await
            .unwrap();

    assert!(!outcome.should_rotate);

    let dan_devices = outcome.devices.get(KeyDistributionTestData::dan_id()).unwrap();
    let dave_devices = outcome.devices.get(KeyDistributionTestData::dave_id()).unwrap();
    let good_devices = outcome.devices.get(KeyDistributionTestData::good_id()).unwrap();

    assert_eq!(dan_devices.len(), 2);
    assert_eq!(dave_devices.len(), 1);
    assert_eq!(good_devices.len(), 2);
}
/// Runs the "only trusted devices" scenario without erroring on verified
/// user problems.
#[async_test]
async fn test_share_with_per_device_strategy_only_trusted() {
    let error_on_verified_user_problem = false;
    test_share_only_trusted_helper(error_on_verified_user_problem).await;
}
/// Runs the "only trusted devices" scenario with
/// `error_on_verified_user_problem` switched on as well.
#[async_test]
async fn test_share_with_per_device_strategy_only_trusted_error_on_unsigned_of_verified() {
    let error_on_verified_user_problem = true;
    test_share_only_trusted_helper(error_on_verified_user_problem).await;
}
/// Shared body of the "only trusted devices" tests: only dan's signed device
/// receives the key, while dan's unsigned device and dave's device are
/// withheld from with `WithheldCode::Unverified`.
async fn test_share_only_trusted_helper(error_on_verified_user_problem: bool) {
    let machine = set_up_test_machine().await;

    let sharing_strategy = CollectStrategy::DeviceBasedStrategy {
        only_allow_trusted_devices: true,
        error_on_verified_user_problem,
    };
    let settings = EncryptionSettings { sharing_strategy, ..Default::default() };
    let session = create_test_outbound_group_session(&machine, &settings);

    let recipients = vec![
        KeyDistributionTestData::dan_id(),
        KeyDistributionTestData::dave_id(),
        KeyDistributionTestData::good_id(),
    ];
    let outcome =
        collect_session_recipients(machine.store(), recipients.into_iter(), &settings, &session)
            .await
            .unwrap();

    assert!(!outcome.should_rotate);

    // dave and good both get an entry, but none of their devices qualify.
    let dave_devices = outcome.devices.get(KeyDistributionTestData::dave_id());
    let good_devices = outcome.devices.get(KeyDistributionTestData::good_id());
    assert!(dave_devices.unwrap().is_empty());
    assert!(good_devices.unwrap().is_empty());

    // dan has exactly one qualifying device: the signed one.
    let dan_devices = outcome.devices.get(KeyDistributionTestData::dan_id()).unwrap();
    assert_eq!(dan_devices.len(), 1);
    let receiving_device = &dan_devices[0];
    assert_eq!(
        receiving_device.device_id().as_str(),
        KeyDistributionTestData::dan_signed_device_id()
    );

    let (_, dan_code) = outcome
        .withheld_devices
        .iter()
        .find(|(device, _)| {
            device.device_id() == KeyDistributionTestData::dan_unsigned_device_id()
        })
        .expect("This dan's device should receive a withheld code");
    assert_eq!(dan_code, &WithheldCode::Unverified);

    let (_, dave_code) = outcome
        .withheld_devices
        .iter()
        .find(|(device, _)| device.device_id() == KeyDistributionTestData::dave_device_id())
        .expect("This daves's device should receive a withheld code");
    assert_eq!(dave_code, &WithheldCode::Unverified);
}
/// With `error_on_verified_user_problem` set, unsigned devices of verified
/// users make recipient collection fail, and the error lists the affected
/// devices of every such user.
#[async_test]
async fn test_error_on_unsigned_of_verified_users() {
    use VerificationViolationTestData as DataSet;

    let machine = unsigned_of_verified_setup().await;

    // Add carol, who is verified as well and also has an unsigned device.
    let carol_keys = DataSet::carol_keys_query_response_signed();
    machine.mark_request_as_sent(&TransactionId::new(), &carol_keys).await.unwrap();

    let carol_identity =
        machine.get_identity(DataSet::carol_id(), None).await.unwrap().unwrap();
    assert!(carol_identity.other().unwrap().is_verified());

    let carol_unsigned_device = machine
        .get_device(DataSet::carol_id(), DataSet::carol_unsigned_device_id(), None)
        .await
        .unwrap()
        .unwrap();
    assert!(!carol_unsigned_device.is_verified());

    let settings = error_on_verification_problem_encryption_settings();
    let session = create_test_outbound_group_session(&machine, &settings);
    let recipients = vec![DataSet::bob_id(), DataSet::carol_id()];
    let outcome =
        collect_session_recipients(machine.store(), recipients.into_iter(), &settings, &session)
            .await;

    assert_let!(
        Err(OlmError::SessionRecipientCollectionError(
            SessionRecipientCollectionError::VerifiedUserHasUnsignedDevice(unverified_devices)
        )) = outcome
    );

    // Both users must be reported, each with their unsigned device.
    let expected = BTreeMap::from([
        (DataSet::bob_id().to_owned(), vec![DataSet::bob_device_2_id().to_owned()]),
        (
            DataSet::carol_id().to_owned(),
            vec![DataSet::carol_unsigned_device_id().to_owned()],
        ),
    ]);
    assert_eq!(unverified_devices, expected);
}
/// Marking the unsigned device of a verified user as `LocalTrust::Ignored`
/// resolves the error: both of bob's devices receive the key.
#[async_test]
async fn test_error_on_unsigned_of_verified_resolve_by_whitelisting() {
    use VerificationViolationTestData as DataSet;

    let machine = unsigned_of_verified_setup().await;

    let bob_unsigned_device = machine
        .get_device(DataSet::bob_id(), DataSet::bob_device_2_id(), None)
        .await
        .unwrap()
        .unwrap();
    bob_unsigned_device.set_local_trust(LocalTrust::Ignored).await.unwrap();

    let settings = error_on_verification_problem_encryption_settings();
    let session = create_test_outbound_group_session(&machine, &settings);
    let outcome = collect_session_recipients(
        machine.store(),
        iter::once(DataSet::bob_id()),
        &settings,
        &session,
    )
    .await
    .unwrap();

    assert_eq!(2, outcome.devices.get(DataSet::bob_id()).unwrap().len());
    assert_eq!(0, outcome.withheld_devices.len());
}
/// Blacklisting the unsigned device of a verified user resolves the error:
/// only bob's other device receives the key and the blacklisted one is
/// withheld from with `WithheldCode::Blacklisted`.
#[async_test]
async fn test_error_on_unsigned_of_verified_resolve_by_blacklisting() {
    use VerificationViolationTestData as DataSet;

    let machine = unsigned_of_verified_setup().await;

    let bob_unsigned_device = machine
        .get_device(DataSet::bob_id(), DataSet::bob_device_2_id(), None)
        .await
        .unwrap()
        .unwrap();
    bob_unsigned_device.set_local_trust(LocalTrust::BlackListed).await.unwrap();

    let settings = error_on_verification_problem_encryption_settings();
    let session = create_test_outbound_group_session(&machine, &settings);
    let outcome = collect_session_recipients(
        machine.store(),
        iter::once(DataSet::bob_id()),
        &settings,
        &session,
    )
    .await
    .unwrap();

    assert_eq!(1, outcome.devices.get(DataSet::bob_id()).unwrap().len());

    let withheld: Vec<_> = outcome
        .withheld_devices
        .iter()
        .map(|(device, code)| (device.device_id().to_owned(), code.clone()))
        .collect();
    let expected = vec![(DataSet::bob_device_2_id().to_owned(), WithheldCode::Blacklisted)];
    assert_eq!(withheld, expected);
}
/// The unsigned-device check also applies to our own user: an unsigned
/// device of ours makes recipient collection fail.
#[async_test]
async fn test_error_on_unsigned_of_verified_owner_is_us() {
    use VerificationViolationTestData as DataSet;

    let machine = unsigned_of_verified_setup().await;

    // Replace our device list with one signed and one unsigned device.
    let own_devices = BTreeMap::from([
        DataSet::own_signed_device_keys(),
        DataSet::own_unsigned_device_keys(),
    ]);
    let mut own_keys = DataSet::own_keys_query_response_1().clone();
    own_keys.device_keys.insert(DataSet::own_id().to_owned(), own_devices);
    machine.mark_request_as_sent(&TransactionId::new(), &own_keys).await.unwrap();

    let settings = error_on_verification_problem_encryption_settings();
    let session = create_test_outbound_group_session(&machine, &settings);
    let outcome = collect_session_recipients(
        machine.store(),
        iter::once(DataSet::own_id()),
        &settings,
        &session,
    )
    .await;

    assert_let!(
        Err(OlmError::SessionRecipientCollectionError(
            SessionRecipientCollectionError::VerifiedUserHasUnsignedDevice(unverified_devices)
        )) = outcome
    );

    let expected =
        BTreeMap::from([(DataSet::own_id().to_owned(), vec![DataSet::own_unsigned_device_id()])]);
    assert_eq!(unverified_devices, expected);
}
/// An unsigned device of a user who is not verified must not make recipient
/// collection fail, even with `error_on_verified_user_problem` set.
#[async_test]
async fn test_should_not_error_on_unsigned_of_unverified() {
    use VerificationViolationTestData as DataSet;

    let machine = OlmMachine::new(DataSet::own_id(), device_id!("LOCAL")).await;

    let own_keys = DataSet::own_keys_query_response_1();
    machine.mark_request_as_sent(&TransactionId::new(), &own_keys).await.unwrap();

    let cross_signing_keys = CrossSigningKeyExport {
        master_key: DataSet::MASTER_KEY_PRIVATE_EXPORT.to_owned().into(),
        self_signing_key: DataSet::SELF_SIGNING_KEY_PRIVATE_EXPORT.to_owned().into(),
        user_signing_key: DataSet::USER_SIGNING_KEY_PRIVATE_EXPORT.to_owned().into(),
    };
    machine.import_cross_signing_keys(cross_signing_keys).await.unwrap();

    // With this key query response bob is not verified by us.
    let bob_keys = DataSet::bob_keys_query_response_rotated();
    machine.mark_request_as_sent(&TransactionId::new(), &bob_keys).await.unwrap();

    let bob_identity = machine.get_identity(DataSet::bob_id(), None).await.unwrap().unwrap();
    assert!(!bob_identity.other().unwrap().is_verified());

    let bob_device = machine
        .get_device(DataSet::bob_id(), DataSet::bob_device_1_id(), None)
        .await
        .unwrap()
        .unwrap();
    assert!(!bob_device.is_cross_signed_by_owner());

    let settings = error_on_verification_problem_encryption_settings();
    let session = create_test_outbound_group_session(&machine, &settings);

    // Collection must succeed; the result itself is not inspected.
    collect_session_recipients(
        machine.store(),
        iter::once(DataSet::bob_id()),
        &settings,
        &session,
    )
    .await
    .unwrap();
}
/// An unsigned device of a user whose identity is signed by our identity,
/// but who is still not considered verified, must not make recipient
/// collection fail.
#[async_test]
async fn test_should_not_error_on_unsigned_of_signed_but_unverified() {
    use VerificationViolationTestData as DataSet;

    let machine = OlmMachine::new(DataSet::own_id(), device_id!("LOCAL")).await;

    // Note that no private cross-signing keys are imported in this test.
    let own_keys = DataSet::own_keys_query_response_1();
    machine.mark_request_as_sent(&TransactionId::new(), &own_keys).await.unwrap();

    let bob_keys = DataSet::bob_keys_query_response_signed();
    machine.mark_request_as_sent(&TransactionId::new(), &bob_keys).await.unwrap();

    let bob_identity =
        machine.get_identity(DataSet::bob_id(), None).await.unwrap().unwrap().other().unwrap();

    // Signed by our identity...
    let our_identity = bob_identity.own_identity.as_ref().unwrap();
    assert!(our_identity.is_identity_signed(&bob_identity.inner));
    // ...but not verified.
    assert!(!bob_identity.is_verified());

    let bob_unsigned_device = machine
        .get_device(DataSet::bob_id(), DataSet::bob_device_2_id(), None)
        .await
        .unwrap()
        .unwrap();
    assert!(!bob_unsigned_device.is_cross_signed_by_owner());

    let settings = error_on_verification_problem_encryption_settings();
    let session = create_test_outbound_group_session(&machine, &settings);

    // Collection must succeed; the result itself is not inspected.
    collect_session_recipients(
        machine.store(),
        iter::once(DataSet::bob_id()),
        &settings,
        &session,
    )
    .await
    .unwrap();
}
/// A previously verified user whose identity changed makes recipient
/// collection fail until the verification is withdrawn.
#[async_test]
async fn test_verified_user_changed_identity() {
    use test_json::keys_query_sets::VerificationViolationTestData as DataSet;

    let machine = unsigned_of_verified_setup().await;

    // Bob, who was verified by the setup, now shows up with different keys.
    let rotated_bob_keys = DataSet::bob_keys_query_response_rotated();
    machine.mark_request_as_sent(&TransactionId::new(), &rotated_bob_keys).await.unwrap();

    let bob_identity = machine.get_identity(DataSet::bob_id(), None).await.unwrap().unwrap();
    assert!(bob_identity.has_verification_violation());

    let settings = error_on_verification_problem_encryption_settings();
    let session = create_test_outbound_group_session(&machine, &settings);
    let first_attempt = collect_session_recipients(
        machine.store(),
        iter::once(DataSet::bob_id()),
        &settings,
        &session,
    )
    .await;

    assert_let!(
        Err(OlmError::SessionRecipientCollectionError(
            SessionRecipientCollectionError::VerifiedUserChangedIdentity(violating_users)
        )) = first_attempt
    );
    assert_eq!(violating_users, vec![DataSet::bob_id()]);

    // Withdrawing the verification resolves the problem.
    bob_identity.withdraw_verification().await.unwrap();

    collect_session_recipients(
        machine.store(),
        iter::once(DataSet::bob_id()),
        &settings,
        &session,
    )
    .await
    .unwrap();
}
/// Our own previously verified identity changing makes recipient collection
/// fail for our own user until the verification is withdrawn.
#[async_test]
async fn test_own_verified_identity_changed() {
    use test_json::keys_query_sets::VerificationViolationTestData as DataSet;

    let machine = unsigned_of_verified_setup().await;

    let initial_identity =
        machine.get_identity(DataSet::own_id(), None).await.unwrap().unwrap();
    assert!(initial_identity.own().unwrap().is_verified());

    // A second key query response for our own user leaves us unverified.
    let changed_own_keys = DataSet::own_keys_query_response_2();
    machine.mark_request_as_sent(&TransactionId::new(), &changed_own_keys).await.unwrap();

    let current_identity =
        machine.get_identity(DataSet::own_id(), None).await.unwrap().unwrap();
    assert!(!current_identity.is_verified());

    let settings = error_on_verification_problem_encryption_settings();
    let session = create_test_outbound_group_session(&machine, &settings);
    let first_attempt = collect_session_recipients(
        machine.store(),
        iter::once(DataSet::own_id()),
        &settings,
        &session,
    )
    .await;

    assert_let!(
        Err(OlmError::SessionRecipientCollectionError(
            SessionRecipientCollectionError::VerifiedUserChangedIdentity(violating_users)
        )) = first_attempt
    );
    assert_eq!(violating_users, vec![DataSet::own_id()]);

    // Withdrawing the verification resolves the problem.
    current_identity.withdraw_verification().await.unwrap();

    collect_session_recipients(
        machine.store(),
        iter::once(DataSet::own_id()),
        &settings,
        &session,
    )
    .await
    .unwrap();
}
/// With the identity-based strategy only devices cross-signed by their
/// owner receive the key; the others are withheld from with
/// `WithheldCode::Unverified`.
#[async_test]
async fn test_share_with_identity_strategy() {
    let machine = set_up_test_machine().await;

    let settings = EncryptionSettings {
        sharing_strategy: CollectStrategy::new_identity_based(),
        ..Default::default()
    };
    let session = create_test_outbound_group_session(&machine, &settings);

    let recipients = vec![
        KeyDistributionTestData::dan_id(),
        KeyDistributionTestData::dave_id(),
        KeyDistributionTestData::good_id(),
    ];
    let outcome =
        collect_session_recipients(machine.store(), recipients.into_iter(), &settings, &session)
            .await
            .unwrap();

    assert!(!outcome.should_rotate);

    let dave_devices = outcome.devices.get(KeyDistributionTestData::dave_id());
    let good_devices = outcome.devices.get(KeyDistributionTestData::good_id());
    // dave has no qualifying device, both of good's devices qualify.
    assert!(dave_devices.unwrap().is_empty());
    assert_eq!(good_devices.unwrap().len(), 2);

    // Only dan's signed device qualifies.
    let dan_devices = outcome.devices.get(KeyDistributionTestData::dan_id()).unwrap();
    assert_eq!(dan_devices.len(), 1);
    let receiving_device = &dan_devices[0];
    assert_eq!(
        receiving_device.device_id().as_str(),
        KeyDistributionTestData::dan_signed_device_id()
    );

    let (_, dan_code) = outcome
        .withheld_devices
        .iter()
        .find(|(device, _)| {
            device.device_id() == KeyDistributionTestData::dan_unsigned_device_id()
        })
        .expect("This dan's device should receive a withheld code");
    assert_eq!(dan_code, &WithheldCode::Unverified);

    let (_, dave_code) = outcome
        .withheld_devices
        .iter()
        .find(|(device, _)| device.device_id() == KeyDistributionTestData::dave_device_id())
        .expect("This dave device should receive a withheld code");
    assert_eq!(dave_code, &WithheldCode::Unverified);
}
#[async_test]
async fn test_share_identity_strategy_no_cross_signing() {
let machine: OlmMachine = OlmMachine::new(
KeyDistributionTestData::me_id(),
KeyDistributionTestData::me_device_id(),
)
.await;
let keys_query = KeyDistributionTestData::dan_keys_query_response();
machine.mark_request_as_sent(&TransactionId::new(), &keys_query).await.unwrap();
let fake_room_id = room_id!("!roomid:localhost");
let encryption_settings = EncryptionSettings {
sharing_strategy: CollectStrategy::new_identity_based(),
..Default::default()
};
let request_result = machine
.share_room_key(
fake_room_id,
iter::once(KeyDistributionTestData::dan_id()),
encryption_settings.clone(),
)
.await;
assert_matches!(
request_result,
Err(OlmError::SessionRecipientCollectionError(
SessionRecipientCollectionError::CrossSigningNotSetup
))
);
let keys_query = KeyDistributionTestData::me_keys_query_response();
machine.mark_request_as_sent(&TransactionId::new(), &keys_query).await.unwrap();
let request_result = machine
.share_room_key(
fake_room_id,
iter::once(KeyDistributionTestData::dan_id()),
encryption_settings.clone(),
)
.await;
assert_matches!(
request_result,
Err(OlmError::SessionRecipientCollectionError(
SessionRecipientCollectionError::SendingFromUnverifiedDevice
))
);
machine
.import_cross_signing_keys(CrossSigningKeyExport {
master_key: KeyDistributionTestData::MASTER_KEY_PRIVATE_EXPORT.to_owned().into(),
self_signing_key: KeyDistributionTestData::SELF_SIGNING_KEY_PRIVATE_EXPORT
.to_owned()
.into(),
user_signing_key: KeyDistributionTestData::USER_SIGNING_KEY_PRIVATE_EXPORT
.to_owned()
.into(),
})
.await
.unwrap();
let requests = machine
.share_room_key(
fake_room_id,
iter::once(KeyDistributionTestData::dan_id()),
encryption_settings.clone(),
)
.await
.unwrap();
assert_eq!(requests.len(), 1);
}
/// With the identity-based strategy, users that were marked as previously
/// verified but are not verified right now are reported through
/// `VerifiedUserChangedIdentity`. The test then resolves the problem for one
/// user by withdrawing the verification and for the other by verifying them
/// again, after which sharing succeeds.
#[async_test]
async fn test_share_identity_strategy_report_verification_violation() {
let machine: OlmMachine = OlmMachine::new(
KeyDistributionTestData::me_id(),
KeyDistributionTestData::me_device_id(),
)
.await;
machine.bootstrap_cross_signing(false).await.unwrap();
let user1 = IdentityChangeDataSet::user_id();
let user2 = MaloIdentityChangeDataSet::user_id();
// Feed the initial key query responses for both users.
let keys_query = IdentityChangeDataSet::key_query_with_identity_a();
machine.mark_request_as_sent(&TransactionId::new(), &keys_query).await.unwrap();
let keys_query = MaloIdentityChangeDataSet::initial_key_query();
machine.mark_request_as_sent(&TransactionId::new(), &keys_query).await.unwrap();
// Follow up with a second response for user1 (identity "b", going by the
// fixture name) and flag user1 as previously verified.
let keys_query = IdentityChangeDataSet::key_query_with_identity_b();
machine.mark_request_as_sent(&TransactionId::new(), &keys_query).await.unwrap();
machine
.get_identity(user1, None)
.await
.unwrap()
.unwrap()
.other()
.unwrap()
.mark_as_previously_verified()
.await
.unwrap();
// Same for user2: an updated key query response, then flag them as
// previously verified.
let keys_query = MaloIdentityChangeDataSet::updated_key_query();
machine.mark_request_as_sent(&TransactionId::new(), &keys_query).await.unwrap();
machine
.get_identity(user2, None)
.await
.unwrap()
.unwrap()
.other()
.unwrap()
.mark_as_previously_verified()
.await
.unwrap();
let fake_room_id = room_id!("!roomid:localhost");
let encryption_settings = EncryptionSettings {
sharing_strategy: CollectStrategy::new_identity_based(),
..Default::default()
};
// Sharing must fail and report both users.
let request_result = machine
.share_room_key(
fake_room_id,
vec![user1, user2].into_iter(),
encryption_settings.clone(),
)
.await;
assert_let!(
Err(OlmError::SessionRecipientCollectionError(
SessionRecipientCollectionError::VerifiedUserChangedIdentity(affected_users)
)) = request_result
);
assert_eq!(2, affected_users.len());
// Resolve user1 by withdrawing the verification.
machine
.get_identity(user1, None)
.await
.unwrap()
.unwrap()
.withdraw_verification()
.await
.unwrap();
// Resolve user2 by verifying them again; the resulting request is turned
// into a simulated key query response further down.
let verification_request = machine
.get_identity(user2, None)
.await
.unwrap()
.unwrap()
.other()
.unwrap()
.verify()
.await
.unwrap();
let master_key =
&machine.get_identity(user2, None).await.unwrap().unwrap().other().unwrap().master_key;
let my_identity = machine
.get_identity(KeyDistributionTestData::me_id(), None)
.await
.expect("Should not fail to find own identity")
.expect("Our own identity should not be missing")
.own()
.expect("Our own identity should be of type Own");
let msk = json!({ user2: serde_json::to_value(master_key).expect("Should not fail to serialize")});
let ssk =
serde_json::to_value(&MaloIdentityChangeDataSet::updated_key_query().self_signing_keys)
.expect("Should not fail to serialize");
let kq_response = simulate_key_query_response_for_verification(
verification_request,
my_identity,
KeyDistributionTestData::me_id(),
user2,
msk,
ssk,
);
machine
.mark_request_as_sent(
&TransactionId::new(),
crate::IncomingResponse::KeysQuery(&kq_response),
)
.await
.unwrap();
assert!(machine.get_identity(user2, None).await.unwrap().unwrap().is_verified());
// With both problems resolved, sharing must succeed.
machine
.share_room_key(
fake_room_id,
vec![user1, user2].into_iter(),
encryption_settings.clone(),
)
.await
.unwrap();
}
#[async_test]
async fn test_should_rotate_based_on_visibility() {
let machine = set_up_test_machine().await;
let strategy = CollectStrategy::DeviceBasedStrategy {
only_allow_trusted_devices: false,
error_on_verified_user_problem: false,
};
let encryption_settings = EncryptionSettings {
sharing_strategy: strategy.clone(),
history_visibility: HistoryVisibility::Invited,
..Default::default()
};
let group_session = create_test_outbound_group_session(&machine, &encryption_settings);
let _ = collect_session_recipients(
machine.store(),
vec![KeyDistributionTestData::dan_id()].into_iter(),
&encryption_settings,
&group_session,
)
.await
.unwrap();
let encryption_settings = EncryptionSettings {
sharing_strategy: strategy.clone(),
history_visibility: HistoryVisibility::Shared,
..Default::default()
};
let share_result = collect_session_recipients(
machine.store(),
vec![KeyDistributionTestData::dan_id()].into_iter(),
&encryption_settings,
&group_session,
)
.await
.unwrap();
assert!(share_result.should_rotate);
}
/// When a device that already received the session disappears from the
/// user's device list, collecting recipients again must ask for a rotation.
#[async_test]
async fn test_should_rotate_based_on_device_excluded() {
    let machine = set_up_test_machine().await;
    let fake_room_id = room_id!("!roomid:localhost");

    let sharing_strategy = CollectStrategy::DeviceBasedStrategy {
        only_allow_trusted_devices: false,
        error_on_verified_user_problem: false,
    };
    let settings = EncryptionSettings { sharing_strategy, ..Default::default() };

    // Share the room key with dan and mark every resulting request as sent.
    let requests = machine
        .share_room_key(
            fake_room_id,
            vec![KeyDistributionTestData::dan_id()].into_iter(),
            settings.clone(),
        )
        .await
        .unwrap();

    for request in requests {
        machine
            .inner
            .group_session_manager
            .mark_request_as_sent(request.as_ref().txn_id.as_ref())
            .await
            .unwrap();
    }

    // A new key query response in which one of dan's devices is logged out.
    let logged_out_keys = KeyDistributionTestData::dan_keys_query_response_device_loggedout();
    machine.mark_request_as_sent(&TransactionId::new(), &logged_out_keys).await.unwrap();

    let session =
        machine.store().get_outbound_group_session(fake_room_id).await.unwrap().unwrap();
    let outcome = collect_session_recipients(
        machine.store(),
        vec![KeyDistributionTestData::dan_id()].into_iter(),
        &settings,
        &session,
    )
    .await
    .unwrap();

    assert!(outcome.should_rotate);
}
/// Builds an `OlmMachine` from `VerificationViolationTestData` in which bob
/// is verified, bob's first device is verified and bob's second device is
/// not. The expectations are asserted before the machine is returned.
async fn unsigned_of_verified_setup() -> OlmMachine {
    use test_json::keys_query_sets::VerificationViolationTestData as DataSet;

    let machine = OlmMachine::new(DataSet::own_id(), device_id!("LOCAL")).await;

    // Our own keys, followed by the matching private cross-signing keys.
    let own_keys = DataSet::own_keys_query_response_1();
    machine.mark_request_as_sent(&TransactionId::new(), &own_keys).await.unwrap();

    let cross_signing_keys = CrossSigningKeyExport {
        master_key: DataSet::MASTER_KEY_PRIVATE_EXPORT.to_owned().into(),
        self_signing_key: DataSet::SELF_SIGNING_KEY_PRIVATE_EXPORT.to_owned().into(),
        user_signing_key: DataSet::USER_SIGNING_KEY_PRIVATE_EXPORT.to_owned().into(),
    };
    machine.import_cross_signing_keys(cross_signing_keys).await.unwrap();

    // Bob's keys, with which bob counts as verified.
    let bob_keys = DataSet::bob_keys_query_response_signed();
    machine.mark_request_as_sent(&TransactionId::new(), &bob_keys).await.unwrap();

    let bob_identity = machine.get_identity(DataSet::bob_id(), None).await.unwrap().unwrap();
    assert!(bob_identity.other().unwrap().is_verified());

    let signed_device = machine
        .get_device(DataSet::bob_id(), DataSet::bob_device_1_id(), None)
        .await
        .unwrap()
        .unwrap();
    assert!(signed_device.is_verified());
    assert!(signed_device.device_owner_identity.is_some());

    let unsigned_device = machine
        .get_device(DataSet::bob_id(), DataSet::bob_device_2_id(), None)
        .await
        .unwrap()
        .unwrap();
    assert!(!unsigned_device.is_verified());

    machine
}
/// Default encryption settings, except that the device-based strategy has
/// `error_on_verified_user_problem` switched on.
fn error_on_verification_problem_encryption_settings() -> EncryptionSettings {
    let sharing_strategy = CollectStrategy::DeviceBasedStrategy {
        only_allow_trusted_devices: false,
        error_on_verified_user_problem: true,
    };

    EncryptionSettings { sharing_strategy, ..Default::default() }
}
/// Creates a fresh outbound group session for the fixed test room
/// `!roomid:localhost`, owned by `machine`'s device and using a clone of
/// `encryption_settings`.
fn create_test_outbound_group_session(
    machine: &OlmMachine,
    encryption_settings: &EncryptionSettings,
) -> OutboundGroupSession {
    let device_id = machine.device_id().into();
    let identity_keys = Arc::new(machine.identity_keys());
    let room_id = room_id!("!roomid:localhost");

    OutboundGroupSession::new(device_id, identity_keys, room_id, encryption_settings.clone())
        .expect("creating an outbound group session should not fail")
}
}