diff --git a/executor/adapter.go b/executor/adapter.go index 0e9e18ed07776..cb773cea6d5aa 100644 --- a/executor/adapter.go +++ b/executor/adapter.go @@ -338,12 +338,11 @@ func (a *ExecStmt) Exec(ctx context.Context) (_ sqlexec.RecordSet, err error) { }() failpoint.Inject("assertStaleTSO", func(val failpoint.Value) { - if n, ok := val.(int); ok { + if n, ok := val.(int); ok && a.IsStaleness { startTS := oracle.ExtractPhysical(a.SnapshotTS) / 1000 if n != int(startTS) { panic(fmt.Sprintf("different tso %d != %d", n, startTS)) } - failpoint.Return() } }) sctx := a.Ctx diff --git a/executor/builder.go b/executor/builder.go index 2e40a3c088c5c..6e24478696f69 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -2608,6 +2608,7 @@ func (b *executorBuilder) buildIndexLookUpJoin(v *plannercore.PhysicalIndexJoin) for i, col := range v.OuterHashKeys { outerTypes[col.Index] = outerTypes[col.Index].Clone() outerTypes[col.Index].Collate = innerTypes[v.InnerHashKeys[i].Index].Collate + outerTypes[col.Index].Flag = col.RetType.Flag } var ( diff --git a/executor/index_lookup_join_test.go b/executor/index_lookup_join_test.go index 992f77e8dfb14..335b67843fb2c 100644 --- a/executor/index_lookup_join_test.go +++ b/executor/index_lookup_join_test.go @@ -349,6 +349,19 @@ func (s *testSuite5) TestIssue24547(c *C) { tk.MustExec("delete a from a inner join b on a.k1 = b.k1 and a.k2 = b.k2 where b.k2 <> '333'") } +func (s *testSuite5) TestIssue27893(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t1") + tk.MustExec("drop table if exists t2") + tk.MustExec("create table t1 (a enum('x','y'))") + tk.MustExec("create table t2 (a int, key(a))") + tk.MustExec("insert into t1 values('x')") + tk.MustExec("insert into t2 values(1)") + tk.MustQuery("select /*+ inl_join(t2) */ count(*) from t1 join t2 on t1.a = t2.a").Check(testkit.Rows("1")) + tk.MustQuery("select /*+ inl_hash_join(t2) */ count(*) from t1 join t2 on t1.a = 
t2.a").Check(testkit.Rows("1")) +} + func (s *testSuite5) TestPartitionTableIndexJoinAndIndexReader(c *C) { if israce.RaceEnabled { c.Skip("exhaustive types test, skip race test") diff --git a/executor/stale_txn_test.go b/executor/stale_txn_test.go index 622aac1905e9e..32ac8698fc879 100644 --- a/executor/stale_txn_test.go +++ b/executor/stale_txn_test.go @@ -1014,7 +1014,7 @@ func (s *testStaleTxnSerialSuite) TestStaleReadPrepare(c *C) { c.Assert("execute p1", NotNil) } -func (s *testStaleTxnSuite) TestStmtCtxStaleFlag(c *C) { +func (s *testStaleTxnSerialSuite) TestStmtCtxStaleFlag(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop table if exists t") @@ -1106,3 +1106,41 @@ func (s *testStaleTxnSuite) TestStmtCtxStaleFlag(c *C) { c.Assert(tk.Se.GetSessionVars().StmtCtx.IsStaleness, IsFalse) } } + +func (s *testStaleTxnSerialSuite) TestStaleSessionQuery(c *C) { + tk := testkit.NewTestKit(c, s.store) + // For mocktikv, safe point is not initialized, we manually insert it for snapshot to use. + safePointName := "tikv_gc_safe_point" + safePointValue := "20160102-15:04:05 -0700" + safePointComment := "All versions after safe point can be accessed. 
(DO NOT EDIT)" + updateSafePoint := fmt.Sprintf(`INSERT INTO mysql.tidb VALUES ('%[1]s', '%[2]s', '%[3]s') + ON DUPLICATE KEY + UPDATE variable_value = '%[2]s', comment = '%[3]s'`, safePointName, safePointValue, safePointComment) + tk.MustExec(updateSafePoint) + tk.MustExec("use test") + tk.MustExec("create table t10 (id int);") + tk.MustExec("insert into t10 (id) values (1)") + time.Sleep(2 * time.Second) + now := time.Now() + tk.MustExec(`set @@tidb_read_staleness="-1s"`) + // query will use stale read + c.Assert(failpoint.Enable("github.com/pingcap/tidb/expression/injectNow", fmt.Sprintf(`return(%d)`, now.Unix())), IsNil) + c.Assert(failpoint.Enable("github.com/pingcap/tidb/executor/assertStaleTSO", fmt.Sprintf(`return(%d)`, now.Unix()-1)), IsNil) + c.Assert(tk.MustQuery("select * from t10;").Rows(), HasLen, 1) + c.Assert(failpoint.Disable("github.com/pingcap/tidb/executor/assertStaleTSO"), IsNil) + c.Assert(failpoint.Disable("github.com/pingcap/tidb/expression/injectNow"), IsNil) + // begin transaction won't be affected by read staleness + tk.MustExec("begin") + tk.MustExec("insert into t10(id) values (2);") + tk.MustExec("commit") + tk.MustExec("insert into t10(id) values (3);") + // query will still use staleness read + c.Assert(failpoint.Enable("github.com/pingcap/tidb/expression/injectNow", fmt.Sprintf(`return(%d)`, now.Unix())), IsNil) + c.Assert(failpoint.Enable("github.com/pingcap/tidb/executor/assertStaleTSO", fmt.Sprintf(`return(%d)`, now.Unix()-1)), IsNil) + c.Assert(tk.MustQuery("select * from t10").Rows(), HasLen, 1) + c.Assert(failpoint.Disable("github.com/pingcap/tidb/executor/assertStaleTSO"), IsNil) + c.Assert(failpoint.Disable("github.com/pingcap/tidb/expression/injectNow"), IsNil) + // assert stale read is not exist after empty the variable + tk.MustExec(`set @@tidb_read_staleness=""`) + c.Assert(tk.MustQuery("select * from t10").Rows(), HasLen, 3) +} diff --git 
a/expression/aggregation/aggregation_test.go b/expression/aggregation/aggregation_test.go index 059405cae9e84..3985b2eb67b2a 100644 --- a/expression/aggregation/aggregation_test.go +++ b/expression/aggregation/aggregation_test.go @@ -16,8 +16,8 @@ package aggregation import ( "math" + "testing" - . "github.com/pingcap/check" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/expression" @@ -26,77 +26,77 @@ import ( "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/mock" + "github.com/stretchr/testify/require" ) -var _ = Suite(&testAggFuncSuit{}) - -type testAggFuncSuit struct { +type mockAggFuncSuite struct { ctx sessionctx.Context rows []chunk.Row nullRow chunk.Row } -func generateRowData() []chunk.Row { - rows := make([]chunk.Row, 0, 5050) +func createAggFuncSuite() (s *mockAggFuncSuite) { + s = new(mockAggFuncSuite) + s.ctx = mock.NewContext() + s.ctx.GetSessionVars().GlobalVarsAccessor = variable.NewMockGlobalAccessor() + s.rows = make([]chunk.Row, 0, 5050) for i := 1; i <= 100; i++ { for j := 0; j < i; j++ { - rows = append(rows, chunk.MutRowFromDatums(types.MakeDatums(i)).ToRow()) + s.rows = append(s.rows, chunk.MutRowFromDatums(types.MakeDatums(i)).ToRow()) } } - return rows -} - -func (s *testAggFuncSuit) SetUpSuite(c *C) { - s.ctx = mock.NewContext() - s.ctx.GetSessionVars().GlobalVarsAccessor = variable.NewMockGlobalAccessor() - s.rows = generateRowData() s.nullRow = chunk.MutRowFromDatums([]types.Datum{{}}).ToRow() + return } -func (s *testAggFuncSuit) TestAvg(c *C) { +func TestAvg(t *testing.T) { + t.Parallel() + s := createAggFuncSuite() col := &expression.Column{ Index: 0, RetType: types.NewFieldType(mysql.TypeLonglong), } ctx := mock.NewContext() desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncAvg, []expression.Expression{col}, false) - c.Assert(err, IsNil) + require.NoError(t, err) avgFunc := 
desc.GetAggFunc(ctx) evalCtx := avgFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) result := avgFunc.GetResult(evalCtx) - c.Assert(result.IsNull(), IsTrue) + require.True(t, result.IsNull()) for _, row := range s.rows { err := avgFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) - c.Assert(err, IsNil) + require.NoError(t, err) } result = avgFunc.GetResult(evalCtx) needed := types.NewDecFromStringForTest("67.000000000000000000000000000000") - c.Assert(result.GetMysqlDecimal().Compare(needed) == 0, IsTrue) + require.True(t, result.GetMysqlDecimal().Compare(needed) == 0) err = avgFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, s.nullRow) - c.Assert(err, IsNil) + require.NoError(t, err) result = avgFunc.GetResult(evalCtx) - c.Assert(result.GetMysqlDecimal().Compare(needed) == 0, IsTrue) + require.True(t, result.GetMysqlDecimal().Compare(needed) == 0) desc, err = NewAggFuncDesc(s.ctx, ast.AggFuncAvg, []expression.Expression{col}, true) - c.Assert(err, IsNil) + require.NoError(t, err) distinctAvgFunc := desc.GetAggFunc(ctx) evalCtx = distinctAvgFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) for _, row := range s.rows { err := distinctAvgFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) - c.Assert(err, IsNil) + require.NoError(t, err) } result = distinctAvgFunc.GetResult(evalCtx) needed = types.NewDecFromStringForTest("50.500000000000000000000000000000") - c.Assert(result.GetMysqlDecimal().Compare(needed) == 0, IsTrue) + require.True(t, result.GetMysqlDecimal().Compare(needed) == 0) partialResult := distinctAvgFunc.GetPartialResult(evalCtx) - c.Assert(partialResult[0].GetInt64(), Equals, int64(100)) + require.Equal(t, int64(100), partialResult[0].GetInt64()) needed = types.NewDecFromStringForTest("5050") - c.Assert(partialResult[1].GetMysqlDecimal().Compare(needed) == 0, IsTrue, Commentf("%v, %v ", result.GetMysqlDecimal(), needed)) + require.Equalf(t, 0, partialResult[1].GetMysqlDecimal().Compare(needed), "%v, %v ", result.GetMysqlDecimal(), 
needed) } -func (s *testAggFuncSuit) TestAvgFinalMode(c *C) { +func TestAvgFinalMode(t *testing.T) { + t.Parallel() + s := createAggFuncSuite() rows := make([][]types.Datum, 0, 100) for i := 1; i <= 100; i++ { rows = append(rows, types.MakeDatums(i, types.NewDecFromInt(int64(i*i)))) @@ -111,344 +111,356 @@ func (s *testAggFuncSuit) TestAvgFinalMode(c *C) { RetType: types.NewFieldType(mysql.TypeNewDecimal), } aggFunc, err := NewAggFuncDesc(s.ctx, ast.AggFuncAvg, []expression.Expression{cntCol, sumCol}, false) - c.Assert(err, IsNil) + require.NoError(t, err) aggFunc.Mode = FinalMode avgFunc := aggFunc.GetAggFunc(ctx) evalCtx := avgFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) for _, row := range rows { err := avgFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, chunk.MutRowFromDatums(row).ToRow()) - c.Assert(err, IsNil) + require.NoError(t, err) } result := avgFunc.GetResult(evalCtx) needed := types.NewDecFromStringForTest("67.000000000000000000000000000000") - c.Assert(result.GetMysqlDecimal().Compare(needed) == 0, IsTrue) + require.True(t, result.GetMysqlDecimal().Compare(needed) == 0) } -func (s *testAggFuncSuit) TestSum(c *C) { +func TestSum(t *testing.T) { + t.Parallel() + s := createAggFuncSuite() col := &expression.Column{ Index: 0, RetType: types.NewFieldType(mysql.TypeLonglong), } ctx := mock.NewContext() desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncSum, []expression.Expression{col}, false) - c.Assert(err, IsNil) + require.NoError(t, err) sumFunc := desc.GetAggFunc(ctx) evalCtx := sumFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) result := sumFunc.GetResult(evalCtx) - c.Assert(result.IsNull(), IsTrue) + require.True(t, result.IsNull()) for _, row := range s.rows { err := sumFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) - c.Assert(err, IsNil) + require.NoError(t, err) } result = sumFunc.GetResult(evalCtx) needed := types.NewDecFromStringForTest("338350") - c.Assert(result.GetMysqlDecimal().Compare(needed) == 0, IsTrue) + 
require.True(t, result.GetMysqlDecimal().Compare(needed) == 0) err = sumFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, s.nullRow) - c.Assert(err, IsNil) + require.NoError(t, err) result = sumFunc.GetResult(evalCtx) - c.Assert(result.GetMysqlDecimal().Compare(needed) == 0, IsTrue) + require.True(t, result.GetMysqlDecimal().Compare(needed) == 0) partialResult := sumFunc.GetPartialResult(evalCtx) - c.Assert(partialResult[0].GetMysqlDecimal().Compare(needed) == 0, IsTrue) + require.True(t, partialResult[0].GetMysqlDecimal().Compare(needed) == 0) desc, err = NewAggFuncDesc(s.ctx, ast.AggFuncSum, []expression.Expression{col}, true) - c.Assert(err, IsNil) + require.NoError(t, err) distinctSumFunc := desc.GetAggFunc(ctx) evalCtx = distinctSumFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) for _, row := range s.rows { err := distinctSumFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) - c.Assert(err, IsNil) + require.NoError(t, err) } result = distinctSumFunc.GetResult(evalCtx) needed = types.NewDecFromStringForTest("5050") - c.Assert(result.GetMysqlDecimal().Compare(needed) == 0, IsTrue) + require.True(t, result.GetMysqlDecimal().Compare(needed) == 0) } -func (s *testAggFuncSuit) TestBitAnd(c *C) { +func TestBitAnd(t *testing.T) { + t.Parallel() + s := createAggFuncSuite() col := &expression.Column{ Index: 0, RetType: types.NewFieldType(mysql.TypeLonglong), } ctx := mock.NewContext() desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncBitAnd, []expression.Expression{col}, false) - c.Assert(err, IsNil) + require.NoError(t, err) bitAndFunc := desc.GetAggFunc(ctx) evalCtx := bitAndFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) result := bitAndFunc.GetResult(evalCtx) - c.Assert(result.GetUint64(), Equals, uint64(math.MaxUint64)) + require.Equal(t, uint64(math.MaxUint64), result.GetUint64()) row := chunk.MutRowFromDatums(types.MakeDatums(1)).ToRow() err = bitAndFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) - c.Assert(err, IsNil) + require.NoError(t, 
err) result = bitAndFunc.GetResult(evalCtx) - c.Assert(result.GetUint64(), Equals, uint64(1)) + require.Equal(t, uint64(1), result.GetUint64()) err = bitAndFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, s.nullRow) - c.Assert(err, IsNil) + require.NoError(t, err) result = bitAndFunc.GetResult(evalCtx) - c.Assert(result.GetUint64(), Equals, uint64(1)) + require.Equal(t, uint64(1), result.GetUint64()) row = chunk.MutRowFromDatums(types.MakeDatums(1)).ToRow() err = bitAndFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) - c.Assert(err, IsNil) + require.NoError(t, err) result = bitAndFunc.GetResult(evalCtx) - c.Assert(result.GetUint64(), Equals, uint64(1)) + require.Equal(t, uint64(1), result.GetUint64()) row = chunk.MutRowFromDatums(types.MakeDatums(3)).ToRow() err = bitAndFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) - c.Assert(err, IsNil) + require.NoError(t, err) result = bitAndFunc.GetResult(evalCtx) - c.Assert(result.GetUint64(), Equals, uint64(1)) + require.Equal(t, uint64(1), result.GetUint64()) row = chunk.MutRowFromDatums(types.MakeDatums(2)).ToRow() err = bitAndFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) - c.Assert(err, IsNil) + require.NoError(t, err) result = bitAndFunc.GetResult(evalCtx) - c.Assert(result.GetUint64(), Equals, uint64(0)) + require.Equal(t, uint64(0), result.GetUint64()) partialResult := bitAndFunc.GetPartialResult(evalCtx) - c.Assert(partialResult[0].GetUint64(), Equals, uint64(0)) + require.Equal(t, uint64(0), partialResult[0].GetUint64()) // test bit_and( decimal ) col.RetType = types.NewFieldType(mysql.TypeNewDecimal) bitAndFunc.ResetContext(s.ctx.GetSessionVars().StmtCtx, evalCtx) result = bitAndFunc.GetResult(evalCtx) - c.Assert(result.GetUint64(), Equals, uint64(math.MaxUint64)) + require.Equal(t, uint64(math.MaxUint64), result.GetUint64()) var dec types.MyDecimal err = dec.FromString([]byte("1.234")) - c.Assert(err, IsNil) + require.NoError(t, err) row = 
chunk.MutRowFromDatums(types.MakeDatums(&dec)).ToRow() err = bitAndFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) - c.Assert(err, IsNil) + require.NoError(t, err) result = bitAndFunc.GetResult(evalCtx) - c.Assert(result.GetUint64(), Equals, uint64(1)) + require.Equal(t, uint64(1), result.GetUint64()) err = dec.FromString([]byte("3.012")) - c.Assert(err, IsNil) + require.NoError(t, err) row = chunk.MutRowFromDatums(types.MakeDatums(&dec)).ToRow() err = bitAndFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) - c.Assert(err, IsNil) + require.NoError(t, err) result = bitAndFunc.GetResult(evalCtx) - c.Assert(result.GetUint64(), Equals, uint64(1)) + require.Equal(t, uint64(1), result.GetUint64()) err = dec.FromString([]byte("2.12345678")) - c.Assert(err, IsNil) + require.NoError(t, err) row = chunk.MutRowFromDatums(types.MakeDatums(&dec)).ToRow() err = bitAndFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) - c.Assert(err, IsNil) + require.NoError(t, err) result = bitAndFunc.GetResult(evalCtx) - c.Assert(result.GetUint64(), Equals, uint64(0)) + require.Equal(t, uint64(0), result.GetUint64()) } -func (s *testAggFuncSuit) TestBitOr(c *C) { +func TestBitOr(t *testing.T) { + t.Parallel() + s := createAggFuncSuite() col := &expression.Column{ Index: 0, RetType: types.NewFieldType(mysql.TypeLonglong), } ctx := mock.NewContext() desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncBitOr, []expression.Expression{col}, false) - c.Assert(err, IsNil) + require.NoError(t, err) bitOrFunc := desc.GetAggFunc(ctx) evalCtx := bitOrFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) result := bitOrFunc.GetResult(evalCtx) - c.Assert(result.GetUint64(), Equals, uint64(0)) + require.Equal(t, uint64(0), result.GetUint64()) row := chunk.MutRowFromDatums(types.MakeDatums(1)).ToRow() err = bitOrFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) - c.Assert(err, IsNil) + require.NoError(t, err) result = bitOrFunc.GetResult(evalCtx) - c.Assert(result.GetUint64(), Equals, 
uint64(1)) + require.Equal(t, uint64(1), result.GetUint64()) err = bitOrFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, s.nullRow) - c.Assert(err, IsNil) + require.NoError(t, err) result = bitOrFunc.GetResult(evalCtx) - c.Assert(result.GetUint64(), Equals, uint64(1)) + require.Equal(t, uint64(1), result.GetUint64()) row = chunk.MutRowFromDatums(types.MakeDatums(1)).ToRow() err = bitOrFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) - c.Assert(err, IsNil) + require.NoError(t, err) result = bitOrFunc.GetResult(evalCtx) - c.Assert(result.GetUint64(), Equals, uint64(1)) + require.Equal(t, uint64(1), result.GetUint64()) row = chunk.MutRowFromDatums(types.MakeDatums(3)).ToRow() err = bitOrFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) - c.Assert(err, IsNil) + require.NoError(t, err) result = bitOrFunc.GetResult(evalCtx) - c.Assert(result.GetUint64(), Equals, uint64(3)) + require.Equal(t, uint64(3), result.GetUint64()) row = chunk.MutRowFromDatums(types.MakeDatums(2)).ToRow() err = bitOrFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) - c.Assert(err, IsNil) + require.NoError(t, err) result = bitOrFunc.GetResult(evalCtx) - c.Assert(result.GetUint64(), Equals, uint64(3)) + require.Equal(t, uint64(3), result.GetUint64()) partialResult := bitOrFunc.GetPartialResult(evalCtx) - c.Assert(partialResult[0].GetUint64(), Equals, uint64(3)) + require.Equal(t, uint64(3), partialResult[0].GetUint64()) // test bit_or( decimal ) col.RetType = types.NewFieldType(mysql.TypeNewDecimal) bitOrFunc.ResetContext(s.ctx.GetSessionVars().StmtCtx, evalCtx) result = bitOrFunc.GetResult(evalCtx) - c.Assert(result.GetUint64(), Equals, uint64(0)) + require.Equal(t, uint64(0), result.GetUint64()) var dec types.MyDecimal err = dec.FromString([]byte("12.234")) - c.Assert(err, IsNil) + require.NoError(t, err) row = chunk.MutRowFromDatums(types.MakeDatums(&dec)).ToRow() err = bitOrFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) - c.Assert(err, IsNil) + 
require.NoError(t, err) result = bitOrFunc.GetResult(evalCtx) - c.Assert(result.GetUint64(), Equals, uint64(12)) + require.Equal(t, uint64(12), result.GetUint64()) err = dec.FromString([]byte("1.012")) - c.Assert(err, IsNil) + require.NoError(t, err) row = chunk.MutRowFromDatums(types.MakeDatums(&dec)).ToRow() err = bitOrFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) - c.Assert(err, IsNil) + require.NoError(t, err) result = bitOrFunc.GetResult(evalCtx) - c.Assert(result.GetUint64(), Equals, uint64(13)) + require.Equal(t, uint64(13), result.GetUint64()) err = dec.FromString([]byte("15.12345678")) - c.Assert(err, IsNil) + require.NoError(t, err) row = chunk.MutRowFromDatums(types.MakeDatums(&dec)).ToRow() err = bitOrFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) - c.Assert(err, IsNil) + require.NoError(t, err) result = bitOrFunc.GetResult(evalCtx) - c.Assert(result.GetUint64(), Equals, uint64(15)) + require.Equal(t, uint64(15), result.GetUint64()) err = dec.FromString([]byte("16.00")) - c.Assert(err, IsNil) + require.NoError(t, err) row = chunk.MutRowFromDatums(types.MakeDatums(&dec)).ToRow() err = bitOrFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) - c.Assert(err, IsNil) + require.NoError(t, err) result = bitOrFunc.GetResult(evalCtx) - c.Assert(result.GetUint64(), Equals, uint64(31)) + require.Equal(t, uint64(31), result.GetUint64()) } -func (s *testAggFuncSuit) TestBitXor(c *C) { +func TestBitXor(t *testing.T) { + t.Parallel() + s := createAggFuncSuite() col := &expression.Column{ Index: 0, RetType: types.NewFieldType(mysql.TypeLonglong), } ctx := mock.NewContext() desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncBitXor, []expression.Expression{col}, false) - c.Assert(err, IsNil) + require.NoError(t, err) bitXorFunc := desc.GetAggFunc(ctx) evalCtx := bitXorFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) result := bitXorFunc.GetResult(evalCtx) - c.Assert(result.GetUint64(), Equals, uint64(0)) + require.Equal(t, uint64(0), 
result.GetUint64()) row := chunk.MutRowFromDatums(types.MakeDatums(1)).ToRow() err = bitXorFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) - c.Assert(err, IsNil) + require.NoError(t, err) result = bitXorFunc.GetResult(evalCtx) - c.Assert(result.GetUint64(), Equals, uint64(1)) + require.Equal(t, uint64(1), result.GetUint64()) err = bitXorFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, s.nullRow) - c.Assert(err, IsNil) + require.NoError(t, err) result = bitXorFunc.GetResult(evalCtx) - c.Assert(result.GetUint64(), Equals, uint64(1)) + require.Equal(t, uint64(1), result.GetUint64()) row = chunk.MutRowFromDatums(types.MakeDatums(1)).ToRow() err = bitXorFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) - c.Assert(err, IsNil) + require.NoError(t, err) result = bitXorFunc.GetResult(evalCtx) - c.Assert(result.GetUint64(), Equals, uint64(0)) + require.Equal(t, uint64(0), result.GetUint64()) row = chunk.MutRowFromDatums(types.MakeDatums(3)).ToRow() err = bitXorFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) - c.Assert(err, IsNil) + require.NoError(t, err) result = bitXorFunc.GetResult(evalCtx) - c.Assert(result.GetUint64(), Equals, uint64(3)) + require.Equal(t, uint64(3), result.GetUint64()) row = chunk.MutRowFromDatums(types.MakeDatums(2)).ToRow() err = bitXorFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) - c.Assert(err, IsNil) + require.NoError(t, err) result = bitXorFunc.GetResult(evalCtx) - c.Assert(result.GetUint64(), Equals, uint64(1)) + require.Equal(t, uint64(1), result.GetUint64()) partialResult := bitXorFunc.GetPartialResult(evalCtx) - c.Assert(partialResult[0].GetUint64(), Equals, uint64(1)) + require.Equal(t, uint64(1), partialResult[0].GetUint64()) // test bit_xor( decimal ) col.RetType = types.NewFieldType(mysql.TypeNewDecimal) bitXorFunc.ResetContext(s.ctx.GetSessionVars().StmtCtx, evalCtx) result = bitXorFunc.GetResult(evalCtx) - c.Assert(result.GetUint64(), Equals, uint64(0)) + require.Equal(t, uint64(0), 
result.GetUint64()) var dec types.MyDecimal err = dec.FromString([]byte("1.234")) - c.Assert(err, IsNil) + require.NoError(t, err) row = chunk.MutRowFromDatums(types.MakeDatums(&dec)).ToRow() err = bitXorFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) - c.Assert(err, IsNil) + require.NoError(t, err) result = bitXorFunc.GetResult(evalCtx) - c.Assert(result.GetUint64(), Equals, uint64(1)) + require.Equal(t, uint64(1), result.GetUint64()) err = dec.FromString([]byte("1.012")) - c.Assert(err, IsNil) + require.NoError(t, err) row = chunk.MutRowFromDatums(types.MakeDatums(&dec)).ToRow() err = bitXorFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) - c.Assert(err, IsNil) + require.NoError(t, err) result = bitXorFunc.GetResult(evalCtx) - c.Assert(result.GetUint64(), Equals, uint64(0)) + require.Equal(t, uint64(0), result.GetUint64()) err = dec.FromString([]byte("2.12345678")) - c.Assert(err, IsNil) + require.NoError(t, err) row = chunk.MutRowFromDatums(types.MakeDatums(&dec)).ToRow() err = bitXorFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) - c.Assert(err, IsNil) + require.NoError(t, err) result = bitXorFunc.GetResult(evalCtx) - c.Assert(result.GetUint64(), Equals, uint64(2)) + require.Equal(t, uint64(2), result.GetUint64()) } -func (s *testAggFuncSuit) TestCount(c *C) { +func TestCount(t *testing.T) { + t.Parallel() + s := createAggFuncSuite() col := &expression.Column{ Index: 0, RetType: types.NewFieldType(mysql.TypeLonglong), } ctx := mock.NewContext() desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncCount, []expression.Expression{col}, false) - c.Assert(err, IsNil) + require.NoError(t, err) countFunc := desc.GetAggFunc(ctx) evalCtx := countFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) result := countFunc.GetResult(evalCtx) - c.Assert(result.GetInt64(), Equals, int64(0)) + require.Equal(t, int64(0), result.GetInt64()) for _, row := range s.rows { err := countFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) - c.Assert(err, IsNil) 
+ require.NoError(t, err) } result = countFunc.GetResult(evalCtx) - c.Assert(result.GetInt64(), Equals, int64(5050)) + require.Equal(t, int64(5050), result.GetInt64()) err = countFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, s.nullRow) - c.Assert(err, IsNil) + require.NoError(t, err) result = countFunc.GetResult(evalCtx) - c.Assert(result.GetInt64(), Equals, int64(5050)) + require.Equal(t, int64(5050), result.GetInt64()) partialResult := countFunc.GetPartialResult(evalCtx) - c.Assert(partialResult[0].GetInt64(), Equals, int64(5050)) + require.Equal(t, int64(5050), partialResult[0].GetInt64()) desc, err = NewAggFuncDesc(s.ctx, ast.AggFuncCount, []expression.Expression{col}, true) - c.Assert(err, IsNil) + require.NoError(t, err) distinctCountFunc := desc.GetAggFunc(ctx) evalCtx = distinctCountFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) for _, row := range s.rows { err := distinctCountFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) - c.Assert(err, IsNil) + require.NoError(t, err) } result = distinctCountFunc.GetResult(evalCtx) - c.Assert(result.GetInt64(), Equals, int64(100)) + require.Equal(t, int64(100), result.GetInt64()) } -func (s *testAggFuncSuit) TestConcat(c *C) { +func TestConcat(t *testing.T) { + t.Parallel() + s := createAggFuncSuite() col := &expression.Column{ Index: 0, RetType: types.NewFieldType(mysql.TypeLonglong), @@ -459,52 +471,54 @@ func (s *testAggFuncSuit) TestConcat(c *C) { } ctx := mock.NewContext() desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncGroupConcat, []expression.Expression{col, sep}, false) - c.Assert(err, IsNil) + require.NoError(t, err) concatFunc := desc.GetAggFunc(ctx) evalCtx := concatFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) result := concatFunc.GetResult(evalCtx) - c.Assert(result.IsNull(), IsTrue) + require.True(t, result.IsNull()) row := chunk.MutRowFromDatums(types.MakeDatums(1, "x")) err = concatFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row.ToRow()) - c.Assert(err, IsNil) + 
require.NoError(t, err) result = concatFunc.GetResult(evalCtx) - c.Assert(result.GetString(), Equals, "1") + require.Equal(t, "1", result.GetString()) row.SetDatum(0, types.NewIntDatum(2)) err = concatFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row.ToRow()) - c.Assert(err, IsNil) + require.NoError(t, err) result = concatFunc.GetResult(evalCtx) - c.Assert(result.GetString(), Equals, "1x2") + require.Equal(t, "1x2", result.GetString()) row.SetDatum(0, types.NewDatum(nil)) err = concatFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row.ToRow()) - c.Assert(err, IsNil) + require.NoError(t, err) result = concatFunc.GetResult(evalCtx) - c.Assert(result.GetString(), Equals, "1x2") + require.Equal(t, "1x2", result.GetString()) partialResult := concatFunc.GetPartialResult(evalCtx) - c.Assert(partialResult[0].GetString(), Equals, "1x2") + require.Equal(t, "1x2", partialResult[0].GetString()) desc, err = NewAggFuncDesc(s.ctx, ast.AggFuncGroupConcat, []expression.Expression{col, sep}, true) - c.Assert(err, IsNil) + require.NoError(t, err) distinctConcatFunc := desc.GetAggFunc(ctx) evalCtx = distinctConcatFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) row.SetDatum(0, types.NewIntDatum(1)) err = distinctConcatFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row.ToRow()) - c.Assert(err, IsNil) + require.NoError(t, err) result = distinctConcatFunc.GetResult(evalCtx) - c.Assert(result.GetString(), Equals, "1") + require.Equal(t, "1", result.GetString()) row.SetDatum(0, types.NewIntDatum(1)) err = distinctConcatFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row.ToRow()) - c.Assert(err, IsNil) + require.NoError(t, err) result = distinctConcatFunc.GetResult(evalCtx) - c.Assert(result.GetString(), Equals, "1") + require.Equal(t, "1", result.GetString()) } -func (s *testAggFuncSuit) TestFirstRow(c *C) { +func TestFirstRow(t *testing.T) { + t.Parallel() + s := createAggFuncSuite() col := &expression.Column{ Index: 0, RetType: 
types.NewFieldType(mysql.TypeLonglong), @@ -512,26 +526,28 @@ func (s *testAggFuncSuit) TestFirstRow(c *C) { ctx := mock.NewContext() desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncFirstRow, []expression.Expression{col}, false) - c.Assert(err, IsNil) + require.NoError(t, err) firstRowFunc := desc.GetAggFunc(ctx) evalCtx := firstRowFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) row := chunk.MutRowFromDatums(types.MakeDatums(1)).ToRow() err = firstRowFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) - c.Assert(err, IsNil) + require.NoError(t, err) result := firstRowFunc.GetResult(evalCtx) - c.Assert(result.GetUint64(), Equals, uint64(1)) + require.Equal(t, uint64(1), result.GetUint64()) row = chunk.MutRowFromDatums(types.MakeDatums(2)).ToRow() err = firstRowFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) - c.Assert(err, IsNil) + require.NoError(t, err) result = firstRowFunc.GetResult(evalCtx) - c.Assert(result.GetUint64(), Equals, uint64(1)) + require.Equal(t, uint64(1), result.GetUint64()) partialResult := firstRowFunc.GetPartialResult(evalCtx) - c.Assert(partialResult[0].GetUint64(), Equals, uint64(1)) + require.Equal(t, uint64(1), partialResult[0].GetUint64()) } -func (s *testAggFuncSuit) TestMaxMin(c *C) { +func TestMaxMin(t *testing.T) { + t.Parallel() + s := createAggFuncSuite() col := &expression.Column{ Index: 0, RetType: types.NewFieldType(mysql.TypeLonglong), @@ -539,58 +555,58 @@ func (s *testAggFuncSuit) TestMaxMin(c *C) { ctx := mock.NewContext() desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncMax, []expression.Expression{col}, false) - c.Assert(err, IsNil) + require.NoError(t, err) maxFunc := desc.GetAggFunc(ctx) desc, err = NewAggFuncDesc(s.ctx, ast.AggFuncMin, []expression.Expression{col}, false) - c.Assert(err, IsNil) + require.NoError(t, err) minFunc := desc.GetAggFunc(ctx) maxEvalCtx := maxFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) minEvalCtx := minFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) result := 
maxFunc.GetResult(maxEvalCtx) - c.Assert(result.IsNull(), IsTrue) + require.True(t, result.IsNull()) result = minFunc.GetResult(minEvalCtx) - c.Assert(result.IsNull(), IsTrue) + require.True(t, result.IsNull()) row := chunk.MutRowFromDatums(types.MakeDatums(2)) err = maxFunc.Update(maxEvalCtx, s.ctx.GetSessionVars().StmtCtx, row.ToRow()) - c.Assert(err, IsNil) + require.NoError(t, err) result = maxFunc.GetResult(maxEvalCtx) - c.Assert(result.GetInt64(), Equals, int64(2)) + require.Equal(t, int64(2), result.GetInt64()) err = minFunc.Update(minEvalCtx, s.ctx.GetSessionVars().StmtCtx, row.ToRow()) - c.Assert(err, IsNil) + require.NoError(t, err) result = minFunc.GetResult(minEvalCtx) - c.Assert(result.GetInt64(), Equals, int64(2)) + require.Equal(t, int64(2), result.GetInt64()) row.SetDatum(0, types.NewIntDatum(3)) err = maxFunc.Update(maxEvalCtx, s.ctx.GetSessionVars().StmtCtx, row.ToRow()) - c.Assert(err, IsNil) + require.NoError(t, err) result = maxFunc.GetResult(maxEvalCtx) - c.Assert(result.GetInt64(), Equals, int64(3)) + require.Equal(t, int64(3), result.GetInt64()) err = minFunc.Update(minEvalCtx, s.ctx.GetSessionVars().StmtCtx, row.ToRow()) - c.Assert(err, IsNil) + require.NoError(t, err) result = minFunc.GetResult(minEvalCtx) - c.Assert(result.GetInt64(), Equals, int64(2)) + require.Equal(t, int64(2), result.GetInt64()) row.SetDatum(0, types.NewIntDatum(1)) err = maxFunc.Update(maxEvalCtx, s.ctx.GetSessionVars().StmtCtx, row.ToRow()) - c.Assert(err, IsNil) + require.NoError(t, err) result = maxFunc.GetResult(maxEvalCtx) - c.Assert(result.GetInt64(), Equals, int64(3)) + require.Equal(t, int64(3), result.GetInt64()) err = minFunc.Update(minEvalCtx, s.ctx.GetSessionVars().StmtCtx, row.ToRow()) - c.Assert(err, IsNil) + require.NoError(t, err) result = minFunc.GetResult(minEvalCtx) - c.Assert(result.GetInt64(), Equals, int64(1)) + require.Equal(t, int64(1), result.GetInt64()) row.SetDatum(0, types.NewDatum(nil)) err = maxFunc.Update(maxEvalCtx, 
s.ctx.GetSessionVars().StmtCtx, row.ToRow()) - c.Assert(err, IsNil) + require.NoError(t, err) result = maxFunc.GetResult(maxEvalCtx) - c.Assert(result.GetInt64(), Equals, int64(3)) + require.Equal(t, int64(3), result.GetInt64()) err = minFunc.Update(minEvalCtx, s.ctx.GetSessionVars().StmtCtx, row.ToRow()) - c.Assert(err, IsNil) + require.NoError(t, err) result = minFunc.GetResult(minEvalCtx) - c.Assert(result.GetInt64(), Equals, int64(1)) + require.Equal(t, int64(1), result.GetInt64()) partialResult := minFunc.GetPartialResult(minEvalCtx) - c.Assert(partialResult[0].GetInt64(), Equals, int64(1)) + require.Equal(t, int64(1), partialResult[0].GetInt64()) } diff --git a/expression/builtin_time.go b/expression/builtin_time.go index de2cb3e516fc0..b877fb6cd5d23 100644 --- a/expression/builtin_time.go +++ b/expression/builtin_time.go @@ -2537,6 +2537,16 @@ func (c *nowFunctionClass) getFunction(ctx sessionctx.Context, args []Expression return sig, nil } +// GetStmtTimestamp directly calls getTimeZone with timezone +func GetStmtTimestamp(ctx sessionctx.Context) (time.Time, error) { + tz := getTimeZone(ctx) + tVal, err := getStmtTimestamp(ctx) + if err != nil { + return tVal, err + } + return tVal.In(tz), nil +} + func evalNowWithFsp(ctx sessionctx.Context, fsp int8) (types.Time, bool, error) { nowTs, err := getStmtTimestamp(ctx) if err != nil { @@ -7205,6 +7215,11 @@ func (b *builtinTiDBBoundedStalenessSig) evalTime(row chunk.Row) (types.Time, bo return types.NewTime(types.FromGoTime(calAppropriateTime(minTime, maxTime, getMinSafeTime(b.ctx, timeZone))), mysql.TypeDatetime, 3), false, nil } +// GetMinSafeTime get minSafeTime +func GetMinSafeTime(sessionCtx sessionctx.Context) time.Time { + return getMinSafeTime(sessionCtx, getTimeZone(sessionCtx)) +} + func getMinSafeTime(sessionCtx sessionctx.Context, timeZone *time.Location) time.Time { var minSafeTS uint64 txnScope := config.GetTxnScopeFromConfig() @@ -7222,6 +7237,11 @@ func getMinSafeTime(sessionCtx 
sessionctx.Context, timeZone *time.Location) time return oracle.GetTimeFromTS(minSafeTS).In(timeZone) } +// CalAppropriateTime directly calls calAppropriateTime +func CalAppropriateTime(minTime, maxTime, minSafeTime time.Time) time.Time { + return calAppropriateTime(minTime, maxTime, minSafeTime) +} + // For a SafeTS t and a time range [t1, t2]: // 1. If t < t1, we will use t1 as the result, // and with it, a read request may fail because it's an unreached SafeTS. diff --git a/planner/core/planbuilder.go b/planner/core/planbuilder.go index 5a9410f3a1a86..979cb0f6c83d9 100644 --- a/planner/core/planbuilder.go +++ b/planner/core/planbuilder.go @@ -2571,6 +2571,16 @@ func calculateTsExpr(sctx sessionctx.Context, asOfClause *ast.AsOfClause) (uint6 return oracle.GoTimeToTS(tsTime), nil } +func calculateTsWithReadStaleness(sctx sessionctx.Context, readStaleness time.Duration) (uint64, error) { + nowVal, err := expression.GetStmtTimestamp(sctx) + if err != nil { + return 0, err + } + tsVal := nowVal.Add(readStaleness) + minTsVal := expression.GetMinSafeTime(sctx) + return oracle.GoTimeToTS(expression.CalAppropriateTime(tsVal, nowVal, minTsVal)), nil +} + func collectVisitInfoFromRevokeStmt(sctx sessionctx.Context, vi []visitInfo, stmt *ast.RevokeStmt) ([]visitInfo, error) { // To use REVOKE, you must have the GRANT OPTION privilege, // and you must have the privileges that you are granting. diff --git a/planner/core/preprocess.go b/planner/core/preprocess.go index 73b3f052b6530..63126b612a35e 100644 --- a/planner/core/preprocess.go +++ b/planner/core/preprocess.go @@ -1540,15 +1540,16 @@ func (p *preprocessor) handleAsOfAndReadTS(node *ast.AsOfClause) { // If the statement is in auto-commit mode, we will check whether there exists read_ts, if exists, // we will directly use it. The txnScope will be defined by the zone label, if it is not set, we will use // global txnScope directly. 
- ts := p.ctx.GetSessionVars().TxnReadTS.UseTxnReadTS() - if ts > 0 { + readTS := p.ctx.GetSessionVars().TxnReadTS.UseTxnReadTS() + readStaleness := p.ctx.GetSessionVars().ReadStaleness + var ts uint64 + switch { + case readTS > 0: + ts = readTS if node != nil { p.err = ErrAsOf.FastGenWithCause("can't use select as of while already set transaction as of") return } - // it means we meet following case: - // 1. set transaction read only as of timestamp ts - // 2. select statement if !p.initedLastSnapshotTS { p.SnapshotTSEvaluator = func(sessionctx.Context) (uint64, error) { return ts, nil @@ -1556,8 +1557,11 @@ func (p *preprocessor) handleAsOfAndReadTS(node *ast.AsOfClause) { p.LastSnapshotTS = ts p.setStalenessReturn() } - } - if node != nil { + case readTS == 0 && node != nil: + // If we didn't use read_ts, and node isn't nil, it means we use 'select table as of timestamp ... ' + // for stale read + // It means we meet following case: + // select statement with as of timestamp ts, p.err = calculateTsExpr(p.ctx, node) if p.err != nil { return @@ -1566,8 +1570,6 @@ func (p *preprocessor) handleAsOfAndReadTS(node *ast.AsOfClause) { p.err = errors.Trace(err) return } - // It means we meet following case: - // select statement with as of timestamp if !p.initedLastSnapshotTS { p.SnapshotTSEvaluator = func(ctx sessionctx.Context) (uint64, error) { return calculateTsExpr(ctx, node) @@ -1575,6 +1577,22 @@ func (p *preprocessor) handleAsOfAndReadTS(node *ast.AsOfClause) { p.LastSnapshotTS = ts p.setStalenessReturn() } + case readTS == 0 && node == nil && readStaleness != 0: + ts, p.err = calculateTsWithReadStaleness(p.ctx, readStaleness) + if p.err != nil { + return + } + if err := sessionctx.ValidateStaleReadTS(context.Background(), p.ctx, ts); err != nil { + p.err = errors.Trace(err) + return + } + if !p.initedLastSnapshotTS { + p.SnapshotTSEvaluator = func(ctx sessionctx.Context) (uint64, error) { + return calculateTsWithReadStaleness(p.ctx, readStaleness) + } + 
p.LastSnapshotTS = ts + p.setStalenessReturn() + } } if p.LastSnapshotTS != ts { p.err = ErrAsOf.GenWithStack("can not set different time in the as of") diff --git a/server/server.go b/server/server.go index fa0541054b8ba..617adfe1c36bc 100644 --- a/server/server.go +++ b/server/server.go @@ -249,6 +249,12 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) { } if s.cfg.Socket != "" { + + err := cleanupStaleSocket(s.cfg.Socket) + if err != nil { + return nil, errors.Trace(err) + } + if s.socket, err = net.Listen("unix", s.cfg.Socket); err != nil { return nil, errors.Trace(err) } @@ -295,6 +301,30 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) { return s, nil } +func cleanupStaleSocket(socket string) error { + sockStat, err := os.Stat(socket) + if err == nil { + if sockStat.Mode().Type() != os.ModeSocket { + return fmt.Errorf( + "the specified socket file %s is a %s instead of a socket file", + socket, sockStat.Mode().String()) + } + + _, err = net.Dial("unix", socket) + if err != nil { + logutil.BgLogger().Warn("Unix socket exists and is nonfunctional, removing it", + zap.String("socket", socket), zap.Error(err)) + err = os.Remove(socket) + if err != nil { + return fmt.Errorf("failed to remove socket file %s", socket) + } + } else { + return fmt.Errorf("unix socket %s exists and is functional, not removing it", socket) + } + } + return nil +} + func setSSLVariable(ca, key, cert string) { variable.SetSysVar("have_openssl", "YES") variable.SetSysVar("have_ssl", "YES") diff --git a/server/statistics_handler_serial_test.go b/server/statistics_handler_serial_test.go new file mode 100644 index 0000000000000..7c56cf2186831 --- /dev/null +++ b/server/statistics_handler_serial_test.go @@ -0,0 +1,231 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "database/sql" + "fmt" + "io" + "os" + "testing" + "time" + + "github.com/go-sql-driver/mysql" + "github.com/gorilla/mux" + "github.com/pingcap/tidb/session" + "github.com/pingcap/tidb/statistics/handle" + "github.com/pingcap/tidb/testkit" + "github.com/stretchr/testify/require" +) + +func TestDumpStatsAPI(t *testing.T) { + store, clean := testkit.CreateMockStore(t) + defer clean() + + driver := NewTiDBDriver(store) + client := newTestServerClient() + cfg := newTestConfig() + cfg.Port = client.port + cfg.Status.StatusPort = client.statusPort + cfg.Status.ReportStatus = true + + server, err := NewServer(cfg, driver) + require.NoError(t, err) + defer server.Close() + + client.port = getPortFromTCPAddr(server.listener.Addr()) + client.statusPort = getPortFromTCPAddr(server.statusListener.Addr()) + go func() { + err := server.Run() + require.NoError(t, err) + }() + client.waitUntilServerOnline() + + dom, err := session.GetDomain(store) + require.NoError(t, err) + statsHandler := &StatsHandler{dom} + + prepareData(t, client, statsHandler) + + router := mux.NewRouter() + router.Handle("/stats/dump/{db}/{table}", statsHandler) + + resp0, err := client.fetchStatus("/stats/dump/tidb/test") + require.NoError(t, err) + defer func() { + require.NoError(t, resp0.Body.Close()) + }() + + path := "/tmp/stats.json" + fp, err := os.Create(path) + require.NoError(t, err) + require.NotNil(t, fp) + defer func() { + require.NoError(t, fp.Close()) + require.NoError(t, os.Remove(path)) + 
}() + + js, err := io.ReadAll(resp0.Body) + require.NoError(t, err) + _, err = fp.Write(js) + require.NoError(t, err) + checkData(t, path, client) + checkCorrelation(t, client) + + // sleep for 1 seconds to ensure the existence of tidb.test + time.Sleep(time.Second) + timeBeforeDropStats := time.Now() + snapshot := timeBeforeDropStats.Format("20060102150405") + prepare4DumpHistoryStats(t, client) + + // test dump history stats + resp1, err := client.fetchStatus("/stats/dump/tidb/test") + require.NoError(t, err) + defer func() { + require.NoError(t, resp1.Body.Close()) + }() + js, err = io.ReadAll(resp1.Body) + require.NoError(t, err) + require.Equal(t, "null", string(js)) + + path1 := "/tmp/stats_history.json" + fp1, err := os.Create(path1) + require.NoError(t, err) + require.NotNil(t, fp1) + defer func() { + require.NoError(t, fp1.Close()) + require.NoError(t, os.Remove(path1)) + }() + + resp2, err := client.fetchStatus("/stats/dump/tidb/test/" + snapshot) + require.NoError(t, err) + defer func() { + require.NoError(t, resp2.Body.Close()) + }() + js, err = io.ReadAll(resp2.Body) + require.NoError(t, err) + _, err = fp1.Write(js) + require.NoError(t, err) + checkData(t, path1, client) +} + +func prepareData(t *testing.T, client *testServerClient, statHandle *StatsHandler) { + db, err := sql.Open("mysql", client.getDSN()) + require.NoError(t, err, "Error connecting") + defer func() { + err := db.Close() + require.NoError(t, err) + }() + tk := testkit.NewDBTestKit(t, db) + + h := statHandle.do.StatsHandle() + tk.MustExec("create database tidb") + tk.MustExec("use tidb") + tk.MustExec("create table test (a int, b varchar(20))") + err = h.HandleDDLEvent(<-h.DDLEventCh()) + require.NoError(t, err) + tk.MustExec("create index c on test (a, b)") + tk.MustExec("insert test values (1, 's')") + require.NoError(t, h.DumpStatsDeltaToKV(handle.DumpAll)) + tk.MustExec("analyze table test") + tk.MustExec("insert into test(a,b) values (1, 'v'),(3, 'vvv'),(5, 'vv')") + is := 
statHandle.do.InfoSchema() + require.NoError(t, h.DumpStatsDeltaToKV(handle.DumpAll)) + require.NoError(t, h.Update(is)) +} + +func prepare4DumpHistoryStats(t *testing.T, client *testServerClient) { + db, err := sql.Open("mysql", client.getDSN()) + require.NoError(t, err, "Error connecting") + defer func() { + err := db.Close() + require.NoError(t, err) + }() + + tk := testkit.NewDBTestKit(t, db) + + safePointName := "tikv_gc_safe_point" + safePointValue := "20060102-15:04:05 -0700" + safePointComment := "All versions after safe point can be accessed. (DO NOT EDIT)" + updateSafePoint := fmt.Sprintf(`INSERT INTO mysql.tidb VALUES ('%[1]s', '%[2]s', '%[3]s') + ON DUPLICATE KEY + UPDATE variable_value = '%[2]s', comment = '%[3]s'`, safePointName, safePointValue, safePointComment) + tk.MustExec(updateSafePoint) + + tk.MustExec("drop table tidb.test") + tk.MustExec("create table tidb.test (a int, b varchar(20))") +} + +func checkCorrelation(t *testing.T, client *testServerClient) { + db, err := sql.Open("mysql", client.getDSN()) + require.NoError(t, err, "Error connecting") + tk := testkit.NewDBTestKit(t, db) + defer func() { + err := db.Close() + require.NoError(t, err) + }() + + tk.MustExec("use tidb") + rows := tk.MustQuery("SELECT tidb_table_id FROM information_schema.tables WHERE table_name = 'test' AND table_schema = 'tidb'") + var tableID int64 + if rows.Next() { + err = rows.Scan(&tableID) + require.NoError(t, err) + require.False(t, rows.Next(), "unexpected data") + } else { + require.FailNow(t, "no data") + } + require.NoError(t, rows.Close()) + rows = tk.MustQuery("select correlation from mysql.stats_histograms where table_id = ? 
and hist_id = 1 and is_index = 0", tableID) + if rows.Next() { + var corr float64 + err = rows.Scan(&corr) + require.NoError(t, err) + require.Equal(t, float64(1), corr) + require.False(t, rows.Next(), "unexpected data") + } else { + require.FailNow(t, "no data") + } + require.NoError(t, rows.Close()) +} + +func checkData(t *testing.T, path string, client *testServerClient) { + db, err := sql.Open("mysql", client.getDSN(func(config *mysql.Config) { + config.AllowAllFiles = true + config.Params["sql_mode"] = "''" + })) + require.NoError(t, err, "Error connecting") + tk := testkit.NewDBTestKit(t, db) + defer func() { + err := db.Close() + require.NoError(t, err) + }() + + tk.MustExec("use tidb") + tk.MustExec("drop stats test") + tk.MustExec(fmt.Sprintf("load stats '%s'", path)) + + rows := tk.MustQuery("show stats_meta") + require.True(t, rows.Next(), "unexpected data") + var dbName, tableName string + var modifyCount, count int64 + var other interface{} + err = rows.Scan(&dbName, &tableName, &other, &other, &modifyCount, &count) + require.NoError(t, err) + require.Equal(t, "tidb", dbName) + require.Equal(t, "test", tableName) + require.Equal(t, int64(3), modifyCount) + require.Equal(t, int64(4), count) +} diff --git a/server/statistics_handler_test.go b/server/statistics_handler_test.go deleted file mode 100644 index d3a1af67e1c0d..0000000000000 --- a/server/statistics_handler_test.go +++ /dev/null @@ -1,258 +0,0 @@ -// Copyright 2018 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -package server - -import ( - "database/sql" - "fmt" - "io" - "os" - "time" - - "github.com/go-sql-driver/mysql" - "github.com/gorilla/mux" - . "github.com/pingcap/check" - "github.com/pingcap/tidb/domain" - "github.com/pingcap/tidb/kv" - "github.com/pingcap/tidb/session" - "github.com/pingcap/tidb/statistics/handle" - "github.com/pingcap/tidb/store/mockstore" -) - -type testDumpStatsSuite struct { - *testServerClient - server *Server - sh *StatsHandler - store kv.Storage - domain *domain.Domain -} - -var _ = Suite(&testDumpStatsSuite{ - testServerClient: newTestServerClient(), -}) - -func (ds *testDumpStatsSuite) startServer(c *C) { - var err error - ds.store, err = mockstore.NewMockStore() - c.Assert(err, IsNil) - session.DisableStats4Test() - ds.domain, err = session.BootstrapSession(ds.store) - c.Assert(err, IsNil) - ds.domain.SetStatsUpdating(true) - tidbdrv := NewTiDBDriver(ds.store) - - cfg := newTestConfig() - cfg.Port = ds.port - cfg.Status.StatusPort = ds.statusPort - cfg.Status.ReportStatus = true - - server, err := NewServer(cfg, tidbdrv) - c.Assert(err, IsNil) - ds.port = getPortFromTCPAddr(server.listener.Addr()) - ds.statusPort = getPortFromTCPAddr(server.statusListener.Addr()) - ds.server = server - go func() { - err := server.Run() - c.Assert(err, IsNil) - }() - ds.waitUntilServerOnline() - - do, err := session.GetDomain(ds.store) - c.Assert(err, IsNil) - ds.sh = &StatsHandler{do} -} - -func (ds *testDumpStatsSuite) stopServer(c *C) { - if ds.domain != nil { - ds.domain.Close() - } - if ds.store != nil { - ds.store.Close() - } - if ds.server != nil { - ds.server.Close() - } -} - -func (ds *testDumpStatsSuite) TestDumpStatsAPI(c *C) { - ds.startServer(c) - defer ds.stopServer(c) - ds.prepareData(c) - - router := mux.NewRouter() - router.Handle("/stats/dump/{db}/{table}", 
ds.sh) - - resp, err := ds.fetchStatus("/stats/dump/tidb/test") - c.Assert(err, IsNil) - defer resp.Body.Close() - - path := "/tmp/stats.json" - fp, err := os.Create(path) - c.Assert(err, IsNil) - c.Assert(fp, NotNil) - defer func() { - c.Assert(fp.Close(), IsNil) - c.Assert(os.Remove(path), IsNil) - }() - - js, err := io.ReadAll(resp.Body) - c.Assert(err, IsNil) - _, err = fp.Write(js) - c.Assert(err, IsNil) - ds.checkData(c, path) - ds.checkCorrelation(c) - - // sleep for 1 seconds to ensure the existence of tidb.test - time.Sleep(time.Second) - timeBeforeDropStats := time.Now() - snapshot := timeBeforeDropStats.Format("20060102150405") - ds.prepare4DumpHistoryStats(c) - - // test dump history stats - resp1, err := ds.fetchStatus("/stats/dump/tidb/test") - c.Assert(err, IsNil) - defer resp1.Body.Close() - js, err = io.ReadAll(resp1.Body) - c.Assert(err, IsNil) - c.Assert(string(js), Equals, "null") - - path1 := "/tmp/stats_history.json" - fp1, err := os.Create(path1) - c.Assert(err, IsNil) - c.Assert(fp1, NotNil) - defer func() { - c.Assert(fp1.Close(), IsNil) - c.Assert(os.Remove(path1), IsNil) - }() - - resp1, err = ds.fetchStatus("/stats/dump/tidb/test/" + snapshot) - c.Assert(err, IsNil) - - js, err = io.ReadAll(resp1.Body) - c.Assert(err, IsNil) - _, err = fp1.Write(js) - c.Assert(err, IsNil) - ds.checkData(c, path1) -} - -func (ds *testDumpStatsSuite) prepareData(c *C) { - db, err := sql.Open("mysql", ds.getDSN()) - c.Assert(err, IsNil, Commentf("Error connecting")) - defer func() { - err := db.Close() - c.Assert(err, IsNil) - }() - dbt := &DBTest{c, db} - - h := ds.sh.do.StatsHandle() - dbt.mustExec("create database tidb") - dbt.mustExec("use tidb") - dbt.mustExec("create table test (a int, b varchar(20))") - err = h.HandleDDLEvent(<-h.DDLEventCh()) - c.Assert(err, IsNil) - dbt.mustExec("create index c on test (a, b)") - dbt.mustExec("insert test values (1, 's')") - c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) - dbt.mustExec("analyze table test") 
- dbt.mustExec("insert into test(a,b) values (1, 'v'),(3, 'vvv'),(5, 'vv')") - is := ds.sh.do.InfoSchema() - c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) - c.Assert(h.Update(is), IsNil) -} - -func (ds *testDumpStatsSuite) prepare4DumpHistoryStats(c *C) { - db, err := sql.Open("mysql", ds.getDSN()) - c.Assert(err, IsNil, Commentf("Error connecting")) - defer func() { - err := db.Close() - c.Assert(err, IsNil) - }() - - dbt := &DBTest{c, db} - - safePointName := "tikv_gc_safe_point" - safePointValue := "20060102-15:04:05 -0700" - safePointComment := "All versions after safe point can be accessed. (DO NOT EDIT)" - updateSafePoint := fmt.Sprintf(`INSERT INTO mysql.tidb VALUES ('%[1]s', '%[2]s', '%[3]s') - ON DUPLICATE KEY - UPDATE variable_value = '%[2]s', comment = '%[3]s'`, safePointName, safePointValue, safePointComment) - dbt.mustExec(updateSafePoint) - - dbt.mustExec("drop table tidb.test") - dbt.mustExec("create table tidb.test (a int, b varchar(20))") -} - -func (ds *testDumpStatsSuite) checkCorrelation(c *C) { - db, err := sql.Open("mysql", ds.getDSN()) - c.Assert(err, IsNil, Commentf("Error connecting")) - dbt := &DBTest{c, db} - defer func() { - err := db.Close() - c.Assert(err, IsNil) - }() - - dbt.mustExec("use tidb") - rows := dbt.mustQuery("SELECT tidb_table_id FROM information_schema.tables WHERE table_name = 'test' AND table_schema = 'tidb'") - var tableID int64 - if rows.Next() { - err = rows.Scan(&tableID) - c.Assert(err, IsNil) - dbt.Check(rows.Next(), IsFalse, Commentf("unexpected data")) - } else { - dbt.Error("no data") - } - rows.Close() - rows = dbt.mustQuery("select correlation from mysql.stats_histograms where table_id = ? 
and hist_id = 1 and is_index = 0", tableID) - if rows.Next() { - var corr float64 - err = rows.Scan(&corr) - c.Assert(err, IsNil) - dbt.Check(corr, Equals, float64(1)) - dbt.Check(rows.Next(), IsFalse, Commentf("unexpected data")) - } else { - dbt.Error("no data") - } - rows.Close() -} - -func (ds *testDumpStatsSuite) checkData(c *C, path string) { - db, err := sql.Open("mysql", ds.getDSN(func(config *mysql.Config) { - config.AllowAllFiles = true - config.Params["sql_mode"] = "''" - })) - c.Assert(err, IsNil, Commentf("Error connecting")) - dbt := &DBTest{c, db} - defer func() { - err := db.Close() - c.Assert(err, IsNil) - }() - - dbt.mustExec("use tidb") - dbt.mustExec("drop stats test") - _, err = dbt.db.Exec(fmt.Sprintf("load stats '%s'", path)) - c.Assert(err, IsNil) - - rows := dbt.mustQuery("show stats_meta") - dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data")) - var dbName, tableName string - var modifyCount, count int64 - var other interface{} - err = rows.Scan(&dbName, &tableName, &other, &other, &modifyCount, &count) - dbt.Check(err, IsNil) - dbt.Check(dbName, Equals, "tidb") - dbt.Check(tableName, Equals, "test") - dbt.Check(modifyCount, Equals, int64(3)) - dbt.Check(count, Equals, int64(4)) -} diff --git a/session/session_test.go b/session/session_test.go index 9426240ce7154..ab9f99821d6e4 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -5685,3 +5685,13 @@ func (s *testTiDBAsLibrary) TestMemoryLeak(c *C) { runtime.ReadMemStats(&memStat) c.Assert(memStat.HeapInuse-oldHeapInUse, Less, uint64(150*units.MiB)) } + +func (s *testSessionSuite) TestTiDBReadStaleness(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("set @@tidb_read_staleness='-5s'") + err := tk.ExecToErr("set @@tidb_read_staleness='-5'") + c.Assert(err, NotNil) + err = tk.ExecToErr("set @@tidb_read_staleness='foo'") + c.Assert(err, NotNil) + tk.MustExec("set @@tidb_read_staleness=''") +} diff --git a/sessionctx/stmtctx/stmtctx.go 
b/sessionctx/stmtctx/stmtctx.go index 0689c4551b197..fe0a57990a752 100644 --- a/sessionctx/stmtctx/stmtctx.go +++ b/sessionctx/stmtctx/stmtctx.go @@ -114,12 +114,11 @@ type StatementContext struct { copied uint64 touched uint64 - message string - warnings []SQLWarn - errorCount uint16 - histogramsNotLoad bool - execDetails execdetails.ExecDetails - allExecDetails []*execdetails.ExecDetails + message string + warnings []SQLWarn + errorCount uint16 + execDetails execdetails.ExecDetails + allExecDetails []*execdetails.ExecDetails } // PrevAffectedRows is the affected-rows value(DDL is 0, DML is the number of affected rows). PrevAffectedRows int64 @@ -527,13 +526,6 @@ func (sc *StatementContext) AppendError(warn error) { sc.mu.Unlock() } -// SetHistogramsNotLoad sets histogramsNotLoad. -func (sc *StatementContext) SetHistogramsNotLoad() { - sc.mu.Lock() - sc.mu.histogramsNotLoad = true - sc.mu.Unlock() -} - // HandleTruncate ignores or returns the error based on the StatementContext state. func (sc *StatementContext) HandleTruncate(err error) error { // TODO: At present we have not checked whether the error can be ignored or treated as warning. diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index 0c6bd9c718851..63be07c8a18f2 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -947,6 +947,9 @@ type SessionVars struct { // MPPStoreFailTTL indicates the duration that protect TiDB from sending task to a new recovered TiFlash. MPPStoreFailTTL string + // ReadStaleness indicates the staleness duration for the following query + ReadStaleness time.Duration + // cached is used to optimze the object allocation. 
cached struct { curr int8 diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index 211e852676520..74b66b284c81d 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -852,6 +852,9 @@ var defaultSysVars = []*SysVar{ }, Validation: func(vars *SessionVars, normalizedValue string, originalValue string, scope ScopeFlag) (string, error) { return normalizedValue, nil }}, + {Scope: ScopeSession, Name: TiDBReadStaleness, Value: "", Hidden: false, SetSession: func(s *SessionVars, val string) error { + return setReadStaleness(s, val) + }}, {Scope: ScopeGlobal | ScopeSession, Name: TiDBAllowMPPExecution, Type: TypeBool, Value: BoolToOnOff(DefTiDBAllowMPPExecution), SetSession: func(s *SessionVars, val string) error { s.allowMPPExecution = TiDBOptOn(val) return nil diff --git a/sessionctx/variable/tidb_vars.go b/sessionctx/variable/tidb_vars.go index 06e1e37fc06e3..c9d38dd5ba507 100644 --- a/sessionctx/variable/tidb_vars.go +++ b/sessionctx/variable/tidb_vars.go @@ -208,6 +208,9 @@ const ( // TiDBTxnReadTS indicates the next transaction should be staleness transaction and provide the startTS TiDBTxnReadTS = "tx_read_ts" + + // TiDBReadStaleness indicates the staleness duration for following statement + TiDBReadStaleness = "tidb_read_staleness" ) // TiDB system variable names that both in session and global scope. diff --git a/sessionctx/variable/varsutil.go b/sessionctx/variable/varsutil.go index cdf7aa678da6e..76c5b7190b9ff 100644 --- a/sessionctx/variable/varsutil.go +++ b/sessionctx/variable/varsutil.go @@ -437,6 +437,19 @@ func setTxnReadTS(s *SessionVars, sVal string) error { return err } +func setReadStaleness(s *SessionVars, sVal string) error { + if sVal == "" { + s.ReadStaleness = 0 + return nil + } + d, err := time.ParseDuration(sVal) + if err != nil { + return err + } + s.ReadStaleness = d + return nil +} + // serverGlobalVariable is used to handle variables that acts in server and global scope. 
type serverGlobalVariable struct { sync.Mutex diff --git a/statistics/histogram.go b/statistics/histogram.go index 9ec1663ff3efc..bd5f3e35a5c01 100644 --- a/statistics/histogram.go +++ b/statistics/histogram.go @@ -1072,7 +1072,6 @@ func (c *Column) IsInvalid(sc *stmtctx.StatementContext, collPseudo bool) bool { return true } if c.Histogram.NDV > 0 && c.notNullCount() == 0 && sc != nil { - sc.SetHistogramsNotLoad() HistogramNeededColumns.insert(tableColumnID{TableID: c.PhysicalID, ColumnID: c.Info.ID}) } return c.TotalRowCount() == 0 || (c.Histogram.NDV > 0 && c.notNullCount() == 0) diff --git a/store/batch_coprocessor_test.go b/store/batch_coprocessor_serial_test.go similarity index 51% rename from store/batch_coprocessor_test.go rename to store/batch_coprocessor_serial_test.go index 7f48a35302e6e..9022875cd9586 100644 --- a/store/batch_coprocessor_test.go +++ b/store/batch_coprocessor_serial_test.go @@ -17,30 +17,24 @@ package store import ( "context" "fmt" + "testing" - . "github.com/pingcap/check" "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/parser/model" "github.com/pingcap/tidb/domain" - "github.com/pingcap/tidb/kv" - "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/store/mockstore" "github.com/pingcap/tidb/store/mockstore/unistore" "github.com/pingcap/tidb/table" - "github.com/pingcap/tidb/util/testkit" + "github.com/pingcap/tidb/testkit" + "github.com/stretchr/testify/require" "github.com/tikv/client-go/v2/testutils" ) -type testBatchCopSuite struct { -} - -var _ = SerialSuites(&testBatchCopSuite{}) - -func newStoreWithBootstrap(tiflashNum int) (kv.Storage, *domain.Domain, error) { - store, err := mockstore.NewMockStore( +func createMockTiKVStoreOptions(tiflashNum int) 
[]mockstore.MockTiKVStoreOption { + return []mockstore.MockTiKVStoreOption{ mockstore.WithClusterInspector(func(c testutils.Cluster) { mockCluster := c.(*unistore.Cluster) _, _, region1 := mockstore.BootstrapWithSingleStore(c) @@ -55,98 +49,82 @@ func newStoreWithBootstrap(tiflashNum int) (kv.Storage, *domain.Domain, error) { } }), mockstore.WithStoreType(mockstore.EmbedUnistore), - ) - - if err != nil { - return nil, nil, errors.Trace(err) } - - session.SetSchemaLease(0) - session.DisableStats4Test() - - dom, err := session.BootstrapSession(store) - if err != nil { - return nil, nil, err - } - - dom.SetStatsUpdating(true) - return store, dom, errors.Trace(err) } -func testGetTableByName(c *C, ctx sessionctx.Context, db, table string) table.Table { +func testGetTableByName(t *testing.T, ctx sessionctx.Context, db, table string) table.Table { dom := domain.GetDomain(ctx) // Make sure the table schema is the new schema. err := dom.Reload() - c.Assert(err, IsNil) + require.NoError(t, err) tbl, err := dom.InfoSchema().TableByName(model.NewCIStr(db), model.NewCIStr(table)) - c.Assert(err, IsNil) + require.NoError(t, err) return tbl } -func (s *testBatchCopSuite) TestStoreErr(c *C) { - store, dom, err := newStoreWithBootstrap(1) - c.Assert(err, IsNil) +func TestStoreErr(t *testing.T) { + store, clean := testkit.CreateMockStore(t, createMockTiKVStoreOptions(1)...) 
+ defer clean() + + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/infoschema/mockTiFlashStoreCount", `return(true)`)) defer func() { - dom.Close() - store.Close() + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/infoschema/mockTiFlashStoreCount")) }() - c.Assert(failpoint.Enable("github.com/pingcap/tidb/infoschema/mockTiFlashStoreCount", `return(true)`), IsNil) - defer failpoint.Disable("github.com/pingcap/tidb/infoschema/mockTiFlashStoreCount") - - tk := testkit.NewTestKit(c, store) + tk := testkit.NewTestKit(t, store) tk.MustExec("use test") tk.MustExec("create table t(a int not null, b int not null)") tk.MustExec("alter table t set tiflash replica 1") - tb := testGetTableByName(c, tk.Se, "test", "t") - err = domain.GetDomain(tk.Se).DDL().UpdateTableReplicaInfo(tk.Se, tb.Meta().ID, true) - c.Assert(err, IsNil) + tb := testGetTableByName(t, tk.Session(), "test", "t") + + err := domain.GetDomain(tk.Session()).DDL().UpdateTableReplicaInfo(tk.Session(), tb.Meta().ID, true) + require.NoError(t, err) + tk.MustExec("insert into t values(1,0)") tk.MustExec("set @@session.tidb_isolation_read_engines=\"tiflash\"") tk.MustExec("set @@session.tidb_allow_mpp=OFF") - c.Assert(failpoint.Enable("github.com/pingcap/tidb/store/mockstore/unistore/BatchCopCancelled", "1*return(true)"), IsNil) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/store/mockstore/unistore/BatchCopCancelled", "1*return(true)")) err = tk.QueryToErr("select count(*) from t") - c.Assert(errors.Cause(err), Equals, context.Canceled) + require.Equal(t, context.Canceled, errors.Cause(err)) - c.Assert(failpoint.Enable("github.com/pingcap/tidb/store/mockstore/unistore/BatchCopRpcErrtiflash0", "1*return(\"tiflash0\")"), IsNil) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/store/mockstore/unistore/BatchCopRpcErrtiflash0", "1*return(\"tiflash0\")")) tk.MustQuery("select count(*) from 
t").Check(testkit.Rows("1")) - c.Assert(failpoint.Enable("github.com/pingcap/tidb/store/mockstore/unistore/BatchCopRpcErrtiflash0", "return(\"tiflash0\")"), IsNil) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/store/mockstore/unistore/BatchCopRpcErrtiflash0", "return(\"tiflash0\")")) err = tk.QueryToErr("select count(*) from t") - c.Assert(err, NotNil) + require.Error(t, err) } -func (s *testBatchCopSuite) TestStoreSwitchPeer(c *C) { - store, dom, err := newStoreWithBootstrap(2) - c.Assert(err, IsNil) +func TestStoreSwitchPeer(t *testing.T) { + store, clean := testkit.CreateMockStore(t, createMockTiKVStoreOptions(2)...) + defer clean() + + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/infoschema/mockTiFlashStoreCount", `return(true)`)) defer func() { - dom.Close() - store.Close() + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/infoschema/mockTiFlashStoreCount")) }() - c.Assert(failpoint.Enable("github.com/pingcap/tidb/infoschema/mockTiFlashStoreCount", `return(true)`), IsNil) - defer failpoint.Disable("github.com/pingcap/tidb/infoschema/mockTiFlashStoreCount") - - tk := testkit.NewTestKit(c, store) + tk := testkit.NewTestKit(t, store) tk.MustExec("use test") tk.MustExec("create table t(a int not null, b int not null)") tk.MustExec("alter table t set tiflash replica 1") - tb := testGetTableByName(c, tk.Se, "test", "t") - err = domain.GetDomain(tk.Se).DDL().UpdateTableReplicaInfo(tk.Se, tb.Meta().ID, true) - c.Assert(err, IsNil) + tb := testGetTableByName(t, tk.Session(), "test", "t") + + err := domain.GetDomain(tk.Session()).DDL().UpdateTableReplicaInfo(tk.Session(), tb.Meta().ID, true) + require.NoError(t, err) + tk.MustExec("insert into t values(1,0)") tk.MustExec("set @@session.tidb_isolation_read_engines=\"tiflash\"") tk.MustExec("set @@session.tidb_allow_mpp=OFF") - 
c.Assert(failpoint.Enable("github.com/pingcap/tidb/store/mockstore/unistore/BatchCopRpcErrtiflash0", "return(\"tiflash0\")"), IsNil) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/store/mockstore/unistore/BatchCopRpcErrtiflash0", "return(\"tiflash0\")")) tk.MustQuery("select count(*) from t").Check(testkit.Rows("1")) - c.Assert(failpoint.Enable("github.com/pingcap/tidb/store/mockstore/unistore/BatchCopRpcErrtiflash1", "return(\"tiflash1\")"), IsNil) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/store/mockstore/unistore/BatchCopRpcErrtiflash1", "return(\"tiflash1\")")) err = tk.QueryToErr("select count(*) from t") - c.Assert(err, NotNil) - + require.Error(t, err) } diff --git a/store/main_test.go b/store/main_test.go new file mode 100644 index 0000000000000..40703aa3d2e51 --- /dev/null +++ b/store/main_test.go @@ -0,0 +1,31 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package store + +import ( + "testing" + + "github.com/pingcap/tidb/util/testbridge" + "go.uber.org/goleak" +) + +func TestMain(m *testing.M) { + testbridge.WorkaroundGoCheckFlags() + opts := []goleak.Option{ + goleak.IgnoreTopFunction("go.etcd.io/etcd/pkg/logutil.(*MergeLogger).outputLoop"), + goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"), + } + goleak.VerifyTestMain(m, opts...) 
+} diff --git a/store/mockstore/unistore/cophandler/closure_exec.go b/store/mockstore/unistore/cophandler/closure_exec.go index 29512287db53c..94c37dc71491f 100644 --- a/store/mockstore/unistore/cophandler/closure_exec.go +++ b/store/mockstore/unistore/cophandler/closure_exec.go @@ -330,10 +330,6 @@ func (e *closureExecutor) initIdxScanCtx(idxScan *tipb.IndexScan) { for i, col := range colInfos[:e.idxScanCtx.columnLen] { colIDs[col.ID] = i } - e.scanCtx.newCollationIds = colIDs - - // We don't need to decode handle here, and colIDs >= 0 always. - e.scanCtx.newCollationRd = rowcodec.NewByteDecoder(colInfos[:e.idxScanCtx.columnLen], []int64{-1}, nil, nil) } func isCountAgg(pbAgg *tipb.Aggregation) bool { @@ -515,9 +511,7 @@ type scanCtx struct { desc bool decoder *rowcodec.ChunkDecoder - newCollationRd *rowcodec.BytesDecoder - newCollationIds map[int64]int - execDetail *execDetail + execDetail *execDetail } type idxScanCtx struct { diff --git a/store/store_test.go b/store/store_test.go index 98a3acf365289..c02f87f312d93 100644 --- a/store/store_test.go +++ b/store/store_test.go @@ -17,7 +17,6 @@ package store import ( "context" "fmt" - "os" "strconv" "sync" "sync/atomic" @@ -27,8 +26,7 @@ import ( . 
"github.com/pingcap/check" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/store/mockstore" - "github.com/pingcap/tidb/util/logutil" - "github.com/pingcap/tidb/util/testleak" + "github.com/stretchr/testify/require" ) const ( @@ -39,49 +37,23 @@ const ( type brokenStore struct{} -func (s *brokenStore) Open(schema string) (kv.Storage, error) { +func (s *brokenStore) Open(_ string) (kv.Storage, error) { return nil, kv.ErrTxnRetryable } -func TestT(t *testing.T) { - CustomVerboseFlag = true - logLevel := os.Getenv("log_level") - logutil.InitLogger(logutil.NewLogConfig(logLevel, logutil.DefaultLogFormat, "", logutil.EmptyFileLogConfig, false)) - TestingT(t) -} - -var _ = Suite(&testKVSuite{}) - -type testKVSuite struct { - s kv.Storage -} - -func (s *testKVSuite) SetUpSuite(c *C) { - testleak.BeforeTest() - store, err := mockstore.NewMockStore() - c.Assert(err, IsNil) - s.s = store -} - -func (s *testKVSuite) TearDownSuite(c *C) { - err := s.s.Close() - c.Assert(err, IsNil) - testleak.AfterTest(c)() -} - -func insertData(c *C, txn kv.Transaction) { +func insertData(t *testing.T, txn kv.Transaction) { for i := startIndex; i < testCount; i++ { val := encodeInt(i * indexStep) err := txn.Set(val, val) - c.Assert(err, IsNil) + require.NoError(t, err) } } -func mustDel(c *C, txn kv.Transaction) { +func mustDel(t *testing.T, txn kv.Transaction) { for i := startIndex; i < testCount; i++ { val := encodeInt(i * indexStep) err := txn.Delete(val) - c.Assert(err, IsNil) + require.NoError(t, err) } } @@ -91,22 +63,22 @@ func encodeInt(n int) []byte { func decodeInt(s []byte) int { var n int - fmt.Sscanf(string(s), "%010d", &n) + _, _ = fmt.Sscanf(string(s), "%010d", &n) return n } -func valToStr(c *C, iter kv.Iterator) string { +func valToStr(iter kv.Iterator) string { val := iter.Value() return string(val) } -func checkSeek(c *C, txn kv.Transaction) { +func checkSeek(t *testing.T, txn kv.Transaction) { for i := startIndex; 
i < testCount; i++ { val := encodeInt(i * indexStep) iter, err := txn.Iter(val, nil) - c.Assert(err, IsNil) - c.Assert([]byte(iter.Key()), BytesEquals, val) - c.Assert(decodeInt([]byte(valToStr(c, iter))), Equals, i*indexStep) + require.NoError(t, err) + require.Equal(t, val, []byte(iter.Key())) + require.Equal(t, i*indexStep, decodeInt([]byte(valToStr(iter)))) iter.Close() } @@ -114,306 +86,396 @@ func checkSeek(c *C, txn kv.Transaction) { for i := startIndex; i < testCount-1; i++ { val := encodeInt(i * indexStep) iter, err := txn.Iter(val, nil) - c.Assert(err, IsNil) - c.Assert([]byte(iter.Key()), BytesEquals, val) - c.Assert(valToStr(c, iter), Equals, string(val)) + require.NoError(t, err) + require.Equal(t, val, []byte(iter.Key())) + require.Equal(t, string(val), valToStr(iter)) err = iter.Next() - c.Assert(err, IsNil) - c.Assert(iter.Valid(), IsTrue) + require.NoError(t, err) + require.True(t, iter.Valid()) val = encodeInt((i + 1) * indexStep) - c.Assert([]byte(iter.Key()), BytesEquals, val) - c.Assert(valToStr(c, iter), Equals, string(val)) + require.Equal(t, val, []byte(iter.Key())) + require.Equal(t, string(val), valToStr(iter)) iter.Close() } // Non exist and beyond maximum seek test iter, err := txn.Iter(encodeInt(testCount*indexStep), nil) - c.Assert(err, IsNil) - c.Assert(iter.Valid(), IsFalse) + require.NoError(t, err) + require.False(t, iter.Valid()) // Non exist but between existing keys seek test, // it returns the smallest key that larger than the one we are seeking inBetween := encodeInt((testCount-1)*indexStep - 1) last := encodeInt((testCount - 1) * indexStep) iter, err = txn.Iter(inBetween, nil) - c.Assert(err, IsNil) - c.Assert(iter.Valid(), IsTrue) - c.Assert([]byte(iter.Key()), Not(BytesEquals), inBetween) - c.Assert([]byte(iter.Key()), BytesEquals, last) + require.NoError(t, err) + require.True(t, iter.Valid()) + require.NotEqual(t, inBetween, []byte(iter.Key())) + require.Equal(t, last, []byte(iter.Key())) iter.Close() } -func mustNotGet(c 
*C, txn kv.Transaction) { +func mustNotGet(t *testing.T, txn kv.Transaction) { for i := startIndex; i < testCount; i++ { s := encodeInt(i * indexStep) _, err := txn.Get(context.TODO(), s) - c.Assert(err, NotNil) + require.Error(t, err) } } -func mustGet(c *C, txn kv.Transaction) { +func mustGet(t *testing.T, txn kv.Transaction) { for i := startIndex; i < testCount; i++ { s := encodeInt(i * indexStep) val, err := txn.Get(context.TODO(), s) - c.Assert(err, IsNil) - c.Assert(string(val), Equals, string(s)) + require.NoError(t, err) + require.Equal(t, string(s), string(val)) } } -func (s *testKVSuite) TestNew(c *C) { +func TestNew(t *testing.T) { store, err := New("goleveldb://relative/path") - c.Assert(err, NotNil) - c.Assert(store, IsNil) + require.Error(t, err) + require.Nil(t, store) } -func (s *testKVSuite) TestGetSet(c *C) { - txn, err := s.s.Begin() - c.Assert(err, IsNil) +func TestGetSet(t *testing.T) { + t.Parallel() - insertData(c, txn) + store, err := mockstore.NewMockStore() + require.NoError(t, err) + defer func() { + require.NoError(t, store.Close()) + }() - mustGet(c, txn) + txn, err := store.Begin() + require.NoError(t, err) + + insertData(t, txn) + + mustGet(t, txn) // Check transaction results err = txn.Commit(context.Background()) - c.Assert(err, IsNil) + require.NoError(t, err) - txn, err = s.s.Begin() - c.Assert(err, IsNil) - defer txn.Commit(context.Background()) + txn, err = store.Begin() + require.NoError(t, err) + defer func() { + require.NoError(t, txn.Commit(context.Background())) + }() - mustGet(c, txn) - mustDel(c, txn) + mustGet(t, txn) + mustDel(t, txn) } -func (s *testKVSuite) TestSeek(c *C) { - txn, err := s.s.Begin() - c.Assert(err, IsNil) +func TestSeek(t *testing.T) { + store, err := mockstore.NewMockStore() + require.NoError(t, err) + defer func() { + require.NoError(t, store.Close()) + }() + + txn, err := store.Begin() + require.NoError(t, err) - insertData(c, txn) - checkSeek(c, txn) + insertData(t, txn) + checkSeek(t, txn) // 
Check transaction results err = txn.Commit(context.Background()) - c.Assert(err, IsNil) + require.NoError(t, err) - txn, err = s.s.Begin() - c.Assert(err, IsNil) - defer txn.Commit(context.Background()) + txn, err = store.Begin() + require.NoError(t, err) + defer func() { + require.NoError(t, txn.Commit(context.Background())) + }() - checkSeek(c, txn) - mustDel(c, txn) + checkSeek(t, txn) + mustDel(t, txn) } -func (s *testKVSuite) TestInc(c *C) { - txn, err := s.s.Begin() - c.Assert(err, IsNil) +func TestInc(t *testing.T) { + t.Parallel() + + store, err := mockstore.NewMockStore() + require.NoError(t, err) + defer func() { + require.NoError(t, store.Close()) + }() + + txn, err := store.Begin() + require.NoError(t, err) key := []byte("incKey") n, err := kv.IncInt64(txn, key, 100) - c.Assert(err, IsNil) - c.Assert(n, Equals, int64(100)) + require.NoError(t, err) + require.Equal(t, int64(100), n) // Check transaction results err = txn.Commit(context.Background()) - c.Assert(err, IsNil) + require.NoError(t, err) - txn, err = s.s.Begin() - c.Assert(err, IsNil) + txn, err = store.Begin() + require.NoError(t, err) n, err = kv.IncInt64(txn, key, -200) - c.Assert(err, IsNil) - c.Assert(n, Equals, int64(-100)) + require.NoError(t, err) + require.Equal(t, int64(-100), n) err = txn.Delete(key) - c.Assert(err, IsNil) + require.NoError(t, err) n, err = kv.IncInt64(txn, key, 100) - c.Assert(err, IsNil) - c.Assert(n, Equals, int64(100)) + require.NoError(t, err) + require.Equal(t, int64(100), n) err = txn.Delete(key) - c.Assert(err, IsNil) + require.NoError(t, err) err = txn.Commit(context.Background()) - c.Assert(err, IsNil) + require.NoError(t, err) } -func (s *testKVSuite) TestDelete(c *C) { - txn, err := s.s.Begin() - c.Assert(err, IsNil) +func TestDelete(t *testing.T) { + t.Parallel() - insertData(c, txn) + store, err := mockstore.NewMockStore() + require.NoError(t, err) + defer func() { + require.NoError(t, store.Close()) + }() + + txn, err := store.Begin() + 
require.NoError(t, err) - mustDel(c, txn) + insertData(t, txn) - mustNotGet(c, txn) + mustDel(t, txn) + + mustNotGet(t, txn) err = txn.Commit(context.Background()) - c.Assert(err, IsNil) + require.NoError(t, err) // Try get - txn, err = s.s.Begin() - c.Assert(err, IsNil) + txn, err = store.Begin() + require.NoError(t, err) - mustNotGet(c, txn) + mustNotGet(t, txn) // Insert again - insertData(c, txn) + insertData(t, txn) err = txn.Commit(context.Background()) - c.Assert(err, IsNil) + require.NoError(t, err) // Delete all - txn, err = s.s.Begin() - c.Assert(err, IsNil) + txn, err = store.Begin() + require.NoError(t, err) - mustDel(c, txn) + mustDel(t, txn) err = txn.Commit(context.Background()) - c.Assert(err, IsNil) + require.NoError(t, err) - txn, err = s.s.Begin() - c.Assert(err, IsNil) + txn, err = store.Begin() + require.NoError(t, err) - mustNotGet(c, txn) + mustNotGet(t, txn) err = txn.Commit(context.Background()) - c.Assert(err, IsNil) + require.NoError(t, err) } -func (s *testKVSuite) TestDelete2(c *C) { - txn, err := s.s.Begin() - c.Assert(err, IsNil) +func TestDelete2(t *testing.T) { + t.Parallel() + + store, err := mockstore.NewMockStore() + require.NoError(t, err) + defer func() { + require.NoError(t, store.Close()) + }() + + txn, err := store.Begin() + require.NoError(t, err) val := []byte("test") - txn.Set([]byte("DATA_test_tbl_department_record__0000000001_0003"), val) - txn.Set([]byte("DATA_test_tbl_department_record__0000000001_0004"), val) - txn.Set([]byte("DATA_test_tbl_department_record__0000000002_0003"), val) - txn.Set([]byte("DATA_test_tbl_department_record__0000000002_0004"), val) + require.NoError(t, txn.Set([]byte("DATA_test_tbl_department_record__0000000001_0003"), val)) + require.NoError(t, txn.Set([]byte("DATA_test_tbl_department_record__0000000001_0004"), val)) + require.NoError(t, txn.Set([]byte("DATA_test_tbl_department_record__0000000002_0003"), val)) + require.NoError(t, 
txn.Set([]byte("DATA_test_tbl_department_record__0000000002_0004"), val)) err = txn.Commit(context.Background()) - c.Assert(err, IsNil) + require.NoError(t, err) // Delete all - txn, err = s.s.Begin() - c.Assert(err, IsNil) + txn, err = store.Begin() + require.NoError(t, err) it, err := txn.Iter([]byte("DATA_test_tbl_department_record__0000000001_0003"), nil) - c.Assert(err, IsNil) + require.NoError(t, err) for it.Valid() { err = txn.Delete(it.Key()) - c.Assert(err, IsNil) + require.NoError(t, err) err = it.Next() - c.Assert(err, IsNil) + require.NoError(t, err) } err = txn.Commit(context.Background()) - c.Assert(err, IsNil) + require.NoError(t, err) - txn, err = s.s.Begin() - c.Assert(err, IsNil) + txn, err = store.Begin() + require.NoError(t, err) it, _ = txn.Iter([]byte("DATA_test_tbl_department_record__000000000"), nil) - c.Assert(it.Valid(), IsFalse) + require.False(t, it.Valid()) err = txn.Commit(context.Background()) - c.Assert(err, IsNil) + require.NoError(t, err) } -func (s *testKVSuite) TestSetNil(c *C) { - txn, err := s.s.Begin() - defer txn.Commit(context.Background()) - c.Assert(err, IsNil) +func TestSetNil(t *testing.T) { + t.Parallel() + + store, err := mockstore.NewMockStore() + require.NoError(t, err) + defer func() { + require.NoError(t, store.Close()) + }() + + txn, err := store.Begin() + defer func() { + require.NoError(t, txn.Commit(context.Background())) + }() + require.NoError(t, err) err = txn.Set([]byte("1"), nil) - c.Assert(err, NotNil) + require.Error(t, err) } -func (s *testKVSuite) TestBasicSeek(c *C) { - txn, err := s.s.Begin() - c.Assert(err, IsNil) - txn.Set([]byte("1"), []byte("1")) +func TestBasicSeek(t *testing.T) { + t.Parallel() + + store, err := mockstore.NewMockStore() + require.NoError(t, err) + defer func() { + require.NoError(t, store.Close()) + }() + + txn, err := store.Begin() + require.NoError(t, err) + require.NoError(t, txn.Set([]byte("1"), []byte("1"))) err = txn.Commit(context.Background()) - c.Assert(err, IsNil) - 
txn, err = s.s.Begin() - c.Assert(err, IsNil) - defer txn.Commit(context.Background()) + require.NoError(t, err) + txn, err = store.Begin() + require.NoError(t, err) + defer func() { + require.NoError(t, txn.Commit(context.Background())) + }() it, err := txn.Iter([]byte("2"), nil) - c.Assert(err, IsNil) - c.Assert(it.Valid(), Equals, false) - txn.Delete([]byte("1")) + require.NoError(t, err) + require.False(t, it.Valid()) + require.NoError(t, txn.Delete([]byte("1"))) } -func (s *testKVSuite) TestBasicTable(c *C) { - txn, err := s.s.Begin() - c.Assert(err, IsNil) +func TestBasicTable(t *testing.T) { + t.Parallel() + + store, err := mockstore.NewMockStore() + require.NoError(t, err) + defer func() { + require.NoError(t, store.Close()) + }() + + txn, err := store.Begin() + require.NoError(t, err) for i := 1; i < 5; i++ { b := []byte(strconv.Itoa(i)) - txn.Set(b, b) + require.NoError(t, txn.Set(b, b)) } err = txn.Commit(context.Background()) - c.Assert(err, IsNil) - txn, err = s.s.Begin() - c.Assert(err, IsNil) - defer txn.Commit(context.Background()) + require.NoError(t, err) + txn, err = store.Begin() + require.NoError(t, err) + defer func() { + require.NoError(t, txn.Commit(context.Background())) + }() err = txn.Set([]byte("1"), []byte("1")) - c.Assert(err, IsNil) + require.NoError(t, err) it, err := txn.Iter([]byte("0"), nil) - c.Assert(err, IsNil) - c.Assert(string(it.Key()), Equals, "1") + require.NoError(t, err) + require.Equal(t, "1", string(it.Key())) err = txn.Set([]byte("0"), []byte("0")) - c.Assert(err, IsNil) + require.NoError(t, err) it, err = txn.Iter([]byte("0"), nil) - c.Assert(err, IsNil) - c.Assert(string(it.Key()), Equals, "0") + require.NoError(t, err) + require.Equal(t, "0", string(it.Key())) err = txn.Delete([]byte("0")) - c.Assert(err, IsNil) + require.NoError(t, err) - txn.Delete([]byte("1")) + require.NoError(t, txn.Delete([]byte("1"))) it, err = txn.Iter([]byte("0"), nil) - c.Assert(err, IsNil) - c.Assert(string(it.Key()), Equals, "2") + 
require.NoError(t, err) + require.Equal(t, "2", string(it.Key())) err = txn.Delete([]byte("3")) - c.Assert(err, IsNil) + require.NoError(t, err) it, err = txn.Iter([]byte("2"), nil) - c.Assert(err, IsNil) - c.Assert(string(it.Key()), Equals, "2") + require.NoError(t, err) + require.Equal(t, "2", string(it.Key())) it, err = txn.Iter([]byte("3"), nil) - c.Assert(err, IsNil) - c.Assert(string(it.Key()), Equals, "4") + require.NoError(t, err) + require.Equal(t, "4", string(it.Key())) err = txn.Delete([]byte("2")) - c.Assert(err, IsNil) + require.NoError(t, err) err = txn.Delete([]byte("4")) - c.Assert(err, IsNil) + require.NoError(t, err) } -func (s *testKVSuite) TestRollback(c *C) { - txn, err := s.s.Begin() - c.Assert(err, IsNil) +func TestRollback(t *testing.T) { + t.Parallel() + + store, err := mockstore.NewMockStore() + require.NoError(t, err) + defer func() { + require.NoError(t, store.Close()) + }() + + txn, err := store.Begin() + require.NoError(t, err) err = txn.Rollback() - c.Assert(err, IsNil) + require.NoError(t, err) - txn, err = s.s.Begin() - c.Assert(err, IsNil) + txn, err = store.Begin() + require.NoError(t, err) - insertData(c, txn) + insertData(t, txn) - mustGet(c, txn) + mustGet(t, txn) err = txn.Rollback() - c.Assert(err, IsNil) + require.NoError(t, err) - txn, err = s.s.Begin() - c.Assert(err, IsNil) - defer txn.Commit(context.Background()) + txn, err = store.Begin() + require.NoError(t, err) + defer func() { + require.NoError(t, txn.Commit(context.Background())) + }() for i := startIndex; i < testCount; i++ { _, err := txn.Get(context.TODO(), []byte(strconv.Itoa(i))) - c.Assert(err, NotNil) + require.Error(t, err) } } -func (s *testKVSuite) TestSeekMin(c *C) { +func TestSeekMin(t *testing.T) { + t.Parallel() + + store, err := mockstore.NewMockStore() + require.NoError(t, err) + defer func() { + require.NoError(t, store.Close()) + }() + rows := []struct { key string value string @@ -426,28 +488,36 @@ func (s *testKVSuite) TestSeekMin(c *C) { 
{"DATA_test_main_db_tbl_tbl_test_record__00000000000000000002_0003", "hello"}, } - txn, err := s.s.Begin() - c.Assert(err, IsNil) + txn, err := store.Begin() + require.NoError(t, err) for _, row := range rows { - txn.Set([]byte(row.key), []byte(row.value)) + require.NoError(t, txn.Set([]byte(row.key), []byte(row.value))) } it, err := txn.Iter(nil, nil) - c.Assert(err, IsNil) + require.NoError(t, err) for it.Valid() { - it.Next() + require.NoError(t, it.Next()) } it, err = txn.Iter([]byte("DATA_test_main_db_tbl_tbl_test_record__00000000000000000000"), nil) - c.Assert(err, IsNil) - c.Assert(string(it.Key()), Equals, "DATA_test_main_db_tbl_tbl_test_record__00000000000000000001") + require.NoError(t, err) + require.Equal(t, "DATA_test_main_db_tbl_tbl_test_record__00000000000000000001", string(it.Key())) for _, row := range rows { - txn.Delete([]byte(row.key)) + require.NoError(t, txn.Delete([]byte(row.key))) } } -func (s *testKVSuite) TestConditionIfNotExist(c *C) { +func TestConditionIfNotExist(t *testing.T) { + t.Parallel() + + store, err := mockstore.NewMockStore() + require.NoError(t, err) + defer func() { + require.NoError(t, store.Close()) + }() + var success int64 cnt := 100 b := []byte("1") @@ -456,8 +526,8 @@ func (s *testKVSuite) TestConditionIfNotExist(c *C) { for i := 0; i < cnt; i++ { go func() { defer wg.Done() - txn, err := s.s.Begin() - c.Assert(err, IsNil) + txn, err := store.Begin() + require.NoError(t, err) err = txn.Set(b, b) if err != nil { return @@ -470,38 +540,46 @@ func (s *testKVSuite) TestConditionIfNotExist(c *C) { } wg.Wait() // At least one txn can success. 
- c.Assert(success, Greater, int64(0)) + require.Greater(t, success, int64(0)) // Clean up - txn, err := s.s.Begin() - c.Assert(err, IsNil) + txn, err := store.Begin() + require.NoError(t, err) err = txn.Delete(b) - c.Assert(err, IsNil) + require.NoError(t, err) err = txn.Commit(context.Background()) - c.Assert(err, IsNil) + require.NoError(t, err) } -func (s *testKVSuite) TestConditionIfEqual(c *C) { +func TestConditionIfEqual(t *testing.T) { + t.Parallel() + + store, err := mockstore.NewMockStore() + require.NoError(t, err) + defer func() { + require.NoError(t, store.Close()) + }() + var success int64 cnt := 100 b := []byte("1") var wg sync.WaitGroup wg.Add(cnt) - txn, err := s.s.Begin() - c.Assert(err, IsNil) - txn.Set(b, b) + txn, err := store.Begin() + require.NoError(t, err) + require.NoError(t, txn.Set(b, b)) err = txn.Commit(context.Background()) - c.Assert(err, IsNil) + require.NoError(t, err) for i := 0; i < cnt; i++ { go func() { defer wg.Done() // Use txn1/err1 instead of txn/err is // to pass `go tool vet -shadow` check. 
- txn1, err1 := s.s.Begin() - c.Assert(err1, IsNil) - txn1.Set(b, []byte("newValue")) + txn1, err1 := store.Begin() + require.NoError(t, err1) + require.NoError(t, txn1.Set(b, []byte("newValue"))) err1 = txn1.Commit(context.Background()) if err1 == nil { atomic.AddInt64(&success, 1) @@ -509,68 +587,86 @@ func (s *testKVSuite) TestConditionIfEqual(c *C) { }() } wg.Wait() - c.Assert(success, Greater, int64(0)) + require.Greater(t, success, int64(0)) // Clean up - txn, err = s.s.Begin() - c.Assert(err, IsNil) + txn, err = store.Begin() + require.NoError(t, err) err = txn.Delete(b) - c.Assert(err, IsNil) + require.NoError(t, err) err = txn.Commit(context.Background()) - c.Assert(err, IsNil) + require.NoError(t, err) } -func (s *testKVSuite) TestConditionUpdate(c *C) { - txn, err := s.s.Begin() - c.Assert(err, IsNil) - txn.Delete([]byte("b")) - kv.IncInt64(txn, []byte("a"), 1) +func TestConditionUpdate(t *testing.T) { + t.Parallel() + + store, err := mockstore.NewMockStore() + require.NoError(t, err) + defer func() { + require.NoError(t, store.Close()) + }() + + txn, err := store.Begin() + require.NoError(t, err) + require.NoError(t, txn.Delete([]byte("b"))) + _, err = kv.IncInt64(txn, []byte("a"), 1) + require.NoError(t, err) err = txn.Commit(context.Background()) - c.Assert(err, IsNil) + require.NoError(t, err) } -func (s *testKVSuite) TestDBClose(c *C) { - c.Skip("don't know why it fails.") +func TestDBClose(t *testing.T) { + t.Skip("don't know why it fails.") + store, err := mockstore.NewMockStore() - c.Assert(err, IsNil) + require.NoError(t, err) txn, err := store.Begin() - c.Assert(err, IsNil) + require.NoError(t, err) err = txn.Set([]byte("a"), []byte("b")) - c.Assert(err, IsNil) + require.NoError(t, err) err = txn.Commit(context.Background()) - c.Assert(err, IsNil) + require.NoError(t, err) ver, err := store.CurrentVersion(kv.GlobalTxnScope) - c.Assert(err, IsNil) - c.Assert(kv.MaxVersion.Cmp(ver), Equals, 1) + require.NoError(t, err) + require.Equal(t, 1, 
kv.MaxVersion.Cmp(ver)) snap := store.GetSnapshot(kv.MaxVersion) _, err = snap.Get(context.TODO(), []byte("a")) - c.Assert(err, IsNil) + require.NoError(t, err) txn, err = store.Begin() - c.Assert(err, IsNil) + require.NoError(t, err) err = store.Close() - c.Assert(err, IsNil) + require.NoError(t, err) _, err = store.Begin() - c.Assert(err, NotNil) + require.Error(t, err) _ = store.GetSnapshot(kv.MaxVersion) err = txn.Set([]byte("a"), []byte("b")) - c.Assert(err, IsNil) + require.NoError(t, err) err = txn.Commit(context.Background()) - c.Assert(err, NotNil) + require.Error(t, err) } -func (s *testKVSuite) TestIsolationInc(c *C) { +func TestIsolationInc(t *testing.T) { + t.Parallel() + + store, err := mockstore.NewMockStore() + require.NoError(t, err) + defer func() { + require.NoError(t, store.Close()) + }() + threadCnt := 4 ids := make(map[int64]struct{}, threadCnt*100) @@ -583,18 +679,18 @@ func (s *testKVSuite) TestIsolationInc(c *C) { defer wg.Done() for j := 0; j < 100; j++ { var id int64 - err := kv.RunInNewTxn(context.Background(), s.s, true, func(ctx context.Context, txn kv.Transaction) error { + err := kv.RunInNewTxn(context.Background(), store, true, func(ctx context.Context, txn kv.Transaction) error { var err1 error id, err1 = kv.IncInt64(txn, []byte("key"), 1) return err1 }) - c.Assert(err, IsNil) + require.NoError(t, err) m.Lock() _, ok := ids[id] ids[id] = struct{}{} m.Unlock() - c.Assert(ok, IsFalse) + require.False(t, ok) } }() } @@ -602,13 +698,23 @@ func (s *testKVSuite) TestIsolationInc(c *C) { wg.Wait() // delete - txn, err := s.s.Begin() - c.Assert(err, IsNil) - defer txn.Commit(context.Background()) - txn.Delete([]byte("key")) + txn, err := store.Begin() + require.NoError(t, err) + defer func() { + require.NoError(t, txn.Commit(context.Background())) + }() + require.NoError(t, txn.Delete([]byte("key"))) } -func (s *testKVSuite) TestIsolationMultiInc(c *C) { +func TestIsolationMultiInc(t *testing.T) { + t.Parallel() + + store, err := 
mockstore.NewMockStore() + require.NoError(t, err) + defer func() { + require.NoError(t, store.Close()) + }() + threadCnt := 4 incCnt := 100 keyCnt := 4 @@ -625,7 +731,7 @@ func (s *testKVSuite) TestIsolationMultiInc(c *C) { go func() { defer wg.Done() for j := 0; j < incCnt; j++ { - err := kv.RunInNewTxn(context.Background(), s.s, true, func(ctx context.Context, txn kv.Transaction) error { + err := kv.RunInNewTxn(context.Background(), store, true, func(ctx context.Context, txn kv.Transaction) error { for _, key := range keys { _, err1 := kv.IncInt64(txn, key, 1) if err1 != nil { @@ -635,51 +741,58 @@ func (s *testKVSuite) TestIsolationMultiInc(c *C) { return nil }) - c.Assert(err, IsNil) + require.NoError(t, err) } }() } wg.Wait() - err := kv.RunInNewTxn(context.Background(), s.s, false, func(ctx context.Context, txn kv.Transaction) error { + err = kv.RunInNewTxn(context.Background(), store, false, func(ctx context.Context, txn kv.Transaction) error { for _, key := range keys { id, err1 := kv.GetInt64(context.TODO(), txn, key) if err1 != nil { return err1 } - c.Assert(id, Equals, int64(threadCnt*incCnt)) - txn.Delete(key) + require.Equal(t, int64(threadCnt*incCnt), id) + require.NoError(t, txn.Delete(key)) } return nil }) - c.Assert(err, IsNil) + require.NoError(t, err) } -func (s *testKVSuite) TestRetryOpenStore(c *C) { +func TestRetryOpenStore(t *testing.T) { + t.Parallel() begin := time.Now() - Register("dummy", &brokenStore{}) + require.NoError(t, Register("dummy", &brokenStore{})) store, err := newStoreWithRetry("dummy://dummy-store", 3) if store != nil { - defer store.Close() + defer func() { + require.NoError(t, store.Close()) + }() } - c.Assert(err, NotNil) + require.Error(t, err) elapse := time.Since(begin) - c.Assert(uint64(elapse), GreaterEqual, uint64(3*time.Second), Commentf("elapse: %s", elapse)) + require.GreaterOrEqual(t, uint64(elapse), uint64(3*time.Second)) } -func (s *testKVSuite) TestOpenStore(c *C) { - Register("open", &brokenStore{}) +func 
TestOpenStore(t *testing.T) { + t.Parallel() + require.NoError(t, Register("open", &brokenStore{})) store, err := newStoreWithRetry(":", 3) if store != nil { - defer store.Close() + defer func() { + require.NoError(t, store.Close()) + }() } - c.Assert(err, NotNil) + require.Error(t, err) } -func (s *testKVSuite) TestRegister(c *C) { +func TestRegister(t *testing.T) { + t.Parallel() err := Register("retry", &brokenStore{}) - c.Assert(err, IsNil) + require.NoError(t, err) err = Register("retry", &brokenStore{}) - c.Assert(err, NotNil) + require.Error(t, err) } diff --git a/testkit/dbtestkit.go b/testkit/dbtestkit.go new file mode 100644 index 0000000000000..b0039098a48ae --- /dev/null +++ b/testkit/dbtestkit.go @@ -0,0 +1,60 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build !codes + +package testkit + +import ( + "database/sql" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// DBTestKit is a utility to run sql with a db connection. +type DBTestKit struct { + require *require.Assertions + assert *assert.Assertions + db *sql.DB +} + +// NewDBTestKit returns a new *DBTestKit. +func NewDBTestKit(t *testing.T, db *sql.DB) *DBTestKit { + return &DBTestKit{ + require: require.New(t), + assert: assert.New(t), + db: db, + } +} + +// MustExec query the statements and returns the result. 
+func (tk *DBTestKit) MustExec(sql string, args ...interface{}) sql.Result { + comment := fmt.Sprintf("sql:%s, args:%v", sql, args) + rs, err := tk.db.Exec(sql, args...) + tk.require.NoError(err, comment) + tk.require.NotNil(rs, comment) + return rs +} + +// MustQuery query the statements and returns result rows. +func (tk *DBTestKit) MustQuery(sql string, args ...interface{}) *sql.Rows { + comment := fmt.Sprintf("sql:%s, args:%v", sql, args) + rows, err := tk.db.Query(sql, args...) + tk.require.NoError(err, comment) + tk.require.NotNil(rows, comment) + return rows +} diff --git a/tools/check/check-timeout.go b/tools/check/check-timeout.go index c1857f3563b5f..5c1b9a73fd377 100644 --- a/tools/check/check-timeout.go +++ b/tools/check/check-timeout.go @@ -87,9 +87,9 @@ func init() { "testPessimisticSuite.TestPessimisticLockNonExistsKey", "testPessimisticSuite.TestSelectForUpdateNoWait", "testSessionSerialSuite.TestProcessInfoIssue22068", - "testKVSuite.TestRetryOpenStore", - "testBatchCopSuite.TestStoreErr", - "testBatchCopSuite.TestStoreSwitchPeer", + "TestRetryOpenStore", + "TestStoreErr", + "TestStoreSwitchPeer", "testSequenceSuite.TestSequenceFunction", "testSuiteP2.TestUnion", "testVectorizeSuite1.TestVectorizedBuiltinTimeFuncGenerated", diff --git a/types/binary_literal_test.go b/types/binary_literal_test.go index 2e50e48d38185..81b9fde6767a5 100644 --- a/types/binary_literal_test.go +++ b/types/binary_literal_test.go @@ -15,271 +15,280 @@ package types import ( - . 
"github.com/pingcap/check" + "fmt" + "testing" + "github.com/pingcap/tidb/sessionctx/stmtctx" - "github.com/pingcap/tidb/util/testleak" + "github.com/stretchr/testify/require" ) -var _ = Suite(&testBinaryLiteralSuite{}) +func TestBinaryLiteral(t *testing.T) { + t.Run("TestTrimLeadingZeroBytes", func(t *testing.T) { + t.Parallel() + tbl := []struct { + Input []byte + Expected []byte + }{ + {[]byte{}, []byte{}}, + {[]byte{0x0}, []byte{0x0}}, + {[]byte{0x1}, []byte{0x1}}, + {[]byte{0x1, 0x0}, []byte{0x1, 0x0}}, + {[]byte{0x0, 0x1}, []byte{0x1}}, + {[]byte{0x0, 0x0, 0x0}, []byte{0x0}}, + {[]byte{0x1, 0x0, 0x0}, []byte{0x1, 0x0, 0x0}}, + {[]byte{0x0, 0x1, 0x0, 0x0, 0x1, 0x0, 0x0}, []byte{0x1, 0x0, 0x0, 0x1, 0x0, 0x0}}, + {[]byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x1, 0x0, 0x0}, []byte{0x1, 0x0, 0x0, 0x1, 0x0, 0x0}}, + } + for _, item := range tbl { + b := trimLeadingZeroBytes(item.Input) + require.Equal(t, item.Expected, b, fmt.Sprintf("%#v", item)) + } + }) -type testBinaryLiteralSuite struct { -} + t.Run("TestParseBitStr", func(t *testing.T) { + t.Parallel() + tbl := []struct { + Input string + Expected []byte + IsError bool + }{ + {"b''", []byte{}, false}, + {"B''", []byte{}, false}, + {"0b''", nil, true}, + {"0b0", []byte{0x0}, false}, + {"b'0'", []byte{0x0}, false}, + {"B'0'", []byte{0x0}, false}, + {"0B0", nil, true}, + {"0b123", nil, true}, + {"b'123'", nil, true}, + {"0b'1010'", nil, true}, + {"0b0000000", []byte{0x0}, false}, + {"b'0000000'", []byte{0x0}, false}, + {"B'0000000'", []byte{0x0}, false}, + {"0b00000000", []byte{0x0}, false}, + {"b'00000000'", []byte{0x0}, false}, + {"B'00000000'", []byte{0x0}, false}, + {"0b000000000", []byte{0x0, 0x0}, false}, + {"b'000000000'", []byte{0x0, 0x0}, false}, + {"B'000000000'", []byte{0x0, 0x0}, false}, + {"0b1", []byte{0x1}, false}, + {"b'1'", []byte{0x1}, false}, + {"B'1'", []byte{0x1}, false}, + {"0b00000001", []byte{0x1}, false}, + {"b'00000001'", []byte{0x1}, false}, + 
{"B'00000001'", []byte{0x1}, false}, + {"0b000000010", []byte{0x0, 0x2}, false}, + {"b'000000010'", []byte{0x0, 0x2}, false}, + {"B'000000010'", []byte{0x0, 0x2}, false}, + {"0b000000001", []byte{0x0, 0x1}, false}, + {"b'000000001'", []byte{0x0, 0x1}, false}, + {"B'000000001'", []byte{0x0, 0x1}, false}, + {"0b11111111", []byte{0xFF}, false}, + {"b'11111111'", []byte{0xFF}, false}, + {"B'11111111'", []byte{0xFF}, false}, + {"0b111111111", []byte{0x1, 0xFF}, false}, + {"b'111111111'", []byte{0x1, 0xFF}, false}, + {"B'111111111'", []byte{0x1, 0xFF}, false}, + {"0b1101000011001010110110001101100011011110010000001110111011011110111001001101100011001000010000001100110011011110110111100100000011000100110000101110010", []byte("hello world foo bar"), false}, + {"b'1101000011001010110110001101100011011110010000001110111011011110111001001101100011001000010000001100110011011110110111100100000011000100110000101110010'", []byte("hello world foo bar"), false}, + {"B'1101000011001010110110001101100011011110010000001110111011011110111001001101100011001000010000001100110011011110110111100100000011000100110000101110010'", []byte("hello world foo bar"), false}, + {"0b01101000011001010110110001101100011011110010000001110111011011110111001001101100011001000010000001100110011011110110111100100000011000100110000101110010", []byte("hello world foo bar"), false}, + {"b'01101000011001010110110001101100011011110010000001110111011011110111001001101100011001000010000001100110011011110110111100100000011000100110000101110010'", []byte("hello world foo bar"), false}, + {"B'01101000011001010110110001101100011011110010000001110111011011110111001001101100011001000010000001100110011011110110111100100000011000100110000101110010'", []byte("hello world foo bar"), false}, + } + for _, item := range tbl { + b, err := ParseBitStr(item.Input) + if item.IsError { + require.Error(t, err, fmt.Sprintf("%#v", item)) + } else { + require.NoError(t, err, fmt.Sprintf("%#v", item)) + require.Equal(t, item.Expected, 
[]byte(b), fmt.Sprintf("%#v", item)) + } + } + }) -func (s *testBinaryLiteralSuite) TestTrimLeadingZeroBytes(c *C) { - defer testleak.AfterTest(c)() - tbl := []struct { - Input []byte - Expected []byte - }{ - {[]byte{}, []byte{}}, - {[]byte{0x0}, []byte{0x0}}, - {[]byte{0x1}, []byte{0x1}}, - {[]byte{0x1, 0x0}, []byte{0x1, 0x0}}, - {[]byte{0x0, 0x1}, []byte{0x1}}, - {[]byte{0x0, 0x0, 0x0}, []byte{0x0}}, - {[]byte{0x1, 0x0, 0x0}, []byte{0x1, 0x0, 0x0}}, - {[]byte{0x0, 0x1, 0x0, 0x0, 0x1, 0x0, 0x0}, []byte{0x1, 0x0, 0x0, 0x1, 0x0, 0x0}}, - {[]byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x1, 0x0, 0x0}, []byte{0x1, 0x0, 0x0, 0x1, 0x0, 0x0}}, - } - for _, t := range tbl { - b := trimLeadingZeroBytes(t.Input) - c.Assert(b, DeepEquals, t.Expected, Commentf("%#v", t)) - } -} + t.Run("TestParseBitStr", func(t *testing.T) { + t.Parallel() + b, err := ParseBitStr("") + require.Nil(t, b) + require.Contains(t, err.Error(), "invalid empty ") + }) -func (s *testBinaryLiteralSuite) TestParseBitStr(c *C) { - defer testleak.AfterTest(c)() - tbl := []struct { - Input string - Expected []byte - IsError bool - }{ - {"b''", []byte{}, false}, - {"B''", []byte{}, false}, - {"0b''", nil, true}, - {"0b0", []byte{0x0}, false}, - {"b'0'", []byte{0x0}, false}, - {"B'0'", []byte{0x0}, false}, - {"0B0", nil, true}, - {"0b123", nil, true}, - {"b'123'", nil, true}, - {"0b'1010'", nil, true}, - {"0b0000000", []byte{0x0}, false}, - {"b'0000000'", []byte{0x0}, false}, - {"B'0000000'", []byte{0x0}, false}, - {"0b00000000", []byte{0x0}, false}, - {"b'00000000'", []byte{0x0}, false}, - {"B'00000000'", []byte{0x0}, false}, - {"0b000000000", []byte{0x0, 0x0}, false}, - {"b'000000000'", []byte{0x0, 0x0}, false}, - {"B'000000000'", []byte{0x0, 0x0}, false}, - {"0b1", []byte{0x1}, false}, - {"b'1'", []byte{0x1}, false}, - {"B'1'", []byte{0x1}, false}, - {"0b00000001", []byte{0x1}, false}, - {"b'00000001'", []byte{0x1}, false}, - {"B'00000001'", []byte{0x1}, false}, - {"0b000000010", []byte{0x0, 0x2}, false}, 
- {"b'000000010'", []byte{0x0, 0x2}, false}, - {"B'000000010'", []byte{0x0, 0x2}, false}, - {"0b000000001", []byte{0x0, 0x1}, false}, - {"b'000000001'", []byte{0x0, 0x1}, false}, - {"B'000000001'", []byte{0x0, 0x1}, false}, - {"0b11111111", []byte{0xFF}, false}, - {"b'11111111'", []byte{0xFF}, false}, - {"B'11111111'", []byte{0xFF}, false}, - {"0b111111111", []byte{0x1, 0xFF}, false}, - {"b'111111111'", []byte{0x1, 0xFF}, false}, - {"B'111111111'", []byte{0x1, 0xFF}, false}, - {"0b1101000011001010110110001101100011011110010000001110111011011110111001001101100011001000010000001100110011011110110111100100000011000100110000101110010", []byte("hello world foo bar"), false}, - {"b'1101000011001010110110001101100011011110010000001110111011011110111001001101100011001000010000001100110011011110110111100100000011000100110000101110010'", []byte("hello world foo bar"), false}, - {"B'1101000011001010110110001101100011011110010000001110111011011110111001001101100011001000010000001100110011011110110111100100000011000100110000101110010'", []byte("hello world foo bar"), false}, - {"0b01101000011001010110110001101100011011110010000001110111011011110111001001101100011001000010000001100110011011110110111100100000011000100110000101110010", []byte("hello world foo bar"), false}, - {"b'01101000011001010110110001101100011011110010000001110111011011110111001001101100011001000010000001100110011011110110111100100000011000100110000101110010'", []byte("hello world foo bar"), false}, - {"B'01101000011001010110110001101100011011110010000001110111011011110111001001101100011001000010000001100110011011110110111100100000011000100110000101110010'", []byte("hello world foo bar"), false}, - } - for _, t := range tbl { - b, err := ParseBitStr(t.Input) - if t.IsError { - c.Assert(err, NotNil, Commentf("%#v", t)) - } else { - c.Assert(err, IsNil, Commentf("%#v", t)) - c.Assert([]byte(b), DeepEquals, t.Expected, Commentf("%#v", t)) + t.Run("TestParseHexStr", func(t *testing.T) { + t.Parallel() + tbl := 
[]struct { + Input string + Expected []byte + IsError bool + }{ + {"x'1'", nil, true}, + {"x'01'", []byte{0x1}, false}, + {"X'01'", []byte{0x1}, false}, + {"0x1", []byte{0x1}, false}, + {"0x-1", nil, true}, + {"0X11", nil, true}, + {"x'01+'", nil, true}, + {"0x123", []byte{0x01, 0x23}, false}, + {"0x10", []byte{0x10}, false}, + {"0x4D7953514C", []byte("MySQL"), false}, + {"0x4920616D2061206C6F6E672068657820737472696E67", []byte("I am a long hex string"), false}, + {"x'4920616D2061206C6F6E672068657820737472696E67'", []byte("I am a long hex string"), false}, + {"X'4920616D2061206C6F6E672068657820737472696E67'", []byte("I am a long hex string"), false}, + {"x''", []byte{}, false}, } - } - b, err := ParseBitStr("") - c.Assert(b, IsNil) - c.Assert(err, ErrorMatches, "invalid empty .*") -} - -func (s *testBinaryLiteralSuite) TestParseHexStr(c *C) { - defer testleak.AfterTest(c)() - tbl := []struct { - Input string - Expected []byte - IsError bool - }{ - {"x'1'", nil, true}, - {"x'01'", []byte{0x1}, false}, - {"X'01'", []byte{0x1}, false}, - {"0x1", []byte{0x1}, false}, - {"0x-1", nil, true}, - {"0X11", nil, true}, - {"x'01+'", nil, true}, - {"0x123", []byte{0x01, 0x23}, false}, - {"0x10", []byte{0x10}, false}, - {"0x4D7953514C", []byte("MySQL"), false}, - {"0x4920616D2061206C6F6E672068657820737472696E67", []byte("I am a long hex string"), false}, - {"x'4920616D2061206C6F6E672068657820737472696E67'", []byte("I am a long hex string"), false}, - {"X'4920616D2061206C6F6E672068657820737472696E67'", []byte("I am a long hex string"), false}, - {"x''", []byte{}, false}, - } - for _, t := range tbl { - hex, err := ParseHexStr(t.Input) - if t.IsError { - c.Assert(err, NotNil, Commentf("%#v", t)) - } else { - c.Assert(err, IsNil, Commentf("%#v", t)) - c.Assert([]byte(hex), DeepEquals, t.Expected, Commentf("%#v", t)) + for _, item := range tbl { + hex, err := ParseHexStr(item.Input) + if item.IsError { + require.Error(t, err, fmt.Sprintf("%#v", item)) + } else { + require.NoError(t, 
err, fmt.Sprintf("%#v", item)) + require.Equal(t, item.Expected, []byte(hex), fmt.Sprintf("%#v", item)) + } } - } - hex, err := ParseHexStr("") - c.Assert(hex, IsNil) - c.Assert(err, ErrorMatches, "invalid empty .*") -} + }) -func (s *testBinaryLiteralSuite) TestString(c *C) { - defer testleak.AfterTest(c)() - tbl := []struct { - Input BinaryLiteral - Expected string - }{ - {BinaryLiteral{}, ""}, // Expected - {BinaryLiteral{0x0}, "0x00"}, - {BinaryLiteral{0x1}, "0x01"}, - {BinaryLiteral{0xff, 0x01}, "0xff01"}, - } - for _, t := range tbl { - str := t.Input.String() - c.Assert(str, Equals, t.Expected) - } -} + t.Run("TestParseHexStr", func(t *testing.T) { + t.Parallel() + hex, err := ParseHexStr("") + require.Nil(t, hex) + require.Contains(t, err.Error(), "invalid empty ") + }) -func (s *testBinaryLiteralSuite) TestToBitLiteralString(c *C) { - defer testleak.AfterTest(c)() - tbl := []struct { - Input BinaryLiteral - TrimLeadingZero bool - Expected string - }{ - {BinaryLiteral{}, true, "b''"}, - {BinaryLiteral{}, false, "b''"}, - {BinaryLiteral{0x0}, true, "b'0'"}, - {BinaryLiteral{0x0}, false, "b'00000000'"}, - {BinaryLiteral{0x0, 0x0}, true, "b'0'"}, - {BinaryLiteral{0x0, 0x0}, false, "b'0000000000000000'"}, - {BinaryLiteral{0x1}, true, "b'1'"}, - {BinaryLiteral{0x1}, false, "b'00000001'"}, - {BinaryLiteral{0xff, 0x01}, true, "b'1111111100000001'"}, - {BinaryLiteral{0xff, 0x01}, false, "b'1111111100000001'"}, - {BinaryLiteral{0x0, 0xff, 0x01}, true, "b'1111111100000001'"}, - {BinaryLiteral{0x0, 0xff, 0x01}, false, "b'000000001111111100000001'"}, - } - for _, t := range tbl { - str := t.Input.ToBitLiteralString(t.TrimLeadingZero) - c.Assert(str, Equals, t.Expected) - } -} + t.Run("TestString", func(t *testing.T) { + t.Parallel() + tbl := []struct { + Input BinaryLiteral + Expected string + }{ + {BinaryLiteral{}, ""}, // Expected + {BinaryLiteral{0x0}, "0x00"}, + {BinaryLiteral{0x1}, "0x01"}, + {BinaryLiteral{0xff, 0x01}, "0xff01"}, + } + for _, item := range tbl { + 
str := item.Input.String() + require.Equal(t, str, item.Expected) + } + }) -func (s *testBinaryLiteralSuite) TestToInt(c *C) { - defer testleak.AfterTest(c)() - tbl := []struct { - Input string - Expected uint64 - HasError bool - }{ - {"x''", 0, false}, - {"0x00", 0x0, false}, - {"0xff", 0xff, false}, - {"0x10ff", 0x10ff, false}, - {"0x1010ffff", 0x1010ffff, false}, - {"0x1010ffff8080", 0x1010ffff8080, false}, - {"0x1010ffff8080ff12", 0x1010ffff8080ff12, false}, - {"0x1010ffff8080ff12ff", 0xffffffffffffffff, true}, - } - sc := new(stmtctx.StatementContext) - for _, t := range tbl { - hex, err := ParseHexStr(t.Input) - c.Assert(err, IsNil) - intValue, err := hex.ToInt(sc) - if t.HasError { - c.Assert(err, NotNil) - } else { - c.Assert(err, IsNil) + t.Run("TestToBitLiteralString", func(t *testing.T) { + t.Parallel() + tbl := []struct { + Input BinaryLiteral + TrimLeadingZero bool + Expected string + }{ + {BinaryLiteral{}, true, "b''"}, + {BinaryLiteral{}, false, "b''"}, + {BinaryLiteral{0x0}, true, "b'0'"}, + {BinaryLiteral{0x0}, false, "b'00000000'"}, + {BinaryLiteral{0x0, 0x0}, true, "b'0'"}, + {BinaryLiteral{0x0, 0x0}, false, "b'0000000000000000'"}, + {BinaryLiteral{0x1}, true, "b'1'"}, + {BinaryLiteral{0x1}, false, "b'00000001'"}, + {BinaryLiteral{0xff, 0x01}, true, "b'1111111100000001'"}, + {BinaryLiteral{0xff, 0x01}, false, "b'1111111100000001'"}, + {BinaryLiteral{0x0, 0xff, 0x01}, true, "b'1111111100000001'"}, + {BinaryLiteral{0x0, 0xff, 0x01}, false, "b'000000001111111100000001'"}, } - c.Assert(intValue, Equals, t.Expected) - } -} + for _, item := range tbl { + str := item.Input.ToBitLiteralString(item.TrimLeadingZero) + require.Equal(t, item.Expected, str) + } + }) -func (s *testBinaryLiteralSuite) TestNewBinaryLiteralFromUint(c *C) { - defer testleak.AfterTest(c)() - tbl := []struct { - Input uint64 - ByteSize int - Expected []byte - }{ - {0x0, -1, []byte{0x0}}, - {0x0, 1, []byte{0x0}}, - {0x0, 2, []byte{0x0, 0x0}}, - {0x1, -1, []byte{0x1}}, - {0x1, 1, 
[]byte{0x1}}, - {0x1, 2, []byte{0x0, 0x1}}, - {0x1, 3, []byte{0x0, 0x0, 0x1}}, - {0x10, -1, []byte{0x10}}, - {0x123, -1, []byte{0x1, 0x23}}, - {0x123, 2, []byte{0x1, 0x23}}, - {0x123, 1, []byte{0x23}}, - {0x123, 5, []byte{0x0, 0x0, 0x0, 0x1, 0x23}}, - {0x4D7953514C, -1, []byte{0x4D, 0x79, 0x53, 0x51, 0x4C}}, - {0x4D7953514C, 8, []byte{0x0, 0x0, 0x0, 0x4D, 0x79, 0x53, 0x51, 0x4C}}, - {0x4920616D2061206C, -1, []byte{0x49, 0x20, 0x61, 0x6D, 0x20, 0x61, 0x20, 0x6C}}, - {0x4920616D2061206C, 8, []byte{0x49, 0x20, 0x61, 0x6D, 0x20, 0x61, 0x20, 0x6C}}, - {0x4920616D2061206C, 5, []byte{0x6D, 0x20, 0x61, 0x20, 0x6C}}, - } - for _, t := range tbl { - hex := NewBinaryLiteralFromUint(t.Input, t.ByteSize) - c.Assert([]byte(hex), DeepEquals, t.Expected, Commentf("%#v", t)) - } + t.Run("TestToInt", func(t *testing.T) { + t.Parallel() + tbl := []struct { + Input string + Expected uint64 + HasError bool + }{ + {"x''", 0, false}, + {"0x00", 0x0, false}, + {"0xff", 0xff, false}, + {"0x10ff", 0x10ff, false}, + {"0x1010ffff", 0x1010ffff, false}, + {"0x1010ffff8080", 0x1010ffff8080, false}, + {"0x1010ffff8080ff12", 0x1010ffff8080ff12, false}, + {"0x1010ffff8080ff12ff", 0xffffffffffffffff, true}, + } + sc := new(stmtctx.StatementContext) + for _, item := range tbl { + hex, err := ParseHexStr(item.Input) + require.NoError(t, err) + intValue, err := hex.ToInt(sc) + if item.HasError { + require.Error(t, err) + } else { + require.NoError(t, err) + } + require.Equal(t, item.Expected, intValue) + } + }) - defer func() { - r := recover() - c.Assert(r, NotNil) - }() - NewBinaryLiteralFromUint(0x123, -2) -} + t.Run("TestNewBinaryLiteralFromUint", func(t *testing.T) { + t.Parallel() + tbl := []struct { + Input uint64 + ByteSize int + Expected []byte + }{ + {0x0, -1, []byte{0x0}}, + {0x0, 1, []byte{0x0}}, + {0x0, 2, []byte{0x0, 0x0}}, + {0x1, -1, []byte{0x1}}, + {0x1, 1, []byte{0x1}}, + {0x1, 2, []byte{0x0, 0x1}}, + {0x1, 3, []byte{0x0, 0x0, 0x1}}, + {0x10, -1, []byte{0x10}}, + {0x123, -1, 
[]byte{0x1, 0x23}}, + {0x123, 2, []byte{0x1, 0x23}}, + {0x123, 1, []byte{0x23}}, + {0x123, 5, []byte{0x0, 0x0, 0x0, 0x1, 0x23}}, + {0x4D7953514C, -1, []byte{0x4D, 0x79, 0x53, 0x51, 0x4C}}, + {0x4D7953514C, 8, []byte{0x0, 0x0, 0x0, 0x4D, 0x79, 0x53, 0x51, 0x4C}}, + {0x4920616D2061206C, -1, []byte{0x49, 0x20, 0x61, 0x6D, 0x20, 0x61, 0x20, 0x6C}}, + {0x4920616D2061206C, 8, []byte{0x49, 0x20, 0x61, 0x6D, 0x20, 0x61, 0x20, 0x6C}}, + {0x4920616D2061206C, 5, []byte{0x6D, 0x20, 0x61, 0x20, 0x6C}}, + } + for _, item := range tbl { + hex := NewBinaryLiteralFromUint(item.Input, item.ByteSize) + require.Equal(t, item.Expected, []byte(hex), fmt.Sprintf("%#v", item)) + } -func (s *testBinaryLiteralSuite) TestCompare(c *C) { - tbl := []struct { - a BinaryLiteral - b BinaryLiteral - cmp int - }{ - {BinaryLiteral{0, 0, 1}, BinaryLiteral{2}, -1}, - {BinaryLiteral{0, 1}, BinaryLiteral{0, 0, 2}, -1}, - {BinaryLiteral{0, 1}, BinaryLiteral{1}, 0}, - {BinaryLiteral{0, 2, 1}, BinaryLiteral{1, 2}, 1}, - } - for _, t := range tbl { - c.Assert(t.a.Compare(t.b), Equals, t.cmp) - } -} + defer func() { + r := recover() + require.NotNil(t, r) + }() + NewBinaryLiteralFromUint(0x123, -2) + }) + + t.Run("TestCompare", func(t *testing.T) { + t.Parallel() + tbl := []struct { + a BinaryLiteral + b BinaryLiteral + cmp int + }{ + {BinaryLiteral{0, 0, 1}, BinaryLiteral{2}, -1}, + {BinaryLiteral{0, 1}, BinaryLiteral{0, 0, 2}, -1}, + {BinaryLiteral{0, 1}, BinaryLiteral{1}, 0}, + {BinaryLiteral{0, 2, 1}, BinaryLiteral{1, 2}, 1}, + } + for _, item := range tbl { + require.Equal(t, item.cmp, item.a.Compare(item.b)) + } + }) -func (s *testBinaryLiteralSuite) TestToString(c *C) { - h, _ := NewHexLiteral("x'3A3B'") - str := h.ToString() - c.Assert(str, Equals, ":;") + t.Run("TestToString", func(t *testing.T) { + t.Parallel() + h, _ := NewHexLiteral("x'3A3B'") + str := h.ToString() + require.Equal(t, str, ":;") - b, _ := NewBitLiteral("b'00101011'") - str = b.ToString() - c.Assert(str, Equals, "+") + b, _ := 
NewBitLiteral("b'00101011'") + str = b.ToString() + require.Equal(t, "+", str) + }) } diff --git a/types/errors_test.go b/types/errors_test.go index 93a46618e4f94..a9da8300ee7e7 100644 --- a/types/errors_test.go +++ b/types/errors_test.go @@ -15,16 +15,15 @@ package types import ( - . "github.com/pingcap/check" - "github.com/pingcap/parser/mysql" + "testing" + "github.com/pingcap/parser/terror" + "github.com/stretchr/testify/require" ) -type testErrorSuite struct{} - -var _ = Suite(testErrorSuite{}) +func TestError(t *testing.T) { + t.Parallel() -func (s testErrorSuite) TestError(c *C) { kvErrs := []*terror.Error{ ErrInvalidDefault, ErrDataTooLong, @@ -50,8 +49,9 @@ func (s testErrorSuite) TestError(c *C) { ErrInvalidWeekModeFormat, ErrWrongValue, } + for _, err := range kvErrs { code := terror.ToSQLError(err).Code - c.Assert(code != mysql.ErrUnknown && code == uint16(err.Code()), IsTrue, Commentf("err: %v", err)) + require.Equalf(t, code, uint16(err.Code()), "err: %v", err) } } diff --git a/types/overflow_test.go b/types/overflow_test.go index dadb4848a57b8..f203d145d46fa 100644 --- a/types/overflow_test.go +++ b/types/overflow_test.go @@ -16,19 +16,15 @@ package types import ( "math" + "testing" "time" - . 
"github.com/pingcap/check" - "github.com/pingcap/tidb/util/testleak" + "github.com/stretchr/testify/require" ) -var _ = Suite(&testOverflowSuite{}) +func TestAdd(t *testing.T) { + t.Parallel() -type testOverflowSuite struct { -} - -func (s *testOverflowSuite) TestAdd(c *C) { - defer testleak.AfterTest(c)() tblUint64 := []struct { lsh uint64 rsh uint64 @@ -40,12 +36,12 @@ func (s *testOverflowSuite) TestAdd(c *C) { {1, 1, 2, false}, } - for _, t := range tblUint64 { - ret, err := AddUint64(t.lsh, t.rsh) - if t.overflow { - c.Assert(err, NotNil) + for _, tt := range tblUint64 { + ret, err := AddUint64(tt.lsh, tt.rsh) + if tt.overflow { + require.Error(t, err) } else { - c.Assert(ret, Equals, t.ret) + require.Equal(t, tt.ret, ret) } } @@ -64,18 +60,18 @@ func (s *testOverflowSuite) TestAdd(c *C) { {1, -1, 0, false}, } - for _, t := range tblInt64 { - ret, err := AddInt64(t.lsh, t.rsh) - if t.overflow { - c.Assert(err, NotNil) + for _, tt := range tblInt64 { + ret, err := AddInt64(tt.lsh, tt.rsh) + if tt.overflow { + require.Error(t, err) } else { - c.Assert(ret, Equals, t.ret) + require.Equal(t, tt.ret, ret) } - ret2, err := AddDuration(time.Duration(t.lsh), time.Duration(t.rsh)) - if t.overflow { - c.Assert(err, NotNil) + ret2, err := AddDuration(time.Duration(tt.lsh), time.Duration(tt.rsh)) + if tt.overflow { + require.Error(t, err) } else { - c.Assert(ret2, Equals, time.Duration(t.ret)) + require.Equal(t, time.Duration(tt.ret), ret2) } } @@ -93,18 +89,19 @@ func (s *testOverflowSuite) TestAdd(c *C) { {1, 1, 2, false}, } - for _, t := range tblInt { - ret, err := AddInteger(t.lsh, t.rsh) - if t.overflow { - c.Assert(err, NotNil) + for _, tt := range tblInt { + ret, err := AddInteger(tt.lsh, tt.rsh) + if tt.overflow { + require.Error(t, err) } else { - c.Assert(ret, Equals, t.ret) + require.Equal(t, tt.ret, ret) } } } -func (s *testOverflowSuite) TestSub(c *C) { - defer testleak.AfterTest(c)() +func TestSub(t *testing.T) { + t.Parallel() + 
tblUint64 := []struct { lsh uint64 rsh uint64 @@ -119,12 +116,12 @@ func (s *testOverflowSuite) TestSub(c *C) { {1, 1, 0, false}, } - for _, t := range tblUint64 { - ret, err := SubUint64(t.lsh, t.rsh) - if t.overflow { - c.Assert(err, NotNil) + for _, tt := range tblUint64 { + ret, err := SubUint64(tt.lsh, tt.rsh) + if tt.overflow { + require.Error(t, err) } else { - c.Assert(ret, Equals, t.ret) + require.Equal(t, tt.ret, ret) } } @@ -145,12 +142,12 @@ func (s *testOverflowSuite) TestSub(c *C) { {1, 1, 0, false}, } - for _, t := range tblInt64 { - ret, err := SubInt64(t.lsh, t.rsh) - if t.overflow { - c.Assert(err, NotNil) + for _, tt := range tblInt64 { + ret, err := SubInt64(tt.lsh, tt.rsh) + if tt.overflow { + require.Error(t, err) } else { - c.Assert(ret, Equals, t.ret) + require.Equal(t, tt.ret, ret) } } @@ -169,12 +166,12 @@ func (s *testOverflowSuite) TestSub(c *C) { {1, 1, 0, false}, } - for _, t := range tblInt { - ret, err := SubUintWithInt(t.lsh, t.rsh) - if t.overflow { - c.Assert(err, NotNil) + for _, tt := range tblInt { + ret, err := SubUintWithInt(tt.lsh, tt.rsh) + if tt.overflow { + require.Error(t, err) } else { - c.Assert(ret, Equals, t.ret) + require.Equal(t, tt.ret, ret) } } @@ -192,18 +189,19 @@ func (s *testOverflowSuite) TestSub(c *C) { {1, 1, 0, false}, } - for _, t := range tblInt2 { - ret, err := SubIntWithUint(t.lsh, t.rsh) - if t.overflow { - c.Assert(err, NotNil) + for _, tt := range tblInt2 { + ret, err := SubIntWithUint(tt.lsh, tt.rsh) + if tt.overflow { + require.Error(t, err) } else { - c.Assert(ret, Equals, t.ret) + require.Equal(t, tt.ret, ret) } } } -func (s *testOverflowSuite) TestMul(c *C) { - defer testleak.AfterTest(c)() +func TestMul(t *testing.T) { + t.Parallel() + tblUint64 := []struct { lsh uint64 rsh uint64 @@ -216,12 +214,12 @@ func (s *testOverflowSuite) TestMul(c *C) { {1, 1, 1, false}, } - for _, t := range tblUint64 { - ret, err := MulUint64(t.lsh, t.rsh) - if t.overflow { - c.Assert(err, NotNil) + for _, tt := 
range tblUint64 { + ret, err := MulUint64(tt.lsh, tt.rsh) + if tt.overflow { + require.Error(t, err) } else { - c.Assert(ret, Equals, t.ret) + require.Equal(t, tt.ret, ret) } } @@ -243,12 +241,12 @@ func (s *testOverflowSuite) TestMul(c *C) { {1, 1, 1, false}, } - for _, t := range tblInt64 { - ret, err := MulInt64(t.lsh, t.rsh) - if t.overflow { - c.Assert(err, NotNil) + for _, tt := range tblInt64 { + ret, err := MulInt64(tt.lsh, tt.rsh) + if tt.overflow { + require.Error(t, err) } else { - c.Assert(ret, Equals, t.ret) + require.Equal(t, tt.ret, ret) } } @@ -266,18 +264,19 @@ func (s *testOverflowSuite) TestMul(c *C) { {1, 1, 1, false}, } - for _, t := range tblInt { - ret, err := MulInteger(t.lsh, t.rsh) - if t.overflow { - c.Assert(err, NotNil) + for _, tt := range tblInt { + ret, err := MulInteger(tt.lsh, tt.rsh) + if tt.overflow { + require.Error(t, err) } else { - c.Assert(ret, Equals, t.ret) + require.Equal(t, tt.ret, ret) } } } -func (s *testOverflowSuite) TestDiv(c *C) { - defer testleak.AfterTest(c)() +func TestDiv(t *testing.T) { + t.Parallel() + tblInt64 := []struct { lsh int64 rsh int64 @@ -294,12 +293,12 @@ func (s *testOverflowSuite) TestDiv(c *C) { {math.MinInt64, 2, math.MinInt64 / 2, false}, } - for _, t := range tblInt64 { - ret, err := DivInt64(t.lsh, t.rsh) - if t.overflow { - c.Assert(err, NotNil) + for _, tt := range tblInt64 { + ret, err := DivInt64(tt.lsh, tt.rsh) + if tt.overflow { + require.Error(t, err) } else { - c.Assert(ret, Equals, t.ret) + require.Equal(t, tt.ret, ret) } } @@ -316,12 +315,12 @@ func (s *testOverflowSuite) TestDiv(c *C) { {100, 20, 5, false}, } - for _, t := range tblInt { - ret, err := DivUintWithInt(t.lsh, t.rsh) - if t.overflow { - c.Assert(err, NotNil) + for _, tt := range tblInt { + ret, err := DivUintWithInt(tt.lsh, tt.rsh) + if tt.overflow { + require.Error(t, err) } else { - c.Assert(ret, Equals, t.ret) + require.Equal(t, tt.ret, ret) } } @@ -332,18 +331,18 @@ func (s *testOverflowSuite) TestDiv(c *C) { 
overflow bool err string }{ - {math.MinInt64, math.MaxInt64, 0, true, "*BIGINT UNSIGNED value is out of range in '\\(-9223372036854775808, 9223372036854775807\\)'"}, + {math.MinInt64, math.MaxInt64, 0, true, "^*BIGINT UNSIGNED value is out of range in '\\(-9223372036854775808, 9223372036854775807\\)'$"}, {0, 1, 0, false, ""}, {-1, math.MaxInt64, 0, false, ""}, } - for _, t := range tblInt2 { - ret, err := DivIntWithUint(t.lsh, t.rsh) - if t.overflow { - c.Assert(err, NotNil) - c.Assert(err, ErrorMatches, t.err) + for _, tt := range tblInt2 { + ret, err := DivIntWithUint(tt.lsh, tt.rsh) + if tt.overflow { + require.Error(t, err) + require.Regexp(t, tt.err, err.Error()) } else { - c.Assert(ret, Equals, t.ret) + require.Equal(t, tt.ret, ret) } } }