Skip to content

Commit

Permalink
pdutil: fix retry reusing body reader (#48312)
Browse files Browse the repository at this point in the history
close #48307
  • Loading branch information
D3Hunter committed Nov 6, 2023
1 parent 0d9a4ce commit e2d3047
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 19 deletions.
25 changes: 15 additions & 10 deletions br/pkg/pdutil/pd.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,21 +130,21 @@ var (

// pdHTTPRequest defines the interface to send a request to pd and return the result in bytes.
type pdHTTPRequest func(ctx context.Context, addr string, prefix string,
cli *http.Client, method string, body io.Reader) ([]byte, error)
cli *http.Client, method string, body []byte) ([]byte, error)

// pdRequest is a func to send an HTTP to pd and return the result bytes.
func pdRequest(
ctx context.Context,
addr string, prefix string,
cli *http.Client, method string, body io.Reader) ([]byte, error) {
cli *http.Client, method string, body []byte) ([]byte, error) {
_, respBody, err := pdRequestWithCode(ctx, addr, prefix, cli, method, body)
return respBody, err
}

func pdRequestWithCode(
ctx context.Context,
addr string, prefix string,
cli *http.Client, method string, body io.Reader) (int, []byte, error) {
cli *http.Client, method string, body []byte) (int, []byte, error) {
u, err := url.Parse(addr)
if err != nil {
return 0, nil, errors.Trace(err)
Expand All @@ -154,10 +154,13 @@ func pdRequestWithCode(
req *http.Request
resp *http.Response
)
if body == nil {
body = []byte("")
}
count := 0
// the total retry duration: 120*1 = 2min
for {
req, err = http.NewRequestWithContext(ctx, method, reqURL, body)
req, err = http.NewRequestWithContext(ctx, method, reqURL, bytes.NewBuffer(body))
if err != nil {
return 0, nil, errors.Trace(err)
}
Expand All @@ -184,6 +187,8 @@ func pdRequestWithCode(
(err != nil && !common.IsRetryableError(err)) {
break
}
log.Warn("request failed, will retry later",
zap.String("url", reqURL), zap.Int("retry-count", count), zap.Error(err))
if resp != nil {
_ = resp.Body.Close()
}
Expand Down Expand Up @@ -434,7 +439,7 @@ func (p *PdController) doPauseSchedulers(ctx context.Context,
removedSchedulers := make([]string, 0, len(schedulers))
for _, scheduler := range schedulers {
for _, addr := range p.getAllPDAddrs() {
_, err = post(ctx, addr, pdapi.SchedulerByName(scheduler), p.cli, http.MethodPost, bytes.NewBuffer(body))
_, err = post(ctx, addr, pdapi.SchedulerByName(scheduler), p.cli, http.MethodPost, body)
if err == nil {
removedSchedulers = append(removedSchedulers, scheduler)
break
Expand Down Expand Up @@ -516,7 +521,7 @@ func (p *PdController) resumeSchedulerWith(ctx context.Context, schedulers []str
}
for _, scheduler := range schedulers {
for _, addr := range p.getAllPDAddrs() {
_, err = post(ctx, addr, pdapi.SchedulerByName(scheduler), p.cli, http.MethodPost, bytes.NewBuffer(body))
_, err = post(ctx, addr, pdapi.SchedulerByName(scheduler), p.cli, http.MethodPost, body)
if err == nil {
break
}
Expand Down Expand Up @@ -605,7 +610,7 @@ func (p *PdController) doUpdatePDScheduleConfig(
return errors.Trace(err)
}
_, e := post(ctx, addr, prefix,
p.cli, http.MethodPost, bytes.NewBuffer(reqData))
p.cli, http.MethodPost, reqData)
if e == nil {
return nil
}
Expand Down Expand Up @@ -861,7 +866,7 @@ func (p *PdController) RecoverBaseAllocID(ctx context.Context, id uint64) error
})
var err error
for _, addr := range p.getAllPDAddrs() {
_, e := pdRequest(ctx, addr, pdapi.BaseAllocID, p.cli, http.MethodPost, bytes.NewBuffer(reqData))
_, e := pdRequest(ctx, addr, pdapi.BaseAllocID, p.cli, http.MethodPost, reqData)
if e != nil {
log.Warn("failed to recover base alloc id", zap.String("addr", addr), zap.Error(e))
err = e
Expand All @@ -885,7 +890,7 @@ func (p *PdController) ResetTS(ctx context.Context, ts uint64) error {
})
var err error
for _, addr := range p.getAllPDAddrs() {
code, _, e := pdRequestWithCode(ctx, addr, pdapi.ResetTS, p.cli, http.MethodPost, bytes.NewBuffer(reqData))
code, _, e := pdRequestWithCode(ctx, addr, pdapi.ResetTS, p.cli, http.MethodPost, reqData)
if e != nil {
// for pd version <= 6.2, if the given ts < current ts of pd, pd returns StatusForbidden.
// it's not an error for br
Expand Down Expand Up @@ -962,7 +967,7 @@ func (p *PdController) CreateOrUpdateRegionLabelRule(ctx context.Context, rule L
addrs := p.getAllPDAddrs()
for i, addr := range addrs {
_, lastErr = pdRequest(ctx, addr, pdapi.RegionLabelRule,
p.cli, http.MethodPost, bytes.NewBuffer(reqData))
p.cli, http.MethodPost, reqData)
if lastErr == nil {
return nil
}
Expand Down
19 changes: 10 additions & 9 deletions br/pkg/pdutil/pd_serial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
package pdutil

import (
"bytes"
"context"
"encoding/hex"
"encoding/json"
Expand Down Expand Up @@ -32,7 +31,7 @@ func TestScheduler(t *testing.T) {
defer cancel()

scheduler := "balance-leader-scheduler"
mock := func(context.Context, string, string, *http.Client, string, io.Reader) ([]byte, error) {
mock := func(context.Context, string, string, *http.Client, string, []byte) ([]byte, error) {
return nil, errors.New("failed")
}
schedulerPauseCh := make(chan struct{})
Expand Down Expand Up @@ -67,7 +66,7 @@ func TestScheduler(t *testing.T) {
_, err = pdController.listSchedulersWith(ctx, mock)
require.EqualError(t, err, "failed")

mock = func(context.Context, string, string, *http.Client, string, io.Reader) ([]byte, error) {
mock = func(context.Context, string, string, *http.Client, string, []byte) ([]byte, error) {
return []byte(`["` + scheduler + `"]`), nil
}

Expand All @@ -87,7 +86,7 @@ func TestScheduler(t *testing.T) {
func TestGetClusterVersion(t *testing.T) {
pdController := &PdController{addrs: []string{"", ""}} // two endpoints
counter := 0
mock := func(context.Context, string, string, *http.Client, string, io.Reader) ([]byte, error) {
mock := func(context.Context, string, string, *http.Client, string, []byte) ([]byte, error) {
counter++
if counter <= 1 {
return nil, errors.New("mock error")
Expand All @@ -100,7 +99,7 @@ func TestGetClusterVersion(t *testing.T) {
require.NoError(t, err)
require.Equal(t, "test", respString)

mock = func(context.Context, string, string, *http.Client, string, io.Reader) ([]byte, error) {
mock = func(context.Context, string, string, *http.Client, string, []byte) ([]byte, error) {
return nil, errors.New("mock error")
}
_, err = pdController.getClusterVersionWith(ctx, mock)
Expand Down Expand Up @@ -130,7 +129,7 @@ func TestRegionCount(t *testing.T) {
require.Equal(t, 3, len(regions.Regions))

mock := func(
_ context.Context, addr string, prefix string, _ *http.Client, _ string, _ io.Reader,
_ context.Context, addr string, prefix string, _ *http.Client, _ string, _ []byte,
) ([]byte, error) {
query := fmt.Sprintf("%s/%s", addr, prefix)
u, e := url.Parse(query)
Expand Down Expand Up @@ -181,6 +180,9 @@ func TestPDRequestRetry(t *testing.T) {
count := 0
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
count++
bytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
require.Equal(t, "test", string(bytes))
if count <= pdRequestRetryTime-1 {
w.WriteHeader(http.StatusGatewayTimeout)
return
Expand All @@ -196,8 +198,7 @@ func TestPDRequestRetry(t *testing.T) {
cli.Transport.(*http.Transport).DisableKeepAlives = true

taddr := ts.URL
body := bytes.NewBuffer([]byte("test"))
_, reqErr := pdRequest(ctx, taddr, "", cli, http.MethodPost, body)
_, reqErr := pdRequest(ctx, taddr, "", cli, http.MethodPost, []byte("test"))
require.NoError(t, reqErr)
ts.Close()
count = 0
Expand Down Expand Up @@ -269,7 +270,7 @@ func TestStoreInfo(t *testing.T) {
},
}
mock := func(
_ context.Context, addr string, prefix string, _ *http.Client, _ string, _ io.Reader,
_ context.Context, addr string, prefix string, _ *http.Client, _ string, _ []byte,
) ([]byte, error) {
require.Equal(t,
fmt.Sprintf("http://mock%s", pdapi.StoreByID(1)),
Expand Down

0 comments on commit e2d3047

Please sign in to comment.