use std::{
io::{Cursor, Read},
ops::DerefMut,
};
use ruma::api::client::backup::EncryptedSessionData;
use thiserror::Error;
use vodozemac::{
pk_encryption::{Message, PkDecryption},
Curve25519PublicKey, Curve25519SecretKey,
};
use zeroize::{Zeroize, Zeroizing};
use super::MegolmV1BackupKey;
use crate::{
olm::BackedUpRoomKey,
store::BackupDecryptionKey,
types::{MegolmV1AuthData, RoomKeyBackupInfo},
};
#[derive(Debug, Error)]
pub enum DecodeError {
#[error("The decoded recovery key has an invalid prefix: expected {0:?}, got {1:?}")]
Prefix([u8; 2], [u8; 2]),
#[error("The parity byte of the recovery key doesn't match: expected {0:?}, got {1:?}")]
Parity(u8, u8),
#[error("The decoded recovery key has a invalid length: expected {0}, got {1}")]
Length(usize, usize),
#[error(transparent)]
Base58(#[from] bs58::decode::Error),
#[error(transparent)]
Base64(#[from] vodozemac::Base64DecodeError),
#[error(transparent)]
Io(#[from] std::io::Error),
#[error(transparent)]
PublicKey(#[from] vodozemac::KeyError),
}
#[derive(Debug, Error)]
pub enum DecryptionError {
#[error("The MAC of the ciphertext didn't pass validation {0}")]
Encryption(#[from] vodozemac::pk_encryption::Error),
#[error("The message could not been decoded: {0}")]
Decoding(#[from] vodozemac::pk_encryption::MessageDecodeError),
#[error("The decrypted message isn't valid JSON: {0}")]
Json(#[from] serde_json::error::Error),
}
impl TryFrom<String> for BackupDecryptionKey {
type Error = DecodeError;
fn try_from(value: String) -> Result<Self, Self::Error> {
Self::from_base58(&value)
}
}
impl std::fmt::Display for BackupDecryptionKey {
    /// Formats the key as its base58 recovery-key encoding.
    ///
    /// The output is split into groups of `DISPLAY_CHUNK_SIZE` characters,
    /// separated by a single space (the last group may be shorter).
    ///
    /// Each group is written to the formatter as a slice of the zeroized
    /// encoding. This avoids intermediate `Vec<char>` and per-group `String`
    /// copies of the secret, which would be dropped without being wiped.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let encoded = Zeroizing::new(self.to_base58());

        let mut remaining = encoded.as_str();
        let mut separator = "";

        while !remaining.is_empty() {
            // Byte offset just past the next DISPLAY_CHUNK_SIZE characters, or the
            // end of the string if fewer characters are left.
            let end = remaining
                .char_indices()
                .nth(Self::DISPLAY_CHUNK_SIZE)
                .map_or(remaining.len(), |(offset, _)| offset);
            let (chunk, rest) = remaining.split_at(end);

            f.write_str(separator)?;
            f.write_str(chunk)?;

            separator = " ";
            remaining = rest;
        }

        Ok(())
    }
}
impl BackupDecryptionKey {
const PREFIX: [u8; 2] = [0x8b, 0x01];
const PREFIX_PARITY: u8 = Self::PREFIX[0] ^ Self::PREFIX[1];
const DISPLAY_CHUNK_SIZE: usize = 4;
fn parity_byte(bytes: &[u8]) -> u8 {
bytes.iter().fold(Self::PREFIX_PARITY, |acc, x| acc ^ x)
}
pub fn from_bytes(key: &[u8; Self::KEY_SIZE]) -> Self {
let mut inner = Box::new([0u8; Self::KEY_SIZE]);
inner.copy_from_slice(key);
Self::from_boxed_bytes(inner)
}
fn from_boxed_bytes(key: Box<[u8; Self::KEY_SIZE]>) -> Self {
Self { inner: key }
}
pub fn as_bytes(&self) -> &[u8; Self::KEY_SIZE] {
&self.inner
}
pub fn from_base64(key: &str) -> Result<Self, DecodeError> {
let decoded = Zeroizing::new(vodozemac::base64_decode(key)?);
if decoded.len() != Self::KEY_SIZE {
Err(DecodeError::Length(Self::KEY_SIZE, decoded.len()))
} else {
let mut key = Box::new([0u8; Self::KEY_SIZE]);
key.copy_from_slice(&decoded);
Ok(Self::from_boxed_bytes(key))
}
}
pub fn from_base58(value: &str) -> Result<Self, DecodeError> {
let value: String = value.chars().filter(|c| !c.is_whitespace()).collect();
let decoded = bs58::decode(value).with_alphabet(bs58::Alphabet::BITCOIN).into_vec()?;
let mut decoded = Cursor::new(decoded);
let mut prefix = [0u8; 2];
let mut key = Box::new([0u8; Self::KEY_SIZE]);
let mut expected_parity = [0u8; 1];
decoded.read_exact(&mut prefix)?;
decoded.read_exact(key.deref_mut())?;
decoded.read_exact(&mut expected_parity)?;
let expected_parity = expected_parity[0];
let parity = Self::parity_byte(key.as_ref());
let _ = Zeroizing::new(decoded.into_inner());
if prefix != Self::PREFIX {
Err(DecodeError::Prefix(Self::PREFIX, prefix))
} else if expected_parity != parity {
Err(DecodeError::Parity(expected_parity, parity))
} else {
Ok(Self::from_boxed_bytes(key))
}
}
pub fn to_base58(&self) -> String {
let bytes = Zeroizing::new(
[
Self::PREFIX.as_ref(),
self.inner.as_ref(),
[Self::parity_byte(self.inner.as_ref())].as_ref(),
]
.concat(),
);
bs58::encode(bytes.as_slice()).with_alphabet(bs58::Alphabet::BITCOIN).into_string()
}
fn get_pk_decryption(&self) -> PkDecryption {
let secret_key = Curve25519SecretKey::from_slice(self.inner.as_ref());
PkDecryption::from_key(secret_key)
}
pub fn megolm_v1_public_key(&self) -> MegolmV1BackupKey {
let pk = self.get_pk_decryption();
MegolmV1BackupKey::new(pk.public_key(), None)
}
pub fn to_backup_info(&self) -> RoomKeyBackupInfo {
let pk = self.get_pk_decryption();
let auth_data = MegolmV1AuthData::new(pk.public_key(), Default::default());
RoomKeyBackupInfo::MegolmBackupV1Curve25519AesSha2(auth_data)
}
pub fn decrypt_v1(
&self,
ephemeral_key: &str,
mac: &str,
ciphertext: &str,
) -> Result<String, DecryptionError> {
let message = Message::from_base64(ciphertext, mac, ephemeral_key)?;
let pk = self.get_pk_decryption();
let decrypted = pk.decrypt(&message)?;
Ok(String::from_utf8_lossy(&decrypted).to_string())
}
pub fn decrypt_session_data(
&self,
session_data: EncryptedSessionData,
) -> Result<BackedUpRoomKey, DecryptionError> {
let message = Message {
ciphertext: session_data.ciphertext.into_inner(),
mac: session_data.mac.into_inner(),
ephemeral_key: Curve25519PublicKey::from_slice(session_data.ephemeral.as_bytes())
.map_err(vodozemac::pk_encryption::MessageDecodeError::from)?,
};
let pk = self.get_pk_decryption();
let mut decrypted = pk.decrypt(&message)?;
let result = serde_json::from_slice(&decrypted);
decrypted.zeroize();
Ok(result?)
}
pub fn backup_key_matches(&self, info: &RoomKeyBackupInfo) -> bool {
match info {
RoomKeyBackupInfo::MegolmBackupV1Curve25519AesSha2(info) => {
let pk = self.get_pk_decryption();
let public_key = pk.public_key();
info.public_key == public_key
}
RoomKeyBackupInfo::Other { .. } => false,
}
}
}
#[cfg(test)]
mod tests {
    use matrix_sdk_test::async_test;
    use ruma::api::client::backup::KeyBackupData;
    use serde_json::json;

    use super::{BackupDecryptionKey, DecodeError};
    use crate::olm::{BackedUpRoomKey, ExportedRoomKey, InboundGroupSession};

    /// Raw key bytes that the base58 recovery key used in `base58_decoding`
    /// decodes to.
    const TEST_KEY: [u8; 32] = [
        0x77, 0x07, 0x6D, 0x0A, 0x73, 0x18, 0xA5, 0x7D, 0x3C, 0x16, 0xC1, 0x72, 0x51, 0xB2, 0x66,
        0x45, 0xDF, 0x4C, 0x2F, 0x87, 0xEB, 0xC0, 0x99, 0x2A, 0xB1, 0x77, 0xFB, 0xA5, 0x1D, 0xB9,
        0x2C, 0x2A,
    ];

    /// An exported Megolm room key used to build an inbound group session.
    fn room_key() -> ExportedRoomKey {
        let json = json!({
            "algorithm": "m.megolm.v1.aes-sha2",
            "sender_key": "DeHIg4gwhClxzFYcmNntPNF9YtsdZbmMy8+3kzCMXHA",
            "session_id": "gM8i47Xhu0q52xLfgUXzanCMpLinoyVyH7R58cBuVBU",
            "room_id": "!DovneieKSTkdHKpIXy:morpheus.localhost",
            "session_key": "AQAAAABvWMNZjKFtebYIePKieQguozuoLgzeY6wKcyJjLJcJtQgy1dPqTBD12U+XrYLrRHn\
                            lKmxoozlhFqJl456+9hlHCL+yq+6ScFuBHtJepnY1l2bdLb4T0JMDkNsNErkiLiLnD6yp3J\
                            DSjIhkdHxmup/huygrmroq6/L5TaThEoqvW4DPIuO14btKudsS34FF82pwjKS4p6Mlch+0e\
                            fHAblQV",
            "sender_claimed_keys":{},
            "forwarding_curve25519_key_chain":[]
        });

        serde_json::from_value(json)
            .expect("We should be able to deserialize our backed up room key")
    }

    /// A key survives a base64 round trip, and too-short input is rejected.
    #[test]
    fn base64_decoding() -> Result<(), DecodeError> {
        let key = BackupDecryptionKey::new().expect("Can't create a new recovery key");

        let base64 = key.to_base64();
        let decoded_key = BackupDecryptionKey::from_base64(&base64)?;
        assert_eq!(key.inner, decoded_key.inner, "The decode key doesn't match the original");

        BackupDecryptionKey::from_base64("i").expect_err("The recovery key is too short");

        Ok(())
    }

    /// A key survives a base58 round trip. A known recovery key decodes to
    /// `TEST_KEY` in both compact and space-grouped form, and a corrupted
    /// parity character is rejected.
    #[test]
    fn base58_decoding() -> Result<(), DecodeError> {
        let key = BackupDecryptionKey::new().expect("Can't create a new recovery key");

        let base58 = key.to_base58();
        let decoded_key = BackupDecryptionKey::from_base58(&base58)?;
        assert_eq!(key.inner, decoded_key.inner, "The decode key doesn't match the original");

        let test_key =
            BackupDecryptionKey::from_base58("EsTcLW2KPGiFwKEA3As5g5c4BXwkqeeJZJV8Q9fugUMNUE4d")?;
        assert_eq!(
            test_key.as_bytes(),
            &TEST_KEY,
            "The decoded recovery key doesn't match the test key"
        );

        let test_key = BackupDecryptionKey::from_base58(
            "EsTc LW2K PGiF wKEA 3As5 g5c4 BXwk qeeJ ZJV8 Q9fu gUMN UE4d",
        )?;
        assert_eq!(
            test_key.as_bytes(),
            &TEST_KEY,
            "The decoded recovery key doesn't match the test key"
        );

        BackupDecryptionKey::from_base58(
            "EsTc LW2K PGiF wKEA 3As5 g5c4 BXwk qeeJ ZJV8 Q9fu gUMN UE4e",
        )
        .expect_err("Can't create a recovery key if the parity byte is invalid");

        Ok(())
    }

    /// A known backed-up key can be decrypted through both decryption entry
    /// points.
    #[test]
    fn test_decrypt_key() {
        let decryption_key =
            BackupDecryptionKey::from_base64("Ha9cklU/9NqFo9WKdVfGzmqUL/9wlkdxfEitbSIPVXw")
                .unwrap();

        let data = json!({
            "first_message_index": 0,
            "forwarded_count": 0,
            "is_verified": false,
            "session_data": {
                "ephemeral": "HlLi76oV6wxHz3PCqE/bxJi6yF1HnYz5Dq3T+d/KpRw",
                "ciphertext": "MuM8E3Yc6TSAvhVGb77rQ++jE6p9dRepx63/3YPD2wACKAppkZHeFrnTH6wJ/HSyrmzo\
                               7HfwqVl6tKNpfooSTHqUf6x1LHz+h4B/Id5ITO1WYt16AaI40LOnZqTkJZCfSPuE2oxa\
                               lwEHnCS3biWybutcnrBFPR3LMtaeHvvkb+k3ny9l5ZpsU9G7vCm3XoeYkWfLekWXvDhb\
                               qWrylXD0+CNUuaQJ/S527TzLd4XKctqVjjO/cCH7q+9utt9WJAfK8LGaWT/mZ3AeWjf5\
                               kiqOpKKf5Cn4n5SSil5p/pvGYmjnURvZSEeQIzHgvunIBEPtzK/MYEPOXe/P5achNGlC\
                               x+5N19Ftyp9TFaTFlTWCTi0mpD7ePfCNISrwpozAz9HZc0OhA8+1aSc7rhYFIeAYXFU3\
                               26NuFIFHI5pvpSxjzPQlOA+mavIKmiRAtjlLw11IVKTxgrdT4N8lXeMr4ndCSmvIkAzF\
                               Mo1uZA4fzjiAdQJE4/2WeXFNNpvdfoYmX8Zl9CAYjpSO5HvpwkAbk4/iLEH3hDfCVUwD\
                               fMh05PdGLnxeRpiEFWSMSsJNp+OWAA+5JsF41BoRGrxoXXT+VKqlUDONd+O296Psu8Q+\
                               d8/S618",
                "mac": "GtMrurhDTwo"
            }
        });

        let key_backup_data: KeyBackupData = serde_json::from_value(data).unwrap();
        let ephemeral = key_backup_data.session_data.ephemeral.encode();
        let ciphertext = key_backup_data.session_data.ciphertext.encode();
        let mac = key_backup_data.session_data.mac.encode();

        let decrypted = decryption_key
            .decrypt_v1(&ephemeral, &mac, &ciphertext)
            .expect("The backed up key should be decrypted successfully");

        let _: BackedUpRoomKey = serde_json::from_str(&decrypted)
            .expect("The decrypted payload should contain valid JSON");

        let _ = decryption_key
            .decrypt_session_data(key_backup_data.session_data)
            .expect("The backed up key should be decrypted successfully");
    }

    /// A room key encrypted with the derived backup key can be decrypted again.
    #[async_test]
    async fn test_encryption_cycle() {
        let session = InboundGroupSession::from_export(&room_key()).unwrap();

        let decryption_key = BackupDecryptionKey::new().unwrap();
        let encryption_key = decryption_key.megolm_v1_public_key();

        let encrypted = encryption_key.encrypt(session).await;

        let _ = decryption_key
            .decrypt_session_data(encrypted.session_data)
            .expect("We should be able to decrypt a just encrypted room key");
    }

    /// The backup info generated from a key matches that same key.
    #[test]
    fn key_matches() {
        let decryption_key = BackupDecryptionKey::new().unwrap();

        let key_info = decryption_key.to_backup_info();

        assert!(
            decryption_key.backup_key_matches(&key_info),
            "The backup info should match the decryption key"
        );
    }
}