Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server: support decoding prepared string args to character_set_client #30723

Merged
merged 5 commits into from
Dec 15, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ type clientConn struct {
authPlugin string // default authentication plugin
isUnixSocket bool // connection is Unix Socket file
rsEncoder *resultEncoder // rsEncoder is used to encode the string result to different charsets.
inputDecoder *inputDecoder // inputDecoder is used to decode the different charsets of incoming strings to utf-8.
socketCredUID uint32 // UID from the other end of the Unix Socket
// mu is used for cancelling the execution of current transaction.
mu struct {
Expand Down Expand Up @@ -964,6 +965,15 @@ func (cc *clientConn) initResultEncoder(ctx context.Context) {
cc.rsEncoder = newResultEncoder(chs)
}

func (cc *clientConn) initInputEncoder(ctx context.Context) {
chs, err := variable.GetSessionOrGlobalSystemVar(cc.ctx.GetSessionVars(), variable.CharacterSetClient)
if err != nil {
chs = ""
logutil.Logger(ctx).Warn("get character_set_client system variable failed", zap.Error(err))
}
cc.inputDecoder = newInputDecoder(chs)
}

// initConnect runs the initConnect SQL statement if it has been specified.
// The semantics are MySQL compatible.
func (cc *clientConn) initConnect(ctx context.Context) error {
Expand Down
12 changes: 10 additions & 2 deletions server/conn_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ import (

"github.com/pingcap/errors"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/parser/charset"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/parser/terror"
plannercore "github.com/pingcap/tidb/planner/core"
Expand Down Expand Up @@ -167,6 +168,8 @@ func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) (err e
paramTypes []byte
paramValues []byte
)
cc.initInputEncoder(ctx)
defer cc.inputDecoder.clean()
numParams := stmt.NumParams()
args := make([]types.Datum, numParams)
if numParams > 0 {
Expand Down Expand Up @@ -194,7 +197,7 @@ func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) (err e
paramValues = data[pos+1:]
}

err = parseExecArgs(cc.ctx.GetSessionVars().StmtCtx, args, stmt.BoundParams(), nullBitmaps, stmt.GetParamsType(), paramValues)
err = parseExecArgs(cc.ctx.GetSessionVars().StmtCtx, args, stmt.BoundParams(), nullBitmaps, stmt.GetParamsType(), paramValues, cc.inputDecoder)
stmt.Reset()
if err != nil {
return errors.Annotate(err, cc.preparedStmt2String(stmtID))
Expand Down Expand Up @@ -310,14 +313,18 @@ func parseStmtFetchCmd(data []byte) (uint32, uint32, error) {
return stmtID, fetchSize, nil
}

func parseExecArgs(sc *stmtctx.StatementContext, args []types.Datum, boundParams [][]byte, nullBitmap, paramTypes, paramValues []byte) (err error) {
func parseExecArgs(sc *stmtctx.StatementContext, args []types.Datum, boundParams [][]byte,
nullBitmap, paramTypes, paramValues []byte, enc *inputDecoder) (err error) {
pos := 0
var (
tmp interface{}
v []byte
n int
isNull bool
)
if enc == nil {
tangenta marked this conversation as resolved.
Show resolved Hide resolved
enc = newInputDecoder(charset.CharsetUTF8)
}

for i := 0; i < len(args); i++ {
// if params had received via ComStmtSendLongData, use them directly.
Expand Down Expand Up @@ -543,6 +550,7 @@ func parseExecArgs(sc *stmtctx.StatementContext, args []types.Datum, boundParams
}

if !isNull {
v = enc.decodeInput(v)
tmp = string(hack.String(v))
} else {
tmp = nil
Expand Down
15 changes: 14 additions & 1 deletion server/conn_stmt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,12 +197,25 @@ func TestParseExecArgs(t *testing.T) {
},
}
for _, tt := range tests {
err := parseExecArgs(&stmtctx.StatementContext{}, tt.args.args, tt.args.boundParams, tt.args.nullBitmap, tt.args.paramTypes, tt.args.paramValues)
err := parseExecArgs(&stmtctx.StatementContext{}, tt.args.args, tt.args.boundParams, tt.args.nullBitmap, tt.args.paramTypes, tt.args.paramValues, nil)
require.Truef(t, terror.ErrorEqual(err, tt.err), "err %v", err)
require.Equal(t, tt.expect, tt.args.args[0].GetValue())
}
}

func TestParseExecArgsAndEncode(t *testing.T) {
dt := make([]types.Datum, 1)
err := parseExecArgs(&stmtctx.StatementContext{},
dt,
[][]byte{nil},
[]byte{0x0},
[]byte{mysql.TypeVarchar, 0},
[]byte{4, 178, 226, 202, 212},
newInputDecoder("gbk"))
require.NoError(t, err)
require.Equal(t, "测试", dt[0].GetValue())
}

func TestParseStmtFetchCmd(t *testing.T) {
tests := []struct {
arg []byte
Expand Down
26 changes: 26 additions & 0 deletions server/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,32 @@ func dumpBinaryRow(buffer []byte, columns []*ColumnInfo, row chunk.Row, d *resul
return buffer, nil
}

type inputDecoder struct {
encoding *charset.Encoding

buffer []byte
}

func newInputDecoder(chs string) *inputDecoder {
return &inputDecoder{
encoding: charset.NewEncoding(chs),
buffer: nil,
}
}

// clean prevents the inputDecoder from holding too much memory.
func (i *inputDecoder) clean() {
i.buffer = nil
}

func (i *inputDecoder) decodeInput(src []byte) []byte {
result, err := i.encoding.Decode(i.buffer, src)
if err != nil {
return src
}
return result
}

type resultEncoder struct {
// chsName and encoding are unchanged after the initialization from
// session variable @@character_set_results.
Expand Down