Skip to content

Commit

Permalink
Add JSON support to FromRow derive (#2121)
Browse files Browse the repository at this point in the history
* Add `json` attribute to `FromRow` derive

* Add docs and fix formatting
  • Loading branch information
95ulisse authored Jul 31, 2023
1 parent 84f21e9 commit 9463b75
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 30 deletions.
45 changes: 43 additions & 2 deletions sqlx-core/src/from_row.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use crate::error::Error;
use crate::row::Row;
use crate::{error::Error, row::Row};

/// A record that can be built from a row returned by the database.
///
Expand Down Expand Up @@ -210,6 +209,48 @@ use crate::row::Row;
///
/// In MySql, `BigInt` type matches `i64`, but you can convert it to `u64` by `try_from`.
///
/// #### `json`
///
/// If your database supports a JSON type, you can leverage `#[sqlx(json)]`
/// to automatically integrate JSON deserialization in your [`FromRow`] implementation using [`serde`](https://docs.rs/serde/latest/serde/).
///
/// ```rust,ignore
/// #[derive(serde::Deserialize)]
/// struct Data {
/// field1: String,
/// field2: u64
/// }
///
/// #[derive(sqlx::FromRow)]
/// struct User {
/// id: i32,
/// name: String,
/// #[sqlx(json)]
/// metadata: Data
/// }
/// ```
///
/// Given a query like the following:
///
/// ```sql
/// SELECT
/// 1 AS id,
/// 'Name' AS name,
/// JSON_OBJECT('field1', 'value1', 'field2', 42) AS metadata
/// ```
///
/// The `metadata` field will be deserialized used its `serde::Deserialize` implementation:
///
/// ```rust,ignore
/// User {
/// id: 1,
/// name: "Name",
/// metadata: Data {
/// field1: "value1",
/// field2: 42
/// }
/// }
/// ```
pub trait FromRow<'r, R: Row>: Sized {
fn from_row(row: &'r R) -> Result<Self, Error>;
}
Expand Down
19 changes: 15 additions & 4 deletions sqlx-macros-core/src/derives/attributes.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use proc_macro2::{Ident, Span, TokenStream};
use quote::quote;
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::token::Comma;
use syn::{Attribute, DeriveInput, Field, Lit, Meta, MetaNameValue, NestedMeta, Type, Variant};
use syn::{
punctuated::Punctuated, spanned::Spanned, token::Comma, Attribute, DeriveInput, Field, Lit,
Meta, MetaNameValue, NestedMeta, Type, Variant,
};

macro_rules! assert_attribute {
($e:expr, $err:expr, $input:expr) => {
Expand Down Expand Up @@ -65,6 +65,7 @@ pub struct SqlxChildAttributes {
pub flatten: bool,
pub try_from: Option<Type>,
pub skip: bool,
pub json: bool,
}

pub fn parse_container_attributes(input: &[Attribute]) -> syn::Result<SqlxContainerAttributes> {
Expand Down Expand Up @@ -164,6 +165,7 @@ pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result<SqlxChildAttri
let mut try_from = None;
let mut flatten = false;
let mut skip: bool = false;
let mut json = false;

for attr in input.iter().filter(|a| a.path.is_ident("sqlx")) {
let meta = attr
Expand All @@ -187,12 +189,20 @@ pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result<SqlxChildAttri
Meta::Path(path) if path.is_ident("default") => default = true,
Meta::Path(path) if path.is_ident("flatten") => flatten = true,
Meta::Path(path) if path.is_ident("skip") => skip = true,
Meta::Path(path) if path.is_ident("json") => json = true,
u => fail!(u, "unexpected attribute"),
},
u => fail!(u, "unexpected attribute"),
}
}
}

if json && flatten {
fail!(
attr,
"Cannot use `json` and `flatten` together on the same field"
);
}
}

Ok(SqlxChildAttributes {
Expand All @@ -201,6 +211,7 @@ pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result<SqlxChildAttri
flatten,
try_from,
skip,
json,
})
}

Expand Down
70 changes: 46 additions & 24 deletions sqlx-macros-core/src/derives/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,45 +78,67 @@ fn expand_derive_from_row_struct(
));
}

