Skip to content

Commit

Permalink
fix(mysql): Close prepared statement if persistence is disabled (#2905)
Browse files Browse the repository at this point in the history
* close prepared statement if persistence or statement cache are disabled

* add tests
  • Loading branch information
larsschumacher authored Jan 21, 2024
1 parent 31e541a commit 29dcd44
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 29 deletions.
90 changes: 61 additions & 29 deletions sqlx-mysql/src/connection/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,10 @@ use futures_util::{pin_mut, TryStreamExt};
use std::{borrow::Cow, sync::Arc};

impl MySqlConnection {
async fn get_or_prepare<'c>(
async fn prepare_statement<'c>(
&mut self,
sql: &str,
persistent: bool,
) -> Result<(u32, MySqlStatementMetadata), Error> {
if let Some(statement) = self.cache_statement.get_mut(sql) {
// <MySqlStatementMetadata> is internally reference-counted
return Ok((*statement).clone());
}

// https://dev.mysql.com/doc/internals/en/com-stmt-prepare.html
// https://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html#packet-COM_STMT_PREPARE_OK

Expand Down Expand Up @@ -72,11 +66,23 @@ impl MySqlConnection {
column_names: Arc::new(column_names),
};

if persistent && self.cache_statement.is_enabled() {
// in case of the cache being full, close the least recently used statement
if let Some((id, _)) = self.cache_statement.insert(sql, (id, metadata.clone())) {
self.stream.send_packet(StmtClose { statement: id }).await?;
}
Ok((id, metadata))
}

async fn get_or_prepare_statement<'c>(
&mut self,
sql: &str,
) -> Result<(u32, MySqlStatementMetadata), Error> {
if let Some(statement) = self.cache_statement.get_mut(sql) {
// <MySqlStatementMetadata> is internally reference-counted
return Ok((*statement).clone());
}

let (id, metadata) = self.prepare_statement(sql).await?;

// in case of the cache being full, close the least recently used statement
if let Some((id, _)) = self.cache_statement.insert(sql, (id, metadata.clone())) {
self.stream.send_packet(StmtClose { statement: id }).await?;
}

Ok((id, metadata))
Expand All @@ -102,21 +108,37 @@ impl MySqlConnection {
let mut columns = Arc::new(Vec::new());

let (mut column_names, format, mut needs_metadata) = if let Some(arguments) = arguments {
let (id, metadata) = self.get_or_prepare(
sql,
persistent,
)
.await?;

// https://dev.mysql.com/doc/internals/en/com-stmt-execute.html
self.stream
.send_packet(StatementExecute {
statement: id,
arguments: &arguments,
})
.await?;

(metadata.column_names, MySqlValueFormat::Binary, false)
if persistent && self.cache_statement.is_enabled() {
let (id, metadata) = self
.get_or_prepare_statement(sql)
.await?;

// https://dev.mysql.com/doc/internals/en/com-stmt-execute.html
self.stream
.send_packet(StatementExecute {
statement: id,
arguments: &arguments,
})
.await?;

(metadata.column_names, MySqlValueFormat::Binary, false)
} else {
let (id, metadata) = self
.prepare_statement(sql)
.await?;

// https://dev.mysql.com/doc/internals/en/com-stmt-execute.html
self.stream
.send_packet(StatementExecute {
statement: id,
arguments: &arguments,
})
.await?;

self.stream.send_packet(StmtClose { statement: id }).await?;

(metadata.column_names, MySqlValueFormat::Binary, false)
}
} else {
// https://dev.mysql.com/doc/internals/en/com-query.html
self.stream.send_packet(Query(sql)).await?;
Expand Down Expand Up @@ -269,7 +291,15 @@ impl<'c> Executor<'c> for &'c mut MySqlConnection {
Box::pin(async move {
self.stream.wait_until_ready().await?;

let (_, metadata) = self.get_or_prepare(sql, true).await?;
let metadata = if self.cache_statement.is_enabled() {
self.get_or_prepare_statement(sql).await?.1
} else {
let (id, metadata) = self.prepare_statement(sql).await?;

self.stream.send_packet(StmtClose { statement: id }).await?;

metadata
};

Ok(MySqlStatement {
sql: Cow::Borrowed(sql),
Expand All @@ -287,7 +317,9 @@ impl<'c> Executor<'c> for &'c mut MySqlConnection {
Box::pin(async move {
self.stream.wait_until_ready().await?;

let (_, metadata) = self.get_or_prepare(sql, false).await?;
let (id, metadata) = self.prepare_statement(sql).await?;

self.stream.send_packet(StmtClose { statement: id }).await?;

let columns = (&*metadata.columns).clone();

Expand Down
66 changes: 66 additions & 0 deletions tests/mysql/mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,57 @@ async fn it_caches_statements() -> anyhow::Result<()> {
Ok(())
}

#[sqlx_macros::test]
async fn it_closes_statements_with_persistent_disabled() -> anyhow::Result<()> {
let mut conn = new::<MySql>().await?;

let old_statement_count = select_statement_count(&mut conn).await.unwrap_or_default();

for i in 0..2 {
let row = sqlx::query("SELECT ? AS val")
.bind(i)
.persistent(false)
.fetch_one(&mut conn)
.await?;

let val: i32 = row.get("val");

assert_eq!(i, val);
}

let new_statement_count = select_statement_count(&mut conn).await.unwrap_or_default();

assert_eq!(old_statement_count, new_statement_count);

Ok(())
}

#[sqlx_macros::test]
async fn it_closes_statements_with_cache_disabled() -> anyhow::Result<()> {
setup_if_needed();

let mut url = url::Url::parse(&env::var("DATABASE_URL")?)?;
url.query_pairs_mut()
.append_pair("statement-cache-capacity", "0");

let mut conn = MySqlConnection::connect(url.as_ref()).await?;

let old_statement_count = select_statement_count(&mut conn).await.unwrap_or_default();

for index in 1..=10_i32 {
let _ = sqlx::query("SELECT ?")
.bind(index)
.execute(&mut conn)
.await?;
}

let new_statement_count = select_statement_count(&mut conn).await.unwrap_or_default();

assert_eq!(old_statement_count, new_statement_count);

Ok(())
}

#[sqlx_macros::test]
async fn it_can_bind_null_and_non_null_issue_540() -> anyhow::Result<()> {
let mut conn = new::<MySql>().await?;
Expand Down Expand Up @@ -510,3 +561,18 @@ async fn test_shrink_buffers() -> anyhow::Result<()> {

Ok(())
}

async fn select_statement_count(conn: &mut MySqlConnection) -> Result<i64, sqlx::Error> {
// Fails if performance schema does not exist
sqlx::query_scalar(
r#"
SELECT COUNT(*)
FROM performance_schema.threads AS t
INNER JOIN performance_schema.prepared_statements_instances AS psi
ON psi.OWNER_THREAD_ID = t.THREAD_ID
WHERE t.processlist_id = CONNECTION_ID()
"#,
)
.fetch_one(conn)
.await
}

0 comments on commit 29dcd44

Please sign in to comment.