// vodozemac/olm/account/fallback_keys.rs
use serde::{Deserialize, Serialize};
use crate::{
types::{Curve25519SecretKey, KeyId},
Curve25519PublicKey,
};
/// A single fallback key: a Curve25519 secret key stored together with its
/// key ID and a flag tracking whether it has been marked as published.
#[derive(Serialize, Deserialize, Clone)]
pub(super) struct FallbackKey {
/// Identifier handed to this key when it was constructed.
pub key_id: KeyId,
/// The secret key; the matching public key is derived from it on demand.
pub key: Curve25519SecretKey,
/// `false` for keys built by `new`; set to `true` by `mark_as_published`.
pub published: bool,
}
impl FallbackKey {
    /// Builds a fallback key labelled `key_id`, backed by a secret key obtained
    /// from `Curve25519SecretKey::new()` and flagged as not yet published.
    fn new(key_id: KeyId) -> Self {
        Self { key_id, key: Curve25519SecretKey::new(), published: false }
    }

    /// Derives the public key that corresponds to the stored secret key.
    pub fn public_key(&self) -> Curve25519PublicKey {
        Curve25519PublicKey::from(self.secret_key())
    }

    /// Borrows the secret key held by this fallback key.
    pub const fn secret_key(&self) -> &Curve25519SecretKey {
        &self.key
    }

    /// Returns the identifier this key was constructed with.
    pub const fn key_id(&self) -> KeyId {
        self.key_id
    }

    /// Sets the published flag; calling it again has no further effect.
    pub fn mark_as_published(&mut self) {
        self.published = true;
    }

    /// Reports whether the published flag is set.
    pub const fn published(&self) -> bool {
        self.published
    }
}
/// Holds the current fallback key, the key it replaced, and the counter used
/// to number newly generated keys.
#[derive(Serialize, Deserialize, Clone)]
pub(super) struct FallbackKeys {
/// Raw value wrapped into the `KeyId` of the next generated key; incremented
/// by one on every call to `generate_fallback_key`.
pub key_id: u64,
/// The most recently generated fallback key, if one has been generated.
pub fallback_key: Option<FallbackKey>,
/// The key that was current before the latest generation. It is dropped by
/// the next generation or removed by `forget_previous_fallback_key`.
pub previous_fallback_key: Option<FallbackKey>,
}
impl FallbackKeys {
    /// Creates an empty store: the ID counter starts at zero and neither the
    /// current nor the previous key slot is filled.
    pub const fn new() -> Self {
        Self {
            key_id: 0,
            fallback_key: None,
            previous_fallback_key: None,
        }
    }

    /// Flags the current fallback key as published. Does nothing when no
    /// current key exists; the previous key is never touched.
    pub fn mark_as_published(&mut self) {
        if let Some(current) = &mut self.fallback_key {
            current.mark_as_published();
        }
    }

    /// Rotates the fallback keys.
    ///
    /// A new key numbered with the current counter value becomes the current
    /// key, the former current key moves into the previous slot, and the key
    /// that used to occupy the previous slot is discarded.
    ///
    /// Returns the public key of the discarded key, or `None` if the previous
    /// slot was empty before the call.
    pub fn generate_fallback_key(&mut self) -> Option<Curve25519PublicKey> {
        let id = KeyId(self.key_id);
        self.key_id += 1;

        // Pull the outgoing key out first so the slot is free for the key
        // that is about to be demoted.
        let evicted = self.previous_fallback_key.take();
        self.previous_fallback_key = self.fallback_key.replace(FallbackKey::new(id));

        evicted.map(|old| old.public_key())
    }

    /// Finds the secret key whose public half equals `public_key`.
    ///
    /// The current key is checked first and the previous key only if the
    /// current one does not match. Returns `None` when neither matches.
    pub fn get_secret_key(&self, public_key: &Curve25519PublicKey) -> Option<&Curve25519SecretKey> {
        self.fallback_key
            .iter()
            .chain(self.previous_fallback_key.iter())
            .find(|candidate| candidate.public_key() == *public_key)
            .map(FallbackKey::secret_key)
    }

    /// Removes the previous fallback key from the store and hands it back,
    /// leaving the previous slot empty.
    pub fn forget_previous_fallback_key(&mut self) -> Option<FallbackKey> {
        self.previous_fallback_key.take()
    }

    /// Returns the current fallback key only if it has not been marked as
    /// published; otherwise, or when there is no current key, returns `None`.
    pub fn unpublished_fallback_key(&self) -> Option<&FallbackKey> {
        match self.fallback_key.as_ref() {
            Some(current) if !current.published() => Some(current),
            _ => None,
        }
    }
}
#[cfg(test)]
mod test {
    use super::FallbackKeys;

    /// A key's secret half must stay retrievable through its public key, both
    /// while the key is current and after one rotation has demoted it.
    #[test]
    fn fallback_key_fetching() {
        let err = "Missing fallback key";
        let mut fallback_keys = FallbackKeys::new();

        // First generation: the key sits in the current slot.
        fallback_keys.generate_fallback_key();
        let first = fallback_keys.fallback_key.as_ref().expect(err);
        let first_public = first.public_key();
        let first_secret = first.key.to_bytes();

        let fetched = fallback_keys.get_secret_key(&first_public).expect(err);
        assert_eq!(first_secret, fetched.to_bytes());

        // Second generation: the first key is now the previous key and must
        // still be found.
        fallback_keys.generate_fallback_key();

        let fetched = fallback_keys.get_secret_key(&first_public).expect(err);
        assert_eq!(first_secret, fetched.to_bytes());

        // The freshly generated current key is retrievable as well.
        let second = fallback_keys.fallback_key.as_ref().expect(err);
        let second_public = second.public_key();
        let second_secret = second.key.to_bytes();

        let fetched = fallback_keys.get_secret_key(&second_public).expect(err);
        assert_eq!(second_secret, fetched.to_bytes());
    }

    /// Generating a key bumps the ID counter and yields an unpublished key,
    /// which stops being reported once it is marked as published.
    #[test]
    fn fallback_key_publishing() {
        let mut fallback_keys = FallbackKeys::new();
        assert_eq!(fallback_keys.key_id, 0);

        fallback_keys.generate_fallback_key();
        assert_eq!(fallback_keys.key_id, 1);
        assert!(fallback_keys.unpublished_fallback_key().is_some());

        fallback_keys.mark_as_published();
        assert!(fallback_keys.unpublished_fallback_key().is_none());
    }
}