Skip to content

Instantly share code, notes, and snippets.

@vmolsa
Created July 14, 2024 13:36
Show Gist options
  • Select an option

  • Save vmolsa/9b4ada38fd45aa5e86786082c79bb36e to your computer and use it in GitHub Desktop.

Select an option

Save vmolsa/9b4ada38fd45aa5e86786082c79bb36e to your computer and use it in GitHub Desktop.
Unified Mutex
// LICENSE: Apache-2.0
use std::{
future::poll_fn,
task::{Context, Poll},
ops::{Deref, DerefMut},
};
#[cfg(not(feature = "no_mutex"))]
use std::{
sync::{Mutex, TryLockError},
task::Waker,
};
#[cfg(feature = "no_mutex")]
use crossbeam::atomic::AtomicCell;
#[cfg(not(feature = "no_mutex"))]
use crossbeam::queue::SegQueue;
/// `Um` — a "unified" mutex that can be locked by blocking ([`Um::lock`]), by polling
/// ([`Um::poll_lock`]), or by awaiting ([`Um::wait_lock`]).
///
/// In the default configuration the value sits behind a [`std::sync::Mutex`], and tasks that
/// fail to lock it asynchronously park a [`std::task::Waker`] in a lock-free
/// `crossbeam::queue::SegQueue`; that queue is used to wake pending tasks when a [`Guard`] is
/// dropped.
///
/// With the `no_mutex` feature the mutex and the queue are compiled out: the value is kept in a
/// `crossbeam::atomic::AtomicCell` and every lock call succeeds immediately, so the type acts
/// like a plain mutable reference. In that mode nothing prevents two guards from existing at
/// the same time, so callers must make sure only one is alive at once.
///
/// # Examples
///
/// ```
/// use std::sync::Arc;
///
/// use futures::executor::block_on;
/// use um::Um;
///
/// let um = Arc::new(Um::new(42));
///
/// // Blocking lock.
/// {
///     let guard = um.lock("test");
///     assert_eq!(*guard, 42);
/// }
///
/// // Non-blocking lock: poll once with a no-op waker.
/// {
///     let waker = futures::task::noop_waker();
///     let mut cx = std::task::Context::from_waker(&waker);
///     if let std::task::Poll::Ready(guard) = um.poll_lock(&mut cx, "test") {
///         assert_eq!(*guard, 42);
///     }
/// }
///
/// // Asynchronous lock.
/// block_on(async {
///     let guard = um.wait_lock("test").await;
///     assert_eq!(*guard, 42);
/// });
/// ```
pub struct Um<T> {
    /// Storage for the value when the `no_mutex` feature is enabled.
    #[cfg(feature = "no_mutex")]
    value: AtomicCell<T>,
    /// The protected value.
    #[cfg(not(feature = "no_mutex"))]
    mutex: Mutex<T>,
    /// Wakers of tasks whose `poll_lock` found the mutex busy.
    #[cfg(not(feature = "no_mutex"))]
    wakers: SegQueue<Waker>,
}
impl<T> Um<T> {
    /// Creates a new `Um` instance containing the given value.
    ///
    /// # Arguments
    ///
    /// * `value` - The initial value to be stored in the mutex.
    ///
    /// # Examples
    ///
    /// ```
    /// let um = um::Um::new(42);
    /// ```
    pub fn new(value: T) -> Self {
        Self {
            #[cfg(not(feature = "no_mutex"))]
            mutex: Mutex::new(value),
            #[cfg(not(feature = "no_mutex"))]
            wakers: SegQueue::new(),
            #[cfg(feature = "no_mutex")]
            value: AtomicCell::new(value),
        }
    }

    /// Acquires a blocking lock on the mutex.
    ///
    /// This method blocks the current thread until the lock can be acquired.
    ///
    /// # Arguments
    ///
    /// * `_reason` - A static string describing why the lock is taken. It is only used in the
    ///   debug-build log message emitted when the mutex turns out to be poisoned.
    ///
    /// # Poisoning
    ///
    /// If a previous holder panicked, the poison is ignored and the guard is returned anyway
    /// (logged via `log::error!` in debug builds). This is the same in debug and release
    /// builds, and matches [`Um::poll_lock`] and [`Um::wait_lock`].
    ///
    /// # Examples
    ///
    /// ```
    /// let um = um::Um::new(42);
    /// let guard = um.lock("test");
    /// assert_eq!(*guard, 42);
    /// ```
    #[cfg(not(feature = "no_mutex"))]
    pub fn lock(&self, _reason: &'static str) -> Guard<T> {
        let guard = self.mutex.lock().unwrap_or_else(|poisoned| {
            #[cfg(debug_assertions)]
            log::error!("Lock was poisoned for {}", _reason);
            poisoned.into_inner()
        });
        Guard::new(guard, self)
    }

