Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sqlite: fix a couple segfaults #1351

Merged
merged 6 commits into from
Aug 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sqlx-core/src/sqlite/connection/describe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ pub(super) fn describe<'c: 'e, 'q: 'e, 'e>(
// fallback to [column_decltype]
if !stepped && stmt.read_only() {
stepped = true;
let _ = conn.worker.step(*stmt).await;
let _ = conn.worker.step(stmt).await;
}

let mut ty = stmt.column_type_info(col);
Expand Down
6 changes: 3 additions & 3 deletions sqlx-core/src/sqlite/connection/establish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ pub(crate) async fn establish(options: &SqliteConnectOptions) -> Result<SqliteCo
// https://www.sqlite.org/c3ref/extended_result_codes.html
unsafe {
// NOTE: ignore the failure here
sqlite3_extended_result_codes(handle.0.as_ptr(), 1);
sqlite3_extended_result_codes(handle.as_ptr(), 1);
}

// Configure a busy timeout
Expand All @@ -99,7 +99,7 @@ pub(crate) async fn establish(options: &SqliteConnectOptions) -> Result<SqliteCo
let ms =
i32::try_from(busy_timeout.as_millis()).expect("Given busy timeout value is too big.");

status = unsafe { sqlite3_busy_timeout(handle.0.as_ptr(), ms) };
status = unsafe { sqlite3_busy_timeout(handle.as_ptr(), ms) };

if status != SQLITE_OK {
return Err(Error::Database(Box::new(SqliteError::new(handle.as_ptr()))));
Expand All @@ -109,8 +109,8 @@ pub(crate) async fn establish(options: &SqliteConnectOptions) -> Result<SqliteCo
})?;

