use core::cmp::min;
use core::num::NonZeroU64;
use ::bytes::{Buf, BufMut};
use crate::DecodeError;
/// Encodes `value` as a base-128 varint and appends it to `buf`.
///
/// Each output byte carries seven bits of `value`, least-significant group first. Every byte
/// except the last has its high bit (`0x80`) set to signal that more bytes follow.
#[inline]
pub fn encode_varint(mut value: u64, buf: &mut impl BufMut) {
    // A `u64` splits into at most ten 7-bit groups, so ten iterations always suffice.
    for _ in 0..10 {
        let low_bits = (value & 0x7F) as u8;
        value >>= 7;
        if value == 0 {
            // Nothing left above these seven bits: emit the final byte without the
            // continuation bit and stop.
            buf.put_u8(low_bits);
            return;
        }
        buf.put_u8(low_bits | 0x80);
    }
}
/// Returns the number of bytes `value` occupies when encoded as a varint.
///
/// The result is always between 1 and 10, inclusive.
#[inline]
pub fn encoded_len_varint(value: u64) -> usize {
    // SAFETY: `value | 1` always has its lowest bit set, so it can never be zero.
    let nonzero = unsafe { NonZeroU64::new_unchecked(value | 1) };
    // Index of the highest set bit (0..=63). OR-ing in 1 makes zero behave like one, which
    // also needs a single byte.
    let top_bit = nonzero.ilog2();
    // Equals `top_bit / 7 + 1` for every `top_bit` in 0..=63, but replaces the division by 7
    // with a multiplication by 9 and a division by 64.
    let len = (top_bit * 9 + (64 + 9)) / 64;
    len as usize
}
/// Decodes a varint from the front of `buf`, advancing `buf` past the bytes it consumed.
///
/// Three paths are used, depending on the current chunk:
/// * a single-byte value is returned directly;
/// * if the chunk is longer than 10 bytes, or its last byte has no continuation bit, the
///   value is decoded straight from the chunk via `decode_varint_slice`;
/// * otherwise decoding falls back to `decode_varint_slow`, which reads byte-by-byte
///   through the `Buf` cursor.
///
/// # Errors
///
/// Returns a `DecodeError` ("invalid varint") if the current chunk is empty, or if the
/// selected decoding path rejects the input.
#[inline]
pub fn decode_varint(buf: &mut impl Buf) -> Result<u64, DecodeError> {
    let chunk = buf.chunk();
    let (first, last) = match (chunk.first(), chunk.last()) {
        (Some(&first), Some(&last)) => (first, last),
        // Empty chunk: nothing to decode.
        _ => return Err(DecodeError::new("invalid varint")),
    };

    // Single-byte varint: no continuation bit on the first byte.
    if first < 0x80 {
        buf.advance(1);
        return Ok(u64::from(first));
    }

    // The slice decoder needs either more than 10 bytes or a terminating last byte so that
    // it can never read past the end of the chunk; when neither holds, take the slow path.
    if chunk.len() <= 10 && last >= 0x80 {
        return decode_varint_slow(buf);
    }

    let (value, consumed) = decode_varint_slice(chunk)?;
    buf.advance(consumed);
    Ok(value)
}
/// Decodes a varint from the front of `bytes` using a fully unrolled sequence of reads.
///
/// On success returns the decoded value together with the number of bytes it occupied
/// (1 through 10). The value is accumulated in 32-bit pieces (`part0`, `part1`, `part2`),
/// each covering up to four input bytes, and widened to `u64` only when pieces are combined.
///
/// # Panics
///
/// Panics if `bytes` is empty, or if it is at most 10 bytes long and its last byte has the
/// continuation bit (`0x80`) set.
///
/// # Errors
///
/// Returns a `DecodeError` ("invalid varint") if the tenth byte is `0x02` or greater, i.e. it
/// either still has the continuation bit set or carries bits that do not fit in a `u64`.
#[inline]
fn decode_varint_slice(bytes: &[u8]) -> Result<(u64, usize), DecodeError> {
// These assertions establish the bounds that the unchecked reads below rely on.
assert!(!bytes.is_empty());
assert!(bytes.len() > 10 || bytes[bytes.len() - 1] < 0x80);
// SAFETY (applies to every `get_unchecked` below): index 0 is in bounds because `bytes` is
// non-empty. An index `i > 0` is only read after bytes `0..i` were all seen to be >= 0x80.
// If `bytes.len() > 10`, indices 0..=9 are trivially in bounds; otherwise the last byte is
// < 0x80, so none of those earlier bytes was the last one and `i` is still in bounds.
let mut b: u8 = unsafe { *bytes.get_unchecked(0) };
// `part0` collects bytes 0..=3, i.e. bits 0..=27 of the result.
let mut part0: u32 = u32::from(b);
if b < 0x80 {
return Ok((u64::from(part0), 1));
};
// The byte was added with its continuation bit included; strip that bit now that we know
// it was set. The same add-then-subtract pattern repeats for every following byte.
part0 -= 0x80;
b = unsafe { *bytes.get_unchecked(1) };
part0 += u32::from(b) << 7;
if b < 0x80 {
return Ok((u64::from(part0), 2));
};
part0 -= 0x80 << 7;
b = unsafe { *bytes.get_unchecked(2) };
part0 += u32::from(b) << 14;
if b < 0x80 {
return Ok((u64::from(part0), 3));
};
part0 -= 0x80 << 14;
b = unsafe { *bytes.get_unchecked(3) };
part0 += u32::from(b) << 21;
if b < 0x80 {
return Ok((u64::from(part0), 4));
};
part0 -= 0x80 << 21;
// Low 28 bits are complete.
let value = u64::from(part0);
b = unsafe { *bytes.get_unchecked(4) };
// `part1` collects bytes 4..=7; it is shifted left by 28 when merged into the result.
let mut part1: u32 = u32::from(b);
if b < 0x80 {
return Ok((value + (u64::from(part1) << 28), 5));
};
part1 -= 0x80;
b = unsafe { *bytes.get_unchecked(5) };
part1 += u32::from(b) << 7;
if b < 0x80 {
return Ok((value + (u64::from(part1) << 28), 6));
};
part1 -= 0x80 << 7;
b = unsafe { *bytes.get_unchecked(6) };
part1 += u32::from(b) << 14;
if b < 0x80 {
return Ok((value + (u64::from(part1) << 28), 7));
};
part1 -= 0x80 << 14;
b = unsafe { *bytes.get_unchecked(7) };
part1 += u32::from(b) << 21;
if b < 0x80 {
return Ok((value + (u64::from(part1) << 28), 8));
};
part1 -= 0x80 << 21;
// Low 56 bits are complete.
let value = value + ((u64::from(part1)) << 28);
b = unsafe { *bytes.get_unchecked(8) };
// `part2` collects bytes 8 and 9; it is shifted left by 56 when merged into the result.
let mut part2: u32 = u32::from(b);
if b < 0x80 {
return Ok((value + (u64::from(part2) << 56), 9));
};
part2 -= 0x80;
b = unsafe { *bytes.get_unchecked(9) };
part2 += u32::from(b) << 7;
// The tenth byte may only supply bit 63 of the result, so it must be 0 or 1. This single
// comparison also rejects a tenth byte that still has the continuation bit set.
if b < 0x02 {
return Ok((value + (u64::from(part2) << 56), 10));
};
// Either the varint runs past ten bytes or its tenth byte overflows a `u64`.
Err(DecodeError::new("invalid varint"))
}
/// Decodes a varint one byte at a time through the `Buf` cursor.
///
/// Reads at most ten bytes, and never more than `buf.remaining()`. Bytes read are consumed
/// from `buf` even when decoding fails.
///
/// # Errors
///
/// Returns a `DecodeError` ("invalid varint") if no terminating byte (one without the `0x80`
/// continuation bit) is found within the bytes read, or if the tenth byte is `0x02` or
/// greater and would therefore overflow a `u64`.
#[inline(never)]
#[cold]
fn decode_varint_slow(buf: &mut impl Buf) -> Result<u64, DecodeError> {
    let max_bytes = min(10, buf.remaining());
    let mut decoded: u64 = 0;
    for index in 0..max_bytes {
        let byte = buf.get_u8();
        // Each byte contributes its low seven bits, least-significant group first.
        decoded |= u64::from(byte & 0x7F) << (index * 7);
        if byte & 0x80 != 0 {
            // Continuation bit set: more bytes follow.
            continue;
        }
        // The tenth byte may only supply bit 63; anything above 0x01 overflows a `u64`.
        return if index == 9 && byte > 0x01 {
            Err(DecodeError::new("invalid varint"))
        } else {
            Ok(decoded)
        };
    }
    // Ran out of bytes, or hit the ten-byte limit, without seeing a terminating byte.
    Err(DecodeError::new("invalid varint"))
}
#[cfg(test)]
mod test {
    use super::*;

