Skip to content

Commit

Permalink
remove errBadConnNoWrite and markBadConn
Browse files Browse the repository at this point in the history
  • Loading branch information
methane committed May 18, 2024
1 parent af8d793 commit db0cc0e
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 104 deletions.
75 changes: 34 additions & 41 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,23 +99,12 @@ func (mc *mysqlConn) handleParams() (err error) {
return
}

func (mc *mysqlConn) markBadConn(err error) error {
if mc == nil {
return err
}
if err != errBadConnNoWrite {
return err
}
return driver.ErrBadConn
}

func (mc *mysqlConn) Begin() (driver.Tx, error) {
return mc.begin(false)
}

func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
if mc.closed.Load() {
mc.log(ErrInvalidConn)
return nil, driver.ErrBadConn
}
var q string
Expand All @@ -128,7 +117,7 @@ func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
if err == nil {
return &mysqlTx{mc}, err
}
return nil, mc.markBadConn(err)
return nil, err
}

func (mc *mysqlConn) Close() (err error) {
Expand Down Expand Up @@ -177,7 +166,6 @@ func (mc *mysqlConn) error() error {

func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
if mc.closed.Load() {
mc.log(ErrInvalidConn)
return nil, driver.ErrBadConn
}
// Send command
Expand Down Expand Up @@ -218,8 +206,8 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
buf, err := mc.buf.takeCompleteBuffer()
if err != nil {
// can not take the buffer. Something must be wrong with the connection
mc.log(err)
return "", ErrInvalidConn
mc.cleanup()
return "", err
}
buf = buf[:0]
argPos := 0
Expand Down Expand Up @@ -310,7 +298,6 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin

func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
if mc.closed.Load() {
mc.log(ErrInvalidConn)
return nil, driver.ErrBadConn
}
if len(args) != 0 {
Expand All @@ -330,15 +317,15 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err
copied := mc.result
return &copied, err
}
return nil, mc.markBadConn(err)
return nil, err
}

// Internal function to execute commands
func (mc *mysqlConn) exec(query string) error {
handleOk := mc.clearResult()
// Send command
if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
return mc.markBadConn(err)
return err
}

// Read Result
Expand Down Expand Up @@ -370,7 +357,6 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
handleOk := mc.clearResult()

if mc.closed.Load() {
mc.log(ErrInvalidConn)
return nil, driver.ErrBadConn
}
if len(args) != 0 {
Expand All @@ -384,33 +370,37 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
}
query = prepared
}

