Created
September 5, 2024 10:05
-
-
Save MarinPostma/268743d5180048cd4ef3b0c4ae7c1867 to your computer and use it in GitHub Desktop.
tokio sqlite
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| use std::{future::Future, io, task::{Poll, Waker}}; | |
| use std::cell::RefCell; | |
| use corosensei::{stack::DefaultStack, ScopedCoroutine, Yielder}; | |
| use rand::Rng; | |
| use rusqlite::OpenFlags; | |
| use sqlite_vfs::{register, DatabaseHandle, LockKind, Vfs, WalDisabled}; | |
| use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt as _}; | |
| thread_local! { | |
| static CTX: RefCell<Option<Ctx>> = RefCell::new(None); | |
| } | |
| struct Ctx { | |
| waker: Waker, | |
| yielder: *const Yielder<Waker, ()>, | |
| } | |
| fn run_fut<F>(f: F) -> F::Output | |
| where F: Future, | |
| { | |
| tokio::pin!(f); | |
| loop { | |
| let waker = unsafe { &*CTX.with_borrow(|x| &(x.as_ref().unwrap().waker) as *const _) }; | |
| let mut cx = std::task::Context::from_waker(waker); | |
| match f.as_mut().poll(&mut cx) { | |
| Poll::Pending => { | |
| let yielder = unsafe { &*CTX.with_borrow(|x| x.as_ref().unwrap().yielder) }; | |
| let waker = yielder.suspend(()); | |
| CTX.with_borrow_mut(|c| c.as_mut().unwrap().waker = waker); | |
| } | |
| Poll::Ready(x) => { | |
| return x | |
| } | |
| } | |
| } | |
| } | |
/// A SQLite VFS whose file operations are performed with tokio's async APIs and driven from
/// inside `CoroTask` coroutines through `run_fut`.
struct CoroVfs;
| struct CoroFile { | |
| file: tokio::fs::File, | |
| lock: LockKind, | |
| } | |
| impl DatabaseHandle for CoroFile { | |
| type WalIndex = WalDisabled; | |
| fn size(&self) -> Result<u64, std::io::Error> { | |
| let fut = async { | |
| Ok(self.file.metadata().await?.len()) | |
| }; | |
| run_fut(fut) | |
| } | |
| fn read_exact_at(&mut self, buf: &mut [u8], offset: u64) -> Result<(), std::io::Error> { | |
| let fut = async { | |
| self.file.seek(io::SeekFrom::Start(offset as _)).await?; | |
| self.file.read_exact(buf).await?; | |
| Ok(()) | |
| }; | |
| run_fut(fut) | |
| } | |
| fn write_all_at(&mut self, buf: &[u8], offset: u64) -> Result<(), std::io::Error> { | |
| let fut = async { | |
| self.file.seek(io::SeekFrom::Start(offset as _)).await?; | |
| self.file.write_all(buf).await?; | |
| Ok(()) | |
| }; | |
| run_fut(fut) | |
| } | |
| fn sync(&mut self, _data_only: bool) -> Result<(), std::io::Error> { | |
| let fut = async { | |
| self.file.sync_all().await | |
| }; | |
| run_fut(fut) | |
| } | |
| fn set_len(&mut self, size: u64) -> Result<(), std::io::Error> { | |
| let fut = async { | |
| self.file.set_len(size).await | |
| }; | |
| run_fut(fut) | |
| } | |
| fn lock(&mut self, lock: sqlite_vfs::LockKind) -> Result<bool, std::io::Error> { | |
| self.lock = lock; | |
| Ok(true) | |
| } | |
| fn reserved(&mut self) -> Result<bool, std::io::Error> { | |
| Ok(matches!(self.lock, LockKind::Reserved | LockKind::Exclusive | LockKind::Pending)) | |
| } | |
| fn current_lock(&self) -> Result<sqlite_vfs::LockKind, std::io::Error> { | |
| Ok(self.lock) | |
| } | |
| fn wal_index(&self, _readonly: bool) -> Result<Self::WalIndex, std::io::Error> { | |
| Ok(WalDisabled::default()) | |
| } | |
| } | |
| impl Vfs for CoroVfs { | |
| type Handle = CoroFile; | |
| fn open(&self, db: &str, opts: sqlite_vfs::OpenOptions) -> Result<Self::Handle, std::io::Error> { | |
| let fut = async { | |
| let file = tokio::fs::File::options().read(true).write(true).create(true).open(db).await?; | |
| Ok(CoroFile { | |
| file, lock: LockKind::None | |
| }) | |
| }; | |
| run_fut(fut) | |
| } | |
| fn delete(&self, db: &str) -> Result<(), std::io::Error> { | |
| let fut = async { | |
| tokio::fs::remove_file(db).await?; | |
| Ok(()) | |
| }; | |
| run_fut(fut) | |
| } | |
| fn exists(&self, db: &str) -> Result<bool, std::io::Error> { | |
| let fut = async { | |
| tokio::fs::try_exists(db).await | |
| }; | |
| run_fut(fut) | |
| } | |
| fn temporary_name(&self) -> String { | |
| uuid::Uuid::new_v4().to_string() | |
| } | |
| fn random(&self, buffer: &mut [i8]) { | |
| rand::thread_rng().fill(buffer) | |
| } | |
| fn sleep(&self, duration: std::time::Duration) -> std::time::Duration { | |
| let fut = async { | |
| tokio::time::sleep(duration).await; | |
| duration | |
| }; | |
| run_fut(fut) | |
| } | |
| } | |
| struct CoroTask<'a, R> { | |
| cr: ScopedCoroutine<'a, Waker, (), R, DefaultStack>, | |
| ctx: Option<Ctx> | |
| } | |
| impl<'a, R> CoroTask<'a, R> { | |
| fn new(f: impl FnOnce() -> R + 'a) -> Self { | |
| Self { | |
| cr: ScopedCoroutine::new(|y, w| { | |
| CTX.with(|ctx| { | |
| ctx.borrow_mut().replace(Ctx { waker: w, yielder: y as *const _ }); | |
| }); | |
| f() | |
| }), | |
| ctx: None, | |
| } | |
| } | |
| } | |
| impl<'a, R> Future for CoroTask<'a, R> { | |
| type Output = R; | |
| fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> { | |
| // install ctx | |
| if let Some(ctx) = self.ctx.take() { | |
| CTX.with(|c| { | |
| c.borrow_mut().replace(ctx); | |
| }); | |
| } | |
| let ret = match self.cr.resume(cx.waker().clone()) { | |
| corosensei::CoroutineResult::Yield(()) => Poll::Pending, | |
| corosensei::CoroutineResult::Return(r) => Poll::Ready(r), | |
| }; | |
| CTX.with(|c| { | |
| self.ctx = c.borrow_mut().take(); | |
| }); | |
| ret | |
| } | |
| } | |
/// Runs the given expression inside a fresh `CoroTask`, yielding a future that resolves to
/// the expression's value. The expression may make blocking calls that reach `run_fut`.
macro_rules! a {
    ($($body:tt)*) => {
        CoroTask::new(|| { $($body)* })
    };
}
| #[tokio::main(flavor = "current_thread")] | |
| async fn main() { | |
| register("corovfs", CoroVfs, true).unwrap(); | |
| let conn = a!(rusqlite::Connection::open_with_flags_and_vfs("testdb", OpenFlags::SQLITE_OPEN_CREATE | OpenFlags::SQLITE_OPEN_READ_WRITE, "corovfs")).await.unwrap(); | |
| a!(conn.execute("create table if not exists test (x)", ())).await.unwrap(); | |
| a!(conn.execute("insert into test values (1234)", ())).await.unwrap(); | |
| a!(conn.execute("insert into test values (1234)", ())).await.unwrap(); | |
| a!(conn.execute("insert into test values (1234)", ())).await.unwrap(); | |
| let mut stmt = a!(conn.prepare("select * from test")).await.unwrap(); | |
| let mut rows = a!(stmt.query(())).await.unwrap(); | |
| while let Some(r) = a!(rows.next()).await.unwrap() { | |
| dbg!(r); | |
| } | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment