Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Forward proxy-protocol correctly to tls backends #112

Merged
merged 9 commits into from
Oct 3, 2023
4 changes: 3 additions & 1 deletion ngrok/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "ngrok"
version = "0.14.0-pre.6"
version = "0.14.0-pre.7"
edition = "2021"
license = "MIT OR Apache-2.0"
description = "The ngrok agent SDK"
Expand Down Expand Up @@ -33,6 +33,8 @@ tokio-socks = "0.5.1"
hyper-proxy = { version = "0.9.1", default-features = false, features = ["rustls"] }
url = "2.4.0"
rustls-native-certs = "0.6.3"
proxy-protocol = "0.5.0"
pin-project = "1.1.3"

[target.'cfg(windows)'.dependencies]
windows-sys = { version = "0.45.0", features = ["Win32_Foundation"] }
Expand Down
10 changes: 7 additions & 3 deletions ngrok/src/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@ use tokio::io::{
AsyncWrite,
};

use crate::internals::proto::{
EdgeType,
ProxyHeader,
use crate::{
config::ProxyProto,
internals::proto::{
EdgeType,
ProxyHeader,
},
};
/// A connection from an ngrok tunnel.
///
Expand All @@ -33,6 +36,7 @@ pub(crate) struct ConnInner {
pub(crate) struct Info {
pub(crate) header: ProxyHeader,
pub(crate) remote_addr: SocketAddr,
pub(crate) proxy_proto: ProxyProto,
}

impl ConnInfo for Info {
Expand Down
6 changes: 3 additions & 3 deletions ngrok/src/internals/proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ rpc_req!(SrvInfo, SrvInfoResp, SRV_INFO_REQ);
/// to use with this tunnel.
///
/// [ProxyProto::None] disables PROXY protocol support.
#[derive(Debug, Copy, Clone, Default)]
#[derive(Debug, Copy, Clone, Default, Eq, PartialEq)]
pub enum ProxyProto {
/// No PROXY protocol
#[default]
Expand Down Expand Up @@ -772,11 +772,11 @@ pub struct TlsEndpoint {

#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct TlsTermination {
#[serde(with = "base64bytes", skip_serializing_if = "is_default")]
#[serde(default, with = "base64bytes", skip_serializing_if = "is_default")]
pub cert: Vec<u8>,
#[serde(skip_serializing_if = "is_default", default)]
pub key: SecretBytes,
#[serde(with = "base64bytes", skip_serializing_if = "is_default")]
#[serde(default, with = "base64bytes", skip_serializing_if = "is_default")]
pub sealed_key: Vec<u8>,
}

Expand Down
2 changes: 2 additions & 0 deletions ngrok/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ pub mod config {
mod webhook_verification;
}

mod proxy_proto;

/// Types for working with the ngrok session.
pub mod session;
/// Types for working with ngrok tunnels.
Expand Down
93 changes: 93 additions & 0 deletions ngrok/src/online_tests.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::{
io,
io::prelude::*,
net::SocketAddr,
str::FromStr,
Expand All @@ -9,17 +10,26 @@ use std::{
},
Arc,
},
time::Duration,
};

use anyhow::{
anyhow,
Error,
};
use async_rustls::{
rustls,
rustls::{
ClientConfig,
RootCertStore,
},
};
use axum::{
extract::connect_info::Connected,
routing::get,
Router,
};
use bytes::Bytes;
use flate2::read::GzDecoder;
use futures::{
channel::oneshot,
Expand All @@ -32,7 +42,9 @@ use hyper::{
StatusCode,
Uri,
};
use once_cell::sync::Lazy;
use paste::paste;
use proxy_protocol::ProxyHeader;
use rand::{
distributions::Alphanumeric,
thread_rng,
Expand All @@ -43,14 +55,17 @@ use tokio::{
AsyncReadExt,
AsyncWriteExt,
},
net::TcpStream,
sync::mpsc,
test,
};
use tokio_tungstenite::{
connect_async,
tungstenite::Message,
};
use tokio_util::compat::*;
use tracing_test::traced_test;
use url::Url;

use crate::{
config::{
Expand Down Expand Up @@ -685,3 +700,81 @@ async fn session_tls_config() -> Result<(), Error> {

Ok(())
}

fn tls_client_config() -> Result<Arc<ClientConfig>, &'static io::Error> {
static CONFIG: Lazy<Result<Arc<ClientConfig>, io::Error>> = Lazy::new(|| {
let der_certs = rustls_native_certs::load_native_certs()?
.into_iter()
.map(|c| c.0)
.collect::<Vec<_>>();
let der_certs = der_certs.as_slice();
let mut root_store = RootCertStore::empty();
root_store.add_parsable_certificates(der_certs);
let config = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_no_client_auth();
Ok(Arc::new(config))
});

Ok(CONFIG.as_ref()?.clone())
}

#[traced_test]
#[cfg_attr(not(feature = "paid-tests"), ignore)]
#[test]
async fn forward_proxy_protocol_tls() -> Result<(), Error> {
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
let addr = listener.local_addr()?;

let sess = Session::builder().authtoken_from_env().connect().await?;
let forwarder = sess
.tls_endpoint()
.proxy_proto(ProxyProto::V2)
.termination(Bytes::default(), Bytes::default())
.listen_and_forward(format!("tls://{}", addr).parse()?)
.await?;

let tunnel_url: Url = forwarder.url().to_string().parse()?;

tokio::spawn(async move {
tokio::time::sleep(Duration::from_millis(500)).await;
let tunnel_conn = TcpStream::connect(format!(
"{}:{}",
tunnel_url.host_str().unwrap(),
tunnel_url.port().unwrap_or(443)
))
.await?;

let domain = rustls::ServerName::try_from(tunnel_url.host_str().unwrap())
.map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;

let mut tls_conn = async_rustls::TlsConnector::from(
tls_client_config().map_err(|e| io::Error::from(e.kind()))?,
)
.connect(domain, tunnel_conn.compat())
.await?
.compat();

tls_conn.write_all(b"Hello, world!").await
});

let (conn, _) = listener.accept().await?;

let mut proxy_conn = crate::proxy_proto::Stream::incoming(conn);
let proxy_header = proxy_conn
.proxy_header()
.await?
.unwrap()
.map(Clone::clone)
.unwrap();

match proxy_header {
ProxyHeader::Version2 { .. } => {}
_ => unreachable!("we configured v2"),
}

// TODO: actually accept the tls connection from the server side

Ok(())
}
Loading
Loading