Skip to content

Commit

Permalink
Backport additional wasi-http changes to the 15.x release branch (#7540)
Browse files Browse the repository at this point in the history
* wasi-http: Implement http-error-code, and centralize error conversions (#7534)
* Filter out forbidden headers on incoming request and response resources (#7538)
  • Loading branch information
elliottt committed Nov 15, 2023
1 parent bb8eec8 commit de1f24e
Show file tree
Hide file tree
Showing 7 changed files with 156 additions and 69 deletions.
9 changes: 5 additions & 4 deletions crates/test-programs/src/bin/api_proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,16 @@ impl bindings::exports::wasi::http::incoming_handler::Guest for T {
let req_hdrs = request.headers();

assert!(
!req_hdrs.get(&header).is_empty(),
"missing `custom-forbidden-header` from request"
req_hdrs.get(&header).is_empty(),
"forbidden `custom-forbidden-header` found in request"
);

assert!(req_hdrs.delete(&header).is_err());
assert!(req_hdrs.append(&header, &b"no".to_vec()).is_err());

assert!(
!req_hdrs.get(&header).is_empty(),
"delete of forbidden header succeeded"
req_hdrs.get(&header).is_empty(),
"append of forbidden header succeeded"
);

let hdrs = bindings::wasi::http::types::Headers::new();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,19 @@ fn main() {

{
let request_body = outgoing_body.write().unwrap();
request_body
let e = request_body
.blocking_write_and_flush("more than 11 bytes".as_bytes())
.expect_err("write should fail");

// TODO: show how to use http-error-code to unwrap this error
let e = match e {
test_programs::wasi::io::streams::StreamError::LastOperationFailed(e) => e,
test_programs::wasi::io::streams::StreamError::Closed => panic!("request closed"),
};

assert!(matches!(
http_types::http_error_code(&e),
Some(http_types::ErrorCode::InternalError(Some(msg)))
if msg == "too much written to output stream"));
}

let e =
Expand Down
16 changes: 8 additions & 8 deletions crates/wasi-http/src/http_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::{
outgoing_handler,
types::{self, Scheme},
},
http_request_error, internal_error,
types::{HostFutureIncomingResponse, HostOutgoingRequest, OutgoingRequest},
WasiHttpView,
};
Expand Down Expand Up @@ -77,22 +78,21 @@ impl<T: WasiHttpView> outgoing_handler::Host for T {
uri = uri.path_and_query(path);
}

builder = builder.uri(
uri.build()
.map_err(|_| types::ErrorCode::HttpRequestUriInvalid)?,
);
builder = builder.uri(uri.build().map_err(http_request_error)?);

for (k, v) in req.headers.iter() {
builder = builder.header(k, v);
}

let body = req
.body
.unwrap_or_else(|| Empty::<Bytes>::new().map_err(|_| todo!("thing")).boxed());
let body = req.body.unwrap_or_else(|| {
Empty::<Bytes>::new()
.map_err(|_| unreachable!("Infallible error"))
.boxed()
});

let request = builder
.body(body)
.map_err(|err| types::ErrorCode::InternalError(Some(err.to_string())))?;
.map_err(|err| internal_error(err.to_string()))?;

Ok(Ok(self.send_request(OutgoingRequest {
use_tls,
Expand Down
52 changes: 52 additions & 0 deletions crates/wasi-http/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ pub mod bindings {
tracing: true,
async: false,
with: {
"wasi:io/error": wasmtime_wasi::preview2::bindings::io::error,
"wasi:io/streams": wasmtime_wasi::preview2::bindings::io::streams,
"wasi:io/poll": wasmtime_wasi::preview2::bindings::io::poll,

Expand Down Expand Up @@ -47,3 +48,54 @@ pub(crate) fn dns_error(rcode: String, info_code: u16) -> bindings::http::types:
pub(crate) fn internal_error(msg: String) -> bindings::http::types::ErrorCode {
bindings::http::types::ErrorCode::InternalError(Some(msg))
}

/// Translate a [`http::Error`] to a wasi-http `ErrorCode` in the context of a request.
pub fn http_request_error(err: http::Error) -> bindings::http::types::ErrorCode {
use bindings::http::types::ErrorCode;

if err.is::<http::uri::InvalidUri>() {
return ErrorCode::HttpRequestUriInvalid;
}

tracing::warn!("http request error: {err:?}");

ErrorCode::HttpProtocolError
}

/// Translate a [`hyper::Error`] to a wasi-http `ErrorCode` in the context of a request.
pub fn hyper_request_error(err: hyper::Error) -> bindings::http::types::ErrorCode {
use bindings::http::types::ErrorCode;
use std::error::Error;

// If there's a source, we might be able to extract a wasi-http error from it.
if let Some(cause) = err.source() {
if let Some(err) = cause.downcast_ref::<ErrorCode>() {
return err.clone();
}
}

tracing::warn!("hyper request error: {err:?}");

ErrorCode::HttpProtocolError
}

/// Translate a [`hyper::Error`] to a wasi-http `ErrorCode` in the context of a response.
pub fn hyper_response_error(err: hyper::Error) -> bindings::http::types::ErrorCode {
use bindings::http::types::ErrorCode;
use std::error::Error;

if err.is_timeout() {
return ErrorCode::HttpResponseTimeout;
}

// If there's a source, we might be able to extract a wasi-http error from it.
if let Some(cause) = err.source() {
if let Some(err) = cause.downcast_ref::<ErrorCode>() {
return err.clone();
}
}

tracing::warn!("hyper response error: {err:?}");

ErrorCode::HttpProtocolError
}
88 changes: 67 additions & 21 deletions crates/wasi-http/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
use crate::{
bindings::http::types::{self, Method, Scheme},
body::{HostIncomingBody, HyperIncomingBody, HyperOutgoingBody},
dns_error,
dns_error, hyper_request_error,
};
use http_body_util::BodyExt;
use hyper::header::HeaderName;
Expand Down Expand Up @@ -35,17 +35,18 @@ pub trait WasiHttpView: Send {
fn new_incoming_request(
&mut self,
req: hyper::Request<HyperIncomingBody>,
) -> wasmtime::Result<Resource<HostIncomingRequest>> {
) -> wasmtime::Result<Resource<HostIncomingRequest>>
where
Self: Sized,
{
let (parts, body) = req.into_parts();
let body = HostIncomingBody::new(
body,
// TODO: this needs to be plumbed through
std::time::Duration::from_millis(600 * 1000),
);
Ok(self.table().push(HostIncomingRequest {
parts,
body: Some(body),
})?)
let incoming_req = HostIncomingRequest::new(self, parts, Some(body));
Ok(self.table().push(incoming_req)?)
}

fn new_response_outparam(
Expand Down Expand Up @@ -73,6 +74,41 @@ pub trait WasiHttpView: Send {
}
}

/// Returns `true` when the header is forbidden according to this [`WasiHttpView`] implementation.
pub(crate) fn is_forbidden_header(view: &mut dyn WasiHttpView, name: &HeaderName) -> bool {
static FORBIDDEN_HEADERS: [HeaderName; 9] = [
hyper::header::CONNECTION,
HeaderName::from_static("keep-alive"),
hyper::header::PROXY_AUTHENTICATE,
hyper::header::PROXY_AUTHORIZATION,
HeaderName::from_static("proxy-connection"),
hyper::header::TE,
hyper::header::TRANSFER_ENCODING,
hyper::header::UPGRADE,
HeaderName::from_static("http2-settings"),
];

FORBIDDEN_HEADERS.contains(name) || view.is_forbidden_header(name)
}

/// Removes forbidden headers from a [`hyper::HeaderMap`].
pub(crate) fn remove_forbidden_headers(
view: &mut dyn WasiHttpView,
headers: &mut hyper::HeaderMap,
) {
let forbidden_keys = Vec::from_iter(headers.keys().filter_map(|name| {
if is_forbidden_header(view, name) {
Some(name.clone())
} else {
None
}
}));

for name in forbidden_keys {
headers.remove(name);
}
}

pub fn default_send_request(
view: &mut dyn WasiHttpView,
OutgoingRequest {
Expand Down Expand Up @@ -156,20 +192,22 @@ async fn handler(
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));
let mut parts = authority.split(":");
let host = parts.next().unwrap_or(&authority);
let domain = rustls::ServerName::try_from(host)
.map_err(|_| dns_error("invalid dns name".to_string(), 0))?;
let stream = connector
.connect(domain, tcp_stream)
.await
.map_err(|_| types::ErrorCode::TlsProtocolError)?;
let domain = rustls::ServerName::try_from(host).map_err(|e| {
tracing::warn!("dns lookup error: {e:?}");
dns_error("invalid dns name".to_string(), 0)
})?;
let stream = connector.connect(domain, tcp_stream).await.map_err(|e| {
tracing::warn!("tls protocol error: {e:?}");
types::ErrorCode::TlsProtocolError
})?;

let (sender, conn) = timeout(
connect_timeout,
hyper::client::conn::http1::handshake(stream),
)
.await
.map_err(|_| types::ErrorCode::ConnectionTimeout)?
.map_err(|_| types::ErrorCode::ConnectionTimeout)?;
.map_err(hyper_request_error)?;

let worker = preview2::spawn(async move {
match conn.await {
Expand All @@ -190,7 +228,7 @@ async fn handler(
)
.await
.map_err(|_| types::ErrorCode::ConnectionTimeout)?
.map_err(|_| types::ErrorCode::HttpProtocolError)?;
.map_err(hyper_request_error)?;

let worker = preview2::spawn(async move {
match conn.await {
Expand All @@ -206,11 +244,8 @@ async fn handler(
let resp = timeout(first_byte_timeout, sender.send_request(request))
.await
.map_err(|_| types::ErrorCode::ConnectionReadTimeout)?
.map_err(|_| types::ErrorCode::HttpProtocolError)?
.map(|body| {
body.map_err(|_| types::ErrorCode::HttpProtocolError)
.boxed()
});
.map_err(hyper_request_error)?
.map(|body| body.map_err(hyper_request_error).boxed());

Ok(IncomingResponseInternal {
resp,
Expand Down Expand Up @@ -264,10 +299,21 @@ impl TryInto<http::Method> for types::Method {
}

pub struct HostIncomingRequest {
pub parts: http::request::Parts,
pub(crate) parts: http::request::Parts,
pub body: Option<HostIncomingBody>,
}

impl HostIncomingRequest {
pub fn new(
view: &mut dyn WasiHttpView,
mut parts: http::request::Parts,
body: Option<HostIncomingBody>,
) -> Self {
remove_forbidden_headers(view, &mut parts.headers);
Self { parts, body }
}
}

pub struct HostResponseOutparam {
pub result:
tokio::sync::oneshot::Sender<Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>>,
Expand Down Expand Up @@ -318,7 +364,7 @@ impl TryFrom<HostOutgoingResponse> for hyper::Response<HyperOutgoingBody> {
Some(body) => builder.body(body),
None => builder.body(
Empty::<bytes::Bytes>::new()
.map_err(|_| unreachable!())
.map_err(|_| unreachable!("Infallible error"))
.boxed(),
),
}
Expand Down
33 changes: 10 additions & 23 deletions crates/wasi-http/src/types_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ use crate::{
bindings::http::types::{self, Headers, Method, Scheme, StatusCode, Trailers},
body::{HostFutureTrailers, HostIncomingBody, HostOutgoingBody},
types::{
FieldMap, HostFields, HostFutureIncomingResponse, HostIncomingRequest,
HostIncomingResponse, HostOutgoingRequest, HostOutgoingResponse, HostResponseOutparam,
is_forbidden_header, remove_forbidden_headers, FieldMap, HostFields,
HostFutureIncomingResponse, HostIncomingRequest, HostIncomingResponse, HostOutgoingRequest,
HostOutgoingResponse, HostResponseOutparam,
},
WasiHttpView,
};
use anyhow::Context;
use hyper::header::HeaderName;
use std::any::Any;
use std::str::FromStr;
use wasmtime::component::Resource;
Expand All @@ -20,9 +20,10 @@ use wasmtime_wasi::preview2::{
impl<T: WasiHttpView> crate::bindings::http::types::Host for T {
fn http_error_code(
&mut self,
_err: wasmtime::component::Resource<types::IoError>,
err: wasmtime::component::Resource<types::IoError>,
) -> wasmtime::Result<Option<types::ErrorCode>> {
todo!()
let e = self.table().get(&err)?;
Ok(e.downcast_ref::<types::ErrorCode>().cloned())
}
}

Expand Down Expand Up @@ -88,22 +89,6 @@ fn get_fields_mut<'a>(
}
}

fn is_forbidden_header<T: WasiHttpView>(view: &mut T, name: &HeaderName) -> bool {
static FORBIDDEN_HEADERS: [HeaderName; 9] = [
hyper::header::CONNECTION,
HeaderName::from_static("keep-alive"),
hyper::header::PROXY_AUTHENTICATE,
hyper::header::PROXY_AUTHORIZATION,
HeaderName::from_static("proxy-connection"),
hyper::header::TE,
hyper::header::TRANSFER_ENCODING,
hyper::header::UPGRADE,
HeaderName::from_static("http2-settings"),
];

FORBIDDEN_HEADERS.contains(name) || view.is_forbidden_header(name)
}

impl<T: WasiHttpView> crate::bindings::http::types::HostFields for T {
fn new(&mut self) -> wasmtime::Result<Resource<HostFields>> {
let id = self
Expand Down Expand Up @@ -833,11 +818,13 @@ impl<T: WasiHttpView> crate::bindings::http::types::HostFutureIncomingResponse f
Ok(Err(e)) => return Ok(Some(Ok(Err(e)))),
};

let (parts, body) = resp.resp.into_parts();
let (mut parts, body) = resp.resp.into_parts();

remove_forbidden_headers(self, &mut parts.headers);

let resp = self.table().push(HostIncomingResponse {
status: parts.status.as_u16(),
headers: FieldMap::from(parts.headers),
headers: parts.headers,
body: Some({
let mut body = HostIncomingBody::new(body, resp.between_bytes_timeout);
body.retain_worker(&resp.worker);
Expand Down
Loading

0 comments on commit de1f24e

Please sign in to comment.