diff --git a/sqlx-sqlite/src/connection/establish.rs b/sqlx-sqlite/src/connection/establish.rs index c5425dd19b..91a3aff05e 100644 --- a/sqlx-sqlite/src/connection/establish.rs +++ b/sqlx-sqlite/src/connection/establish.rs @@ -282,6 +282,7 @@ impl EstablishParams { statements: Statements::new(self.statement_cache_capacity), transaction_depth: 0, log_settings: self.log_settings.clone(), + progress_handler_callback: None, }) } } diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index 903353be14..69879768ec 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -1,12 +1,14 @@ use futures_core::future::BoxFuture; use futures_intrusive::sync::MutexGuard; use futures_util::future; -use libsqlite3_sys::sqlite3; +use libsqlite3_sys::{sqlite3, sqlite3_progress_handler}; use sqlx_core::common::StatementCache; use sqlx_core::error::Error; use sqlx_core::transaction::Transaction; use std::cmp::Ordering; use std::fmt::{self, Debug, Formatter}; +use std::os::raw::{c_int, c_void}; +use std::panic::catch_unwind; use std::ptr::NonNull; use crate::connection::establish::EstablishParams; @@ -51,6 +53,10 @@ pub struct LockedSqliteHandle<'a> { pub(crate) guard: MutexGuard<'a, ConnectionState>, } +/// Represents a callback handler that will be shared with the underlying sqlite3 connection. +pub(crate) struct Handler(NonNull bool + Send + 'static>); +unsafe impl Send for Handler {} + pub(crate) struct ConnectionState { pub(crate) handle: ConnectionHandle, @@ -60,6 +66,22 @@ pub(crate) struct ConnectionState { pub(crate) statements: Statements, log_settings: LogSettings, + + /// Stores the progress handler set on the current connection. If the handler returns `false`, + /// the query is interrupted. + progress_handler_callback: Option, +} + +impl ConnectionState { + /// Drops the `progress_handler_callback` if it exists. + pub(crate) fn remove_progress_handler(&mut self) { + if let Some(mut handler) = self.progress_handler_callback.take() { + unsafe { + sqlite3_progress_handler(self.handle.as_ptr(), 0, None, 0 as *mut _); + let _ = { Box::from_raw(handler.0.as_mut()) }; + } + } + } } pub(crate) struct Statements { @@ -177,6 +199,21 @@ impl Connection for SqliteConnection { } } +/// Implements a C binding to a progress callback. The function returns `0` if the +/// user-provided callback returns `true`, and `1` otherwise to signal an interrupt. +extern "C" fn progress_callback(callback: *mut c_void) -> c_int +where + F: FnMut() -> bool, +{ + unsafe { + let r = catch_unwind(|| { + let callback: *mut F = callback.cast::(); + (*callback)() + }); + c_int::from(!r.unwrap_or_default()) + } +} + impl LockedSqliteHandle<'_> { /// Returns the underlying sqlite3* connection handle. /// @@ -206,12 +243,52 @@ impl LockedSqliteHandle<'_> { ) -> Result<(), Error> { collation::create_collation(&mut self.guard.handle, name, compare) } + + /// Sets a progress handler that is invoked periodically during long running calls. If the progress callback + /// returns `false`, then the operation is interrupted. + /// + /// `num_ops` is the approximate number of [virtual machine instructions](https://www.sqlite.org/opcode.html) + /// that are evaluated between successive invocations of the callback. If `num_ops` is less than one then the + /// progress handler is disabled. + /// + /// Only a single progress handler may be defined at one time per database connection; setting a new progress + /// handler cancels the old one. + /// + /// The progress handler callback must not do anything that will modify the database connection that invoked + /// the progress handler. Note that sqlite3_prepare_v2() and sqlite3_step() both modify their database connections + /// in this context. + pub fn set_progress_handler(&mut self, num_ops: i32, mut callback: F) + where + F: FnMut() -> bool + Send + 'static, + { + unsafe { + let callback_boxed = Box::new(callback); + // SAFETY: `Box::into_raw()` always returns a non-null pointer. + let callback = NonNull::new_unchecked(Box::into_raw(callback_boxed)); + let handler = callback.as_ptr() as *mut _; + self.guard.remove_progress_handler(); + self.guard.progress_handler_callback = Some(Handler(callback)); + + sqlite3_progress_handler( + self.as_raw_handle().as_mut(), + num_ops, + Some(progress_callback::), + handler, + ); + } + } + + /// Removes the progress handler on a database connection. The method does nothing if no handler was set. + pub fn remove_progress_handler(&mut self) { + self.guard.remove_progress_handler(); + } } impl Drop for ConnectionState { fn drop(&mut self) { // explicitly drop statements before the connection handle is dropped self.statements.clear(); + self.remove_progress_handler(); } } diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index 03f2013eb4..0c79bec1f3 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -7,6 +7,7 @@ use sqlx::{ SqliteConnection, SqlitePool, Statement, TypeInfo, }; use sqlx_test::new; +use std::sync::Arc; #[sqlx_macros::test] async fn it_connects() -> anyhow::Result<()> { @@ -725,3 +726,71 @@ async fn concurrent_read_and_write() { read.await; write.await; } + +#[sqlx_macros::test] +async fn test_query_with_progress_handler() -> anyhow::Result<()> { + let mut conn = new::().await?; + + // Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer. + let state = format!("test"); + conn.lock_handle().await?.set_progress_handler(1, move || { + assert_eq!(state, "test"); + false + }); + + match sqlx::query("SELECT 'hello' AS title") + .fetch_all(&mut conn) + .await + { + Err(sqlx::Error::Database(err)) => assert_eq!(err.message(), String::from("interrupted")), + _ => panic!("expected an interrupt"), + } + + Ok(()) +} + +#[sqlx_macros::test] +async fn test_multiple_set_progress_handler_calls_drop_old_handler() -> anyhow::Result<()> { + let ref_counted_object = Arc::new(0); + assert_eq!(1, Arc::strong_count(&ref_counted_object)); + + { + let mut conn = new::().await?; + + let o = ref_counted_object.clone(); + conn.lock_handle().await?.set_progress_handler(1, move || { + println!("{:?}", o); + false + }); + assert_eq!(2, Arc::strong_count(&ref_counted_object)); + + let o = ref_counted_object.clone(); + conn.lock_handle().await?.set_progress_handler(1, move || { + println!("{:?}", o); + false + }); + assert_eq!(2, Arc::strong_count(&ref_counted_object)); + + let o = ref_counted_object.clone(); + conn.lock_handle().await?.set_progress_handler(1, move || { + println!("{:?}", o); + false + }); + assert_eq!(2, Arc::strong_count(&ref_counted_object)); + + match sqlx::query("SELECT 'hello' AS title") + .fetch_all(&mut conn) + .await + { + Err(sqlx::Error::Database(err)) => { + assert_eq!(err.message(), String::from("interrupted")) + } + _ => panic!("expected an interrupt"), + } + + conn.lock_handle().await?.remove_progress_handler(); + } + + assert_eq!(1, Arc::strong_count(&ref_counted_object)); + Ok(()) +}