Skip to content

Commit

Permalink
Merge pull request #12702 from hexfusion/add-so
Browse files Browse the repository at this point in the history
*: add support for socket options
  • Loading branch information
gyuho authored Mar 9, 2021
2 parents 792b7f5 + 5b49fb4 commit 6fd85af
Show file tree
Hide file tree
Showing 17 changed files with 463 additions and 51 deletions.
72 changes: 66 additions & 6 deletions pkg/transport/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package transport

import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
Expand All @@ -39,18 +40,66 @@ import (

// NewListener creates a new listner.
func NewListener(addr, scheme string, tlsinfo *TLSInfo) (l net.Listener, err error) {
if l, err = newListener(addr, scheme); err != nil {
return nil, err
}
return wrapTLS(scheme, tlsinfo, l)
return newListener(addr, scheme, WithTLSInfo(tlsinfo))
}

// NewListenerWithOpts creates a new listener which accpets listener options.
func NewListenerWithOpts(addr, scheme string, opts ...ListenerOption) (net.Listener, error) {
return newListener(addr, scheme, opts...)
}

func newListener(addr string, scheme string) (net.Listener, error) {
func newListener(addr, scheme string, opts ...ListenerOption) (net.Listener, error) {
if scheme == "unix" || scheme == "unixs" {
// unix sockets via unix://laddr
return NewUnixListener(addr)
}
return net.Listen("tcp", addr)

lnOpts := newListenOpts(opts...)

switch {
case lnOpts.IsSocketOpts():
// new ListenConfig with socket options.
config, err := newListenConfig(lnOpts.socketOpts)
if err != nil {
return nil, err
}
lnOpts.ListenConfig = config
// check for timeout
fallthrough
case lnOpts.IsTimeout(), lnOpts.IsSocketOpts():
// timeout listener with socket options.
ln, err := lnOpts.ListenConfig.Listen(context.TODO(), "tcp", addr)
if err != nil {
return nil, err
}
lnOpts.Listener = &rwTimeoutListener{
Listener: ln,
readTimeout: lnOpts.readTimeout,
writeTimeout: lnOpts.writeTimeout,
}
case lnOpts.IsTimeout():
ln, err := net.Listen("tcp", addr)
if err != nil {
return nil, err
}
lnOpts.Listener = &rwTimeoutListener{
Listener: ln,
readTimeout: lnOpts.readTimeout,
writeTimeout: lnOpts.writeTimeout,
}
default:
ln, err := net.Listen("tcp", addr)
if err != nil {
return nil, err
}
lnOpts.Listener = ln
}

// only skip if not passing TLSInfo
if lnOpts.skipTLSInfoCheck && !lnOpts.IsTLS() {
return lnOpts.Listener, nil
}
return wrapTLS(scheme, lnOpts.tlsInfo, lnOpts.Listener)
}

func wrapTLS(scheme string, tlsinfo *TLSInfo, l net.Listener) (net.Listener, error) {
Expand All @@ -63,6 +112,17 @@ func wrapTLS(scheme string, tlsinfo *TLSInfo, l net.Listener) (net.Listener, err
return newTLSListener(l, tlsinfo, checkSAN)
}

func newListenConfig(sopts *SocketOpts) (net.ListenConfig, error) {
lc := net.ListenConfig{}
if sopts != nil {
ctls := getControls(sopts)
if len(ctls) > 0 {
lc.Control = ctls.Control
}
}
return lc, nil
}

type TLSInfo struct {
// CertFile is the _server_ cert, it will also be used as a _client_ certificate if ClientCertFile is empty
CertFile string
Expand Down
76 changes: 76 additions & 0 deletions pkg/transport/listener_opts.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package transport

import (
"net"
"time"
)

type ListenerOptions struct {
Listener net.Listener
ListenConfig net.ListenConfig

socketOpts *SocketOpts
tlsInfo *TLSInfo
skipTLSInfoCheck bool
writeTimeout time.Duration
readTimeout time.Duration
}

func newListenOpts(opts ...ListenerOption) *ListenerOptions {
lnOpts := &ListenerOptions{}
lnOpts.applyOpts(opts)
return lnOpts
}

func (lo *ListenerOptions) applyOpts(opts []ListenerOption) {
for _, opt := range opts {
opt(lo)
}
}

// IsTimeout returns true if the listener has a read/write timeout defined.
func (lo *ListenerOptions) IsTimeout() bool { return lo.readTimeout != 0 || lo.writeTimeout != 0 }

// IsSocketOpts returns true if the listener options includes socket options.
func (lo *ListenerOptions) IsSocketOpts() bool {
if lo.socketOpts == nil {
return false
}
return lo.socketOpts.ReusePort == true || lo.socketOpts.ReuseAddress == true
}

// IsTLS returns true if listner options includes TLSInfo.
func (lo *ListenerOptions) IsTLS() bool {
if lo.tlsInfo == nil {
return false
}
return lo.tlsInfo.Empty() == false
}

// ListenerOption are options which can be applied to the listener.
type ListenerOption func(*ListenerOptions)

// WithTimeout allows for a read or write timeout to be applied to the listener.
func WithTimeout(read, write time.Duration) ListenerOption {
return func(lo *ListenerOptions) {
lo.writeTimeout = write
lo.readTimeout = read
}
}

// WithSocketOpts defines socket options that will be applied to the listener.
func WithSocketOpts(s *SocketOpts) ListenerOption {
return func(lo *ListenerOptions) { lo.socketOpts = s }
}

// WithTLSInfo adds TLS credentials to the listener.
func WithTLSInfo(t *TLSInfo) ListenerOption {
return func(lo *ListenerOptions) { lo.tlsInfo = t }
}

// WithSkipTLSInfoCheck when true a transport can be created with an https scheme
// without passing TLSInfo, circumventing not presented error. Skipping this check
// also requires that TLSInfo is not passed.
func WithSkipTLSInfoCheck(skip bool) ListenerOption {
return func(lo *ListenerOptions) { lo.skipTLSInfoCheck = skip }
}
173 changes: 173 additions & 0 deletions pkg/transport/listener_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,161 @@ func TestNewListenerTLSInfo(t *testing.T) {
testNewListenerTLSInfoAccept(t, *tlsInfo)
}

func TestNewListenerWithOpts(t *testing.T) {
tlsInfo, del, err := createSelfCert()
if err != nil {
t.Fatalf("unable to create cert: %v", err)
}
defer del()

tests := map[string]struct {
opts []ListenerOption
scheme string
expectedErr bool
}{
"https scheme no TLSInfo": {
opts: []ListenerOption{},
expectedErr: true,
scheme: "https",
},
"https scheme no TLSInfo with skip check": {
opts: []ListenerOption{WithSkipTLSInfoCheck(true)},
expectedErr: false,
scheme: "https",
},
"https scheme empty TLSInfo with skip check": {
opts: []ListenerOption{
WithSkipTLSInfoCheck(true),
WithTLSInfo(&TLSInfo{}),
},
expectedErr: false,
scheme: "https",
},
"https scheme empty TLSInfo no skip check": {
opts: []ListenerOption{
WithTLSInfo(&TLSInfo{}),
},
expectedErr: true,
scheme: "https",
},
"https scheme with TLSInfo and skip check": {
opts: []ListenerOption{
WithSkipTLSInfoCheck(true),
WithTLSInfo(tlsInfo),
},
expectedErr: false,
scheme: "https",
},
}
for testName, test := range tests {
t.Run(testName, func(t *testing.T) {
ln, err := NewListenerWithOpts("127.0.0.1:0", test.scheme, test.opts...)
if ln != nil {
defer ln.Close()
}
if test.expectedErr && err == nil {
t.Fatalf("expected error")
}
if !test.expectedErr && err != nil {
t.Fatalf("unexpected error: %v", err)
}
})
}
}

func TestNewListenerWithSocketOpts(t *testing.T) {
tlsInfo, del, err := createSelfCert()
if err != nil {
t.Fatalf("unable to create cert: %v", err)
}
defer del()

tests := map[string]struct {
opts []ListenerOption
scheme string
expectedErr bool
}{
"nil socketopts": {
opts: []ListenerOption{WithSocketOpts(nil)},
expectedErr: true,
scheme: "http",
},
"empty socketopts": {
opts: []ListenerOption{WithSocketOpts(&SocketOpts{})},
expectedErr: true,
scheme: "http",
},

"reuse address": {
opts: []ListenerOption{WithSocketOpts(&SocketOpts{ReuseAddress: true})},
scheme: "http",
expectedErr: true,
},
"reuse address with TLS": {
opts: []ListenerOption{
WithSocketOpts(&SocketOpts{ReuseAddress: true}),
WithTLSInfo(tlsInfo),
},
scheme: "https",
expectedErr: true,
},
"reuse address and port": {
opts: []ListenerOption{WithSocketOpts(&SocketOpts{ReuseAddress: true, ReusePort: true})},
scheme: "http",
expectedErr: false,
},
"reuse address and port with TLS": {
opts: []ListenerOption{
WithSocketOpts(&SocketOpts{ReuseAddress: true, ReusePort: true}),
WithTLSInfo(tlsInfo),
},
scheme: "https",
expectedErr: false,
},
"reuse port with TLS and timeout": {
opts: []ListenerOption{
WithSocketOpts(&SocketOpts{ReusePort: true}),
WithTLSInfo(tlsInfo),
WithTimeout(5*time.Second, 5*time.Second),
},
scheme: "https",
expectedErr: false,
},
"reuse port with https scheme and no TLSInfo skip check": {
opts: []ListenerOption{
WithSocketOpts(&SocketOpts{ReusePort: true}),
WithSkipTLSInfoCheck(true),
},
scheme: "https",
expectedErr: false,
},
"reuse port": {
opts: []ListenerOption{WithSocketOpts(&SocketOpts{ReusePort: true})},
scheme: "http",
expectedErr: false,
},
}
for testName, test := range tests {
t.Run(testName, func(t *testing.T) {
ln, err := NewListenerWithOpts("127.0.0.1:0", test.scheme, test.opts...)
if err != nil {
t.Fatalf("unexpected NewListenerWithSocketOpts error: %v", err)
}
defer ln.Close()
ln2, err := NewListenerWithOpts(ln.Addr().String(), test.scheme, test.opts...)
if ln2 != nil {
ln2.Close()
}
if test.expectedErr && err == nil {
t.Fatalf("expected error")
}
if !test.expectedErr && err != nil {
t.Fatalf("unexpected error: %v", err)
}
})
}
}

func testNewListenerTLSInfoAccept(t *testing.T, tlsInfo TLSInfo) {
ln, err := NewListener("127.0.0.1:0", "https", &tlsInfo)
if err != nil {
Expand Down Expand Up @@ -401,3 +556,21 @@ func TestIsClosedConnError(t *testing.T) {
t.Fatalf("expect true, got false (%v)", err)
}
}

func TestSocktOptsEmpty(t *testing.T) {
tests := []struct {
sopts SocketOpts
want bool
}{
{SocketOpts{}, true},
{SocketOpts{ReuseAddress: true, ReusePort: false}, false},
{SocketOpts{ReusePort: true}, false},
}

for i, tt := range tests {
got := tt.sopts.Empty()
if tt.want != got {
t.Errorf("#%d: result of Empty() incorrect: want=%t got=%t", i, tt.want, got)
}
}
}
Loading

0 comments on commit 6fd85af

Please sign in to comment.