Skip to content

Commit

Permalink
Merge pull request #4054 from weiznich/fix/auto_type_lifetimes
Browse files Browse the repository at this point in the history
Fix lifetimes with `#[auto_type]`
  • Loading branch information
weiznich committed Jun 12, 2024
1 parent 1062c27 commit 4d5fb4f
Show file tree
Hide file tree
Showing 8 changed files with 228 additions and 4 deletions.
23 changes: 23 additions & 0 deletions diesel_compile_tests/tests/fail/auto_type_life_times.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
use diesel::dsl::*;
use diesel::prelude::*;

diesel::table! {
users {
id -> Integer,
name -> Text,
}
}

#[auto_type]
fn with_lifetime(name: &'_ str) -> _ {
users::table.filter(users::name.eq(name))
}

#[auto_type]
fn with_lifetime2(name: &str) -> _ {
users::table.filter(users::name.eq(name))
}

fn main() {
println!("Hello, world!");
}
49 changes: 49 additions & 0 deletions diesel_compile_tests/tests/fail/auto_type_life_times.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
error: `#[auto_type]` requires named lifetimes
--> tests/fail/auto_type_life_times.rs:12:25
|
12 | fn with_lifetime(name: &'_ str) -> _ {
| ^^

error: `#[auto_type]` requires named lifetimes
--> tests/fail/auto_type_life_times.rs:17:25
|
17 | fn with_lifetime2(name: &str) -> _ {
| ^^^^

error[E0106]: missing lifetime specifier
--> tests/fail/auto_type_life_times.rs:12:25
|
12 | fn with_lifetime(name: &'_ str) -> _ {
| ^^ expected named lifetime parameter
|
help: consider introducing a named lifetime parameter
|
12 | fn with_lifetime<'a>(name: &'a str) -> _ {
| ++++ ~~

error[E0106]: missing lifetime specifier
--> tests/fail/auto_type_life_times.rs:17:25
|
17 | fn with_lifetime2(name: &str) -> _ {
| ^ expected named lifetime parameter
|
help: consider introducing a named lifetime parameter
|
17 | fn with_lifetime2<'a>(name: &'a str) -> _ {
| ++++ ++

error: lifetime may not live long enough
--> tests/fail/auto_type_life_times.rs:13:5
|
12 | fn with_lifetime(name: &'_ str) -> _ {
| - let's call the lifetime of this reference `'1`
13 | users::table.filter(users::name.eq(name))
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ returning this value requires that `'1` must outlive `'static`

error: lifetime may not live long enough
--> tests/fail/auto_type_life_times.rs:18:5
|
17 | fn with_lifetime2(name: &str) -> _ {
| - let's call the lifetime of this reference `'1`
18 | users::table.filter(users::name.eq(name))
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ returning this value requires that `'1` must outlive `'static`
18 changes: 18 additions & 0 deletions diesel_derives/tests/auto_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,24 @@ fn test_normal_functions() -> _ {
))
}

#[auto_type]
fn with_lifetime<'a>(name: &'a str) -> _ {
users::table.filter(users::name.eq(name))
}

#[auto_type]
fn with_type_generics<'a, T>(name: &'a T) -> _
where
&'a T: diesel::expression::AsExpression<diesel::sql_types::Text>,
{
users::name.eq(name)
}

#[auto_type]
fn with_const_generics<const N: i32>() -> _ {
users::id.eq(N)
}

// #[auto_type]
// fn test_sql_fragment() -> _ {
// sql("foo")
Expand Down
2 changes: 1 addition & 1 deletion dsl_auto_type/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ either = "1"
heck = "0.5"
proc-macro2 = "1"
quote = "1"
syn = { version = "2", features = ["extra-traits", "full", "derive", "parsing"] }
syn = { version = "2", features = ["extra-traits", "full", "derive", "parsing", "visit"] }

[dev-dependencies]
diesel = { path = "../diesel" }
2 changes: 2 additions & 0 deletions dsl_auto_type/src/auto_type/expression_type_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,12 @@ impl TypeInferrer<'_> {
Err(e) => self.register_error(e, expr.span()),
}
}

fn register_error(&self, error: syn::Error, infer_type_span: Span) -> syn::Type {
self.errors.borrow_mut().push(Rc::new(error));
parse_quote_spanned!(infer_type_span=> _)
}

