diff --git a/pkg/transport/listener_tls.go b/pkg/transport/listener_tls.go index 53e6a104890..ecc12454899 100644 --- a/pkg/transport/listener_tls.go +++ b/pkg/transport/listener_tls.go @@ -15,7 +15,9 @@ package transport import ( + "context" "crypto/tls" + "crypto/x509" "fmt" "net" "sync" @@ -40,11 +42,16 @@ func newTLSListener(l net.Listener, tlsinfo *TLSInfo) (net.Listener, error) { if err != nil { return nil, err } + + hf := tlsinfo.HandshakeFailure + if hf == nil { + hf = func(*tls.Conn, error) {} + } tlsl := &tlsListener{ Listener: tls.NewListener(l, tlscfg), connc: make(chan net.Conn), donec: make(chan struct{}), - handshakeFailure: tlsinfo.HandshakeFailure, + handshakeFailure: hf, } go tlsl.acceptLoop() return tlsl, nil @@ -66,9 +73,9 @@ func (l *tlsListener) acceptLoop() { var pendingMu sync.Mutex pending := make(map[net.Conn]struct{}) - stopc := make(chan struct{}) + ctx, cancel := context.WithCancel(context.Background()) defer func() { - close(stopc) + cancel() pendingMu.Lock() for c := range pending { c.Close() @@ -104,32 +111,58 @@ func (l *tlsListener) acceptLoop() { delete(pending, conn) pendingMu.Unlock() if herr != nil { - if l.handshakeFailure != nil { - l.handshakeFailure(tlsConn, herr) - } + l.handshakeFailure(tlsConn, herr) return } st := tlsConn.ConnectionState() if len(st.PeerCertificates) > 0 { cert := st.PeerCertificates[0] - if len(cert.IPAddresses) > 0 || len(cert.DNSNames) > 0 { - addr := tlsConn.RemoteAddr().String() - h, _, herr := net.SplitHostPort(addr) - if herr != nil || cert.VerifyHostname(h) != nil { - return - } + addr := tlsConn.RemoteAddr().String() + if cerr := checkCert(ctx, cert, addr); cerr != nil { + l.handshakeFailure(tlsConn, cerr) + return } } select { case l.connc <- tlsConn: conn = nil - case <-stopc: + case <-ctx.Done(): } }() } } +func checkCert(ctx context.Context, cert *x509.Certificate, remoteAddr string) error { + h, _, herr := net.SplitHostPort(remoteAddr) + if len(cert.IPAddresses) == 0 && len(cert.DNSNames) == 0 { + return nil + } + if herr != nil { + return herr + } + if len(cert.IPAddresses) > 0 { + if cerr := cert.VerifyHostname(h); cerr != nil && len(cert.DNSNames) == 0 { + return cerr + } + } + if len(cert.DNSNames) > 0 { + for _, dns := range cert.DNSNames { + addrs, lerr := net.DefaultResolver.LookupHost(ctx, dns) + if lerr != nil { + continue + } + for _, addr := range addrs { + if addr == h { + return nil + } + } + } + return fmt.Errorf("tls: %q does not match any of DNSNames %q", h, cert.DNSNames) + } + return nil +} + func (l *tlsListener) Close() error { err := l.Listener.Close() <-l.donec