Skip to content

Commit

Permalink
http2: add Server.WriteByteTimeout
Browse files Browse the repository at this point in the history
Transports support a WriteByteTimeout option which sets the maximum
amount of time we can go without being able to write any bytes to
a connection. Add an equivalent option to Server for consistency.

Fixes golang/go#61777

Change-Id: Iaa8a69dfc403906eb224829320f901e5a6a5c429
Reviewed-on: https://go-review.googlesource.com/c/net/+/601496
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Carlos Amedee <carlos@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
  • Loading branch information
neild committed Sep 23, 2024
1 parent 3c333c0 commit 541dbe5
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 28 deletions.
5 changes: 2 additions & 3 deletions http2/connframes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package http2

import (
"bytes"
"context"
"io"
"net/http"
"os"
Expand Down Expand Up @@ -295,7 +294,7 @@ func (tf *testConnFramer) wantClosed() {
if err == nil {
tf.t.Fatalf("got unexpected frame (want closed connection): %v", fr)
}
if err == context.DeadlineExceeded {
if err == os.ErrDeadlineExceeded {
tf.t.Fatalf("connection is not closed; want it to be")
}
}
Expand All @@ -306,7 +305,7 @@ func (tf *testConnFramer) wantIdle() {
if err == nil {
tf.t.Fatalf("got unexpected frame (want idle connection): %v", fr)
}
if err != context.DeadlineExceeded {
if err != os.ErrDeadlineExceeded {
tf.t.Fatalf("got unexpected frame error (want idle connection): %v", err)
}
}
Expand Down
53 changes: 46 additions & 7 deletions http2/http2.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ import (
"bufio"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"os"
"sort"
Expand Down Expand Up @@ -237,13 +238,19 @@ func (cw closeWaiter) Wait() {
// Its buffered writer is lazily allocated as needed, to minimize
// idle memory usage with many connections.
type bufferedWriter struct {
_ incomparable
w io.Writer // immutable
bw *bufio.Writer // non-nil when data is buffered
_ incomparable
group synctestGroupInterface // immutable
conn net.Conn // immutable
bw *bufio.Writer // non-nil when data is buffered
byteTimeout time.Duration // immutable, WriteByteTimeout
}

func newBufferedWriter(w io.Writer) *bufferedWriter {
return &bufferedWriter{w: w}
func newBufferedWriter(group synctestGroupInterface, conn net.Conn, timeout time.Duration) *bufferedWriter {
return &bufferedWriter{
group: group,
conn: conn,
byteTimeout: timeout,
}
}

// bufWriterPoolBufferSize is the size of bufio.Writer's
Expand All @@ -270,7 +277,7 @@ func (w *bufferedWriter) Available() int {
func (w *bufferedWriter) Write(p []byte) (n int, err error) {
if w.bw == nil {
bw := bufWriterPool.Get().(*bufio.Writer)
bw.Reset(w.w)
bw.Reset((*bufferedWriterTimeoutWriter)(w))
w.bw = bw
}
return w.bw.Write(p)
Expand All @@ -288,6 +295,38 @@ func (w *bufferedWriter) Flush() error {
return err
}

type bufferedWriterTimeoutWriter bufferedWriter

func (w *bufferedWriterTimeoutWriter) Write(p []byte) (n int, err error) {
return writeWithByteTimeout(w.group, w.conn, w.byteTimeout, p)
}

// writeWithByteTimeout writes to conn.
// If more than timeout passes without any bytes being written to the connection,
// the write fails.
func writeWithByteTimeout(group synctestGroupInterface, conn net.Conn, timeout time.Duration, p []byte) (n int, err error) {
if timeout <= 0 {
return conn.Write(p)
}
for {
var now time.Time
if group == nil {
now = time.Now()
} else {
now = group.Now()
}
conn.SetWriteDeadline(now.Add(timeout))
nn, err := conn.Write(p[n:])
n += nn
if n == len(p) || nn == 0 || !errors.Is(err, os.ErrDeadlineExceeded) {
// Either we finished the write, made no progress, or hit the deadline.
// Whichever it is, we're done now.
conn.SetWriteDeadline(time.Time{})
return n, err
}
}
}

func mustUint31(v int32) uint32 {
if v < 0 || v > 2147483647 {
panic("out of range")
Expand Down
12 changes: 11 additions & 1 deletion http2/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,12 @@ type Server struct {
// If zero or negative, there is no timeout.
IdleTimeout time.Duration

// WriteByteTimeout is the timeout after which a connection will be
// closed if no data can be written to it. The timeout begins when data is
// available to write, and is extended whenever any bytes are written.
// If zero or negative, there is no timeout.
WriteByteTimeout time.Duration

// MaxUploadBufferPerConnection is the size of the initial flow
// control window for each connections. The HTTP/2 spec does not
// allow this to be smaller than 65535 or larger than 2^32-1.
Expand Down Expand Up @@ -446,7 +452,7 @@ func (s *Server) serveConn(c net.Conn, opts *ServeConnOpts, newf func(*serverCon
conn: c,
baseCtx: baseCtx,
remoteAddrStr: c.RemoteAddr().String(),
bw: newBufferedWriter(c),
bw: newBufferedWriter(s.group, c, s.WriteByteTimeout),
handler: opts.handler(),
streams: make(map[uint32]*stream),
readFrameCh: make(chan readFrameResult),
Expand Down Expand Up @@ -1320,6 +1326,10 @@ func (sc *serverConn) wroteFrame(res frameWriteResult) {
sc.writingFrame = false
sc.writingFrameAsync = false

if res.err != nil {
sc.conn.Close()
}

wr := res.wr

if writeEndsStream(wr.write) {
Expand Down
32 changes: 32 additions & 0 deletions http2/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4674,3 +4674,35 @@ func TestServerSetReadWriteDeadlineRace(t *testing.T) {
}
resp.Body.Close()
}

func TestServerWriteByteTimeout(t *testing.T) {
const timeout = 1 * time.Second
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
w.Write(make([]byte, 100))
}, func(s *Server) {
s.WriteByteTimeout = timeout
})
st.greet()

st.cc.(*synctestNetConn).SetReadBufferSize(1) // write one byte at a time
st.writeHeaders(HeadersFrameParam{
StreamID: 1,
BlockFragment: st.encodeHeader(),
EndStream: true,
EndHeaders: true,
})

// Read a few bytes, staying just under WriteByteTimeout.
for i := 0; i < 10; i++ {
st.advance(timeout - 1)
if n, err := st.cc.Read(make([]byte, 1)); n != 1 || err != nil {
t.Fatalf("read %v: %v, %v; want 1, nil", i, n, err)
}
}

// Wait for WriteByteTimeout.
// The connection should close.
st.advance(1 * time.Second) // timeout after writing one byte
st.advance(1 * time.Second) // timeout after failing to write any more bytes
st.wantClosed()
}
24 changes: 7 additions & 17 deletions http2/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (
"net/http"
"net/http/httptrace"
"net/textproto"
"os"
"sort"
"strconv"
"strings"
Expand Down Expand Up @@ -499,6 +498,7 @@ func (cs *clientStream) closeReqBodyLocked() {
}

type stickyErrWriter struct {
group synctestGroupInterface
conn net.Conn
timeout time.Duration
err *error
Expand All @@ -508,22 +508,9 @@ func (sew stickyErrWriter) Write(p []byte) (n int, err error) {
if *sew.err != nil {
return 0, *sew.err
}
for {
if sew.timeout != 0 {
sew.conn.SetWriteDeadline(time.Now().Add(sew.timeout))
}
nn, err := sew.conn.Write(p[n:])
n += nn
if n < len(p) && nn > 0 && errors.Is(err, os.ErrDeadlineExceeded) {
// Keep extending the deadline so long as we're making progress.
continue
}
if sew.timeout != 0 {
sew.conn.SetWriteDeadline(time.Time{})
}
*sew.err = err
return n, err
}
n, err = writeWithByteTimeout(sew.group, sew.conn, sew.timeout, p)
*sew.err = err
return n, err
}

// noCachedConnError is the concrete type of ErrNoCachedConn, which
Expand Down Expand Up @@ -792,10 +779,12 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro
pings: make(map[[8]byte]chan struct{}),
reqHeaderMu: make(chan struct{}, 1),
}
var group synctestGroupInterface
if t.transportTestHooks != nil {
t.markNewGoroutine()
t.transportTestHooks.newclientconn(cc)
c = cc.tconn
group = t.group
}
if VerboseLogs {
t.vlogf("http2: Transport creating client conn %p to %v", cc, c.RemoteAddr())
Expand All @@ -807,6 +796,7 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro
// TODO: adjust this writer size to account for frame size +
// MTU + crypto/tls record padding.
cc.bw = bufio.NewWriter(stickyErrWriter{
group: group,
conn: c,
timeout: t.WriteByteTimeout,
err: &cc.werr,
Expand Down

0 comments on commit 541dbe5

Please sign in to comment.