Skip to content

Commit

Permalink
ngrok: push server addr validation to configure time
Browse files Browse the repository at this point in the history
  • Loading branch information
jrobsonchase committed Aug 23, 2023
1 parent 1e3a5ad commit 1bcb220
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 32 deletions.
1 change: 1 addition & 0 deletions ngrok/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ parking_lot = "0.12.1"
once_cell = "1.17.1"
hostname = "0.3.1"
regex = "1.7.3"
http = "0.2.9"

[target.'cfg(windows)'.dependencies]
windows-sys = { version = "0.45.0", features = ["Win32_Foundation"] }
Expand Down
2 changes: 1 addition & 1 deletion ngrok/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pub mod config {

mod headers;
mod http;
pub use http::*;
pub use self::http::*;
mod labeled;
pub use labeled::*;
mod oauth;
Expand Down
74 changes: 43 additions & 31 deletions ngrok/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use std::{
env,
future::Future,
io,
num::ParseIntError,
sync::{
atomic::{
AtomicBool,
Expand All @@ -18,17 +17,15 @@ use std::{
};

use arc_swap::ArcSwap;
use async_rustls::rustls::{
self,
client::InvalidDnsNameError,
};
use async_rustls::rustls::{self,};
use async_trait::async_trait;
use bytes::Bytes;
use futures::{
future,
prelude::*,
FutureExt,
};
use http::Uri;
use muxado::heartbeat::HeartbeatConfig;
pub use muxado::heartbeat::HeartbeatHandler;
use once_cell::sync::OnceCell;
Expand Down Expand Up @@ -168,7 +165,8 @@ pub trait Connector: Sync + Send + 'static {
/// returned from the [default_connect] function.
async fn connect(
&self,
addr: String,
host: String,
port: u16,
tls_config: Arc<rustls::ClientConfig>,
err: Option<AcceptError>,
) -> Result<Box<dyn IoStream>, ConnectError>;
Expand All @@ -177,16 +175,17 @@ pub trait Connector: Sync + Send + 'static {
#[async_trait]
impl<F, U> Connector for F
where
F: Fn(String, Arc<rustls::ClientConfig>, Option<AcceptError>) -> U + Send + Sync + 'static,
F: Fn(String, u16, Arc<rustls::ClientConfig>, Option<AcceptError>) -> U + Send + Sync + 'static,
U: Future<Output = Result<Box<dyn IoStream>, ConnectError>> + Send,
{
async fn connect(
&self,
addr: String,
host: String,
port: u16,
tls_config: Arc<rustls::ClientConfig>,
err: Option<AcceptError>,
) -> Result<Box<dyn IoStream>, ConnectError> {
self(addr, tls_config, err).await
self(host, port, tls_config, err).await
}
}

Expand All @@ -198,24 +197,21 @@ where
/// Discards any errors during reconnect, allowing attempts to recur
/// indefinitely.
pub async fn default_connect(
addr: String,
host: String,
port: u16,
tls_config: Arc<rustls::ClientConfig>,
_: Option<AcceptError>,
) -> Result<Box<dyn IoStream>, ConnectError> {
let mut split = addr.split(':');
let host = split.next().unwrap();
let port = split
.next()
.map(str::parse::<u16>)
.transpose()?
.unwrap_or(443);
let conn = tokio::net::TcpStream::connect(&(host, port))
let stream = tokio::net::TcpStream::connect(&(host.as_str(), port))
.await
.map_err(ConnectError::Tcp)?
.compat();

let domain = rustls::ServerName::try_from(host.as_str())
.expect("host should have been validated by SessionBuilder::server_addr");

let tls_conn = async_rustls::TlsConnector::from(tls_config)
.connect(rustls::ServerName::try_from(host)?, conn)
.connect(domain, stream)
.await
.map_err(ConnectError::Tls)?;
Ok(Box::new(tls_conn.compat()) as Box<dyn IoStream>)
Expand All @@ -232,7 +228,8 @@ pub struct SessionBuilder {
heartbeat_interval: Option<i64>,
heartbeat_tolerance: Option<i64>,
heartbeat_handler: Option<Arc<dyn HeartbeatHandler>>,
server_addr: String,
server_host: String,
server_port: u16,
ca_cert: Option<bytes::Bytes>,
tls_config: Option<rustls::ClientConfig>,
connector: Arc<dyn Connector>,
Expand All @@ -245,12 +242,6 @@ pub struct SessionBuilder {
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum ConnectError {
/// The builder specified an invalid port.
#[error("invalid server port")]
InvalidServerPort(#[from] ParseIntError),
/// The builder specified an invalid server name.
#[error("invalid server name")]
InvalidServerName(#[from] InvalidDnsNameError),
/// An error occurred when establishing a TCP connection to the ngrok
/// server.
#[error("failed to establish tcp connection")]
Expand Down Expand Up @@ -311,6 +302,11 @@ pub struct InvalidHeartbeatInterval(u128);
#[error("invalid heartbeat tolerance: {0}")]
pub struct InvalidHeartbeatTolerance(u128);

/// The builder provided an invalid server address
#[derive(Error, Debug, Clone)]
#[error("invalid server address: {0}")]
pub struct InvalidServerAddr(String);

impl Default for SessionBuilder {
fn default() -> Self {
SessionBuilder {
Expand All @@ -322,7 +318,8 @@ impl Default for SessionBuilder {
heartbeat_interval: None,
heartbeat_tolerance: None,
heartbeat_handler: None,
server_addr: "tunnel.ngrok.com:443".into(),
server_host: "connect.ngrok-agent.com".into(),
server_port: 443,
ca_cert: None,
tls_config: None,
connector: Arc::new(default_connect),
Expand Down Expand Up @@ -415,9 +412,23 @@ impl SessionBuilder {
/// See the [server_addr parameter in the ngrok docs] for additional details.
///
/// [server_addr parameter in the ngrok docs]: https://ngrok.com/docs/ngrok-agent/config#server_addr
pub fn server_addr(&mut self, addr: impl Into<String>) -> &mut Self {
self.server_addr = addr.into();
self
pub fn server_addr(&mut self, addr: impl Into<String>) -> Result<&mut Self, InvalidServerAddr> {
let addr = addr.into();
let server_uri: Uri = format!("http://{addr}")
.parse()
.map_err(|_| InvalidServerAddr(addr.clone()))?;

self.server_host = server_uri
.host()
.map(String::from)
.ok_or_else(|| InvalidServerAddr(addr.clone()))?;

rustls::ServerName::try_from(self.server_host.as_str())
.map_err(|_| InvalidServerAddr(addr.clone()))?;

self.server_port = server_uri.port_u16().unwrap_or(443);

Ok(self)
}

/// Sets the default certificate in PEM format to validate ngrok Session TLS connections.
Expand Down Expand Up @@ -586,7 +597,8 @@ impl SessionBuilder {
let conn = self
.connector
.connect(
self.server_addr.clone(),
self.server_host.clone(),
self.server_port,
Arc::new(self.get_or_create_tls_config()),
err.into(),
)
Expand Down

0 comments on commit 1bcb220

Please sign in to comment.