From 95d39be76d0b3c813e7f3a64ab4cba743b7e5dc8 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Tue, 20 Apr 2021 19:10:50 -0700 Subject: [PATCH] WIP feat: reintroduce `Pool` Signed-off-by: Austin Bonander --- Cargo.lock | 27 ++ script/enforce-new-mod-style.sh | 30 +++ sqlx-core/src/runtime.rs | 11 + sqlx-core/src/runtime/tokio.rs | 10 + sqlx/Cargo.toml | 10 + sqlx/src/lib.rs | 3 + sqlx/src/pool.rs | 65 +++++ sqlx/src/pool/connection.rs | 156 +++++++++++ sqlx/src/pool/options.rs | 268 +++++++++++++++++++ sqlx/src/pool/shared.rs | 245 +++++++++++++++++ sqlx/src/pool/wait_list.rs | 453 ++++++++++++++++++++++++++++++++ 11 files changed, 1278 insertions(+) create mode 100755 script/enforce-new-mod-style.sh create mode 100644 sqlx/src/pool.rs create mode 100644 sqlx/src/pool/connection.rs create mode 100644 sqlx/src/pool/options.rs create mode 100644 sqlx/src/pool/shared.rs create mode 100644 sqlx/src/pool/wait_list.rs diff --git a/Cargo.lock b/Cargo.lock index dab07ae59c..98d3b3aa4f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -429,6 +429,21 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "futures" +version = "0.3.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f55667319111d593ba876406af7c409c0ebb44dc4be6132a783ccf163ea14c1" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.13" @@ -436,6 +451,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8c2dd2df839b57db9ab69c2c9d8f3e8c81984781937fe2807dc6dcf3b2ad2939" dependencies = [ "futures-core", + "futures-sink", ] [[package]] @@ -488,6 +504,12 @@ dependencies = [ "syn", ] +[[package]] +name = "futures-sink" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"5c5629433c555de3d82861a7a4e3794a4c40040390907cfbfd7143a92a426c23" + [[package]] name = "futures-task" version = "0.3.13" @@ -500,9 +522,11 @@ version = "0.3.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1812c7ab8aedf8d6f2701a43e1243acdbcc2b36ab26e2ad421eb99ac963d96d1" dependencies = [ + "futures-channel", "futures-core", "futures-io", "futures-macro", + "futures-sink", "futures-task", "memchr", "pin-project-lite", @@ -1088,7 +1112,10 @@ checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" name = "sqlx" version = "0.6.0-pre" dependencies = [ + "crossbeam-queue", + "futures", "futures-util", + "parking_lot", "sqlx-core", "sqlx-mysql", "sqlx-postgres", diff --git a/script/enforce-new-mod-style.sh b/script/enforce-new-mod-style.sh new file mode 100755 index 0000000000..a448d3fd41 --- /dev/null +++ b/script/enforce-new-mod-style.sh @@ -0,0 +1,30 @@ +#!/usr/bin/env bash + +# This script scans the project for `mod.rs` files and exits with a nonzero code if it finds any. +# +# You can also call it with `--fix` to replace any `mod.rs` files with their 2018 edition equivalents. +# The new files will be staged for commit for convenience. + +FILES=$(find ./ -name mod.rs -print) + +if [[ -z $FILES ]]; then + exit 0 +fi + +if [ "$1" != "--fix" ]; then + echo 'This project uses the Rust 2018 module style. mod.rs files are forbidden.' + echo "Execute \`$0 --fix\` to replace these with their 2018 equivalents and stage for commit." 
+ echo 'Found mod.rs files:' + echo "$FILES" + exit 1 +fi + +echo 'Fixing Rust 2018 Module Style' + +while read -r file; do + dest="$(dirname $file).rs" + echo "$file -> $dest" + mv $file $dest + git add $dest +done <<< $FILES + diff --git a/sqlx-core/src/runtime.rs b/sqlx-core/src/runtime.rs index 9f2e61469f..99b4c47004 100644 --- a/sqlx-core/src/runtime.rs +++ b/sqlx-core/src/runtime.rs @@ -26,6 +26,8 @@ mod tokio_; pub use actix_::Actix; #[cfg(feature = "async-std")] pub use async_std_::AsyncStd; +use std::future::Future; +use std::time::Instant; #[cfg(feature = "tokio")] pub use tokio_::Tokio; @@ -82,6 +84,15 @@ pub trait Runtime: 'static + Send + Sync + Sized + Debug { fn connect_unix_async(path: &Path) -> BoxFuture<'_, io::Result> where Self: Async; + + #[doc(hidden)] + #[cfg(all(unix, feature = "async"))] + fn timeout_at_async<'a, F: Future + Send + 'a>( + fut: F, + deadline: Instant, + ) -> BoxFuture<'a, Option> + where + Self: Async; } /// Marks a [`Runtime`] as being capable of handling asynchronous execution. diff --git a/sqlx-core/src/runtime/tokio.rs b/sqlx-core/src/runtime/tokio.rs index 8422e15e3e..943632ddbe 100644 --- a/sqlx-core/src/runtime/tokio.rs +++ b/sqlx-core/src/runtime/tokio.rs @@ -12,6 +12,8 @@ use futures_util::{AsyncReadExt, AsyncWriteExt, FutureExt, TryFutureExt}; use crate::io::Stream; use crate::{Async, Runtime}; +use std::future::Future; +use std::time::Instant; /// Provides [`Runtime`] for [**Tokio**](https://tokio.rs). Supports only non-blocking operation. 
/// @@ -55,6 +57,14 @@ impl Runtime for Tokio { fn connect_unix_async(path: &Path) -> BoxFuture<'_, io::Result> { UnixStream::connect(path).map_ok(Compat::new).boxed() } + + #[doc(hidden)] + fn timeout_at_async<'a, F: Future + Send + 'a>( + fut: F, + deadline: Instant, + ) -> BoxFuture<'a, Option> { + Box::pin(_tokio::time::timeout_at(deadline.into(), fut).map(Result::ok)) + } } impl Async for Tokio {} diff --git a/sqlx/Cargo.toml b/sqlx/Cargo.toml index ef2da2ce34..b633c9636c 100644 --- a/sqlx/Cargo.toml +++ b/sqlx/Cargo.toml @@ -28,6 +28,9 @@ async-std = ["async", "sqlx-core/async-std"] actix = ["async", "sqlx-core/actix"] tokio = ["async", "sqlx-core/tokio"] +# Connection Pool +pool = ["crossbeam-queue", "parking_lot"] + # MySQL mysql = ["sqlx-mysql"] mysql-async = ["async", "mysql", "sqlx-mysql/async"] @@ -38,8 +41,15 @@ postgres = ["sqlx-postgres"] postgres-async = ["async", "postgres", "sqlx-postgres/async"] postgres-blocking = ["blocking", "postgres", "sqlx-postgres/blocking"] + [dependencies] sqlx-core = { version = "0.6.0-pre", path = "../sqlx-core" } sqlx-mysql = { version = "0.6.0-pre", path = "../sqlx-mysql", optional = true } sqlx-postgres = { version = "0.6.0-pre", path = "../sqlx-postgres", optional = true } futures-util = { version = "0.3", optional = true, features = ["io"] } + +crossbeam-queue = { version = "0.3.1", optional = true } +parking_lot = { version = "0.11", optional = true } + +[dev-dependencies] +futures = "0.3.5" diff --git a/sqlx/src/lib.rs b/sqlx/src/lib.rs index 530d33649f..57a04a2aa8 100644 --- a/sqlx/src/lib.rs +++ b/sqlx/src/lib.rs @@ -46,6 +46,9 @@ #[cfg(feature = "blocking")] pub mod blocking; +#[cfg(feature = "pool")] +pub mod pool; + mod query; mod query_as; mod runtime; diff --git a/sqlx/src/pool.rs b/sqlx/src/pool.rs new file mode 100644 index 0000000000..16035152c6 --- /dev/null +++ b/sqlx/src/pool.rs @@ -0,0 +1,65 @@ +use crate::pool::connection::{Idle, Pooled}; +use crate::pool::options::PoolOptions; +use 
crate::pool::shared::{SharedPool, TryAcquireResult}; +use crate::pool::wait_list::WaitList; +use crate::{Connect, Connection, DefaultRuntime, Runtime}; +use crossbeam_queue::ArrayQueue; +use std::sync::atomic::AtomicU32; +use std::sync::Arc; +use std::time::Instant; + +mod connection; +mod options; +mod shared; +mod wait_list; + +pub struct Pool> { + shared: Arc>, +} + +impl> Pool { + pub fn new(uri: &str) -> crate::Result { + Self::builder().build(uri) + } + + pub fn new_with(connect_options: >::Options) -> Self { + Self::builder().build_with(connect_options) + } + + pub fn builder() -> PoolOptions { + PoolOptions::new() + } +} + +#[cfg(feature = "async")] +impl> Pool { + pub async fn connect(uri: &str) -> crate::Result { + Self::builder().connect(uri).await + } + + pub async fn connect_with(connect_options: >::Options) -> crate::Result { + Self::builder().connect_with(connect_options).await + } + + pub async fn acquire(&self) -> crate::Result> {} + + async fn acquire_inner(&self, deadline: Option) -> crate::Result> { + let mut acquire_permit = None; + + loop { + match self.shared.try_acquire(acquire_permit.take()) { + TryAcquireResult::Acquired(mut conn) => { + match self.shared.on_acquire_async(&mut conn) { + Ok(()) => return Ok(conn.attach(&self.shared)), + Err(e) => { + log::info!("error from before_acquire: {:?}", e); + } + } + } + TryAcquireResult::Connect(permit) => self.shared.connect_async(permit).await, + TryAcquireResult::Wait => {} + TryAcquireResult::PoolClosed => Err(todo!("crate::Error::PoolClosed")), + } + } + } +} diff --git a/sqlx/src/pool/connection.rs b/sqlx/src/pool/connection.rs new file mode 100644 index 0000000000..3d1e27c816 --- /dev/null +++ b/sqlx/src/pool/connection.rs @@ -0,0 +1,156 @@ +use super::shared::{DecrementSizeGuard, SharedPool}; +use crate::{Connection, Runtime}; +use std::fmt::{self, Debug, Formatter}; +use std::marker::PhantomData; +use std::ops::{Deref, DerefMut}; +use std::sync::Arc; +use std::time::Instant; + +/// A 
connection managed by a [`Pool`][crate::pool::Pool]. +/// +/// Will be returned to the pool on-drop. +pub struct Pooled> { + live: Option, + pub(crate) pool: Arc>, +} + +pub(super) struct Live> { + pub(super) raw: C, + pub(super) created: Instant, + _rt: PhantomData, +} + +pub(super) struct Idle> { + pub(super) live: Live, + pub(super) since: Instant, +} + +/// RAII wrapper for connections being handled by functions that may drop them +pub(super) struct Floating<'pool, C> { + inner: C, + guard: DecrementSizeGuard<'pool>, +} + +const DEREF_ERR: &str = "(bug) connection already released to pool"; + +impl> Debug for Pooled { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + // TODO: Show the type name of the connection ? + f.debug_struct("PoolConnection").finish() + } +} + +impl> Deref for Pooled { + type Target = C; + + fn deref(&self) -> &Self::Target { + &self.live.as_ref().expect(DEREF_ERR).raw + } +} + +impl> DerefMut for Pooled { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.live.as_mut().expect(DEREF_ERR).raw + } +} + +impl> Pooled { + /// Explicitly release a connection from the pool + pub fn release(mut self) -> C { + self.live.take().expect("PoolConnection double-dropped").float(&self.pool).detach() + } +} + +/// Returns the connection to the [`Pool`][crate::pool::Pool] it was checked-out from. 
+impl> Drop for Pooled { + fn drop(&mut self) { + if let Some(live) = self.live.take() { + self.pool.release(live); + } + } +} + +impl> Live { + pub fn float(self, guard: DecrementSizeGuard<'_>) -> Floating<'_, Self> { + Floating { inner: self, guard } + } + + pub fn into_idle(self) -> Idle { + Idle { live: self, since: Instant::now() } + } +} + +impl> Idle { + pub fn float(self, guard: DecrementSizeGuard<'_>) -> Floating<'_, Self> { + Floating { inner: self, guard } + } +} + +impl> Deref for Idle { + type Target = Live; + + fn deref(&self) -> &Self::Target { + &self.live + } +} + +impl> DerefMut for Idle { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.live + } +} + +impl<'s, C> Floating<'s, C> { + pub fn into_leakable(self) -> C { + self.guard.cancel(); + self.inner + } + + pub fn same_pool(&self, other: &SharedPool) -> bool { + self.guard.same_pool(other) + } +} + +impl<'s, Rt: Runtime, C: Connection> Floating<'s, Live> { + pub fn attach(self, pool: &Arc>) -> Pooled { + let Floating { inner, guard } = self; + + debug_assert!(guard.same_pool(pool), "BUG: attaching connection to different pool"); + + guard.cancel(); + Pooled { live: Some(inner), pool: Arc::clone(pool) } + } + + pub fn detach(self) -> C { + self.inner.raw + } + + pub fn into_idle(self) -> Floating<'s, Idle> { + Floating { inner: self.inner.into_idle(), guard: self.guard } + } +} + +impl<'s, Rt: Runtime, C: Connection> Floating<'s, Idle> { + pub fn into_live(self) -> Floating<'s, Live> { + Floating { inner: self.inner.live, guard: self.guard } + } + + pub async fn close(self) -> crate::Result<()> { + // `guard` is dropped as intended + self.inner.live.raw.close().await + } +} + +impl Deref for Floating<'_, C> { + type Target = C; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl DerefMut for Floating<'_, C> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} diff --git a/sqlx/src/pool/options.rs b/sqlx/src/pool/options.rs new file mode 100644 
index 0000000000..f0da65514c --- /dev/null +++ b/sqlx/src/pool/options.rs @@ -0,0 +1,268 @@ +use crate::pool::shared::SharedPool; +use crate::pool::Pool; +use crate::{Connect, ConnectOptions, Connection, Runtime}; +use std::cmp; +use std::fmt::{self, Debug, Formatter}; +use std::marker::PhantomData; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +pub struct PoolOptions> { + // to satisfy the orphan type params check + _rt: PhantomData, + + // general options + pub(crate) max_connections: u32, + pub(crate) connect_timeout: Duration, + pub(crate) min_connections: u32, + pub(crate) max_lifetime: Option, + pub(crate) idle_timeout: Option, + + // callback functions (any runtime) + pub(crate) after_release: Option bool + 'static + Send + Sync>>, + + // callback functions (async) + #[cfg(feature = "async")] + pub(crate) after_connect_async: Option< + Box< + dyn Fn(&mut C) -> futures_util::BoxFuture<'_, crate::Result<()>> + + Send + + Sync + + 'static, + >, + >, + + #[cfg(feature = "async")] + pub(crate) before_acquire_async: Option< + Box< + dyn Fn(&mut C) -> futures_util::BoxFuture<'_, crate::Result<()>> + + Send + + Sync + + 'static, + >, + >, + + //callback functions (blocking) + #[cfg(feature = "blocking")] + pub(crate) after_connect_blocking: + Option crate::Result<()> + Send + Sync + 'static>>, + #[cfg(feature = "blocking")] + pub(crate) before_acquire_blocking: + Option crate::Result<()> + Send + Sync + 'static>>, +} + +impl> Default for PoolOptions { + fn default() -> Self { + Self::new() + } +} + +impl> PoolOptions { + /// Create a new `PoolOptions` with some arbitrary, but sane, default values. + /// + /// See the source of this method for the current values. 
+ pub fn new() -> Self { + Self { + _rt: PhantomData, + min_connections: 0, + max_connections: 10, + connect_timeout: Duration::from_secs(30), + idle_timeout: Some(Duration::from_secs(10 * 60)), + max_lifetime: Some(Duration::from_secs(30 * 60)), + after_release: None, + #[cfg(feature = "async")] + after_connect_async: None, + #[cfg(feature = "async")] + before_acquire_async: None, + #[cfg(feature = "blocking")] + after_connect_blocking: None, + #[cfg(feature = "blocking")] + before_acquire_blocking: None, + } + } + + /// Set the minimum number of connections that this pool should maintain at all times. + /// + /// When the pool size drops below this amount, new connections are established automatically + /// in the background. + pub fn min_connections(mut self, min: u32) -> Self { + self.min_connections = min; + self + } + + /// Set the maximum number of connections that this pool should maintain. + pub fn max_connections(mut self, max: u32) -> Self { + self.max_connections = max; + self + } + + /// Set the amount of time to attempt connecting to the database. + /// + /// If this timeout elapses, [`Pool::acquire`] will return an error. + pub fn connect_timeout(mut self, timeout: Duration) -> Self { + self.connect_timeout = timeout; + self + } + + /// Set the maximum lifetime of individual connections. + /// + /// Any connection with a lifetime greater than this will be closed. + /// + /// When set to `None`, all connections live until either reaped by [`idle_timeout`] + /// or explicitly disconnected. + /// + /// Long-lived connections are not recommended due to the unfortunate reality of memory/resource + /// leaks on the database-side. It is better to retire connections periodically + /// (even if only once daily) to allow the database the opportunity to clean up data structures + /// (parse trees, query metadata caches, thread-local storage, etc.) that are associated with a + /// session. 
+ /// + /// [`idle_timeout`]: Self::idle_timeout + pub fn max_lifetime(mut self, lifetime: impl Into>) -> Self { + self.max_lifetime = lifetime.into(); + self + } + + /// Set a maximum idle duration for individual connections. + /// + /// Any connection with an idle duration longer than this will be closed. + /// + /// For usage-based database server billing, this can be a cost saver. + pub fn idle_timeout(mut self, timeout: impl Into>) -> Self { + self.idle_timeout = timeout.into(); + self + } + + /// If set, the health of a connection will be verified by a call to [`Connection::ping`] + /// before returning the connection. + /// + /// This overrides a previous callback set to [Self::before_acquire] and is also overridden by + /// `before_acquire`. + pub fn test_before_acquire(mut self) -> Self { + #[cfg(feature = "async")] + self.before_acquire_async = Some(Box::new(Connection::ping)); + #[cfg(feature = "blocking")] + todo!("Connection doesn't have a ping_blocking()"); + + self + } + + pub fn after_release(mut self, callback: F) -> Self + where + F: Fn(&mut C) -> bool + 'static + Send + Sync, + { + self.after_release = Some(Box::new(callback)); + self + } + + /// Creates a new pool from this configuration. + /// + /// Note that **this does not immediately connect to the database**; + /// this call will only error if the URI fails to parse. + /// + /// A connection will first be established either on the first call to + /// [`Pool::acquire()`][super::Pool::acquire()] or, + /// if [`self.min_connections`][Self::min_connections] is nonzero, + /// when the background monitor task (async runtime) or thread (blocking runtime) is spawned. + /// + /// If you prefer to establish a minimum number of connections on startup to ensure a valid + /// configuration, use [`.connect()`][Self::connect()] instead. + /// + /// See [`Self::build_with()`] for a version that lets you pass a [`ConnectOptions`]. 
+ pub fn build(self, uri: &str) -> crate::Result> { + Ok(self.build_with(uri.parse()?)) + } + + /// Creates a new pool from this configuration. + /// + /// Note that **this does not immediately connect to the database**; + /// this method call is infallible. + /// + /// A connection will first be established either on the first call to + /// [`Pool::acquire()`][super::Pool::acquire()] or, + /// if [`self.min_connections`][Self::min_connections] is nonzero, + /// when the background monitor task (async runtime) or thread (blocking runtime) is spawned. + /// + /// If you prefer to establish at least one connections on startup to ensure a valid + /// configuration, use [`.connect_with()`][Self::connect_with()] instead. + pub fn build_with(self, options: >::Options) -> Pool { + Pool { shared: SharedPool::new(self, options).into() } + } +} + +#[cfg(feature = "async")] +impl> PoolOptions { + /// Perform an action after connecting to the database. + pub fn after_connect(mut self, callback: F) -> Self + where + for<'c> F: + Fn(&'c mut C) -> futures_util::BoxFuture<'c, crate::Result<()>> + Send + Sync + 'static, + { + self.after_connect = Some(Box::new(callback)); + self + } + + /// If set, this callback is executed with a connection that has been acquired from the idle + /// queue. + /// + /// If the callback returns `Ok`, the acquired connection is returned to the caller. If + /// it returns `Err`, the error is logged and the caller attempts to acquire another connection. + /// + /// This overrides [`Self::test_before_acquire()`]. + pub fn before_acquire(mut self, callback: F) -> Self + where + for<'c> F: Fn(&'c mut C) -> futures_util::BoxFuture<'c, crate::Result> + + Send + + Sync + + 'static, + { + self.before_acquire = Some(Box::new(callback)); + self + } + + /// Creates a new pool from this configuration and immediately establishes + /// [`self.min_connections`][Self::min_connections()], + /// or just one connection if `min_connections == 0`. 
+ /// + /// Returns an error if the URI fails to parse or an error occurs while establishing a connection. + /// + /// See [`Self::connect_with()`] for a version that lets you pass a [`ConnectOptions`]. + /// + /// If you do not want to connect immediately on startup, + /// use [`.build()`][Self::build()] instead. + pub async fn connect(self, uri: &str) -> crate::Result> { + self.connect_with(uri.parse()?).await + } + + /// Creates a new pool from this configuration and immediately establishes + /// [`self.min_connections`][Self::min_connections()], + /// or just one connection if `min_connections == 0`. + /// + /// Returns an error if an error occurs while establishing a connection. + /// + /// If you do not want to connect immediately on startup, + /// use [`.build_with()`][Self::build_with()] instead. + pub async fn connect_with( + self, + options: >::Options, + ) -> crate::Result> { + let mut shared = SharedPool::new(self, options); + + shared.init_min_connections().await?; + + Ok(Pool { shared: shared.into() }) + } +} + +impl> Debug for PoolOptions { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("PoolOptions") + .field("max_connections", &self.max_connections) + .field("min_connections", &self.min_connections) + .field("connect_timeout", &self.connect_timeout) + .field("max_lifetime", &self.max_lifetime) + .field("idle_timeout", &self.idle_timeout) + .field("test_before_acquire", &self.test_before_acquire) + .finish() + } +} diff --git a/sqlx/src/pool/shared.rs b/sqlx/src/pool/shared.rs new file mode 100644 index 0000000000..16d538c04f --- /dev/null +++ b/sqlx/src/pool/shared.rs @@ -0,0 +1,245 @@ +use std::sync::atomic::{AtomicBool, AtomicU32, Ordering}; +use std::sync::Arc; +use std::time::Instant; +use std::{cmp, mem, ptr}; + +use crossbeam_queue::ArrayQueue; + +use crate::pool::connection::{Floating, Idle, Pooled}; +use crate::pool::options::PoolOptions; +use crate::pool::wait_list::WaitList; +use crate::{Acquire, Connect, 
Connection, Runtime}; + +pub struct SharedPool> { + idle: ArrayQueue>, + wait_list: WaitList, + size: AtomicU32, + is_closed: AtomicBool, + pool_options: PoolOptions, + connect_options: >::Options, +} + +/// RAII guard returned by `Pool::try_increment_size()` and others. +/// +/// Will decrement the pool size if dropped, to avoid semantically "leaking" connections +/// (where the pool thinks it has more connections than it does). +pub struct DecrementSizeGuard<'pool> { + size: &'pool AtomicU32, + wait_list: &'pool WaitList, + dropped: bool, +} + +// NOTE: neither of these may be `Copy` or `Clone`! +pub struct ConnectPermit<'pool>(DecrementSizeGuard<'pool>); +pub struct AcquirePermit<'pool>(&'pool AtomicU32); // just need a pointer to compare for sanity check + +/// Returned by `SharedPool::try_acquire()`. +/// +/// Compared to SQLx <= 0.5, the process of acquiring a connection is broken into distinct steps +/// in order to facilitate both blocking and nonblocking versions. +pub enum TryAcquireResult<'pool, Rt: Runtime, C: Connection> { + /// A connection has been acquired from the idle queue. + /// + /// Depending on the pool settings, it may still need to be tested for liveness before being + /// returned to the user. + Acquired(Floating<'pool, Idle>), + /// The pool's current size dropped below its maximum and a new connection may be opened. + /// + /// Call `.connect_async()` or `.connect_blocking()` with the given permit. + Connect(ConnectPermit<'pool>), + /// The task or thread should wait and call `.try_acquire()` again. + /// + /// The inner value is the same `AcquirePermit` that was passed to `.try_acquire()`. + Wait, + /// The pool is closed; the attempt to acquire the connection should return an error. 
+ PoolClosed, +} + +impl> SharedPool { + pub fn new( + pool_options: PoolOptions, + connect_options: >::Options, + ) -> Self { + Self { + idle: ArrayQueue::new(pool_options.max_connections as usize), + wait_list: WaitList::new(), + size: AtomicU32::new(0), + is_closed: AtomicBool::new(false), + pool_options, + connect_options, + } + } + + #[inline] + pub fn is_closed(&self) -> bool { + self.is_closed.load(Ordering::Acquire) + } + + /// Attempt to acquire a connection. + /// + /// If `permit` is `Some`, + pub fn try_acquire(&self, permit: Option>) -> TryAcquireResult<'_, C> { + use TryAcquireResult::*; + + assert!( + permit.map_or(true, |permit| ptr::eq(&self.size, permit.0)), + "BUG: given AcquirePermit is from a different pool" + ); + + if self.is_closed() { + return PoolClosed; + } + + // if the user has an `AcquirePermit`, then they've already waited at least once + // and we should try to get them a connection immediately if possible; + // + // otherwise, we can immediately return a connection or `ConnectPermit` if no one is waiting + if permit.is_some() || self.wait_list.is_empty() { + // try to pull a connection from the idle queue + if let Some(idle) = self.idle.pop() { + return Acquired(idle.float(self)); + } + + // try to bump `self.size` + if let Some(guard) = self.try_increment_size() { + return Connect(ConnectPermit(guard)); + } + } + + // check again after the others to make sure + if self.is_closed() { + return PoolClosed; + } + + Wait + } + + /// Attempt to increment the current size, failing if it would exceed the maximum size. 
+ fn try_increment_size(&self) -> Option> { + if self.is_closed() { + return None; + } + + self.size + .fetch_update(Ordering::AcqRel, Ordering::Acquire, |size| { + (size < todo!("self.options.max_connections")).then(|| size + 1) + }) + .ok() + .map(|_| DecrementSizeGuard::new(self)) + } +} + +#[cfg(feature = "async")] +impl> SharedPool { + pub async fn wait_async(&self, deadline: Option) -> Option { + if let Some(deadline) = deadline { + // returns `None` if deadline elapses + Rt::timeout_at(self.wait_list.wait(), deadline).await?; + } else { + self.wait_list.wait().await; + } + + Some(AcquirePermit(&self.size)) + } + + pub async fn connect_async( + self: &Arc, + permit: ConnectPermit, + ) -> crate::Result> + where + C: crate::Connect, + { + assert!(permit.0.same_pool(self), "BUG: ConnectPermit is from a different pool!"); + + let mut conn = crate::Connect::connect_with(&self.connect_options) + .await + .map(|c| Floating::new_live(c, permit.0))?; + + if let Some(ref after_connect) = self.pool_options.after_connect_async { + after_connect(&mut conn).await?; + } + + Ok(conn.attach(self)) + } + + pub async fn on_acquire_async( + self: &Arc, + conn: &mut Floating<'_, C>, + ) -> crate::Result<()> { + assert!(conn.same_pool(self), "BUG: connection is from a different pool"); + + if let Some(ref before_acquire) = self.pool_options.before_acquire_async { + before_acquire(conn).await?; + } + + Ok(()) + } + + pub async fn init_min_connections>( + &mut self, + ) -> crate::Result<()> { + for _ in 0..cmp::max(pool.options.min_connections, 1) { + let deadline = Instant::now() + pool.options.connect_timeout; + + // this guard will prevent us from exceeding `max_size` + if let Some(guard) = pool.try_increment_size() { + // [connect] will raise an error when past deadline + let conn = pool.connection(deadline, guard).await?; + let is_ok = pool.idle_conns.push(conn.into_idle().into_leakable()).is_ok(); + + if !is_ok { + panic!("BUG: connection queue overflow in 
init_min_connections"); + } + } + } + + Ok(()) + } +} + +#[cfg(feature = "blocking")] +impl> SharedPool { + pub fn wait_blocking(&self, deadline: Option) -> Option> { + self.wait_list.wait().block_on(deadline).then(|| AcquirePermit(&self.size)) + } + + pub fn connect_blocking( + self: &Arc, + permit: ConnectPermit<'_>, + ) -> crate::Result> + where + C: crate::blocking::Connect, + { + assert!(permit.0.same_pool(self), "BUG: ConnectPermit is from a different pool!"); + + crate::blocking::Connect::connect_with(&self.connect_options) + .map(|c| Floating::new_live(c, permit.0).attach(self)) + } +} + +impl<'pool> DecrementSizeGuard<'pool> { + fn new>(pool: &'pool SharedPool) -> Self { + Self { size: &pool.size, wait_list: &pool.wait_list, dropped: false } + } + + /// Return `true` if the internal references point to the same fields in `SharedPool`. + pub fn same_pool>( + &self, + pool: &'pool SharedPool, + ) -> bool { + ptr::eq(self.size, &pool.size) && ptr::eq(self.wait_list, &pool.wait_list) + } + + pub fn cancel(self) { + mem::forget(self); + } +} + +impl Drop for DecrementSizeGuard<'_> { + fn drop(&mut self) { + assert!(!self.dropped, "double-dropped!"); + self.dropped = true; + self.size.fetch_sub(1, Ordering::SeqCst); + self.wait_list.wake_one(); + } +} diff --git a/sqlx/src/pool/wait_list.rs b/sqlx/src/pool/wait_list.rs new file mode 100644 index 0000000000..52559b2ac5 --- /dev/null +++ b/sqlx/src/pool/wait_list.rs @@ -0,0 +1,453 @@ +// see `SAFETY:` annotations +#![allow(unsafe_code)] + +use parking_lot::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard}; +use std::future::Future; +use std::marker::PhantomPinned; +use std::pin::Pin; +use std::ptr; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::task::{Context, Poll, Waker}; +use std::thread::{self, Thread}; +use std::time::Instant; + +/// An intrusive list of waiting tasks. 
+/// +/// Tasks wait by calling `.wait().await` for async code or `.wait().block_on(deadline)` +/// for blocking code where `deadline` is `Option` +pub struct WaitList(RwLock); + +struct ListInner { + // NOTE: these must either both be null or both be pointing to a node + /// The head of the list; if NULL then the list is empty. + head: *mut Node, + /// The tail of the list; if NULL then the list is empty. + tail: *mut Node, +} + +// SAFETY: access to `Node` pointers must be protected by a lock +// this could potentially be made lock-free but the critical sections are short +// so using a lightweight RwLock like from `parking_lot` seemed reasonable +unsafe impl Send for ListInner {} +unsafe impl Sync for ListInner {} + +impl WaitList { + pub fn new() -> Self { + WaitList(RwLock::new(ListInner { head: ptr::null_mut(), tail: ptr::null_mut() })) + } + + pub fn is_empty(&self) -> bool { + let inner = self.0.read(); + inner.head.is_null() && inner.tail.is_null() + } + + pub fn wake_one(&self) { + self.0.read().wake(false) + } + + pub fn wake_all(&self) { + self.0.read().wake(true) + } + + /// Wait in this waitlist for a call to either `.wake_one()` or `.wake_all()`. + /// + /// The returned handle may either be `.await`ed for async code, or you can call + /// `.block_on(deadline)` for blocking code, where `deadline` is the optional `Instant` + /// at which to stop waiting. + pub fn wait(&self) -> Wait<'_> { + Wait { list: &self.0, node: None, actually_woken: bool, _not_unpin: PhantomPinned } + } +} + +impl ListInner { + /// Wake either one or all nodes in the list. 
+ fn wake(&self, all: bool) { + let mut node_p: *const Node = inner.head; + + // SAFETY: `node_p` is not dangling as long as we have a shared lock + // (implied by having `&self`) + while let Some(node) = unsafe { node_p.as_ref() } { + // `.wake()` only returns `true` if the node was not already woken + if node.wake() && !all { + break; + } + + node_p = node.next; + } + } +} + +pub struct Wait<'a> { + list: &'a RwLock, + /// SAFETY: `Node` must not be modified without a lock + /// SAFETY: `Node` may not be moved once it's entered in the list + node: Option, + actually_woken: bool, + _not_unpin: PhantomPinned, +} + +/// cancel-safe +impl<'a> Future for Wait<'a> { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let node = self.get_node(|| Wake::Waker(cx.waker().clone())); + + if node.woken.load(Ordering::Acquire) { + // SAFETY: not moving out of `self` here + unsafe { self.get_unchecked_mut().actually_woken = true } + Poll::Ready(()) + } else { + let wake = RwLock::upgradable_read(&node.wake); + + // make sure our `Waker` is up to date; + // the waker may change if the task moves between threads + if !wake.waker_eq(cx.waker()) { + *RwLockUpgradableReadGuard::upgrade(wake) = Wake::Waker(cx.waker().clone()); + } + + Poll::Pending + } + } +} + +impl<'a> Wait<'a> { + /// Insert a node into the parent `WaitList` referred to by `self` and return it. + /// + /// The provided closure should return the appropriate `Wake` variant for waking the calling + /// task. 
+ fn get_node(self: Pin<&mut Self>, get_wake: impl FnOnce() -> Wake) -> &Node {
+ // SAFETY: `this.node` must not be moved once it's entered in the list
+ let this = unsafe { self.get_unchecked_mut() };
+
+ if let Some(ref node) = this.node {
+ node
+ } else {
+ // FIXME: use `Option::insert()` when stable
+ let node = this.node.get_or_insert_with(|| Node::new(get_wake()));
+
+ // SAFETY: we need an exclusive lock to modify the list
+ let mut list = this.list.write();
+
+ if list.head.is_null() {
+ // sanity check; see `ListInner` definition
+ assert!(list.tail.is_null());
+
+ // the list is empty so insert this node as both the head and tail
+ list.head = node;
+ list.tail = node;
+ } else {
+ // sanity check; see `ListInner` definition
+ assert!(!list.tail.is_null());
+
+ // the list is nonempty so insert this node as the tail
+
+ // SAFETY: `list.tail` is not null because of the above assert and
+ // not dangling as long as we have an exclusive lock for modifying the list
+ // (or any nodes in it)
+ unsafe {
+ // set the `next` pointer of the previous tail to this node
+ (*list.tail).next = node;
+ }
+ node.prev = list.tail;
+ list.tail = node;
+ }
+
+ node
+ }
+ }
+
+ /// Block until woken.
+ ///
+ /// Returns `true` if we were woken without the deadline elapsing, `false` if the deadline elapsed.
+ /// If no deadline is set then this always returns `true`.
+ #[cfg(feature = "blocking")]
+ pub fn block_on(mut self, deadline: Option<Instant>) -> bool {
+ // SAFETY:
+ // * `self` may not escape this scope
+ // * `this` may not be moved out of
+ // * we must remove `self.node` from the list before returning (covered by `impl Drop for Self`)
+ let mut this = unsafe { Pin::new_unchecked(&mut self) };
+ let node = this.as_mut().get_node(|| Wake::Thread(thread::current()));
+
+ while !node.woken.load(Ordering::Acquire) {
+ if let Some(deadline) = deadline {
+ let now = Instant::now();
+
+ if deadline < now {
+ return false;
+ } else {
+ // N.B.
may wake spuriously + thread::park_timeout(deadline - now); + } + } else { + // N.B. may return spuriously + thread::park(); + } + } + + // SAFETY: we're not moving out of `this` here + unsafe { + this.get_unchecked_mut().actually_woken = true; + } + + true + } +} + +// SAFETY: since futures must be pinned to be polled we can be sure that `Drop::drop()` is called +// because there's no way to leak a future without the memory location remaining valid for the +// life of the program: +// * can't be moved into `mem::forget()` or an Rc-cycle because it's pinned +// * leaking `Pin>` or via Rc-cycle keeps it around forever, perfectly fine +// * aborting the program means it's not our problem anymore +// +// The only way this could cause memory issues is if the *thread* is aborted without unwinding +// or aborting the process, which doesn't have a safe API in Rust and the C APIs for canceling +// threads don't recommend doing it either for similar reasons. +// * https://man7.org/linux/man-pages/man3/pthread_exit.3.html#DESCRIPTION +// * https://docs.microsoft.com/en-us/windows/win32/api/processthreadsapi/nf-processthreadsapi-exitthread#remarks +// +// However, if Rust were to gain a safe API for instantly exiting a thread it would completely break +// the assumptions that the `Pin` API are built on so it's not something for us to worry about +// specifically. 
+impl<'a> Drop for Wait<'a> { + fn drop(&mut self) { + // SAFETY: we must have an exclusive lock while we're futzing with the list + let mut list = self.list.write(); + + // remove the node from the list, + // linking the previous node (if applicable) to the next node (if applicable) + if let Some(node) = &mut self.node { + // SAFETY: `prev` cannot be dangling while we have an exclusive lock + if let Some(prev) = unsafe { node.prev.as_mut() } { + // set the `next` pointer of the previous node to this node's `next` pointer + // note: `node.next` may be null which means we're the tail of the list + prev.next = node.next; + } else { + // we were the head of the list so we set the head to the next node + list.head = node.next; + } + + // SAFETY: `next` cannot be dangling while we have an exclusive lock + if let Some(next) = unsafe { node.next.as_mut() } { + // set the `prev` pointer of the next node to this node's `prev` pointer + // note: `node.prev` may be null which means we're the head of the list + next.prev = node.prev; + } else { + // we were the tail of the list so we set the tail to the previous node + list.tail = node.prev; + } + + // sanity check; see `ListInner` definition + assert_eq!(list.head.is_null(), list.tail.is_null()); + + // if this node was marked woken but we didn't actually wake, + // then we need to wake the next node in the list + if node.woken.load(Ordering::Acquire) && !self.actually_woken { + RwLockWriteGuard::downgrade(list).wake(false); + } + } + } +} + +struct Node { + /// The previous node in the list. If NULL, then this node is the head of the list. + prev: *mut Node, + /// The next node in the list. If NULL, then this node is the tail of the list. 
+ next: *mut Node,
+ woken: AtomicBool,
+ wake: RwLock<Wake>,
+}
+
+// SAFETY: access to `Node` pointers must be protected by a lock
+unsafe impl Send for Node {}
+unsafe impl Sync for Node {}
+
+impl Node {
+ fn new(wake: Wake) -> Self {
+ Node {
+ prev: ptr::null_mut(),
+ next: ptr::null_mut(),
+ woken: AtomicBool::new(false),
+ wake: RwLock::new(wake),
+ }
+ }
+
+ /// Returns `true` if this node was woken by this call, `false` otherwise.
+ fn wake(&self) -> bool {
+ let do_wake =
+ self.woken.compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire).is_ok();
+
+ if do_wake {
+ match &*self.wake.read() {
+ Wake::Waker(waker) => waker.wake_by_ref(),
+ #[cfg(feature = "blocking")]
+ Wake::Thread(thread) => thread.unpark(),
+ }
+ }
+
+ do_wake
+ }
+}
+
+enum Wake {
+ Waker(Waker),
+ #[cfg(feature = "blocking")]
+ Thread(Thread),
+}
+
+impl Wake {
+ fn waker_eq(&self, waker: &Waker) -> bool {
+ match self {
+ Self::Waker(waker_) => waker_.will_wake(waker),
+ #[cfg(feature = "blocking")]
+ _ => false,
+ }
+ }
+}
+
+// note: this test should take about 2 minutes to run!
+#[test] +#[cfg(feature = "blocking")] +fn test_wait_list_blocking() { + use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::Arc; + use std::thread; + use std::time::{Duration, Instant}; + + const NUM_THREADS: u64 = 200; + + let list = Arc::new(WaitList::new()); + let mut threads = Vec::new(); + + // create an arbitrary pattern of deadlines; some of these may elapse, others may not + // the ultimate goal of this test is to make sure that no threads _deadlock_ or segfault + for i in 1..NUM_THREADS { + let ms = i + i * 25 % 100; + + let deadline = (i < 100).then(|| Instant::now() + Duration::from_millis(ms)); + + let list = Arc::new(list.clone()); + let thread = Arc::new(AtomicBool::new(false)); + + threads.push((thread.clone(), deadline)); + + thread::spawn(move || { + list.wait().block_on(deadline); + thread.store(true, Ordering::Release); + }); + } + + // + for _ in 1..NUM_THREADS { + thread::sleep(Duration::from_millis(5)); + list.wake_one(); + } + + // wait enough time for all timeouts to elapse + thread::sleep(Duration::from_secs(60)); + + for (i, (thread, deadline)) in threads.iter().enumerate() { + assert!( + thread.load(Ordering::Acquire), + "thread {} did not exit; deadline: {:?}", + i, + deadline + ); + } +} + +// #[cfg(all(test, feature = "async"))] +// mod test_async { +// use super::WaitList; +// +// #[cfg(feature = "tokio")] +// +// async fn test_waiter_list() { +// use futures::future::{join_all, FutureExt}; +// use futures::pin_mut; +// use std::sync::Arc; +// use std::time::Duration; +// +// let list = Arc::new(WaitList::new()); +// let mut tasks = Vec::new(); +// +// for _ in 0..1000 { +// let list = list.clone(); +// +// tasks.push(spawn(async move { +// list.wait().await; +// +// list.wait().await; +// })); +// } +// +// let waker = async { +// loop { +// list.wake_one(); +// yield_now().await; +// } +// } +// .fuse(); +// +// let timeout = timeout(Duration::from_secs(10), join_all(tasks)).fuse(); +// +// pin_mut!(waker); +// 
pin_mut!(timeout); +// +// futures::select_biased!( +// res = timeout => res.expect("all tasks should have exited by now"), +// _ = waker => unreachable!("waker shouldn't have quit"), +// ); +// } +// } +// +// // N.B. test will run forever +// #[test] +// #[ignore] +// fn test_waiter_list_forever() { +// use async_std::{ +// future::{timeout, Future}, +// task, +// }; +// use futures::future::poll_fn; +// use futures::pin_mut; +// use futures::stream::{FuturesUnordered, StreamExt}; +// use std::sync::Arc; +// use std::thread; +// use std::time::Duration; +// +// let list = Arc::new(WaitList::new()); +// +// let list_ = list.clone(); +// task::spawn(async move { +// let mut unordered = FuturesUnordered::new(); +// +// loop { +// unordered.push(WaitList::wait(&list_)); +// let _ = timeout(Duration::from_millis(50), unordered.next()).await; +// } +// }); +// +// let list_ = list.clone(); +// task::spawn(poll_fn::<(), _>(move |cx| { +// let yielder = task::yield_now(); +// pin_mut!(yielder); +// let _ = yielder.poll(cx); +// +// let park = WaitList::wait(&list_); +// pin_mut!(park); +// let _ = park.poll(cx); +// +// Poll::Pending +// })); +// +// for num in (0..5).cycle() { +// for _ in 0..num { +// list.wake_one(); +// } +// +// thread::sleep(Duration::from_millis(50)); +// } +// }