Skip to content

Commit

Permalink
add nullability info to Describe
Browse files Browse the repository at this point in the history
implement nullability check for Postgres as a query on pg_attribute

implement type name fetching for Postgres as part of `describe()`

add nullability for describe() to MySQL

improve errors with unknown result column type IDs in `query!()`

run cargo fmt and fix warnings

improve error when feature gates for chrono/uuid types is not turned on

workflows/rust: add step to UI-test missing optional features

improve error for unsupported/feature-gated input parameter types

fix `PgConnection::get_type_names()` for empty type IDs list

fix `tests::mysql::test_describe()` on MariaDB 10.4

copy-edit unsupported/feature-gated type errors in `query!()`

Postgres: fix SQL type of string array

Co-Authored-By: Anthony Dodd <Dodd.AnthonyJosiah@gmail.com>
  • Loading branch information
abonander and thedodd committed Mar 3, 2020
1 parent 56c0fc3 commit 30d45e8
Show file tree
Hide file tree
Showing 33 changed files with 651 additions and 62 deletions.
35 changes: 35 additions & 0 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,16 @@ jobs:
env:
DATABASE_URL: postgres://postgres:postgres@localhost:${{ job.services.postgres.ports[5432] }}/postgres

# UI feature gate tests: async-std
- run: cargo test --no-default-features --features 'runtime-async-std postgres macros tls'
env:
DATABASE_URL: postgres://postgres:postgres@localhost:${{ job.services.postgres.ports[5432] }}/postgres

# UI feature gate tests: tokio
- run: cargo test --no-default-features --features 'runtime-tokio postgres macros tls'
env:
DATABASE_URL: postgres://postgres:postgres@localhost:${{ job.services.postgres.ports[5432] }}/postgres

mysql:
needs: build
runs-on: ubuntu-latest
Expand Down Expand Up @@ -176,6 +186,21 @@ jobs:
# NOTE: Github Actions' YML parser doesn't handle multiline strings correctly
DATABASE_URL: mysql://root:password@localhost:${{ job.services.mysql.ports[3306] }}/sqlx?ssl-mode=VERIFY_CA&ssl-ca=%2Fdata%2Fmysql%2Fca.pem

# UI feature gate tests: async-std
- run: cargo test --no-default-features --features 'runtime-async-std mysql macros tls' --test ui-tests
env:
# pass the path to the CA that the MySQL service generated
# NOTE: Github Actions' YML parser doesn't handle multiline strings correctly
DATABASE_URL: mysql://root:password@localhost:${{ job.services.mysql.ports[3306] }}/sqlx?ssl-mode=VERIFY_CA&ssl-ca=%2Fdata%2Fmysql%2Fca.pem

# UI feature gate tests: tokio
- run: cargo test --no-default-features --features 'runtime-tokio mysql macros tls' --test ui-tests
env:
# pass the path to the CA that the MySQL service generated
# NOTE: Github Actions' YML parser doesn't handle multiline strings correctly
DATABASE_URL: mysql://root:password@localhost:${{ job.services.mysql.ports[3306] }}/sqlx?ssl-mode=VERIFY_CA&ssl-ca=%2Fdata%2Fmysql%2Fca.pem


mariadb:
needs: build
runs-on: ubuntu-latest
Expand Down Expand Up @@ -225,3 +250,13 @@ jobs:
- run: cargo test --no-default-features --features 'runtime-tokio mysql macros uuid chrono tls'
env:
DATABASE_URL: mariadb://root:password@localhost:${{ job.services.mariadb.ports[3306] }}/sqlx

# UI feature gate tests: async-std
- run: cargo test --no-default-features --features 'runtime-async-std mysql macros tls'
env:
DATABASE_URL: mariadb://root:password@localhost:${{ job.services.mariadb.ports[3306] }}/sqlx

# UI feature gate tests: tokio
- run: cargo test --no-default-features --features 'runtime-tokio mysql macros tls'
env:
DATABASE_URL: mariadb://root:password@localhost:${{ job.services.mariadb.ports[3306] }}/sqlx
3 changes: 3 additions & 0 deletions sqlx-core/src/describe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ where
pub name: Option<Box<str>>,
pub table_id: Option<DB::TableId>,
pub type_info: DB::TypeInfo,
/// Whether or not the column cannot be `NULL` (or if that is even knowable).
pub non_null: Option<bool>,
}

impl<DB> Debug for Column<DB>
Expand All @@ -53,6 +55,7 @@ where
.field("name", &self.name)
.field("table_id", &self.table_id)
.field("type_id", &self.type_info)
.field("nonnull", &self.non_null)
.finish()
}
}
5 changes: 3 additions & 2 deletions sqlx-core/src/mysql/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ use std::sync::Arc;
use futures_core::future::BoxFuture;
use futures_core::stream::BoxStream;

