use std::{
hash::{Hash, Hasher},
mem,
sync::RwLock,
task::{Context, Poll, Waker},
};
/// A value paired with the bookkeeping needed to notify subscribers when it
/// changes.
///
/// The value is read through `&self` and written through `&mut self`. The
/// metadata (version counter and registered wakers) sits behind its own
/// `RwLock`, so subscribers can register wakers and the state can be closed
/// through a shared reference.
#[derive(Debug)]
pub struct ObservableState<T> {
    /// The current value.
    value: T,
    /// Version counter and wakers of subscribers waiting for the next change.
    metadata: RwLock<ObservableStateMetadata>,
}
/// Change-tracking bookkeeping for an `ObservableState`.
#[derive(Debug)]
struct ObservableStateMetadata {
    /// Change counter: starts at `1` and is incremented on every write.
    /// The value `0` is reserved to mean "closed".
    version: u64,
    /// Wakers registered by subscribers that polled without seeing a newer
    /// version; they are drained and woken on the next write or on close.
    wakers: Vec<Waker>,
}

impl Default for ObservableStateMetadata {
    /// Metadata at the initial version (`1`) with no wakers registered.
    fn default() -> Self {
        let initial_version = 1;
        ObservableStateMetadata {
            wakers: Vec::default(),
            version: initial_version,
        }
    }
}
impl<T> ObservableState<T> {
    /// Creates a new state holding `value`, at the initial version and with no
    /// registered wakers.
    pub(crate) fn new(value: T) -> Self {
        Self { value, metadata: Default::default() }
    }

    /// Returns a reference to the current value.
    pub(crate) fn get(&self) -> &T {
        &self.value
    }

    /// Returns the current version.
    ///
    /// Every write through this type increments the version, and
    /// [`close`](Self::close) sets it to `0`.
    ///
    /// # Panics
    ///
    /// Panics if the metadata lock has been poisoned.
    pub(crate) fn version(&self) -> u64 {
        self.metadata.read().unwrap().version
    }

    /// Checks whether a version newer than `*observed_version` is available.
    ///
    /// - If the state is closed (version `0`), returns `Poll::Ready(None)`.
    /// - If the current version is greater than `*observed_version`, stores the
    ///   current version into `*observed_version` and returns
    ///   `Poll::Ready(Some(()))`.
    /// - Otherwise, registers a clone of `cx`'s waker, to be woken on the next
    ///   write or on close, and returns `Poll::Pending`.
    ///
    /// Every `Pending` result pushes a new waker clone. No deduplication is
    /// done, so a task polled repeatedly without an intervening update is
    /// registered once per poll.
    ///
    /// # Panics
    ///
    /// Panics if the metadata lock has been poisoned.
    pub(crate) fn poll_update(
        &self,
        observed_version: &mut u64,
        cx: &Context<'_>,
    ) -> Poll<Option<()>> {
        // Hold the write lock across both the version check and the waker
        // registration. Otherwise a concurrent `close` (which only needs
        // `&self`) could run in between and leave this waker unwoken.
        let mut metadata = self.metadata.write().unwrap();
        if metadata.version == 0 {
            Poll::Ready(None)
        } else if *observed_version < metadata.version {
            *observed_version = metadata.version;
            Poll::Ready(Some(()))
        } else {
            metadata.wakers.push(cx.waker().clone());
            Poll::Pending
        }
    }

    /// Replaces the value, bumps the version and wakes all registered
    /// subscribers. Returns the previous value.
    pub(crate) fn set(&mut self, value: T) -> T {
        let result = mem::replace(&mut self.value, value);
        self.incr_version_and_wake();
        result
    }

    /// Like [`set`](Self::set), but only if `value` differs from the current
    /// value according to `PartialEq`.
    ///
    /// Returns the previous value if an update happened, `None` otherwise.
    pub(crate) fn set_if_not_eq(&mut self, value: T) -> Option<T>
    where
        T: PartialEq,
    {
        if self.value != value {
            Some(self.set(value))
        } else {
            None
        }
    }

    /// Like [`set`](Self::set), but only if the hash of `value` differs from
    /// the hash of the current value.
    ///
    /// Only hashes are compared, so two distinct values whose hashes collide
    /// are treated as equal and the update is skipped.
    ///
    /// Returns the previous value if an update happened, `None` otherwise.
    pub(crate) fn set_if_hash_not_eq(&mut self, value: T) -> Option<T>
    where
        T: Hash,
    {
        if hash(&self.value) != hash(&value) {
            Some(self.set(value))
        } else {
            None
        }
    }

    /// Mutates the value in place with `f`, then always bumps the version and
    /// wakes all registered subscribers.
    pub(crate) fn update(&mut self, f: impl FnOnce(&mut T)) {
        f(&mut self.value);
        self.incr_version_and_wake();
    }

    /// Mutates the value in place with `f`. Bumps the version and wakes
    /// subscribers only if `f` returns `true`.
    pub(crate) fn update_if(&mut self, f: impl FnOnce(&mut T) -> bool) {
        if f(&mut self.value) {
            self.incr_version_and_wake();
        }
    }

    /// Marks the state as closed (version `0`) and wakes every registered
    /// subscriber.
    ///
    /// While the version stays `0`, `poll_update` returns `Poll::Ready(None)`.
    /// A later write through `&mut self` would increment it back to `1`.
    ///
    /// # Panics
    ///
    /// Panics if the metadata lock has been poisoned.
    pub(crate) fn close(&self) {
        // Collect the wakers under the lock, but wake them only after the guard
        // is dropped. A woken task may be polled immediately on another thread
        // and would otherwise block on this same lock in `poll_update`. Waking
        // outside the lock also means a panicking waker cannot poison it.
        // No wake-up is lost: any poll that runs after the guard is dropped
        // observes `version == 0` and returns `Ready(None)` without
        // registering a waker.
        let wakers = {
            let mut metadata = self.metadata.write().unwrap();
            metadata.version = 0;
            mem::take(&mut metadata.wakers)
        };
        wake(wakers);
    }

    /// Increments the version and wakes (and removes) all registered wakers.
    ///
    /// `&mut self` guarantees exclusive access, so the metadata is reached via
    /// `RwLock::get_mut` without locking. `drain(..)` keeps the vector's
    /// allocation for future registrations.
    fn incr_version_and_wake(&mut self) {
        let metadata = self.metadata.get_mut().unwrap();
        metadata.version += 1;
        wake(metadata.wakers.drain(..));
    }
}
/// Hashes `value` with the standard library's `DefaultHasher`.
///
/// All `DefaultHasher::new()` instances start from the same state, so equal
/// values produce equal hashes within one program run. That is what the
/// change detection in `set_if_hash_not_eq` relies on. The algorithm is
/// unspecified, so results must not be persisted or compared across builds.
fn hash<T: Hash>(value: &T) -> u64 {
    let mut state = std::collections::hash_map::DefaultHasher::new();
    Hash::hash(value, &mut state);
    Hasher::finish(&state)
}
/// Wakes every waker yielded by `wakers`, consuming each one.
///
/// The `ExactSizeIterator` bound lets the number of wakers be logged up front
/// when the `tracing` feature is enabled.
fn wake<I>(wakers: I)
where
    I: IntoIterator<Item = Waker>,
    I::IntoIter: ExactSizeIterator,
{
    let pending = wakers.into_iter();

    #[cfg(feature = "tracing")]
    {
        match pending.len() {
            0 => tracing::debug!("No wakers"),
            num_wakers => tracing::debug!("Waking up {num_wakers} waiting subscribers"),
        }
    }

    pending.for_each(Waker::wake);
}