use std::fmt::Debug;
use prost::Message;
use serde::{Deserialize, Serialize};
use crate::{
cipher::{Cipher, Mac, MessageMac},
types::{Ed25519Keypair, Ed25519Signature},
utilities::{base64_decode, base64_encode, extract_mac, VarInt},
DecodeError,
};
#[cfg(feature = "low-level-api")]
use crate::{Ed25519PublicKey, SignatureError};
/// Version byte used by messages whose MAC is stored in truncated form
/// (`Mac::TRUNCATED_LEN` bytes).
const MAC_TRUNCATED_VERSION: u8 = 3;
/// Version byte used by messages whose MAC is stored at full length
/// (`Mac::LENGTH` bytes).
const VERSION: u8 = 4;
/// An encrypted message together with its MAC and signature.
///
/// The binary encoding (see `to_bytes`) is the version byte, followed by a
/// protobuf payload holding the message index and ciphertext, followed by the
/// MAC and finally the Ed25519 signature.
#[derive(Clone, PartialEq, Eq)]
pub struct MegolmMessage {
/// Version byte written as the first byte of the encoded message; it also
/// decides whether the MAC is truncated or full-length when decoding.
pub(super) version: u8,
/// The encrypted payload.
pub(super) ciphertext: Vec<u8>,
/// The message index carried alongside the ciphertext.
pub(super) message_index: u32,
/// MAC computed over the version byte and the protobuf payload.
pub(super) mac: MessageMac,
/// Signature computed over the version byte, the protobuf payload and the
/// MAC.
pub(super) signature: Ed25519Signature,
}
/// Number of trailing bytes (truncated MAC plus signature) that follow the
/// protobuf payload in a message using `MAC_TRUNCATED_VERSION`.
const MESSAGE_TRUNCATED_SUFFIX_LENGTH: usize = Mac::TRUNCATED_LEN + Ed25519Signature::LENGTH;
/// Number of trailing bytes (full MAC plus signature) that follow the
/// protobuf payload in a message using `VERSION`.
const MESSAGE_SUFFIX_LENGTH: usize = Mac::LENGTH + Ed25519Signature::LENGTH;
impl MegolmMessage {
    /// Returns the encrypted payload of this message.
    pub fn ciphertext(&self) -> &[u8] {
        self.ciphertext.as_slice()
    }

    /// Returns the message index stored in this message.
    pub const fn message_index(&self) -> u32 {
        self.message_index
    }

    /// Returns the raw bytes of the MAC attached to this message.
    pub fn mac(&self) -> &[u8] {
        self.mac.as_bytes()
    }

    /// Returns the signature attached to this message.
    pub const fn signature(&self) -> &Ed25519Signature {
        &self.signature
    }

    /// Decodes a message from its binary encoding.
    ///
    /// # Errors
    ///
    /// Returns a [`DecodeError`] when the input is empty, carries an
    /// unsupported version byte, is too short, or contains a payload or
    /// signature that cannot be parsed.
    pub fn from_bytes(message: &[u8]) -> Result<Self, DecodeError> {
        <Self as TryFrom<&[u8]>>::try_from(message)
    }

    /// Encodes the message into its binary form: the version byte and
    /// protobuf payload, then the MAC, then the signature.
    pub fn to_bytes(&self) -> Vec<u8> {
        // The signed portion is exactly "payload + MAC", so the full encoding
        // is that portion with the signature appended.
        let mut bytes = self.to_signature_bytes();
        bytes.extend(self.signature.to_bytes());
        bytes
    }

    /// Decodes a message from the base64 form of its binary encoding.
    ///
    /// # Errors
    ///
    /// Returns a [`DecodeError`] when the base64 is invalid or the decoded
    /// bytes are not a valid message.
    pub fn from_base64(message: &str) -> Result<Self, DecodeError> {
        <Self as TryFrom<&str>>::try_from(message)
    }

    /// Encodes the message into binary form and returns it as base64.
    pub fn to_base64(&self) -> String {
        let bytes = self.to_bytes();
        base64_encode(bytes)
    }

    /// Replaces the signature of this message after checking that it is
    /// valid for the message contents under `signing_key`.
    ///
    /// # Errors
    ///
    /// Returns a [`SignatureError`] and leaves the message untouched when the
    /// signature does not verify.
    #[cfg(feature = "low-level-api")]
    pub fn add_signature(
        &mut self,
        signature: Ed25519Signature,
        signing_key: Ed25519PublicKey,
    ) -> Result<(), SignatureError> {
        let signed_payload = self.to_signature_bytes();
        signing_key.verify(&signed_payload, &signature)?;

        // Only reached when verification succeeded.
        self.signature = signature;
        Ok(())
    }

    /// Produces the version byte followed by the protobuf payload.
    fn encode_message(&self) -> Vec<u8> {
        ProtobufMegolmMessage {
            message_index: self.message_index,
            ciphertext: self.ciphertext.clone(),
        }
        .encode_manual(self.version)
    }

    /// Stores `mac`, truncating it first if the message currently holds a
    /// truncated MAC, so the MAC format of the message never changes.
    fn set_mac(&mut self, mac: Mac) {
        self.mac = match self.mac {
            MessageMac::Truncated(_) => mac.truncate().into(),
            MessageMac::Full(_) => mac.into(),
        };
    }

    /// Encrypts `plaintext` and returns a MACed and signed message.
    ///
    /// This delegates to the truncated-MAC variant of the message format.
    #[cfg(feature = "low-level-api")]
    pub fn encrypt(
        message_index: u32,
        cipher: &Cipher,
        signing_key: &Ed25519Keypair,
        plaintext: &[u8],
    ) -> Self {
        Self::encrypt_truncated_mac(message_index, cipher, signing_key, plaintext)
    }

    /// Encrypts `plaintext` into a message that carries a full-length MAC.
    pub(super) fn encrypt_full_mac(
        message_index: u32,
        cipher: &Cipher,
        signing_key: &Ed25519Keypair,
        plaintext: &[u8],
    ) -> Self {
        let ciphertext = cipher.encrypt(plaintext);

        // MAC and signature start out zeroed; the helper fills in real values.
        let unsealed = Self {
            version: VERSION,
            ciphertext,
            message_index,
            mac: Mac([0u8; Mac::LENGTH]).into(),
            signature: Self::placeholder_signature(),
        };

        Self::encrypt_helper(cipher, signing_key, unsealed)
    }

    /// Encrypts `plaintext` into a message that carries a truncated MAC.
    pub(super) fn encrypt_truncated_mac(
        message_index: u32,
        cipher: &Cipher,
        signing_key: &Ed25519Keypair,
        plaintext: &[u8],
    ) -> Self {
        let ciphertext = cipher.encrypt(plaintext);

        // MAC and signature start out zeroed; the helper fills in real values.
        let unsealed = Self {
            version: MAC_TRUNCATED_VERSION,
            ciphertext,
            message_index,
            mac: [0u8; Mac::TRUNCATED_LEN].into(),
            signature: Self::placeholder_signature(),
        };

        Self::encrypt_helper(cipher, signing_key, unsealed)
    }

    /// Builds the all-zero signature used until the real one is computed.
    fn placeholder_signature() -> Ed25519Signature {
        Ed25519Signature::from_slice(&[0; Ed25519Signature::LENGTH])
            .expect("Can't create an empty signature")
    }

    /// Computes the MAC and then the signature for `message`, storing both.
    ///
    /// The MAC must be set first because the signature covers it.
    fn encrypt_helper(
        cipher: &Cipher,
        signing_key: &Ed25519Keypair,
        mut message: MegolmMessage,
    ) -> Self {
        let mac_input = message.to_mac_bytes();
        message.set_mac(cipher.mac(&mac_input));

        let signed_payload = message.to_signature_bytes();
        message.signature = signing_key.sign(&signed_payload);

        message
    }

    /// Returns the bytes the MAC is computed over: the version byte and the
    /// protobuf payload.
    pub(super) fn to_mac_bytes(&self) -> Vec<u8> {
        self.encode_message()
    }

