From d27440de3ff4bcac15a44957845d2b7c1042ad35 Mon Sep 17 00:00:00 2001
From: Michael Raimondi
Date: Mon, 5 Nov 2018 18:13:34 -0500
Subject: [PATCH] client: set TCP_USER_TIMEOUT socket option for linux (#2307)

Implements proposal A18 (https://github.com/grpc/proposal/blob/master/A18-tcp-user-timeout.md).

gRPC Core issue for reference: https://github.com/grpc/grpc/issues/15889
---
 internal/syscall/syscall_linux.go    | 62 ++++++++++++++++++++++++++++
 internal/syscall/syscall_nonlinux.go | 18 +++++++-
 internal/transport/http2_client.go   | 28 ++++++++-----
 internal/transport/transport_test.go | 59 ++++++++++++++++++++++++++
 4 files changed, 156 insertions(+), 11 deletions(-)

diff --git a/internal/syscall/syscall_linux.go b/internal/syscall/syscall_linux.go
index 1c7cef610396..43281a3e078d 100644
--- a/internal/syscall/syscall_linux.go
+++ b/internal/syscall/syscall_linux.go
@@ -23,7 +23,10 @@
 package syscall
 
 import (
+	"fmt"
+	"net"
 	"syscall"
+	"time"
 
 	"golang.org/x/sys/unix"
 	"google.golang.org/grpc/grpclog"
@@ -65,3 +68,62 @@ func CPUTimeDiff(first *Rusage, latest *Rusage) (float64, float64) {
 
 	return uTimeElapsed, sTimeElapsed
 }
+
+// SetTCPUserTimeout sets the TCP user timeout on a connection's socket.
+// It does nothing and returns nil when conn is not a *net.TCPConn. The timeout
+// is handed to the socket option truncated to whole milliseconds.
+func SetTCPUserTimeout(conn net.Conn, timeout time.Duration) error {
+	tcpconn, ok := conn.(*net.TCPConn)
+	if !ok {
+		// not a TCP connection. exit early
+		return nil
+	}
+	rawConn, err := tcpconn.SyscallConn()
+	if err != nil {
+		return fmt.Errorf("error getting raw connection: %v", err)
+	}
+	// The setsockopt result is kept in its own variable: assigning it to err
+	// inside the closure would be overwritten by Control's return value.
+	var sockErr error
+	err = rawConn.Control(func(fd uintptr) {
+		sockErr = syscall.SetsockoptInt(int(fd), syscall.IPPROTO_TCP, unix.TCP_USER_TIMEOUT, int(timeout/time.Millisecond))
+	})
+	if err == nil {
+		err = sockErr
+	}
+	if err != nil {
+		return fmt.Errorf("error setting option on socket: %v", err)
+	}
+
+	return nil
+}
+
+// GetTCPUserTimeout gets the TCP user timeout on a connection's socket.
+// It returns an error when conn is not a *net.TCPConn.
+func GetTCPUserTimeout(conn net.Conn) (opt int, err error) {
+	tcpconn, ok := conn.(*net.TCPConn)
+	if !ok {
+		err = fmt.Errorf("conn is not *net.TCPConn. got %T", conn)
+		return
+	}
+	rawConn, err := tcpconn.SyscallConn()
+	if err != nil {
+		err = fmt.Errorf("error getting raw connection: %v", err)
+		return
+	}
+	// As in SetTCPUserTimeout, keep the getsockopt error apart from the error
+	// returned by Control so that neither one masks the other.
+	var sockErr error
+	err = rawConn.Control(func(fd uintptr) {
+		opt, sockErr = syscall.GetsockoptInt(int(fd), syscall.IPPROTO_TCP, unix.TCP_USER_TIMEOUT)
+	})
+	if err == nil {
+		err = sockErr
+	}
+	if err != nil {
+		err = fmt.Errorf("error getting option on socket: %v", err)
+		return
+	}
+
+	return
+}
diff --git a/internal/syscall/syscall_nonlinux.go b/internal/syscall/syscall_nonlinux.go
index 887c38bbd04e..61678feb0044 100644
--- a/internal/syscall/syscall_nonlinux.go
+++ b/internal/syscall/syscall_nonlinux.go
@@ -20,7 +20,12 @@
 
 package syscall
 
-import "google.golang.org/grpc/grpclog"
+import (
+	"net"
+	"time"
+
+	"google.golang.org/grpc/grpclog"
+)
 
 func init() {
 	grpclog.Info("CPU time info is unavailable on non-linux or appengine environment.")
@@ -45,3 +50,14 @@ func GetRusage() (rusage *Rusage) {
 func CPUTimeDiff(first *Rusage, latest *Rusage) (float64, float64) {
 	return 0, 0
 }
+
+// SetTCPUserTimeout is a no-op function under non-linux or appengine environments
+func SetTCPUserTimeout(conn net.Conn, timeout time.Duration) error {
+	return nil
+}
+
+// GetTCPUserTimeout is a no-op function under non-linux or appengine environments
+// a negative return value indicates the operation is not supported
+func GetTCPUserTimeout(conn net.Conn) (int, error) {
+	return -1, nil
+}
diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go
index a1e602658cdf..e7e881a2230d 100644
--- a/internal/transport/http2_client.go
+++ b/internal/transport/http2_client.go
@@ -36,6 +36,7 @@ import (
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/credentials"
 	"google.golang.org/grpc/internal/channelz"
+	"google.golang.org/grpc/internal/syscall"
 	"google.golang.org/grpc/keepalive"
 	"google.golang.org/grpc/metadata"
 	"google.golang.org/grpc/peer"
@@ -166,6 +167,21 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne
 			conn.Close()
 		}
 	}(conn)
+	kp := opts.KeepaliveParams
+	// Validate keepalive parameters.
+	if kp.Time == 0 {
+		kp.Time = defaultClientKeepaliveTime
+	}
+	if kp.Timeout == 0 {
+		kp.Timeout = defaultClientKeepaliveTimeout
+	}
+	keepaliveEnabled := false
+	if kp.Time != infinity {
+		if err = syscall.SetTCPUserTimeout(conn, kp.Timeout); err != nil {
+			return nil, connectionErrorf(false, err, "transport: failed to set TCP_USER_TIMEOUT: %v", err)
+		}
+		keepaliveEnabled = true
+	}
 	var (
 		isSecure bool
 		authInfo credentials.AuthInfo
@@ -189,14 +205,6 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne
 		}
 		isSecure = true
 	}
-	kp := opts.KeepaliveParams
-	// Validate keepalive parameters.
-	if kp.Time == 0 {
-		kp.Time = defaultClientKeepaliveTime
-	}
-	if kp.Timeout == 0 {
-		kp.Timeout = defaultClientKeepaliveTimeout
-	}
 	dynamicWindow := true
 	icwz := int32(initialWindowSize)
 	if opts.InitialConnWindowSize >= defaultWindowSize {
@@ -240,6 +248,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne
 		czData:                new(channelzData),
 		onGoAway:              onGoAway,
 		onClose:               onClose,
+		keepaliveEnabled:      keepaliveEnabled,
 	}
 	t.controlBuf = newControlBuffer(t.ctxDone)
 	if opts.InitialWindowSize >= defaultWindowSize {
@@ -268,8 +277,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne
 	if channelz.IsOn() {
 		t.channelzID = channelz.RegisterNormalSocket(t, opts.ChannelzParentID, fmt.Sprintf("%s -> %s", t.localAddr, t.remoteAddr))
 	}
-	if t.kp.Time != infinity {
-		t.keepaliveEnabled = true
+	if t.keepaliveEnabled {
 		go t.keepalive()
 	}
 	// Start the reader goroutine for incoming message. Each transport has
diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go
index 45e29be8cfb6..3911a2925d19 100644
--- a/internal/transport/transport_test.go
+++ b/internal/transport/transport_test.go
@@ -41,6 +41,7 @@ import (
 	"golang.org/x/net/http2/hpack"
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/internal/leakcheck"
+	"google.golang.org/grpc/internal/syscall"
 	"google.golang.org/grpc/keepalive"
 	"google.golang.org/grpc/status"
 )
@@ -2317,3 +2318,61 @@ func TestHeaderTblSize(t *testing.T) {
 		t.Fatalf("expected len(limits) = 2 within 10s, got != 2")
 	}
 }
+
+// TestTCPUserTimeout tests that the TCP_USER_TIMEOUT socket option is set to the
+// keepalive timeout, as detailed in proposal A18
+func TestTCPUserTimeout(t *testing.T) {
+	tests := []struct {
+		time    time.Duration
+		timeout time.Duration
+	}{
+		{
+			10 * time.Second,
+			10 * time.Second,
+		},
+		{
+			0,
+			0,
+		},
+	}
+	for _, tt := range tests {
+		server, client, cancel := setUpWithOptions(
+			t,
+			0,
+			&ServerConfig{
+				KeepaliveParams: keepalive.ServerParameters{
+					Time:    tt.timeout,
+					Timeout: tt.timeout,
+				},
+			},
+			normal,
+			ConnectOptions{
+				KeepaliveParams: keepalive.ClientParameters{
+					Time:    tt.time,
+					Timeout: tt.timeout,
+				},
+			},
+		)
+		defer cancel()
+		defer server.stop()
+		defer client.Close()
+
+		stream, err := client.NewStream(context.Background(), &CallHdr{})
+		if err != nil {
+			t.Fatalf("Client failed to create RPC request: %v", err)
+		}
+		client.closeStream(stream, io.EOF, true, http2.ErrCodeCancel, nil, nil, false)
+
+		opt, err := syscall.GetTCPUserTimeout(client.conn)
+		if err != nil {
+			t.Fatalf("GetTCPUserTimeout error: %v", err)
+		}
+		if opt < 0 {
+			t.Skipf("skipping test on unsupported environment")
+		}
+		if timeoutMS := int(tt.timeout / time.Millisecond); timeoutMS != opt {
+			t.Fatalf("wrong TCP_USER_TIMEOUT set on conn. expected %d. got %d",
+				timeoutMS, opt)
+		}
+	}
+}