use crate::describe::{Column, Describe};
use crate::describe::{Column, Describe, Nullability};
use crate::executor::Executor;
use crate::mysql::protocol::{
Capabilities, ColumnCount, ColumnDefinition, ComQuery, ComStmtExecute, ComStmtPrepare,
ComStmtPrepareOk, Cursor, Decode, EofPacket, OkPacket, Row, TypeId,
ComStmtPrepareOk, Cursor, Decode, EofPacket, FieldFlags, OkPacket, Row, TypeId,
};
use crate::mysql::{MySql, MySqlArguments, MySqlConnection, MySqlRow, MySqlTypeInfo};

Expand Down Expand Up @@ -257,6 +257,7 @@ impl MySqlConnection {
type_info: MySqlTypeInfo::from_column_def(&column),
name: column.column_alias.or(column.column),
table_id: column.table_alias.or(column.table),
non_null: Some(!column.flags.contains(FieldFlags::NOT_NULL)),
});
}

Expand Down
29 changes: 28 additions & 1 deletion sqlx-core/src/mysql/protocol/type.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,36 @@
use std::fmt::{self, Debug, Display, Formatter};

// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/binary__log__types_8h.html
// https://mariadb.com/kb/en/library/resultset/#field-types
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct TypeId(pub u8);

macro_rules! type_id_consts {
($(
pub const $name:ident: TypeId = TypeId($id:literal);
)*) => (
impl TypeId {
$(pub const $name: TypeId = TypeId($id);)*

#[doc(hidden)]
pub fn type_name(&self) -> &'static str {
match self.0 {
$($id => stringify!($name),)*
_ => "<unknown>"
}
}
}
)
}

impl Display for TypeId {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(f, "{} ({:#x})", self.type_name(), self.0)
}
}

// https://github.com/google/mysql/blob/c01fc2134d439282a21a2ddf687566e198ddee28/include/mysql_com.h#L429
impl TypeId {
type_id_consts! {
pub const NULL: TypeId = TypeId(6);

// String: CHAR, VARCHAR, TEXT
Expand All @@ -23,6 +49,7 @@ impl TypeId {
pub const SMALL_INT: TypeId = TypeId(2);
pub const INT: TypeId = TypeId(3);
pub const BIG_INT: TypeId = TypeId(8);
pub const MEDIUM_INT: TypeId = TypeId(9);

// Numeric: FLOAT, DOUBLE
pub const FLOAT: TypeId = TypeId(4);
Expand Down
20 changes: 18 additions & 2 deletions sqlx-core/src/mysql/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,28 @@ impl MySqlTypeInfo {
char_set: def.char_set,
}
}

#[doc(hidden)]
pub fn type_name(&self) -> &'static str {
self.id.type_name()
}

#[doc(hidden)]
pub fn type_feature_gate(&self) -> Option<&'static str> {
match self.id {
TypeId::DATE | TypeId::TIME | TypeId::DATETIME | TypeId::TIMESTAMP => Some("chrono"),
_ => None,
}
}
}

impl Display for MySqlTypeInfo {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
// TODO: Should we attempt to render the type *name* here?
write!(f, "{}", self.id.0)
if self.id.type_name() != "<unknown>" {
write!(f, "{}", self.id.type_name())
} else {
write!(f, "ID {:#x}", self.id.0)
}
}
}