    /// Round-trips values on either side of every varint length boundary through the
    /// encoder, the length calculation, and both decoders.
    #[test]
    fn varint() {
        fn check(value: u64, encoded: &[u8]) {
            // Encode into a buffer that has to grow while writing.
            let mut buf = Vec::with_capacity(1);
            encode_varint(value, &mut buf);
            assert_eq!(buf, encoded);

            // Encode into a buffer with plenty of spare capacity.
            let mut buf = Vec::with_capacity(100);
            encode_varint(value, &mut buf);
            assert_eq!(buf, encoded);

            assert_eq!(encoded_len_varint(value), encoded.len());

            // The main decoder must return the value and consume the whole encoding.
            let mut encoded_copy = encoded;
            let roundtrip_value = decode_varint(&mut encoded_copy).expect("decoding failed");
            assert_eq!(value, roundtrip_value);
            assert!(encoded_copy.is_empty(), "decoding left unread bytes");

            // The slow decoder must agree with it.
            let mut encoded_copy = encoded;
            let roundtrip_value =
                decode_varint_slow(&mut encoded_copy).expect("slow decoding failed");
            assert_eq!(value, roundtrip_value);
            assert!(encoded_copy.is_empty(), "slow decoding left unread bytes");
        }

        check(2u64.pow(0) - 1, &[0x00]);
        check(2u64.pow(0), &[0x01]);

        check(2u64.pow(7) - 1, &[0x7F]);
        check(2u64.pow(7), &[0x80, 0x01]);
        check(300, &[0xAC, 0x02]);

        check(2u64.pow(14) - 1, &[0xFF, 0x7F]);
        check(2u64.pow(14), &[0x80, 0x80, 0x01]);

        check(2u64.pow(21) - 1, &[0xFF, 0xFF, 0x7F]);
        check(2u64.pow(21), &[0x80, 0x80, 0x80, 0x01]);

        check(2u64.pow(28) - 1, &[0xFF, 0xFF, 0xFF, 0x7F]);
        check(2u64.pow(28), &[0x80, 0x80, 0x80, 0x80, 0x01]);

        check(2u64.pow(35) - 1, &[0xFF, 0xFF, 0xFF, 0xFF, 0x7F]);
        check(2u64.pow(35), &[0x80, 0x80, 0x80, 0x80, 0x80, 0x01]);

        check(2u64.pow(42) - 1, &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F]);
        check(2u64.pow(42), &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01]);

        check(
            2u64.pow(49) - 1,
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F],
        );
        check(
            2u64.pow(49),
            &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01],
        );

        check(
            2u64.pow(56) - 1,
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F],
        );
        check(
            2u64.pow(56),
            &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01],
        );

        check(
            2u64.pow(63) - 1,
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F],
        );
        check(
            2u64.pow(63),
            &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01],
        );

        check(
            u64::MAX,
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01],
        );
    }

    /// Ten-byte encoding whose final byte is `0x02`, i.e. one more than `u64::MAX`.
    const U64_MAX_PLUS_ONE: &[u8] = &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x02];

    /// The main decoder must reject a value that does not fit in a `u64`.
    #[test]
    fn varint_overflow() {
        let mut copy = U64_MAX_PLUS_ONE;
        decode_varint(&mut copy).expect_err("decoding u64::MAX + 1 succeeded");
    }

    /// The slow decoder must reject a value that does not fit in a `u64`.
    #[test]
    fn varint_slow_overflow() {
        let mut copy = U64_MAX_PLUS_ONE;
        decode_varint_slow(&mut copy).expect_err("slow decoding u64::MAX + 1 succeeded");
    }
}