diff --git a/internal/pool/conn.go b/internal/pool/conn.go index 7f45bc0bb..d315c7937 100644 --- a/internal/pool/conn.go +++ b/internal/pool/conn.go @@ -3,8 +3,10 @@ package pool import ( "bufio" "context" + "crypto/tls" "net" "sync/atomic" + "syscall" "time" "github.com/redis/go-redis/v9/internal/proto" @@ -16,6 +18,9 @@ type Conn struct { usedAt int64 // atomic netConn net.Conn + // for checking the health status of the connection, it may be nil. + sysConn syscall.Conn + rd *proto.Reader bw *bufio.Writer wr *proto.Writer @@ -34,6 +39,7 @@ func NewConn(netConn net.Conn) *Conn { cn.bw = bufio.NewWriter(netConn) cn.wr = proto.NewWriter(cn.bw) cn.SetUsedAt(time.Now()) + cn.setSysConn() return cn } @@ -50,6 +56,22 @@ func (cn *Conn) SetNetConn(netConn net.Conn) { cn.netConn = netConn cn.rd.Reset(netConn) cn.bw.Reset(netConn) + cn.setSysConn() +} + +func (cn *Conn) setSysConn() { + cn.sysConn = nil + conn := cn.netConn + if conn == nil { + return + } + if tlsConn, ok := conn.(*tls.Conn); ok { + conn = tlsConn.NetConn() + } + + if sysConn, ok := conn.(syscall.Conn); ok { + cn.sysConn = sysConn + } } func (cn *Conn) Write(b []byte) (int, error) { diff --git a/internal/pool/conn_check.go b/internal/pool/conn_check.go index 07c261c2b..f28833850 100644 --- a/internal/pool/conn_check.go +++ b/internal/pool/conn_check.go @@ -3,28 +3,14 @@ package pool import ( - "crypto/tls" "errors" "io" - "net" "syscall" - "time" ) var errUnexpectedRead = errors.New("unexpected read from socket") -func connCheck(conn net.Conn) error { - // Reset previous timeout. - _ = conn.SetDeadline(time.Time{}) - - // Check if tls.Conn. - if c, ok := conn.(*tls.Conn); ok { - conn = c.NetConn() - } - sysConn, ok := conn.(syscall.Conn) - if !ok { - return nil - } +func connCheck(sysConn syscall.Conn) error { rawConn, err := sysConn.SyscallConn() if err != nil { return err diff --git a/internal/pool/conn_check_dummy.go b/internal/pool/conn_check_dummy.go index 295da1268..2d270cf56 100644 --- a/internal/pool/conn_check_dummy.go +++ b/internal/pool/conn_check_dummy.go @@ -2,8 +2,8 @@ package pool -import "net" +import "syscall" -func connCheck(conn net.Conn) error { +func connCheck(_ syscall.Conn) error { return nil } diff --git a/internal/pool/conn_check_test.go b/internal/pool/conn_check_test.go index 214993339..d19969adf 100644 --- a/internal/pool/conn_check_test.go +++ b/internal/pool/conn_check_test.go @@ -6,6 +6,7 @@ import ( "crypto/tls" "net" "net/http/httptest" + "syscall" "time" . "github.com/bsm/ginkgo/v2" @@ -16,16 +17,20 @@ var _ = Describe("tests conn_check with real conns", func() { var ts *httptest.Server var conn net.Conn var tlsConn *tls.Conn + var sysConn syscall.Conn + var tlsSysConn syscall.Conn var err error BeforeEach(func() { ts = httptest.NewServer(nil) conn, err = net.DialTimeout(ts.Listener.Addr().Network(), ts.Listener.Addr().String(), time.Second) Expect(err).NotTo(HaveOccurred()) + sysConn = conn.(syscall.Conn) tlsTestServer := httptest.NewUnstartedServer(nil) tlsTestServer.StartTLS() tlsConn, err = tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, tlsTestServer.Listener.Addr().Network(), tlsTestServer.Listener.Addr().String(), &tls.Config{InsecureSkipVerify: true}) Expect(err).NotTo(HaveOccurred()) + tlsSysConn = tlsConn.NetConn().(syscall.Conn) }) AfterEach(func() { @@ -33,33 +38,37 @@ var _ = Describe("tests conn_check with real conns", func() { }) It("good conn check", func() { - Expect(connCheck(conn)).NotTo(HaveOccurred()) + Expect(connCheck(sysConn)).NotTo(HaveOccurred()) Expect(conn.Close()).NotTo(HaveOccurred()) - Expect(connCheck(conn)).To(HaveOccurred()) + Expect(connCheck(sysConn)).To(HaveOccurred()) }) It("good tls conn check", func() { - Expect(connCheck(tlsConn)).NotTo(HaveOccurred()) + Expect(connCheck(tlsSysConn)).NotTo(HaveOccurred()) Expect(tlsConn.Close()).NotTo(HaveOccurred()) - Expect(connCheck(tlsConn)).To(HaveOccurred()) + Expect(connCheck(tlsSysConn)).To(HaveOccurred()) }) It("bad conn check", func() { Expect(conn.Close()).NotTo(HaveOccurred()) - Expect(connCheck(conn)).To(HaveOccurred()) + Expect(connCheck(sysConn)).To(HaveOccurred()) }) It("bad tls conn check", func() { Expect(tlsConn.Close()).NotTo(HaveOccurred()) - Expect(connCheck(tlsConn)).To(HaveOccurred()) + Expect(connCheck(tlsSysConn)).To(HaveOccurred()) }) It("check conn deadline", func() { Expect(conn.SetDeadline(time.Now())).NotTo(HaveOccurred()) time.Sleep(time.Millisecond * 10) - Expect(connCheck(conn)).NotTo(HaveOccurred()) + Expect(connCheck(sysConn)).To(HaveOccurred()) + + Expect(conn.SetDeadline(time.Now().Add(time.Minute))).NotTo(HaveOccurred()) + time.Sleep(time.Millisecond * 10) + Expect(connCheck(sysConn)).NotTo(HaveOccurred()) Expect(conn.Close()).NotTo(HaveOccurred()) }) }) diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 2125f3e13..9b84993cc 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -499,6 +499,8 @@ func (p *ConnPool) Close() error { return firstErr } +var zeroTime = time.Time{} + func (p *ConnPool) isHealthyConn(cn *Conn) bool { now := time.Now() @@ -509,8 +511,12 @@ func (p *ConnPool) isHealthyConn(cn *Conn) bool { return false } - if connCheck(cn.netConn) != nil { - return false + if cn.sysConn != nil { + // reset previous timeout. + _ = cn.netConn.SetDeadline(zeroTime) + if connCheck(cn.sysConn) != nil { + return false + } } cn.SetUsedAt(now)