From b0102c9eba67ee12b55421f80753c9b5dc9a6142 Mon Sep 17 00:00:00 2001 From: Erik Dubbelboer Date: Thu, 23 Jan 2020 18:30:36 +0100 Subject: [PATCH] Support calling Serve multiple times on a Server (#731) You can use the following methods in the handler to find out which listener the connection is coming in on. RequestCtx.IsTLS() RequestCtx.LocalAddr() RequestCtx.Request.Header.Host() --- server.go | 26 +++++++++++++++----------- server_test.go | 44 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 11 deletions(-) diff --git a/server.go b/server.go index 1c652b0c7e..7b017b24b1 100644 --- a/server.go +++ b/server.go @@ -359,8 +359,8 @@ type Server struct { writerPool sync.Pool hijackConnPool sync.Pool - // We need to know our listener so we can close it in Shutdown(). - ln net.Listener + // We need to know our listeners so we can close them in Shutdown(). + ln []net.Listener mu sync.Mutex open int32 @@ -1577,20 +1577,21 @@ func (s *Server) Serve(ln net.Listener) error { var c net.Conn var err error + maxWorkersCount := s.getConcurrency() + s.mu.Lock() { - if s.ln != nil { - s.mu.Unlock() - return ErrAlreadyServing + s.ln = append(s.ln, ln) + if s.done == nil { + s.done = make(chan struct{}) } - s.ln = ln - s.done = make(chan struct{}) + if s.concurrencyCh == nil { + s.concurrencyCh = make(chan struct{}, maxWorkersCount) + } } s.mu.Unlock() - maxWorkersCount := s.getConcurrency() - s.concurrencyCh = make(chan struct{}, maxWorkersCount) wp := &workerPool{ WorkerFunc: s.serveConn, MaxWorkersCount: maxWorkersCount, @@ -1663,8 +1664,10 @@ func (s *Server) Shutdown() error { return nil } - if err := s.ln.Close(); err != nil { - return err + for _, ln := range s.ln { + if err := ln.Close(); err != nil { + return err + } } if s.done != nil { @@ -1684,6 +1687,7 @@ func (s *Server) Shutdown() error { time.Sleep(time.Millisecond * 100) } + s.done = nil s.ln = nil return nil } diff --git a/server_test.go b/server_test.go index bf6a110f0b..d8b265bdc6 100644 --- a/server_test.go +++ b/server_test.go @@ -3047,6 +3047,50 @@ func TestShutdownErr(t *testing.T) { verifyResponse(t, br, StatusOK, "aaa/bbb", "real response") } +func TestMultipleServe(t *testing.T) { + t.Parallel() + + s := &Server{ + Handler: func(ctx *RequestCtx) { + ctx.Success("aaa/bbb", []byte("real response")) + }, + } + + ln1 := fasthttputil.NewInmemoryListener() + ln2 := fasthttputil.NewInmemoryListener() + + go func() { + if err := s.Serve(ln1); err != nil { + t.Errorf("unexepcted error: %s", err) + } + }() + go func() { + if err := s.Serve(ln2); err != nil { + t.Errorf("unexepcted error: %s", err) + } + }() + + conn, err := ln1.Dial() + if err != nil { + t.Fatalf("unexepcted error: %s", err) + } + if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil { + t.Fatalf("unexpected error: %s", err) + } + br := bufio.NewReader(conn) + verifyResponse(t, br, StatusOK, "aaa/bbb", "real response") + + conn, err = ln2.Dial() + if err != nil { + t.Fatalf("unexepcted error: %s", err) + } + if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil { + t.Fatalf("unexpected error: %s", err) + } + br = bufio.NewReader(conn) + verifyResponse(t, br, StatusOK, "aaa/bbb", "real response") +} + func TestMaxBodySizePerRequest(t *testing.T) { t.Parallel()