Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PipelineClient name #994

Merged
merged 8 commits into from
Mar 15, 2021
37 changes: 37 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2120,6 +2120,13 @@ type PipelineClient struct {
// Address of the host to connect to.
Addr string

// PipelineClient name. Used in User-Agent request header.
Name string

// NoDefaultUserAgentHeader when set to true, causes the default
// User-Agent header to be excluded from the Request.
NoDefaultUserAgentHeader bool

// The maximum number of concurrent connections to the Addr.
//
// A single connection is used by default.
Expand Down Expand Up @@ -2225,6 +2232,8 @@ type pipelineConnClient struct {
noCopy noCopy //nolint:unused,structcheck

Addr string
Name string
NoDefaultUserAgentHeader bool
MaxPendingRequests int
MaxBatchDelay time.Duration
Dial DialFunc
Expand All @@ -2248,6 +2257,7 @@ type pipelineConnClient struct {

tlsConfigLock sync.Mutex
tlsConfig *tls.Config
clientName atomic.Value
}

type pipelineWork struct {
Expand Down Expand Up @@ -2316,6 +2326,11 @@ func (c *pipelineConnClient) DoDeadline(req *Request, resp *Response, deadline t
req.URI().DisablePathNormalizing = true
}

userAgentOld := req.Header.UserAgent()
if len(userAgentOld) == 0 {
req.Header.userAgent = append(req.Header.userAgent[:0], c.getClientName()...)
}

w := acquirePipelineWork(&c.workPool, timeout)
w.respCopy.Header.disableNormalizing = c.DisableHeaderNamesNormalizing
w.req = &w.reqCopy
Expand Down Expand Up @@ -2380,6 +2395,11 @@ func (c *pipelineConnClient) Do(req *Request, resp *Response) error {
req.URI().DisablePathNormalizing = true
}

userAgentOld := req.Header.UserAgent()
if len(userAgentOld) == 0 {
req.Header.userAgent = append(req.Header.userAgent[:0], c.getClientName()...)
}

w := acquirePipelineWork(&c.workPool, 0)
w.req = req
if resp != nil {
Expand Down Expand Up @@ -2459,6 +2479,8 @@ func (c *PipelineClient) getConnClientUnlocked() *pipelineConnClient {
func (c *PipelineClient) newConnClient() *pipelineConnClient {
cc := &pipelineConnClient{
Addr: c.Addr,
Name: c.Name,
NoDefaultUserAgentHeader: c.NoDefaultUserAgentHeader,
MaxPendingRequests: c.MaxPendingRequests,
MaxBatchDelay: c.MaxBatchDelay,
Dial: c.Dial,
Expand Down Expand Up @@ -2770,6 +2792,21 @@ func (c *pipelineConnClient) PendingRequests() int {
return n
}

func (c *pipelineConnClient) getClientName() []byte {
v := c.clientName.Load()
var clientName []byte
if v == nil {
clientName = []byte(c.Name)
if len(clientName) == 0 && !c.NoDefaultUserAgentHeader {
clientName = defaultUserAgent
}
c.clientName.Store(clientName)
} else {
clientName = v.([]byte)
}
return clientName
}

var errPipelineConnStopped = errors.New("pipeline connection has been stopped")

func acquirePipelineWork(pool *sync.Pool, timeout time.Duration) *pipelineWork {
Expand Down
50 changes: 50 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,56 @@ func TestCloseIdleConnections(t *testing.T) {
}
}

func TestPipelineClientSetUserAgent(t *testing.T) {
t.Parallel()

testPipelineClientSetUserAgent(t, 0)
}

func TestPipelineClientSetUserAgentTimeout(t *testing.T) {
t.Parallel()

testPipelineClientSetUserAgent(t, time.Second)
}

func testPipelineClientSetUserAgent(t *testing.T, timeout time.Duration) {
ln := fasthttputil.NewInmemoryListener()

userAgentSeen := ""
s := &Server{
Handler: func(ctx *RequestCtx) {
userAgentSeen = string(ctx.UserAgent())
},
}
go s.Serve(ln) //nolint:errcheck

userAgent := "I'm not fasthttp"
c := &HostClient{
Name: userAgent,
Dial: func(addr string) (net.Conn, error) {
return ln.Dial()
},
}
req := AcquireRequest()
res := AcquireResponse()

req.SetRequestURI("http://example.com")

var err error
if timeout <= 0 {
err = c.Do(req, res)
} else {
err = c.DoTimeout(req, res, timeout)
}

if err != nil {
t.Fatal(err)
}
if userAgentSeen != userAgent {
t.Fatalf("User-Agent defers %q != %q", userAgentSeen, userAgent)
}
}

func TestPipelineClientIssue832(t *testing.T) {
t.Parallel()

Expand Down