diff --git a/ngrok/Cargo.toml b/ngrok/Cargo.toml
index 03789a5..08dc093 100644
--- a/ngrok/Cargo.toml
+++ b/ngrok/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "ngrok"
-version = "0.14.0-pre.6"
+version = "0.14.0-pre.7"
 edition = "2021"
 license = "MIT OR Apache-2.0"
 description = "The ngrok agent SDK"
@@ -33,6 +33,8 @@ tokio-socks = "0.5.1"
 hyper-proxy = { version = "0.9.1", default-features = false, features = ["rustls"] }
 url = "2.4.0"
 rustls-native-certs = "0.6.3"
+proxy-protocol = "0.5.0"
+pin-project = "1.1.3"

 [target.'cfg(windows)'.dependencies]
 windows-sys = { version = "0.45.0", features = ["Win32_Foundation"] }
diff --git a/ngrok/src/conn.rs b/ngrok/src/conn.rs
index 9a1a299..6f7f4a3 100644
--- a/ngrok/src/conn.rs
+++ b/ngrok/src/conn.rs
@@ -16,9 +16,12 @@ use tokio::io::{
     AsyncWrite,
 };

-use crate::internals::proto::{
-    EdgeType,
-    ProxyHeader,
+use crate::{
+    config::ProxyProto,
+    internals::proto::{
+        EdgeType,
+        ProxyHeader,
+    },
 };
 /// A connection from an ngrok tunnel.
 ///
@@ -33,6 +36,7 @@ pub(crate) struct ConnInner {
 pub(crate) struct Info {
     pub(crate) header: ProxyHeader,
     pub(crate) remote_addr: SocketAddr,
+    pub(crate) proxy_proto: ProxyProto,
 }

 impl ConnInfo for Info {
diff --git a/ngrok/src/internals/proto.rs b/ngrok/src/internals/proto.rs
index fc2320c..04a1b77 100644
--- a/ngrok/src/internals/proto.rs
+++ b/ngrok/src/internals/proto.rs
@@ -505,7 +505,7 @@ rpc_req!(SrvInfo, SrvInfoResp, SRV_INFO_REQ);
 /// to use with this tunnel.
 ///
 /// [ProxyProto::None] disables PROXY protocol support.
-#[derive(Debug, Copy, Clone, Default)]
+#[derive(Debug, Copy, Clone, Default, Eq, PartialEq)]
 pub enum ProxyProto {
     /// No PROXY protocol
     #[default]
@@ -772,11 +772,11 @@ pub struct TlsEndpoint {

 #[derive(Serialize, Deserialize, Debug, Clone, Default)]
 pub struct TlsTermination {
-    #[serde(with = "base64bytes", skip_serializing_if = "is_default")]
+    #[serde(default, with = "base64bytes", skip_serializing_if = "is_default")]
     pub cert: Vec<u8>,
     #[serde(skip_serializing_if = "is_default", default)]
     pub key: SecretBytes,
-    #[serde(with = "base64bytes", skip_serializing_if = "is_default")]
+    #[serde(default, with = "base64bytes", skip_serializing_if = "is_default")]
     pub sealed_key: Vec<u8>,
 }

diff --git a/ngrok/src/lib.rs b/ngrok/src/lib.rs
index 322f174..6009836 100644
--- a/ngrok/src/lib.rs
+++ b/ngrok/src/lib.rs
@@ -31,6 +31,8 @@ pub mod config {
     mod webhook_verification;
 }

+mod proxy_proto;
+
 /// Types for working with the ngrok session.
 pub mod session;
 /// Types for working with ngrok tunnels.
diff --git a/ngrok/src/online_tests.rs b/ngrok/src/online_tests.rs
index 7e09946..33c19e7 100644
--- a/ngrok/src/online_tests.rs
+++ b/ngrok/src/online_tests.rs
@@ -1,4 +1,5 @@
 use std::{
+    io,
     io::prelude::*,
     net::SocketAddr,
     str::FromStr,
@@ -9,17 +10,26 @@ use std::{
         },
         Arc,
     },
+    time::Duration,
 };

 use anyhow::{
     anyhow,
     Error,
 };
+use async_rustls::{
+    rustls,
+    rustls::{
+        ClientConfig,
+        RootCertStore,
+    },
+};
 use axum::{
     extract::connect_info::Connected,
     routing::get,
     Router,
 };
+use bytes::Bytes;
 use flate2::read::GzDecoder;
 use futures::{
     channel::oneshot,
@@ -32,7 +42,9 @@ use hyper::{
     StatusCode,
     Uri,
 };
+use once_cell::sync::Lazy;
 use paste::paste;
+use proxy_protocol::ProxyHeader;
 use rand::{
     distributions::Alphanumeric,
     thread_rng,
@@ -43,6 +55,7 @@ use tokio::{
         AsyncReadExt,
         AsyncWriteExt,
     },
+    net::TcpStream,
     sync::mpsc,
     test,
 };
@@ -50,7 +63,9 @@ use tokio_tungstenite::{
     connect_async,
     tungstenite::Message,
 };
+use tokio_util::compat::*;
 use tracing_test::traced_test;
+use url::Url;

 use crate::{
     config::{
@@ -685,3 +700,81 @@ async fn session_tls_config() -> Result<(), Error> {

     Ok(())
 }
+
+fn tls_client_config() -> Result<Arc<ClientConfig>, &'static io::Error> {
+    static CONFIG: Lazy<Result<Arc<ClientConfig>, io::Error>> = Lazy::new(|| {
+        let der_certs = rustls_native_certs::load_native_certs()?
+            .into_iter()
+            .map(|c| c.0)
+            .collect::<Vec<_>>();
+        let der_certs = der_certs.as_slice();
+        let mut root_store = RootCertStore::empty();
+        root_store.add_parsable_certificates(der_certs);
+        let config = ClientConfig::builder()
+            .with_safe_defaults()
+            .with_root_certificates(root_store)
+            .with_no_client_auth();
+        Ok(Arc::new(config))
+    });
+
+    Ok(CONFIG.as_ref()?.clone())
+}
+
+#[traced_test]
+#[cfg_attr(not(feature = "paid-tests"), ignore)]
+#[test]
+async fn forward_proxy_protocol_tls() -> Result<(), Error> {
+    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
+    let addr = listener.local_addr()?;
+
+    let sess = Session::builder().authtoken_from_env().connect().await?;
+    let forwarder = sess
+        .tls_endpoint()
+        .proxy_proto(ProxyProto::V2)
+        .termination(Bytes::default(), Bytes::default())
+        .listen_and_forward(format!("tls://{}", addr).parse()?)
+        .await?;
+
+    let tunnel_url: Url = forwarder.url().to_string().parse()?;
+
+    tokio::spawn(async move {
+        tokio::time::sleep(Duration::from_millis(500)).await;
+        let tunnel_conn = TcpStream::connect(format!(
+            "{}:{}",
+            tunnel_url.host_str().unwrap(),
+            tunnel_url.port().unwrap_or(443)
+        ))
+        .await?;
+
+        let domain = rustls::ServerName::try_from(tunnel_url.host_str().unwrap())
+            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
+
+        let mut tls_conn = async_rustls::TlsConnector::from(
+            tls_client_config().map_err(|e| io::Error::from(e.kind()))?,
+        )
+        .connect(domain, tunnel_conn.compat())
+        .await?
+        .compat();
+
+        tls_conn.write_all(b"Hello, world!").await
+    });
+
+    let (conn, _) = listener.accept().await?;
+
+    let mut proxy_conn = crate::proxy_proto::Stream::incoming(conn);
+    let proxy_header = proxy_conn
+        .proxy_header()
+        .await?
+        .unwrap()
+        .map(Clone::clone)
+        .unwrap();
+
+    match proxy_header {
+        ProxyHeader::Version2 { .. } => {}
+        _ => unreachable!("we configured v2"),
+    }
+
+    // TODO: actually accept the tls connection from the server side
+
+    Ok(())
+}
diff --git a/ngrok/src/proxy_proto.rs b/ngrok/src/proxy_proto.rs
new file mode 100644
index 0000000..7095893
--- /dev/null
+++ b/ngrok/src/proxy_proto.rs
@@ -0,0 +1,564 @@
+use std::{
+    io,
+    mem,
+    pin::{
+        pin,
+        Pin,
+    },
+    task::{
+        ready,
+        Context,
+        Poll,
+    },
+};
+
+use bytes::{
+    Buf,
+    BytesMut,
+};
+use proxy_protocol::{
+    ParseError,
+    ProxyHeader,
+};
+use tokio::io::{
+    AsyncRead,
+    AsyncWrite,
+    ReadBuf,
+};
+use tracing::instrument;
+
+// 536 is the smallest possible TCP segment, which both v1 and v2 are guaranteed
+// to fit into.
+const MAX_HEADER_LEN: usize = 536;
+// v2 headers start with at least 16 bytes
+const MIN_HEADER_LEN: usize = 16;
+
+#[derive(Debug)]
+enum ReadState {
+    Reading(Option<ParseError>, BytesMut),
+    Error(proxy_protocol::ParseError, BytesMut),
+    Header(Option<ProxyHeader>, BytesMut),
+    None,
+}
+
+impl ReadState {
+    fn new() -> ReadState {
+        ReadState::Reading(None, BytesMut::with_capacity(MAX_HEADER_LEN))
+    }
+
+    fn header(&self) -> Result<Option<&ProxyHeader>, &ParseError> {
+        match self {
+            ReadState::Error(err, _) | ReadState::Reading(Some(err), _) => Err(err),
+            ReadState::None | ReadState::Reading(None, _) => Ok(None),
+            ReadState::Header(hdr, _) => Ok(hdr.as_ref()),
+        }
+    }
+
+    /// Read the header from the stream *once*. Once a header has been read, or
+    /// it's been determined that no header is coming, this will be a no-op.
+    #[instrument(level = "trace", skip(reader))]
+    fn poll_read_header_once(
+        &mut self,
+        cx: &mut Context,
+        mut reader: Pin<&mut impl AsyncRead>,
+    ) -> Poll<io::Result<()>> {
+        loop {
+            let read_state = mem::replace(self, ReadState::None);
+            let (last_err, mut hdr_buf) = match read_state {
+                // End states
+                ReadState::None | ReadState::Header(_, _) | ReadState::Error(_, _) => {
+                    *self = read_state;
+                    return Poll::Ready(Ok(()));
+                }
+                ReadState::Reading(err, hdr_buf) => (err, hdr_buf),
+            };
+
+            if hdr_buf.len() < MAX_HEADER_LEN {
+                let mut tmp_buf = ReadBuf::uninit(hdr_buf.spare_capacity_mut());
+                let read_res = reader.as_mut().poll_read(cx, &mut tmp_buf);
+                // Regardless of error, make sure we track the read bytes
+                let filled = tmp_buf.filled().len();
+                if filled > 0 {
+                    let len = hdr_buf.len();
+                    // Safety: the tmp_buf is backed by the uninitialized
+                    // portion of hdr_buf. Advancing the len to len + filled is
+                    // guaranteed to only cover the bytes initialized by the
+                    // read.
+                    unsafe { hdr_buf.set_len(len + filled) }
+                }
+                match read_res {
+                    // If we hit the end of the stream due to either an EOF or
+                    // an error, set the state to a terminal one and return the
+                    // result.
+                    Poll::Ready(ref res) if res.is_err() || filled == 0 => {
+                        *self = match last_err {
+                            Some(err) => ReadState::Error(err, hdr_buf),
+                            None => ReadState::Header(None, hdr_buf),
+                        };
+                        return read_res;
+                    }
+                    // Pending leaves the last error and buffer unchanged.
+                    Poll::Pending => {
+                        *self = ReadState::Reading(last_err, hdr_buf);
+                        return read_res;
+                    }
+                    _ => {}
+                }
+            }
+
+            // Create a view into the header buffer so that failed parse
+            // attempts don't consume it.
+            let mut hdr_view = &*hdr_buf;
+
+            // Don't try to parse unless we have a minimum number of bytes to
+            // avoid spurious "NotProxyHeader" errors.
+            // Also hack around a bug in the proxy_protocol crate that results
+            // in panics when the input ends in \r without the \n.
+            if hdr_view.len() < MIN_HEADER_LEN || matches!(hdr_view.last(), Some(b'\r')) {
+                *self = ReadState::Reading(last_err, hdr_buf);
+                continue;
+            }
+
+            match proxy_protocol::parse(&mut hdr_view) {
+                Ok(hdr) => {
+                    hdr_buf.advance(hdr_buf.len() - hdr_view.len());
+                    *self = ReadState::Header(Some(hdr), hdr_buf);
+                    return Poll::Ready(Ok(()));
+                }
+                Err(ParseError::NotProxyHeader) => {
+                    *self = ReadState::Header(None, hdr_buf);
+                    return Poll::Ready(Ok(()));
+                }
+
+                // Keep track of the last error - it might not be fatal if we
+                // simply haven't read enough
+                Err(err) => {
+                    // If we've read too much, consider the error fatal.
+                    if hdr_buf.len() >= MAX_HEADER_LEN {
+                        *self = ReadState::Error(err, hdr_buf);
+                    } else {
+                        *self = ReadState::Reading(Some(err), hdr_buf);
+                    }
+                    continue;
+                }
+            }
+        }
+    }
+}
+
+#[derive(Debug)]
+enum WriteState {
+    Writing(BytesMut),
+    None,
+}
+
+impl WriteState {
+    fn new(hdr: proxy_protocol::ProxyHeader) -> Result<Self, proxy_protocol::EncodeError> {
+        proxy_protocol::encode(hdr).map(WriteState::Writing)
+    }
+
+    /// Write the header *once*. After it's written to the stream, this will be a
+    /// no-op.
+    #[instrument(level = "trace", skip(writer))]
+    fn poll_write_header_once(
+        &mut self,
+        cx: &mut Context,
+        mut writer: Pin<&mut impl AsyncWrite>,
+    ) -> Poll<io::Result<()>> {
+        loop {
+            let state = mem::replace(self, WriteState::None);
+            match state {
+                WriteState::None => return Poll::Ready(Ok(())),
+                WriteState::Writing(mut buf) => {
+                    let write_res = writer.as_mut().poll_write(cx, &buf);
+                    match write_res {
+                        Poll::Pending | Poll::Ready(Err(_)) => {
+                            *self = WriteState::Writing(buf);
+                            ready!(write_res)?;
+                            unreachable!(
+                                "ready! will return for us on either Pending or Ready(Err)"
+                            );
+                        }
+                        Poll::Ready(Ok(written)) => {
+                            buf.advance(written);
+                            if !buf.is_empty() {
+                                *self = WriteState::Writing(buf);
+                                continue;
+                            } else {
+                                return Ok(()).into();
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+#[derive(Debug)]
+#[pin_project::pin_project]
+pub struct Stream<S> {
+    read_state: ReadState,
+    write_state: WriteState,
+    #[pin]
+    inner: S,
+}
+
+impl<S> Stream<S> {
+    pub fn outgoing(stream: S, header: ProxyHeader) -> Result<Self, proxy_protocol::EncodeError> {
+        Ok(Stream {
+            inner: stream,
+            write_state: WriteState::new(header)?,
+            read_state: ReadState::None,
+        })
+    }
+
+    pub fn incoming(stream: S) -> Self {
+        Stream {
+            inner: stream,
+            read_state: ReadState::new(),
+            write_state: WriteState::None,
+        }
+    }
+
+    pub fn disabled(stream: S) -> Self {
+        Stream {
+            inner: stream,
+            read_state: ReadState::None,
+            write_state: WriteState::None,
+        }
+    }
+}
+
+impl<S> Stream<S>
+where
+    S: AsyncRead,
+{
+    #[instrument(level = "trace", skip(self), fields(read_state = ?self.read_state))]
+    pub fn poll_proxy_header(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<io::Result<Result<Option<&ProxyHeader>, &ParseError>>> {
+        let this = self.project();
+
+        ready!(this.read_state.poll_read_header_once(cx, this.inner))?;
+
+        Ok(this.read_state.header()).into()
+    }
+
+    #[instrument(level = "debug", skip(self))]
+    pub async fn proxy_header(&mut self) -> io::Result<Result<Option<&ProxyHeader>, &ParseError>>
+    where
+        Self: Unpin,
+    {
+        let mut this = Pin::new(self);
+
+        futures::future::poll_fn(|cx| {
+            let this = this.as_mut().project();
+            this.read_state.poll_read_header_once(cx, this.inner)
+        })
+        .await?;
+
+        Ok(this.get_mut().read_state.header())
+    }
+}
+
+impl<S> AsyncRead for Stream<S>
+where
+    S: AsyncRead,
+{
+    #[instrument(level = "trace", skip(self), fields(read_state = ?self.read_state))]
+    fn poll_read(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &mut ReadBuf<'_>,
+    ) -> Poll<io::Result<()>> {
+        let mut this = self.project();
+
+        ready!(this
+            .read_state
+            .poll_read_header_once(cx, this.inner.as_mut()))?;
+
+        match this.read_state {
+            ReadState::Error(_, remainder) | ReadState::Header(_, remainder) => {
+                if !remainder.is_empty() {
+                    let available = std::cmp::min(remainder.len(), buf.remaining());
+                    buf.put_slice(&remainder.split_to(available));
+                    if buf.remaining() == 0 {
+                        return Poll::Ready(Ok(()));
+                    }
+                }
+            }
+            ReadState::None => {}
+            _ => unreachable!(),
+        }
+
+        this.inner.poll_read(cx, buf)
+    }
+}
+
+impl<S> AsyncWrite for Stream<S>
+where
+    S: AsyncWrite,
+{
+    #[instrument(level = "trace", skip(self), fields(write_state = ?self.write_state))]
+    fn poll_write(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &[u8],
+    ) -> Poll<io::Result<usize>> {
+        let mut this = self.project();
+
+        ready!(this
+            .write_state
+            .poll_write_header_once(cx, this.inner.as_mut()))?;
+
+        this.inner.poll_write(cx, buf)
+    }
+    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        self.project().inner.poll_flush(cx)
+    }
+    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        self.project().inner.poll_shutdown(cx)
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use std::{
+        cmp,
+        io,
+        pin::Pin,
+        task::{
+            ready,
+            Context,
+            Poll,
+        },
+        time::Duration,
+    };
+
+    use bytes::{
+        BufMut,
+        BytesMut,
+    };
+    use proxy_protocol::{
+        version2::{
+            self,
+            ProxyCommand,
+        },
+        ProxyHeader,
+    };
+    use tokio::io::{
+        AsyncRead,
+        AsyncReadExt,
+        AsyncWriteExt,
+        ReadBuf,
+    };
+
+    use super::Stream;
+
+    #[pin_project::pin_project]
+    struct ShortReader<S> {
+        #[pin]
+        inner: S,
+        min: usize,
+        max: usize,
+    }
+
+    impl<S> AsyncRead for ShortReader<S>
+    where
+        S: AsyncRead,
+    {
+        fn poll_read(
+            self: Pin<&mut Self>,
+            cx: &mut Context<'_>,
+            buf: &mut ReadBuf<'_>,
+        ) -> Poll<io::Result<()>> {
+            let mut this = self.project();
+            let max_bytes =
+                *this.min + cmp::max(1, rand::random::<usize>() % (*this.max - *this.min));
+            let mut tmp = vec![0; max_bytes];
+            let mut tmp_buf = ReadBuf::new(&mut tmp);
+            let res = ready!(this.inner.as_mut().poll_read(cx, &mut tmp_buf));
+
+            buf.put_slice(tmp_buf.filled());
+
+            res?;
+
+            Poll::Ready(Ok(()))
+        }
+    }
+
+    impl<S> ShortReader<S> {
+        fn new(inner: S, min: usize, max: usize) -> Self {
+            ShortReader { inner, min, max }
+        }
+    }
+
+    const INPUT: &str = "PROXY TCP4 192.168.0.1 192.168.0.11 56324 443\r\n";
+    const PARTIAL_INPUT: &str = "PROXY TCP4 192.168.0.1";
+    const FINAL_INPUT: &str = " 192.168.0.11 56324 443\r\n";
+
+    // Smoke test to ensure that the proxy protocol parser works as expected.
+    // Not actually testing our code.
+    #[test]
+    fn test_proxy_protocol() {
+        let mut buf = BytesMut::from(INPUT);
+
+        assert!(proxy_protocol::parse(&mut buf).is_ok());
+
+        buf = BytesMut::from(PARTIAL_INPUT);
+
+        assert!(proxy_protocol::parse(&mut &*buf).is_err());
+
+        buf.put_slice(FINAL_INPUT.as_bytes());
+
+        assert!(proxy_protocol::parse(&mut &*buf).is_ok());
+    }
+
+    #[tokio::test]
+    #[tracing_test::traced_test]
+    async fn test_header_stream_v2() {
+        let (left, mut right) = tokio::io::duplex(1024);
+
+        let header = ProxyHeader::Version2 {
+            command: ProxyCommand::Proxy,
+            transport_protocol: version2::ProxyTransportProtocol::Stream,
+            addresses: version2::ProxyAddresses::Ipv4 {
+                source: "127.0.0.1:1".parse().unwrap(),
+                destination: "127.0.0.2:2".parse().unwrap(),
+            },
+        };
+
+        let input = proxy_protocol::encode(header).unwrap();
+
+        let mut proxy_stream = Stream::incoming(ShortReader::new(left, 2, 5));
+
+        // Chunk our writes to ensure that our reader is resilient across split inputs.
+        tokio::spawn(async move {
+            tokio::time::sleep(Duration::from_millis(50)).await;
+
+            right.write_all(&input).await.expect("write header");
+
+            right
+                .write_all(b"Hello, world!")
+                .await
+                .expect("write hello");
+
+            right.shutdown().await.expect("shutdown");
+        });
+
+        let hdr = proxy_stream
+            .proxy_header()
+            .await
+            .expect("read header")
+            .expect("decode header")
+            .expect("header exists");
+
+        assert!(matches!(hdr, ProxyHeader::Version2 { .. }));
+
+        let mut buf = String::new();
+
+        proxy_stream
+            .read_to_string(&mut buf)
+            .await
+            .expect("read rest");
+
+        assert_eq!(buf, "Hello, world!");
+
+        // Get the header again - should be the same.
+        let hdr = proxy_stream
+            .proxy_header()
+            .await
+            .expect("read header")
+            .expect("decode header")
+            .expect("header exists");
+
+        assert!(matches!(hdr, ProxyHeader::Version2 { .. }));
+    }
+
+    #[tokio::test]
+    #[tracing_test::traced_test]
+    async fn test_header_stream() {
+        let (left, mut right) = tokio::io::duplex(1024);
+
+        let mut proxy_stream = Stream::incoming(ShortReader::new(left, 2, 5));
+
+        // Chunk our writes to ensure that our reader is resilient across split inputs.
+        tokio::spawn(async move {
+            tokio::time::sleep(Duration::from_millis(50)).await;
+
+            right
+                .write_all(INPUT.as_bytes())
+                .await
+                .expect("write header");
+
+            right
+                .write_all(b"Hello, world!")
+                .await
+                .expect("write hello");
+
+            right.shutdown().await.expect("shutdown");
+        });
+
+        let hdr = proxy_stream
+            .proxy_header()
+            .await
+            .expect("read header")
+            .expect("decode header")
+            .expect("header exists");
+
+        assert!(matches!(hdr, ProxyHeader::Version1 { .. }));
+
+        let mut buf = String::new();
+
+        proxy_stream
+            .read_to_string(&mut buf)
+            .await
+            .expect("read rest");
+
+        assert_eq!(buf, "Hello, world!");
+
+        // Get the header again - should be the same.
+        let hdr = proxy_stream
+            .proxy_header()
+            .await
+            .expect("read header")
+            .expect("decode header")
+            .expect("header exists");
+
+        assert!(matches!(hdr, ProxyHeader::Version1 { .. }));
+    }
+
+    #[tokio::test]
+    #[tracing_test::traced_test]
+    async fn test_noheader() {
+        let (left, mut right) = tokio::io::duplex(1024);
+
+        let mut proxy_stream = Stream::incoming(left);
+
+        right
+            .write_all(b"Hello, world!")
+            .await
+            .expect("write stream");
+
+        right.shutdown().await.expect("shutdown");
+        drop(right);
+
+        assert!(proxy_stream
+            .proxy_header()
+            .await
+            .unwrap()
+            .unwrap()
+            .is_none());
+
+        let mut buf = String::new();
+
+        proxy_stream
+            .read_to_string(&mut buf)
+            .await
+            .expect("read stream");
+
+        assert_eq!(buf, "Hello, world!");
+    }
+}
diff --git a/ngrok/src/session.rs b/ngrok/src/session.rs
index acf0e1b..923be3b 100644
--- a/ngrok/src/session.rs
+++ b/ngrok/src/session.rs
@@ -85,6 +85,7 @@ use crate::{
     config::{
         HttpTunnelBuilder,
         LabeledTunnelBuilder,
+        ProxyProto,
        TcpTunnelBuilder,
        TlsTunnelBuilder,
        TunnelConfig,
@@ -96,7 +97,10 @@ use crate::{
             BindExtra,
             BindOpts,
             Error,
+            HttpEndpoint,
             SecretString,
+            TcpEndpoint,
+            TlsEndpoint,
         },
         raw_session::{
             AcceptError as RawAcceptError,
@@ -999,11 +1003,22 @@ async fn accept_one(
         if let Some(BindOpts::Tls(opts)) = &tun.opts {
             header.passthrough_tls = opts.tls_termination.is_none();
         }
+        let proxy_proto = if let Some(
+            BindOpts::Tls(TlsEndpoint { proxy_proto, .. })
+            | BindOpts::Http(HttpEndpoint { proxy_proto, .. })
+            | BindOpts::Tcp(TcpEndpoint { proxy_proto, .. }),
+        ) = tun.opts
+        {
+            proxy_proto
+        } else {
+            ProxyProto::None
+        };
         tun.tx
             .send(Ok(ConnInner {
                 info: crate::conn::Info {
                     remote_addr,
                     header,
+                    proxy_proto,
                 },
                 stream: conn.stream,
             }))
diff --git a/ngrok/src/tunnel_ext.rs b/ngrok/src/tunnel_ext.rs
index 4e1191a..e1b5a1b 100644
--- a/ngrok/src/tunnel_ext.rs
+++ b/ngrok/src/tunnel_ext.rs
@@ -28,6 +28,7 @@ use hyper::{
     StatusCode,
 };
 use once_cell::sync::Lazy;
+use proxy_protocol::ProxyHeader;
 #[cfg(target_os = "windows")]
 use tokio::net::windows::named_pipe::ClientOptions;
 #[cfg(not(target_os = "windows"))]
@@ -50,6 +51,7 @@ use tokio_util::compat::{
 use tracing::{
     debug,
     field,
+    warn,
     Instrument,
     Span,
 };
@@ -58,7 +60,9 @@ use url::Url;
 #[cfg(target_os = "windows")]
 use windows_sys::Win32::Foundation::ERROR_PIPE_BUSY;
 use crate::{
+    config::ProxyProto,
     prelude::*,
+    proxy_proto,
     session::IoStream,
     EdgeConn,
     EndpointConn,
@@ -97,9 +101,8 @@ pub trait TunnelExt: Tunnel + Send {
     async fn forward(&mut self, url: Url) -> Result<(), io::Error>;
 }

-#[async_trait]
 pub(crate) trait ConnExt {
-    async fn forward_to(mut self, url: &Url) -> Result<JoinHandle<()>, io::Error>;
+    fn forward_to(self, url: &Url) -> JoinHandle<io::Result<()>>;
 }

 #[tracing::instrument(skip_all, fields(tunnel_id = tun.id(), url = %url))]
@@ -119,58 +122,81 @@ where
             return Ok(());
         };

-        tunnel_conn.forward_to(&url).await?;
+        tunnel_conn.forward_to(&url);
     }
 }

-#[async_trait]
 impl ConnExt for EdgeConn {
-    async fn forward_to(self, url: &Url) -> Result<JoinHandle<()>, io::Error> {
-        let upstream = match connect(
-            self.edge_type() == EdgeType::Tls && self.passthrough_tls(),
-            false,
-            url,
-        )
-        .await
-        {
-            Ok(conn) => conn,
-            Err(e) => {
-                #[cfg(feature = "hyper")]
-                if self.edge_type() == EdgeType::Https {
-                    serve_gateway_error(format!("{e}"), self);
+    fn forward_to(mut self, url: &Url) -> JoinHandle<io::Result<()>> {
+        let url = url.clone();
+        tokio::spawn(async move {
+            let mut upstream = match connect(
+                self.edge_type() == EdgeType::Tls && self.passthrough_tls(),
+                None, // Edges don't support proxyproto (afaik)
+                &url,
+            )
+            .await
+            {
+                Ok(conn) => conn,
+                Err(error) => {
+                    #[cfg(feature = "hyper")]
+                    if self.edge_type() == EdgeType::Https {
+                        serve_gateway_error(format!("{error}"), self);
+                    }
+                    warn!(%error, "error connecting to upstream");
+                    return Err(error);
                 }
-                return Err(e);
-            }
-        };
+            };

-        Ok(join_streams(self, upstream))
+            copy_bidirectional(&mut self, &mut upstream).await?;
+            Ok(())
+        })
     }
 }

-#[async_trait]
 impl ConnExt for EndpointConn {
-    async fn forward_to(self, url: &Url) -> Result<JoinHandle<()>, io::Error> {
-        let upstream = match connect(
-            self.proto() == "tls" && self.inner.info.header.passthrough_tls,
-            false,
-            url,
-        )
-        .await
-        {
-            Ok(conn) => conn,
-            Err(e) => {
-                #[cfg(feature = "hyper")]
-                match self.proto() {
-                    "http" | "https" => {
-                        serve_gateway_error(format!("{e}"), self);
+    fn forward_to(self, url: &Url) -> JoinHandle<io::Result<()>> {
+        let url = url.clone();
+        tokio::spawn(async move {
+            let proxy_proto = self.inner.info.proxy_proto;
+            let proto_tls = self.proto() == "tls";
+            let proto_http = matches!(self.proto(), "http" | "https");
+            let passthrough_tls = self.inner.info.passthrough_tls();
+
+            let (mut stream, proxy_header) = match proxy_proto {
+                ProxyProto::None => (crate::proxy_proto::Stream::disabled(self), None),
+                _ => {
+                    let mut stream = crate::proxy_proto::Stream::incoming(self);
+                    let header = stream
+                        .proxy_header()
+                        .await?
+                        .map_err(|e| {
+                            io::Error::new(
+                                io::ErrorKind::InvalidData,
+                                format!("invalid proxy-protocol header: {}", e),
+                            )
+                        })?
+                        .cloned();
+                    (stream, header)
+                }
+            };
+
+            let mut upstream = match connect(proto_tls && passthrough_tls, proxy_header, &url).await
+            {
+                Ok(conn) => conn,
+                Err(error) => {
+                    #[cfg(feature = "hyper")]
+                    if proto_http {
+                        serve_gateway_error(format!("{error}"), stream);
+                    }
-                    }
-                    _ => {}
+                    warn!(%error, "error connecting to upstream");
+                    return Err(error);
                 }
-                return Err(e);
-            }
-        };
+            };

-        Ok(join_streams(self, upstream))
+            copy_bidirectional(&mut stream, &mut upstream).await?;
+            Ok(())
+        })
     }
 }

@@ -199,7 +225,7 @@ fn tls_config() -> Result<Arc<rustls::ClientConfig>, &'static io::Error> {
 // Note: this additional wrapping logic currently unimplemented.
 async fn connect(
     tunnel_tls: bool,
-    _tunnel_proxyproto: bool,
+    proxy_proto_header: Option<ProxyHeader>,
     url: &Url,
 ) -> Result<Box<dyn IoStream>, io::Error> {
     let host = url.host_str().unwrap_or("localhost");
@@ -282,6 +308,14 @@ async fn connect(
         }
     };

+    // We have to write the proxy header _before_ tls termination
+    if let Some(header) = proxy_proto_header {
+        conn = Box::new(
+            proxy_proto::Stream::outgoing(conn, header)
+                .expect("re-serializing proxy header should always succeed"),
+        )
+    };
+
     if backend_tls && !tunnel_tls {
         let domain = rustls::ServerName::try_from(host)
             .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
@@ -293,7 +327,7 @@ async fn connect(
         )
     }

-    // TODO: proxyproto, header rewrites?
+    // TODO: header rewrites?

     Ok(conn)
 }
@@ -306,21 +340,6 @@ async fn connect_tcp(host: &str, port: u16) -> Result<TcpStream, io::Error> {
     Ok(conn)
 }

-fn join_streams(
-    mut left: impl AsyncRead + AsyncWrite + Unpin + Send + 'static,
-    mut right: impl AsyncRead + AsyncWrite + Unpin + Send + 'static,
-) -> JoinHandle<()> {
-    tokio::spawn(
-        async move {
-            match copy_bidirectional(&mut left, &mut right).await {
-                Ok((l_bytes, r_bytes)) => debug!("joined streams closed, bytes from tunnel: {l_bytes}, bytes from local: {r_bytes}"),
-                Err(e) => debug!("joined streams error: {e}"),
-            };
-        }
-        .in_current_span(),
-    )
-}
-
 #[cfg(feature = "hyper")]
 fn serve_gateway_error(
     err: impl fmt::Display + Send + 'static,
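
Usage note (not part of the diff): a minimal sketch of how the PROXY protocol support added above might be exercised from an application. It assumes `NGROK_AUTHTOKEN` is set and uses only names that appear in this change or in the public `ngrok` / `proxy_protocol` APIs (`tcp_endpoint`, `proxy_proto`, `listen_and_forward`, `proxy_protocol::parse`). The single `read_buf` is a simplification; a real upstream should keep reading until the header parses, which is exactly what the new `proxy_proto::Stream` does internally on the SDK side.

```rust
use bytes::BytesMut;
use ngrok::{config::ProxyProto, prelude::*, Session};
use tokio::{io::AsyncReadExt, net::TcpListener};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Local "upstream" that the ngrok endpoint forwards to.
    let listener = TcpListener::bind("127.0.0.1:0").await?;
    let addr = listener.local_addr()?;

    let sess = Session::builder().authtoken_from_env().connect().await?;
    // Per this change, the agent re-serializes the PROXY header it receives
    // from the tunnel and writes it to the upstream connection before copying
    // any application bytes.
    let _forwarder = sess
        .tcp_endpoint()
        .proxy_proto(ProxyProto::V1)
        .listen_and_forward(format!("tcp://{addr}").parse()?)
        .await?;

    let (mut conn, _) = listener.accept().await?;
    let mut buf = BytesMut::with_capacity(536);
    conn.read_buf(&mut buf).await?;
    // Simplification: assume the whole v1 header arrived in the first read;
    // `parse` consumes the header and leaves application data in `buf`.
    let header = proxy_protocol::parse(&mut buf)?;
    println!("PROXY header from tunnel connection: {header:?}");
    Ok(())
}
```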