From d5b8c66e24d66eb082ec76b2d1ab08048b481227 Mon Sep 17 00:00:00 2001 From: tyrelr <44035897+tyrelr@users.noreply.github.com> Date: Wed, 14 Sep 2022 19:12:21 -0600 Subject: [PATCH] Fix sqlite update return and order by type inference (#1960) * add failing test cases for update/delete return into * fix regression in null tracking by improving tracking of cursor empty/full state * add failing test case for order by column types * Add support for SorterOpen,SorterInsert,SorterData * add failing test case for unions * fix range copy/move implementation * fix wrong copy/move range * remove calls to dbg! --- sqlx-core/src/sqlite/connection/explain.rs | 153 ++++++++++++++++----- tests/sqlite/describe.rs | 143 +++++++++++++++---- 2 files changed, 237 insertions(+), 59 deletions(-) diff --git a/sqlx-core/src/sqlite/connection/explain.rs b/sqlx-core/src/sqlite/connection/explain.rs index ace42616f8..b470ce1ecf 100644 --- a/sqlx-core/src/sqlite/connection/explain.rs +++ b/sqlx-core/src/sqlite/connection/explain.rs @@ -67,7 +67,11 @@ const OP_SEEK_LT: &str = "SeekLT"; const OP_SEEK_ROW_ID: &str = "SeekRowId"; const OP_SEEK_SCAN: &str = "SeekScan"; const OP_SEQUENCE_TEST: &str = "SequenceTest"; +const OP_SORT: &str = "Sort"; +const OP_SORTER_DATA: &str = "SorterData"; +const OP_SORTER_INSERT: &str = "SorterInsert"; const OP_SORTER_NEXT: &str = "SorterNext"; +const OP_SORTER_OPEN: &str = "SorterOpen"; const OP_SORTER_SORT: &str = "SorterSort"; const OP_V_FILTER: &str = "VFilter"; const OP_V_NEXT: &str = "VNext"; @@ -180,29 +184,36 @@ impl RegDataType { #[derive(Debug, Clone, Eq, PartialEq)] enum CursorDataType { - Normal(HashMap), + Normal { + cols: HashMap, + is_empty: Option, + }, Pseudo(i64), } impl CursorDataType { - fn from_sparse_record(record: &HashMap) -> Self { - Self::Normal( - record + fn from_sparse_record(record: &HashMap, is_empty: Option) -> Self { + Self::Normal { + cols: record .iter() .map(|(&colnum, &datatype)| (colnum, datatype)) .collect(), - ) + is_empty, + } } - fn from_dense_record(record: &Vec) -> Self { - Self::Normal((0..).zip(record.iter().copied()).collect()) + fn from_dense_record(record: &Vec, is_empty: Option) -> Self { + Self::Normal { + cols: (0..).zip(record.iter().copied()).collect(), + is_empty, + } } fn map_to_dense_record(&self, registers: &HashMap) -> Vec { match self { - Self::Normal(record) => { - let mut rowdata = vec![ColumnType::default(); record.len()]; - for (idx, col) in record.iter() { + Self::Normal { cols, .. } => { + let mut rowdata = vec![ColumnType::default(); cols.len()]; + for (idx, col) in cols.iter() { rowdata[*idx as usize] = col.clone(); } rowdata @@ -219,13 +230,20 @@ impl CursorDataType { registers: &HashMap, ) -> HashMap { match self { - Self::Normal(c) => c.clone(), + Self::Normal { cols, .. } => cols.clone(), Self::Pseudo(i) => match registers.get(i) { Some(RegDataType::Record(r)) => (0..).zip(r.iter().copied()).collect(), _ => HashMap::new(), }, } } + + fn is_empty(&self) -> Option { + match self { + Self::Normal { is_empty, .. } => *is_empty, + Self::Pseudo(_) => Some(false), //pseudo cursors have exactly one row + } + } } #[allow(clippy::wildcard_in_or_patterns)] @@ -379,11 +397,11 @@ pub(super) fn explain( | OP_GE | OP_GO_SUB | OP_GT | OP_IDX_GE | OP_IDX_GT | OP_IDX_LE | OP_IDX_LT | OP_IF | OP_IF_NO_HOPE | OP_IF_NOT | OP_IF_NOT_OPEN | OP_IF_NOT_ZERO | OP_IF_NULL_ROW | OP_IF_POS | OP_IF_SMALLER | OP_INCR_VACUUM | OP_IS_NULL - | OP_IS_NULL_OR_TYPE | OP_LE | OP_LAST | OP_LT | OP_MUST_BE_INT | OP_NE - | OP_NEXT | OP_NO_CONFLICT | OP_NOT_EXISTS | OP_NOT_NULL | OP_ONCE | OP_PREV - | OP_PROGRAM | OP_ROW_SET_READ | OP_ROW_SET_TEST | OP_SEEK_GE | OP_SEEK_GT - | OP_SEEK_LE | OP_SEEK_LT | OP_SEEK_ROW_ID | OP_SEEK_SCAN | OP_SEQUENCE_TEST - | OP_SORTER_NEXT | OP_SORTER_SORT | OP_V_FILTER | OP_V_NEXT | OP_REWIND => { + | OP_IS_NULL_OR_TYPE | OP_LE | OP_LT | OP_MUST_BE_INT | OP_NE | OP_NEXT + | OP_NO_CONFLICT | OP_NOT_EXISTS | OP_NOT_NULL | OP_ONCE | OP_PREV | OP_PROGRAM + | OP_ROW_SET_READ | OP_ROW_SET_TEST | OP_SEEK_GE | OP_SEEK_GT | OP_SEEK_LE + | OP_SEEK_LT | OP_SEEK_ROW_ID | OP_SEEK_SCAN | OP_SEQUENCE_TEST + | OP_SORTER_NEXT | OP_V_FILTER | OP_V_NEXT => { // goto or next instruction (depending on actual values) state.visited[state.program_i] = true; @@ -395,6 +413,35 @@ pub(super) fn explain( continue; } + OP_REWIND | OP_LAST | OP_SORT | OP_SORTER_SORT => { + // goto if cursor p1 is empty, else next instruction + state.visited[state.program_i] = true; + + if let Some(cursor) = state.p.get(&p1) { + if matches!(cursor.is_empty(), None | Some(true)) { + //only take this branch if the cursor is empty + + let mut branch_state = state.clone(); + branch_state.program_i = p2 as usize; + + if let Some(CursorDataType::Normal { is_empty, .. }) = + branch_state.p.get_mut(&p1) + { + *is_empty = Some(true); + } + states.push(branch_state); + } + + if matches!(cursor.is_empty(), None | Some(false)) { + //only take this branch if the cursor is non-empty + state.program_i += 1; + continue; + } + } + + break; + } + OP_INIT_COROUTINE => { // goto or next instruction (depending on actual values) state.visited[state.program_i] = true; @@ -503,7 +550,7 @@ pub(super) fn explain( } } - OP_ROW_DATA => { + OP_ROW_DATA | OP_SORTER_DATA => { //Get entire row from cursor p1, store it into register p2 if let Some(record) = state.p.get(&p1) { let rowdata = record.map_to_dense_record(&state.r); @@ -528,11 +575,14 @@ pub(super) fn explain( state.r.insert(p3, RegDataType::Record(record)); } - OP_INSERT | OP_IDX_INSERT => { + OP_INSERT | OP_IDX_INSERT | OP_SORTER_INSERT => { if let Some(RegDataType::Record(record)) = state.r.get(&p2) { - if let Some(CursorDataType::Normal(row)) = state.p.get_mut(&p1) { + if let Some(CursorDataType::Normal { cols, is_empty }) = + state.p.get_mut(&p1) + { // Insert the record into wherever pointer p1 is - *row = (0..).zip(record.iter().copied()).collect(); + *cols = (0..).zip(record.iter().copied()).collect(); + *is_empty = Some(false); } } //Noop if the register p2 isn't a record, or if pointer p1 does not exist @@ -548,24 +598,35 @@ pub(super) fn explain( if let Some(columns) = root_block_cols.get(&p2) { state .p - .insert(p1, CursorDataType::from_sparse_record(columns)); + .insert(p1, CursorDataType::from_sparse_record(columns, None)); } else { - state - .p - .insert(p1, CursorDataType::Normal(HashMap::with_capacity(6))); + state.p.insert( + p1, + CursorDataType::Normal { + cols: HashMap::with_capacity(6), + is_empty: None, + }, + ); } } else { - state - .p - .insert(p1, CursorDataType::Normal(HashMap::with_capacity(6))); + state.p.insert( + p1, + CursorDataType::Normal { + cols: HashMap::with_capacity(6), + is_empty: None, + }, + ); } } - OP_OPEN_EPHEMERAL | OP_OPEN_AUTOINDEX => { + OP_OPEN_EPHEMERAL | OP_OPEN_AUTOINDEX | OP_SORTER_OPEN => { //Create a new pointer which is referenced by p1 state.p.insert( p1, - CursorDataType::from_dense_record(&vec![ColumnType::null(); p2 as usize]), + CursorDataType::from_dense_record( + &vec![ColumnType::null(); p2 as usize], + Some(true), + ), ); } @@ -594,8 +655,9 @@ pub(super) fn explain( OP_NULL_ROW => { // all columns in cursor X are potentially nullable - if let Some(CursorDataType::Normal(ref mut cursor)) = state.p.get_mut(&p1) { - for ref mut col in cursor.values_mut() { + if let Some(CursorDataType::Normal { ref mut cols, .. }) = state.p.get_mut(&p1) + { + for ref mut col in cols.values_mut() { col.nullable = Some(true); } } @@ -649,13 +711,40 @@ pub(super) fn explain( } } - OP_COPY | OP_MOVE | OP_SCOPY | OP_INT_COPY => { + OP_SCOPY | OP_INT_COPY => { // r[p2] = r[p1] if let Some(v) = state.r.get(&p1).cloned() { state.r.insert(p2, v); } } + OP_COPY => { + // r[p2..=p2+p3] = r[p1..=p1+p3] + if p3 >= 0 { + for i in 0..=p3 { + let src = p1 + i; + let dst = p2 + i; + if let Some(v) = state.r.get(&src).cloned() { + state.r.insert(dst, v); + } + } + } + } + + OP_MOVE => { + // r[p2..p2+p3] = r[p1..p1+p3]; r[p1..p1+p3] = null + if p3 >= 1 { + for i in 0..p3 { + let src = p1 + i; + let dst = p2 + i; + if let Some(v) = state.r.get(&src).cloned() { + state.r.insert(dst, v); + state.r.insert(src, RegDataType::Single(ColumnType::null())); + } + } + } + } + OP_INTEGER => { // r[p2] = p1 state.r.insert(p2, RegDataType::Int(p1)); diff --git a/tests/sqlite/describe.rs b/tests/sqlite/describe.rs index fe42220663..ffd088badc 100644 --- a/tests/sqlite/describe.rs +++ b/tests/sqlite/describe.rs @@ -185,7 +185,47 @@ async fn it_describes_insert_with_returning() -> anyhow::Result<()> { assert_eq!(d.columns().len(), 4); assert_eq!(d.column(0).type_info().name(), "INTEGER"); + assert_eq!(d.nullable(0), Some(false)); assert_eq!(d.column(1).type_info().name(), "TEXT"); + assert_eq!(d.nullable(1), Some(false)); + + let d = conn + .describe("INSERT INTO accounts (name, is_active) VALUES ('a', true) RETURNING id") + .await?; + + assert_eq!(d.columns().len(), 1); + assert_eq!(d.column(0).type_info().name(), "INTEGER"); + assert_eq!(d.nullable(0), Some(false)); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_describes_update_with_returning() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let d = conn + .describe("UPDATE accounts SET is_active=true WHERE name=?1 RETURNING id") + .await?; + + assert_eq!(d.columns().len(), 1); + assert_eq!(d.column(0).type_info().name(), "INTEGER"); + assert_eq!(d.nullable(0), Some(false)); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_describes_delete_with_returning() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let d = conn + .describe("DELETE FROM accounts WHERE name=?1 RETURNING id") + .await?; + + assert_eq!(d.columns().len(), 1); + assert_eq!(d.column(0).type_info().name(), "INTEGER"); + assert_eq!(d.nullable(0), Some(false)); Ok(()) } @@ -299,38 +339,38 @@ async fn it_describes_literal_subquery() -> anyhow::Result<()> { Ok(()) } -#[sqlx_macros::test] -async fn it_describes_table_subquery() -> anyhow::Result<()> { - async fn assert_tweet_described( - conn: &mut sqlx::SqliteConnection, - query: &str, - ) -> anyhow::Result<()> { - let info = conn.describe(query).await?; - let columns = info.columns(); +async fn assert_tweet_described( + conn: &mut sqlx::SqliteConnection, + query: &str, +) -> anyhow::Result<()> { + let info = conn.describe(query).await?; + let columns = info.columns(); - assert_eq!(columns[0].name(), "id", "{}", query); - assert_eq!(columns[1].name(), "text", "{}", query); - assert_eq!(columns[2].name(), "is_sent", "{}", query); - assert_eq!(columns[3].name(), "owner_id", "{}", query); + assert_eq!(columns[0].name(), "id", "{}", query); + assert_eq!(columns[1].name(), "text", "{}", query); + assert_eq!(columns[2].name(), "is_sent", "{}", query); + assert_eq!(columns[3].name(), "owner_id", "{}", query); - assert_eq!(columns[0].ordinal(), 0, "{}", query); - assert_eq!(columns[1].ordinal(), 1, "{}", query); - assert_eq!(columns[2].ordinal(), 2, "{}", query); - assert_eq!(columns[3].ordinal(), 3, "{}", query); + assert_eq!(columns[0].ordinal(), 0, "{}", query); + assert_eq!(columns[1].ordinal(), 1, "{}", query); + assert_eq!(columns[2].ordinal(), 2, "{}", query); + assert_eq!(columns[3].ordinal(), 3, "{}", query); - assert_eq!(info.nullable(0), Some(false), "{}", query); - assert_eq!(info.nullable(1), Some(false), "{}", query); - assert_eq!(info.nullable(2), Some(false), "{}", query); - assert_eq!(info.nullable(3), Some(true), "{}", query); + assert_eq!(info.nullable(0), Some(false), "{}", query); + assert_eq!(info.nullable(1), Some(false), "{}", query); + assert_eq!(info.nullable(2), Some(false), "{}", query); + assert_eq!(info.nullable(3), Some(true), "{}", query); - assert_eq!(columns[0].type_info().name(), "INTEGER", "{}", query); - assert_eq!(columns[1].type_info().name(), "TEXT", "{}", query); - assert_eq!(columns[2].type_info().name(), "BOOLEAN", "{}", query); - assert_eq!(columns[3].type_info().name(), "INTEGER", "{}", query); + assert_eq!(columns[0].type_info().name(), "INTEGER", "{}", query); + assert_eq!(columns[1].type_info().name(), "TEXT", "{}", query); + assert_eq!(columns[2].type_info().name(), "BOOLEAN", "{}", query); + assert_eq!(columns[3].type_info().name(), "INTEGER", "{}", query); - Ok(()) - } + Ok(()) +} +#[sqlx_macros::test] +async fn it_describes_table_subquery() -> anyhow::Result<()> { let mut conn = new::().await?; assert_tweet_described(&mut conn, "SELECT * FROM tweet").await?; assert_tweet_described(&mut conn, "SELECT * FROM (SELECT * FROM tweet)").await?; @@ -348,6 +388,43 @@ async fn it_describes_table_subquery() -> anyhow::Result<()> { Ok(()) } +#[sqlx_macros::test] +async fn it_describes_table_order_by() -> anyhow::Result<()> { + let mut conn = new::().await?; + assert_tweet_described(&mut conn, "SELECT * FROM tweet ORDER BY id").await?; + assert_tweet_described(&mut conn, "SELECT * FROM tweet ORDER BY id NULLS LAST").await?; + assert_tweet_described( + &mut conn, + "SELECT * FROM tweet ORDER BY owner_id DESC, text ASC", + ) + .await?; + + async fn assert_literal_order_by_described( + conn: &mut sqlx::SqliteConnection, + query: &str, + ) -> anyhow::Result<()> { + let info = conn.describe(query).await?; + + assert_eq!(info.column(0).type_info().name(), "TEXT", "{}", query); + assert_eq!(info.nullable(0), Some(false), "{}", query); + assert_eq!(info.column(1).type_info().name(), "TEXT", "{}", query); + assert_eq!(info.nullable(1), Some(false), "{}", query); + + Ok(()) + } + + assert_literal_order_by_described(&mut conn, "SELECT 'a', text FROM tweet ORDER BY id").await?; + assert_literal_order_by_described( + &mut conn, + "SELECT 'a', text FROM tweet ORDER BY id NULLS LAST", + ) + .await?; + assert_literal_order_by_described(&mut conn, "SELECT 'a', text FROM tweet ORDER BY text") + .await?; + + Ok(()) +} + #[sqlx_macros::test] async fn it_describes_union() -> anyhow::Result<()> { async fn assert_union_described( @@ -375,8 +452,20 @@ async fn it_describes_union() -> anyhow::Result<()> { "SELECT 'txt','a',null,'b' UNION ALL SELECT 'int',NULL,1,2 ", ) .await?; + //TODO: insert into temp-table not merging datatype/nullable of all operations - currently keeping last-writer - //assert_union_described(&mut conn, "SELECT 'txt','a',null,'b' UNION SELECT 'int',NULL,1,2 ").await?; + //assert_union_described(&mut conn, "SELECT 'txt','a',null,'b' UNION SELECT 'int',NULL,1,2 ").await?; + + assert_union_described( + &mut conn, + "SELECT 'tweet',text,owner_id id,null from tweet + UNION SELECT 'account',name,id,is_active from accounts + UNION SELECT 'account',name,id,is_active from accounts_view + UNION SELECT 'dummy',null,null,null + ORDER BY id + ", + ) + .await?; Ok(()) }