From 5bb3c06f53cc65cd131864ed3a7722fb9ade53a0 Mon Sep 17 00:00:00 2001 From: "Zhuomin(Charming) Liu" Date: Fri, 26 Jul 2019 14:30:40 +0800 Subject: [PATCH] executor: fix autoid doesn't handle float, double type and tiny cleanup (#11110) (#11385) --- executor/insert_common.go | 30 ++++- executor/insert_test.go | 245 ++++++++++++++++++++++++++++++++++++++ executor/update_test.go | 125 +++++++++++++++++++ executor/write.go | 6 +- 4 files changed, 399 insertions(+), 7 deletions(-) diff --git a/executor/insert_common.go b/executor/insert_common.go index 851305da981d9..70750b4ea60c2 100644 --- a/executor/insert_common.go +++ b/executor/insert_common.go @@ -15,6 +15,7 @@ package executor import ( "context" + "math" "github.com/pingcap/errors" "github.com/pingcap/parser/ast" @@ -479,12 +480,10 @@ func (e *InsertValues) adjustAutoIncrementDatum(ctx context.Context, d types.Dat d.SetNull() } if !d.IsNull() { - sc := e.ctx.GetSessionVars().StmtCtx - datum, err1 := d.ConvertTo(sc, &c.FieldType) - if e.filterErr(err1) != nil { - return types.Datum{}, err1 + recordID, err = getAutoRecordID(d, &c.FieldType, true) + if err != nil { + return types.Datum{}, err } - recordID = datum.GetInt64() } // Use the value if it's not null and not 0. if recordID != 0 { @@ -494,7 +493,6 @@ func (e *InsertValues) adjustAutoIncrementDatum(ctx context.Context, d types.Dat } e.ctx.GetSessionVars().StmtCtx.InsertID = uint64(recordID) retryInfo.AddAutoIncrementID(recordID) - d.SetAutoID(recordID, c.Flag) return d, nil } @@ -522,6 +520,26 @@ func (e *InsertValues) adjustAutoIncrementDatum(ctx context.Context, d types.Dat return casted, nil } +func getAutoRecordID(d types.Datum, target *types.FieldType, isInsert bool) (int64, error) { + var recordID int64 + + switch target.Tp { + case mysql.TypeFloat, mysql.TypeDouble: + f := d.GetFloat64() + if isInsert { + recordID = int64(math.Round(f)) + } else { + recordID = int64(f) + } + case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong: + recordID = d.GetInt64() + default: + return 0, errors.Errorf("unexpected field type [%v]", target.Tp) + } + + return recordID, nil +} + func (e *InsertValues) handleWarning(err error) { sc := e.ctx.GetSessionVars().StmtCtx sc.AppendWarning(err) diff --git a/executor/insert_test.go b/executor/insert_test.go index e26834572f12b..4e60d258af50a 100644 --- a/executor/insert_test.go +++ b/executor/insert_test.go @@ -300,3 +300,248 @@ func (s *testSuite3) TestPartitionInsertOnDuplicate(c *C) { tk.MustExec(`insert into t2 set a=1,b=1 on duplicate key update a=1,b=1`) tk.MustQuery(`select * from t2`).Check(testkit.Rows("1 1")) } + +func (s *testSuite3) TestInsertWithAutoidSchema(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec(`use test`) + tk.MustExec(`create table t1(id int primary key auto_increment, n int);`) + tk.MustExec(`create table t2(id int unsigned primary key auto_increment, n int);`) + tk.MustExec(`create table t3(id tinyint primary key auto_increment, n int);`) + tk.MustExec(`create table t4(id int primary key, n float auto_increment, key I_n(n));`) + tk.MustExec(`create table t5(id int primary key, n float unsigned auto_increment, key I_n(n));`) + tk.MustExec(`create table t6(id int primary key, n double auto_increment, key I_n(n));`) + tk.MustExec(`create table t7(id int primary key, n double unsigned auto_increment, key I_n(n));`) + + tests := []struct { + insert string + query string + result [][]interface{} + }{ + { + `insert into t1(id, n) values(1, 1)`, + `select * from t1 where id = 1`, + testkit.Rows(`1 1`), + }, + { + `insert into t1(n) values(2)`, + `select * from t1 where id = 2`, + testkit.Rows(`2 2`), + }, + { + `insert into t1(n) values(3)`, + `select * from t1 where id = 3`, + testkit.Rows(`3 3`), + }, + { + `insert into t1(id, n) values(-1, 4)`, + `select * from t1 where id = -1`, + testkit.Rows(`-1 4`), + }, + { + `insert into t1(n) values(5)`, + `select * from t1 where id = 4`, + testkit.Rows(`4 5`), + }, + { + `insert into t1(id, n) values('5', 6)`, + `select * from t1 where id = 5`, + testkit.Rows(`5 6`), + }, + { + `insert into t1(n) values(7)`, + `select * from t1 where id = 6`, + testkit.Rows(`6 7`), + }, + { + `insert into t1(id, n) values(7.4, 8)`, + `select * from t1 where id = 7`, + testkit.Rows(`7 8`), + }, + { + `insert into t1(id, n) values(7.5, 9)`, + `select * from t1 where id = 8`, + testkit.Rows(`8 9`), + }, + { + `insert into t1(n) values(9)`, + `select * from t1 where id = 9`, + testkit.Rows(`9 9`), + }, + { + `insert into t2(id, n) values(1, 1)`, + `select * from t2 where id = 1`, + testkit.Rows(`1 1`), + }, + { + `insert into t2(n) values(2)`, + `select * from t2 where id = 2`, + testkit.Rows(`2 2`), + }, + { + `insert into t2(n) values(3)`, + `select * from t2 where id = 3`, + testkit.Rows(`3 3`), + }, + { + `insert into t3(id, n) values(1, 1)`, + `select * from t3 where id = 1`, + testkit.Rows(`1 1`), + }, + { + `insert into t3(n) values(2)`, + `select * from t3 where id = 2`, + testkit.Rows(`2 2`), + }, + { + `insert into t3(n) values(3)`, + `select * from t3 where id = 3`, + testkit.Rows(`3 3`), + }, + { + `insert into t3(id, n) values(-1, 4)`, + `select * from t3 where id = -1`, + testkit.Rows(`-1 4`), + }, + { + `insert into t3(n) values(5)`, + `select * from t3 where id = 4`, + testkit.Rows(`4 5`), + }, + { + `insert into t4(id, n) values(1, 1)`, + `select * from t4 where id = 1`, + testkit.Rows(`1 1`), + }, + { + `insert into t4(id) values(2)`, + `select * from t4 where id = 2`, + testkit.Rows(`2 2`), + }, + { + `insert into t4(id, n) values(3, -1)`, + `select * from t4 where id = 3`, + testkit.Rows(`3 -1`), + }, + { + `insert into t4(id) values(4)`, + `select * from t4 where id = 4`, + testkit.Rows(`4 3`), + }, + { + `insert into t4(id, n) values(5, 5.5)`, + `select * from t4 where id = 5`, + testkit.Rows(`5 5.5`), + }, + { + `insert into t4(id) values(6)`, + `select * from t4 where id = 6`, + testkit.Rows(`6 7`), + }, + { + `insert into t4(id, n) values(7, '7.7')`, + `select * from t4 where id = 7`, + testkit.Rows(`7 7.7`), + }, + { + `insert into t4(id) values(8)`, + `select * from t4 where id = 8`, + testkit.Rows(`8 9`), + }, + { + `insert into t4(id, n) values(9, 10.4)`, + `select * from t4 where id = 9`, + testkit.Rows(`9 10.4`), + }, + { + `insert into t4(id) values(10)`, + `select * from t4 where id = 10`, + testkit.Rows(`10 11`), + }, + { + `insert into t5(id, n) values(1, 1)`, + `select * from t5 where id = 1`, + testkit.Rows(`1 1`), + }, + { + `insert into t5(id) values(2)`, + `select * from t5 where id = 2`, + testkit.Rows(`2 2`), + }, + { + `insert into t5(id) values(3)`, + `select * from t5 where id = 3`, + testkit.Rows(`3 3`), + }, + { + `insert into t6(id, n) values(1, 1)`, + `select * from t6 where id = 1`, + testkit.Rows(`1 1`), + }, + { + `insert into t6(id) values(2)`, + `select * from t6 where id = 2`, + testkit.Rows(`2 2`), + }, + { + `insert into t6(id, n) values(3, -1)`, + `select * from t6 where id = 3`, + testkit.Rows(`3 -1`), + }, + { + `insert into t6(id) values(4)`, + `select * from t6 where id = 4`, + testkit.Rows(`4 3`), + }, + { + `insert into t6(id, n) values(5, 5.5)`, + `select * from t6 where id = 5`, + testkit.Rows(`5 5.5`), + }, + { + `insert into t6(id) values(6)`, + `select * from t6 where id = 6`, + testkit.Rows(`6 7`), + }, + { + `insert into t6(id, n) values(7, '7.7')`, + `select * from t4 where id = 7`, + testkit.Rows(`7 7.7`), + }, + { + `insert into t6(id) values(8)`, + `select * from t4 where id = 8`, + testkit.Rows(`8 9`), + }, + { + `insert into t6(id, n) values(9, 10.4)`, + `select * from t6 where id = 9`, + testkit.Rows(`9 10.4`), + }, + { + `insert into t6(id) values(10)`, + `select * from t6 where id = 10`, + testkit.Rows(`10 11`), + }, + { + `insert into t7(id, n) values(1, 1)`, + `select * from t7 where id = 1`, + testkit.Rows(`1 1`), + }, + { + `insert into t7(id) values(2)`, + `select * from t7 where id = 2`, + testkit.Rows(`2 2`), + }, + { + `insert into t7(id) values(3)`, + `select * from t7 where id = 3`, + testkit.Rows(`3 3`), + }, + } + + for _, tt := range tests { + tk.MustExec(tt.insert) + tk.MustQuery(tt.query).Check(tt.result) + } + +} diff --git a/executor/update_test.go b/executor/update_test.go index cc51e8dc02f08..30206d4d213c4 100644 --- a/executor/update_test.go +++ b/executor/update_test.go @@ -87,3 +87,128 @@ func (s *testUpdateSuite) TestUpdateGenColInTxn(c *C) { tk.MustQuery(`select * from t;`).Check(testkit.Rows( `1 2`)) } + +func (s *testUpdateSuite) TestUpdateWithAutoidSchema(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec(`use test`) + tk.MustExec(`create table t1(id int primary key auto_increment, n int);`) + tk.MustExec(`create table t2(id int primary key, n float auto_increment, key I_n(n));`) + tk.MustExec(`create table t3(id int primary key, n double auto_increment, key I_n(n));`) + + tests := []struct { + exec string + query string + result [][]interface{} + }{ + { + `insert into t1 set n = 1`, + `select * from t1 where id = 1`, + testkit.Rows(`1 1`), + }, + { + `update t1 set id = id+1`, + `select * from t1 where id = 2`, + testkit.Rows(`2 1`), + }, + { + `insert into t1 set n = 2`, + `select * from t1 where id = 3`, + testkit.Rows(`3 2`), + }, + { + `update t1 set id = id + '1.1' where id = 3`, + `select * from t1 where id = 4`, + testkit.Rows(`4 2`), + }, + { + `insert into t1 set n = 3`, + `select * from t1 where id = 5`, + testkit.Rows(`5 3`), + }, + { + `update t1 set id = id + '0.5' where id = 5`, + `select * from t1 where id = 6`, + testkit.Rows(`6 3`), + }, + { + `insert into t1 set n = 4`, + `select * from t1 where id = 7`, + testkit.Rows(`7 4`), + }, + { + `insert into t2 set id = 1`, + `select * from t2 where id = 1`, + testkit.Rows(`1 1`), + }, + { + `update t2 set n = n+1`, + `select * from t2 where id = 1`, + testkit.Rows(`1 2`), + }, + { + `insert into t2 set id = 2`, + `select * from t2 where id = 2`, + testkit.Rows(`2 3`), + }, + { + `update t2 set n = n + '2.2'`, + `select * from t2 where id = 2`, + testkit.Rows(`2 5.2`), + }, + { + `insert into t2 set id = 3`, + `select * from t2 where id = 3`, + testkit.Rows(`3 6`), + }, + { + `update t2 set n = n + '0.5' where id = 3`, + `select * from t2 where id = 3`, + testkit.Rows(`3 6.5`), + }, + { + `insert into t2 set id = 4`, + `select * from t2 where id = 4`, + testkit.Rows(`4 7`), + }, + { + `insert into t3 set id = 1`, + `select * from t3 where id = 1`, + testkit.Rows(`1 1`), + }, + { + `update t3 set n = n+1`, + `select * from t3 where id = 1`, + testkit.Rows(`1 2`), + }, + { + `insert into t3 set id = 2`, + `select * from t3 where id = 2`, + testkit.Rows(`2 3`), + }, + { + `update t3 set n = n + '3.3'`, + `select * from t3 where id = 2`, + testkit.Rows(`2 6.3`), + }, + { + `insert into t3 set id = 3`, + `select * from t3 where id = 3`, + testkit.Rows(`3 7`), + }, + { + `update t3 set n = n + '0.5' where id = 3`, + `select * from t3 where id = 3`, + testkit.Rows(`3 7.5`), + }, + { + `insert into t3 set id = 4`, + `select * from t3 where id = 4`, + testkit.Rows(`4 8`), + }, + } + + for _, tt := range tests { + tk.MustExec(tt.exec) + tk.MustQuery(tt.query).Check(tt.result) + } +} diff --git a/executor/write.go b/executor/write.go index f93c302b479b1..2a531eeff63f0 100644 --- a/executor/write.go +++ b/executor/write.go @@ -88,7 +88,11 @@ func updateRecord(ctx sessionctx.Context, h int64, oldData, newData []types.Datu modified[i] = true // Rebase auto increment id if the field is changed. if mysql.HasAutoIncrementFlag(col.Flag) { - if err = t.RebaseAutoID(ctx, newData[i].GetInt64(), true); err != nil { + recordID, err := getAutoRecordID(newData[i], &col.FieldType, false) + if err != nil { + return false, false, 0, err + } + if err = t.RebaseAutoID(ctx, recordID, true); err != nil { return false, false, 0, err } }