Ok(SqliteConnection {
worker: StatementWorker::new(handle.to_ref()),
handle,
worker: StatementWorker::new(),
statements: StatementCache::new(options.statement_cache_capacity),
statement: None,
transaction_depth: 0,
Expand Down
46 changes: 27 additions & 19 deletions sqlx-core/src/sqlite/connection/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::error::Error;
use crate::executor::{Execute, Executor};
use crate::logger::QueryLogger;
use crate::sqlite::connection::describe::describe;
use crate::sqlite::statement::{StatementHandle, VirtualStatement};
use crate::sqlite::statement::{StatementHandle, StatementWorker, VirtualStatement};
use crate::sqlite::{
Sqlite, SqliteArguments, SqliteConnection, SqliteQueryResult, SqliteRow, SqliteStatement,
SqliteTypeInfo,
Expand All @@ -16,7 +16,8 @@ use libsqlite3_sys::sqlite3_last_insert_rowid;
use std::borrow::Cow;
use std::sync::Arc;

fn prepare<'a>(
async fn prepare<'a>(
worker: &mut StatementWorker,
statements: &'a mut StatementCache<VirtualStatement>,
statement: &'a mut Option<VirtualStatement>,
query: &str,
Expand All @@ -39,7 +40,7 @@ fn prepare<'a>(
if exists {
// as this statement has been executed before, we reset before continuing
// this also causes any rows that are from the statement to be inflated
statement.reset();
statement.reset(worker).await?;
}

Ok(statement)
Expand All @@ -61,19 +62,25 @@ fn bind(

/// A structure holding sqlite statement handle and resetting the
/// statement when it is dropped.
struct StatementResetter {
handle: StatementHandle,
struct StatementResetter<'a> {
handle: Arc<StatementHandle>,
worker: &'a mut StatementWorker,
}

impl StatementResetter {
fn new(handle: StatementHandle) -> Self {
Self { handle }
impl<'a> StatementResetter<'a> {
fn new(worker: &'a mut StatementWorker, handle: &Arc<StatementHandle>) -> Self {
Self {
worker,
handle: Arc::clone(handle),
}
}
}

impl Drop for StatementResetter {
impl Drop for StatementResetter<'_> {
fn drop(&mut self) {
self.handle.reset();
// this method is designed to eagerly send the reset command
// so we don't need to await or spawn it
let _ = self.worker.reset(&self.handle);
}
}

Expand Down Expand Up @@ -103,7 +110,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {
} = self;

// prepare statement object (or checkout from cache)
let stmt = prepare(statements, statement, sql, persistent)?;
let stmt = prepare(worker, statements, statement, sql, persistent).await?;

// keep track of how many arguments we have bound
let mut num_arguments = 0;
Expand All @@ -113,7 +120,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {
// is dropped. `StatementResetter` will reliably reset the
// statement even if the stream returned from `fetch_many`
// is dropped early.
let _resetter = StatementResetter::new(*stmt);
let resetter = StatementResetter::new(worker, stmt);

// bind values to the statement
num_arguments += bind(stmt, &arguments, num_arguments)?;
Expand All @@ -125,7 +132,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {

// invoke [sqlite3_step] on the dedicated worker thread
// this will move us forward one row or finish the statement
let s = worker.step(*stmt).await?;
let s = resetter.worker.step(stmt).await?;

match s {
Either::Left(changes) => {
Expand All @@ -145,7 +152,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {

Either::Right(()) => {
let (row, weak_values_ref) = SqliteRow::current(
*stmt,
&stmt,
columns,
column_names
);
Expand Down Expand Up @@ -188,7 +195,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {
} = self;

// prepare statement object (or checkout from cache)
let virtual_stmt = prepare(statements, statement, sql, persistent)?;
let virtual_stmt = prepare(worker, statements, statement, sql, persistent).await?;

// keep track of how many arguments we have bound
let mut num_arguments = 0;
Expand All @@ -205,18 +212,18 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {

// invoke [sqlite3_step] on the dedicated worker thread
// this will move us forward one row or finish the statement
match worker.step(*stmt).await? {
match worker.step(stmt).await? {
Either::Left(_) => (),

Either::Right(()) => {
let (row, weak_values_ref) =
SqliteRow::current(*stmt, columns, column_names);
SqliteRow::current(stmt, columns, column_names);

*last_row_values = Some(weak_values_ref);

logger.increment_rows();

virtual_stmt.reset();
virtual_stmt.reset(worker).await?;
return Ok(Some(row));
}
}
Expand All @@ -238,11 +245,12 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {
handle: ref mut conn,
ref mut statements,
ref mut statement,
ref mut worker,
..
} = self;

// prepare statement object (or checkout from cache)
let statement = prepare(statements, statement, sql, true)?;
let statement = prepare(worker, statements, statement, sql, true).await?;

let mut parameters = 0;
let mut columns = None;
Expand Down
35 changes: 30 additions & 5 deletions sqlx-core/src/sqlite/connection/handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,23 @@ use std::ptr::NonNull;
use libsqlite3_sys::{sqlite3, sqlite3_close, SQLITE_OK};

use crate::sqlite::SqliteError;
use std::sync::Arc;

/// Managed handle to the raw SQLite3 database handle.
/// The database handle will be closed when this is dropped.
/// The database handle will be closed when this is dropped and no `ConnectionHandleRef`s exist.
#[derive(Debug)]
pub(crate) struct ConnectionHandle(pub(super) NonNull<sqlite3>);
pub(crate) struct ConnectionHandle(Arc<HandleInner>);

/// A wrapper around `ConnectionHandle` which only exists for a `StatementWorker` to own
/// which prevents the `sqlite3` handle from being finalized while it is running `sqlite3_step()`
/// or `sqlite3_reset()`.
///
/// Note that this does *not* actually give access to the database handle!
pub(crate) struct ConnectionHandleRef(Arc<HandleInner>);

// Wrapper for `*mut sqlite3` which finalizes the handle on-drop.
#[derive(Debug)]
struct HandleInner(NonNull<sqlite3>);

// A SQLite3 handle is safe to send between threads, provided not more than
// one is accessing it at the same time. This is upheld as long as [SQLITE_CONFIG_MULTITHREAD] is
Expand All @@ -20,19 +32,32 @@ pub(crate) struct ConnectionHandle(pub(super) NonNull<sqlite3>);

unsafe impl Send for ConnectionHandle {}

// SAFETY: `Arc<T>` normally only implements `Send` where `T: Sync` because it allows
// concurrent access.
//
// However, in this case we're only using `Arc` to prevent the database handle from being
// finalized while the worker still holds a statement handle; `ConnectionHandleRef` thus
// should *not* actually provide access to the database handle.
unsafe impl Send for ConnectionHandleRef {}

impl ConnectionHandle {
#[inline]
pub(super) unsafe fn new(ptr: *mut sqlite3) -> Self {
Self(NonNull::new_unchecked(ptr))
Self(Arc::new(HandleInner(NonNull::new_unchecked(ptr))))
}

#[inline]
pub(crate) fn as_ptr(&self) -> *mut sqlite3 {
self.0.as_ptr()
self.0 .0.as_ptr()
}

#[inline]
pub(crate) fn to_ref(&self) -> ConnectionHandleRef {
ConnectionHandleRef(Arc::clone(&self.0))
}
}

impl Drop for ConnectionHandle {
impl Drop for HandleInner {
fn drop(&mut self) {
unsafe {
// https://sqlite.org/c3ref/close.html
Expand Down
2 changes: 1 addition & 1 deletion sqlx-core/src/sqlite/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ mod executor;
mod explain;
mod handle;

pub(crate) use handle::ConnectionHandle;
pub(crate) use handle::{ConnectionHandle, ConnectionHandleRef};

/// A connection to a [Sqlite] database.
pub struct SqliteConnection {
Expand Down
6 changes: 3 additions & 3 deletions sqlx-core/src/sqlite/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub struct SqliteRow {
// IF the user drops the Row before iterating the stream (so
// nearly all of our internal stream iterators), the executor moves on; otherwise,
// it actually inflates this row with a list of owned sqlite3 values.
pub(crate) statement: StatementHandle,
pub(crate) statement: Arc<StatementHandle>,

pub(crate) values: Arc<AtomicPtr<SqliteValue>>,
pub(crate) num_values: usize,
Expand All @@ -48,7 +48,7 @@ impl SqliteRow {
// returns a weak reference to an atomic list where the executor should inflate if its going
// to increment the statement with [step]
pub(crate) fn current(
statement: StatementHandle,
statement: &Arc<StatementHandle>,
columns: &Arc<Vec<SqliteColumn>>,
column_names: &Arc<HashMap<UStr, usize>>,
) -> (Self, Weak<AtomicPtr<SqliteValue>>) {
Expand All @@ -57,7 +57,7 @@ impl SqliteRow {
let size = statement.column_count();

let row = Self {
statement,
statement: Arc::clone(statement),
values,
num_values: size,
columns: Arc::clone(columns),
Expand Down
50 changes: 39 additions & 11 deletions sqlx-core/src/sqlite/statement/handle.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::ffi::c_void;
use std::ffi::CStr;

use std::os::raw::{c_char, c_int};
use std::ptr;
use std::ptr::NonNull;
Expand All @@ -9,21 +10,22 @@ use std::str::{from_utf8, from_utf8_unchecked};
use libsqlite3_sys::{
sqlite3, sqlite3_bind_blob64, sqlite3_bind_double, sqlite3_bind_int, sqlite3_bind_int64,
sqlite3_bind_null, sqlite3_bind_parameter_count, sqlite3_bind_parameter_name,
sqlite3_bind_text64, sqlite3_changes, sqlite3_column_blob, sqlite3_column_bytes,
sqlite3_column_count, sqlite3_column_database_name, sqlite3_column_decltype,
sqlite3_column_double, sqlite3_column_int, sqlite3_column_int64, sqlite3_column_name,
sqlite3_column_origin_name, sqlite3_column_table_name, sqlite3_column_type,
sqlite3_column_value, sqlite3_db_handle, sqlite3_reset, sqlite3_sql, sqlite3_stmt,
sqlite3_stmt_readonly, sqlite3_table_column_metadata, sqlite3_value, SQLITE_OK,
SQLITE_TRANSIENT, SQLITE_UTF8,
sqlite3_bind_text64, sqlite3_changes, sqlite3_clear_bindings, sqlite3_column_blob,
sqlite3_column_bytes, sqlite3_column_count, sqlite3_column_database_name,
sqlite3_column_decltype, sqlite3_column_double, sqlite3_column_int, sqlite3_column_int64,
sqlite3_column_name, sqlite3_column_origin_name, sqlite3_column_table_name,
sqlite3_column_type, sqlite3_column_value, sqlite3_db_handle, sqlite3_finalize, sqlite3_reset,
sqlite3_sql, sqlite3_step, sqlite3_stmt, sqlite3_stmt_readonly, sqlite3_table_column_metadata,
sqlite3_value, SQLITE_DONE, SQLITE_MISUSE, SQLITE_OK, SQLITE_ROW, SQLITE_TRANSIENT,
SQLITE_UTF8,
};

use crate::error::{BoxDynError, Error};
use crate::sqlite::type_info::DataType;
use crate::sqlite::{SqliteError, SqliteTypeInfo};

#[derive(Debug, Copy, Clone)]
pub(crate) struct StatementHandle(pub(super) NonNull<sqlite3_stmt>);
#[derive(Debug)]
pub(crate) struct StatementHandle(NonNull<sqlite3_stmt>);

// access to SQLite3 statement handles are safe to send and share between threads
// as long as the `sqlite3_step` call is serialized.
Expand All @@ -32,6 +34,14 @@ unsafe impl Send for StatementHandle {}
unsafe impl Sync for StatementHandle {}

impl StatementHandle {
pub(super) fn new(ptr: NonNull<sqlite3_stmt>) -> Self {
Self(ptr)
}

pub(crate) fn as_ptr(&self) -> *mut sqlite3_stmt {
self.0.as_ptr()
}

#[inline]
pub(super) unsafe fn db_handle(&self) -> *mut sqlite3 {
// O(c) access to the connection handle for this statement handle
Expand Down Expand Up @@ -280,7 +290,25 @@ impl StatementHandle {
Ok(from_utf8(self.column_blob(index))?)
}

pub(crate) fn reset(&self) {
unsafe { sqlite3_reset(self.0.as_ptr()) };
pub(crate) fn clear_bindings(&self) {
unsafe { sqlite3_clear_bindings(self.0.as_ptr()) };
}
}
impl Drop for StatementHandle {
fn drop(&mut self) {
// SAFETY: we have exclusive access to the `StatementHandle` here
unsafe {
// https://sqlite.org/c3ref/finalize.html
let status = sqlite3_finalize(self.0.as_ptr());
if status == SQLITE_MISUSE {
// Panic in case of detected misuse of SQLite API.
//
// sqlite3_finalize returns it at least in the
// case of detected double free, i.e. calling
// sqlite3_finalize on already finalized
// statement.
panic!("Detected sqlite3_finalize misuse.");
}
}
}
}
Loading