Skip to content

Commit

Permalink
Merge pull request #698 from iangudger/master
Browse files Browse the repository at this point in the history
Fix connection leak on conn.ssl or conn.startup failure
  • Loading branch information
maddyblue authored Jan 23, 2018
2 parents 27ea5d9 + 5253e15 commit 61fe37a
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 2 deletions.
10 changes: 10 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,15 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) {
if err != nil {
return nil, err
}

// cn.ssl and cn.startup panic on error. Make sure we don't leak cn.c.
panicking := true
defer func() {
if panicking {
cn.c.Close()
}
}()

cn.ssl(o)
cn.buf = bufio.NewReader(cn.c)
cn.startup(o)
Expand All @@ -347,6 +356,7 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) {
if timeout, ok := o["connect_timeout"]; ok && timeout != "0" {
err = cn.c.SetDeadline(time.Time{})
}
panicking = false
return cn, err
}

Expand Down
58 changes: 56 additions & 2 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func forceBinaryParameters() bool {
}
}

func openTestConnConninfo(conninfo string) (*sql.DB, error) {
func testConninfo(conninfo string) string {
defaultTo := func(envvar string, value string) {
if os.Getenv(envvar) == "" {
os.Setenv(envvar, value)
Expand All @@ -43,8 +43,11 @@ func openTestConnConninfo(conninfo string) (*sql.DB, error) {
!strings.HasPrefix(conninfo, "postgresql://") {
conninfo = conninfo + " binary_parameters=yes"
}
return conninfo
}

return sql.Open("postgres", conninfo)
func openTestConnConninfo(conninfo string) (*sql.DB, error) {
return sql.Open("postgres", testConninfo(conninfo))
}

func openTestConn(t Fatalistic) *sql.DB {
Expand Down Expand Up @@ -637,6 +640,57 @@ func TestErrorDuringStartup(t *testing.T) {
}
}

type testConn struct {
closed bool
net.Conn
}

func (c *testConn) Close() error {
c.closed = true
return c.Conn.Close()
}

type testDialer struct {
conns []*testConn
}

func (d *testDialer) Dial(ntw, addr string) (net.Conn, error) {
c, err := net.Dial(ntw, addr)
if err != nil {
return nil, err
}
tc := &testConn{Conn: c}
d.conns = append(d.conns, tc)
return tc, nil
}

func (d *testDialer) DialTimeout(ntw, addr string, timeout time.Duration) (net.Conn, error) {
c, err := net.DialTimeout(ntw, addr, timeout)
if err != nil {
return nil, err
}
tc := &testConn{Conn: c}
d.conns = append(d.conns, tc)
return tc, nil
}

func TestErrorDuringStartupClosesConn(t *testing.T) {
// Don't use the normal connection setup, this is intended to
// blow up in the startup packet from a non-existent user.
var d testDialer
c, err := DialOpen(&d, testConninfo("user=thisuserreallydoesntexist"))
if err == nil {
c.Close()
t.Fatal("expected dial error")
}
if len(d.conns) != 1 {
t.Fatalf("got len(d.conns) = %d, want = %d", len(d.conns), 1)
}
if !d.conns[0].closed {
t.Error("connection leaked")
}
}

func TestBadConn(t *testing.T) {
var err error

Expand Down

0 comments on commit 61fe37a

Please sign in to comment.