Skip to content

Instantly share code, notes, and snippets.

@MarinPostma
Created September 5, 2024 10:05
Show Gist options
  • Select an option

  • Save MarinPostma/268743d5180048cd4ef3b0c4ae7c1867 to your computer and use it in GitHub Desktop.

Select an option

Save MarinPostma/268743d5180048cd4ef3b0c4ae7c1867 to your computer and use it in GitHub Desktop.
tokio sqlite
use std::{future::Future, io, task::{Poll, Waker}};
use std::cell::RefCell;
use corosensei::{stack::DefaultStack, ScopedCoroutine, Yielder};
use rand::Rng;
use rusqlite::OpenFlags;
use sqlite_vfs::{register, DatabaseHandle, LockKind, Vfs, WalDisabled};
use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt as _};
thread_local! {
    /// Per-thread slot for the context of the coroutine that is currently running.
    ///
    /// The coroutine body built in `CoroTask::new` fills it, `CoroTask::poll`
    /// moves the value in and out around each resume, and `run_fut` reads the
    /// waker and yielder from it.
    // The initializer is a constant expression, so use the `const` form and
    // skip the lazy-initialization check on every access.
    static CTX: RefCell<Option<Ctx>> = const { RefCell::new(None) };
}
/// Everything `run_fut` needs to poll a future and to suspend the coroutine
/// it is running in.
struct Ctx {
    /// Waker passed in by the most recent resume; `run_fut` builds its
    /// polling `Context` from it.
    waker: Waker,
    /// Raw pointer to the coroutine's yielder, captured from the reference
    /// handed to the coroutine body so `run_fut` can call `suspend` on it.
    yielder: *const Yielder<Waker, ()>,
}
/// Drives `f` to completion from inside a `CoroTask` coroutine and returns its
/// output, so synchronous code can wait on a future.
///
/// The future is polled with the waker currently stored in `CTX`. Whenever it
/// reports `Poll::Pending`, the coroutine is suspended through the yielder
/// stored in `CTX`; the waker it is resumed with is written back into `CTX`
/// and the future is polled again.
///
/// # Panics
///
/// Panics if `CTX` is empty, i.e. when called outside a running `CoroTask`.
fn run_fut<F>(f: F) -> F::Output
where
    F: Future,
{
    const NO_CTX: &str = "run_fut must be called from inside a CoroTask coroutine";

    tokio::pin!(f);
    loop {
        // Clone the waker instead of keeping a reference into the thread-local:
        // a reference would outlive the `RefCell` borrow and dangle if the slot
        // were modified while the future is being polled.
        let waker = CTX.with_borrow(|ctx| ctx.as_ref().expect(NO_CTX).waker.clone());
        let mut cx = std::task::Context::from_waker(&waker);
        if let Poll::Ready(output) = f.as_mut().poll(&mut cx) {
            return output;
        }

        let yielder_ptr = CTX.with_borrow(|ctx| ctx.as_ref().expect(NO_CTX).yielder);
        // SAFETY: the pointer was created from the `&Yielder` handed to the
        // coroutine body in `CoroTask::new`. This relies on `CTX` holding the
        // context of the coroutine we are executing in right now, whose
        // yielder is alive for as long as that body is running.
        let yielder = unsafe { &*yielder_ptr };
        let resumed_with = yielder.suspend(());
        CTX.with_borrow_mut(|ctx| ctx.as_mut().expect(NO_CTX).waker = resumed_with);
    }
}
/// SQLite VFS whose operations run tokio futures to completion via `run_fut`.
struct CoroVfs;
/// File handle produced by `CoroVfs::open`.
struct CoroFile {
    /// Tokio file that every read, write, sync and resize goes through.
    file: tokio::fs::File,
    /// Lock level last recorded by `lock`; plain bookkeeping, nothing in this
    /// file takes an OS-level lock.
    lock: LockKind,
}
/// `DatabaseHandle` backed by a `tokio::fs::File`.
///
/// All I/O methods build a future and pass it to `run_fut`, so they must be
/// called from inside a `CoroTask` coroutine.
impl DatabaseHandle for CoroFile {
    type WalIndex = WalDisabled;

    /// Returns the current length of the file in bytes.
    fn size(&self) -> Result<u64, std::io::Error> {
        run_fut(async { Ok(self.file.metadata().await?.len()) })
    }

