From d20f3feb18b034c02c9b920988db7ea92b650c76 Mon Sep 17 00:00:00 2001 From: Aaron Hill Date: Thu, 22 Aug 2019 20:49:37 -0400 Subject: [PATCH] Make generated 'project' reference take an '&mut Pin<&mut Self>' Based on https://github.com/rust-lang/unsafe-code-guidelines/issues/148#issuecomment-523711778 by @CAD97 Currently, the generated 'project' method takes a 'Pin<&mut Self>', consuming it. This makes it impossible to use the original Pin<&mut Self> after calling project(), since the 'Pin<&mut Self>' has been moved into the the 'Project' method. This makes it impossible to implement useful pattern when working with enums: ```rust enum Foo { Variant1(#[pin] SomeFuture), Variant2(OtherType) } fn process(foo: Pin<&mut Foo>) { match foo.project() { __FooProjection(fut) => { fut.poll(); let new_foo: Foo = ...; foo.set(new_foo); }, _ => {} } } ``` This pattern is common when implementing a Future combinator - an inner future is polled, and then the containing enum is changed to a new variant. However, as soon as 'project()' is called, it becoms imposible to call 'set' on the original 'Pin<&mut Self>'. To support this pattern, this commit changes the 'project' method to take a '&mut Pin<&mut Self>'. The projection types works exactly as before - however, creating it no longer requires consuming the original 'Pin<&mut Self>' Unfortunately, current limitations of Rust prevent us from simply modifiying the signature of the 'project' method in the inherent impl of the projection type. While using 'Pin<&mut Self>' as a receiver is supported on stable rust, using '&mut Pin<&mut Self>' as a receiver requires the unstable `#![feature(arbitrary_self_types)]` For compatibility with stable Rust, we instead dynamically define a new trait, '__{Type}ProjectionTrait', where {Type} is the name of the type with the `#[pin_project]` attribute. This trait looks like this: ```rust trait __FooProjectionTrait { fn project(&'a mut self) -> __FooProjection<'a>; } ``` It is then implemented for `Pin<&mut {Type}>`. This allows the `project` method to be invoked on `&mut Pin<&mut {Type}>`, which is what we want. If Generic Associated Types (https://github.com/rust-lang/rust/issues/44265) were implemented and stablized, we could use a single trait for all pin projections: ```rust trait Projectable { type Projection<'a>; fn project(&'a mut self) -> Self::Projection<'a>; } ``` However, Generic Associated Types are not even implemented on nightly yet, so we need for generate a new trait per type for the forseeable future. --- pin-project-internal/src/pin_project/enums.rs | 7 +-- pin-project-internal/src/pin_project/mod.rs | 43 +++++++++++++---- .../src/pin_project/structs.rs | 7 +-- pin-project-internal/src/utils.rs | 4 ++ src/lib.rs | 2 +- tests/pin_project.rs | 46 +++++++++++++++++-- tests/pinned_drop.rs | 2 +- 7 files changed, 90 insertions(+), 21 deletions(-) diff --git a/pin-project-internal/src/pin_project/enums.rs b/pin-project-internal/src/pin_project/enums.rs index 3262c01e..2b7e2c67 100644 --- a/pin-project-internal/src/pin_project/enums.rs +++ b/pin-project-internal/src/pin_project/enums.rs @@ -29,16 +29,17 @@ pub(super) fn parse(mut cx: Context, mut item: ItemEnum) -> Result let Context { original, projected, lifetime, impl_unpin, .. } = cx; let proj_generics = proj_generics(&item.generics, &lifetime); let proj_ty_generics = proj_generics.split_for_impl().1; + let proj_trait = &cx.projected_trait; let (impl_generics, ty_generics, where_clause) = item.generics.split_for_impl(); let mut proj_items = quote! { enum #projected #proj_generics #where_clause { #(#proj_variants,)* } }; let proj_method = quote! { - impl #impl_generics #original #ty_generics #where_clause { - fn project<#lifetime>(self: ::core::pin::Pin<&#lifetime mut Self>) -> #projected #proj_ty_generics { + impl #impl_generics #proj_trait #ty_generics for ::core::pin::Pin<&mut #original #ty_generics> #where_clause { + fn project<#lifetime>(&#lifetime mut self) -> #projected #proj_ty_generics #where_clause { unsafe { - match ::core::pin::Pin::get_unchecked_mut(self) { + match self.as_mut().get_unchecked_mut() { #(#proj_arms,)* } } diff --git a/pin-project-internal/src/pin_project/mod.rs b/pin-project-internal/src/pin_project/mod.rs index 64f75944..474b638e 100644 --- a/pin-project-internal/src/pin_project/mod.rs +++ b/pin-project-internal/src/pin_project/mod.rs @@ -4,11 +4,10 @@ use syn::{ parse::{Parse, ParseStream}, punctuated::Punctuated, token::Comma, - Fields, FieldsNamed, FieldsUnnamed, GenericParam, Generics, Index, Item, ItemStruct, Lifetime, - LifetimeDef, Meta, NestedMeta, Result, Type, + * }; -use crate::utils::{crate_path, proj_ident}; +use crate::utils::{crate_path, proj_ident, proj_trait_ident}; mod enums; mod structs; @@ -51,6 +50,10 @@ struct Context { original: Ident, /// Name of the projected type. projected: Ident, + /// Name of the trait generated + /// to provide a 'project' method + projected_trait: Ident, + generics: Generics, lifetime: Lifetime, impl_unpin: ImplUnpin, @@ -63,7 +66,8 @@ impl Context { let projected = proj_ident(&original); let lifetime = proj_lifetime(&generics.params); let impl_unpin = ImplUnpin::new(generics, unsafe_unpin); - Ok(Self { original, projected, lifetime, impl_unpin, pinned_drop }) + let projected_trait = proj_trait_ident(&original); + Ok(Self { original, projected, projected_trait, lifetime, impl_unpin, pinned_drop, generics: generics.clone() }) } fn impl_drop<'a>(&self, generics: &'a Generics) -> ImplDrop<'a> { @@ -72,24 +76,47 @@ impl Context { } fn parse(args: TokenStream, input: TokenStream) -> Result { + match syn::parse2(input)? { Item::Struct(item) => { - let cx = Context::new(args, item.ident.clone(), &item.generics)?; + let mut cx = Context::new(args, item.ident.clone(), &item.generics)?; + let mut res = make_proj_trait(&mut cx)?; + let packed_check = ensure_not_packed(&item)?; - let mut res = structs::parse(cx, item)?; + res.extend(structs::parse(cx, item)?); res.extend(packed_check); Ok(res) } Item::Enum(item) => { - let cx = Context::new(args, item.ident.clone(), &item.generics)?; + let mut cx = Context::new(args, item.ident.clone(), &item.generics)?; + let mut res = make_proj_trait(&mut cx)?; + // We don't need to check for '#[repr(packed)]', // since it does not apply to enums - enums::parse(cx, item) + res.extend(enums::parse(cx, item)); + Ok(res) } item => Err(error!(item, "may only be used on structs or enums")), } } +fn make_proj_trait(cx: &mut Context) -> Result { + let proj_trait = &cx.projected_trait; + let lifetime = &cx.lifetime; + let proj_ident = &cx.projected; + let proj_generics = proj_generics(&cx.generics, &cx.lifetime); + let proj_ty_generics = proj_generics.split_for_impl().1; + + let (orig_generics, _orig_ty_generics, orig_where_clause) = cx.generics.split_for_impl(); + + Ok(quote! { + trait #proj_trait #orig_generics { + fn project<#lifetime>(&#lifetime mut self) -> #proj_ident #proj_ty_generics #orig_where_clause; + } + }) + +} + fn ensure_not_packed(item: &ItemStruct) -> Result { for meta in item.attrs.iter().filter_map(|attr| attr.parse_meta().ok()) { if let Meta::List(l) = meta { diff --git a/pin-project-internal/src/pin_project/structs.rs b/pin-project-internal/src/pin_project/structs.rs index 4ed64bc1..83b887a6 100644 --- a/pin-project-internal/src/pin_project/structs.rs +++ b/pin-project-internal/src/pin_project/structs.rs @@ -26,16 +26,17 @@ pub(super) fn parse(mut cx: Context, mut item: ItemStruct) -> Result(self: ::core::pin::Pin<&#lifetime mut Self>) -> #proj_ident #proj_ty_generics { + impl #impl_generics #proj_trait #ty_generics for ::core::pin::Pin<&mut #orig_ident #ty_generics> #where_clause { + fn project<#lifetime>(&#lifetime mut self) -> #proj_ident #proj_ty_generics #where_clause { unsafe { - let this = ::core::pin::Pin::get_unchecked_mut(self); + let this = self.as_mut().get_unchecked_mut(); #proj_ident #proj_init } } diff --git a/pin-project-internal/src/utils.rs b/pin-project-internal/src/utils.rs index c411177a..8c482415 100644 --- a/pin-project-internal/src/utils.rs +++ b/pin-project-internal/src/utils.rs @@ -7,6 +7,10 @@ pub(crate) fn proj_ident(ident: &Ident) -> Ident { format_ident!("__{}Projection", ident) } +pub(crate) fn proj_trait_ident(ident: &Ident) -> Ident { + format_ident!("__{}ProjectionTrait", ident) +} + pub(crate) trait VecExt { fn find_remove(&mut self, ident: &str) -> Option; } diff --git a/src/lib.rs b/src/lib.rs index 2aebc6f9..60357506 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,7 +22,7 @@ //! } //! //! impl Foo { -//! fn baz(self: Pin<&mut Self>) { +//! fn baz(mut self: Pin<&mut Self>) { //! let this = self.project(); //! let _: Pin<&mut T> = this.future; // Pinned reference to the field //! let _: &mut U = this.field; // Normal reference to the field diff --git a/tests/pin_project.rs b/tests/pin_project.rs index 2d5d7243..5cfde74d 100644 --- a/tests/pin_project.rs +++ b/tests/pin_project.rs @@ -19,7 +19,8 @@ fn test_pin_project() { let mut foo = Foo { field1: 1, field2: 2 }; - let foo = Pin::new(&mut foo).project(); + let mut foo_orig = Pin::new(&mut foo); + let foo = foo_orig.project(); let x: Pin<&mut i32> = foo.field1; assert_eq!(*x, 1); @@ -27,9 +28,13 @@ fn test_pin_project() { let y: &mut i32 = foo.field2; assert_eq!(*y, 2); + assert_eq!(foo_orig.as_ref().field1, 1); + assert_eq!(foo_orig.as_ref().field2, 2); + let mut foo = Foo { field1: 1, field2: 2 }; - let foo = Pin::new(&mut foo).project(); + let mut foo = Pin::new(&mut foo); + let foo = foo.project(); let __FooProjection { field1, field2 } = foo; let _: Pin<&mut i32> = field1; @@ -42,7 +47,8 @@ fn test_pin_project() { let mut bar = Bar(1, 2); - let bar = Pin::new(&mut bar).project(); + let mut bar = Pin::new(&mut bar); + let bar = bar.project(); let x: Pin<&mut i32> = bar.0; assert_eq!(*x, 1); @@ -53,6 +59,7 @@ fn test_pin_project() { // enum #[pin_project] + #[derive(Eq, PartialEq, Debug)] enum Baz { Variant1(#[pin] A, B), Variant2 { @@ -65,7 +72,8 @@ fn test_pin_project() { let mut baz = Baz::Variant1(1, 2); - let baz = Pin::new(&mut baz).project(); + let mut baz_orig = Pin::new(&mut baz); + let baz = baz_orig.project(); match baz { __BazProjection::Variant1(x, y) => { @@ -82,9 +90,12 @@ fn test_pin_project() { __BazProjection::None => {} } + assert_eq!(Pin::into_ref(baz_orig).get_ref(), &Baz::Variant1(1, 2)); + let mut baz = Baz::Variant2 { field1: 3, field2: 4 }; - let mut baz = Pin::new(&mut baz).project(); + let mut baz = Pin::new(&mut baz); + let mut baz = baz.project(); match &mut baz { __BazProjection::Variant1(x, y) => { @@ -110,6 +121,31 @@ fn test_pin_project() { } } +#[test] +fn enum_project_set() { + + #[pin_project] + #[derive(Eq, PartialEq, Debug)] + enum Bar { + Variant1(#[pin] u8), + Variant2(bool) + } + + let mut bar = Bar::Variant1(25); + let mut bar_orig = Pin::new(&mut bar); + let bar_proj = bar_orig.project(); + + match bar_proj { + __BarProjection::Variant1(val) => { + let new_bar = Bar::Variant2(val.as_ref().get_ref() == &25); + bar_orig.set(new_bar); + }, + _ => unreachable!() + } + + assert_eq!(bar, Bar::Variant2(true)); +} + #[test] fn where_clause_and_associated_type_fields() { // struct diff --git a/tests/pinned_drop.rs b/tests/pinned_drop.rs index 6f331da0..f58a2cd4 100644 --- a/tests/pinned_drop.rs +++ b/tests/pinned_drop.rs @@ -14,7 +14,7 @@ pub struct Foo<'a> { } #[pinned_drop] -fn do_drop(foo: Pin<&mut Foo<'_>>) { +fn do_drop(mut foo: Pin<&mut Foo<'_>>) { **foo.project().was_dropped = true; }