From 24cca65c7ab4677b399a25038721b64cb9cd4da7 Mon Sep 17 00:00:00 2001 From: Collin Styles Date: Sat, 21 Oct 2023 14:59:20 -0700 Subject: [PATCH] Add `TryAny` adapter --- futures-util/src/stream/try_stream/mod.rs | 33 +++++++ futures-util/src/stream/try_stream/try_any.rs | 98 +++++++++++++++++++ futures/tests/stream_try_stream.rs | 22 +++++ 3 files changed, 153 insertions(+) create mode 100644 futures-util/src/stream/try_stream/try_any.rs diff --git a/futures-util/src/stream/try_stream/mod.rs b/futures-util/src/stream/try_stream/mod.rs index 7cb9f3e1c..5d5702f36 100644 --- a/futures-util/src/stream/try_stream/mod.rs +++ b/futures-util/src/stream/try_stream/mod.rs @@ -170,6 +170,10 @@ mod try_all; #[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411 pub use self::try_all::TryAll; +mod try_any; +#[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411 +pub use self::try_any::TryAny; + impl TryStreamExt for S {} /// Adapters specific to `Result`-returning streams @@ -1215,4 +1219,33 @@ pub trait TryStreamExt: TryStream { { assert_future::, _>(TryAll::new(self, f)) } + + /// Attempt to execute a predicate over an asynchronous stream and evaluate if any items + /// satisfy the predicate. Exits early if an `Err` is encountered or if an `Ok` item is found + /// that satisfies the predicate. + /// + /// # Examples + /// + /// ``` + /// # futures::executor::block_on(async { + /// use futures::stream::{self, StreamExt, TryStreamExt}; + /// use std::convert::Infallible; + /// + /// let number_stream = stream::iter(0..10).map(Ok::<_, Infallible>); + /// let contain_three = number_stream.try_any(|i| async move { i == 3 }); + /// assert_eq!(contain_three.await, Ok(true)); + /// + /// let stream_with_errors = stream::iter([Ok(1), Err("err"), Ok(3)]); + /// let contain_three = stream_with_errors.try_any(|i| async move { i == 3 }); + /// assert_eq!(contain_three.await, Err("err")); + /// # }); + /// ``` + fn try_any(self, f: F) -> TryAny + where + Self: Sized, + F: FnMut(Self::Ok) -> Fut, + Fut: Future, + { + assert_future::, _>(TryAny::new(self, f)) + } } diff --git a/futures-util/src/stream/try_stream/try_any.rs b/futures-util/src/stream/try_stream/try_any.rs new file mode 100644 index 000000000..15adb3097 --- /dev/null +++ b/futures-util/src/stream/try_stream/try_any.rs @@ -0,0 +1,98 @@ +use core::fmt; +use core::pin::Pin; +use futures_core::future::{FusedFuture, Future}; +use futures_core::ready; +use futures_core::stream::TryStream; +use futures_core::task::{Context, Poll}; +use pin_project_lite::pin_project; + +pin_project! { + /// Future for the [`any`](super::StreamExt::any) method. + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct TryAny { + #[pin] + stream: St, + f: F, + done: bool, + #[pin] + future: Option, + } +} + +impl fmt::Debug for TryAny +where + St: fmt::Debug, + Fut: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TryAny") + .field("stream", &self.stream) + .field("done", &self.done) + .field("future", &self.future) + .finish() + } +} + +impl TryAny +where + St: TryStream, + F: FnMut(St::Ok) -> Fut, + Fut: Future, +{ + pub(super) fn new(stream: St, f: F) -> Self { + Self { stream, f, done: false, future: None } + } +} + +impl FusedFuture for TryAny +where + St: TryStream, + F: FnMut(St::Ok) -> Fut, + Fut: Future, +{ + fn is_terminated(&self) -> bool { + self.done && self.future.is_none() + } +} + +impl Future for TryAny +where + St: TryStream, + F: FnMut(St::Ok) -> Fut, + Fut: Future, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + + Poll::Ready(loop { + if let Some(fut) = this.future.as_mut().as_pin_mut() { + // we're currently processing a future to produce a new value + let acc = ready!(fut.poll(cx)); + this.future.set(None); + if acc { + *this.done = true; + break Ok(true); + } // early exit + } else if !*this.done { + // we're waiting on a new item from the stream + match ready!(this.stream.as_mut().try_poll_next(cx)) { + Some(Ok(item)) => { + this.future.set(Some((this.f)(item))); + } + Some(Err(err)) => { + *this.done = true; + break Err(err); + } + None => { + *this.done = true; + break Ok(false); + } + } + } else { + panic!("TryAny polled after completion") + } + }) + } +} diff --git a/futures/tests/stream_try_stream.rs b/futures/tests/stream_try_stream.rs index 4095fa536..ef38c510b 100644 --- a/futures/tests/stream_try_stream.rs +++ b/futures/tests/stream_try_stream.rs @@ -159,3 +159,25 @@ fn try_all() { assert_eq!(Err("err"), all); }); } + +#[test] +fn try_any() { + block_on(async { + let empty: [Result; 0] = []; + let st = stream::iter(empty); + let any = st.try_any(is_even).await; + assert_eq!(Ok(false), any); + + let st = stream::iter([Ok::<_, Infallible>(1), Ok(2), Ok(3)]); + let any = st.try_any(is_even).await; + assert_eq!(Ok(true), any); + + let st = stream::iter([Ok::<_, Infallible>(1), Ok(3), Ok(5)]); + let any = st.try_any(is_even).await; + assert_eq!(Ok(false), any); + + let st = stream::iter([Ok(1), Ok(3), Err("err"), Ok(8)]); + let any = st.try_any(is_even).await; + assert_eq!(Err("err"), any); + }); +}