fn try_infer_expression_type(
&self,
expr: &syn::Expr,
Expand Down
10 changes: 10 additions & 0 deletions dsl_auto_type/src/auto_type/local_variables_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,16 @@ impl<'a, 'p> LocalVariablesMap<'a, 'p> {
Ok(())
}

pub(crate) fn process_const_generic(&mut self, const_generic: &'a syn::ConstParam) {
self.inner.map.insert(
&const_generic.ident,
LetStatementInferredType {
type_: const_generic.ty.clone(),
errors: Vec::new(),
},
);
}

/// Finishes a block inference for this map.
/// It may be initialized with `pat`s before (such as function parameters),
/// then this function is used to infer the type of the last expression in the block.
Expand Down
19 changes: 16 additions & 3 deletions dsl_auto_type/src/auto_type/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
mod case;
pub mod expression_type_inference;
mod local_variables_map;
mod referenced_generics;
mod settings_builder;

use {
Expand Down Expand Up @@ -134,6 +135,9 @@ pub(crate) fn auto_type_impl(
parent: None,
},
};
for const_generic in input_function.sig.generics.const_params() {
local_variables_map.process_const_generic(const_generic);
}
for function_param in &input_function.sig.inputs {
if let syn::FnArg::Typed(syn::PatType { pat, ty, .. }) = function_param {
match local_variables_map.process_pat(pat, Some(ty), None) {
Expand Down Expand Up @@ -165,11 +169,19 @@ pub(crate) fn auto_type_impl(

let type_alias = match type_alias {
Some(type_alias) => {
// We're generating a type alias so we need to extract the necessary lifetimes and
// generic type parameters for that type alias
let type_alias_generics = referenced_generics::extract_referenced_generics(
&return_type,
&input_function.sig.generics,
&mut errors,
);

let vis = &input_function.vis;
input_function.sig.output = parse_quote!(-> #type_alias);
input_function.sig.output = parse_quote!(-> #type_alias #type_alias_generics);
quote! {
#[allow(non_camel_case_types)]
#vis type #type_alias = #return_type;
#vis type #type_alias #type_alias_generics = #return_type;
}
}
None => {
Expand All @@ -180,12 +192,13 @@ pub(crate) fn auto_type_impl(

let mut res = quote! {
#type_alias
#[allow(clippy::needless_lifetimes)]
#input_function
};

for error in errors {
// Extracting from the `Rc` only if it's the last reference is an elegant way to
// deduplicate errors For this to work it is necessary that the rest of
// deduplicate errors. For this to work it is necessary that the rest of
// the errors (those from the local variables map that weren't used) are
// dropped before, which is the case here, and that we are iterating on the
// errors in an owned manner.
Expand Down
109 changes: 109 additions & 0 deletions dsl_auto_type/src/auto_type/referenced_generics.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
use std::rc::Rc;
use syn::parse_quote;
use syn::visit::{self, Visit};
use syn::{Ident, Lifetime};

pub(crate) fn extract_referenced_generics(
ty: &syn::Type,
generics: &syn::Generics,
errors: &mut Vec<Rc<syn::Error>>,
) -> syn::Generics {
struct Visitor<'g, 'errs> {
lifetimes: Vec<(&'g Lifetime, bool)>,
type_parameters: Vec<(&'g Ident, bool)>,
errors: &'errs mut Vec<Rc<syn::Error>>,
}

let mut visitor = Visitor {
lifetimes: generics
.lifetimes()
.map(|lt| (&lt.lifetime, false))
.collect(),
type_parameters: generics
.type_params()
.map(|tp| (&tp.ident, false))
.collect(),
errors,
};
visitor.lifetimes.sort_unstable();
visitor.type_parameters.sort_unstable();

impl<'ast> Visit<'ast> for Visitor<'_, '_> {
fn visit_lifetime(&mut self, lifetime: &'ast Lifetime) {
if lifetime.ident == "_" {
self.errors.push(Rc::new(syn::Error::new_spanned(
lifetime,
"`#[auto_type]` requires named lifetimes",
)));
} else if lifetime.ident != "static" {
if let Ok(lifetime_idx) = self
.lifetimes
.binary_search_by_key(&lifetime, |(lt, _)| *lt)
{
self.lifetimes[lifetime_idx].1 = true;
}
}
visit::visit_lifetime(self, lifetime)
}

fn visit_type_reference(&mut self, reference: &'ast syn::TypeReference) {
if reference.lifetime.is_none() {
self.errors.push(Rc::new(syn::Error::new_spanned(
reference,
"`#[auto_type]` requires named lifetimes",
)));
}
visit::visit_type_reference(self, reference)
}

fn visit_type_path(&mut self, type_path: &'ast syn::TypePath) {
if let Some(path_ident) = type_path.path.get_ident() {
if let Ok(type_param_idx) = self
.type_parameters
.binary_search_by_key(&path_ident, |tp| tp.0)
{
self.type_parameters[type_param_idx].1 = true;
}
}
visit::visit_type_path(self, type_path)
}
}

visitor.visit_type(ty);

let generic_params: syn::punctuated::Punctuated<syn::GenericParam, _> = generics
.params
.iter()
.filter_map(|param| match param {
syn::GenericParam::Lifetime(lt)
if visitor
.lifetimes
.binary_search(&(&lt.lifetime, true))
.is_ok() =>
{
let lt = &lt.lifetime;
Some(parse_quote!(#lt))
}
syn::GenericParam::Type(tp)
if visitor
.type_parameters
.binary_search(&(&tp.ident, true))
.is_ok() =>
{
let ident = &tp.ident;
Some(parse_quote!(#ident))
}
_ => None::<syn::GenericParam>,
})
.collect();

// We need to not set the lt_token and gt_token if `params` is empty to get
// a reasonable error message for the case that there is no lifetime specifier
// but we need one
syn::Generics {
lt_token: (!generic_params.is_empty()).then(Default::default),
gt_token: (!generic_params.is_empty()).then(Default::default),
params: generic_params,
where_clause: None,
}
}

0 comments on commit 4d5fb4f

Please sign in to comment.