Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

h2-support,h2-tests: add tools to ensure wake #794

Merged
merged 1 commit into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 137 additions & 4 deletions tests/h2-support/src/future_ext.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use futures::FutureExt;
use futures::{FutureExt, TryFuture};
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

/// Future extension helpers that are useful for tests
pub trait TestFuture: Future {
Expand All @@ -15,9 +17,140 @@ pub trait TestFuture: Future {
{
Drive {
driver: self,
future: Box::pin(other),
future: other.wakened(),
}
}

fn wakened(self) -> Wakened<Self>
where
Self: Sized,
{
Wakened {
future: Box::pin(self),
woken: Arc::new(AtomicBool::new(true)),
}
}
}

/// Wraps futures::future::join to ensure that the futures are only polled if they are woken.
pub fn join<Fut1, Fut2>(
future1: Fut1,
future2: Fut2,
) -> futures::future::Join<Wakened<Fut1>, Wakened<Fut2>>
where
Fut1: Future,
Fut2: Future,
{
futures::future::join(future1.wakened(), future2.wakened())
}

/// Wraps futures::future::join3 to ensure that the futures are only polled if they are woken.
pub fn join3<Fut1, Fut2, Fut3>(
future1: Fut1,
future2: Fut2,
future3: Fut3,
) -> futures::future::Join3<Wakened<Fut1>, Wakened<Fut2>, Wakened<Fut3>>
where
Fut1: Future,
Fut2: Future,
Fut3: Future,
{
futures::future::join3(future1.wakened(), future2.wakened(), future3.wakened())
}

/// Wraps futures::future::join4 to ensure that the futures are only polled if they are woken.
pub fn join4<Fut1, Fut2, Fut3, Fut4>(
future1: Fut1,
future2: Fut2,
future3: Fut3,
future4: Fut4,
) -> futures::future::Join4<Wakened<Fut1>, Wakened<Fut2>, Wakened<Fut3>, Wakened<Fut4>>
where
Fut1: Future,
Fut2: Future,
Fut3: Future,
Fut4: Future,
{
futures::future::join4(
future1.wakened(),
future2.wakened(),
future3.wakened(),
future4.wakened(),
)
}

/// Wraps futures::future::try_join to ensure that the futures are only polled if they are woken.
pub fn try_join<Fut1, Fut2>(
future1: Fut1,
future2: Fut2,
) -> futures::future::TryJoin<Wakened<Fut1>, Wakened<Fut2>>
where
Fut1: futures::future::TryFuture + Future,
Fut2: Future,
Wakened<Fut1>: futures::future::TryFuture,
Wakened<Fut2>: futures::future::TryFuture<Error = <Wakened<Fut1> as TryFuture>::Error>,
{
futures::future::try_join(future1.wakened(), future2.wakened())
}

/// Wraps futures::future::select to ensure that the futures are only polled if they are woken.
pub fn select<A, B>(future1: A, future2: B) -> futures::future::Select<Wakened<A>, Wakened<B>>
where
A: Future + Unpin,
B: Future + Unpin,
{
futures::future::select(future1.wakened(), future2.wakened())
}

/// Wraps futures::future::join_all to ensure that the futures are only polled if they are woken.
pub fn join_all<I>(iter: I) -> futures::future::JoinAll<Wakened<I::Item>>
where
I: IntoIterator,
I::Item: Future,
{
futures::future::join_all(iter.into_iter().map(|f| f.wakened()))
}

/// A future that only polls the inner future if it has been woken (after the initial poll).
pub struct Wakened<T> {
future: Pin<Box<T>>,
woken: Arc<AtomicBool>,
}

/// A future that only polls the inner future if it has been woken (after the initial poll).
impl<T> Future for Wakened<T>
where
T: Future,
{
type Output = T::Output;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
if !this.woken.load(std::sync::atomic::Ordering::SeqCst) {
return Poll::Pending;
}
this.woken.store(false, std::sync::atomic::Ordering::SeqCst);
let my_waker = IfWokenWaker {
inner: cx.waker().clone(),
wakened: this.woken.clone(),
};
let my_waker = Arc::new(my_waker).into();
let mut cx = Context::from_waker(&my_waker);
this.future.as_mut().poll(&mut cx)
}
}

impl Wake for IfWokenWaker {
fn wake(self: Arc<Self>) {
self.wakened
.store(true, std::sync::atomic::Ordering::SeqCst);
self.inner.wake_by_ref();
}
}

struct IfWokenWaker {
inner: Waker,
wakened: Arc<AtomicBool>,
}

impl<T: Future> TestFuture for T {}
Expand All @@ -29,7 +162,7 @@ impl<T: Future> TestFuture for T {}
/// This is useful for H2 futures that also require the connection to be polled.
pub struct Drive<'a, T, U> {
driver: &'a mut T,
future: Pin<Box<U>>,
future: Wakened<U>,
}

