From 9f6d8755e44564b74017198cb312052f524af68d Mon Sep 17 00:00:00 2001 From: georgehao Date: Thu, 8 Dec 2022 09:51:13 +0800 Subject: [PATCH] optimize some format (#392) Co-authored-by: haohongfan1 --- pkg/datasource/sql/conn.go | 3 +- pkg/datasource/sql/conn_at.go | 10 ++-- pkg/datasource/sql/conn_at_test.go | 12 ++--- pkg/datasource/sql/conn_xa.go | 4 +- pkg/datasource/sql/conn_xa_test.go | 12 ++--- pkg/datasource/sql/connector.go | 6 +-- pkg/datasource/sql/connector_test.go | 4 +- pkg/datasource/sql/driver.go | 3 +- pkg/datasource/sql/exec/executor.go | 63 ++++++++++------------- pkg/datasource/sql/exec/hook.go | 17 +++--- pkg/datasource/sql/exec/xa/executor_xa.go | 30 +++++------ pkg/datasource/sql/stmt.go | 24 ++++----- pkg/datasource/sql/tx.go | 6 +-- pkg/datasource/sql/tx_at.go | 2 +- pkg/datasource/sql/tx_xa.go | 2 +- pkg/datasource/sql/types/executor.go | 20 +++---- pkg/datasource/sql/types/types.go | 36 ++++++------- 17 files changed, 113 insertions(+), 141 deletions(-) diff --git a/pkg/datasource/sql/conn.go b/pkg/datasource/sql/conn.go index ea9f25a1e..3f224c991 100644 --- a/pkg/datasource/sql/conn.go +++ b/pkg/datasource/sql/conn.go @@ -29,7 +29,6 @@ import ( // by multiple goroutines. // // Conn is assumed to be stateful. - type Conn struct { res *DBResource txCtx *types.TransactionContext @@ -135,7 +134,7 @@ func (c *Conn) Query(query string, args []driver.Value) (driver.Rows, error) { return nil, driver.ErrSkip } - executor, err := exec.BuildExecutor(c.res.dbType, c.txCtx.TxType, query) + executor, err := exec.BuildExecutor(c.res.dbType, c.txCtx.TransactionMode, query) if err != nil { return nil, err } diff --git a/pkg/datasource/sql/conn_at.go b/pkg/datasource/sql/conn_at.go index 94fb8fe8b..5e2edd54c 100644 --- a/pkg/datasource/sql/conn_at.go +++ b/pkg/datasource/sql/conn_at.go @@ -53,7 +53,7 @@ func (c *ATConn) QueryContext(ctx context.Context, query string, args []driver.N } ret, err := c.createNewTxOnExecIfNeed(ctx, func() (types.ExecResult, error) { - executor, err := exec.BuildExecutor(c.res.dbType, c.txCtx.TxType, query) + executor, err := exec.BuildExecutor(c.res.dbType, c.txCtx.TransactionMode, query) if err != nil { return nil, err } @@ -89,7 +89,7 @@ func (c *ATConn) ExecContext(ctx context.Context, query string, args []driver.Na } ret, err := c.createNewTxOnExecIfNeed(ctx, func() (types.ExecResult, error) { - executor, err := exec.BuildExecutor(c.res.dbType, c.txCtx.TxType, query) + executor, err := exec.BuildExecutor(c.res.dbType, c.txCtx.TransactionMode, query) if err != nil { return nil, err } @@ -130,7 +130,7 @@ func (c *ATConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, if tm.IsGlobalTx(ctx) { c.txCtx.XID = tm.GetXID(ctx) - c.txCtx.TxType = types.ATMode + c.txCtx.TransactionMode = types.ATMode } tx, err := c.Conn.BeginTx(ctx, opts) @@ -149,7 +149,7 @@ func (c *ATConn) createOnceTxContext(ctx context.Context) bool { c.txCtx.DBType = c.res.dbType c.txCtx.ResourceID = c.res.resourceID c.txCtx.XID = tm.GetXID(ctx) - c.txCtx.TxType = types.ATMode + c.txCtx.TransactionMode = types.ATMode c.txCtx.GlobalLockRequire = true } @@ -162,7 +162,7 @@ func (c *ATConn) createNewTxOnExecIfNeed(ctx context.Context, f func() (types.Ex err error ) - if c.txCtx.TxType != types.Local && c.autoCommit { + if c.txCtx.TransactionMode != types.Local && c.autoCommit { tx, err = c.BeginTx(ctx, driver.TxOptions{Isolation: driver.IsolationLevel(gosql.LevelDefault)}) if err != nil { return nil, err diff --git a/pkg/datasource/sql/conn_at_test.go b/pkg/datasource/sql/conn_at_test.go index ca868923c..017f0d587 100644 --- a/pkg/datasource/sql/conn_at_test.go +++ b/pkg/datasource/sql/conn_at_test.go @@ -87,14 +87,14 @@ func TestATConn_ExecContext(t *testing.T) { beforeHook := func(_ context.Context, execCtx *types.ExecContext) { t.Logf("on exec xid=%s", execCtx.TxCtx.XID) assert.Equal(t, tm.GetXID(ctx), execCtx.TxCtx.XID) - assert.Equal(t, types.ATMode, execCtx.TxCtx.TxType) + assert.Equal(t, types.ATMode, execCtx.TxCtx.TransactionMode) } mi.before = beforeHook var comitCnt int32 beforeCommit := func(tx *Tx) { atomic.AddInt32(&comitCnt, 1) - assert.Equal(t, types.ATMode, tx.tranCtx.TxType) + assert.Equal(t, types.ATMode, tx.tranCtx.TransactionMode) } ti.beforeCommit = beforeCommit @@ -112,7 +112,7 @@ func TestATConn_ExecContext(t *testing.T) { t.Run("not xid", func(t *testing.T) { mi.before = func(_ context.Context, execCtx *types.ExecContext) { assert.Equal(t, "", execCtx.TxCtx.XID) - assert.Equal(t, types.Local, execCtx.TxCtx.TxType) + assert.Equal(t, types.Local, execCtx.TxCtx.TransactionMode) } var comitCnt int32 @@ -149,7 +149,7 @@ func TestATConn_BeginTx(t *testing.T) { mi.before = func(_ context.Context, execCtx *types.ExecContext) { assert.Equal(t, "", execCtx.TxCtx.XID) - assert.Equal(t, types.Local, execCtx.TxCtx.TxType) + assert.Equal(t, types.Local, execCtx.TxCtx.TransactionMode) } var comitCnt int32 @@ -175,7 +175,7 @@ func TestATConn_BeginTx(t *testing.T) { mi.before = func(_ context.Context, execCtx *types.ExecContext) { assert.Equal(t, "", execCtx.TxCtx.XID) - assert.Equal(t, types.Local, execCtx.TxCtx.TxType) + assert.Equal(t, types.Local, execCtx.TxCtx.TransactionMode) } var comitCnt int32 @@ -203,7 +203,7 @@ func TestATConn_BeginTx(t *testing.T) { mi.before = func(_ context.Context, execCtx *types.ExecContext) { assert.Equal(t, tm.GetXID(ctx), execCtx.TxCtx.XID) - assert.Equal(t, types.ATMode, execCtx.TxCtx.TxType) + assert.Equal(t, types.ATMode, execCtx.TxCtx.TransactionMode) } var comitCnt int32 diff --git a/pkg/datasource/sql/conn_xa.go b/pkg/datasource/sql/conn_xa.go index 3f8ba9212..7d1b66f7d 100644 --- a/pkg/datasource/sql/conn_xa.go +++ b/pkg/datasource/sql/conn_xa.go @@ -73,7 +73,7 @@ func (c *XAConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, c.txCtx.TxOpt = opts if tm.IsGlobalTx(ctx) { - c.txCtx.TxType = types.XAMode + c.txCtx.TransactionMode = types.XAMode c.txCtx.XID = tm.GetXID(ctx) } @@ -92,7 +92,7 @@ func (c *XAConn) createOnceTxContext(ctx context.Context) bool { c.txCtx = types.NewTxCtx() c.txCtx.DBType = c.res.dbType c.txCtx.XID = tm.GetXID(ctx) - c.txCtx.TxType = types.XAMode + c.txCtx.TransactionMode = types.XAMode } return onceTx diff --git a/pkg/datasource/sql/conn_xa_test.go b/pkg/datasource/sql/conn_xa_test.go index b06fce24f..ee04412cd 100644 --- a/pkg/datasource/sql/conn_xa_test.go +++ b/pkg/datasource/sql/conn_xa_test.go @@ -138,14 +138,14 @@ func TestXAConn_ExecContext(t *testing.T) { before := func(_ context.Context, execCtx *types.ExecContext) { t.Logf("on exec xid=%s", execCtx.TxCtx.XID) assert.Equal(t, tm.GetXID(ctx), execCtx.TxCtx.XID) - assert.Equal(t, types.XAMode, execCtx.TxCtx.TxType) + assert.Equal(t, types.XAMode, execCtx.TxCtx.TransactionMode) } mi.before = before var comitCnt int32 beforeCommit := func(tx *Tx) { atomic.AddInt32(&comitCnt, 1) - assert.Equal(t, tx.tranCtx.TxType, types.XAMode) + assert.Equal(t, tx.tranCtx.TransactionMode, types.XAMode) } ti.beforeCommit = beforeCommit @@ -164,7 +164,7 @@ func TestXAConn_ExecContext(t *testing.T) { t.Run("not xid", func(t *testing.T) { before := func(_ context.Context, execCtx *types.ExecContext) { assert.Equal(t, "", execCtx.TxCtx.XID) - assert.Equal(t, types.Local, execCtx.TxCtx.TxType) + assert.Equal(t, types.Local, execCtx.TxCtx.TransactionMode) } mi.before = before @@ -203,7 +203,7 @@ func TestXAConn_BeginTx(t *testing.T) { mi.before = func(_ context.Context, execCtx *types.ExecContext) { assert.Equal(t, "", execCtx.TxCtx.XID) - assert.Equal(t, types.Local, execCtx.TxCtx.TxType) + assert.Equal(t, types.Local, execCtx.TxCtx.TransactionMode) } var comitCnt int32 @@ -229,7 +229,7 @@ func TestXAConn_BeginTx(t *testing.T) { mi.before = func(_ context.Context, execCtx *types.ExecContext) { assert.Equal(t, "", execCtx.TxCtx.XID) - assert.Equal(t, types.Local, execCtx.TxCtx.TxType) + assert.Equal(t, types.Local, execCtx.TxCtx.TransactionMode) } var comitCnt int32 @@ -257,7 +257,7 @@ func TestXAConn_BeginTx(t *testing.T) { mi.before = func(_ context.Context, execCtx *types.ExecContext) { assert.Equal(t, tm.GetXID(ctx), execCtx.TxCtx.XID) - assert.Equal(t, types.XAMode, execCtx.TxCtx.TxType) + assert.Equal(t, types.XAMode, execCtx.TxCtx.TransactionMode) } var comitCnt int32 diff --git a/pkg/datasource/sql/connector.go b/pkg/datasource/sql/connector.go index 1eee50761..36d861a5b 100644 --- a/pkg/datasource/sql/connector.go +++ b/pkg/datasource/sql/connector.go @@ -29,7 +29,7 @@ import ( type seataATConnector struct { *seataConnector - transType types.TransactionType + transType types.TransactionMode } func (c *seataATConnector) Connect(ctx context.Context) (driver.Conn, error) { @@ -53,7 +53,7 @@ func (c *seataATConnector) Driver() driver.Driver { type seataXAConnector struct { *seataConnector - transType types.TransactionType + transType types.TransactionMode } func (c *seataXAConnector) Connect(ctx context.Context) (driver.Conn, error) { @@ -88,7 +88,7 @@ func (c *seataXAConnector) Driver() driver.Driver { // If a Connector implements io.Closer, the sql package's DB.Close // method will call Close and return error (if any). type seataConnector struct { - transType types.TransactionType + transType types.TransactionMode conf *seataServerConfig res *DBResource once sync.Once diff --git a/pkg/datasource/sql/connector_test.go b/pkg/datasource/sql/connector_test.go index 9223724d2..950e81337 100644 --- a/pkg/datasource/sql/connector_test.go +++ b/pkg/datasource/sql/connector_test.go @@ -82,7 +82,7 @@ func Test_seataATConnector_Connect(t *testing.T) { atConn, ok := conn.(*ATConn) assert.True(t, ok, "need return seata at connection") - assert.True(t, atConn.txCtx.TxType == types.Local, "init need local tx") + assert.True(t, atConn.txCtx.TransactionMode == types.Local, "init need local tx") } func initMockXaConnector(t *testing.T, ctrl *gomock.Controller, db *sql.DB, f initConnectorFunc) driver.Connector { @@ -126,5 +126,5 @@ func Test_seataXAConnector_Connect(t *testing.T) { xaConn, ok := conn.(*XAConn) assert.True(t, ok, "need return seata xa connection") - assert.True(t, xaConn.txCtx.TxType == types.Local, "init need local tx") + assert.True(t, xaConn.txCtx.TransactionMode == types.Local, "init need local tx") } diff --git a/pkg/datasource/sql/driver.go b/pkg/datasource/sql/driver.go index 04b79b33b..01e3c2a00 100644 --- a/pkg/datasource/sql/driver.go +++ b/pkg/datasource/sql/driver.go @@ -47,6 +47,7 @@ func init() { target: mysql.MySQLDriver{}, }, }) + sql.Register(SeataXAMySQLDriver, &seataXADriver{ seataDriver: &seataDriver{ transType: types.XAMode, @@ -96,7 +97,7 @@ func (d *seataXADriver) OpenConnector(name string) (c driver.Connector, err erro } type seataDriver struct { - transType types.TransactionType + transType types.TransactionMode target driver.Driver } diff --git a/pkg/datasource/sql/exec/executor.go b/pkg/datasource/sql/exec/executor.go index 2c2cff60e..ed604aa13 100644 --- a/pkg/datasource/sql/exec/executor.go +++ b/pkg/datasource/sql/exec/executor.go @@ -22,6 +22,7 @@ import ( "database/sql/driver" "fmt" + "github.com/mitchellh/copystructure" "github.com/pkg/errors" "github.com/seata/seata-go/pkg/datasource/sql/parser" @@ -30,8 +31,6 @@ import ( "github.com/seata/seata-go/pkg/datasource/sql/undo/builder" "github.com/seata/seata-go/pkg/tm" "github.com/seata/seata-go/pkg/util/log" - - "github.com/mitchellh/copystructure" ) func init() { @@ -39,13 +38,12 @@ func init() { undo.RegisterUndoLogBuilder(types.MultiExecutor, builder.GetMySQLMultiUndoLogBuilder) } -// executorSolts var ( executorSoltsAT = make(map[types.DBType]map[types.ExecutorType]func() SQLExecutor) executorSoltsXA = make(map[types.DBType]func() SQLExecutor) ) -// RegisterATExecutor +// RegisterATExecutor AT executor func RegisterATExecutor(dt types.DBType, et types.ExecutorType, builder func() SQLExecutor) { if _, ok := executorSoltsAT[dt]; !ok { executorSoltsAT[dt] = make(map[types.ExecutorType]func() SQLExecutor) @@ -58,7 +56,7 @@ func RegisterATExecutor(dt types.DBType, et types.ExecutorType, builder func() S } } -// RegisterXAExecutor +// RegisterXAExecutor XA executor func RegisterXAExecutor(dt types.DBType, builder func() SQLExecutor) { executorSoltsXA[dt] = func() SQLExecutor { return &BaseExecutor{ex: builder()} @@ -71,41 +69,37 @@ type ( CallbackWithValue func(ctx context.Context, query string, args []driver.Value) (types.ExecResult, error) SQLExecutor interface { - // Interceptors Interceptors(interceptors []SQLHook) - // Exec ExecWithNamedValue(ctx context.Context, execCtx *types.ExecContext, f CallbackWithNamedValue) (types.ExecResult, error) - // Exec ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f CallbackWithValue) (types.ExecResult, error) } ) -// BuildExecutor -func BuildExecutor(dbType types.DBType, txType types.TransactionType, query string) (SQLExecutor, error) { - parseCtx, err := parser.DoParser(query) +// BuildExecutor use db type and transaction type to build an executor. the executor can +// add custom hook, and intercept the user's business sql to generate the undo log. +func BuildExecutor(dbType types.DBType, transactionMode types.TransactionMode, query string) (SQLExecutor, error) { + parseContext, err := parser.DoParser(query) if err != nil { return nil, err } hooks := make([]SQLHook, 0, 4) hooks = append(hooks, commonHook...) - hooks = append(hooks, hookSolts[parseCtx.SQLType]...) - hooks = append(hooks, commonHook...) + hooks = append(hooks, hookSolts[parseContext.SQLType]...) - if txType == types.XAMode { + if transactionMode == types.XAMode { e := executorSoltsXA[dbType]() e.Interceptors(hooks) return e, nil } - if txType == types.ATMode { - e := executorSoltsAT[dbType][parseCtx.ExecutorType]() + if transactionMode == types.ATMode { + e := executorSoltsAT[dbType][parseContext.ExecutorType]() e.Interceptors(hooks) return e, nil } factories, ok := executorSoltsAT[dbType] - if !ok { log.Debugf("%s not found executor factories, return default Executor", dbType.String()) e := &BaseExecutor{} @@ -113,10 +107,10 @@ func BuildExecutor(dbType types.DBType, txType types.TransactionType, query stri return e, nil } - supplier, ok := factories[parseCtx.ExecutorType] + supplier, ok := factories[parseContext.ExecutorType] if !ok { log.Debugf("%s not found executor for %s, return default Executor", - dbType.String(), parseCtx.ExecutorType) + dbType.String(), parseContext.ExecutorType) e := &BaseExecutor{} e.Interceptors(hooks) return e, nil @@ -128,19 +122,17 @@ func BuildExecutor(dbType types.DBType, txType types.TransactionType, query stri } type BaseExecutor struct { - is []SQLHook - ex SQLExecutor + hooks []SQLHook + ex SQLExecutor } -// Interceptors -func (e *BaseExecutor) Interceptors(interceptors []SQLHook) { - e.is = interceptors +func (e *BaseExecutor) Interceptors(hooks []SQLHook) { + e.hooks = hooks } -// ExecWithNamedValue func (e *BaseExecutor) ExecWithNamedValue(ctx context.Context, execCtx *types.ExecContext, f CallbackWithNamedValue) (types.ExecResult, error) { - for i := range e.is { - _ = e.is[i].Before(ctx, execCtx) + for _, hook := range e.hooks { + hook.Before(ctx, execCtx) } var ( @@ -167,8 +159,8 @@ func (e *BaseExecutor) ExecWithNamedValue(ctx context.Context, execCtx *types.Ex } defer func() { - for i := range e.is { - _ = e.is[i].After(ctx, execCtx) + for _, hook := range e.hooks { + hook.After(ctx, execCtx) } }() @@ -210,10 +202,9 @@ func (e *BaseExecutor) prepareUndoLog(ctx context.Context, execCtx *types.ExecCo return undoLogManager.FlushUndoLog(execCtx.TxCtx, execCtx.Conn) } -// ExecWithValue func (e *BaseExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f CallbackWithValue) (types.ExecResult, error) { - for i := range e.is { - e.is[i].Before(ctx, execCtx) + for _, hook := range e.hooks { + hook.Before(ctx, execCtx) } var ( @@ -232,8 +223,8 @@ func (e *BaseExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecCon } defer func() { - for i := range e.is { - _ = e.is[i].After(ctx, execCtx) + for _, hook := range e.hooks { + hook.After(ctx, execCtx) } }() @@ -257,7 +248,7 @@ func (e *BaseExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecCon return result, err } -func (h *BaseExecutor) beforeImage(ctx context.Context, execCtx *types.ExecContext) ([]*types.RecordImage, error) { +func (e *BaseExecutor) beforeImage(ctx context.Context, execCtx *types.ExecContext) ([]*types.RecordImage, error) { if !tm.IsGlobalTx(ctx) { return nil, nil } @@ -279,7 +270,7 @@ func (h *BaseExecutor) beforeImage(ctx context.Context, execCtx *types.ExecConte } // After -func (h *BaseExecutor) afterImage(ctx context.Context, execCtx *types.ExecContext, beforeImages []*types.RecordImage) ([]*types.RecordImage, error) { +func (e *BaseExecutor) afterImage(ctx context.Context, execCtx *types.ExecContext, beforeImages []*types.RecordImage) ([]*types.RecordImage, error) { if !tm.IsGlobalTx(ctx) { return nil, nil } diff --git a/pkg/datasource/sql/exec/hook.go b/pkg/datasource/sql/exec/hook.go index a9c47a486..a2d703cc0 100644 --- a/pkg/datasource/sql/exec/hook.go +++ b/pkg/datasource/sql/exec/hook.go @@ -29,7 +29,7 @@ var ( hookSolts = map[types.SQLType][]SQLHook{} ) -// RegisCommonHook not goroutine safe +// RegisterCommonHook not goroutine safe func RegisterCommonHook(hook SQLHook) { commonHook = append(commonHook, hook) } @@ -40,13 +40,16 @@ func CleanCommonHook() { // RegisterHook not goroutine safe func RegisterHook(hook SQLHook) { - _, ok := hookSolts[hook.Type()] + sqlType := hook.Type() + if sqlType == types.SQLTypeUnknown { + return + } + _, ok := hookSolts[sqlType] if !ok { - hookSolts[hook.Type()] = make([]SQLHook, 0, 4) + hookSolts[sqlType] = make([]SQLHook, 0, 4) } - - hookSolts[hook.Type()] = append(hookSolts[hook.Type()], hook) + hookSolts[sqlType] = append(hookSolts[sqlType], hook) } // SQLHook SQL execution front and back interceptor @@ -55,10 +58,6 @@ func RegisterHook(hook SQLHook) { // case 3. SQL black and white list type SQLHook interface { Type() types.SQLType - - // Before Before(ctx context.Context, execCtx *types.ExecContext) error - - // After After(ctx context.Context, execCtx *types.ExecContext) error } diff --git a/pkg/datasource/sql/exec/xa/executor_xa.go b/pkg/datasource/sql/exec/xa/executor_xa.go index fb9da4b2c..e7a0c3c6c 100644 --- a/pkg/datasource/sql/exec/xa/executor_xa.go +++ b/pkg/datasource/sql/exec/xa/executor_xa.go @@ -24,28 +24,26 @@ import ( "github.com/seata/seata-go/pkg/datasource/sql/types" ) -// todo -// 完善XA prepare -// +// XAExecutor The XA transaction manager. type XAExecutor struct { - is []exec.SQLHook - ex exec.SQLExecutor + hooks []exec.SQLHook + ex exec.SQLExecutor } -// Interceptors -func (e *XAExecutor) Interceptors(interceptors []exec.SQLHook) { - e.is = interceptors +// Interceptors set xa executor hooks +func (e *XAExecutor) Interceptors(hooks []exec.SQLHook) { + e.hooks = hooks } // ExecWithNamedValue func (e *XAExecutor) ExecWithNamedValue(ctx context.Context, execCtx *types.ExecContext, f exec.CallbackWithNamedValue) (types.ExecResult, error) { - for i := range e.is { - e.is[i].Before(ctx, execCtx) + for _, hook := range e.hooks { + hook.Before(ctx, execCtx) } defer func() { - for i := range e.is { - e.is[i].After(ctx, execCtx) + for _, hook := range e.hooks { + hook.After(ctx, execCtx) } }() @@ -58,13 +56,13 @@ func (e *XAExecutor) ExecWithNamedValue(ctx context.Context, execCtx *types.Exec // ExecWithValue func (e *XAExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f exec.CallbackWithValue) (types.ExecResult, error) { - for i := range e.is { - e.is[i].Before(ctx, execCtx) + for _, hook := range e.hooks { + hook.Before(ctx, execCtx) } defer func() { - for i := range e.is { - e.is[i].After(ctx, execCtx) + for _, hook := range e.hooks { + hook.After(ctx, execCtx) } }() diff --git a/pkg/datasource/sql/stmt.go b/pkg/datasource/sql/stmt.go index 4522e4d92..9e5a84ed8 100644 --- a/pkg/datasource/sql/stmt.go +++ b/pkg/datasource/sql/stmt.go @@ -26,15 +26,11 @@ import ( ) type Stmt struct { - conn *Conn - // res - res *DBResource - // txCtx + conn *Conn + res *DBResource txCtx *types.TransactionContext - // query query string - // stmt - stmt driver.Stmt + stmt driver.Stmt } // Close closes the statement. @@ -67,7 +63,7 @@ func (s *Stmt) NumInput() int { // // Deprecated: Drivers should implement StmtQueryContext instead (or additionally). func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) { - executor, err := exec.BuildExecutor(s.res.dbType, s.txCtx.TxType, s.query) + executor, err := exec.BuildExecutor(s.res.dbType, s.txCtx.TransactionMode, s.query) if err != nil { return nil, err } @@ -94,10 +90,8 @@ func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) { return ret.GetRows(), nil } -// StmtQueryContext enhances the Stmt interface by providing Query with context. -// QueryContext executes a query that may return rows, such as a -// SELECT. -// +// QueryContext StmtQueryContext enhances the Stmt interface by providing Query with context. +// QueryContext executes a query that may return rows, such as a SELECT. // QueryContext must honor the context timeout and return when it is canceled. func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { stmt, ok := s.stmt.(driver.StmtQueryContext) @@ -105,7 +99,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv return nil, driver.ErrSkip } - executor, err := exec.BuildExecutor(s.res.dbType, s.txCtx.TxType, s.query) + executor, err := exec.BuildExecutor(s.res.dbType, s.txCtx.TransactionMode, s.query) if err != nil { return nil, err } @@ -138,7 +132,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv // Deprecated: Drivers should implement StmtExecContext instead (or additionally). func (s *Stmt) Exec(args []driver.Value) (driver.Result, error) { // in transaction, need run Executor - executor, err := exec.BuildExecutor(s.res.dbType, s.txCtx.TxType, s.query) + executor, err := exec.BuildExecutor(s.res.dbType, s.txCtx.TransactionMode, s.query) if err != nil { return nil, err } @@ -173,7 +167,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive } // in transaction, need run Executor - executor, err := exec.BuildExecutor(s.res.dbType, s.txCtx.TxType, s.query) + executor, err := exec.BuildExecutor(s.res.dbType, s.txCtx.TransactionMode, s.query) if err != nil { return nil, err } diff --git a/pkg/datasource/sql/tx.go b/pkg/datasource/sql/tx.go index b73d7b4fa..296c0b561 100644 --- a/pkg/datasource/sql/tx.go +++ b/pkg/datasource/sql/tx.go @@ -156,11 +156,11 @@ func (tx *Tx) register(ctx *types.TransactionContext) error { } request := rm.BranchRegisterParam{ Xid: ctx.XID, - BranchType: ctx.TxType.GetBranchType(), + BranchType: ctx.TransactionMode.BranchType(), ResourceId: ctx.ResourceID, LockKeys: lockKey, } - dataSourceManager := datasource.GetDataSourceManager(ctx.TxType.GetBranchType()) + dataSourceManager := datasource.GetDataSourceManager(ctx.TransactionMode.BranchType()) branchId, err := dataSourceManager.BranchRegister(context.Background(), request) if err != nil { log.Infof("Failed to report branch status: %s", err.Error()) @@ -181,7 +181,7 @@ func (tx *Tx) report(success bool) error { BranchId: int64(tx.tranCtx.BranchID), Status: status, } - dataSourceManager := datasource.GetDataSourceManager(tx.tranCtx.TxType.GetBranchType()) + dataSourceManager := datasource.GetDataSourceManager(tx.tranCtx.TransactionMode.BranchType()) if dataSourceManager == nil { return errors.New("get dataSourceManager failed") } diff --git a/pkg/datasource/sql/tx_at.go b/pkg/datasource/sql/tx_at.go index 04c73f6be..b8e2777a2 100644 --- a/pkg/datasource/sql/tx_at.go +++ b/pkg/datasource/sql/tx_at.go @@ -43,7 +43,7 @@ func (tx *ATTx) Rollback() error { originTx := tx.tx - if originTx.tranCtx.OpenGlobalTrsnaction() && originTx.tranCtx.IsBranchRegistered() { + if originTx.tranCtx.OpenGlobalTransaction() && originTx.tranCtx.IsBranchRegistered() { originTx.report(false) } } diff --git a/pkg/datasource/sql/tx_xa.go b/pkg/datasource/sql/tx_xa.go index d9e602ee7..5e6dea178 100644 --- a/pkg/datasource/sql/tx_xa.go +++ b/pkg/datasource/sql/tx_xa.go @@ -37,7 +37,7 @@ func (tx *XATx) Rollback() error { originTx := tx.tx - if originTx.tranCtx.OpenGlobalTrsnaction() && originTx.tranCtx.IsBranchRegistered() { + if originTx.tranCtx.OpenGlobalTransaction() && originTx.tranCtx.IsBranchRegistered() { originTx.report(false) } } diff --git a/pkg/datasource/sql/types/executor.go b/pkg/datasource/sql/types/executor.go index 362d43cfb..2a4b52818 100644 --- a/pkg/datasource/sql/types/executor.go +++ b/pkg/datasource/sql/types/executor.go @@ -25,9 +25,6 @@ import ( seatabytes "github.com/seata/seata-go/pkg/util/bytes" ) -// ExecutorType -// -//go:generate stringer -type=ExecutorType type ExecutorType int32 const ( @@ -45,18 +42,13 @@ const ( ) type ParseContext struct { - // SQLType - SQLType SQLType - // ExecutorType + SQLType SQLType ExecutorType ExecutorType - // InsertStmt - InsertStmt *ast.InsertStmt - // UpdateStmt - UpdateStmt *ast.UpdateStmt - SelectStmt *ast.SelectStmt - // DeleteStmt - DeleteStmt *ast.DeleteStmt - MultiStmt []*ParseContext + InsertStmt *ast.InsertStmt + UpdateStmt *ast.UpdateStmt + SelectStmt *ast.SelectStmt + DeleteStmt *ast.DeleteStmt + MultiStmt []*ParseContext } func (p *ParseContext) HasValidStmt() bool { diff --git a/pkg/datasource/sql/types/types.go b/pkg/datasource/sql/types/types.go index 5a393c16b..39066db35 100644 --- a/pkg/datasource/sql/types/types.go +++ b/pkg/datasource/sql/types/types.go @@ -27,11 +27,9 @@ import ( "github.com/google/uuid" ) -//go:generate stringer -type=DBType type DBType int16 type ( - // DBType // BranchPhase BranchPhase int8 // IndexType index type @@ -102,24 +100,24 @@ func ParseDBType(driverName string) DBType { } } -// TransactionType -type TransactionType int8 +type TransactionMode int8 const ( - _ TransactionType = iota + _ TransactionMode = iota Local XAMode ATMode ) -func (t TransactionType) GetBranchType() branch.BranchType { - if t == XAMode { +func (t TransactionMode) BranchType() branch.BranchType { + switch t { + case XAMode: return branch.BranchTypeXA - } - if t == ATMode { + case ATMode: return branch.BranchTypeAT + default: + return branch.BranchTypeUnknow } - return branch.BranchTypeUnknow } // TransactionContext seata-go‘s context of transaction @@ -132,8 +130,8 @@ type TransactionContext struct { DBType DBType // TxOpt transaction option TxOpt driver.TxOptions - // TxType transaction mode, eg. XA/AT - TxType TransactionType + // TransactionMode transaction mode, eg. XA/AT + TransactionMode TransactionMode // ResourceID resource id, database-table ResourceID string // BranchID transaction branch unique id @@ -167,16 +165,16 @@ type ExecContext struct { func NewTxCtx() *TransactionContext { return &TransactionContext{ - LockKeys: make(map[string]struct{}, 0), - TxType: Local, - LocalTransID: uuid.New().String(), - RoundImages: &RoundRecordImage{}, + LockKeys: make(map[string]struct{}, 0), + TransactionMode: Local, + LocalTransID: uuid.New().String(), + RoundImages: &RoundRecordImage{}, } } // HasUndoLog func (t *TransactionContext) HasUndoLog() bool { - return t.TxType == ATMode && !t.RoundImages.IsEmpty() + return t.TransactionMode == ATMode && !t.RoundImages.IsEmpty() } // HasLockKey @@ -184,8 +182,8 @@ func (t *TransactionContext) HasLockKey() bool { return len(t.LockKeys) != 0 } -func (t *TransactionContext) OpenGlobalTrsnaction() bool { - return t.TxType != Local +func (t *TransactionContext) OpenGlobalTransaction() bool { + return t.TransactionMode != Local } func (t *TransactionContext) IsBranchRegistered() bool {