use std::cmp::Ordering;
use aes::cipher::block_padding::UnpadError;
use hmac::digest::MacError;
use serde::{Deserialize, Serialize};
use subtle::ConstantTimeEq;
use thiserror::Error;
use super::{
default_config,
message::MegolmMessage,
ratchet::Ratchet,
session_config::Version,
session_keys::{ExportedSessionKey, SessionKey},
GroupSession, SessionConfig,
};
use crate::{
cipher::{Cipher, Mac, MessageMac},
types::{Ed25519PublicKey, SignatureError},
utilities::{base64_encode, pickle, unpickle},
PickleError,
};
/// The outcome of comparing two inbound group sessions with
/// [`InboundGroupSession::compare`].
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum SessionOrdering {
    /// The sessions are connected and start at the same first known index.
    Equal,
    /// The sessions are connected and the session `compare` was called on has
    /// the lower first known index, so it can decrypt more messages.
    Better,
    /// The sessions are connected and the session `compare` was called on has
    /// the higher first known index, so it can decrypt fewer messages.
    Worse,
    /// The sessions differ in config or signing key, or their ratchets do not
    /// match once brought to a common index.
    Unconnected,
}
#[derive(Debug, Error)]
pub enum DecryptionError {
#[error("The signature on the message was invalid: {0}")]
Signature(#[from] SignatureError),
#[error("Failed decrypting Megolm message, invalid MAC: {0}")]
InvalidMAC(#[from] MacError),
#[error("Failed decrypting Olm message, invalid MAC length: expected {0}, got {1}")]
InvalidMACLength(usize, usize),
#[error("Failed decrypting Megolm message, invalid padding")]
InvalidPadding(#[from] UnpadError),
#[error(
"The message was encrypted using an unknown message index, \
first known index {0}, index of the message {1}"
)]
UnknownMessageIndex(u32, u32),
}
/// The receiving side of a Megolm group session, used to decrypt messages
/// produced by the matching [`GroupSession`].
///
/// Deserialization goes through [`InboundGroupSessionPickle`].
#[derive(Deserialize)]
#[serde(try_from = "InboundGroupSessionPickle")]
pub struct InboundGroupSession {
    /// Session configuration. Its version selects the MAC format that
    /// decryption expects.
    config: SessionConfig,
    /// Public key used to verify the signature on every message. Its base64
    /// encoding is the session ID.
    signing_key: Ed25519PublicKey,
    /// Whether the signing key is considered verified: set by `new`, cleared by
    /// `import`, and OR-ed together by `merge`.
    signing_key_verified: bool,
    /// Ratchet at the first known index. Messages with a lower index cannot be
    /// decrypted.
    initial_ratchet: Ratchet,
    /// Cache of the most recently requested ratchet, expected never to be
    /// behind `initial_ratchet`. It is not stored in pickles.
    latest_ratchet: Ratchet,
}
/// The result of successfully decrypting a Megolm message.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct DecryptedMessage {
    /// The decrypted message bytes.
    pub plaintext: Vec<u8>,
    /// The message's index within the Megolm session, copied from the
    /// decrypted message.
    pub message_index: u32,
}
impl InboundGroupSession {
    /// Creates an inbound session from a [`SessionKey`].
    ///
    /// The initial and latest ratchets both start at the ratchet carried by
    /// the key. The signing key is marked as verified, unlike in
    /// [`InboundGroupSession::import`].
    pub fn new(key: &SessionKey, session_config: SessionConfig) -> Self {
        let exported = &key.session_key;
        let ratchet = Ratchet::from_bytes(exported.ratchet.clone(), exported.ratchet_index);

        Self {
            latest_ratchet: ratchet.clone(),
            initial_ratchet: ratchet,
            signing_key: exported.signing_key,
            signing_key_verified: true,
            config: session_config,
        }
    }

    /// Creates an inbound session from an [`ExportedSessionKey`].
    ///
    /// The signing key of the resulting session is marked as *not* verified.
    /// Merging with a verified session of the same origin (see
    /// [`InboundGroupSession::merge`]) yields a verified session.
    pub fn import(session_key: &ExportedSessionKey, session_config: SessionConfig) -> Self {
        let ratchet = Ratchet::from_bytes(session_key.ratchet.clone(), session_key.ratchet_index);

        Self {
            latest_ratchet: ratchet.clone(),
            initial_ratchet: ratchet,
            signing_key: session_key.signing_key,
            signing_key_verified: false,
            config: session_config,
        }
    }

