use std::fmt::Debug;
use prost::Message as ProstMessage;
use serde::{Deserialize, Serialize};
use crate::{
cipher::{Mac, MessageMac},
utilities::{base64_decode, base64_encode, extract_mac, VarInt},
Curve25519PublicKey, DecodeError,
};
/// Protocol version byte for messages that carry a truncated MAC (`Mac::TRUNCATED_LEN` bytes).
const MAC_TRUNCATED_VERSION: u8 = 0x03;
/// Protocol version byte for messages that carry a full-length MAC (`Mac::LENGTH` bytes).
const VERSION: u8 = 0x04;
/// A message made of a protocol version, a ratchet key, a chain index, a
/// ciphertext payload and a MAC.
///
/// Its binary form (see `to_bytes`) is the version byte, then a protobuf body
/// holding the ratchet key, chain index and ciphertext, then the raw MAC bytes.
#[derive(Clone, PartialEq, Eq)]
pub struct Message {
    /// Protocol version byte. Messages built or decoded in this module use either
    /// `VERSION` (full-length MAC) or `MAC_TRUNCATED_VERSION` (truncated MAC).
    pub(crate) version: u8,
    /// Curve25519 ratchet public key carried in the message body.
    pub(crate) ratchet_key: Curve25519PublicKey,
    /// Chain index carried in the message body (protobuf field 2).
    pub(crate) chain_index: u64,
    /// Ciphertext payload carried in the message body (protobuf field 4).
    pub(crate) ciphertext: Vec<u8>,
    /// MAC appended after the encoded body. Messages built with `new` or
    /// `new_truncated_mac` hold a zero-filled placeholder until `set_mac` is called.
    /// NOTE(review): presumably computed over `to_mac_bytes()`; confirm with the caller.
    pub(crate) mac: MessageMac,
}
impl Message {
pub const fn ratchet_key(&self) -> Curve25519PublicKey {
self.ratchet_key
}
pub const fn chain_index(&self) -> u64 {
self.chain_index
}
pub fn ciphertext(&self) -> &[u8] {
&self.ciphertext
}
pub const fn version(&self) -> u8 {
self.version
}
pub const fn mac_truncated(&self) -> bool {
self.version == MAC_TRUNCATED_VERSION
}
pub fn from_bytes(bytes: &[u8]) -> Result<Self, DecodeError> {
Self::try_from(bytes)
}
pub fn to_bytes(&self) -> Vec<u8> {
let mut message = self.encode();
message.extend(self.mac.as_bytes());
message
}
pub fn from_base64(message: &str) -> Result<Self, DecodeError> {
Self::try_from(message)
}
pub fn to_base64(&self) -> String {
base64_encode(self.to_bytes())
}
pub(crate) fn new(
ratchet_key: Curve25519PublicKey,
chain_index: u64,
ciphertext: Vec<u8>,
) -> Self {
Self {
version: VERSION,
ratchet_key,
chain_index,
ciphertext,
mac: Mac([0u8; Mac::LENGTH]).into(),
}
}
pub(crate) fn new_truncated_mac(
ratchet_key: Curve25519PublicKey,
chain_index: u64,
ciphertext: Vec<u8>,
) -> Self {
Self {
version: MAC_TRUNCATED_VERSION,
ratchet_key,
chain_index,
ciphertext,
mac: [0u8; Mac::TRUNCATED_LEN].into(),
}
}
fn encode(&self) -> Vec<u8> {
ProtoBufMessage {
ratchet_key: self.ratchet_key.to_bytes().to_vec(),
chain_index: self.chain_index,
ciphertext: self.ciphertext.clone(),
}
.encode_manual(self.version)
}
pub(crate) fn to_mac_bytes(&self) -> Vec<u8> {
self.encode()
}
pub(crate) fn set_mac(&mut self, mac: Mac) {
match self.mac {
MessageMac::Truncated(_) => self.mac = mac.truncate().into(),
MessageMac::Full(_) => self.mac = mac.into(),
}
}
}
impl Serialize for Message {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let message = self.to_base64();
serializer.serialize_str(&message)
}
}
impl<'de> Deserialize<'de> for Message {
fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
let ciphertext = String::deserialize(d)?;
Message::from_base64(&ciphertext).map_err(serde::de::Error::custom)
}
}
impl TryFrom<&str> for Message {
    type Error = DecodeError;