    /// Returns the bytes the signature is computed over: the MACed bytes
    /// followed by the MAC itself.
    pub(super) fn to_signature_bytes(&self) -> Vec<u8> {
        let mut bytes = self.to_mac_bytes();
        bytes.extend(self.mac.as_bytes());
        bytes
    }
}
/// Serializes the message as the base64 string produced by `to_base64`.
impl Serialize for MegolmMessage {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        serializer.serialize_str(&self.to_base64())
    }
}
impl<'de> Deserialize<'de> for MegolmMessage {
fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
let ciphertext = String::deserialize(d)?;
Self::from_base64(&ciphertext).map_err(serde::de::Error::custom)
}
}
/// Decodes a message from the base64 form of its binary encoding.
impl TryFrom<&str> for MegolmMessage {
    type Error = DecodeError;

    fn try_from(message: &str) -> Result<Self, Self::Error> {
        let raw = base64_decode(message)?;
        // Parse the decoded bytes straight from the borrowed slice.
        Self::try_from(raw.as_slice())
    }
}
/// Decodes a message from an owned buffer holding its binary encoding.
impl TryFrom<Vec<u8>> for MegolmMessage {
    type Error = DecodeError;

    fn try_from(message: Vec<u8>) -> Result<Self, Self::Error> {
        // Parsing only needs a borrowed view of the buffer.
        Self::try_from(&message[..])
    }
}
impl TryFrom<&[u8]> for MegolmMessage {
type Error = DecodeError;
fn try_from(message: &[u8]) -> Result<Self, Self::Error> {
let version = *message.first().ok_or(DecodeError::MissingVersion)?;
let suffix_length = match version {
VERSION => MESSAGE_SUFFIX_LENGTH,
MAC_TRUNCATED_VERSION => MESSAGE_TRUNCATED_SUFFIX_LENGTH,
_ => return Err(DecodeError::InvalidVersion(VERSION, version)),
};
if message.len() < suffix_length + 2 {
Err(DecodeError::MessageTooShort(message.len()))
} else {
let inner = ProtobufMegolmMessage::decode(
message
.get(1..message.len() - suffix_length)
.ok_or_else(|| DecodeError::MessageTooShort(message.len()))?,
)?;
let signature_location = message.len() - Ed25519Signature::LENGTH;
let signature_slice = &message[signature_location..];
let signature = Ed25519Signature::from_slice(signature_slice)?;
let mac_slice = &message[message.len() - suffix_length..];
let mac = extract_mac(mac_slice, version == MAC_TRUNCATED_VERSION);
Ok(MegolmMessage {
version,
ciphertext: inner.ciphertext,
message_index: inner.message_index,
mac,
signature,
})
}
}
}
/// Debug output shows only the version and message index; the ciphertext,
/// MAC and signature are omitted and the struct is marked non-exhaustive.
impl Debug for MegolmMessage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("MegolmMessage");
        builder.field("version", &self.version);
        builder.field("message_index", &self.message_index);
        builder.finish_non_exhaustive()
    }
}
/// Protobuf representation of the part of a message that sits between the
/// version byte and the MAC.
#[derive(Clone, Message, PartialEq, Eq)]
struct ProtobufMegolmMessage {
/// The message index, encoded as protobuf field 1.
#[prost(uint32, tag = "1")]
pub message_index: u32,
/// The ciphertext, encoded as protobuf field 2.
#[prost(bytes, tag = "2")]
pub ciphertext: Vec<u8>,
}
impl ProtobufMegolmMessage {
    /// Key byte written in front of the message index.
    const INDEX_TAG: &'static [u8; 1] = b"\x08";
    /// Key byte written in front of the length-prefixed ciphertext.
    const CIPHER_TAG: &'static [u8; 1] = b"\x12";