    /// Fills `buf` completely with the bytes found at `offset`.
    ///
    /// # Errors
    ///
    /// Returns the seek or read error, including the error `read_exact`
    /// reports when the file ends before `buf` is full.
    fn read_exact_at(&mut self, buf: &mut [u8], offset: u64) -> Result<(), std::io::Error> {
        run_fut(async {
            self.file.seek(io::SeekFrom::Start(offset)).await?;
            self.file.read_exact(buf).await?;
            Ok(())
        })
    }

    /// Writes all of `buf` at `offset`.
    ///
    /// # Errors
    ///
    /// Returns the seek, write or flush error.
    fn write_all_at(&mut self, buf: &[u8], offset: u64) -> Result<(), std::io::Error> {
        run_fut(async {
            self.file.seek(io::SeekFrom::Start(offset)).await?;
            self.file.write_all(buf).await?;
            // tokio's `File` may finish a write in the background after
            // `write_all` returns. Flushing waits for it, so a failed write is
            // reported by this call rather than by some later operation.
            self.file.flush().await?;
            Ok(())
        })
    }

    /// Syncs the file to disk; when `data_only` is set only the file contents
    /// are synced (`sync_data`), otherwise contents and metadata (`sync_all`).
    fn sync(&mut self, data_only: bool) -> Result<(), std::io::Error> {
        run_fut(async {
            if data_only {
                self.file.sync_data().await
            } else {
                self.file.sync_all().await
            }
        })
    }

    /// Truncates or extends the file to exactly `size` bytes.
    fn set_len(&mut self, size: u64) -> Result<(), std::io::Error> {
        run_fut(async { self.file.set_len(size).await })
    }

    /// Records `lock` as the current lock level and always reports success.
    ///
    /// NOTE: this is bookkeeping only; no lock is taken on the underlying
    /// file, so other handles to the same file are not excluded.
    fn lock(&mut self, lock: sqlite_vfs::LockKind) -> Result<bool, std::io::Error> {
        self.lock = lock;
        Ok(true)
    }

    /// Reports whether the recorded lock level of this handle is `Reserved`,
    /// `Pending` or `Exclusive`.
    fn reserved(&mut self) -> Result<bool, std::io::Error> {
        Ok(matches!(
            self.lock,
            LockKind::Reserved | LockKind::Exclusive | LockKind::Pending
        ))
    }

    /// Returns the lock level last recorded by `lock`.
    fn current_lock(&self) -> Result<sqlite_vfs::LockKind, std::io::Error> {
        Ok(self.lock)
    }

    /// Always returns a default `WalDisabled`.
    fn wal_index(&self, _readonly: bool) -> Result<Self::WalIndex, std::io::Error> {
        Ok(WalDisabled::default())
    }
}
/// `Vfs` implementation that performs file-system operations with tokio and
/// waits for them through `run_fut`.
///
/// The methods that call `run_fut` must be invoked from inside a `CoroTask`
/// coroutine.
impl Vfs for CoroVfs {
    type Handle = CoroFile;

    /// Opens the file at `db` with the access mode requested in `opts`.
    ///
    /// The file is always opened for reading. Write access, creation and
    /// exclusive creation follow `opts.access`; the other options are not
    /// consulted. The returned handle starts with `LockKind::None`.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the open call.
    fn open(&self, db: &str, opts: sqlite_vfs::OpenOptions) -> Result<Self::Handle, std::io::Error> {
        // Honour the requested access instead of unconditionally opening
        // read+write+create, which would create missing files on read-only
        // opens and never fail for exclusive creation.
        let writable = !matches!(opts.access, sqlite_vfs::OpenAccess::Read);
        let create = matches!(opts.access, sqlite_vfs::OpenAccess::Create);
        let create_new = matches!(opts.access, sqlite_vfs::OpenAccess::CreateNew);
        run_fut(async {
            let file = tokio::fs::File::options()
                .read(true)
                .write(writable)
                .create(create)
                .create_new(create_new)
                .open(db)
                .await?;
            Ok(CoroFile {
                file,
                lock: LockKind::None,
            })
        })
    }

    /// Removes the file at `db`.
    fn delete(&self, db: &str) -> Result<(), std::io::Error> {
        run_fut(async { tokio::fs::remove_file(db).await })
    }

    /// Reports whether a file exists at `db`.
    fn exists(&self, db: &str) -> Result<bool, std::io::Error> {
        run_fut(async { tokio::fs::try_exists(db).await })
    }

    /// Returns a freshly generated v4 UUID, formatted as a string, to be used
    /// as a temporary file name.
    fn temporary_name(&self) -> String {
        uuid::Uuid::new_v4().to_string()
    }

    /// Fills `buffer` with values from the thread-local random number generator.
    fn random(&self, buffer: &mut [i8]) {
        rand::thread_rng().fill(buffer)
    }

    /// Waits for `duration` on the tokio timer and returns `duration` as the
    /// amount of time slept.
    fn sleep(&self, duration: std::time::Duration) -> std::time::Duration {
        run_fut(async {
            tokio::time::sleep(duration).await;
            duration
        })
    }
}
/// Future that runs a synchronous closure on its own coroutine stack, letting
/// the closure wait on futures through `run_fut`.
struct CoroTask<'a, R> {
    /// Coroutine executing the wrapped closure; each poll resumes it with a
    /// clone of that poll's waker.
    cr: ScopedCoroutine<'a, Waker, (), R, DefaultStack>,
    /// This task's context, kept here between polls after being taken out of
    /// `CTX`. `None` until the coroutine has run for the first time.
    ctx: Option<Ctx>,
}
impl<'a, R> CoroTask<'a, R> {
fn new(f: impl FnOnce() -> R + 'a) -> Self {
Self {
cr: ScopedCoroutine::new(|y, w| {
CTX.with(|ctx| {
ctx.borrow_mut().replace(Ctx { waker: w, yielder: y as *const _ });
});
f()
}),
ctx: None,
}
}
}
/// Polling a `CoroTask` resumes its coroutine until the closure either
/// suspends (via `run_fut`) or returns.
impl<'a, R> Future for CoroTask<'a, R> {
    type Output = R;

    /// Resumes the coroutine with a clone of `cx`'s waker.
    ///
    /// Returns `Poll::Pending` when the coroutine yields and `Poll::Ready`
    /// with the closure's result when it returns. While the coroutine runs,
    /// this task's context is installed in `CTX`; afterwards it is parked in
    /// `self.ctx` and whatever `CTX` held before the poll is put back.
    fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        // Take our own handle on the waker before touching `CTX`, so nothing
        // below depends on what the slot currently holds.
        let waker = cx.waker().clone();

        // Install our context (still `None` on the first poll; the coroutine
        // body fills the slot itself) and remember the previous occupant, so a
        // task polled while another task's context is installed does not wipe it.
        let outer = CTX.replace(self.ctx.take());

        let ret = match self.cr.resume(waker) {
            corosensei::CoroutineResult::Yield(()) => Poll::Pending,
            corosensei::CoroutineResult::Return(r) => Poll::Ready(r),
        };

        // Park our context until the next poll and restore the previous one.
        self.ctx = CTX.replace(outer);
        ret
    }
}
/// Wraps the given tokens in a closure and passes it to `CoroTask::new`,
/// producing a future that can be `.await`ed.
macro_rules! a {
    ($($body:tt)*) => {
        CoroTask::new(|| { $($body)* })
    };
}
#[tokio::main(flavor = "current_thread")]
async fn main() {
register("corovfs", CoroVfs, true).unwrap();
let conn = a!(rusqlite::Connection::open_with_flags_and_vfs("testdb", OpenFlags::SQLITE_OPEN_CREATE | OpenFlags::SQLITE_OPEN_READ_WRITE, "corovfs")).await.unwrap();
a!(conn.execute("create table if not exists test (x)", ())).await.unwrap();
a!(conn.execute("insert into test values (1234)", ())).await.unwrap();
a!(conn.execute("insert into test values (1234)", ())).await.unwrap();
a!(conn.execute("insert into test values (1234)", ())).await.unwrap();
let mut stmt = a!(conn.prepare("select * from test")).await.unwrap();
let mut rows = a!(stmt.query(())).await.unwrap();
while let Some(r) = a!(rows.next()).await.unwrap() {
dbg!(r);
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment