Skip to content

Commit

Permalink
add progress handler support to sqlite (#2256)
Browse files Browse the repository at this point in the history
* rebase main

* fmt

* use NonNull to fix UB

* apply code suggestions

* add test for multiple handler drops

* remove nightly features for test
  • Loading branch information
nbaztec authored Mar 24, 2023
1 parent 14d70fe commit 4f1ac1d
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 1 deletion.
1 change: 1 addition & 0 deletions sqlx-sqlite/src/connection/establish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ impl EstablishParams {
statements: Statements::new(self.statement_cache_capacity),
transaction_depth: 0,
log_settings: self.log_settings.clone(),
progress_handler_callback: None,
})
}
}
79 changes: 78 additions & 1 deletion sqlx-sqlite/src/connection/mod.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
use futures_core::future::BoxFuture;
use futures_intrusive::sync::MutexGuard;
use futures_util::future;
use libsqlite3_sys::sqlite3;
use libsqlite3_sys::{sqlite3, sqlite3_progress_handler};
use sqlx_core::common::StatementCache;
use sqlx_core::error::Error;
use sqlx_core::transaction::Transaction;
use std::cmp::Ordering;
use std::fmt::{self, Debug, Formatter};
use std::os::raw::{c_int, c_void};
use std::panic::catch_unwind;
use std::ptr::NonNull;

use crate::connection::establish::EstablishParams;
Expand Down Expand Up @@ -51,6 +53,10 @@ pub struct LockedSqliteHandle<'a> {
pub(crate) guard: MutexGuard<'a, ConnectionState>,
}

/// Represents a callback handler that will be shared with the underlying sqlite3 connection.
pub(crate) struct Handler(NonNull<dyn FnMut() -> bool + Send + 'static>);
unsafe impl Send for Handler {}

pub(crate) struct ConnectionState {
pub(crate) handle: ConnectionHandle,

Expand All @@ -60,6 +66,22 @@ pub(crate) struct ConnectionState {
pub(crate) statements: Statements,

log_settings: LogSettings,

/// Stores the progress handler set on the current connection. If the handler returns `false`,
/// the query is interrupted.
progress_handler_callback: Option<Handler>,
}

impl ConnectionState {
/// Drops the `progress_handler_callback` if it exists.
pub(crate) fn remove_progress_handler(&mut self) {
if let Some(mut handler) = self.progress_handler_callback.take() {
unsafe {
sqlite3_progress_handler(self.handle.as_ptr(), 0, None, 0 as *mut _);
let _ = { Box::from_raw(handler.0.as_mut()) };
}
}
}
}

pub(crate) struct Statements {
Expand Down Expand Up @@ -177,6 +199,21 @@ impl Connection for SqliteConnection {
}
}

/// Implements a C binding to a progress callback. The function returns `0` if the
/// user-provided callback returns `true`, and `1` otherwise to signal an interrupt.
extern "C" fn progress_callback<F>(callback: *mut c_void) -> c_int
where
F: FnMut() -> bool,
{
unsafe {
let r = catch_unwind(|| {
let callback: *mut F = callback.cast::<F>();
(*callback)()
});
c_int::from(!r.unwrap_or_default())
}
}

impl LockedSqliteHandle<'_> {
/// Returns the underlying sqlite3* connection handle.
///
Expand Down Expand Up @@ -206,12 +243,52 @@ impl LockedSqliteHandle<'_> {
) -> Result<(), Error> {
collation::create_collation(&mut self.guard.handle, name, compare)
}

/// Sets a progress handler that is invoked periodically during long running calls. If the progress callback
/// returns `false`, then the operation is interrupted.
///
/// `num_ops` is the approximate number of [virtual machine instructions](https://www.sqlite.org/opcode.html)
/// that are evaluated between successive invocations of the callback. If `num_ops` is less than one then the
/// progress handler is disabled.
///
/// Only a single progress handler may be defined at one time per database connection; setting a new progress
/// handler cancels the old one.
///
/// The progress handler callback must not do anything that will modify the database connection that invoked
/// the progress handler. Note that sqlite3_prepare_v2() and sqlite3_step() both modify their database connections
/// in this context.
pub fn set_progress_handler<F>(&mut self, num_ops: i32, mut callback: F)
where
F: FnMut() -> bool + Send + 'static,
{
unsafe {
let callback_boxed = Box::new(callback);
// SAFETY: `Box::into_raw()` always returns a non-null pointer.
let callback = NonNull::new_unchecked(Box::into_raw(callback_boxed));
let handler = callback.as_ptr() as *mut _;
self.guard.remove_progress_handler();
self.guard.progress_handler_callback = Some(Handler(callback));

sqlite3_progress_handler(
self.as_raw_handle().as_mut(),
num_ops,
Some(progress_callback::<F>),
handler,
);
}
}

/// Removes the progress handler on a database connection. The method does nothing if no handler was set.
pub fn remove_progress_handler(&mut self) {
self.guard.remove_progress_handler();
}
}

impl Drop for ConnectionState {
fn drop(&mut self) {
// explicitly drop statements before the connection handle is dropped
self.statements.clear();
self.remove_progress_handler();
}
}

Expand Down
69 changes: 69 additions & 0 deletions tests/sqlite/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use sqlx::{
SqliteConnection, SqlitePool, Statement, TypeInfo,
};
use sqlx_test::new;
use std::sync::Arc;

#[sqlx_macros::test]
async fn it_connects() -> anyhow::Result<()> {
Expand Down Expand Up @@ -725,3 +726,71 @@ async fn concurrent_read_and_write() {
read.await;
write.await;
}

#[sqlx_macros::test]
async fn test_query_with_progress_handler() -> anyhow::Result<()> {
let mut conn = new::<Sqlite>().await?;

// Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer.
let state = format!("test");
conn.lock_handle().await?.set_progress_handler(1, move || {
assert_eq!(state, "test");
false
});

match sqlx::query("SELECT 'hello' AS title")
.fetch_all(&mut conn)
.await
{
Err(sqlx::Error::Database(err)) => assert_eq!(err.message(), String::from("interrupted")),
_ => panic!("expected an interrupt"),
}

Ok(())
}

#[sqlx_macros::test]
async fn test_multiple_set_progress_handler_calls_drop_old_handler() -> anyhow::Result<()> {
let ref_counted_object = Arc::new(0);
assert_eq!(1, Arc::strong_count(&ref_counted_object));

{
let mut conn = new::<Sqlite>().await?;

let o = ref_counted_object.clone();
conn.lock_handle().await?.set_progress_handler(1, move || {
println!("{:?}", o);
false
});
assert_eq!(2, Arc::strong_count(&ref_counted_object));

let o = ref_counted_object.clone();
conn.lock_handle().await?.set_progress_handler(1, move || {
println!("{:?}", o);
false
});
assert_eq!(2, Arc::strong_count(&ref_counted_object));

let o = ref_counted_object.clone();
conn.lock_handle().await?.set_progress_handler(1, move || {
println!("{:?}", o);
false
});
assert_eq!(2, Arc::strong_count(&ref_counted_object));

match sqlx::query("SELECT 'hello' AS title")
.fetch_all(&mut conn)
.await
{
Err(sqlx::Error::Database(err)) => {
assert_eq!(err.message(), String::from("interrupted"))
}
_ => panic!("expected an interrupt"),
}

conn.lock_handle().await?.remove_progress_handler();
}

assert_eq!(1, Arc::strong_count(&ref_counted_object));
Ok(())
}

0 comments on commit 4f1ac1d

Please sign in to comment.