    /// Base64-decodes `value`, then parses the resulting bytes as a binary message.
    fn try_from(value: &str) -> Result<Self, Self::Error> {
        let bytes = base64_decode(value)?;
        Self::try_from(bytes.as_slice())
    }
}
impl TryFrom<Vec<u8>> for Message {
    type Error = DecodeError;

    /// Parses an owned byte buffer by delegating to the `&[u8]` implementation.
    fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
        Self::try_from(&value[..])
    }
}
impl TryFrom<&[u8]> for Message {
type Error = DecodeError;
fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
let version = *value.first().ok_or(DecodeError::MissingVersion)?;
let mac_length = match version {
VERSION => Mac::LENGTH,
MAC_TRUNCATED_VERSION => Mac::TRUNCATED_LEN,
_ => return Err(DecodeError::InvalidVersion(VERSION, version)),
};
if value.len() < mac_length + 2 {
Err(DecodeError::MessageTooShort(value.len()))
} else {
let inner = ProtoBufMessage::decode(
value
.get(1..value.len() - mac_length)
.ok_or_else(|| DecodeError::MessageTooShort(value.len()))?,
)?;
let mac_slice = &value[value.len() - mac_length..];
if mac_slice.len() != mac_length {
Err(DecodeError::InvalidMacLength(mac_length, mac_slice.len()))
} else {
let mac = extract_mac(mac_slice, version == MAC_TRUNCATED_VERSION);
let chain_index = inner.chain_index;
let ciphertext = inner.ciphertext;
let ratchet_key = Curve25519PublicKey::from_slice(&inner.ratchet_key)?;
let message = Message { version, ratchet_key, chain_index, ciphertext, mac };
Ok(message)
}
}
}
}
impl Debug for Message {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let Self { version, ratchet_key, chain_index, ciphertext: _, mac: _ } = self;
f.debug_struct("Message")
.field("version", version)
.field("ratchet_key", ratchet_key)
.field("chain_index", chain_index)
.finish_non_exhaustive()
}
}
/// Protobuf schema for the message body: everything between the leading version
/// byte and the trailing MAC.
#[derive(ProstMessage, PartialEq, Eq)]
struct ProtoBufMessage {
    /// Field 1: raw ratchet public key bytes.
    #[prost(bytes, tag = "1")]
    ratchet_key: Vec<u8>,
    /// Field 2: chain index.
    #[prost(uint64, tag = "2")]
    chain_index: u64,
    /// Field 4: ciphertext bytes (tag 3 is not used by this struct).
    #[prost(bytes, tag = "4")]
    ciphertext: Vec<u8>,
}
impl ProtoBufMessage {
    /// Key byte for field 1 with wire type 2 (length-delimited): `(1 << 3) | 2`.
    const RATCHET_TAG: &'static [u8; 1] = b"\x0A";
    /// Key byte for field 2 with wire type 0 (varint): `(2 << 3) | 0`.
    const INDEX_TAG: &'static [u8; 1] = b"\x10";
    /// Key byte for field 4 with wire type 2 (length-delimited): `(4 << 3) | 2`.
    const CIPHER_TAG: &'static [u8; 1] = b"\x22";

    /// Writes `version`, then the protobuf encoding of this body with all three
    /// fields always present, in the fixed order 1, 2, 4.
    ///
    /// NOTE(review): prost's derived encoder would typically omit default-valued
    /// proto3 fields (e.g. a zero chain index). That is presumably why this encoding
    /// is hand-rolled, to keep the byte layout fixed; confirm before replacing it.
    fn encode_manual(&self, version: u8) -> Vec<u8> {
        let ratchet_len = self.ratchet_key.len().to_var_int();
        let index = self.chain_index.to_var_int();
        let ciphertext_len = self.ciphertext.len().to_var_int();

        let mut encoded = Vec::new();
        encoded.push(version);

        // Field 1: ratchet key, as a varint length prefix followed by the bytes.
        encoded.extend_from_slice(Self::RATCHET_TAG);
        encoded.extend_from_slice(&ratchet_len);
        encoded.extend_from_slice(&self.ratchet_key);

        // Field 2: chain index as a varint.
        encoded.extend_from_slice(Self::INDEX_TAG);
        encoded.extend_from_slice(&index);

        // Field 4: ciphertext, as a varint length prefix followed by the bytes.
        encoded.extend_from_slice(Self::CIPHER_TAG);
        encoded.extend_from_slice(&ciphertext_len);
        encoded.extend_from_slice(&self.ciphertext);

        encoded
    }
}
#[cfg(test)]
mod test {
    use assert_matches::assert_matches;

    use super::{Message, VERSION};
    use crate::{olm::messages::message::MAC_TRUNCATED_VERSION, Curve25519PublicKey, DecodeError};

    /// 32-byte stand-in for a ratchet public key, shared by the tests below.
    const RATCHET_KEY_BYTES: [u8; 32] = *b"ratchetkeyhereprettyplease123456";

    #[test]
    fn encode() {
        let message = b"\x03\n\x20ratchetkeyhereprettyplease123456\x10\x02\"\nciphertext";
        let message_mac =
            b"\x03\n\x20ratchetkeyhereprettyplease123456\x10\x02\"\nciphertextMACHEREE";
        let ratchet_key = Curve25519PublicKey::from(RATCHET_KEY_BYTES);
        let ciphertext = b"ciphertext";
        let chain_index = 2;
        let mut encoded = Message::new_truncated_mac(ratchet_key, chain_index, ciphertext.to_vec());
        encoded.mac = (*b"MACHEREE").into();
        assert_eq!(encoded.to_mac_bytes(), message.as_ref());
        assert_eq!(encoded.to_bytes(), message_mac.as_ref());
        assert_eq!(encoded.ciphertext(), ciphertext.to_vec());
        assert_eq!(encoded.chain_index(), chain_index);
        assert_eq!(encoded.version(), MAC_TRUNCATED_VERSION);
    }

    #[test]
    fn decode_truncated_mac_message() {
        let bytes =
            b"\x03\n\x20ratchetkeyhereprettyplease123456\x10\x02\"\nciphertextMACHEREE";
        let decoded = Message::from_bytes(bytes).expect("a well-formed message should decode");

        assert!(decoded.mac_truncated());
        assert_eq!(decoded.version(), MAC_TRUNCATED_VERSION);
        assert_eq!(decoded.chain_index(), 2);
        assert_eq!(decoded.ciphertext(), b"ciphertext".as_ref());
        assert_eq!(decoded.ratchet_key(), Curve25519PublicKey::from(RATCHET_KEY_BYTES));
        // Re-encoding must reproduce the exact input, MAC included.
        assert_eq!(decoded.to_bytes(), bytes.as_ref());
    }

    #[test]
    fn full_mac_round_trip() {
        let ratchet_key = Curve25519PublicKey::from(RATCHET_KEY_BYTES);
        let message = Message::new(ratchet_key, 7, b"ciphertext".to_vec());
        let bytes = message.to_bytes();

        let decoded = Message::from_bytes(&bytes).expect("an encoded message should decode");

        assert!(!decoded.mac_truncated());
        assert_eq!(decoded.version(), VERSION);
        assert_eq!(decoded.chain_index(), 7);
        assert_eq!(decoded.ratchet_key(), ratchet_key);
        assert_eq!(decoded.to_bytes(), bytes);
    }

    #[test]
    fn from_bytes_empty() {
        assert_matches!(Message::from_bytes(&[]), Err(DecodeError::MissingVersion));
    }

    #[test]
    fn from_bytes_invalid_version() {
        // The version is checked before the length, so even a tiny buffer reports it.
        let result = Message::try_from(vec![0xFF, 0, 0]);
        assert_matches!(result, Err(DecodeError::InvalidVersion(VERSION, 0xFF)));
    }

    #[test]
    fn from_bytes_too_short() {
        let bytes = vec![MAC_TRUNCATED_VERSION, 0, 0, 0, 0, 0, 0, 0, 0];
        let result = Message::try_from(bytes);
        assert_matches!(result, Err(DecodeError::MessageTooShort(9)));
    }
}