diff --git a/session/nontransactional.go b/session/nontransactional.go index 8358a5253cee9..f834965cafe54 100644 --- a/session/nontransactional.go +++ b/session/nontransactional.go @@ -17,10 +17,12 @@ package session import ( "context" "fmt" + "math" "strings" "github.com/pingcap/errors" "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/parser/ast" "github.com/pingcap/tidb/parser/format" @@ -65,9 +67,8 @@ func HandleNonTransactionalDelete(ctx context.Context, stmt *ast.NonTransactiona if err != nil { return nil, err } - if !(se.GetSessionVars().IsAutocommit() && !se.GetSessionVars().InTxn()) { - return nil, errors.Errorf("non-transactional statement can only run in auto-commit mode. auto-commit:%v, inTxn:%v", - se.GetSessionVars().IsAutocommit(), se.GetSessionVars().InTxn()) + if err := checkConstraint(ctx, stmt, se); err != nil { + return nil, err } tableName, selectSQL, shardColumnInfo, err := buildSelectSQL(stmt, se) if err != nil { @@ -91,6 +92,32 @@ func HandleNonTransactionalDelete(ctx context.Context, stmt *ast.NonTransactiona return buildExecuteResults(jobs, se.GetSessionVars().BatchSize.MaxChunkSize) } +func checkConstraint(ctx context.Context, stmt *ast.NonTransactionalDeleteStmt, se Session) error { + sessVars := se.GetSessionVars() + if !(sessVars.IsAutocommit() && !sessVars.InTxn()) { + return errors.Errorf("non-transactional statement can only run in auto-commit mode. auto-commit:%v, inTxn:%v", + se.GetSessionVars().IsAutocommit(), se.GetSessionVars().InTxn()) + } + if config.GetGlobalConfig().EnableBatchDML && sessVars.DMLBatchSize > 0 && (sessVars.BatchDelete || sessVars.BatchInsert) { + return errors.Errorf("can't run non-transactional statement with batch dml") + } + + if sessVars.ReadConsistency.IsWeak() { + return errors.New("can't run non-transactional under weak read consistency") + } + if sessVars.SnapshotTS != 0 { + return errors.New("can't do non-transactional DML when tidb_snapshot is set") + } + // TODO: return error if there are multiple tables + if stmt.DeleteStmt.TableRefs == nil || stmt.DeleteStmt.TableRefs.TableRefs == nil { + return errors.New("table reference is nil") + } + if stmt.DeleteStmt.TableRefs.TableRefs.Right != nil { + return errors.New("Non-transactional delete doesn't support multiple tables") + } + return nil +} + // single-threaded worker. work on the key range [start, end] func splitDeleteWorker(ctx context.Context, jobs []job, stmt *ast.NonTransactionalDeleteStmt, tableName *ast.TableName, se Session, originalCondition ast.ExprNode) ([]string, error) { @@ -270,7 +297,16 @@ func buildShardJobs(ctx context.Context, stmt *ast.NonTransactionalDeleteStmt, s shardColumnCollate = "" } + // A NT-DML is not a SELECT. We ignore the SelectLimit for selectSQL so that it can read all values. + originalSelectLimit := se.GetSessionVars().SelectLimit + se.GetSessionVars().SelectLimit = math.MaxUint64 + // NT-DML is a write operation, and should not be affected by read_staleness that is supposed to affect only SELECT. + originalReadStaleness := se.GetSessionVars().ReadStaleness + se.GetSessionVars().ReadStaleness = 0 rss, err := se.Execute(ctx, selectSQL) + se.GetSessionVars().SelectLimit = originalSelectLimit + se.GetSessionVars().ReadStaleness = originalReadStaleness + if err != nil { return nil, err } @@ -344,13 +380,6 @@ func buildShardJobs(ctx context.Context, stmt *ast.NonTransactionalDeleteStmt, s func buildSelectSQL(stmt *ast.NonTransactionalDeleteStmt, se Session) (*ast.TableName, string, *model.ColumnInfo, error) { // only use the first table - // TODO: return error if there are multiple tables - if stmt.DeleteStmt.TableRefs == nil || stmt.DeleteStmt.TableRefs.TableRefs == nil { - return nil, "", nil, errors.New("table reference is nil") - } - if stmt.DeleteStmt.TableRefs.TableRefs.Right != nil { - return nil, "", nil, errors.New("Non-transactional delete doesn't support multiple tables") - } tableSource, ok := stmt.DeleteStmt.TableRefs.TableRefs.Left.(*ast.TableSource) if !ok { return nil, "", nil, errors.New("Non-transactional delete, table source not found") diff --git a/session/nontransactional_test.go b/session/nontransactional_test.go index 5ea4f731f3253..cc83d85594e2a 100644 --- a/session/nontransactional_test.go +++ b/session/nontransactional_test.go @@ -18,10 +18,13 @@ import ( "fmt" "strings" "testing" + "time" "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/testkit" "github.com/stretchr/testify/require" + tikvutil "github.com/tikv/client-go/v2/util" ) func TestNonTransactionalDeleteSharding(t *testing.T) { @@ -226,3 +229,111 @@ func TestNonTransactionalDeleteInvisibleIndex(t *testing.T) { tk.MustExec("split on a limit 10 delete from t") tk.MustQuery("select count(*) from t").Check(testkit.Rows("0")) } + +func TestNonTransactionalDeleteIgnoreSelectLimit(t *testing.T) { + store, clean := createStorage(t) + defer clean() + tk := testkit.NewTestKit(t, store) + tk.MustExec("set @@tidb_max_chunk_size=35") + tk.MustExec("set @@sql_select_limit=3") + tk.MustExec("use test") + tk.MustExec("create table t(a int, b int, key(a))") + for i := 0; i < 100; i++ { + tk.MustExec(fmt.Sprintf("insert into t values (%d, %d)", i, i*2)) + } + tk.MustExec("split on a limit 10 delete from t") + tk.MustQuery("select count(*) from t").Check(testkit.Rows("0")) +} + +func TestNonTransactionalDeleteReadStaleness(t *testing.T) { + store, clean := createStorage(t) + defer clean() + tk := testkit.NewTestKit(t, store) + tk.MustExec("set @@tidb_max_chunk_size=35") + tk.MustExec("set @@tidb_read_staleness=-100") + tk.MustExec("use test") + tk.MustExec("create table t(a int, b int, key(a))") + for i := 0; i < 100; i++ { + tk.MustExec(fmt.Sprintf("insert into t values (%d, %d)", i, i*2)) + } + tk.MustExec("split on a limit 10 delete from t") + tk.MustExec("set @@tidb_read_staleness=0") + tk.MustQuery("select count(*) from t").Check(testkit.Rows("0")) +} + +func TestNonTransactionalDeleteCheckConstraint(t *testing.T) { + store, clean := createStorage(t) + defer clean() + tk := testkit.NewTestKit(t, store) + + tk.MustExec("use test") + tk.MustExec("create table t(a int, b int, key(a))") + + // For mocked tikv, safe point is not initialized, we manually insert it for snapshot to use. + safePointName := "tikv_gc_safe_point" + now := time.Now() + safePointValue := now.Format(tikvutil.GCTimeFormat) + safePointComment := "All versions after safe point can be accessed. (DO NOT EDIT)" + updateSafePoint := fmt.Sprintf("INSERT INTO mysql.tidb VALUES ('%[1]s', '%[2]s', '%[3]s') ON DUPLICATE KEY UPDATE variable_value = '%[2]s', comment = '%[3]s'", safePointName, safePointValue, safePointComment) + tk.MustExec(updateSafePoint) + + tk.MustExec("set @@tidb_max_chunk_size=35") + tk.MustExec("set @a=now(6)") + + for i := 0; i < 100; i++ { + tk.MustExec(fmt.Sprintf("insert into t values (%d, %d)", i, i*2)) + } + tk.MustExec("set @@tidb_snapshot=@a") + err := tk.ExecToErr("split on a limit 10 delete from t") + require.Error(t, err) + tk.MustExec("set @@tidb_snapshot=''") + tk.MustQuery("select count(*) from t").Check(testkit.Rows("100")) + + tk.MustExec("set @@tidb_read_consistency=weak") + err = tk.ExecToErr("split on a limit 10 delete from t") + require.Error(t, err) + tk.MustQuery("select count(*) from t").Check(testkit.Rows("100")) + tk.MustExec("set @@tidb_read_consistency=strict") + + tk.MustExec("set autocommit=0") + err = tk.ExecToErr("split on a limit 10 delete from t") + require.Error(t, err) + tk.MustQuery("select count(*) from t").Check(testkit.Rows("100")) + tk.MustExec("set autocommit=1") + + tk.MustExec("begin") + err = tk.ExecToErr("split on a limit 10 delete from t") + require.Error(t, err) + tk.MustQuery("select count(*) from t").Check(testkit.Rows("100")) + tk.MustExec("commit") + + config.GetGlobalConfig().EnableBatchDML = true + tk.Session().GetSessionVars().BatchInsert = true + tk.Session().GetSessionVars().DMLBatchSize = 1 + err = tk.ExecToErr("split on a limit 10 delete from t") + require.Error(t, err) + tk.MustQuery("select count(*) from t").Check(testkit.Rows("100")) + config.GetGlobalConfig().EnableBatchDML = false + tk.Session().GetSessionVars().BatchInsert = false + tk.Session().GetSessionVars().DMLBatchSize = 0 + + tk.MustExec("create table t1(a int, b int, key(a))") + tk.MustExec("insert into t1 values (1, 1)") + err = tk.ExecToErr("split limit 1 delete t, t1 from t, t1 where t.a = t1.a") + require.Error(t, err) + tk.MustQuery("select count(*) from t").Check(testkit.Rows("100")) + tk.MustQuery("select count(*) from t1").Check(testkit.Rows("1")) +} + +func TestNonTransactionalDeleteOptimizerHints(t *testing.T) { + store, clean := createStorage(t) + defer clean() + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table t(a int, b int, key(a))") + for i := 0; i < 10; i++ { + tk.MustExec(fmt.Sprintf("insert into t values ('%d', %d)", i, i*2)) + } + result := tk.MustQuery("split on a limit 10 dry run delete /*+ USE_INDEX(t) */ from t").Rows()[0][0].(string) + require.Equal(t, result, "DELETE /*+ USE_INDEX(`t` )*/ FROM `test`.`t` WHERE `a` BETWEEN 0 AND 9") +}