From 97140d3fb4d49af9f31e2e3608852e1b944caf01 Mon Sep 17 00:00:00 2001 From: tangenta Date: Wed, 15 Dec 2021 12:14:34 +0800 Subject: [PATCH] server: support decoding prepared string args to character_set_client (#30723) --- server/conn.go | 10 ++++++++++ server/conn_stmt.go | 12 ++++++++++-- server/conn_stmt_test.go | 15 ++++++++++++++- server/util.go | 26 ++++++++++++++++++++++++++ 4 files changed, 60 insertions(+), 3 deletions(-) diff --git a/server/conn.go b/server/conn.go index 49a42fe54bb94..cf6cf3b9c1b84 100644 --- a/server/conn.go +++ b/server/conn.go @@ -189,6 +189,7 @@ type clientConn struct { authPlugin string // default authentication plugin isUnixSocket bool // connection is Unix Socket file rsEncoder *resultEncoder // rsEncoder is used to encode the string result to different charsets. + inputDecoder *inputDecoder // inputDecoder is used to decode the different charsets of incoming strings to utf-8. socketCredUID uint32 // UID from the other end of the Unix Socket // mu is used for cancelling the execution of current transaction. mu struct { @@ -931,6 +932,15 @@ func (cc *clientConn) initResultEncoder(ctx context.Context) { cc.rsEncoder = newResultEncoder(chs) } +func (cc *clientConn) initInputEncoder(ctx context.Context) { + chs, err := variable.GetSessionOrGlobalSystemVar(cc.ctx.GetSessionVars(), variable.CharacterSetClient) + if err != nil { + chs = "" + logutil.Logger(ctx).Warn("get character_set_client system variable failed", zap.Error(err)) + } + cc.inputDecoder = newInputDecoder(chs) +} + // initConnect runs the initConnect SQL statement if it has been specified. // The semantics are MySQL compatible. func (cc *clientConn) initConnect(ctx context.Context) error { diff --git a/server/conn_stmt.go b/server/conn_stmt.go index 980de55c6c896..07e4699d52738 100644 --- a/server/conn_stmt.go +++ b/server/conn_stmt.go @@ -46,6 +46,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/parser/charset" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/parser/terror" plannercore "github.com/pingcap/tidb/planner/core" @@ -167,6 +168,8 @@ func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) (err e paramTypes []byte paramValues []byte ) + cc.initInputEncoder(ctx) + defer cc.inputDecoder.clean() numParams := stmt.NumParams() args := make([]types.Datum, numParams) if numParams > 0 { @@ -194,7 +197,7 @@ func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) (err e paramValues = data[pos+1:] } - err = parseExecArgs(cc.ctx.GetSessionVars().StmtCtx, args, stmt.BoundParams(), nullBitmaps, stmt.GetParamsType(), paramValues) + err = parseExecArgs(cc.ctx.GetSessionVars().StmtCtx, args, stmt.BoundParams(), nullBitmaps, stmt.GetParamsType(), paramValues, cc.inputDecoder) stmt.Reset() if err != nil { return errors.Annotate(err, cc.preparedStmt2String(stmtID)) @@ -310,7 +313,8 @@ func parseStmtFetchCmd(data []byte) (uint32, uint32, error) { return stmtID, fetchSize, nil } -func parseExecArgs(sc *stmtctx.StatementContext, args []types.Datum, boundParams [][]byte, nullBitmap, paramTypes, paramValues []byte) (err error) { +func parseExecArgs(sc *stmtctx.StatementContext, args []types.Datum, boundParams [][]byte, + nullBitmap, paramTypes, paramValues []byte, enc *inputDecoder) (err error) { pos := 0 var ( tmp interface{} @@ -318,6 +322,9 @@ func parseExecArgs(sc *stmtctx.StatementContext, args []types.Datum, boundParams n int isNull bool ) + if enc == nil { + enc = newInputDecoder(charset.CharsetUTF8) + } for i := 0; i < len(args); i++ { // if params had received via ComStmtSendLongData, use them directly. @@ -543,6 +550,7 @@ func parseExecArgs(sc *stmtctx.StatementContext, args []types.Datum, boundParams } if !isNull { + v = enc.decodeInput(v) tmp = string(hack.String(v)) } else { tmp = nil diff --git a/server/conn_stmt_test.go b/server/conn_stmt_test.go index 512093e85098a..cd63aea7e66bf 100644 --- a/server/conn_stmt_test.go +++ b/server/conn_stmt_test.go @@ -197,12 +197,25 @@ func TestParseExecArgs(t *testing.T) { }, } for _, tt := range tests { - err := parseExecArgs(&stmtctx.StatementContext{}, tt.args.args, tt.args.boundParams, tt.args.nullBitmap, tt.args.paramTypes, tt.args.paramValues) + err := parseExecArgs(&stmtctx.StatementContext{}, tt.args.args, tt.args.boundParams, tt.args.nullBitmap, tt.args.paramTypes, tt.args.paramValues, nil) require.Truef(t, terror.ErrorEqual(err, tt.err), "err %v", err) require.Equal(t, tt.expect, tt.args.args[0].GetValue()) } } +func TestParseExecArgsAndEncode(t *testing.T) { + dt := make([]types.Datum, 1) + err := parseExecArgs(&stmtctx.StatementContext{}, + dt, + [][]byte{nil}, + []byte{0x0}, + []byte{mysql.TypeVarchar, 0}, + []byte{4, 178, 226, 202, 212}, + newInputDecoder("gbk")) + require.NoError(t, err) + require.Equal(t, "测试", dt[0].GetValue()) +} + func TestParseStmtFetchCmd(t *testing.T) { tests := []struct { arg []byte diff --git a/server/util.go b/server/util.go index 6a8cbad8386c5..4f3b8b04f29a8 100644 --- a/server/util.go +++ b/server/util.go @@ -290,6 +290,32 @@ func dumpBinaryRow(buffer []byte, columns []*ColumnInfo, row chunk.Row, d *resul return buffer, nil } +type inputDecoder struct { + encoding *charset.Encoding + + buffer []byte +} + +func newInputDecoder(chs string) *inputDecoder { + return &inputDecoder{ + encoding: charset.NewEncoding(chs), + buffer: nil, + } +} + +// clean prevents the inputDecoder from holding too much memory. +func (i *inputDecoder) clean() { + i.buffer = nil +} + +func (i *inputDecoder) decodeInput(src []byte) []byte { + result, err := i.encoding.Decode(i.buffer, src) + if err != nil { + return src + } + return result +} + type resultEncoder struct { // chsName and encoding are unchanged after the initialization from // session variable @@character_set_results.