// Send command
err := mc.writeCommandPacketStr(comQuery, query)
if err == nil {
// Read Result
var resLen int
resLen, err = handleOk.readResultSetHeaderPacket()
if err == nil {
rows := new(textRows)
rows.mc = mc
if err != nil {
return nil, err
}

if resLen == 0 {
rows.rs.done = true
// Read Result
var resLen int
resLen, err = handleOk.readResultSetHeaderPacket()
if err != nil {
return nil, err
}

switch err := rows.NextResultSet(); err {
case nil, io.EOF:
return rows, nil
default:
return nil, err
}
}
rows := new(textRows)
rows.mc = mc

// Columns
rows.rs.columns, err = mc.readColumns(resLen)
return rows, err
if resLen == 0 {
rows.rs.done = true

switch err := rows.NextResultSet(); err {
case nil, io.EOF:
return rows, nil
default:
return nil, err
}
}
return nil, mc.markBadConn(err)

// Columns
rows.rs.columns, err = mc.readColumns(resLen)
return rows, err
}

// Gets the value of the given MySQL System Variable
Expand Down Expand Up @@ -465,7 +455,6 @@ func (mc *mysqlConn) finish() {
// Ping implements driver.Pinger interface
func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
if mc.closed.Load() {
mc.log(ErrInvalidConn)
return driver.ErrBadConn
}

Expand All @@ -476,7 +465,7 @@ func (mc *mysqlConn) Ping(ctx context.Context) (err error) {

handleOk := mc.clearResult()
if err = mc.writeCommandPacket(comPing); err != nil {
return mc.markBadConn(err)
return err
}

return handleOk.readResultOK()
Expand Down Expand Up @@ -682,8 +671,12 @@ func (mc *mysqlConn) ResetSession(ctx context.Context) error {
return nil
}

var _ driver.SessionResetter = &mysqlConn{}

// IsValid implements driver.Validator interface
// (From Go 1.15)
func (mc *mysqlConn) IsValid() bool {
return !mc.closed.Load()
}

var _ driver.Validator = &mysqlConn{}
9 changes: 5 additions & 4 deletions connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,12 +163,13 @@ func TestPingMarkBadConnection(t *testing.T) {
netConn: nc,
buf: newBuffer(nc),
maxAllowedPacket: defaultMaxAllowedPacket,
closech: make(chan struct{}),
}

err := mc.Ping(context.Background())

if err != driver.ErrBadConn {
t.Errorf("expected driver.ErrBadConn, got %#v", err)
if !errors.Is(err, nc.err) {
t.Errorf("expected %v, got %#v", nc.err, err)
}
}

Expand All @@ -184,8 +185,8 @@ func TestPingErrInvalidConn(t *testing.T) {

err := mc.Ping(context.Background())

if err != ErrInvalidConn {
t.Errorf("expected ErrInvalidConn, got %#v", err)
if !errors.Is(err, nc.err) {
t.Errorf("expected %v, got %#v", nc.err, err)
}
}

Expand Down
6 changes: 0 additions & 6 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,6 @@ var (
ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?")
ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the `Config.MaxAllowedPacket`")
ErrBusyBuffer = errors.New("busy buffer")

// errBadConnNoWrite is used for connection errors where nothing was sent to the database yet.
// If this happens first in a function starting a database interaction, it should be replaced by driver.ErrBadConn
// to trigger a resend.
// See https://github.com/go-sql-driver/mysql/pull/302
errBadConnNoWrite = errors.New("bad connection")
)

var defaultLogger = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile))
Expand Down
96 changes: 46 additions & 50 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,37 +117,32 @@ func (mc *mysqlConn) writePacket(data []byte) error {
// Write packet
if mc.writeTimeout > 0 {
if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil {
mc.log(err)
mc.cleanup()
return err
}
}

n, err := mc.netConn.Write(data[:4+size])
if err == nil && n == 4+size {
mc.sequence++
if size != maxPacketSize {
return nil
}
pktLen -= size
data = data[size:]
continue
}

// Handle error
if err == nil { // n != len(data)
if err != nil {
mc.cleanup()
mc.log(ErrMalformPkt)
} else {
if cerr := mc.canceled.Value(); cerr != nil {
return cerr
}
if n == 0 && pktLen == len(data)-4 {
// only for the first loop iteration when nothing was written yet
return errBadConnNoWrite
}
return err
}
if n != size+4 {
mc.cleanup()
mc.log(err)
return io.ErrShortWrite
}

mc.sequence++
if size != maxPacketSize {
return nil
}
return ErrInvalidConn
pktLen -= size
data = data[size:]
continue
}
}

Expand Down Expand Up @@ -303,8 +298,8 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
data, err := mc.buf.takeBuffer(pktLen + 4)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
mc.log(err)
return errBadConnNoWrite
mc.cleanup()
return err
}

// ClientFlags [32 bit]
Expand Down Expand Up @@ -392,8 +387,8 @@ func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error {
data, err := mc.buf.takeSmallBuffer(pktLen)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
mc.log(err)
return errBadConnNoWrite
mc.cleanup()
return err
}

// Add the auth data [EOF]
Expand All @@ -412,8 +407,8 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error {
data, err := mc.buf.takeSmallBuffer(4 + 1)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
mc.log(err)
return errBadConnNoWrite
mc.cleanup()
return err
}

// Add command byte
Expand All @@ -431,8 +426,8 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
data, err := mc.buf.takeBuffer(pktLen + 4)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
mc.log(err)
return errBadConnNoWrite
mc.cleanup()
return err
}

// Add command byte
Expand All @@ -452,8 +447,8 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
data, err := mc.buf.takeSmallBuffer(4 + 1 + 4)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
mc.log(err)
return errBadConnNoWrite
mc.cleanup()
return err
}

// Add command byte
Expand Down Expand Up @@ -522,32 +517,33 @@ func (mc *okHandler) readResultOK() error {
}

// Result Set Header Packet
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response.html
func (mc *okHandler) readResultSetHeaderPacket() (int, error) {
// handleOkPacket replaces both values; other cases leave the values unchanged.
mc.result.affectedRows = append(mc.result.affectedRows, 0)
mc.result.insertIds = append(mc.result.insertIds, 0)

data, err := mc.conn().readPacket()
if err == nil {
switch data[0] {

case iOK:
return 0, mc.handleOkPacket(data)
if err != nil {
return 0, err
}

case iERR:
return 0, mc.conn().handleErrorPacket(data)
switch data[0] {
case iOK:
return 0, mc.handleOkPacket(data)

case iLocalInFile:
return 0, mc.handleInFileRequest(string(data[1:]))
}
case iERR:
return 0, mc.conn().handleErrorPacket(data)

// column count
num, _, _ := readLengthEncodedInteger(data)
// ignore remaining data in the packet. see #1478.
return int(num), nil
case iLocalInFile:
return 0, mc.handleInFileRequest(string(data[1:]))
}
return 0, err

// column count
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_text_resultset.html
num, _, _ := readLengthEncodedInteger(data)
// ignore remaining data in the packet. see #1478.
return int(num), nil
}

// Error Packet
Expand Down Expand Up @@ -994,8 +990,8 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
}
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
mc.log(err)
return errBadConnNoWrite
mc.cleanup()
return err
}

// command [1 byte]
Expand Down Expand Up @@ -1193,8 +1189,8 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
if valuesCap != cap(paramValues) {
data = append(data[:pos], paramValues...)
if err = mc.buf.store(data); err != nil {
mc.log(err)
return errBadConnNoWrite
mc.cleanup()
return err
}
}

Expand Down
Loading

0 comments on commit db0cc0e

Please sign in to comment.