Skip to content

Commit

Permalink
add nullability for describe() to MySQL
Browse files Browse the repository at this point in the history
improve errors with unknown result column type IDs in `query!()`

run cargo fmt
  • Loading branch information
abonander committed Feb 6, 2020
1 parent 093d496 commit b14c45b
Show file tree
Hide file tree
Showing 9 changed files with 136 additions and 33 deletions.
2 changes: 1 addition & 1 deletion sqlx-core/src/describe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,5 +62,5 @@ where
pub enum Nullability {
NonNull,
Nullable,
Unknown
Unknown,
}
9 changes: 7 additions & 2 deletions sqlx-core/src/mysql/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ use std::sync::Arc;
use futures_core::future::BoxFuture;
use futures_core::stream::BoxStream;

use crate::describe::{Column, Describe};
use crate::describe::{Column, Describe, Nullability};
use crate::executor::Executor;
use crate::mysql::protocol::{
Capabilities, ColumnCount, ColumnDefinition, ComQuery, ComStmtExecute, ComStmtPrepare,
ComStmtPrepareOk, Cursor, Decode, EofPacket, OkPacket, Row, TypeId,
ComStmtPrepareOk, Cursor, Decode, EofPacket, FieldFlags, OkPacket, Row, TypeId,
};
use crate::mysql::{MySql, MySqlArguments, MySqlConnection, MySqlRow, MySqlTypeInfo};

Expand Down Expand Up @@ -257,6 +257,11 @@ impl MySqlConnection {
type_info: MySqlTypeInfo::from_column_def(&column),
name: column.column_alias.or(column.column),
table_id: column.table_alias.or(column.table),
nullability: if column.flags.contains(FieldFlags::NOT_NULL) {
Nullability::NonNull
} else {
Nullability::Nullable
},
});
}

Expand Down
29 changes: 28 additions & 1 deletion sqlx-core/src/mysql/protocol/type.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,36 @@
use std::fmt::{self, Debug, Display, Formatter};

// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/binary__log__types_8h.html
// https://mariadb.com/kb/en/library/resultset/#field-types
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct TypeId(pub u8);

macro_rules! type_id_consts {
($(
pub const $name:ident: TypeId = TypeId($id:literal);
)*) => (
impl TypeId {
$(pub const $name: TypeId = TypeId($id);)*

#[doc(hidden)]
pub fn type_name(&self) -> &'static str {
match self.0 {
$($id => stringify!($name),)*
_ => "<unknown>"
}
}
}
)
}

impl Display for TypeId {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(f, "{} ({:#x})", self.type_name(), self.0)
}
}

// https://github.com/google/mysql/blob/c01fc2134d439282a21a2ddf687566e198ddee28/include/mysql_com.h#L429
impl TypeId {
type_id_consts! {
pub const NULL: TypeId = TypeId(6);

// String: CHAR, VARCHAR, TEXT
Expand All @@ -23,6 +49,7 @@ impl TypeId {
pub const SMALL_INT: TypeId = TypeId(2);
pub const INT: TypeId = TypeId(3);
pub const BIG_INT: TypeId = TypeId(8);
pub const MEDIUM_INT: TypeId = TypeId(9);

// Numeric: FLOAT, DOUBLE
pub const FLOAT: TypeId = TypeId(4);
Expand Down
7 changes: 6 additions & 1 deletion sqlx-core/src/mysql/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,17 @@ impl MySqlTypeInfo {
char_set: def.char_set,
}
}

#[doc(hidden)]
pub fn type_name(&self) -> &'static str {
self.id.type_name()
}
}

impl Display for MySqlTypeInfo {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
// TODO: Should we attempt to render the type *name* here?
write!(f, "{}", self.id.0)
write!(f, "ID {:#x}", self.id.0)
}
}

Expand Down
57 changes: 36 additions & 21 deletions sqlx-core/src/postgres/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ use futures_util::{stream, FutureExt, StreamExt, TryStreamExt};
use crate::arguments::Arguments;
use crate::describe::{Column, Describe, Nullability};
use crate::encode::IsNull::No;
use crate::postgres::{PgArguments, PgRow, PgTypeInfo, Postgres};
use crate::postgres::protocol::{self, Encode, Field, Message, StatementId, TypeFormat, TypeId};
use crate::row::Row;
use crate::postgres::types::SharedStr;
use crate::postgres::{PgArguments, PgRow, PgTypeInfo, Postgres};
use crate::row::Row;