    /// Returns the session ID, which is the base64 encoding of the signing key.
    pub fn session_id(&self) -> String {
        let key_bytes = self.signing_key.as_bytes();
        base64_encode(key_bytes)
    }

    /// Checks whether `self` and `other` originate from the same Megolm session.
    ///
    /// The sessions are connected when all of the following hold:
    /// - they share the same config and signing key;
    /// - the session with the lower first known index is advanced to the
    ///   other's first known index;
    /// - after that, both ratchets compare equal. This comparison runs in
    ///   constant time.
    ///
    /// Checking may advance the cached latest ratchet of either session, which
    /// is why both are borrowed mutably.
    pub fn connected(&mut self, other: &mut InboundGroupSession) -> bool {
        let same_origin = self.config == other.config && self.signing_key == other.signing_key;

        if !same_origin {
            return false;
        }

        let other_first = other.first_known_index();

        if let Some(ratchet) = self.find_ratchet(other_first) {
            return ratchet.ct_eq(&other.initial_ratchet).into();
        }

        // `self` starts after `other`, so bring `other` forward to our start.
        let self_first = self.first_known_index();

        match other.find_ratchet(self_first) {
            Some(ratchet) => self.initial_ratchet.ct_eq(ratchet).into(),
            None => {
                unreachable!("Either index A >= index B, or vice versa. There is no third option.")
            }
        }
    }

    /// Compares this session with `other`.
    ///
    /// Returns [`SessionOrdering::Unconnected`] when
    /// [`InboundGroupSession::connected`] is false. Otherwise the session with
    /// the lower first known index is the better one.
    pub fn compare(&mut self, other: &mut InboundGroupSession) -> SessionOrdering {
        if !self.connected(other) {
            return SessionOrdering::Unconnected;
        }

        // A lower first known index means more (earlier) messages can be decrypted.
        match self.first_known_index().cmp(&other.first_known_index()) {
            Ordering::Greater => SessionOrdering::Worse,
            Ordering::Equal => SessionOrdering::Equal,
            Ordering::Less => SessionOrdering::Better,
        }
    }

    /// Merges two connected sessions into a new one.
    ///
    /// The new session starts at the lower of the two first known indices. Its
    /// signing key is verified if either input's was. Returns `None` when the
    /// sessions are not connected.
    pub fn merge(&mut self, other: &mut InboundGroupSession) -> Option<InboundGroupSession> {
        let best_ratchet = match self.compare(other) {
            SessionOrdering::Unconnected => return None,
            SessionOrdering::Worse => other.initial_ratchet.clone(),
            SessionOrdering::Better | SessionOrdering::Equal => self.initial_ratchet.clone(),
        };

        Some(InboundGroupSession {
            latest_ratchet: best_ratchet.clone(),
            initial_ratchet: best_ratchet,
            signing_key: self.signing_key,
            signing_key_verified: self.signing_key_verified || other.signing_key_verified,
            config: self.config,
        })
    }

    /// Returns the lowest message index this session can decrypt.
    pub fn first_known_index(&self) -> u32 {
        self.initial_ratchet.index()
    }

    /// Moves the first known index forward to `index`.
    ///
    /// Afterwards, messages with a lower index can no longer be decrypted.
    /// Returns `true` if the session was advanced. Returns `false` and leaves
    /// the session untouched if `index` is not greater than the current first
    /// known index.
    pub fn advance_to(&mut self, index: u32) -> bool {
        if index <= self.first_known_index() {
            return false;
        }

        self.initial_ratchet.advance_to(index);

        // Keep the cached ratchet from falling behind the new starting point.
        if self.latest_ratchet.index() < index {
            self.latest_ratchet = self.initial_ratchet.clone();
        }

        true
    }

    /// Returns the cipher for the message with index `message_index`.
    ///
    /// Returns `None` if `message_index` is below the first known index. The
    /// session itself is not modified; a copy of the initial ratchet is
    /// advanced instead.
    #[cfg(feature = "low-level-api")]
    pub fn get_cipher_at(&self, message_index: u32) -> Option<Cipher> {
        let first_index = self.initial_ratchet.index();

        if message_index < first_index {
            return None;
        }

        let mut ratchet = self.initial_ratchet.clone();

        if first_index < message_index {
            ratchet.advance_to(message_index);
        }

        Some(Cipher::new_megolm(ratchet.as_bytes()))
    }

