Skip to content

Commit

Permalink
feat(tonic): add h2::Error as a source for Status (#612)
Browse files Browse the repository at this point in the history
## Motivation

A gRPC server may send a HTTP/2 GOAWAY frame with NO_ERROR status to gracefully shutdown a connection. This appears to Tonic users as a `tonic::Status` with `Code::Internal` and the message set to `h2 protocol error: protocol error: not a result of an error`.

The only way to currently detect this case and differentiate it from other internal errors (e.g., an application-level internal error) is to match on the message. A client may want to differentiate these cases because it may only want to alert on the application-level internal error and not on the transient transport-level issue. (Indeed, this is the use case for which I'm envisioning using this change.)

Matching on a message is not as robust, however, as matching on an `h2::Error` and its reason code. (The message could change for example if a future version of Tonic decided to vary the message. This would break any users that matched on the previous version of the message.)

## Solution

Store the `h2::Error` used when creating a `tonic::Status` from a `h2::Error` and provide it as the `source` for purposes of `std::error::Error`. This will allow users to downcast it and match on the original `h2::Reason`.
  • Loading branch information
Tom Dyas committed Jun 23, 2021
1 parent 12815d0 commit b90bb7b
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 46 deletions.
2 changes: 1 addition & 1 deletion tonic/src/client/grpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ impl<T> Grpc<T> {
.inner
.call(request)
.await
.map_err(|err| Status::from_error(&*(err.into())))?;
.map_err(|err| Status::from_error(err.into()))?;

let status_code = response.status();
let trailers_only_status = Status::from_header_map(response.headers());
Expand Down
8 changes: 4 additions & 4 deletions tonic/src/codec/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,9 @@ impl<T> Streaming<T> {
// them manually.
let map = future::poll_fn(|cx| Pin::new(&mut self.body).poll_trailers(cx))
.await
.map_err(|e| Status::from_error(&e))?;
.map_err(|e| Status::from_error(Box::new(e)));

Ok(map.map(MetadataMap::from_headers))
map.map(|x| x.map(MetadataMap::from_headers))
}

fn decode_chunk(&mut self) -> Result<Option<T>, Status> {
Expand Down Expand Up @@ -232,7 +232,7 @@ impl<T> Stream for Streaming<T> {
Some(Err(e)) => {
let err: crate::Error = e.into();
debug!("decoder inner stream error: {:?}", err);
let status = Status::from_error(&*err);
let status = Status::from_error(err);
return Poll::Ready(Some(Err(status)));
}
None => None,
Expand Down Expand Up @@ -266,7 +266,7 @@ impl<T> Stream for Streaming<T> {
Err(e) => {
let err: crate::Error = e.into();
debug!("decoder inner trailers error: {:?}", err);
let status = Status::from_error(&*err);
let status = Status::from_error(err);
return Some(Err(status)).into();
}
}
Expand Down
2 changes: 1 addition & 1 deletion tonic/src/codec/prost.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ mod tests {

let msg = Vec::from(&[0u8; 1024][..]);

let messages = std::iter::repeat(Ok::<_, Status>(msg)).take(10000);
let messages = std::iter::repeat_with(move || Ok::<_, Status>(msg.clone())).take(10000);
let source = futures_util::stream::iter(messages);

let body = encode_server(encoder, source);
Expand Down
118 changes: 83 additions & 35 deletions tonic/src/status.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ const GRPC_STATUS_DETAILS_HEADER: &str = "grpc-status-details-bin";
/// assert_eq!(status1.code(), Code::InvalidArgument);
/// assert_eq!(status1.code(), status2.code());
/// ```
#[derive(Clone)]
pub struct Status {
/// The gRPC status code, found in the `grpc-status` header.
code: Code,
Expand All @@ -45,6 +44,8 @@ pub struct Status {
/// If the metadata contains any headers with names reserved either by the gRPC spec
/// or by `Status` fields above, they will be ignored.
metadata: MetadataMap,
/// Optional underlying error.
source: Option<Box<dyn Error + Send + Sync + 'static>>,
}

/// gRPC status codes used by [`Status`].
Expand Down Expand Up @@ -162,6 +163,7 @@ impl Status {
message: message.into(),
details: Bytes::new(),
metadata: MetadataMap::new(),
source: None,
}
}

Expand Down Expand Up @@ -302,38 +304,34 @@ impl Status {
}

#[cfg_attr(not(feature = "transport"), allow(dead_code))]
pub(crate) fn from_error(err: &(dyn Error + 'static)) -> Status {
Status::try_from_error(err).unwrap_or_else(|| Status::new(Code::Unknown, err.to_string()))
pub(crate) fn from_error(err: Box<dyn Error + Send + Sync + 'static>) -> Status {
Status::try_from_error(err)
.unwrap_or_else(|err| Status::new(Code::Unknown, err.to_string()))
}

pub(crate) fn try_from_error(err: &(dyn Error + 'static)) -> Option<Status> {
let mut cause = Some(err);

while let Some(err) = cause {
if let Some(status) = err.downcast_ref::<Status>() {
return Some(Status {
code: status.code,
message: status.message.clone(),
details: status.details.clone(),
metadata: status.metadata.clone(),
});
pub(crate) fn try_from_error(
err: Box<dyn Error + Send + Sync + 'static>,
) -> Result<Status, Box<dyn Error + Send + Sync + 'static>> {
let err = match err.downcast::<Status>() {
Ok(status) => {
return Ok(*status);
}
Err(err) => err,
};

#[cfg(feature = "transport")]
{
if let Some(h2) = err.downcast_ref::<h2::Error>() {
return Some(Status::from_h2_error(h2));
}

if let Some(timeout) = err.downcast_ref::<crate::transport::TimeoutExpired>() {
return Some(Status::cancelled(timeout.to_string()));
}
#[cfg(feature = "transport")]
let err = match err.downcast::<h2::Error>() {
Ok(h2) => {
return Ok(Status::from_h2_error(&*h2));
}
Err(err) => err,
};

cause = err.source();
if let Some(status) = find_status_in_source_chain(&*err) {
return Ok(status);
}

None
Err(err)
}

// FIXME: bubble this into `transport` and expose generic http2 reasons.
Expand All @@ -356,7 +354,13 @@ impl Status {
_ => Code::Unknown,
};

Status::new(code, format!("h2 protocol error: {}", err))
let mut status = Self::new(code, format!("h2 protocol error: {}", err));
let error = err
.reason()
.map(h2::Error::from)
.map(|err| Box::new(err) as Box<dyn Error + Send + Sync + 'static>);
status.source = error;
status
}

#[cfg(feature = "transport")]
Expand All @@ -374,7 +378,8 @@ impl Status {
where
E: Into<Box<dyn Error + Send + Sync>>,
{
Status::from_error(&*err.into())
let err: Box<dyn Error + Send + Sync> = err.into();
Status::from_error(err)
}

/// Extract a `Status` from a hyper `HeaderMap`.
Expand Down Expand Up @@ -410,6 +415,7 @@ impl Status {
message,
details,
metadata: MetadataMap::from_headers(other_headers),
source: None,
},
Err(err) => {
warn!("Error deserializing status message header: {}", err);
Expand All @@ -418,6 +424,7 @@ impl Status {
message: format!("Error deserializing status message header: {}", err),
details,
metadata: MetadataMap::from_headers(other_headers),
source: None,
}
}
}
Expand Down Expand Up @@ -505,6 +512,7 @@ impl Status {
message: message.into(),
details,
metadata,
source: None,
}
}

Expand All @@ -524,6 +532,32 @@ impl Status {
}
}

fn find_status_in_source_chain(err: &(dyn Error + 'static)) -> Option<Status> {
let mut source = Some(err);

while let Some(err) = source {
if let Some(status) = err.downcast_ref::<Status>() {
return Some(Status {
code: status.code,
message: status.message.clone(),
details: status.details.clone(),
metadata: status.metadata.clone(),
// Since `Status` is not `Clone`, any `source` on the original Status
// cannot be cloned so must remain with the original `Status`.
source: None,
});
}

#[cfg(feature = "transport")]
if let Some(timeout) = err.downcast_ref::<crate::transport::TimeoutExpired>() {
return Some(Status::cancelled(timeout.to_string()));
}

source = err.source();
}

None
}
impl fmt::Debug for Status {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
// A manual impl to reduce the noise of frequently empty fields.
Expand All @@ -543,6 +577,8 @@ impl fmt::Debug for Status {
builder.field("metadata", &self.metadata);
}

builder.field("source", &self.source);

builder.finish()
}
}
Expand Down Expand Up @@ -609,7 +645,11 @@ impl fmt::Display for Status {
}
}

impl Error for Status {}
impl Error for Status {
fn source(&self) -> Option<&(dyn Error + 'static)> {
self.source.as_ref().map(|err| (&**err) as _)
}
}

///
/// Take the `Status` value from `trailers` if it is available, else from `status_code`.
Expand Down Expand Up @@ -775,25 +815,25 @@ mod tests {
#[test]
fn from_error_status() {
let orig = Status::new(Code::OutOfRange, "weeaboo");
let found = Status::from_error(&orig);
let found = Status::from_error(Box::new(orig));

assert_eq!(orig.code(), found.code());
assert_eq!(orig.message(), found.message());
assert_eq!(found.code(), Code::OutOfRange);
assert_eq!(found.message(), "weeaboo");
}

#[test]
fn from_error_unknown() {
let orig: Error = "peek-a-boo".into();
let found = Status::from_error(&*orig);
let found = Status::from_error(orig);

assert_eq!(found.code(), Code::Unknown);
assert_eq!(found.message(), orig.to_string());
assert_eq!(found.message(), "peek-a-boo".to_string());
}

#[test]
fn from_error_nested() {
let orig = Nested(Box::new(Status::new(Code::OutOfRange, "weeaboo")));
let found = Status::from_error(&orig);
let found = Status::from_error(Box::new(orig));

assert_eq!(found.code(), Code::OutOfRange);
assert_eq!(found.message(), "weeaboo");
Expand All @@ -802,10 +842,18 @@ mod tests {
#[test]
#[cfg(feature = "transport")]
fn from_error_h2() {
use std::error::Error as _;

let orig = h2::Error::from(h2::Reason::CANCEL);
let found = Status::from_error(&orig);
let found = Status::from_error(Box::new(orig));

assert_eq!(found.code(), Code::Cancelled);

let source = found
.source()
.and_then(|err| err.downcast_ref::<h2::Error>())
.unwrap();
assert_eq!(source.reason(), Some(h2::Reason::CANCEL));
}

#[test]
Expand Down
9 changes: 4 additions & 5 deletions tonic/src/transport/server/recover_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,14 @@ where
let response = response.map(MaybeEmptyBody::full);
Poll::Ready(Ok(response))
}
Err(err) => {
if let Some(status) = Status::try_from_error(&*err) {
Err(err) => match Status::try_from_error(err) {
Ok(status) => {
let mut res = Response::new(MaybeEmptyBody::empty());
status.add_header(res.headers_mut()).unwrap();
Poll::Ready(Ok(res))
} else {
Poll::Ready(Err(err))
}
}
Err(err) => Poll::Ready(Err(err)),
},
}
}
}
Expand Down

0 comments on commit b90bb7b

Please sign in to comment.