From b14c45bf9f79a3b18cf683de4e4729cb33b1dedc Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Wed, 5 Feb 2020 18:55:25 -0800 Subject: [PATCH] add nullability for describe() to MySQL improve errors with unknown result column type IDs in `query!()` run cargo fmt --- sqlx-core/src/describe.rs | 2 +- sqlx-core/src/mysql/executor.rs | 9 +++- sqlx-core/src/mysql/protocol/type.rs | 29 ++++++++++++- sqlx-core/src/mysql/types/mod.rs | 7 +++- sqlx-core/src/postgres/executor.rs | 57 ++++++++++++++++---------- sqlx-core/src/postgres/types/mod.rs | 12 ++++-- sqlx-macros/src/query_macros/output.rs | 7 +++- tests/mysql.rs | 35 ++++++++++++++++ tests/postgres.rs | 11 +++-- 9 files changed, 136 insertions(+), 33 deletions(-) diff --git a/sqlx-core/src/describe.rs b/sqlx-core/src/describe.rs index 3133dd6f68..61eb31d1aa 100644 --- a/sqlx-core/src/describe.rs +++ b/sqlx-core/src/describe.rs @@ -62,5 +62,5 @@ where pub enum Nullability { NonNull, Nullable, - Unknown + Unknown, } diff --git a/sqlx-core/src/mysql/executor.rs b/sqlx-core/src/mysql/executor.rs index 8333e1a11c..eae01a113b 100644 --- a/sqlx-core/src/mysql/executor.rs +++ b/sqlx-core/src/mysql/executor.rs @@ -4,11 +4,11 @@ use std::sync::Arc; use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; -use crate::describe::{Column, Describe}; +use crate::describe::{Column, Describe, Nullability}; use crate::executor::Executor; use crate::mysql::protocol::{ Capabilities, ColumnCount, ColumnDefinition, ComQuery, ComStmtExecute, ComStmtPrepare, - ComStmtPrepareOk, Cursor, Decode, EofPacket, OkPacket, Row, TypeId, + ComStmtPrepareOk, Cursor, Decode, EofPacket, FieldFlags, OkPacket, Row, TypeId, }; use crate::mysql::{MySql, MySqlArguments, MySqlConnection, MySqlRow, MySqlTypeInfo}; @@ -257,6 +257,11 @@ impl MySqlConnection { type_info: MySqlTypeInfo::from_column_def(&column), name: column.column_alias.or(column.column), table_id: column.table_alias.or(column.table), + nullability: if column.flags.contains(FieldFlags::NOT_NULL) { + Nullability::NonNull + } else { + Nullability::Nullable + }, }); } diff --git a/sqlx-core/src/mysql/protocol/type.rs b/sqlx-core/src/mysql/protocol/type.rs index 6284401b08..b242ed02b0 100644 --- a/sqlx-core/src/mysql/protocol/type.rs +++ b/sqlx-core/src/mysql/protocol/type.rs @@ -1,10 +1,36 @@ +use std::fmt::{self, Debug, Display, Formatter}; + // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/binary__log__types_8h.html // https://mariadb.com/kb/en/library/resultset/#field-types #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub struct TypeId(pub u8); +macro_rules! type_id_consts { + ($( + pub const $name:ident: TypeId = TypeId($id:literal); + )*) => ( + impl TypeId { + $(pub const $name: TypeId = TypeId($id);)* + + #[doc(hidden)] + pub fn type_name(&self) -> &'static str { + match self.0 { + $($id => stringify!($name),)* + _ => "" + } + } + } + ) +} + +impl Display for TypeId { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "{} ({:#x})", self.type_name(), self.0) + } +} + // https://github.com/google/mysql/blob/c01fc2134d439282a21a2ddf687566e198ddee28/include/mysql_com.h#L429 -impl TypeId { +type_id_consts! { pub const NULL: TypeId = TypeId(6); // String: CHAR, VARCHAR, TEXT @@ -23,6 +49,7 @@ impl TypeId { pub const SMALL_INT: TypeId = TypeId(2); pub const INT: TypeId = TypeId(3); pub const BIG_INT: TypeId = TypeId(8); + pub const MEDIUM_INT: TypeId = TypeId(9); // Numeric: FLOAT, DOUBLE pub const FLOAT: TypeId = TypeId(4); diff --git a/sqlx-core/src/mysql/types/mod.rs b/sqlx-core/src/mysql/types/mod.rs index 39e8252b1d..730bba8728 100644 --- a/sqlx-core/src/mysql/types/mod.rs +++ b/sqlx-core/src/mysql/types/mod.rs @@ -49,12 +49,17 @@ impl MySqlTypeInfo { char_set: def.char_set, } } + + #[doc(hidden)] + pub fn type_name(&self) -> &'static str { + self.id.type_name() + } } impl Display for MySqlTypeInfo { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { // TODO: Should we attempt to render the type *name* here? - write!(f, "{}", self.id.0) + write!(f, "ID {:#x}", self.id.0) } } diff --git a/sqlx-core/src/postgres/executor.rs b/sqlx-core/src/postgres/executor.rs index 5e100a40c2..213e37776e 100644 --- a/sqlx-core/src/postgres/executor.rs +++ b/sqlx-core/src/postgres/executor.rs @@ -10,10 +10,10 @@ use futures_util::{stream, FutureExt, StreamExt, TryStreamExt}; use crate::arguments::Arguments; use crate::describe::{Column, Describe, Nullability}; use crate::encode::IsNull::No; -use crate::postgres::{PgArguments, PgRow, PgTypeInfo, Postgres}; use crate::postgres::protocol::{self, Encode, Field, Message, StatementId, TypeFormat, TypeId}; -use crate::row::Row; use crate::postgres::types::SharedStr; +use crate::postgres::{PgArguments, PgRow, PgTypeInfo, Postgres}; +use crate::row::Row; #[derive(Debug)] enum Step { @@ -37,7 +37,7 @@ impl super::PgConnection { query, param_types: &*args.types, } - .encode(self.stream.buffer_mut()); + .encode(self.stream.buffer_mut()); self.statement_cache.put(query.to_owned(), id); @@ -59,7 +59,7 @@ impl super::PgConnection { values: &*args.values, result_formats: &[TypeFormat::Binary], } - .encode(self.stream.buffer_mut()); + .encode(self.stream.buffer_mut()); } fn write_execute(&mut self, portal: &str, limit: i32) { @@ -291,13 +291,14 @@ impl super::PgConnection { let result_fields = result.map_or_else(Default::default, |r| r.fields); // TODO: cache this result - let type_names = self.get_type_names( - params - .ids - .iter() - .cloned() - .chain(result_fields.iter().map(|field| field.type_id)) - ) + let type_names = self + .get_type_names( + params + .ids + .iter() + .cloned() + .chain(result_fields.iter().map(|field| field.type_id)), + ) .await?; Ok(Describe { @@ -307,12 +308,17 @@ impl super::PgConnection { .map(|id| PgTypeInfo::new(*id, &type_names[&id.0])) .collect::>() .into_boxed_slice(), - result_columns: self.map_result_columns(result_fields, type_names).await? + result_columns: self + .map_result_columns(result_fields, type_names) + .await? .into_boxed_slice(), }) } - async fn get_type_names(&mut self, ids: impl IntoIterator) -> crate::Result> { + async fn get_type_names( + &mut self, + ids: impl IntoIterator, + ) -> crate::Result> { let type_ids: HashSet = ids.into_iter().map(|id| id.0).collect::>(); let mut query = "select types.type_id, pg_type.typname from (VALUES ".to_string(); @@ -320,7 +326,7 @@ impl super::PgConnection { let mut pushed = false; // TODO: dedup this with the one below, ideally as an API we can export - for (i, (&type_id, bind)) in type_ids.iter().zip((1 .. ).step_by(2)).enumerate() { + for (i, (&type_id, bind)) in type_ids.iter().zip((1..).step_by(2)).enumerate() { if pushed { query += ", "; } @@ -334,8 +340,8 @@ impl super::PgConnection { } query += ") as types(idx, type_id) \ - inner join pg_catalog.pg_type on pg_type.oid = type_id \ - order by types.idx"; + inner join pg_catalog.pg_type on pg_type.oid = type_id \ + order by types.idx"; self.fetch(&query, args) .map_ok(|row: PgRow| -> (u32, SharedStr) { @@ -345,16 +351,22 @@ impl super::PgConnection { .await } - async fn map_result_columns(&mut self, fields: Box<[Field]>, type_names: HashMap) -> crate::Result>> { + async fn map_result_columns( + &mut self, + fields: Box<[Field]>, + type_names: HashMap, + ) -> crate::Result>> { use crate::describe::Nullability::*; - if fields.is_empty() { return Ok(vec![]); } + if fields.is_empty() { + return Ok(vec![]); + } let mut query = "select col.idx, pg_attribute.attnotnull from (VALUES ".to_string(); let mut pushed = false; let mut args = PgArguments::default(); - for (i, (field, bind)) in fields.iter().zip((1 ..).step_by(3)).enumerate() { + for (i, (field, bind)) in fields.iter().zip((1..).step_by(3)).enumerate() { if pushed { query += ", "; } @@ -381,14 +393,17 @@ impl super::PgConnection { let nonnull = row.get::, _>(1); if idx != fidx as i32 { - return Err(protocol_err!("missing field from query, field: {:?}", field).into()); + return Err( + protocol_err!("missing field from query, field: {:?}", field).into(), + ); } Ok(Column { name: field.name, table_id: field.table_id, type_info: PgTypeInfo::new(field.type_id, &type_names[&field.type_id.0]), - nullability: nonnull.map(|nonnull| if nonnull { NonNull } else { Nullable }) + nullability: nonnull + .map(|nonnull| if nonnull { NonNull } else { Nullable }) .unwrap_or(Unknown), }) }) diff --git a/sqlx-core/src/postgres/types/mod.rs b/sqlx-core/src/postgres/types/mod.rs index 2b50bb582e..8bd04fdc38 100644 --- a/sqlx-core/src/postgres/types/mod.rs +++ b/sqlx-core/src/postgres/types/mod.rs @@ -26,7 +26,10 @@ pub struct PgTypeInfo { impl PgTypeInfo { pub(crate) fn new(id: TypeId, name: impl Into) -> Self { - Self { id, name: Some(name.into()) } + Self { + id, + name: Some(name.into()), + } } /// Create a `PgTypeInfo` from a type's object identifier. @@ -34,7 +37,10 @@ impl PgTypeInfo { /// The object identifier of a type can be queried with /// `SELECT oid FROM pg_type WHERE typname = ;` pub fn with_oid(oid: u32) -> Self { - Self { id: TypeId(oid), name: None } + Self { + id: TypeId(oid), + name: None, + } } #[doc(hidden)] @@ -101,4 +107,4 @@ impl fmt::Display for SharedStr { fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { fmt.pad(self) } -} \ No newline at end of file +} diff --git a/sqlx-macros/src/query_macros/output.rs b/sqlx-macros/src/query_macros/output.rs index 089223517e..867bc914e5 100644 --- a/sqlx-macros/src/query_macros/output.rs +++ b/sqlx-macros/src/query_macros/output.rs @@ -25,7 +25,12 @@ pub fn columns_to_rust(describe: &Describe) -> crate::Resul let ident = parse_ident(name)?; let type_ = ::return_type_for_id(&column.type_info) - .ok_or_else(|| format!("unknown type: {}", &column.type_info))? + .ok_or_else(|| { + format!( + "unknown output type {} for column at position {} (name: {:?})", + column.type_info, i, column.name + ) + })? .parse::() .unwrap(); diff --git a/tests/mysql.rs b/tests/mysql.rs index b67ddb9dd2..6124035aca 100644 --- a/tests/mysql.rs +++ b/tests/mysql.rs @@ -65,6 +65,41 @@ async fn it_selects_null() -> anyhow::Result<()> { Ok(()) } +#[cfg_attr(feature = "runtime-async-std", async_std::test)] +#[cfg_attr(feature = "runtime-tokio", tokio::test)] +async fn test_describe() -> anyhow::Result<()> { + use sqlx::describe::Nullability::*; + + let mut conn = connect().await?; + + let _ = conn + .send( + r#" + CREATE TEMPORARY TABLE describe_test ( + id int primary key auto_increment, + name text not null, + hash blob + ) + "#, + ) + .await?; + + let describe = conn + .describe("select nt.*, false from describe_test nt") + .await?; + + assert_eq!(describe.result_columns[0].nullability, NonNull); + assert_eq!(describe.result_columns[0].type_info.type_name(), "INT"); + assert_eq!(describe.result_columns[1].nullability, NonNull); + assert_eq!(describe.result_columns[1].type_info.type_name(), "TEXT"); + assert_eq!(describe.result_columns[2].nullability, Nullable); + assert_eq!(describe.result_columns[2].type_info.type_name(), "TEXT"); + assert_eq!(describe.result_columns[3].nullability, NonNull); + assert_eq!(describe.result_columns[3].type_info.type_name(), "BIG_INT"); + + Ok(()) +} + #[cfg_attr(feature = "runtime-async-std", async_std::test)] #[cfg_attr(feature = "runtime-tokio", tokio::test)] async fn pool_immediately_fails_with_db_error() -> anyhow::Result<()> { diff --git a/tests/postgres.rs b/tests/postgres.rs index b592e27b25..99b7c616ca 100644 --- a/tests/postgres.rs +++ b/tests/postgres.rs @@ -163,15 +163,20 @@ async fn test_describe() -> anyhow::Result<()> { let mut conn = connect().await?; - let _ = conn.send(r#" + let _ = conn + .send( + r#" CREATE TEMP TABLE describe_test ( id SERIAL primary key, name text not null, hash bytea ) - "#).await?; + "#, + ) + .await?; - let describe = conn.describe("select nt.*, false from describe_test nt") + let describe = conn + .describe("select nt.*, false from describe_test nt") .await?; assert_eq!(describe.result_columns[0].nullability, NonNull);