Expand Down
4 changes: 2 additions & 2 deletions sqlx-core/src/postgres/cursor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ fn parse_row_description(rd: RowDescription) -> (HashMap<Box<str>, usize>, Vec<T

// Used to describe the incoming results
// We store the column map in an Arc and share it among all rows
async fn describe(
async fn expect_desc(
conn: &mut PgConnection,
) -> crate::Result<(HashMap<Box<str>, usize>, Vec<TypeFormat>)> {
let description: Option<_> = loop {
Expand Down Expand Up @@ -108,7 +108,7 @@ async fn get_or_describe(
if !conn.cache_statement_columns.contains_key(&statement)
|| !conn.cache_statement_formats.contains_key(&statement)
{
let (columns, formats) = describe(conn).await?;
let (columns, formats) = expect_desc(conn).await?;

conn.cache_statement_columns
.insert(statement, Arc::new(columns));
Expand Down
159 changes: 144 additions & 15 deletions sqlx-core/src/postgres/executor.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
use std::collections::{HashMap, HashSet};
use std::fmt::Write;

use futures_core::future::BoxFuture;
use futures_util::{stream, StreamExt, TryStreamExt};

use crate::arguments::Arguments;
use crate::cursor::Cursor;
use crate::describe::{Column, Describe};
use crate::executor::{Execute, Executor, RefExecutor};
use crate::postgres::protocol::{
self, CommandComplete, Message, ParameterDescription, RowDescription, StatementId, TypeFormat,
self, CommandComplete, Field, Message, ParameterDescription, RowDescription, StatementId,
TypeFormat, TypeId,
};
use crate::postgres::{PgArguments, PgConnection, PgCursor, PgTypeInfo, Postgres};
use crate::postgres::types::SharedStr;
use crate::postgres::{PgArguments, PgConnection, PgCursor, PgRow, PgTypeInfo, Postgres};
use crate::row::Row;

impl PgConnection {
pub(crate) fn write_simple_query(&mut self, query: &str) {
Expand Down Expand Up @@ -132,13 +140,14 @@ impl PgConnection {
&'e mut self,
query: &'q str,
) -> crate::Result<Describe<Postgres>> {
self.is_ready = false;

let statement = self.write_prepare(query, &Default::default());

self.write_describe(protocol::Describe::Statement(statement));
self.write_sync();

self.stream.flush().await?;
self.wait_until_ready().await?;

let params = loop {
match self.stream.read().await? {
Expand Down Expand Up @@ -171,27 +180,147 @@ impl PgConnection {
}
};

self.wait_until_ready().await?;

let result_fields = result.map_or_else(Default::default, |r| r.fields);

// TODO: cache this result
let type_names = self
.get_type_names(
params
.ids
.iter()
.cloned()
.chain(result_fields.iter().map(|field| field.type_id)),
)
.await?;

Ok(Describe {
param_types: params
.ids
.iter()
.map(|id| PgTypeInfo::new(*id))
.map(|id| PgTypeInfo::new(*id, &type_names[&id.0]))
.collect::<Vec<_>>()
.into_boxed_slice(),
result_columns: result
.map(|r| r.fields)
.unwrap_or_default()
.into_vec()
.into_iter()
// TODO: Should [Column] just wrap [protocol::Field] ?
.map(|field| Column {
result_columns: self
.map_result_columns(result_fields, type_names)
.await?
.into_boxed_slice(),
})
}

async fn get_type_names(
&mut self,
ids: impl IntoIterator<Item = TypeId>,
) -> crate::Result<HashMap<u32, SharedStr>> {
let type_ids: HashSet<u32> = ids.into_iter().map(|id| id.0).collect::<HashSet<u32>>();

if type_ids.is_empty() {
return Ok(HashMap::new());
}

// uppercase type names are easier to visually identify
let mut query = "select types.type_id, UPPER(pg_type.typname) from (VALUES ".to_string();
let mut args = PgArguments::default();
let mut pushed = false;

// TODO: dedup this with the one below, ideally as an API we can export
for (i, (&type_id, bind)) in type_ids.iter().zip((1..).step_by(2)).enumerate() {
if pushed {
query += ", ";
}

pushed = true;
let _ = write!(query, "(${}, ${})", bind, bind + 1);

// not used in the output but ensures are values are sorted correctly
args.add(i as i32);
args.add(type_id as i32);
}

query += ") as types(idx, type_id) \
inner join pg_catalog.pg_type on pg_type.oid = type_id \
order by types.idx";

crate::query::query(&query)
.bind_all(args)
.map(|row: PgRow| -> crate::Result<(u32, SharedStr)> {
Ok((
row.get::<i32, _>(0)? as u32,
row.get::<String, _>(1)?.into(),
))
})
.fetch(self)
.try_collect()
.await
}

async fn map_result_columns(
&mut self,
fields: Box<[Field]>,
type_names: HashMap<u32, SharedStr>,
) -> crate::Result<Vec<Column<Postgres>>> {
if fields.is_empty() {
return Ok(vec![]);
}

let mut query = "select col.idx, pg_attribute.attnotnull from (VALUES ".to_string();
let mut pushed = false;
let mut args = PgArguments::default();

for (i, (field, bind)) in fields.iter().zip((1..).step_by(3)).enumerate() {
if pushed {
query += ", ";
}

pushed = true;
let _ = write!(
query,
"(${}::int4, ${}::int4, ${}::int2)",
bind,
bind + 1,
bind + 2
);

args.add(i as i32);
args.add(field.table_id.map(|id| id as i32));
args.add(field.column_id);
}

query += ") as col(idx, table_id, col_idx) \
left join pg_catalog.pg_attribute on table_id is not null and attrelid = table_id and attnum = col_idx \
order by col.idx;";

log::trace!("describe pg_attribute query: {:#?}", query);

crate::query::query(&query)
.bind_all(args)
.map(|row: PgRow| {
let idx = row.get::<i32, _>(0)?;
let non_null = row.get::<Option<bool>, _>(1)?;

Ok((idx, non_null))
})
.fetch(self)
.zip(stream::iter(fields.into_vec().into_iter().enumerate()))
.map(|(row, (fidx, field))| -> crate::Result<Column<_>> {
let (idx, non_null) = row?;

if idx != fidx as i32 {
return Err(
protocol_err!("missing field from query, field: {:?}", field).into(),
);
}

Ok(Column {
name: field.name,
table_id: field.table_id,
type_info: PgTypeInfo::new(field.type_id),
type_info: PgTypeInfo::new(field.type_id, &type_names[&field.type_id.0]),
non_null,
})
.collect::<Vec<_>>()
.into_boxed_slice(),
})
})
.try_collect()
.await
}

// Poll messages from Postgres, counting the rows affected, until we finish the query
Expand Down
Loading

0 comments on commit 30d45e8

Please sign in to comment.