let expr: Expr = match (attributes.flatten, attributes.try_from) {
(true, None) => {
predicates.push(parse_quote!(#ty: ::sqlx::FromRow<#lifetime, R>));
parse_quote!(<#ty as ::sqlx::FromRow<#lifetime, R>>::from_row(row))
}
(false, None) => {
let id_s = attributes
.rename
.or_else(|| Some(id.to_string().trim_start_matches("r#").to_owned()))
.map(|s| match container_attributes.rename_all {
Some(pattern) => rename_all(&s, pattern),
None => s,
})
.unwrap();

let expr: Expr = match (attributes.flatten, attributes.try_from, attributes.json) {
// <No attributes>
(false, None, false) => {
predicates
.push(parse_quote!(#ty: ::sqlx::decode::Decode<#lifetime, R::Database>));
predicates.push(parse_quote!(#ty: ::sqlx::types::Type<R::Database>));

let id_s = attributes
.rename
.or_else(|| Some(id.to_string().trim_start_matches("r#").to_owned()))
.map(|s| match container_attributes.rename_all {
Some(pattern) => rename_all(&s, pattern),
None => s,
})
.unwrap();
parse_quote!(row.try_get(#id_s))
}
(true,Some(try_from)) => {
// Flatten
(true, None, false) => {
predicates.push(parse_quote!(#ty: ::sqlx::FromRow<#lifetime, R>));
parse_quote!(<#ty as ::sqlx::FromRow<#lifetime, R>>::from_row(row))
}
// Flatten + Try from
(true, Some(try_from), false) => {
predicates.push(parse_quote!(#try_from: ::sqlx::FromRow<#lifetime, R>));
parse_quote!(<#try_from as ::sqlx::FromRow<#lifetime, R>>::from_row(row).and_then(|v| <#ty as ::std::convert::TryFrom::<#try_from>>::try_from(v).map_err(|e| ::sqlx::Error::ColumnNotFound("FromRow: try_from failed".to_string()))))
}
(false,Some(try_from)) => {
// Flatten + Json
(true, _, true) => {
panic!("Cannot use both flatten and json")
}
// Try from
(false, Some(try_from), false) => {
predicates
.push(parse_quote!(#try_from: ::sqlx::decode::Decode<#lifetime, R::Database>));
predicates.push(parse_quote!(#try_from: ::sqlx::types::Type<R::Database>));

let id_s = attributes
.rename
.or_else(|| Some(id.to_string().trim_start_matches("r#").to_owned()))
.map(|s| match container_attributes.rename_all {
Some(pattern) => rename_all(&s, pattern),
None => s,
})
.unwrap();
parse_quote!(row.try_get(#id_s).and_then(|v| <#ty as ::std::convert::TryFrom::<#try_from>>::try_from(v).map_err(|e| ::sqlx::Error::ColumnNotFound("FromRow: try_from failed".to_string()))))
}
// Try from + Json
(false, Some(try_from), true) => {
predicates
.push(parse_quote!(::sqlx::types::Json<#try_from>: ::sqlx::decode::Decode<#lifetime, R::Database>));
predicates.push(parse_quote!(::sqlx::types::Json<#try_from>: ::sqlx::types::Type<R::Database>));

parse_quote!(
row.try_get::<::sqlx::types::Json<_>, _>(#id_s).and_then(|v|
<#ty as ::std::convert::TryFrom::<#try_from>>::try_from(v.0)
.map_err(|e| ::sqlx::Error::ColumnNotFound("FromRow: try_from failed".to_string()))
)
)
},
// Json
(false, None, true) => {
predicates
.push(parse_quote!(::sqlx::types::Json<#ty>: ::sqlx::decode::Decode<#lifetime, R::Database>));
predicates.push(parse_quote!(::sqlx::types::Json<#ty>: ::sqlx::types::Type<R::Database>));

parse_quote!(row.try_get::<::sqlx::types::Json<_>, _>(#id_s).map(|x| x.0))
},
};

if attributes.default {
Expand Down
62 changes: 62 additions & 0 deletions tests/mysql/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -468,4 +468,66 @@ async fn test_try_from_attr_with_complex_type() -> anyhow::Result<()> {
Ok(())
}

#[sqlx_macros::test]
async fn test_from_row_json_attr() -> anyhow::Result<()> {
#[derive(serde::Deserialize)]
struct J {
a: u32,
b: u32,
}

#[derive(sqlx::FromRow)]
struct Record {
#[sqlx(json)]
j: J,
}

let mut conn = new::<MySql>().await?;

let record = sqlx::query_as::<_, Record>("select json_object('a', 1, 'b', 2) as j")
.fetch_one(&mut conn)
.await?;

assert_eq!(record.j.a, 1);
assert_eq!(record.j.b, 2);

Ok(())
}

#[sqlx_macros::test]
async fn test_from_row_json_try_from_attr() -> anyhow::Result<()> {
#[derive(serde::Deserialize)]
struct J {
a: u32,
b: u32,
}

// Non-deserializable
struct J2 {
sum: u32,
}

impl std::convert::From<J> for J2 {
fn from(j: J) -> Self {
Self { sum: j.a + j.b }
}
}

#[derive(sqlx::FromRow)]
struct Record {
#[sqlx(json, try_from = "J")]
j: J2,
}

let mut conn = new::<MySql>().await?;

let record = sqlx::query_as::<_, Record>("select json_object('a', 1, 'b', 2) as j")
.fetch_one(&mut conn)
.await?;

assert_eq!(record.j.sum, 3);

Ok(())
}

// we don't emit bind parameter type-checks for MySQL so testing the overrides is redundant

0 comments on commit 9463b75

Please sign in to comment.