    /// Encodes the message by hand, prefixed with `version`.
    ///
    /// The output is the version byte, the index tag and varint-encoded
    /// index, then the ciphertext tag, the varint-encoded ciphertext length
    /// and the ciphertext itself, always in this order.
    fn encode_manual(&self, version: u8) -> Vec<u8> {
        let index = self.message_index.to_var_int();
        let ciphertext_len = self.ciphertext.len().to_var_int();

        let total = 1
            + Self::INDEX_TAG.len()
            + index.len()
            + Self::CIPHER_TAG.len()
            + ciphertext_len.len()
            + self.ciphertext.len();

        let mut encoded = Vec::with_capacity(total);
        encoded.push(version);
        encoded.extend_from_slice(Self::INDEX_TAG);
        encoded.extend_from_slice(&index);
        encoded.extend_from_slice(Self::CIPHER_TAG);
        encoded.extend_from_slice(&ciphertext_len);
        encoded.extend_from_slice(&self.ciphertext);
        encoded
    }
}
#[cfg(test)]
mod test {
    #[cfg(feature = "low-level-api")]
    use std::vec;

    use crate::{
        cipher::Mac,
        megolm::{
            message::{
                MAC_TRUNCATED_VERSION, MESSAGE_SUFFIX_LENGTH, MESSAGE_TRUNCATED_SUFFIX_LENGTH,
                VERSION,
            },
            MegolmMessage,
        },
        DecodeError, Ed25519Signature,
    };
    #[cfg(feature = "low-level-api")]
    use crate::{Ed25519Keypair, Ed25519PublicKey};

    /// Asserts that decoding `bytes` is rejected as too short.
    fn assert_too_short(bytes: &[u8]) {
        assert!(matches!(
            MegolmMessage::try_from(bytes),
            Err(DecodeError::MessageTooShort(_))
        ));
    }

    /// Builds a full-MAC message with empty ciphertext and zeroed MAC and
    /// signature, used as the starting point of the signature tests.
    #[cfg(feature = "low-level-api")]
    fn zeroed_message() -> MegolmMessage {
        MegolmMessage {
            version: VERSION,
            ciphertext: vec![],
            message_index: 0,
            mac: Mac([0u8; Mac::LENGTH]).into(),
            signature: Ed25519Signature::from_slice(&[0; Ed25519Signature::LENGTH]).unwrap(),
        }
    }

    /// The suffix constants must equal MAC length plus signature length.
    #[test]
    fn suffix_lengths() {
        let signature_len = Ed25519Signature::LENGTH;

        assert_eq!(MESSAGE_TRUNCATED_SUFFIX_LENGTH, Mac::TRUNCATED_LEN + signature_len);
        assert_eq!(MESSAGE_SUFFIX_LENGTH, Mac::LENGTH + signature_len);
    }

    /// A 97-byte full-MAC message is rejected as too short.
    #[test]
    fn message_to_short() {
        let mut bytes = [1u8; 97];
        bytes[0] = VERSION;

        assert_too_short(&bytes);
    }

    /// A 73-byte truncated-MAC message is rejected as too short.
    #[test]
    fn truncated_message_to_short() {
        let mut bytes = [1u8; 73];
        bytes[0] = MAC_TRUNCATED_VERSION;

        assert_too_short(&bytes);
    }

    /// A signature made by the matching keypair is accepted and stored.
    #[cfg(feature = "low-level-api")]
    #[test]
    fn add_valid_signature_succeeds() {
        let mut message = zeroed_message();

        let keypair = Ed25519Keypair::new();
        let signature = keypair.sign(&message.to_signature_bytes());

        message
            .add_signature(signature, keypair.public_key())
            .expect("Should be able to add valid signature");

        assert_eq!(message.signature, signature);
    }

    /// A signature that does not verify is rejected and not stored.
    #[cfg(feature = "low-level-api")]
    #[test]
    fn add_invalid_signature_fails() {
        let mut message = zeroed_message();

        let public_key = Ed25519PublicKey::from_slice(&[0; 32]).unwrap();
        let signature = Ed25519Signature::from_slice(&[1; Ed25519Signature::LENGTH]).unwrap();

        message
            .add_signature(signature, public_key)
            .expect_err("Should not be able to add invalid signature");

        assert_ne!(message.signature, signature);
    }
}