    /// Returns the ratchet for `message_index`, or `None` if the index is below
    /// the first known index.
    ///
    /// `latest_ratchet` acts as a cache. It is walked forward when possible and
    /// rebuilt from `initial_ratchet` when the requested index lies behind it.
    fn find_ratchet(&mut self, message_index: u32) -> Option<&Ratchet> {
        let first = self.initial_ratchet.index();
        let latest = self.latest_ratchet.index();

        if first == message_index {
            return Some(&self.initial_ratchet);
        }

        if latest != message_index {
            if latest < message_index {
                // Walk the cached ratchet forward to the requested index.
                self.latest_ratchet.advance_to(message_index);
            } else if first < message_index {
                // The cache is already past the requested index; restart from
                // a copy of the initial ratchet.
                self.latest_ratchet = self.initial_ratchet.clone();
                self.latest_ratchet.advance_to(message_index);
            } else {
                // Earlier than anything this session can decrypt.
                return None;
            }
        }

        Some(&self.latest_ratchet)
    }

    /// Verifies the MAC of `message` with `cipher`.
    ///
    /// A version 1 session expects a truncated MAC and a version 2 session
    /// expects a full-length MAC. Any other combination is reported as
    /// [`DecryptionError::InvalidMACLength`].
    fn verify_mac(&self, cipher: &Cipher, message: &MegolmMessage) -> Result<(), DecryptionError> {
        match (self.config.version, &message.mac) {
            (Version::V1, MessageMac::Truncated(mac)) => {
                Ok(cipher.verify_truncated_mac(&message.to_mac_bytes(), mac)?)
            }
            (Version::V1, _) => {
                Err(DecryptionError::InvalidMACLength(Mac::TRUNCATED_LEN, Mac::LENGTH))
            }
            (Version::V2, MessageMac::Full(mac)) => {
                Ok(cipher.verify_mac(&message.to_mac_bytes(), mac)?)
            }
            (Version::V2, _) => {
                Err(DecryptionError::InvalidMACLength(Mac::LENGTH, Mac::TRUNCATED_LEN))
            }
        }
    }

    /// Decrypts a Megolm message.
    ///
    /// # Errors
    ///
    /// - [`DecryptionError::Signature`] if the signature does not verify.
    /// - [`DecryptionError::UnknownMessageIndex`] if the message index is below
    ///   the first known index.
    /// - [`DecryptionError::InvalidMAC`] or [`DecryptionError::InvalidMACLength`]
    ///   if MAC verification fails.
    /// - An error converted from the cipher if the ciphertext cannot be
    ///   decrypted.
    pub fn decrypt(
        &mut self,
        message: &MegolmMessage,
    ) -> Result<DecryptedMessage, DecryptionError> {
        // Check the signature before any ratchet state is touched.
        self.signing_key.verify(&message.to_signature_bytes(), &message.signature)?;

        let message_index = message.message_index;

        match self.find_ratchet(message_index) {
            Some(ratchet) => {
                let cipher = Cipher::new_megolm(ratchet.as_bytes());
                self.verify_mac(&cipher, message)?;

                let plaintext = cipher.decrypt(&message.ciphertext)?;

                Ok(DecryptedMessage { plaintext, message_index })
            }
            None => Err(DecryptionError::UnknownMessageIndex(
                self.initial_ratchet.index(),
                message_index,
            )),
        }
    }

    /// Exports the session key at `index`.
    ///
    /// Returns `None` if `index` is below the first known index.
    pub fn export_at(&mut self, index: u32) -> Option<ExportedSessionKey> {
        // Copy the key out first: the returned ratchet keeps `self` borrowed.
        let signing_key = self.signing_key;
        let ratchet = self.find_ratchet(index)?;

        Some(ExportedSessionKey::new(ratchet, signing_key))
    }

    /// Exports the session key at the first known index.
    pub fn export_at_first_known_index(&self) -> ExportedSessionKey {
        let Self { initial_ratchet, signing_key, .. } = self;
        ExportedSessionKey::new(initial_ratchet, *signing_key)
    }

