Skip to content

Commit

Permalink
Do not require db url for prepare (#2882)
Browse files Browse the repository at this point in the history
* feat(cli): do not require db url

* chore: remove unused import

* fix(cli): do not always pass DATABASE_URL

* fix(cli): check db when DATABASE_URL is provided
  • Loading branch information
tamasfe authored Nov 21, 2023
1 parent 979a55d commit d3a28d4
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 25 deletions.
14 changes: 6 additions & 8 deletions sqlx-cli/src/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ pub async fn create(connect_opts: &ConnectOpts) -> anyhow::Result<()> {
std::sync::atomic::Ordering::Release,
);

Any::create_database(&connect_opts.database_url).await?;
Any::create_database(connect_opts.required_db_url()?).await?;
}

Ok(())
}

pub async fn drop(connect_opts: &ConnectOpts, confirm: bool) -> anyhow::Result<()> {
if confirm && !ask_to_continue(connect_opts) {
if confirm && !ask_to_continue_drop(connect_opts.required_db_url()?) {
return Ok(());
}

Expand All @@ -33,7 +33,7 @@ pub async fn drop(connect_opts: &ConnectOpts, confirm: bool) -> anyhow::Result<(
let exists = crate::retry_connect_errors(connect_opts, Any::database_exists).await?;

if exists {
Any::drop_database(&connect_opts.database_url).await?;
Any::drop_database(connect_opts.required_db_url()?).await?;
}

Ok(())
Expand All @@ -53,12 +53,10 @@ pub async fn setup(migration_source: &str, connect_opts: &ConnectOpts) -> anyhow
migrate::run(migration_source, connect_opts, false, false, None).await
}

fn ask_to_continue(connect_opts: &ConnectOpts) -> bool {
fn ask_to_continue_drop(db_url: &str) -> bool {
loop {
let r: Result<String, ReadlineError> = prompt(format!(
"Drop database at {}? (y/n)",
style(&connect_opts.database_url).cyan()
));
let r: Result<String, ReadlineError> =
prompt(format!("Drop database at {}? (y/n)", style(db_url).cyan()));
match r {
Ok(response) => {
if response == "n" || response == "N" {
Expand Down
12 changes: 7 additions & 5 deletions sqlx-cli/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ pub async fn run(opt: Opt) -> Result<()> {
}

/// Attempt to connect to the database server, retrying up to `ops.connect_timeout`.
async fn connect(opts: &ConnectOpts) -> sqlx::Result<AnyConnection> {
async fn connect(opts: &ConnectOpts) -> anyhow::Result<AnyConnection> {
retry_connect_errors(opts, AnyConnection::connect).await
}

Expand All @@ -112,32 +112,34 @@ async fn connect(opts: &ConnectOpts) -> sqlx::Result<AnyConnection> {
async fn retry_connect_errors<'a, F, Fut, T>(
opts: &'a ConnectOpts,
mut connect: F,
) -> sqlx::Result<T>
) -> anyhow::Result<T>
where
F: FnMut(&'a str) -> Fut,
Fut: Future<Output = sqlx::Result<T>> + 'a,
{
sqlx::any::install_default_drivers();

let db_url = opts.required_db_url()?;

backoff::future::retry(
backoff::ExponentialBackoffBuilder::new()
.with_max_elapsed_time(Some(Duration::from_secs(opts.connect_timeout)))
.build(),
|| {
connect(&opts.database_url).map_err(|e| -> backoff::Error<sqlx::Error> {
connect(db_url).map_err(|e| -> backoff::Error<anyhow::Error> {
match e {
sqlx::Error::Io(ref ioe) => match ioe.kind() {
io::ErrorKind::ConnectionRefused
| io::ErrorKind::ConnectionReset
| io::ErrorKind::ConnectionAborted => {
return backoff::Error::transient(e);
return backoff::Error::transient(e.into());
}
_ => (),
},
_ => (),
}

backoff::Error::permanent(e)
backoff::Error::permanent(e.into())
})
},
)
Expand Down
16 changes: 14 additions & 2 deletions sqlx-cli/src/opt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,9 +229,9 @@ impl Deref for Source {
/// Argument for the database URL.
#[derive(Args, Debug)]
pub struct ConnectOpts {
/// Location of the DB, by default will be read from the DATABASE_URL env var
/// Location of the DB, by default will be read from the DATABASE_URL env var or `.env` files.
#[clap(long, short = 'D', env)]
pub database_url: String,
pub database_url: Option<String>,

/// The maximum time, in seconds, to try connecting to the database server before
/// returning an error.
Expand All @@ -251,6 +251,18 @@ pub struct ConnectOpts {
pub sqlite_create_db_wal: bool,
}

impl ConnectOpts {
/// Require a database URL to be provided, otherwise
/// return an error.
pub fn required_db_url(&self) -> anyhow::Result<&str> {
self.database_url.as_deref().ok_or_else(
|| anyhow::anyhow!(
"the `--database-url` option the or `DATABASE_URL` environment variable must be provided"
)
)
}
}

/// Argument for automatic confirmation.
#[derive(Args, Copy, Clone, Debug)]
pub struct Confirmation {
Expand Down
26 changes: 16 additions & 10 deletions sqlx-cli/src/prepare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use std::process::Command;

use anyhow::{bail, Context};
use console::style;

use sqlx::Connection;

use crate::metadata::{manifest_dir, Metadata};
Expand Down Expand Up @@ -64,7 +63,9 @@ hint: This command only works in the manifest directory of a Cargo package or wo
}

async fn prepare(ctx: &PrepareCtx) -> anyhow::Result<()> {
check_backend(&ctx.connect_opts).await?;
if ctx.connect_opts.database_url.is_some() {
check_backend(&ctx.connect_opts).await?;
}

let prepare_dir = ctx.prepare_dir()?;
run_prepare_step(ctx, &prepare_dir)?;
Expand All @@ -90,7 +91,9 @@ async fn prepare(ctx: &PrepareCtx) -> anyhow::Result<()> {
}

async fn prepare_check(ctx: &PrepareCtx) -> anyhow::Result<()> {
let _ = check_backend(&ctx.connect_opts).await?;
if ctx.connect_opts.database_url.is_some() {
check_backend(&ctx.connect_opts).await?;
}

// Re-generate and store the queries in a separate directory from both the prepared
// queries and the ones generated by `cargo check`, to avoid conflicts.
Expand Down Expand Up @@ -171,10 +174,14 @@ fn run_prepare_step(ctx: &PrepareCtx, cache_dir: &Path) -> anyhow::Result<()> {
check_command
.arg("check")
.args(&ctx.cargo_args)
.env("DATABASE_URL", &ctx.connect_opts.database_url)
.env("SQLX_TMP", tmp_dir)
.env("SQLX_OFFLINE", "false")
.env("SQLX_OFFLINE_DIR", cache_dir);

if let Some(database_url) = &ctx.connect_opts.database_url {
check_command.env("DATABASE_URL", database_url);
}

// `cargo check` recompiles on changed rust flags which can be set either via the env var
// or through the `rustflags` field in `$CARGO_HOME/config` when the env var isn't set.
// Because of this we only pass in `$RUSTFLAGS` when present.
Expand Down Expand Up @@ -319,12 +326,6 @@ fn minimal_project_recompile_action(metadata: &Metadata) -> ProjectRecompileActi
}
}

/// Ensure the database server is available.
async fn check_backend(opts: &ConnectOpts) -> anyhow::Result<()> {
crate::connect(opts).await?.close().await?;
Ok(())
}

/// Find all `query-*.json` files in a directory.
fn glob_query_files(path: impl AsRef<Path>) -> anyhow::Result<Vec<PathBuf>> {
let path = path.as_ref();
Expand All @@ -347,6 +348,11 @@ fn load_json_file(path: impl AsRef<Path>) -> anyhow::Result<serde_json::Value> {
Ok(serde_json::from_slice(&file_bytes)?)
}

async fn check_backend(opts: &ConnectOpts) -> anyhow::Result<()> {
crate::connect(opts).await?.close().await?;
Ok(())
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down

0 comments on commit d3a28d4

Please sign in to comment.