diff --git a/Cargo.toml b/Cargo.toml index 8b8060ee..b01f4f30 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,9 +37,9 @@ tracing = "0.1" # optional dependencies ## rustls arc-swap = { version = "1", optional = true } -rustls = { version = "0.21", features = ["dangerous_configuration"], optional = true } +rustls = { version = "0.22.1", optional = true } rustls-pemfile = { version = "2.0.0", optional = true } -tokio-rustls = { version = "0.24", optional = true } +tokio-rustls = { version = "0.25.0", optional = true } ## openssl openssl = { version = "0.10", optional = true } @@ -47,7 +47,7 @@ tokio-openssl = { version = "0.6", optional = true } [dev-dependencies] serial_test = "2.0" -axum = "0.7" +axum = "0.7.1" hyper = { version = "1.0.1", features = ["full"] } tokio = { version = "1", features = ["full"] } tower = { version = "0.4", features = ["util"] } diff --git a/src/tls_rustls/mod.rs b/src/tls_rustls/mod.rs index d4422f0a..4235a312 100644 --- a/src/tls_rustls/mod.rs +++ b/src/tls_rustls/mod.rs @@ -33,8 +33,10 @@ use crate::{ server::{io_other, Server}, }; use arc_swap::ArcSwap; -use rustls::{Certificate, PrivateKey, ServerConfig}; -use rustls_pemfile::Item; +use rustls::{ + pki_types::{CertificateDer, PrivateKeyDer}, + ServerConfig, +}; use std::time::Duration; use std::{fmt, io, net::SocketAddr, path::Path, sync::Arc}; use tokio::{ @@ -172,10 +174,8 @@ impl RustlsConfig { /// The certificate must be DER-encoded X.509. /// /// The private key must be DER-encoded ASN.1 in either PKCS#8 or PKCS#1 format. - pub async fn from_der(cert: Vec>, key: Vec) -> io::Result { - let server_config = spawn_blocking(|| config_from_der(cert, key)) - .await - .unwrap()?; + pub async fn from_der(cert: Vec>, key: PrivateKeyDer<'static>) -> io::Result { + let server_config = config_from_der(cert, key)?; let inner = Arc::new(ArcSwap::from_pointee(server_config)); Ok(Self { inner }) @@ -218,10 +218,12 @@ impl RustlsConfig { /// The certificate must be DER-encoded X.509. /// /// The private key must be DER-encoded ASN.1 in either PKCS#8 or PKCS#1 format. - pub async fn reload_from_der(&self, cert: Vec>, key: Vec) -> io::Result<()> { - let server_config = spawn_blocking(|| config_from_der(cert, key)) - .await - .unwrap()?; + pub async fn reload_from_der( + &self, + cert: Vec>, + key: PrivateKeyDer<'static>, + ) -> io::Result<()> { + let server_config = config_from_der(cert, key)?; let inner = Arc::new(server_config); self.inner.store(inner); @@ -278,12 +280,10 @@ impl fmt::Debug for RustlsConfig { } } -fn config_from_der(cert: Vec>, key: Vec) -> io::Result { - let cert = cert.into_iter().map(Certificate).collect(); - let key = PrivateKey(key); +fn config_from_der(cert: Vec>, key: PrivateKeyDer<'static>) -> io::Result { + let cert = cert.into_iter().map(CertificateDer::from).collect(); let mut config = ServerConfig::builder() - .with_safe_defaults() .with_no_client_auth() .with_single_cert(cert, key) .map_err(io_other)?; @@ -295,24 +295,13 @@ fn config_from_der(cert: Vec>, key: Vec) -> io::Result fn config_from_pem(cert: Vec, key: Vec) -> io::Result { let cert = rustls_pemfile::certs(&mut cert.as_ref()) - .map(|it| it.map(|it| it.to_vec())) + .map(|cert| cert.map(|cert| cert.as_ref().to_vec())) .collect::, _>>()?; - // Check the entire PEM file for the key in case it is not first section - let mut key_vec: Vec> = rustls_pemfile::read_all(&mut key.as_ref()) - .filter_map(|i| match i.ok()? { - Item::Sec1Key(key) => Some(key.secret_sec1_der().to_vec()), - Item::Pkcs1Key(key) => Some(key.secret_pkcs1_der().to_vec().into()), - Item::Pkcs8Key(key) => Some(key.secret_pkcs8_der().to_vec().into()), - _ => None, - }) - .collect(); - - // Make sure file contains only one key - if key_vec.len() != 1 { - return Err(io_other("private key format not supported")); - } + // Use the first private key found. + let key = rustls_pemfile::private_key(&mut key.as_ref())? + .ok_or(io_other("private key format not found"))?; - config_from_der(cert, key_vec.pop().unwrap()) + config_from_der(cert, key) } async fn config_from_pem_file( @@ -330,21 +319,12 @@ async fn config_from_pem_chain_file( chain: impl AsRef, ) -> io::Result { let cert = tokio::fs::read(cert.as_ref()).await?; - let cert = rustls_pemfile::certs(&mut cert.as_ref()) - .map(|it| it.map(|it| rustls::Certificate(it.to_vec()))) - .collect::, _>>()?; + let cert = rustls_pemfile::certs(&mut cert.as_ref()).collect::, _>>()?; let key = tokio::fs::read(chain.as_ref()).await?; - let key_cert: rustls::PrivateKey = match rustls_pemfile::read_one(&mut key.as_ref())? - .ok_or_else(|| io_other("could not parse pem file"))? - { - Item::Pkcs8Key(key) => Ok(rustls::PrivateKey(key.secret_pkcs8_der().to_vec().into())), - x => Err(io_other(format!( - "invalid certificate format, received: {x:?}" - ))), - }?; + let key_cert = rustls_pemfile::private_key(&mut key.as_ref())? + .ok_or_else(|| io_other("could not parse pem file"))?; ServerConfig::builder() - .with_safe_defaults() .with_no_client_auth() .with_single_cert(cert, key_cert) .map_err(|_| io_other("invalid certificate")) @@ -362,17 +342,10 @@ mod tests { use http_body_util::BodyExt; use hyper::client::conn::http1::{handshake, SendRequest}; use hyper_util::rt::TokioIo; - use rustls::{ - client::{ServerCertVerified, ServerCertVerifier}, - Certificate, ClientConfig, ServerName, - }; - use std::{ - convert::TryFrom, - io, - net::SocketAddr, - sync::Arc, - time::{Duration, SystemTime}, - }; + use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}; + use rustls::pki_types::{CertificateDer, ServerName}; + use rustls::{ClientConfig, SignatureScheme}; + use std::{io, net::SocketAddr, sync::Arc, time::Duration}; use tokio::time::sleep; use tokio::{net::TcpStream, task::JoinHandle, time::timeout}; use tokio_rustls::TlsConnector; @@ -552,13 +525,15 @@ mod tests { (handle, server_task, addr) } - async fn get_first_cert(addr: SocketAddr) -> Certificate { + async fn get_first_cert(addr: SocketAddr) -> CertificateDer<'static> { let stream = TcpStream::connect(addr).await.unwrap(); let tls_stream = tls_connector().connect(dns_name(), stream).await.unwrap(); let (_io, client_connection) = tls_stream.into_inner(); - client_connection.peer_certificates().unwrap()[0].clone() + client_connection.peer_certificates().unwrap()[0] + .clone() + .into_owned() } async fn connect(addr: SocketAddr) -> (SendRequest, JoinHandle<()>) { @@ -586,24 +561,50 @@ mod tests { } fn tls_connector() -> TlsConnector { + #[derive(Debug)] struct NoVerify; impl ServerCertVerifier for NoVerify { fn verify_server_cert( &self, - _end_entity: &Certificate, - _intermediates: &[Certificate], - _server_name: &ServerName, - _scts: &mut dyn Iterator, + _end_entity: &CertificateDer<'_>, + _intermediates: &[CertificateDer<'_>], + _server_name: &ServerName<'_>, _ocsp_response: &[u8], - _now: SystemTime, + _now: rustls::pki_types::UnixTime, ) -> Result { Ok(ServerCertVerified::assertion()) } + + fn verify_tls12_signature( + &self, + _message: &[u8], + _cert: &CertificateDer<'_>, + _dss: &rustls::DigitallySignedStruct, + ) -> Result { + Ok(HandshakeSignatureValid::assertion()) + } + + fn verify_tls13_signature( + &self, + _message: &[u8], + _cert: &CertificateDer<'_>, + _dss: &rustls::DigitallySignedStruct, + ) -> Result { + Ok(HandshakeSignatureValid::assertion()) + } + + fn supported_verify_schemes(&self) -> Vec { + vec![ + SignatureScheme::RSA_PKCS1_SHA256, + SignatureScheme::RSA_PSS_SHA256, + SignatureScheme::ECDSA_NISTP256_SHA256, + ] + } } let mut client_config = ClientConfig::builder() - .with_safe_defaults() + .dangerous() .with_custom_certificate_verifier(Arc::new(NoVerify)) .with_no_client_auth(); @@ -612,7 +613,7 @@ mod tests { TlsConnector::from(Arc::new(client_config)) } - fn dns_name() -> ServerName { + fn dns_name() -> ServerName<'static> { ServerName::try_from("localhost").unwrap() } }