    /// Creates a serializable snapshot of this session.
    ///
    /// Only the initial ratchet is stored; the cached latest ratchet is not.
    pub fn pickle(&self) -> InboundGroupSessionPickle {
        let initial_ratchet = self.initial_ratchet.clone();

        InboundGroupSessionPickle {
            initial_ratchet,
            signing_key: self.signing_key,
            signing_key_verified: self.signing_key_verified,
            config: self.config,
        }
    }

    /// Restores a session from a snapshot created by
    /// [`InboundGroupSession::pickle`].
    pub fn from_pickle(pickle: InboundGroupSessionPickle) -> Self {
        pickle.into()
    }

    /// Restores a session from a libolm pickle string, decoding it with
    /// `pickle_key`.
    ///
    /// Sessions restored this way use the version 1 session config.
    #[cfg(feature = "libolm-compat")]
    pub fn from_libolm_pickle(
        pickle: &str,
        pickle_key: &[u8],
    ) -> Result<Self, crate::LibolmPickleError> {
        use crate::{
            megolm::inbound_group_session::libolm_compat::Pickle, utilities::unpickle_libolm,
        };

        // The libolm pickle version handed to `unpickle_libolm`.
        const LIBOLM_PICKLE_VERSION: u32 = 2;

        unpickle_libolm::<Pickle, _>(pickle, pickle_key, LIBOLM_PICKLE_VERSION)
    }
}
#[cfg(feature = "libolm-compat")]
mod libolm_compat {
    use matrix_pickle::Decode;
    use zeroize::{Zeroize, ZeroizeOnDrop};

    use super::InboundGroupSession;
    use crate::{
        megolm::{libolm::LibolmRatchetPickle, SessionConfig},
        Ed25519PublicKey,
    };

    /// The decoded layout of a libolm inbound group session pickle.
    ///
    /// NOTE(review): fields are presumably decoded in declaration order by the
    /// `Decode` derive, so do not reorder them.
    #[derive(Zeroize, ZeroizeOnDrop, Decode)]
    pub(super) struct Pickle {
        version: u32,
        initial_ratchet: LibolmRatchetPickle,
        latest_ratchet: LibolmRatchetPickle,
        signing_key: [u8; 32],
        signing_key_verified: bool,
    }

    impl TryFrom<Pickle> for InboundGroupSession {
        type Error = crate::LibolmPickleError;

        /// Builds a session from a decoded libolm pickle.
        ///
        /// Fails if the stored signing key is not a valid Ed25519 public key.
        /// The resulting session always uses the version 1 config.
        fn try_from(pickle: Pickle) -> Result<Self, Self::Error> {
            // Ratchets are converted from references instead of being moved out
            // of `pickle`. `Pickle` derives `ZeroizeOnDrop`, which presumably
            // rules out partial moves. NOTE(review): the `needless_borrow`
            // allowances look like clippy false positives; confirm.
            #[allow(clippy::needless_borrow)]
            let initial = (&pickle.initial_ratchet).into();
            #[allow(clippy::needless_borrow)]
            let latest = (&pickle.latest_ratchet).into();

            let signing_key = Ed25519PublicKey::from_slice(&pickle.signing_key)?;

            Ok(Self {
                initial_ratchet: initial,
                latest_ratchet: latest,
                signing_key,
                signing_key_verified: pickle.signing_key_verified,
                config: SessionConfig::version_1(),
            })
        }
    }
}
/// A serializable snapshot of an [`InboundGroupSession`].
///
/// Only the initial ratchet is stored. When the session is restored, its
/// latest ratchet starts out as a copy of the initial one.
#[derive(Serialize, Deserialize)]
pub struct InboundGroupSessionPickle {
    /// Ratchet at the session's first known index.
    initial_ratchet: Ratchet,
    /// Public key used to verify message signatures.
    signing_key: Ed25519PublicKey,
    /// Whether the signing key was considered verified. This field is read when
    /// the session is restored, so it needs no `dead_code` allowance.
    signing_key_verified: bool,
    /// Session configuration. Pickles without a `config` entry fall back to
    /// `default_config()`.
    #[serde(default = "default_config")]
    config: SessionConfig,
}
impl InboundGroupSessionPickle {
    /// Encrypts this snapshot with `pickle_key` using the crate's `pickle`
    /// helper, consuming it and returning the resulting string.
    pub fn encrypt(self, pickle_key: &[u8; 32]) -> String {
        let snapshot = &self;
        pickle(snapshot, pickle_key)
    }