    /// Acquires the lock; without a mutex (`no_mutex`) this always succeeds immediately.
    #[cfg(feature = "no_mutex")]
    pub fn lock(&self, _reason: &'static str) -> Guard<T> {
        Guard { um: self }
    }

    /// Tries to acquire the lock without blocking.
    ///
    /// If the lock is busy, the current task's waker is queued and the lock is tried once more.
    /// The second attempt closes the race where the holder releases the lock between the first
    /// attempt and the registration. If it is still busy, `Poll::Pending` is returned and the
    /// task is woken when a guard is dropped.
    ///
    /// # Arguments
    ///
    /// * `cx` - The current task's context.
    /// * `_reason` - A static string describing why the lock is taken. It is only used in the
    ///   debug-build log message emitted when the mutex is poisoned.
    ///
    /// # Returns
    ///
    /// `Poll::Ready` with a `Guard` if the lock is acquired (a poisoned mutex is recovered), or
    /// `Poll::Pending` if the lock is not available.
    ///
    /// # Examples
    ///
    /// ```
    /// let um = um::Um::new(42);
    /// let waker = futures::task::noop_waker();
    /// let mut cx = std::task::Context::from_waker(&waker);
    /// if let std::task::Poll::Ready(guard) = um.poll_lock(&mut cx, "test") {
    ///     assert_eq!(*guard, 42);
    /// };
    /// ```
    #[cfg(not(feature = "no_mutex"))]
    pub fn poll_lock(&self, cx: &mut Context<'_>, _reason: &'static str) -> Poll<Guard<T>> {
        if let Some(guard) = self.try_acquire(_reason) {
            return Poll::Ready(guard);
        }
        self.wakers.push(cx.waker().clone());
        // Pairs with the fence in `Guard::drop` (after unlocking, before draining wakers).
        // Either this retry observes that unlock, or the releasing guard observes the waker
        // we just pushed, so the wakeup cannot be lost.
        std::sync::atomic::fence(std::sync::atomic::Ordering::SeqCst);
        match self.try_acquire(_reason) {
            Some(guard) => Poll::Ready(guard),
            None => Poll::Pending,
        }
    }

    /// Polls for the lock; without a mutex (`no_mutex`) this is always `Poll::Ready`.
    #[cfg(feature = "no_mutex")]
    pub fn poll_lock(&self, _cx: &mut Context<'_>, _reason: &'static str) -> Poll<Guard<T>> {
        Poll::Ready(Guard { um: self })
    }

    /// Makes a single non-blocking attempt to take the mutex.
    ///
    /// Returns `Some` on success (recovering from poison, logged in debug builds), or `None` if
    /// the mutex is currently held elsewhere.
    #[cfg(not(feature = "no_mutex"))]
    fn try_acquire(&self, _reason: &'static str) -> Option<Guard<'_, T>> {
        match self.mutex.try_lock() {
            Ok(guard) => Some(Guard::new(guard, self)),
            Err(TryLockError::Poisoned(poisoned)) => {
                #[cfg(debug_assertions)]
                log::error!("Lock was poisoned for {}", _reason);
                Some(Guard::new(poisoned.into_inner(), self))
            }
            Err(TryLockError::WouldBlock) => None,
        }
    }

    /// Acquires an asynchronous lock on the mutex.
    ///
    /// This method returns a future that resolves to a `Guard` once the lock is acquired. It
    /// repeatedly drives [`Um::poll_lock`].
    ///
    /// # Arguments
    ///
    /// * `_reason` - A static string describing why the lock is taken (used for logging).
    ///
    /// # Examples
    ///
    /// ```
    /// use std::sync::Arc;
    ///
    /// let um = Arc::new(um::Um::new(42));
    ///
    /// futures::executor::block_on(async {
    ///     let guard = um.wait_lock("test").await;
    ///     assert_eq!(*guard, 42);
    /// });
    /// ```
    pub async fn wait_lock(&self, _reason: &'static str) -> Guard<T> {
        poll_fn(|cx| self.poll_lock(cx, _reason)).await
    }
}

/// A guard that releases the lock when dropped.
///
/// Wraps a [`std::sync::MutexGuard`]. On drop it first unlocks the mutex and only then wakes
/// every task queued by [`Um::poll_lock`]. A woken task therefore finds the mutex free, unless
/// another caller grabs it first; in that case that caller's guard wakes the task again.
#[cfg(not(feature = "no_mutex"))]
pub struct Guard<'a, T> {
    /// Wrapped in `ManuallyDrop` so `Drop` can unlock *before* waking waiters.
    guard: std::mem::ManuallyDrop<std::sync::MutexGuard<'a, T>>,
    um: &'a Um<T>,
}

