Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport additional wasi-http changes to the 15.x release branch #7540

Merged
merged 2 commits into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions crates/test-programs/src/bin/api_proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,16 @@ impl bindings::exports::wasi::http::incoming_handler::Guest for T {
let req_hdrs = request.headers();

assert!(
!req_hdrs.get(&header).is_empty(),
"missing `custom-forbidden-header` from request"
req_hdrs.get(&header).is_empty(),
"forbidden `custom-forbidden-header` found in request"
);

assert!(req_hdrs.delete(&header).is_err());
assert!(req_hdrs.append(&header, &b"no".to_vec()).is_err());

assert!(
!req_hdrs.get(&header).is_empty(),
"delete of forbidden header succeeded"
req_hdrs.get(&header).is_empty(),
"append of forbidden header succeeded"
);

let hdrs = bindings::wasi::http::types::Headers::new();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,19 @@ fn main() {

{
let request_body = outgoing_body.write().unwrap();
request_body
let e = request_body
.blocking_write_and_flush("more than 11 bytes".as_bytes())
.expect_err("write should fail");

// TODO: show how to use http-error-code to unwrap this error
let e = match e {
test_programs::wasi::io::streams::StreamError::LastOperationFailed(e) => e,
test_programs::wasi::io::streams::StreamError::Closed => panic!("request closed"),
};

assert!(matches!(
http_types::http_error_code(&e),
Some(http_types::ErrorCode::InternalError(Some(msg)))
if msg == "too much written to output stream"));
}

let e =
Expand Down
16 changes: 8 additions & 8 deletions crates/wasi-http/src/http_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::{
outgoing_handler,
types::{self, Scheme},
},
http_request_error, internal_error,
types::{HostFutureIncomingResponse, HostOutgoingRequest, OutgoingRequest},
WasiHttpView,
};
Expand Down Expand Up @@ -77,22 +78,21 @@ impl<T: WasiHttpView> outgoing_handler::Host for T {
uri = uri.path_and_query(path);
}

builder = builder.uri(
uri.build()
.map_err(|_| types::ErrorCode::HttpRequestUriInvalid)?,
);
builder = builder.uri(uri.build().map_err(http_request_error)?);

for (k, v) in req.headers.iter() {
builder = builder.header(k, v);
}

let body = req
.body
.unwrap_or_else(|| Empty::<Bytes>::new().map_err(|_| todo!("thing")).boxed());
let body = req.body.unwrap_or_else(|| {
Empty::<Bytes>::new()
.map_err(|_| unreachable!("Infallible error"))
.boxed()
});

let request = builder
.body(body)
.map_err(|err| types::ErrorCode::InternalError(Some(err.to_string())))?;
.map_err(|err| internal_error(err.to_string()))?;

Ok(Ok(self.send_request(OutgoingRequest {
use_tls,
Expand Down
52 changes: 52 additions & 0 deletions crates/wasi-http/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ pub mod bindings {
tracing: true,
async: false,
with: {
"wasi:io/error": wasmtime_wasi::preview2::bindings::io::error,
"wasi:io/streams": wasmtime_wasi::preview2::bindings::io::streams,
"wasi:io/poll": wasmtime_wasi::preview2::bindings::io::poll,

Expand Down Expand Up @@ -47,3 +48,54 @@ pub(crate) fn dns_error(rcode: String, info_code: u16) -> bindings::http::types:
pub(crate) fn internal_error(msg: String) -> bindings::http::types::ErrorCode {
bindings::http::types::ErrorCode::InternalError(Some(msg))
}

/// Translate a [`http::Error`] to a wasi-http `ErrorCode` in the context of a request.
pub fn http_request_error(err: http::Error) -> bindings::http::types::ErrorCode {
use bindings::http::types::ErrorCode;

if err.is::<http::uri::InvalidUri>() {
return ErrorCode::HttpRequestUriInvalid;
}

tracing::warn!("http request error: {err:?}");

ErrorCode::HttpProtocolError
}

/// Translate a [`hyper::Error`] to a wasi-http `ErrorCode` in the context of a request.
pub fn hyper_request_error(err: hyper::Error) -> bindings::http::types::ErrorCode {
use bindings::http::types::ErrorCode;
use std::error::Error;

// If there's a source, we might be able to extract a wasi-http error from it.
if let Some(cause) = err.source() {
if let Some(err) = cause.downcast_ref::<ErrorCode>() {
return err.clone();
}
}

tracing::warn!("hyper request error: {err:?}");

ErrorCode::HttpProtocolError
}

/// Translate a [`hyper::Error`] to a wasi-http `ErrorCode` in the context of a response.
pub fn hyper_response_error(err: hyper::Error) -> bindings::http::types::ErrorCode {
use bindings::http::types::ErrorCode;
use std::error::Error;

if err.is_timeout() {
return ErrorCode::HttpResponseTimeout;
}

// If there's a source, we might be able to extract a wasi-http error from it.
if let Some(cause) = err.source() {
if let Some(err) = cause.downcast_ref::<ErrorCode>() {
return err.clone();
}
}

tracing::warn!("hyper response error: {err:?}");

ErrorCode::HttpProtocolError
}
88 changes: 67 additions & 21 deletions crates/wasi-http/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
use crate::{
bindings::http::types::{self, Method, Scheme},
body::{HostIncomingBody, HyperIncomingBody, HyperOutgoingBody},
dns_error,
dns_error, hyper_request_error,
};
use http_body_util::BodyExt;
use hyper::header::HeaderName;
Expand Down Expand Up @@ -35,17 +35,18 @@ pub trait WasiHttpView: Send {
fn new_incoming_request(
&mut self,
req: hyper::Request<HyperIncomingBody>,
) -> wasmtime::Result<Resource<HostIncomingRequest>> {
) -> wasmtime::Result<Resource<HostIncomingRequest>>
where
Self: Sized,
{
let (parts, body) = req.into_parts();
let body = HostIncomingBody::new(
body,
// TODO: this needs to be plumbed through
std::time::Duration::from_millis(600 * 1000),
);
Ok(self.table().push(HostIncomingRequest {
parts,
body: Some(body),
})?)
let incoming_req = HostIncomingRequest::new(self, parts, Some(body));
Ok(self.table().push(incoming_req)?)
}

fn new_response_outparam(
Expand Down Expand Up @@ -73,6 +74,41 @@ pub trait WasiHttpView: Send {
}
}

/// Returns `true` when the header is forbidden according to this [`WasiHttpView`] implementation.
pub(crate) fn is_forbidden_header(view: &mut dyn WasiHttpView, name: &HeaderName) -> bool {
static FORBIDDEN_HEADERS: [HeaderName; 9] = [
hyper::header::CONNECTION,
HeaderName::from_static("keep-alive"),
hyper::header::PROXY_AUTHENTICATE,
hyper::header::PROXY_AUTHORIZATION,
HeaderName::from_static("proxy-connection"),
hyper::header::TE,
hyper::header::TRANSFER_ENCODING,
hyper::header::UPGRADE,
HeaderName::from_static("http2-settings"),
];

FORBIDDEN_HEADERS.contains(name) || view.is_forbidden_header(name)
}

/// Removes forbidden headers from a [`hyper::HeaderMap`].
pub(crate) fn remove_forbidden_headers(
view: &mut dyn WasiHttpView,
headers: &mut hyper::HeaderMap,
) {
let forbidden_keys = Vec::from_iter(headers.keys().filter_map(|name| {
if is_forbidden_header(view, name) {
Some(name.clone())
} else {
None
}
}));

for name in forbidden_keys {
headers.remove(name);
}
}

pub fn default_send_request(
view: &mut dyn WasiHttpView,
OutgoingRequest {
Expand Down Expand Up @@ -156,20 +192,22 @@ async fn handler(
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));
let mut parts = authority.split(":");
let host = parts.next().unwrap_or(&authority);
let domain = rustls::ServerName::try_from(host)
.map_err(|_| dns_error("invalid dns name".to_string(), 0))?;
let stream = connector
.connect(domain, tcp_stream)
.await
.map_err(|_| types::ErrorCode::TlsProtocolError)?;
let domain = rustls::ServerName::try_from(host).map_err(|e| {
tracing::warn!("dns lookup error: {e:?}");
dns_error("invalid dns name".to_string(), 0)
})?;
let stream = connector.connect(domain, tcp_stream).await.map_err(|e| {
tracing::warn!("tls protocol error: {e:?}");
types::ErrorCode::TlsProtocolError
})?;

let (sender, conn) = timeout(
connect_timeout,
hyper::client::conn::http1::handshake(stream),
)
.await
.map_err(|_| types::ErrorCode::ConnectionTimeout)?
.map_err(|_| types::ErrorCode::ConnectionTimeout)?;
.map_err(hyper_request_error)?;

let worker = preview2::spawn(async move {
match conn.await {
Expand All @@ -190,7 +228,7 @@ async fn handler(
)
.await
.map_err(|_| types::ErrorCode::ConnectionTimeout)?
.map_err(|_| types::ErrorCode::HttpProtocolError)?;
.map_err(hyper_request_error)?;

let worker = preview2::spawn(async move {
match conn.await {
Expand All @@ -206,11 +244,8 @@ async fn handler(
let resp = timeout(first_byte_timeout, sender.send_request(request))
.await
.map_err(|_| types::ErrorCode::ConnectionReadTimeout)?
.map_err(|_| types::ErrorCode::HttpProtocolError)?
.map(|body| {
body.map_err(|_| types::ErrorCode::HttpProtocolError)
.boxed()
});
.map_err(hyper_request_error)?
.map(|body| body.map_err(hyper_request_error).boxed());

Ok(IncomingResponseInternal {
resp,
Expand Down Expand Up @@ -264,10 +299,21 @@ impl TryInto<http::Method> for types::Method {
}

pub struct HostIncomingRequest {
pub parts: http::request::Parts,
pub(crate) parts: http::request::Parts,
pub body: Option<HostIncomingBody>,
}

impl HostIncomingRequest {
pub fn new(
view: &mut dyn WasiHttpView,
mut parts: http::request::Parts,
body: Option<HostIncomingBody>,
) -> Self {
remove_forbidden_headers(view, &mut parts.headers);
Self { parts, body }
}
}

pub struct HostResponseOutparam {
pub result:
tokio::sync::oneshot::Sender<Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>>,
Expand Down Expand Up @@ -318,7 +364,7 @@ impl TryFrom<HostOutgoingResponse> for hyper::Response<HyperOutgoingBody> {
Some(body) => builder.body(body),
None => builder.body(
Empty::<bytes::Bytes>::new()
.map_err(|_| unreachable!())
.map_err(|_| unreachable!("Infallible error"))
.boxed(),
),
}
Expand Down
33 changes: 10 additions & 23 deletions crates/wasi-http/src/types_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ use crate::{
bindings::http::types::{self, Headers, Method, Scheme, StatusCode, Trailers},
body::{HostFutureTrailers, HostIncomingBody, HostOutgoingBody},
types::{
FieldMap, HostFields, HostFutureIncomingResponse, HostIncomingRequest,
HostIncomingResponse, HostOutgoingRequest, HostOutgoingResponse, HostResponseOutparam,
is_forbidden_header, remove_forbidden_headers, FieldMap, HostFields,
HostFutureIncomingResponse, HostIncomingRequest, HostIncomingResponse, HostOutgoingRequest,
HostOutgoingResponse, HostResponseOutparam,
},
WasiHttpView,
};
use anyhow::Context;
use hyper::header::HeaderName;
use std::any::Any;
use std::str::FromStr;
use wasmtime::component::Resource;
Expand All @@ -20,9 +20,10 @@ use wasmtime_wasi::preview2::{
impl<T: WasiHttpView> crate::bindings::http::types::Host for T {
fn http_error_code(
&mut self,
_err: wasmtime::component::Resource<types::IoError>,
err: wasmtime::component::Resource<types::IoError>,
) -> wasmtime::Result<Option<types::ErrorCode>> {
todo!()
let e = self.table().get(&err)?;
Ok(e.downcast_ref::<types::ErrorCode>().cloned())
}
}

Expand Down Expand Up @@ -88,22 +89,6 @@ fn get_fields_mut<'a>(
}
}

fn is_forbidden_header<T: WasiHttpView>(view: &mut T, name: &HeaderName) -> bool {
static FORBIDDEN_HEADERS: [HeaderName; 9] = [
hyper::header::CONNECTION,
HeaderName::from_static("keep-alive"),
hyper::header::PROXY_AUTHENTICATE,
hyper::header::PROXY_AUTHORIZATION,
HeaderName::from_static("proxy-connection"),
hyper::header::TE,
hyper::header::TRANSFER_ENCODING,
hyper::header::UPGRADE,
HeaderName::from_static("http2-settings"),
];

FORBIDDEN_HEADERS.contains(name) || view.is_forbidden_header(name)
}

impl<T: WasiHttpView> crate::bindings::http::types::HostFields for T {
fn new(&mut self) -> wasmtime::Result<Resource<HostFields>> {
let id = self
Expand Down Expand Up @@ -833,11 +818,13 @@ impl<T: WasiHttpView> crate::bindings::http::types::HostFutureIncomingResponse f
Ok(Err(e)) => return Ok(Some(Ok(Err(e)))),
};

let (parts, body) = resp.resp.into_parts();
let (mut parts, body) = resp.resp.into_parts();

remove_forbidden_headers(self, &mut parts.headers);

let resp = self.table().push(HostIncomingResponse {
status: parts.status.as_u16(),
headers: FieldMap::from(parts.headers),
headers: parts.headers,
body: Some({
let mut body = HostIncomingBody::new(body, resp.between_bytes_timeout);
body.retain_worker(&resp.worker);
Expand Down
Loading