From ddb7b36cb6e37206d4e7f810c2b4723fd58a4500 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E8=B6=85?= Date: Fri, 18 Aug 2023 15:36:39 +0800 Subject: [PATCH] *: Add system variable `tidb_session_alias` to log a custom field `session_alias` in session log (#46072) close pingcap/tidb#46071 --- parser/model/model.go | 8 ++++ planner/core/preprocess.go | 52 +++++++++------------- server/BUILD.bazel | 1 + server/conn.go | 18 +++++++- session/test/variable/BUILD.bazel | 2 +- session/test/variable/variable_test.go | 21 +++++++++ sessionctx/variable/session.go | 3 ++ sessionctx/variable/sysvar.go | 20 +++++++++ sessionctx/variable/tidb_vars.go | 2 + util/logutil/BUILD.bazel | 3 ++ util/logutil/log.go | 60 ++++++++++++++++++-------- util/logutil/log_test.go | 40 ++++++++++++++++- util/logutil/main_test.go | 2 + util/tracing/BUILD.bazel | 2 + util/tracing/util.go | 22 ++++++++++ util/tracing/util_test.go | 20 +++++++++ util/util.go | 12 ++++++ util/util_test.go | 18 ++++++++ 18 files changed, 251 insertions(+), 55 deletions(-) diff --git a/parser/model/model.go b/parser/model/model.go index 64cda27eb4d53..1402afabd7983 100644 --- a/parser/model/model.go +++ b/parser/model/model.go @@ -2167,3 +2167,11 @@ func (s WindowRepeatType) String() string { return "" } } + +// TraceInfo is the information for trace. +type TraceInfo struct { + // ConnectionID is the id of the connection + ConnectionID uint64 `json:"connection_id"` + // SessionAlias is the alias of session + SessionAlias string `json:"session_alias"` +} diff --git a/planner/core/preprocess.go b/planner/core/preprocess.go index 7b4cdb58c1aba..b231939ce3081 100644 --- a/planner/core/preprocess.go +++ b/planner/core/preprocess.go @@ -785,32 +785,32 @@ func (p *preprocessor) checkSetOprSelectList(stmt *ast.SetOprSelectList) { } func (p *preprocessor) checkCreateDatabaseGrammar(stmt *ast.CreateDatabaseStmt) { - if isIncorrectName(stmt.Name.L) { + if util.IsInCorrectIdentifierName(stmt.Name.L) { p.err = dbterror.ErrWrongDBName.GenWithStackByArgs(stmt.Name) } } func (p *preprocessor) checkAlterDatabaseGrammar(stmt *ast.AlterDatabaseStmt) { // for 'ALTER DATABASE' statement, database name can be empty to alter default database. - if isIncorrectName(stmt.Name.L) && !stmt.AlterDefaultDatabase { + if util.IsInCorrectIdentifierName(stmt.Name.L) && !stmt.AlterDefaultDatabase { p.err = dbterror.ErrWrongDBName.GenWithStackByArgs(stmt.Name) } } func (p *preprocessor) checkDropDatabaseGrammar(stmt *ast.DropDatabaseStmt) { - if isIncorrectName(stmt.Name.L) { + if util.IsInCorrectIdentifierName(stmt.Name.L) { p.err = dbterror.ErrWrongDBName.GenWithStackByArgs(stmt.Name) } } func (p *preprocessor) checkFlashbackTableGrammar(stmt *ast.FlashBackTableStmt) { - if isIncorrectName(stmt.NewName) { + if util.IsInCorrectIdentifierName(stmt.NewName) { p.err = dbterror.ErrWrongTableName.GenWithStackByArgs(stmt.NewName) } } func (p *preprocessor) checkFlashbackDatabaseGrammar(stmt *ast.FlashBackDatabaseStmt) { - if isIncorrectName(stmt.NewName) { + if util.IsInCorrectIdentifierName(stmt.NewName) { p.err = dbterror.ErrWrongDBName.GenWithStackByArgs(stmt.NewName) } } @@ -874,7 +874,7 @@ func (p *preprocessor) checkCreateTableGrammar(stmt *ast.CreateTableStmt) { } } tName := stmt.Table.Name.String() - if isIncorrectName(tName) { + if util.IsInCorrectIdentifierName(tName) { p.err = dbterror.ErrWrongTableName.GenWithStackByArgs(tName) return } @@ -938,7 +938,7 @@ func (p *preprocessor) checkCreateTableGrammar(stmt *ast.CreateTableStmt) { if stmt.Partition != nil { for _, def := range stmt.Partition.Definitions { pName := def.Name.String() - if isIncorrectName(pName) { + if util.IsInCorrectIdentifierName(pName) { p.err = dbterror.ErrWrongPartitionName.GenWithStackByArgs() return } @@ -948,12 +948,12 @@ func (p *preprocessor) checkCreateTableGrammar(stmt *ast.CreateTableStmt) { func (p *preprocessor) checkCreateViewGrammar(stmt *ast.CreateViewStmt) { vName := stmt.ViewName.Name.String() - if isIncorrectName(vName) { + if util.IsInCorrectIdentifierName(vName) { p.err = dbterror.ErrWrongTableName.GenWithStackByArgs(vName) return } for _, col := range stmt.Cols { - if isIncorrectName(col.String()) { + if util.IsInCorrectIdentifierName(col.String()) { p.err = dbterror.ErrWrongColumnName.GenWithStackByArgs(col) return } @@ -1014,7 +1014,7 @@ func (p *preprocessor) checkDropTableGrammar(stmt *ast.DropTableStmt) { func (p *preprocessor) checkDropTemporaryTableGrammar(stmt *ast.DropTableStmt) { currentDB := model.NewCIStr(p.sctx.GetSessionVars().CurrentDB) for _, t := range stmt.Tables { - if isIncorrectName(t.Name.String()) { + if util.IsInCorrectIdentifierName(t.Name.String()) { p.err = dbterror.ErrWrongTableName.GenWithStackByArgs(t.Name.String()) return } @@ -1045,7 +1045,7 @@ func (p *preprocessor) checkDropTemporaryTableGrammar(stmt *ast.DropTableStmt) { func (p *preprocessor) checkDropTableNames(tables []*ast.TableName) { for _, t := range tables { - if isIncorrectName(t.Name.String()) { + if util.IsInCorrectIdentifierName(t.Name.String()) { p.err = dbterror.ErrWrongTableName.GenWithStackByArgs(t.Name.String()) return } @@ -1118,7 +1118,7 @@ func checkColumnOptions(isTempTable bool, ops []*ast.ColumnOption) (int, error) func (p *preprocessor) checkCreateIndexGrammar(stmt *ast.CreateIndexStmt) { tName := stmt.Table.Name.String() - if isIncorrectName(tName) { + if util.IsInCorrectIdentifierName(tName) { p.err = dbterror.ErrWrongTableName.GenWithStackByArgs(tName) return } @@ -1152,12 +1152,12 @@ func (p *preprocessor) checkRenameTableGrammar(stmt *ast.RenameTableStmt) { } func (p *preprocessor) checkRenameTable(oldTable, newTable string) { - if isIncorrectName(oldTable) { + if util.IsInCorrectIdentifierName(oldTable) { p.err = dbterror.ErrWrongTableName.GenWithStackByArgs(oldTable) return } - if isIncorrectName(newTable) { + if util.IsInCorrectIdentifierName(newTable) { p.err = dbterror.ErrWrongTableName.GenWithStackByArgs(newTable) return } @@ -1182,7 +1182,7 @@ func (p *preprocessor) checkRepairTableGrammar(stmt *ast.RepairTableStmt) { func (p *preprocessor) checkAlterTableGrammar(stmt *ast.AlterTableStmt) { tName := stmt.Table.Name.String() - if isIncorrectName(tName) { + if util.IsInCorrectIdentifierName(tName) { p.err = dbterror.ErrWrongTableName.GenWithStackByArgs(tName) return } @@ -1190,7 +1190,7 @@ func (p *preprocessor) checkAlterTableGrammar(stmt *ast.AlterTableStmt) { for _, spec := range specs { if spec.NewTable != nil { ntName := spec.NewTable.Name.String() - if isIncorrectName(ntName) { + if util.IsInCorrectIdentifierName(ntName) { p.err = dbterror.ErrWrongTableName.GenWithStackByArgs(ntName) return } @@ -1217,7 +1217,7 @@ func (p *preprocessor) checkAlterTableGrammar(stmt *ast.AlterTableStmt) { } case ast.AlterTableAddStatistics, ast.AlterTableDropStatistics: statsName := spec.Statistics.StatsName - if isIncorrectName(statsName) { + if util.IsInCorrectIdentifierName(statsName) { msg := fmt.Sprintf("Incorrect statistics name: %s", statsName) p.err = ErrInternal.GenWithStack(msg) return @@ -1225,7 +1225,7 @@ func (p *preprocessor) checkAlterTableGrammar(stmt *ast.AlterTableStmt) { case ast.AlterTableAddPartitions: for _, def := range spec.PartDefinitions { pName := def.Name.String() - if isIncorrectName(pName) { + if util.IsInCorrectIdentifierName(pName) { p.err = dbterror.ErrWrongPartitionName.GenWithStackByArgs() return } @@ -1334,7 +1334,7 @@ func checkReferInfoForTemporaryTable(tableMetaInfo *model.TableInfo) error { func checkColumn(colDef *ast.ColumnDef) error { // Check column name. cName := colDef.Name.Name.String() - if isIncorrectName(cName) { + if util.IsInCorrectIdentifierName(cName) { return dbterror.ErrWrongColumnName.GenWithStackByArgs(cName) } @@ -1457,18 +1457,6 @@ func isInvalidDefaultValue(colDef *ast.ColumnDef) bool { return false } -// isIncorrectName checks if the identifier is incorrect. -// See https://dev.mysql.com/doc/refman/5.7/en/identifiers.html -func isIncorrectName(name string) bool { - if len(name) == 0 { - return true - } - if name[len(name)-1] == ' ' { - return true - } - return false -} - // checkContainDotColumn checks field contains the table name. // for example :create table t (c1.c2 int default null). func (p *preprocessor) checkContainDotColumn(stmt *ast.CreateTableStmt) { @@ -1683,7 +1671,7 @@ func (p *preprocessor) resolveAlterTableStmt(node *ast.AlterTableStmt) { func (p *preprocessor) resolveCreateSequenceStmt(stmt *ast.CreateSequenceStmt) { sName := stmt.Name.Name.String() - if isIncorrectName(sName) { + if util.IsInCorrectIdentifierName(sName) { p.err = dbterror.ErrWrongTableName.GenWithStackByArgs(sName) return } diff --git a/server/BUILD.bazel b/server/BUILD.bazel index 590485e673fe1..1b8e22be32874 100644 --- a/server/BUILD.bazel +++ b/server/BUILD.bazel @@ -36,6 +36,7 @@ go_library( "//parser/ast", "//parser/auth", "//parser/charset", + "//parser/model", "//parser/mysql", "//parser/terror", "//planner/core", diff --git a/server/conn.go b/server/conn.go index 94d39eda35656..0d5ce53e3b137 100644 --- a/server/conn.go +++ b/server/conn.go @@ -67,6 +67,7 @@ import ( "github.com/pingcap/tidb/parser" "github.com/pingcap/tidb/parser/ast" "github.com/pingcap/tidb/parser/auth" + "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/parser/terror" plannercore "github.com/pingcap/tidb/planner/core" @@ -977,17 +978,30 @@ func (cc *clientConn) Run(ctx context.Context) { close(cc.quit) }() + parentCtx := ctx + var traceInfo *model.TraceInfo // Usually, client connection status changes between [dispatching] <=> [reading]. // When some event happens, server may notify this client connection by setting // the status to special values, for example: kill or graceful shutdown. // The client connection would detect the events when it fails to change status // by CAS operation, it would then take some actions accordingly. for { + sessVars := cc.ctx.GetSessionVars() + if alias := sessVars.SessionAlias; traceInfo == nil || traceInfo.SessionAlias != alias { + // We should reset the context trace info when traceInfo not inited or session alias changed. + traceInfo = &model.TraceInfo{ + ConnectionID: cc.connectionID, + SessionAlias: alias, + } + ctx = logutil.WithSessionAlias(parentCtx, sessVars.SessionAlias) + ctx = tracing.ContextWithTraceInfo(ctx, traceInfo) + } + // Close connection between txn when we are going to shutdown server. // Note the current implementation when shutting down, for an idle connection, the connection may block at readPacket() // consider provider a way to close the connection directly after sometime if we can not read any data. if cc.server.inShutdownMode.Load() { - if !cc.ctx.GetSessionVars().InTxn() { + if !sessVars.InTxn() { return } } @@ -1216,7 +1230,7 @@ func (cc *clientConn) dispatch(ctx context.Context, data []byte) error { defer task.End() trace.Log(ctx, "sql", lc.String()) - ctx = logutil.WithTraceLogger(ctx, cc.connectionID) + ctx = logutil.WithTraceLogger(ctx, tracing.TraceInfoFromContext(ctx)) taskID := *(*uint64)(unsafe.Pointer(task)) ctx = pprof.WithLabels(ctx, pprof.Labels("trace", strconv.FormatUint(taskID, 10))) diff --git a/session/test/variable/BUILD.bazel b/session/test/variable/BUILD.bazel index 9c0c49db0ceec..2edfbf857341e 100644 --- a/session/test/variable/BUILD.bazel +++ b/session/test/variable/BUILD.bazel @@ -8,7 +8,7 @@ go_test( "variable_test.go", ], flaky = True, - shard_count = 21, + shard_count = 22, deps = [ "//config", "//kv", diff --git a/session/test/variable/variable_test.go b/session/test/variable/variable_test.go index 6c9624f629d95..00a5089a28704 100644 --- a/session/test/variable/variable_test.go +++ b/session/test/variable/variable_test.go @@ -608,3 +608,24 @@ func TestSysdateIsNow(t *testing.T) { tk.MustQuery("show variables like '%tidb_sysdate_is_now%'").Check(testkit.Rows("tidb_sysdate_is_now ON")) require.True(t, tk.Session().GetSessionVars().SysdateIsNow) } + +func TestSessionAlias(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + tk.MustQuery("select @@tidb_session_alias").Check(testkit.Rows("")) + // normal set + tk.MustExec("set @@tidb_session_alias='alias123'") + tk.MustQuery("select @@tidb_session_alias").Check(testkit.Rows("alias123")) + // set a long value + val := "0123456789012345678901234567890123456789012345678901234567890123456789" + tk.MustExec("set @@tidb_session_alias=?", val) + tk.MustQuery("select @@tidb_session_alias").Check(testkit.Rows(val[:64])) + // an invalid value + err := tk.ExecToErr("set @@tidb_session_alias='abc '") + require.EqualError(t, err, "[variable:1231]Incorrect value for variable @@tidb_session_alias 'abc '") + tk.MustQuery("select @@tidb_session_alias").Check(testkit.Rows(val[:64])) + // reset to empty + tk.MustExec("set @@tidb_session_alias=''") + tk.MustQuery("select @@tidb_session_alias").Check(testkit.Rows("")) +} diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index f9eb3714dee67..99018080e4c75 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -1534,6 +1534,9 @@ type SessionVars struct { // When set to true, skip missing partition stats and continue to merge other partition stats to global stats. // When set to false, give up merging partition stats to global stats. SkipMissingPartitionStats bool + + // SessionAlias is the identifier of the session + SessionAlias string } // GetOptimizerFixControlMap returns the specified value of the optimizer fix control. diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index bcf1f4efe8337..eeb51efbe9a57 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -38,6 +38,7 @@ import ( "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/types" _ "github.com/pingcap/tidb/types/parser_driver" // for parser driver + "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/collate" "github.com/pingcap/tidb/util/gctuner" "github.com/pingcap/tidb/util/logutil" @@ -2792,6 +2793,25 @@ var defaultSysVars = []*SysVar{ }, GetGlobal: func(ctx context.Context, vars *SessionVars) (string, error) { return BoolToOnOff(EnableCheckConstraint.Load()), nil }}, + {Scope: ScopeSession, Name: TiDBSessionAlias, Value: "", Type: TypeStr, + Validation: func(s *SessionVars, normalizedValue string, originalValue string, _ ScopeFlag) (string, error) { + if len(normalizedValue) > 64 { + s.StmtCtx.AppendWarning(ErrTruncatedWrongValue.GenWithStackByArgs(TiDBSessionAlias, originalValue)) + normalizedValue = normalizedValue[:64] + } + + if len(normalizedValue) > 0 && util.IsInCorrectIdentifierName(normalizedValue) { + return "", ErrWrongValueForVar.GenWithStack("Incorrect value for variable @@%s '%s'", TiDBSessionAlias, normalizedValue) + } + + return normalizedValue, nil + }, + SetSession: func(vars *SessionVars, s string) error { + vars.SessionAlias = s + return nil + }, GetSession: func(vars *SessionVars) (string, error) { + return vars.SessionAlias, nil + }}, } func setTiFlashComputeDispatchPolicy(s *SessionVars, val string) error { diff --git a/sessionctx/variable/tidb_vars.go b/sessionctx/variable/tidb_vars.go index 89159050c2f5e..381c86c39daa6 100644 --- a/sessionctx/variable/tidb_vars.go +++ b/sessionctx/variable/tidb_vars.go @@ -1083,6 +1083,8 @@ const ( // When set to true, skip missing partition stats and continue to merge other partition stats to global stats. // When set to false, give up merging partition stats to global stats. TiDBSkipMissingPartitionStats = "tidb_skip_missing_partition_stats" + // TiDBSessionAlias indicates the alias of a session which is used for tracing. + TiDBSessionAlias = "tidb_session_alias" ) // TiDB intentional limits diff --git a/util/logutil/BUILD.bazel b/util/logutil/BUILD.bazel index 5b3bf9d7a8da1..c032019d3c608 100644 --- a/util/logutil/BUILD.bazel +++ b/util/logutil/BUILD.bazel @@ -10,6 +10,7 @@ go_library( importpath = "github.com/pingcap/tidb/util/logutil", visibility = ["//visibility:public"], deps = [ + "//parser/model", "@com_github_golang_protobuf//proto", "@com_github_grpc_ecosystem_go_grpc_middleware//logging/zap", "@com_github_opentracing_opentracing_go//:opentracing-go", @@ -35,7 +36,9 @@ go_test( flaky = True, deps = [ "//kv", + "//parser/model", "//testkit/testsetup", + "@com_github_google_uuid//:uuid", "@com_github_pingcap_kvproto//pkg/metapb", "@com_github_pingcap_log//:log", "@com_github_stretchr_testify//require", diff --git a/util/logutil/log.go b/util/logutil/log.go index 0e48e5ce64678..b01ea86987fd9 100644 --- a/util/logutil/log.go +++ b/util/logutil/log.go @@ -27,6 +27,7 @@ import ( tlog "github.com/opentracing/opentracing-go/log" "github.com/pingcap/errors" "github.com/pingcap/log" + "github.com/pingcap/tidb/parser/model" "github.com/tikv/client-go/v2/tikv" "go.uber.org/zap" "go.uber.org/zap/zapcore" @@ -201,43 +202,56 @@ func BgLogger() *zap.Logger { // WithConnID attaches connId to context. func WithConnID(ctx context.Context, connID uint64) context.Context { - var logger *zap.Logger - if ctxLogger, ok := ctx.Value(CtxLogKey).(*zap.Logger); ok { - logger = ctxLogger - } else { - logger = log.L() - } - return context.WithValue(ctx, CtxLogKey, logger.With(zap.Uint64("conn", connID))) + return WithFields(ctx, zap.Uint64("conn", connID)) +} + +// WithSessionAlias attaches session_alias to context +func WithSessionAlias(ctx context.Context, alias string) context.Context { + return WithFields(ctx, zap.String("session_alias", alias)) } // WithCategory attaches category to context. func WithCategory(ctx context.Context, category string) context.Context { - var logger *zap.Logger - if ctxLogger, ok := ctx.Value(CtxLogKey).(*zap.Logger); ok { - logger = ctxLogger - } else { - logger = log.L() + return WithFields(ctx, zap.String("category", category)) +} + +func fieldsFromTraceInfo(info *model.TraceInfo) []zap.Field { + if info == nil { + return nil + } + + fields := make([]zap.Field, 0, 2) + if info.ConnectionID != 0 { + fields = append(fields, zap.Uint64("conn", info.ConnectionID)) } - return context.WithValue(ctx, CtxLogKey, logger.With(zap.String("category", category))) + + if info.SessionAlias != "" { + fields = append(fields, zap.String("session_alias", info.SessionAlias)) + } + + return fields } // WithTraceLogger attaches trace identifier to context -func WithTraceLogger(ctx context.Context, connID uint64) context.Context { +func WithTraceLogger(ctx context.Context, info *model.TraceInfo) context.Context { var logger *zap.Logger if ctxLogger, ok := ctx.Value(CtxLogKey).(*zap.Logger); ok { logger = ctxLogger } else { logger = log.L() } - return context.WithValue(ctx, CtxLogKey, wrapTraceLogger(ctx, connID, logger)) + return context.WithValue(ctx, CtxLogKey, wrapTraceLogger(ctx, info, logger)) } -func wrapTraceLogger(ctx context.Context, connID uint64, logger *zap.Logger) *zap.Logger { +func wrapTraceLogger(ctx context.Context, info *model.TraceInfo, logger *zap.Logger) *zap.Logger { return logger.WithOptions(zap.WrapCore(func(core zapcore.Core) zapcore.Core { tl := &traceLog{ctx: ctx} // cfg.Format == "", never return error enc, _ := log.NewTextEncoder(&log.Config{}) - traceCore := log.NewTextCore(enc, tl, tl).With([]zapcore.Field{zap.Uint64("conn", connID)}) + traceCore := log.NewTextCore(enc, tl, tl) + if fields := fieldsFromTraceInfo(info); len(fields) > 0 { + traceCore = traceCore.With(fields) + } return zapcore.NewTee(traceCore, core) })) } @@ -261,13 +275,23 @@ func (*traceLog) Sync() error { // WithKeyValue attaches key/value to context. func WithKeyValue(ctx context.Context, key, value string) context.Context { + return WithFields(ctx, zap.String(key, value)) +} + +// WithFields attaches key/value to context. +func WithFields(ctx context.Context, fields ...zap.Field) context.Context { var logger *zap.Logger if ctxLogger, ok := ctx.Value(CtxLogKey).(*zap.Logger); ok { logger = ctxLogger } else { logger = log.L() } - return context.WithValue(ctx, CtxLogKey, logger.With(zap.String(key, value))) + + if len(fields) > 0 { + logger = logger.With(fields...) + } + + return context.WithValue(ctx, CtxLogKey, logger) } // TraceEventKey presents the TraceEventKey in span log. diff --git a/util/logutil/log_test.go b/util/logutil/log_test.go index 1985ae43ed0e5..975c80d16200d 100644 --- a/util/logutil/log_test.go +++ b/util/logutil/log_test.go @@ -17,17 +17,37 @@ package logutil import ( "bufio" "context" + "fmt" "io" "os" "runtime" "testing" + "github.com/google/uuid" "github.com/pingcap/log" + "github.com/pingcap/tidb/parser/model" "github.com/stretchr/testify/require" "go.uber.org/zap" "go.uber.org/zap/zapcore" ) +func TestFieldsFromTraceInfo(t *testing.T) { + fields := fieldsFromTraceInfo(nil) + require.Equal(t, 0, len(fields)) + + fields = fieldsFromTraceInfo(&model.TraceInfo{}) + require.Equal(t, 0, len(fields)) + + fields = fieldsFromTraceInfo(&model.TraceInfo{ConnectionID: 1}) + require.Equal(t, []zap.Field{zap.Uint64("conn", 1)}, fields) + + fields = fieldsFromTraceInfo(&model.TraceInfo{SessionAlias: "alias123"}) + require.Equal(t, []zap.Field{zap.String("session_alias", "alias123")}, fields) + + fields = fieldsFromTraceInfo(&model.TraceInfo{ConnectionID: 1, SessionAlias: "alias123"}) + require.Equal(t, []zap.Field{zap.Uint64("conn", 1), zap.String("session_alias", "alias123")}, fields) +} + func TestZapLoggerWithKeys(t *testing.T) { if runtime.GOOS == "windows" { // Skip this test on windows for two reason: @@ -36,7 +56,7 @@ func TestZapLoggerWithKeys(t *testing.T) { t.Skip("skip on windows") } - fileCfg := FileLogConfig{log.FileLogConfig{Filename: "zap_log", MaxSize: 4096}} + fileCfg := FileLogConfig{log.FileLogConfig{Filename: fmt.Sprintf("zap_log_%s", uuid.NewString()), MaxSize: 4096}} conf := NewLogConfig("info", DefaultLogFormat, "", fileCfg, false) err := InitLogger(conf) require.NoError(t, err) @@ -46,11 +66,27 @@ func TestZapLoggerWithKeys(t *testing.T) { err = os.Remove(fileCfg.Filename) require.NoError(t, err) + conf = NewLogConfig("info", DefaultLogFormat, "", fileCfg, false) + err = InitLogger(conf) + require.NoError(t, err) + ctx = WithConnID(context.Background(), connID) + ctx = WithSessionAlias(ctx, "alias123") + testZapLogger(ctx, t, fileCfg.Filename, zapLogWithTraceInfoPattern) + err = os.Remove(fileCfg.Filename) + require.NoError(t, err) + + err = InitLogger(conf) + require.NoError(t, err) + ctx1 := WithFields(context.Background(), zap.Int64("conn", 123), zap.String("session_alias", "alias456")) + testZapLogger(ctx1, t, fileCfg.Filename, zapLogWithTraceInfoPattern) + err = os.Remove(fileCfg.Filename) + require.NoError(t, err) + err = InitLogger(conf) require.NoError(t, err) key := "ctxKey" val := "ctxValue" - ctx1 := WithKeyValue(context.Background(), key, val) + ctx1 = WithKeyValue(context.Background(), key, val) testZapLogger(ctx1, t, fileCfg.Filename, zapLogWithKeyValPatternByCtx) err = os.Remove(fileCfg.Filename) require.NoError(t, err) diff --git a/util/logutil/main_test.go b/util/logutil/main_test.go index f2a143eee0c23..0f6b850e693a9 100644 --- a/util/logutil/main_test.go +++ b/util/logutil/main_test.go @@ -25,6 +25,8 @@ const ( // zapLogPatern is used to match the zap log format, such as the following log: // [2019/02/13 15:56:05.385 +08:00] [INFO] [log_test.go:167] ["info message"] [conn=conn1] ["str key"=val] ["int key"=123] zapLogWithConnIDPattern = `\[\d\d\d\d/\d\d/\d\d \d\d:\d\d:\d\d.\d\d\d\ (\+|-)\d\d:\d\d\] \[(FATAL|ERROR|WARN|INFO|DEBUG)\] \[([\w_%!$@.,+~-]+|\\.)+:\d+\] \[.*\] \[conn=.*\] (\[.*=.*\]).*\n` + // [2019/02/13 15:56:05.385 +08:00] [INFO] [log_test.go:167] ["info message"] [conn=conn1] [session_alias=alias] ["str key"=val] ["int key"=123] + zapLogWithTraceInfoPattern = `\[\d\d\d\d/\d\d/\d\d \d\d:\d\d:\d\d.\d\d\d\ (\+|-)\d\d:\d\d\] \[(FATAL|ERROR|WARN|INFO|DEBUG)\] \[([\w_%!$@.,+~-]+|\\.)+:\d+\] \[.*\] \[conn=.*\] \[session_alias=.*\] (\[.*=.*\]).*\n` // [2019/02/13 15:56:05.385 +08:00] [INFO] [log_test.go:167] ["info message"] [ctxKey=ctxKey1] ["str key"=val] ["int key"=123] zapLogWithKeyValPatternByCtx = `\[\d\d\d\d/\d\d/\d\d \d\d:\d\d:\d\d.\d\d\d\ (\+|-)\d\d:\d\d\] \[(FATAL|ERROR|WARN|INFO|DEBUG)\] \[([\w_%!$@.,+~-]+|\\.)+:\d+\] \[.*\] \[ctxKey=.*\] (\[.*=.*\]).*\n` // [2019/02/13 15:56:05.385 +08:00] [INFO] [log_test.go:167] ["info message"] [coreKey=coreKey1] ["str key"=val] ["int key"=123] diff --git a/util/tracing/BUILD.bazel b/util/tracing/BUILD.bazel index 9f6ccac4a590e..ea2cada504ca3 100644 --- a/util/tracing/BUILD.bazel +++ b/util/tracing/BUILD.bazel @@ -9,6 +9,7 @@ go_library( importpath = "github.com/pingcap/tidb/util/tracing", visibility = ["//visibility:public"], deps = [ + "//parser/model", "@com_github_opentracing_basictracer_go//:basictracer-go", "@com_github_opentracing_opentracing_go//:opentracing-go", ], @@ -26,6 +27,7 @@ go_test( embed = [":tracing"], flaky = True, deps = [ + "//parser/model", "//testkit/testsetup", "@com_github_opentracing_basictracer_go//:basictracer-go", "@com_github_opentracing_opentracing_go//:opentracing-go", diff --git a/util/tracing/util.go b/util/tracing/util.go index 924e2eb039f44..e26a15340cd96 100644 --- a/util/tracing/util.go +++ b/util/tracing/util.go @@ -20,11 +20,16 @@ import ( "github.com/opentracing/basictracer-go" "github.com/opentracing/opentracing-go" + "github.com/pingcap/tidb/parser/model" ) // TiDBTrace is set as Baggage on traces which are used for tidb tracing. const TiDBTrace = "tr" +type sqlTracingCtxKeyType struct{} + +var sqlTracingCtxKey = sqlTracingCtxKeyType{} + // A CallbackRecorder immediately invokes itself on received trace spans. type CallbackRecorder func(sp basictracer.RawSpan) @@ -110,3 +115,20 @@ func (r Region) End() { } r.Region.End() } + +// TraceInfoFromContext returns the `model.TraceInfo` in context +func TraceInfoFromContext(ctx context.Context) *model.TraceInfo { + val := ctx.Value(sqlTracingCtxKey) + if info, ok := val.(*model.TraceInfo); ok { + return info + } + return nil +} + +// ContextWithTraceInfo creates a new `model.TraceInfo` for context +func ContextWithTraceInfo(ctx context.Context, info *model.TraceInfo) context.Context { + if info == nil { + return ctx + } + return context.WithValue(ctx, sqlTracingCtxKey, info) +} diff --git a/util/tracing/util_test.go b/util/tracing/util_test.go index 119f2777017b9..3af7ac5b99fa4 100644 --- a/util/tracing/util_test.go +++ b/util/tracing/util_test.go @@ -20,6 +20,7 @@ import ( "github.com/opentracing/basictracer-go" "github.com/opentracing/opentracing-go" + "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/util/tracing" "github.com/stretchr/testify/require" ) @@ -130,3 +131,22 @@ func TestTreeRelationship(t *testing.T) { require.Equal(t, collectedSpans[1].Context.SpanID, collectedSpans[2].ParentSpanID) } } + +func TestTraceInfoFromContext(t *testing.T) { + ctx := context.Background() + // get info from a non-tracing context + require.Nil(t, tracing.TraceInfoFromContext(ctx)) + // ContextWithTraceInfo with a nil info will return the original context + require.Equal(t, ctx, tracing.ContextWithTraceInfo(ctx, nil)) + // create a context with trace info + ctx, cancel := context.WithCancel(context.WithValue(ctx, "val1", "a")) + ctx = tracing.ContextWithTraceInfo(ctx, &model.TraceInfo{ConnectionID: 12345, SessionAlias: "alias1"}) + // new context should have the same value as the original one + info := tracing.TraceInfoFromContext(ctx) + require.Equal(t, uint64(12345), info.ConnectionID) + require.Equal(t, "alias1", info.SessionAlias) + require.Equal(t, "a", ctx.Value("val1")) + require.NoError(t, ctx.Err()) + cancel() + require.Error(t, ctx.Err()) +} diff --git a/util/util.go b/util/util.go index 271fa32730784..69882cf3aed19 100644 --- a/util/util.go +++ b/util/util.go @@ -283,3 +283,15 @@ func ReadLines(reader *bufio.Reader, count int, maxLineSize int) ([][]byte, erro } return lines, nil } + +// IsInCorrectIdentifierName checks if the identifier is incorrect. +// See https://dev.mysql.com/doc/refman/5.7/en/identifiers.html +func IsInCorrectIdentifierName(name string) bool { + if len(name) == 0 { + return true + } + if name[len(name)-1] == ' ' { + return true + } + return false +} diff --git a/util/util_test.go b/util/util_test.go index 18f5aa8173060..71cbe8e18f84d 100644 --- a/util/util_test.go +++ b/util/util_test.go @@ -94,3 +94,21 @@ line3`)) require.Equal(t, io.EOF, err) require.Len(t, line, 0) } + +func TestIsInCorrectIdentifierName(t *testing.T) { + tests := []struct { + name string + input string + correct bool + }{ + {"Empty identifier", "", true}, + {"Ending space", "test ", true}, + {"Correct identifier", "test", false}, + {"Other correct Identifier", "aaa --\n\txyz", false}, + } + + for _, tc := range tests { + got := IsInCorrectIdentifierName(tc.input) + require.Equalf(t, tc.correct, got, "IsInCorrectIdentifierName(%v) != %v", tc.name, tc.correct) + } +}