Skip to content

Commit

Permalink
*: Make code cleaner for binary execute
Browse files Browse the repository at this point in the history
  • Loading branch information
lcwangchao committed Jul 4, 2022
1 parent fe4cb85 commit 08403f0
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 58 deletions.
7 changes: 4 additions & 3 deletions executor/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ func (a ExecStmt) GetStmtNode() ast.StmtNode {
}

// PointGet short path for point exec directly from plan, keep only necessary steps
func (a *ExecStmt) PointGet(ctx context.Context, is infoschema.InfoSchema) (*recordSet, error) {
func (a *ExecStmt) PointGet(ctx context.Context) (*recordSet, error) {
if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil {
span1 := span.Tracer().StartSpan("ExecStmt.PointGet", opentracing.ChildOf(span.Context()))
span1.LogKV("sql", a.OriginText())
Expand All @@ -238,7 +238,7 @@ func (a *ExecStmt) PointGet(ctx context.Context, is infoschema.InfoSchema) (*rec
sessiontxn.RecordAssert(a.Ctx, "assertTxnManagerInShortPointGetPlan", true)
// stale read should not reach here
staleread.AssertStmtStaleness(a.Ctx, false)
sessiontxn.AssertTxnManagerInfoSchema(a.Ctx, is)
sessiontxn.AssertTxnManagerInfoSchema(a.Ctx, a.InfoSchema)
})

ctx = a.observeStmtBeginForTopSQL(ctx)
Expand All @@ -262,7 +262,7 @@ func (a *ExecStmt) PointGet(ctx context.Context, is infoschema.InfoSchema) (*rec
}
}
if a.PsStmt.Executor == nil {
b := newExecutorBuilder(a.Ctx, is, a.Ti)
b := newExecutorBuilder(a.Ctx, a.InfoSchema, a.Ti)
newExecutor := b.build(a.Plan)
if b.err != nil {
return nil, b.err
Expand Down Expand Up @@ -315,6 +315,7 @@ func (a *ExecStmt) RebuildPlan(ctx context.Context) (int64, error) {
sessiontxn.RecordAssert(a.Ctx, "assertTxnManagerInRebuildPlan", true)
sessiontxn.AssertTxnManagerInfoSchema(a.Ctx, ret.InfoSchema)
staleread.AssertStmtStaleness(a.Ctx, ret.IsStaleness)
sessiontxn.AssertTxnManagerReadTS(a.Ctx, ret.LastSnapshotTS)
})

a.InfoSchema = sessiontxn.GetTxnManager(a.Ctx).GetTxnInfoSchema()
Expand Down
13 changes: 2 additions & 11 deletions executor/prepared.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ import (
plannercore "github.com/pingcap/tidb/planner/core"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessiontxn"
"github.com/pingcap/tidb/sessiontxn/staleread"
"github.com/pingcap/tidb/types"
driver "github.com/pingcap/tidb/types/parser_driver"
"github.com/pingcap/tidb/util"
Expand Down Expand Up @@ -332,27 +331,19 @@ func (e *DeallocateExec) Next(ctx context.Context, req *chunk.Chunk) error {
}

// CompileExecutePreparedStmt compiles a session Execute command to a stmt.Statement.
func CompileExecutePreparedStmt(ctx context.Context, sctx sessionctx.Context,
execStmt *ast.ExecuteStmt, is infoschema.InfoSchema, snapshotTS uint64, replicaReadScope string, args []types.Datum) (*ExecStmt, bool, bool, error) {
func CompileExecutePreparedStmt(ctx context.Context, sctx sessionctx.Context, execStmt *ast.ExecuteStmt) (*ExecStmt, bool, bool, error) {
startTime := time.Now()
defer func() {
sctx.GetSessionVars().DurationCompile = time.Since(startTime)
}()
isStaleness := snapshotTS != 0
sctx.GetSessionVars().StmtCtx.IsStaleness = isStaleness
execStmt.BinaryArgs = args
is := sessiontxn.GetTxnManager(sctx).GetTxnInfoSchema()
execPlan, names, err := planner.Optimize(ctx, sctx, execStmt, is)
if err != nil {
return nil, false, false, err
}

failpoint.Inject("assertTxnManagerInCompile", func() {
sessiontxn.RecordAssert(sctx, "assertTxnManagerInCompile", true)
sessiontxn.AssertTxnManagerInfoSchema(sctx, is)
staleread.AssertStmtStaleness(sctx, snapshotTS != 0)
if snapshotTS != 0 {
sessiontxn.AssertTxnManagerReadTS(sctx, snapshotTS)
}
})

stmt := &ExecStmt{
Expand Down
7 changes: 2 additions & 5 deletions executor/seqtest/prepared_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ import (
"time"

"github.com/pingcap/tidb/executor"
"github.com/pingcap/tidb/infoschema"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/metrics"
"github.com/pingcap/tidb/parser/ast"
"github.com/pingcap/tidb/parser/mysql"
Expand Down Expand Up @@ -158,10 +156,9 @@ func TestPrepared(t *testing.T) {
require.NoError(t, err)
tk.ResultSetToResult(rs, fmt.Sprintf("%v", rs)).Check(testkit.Rows())

execStmt := &ast.ExecuteStmt{ExecID: stmtID}
execStmt := &ast.ExecuteStmt{ExecID: stmtID, BinaryArgs: []types.Datum{types.NewDatum(1)}}
// Check that ast.Statement created by executor.CompileExecutePreparedStmt has query text.
stmt, _, _, err := executor.CompileExecutePreparedStmt(context.TODO(), tk.Session(), execStmt,
tk.Session().GetInfoSchema().(infoschema.InfoSchema), 0, kv.GlobalReplicaScope, []types.Datum{types.NewDatum(1)})
stmt, _, _, err := executor.CompileExecutePreparedStmt(context.TODO(), tk.Session(), execStmt)
require.NoError(t, err)
require.Equal(t, query, stmt.OriginText())

Expand Down
6 changes: 2 additions & 4 deletions session/bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import (
"github.com/pingcap/tidb/config"
"github.com/pingcap/tidb/domain"
"github.com/pingcap/tidb/executor"
"github.com/pingcap/tidb/infoschema"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/parser/ast"
"github.com/pingcap/tidb/store/mockstore"
Expand Down Expand Up @@ -1810,12 +1809,11 @@ func BenchmarkCompileExecutePreparedStmt(b *testing.B) {
}

args := []types.Datum{types.NewDatum(3401544)}
is := se.GetInfoSchema()

b.ResetTimer()
stmtExec := &ast.ExecuteStmt{ExecID: stmtID}
stmtExec := &ast.ExecuteStmt{ExecID: stmtID, BinaryArgs: args}
for i := 0; i < b.N; i++ {
_, _, _, err := executor.CompileExecutePreparedStmt(context.Background(), se, stmtExec, is.(infoschema.InfoSchema), 0, kv.GlobalTxnScope, args)
_, _, _, err := executor.CompileExecutePreparedStmt(context.Background(), se, stmtExec)
if err != nil {
b.Fatal(err)
}
Expand Down
63 changes: 30 additions & 33 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -2235,19 +2235,20 @@ func (s *session) PrepareStmt(sql string) (stmtID uint32, paramCount int, fields
return prepareExec.ID, prepareExec.ParamCount, prepareExec.Fields, nil
}

func (s *session) preparedStmtExec(ctx context.Context,
is infoschema.InfoSchema, snapshotTS uint64,
execStmt *ast.ExecuteStmt, prepareStmt *plannercore.CachedPrepareStmt, replicaReadScope string, args []types.Datum) (sqlexec.RecordSet, error) {

func (s *session) preparedStmtExec(ctx context.Context, execStmt *ast.ExecuteStmt, prepareStmt *plannercore.CachedPrepareStmt) (sqlexec.RecordSet, error) {
failpoint.Inject("assertTxnManagerInPreparedStmtExec", func() {
sessiontxn.RecordAssert(s, "assertTxnManagerInPreparedStmtExec", true)
sessiontxn.AssertTxnManagerInfoSchema(s, is)
if snapshotTS != 0 {
sessiontxn.AssertTxnManagerReadTS(s, snapshotTS)
if prepareStmt.SnapshotTSEvaluator != nil {
staleread.AssertStmtStaleness(s, true)
ts, err := prepareStmt.SnapshotTSEvaluator(s)
if err != nil {
panic(err)
}
sessiontxn.AssertTxnManagerReadTS(s, ts)
}
})

st, tiFlashPushDown, tiFlashExchangePushDown, err := executor.CompileExecutePreparedStmt(ctx, s, execStmt, is, snapshotTS, replicaReadScope, args)
st, tiFlashPushDown, tiFlashExchangePushDown, err := executor.CompileExecutePreparedStmt(ctx, s, execStmt)
if err != nil {
return nil, err
}
Expand All @@ -2267,18 +2268,17 @@ func (s *session) preparedStmtExec(ctx context.Context,

// cachedPointPlanExec is a short path currently ONLY for cached "point select plan" execution
func (s *session) cachedPointPlanExec(ctx context.Context,
is infoschema.InfoSchema, execAst *ast.ExecuteStmt, prepareStmt *plannercore.CachedPrepareStmt, replicaReadScope string, args []types.Datum) (sqlexec.RecordSet, bool, error) {
execAst *ast.ExecuteStmt, prepareStmt *plannercore.CachedPrepareStmt) (sqlexec.RecordSet, bool, error) {

prepared := prepareStmt.PreparedAst

failpoint.Inject("assertTxnManagerInCachedPlanExec", func() {
sessiontxn.RecordAssert(s, "assertTxnManagerInCachedPlanExec", true)
sessiontxn.AssertTxnManagerInfoSchema(s, is)
// stale read should not reach here
staleread.AssertStmtStaleness(s, false)
})

execAst.BinaryArgs = args
is := sessiontxn.GetTxnManager(s).GetTxnInfoSchema()
execPlan, err := planner.OptimizeExecStmt(ctx, s, execAst, is)
if err != nil {
return nil, false, err
Expand Down Expand Up @@ -2324,7 +2324,7 @@ func (s *session) cachedPointPlanExec(ctx context.Context,
var resultSet sqlexec.RecordSet
switch execPlan.(type) {
case *plannercore.PointGetPlan:
resultSet, err = stmt.PointGet(ctx, is)
resultSet, err = stmt.PointGet(ctx)
s.txn.changeToInvalid()
case *plannercore.Update:
stmtCtx.Priority = kv.PriorityHigh
Expand All @@ -2341,9 +2341,9 @@ func (s *session) cachedPointPlanExec(ctx context.Context,
// IsCachedExecOk check if we can execute using plan cached in prepared structure
// Be careful with the short path, current precondition is ths cached plan satisfying
// IsPointGetWithPKOrUniqueKeyByAutoCommit
func (s *session) IsCachedExecOk(ctx context.Context, preparedStmt *plannercore.CachedPrepareStmt, isStaleness bool) (bool, error) {
func (s *session) IsCachedExecOk(preparedStmt *plannercore.CachedPrepareStmt) (bool, error) {
prepared := preparedStmt.PreparedAst
if prepared.CachedPlan == nil || isStaleness {
if prepared.CachedPlan == nil || staleread.IsStmtStaleness(s) {
return false, nil
}
// check auto commit
Expand Down Expand Up @@ -2396,60 +2396,57 @@ func (s *session) ExecutePreparedStmt(ctx context.Context, stmtID uint32, args [
return nil, errors.Errorf("invalid CachedPrepareStmt type")
}

var snapshotTS uint64
replicaReadScope := oracle.GlobalTxnScope
execStmt := &ast.ExecuteStmt{ExecID: stmtID, BinaryArgs: args}
if err := executor.ResetContextOfStmt(s, execStmt); err != nil {
return nil, err
}

staleReadProcessor := staleread.NewStaleReadProcessor(s)
if err = staleReadProcessor.OnExecutePreparedStmt(preparedStmt.SnapshotTSEvaluator); err != nil {
return nil, err
}

txnManager := sessiontxn.GetTxnManager(s)
if staleReadProcessor.IsStaleness() {
snapshotTS = staleReadProcessor.GetStalenessReadTS()
is := staleReadProcessor.GetStalenessInfoSchema()
replicaReadScope = config.GetTxnScopeFromConfig()
err = txnManager.EnterNewTxn(ctx, &sessiontxn.EnterNewTxnRequest{
Type: sessiontxn.EnterNewTxnWithReplaceProvider,
Provider: staleread.NewStalenessTxnContextProvider(s, snapshotTS, is),
s.sessionVars.StmtCtx.IsStaleness = true
err = sessiontxn.GetTxnManager(s).EnterNewTxn(ctx, &sessiontxn.EnterNewTxnRequest{
Type: sessiontxn.EnterNewTxnWithReplaceProvider,
Provider: staleread.NewStalenessTxnContextProvider(
s,
staleReadProcessor.GetStalenessReadTS(),
staleReadProcessor.GetStalenessInfoSchema(),
),
})

if err != nil {
return nil, err
}
}

staleness := snapshotTS > 0
executor.CountStmtNode(preparedStmt.PreparedAst.Stmt, s.sessionVars.InRestrictedSQL)
ok, err = s.IsCachedExecOk(ctx, preparedStmt, staleness)
cacheExecOk, err := s.IsCachedExecOk(preparedStmt)
if err != nil {
return nil, err
}
s.txn.onStmtStart(preparedStmt.SQLDigest.String())
defer s.txn.onStmtEnd()

execStmt := &ast.ExecuteStmt{ExecID: stmtID}
if err := executor.ResetContextOfStmt(s, execStmt); err != nil {
return nil, err
}

if err = s.onTxnManagerStmtStartOrRetry(ctx, execStmt); err != nil {
return nil, err
}
s.setRequestSource(ctx, preparedStmt.PreparedAst.StmtType, preparedStmt.PreparedAst.Stmt)
// even the txn is valid, still need to set session variable for coprocessor usage.
s.sessionVars.RequestSourceType = preparedStmt.PreparedAst.StmtType

if ok {
rs, ok, err := s.cachedPointPlanExec(ctx, txnManager.GetTxnInfoSchema(), execStmt, preparedStmt, replicaReadScope, args)
if cacheExecOk {
rs, ok, err := s.cachedPointPlanExec(ctx, execStmt, preparedStmt)
if err != nil {
return nil, err
}
if ok { // fallback to preparedStmtExec if we cannot get a valid point select plan in cachedPointPlanExec
return rs, nil
}
}
return s.preparedStmtExec(ctx, txnManager.GetTxnInfoSchema(), snapshotTS, execStmt, preparedStmt, replicaReadScope, args)
return s.preparedStmtExec(ctx, execStmt, preparedStmt)
}

func (s *session) DropPreparedStmt(stmtID uint32) error {
Expand Down
30 changes: 28 additions & 2 deletions sessiontxn/txn_context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -587,13 +587,13 @@ func TestTxnContextForPrepareExecute(t *testing.T) {
}

func TestTxnContextForStaleReadInPrepare(t *testing.T) {
store, do, deferFunc := setupTxnContextTest(t)
store, _, deferFunc := setupTxnContextTest(t)
defer deferFunc()
tk := testkit.NewTestKit(t, store)
tk.MustExec("use test")
se := tk.Session()

is1 := do.InfoSchema()
is1 := se.GetDomainInfoSchema()
tk.MustExec("do sleep(0.1)")
tk.MustExec("set @a=now(6)")
tk.MustExec("prepare s1 from 'select * from t1 where id=1'")
Expand Down Expand Up @@ -660,6 +660,32 @@ func TestTxnContextForStaleReadInPrepare(t *testing.T) {
doWithCheckPath(t, se, normalPathRecords, func() {
tk.MustExec("execute s3")
})
se.SetValue(sessiontxn.AssertTxnInfoSchemaKey, nil)

// stale read should not use plan cache
is2 := se.GetDomainInfoSchema()
se.SetValue(sessiontxn.AssertTxnInfoSchemaKey, nil)
tk.MustExec("set @@tx_read_ts=''")
tk.MustExec("do sleep(0.1)")
tk.MustExec("set @b=now(6)")
tk.MustExec("do sleep(0.1)")
tk.MustExec("update t1 set v=v+1 where id=1")
se.SetValue(sessiontxn.AssertTxnInfoSchemaKey, is2)
doWithCheckPath(t, se, path, func() {
rs, err := se.ExecutePreparedStmt(context.TODO(), stmtID1, nil)
require.NoError(t, err)
tk.ResultSetToResult(rs, fmt.Sprintf("%v", rs)).Check(testkit.Rows("1 12"))
})
se.SetValue(sessiontxn.AssertTxnInfoSchemaKey, nil)
tk.MustExec("set @@tx_read_ts=@b")
se.SetValue(sessiontxn.AssertTxnInfoSchemaKey, is2)
doWithCheckPath(t, se, path, func() {
rs, err := se.ExecutePreparedStmt(context.TODO(), stmtID1, nil)
require.NoError(t, err)
tk.ResultSetToResult(rs, fmt.Sprintf("%v", rs)).Check(testkit.Rows("1 11"))
})
se.SetValue(sessiontxn.AssertTxnInfoSchemaKey, nil)
tk.MustExec("set @@tx_read_ts=''")
}

func TestTxnContextPreparedStmtWithForUpdate(t *testing.T) {
Expand Down

0 comments on commit 08403f0

Please sign in to comment.