#[cfg(not(feature = "no_mutex"))]
impl<'a, T> Guard<'a, T> {
    /// Wraps a freshly acquired `MutexGuard` belonging to `um`.
    fn new(guard: std::sync::MutexGuard<'a, T>, um: &'a Um<T>) -> Self {
        Self {
            guard: std::mem::ManuallyDrop::new(guard),
            um,
        }
    }
}

#[cfg(not(feature = "no_mutex"))]
impl<T> Deref for Guard<'_, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        &**self.guard
    }
}

#[cfg(not(feature = "no_mutex"))]
impl<T> DerefMut for Guard<'_, T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut **self.guard
    }
}

#[cfg(not(feature = "no_mutex"))]
impl<T> Drop for Guard<'_, T> {
    fn drop(&mut self) {
        // Unlock first. A waiter woken while the mutex is still held would fail `try_lock`,
        // re-register and go back to sleep with nobody left to wake it.
        // SAFETY: `self.guard` is dropped exactly once, here. `Drop::drop` runs at most once,
        // `ManuallyDrop` suppresses the automatic field drop, and the field is not accessed
        // again below.
        unsafe { std::mem::ManuallyDrop::drop(&mut self.guard) };
        // Pairs with the fence in `Um::poll_lock` (after pushing a waker, before retrying).
        std::sync::atomic::fence(std::sync::atomic::Ordering::SeqCst);
        // Wake every queued waker, not just one. The queue may contain stale entries (cancelled
        // futures, duplicate registrations, tasks that won the retry in `poll_lock`), and
        // waking only a stale one would strand live waiters. Losers simply re-register.
        while let Some(waker) = self.um.wakers.pop() {
            waker.wake();
        }
    }
}
/// Guard returned by every locking method when the `no_mutex` feature is enabled.
///
/// It performs no synchronization. It simply hands out references to the value stored in the
/// `AtomicCell` of the `Um` it borrows.
///
/// NOTE(review): under `no_mutex`, `Um::lock` and `Um::poll_lock` take `&self` and always
/// succeed. Nothing prevents two `Guard`s for the same `Um` from coexisting, and `deref_mut`
/// would then create aliasing `&mut T`, which is undefined behavior. Callers must ensure at most
/// one guard is alive at a time. Enforcing this in the type system would require changing `Um`
/// itself (e.g. making it `!Sync` and tracking an outstanding borrow).
#[cfg(feature = "no_mutex")]
pub struct Guard<'a, T> {
    um: &'a Um<T>,
}

#[cfg(feature = "no_mutex")]
impl<T> Deref for Guard<'_, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        // SAFETY: `as_ptr` points at the live value inside `self.um.value`, which outlives this
        // guard. Soundness also requires that no `&mut T` obtained through another guard is
        // alive at the same time (see the note on `Guard`). That is NOT checked here.
        unsafe { &*self.um.value.as_ptr() }
    }
}

#[cfg(feature = "no_mutex")]
impl<T> DerefMut for Guard<'_, T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: as in `deref`. In addition, the returned borrow must be the only reference to
        // the value while it lives, which callers must guarantee (NOT checked here).
        unsafe { &mut *self.um.value.as_ptr() }
    }
}

#[cfg(feature = "no_mutex")]
impl<T> Drop for Guard<'_, T> {
    /// Nothing to release without a mutex. The impl is kept, presumably for parity with the
    /// mutex-backed `Guard`, which also implements `Drop`.
    fn drop(&mut self) {}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment