Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add client SSL authentication using key-file for Postgres, MySQL and MariaDB #1850

Merged
merged 22 commits into from
Feb 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 59 additions & 2 deletions .github/workflows/sqlx.yml
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,26 @@ jobs:
# but `PgLTree` should just fall back to text format
RUSTFLAGS: --cfg postgres_${{ matrix.postgres }}

# client SSL authentication

- run: |
docker stop postgres_${{ matrix.postgres }}
docker-compose -f tests/docker-compose.yml run -d -p 5432:5432 --name postgres_${{ matrix.postgres }}_client_ssl postgres_${{ matrix.postgres }}_client_ssl
docker exec postgres_${{ matrix.postgres }}_client_ssl bash -c "until pg_isready; do sleep 1; done"

- uses: actions-rs/cargo@v1
if: matrix.tls != 'none'
with:
command: test
args: >
--no-default-features
--features any,postgres,macros,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }}
env:
DATABASE_URL: postgres://postgres@localhost:5432/sqlx?sslmode=verify-ca&sslrootcert=.%2Ftests%2Fcerts%2Fca.crt&sslkey=.%2Ftests%2Fkeys%2Fclient.key&sslcert=.%2Ftests%2Fcerts%2Fclient.crt
# FIXME: needed to disable `ltree` tests in Postgres 9.6
# but `PgLTree` should just fall back to text format
RUSTFLAGS: --cfg postgres_${{ matrix.postgres }}_client_ssl
ThibsG marked this conversation as resolved.
Show resolved Hide resolved

mysql:
name: MySQL
runs-on: ubuntu-20.04
Expand Down Expand Up @@ -260,7 +280,7 @@ jobs:
args: >
--features mysql,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }}

- run: docker-compose -f tests/docker-compose.yml run -d -p 3306:3306 mysql_${{ matrix.mysql }}
- run: docker-compose -f tests/docker-compose.yml run -d -p 3306:3306 --name mysql_${{ matrix.mysql }} mysql_${{ matrix.mysql }}
- run: sleep 60

- uses: actions-rs/cargo@v1
Expand All @@ -285,6 +305,25 @@ jobs:
DATABASE_URL: mysql://root:password@localhost:3306/sqlx
RUSTFLAGS: --cfg mysql_${{ matrix.mysql }}

# client SSL authentication

- run: |
docker stop mysql_${{ matrix.mysql }}
docker-compose -f tests/docker-compose.yml run -d -p 3306:3306 --name mysql_${{ matrix.mysql }}_client_ssl mysql_${{ matrix.mysql }}_client_ssl
sleep 60

# MySQL 5.7 supports TLS but not TLSv1.3 as required by RusTLS.
- uses: actions-rs/cargo@v1
if: ${{ !(matrix.mysql == '5_7' && matrix.tls == 'rustls') && matrix.tls != 'none' }}
with:
command: test
args: >
--no-default-features
--features any,mysql,macros,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }}
env:
DATABASE_URL: mysql://root@localhost:3306/sqlx?sslmode=verify_ca&ssl-ca=.%2Ftests%2Fcerts%2Fca.crt&ssl-key=.%2Ftests%2Fkeys%2Fclient.key&ssl-cert=.%2Ftests%2Fcerts%2Fclient.crt
RUSTFLAGS: --cfg mysql_${{ matrix.mysql }}

mariadb:
name: MariaDB
runs-on: ubuntu-20.04
Expand Down Expand Up @@ -313,7 +352,7 @@ jobs:
args: >
--features mysql,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }}

- run: docker-compose -f tests/docker-compose.yml run -d -p 3306:3306 mariadb_${{ matrix.mariadb }}
- run: docker-compose -f tests/docker-compose.yml run -d -p 3306:3306 --name mariadb_${{ matrix.mariadb }} mariadb_${{ matrix.mariadb }}
- run: sleep 30

- uses: actions-rs/cargo@v1
Expand All @@ -325,3 +364,21 @@ jobs:
env:
DATABASE_URL: mysql://root:password@localhost:3306/sqlx
RUSTFLAGS: --cfg mariadb_${{ matrix.mariadb }}

# client SSL authentication

- run: |
docker stop mariadb_${{ matrix.mariadb }}
docker-compose -f tests/docker-compose.yml run -d -p 3306:3306 --name mariadb_${{ matrix.mariadb }}_client_ssl mariadb_${{ matrix.mariadb }}_client_ssl
sleep 60

- uses: actions-rs/cargo@v1
if: matrix.tls != 'none'
with:
command: test
args: >
--no-default-features
--features any,mysql,macros,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }}
env:
DATABASE_URL: mysql://root@localhost:3306/sqlx?sslmode=verify_ca&ssl-ca=.%2Ftests%2Fcerts%2Fca.crt&ssl-key=.%2Ftests%2Fkeys%2Fclient.key&ssl-cert=.%2Ftests%2Fcerts%2Fclient.crt
RUSTFLAGS: --cfg mariadb_${{ matrix.mariadb }}
2 changes: 2 additions & 0 deletions sqlx-core/src/net/tls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ pub struct TlsConfig<'a> {
pub accept_invalid_hostnames: bool,
pub hostname: &'a str,
pub root_cert_path: Option<&'a CertificateInput>,
pub client_cert_path: Option<&'a CertificateInput>,
pub client_key_path: Option<&'a CertificateInput>,
}

pub async fn handshake<S, Ws>(
Expand Down
10 changes: 9 additions & 1 deletion sqlx-core/src/net/tls/tls_native_tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::net::tls::TlsConfig;
use crate::net::Socket;
use crate::Error;

use native_tls::HandshakeError;
use native_tls::{HandshakeError, Identity};
use std::task::{Context, Poll};

pub struct NativeTlsSocket<S: Socket> {
Expand Down Expand Up @@ -53,6 +53,14 @@ pub async fn handshake<S: Socket>(
builder.add_root_certificate(native_tls::Certificate::from_pem(&data).map_err(Error::tls)?);
}

// authentication using user's key-file and its associated certificate
if let (Some(cert_path), Some(key_path)) = (config.client_cert_path, config.client_key_path) {
let cert_path = cert_path.data().await?;
let key_path = key_path.data().await?;
let identity = Identity::from_pkcs8(&cert_path, &key_path).map_err(Error::tls)?;
builder.identity(identity);
}

let connector = builder.build().map_err(Error::tls)?;

let mut mid_handshake = match connector.connect(config.hostname, StdSocket::new(socket)) {
Expand Down
82 changes: 72 additions & 10 deletions sqlx-core/src/net/tls/tls_rustls.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use futures_util::future;
use std::io;
use std::io::{Cursor, Read, Write};
use rustls::{Certificate, PrivateKey};
use std::io::{self, BufReader, Cursor, Read, Write};
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::SystemTime;
Expand All @@ -13,7 +13,7 @@ use rustls::{
use crate::error::Error;
use crate::io::ReadBuf;
use crate::net::tls::util::StdSocket;
use crate::net::tls::TlsConfig;
use crate::net::tls::{CertificateInput, TlsConfig};
use crate::net::Socket;

pub struct RustlsSocket<S: Socket> {
Expand Down Expand Up @@ -48,7 +48,7 @@ impl<S: Socket> Socket for RustlsSocket<S> {
match self.state.writer().write(buf) {
// Returns a zero-length write when the buffer is full.
Ok(0) => Err(io::ErrorKind::WouldBlock.into()),
other => return other,
other => other,
}
}

Expand Down Expand Up @@ -81,10 +81,32 @@ where
{
let config = ClientConfig::builder().with_safe_defaults();

// authentication using user's key and its associated certificate
let user_auth = match (tls_config.client_cert_path, tls_config.client_key_path) {
(Some(cert_path), Some(key_path)) => {
let cert_chain = certs_from_pem(cert_path.data().await?)?;
let key_der = private_key_from_pem(key_path.data().await?)?;
Some((cert_chain, key_der))
}
(None, None) => None,
(_, _) => {
return Err(Error::Configuration(
"user auth key and certs must be given together".into(),
))
}
};

let config = if tls_config.accept_invalid_certs {
config
.with_custom_certificate_verifier(Arc::new(DummyTlsVerifier))
.with_no_client_auth()
if let Some(user_auth) = user_auth {
config
.with_custom_certificate_verifier(Arc::new(DummyTlsVerifier))
.with_single_cert(user_auth.0, user_auth.1)
.map_err(Error::tls)?
} else {
config
.with_custom_certificate_verifier(Arc::new(DummyTlsVerifier))
.with_no_client_auth()
}
} else {
let mut cert_store = RootCertStore::empty();
cert_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
Expand All @@ -100,7 +122,7 @@ where
let mut cursor = Cursor::new(data);

for cert in rustls_pemfile::certs(&mut cursor)
.map_err(|_| Error::Tls(format!("Invalid certificate {}", ca).into()))?
.map_err(|_| Error::Tls(format!("Invalid certificate {ca}").into()))?
{
cert_store
.add(&rustls::Certificate(cert))
Expand All @@ -111,9 +133,21 @@ where
if tls_config.accept_invalid_hostnames {
let verifier = WebPkiVerifier::new(cert_store, None);

if let Some(user_auth) = user_auth {
config
.with_custom_certificate_verifier(Arc::new(NoHostnameTlsVerifier { verifier }))
.with_single_cert(user_auth.0, user_auth.1)
.map_err(Error::tls)?
} else {
config
.with_custom_certificate_verifier(Arc::new(NoHostnameTlsVerifier { verifier }))
.with_no_client_auth()
}
} else if let Some(user_auth) = user_auth {
config
.with_custom_certificate_verifier(Arc::new(NoHostnameTlsVerifier { verifier }))
.with_no_client_auth()
.with_root_certificates(cert_store)
.with_single_cert(user_auth.0, user_auth.1)
.map_err(Error::tls)?
} else {
config
.with_root_certificates(cert_store)
Expand All @@ -135,6 +169,34 @@ where
Ok(socket)
}

fn certs_from_pem(pem: Vec<u8>) -> Result<Vec<rustls::Certificate>, Error> {
let cur = Cursor::new(pem);
let mut reader = BufReader::new(cur);
rustls_pemfile::certs(&mut reader)?
.into_iter()
.map(|v| Ok(rustls::Certificate(v)))
.collect()
}

fn private_key_from_pem(pem: Vec<u8>) -> Result<rustls::PrivateKey, Error> {
let cur = Cursor::new(pem);
let mut reader = BufReader::new(cur);

loop {
match rustls_pemfile::read_one(&mut reader)? {
Some(
rustls_pemfile::Item::RSAKey(key)
| rustls_pemfile::Item::PKCS8Key(key)
| rustls_pemfile::Item::ECKey(key),
) => return Ok(rustls::PrivateKey(key)),
None => break,
_ => {}
}
}

Err(Error::Configuration("no keys found pem file".into()))
}

struct DummyTlsVerifier;

impl ServerCertVerifier for DummyTlsVerifier {
Expand Down
2 changes: 2 additions & 0 deletions sqlx-mysql/src/connection/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ pub(super) async fn maybe_upgrade<S: Socket>(
accept_invalid_hostnames: !matches!(options.ssl_mode, MySqlSslMode::VerifyIdentity),
hostname: &options.host,
root_cert_path: options.ssl_ca.as_ref(),
client_cert_path: options.ssl_client_cert.as_ref(),
client_key_path: options.ssl_client_key.as_ref(),
};

// Request TLS upgrade
Expand Down
34 changes: 34 additions & 0 deletions sqlx-mysql/src/options/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ pub struct MySqlConnectOptions {
pub(crate) database: Option<String>,
pub(crate) ssl_mode: MySqlSslMode,
pub(crate) ssl_ca: Option<CertificateInput>,
pub(crate) ssl_client_cert: Option<CertificateInput>,
pub(crate) ssl_client_key: Option<CertificateInput>,
pub(crate) statement_cache_capacity: usize,
pub(crate) charset: String,
pub(crate) collation: Option<String>,
Expand Down Expand Up @@ -88,6 +90,8 @@ impl MySqlConnectOptions {
collation: None,
ssl_mode: MySqlSslMode::Preferred,
ssl_ca: None,
ssl_client_cert: None,
ssl_client_key: None,
statement_cache_capacity: 100,
log_settings: Default::default(),
pipes_as_concat: true,
Expand Down Expand Up @@ -186,6 +190,36 @@ impl MySqlConnectOptions {
self
}

/// Sets the name of a file containing SSL client certificate.
///
/// # Example
///
/// ```rust
/// # use sqlx_core::mysql::{MySqlSslMode, MySqlConnectOptions};
/// let options = MySqlConnectOptions::new()
/// .ssl_mode(MySqlSslMode::VerifyCa)
/// .ssl_client_cert("path/to/client.crt");
/// ```
pub fn ssl_client_cert(mut self, cert: impl AsRef<Path>) -> Self {
self.ssl_client_cert = Some(CertificateInput::File(cert.as_ref().to_path_buf()));
self
}

/// Sets the name of a file containing SSL client key.
///
/// # Example
///
/// ```rust
/// # use sqlx_core::mysql::{MySqlSslMode, MySqlConnectOptions};
/// let options = MySqlConnectOptions::new()
/// .ssl_mode(MySqlSslMode::VerifyCa)
/// .ssl_client_key("path/to/client.key");
/// ```
pub fn ssl_client_key(mut self, key: impl AsRef<Path>) -> Self {
self.ssl_client_key = Some(CertificateInput::File(key.as_ref().to_path_buf()));
self
}

/// Sets the capacity of the connection's statement cache in a number of stored
/// distinct statements. Caching is handled using LRU, meaning when the
/// amount of queries hits the defined limit, the oldest statement will get
Expand Down
8 changes: 6 additions & 2 deletions sqlx-mysql/src/options/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,11 @@ impl MySqlConnectOptions {

for (key, value) in url.query_pairs().into_iter() {
match &*key {
"ssl-mode" => {
"sslmode" | "ssl-mode" => {
options = options.ssl_mode(value.parse().map_err(Error::config)?);
}

"ssl-ca" => {
"sslca" | "ssl-ca" => {
options = options.ssl_ca(&*value);
}

Expand All @@ -59,6 +59,10 @@ impl MySqlConnectOptions {
options = options.collation(&*value);
}

"sslcert" | "ssl-cert" => options = options.ssl_client_cert(&*value),

"sslkey" | "ssl-key" => options = options.ssl_client_key(&*value),

"statement-cache-capacity" => {
options =
options.statement_cache_capacity(value.parse().map_err(Error::config)?);
Expand Down
2 changes: 2 additions & 0 deletions sqlx-postgres/src/connection/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ async fn maybe_upgrade<S: Socket>(
accept_invalid_hostnames,
hostname: &options.host,
root_cert_path: options.ssl_root_cert.as_ref(),
client_cert_path: options.ssl_client_cert.as_ref(),
client_key_path: options.ssl_client_key.as_ref(),
};

tls::handshake(socket, config, SocketIntoBox).await
Expand Down
Loading