Skip to content

Commit

Permalink
Support async gen fn
Browse files Browse the repository at this point in the history
  • Loading branch information
compiler-errors committed Dec 5, 2023
1 parent 8d98feb commit 7cc2ec8
Show file tree
Hide file tree
Showing 14 changed files with 93 additions and 83 deletions.
11 changes: 8 additions & 3 deletions compiler/rustc_ast/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2415,10 +2415,12 @@ pub enum Unsafe {
/// Iterator`.
#[derive(Copy, Clone, Encodable, Decodable, Debug)]
pub enum CoroutineKind {
/// `async`, which evaluates to `impl Future`
/// `async`, which returns an `impl Future`
Async { span: Span, closure_id: NodeId, return_impl_trait_id: NodeId },
/// `gen`, which evaluates to `impl Iterator`
/// `gen`, which returns an `impl Iterator`
Gen { span: Span, closure_id: NodeId, return_impl_trait_id: NodeId },
/// `async gen`, which returns an `impl AsyncIterator`
AsyncGen { span: Span, closure_id: NodeId, return_impl_trait_id: NodeId },
}

impl CoroutineKind {
Expand All @@ -2435,7 +2437,10 @@ impl CoroutineKind {
pub fn return_id(self) -> (NodeId, Span) {
match self {
CoroutineKind::Async { return_impl_trait_id, span, .. }
| CoroutineKind::Gen { return_impl_trait_id, span, .. } => (return_impl_trait_id, span),
| CoroutineKind::Gen { return_impl_trait_id, span, .. }
| CoroutineKind::AsyncGen { return_impl_trait_id, span, .. } => {
(return_impl_trait_id, span)
}
}
}
}
Expand Down
3 changes: 2 additions & 1 deletion compiler/rustc_ast/src/mut_visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -874,7 +874,8 @@ pub fn noop_visit_closure_binder<T: MutVisitor>(binder: &mut ClosureBinder, vis:
pub fn noop_visit_coroutine_kind<T: MutVisitor>(coroutine_kind: &mut CoroutineKind, vis: &mut T) {
match coroutine_kind {
CoroutineKind::Async { span, closure_id, return_impl_trait_id }
| CoroutineKind::Gen { span, closure_id, return_impl_trait_id } => {
| CoroutineKind::Gen { span, closure_id, return_impl_trait_id }
| CoroutineKind::AsyncGen { span, closure_id, return_impl_trait_id } => {
vis.visit_span(span);
vis.visit_id(closure_id);
vis.visit_id(return_impl_trait_id);
Expand Down
16 changes: 9 additions & 7 deletions compiler/rustc_ast_lowering/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use rustc_ast::*;
use rustc_data_structures::stack::ensure_sufficient_stack;
use rustc_hir as hir;
use rustc_hir::def::{DefKind, Res};
use rustc_middle::span_bug;
use rustc_session::errors::report_lit_error;
use rustc_span::source_map::{respan, Spanned};
use rustc_span::symbol::{kw, sym, Ident, Symbol};
Expand Down Expand Up @@ -202,15 +203,12 @@ impl<'hir> LoweringContext<'_, 'hir> {
fn_decl_span,
fn_arg_span,
}) => match coroutine_kind {
Some(
CoroutineKind::Async { closure_id, .. }
| CoroutineKind::Gen { closure_id, .. },
) => self.lower_expr_async_closure(
Some(coroutine_kind) => self.lower_expr_coroutine_closure(
binder,
*capture_clause,
e.id,
hir_id,
*closure_id,
*coroutine_kind,
fn_decl,
body,
*fn_decl_span,
Expand Down Expand Up @@ -1098,18 +1096,22 @@ impl<'hir> LoweringContext<'_, 'hir> {
(binder, params)
}

fn lower_expr_async_closure(
fn lower_expr_coroutine_closure(
&mut self,
binder: &ClosureBinder,
capture_clause: CaptureBy,
closure_id: NodeId,
closure_hir_id: hir::HirId,
inner_closure_id: NodeId,
coroutine_kind: CoroutineKind,
decl: &FnDecl,
body: &Expr,
fn_decl_span: Span,
fn_arg_span: Span,
) -> hir::ExprKind<'hir> {
let CoroutineKind::Async { closure_id: inner_closure_id, .. } = coroutine_kind else {
span_bug!(fn_decl_span, "`async gen` and `gen` closures are not supported, yet");
};

if let &ClosureBinder::For { span, .. } = binder {
self.tcx.sess.emit_err(NotSupportedForLifetimeBinderAsyncClosure { span });
}
Expand Down
16 changes: 11 additions & 5 deletions compiler/rustc_ast_lowering/src/item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1036,11 +1036,9 @@ impl<'hir> LoweringContext<'_, 'hir> {
let (Some(coroutine_kind), Some(body)) = (coroutine_kind, body) else {
return self.lower_fn_body_block(span, decl, body);
};
let closure_id = match coroutine_kind {
CoroutineKind::Async { closure_id, .. } | CoroutineKind::Gen { closure_id, .. } => {
closure_id
}
};
let (CoroutineKind::Async { closure_id, .. }
| CoroutineKind::Gen { closure_id, .. }
| CoroutineKind::AsyncGen { closure_id, .. }) = coroutine_kind;

self.lower_body(|this| {
let mut parameters: Vec<hir::Param<'_>> = Vec::new();
Expand Down Expand Up @@ -1224,6 +1222,14 @@ impl<'hir> LoweringContext<'_, 'hir> {
hir::CoroutineSource::Fn,
mkbody,
),
CoroutineKind::AsyncGen { .. } => this.make_async_gen_expr(
CaptureBy::Value { move_kw: rustc_span::DUMMY_SP },
closure_id,
None,
body.span,
hir::CoroutineSource::Fn,
mkbody,
),
};

let hir_id = this.lower_node_id(closure_id);
Expand Down
8 changes: 5 additions & 3 deletions compiler/rustc_ast_lowering/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1904,7 +1904,8 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {

let opaque_ty_node_id = match coro {
CoroutineKind::Async { return_impl_trait_id, .. }
| CoroutineKind::Gen { return_impl_trait_id, .. } => return_impl_trait_id,
| CoroutineKind::Gen { return_impl_trait_id, .. }
| CoroutineKind::AsyncGen { return_impl_trait_id, .. } => return_impl_trait_id,
};

let captured_lifetimes: Vec<_> = self
Expand Down Expand Up @@ -1960,8 +1961,9 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {

// "<$assoc_ty_name = T>"
let (assoc_ty_name, trait_lang_item) = match coro {
CoroutineKind::Async { .. } => (hir::FN_OUTPUT_NAME, hir::LangItem::Future),
CoroutineKind::Gen { .. } => (hir::ITERATOR_ITEM_NAME, hir::LangItem::Iterator),
CoroutineKind::Async { .. } => (sym::Output, hir::LangItem::Future),
CoroutineKind::Gen { .. } => (sym::Item, hir::LangItem::Iterator),
CoroutineKind::AsyncGen { .. } => (sym::Item, hir::LangItem::AsyncIterator),
};

let future_args = self.arena.alloc(hir::GenericArgs {
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_ast_lowering/src/path.rs
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
FnRetTy::Default(_) => self.arena.alloc(self.ty_tup(*span, &[])),
};
let args = smallvec![GenericArg::Type(self.arena.alloc(self.ty_tup(*inputs_span, inputs)))];
let binding = self.assoc_ty_binding(hir::FN_OUTPUT_NAME, output_ty.span, output_ty);
let binding = self.assoc_ty_binding(sym::Output, output_ty.span, output_ty);
(
GenericArgsCtor {
args,
Expand Down
4 changes: 4 additions & 0 deletions compiler/rustc_ast_pretty/src/pprust/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1498,6 +1498,10 @@ impl<'a> State<'a> {
ast::CoroutineKind::Async { .. } => {
self.word_nbsp("async");
}
ast::CoroutineKind::AsyncGen { .. } => {
self.word_nbsp("async");
self.word_nbsp("gen");
}
}
}

Expand Down
5 changes: 0 additions & 5 deletions compiler/rustc_hir/src/hir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2254,11 +2254,6 @@ pub enum ImplItemKind<'hir> {
Type(&'hir Ty<'hir>),
}

/// The name of the associated type for `Fn` return types.
pub const FN_OUTPUT_NAME: Symbol = sym::Output;
/// The name of the associated type for `Iterator` item types.
pub const ITERATOR_ITEM_NAME: Symbol = sym::Item;

/// Bind a type to an associated type (i.e., `A = Foo`).
///
/// Bindings like `A: Debug` are represented as a special type `A =
Expand Down
2 changes: 0 additions & 2 deletions compiler/rustc_parse/messages.ftl
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ parse_async_block_in_2015 = `async` blocks are only allowed in Rust 2018 or late
parse_async_fn_in_2015 = `async fn` is not permitted in Rust 2015
.label = to use `async fn`, switch to Rust 2018 or later
parse_async_gen_fn = `async gen` functions are not supported
parse_async_move_block_in_2015 = `async move` blocks are only allowed in Rust 2018 or later
parse_async_move_order_incorrect = the order of `move` and `async` is incorrect
Expand Down
7 changes: 0 additions & 7 deletions compiler/rustc_parse/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -562,13 +562,6 @@ pub(crate) struct GenFn {
pub span: Span,
}

#[derive(Diagnostic)]
#[diag(parse_async_gen_fn)]
pub(crate) struct AsyncGenFn {
#[primary_span]
pub span: Span,
}

#[derive(Diagnostic)]
#[diag(parse_comma_after_base_struct)]
#[note]
Expand Down
19 changes: 13 additions & 6 deletions compiler/rustc_parse/src/parser/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2235,8 +2235,8 @@ impl<'a> Parser<'a> {
let movability =
if self.eat_keyword(kw::Static) { Movability::Static } else { Movability::Movable };

let asyncness = if self.token.uninterpolated_span().at_least_rust_2018() {
self.parse_asyncness(Case::Sensitive)
let coroutine_kind = if self.token.uninterpolated_span().at_least_rust_2018() {
self.parse_coroutine_kind(Case::Sensitive)
} else {
None
};
Expand All @@ -2262,9 +2262,16 @@ impl<'a> Parser<'a> {
}
};

if let Some(CoroutineKind::Async { span, .. }) = asyncness {
// Feature-gate `async ||` closures.
self.sess.gated_spans.gate(sym::async_closure, span);
match coroutine_kind {
Some(CoroutineKind::Async { span, .. }) => {
// Feature-gate `async ||` closures.
self.sess.gated_spans.gate(sym::async_closure, span);
}
Some(CoroutineKind::AsyncGen { span, .. }) | Some(CoroutineKind::Gen { span, .. }) => {
// Feature-gate `async ||` closures.
self.sess.gated_spans.gate(sym::gen_blocks, span);
}
None => {}
}

if self.token.kind == TokenKind::Semi
Expand All @@ -2285,7 +2292,7 @@ impl<'a> Parser<'a> {
binder,
capture_clause,
constness,
coroutine_kind: asyncness,
coroutine_kind,
movability,
fn_decl,
body,
Expand Down
47 changes: 21 additions & 26 deletions compiler/rustc_parse/src/parser/item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2392,18 +2392,15 @@ impl<'a> Parser<'a> {
let constness = self.parse_constness(case);

let async_start_sp = self.token.span;
let asyncness = self.parse_asyncness(case);

let _gen_start_sp = self.token.span;
let genness = self.parse_genness(case);
let coroutine_kind = self.parse_coroutine_kind(case);

let unsafe_start_sp = self.token.span;
let unsafety = self.parse_unsafety(case);

let ext_start_sp = self.token.span;
let ext = self.parse_extern(case);

if let Some(CoroutineKind::Async { span, .. }) = asyncness {
if let Some(CoroutineKind::Async { span, .. }) = coroutine_kind {
if span.is_rust_2015() {
self.sess.emit_err(errors::AsyncFnIn2015 {
span,
Expand All @@ -2412,16 +2409,11 @@ impl<'a> Parser<'a> {
}
}

if let Some(CoroutineKind::Gen { span, .. }) = genness {
self.sess.gated_spans.gate(sym::gen_blocks, span);
}

if let (
Some(CoroutineKind::Async { span: async_span, .. }),
Some(CoroutineKind::Gen { span: gen_span, .. }),
) = (asyncness, genness)
{
self.sess.emit_err(errors::AsyncGenFn { span: async_span.to(gen_span) });
match coroutine_kind {
Some(CoroutineKind::Gen { span, .. }) | Some(CoroutineKind::AsyncGen { span, .. }) => {
self.sess.gated_spans.gate(sym::gen_blocks, span);
}
Some(CoroutineKind::Async { .. }) | None => {}
}

if !self.eat_keyword_case(kw::Fn, case) {
Expand All @@ -2440,7 +2432,7 @@ impl<'a> Parser<'a> {

// We may be able to recover
let mut recover_constness = constness;
let mut recover_asyncness = asyncness;
let mut recover_coroutine_kind = coroutine_kind;
let mut recover_unsafety = unsafety;
// This will allow the machine fix to directly place the keyword in the correct place or to indicate
// that the keyword is already present and the second instance should be removed.
Expand All @@ -2453,15 +2445,24 @@ impl<'a> Parser<'a> {
}
}
} else if self.check_keyword(kw::Async) {
match asyncness {
match coroutine_kind {
Some(CoroutineKind::Async { span, .. }) => {
Some(WrongKw::Duplicated(span))
}
Some(CoroutineKind::AsyncGen { span, .. }) => {
Some(WrongKw::Duplicated(span))
}
Some(CoroutineKind::Gen { .. }) => {
panic!("not sure how to recover here")
recover_coroutine_kind = Some(CoroutineKind::AsyncGen {
span: self.token.span,
closure_id: DUMMY_NODE_ID,
return_impl_trait_id: DUMMY_NODE_ID,
});
// FIXME(gen_blocks): This span is wrong, didn't want to think about it.
Some(WrongKw::Misplaced(unsafe_start_sp))
}
None => {
recover_asyncness = Some(CoroutineKind::Async {
recover_coroutine_kind = Some(CoroutineKind::Async {
span: self.token.span,
closure_id: DUMMY_NODE_ID,
return_impl_trait_id: DUMMY_NODE_ID,
Expand Down Expand Up @@ -2559,7 +2560,7 @@ impl<'a> Parser<'a> {
return Ok(FnHeader {
constness: recover_constness,
unsafety: recover_unsafety,
coroutine_kind: recover_asyncness,
coroutine_kind: recover_coroutine_kind,
ext,
});
}
Expand All @@ -2569,12 +2570,6 @@ impl<'a> Parser<'a> {
}
}

let coroutine_kind = match asyncness {
Some(CoroutineKind::Async { .. }) => asyncness,
Some(CoroutineKind::Gen { .. }) => unreachable!("asycness cannot be Gen"),
None => genness,
};

Ok(FnHeader { constness, unsafety, coroutine_kind, ext })
}

Expand Down
33 changes: 17 additions & 16 deletions compiler/rustc_parse/src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1125,23 +1125,24 @@ impl<'a> Parser<'a> {
}

/// Parses asyncness: `async` or nothing.
fn parse_asyncness(&mut self, case: Case) -> Option<CoroutineKind> {
fn parse_coroutine_kind(&mut self, case: Case) -> Option<CoroutineKind> {
let span = self.token.uninterpolated_span();
if self.eat_keyword_case(kw::Async, case) {
let span = self.prev_token.uninterpolated_span();
Some(CoroutineKind::Async {
span,
closure_id: DUMMY_NODE_ID,
return_impl_trait_id: DUMMY_NODE_ID,
})
} else {
None
}
}

/// Parses genness: `gen` or nothing.
fn parse_genness(&mut self, case: Case) -> Option<CoroutineKind> {
if self.token.span.at_least_rust_2024() && self.eat_keyword_case(kw::Gen, case) {
let span = self.prev_token.uninterpolated_span();
if self.eat_keyword_case(kw::Gen, case) {
let gen_span = self.prev_token.uninterpolated_span();
Some(CoroutineKind::AsyncGen {
span: span.to(gen_span),
closure_id: DUMMY_NODE_ID,
return_impl_trait_id: DUMMY_NODE_ID,
})
} else {
Some(CoroutineKind::Async {
span,
closure_id: DUMMY_NODE_ID,
return_impl_trait_id: DUMMY_NODE_ID,
})
}
} else if self.eat_keyword_case(kw::Gen, case) {
Some(CoroutineKind::Gen {
span,
closure_id: DUMMY_NODE_ID,
Expand Down
3 changes: 2 additions & 1 deletion compiler/rustc_resolve/src/def_collector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,8 @@ impl<'a, 'b, 'tcx> visit::Visitor<'a> for DefCollector<'a, 'b, 'tcx> {
match closure.coroutine_kind {
Some(
CoroutineKind::Async { closure_id, .. }
| CoroutineKind::Gen { closure_id, .. },
| CoroutineKind::Gen { closure_id, .. }
| CoroutineKind::AsyncGen { closure_id, .. },
) => self.create_def(closure_id, kw::Empty, DefKind::Closure, expr.span),
None => closure_def,
}
Expand Down

0 comments on commit 7cc2ec8

Please sign in to comment.