impl<'a, T, U> Future for Drive<'a, T, U>
Expand Down
2 changes: 1 addition & 1 deletion tests/h2-support/src/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ pub use {bytes, futures, http, tokio::io as tokio_io, tracing, tracing_subscribe
pub use futures::{Future, Sink, Stream};

// And our Future extensions
pub use super::future_ext::TestFuture;
pub use super::future_ext::{join, join3, join4, join_all, select, try_join, TestFuture};

// Our client_ext helpers
pub use super::client_ext::SendRequestExt;
Expand Down
21 changes: 7 additions & 14 deletions tests/h2-tests/tests/client_request.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use futures::future::{join, join_all, ready, select, Either};
use futures::future::{ready, Either};
use futures::stream::FuturesUnordered;
use futures::StreamExt;
use h2_support::prelude::*;
Expand Down Expand Up @@ -849,7 +849,7 @@ async fn recv_too_big_headers() {
};

let client = async move {
let (mut client, conn) = client::Builder::new()
let (mut client, mut conn) = client::Builder::new()
.max_header_list_size(10)
.handshake::<_, Bytes>(io)
.await
Expand All @@ -863,30 +863,23 @@ async fn recv_too_big_headers() {
let req1 = client.send_request(request, true);
// Spawn tasks to ensure that the error wakes up tasks that are blocked
// waiting for a response.
let req1 = tokio::spawn(async move {
let req1 = async move {
let err = req1.expect("send_request").0.await.expect_err("response1");
assert_eq!(err.reason(), Some(Reason::PROTOCOL_ERROR));
});
};

let request = Request::builder()
.uri("https://http2.akamai.com/")
.body(())
.unwrap();

let req2 = client.send_request(request, true);
let req2 = tokio::spawn(async move {
let req2 = async move {
let err = req2.expect("send_request").0.await.expect_err("response2");
assert_eq!(err.reason(), Some(Reason::PROTOCOL_ERROR));
});
};

let conn = tokio::spawn(async move {
conn.await.expect("client");
});
for err in join_all([req1, req2, conn]).await {
if let Some(err) = err.err().and_then(|err| err.try_into_panic().ok()) {
std::panic::resume_unwind(err);
}
}
conn.drive(join(req1, req2)).await;
};

join(srv, client).await;
Expand Down
1 change: 0 additions & 1 deletion tests/h2-tests/tests/codec_read.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use futures::future::join;
use h2_support::prelude::*;

#[tokio::test]
Expand Down
1 change: 0 additions & 1 deletion tests/h2-tests/tests/codec_write.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use futures::future::join;
use h2_support::prelude::*;

#[tokio::test]
Expand Down
1 change: 0 additions & 1 deletion tests/h2-tests/tests/flow_control.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use futures::future::{join, join4};
use futures::{StreamExt, TryStreamExt};
use h2_support::prelude::*;
use h2_support::util::yield_once;
Expand Down
1 change: 0 additions & 1 deletion tests/h2-tests/tests/ping_pong.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use futures::channel::oneshot;
use futures::future::join;
use futures::StreamExt;
use h2_support::assert_ping;
use h2_support::prelude::*;
Expand Down
1 change: 0 additions & 1 deletion tests/h2-tests/tests/prioritization.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use futures::future::{join, select};
use futures::{pin_mut, FutureExt, StreamExt};

use h2_support::prelude::*;
Expand Down
16 changes: 4 additions & 12 deletions tests/h2-tests/tests/push_promise.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
use std::iter::FromIterator;

use futures::{future::join, FutureExt as _, StreamExt, TryStreamExt};
use futures::{StreamExt, TryStreamExt};
use h2_support::prelude::*;

#[tokio::test]
Expand Down Expand Up @@ -52,15 +50,9 @@ async fn recv_push_works() {
let ps: Vec<_> = p.collect().await;
assert_eq!(1, ps.len())
};
// Use a FuturesUnordered to poll both tasks but only poll them
// if they have been notified.
let tasks = futures::stream::FuturesUnordered::from_iter([
check_resp_status.boxed(),
check_pushed_response.boxed(),
])
.collect::<()>();

h2.drive(tasks).await;

h2.drive(join(check_resp_status, check_pushed_response))
.await;
};

join(mock, h2).await;
Expand Down
1 change: 0 additions & 1 deletion tests/h2-tests/tests/server.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#![deny(warnings)]

use futures::future::join;
use futures::StreamExt;
use h2_support::prelude::*;
use tokio::io::AsyncWriteExt;
Expand Down
2 changes: 1 addition & 1 deletion tests/h2-tests/tests/stream_states.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#![deny(warnings)]

use futures::future::{join, join3, lazy, try_join};
use futures::future::lazy;
use futures::{FutureExt, StreamExt, TryStreamExt};
use h2_support::prelude::*;
use h2_support::util::yield_once;
Expand Down