    /// Restores a snapshot from a string produced by
    /// [`InboundGroupSessionPickle::encrypt`].
    ///
    /// # Errors
    ///
    /// Returns a [`PickleError`] when `unpickle` rejects the input. Presumably
    /// this happens on a wrong key or malformed data.
    pub fn from_encrypted(ciphertext: &str, pickle_key: &[u8; 32]) -> Result<Self, PickleError> {
        unpickle(ciphertext, pickle_key)
    }
}
impl From<&InboundGroupSession> for InboundGroupSessionPickle {
fn from(session: &InboundGroupSession) -> Self {
session.pickle()
}
}
impl From<InboundGroupSessionPickle> for InboundGroupSession {
fn from(pickle: InboundGroupSessionPickle) -> Self {
Self {
initial_ratchet: pickle.initial_ratchet.clone(),
latest_ratchet: pickle.initial_ratchet,
signing_key: pickle.signing_key,
signing_key_verified: pickle.signing_key_verified,
config: pickle.config,
}
}
}
impl From<&GroupSession> for InboundGroupSession {
    /// Creates the inbound counterpart of an outbound `session`.
    ///
    /// It uses the outbound session's current session key and config, so its
    /// signing key is marked as verified.
    fn from(session: &GroupSession) -> Self {
        let session_key = session.session_key();
        let config = session.session_config();

        Self::new(&session_key, config)
    }
}
#[cfg(test)]
mod test {
    // libolm bindings, used to check interoperability with libolm-produced messages.
    use olm_rs::outbound_group_session::OlmOutboundGroupSession;

    use super::InboundGroupSession;
    use crate::{
        cipher::Cipher,
        megolm::{GroupSession, SessionConfig, SessionKey, SessionOrdering},
    };

    /// `advance_to` moves both the initial and the cached latest ratchet
    /// forward, and refuses to "advance" to an index that is not strictly
    /// greater.
    #[test]
    fn advance_inbound_session() {
        let mut session = InboundGroupSession::from(&GroupSession::new(Default::default()));
        assert_eq!(session.first_known_index(), 0);
        assert_eq!(session.latest_ratchet.index(), 0);
        assert!(session.advance_to(10));
        assert_eq!(session.first_known_index(), 10);
        assert_eq!(session.latest_ratchet.index(), 10);
        // Advancing to the current index is a no-op and reports `false`.
        assert!(!session.advance_to(10));
        assert!(session.advance_to(20));
        assert_eq!(session.first_known_index(), 20);
        assert_eq!(session.latest_ratchet.index(), 20);
    }

    /// Connectivity requires the same config, the same signing key and
    /// matching ratchets, and it is symmetric.
    #[test]
    fn connecting() {
        let outbound = GroupSession::new(Default::default());
        let mut session = InboundGroupSession::from(&outbound);
        let mut clone = InboundGroupSession::from(&outbound);
        assert!(session.connected(&mut clone));
        assert!(clone.connected(&mut session));
        // Still connected after one side has moved past the other's start.
        clone.advance_to(10);
        assert!(session.connected(&mut clone));
        assert!(clone.connected(&mut session));
        // A session from an unrelated outbound session is not connected.
        let mut other = InboundGroupSession::from(&GroupSession::new(Default::default()));
        assert!(!session.connected(&mut other));
        assert!(!clone.connected(&mut other));
        // Sharing only the signing key is not enough: the ratchets differ.
        other.signing_key = session.signing_key;
        assert!(!session.connected(&mut other));
        assert!(!clone.connected(&mut other));
        // Same ratchet and signing key but a different config is not connected.
        // `version_1()` differs from the default config used above.
        let session_key = session.export_at_first_known_index();
        let mut different_config =
            InboundGroupSession::import(&session_key, SessionConfig::version_1());
        assert!(!session.connected(&mut different_config));
        assert!(!different_config.connected(&mut session));
    }

