From 3484db1f68a7b493faffc08c1897360fdd7a67f9 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Sat, 29 Jun 2024 08:36:17 +0900 Subject: [PATCH] improve error handling in writePacket (#1601) * handle error before success case. * return io.ErrShortWrite if not all bytes were written but err is nil. * return err instead of ErrInvalidConn. --- connection_test.go | 6 ++++-- packets.go | 34 +++++++++++++++++----------------- 2 files changed, 21 insertions(+), 19 deletions(-) diff --git a/connection_test.go b/connection_test.go index c59cb6176..6f8d2a6d7 100644 --- a/connection_test.go +++ b/connection_test.go @@ -163,6 +163,8 @@ func TestPingMarkBadConnection(t *testing.T) { netConn: nc, buf: newBuffer(nc), maxAllowedPacket: defaultMaxAllowedPacket, + closech: make(chan struct{}), + cfg: NewConfig(), } err := mc.Ping(context.Background()) @@ -184,8 +186,8 @@ func TestPingErrInvalidConn(t *testing.T) { err := mc.Ping(context.Background()) - if err != ErrInvalidConn { - t.Errorf("expected ErrInvalidConn, got %#v", err) + if err != nc.err { + t.Errorf("expected %#v, got %#v", nc.err, err) } } diff --git a/packets.go b/packets.go index b90b14c5c..df850fd41 100644 --- a/packets.go +++ b/packets.go @@ -124,32 +124,32 @@ func (mc *mysqlConn) writePacket(data []byte) error { } n, err := mc.netConn.Write(data[:4+size]) - if err == nil && n == 4+size { - mc.sequence++ - if size != maxPacketSize { - return nil - } - pktLen -= size - data = data[size:] - continue - } - - // Handle error - if err == nil { // n != len(data) - mc.cleanup() - mc.log(ErrMalformPkt) - } else { + if err != nil { if cerr := mc.canceled.Value(); cerr != nil { return cerr } + mc.cleanup() if n == 0 && pktLen == len(data)-4 { // only for the first loop iteration when nothing was written yet + mc.log(err) return errBadConnNoWrite + } else { + return err } + } + if n != 4+size { + // io.Writer(b) must return a non-nil error if it cannot write len(b) bytes. + // The io.ErrShortWrite error is used to indicate that this rule has not been followed. mc.cleanup() - mc.log(err) + return io.ErrShortWrite + } + + mc.sequence++ + if size != maxPacketSize { + return nil } - return ErrInvalidConn + pktLen -= size + data = data[size:] } }