Skip to content

Commit

Permalink
stream: add next_many and poll_next_many to StreamMap (#6409)
Browse files Browse the repository at this point in the history
  • Loading branch information
maminrayej authored Mar 26, 2024
1 parent deff252 commit 4601c84
Show file tree
Hide file tree
Showing 4 changed files with 378 additions and 3 deletions.
3 changes: 3 additions & 0 deletions tokio-stream/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@
#[macro_use]
mod macros;

mod poll_fn;
pub(crate) use poll_fn::poll_fn;

pub mod wrappers;

mod stream_ext;
Expand Down
35 changes: 35 additions & 0 deletions tokio-stream/src/poll_fn.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};

pub(crate) struct PollFn<F> {
f: F,
}

pub(crate) fn poll_fn<T, F>(f: F) -> PollFn<F>
where
F: FnMut(&mut Context<'_>) -> Poll<T>,
{
PollFn { f }
}

impl<T, F> Future for PollFn<F>
where
F: FnMut(&mut Context<'_>) -> Poll<T>,
{
type Output = T;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<T> {
// Safety: We never construct a `Pin<&mut F>` anywhere, so accessing `f`
// mutably in an unpinned way is sound.
//
// This use of unsafe cannot be replaced with the pin-project macro
// because:
// * If we put `#[pin]` on the field, then it gives us a `Pin<&mut F>`,
// which we can't use to call the closure.
// * If we don't put `#[pin]` on the field, then it makes `PollFn` be
// unconditionally `Unpin`, which we also don't want.
let me = unsafe { Pin::into_inner_unchecked(self) };
(me.f)(cx)
}
}
106 changes: 105 additions & 1 deletion tokio-stream/src/stream_map.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::Stream;
use crate::{poll_fn, Stream};

use std::borrow::Borrow;
use std::hash::Hash;
Expand Down Expand Up @@ -561,6 +561,110 @@ impl<K, V> Default for StreamMap<K, V> {
}
}

impl<K, V> StreamMap<K, V>
where
K: Clone + Unpin,
V: Stream + Unpin,
{
/// Receives multiple items on this [`StreamMap`], extending the provided `buffer`.
///
/// This method returns the number of items that is appended to the `buffer`.
///
/// Note that this method does not guarantee that exactly `limit` items
/// are received. Rather, if at least one item is available, it returns
/// as many items as it can up to the given limit. This method returns
/// zero only if the `StreamMap` is empty (or if `limit` is zero).
///
/// # Cancel safety
///
/// This method is cancel safe. If `next_many` is used as the event in a
/// [`tokio::select!`](tokio::select) statement and some other branch
/// completes first, it is guaranteed that no items were received on any of
/// the underlying streams.
pub async fn next_many(&mut self, buffer: &mut Vec<(K, V::Item)>, limit: usize) -> usize {
poll_fn(|cx| self.poll_next_many(cx, buffer, limit)).await
}

/// Polls to receive multiple items on this `StreamMap`, extending the provided `buffer`.
///
/// This method returns:
/// * `Poll::Pending` if no items are available but the `StreamMap` is not empty.
/// * `Poll::Ready(count)` where `count` is the number of items successfully received and
/// stored in `buffer`. This can be less than, or equal to, `limit`.
/// * `Poll::Ready(0)` if `limit` is set to zero or when the `StreamMap` is empty.
///
/// Note that this method does not guarantee that exactly `limit` items
/// are received. Rather, if at least one item is available, it returns
/// as many items as it can up to the given limit. This method returns
/// zero only if the `StreamMap` is empty (or if `limit` is zero).
pub fn poll_next_many(
&mut self,
cx: &mut Context<'_>,
buffer: &mut Vec<(K, V::Item)>,
limit: usize,
) -> Poll<usize> {
if limit == 0 || self.entries.is_empty() {
return Poll::Ready(0);
}

let mut added = 0;

let start = self::rand::thread_rng_n(self.entries.len() as u32) as usize;
let mut idx = start;

while added < limit {
// Indicates whether at least one stream returned a value when polled or not
let mut should_loop = false;

for _ in 0..self.entries.len() {
let (_, stream) = &mut self.entries[idx];

match Pin::new(stream).poll_next(cx) {
Poll::Ready(Some(val)) => {
added += 1;

let key = self.entries[idx].0.clone();
buffer.push((key, val));

should_loop = true;

idx = idx.wrapping_add(1) % self.entries.len();
}
Poll::Ready(None) => {
// Remove the entry
self.entries.swap_remove(idx);

// Check if this was the last entry, if so the cursor needs
// to wrap
if idx == self.entries.len() {
idx = 0;
} else if idx < start && start <= self.entries.len() {
// The stream being swapped into the current index has
// already been polled, so skip it.
idx = idx.wrapping_add(1) % self.entries.len();
}
}
Poll::Pending => {
idx = idx.wrapping_add(1) % self.entries.len();
}
}
}

if !should_loop {
break;
}
}

if added > 0 {
Poll::Ready(added)
} else if self.entries.is_empty() {
Poll::Ready(0)
} else {
Poll::Pending
}
}
}

impl<K, V> Stream for StreamMap<K, V>
where
K: Clone + Unpin,
Expand Down
Loading

0 comments on commit 4601c84

Please sign in to comment.