#[derive(Debug)]
enum Step {
Expand All @@ -37,7 +37,7 @@ impl super::PgConnection {
query,
param_types: &*args.types,
}
.encode(self.stream.buffer_mut());
.encode(self.stream.buffer_mut());

self.statement_cache.put(query.to_owned(), id);

Expand All @@ -59,7 +59,7 @@ impl super::PgConnection {
values: &*args.values,
result_formats: &[TypeFormat::Binary],
}
.encode(self.stream.buffer_mut());
.encode(self.stream.buffer_mut());
}

fn write_execute(&mut self, portal: &str, limit: i32) {
Expand Down Expand Up @@ -291,13 +291,14 @@ impl super::PgConnection {
let result_fields = result.map_or_else(Default::default, |r| r.fields);

// TODO: cache this result
let type_names = self.get_type_names(
params
.ids
.iter()
.cloned()
.chain(result_fields.iter().map(|field| field.type_id))
)
let type_names = self
.get_type_names(
params
.ids
.iter()
.cloned()
.chain(result_fields.iter().map(|field| field.type_id)),
)
.await?;

Ok(Describe {
Expand All @@ -307,20 +308,25 @@ impl super::PgConnection {
.map(|id| PgTypeInfo::new(*id, &type_names[&id.0]))
.collect::<Vec<_>>()
.into_boxed_slice(),
result_columns: self.map_result_columns(result_fields, type_names).await?
result_columns: self
.map_result_columns(result_fields, type_names)
.await?
.into_boxed_slice(),
})
}

async fn get_type_names(&mut self, ids: impl IntoIterator<Item = TypeId>) -> crate::Result<HashMap<u32, SharedStr>> {
async fn get_type_names(
&mut self,
ids: impl IntoIterator<Item = TypeId>,
) -> crate::Result<HashMap<u32, SharedStr>> {
let type_ids: HashSet<u32> = ids.into_iter().map(|id| id.0).collect::<HashSet<u32>>();

let mut query = "select types.type_id, pg_type.typname from (VALUES ".to_string();
let mut args = PgArguments::default();
let mut pushed = false;

// TODO: dedup this with the one below, ideally as an API we can export
for (i, (&type_id, bind)) in type_ids.iter().zip((1 .. ).step_by(2)).enumerate() {
for (i, (&type_id, bind)) in type_ids.iter().zip((1..).step_by(2)).enumerate() {
if pushed {
query += ", ";
}
Expand All @@ -334,8 +340,8 @@ impl super::PgConnection {
}

query += ") as types(idx, type_id) \
inner join pg_catalog.pg_type on pg_type.oid = type_id \
order by types.idx";
inner join pg_catalog.pg_type on pg_type.oid = type_id \
order by types.idx";

self.fetch(&query, args)
.map_ok(|row: PgRow| -> (u32, SharedStr) {
Expand All @@ -345,16 +351,22 @@ impl super::PgConnection {
.await
}

async fn map_result_columns(&mut self, fields: Box<[Field]>, type_names: HashMap<u32, SharedStr>) -> crate::Result<Vec<Column<Postgres>>> {
async fn map_result_columns(
&mut self,
fields: Box<[Field]>,
type_names: HashMap<u32, SharedStr>,
) -> crate::Result<Vec<Column<Postgres>>> {
use crate::describe::Nullability::*;

if fields.is_empty() { return Ok(vec![]); }
if fields.is_empty() {
return Ok(vec![]);
}

let mut query = "select col.idx, pg_attribute.attnotnull from (VALUES ".to_string();
let mut pushed = false;
let mut args = PgArguments::default();

for (i, (field, bind)) in fields.iter().zip((1 ..).step_by(3)).enumerate() {
for (i, (field, bind)) in fields.iter().zip((1..).step_by(3)).enumerate() {
if pushed {
query += ", ";
}
Expand All @@ -381,14 +393,17 @@ impl super::PgConnection {
let nonnull = row.get::<Option<bool>, _>(1);

if idx != fidx as i32 {
return Err(protocol_err!("missing field from query, field: {:?}", field).into());
return Err(
protocol_err!("missing field from query, field: {:?}", field).into(),
);
}

Ok(Column {
name: field.name,
table_id: field.table_id,
type_info: PgTypeInfo::new(field.type_id, &type_names[&field.type_id.0]),
nullability: nonnull.map(|nonnull| if nonnull { NonNull } else { Nullable })
nullability: nonnull
.map(|nonnull| if nonnull { NonNull } else { Nullable })
.unwrap_or(Unknown),
})
})
Expand Down
12 changes: 9 additions & 3 deletions sqlx-core/src/postgres/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,21 @@ pub struct PgTypeInfo {

impl PgTypeInfo {
pub(crate) fn new(id: TypeId, name: impl Into<SharedStr>) -> Self {
Self { id, name: Some(name.into()) }
Self {
id,
name: Some(name.into()),
}
}

/// Create a `PgTypeInfo` from a type's object identifier.
///
/// The object identifier of a type can be queried with
/// `SELECT oid FROM pg_type WHERE typname = <name>;`
pub fn with_oid(oid: u32) -> Self {
Self { id: TypeId(oid), name: None }
Self {
id: TypeId(oid),
name: None,
}
}

#[doc(hidden)]
Expand Down Expand Up @@ -101,4 +107,4 @@ impl fmt::Display for SharedStr {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt.pad(self)
}
}
}
7 changes: 6 additions & 1 deletion sqlx-macros/src/query_macros/output.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,12 @@ pub fn columns_to_rust<DB: DatabaseExt>(describe: &Describe<DB>) -> crate::Resul
let ident = parse_ident(name)?;

let type_ = <DB as DatabaseExt>::return_type_for_id(&column.type_info)
.ok_or_else(|| format!("unknown type: {}", &column.type_info))?
.ok_or_else(|| {
format!(
"unknown output type {} for column at position {} (name: {:?})",
column.type_info, i, column.name
)
})?
.parse::<TokenStream>()
.unwrap();

Expand Down
35 changes: 35 additions & 0 deletions tests/mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,41 @@ async fn it_selects_null() -> anyhow::Result<()> {
Ok(())
}

#[cfg_attr(feature = "runtime-async-std", async_std::test)]
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
async fn test_describe() -> anyhow::Result<()> {
use sqlx::describe::Nullability::*;

let mut conn = connect().await?;

let _ = conn
.send(
r#"
CREATE TEMPORARY TABLE describe_test (
id int primary key auto_increment,
name text not null,
hash blob
)
"#,
)
.await?;

let describe = conn
.describe("select nt.*, false from describe_test nt")
.await?;

assert_eq!(describe.result_columns[0].nullability, NonNull);
assert_eq!(describe.result_columns[0].type_info.type_name(), "INT");
assert_eq!(describe.result_columns[1].nullability, NonNull);
assert_eq!(describe.result_columns[1].type_info.type_name(), "TEXT");
assert_eq!(describe.result_columns[2].nullability, Nullable);
assert_eq!(describe.result_columns[2].type_info.type_name(), "TEXT");
assert_eq!(describe.result_columns[3].nullability, NonNull);
assert_eq!(describe.result_columns[3].type_info.type_name(), "BIG_INT");

Ok(())
}

#[cfg_attr(feature = "runtime-async-std", async_std::test)]
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
async fn pool_immediately_fails_with_db_error() -> anyhow::Result<()> {
Expand Down
11 changes: 8 additions & 3 deletions tests/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,15 +163,20 @@ async fn test_describe() -> anyhow::Result<()> {

let mut conn = connect().await?;

let _ = conn.send(r#"
let _ = conn
.send(
r#"
CREATE TEMP TABLE describe_test (
id SERIAL primary key,
name text not null,
hash bytea
)
"#).await?;
"#,
)
.await?;

let describe = conn.describe("select nt.*, false from describe_test nt")
let describe = conn
.describe("select nt.*, false from describe_test nt")
.await?;

assert_eq!(describe.result_columns[0].nullability, NonNull);
Expand Down

0 comments on commit b14c45b

Please sign in to comment.