    /// `compare` ranks connected sessions by first known index (lower is
    /// better) and reports everything else as unconnected.
    #[test]
    fn comparison() {
        let outbound = GroupSession::new(Default::default());
        let mut session = InboundGroupSession::from(&outbound);
        let mut clone = InboundGroupSession::from(&outbound);
        assert_eq!(session.compare(&mut clone), SessionOrdering::Equal);
        assert_eq!(clone.compare(&mut session), SessionOrdering::Equal);
        // `clone` can no longer decrypt indices 0..10, so `session` is better.
        clone.advance_to(10);
        assert_eq!(session.compare(&mut clone), SessionOrdering::Better);
        assert_eq!(clone.compare(&mut session), SessionOrdering::Worse);
        let mut other = InboundGroupSession::from(&GroupSession::new(Default::default()));
        assert_eq!(session.compare(&mut other), SessionOrdering::Unconnected);
        assert_eq!(clone.compare(&mut other), SessionOrdering::Unconnected);
        // A shared signing key with a different ratchet is still unconnected.
        other.signing_key = session.signing_key;
        assert_eq!(session.compare(&mut other), SessionOrdering::Unconnected);
        assert_eq!(clone.compare(&mut other), SessionOrdering::Unconnected);
    }

    /// Merging an imported, unverified session that starts at index 10 with a
    /// verified session starting at 0 yields a verified session starting at 0.
    #[test]
    fn upgrade() {
        let session = GroupSession::new(Default::default());
        let session_key = session.session_key();
        // Created from a `SessionKey`, so its signing key is marked verified.
        let mut first_session = InboundGroupSession::new(&session_key, Default::default());
        // Imported from an exported key, so its signing key is not verified.
        let mut second_session =
            InboundGroupSession::import(&first_session.export_at(10).unwrap(), Default::default());
        assert!(!second_session.signing_key_verified);
        assert_eq!(first_session.compare(&mut second_session), SessionOrdering::Better);
        let mut merged = second_session.merge(&mut first_session).unwrap();
        assert!(merged.signing_key_verified);
        assert_eq!(merged.compare(&mut second_session), SessionOrdering::Better);
        assert_eq!(merged.compare(&mut first_session), SessionOrdering::Equal);
    }

    /// A libolm-produced message verifies against the matching session key and
    /// fails against an unrelated one. libolm messages carry truncated MACs,
    /// hence the version 1 config.
    #[test]
    fn verify_mac() {
        let olm_session = OlmOutboundGroupSession::new();
        let session_key = SessionKey::from_base64(&olm_session.session_key()).unwrap();
        let message = olm_session.encrypt("Hello").as_str().try_into().unwrap();
        let mut session = InboundGroupSession::new(&session_key, SessionConfig::version_1());
        // The first message of a fresh outbound session has index 0.
        let ratchet = session.find_ratchet(0).unwrap();
        let cipher = Cipher::new_megolm(ratchet.as_bytes());
        session
            .verify_mac(&cipher, &message)
            .expect("Should verify MAC from matching outbound session");
        // Same message, but ratchet state from a different outbound session.
        let olm_session = OlmOutboundGroupSession::new();
        let session_key = SessionKey::from_base64(&olm_session.session_key()).unwrap();
        let mut session = InboundGroupSession::new(&session_key, SessionConfig::version_1());
        let ratchet = session.find_ratchet(0).unwrap();
        let cipher = Cipher::new_megolm(ratchet.as_bytes());
        session
            .verify_mac(&cipher, &message)
            .expect_err("Should not verify MAC from different outbound session");
    }

    /// `get_cipher_at` yields no cipher below the first known index and
    /// distinct ciphers for distinct indices at or above it.
    #[cfg(feature = "low-level-api")]
    #[test]
    fn get_cipher_at() {
        let mut group_session = GroupSession::new(Default::default());
        // Two messages were encrypted, so the inbound session created next
        // starts at index 2.
        group_session.encrypt("test1");
        group_session.encrypt("test2");
        let session = InboundGroupSession::from(&group_session);
        assert!(session.get_cipher_at(0).is_none());
        assert!(session.get_cipher_at(1).is_none());
        assert!(session.get_cipher_at(2).is_some());
        assert!(session.get_cipher_at(1000).is_some());
        assert_ne!(
            session.get_cipher_at(2).unwrap().encrypt(b""),
            session.get_cipher_at(3).unwrap().encrypt(b"")
        );
        assert_ne!(
            session.get_cipher_at(3).unwrap().encrypt(b""),
            session.get_cipher_at(1000).unwrap().encrypt(b"")
        );
    }
}